【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】

news2025/1/19 11:39:39

通常为了使模型的预测精度达到较高的标准,需要收集十分庞大的数据集来进行模型训练。一种比较巧妙解决该问题的办法是应用迁移学习(transfer learning),将从某个已有的数据集学到的知识迁移到目标数据集上。例如,假如我们想训练一个识别椅子的模型,我们使用已有的ImageNet数据集【它有超过1,000万的图像和1,000类的物体】,虽然的ImageNet数据集图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

目录

  • 1. 实战案例:热狗识别
    • 1.1 获取数据集
    • 1.2 定义和初始化模型
    • 1.3 使用微调技术训练模型
  • 总结

本文我们将介绍迁移学习中的一种常用技术:微调(fine tuning)。如下图所示,微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

在这里插入图片描述

当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

1. 实战案例:热狗识别

接下来我们来实践一个具体的例子:热狗识别。我们将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张包含热狗和不包含热狗的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗

关注GZH:阿旭算法与机器学习,回复:“微调实战”即可获取本文数据集与项目文档

首先,导入实验所需的包或模块。torchvision的models包提供了常用的预训练模型。如果希望获取更多的预训练模型,可以使用使用pretrained-models.pytorch仓库。

%matplotlib inline
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os

import sys
import d2lzh_pytorch as d2l

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1.1 获取数据集

我们使用的热狗数据集含有1400张包含热狗的正类图像,和同样多包含其他食品的负类图像。各类的1000张图像被用于训练,其余则用于测试。

我们首先将压缩后的数据集下载到路径data_dir之下,然后在该路径将下载好的数据集解压,得到两个文件夹hotdog/trainhotdog/test。这两个文件夹下面均有hotdognot-hotdog两个类别文件夹,每个类别文件夹里面是图像文件。

data_dir = './data'
os.listdir(os.path.join(data_dir, "hotdog")) # ['train', 'test']

我们创建两个ImageFolder实例来分别读取训练数据集和测试数据集中的所有图像文件。

train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))
test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))

下面画出前8张正类图像和最后8张负类图像。可以看到,它们的大小和高宽比各不相同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)

在这里插入图片描述

在训练时,我们先从图像中裁剪出随机大小和随机高宽比的一块随机区域,然后将该区域缩放为高和宽均为224像素的输入。测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高和宽均为224像素的中心区域作为输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化:每个数值减去该通道所有数值的平均值,再除以该通道所有数值的标准差作为输出。

注: 在使用训练模型时,一定要将预测数据,作训练时同样的预处理。
如果你使用的是torchvisionmodels,那就要求:
All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
如果你使用的是pretrained-models.pytorch仓库,请务必阅读其README,其中说明了如何预处理。

# 指定RGB三个通道的均值和方差来将图像通道归一化
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

test_augs = transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        normalize
    ])

1.2 定义和初始化模型

我们使用在ImageNet数据集上预训练的ResNet-18作为源模型。这里指定pretrained=True来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

pretrained_net = models.resnet18(pretrained=True)

不管你是使用的torchvision的models还是pretrained-models.pytorch仓库,默认都会将预训练好的模型参数下载到你的home目录下.torch文件夹。

下面打印源模型的成员变量fc。作为一个全连接层,它将ResNet最终的全局平均池化层输出变换成ImageNet数据集上1000类的输出。

print(pretrained_net.fc)

输出:

Linear(in_features=512, out_features=1000, bias=True)

注: 如果你使用的是其他模型,那可能没有成员变量fc(比如models中的VGG预训练模型),所以正确做法是查看对应模型源码中其定义部分,这样既不会出错也能加深我们对模型的理解。pretrained-models.pytorch仓库貌似统一了接口,但是我还是建议使用时查看一下对应模型的源码。

可见此时pretrained_net最后的输出个数等于目标数据集的类别数1000。所以我们应该将最后的fc成修改我们需要的输出类别数:

# 随机初始化输出层参数
pretrained_net.fc = nn.Linear(512, 2)
print(pretrained_net.fc)

输出:

Linear(in_features=512, out_features=2, bias=True)

此时,pretrained_netfc层就被随机初始化了,但是其他层依然保存着预训练得到的参数。由于是在很大的ImageNet数据集上预训练的,所以参数已经足够好,因此一般只需使用较小的学习率来微调这些参数,而fc中的随机初始化参数一般需要更大的学习率从头训练。PyTorch可以方便的对模型的不同部分设置不同的学习参数,我们在下面代码中将fc的学习率设为已经预训练过的部分的10倍。

output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())

lr = 0.01
optimizer = optim.SGD([{'params': feature_params},
                       {'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],
                       lr=lr, weight_decay=0.001)

1.3 使用微调技术训练模型

我们先定义一个使用微调的训练函数train_fine_tuning以便多次调用。

def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=5):
    train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform=train_augs),
                            batch_size, shuffle=True)
    test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs),
                           batch_size)
    loss = torch.nn.CrossEntropyLoss()
    d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)

根据前面的设置,我们将以10倍的学习率从头训练目标模型的输出层参数。

train_fine_tuning(pretrained_net, optimizer)

输出:

training on  cuda
epoch 1, loss 3.1183, train acc 0.731, test acc 0.932, time 41.4 sec
epoch 2, loss 0.6471, train acc 0.829, test acc 0.869, time 25.6 sec
epoch 3, loss 0.0964, train acc 0.920, test acc 0.910, time 24.9 sec
epoch 4, loss 0.0659, train acc 0.922, test acc 0.936, time 25.2 sec
epoch 5, loss 0.0668, train acc 0.913, test acc 0.929, time 25.0 sec

作为对比,我们定义一个相同的模型,但将它的所有模型参数都初始化为随机值。由于整个模型都需要从头训练,我们可以使用较大的学习率。

scratch_net = models.resnet18(pretrained=False, num_classes=2)
lr = 0.1
optimizer = optim.SGD(scratch_net.parameters(), lr=lr, weight_decay=0.001)
train_fine_tuning(scratch_net, optimizer)

输出:

training on  cuda
epoch 1, loss 2.6686, train acc 0.582, test acc 0.556, time 25.3 sec
epoch 2, loss 0.2434, train acc 0.797, test acc 0.776, time 25.3 sec
epoch 3, loss 0.1251, train acc 0.845, test acc 0.802, time 24.9 sec
epoch 4, loss 0.0958, train acc 0.833, test acc 0.810, time 25.0 sec
epoch 5, loss 0.0757, train acc 0.836, test acc 0.780, time 24.9 sec

可以看到,微调的模型因为参数初始值更好,往往在相同迭代周期下取得更高的精度。

总结

  • 迁移学习将从源数据集学到的知识迁移到目标数据集上。微调是迁移学习的一种常用技术。
  • 目标模型复制了源模型上除了输出层外的所有模型设计及其参数,并基于目标数据集微调这些参数。而目标模型的输出层需要从头训练。
  • 一般来说,微调参数会使用较小的学习率,而从头训练输出层可以使用较大的学习率。

如果文章内容对你有帮助,感谢点赞+关注!

欢迎关注下方GZH:阿旭算法与机器学习,共同学习交流~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/155362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序安装 Vant 组件库与API Promise组件库并实现简单的增删改查

在项目内右键空白处选择在外部终端打开2、在终端窗口输入 npm init -y,创建package-lock.jsonnpm init -y3、在终端输入npm i vant/weapp1.3.3 -S --production,创建node_modules文件夹npm i vant/weapp1.3.3 -S --production4、详情-本地设置&#xff0…

Vue2.0开发之——Vue组件-组件的实例对象(36)

一 概述 浏览器无法直接解析Vue文件package.json中的’vue-template-compiler’将vue结尾的文件解析为js文件交给浏览器处理Count组件实例对象 二 浏览器无法直接解析Vue文件 将Vue文件拖放到浏览器中无法直接显示 三 package.json中的’vue-template-compiler’将vue结尾的文…

软件著作权登记指南

一、什么是计算机软件《计算机软件保护条例》第二条、第三条规定,本条例所称计算机软件(以下简称软件),是指计算机程序及其有关文档;(一)计算机程序,是指为了得到某种结果而可以由计…

第13章 Token的Postman、Swagger和Vue调试

1 准备工作 1.1 WebApi.Controllers.JwtSettingModel namespace WebApi.Test { /// <summary> /// 【Jwt设置模型--纪录】 /// <remarks> /// 摘要&#xff1a; /// 通过该纪录中的属性成员实例存储“AppSettings.json”文件中的Jwt相关设置数据&#xff0…

java应用程序多级缓存架构

多级缓存架构 一级缓存&#xff1a;OpenResty—Lua—Redis 二级缓存&#xff1a;Nginx proxy-cache 三级缓存&#xff1a;Redis 使用OpenResty lua脚本访问redis proxy-cache 缓存注解 <!--依赖--> <dependency><groupId>org.springframework.boot</gr…

最新研究发现:天然海绵含有抑制Omicron变体感染的天然化合物

本文原文首发于2023年1月9日E-LIFESTYLE &#xff08;阅读时间4分钟&#xff09; 附标题&#xff1a;通过研究370多种来自植物、真菌和海绵等天然来源的化合物&#xff0c;寻找可用于治疗新冠肺炎的新抗病毒药物&#xff0c;用这些天然化合物制成的溶液中沐浴人类被SARS-CoV-2感…

SolidWorks装配体保存成零件,能有效压缩文件体积,方便二次装配

SolidWorks装配体保存成零件&#xff0c;能有效压缩文件体积&#xff0c;方便二次装配1. 先使用solidworks打开我们要转换成零件的装配体2. 然后点击上方保存下面的小三角&#xff0c;选择另存为3.之后选择要保存的位置&#xff0c;点击文件格式&#xff0c;然后在文件格式里找…

Zabbix监控服务详解+实战

目录 一、监控体系概述 1. 为什么需要监控 2. 监控目标与流程 &#xff08;1&#xff09;监控的目标 &#xff08;2&#xff09; 监控的流程 3. 监控的对象 &#xff08;1&#xff09;CPU监控 &#xff08;2&#xff09;磁盘监控 &#xff08;3&#xff09;内存监控 …

win7电脑怎么录屏?免费的录屏软件分享

现在大家的电脑一般是win10、11系统&#xff0c;但是还是有一些小伙伴喜欢使用win7系统的电脑。那你知道win7电脑怎么录屏吗&#xff1f;有没有好用且简单的win7电脑录屏软件推荐&#xff1f;当然有&#xff01;今天小编给使用win7电脑的小伙伴推荐两款简单且好用的电脑录屏软件…

各类字符串函数和内存函数的使用以及模拟(万字解析)

函数一.字符串函数(使用都需要包含string.h)1.求字符串长度—strlen2.长度不受限制的字符串函数1.strcpy-字符串拷贝2.strcat-追加字符串3.strcmp-字符串比较4.为什么长度不受限制3.长度受限制的字符串函数—strncopy,strncat,strncmp4.字符串查找1.strstr-判断是否为子字符串2…

Linux 文件句柄导致系统压力测试时出现错误率

最近&#xff0c;在对一个golang写的获取商品详情信息的接口做压力测试时&#xff0c;tps 单机可以达到1400多&#xff0c;但是发现每当压力测试开始2分钟多时就会出现502或504 错误&#xff0c;整体的错误率在0.5%左右。一开始是怀疑代码写的效率不高&#xff0c;是不是协程开…

【SAP Hana】SAP HANA SQL 进阶教程

SAP HANA SQL 进阶教程5、HANA SQL 进阶教程&#xff08;1&#xff09;Databases&#xff08;2&#xff09;User & Role&#xff08;3&#xff09;Schemas&#xff08;4&#xff09;Tables&#xff08;5&#xff09;Table Index&#xff08;6&#xff09;Table Partitions&…

于仕琪C/C++ 学习笔记

C函数指针有哪几类&#xff1f;函数指针、lambda、仿函数对象分别是什么&#xff1f;如何利用谓词对给定容器进行自定义排序&#xff1f;传递引用和传递值的区别&#xff1f;传递常引用和传递引用之间的区别&#xff1f;传递右值引用和传递引用之 间的区别&#xff1f;函数对象…

【PWA学习】6. 使用 Service Worker 进行后台同步

引言 你一定遇到过类似这样的场景&#xff1a; 当用手机填写完一张信息表单点击"提交"时&#xff0c;恰好手机网络很差或没有网络&#xff0c;这时候只能盯着手机看着旋转的小圆圈。经过长时间等待后依然没有结果&#xff0c;这时候关闭浏览器&#xff0c;请求也被终…

红外传感器使用

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、红外传感器&#xff1f;二、使用步骤1.确保驱动已经安装2.安装GPIO工具3.安装GPIO的Python支持4.Python3代码5.测试结果总结前言 最近在做一个项目需要用到…

Linux命令学习

1、linux目录结构 linux目录结构是一个树状结构 当我们直接打开ubuntu的控制台&#xff0c;进入的是 home 目录下的创建的用户&#xff0c;这里是真正的 家 目录 或者在安装 ssh 服务器之后可以直接通过 windows 命令行 访问 ubuntu 的ssh服务器&#xff0c;进入的是 home 目录…

【规范】我们是怎么做MySQL数据库安全管理的?

一、背景说明 MySQL作为数据库管理系统&#xff0c;里面保存企业的重要业务数据&#xff0c;因此保证数据库的安全性非常重要&#xff0c;如何保证数据库的安全性呢&#xff1f;用户和用户权限管理是一个很重要的方面。 MySQL数据库具有非常高的安全性&#xff0c;为我们提供…

Vue 2 即将成为过去

自从 2020 年 9 月 18 日 Vue 3 正式发布以来&#xff0c;已经有两年多时间了&#xff0c;终于在 2022 年 2 月 7 日 Vue 作者发布了一则消息&#xff1a;Vue 3 将成为新的默认版本。与此同时&#xff0c;Vue 相关官方周边的核心库 latest 发布标签将指向其 Vue 3 的兼容版本。…

从0到1完成一个Vue后台管理项目(二十一、网上地图资源、树形控件及路由权限分析、路由守卫)

往期 从0到1完成一个Vue后台管理项目&#xff08;一、创建项目&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;二、使用element-ui&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;三、使用SCSS/LESS&#xff0c;安装图标库&#xff09; 从0到1完成一个Vu…

JAVA SE 详解类和对象

类和对象 面向对象的初步认知 什么是面向对象 Java是一门纯面向对象的语言(Object Oriented Program&#xff0c;简称OOP)&#xff0c;在面向对象的世界里&#xff0c;一切皆为对象。 面 向对象是解决问题的一种思想&#xff0c;主要依靠对象之间的交互完成一件事情。用面向…