【深度学习 pytorch】迁移学习 (迁移ResNet18)

news2024/12/26 11:40:19

李宏毅深度学习笔记
《深度学习原理Pytorch实战》
https://blog.csdn.net/peter6768/article/details/135712687

迁移学习

实际应用中很多任务的数据的标注成本很高,无法获得充足的训练数据,这种情况可以使用迁移学习(transfer learning)。假设A、B是两个相关的任务,A任务有很多训练数据,就可以把从A任务中学习到的某些可以泛化知识迁移到B任务。

有了迁移学习技术,我们便可以将神经网络像软件模块一样进行拼装和重复利用。例如,我们可以将在大数据集上训练好的大型网络迁移到小数据集上,从而只需经过少量的训练就能达到良好的效果。我们也可以将两个神经网络同时迁移过来,组合成一个新的网络,这两个神经网络就像软件模块一样被组合了起来。

监督学习要求训练集和测试集上的数据具有相同的分布特性。两个数据集具有相同的分布就意味着,两者中猫和狗的比例(甚至胖猫、瘦猫和胖狗、瘦狗的分布比例)大体相同,这样才能保证模型在训练集中学习到稳定的特征,从而应用到测试集中。
而迁移学习则不同,它允许训练集和测试集的数据有不同的分布、目标,甚至领域。

迁移学习方式:

  • 预训练模式:将迁移过来的权重视作新网络的初始权重,但是它在训练的过程中会被梯度下降算法改变数值。
  • 固定值模式:迁移过来的部分网络在结构和权重上都保持固定的数值,训练过程仅针对迁移模块后面的全连接网络。

贫困预测场景

一个地区的遥感图像大体能够反映该地区的贫困状况,因为贫困地区的街道布置往往更加混乱。但是,要想训练一个深度卷积神经网络来预测贫困地区,除了需要大量输入图像,还需要对每一张图像进行贫困程度的标注。
由于非洲贫困地区可获得的贫困数据非常少,仅有大概600多个数据点,这对于训练一个大型的卷积神经网络来说远远不够(对于一个8层的卷积神经网络来说,训练数据的量级至少要达到数百万)。所以,我们需要对遥感图像进行手工标注。然而,这种标注工作量巨大,简直就是一个不可能完成的任务。

应用迁移学习技术,解决标注数据缺失问题的方法。首先,训练一个卷积神经网络,用遥感图像来预测夜光亮度;然后,将训练好的网络迁移到运用遥感图像预测贫困地区的任务中,这样即使训练数据仅有几百个,我们照样可以进行更准确的预测。
在这里插入图片描述

1、首先,他们使用预训练的方法,将用于图像分类的大型卷积神经网络VGGF(包含8个卷积层)迁移过来,作为初始分类网络。物体分类网络VGGF是经过干万张图像训练而成的,它已经学会了如何对图像进行特征提取,例如提取物体的边缘等。
2、其次,在预训练好的VGGF网络上,应用卫星遥感影像数据和夜光影像数据对其进行训练。该模型的输入为某地区的卫星遥感图像,输出为该地区夜间明暗程度的预测。由于夜光数据很容易获得,因此将它作为标签,可以轻松获得数十万个成对的训练数据。另外,当卷积神经网络尝试预测夜光时,它需要学会有效地从卫星遥感图像中提取特定的特征,例如街道、房屋屋顶、混凝土建筑等。这样学到的网络就是一个能从卫星遥感图像中有效提炼特征的特征提取器。
3、然后,我们将用于预测夜光的神经网络的卷积层迁移过来,拼接一个新的全连接网络,用于预测一个地区的贫困程度。在这一部分,我们将采取固定值迁移方法,仅训练全连接网部分。这样便可以应用仅有的数百个贫困数据来训练这个预测器。在这一步,我们相当于在原始图像中提取有关特征,据此预测贫困程度。

分类问题(蚂蚁还是蜜蜂)

一是这些图像极其复杂,人类肉眼都不太容易一下子区分出画面中是蚂蚁还是蜜蜂,简单的卷积神经网络无法应付这个分类任务
二是整个训练集仅有244个样本,这么小的数据量无法训川练大的卷积神经网络。

我们把ResNet18中的卷积模块作为特征提取层迁移过来,用于提取局部特征。同时,构建一个包含512个隐含节点的全连接层,后接两个节点的输出层,用于最后的分类输出,最终构建一个包含20层的深度网络。

数据加载

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as pyplot
import time
import copy
import os
 
data_path = 'data'
image_size = 224


class TranNet():
    def __init__(self):
        super(TranNet, self).__init__()
        
#加载的过程将会对图像进行如下增强操作:
#1.随机从原始图像中切下来一块224×224大小的区域
#2.随机水平翻转图像
#3.将图像的色彩数值标准化 
        self.train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), transforms.Compose([
            transforms.RandomSizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
#加载校验数据集,对每个加载的数据进行如下处理:
#1.放大到256×256
#2.从中心区域切割下224×224大小的区域
#3.将图像的色彩数值标准化
        self.verify_dataset = datasets.ImageFolder(os.path.join(data_path, 'verify'), transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))`在这里插入代码片`
#创建相应的数据加载器
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.verify_loader = torch.utils.data.DataLoader(self.verify_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.num_classes = len(self.train_dataset.classes)
net = models.resnet18(pretrained=True)

可以使用预训练的方式将这个网络迁移过来:

        net = models.resnet18(pretrained=True)
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

其中,num_features存储了ResNet18最后的全连接层的输入神经元个数。事实上,以上代码所做的就是将原来的ResNet18最后两层全连接层替换成一个输出单元为2的全连接层,这就是net.fc。之后,我们按照普通的方法定义损失函数和优化器。因此,这个模型首先会利用ResNet预训练好的权重,提取输入图像中的重要特征,然后利用net.fc这个线性层,根据输入特征进行分类。

当使用固定值的方式进行迁移的时候,可以使用下列代码:

net = models.resnet18(pretrained=True)
for param in net.parameters():
	param.requires_grad=False
num_features = net.fc.in_features
net.fc = nn.Linear(num_features, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

gpu加速

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
net = models.resnet18(pretrained=True)
#如果存在GPU,就将网络加载到GPU上
net = net.cuda() if use_cuda else net

将训练数据加载到gpu上

#将数据复制出来,然后加载到GPU上
data,target=data.clone().detach().requires_grad(True),target.clone().detach()
if use_cuda:
  data,target=data.cuda(),target.cuda()

最后,我们使用.cpu()将GPU上的计算结果再次转回内存中:

#待计算完成后,需将数据放回CPU
loss=loss.cpu() if use_cuda else loss

训练

    def model_prepare(self):
        net = models.resnet18(pretrained=True)
        
        # jusge whether GPU
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
            itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
            net = net.cuda() if use_cuda else net
        
        # float net values
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
        
        # fixed net values
        '''
        for param in net.parameters():
            param.requires_grad = False
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)
        '''
        record = []
        num_epochs = 20
        net.train(True) # open dropout  给网络模型做标记,说明模型在训练集上训练
        for epoch in range(num_epochs):
            train_rights = []
            train_losses = []
            for batch_index, (data, target) in enumerate(self.train_loader):
                data, target = data.clone().detach().requires_grad_(True), target.clone().detach()
                output = net(data) #完成一次预测
                loss = criterion(output, target)  #计算误差
                optimizer.zero_grad() #清空梯度
                loss.backward()  # 反向传播
                optimizer.step() #随机梯度下降
                right = rightness(output, target)  #计算准确率所需数值,返回数值为(正确样例数,总样本数)
                train_rights.append(right)
                train_losses.append(loss.data.numpy())
                if batch_index % 400 == 0:
                    verify_rights = []
                    for index, (data_v, target_v) in enumerate(self.verify_loader):
                        data_v, target_v = data_v.clone().detach(), target_v.clone().detach()
                        output_v = net(data_v)
                        right = rightness(output_v, target_v)
                        verify_rights.append(right)
                    verify_accu = sum([row[0] for row in verify_rights]) / sum([row[1] for row in verify_rights])
                    record.append((verify_accu))
                    print(f'verify data accu:{verify_accu}')
        # plot
        pyplot.figure(figsize=(8, 6))
        pyplot.plot(record)
        pyplot.xlabel('step')
        pyplot.ylabel('verify loss')
        pyplot.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第三期闯关基础岛

1、 Linux 基础知识 任务描述完成所需时间闯关任务完成SSH连接与端口映射并运行hello_world.py10min可选任务 1将Linux基础命令在开发机上完成一遍10min可选任务 2使用 VSCODE 远程连接开发机并创建一个conda环境10min可选任务 3创建并运行test.sh文件10min 1.1、SSH连接 使用…

Android Spinner

1. Spinner Spinner是下拉列表,如图3-14所示,通常用于为用户提供选择输入。Spinner有一个重要的属性:spinnerMode,它有2种情况: 属性值为dropdown时,表示Spinner的数据下拉展示,如图1&#xf…

自己动手写一个滑动验证码组件(后端为Spring Boot项目)

近期参加的项目,主管丢给我一个任务,说要支持滑动验证码。我身为50岁的软件攻城狮,当时正背着双手,好像一个受训的保安似的,中规中矩地参加每日站会,心想滑动验证码在今时今日已经是标配了,司空…

jenkins系列-06.harbor

https://github.com/goharbor/harbor/releases?page2 https://github.com/goharbor/harbor/releases/download/v2.3.4/harbor-offline-installer-v2.3.4.tgz harbor官网:https://goharbor.io/ 点击 Download now 链接,会自动跳转到上述github页面&am…

采用自动微分进行模型的训练

自动微分训练模型 简单代码实现: import torch import torch.nn as nn import torch.optim as optim# 定义一个简单的线性回归模型 class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear nn.Linear(1, 1) …

python:使用matplotlib库绘制图像(四)

作者是跟着http://t.csdnimg.cn/4fVW0学习的,matplotlib系列文章是http://t.csdnimg.cn/4fVW0的自己学习过程中整理的详细说明版本,对小白更友好哦! 四、条形图 1. 一个数据样本的条形图 条形图:常用于比较不同类别的数量或值&…

DockerCompose介绍,安装,使用

DockerCompose 1、Compose介绍 将单机服务-通过Dockerfile 构建为镜像 -docker run 成为一个服务 user 8080 net 7000 pay 8181 admin 5000 监控 .... docker run 单机版、一个个容器启动和停止问题: 前面我们使用Docker的时候,定义 Dockerfil…

深入理解Java泛型:概念、用法与案例分析

个人名片** 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] &#x1f4…

Transformer模型:Encoder的self-attention mask实现

前言 这是对Transformer模型的Word Embedding、Postion Embedding内容的续篇。 视频链接:19、Transformer模型Encoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili 文章链接:Transformer模型:WordEmbedding实现-CSDN博客 Transformer模型…

docker-compose安装PolarDB-PG数据库

文章目录 一. Mac1.1 docker-compose.yaml1.2 部署1.3 卸载4. 连接 二. Win102.1 docker-compose.yaml2.2 部署2.3 卸载 参考官方文档 基于单机文件系统部署 一. Mac 1.1 docker-compose.yaml mkdir -p /Users/wanfei/docker-compose/polardb-pg && cd /Users/wanfei…

Linux - 综合使用shell脚本,输出网站有效数据

综合示例: shell脚本实现查看网站分数 使用编辑器编辑文件jw.sh为如下内容: #!/bin/bash save_file"score" # 临时文件 semester20102 # 查分的学期, 20102代表2010年第二学期 jw_home"http://jwas3.nju.edu.cn:8080/jiaowu" # 测试网站首页地址 jw_logi…

zigbee开发工具:2、zigbee工程建立与配置

本文演示基于IAR for 8051(版本10.10.1)如何建立一个开发芯片cc2530的zigbee的工程,并配置这个工程,使其能够将编译的代码进行烧录,生成.hex文件。IAR for 8051(版本10.10.1)支持工程使用C语言&…

STM32智能交通灯系统教程

目录 引言环境准备智能交通灯系统基础代码实现:实现智能交通灯系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景:交通管理与优化问题解决方案与优化收尾与总结 1. 引言 智能交通灯系统通过STM…

Python游戏开发:四连珠(内附完整代码)

四连珠(Connect Four)是一款经典的棋类游戏,由两名玩家在7列6行的网格上轮流下棋。玩家的目标是将自己的棋子在垂直、水平或对角线上连成一条线,通常是四个棋子。如果一方成功做到这一点,那么他就赢得了游戏。如果所有…

视频监控汇聚平台:通过SDK接入大华DSS视频监控平台的源代码解释和分享

目录 一、视频监控汇聚平台 1、概述 2、视频接入能力 3、视频汇聚能力 二、大华DSS平台 1、DSS平台概述 2、大华DSS平台的主要特点 (1)高可用性 (2)高可靠性 (3)易维护性 (4&#xf…

《昇思25天学习打卡营第2天|02快速入门》

课程目标 这节课准备再学习下训练模型的基本流程,因此还是选择快速入门课程。 整体流程 整体介绍下流程: 数据处理构建网络模型训练模型保存模型加载模型 思路是比较清晰的,看来文档写的是比较连贯合理的。 数据处理 看数据也是手写体数…

【算法】平衡二叉树

难度:简单 题目 给定一个二叉树,判断它是否是 平衡二叉树 示例: 示例1: 输入:root [3,9,20,null,null,15,7] 输出:true 示例2: 输入:root [1,2,2,3,3,null,null,4,4] 输出&…

炒鸡清晰的防御综合实验(内含区域划分,安全策略,用户认证,NAT认证,智能选路,域名访问)

实验拓扑图如下: 前面六个条件在之间的实验中做过了,详细步骤可以去之前的文章看 这里简写一下大致步骤 第一步: 先将防火墙之外的配置给配置好,比如,PC的IP,交换上的Vlan划分。 第二步: 在浏览器上登…

用SurfaceView实现落花动画效果

上篇文章 Android子线程真的不能刷新UI吗?(一)复现异常 中可以看出子线程更新main线程创建的View,会抛出异常。SurfaceView不依赖main线程,可以直接使用自己的线程控制绘制逻辑。具体代码怎么实现了? 这篇文章用Surfa…

【算法专题】快速排序

1. 颜色分类 75. 颜色分类 - 力扣(LeetCode) 依据题意,我们需要把只包含0、1、2的数组划分为三个部分,事实上,在我们前面学习过的【算法专题】双指针算法-CSDN博客中,有一道题叫做移动零,题目要…