计算机视觉之迁移学习中的微调(fine tuning)

news2025/1/11 5:54:54

        现在的数据集越来越大,都是大模型的训练,参数都早已超过亿级,面对如此大的训练集,绝大部分用户的硬件配置达不到,那有没有一种方法让这些训练好的大型数据集的参数,迁移到自己的一个目标训练数据集当中呢?比如使用最广泛的图像数据集ImageNet,超过1000万张的图像和1000个分类,这些耗费大量时间人力物力而训练出来的参数,为我所用?

答案是肯定的,就是接下来说的微调(fine tuning),顾名思义就是细微的调节将这个已训练好的模型参数迁移过来,或者说复制(不是完全拷贝,故有点区别,所以叫微调)过来,最后再对自己的模型进行训练。
比如说我们的一个数据集是想找出图片中的热狗,但ImageNet数据集的图像大多于此无关,那迁移过来的参数有用吗?有用,因为在训练ImageNet中抽取的特征,如:边缘、纹理、形状等,对于识别物体都有同样的效果。

如何从源数据集迁移到目标数据集呢?方法就是将除了输出层的其余层的参数复制(做了微调)到目标数据集的除了输出层的其余层。而对于目标数据集的输出层,我们随机初始化该层的模型参数,然后将整个目标数据集重新训练一遍即可。

我们下载热狗数据集,来识别图像中的热狗的一个示例:

import d2lzh as d2l
from mxnet import gluon,init,nd
from mxnet.gluon import data as gdata,loss as gloss,model_zoo
from mxnet.gluon import utils as gutils
import os
import zipfile
#下载热狗数据集(如果超时等下载不了,就直接手动下载再解压)
#解压之后是hotdog目录,里面是train和test目录,分别都有hotdog和not-hotdog目录,存放热狗和非热狗(和热狗长的像的,比如香蕉之类)的图像
data_dir='../data'
fname=gutils.download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/hotdog.zip')
with zipfile.ZipFile(fname,'r') as z:
     z.extractall(data_dir)

当然如果遇到权限错误,比如我的是C盘存放,使用管理员权限的命令行执行即可。
PermissionError: [WinError 5] 拒绝访问。: '..\\data\\hotdog'

 数据集下载下来了,我们先来显示正类图像(热狗)和负类图像(非热狗),熟悉下这个数据集,代码如下:

train_imgs=gdata.vision.ImageFolderDataset(os.path.join(data_dir,'hotdog/train'))
test_imgs=gdata.vision.ImageFolderDataset(os.path.join(data_dir,'hotdog/test'))
#print(len(train_imgs),len(test_imgs))#2000,800

#print(train_imgs[0])#...<NDArray 144x122x3 @cpu(0)>, 0)(高,宽,通道)和标签值,0是热狗,1是非热狗
#print(train_imgs.items[0])#('../data\\hotdog/train\\hotdog\\0.png', 0)
hotdogs=[train_imgs[i][0] for i in range(8)]
not_hotdogs=[test_imgs[-i][0] for i in range(8)]
d2l.show_images(hotdogs+not_hotdogs,2,8,scale=1.5)
d2l.plt.show()

从画布中显示的图像可以看出,都是些大小和宽高比都不一样的热狗与非热狗图。

训练模型前的图像增广

        在训练时,先从图像中裁剪出随机大小和随机高宽比的一块区域,然后将该区域缩放到高宽为224像素的输入。测试时,我们将图像的高宽缩放到256像素,然后从中裁剪出高宽为224的中心区域作为输入。此外,我们对颜色通道做标准化,就是每个数值减去所有数值的均值,再除以标准差。

# 对颜色通道做标准化处理
normalize = gdata.vision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 均值,方差
train_augs = gdata.vision.transforms.Compose([gdata.vision.transforms.RandomResizedCrop(224), gdata.vision.transforms.RandomFlipLeftRight(),
                                              gdata.vision.transforms.ToTensor(), normalize])
test_augs = gdata.vision.transforms.Compose([gdata.vision.transforms.Resize(256), gdata.vision.transforms.CenterCrop(224),
                                             gdata.vision.transforms.ToTensor(), normalize])

定义和初始化模型

使用ImageNet数据集上预训练的ResNet-18作为源模型。预训练模型的参数,将下载到:C:\Users\Tony\.mxnet\models里面的resnet18_v2-8aacf80f.params参数文件

pretrained_net = model_zoo.vision.resnet18_v2(pretrained=True)
#包含两个成员变量:features和output
#打印output输出层看下,输出1000类的全连接层
print(pretrained_net.features)#Dense(512 -> 1000, linear)

接下来新建目标模型,其定义和预训练的源模型一样,只不过将最后的输出数修改为目标数据集的类别数,因为features中的模型参数是已经在ImageNet数据集上预训练得到的,已经足够好,所以只需要使用较小的学习率来微调这些参数。对于output中的模型参数我们采用随机初始化,一般需要更大的学习率(10倍学习率)从头训练。

finetune_net=model_zoo.vision.resnet18_v2(classes=2)
finetune_net.features=pretrained_net.features
finetune_net.output.initialize(init.Xavier())
finetune_net.output.collect_params().setattr('lr_mult',10)

微调模型

定义一个使用微调的训练函数,便于多次调用:

def train_fine_tuning(net,learning_rate,batch_size=128,num_epochs=5):
    train_iter=gdata.DataLoader(train_imgs.transform_first(train_augs),batch_size,shuffle=True)
    test_iter=gdata.DataLoader(test_imgs.transform_first(test_augs),batch_size)
    ctx=d2l.try_all_gpus()
    net.collect_params().reset_ctx(ctx)
    net.hybridize()
    loss=gloss.SoftmaxCrossEntropyLoss()
    trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':learning_rate,'wd':0.001})
    d2l.train(train_iter,test_iter,net,loss,trainer,ctx,num_epochs)

我们以小的学习率0.01微调获得预训练模型参数,然后以10倍学习率从头训练目标模型的输出层参数,当然本人配置不是很好,批处理大小就弄小点,不然内存溢出。

train_fine_tuning(finetune_net,0.01,32)
'''
(pygpu) C:\Users\Tony>python p.py
training on [gpu(0)]
epoch 1, loss 1.5428, train acc 0.815, test acc 0.594, time 28.9 sec
epoch 2, loss 0.7645, train acc 0.864, test acc 0.921, time 24.9 sec
epoch 3, loss 0.5352, train acc 0.882, test acc 0.921, time 24.8 sec
epoch 4, loss 0.4242, train acc 0.882, test acc 0.774, time 24.8 sec
epoch 5, loss 0.3898, train acc 0.893, test acc 0.917, time 24.9 sec
'''

作为对比,定义一个相同模型,但是将它所有模型参数都是初始化为随机值,由于需要从头训练,学习率大点0.1。

scratch_net=model_zoo.vision.resnet18_v2(classes=2)
scratch_net.initialize(init=init.Xavier())
train_fine_tuning(scratch_net,0.1,32)
'''
(pygpu) C:\Users\Tony>python p.py
training on [gpu(0)]
epoch 1, loss 0.5667, train acc 0.750, test acc 0.818, time 27.9 sec
epoch 2, loss 0.4731, train acc 0.795, test acc 0.824, time 25.0 sec
epoch 3, loss 0.4109, train acc 0.818, test acc 0.854, time 24.9 sec
epoch 4, loss 0.3790, train acc 0.828, test acc 0.812, time 24.9 sec
epoch 5, loss 0.4080, train acc 0.826, test acc 0.853, time 24.9 sec
'''

如果不微调,直接使用源模型参数,将会怎么样呢?

finetune_net = model_zoo.vision.resnet18_v2(classes=2)
finetune_net.features=pretrained_net.features
finetune_net.features.collect_params().setattr('grad_req', 'null')
finetune_net.output.initialize(init.Xavier())
finetune_net.output.collect_params().setattr('lr_mult', 10)

测试几次的效果,loss效果比较差,精度还好,不稳定性要大点,可能是小数据集容易过拟合的原因,不知道伙伴们测试的效果如何?

模型前缀不一样

我们也可以将微调得到的最终参数保存起来

finetune_net.collect_params().save('hotdog.params')

然后我们加载保存的参数看下:

mynet=model_zoo.vision.resnet18_v2(classes=2)
mynet.collect_params().load('hotdog.params')
train_fine_tuning(mynet, 0.01, 32)

出现如下错误:

AssertionError: Parameter 'resnetv22_batchnorm0_gamma' is missing in file 'hotdog.params', which contains parameters: 'resnetv20_batchnorm0_gamma', 'resnetv20_batchnorm0_beta', 'resnetv20_batchnorm0_running_mean', ..., 'resnetv20_batchnorm2_running_mean', 'resnetv20_batchnorm2_running_var', 'resnetv21_dense0_weight', 'resnetv21_dense0_bias'. Please make sure source and target networks have the same prefix.

很明显里面的前缀名不一致,这个就是没有固定前缀,在训练模型的时候都是动态递增的前缀,所以我们需要固定前缀,这样在加载的时候就指定前缀就好了
训练的时候,指定前缀:

pretrained_net = model_zoo.vision.resnet18_v2(pretrained=True,prefix='res_')
finetune_net = model_zoo.vision.resnet18_v2(classes=2,prefix='res_')
finetune_net.features=pretrained_net.features
finetune_net.output.initialize(init.Xavier())
finetune_net.output.collect_params().setattr('lr_mult', 10)

加载的时候,同样指定前缀即可:

mynet=model_zoo.vision.resnet18_v2(classes=2,prefix='res_')
mynet.collect_params().load('hotdog.params')
#print(mynet)

然后看下效果:

train_fine_tuning(mynet, 0.01, 32)
'''
(pygpu) C:\Users\Tony>python p.py
training on [gpu(0)]
epoch 1, loss 0.1788, train acc 0.933, test acc 0.939, time 28.0 sec
epoch 2, loss 0.1303, train acc 0.947, test acc 0.926, time 25.7 sec
epoch 3, loss 0.1398, train acc 0.948, test acc 0.943, time 25.7 sec
epoch 4, loss 0.1236, train acc 0.953, test acc 0.939, time 25.1 sec
epoch 5, loss 0.1157, train acc 0.955, test acc 0.941, time 25.1 sec
'''

ImageNet中热狗

目前的输出层我们使用的是随机初始化,现在我们将ImageNet数据集中的热狗这个类的权重参数应用到我们的初始化里面来,看下有什么效果:

weight=pretrained_net.output.weight
hotdog_w=nd.split(weight.data(),1000,axis=0)[713]
output_init=nd.concat(hotdog_w,-hotdog_w,dim=0)#将热狗的权重分一个负类的代表非热狗
print(output_init)
'''
[[-0.07650785  0.02459255  0.00455526 ... -0.06427797 -0.01825024
  -0.02214353]
 [ 0.07650785 -0.02459255 -0.00455526 ...  0.06427797  0.01825024
   0.02214353]]
<NDArray 2x512 @cpu(0)>
'''
finetune_net.output.initialize(init.Constant(output_init))#加载ImageNet里的热狗权重参数(增加一个负类)
finetune_net.output.collect_params().setattr('lr_mult', 10)
train_fine_tuning(finetune_net, 0.01, 32)
'''
training on [gpu(0)]
[12:41:37] c:\jenkins\workspace\mxnet-tag\mxnet\src\operator\nn\cudnn\./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
epoch 1, loss 1.3655, train acc 0.804, test acc 0.818, time 27.9 sec
epoch 2, loss 0.6536, train acc 0.875, test acc 0.902, time 25.0 sec
epoch 3, loss 0.5011, train acc 0.885, test acc 0.934, time 25.1 sec
epoch 4, loss 0.3031, train acc 0.913, test acc 0.922, time 25.0 sec
epoch 5, loss 0.2167, train acc 0.917, test acc 0.924, time 24.9 sec
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/20466.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

javaweb Ajax AXios异步框架 JSON 案例

AJAX概念&#xff1a;AJAX(Asynchronous JavaScript And XML)&#xff1a;异步的 JavaScript 和 XML AJAX作用&#xff1a; 与服务器进行数据交换&#xff1a;通过AJAX可以给服务器发送请求&#xff0c;并获取服务器响应的数据 使用了AJAX和服务器进行通信&#xff0c;就可以…

ArcGIS_将多个点数据整合成一个点数据

问题描述:如何将多个点合并成一个点,并保留原始点数据的字段信息 方法一:整合 打开arcgis整合工具 :“数据管理工具——要素类——整合” 容差半径可以通过arcgis测量工具获取,根据自己的目标任务,选择合适的容差半径 该方法优点在于整合后的点可以正好位于原始点数据…

牛客网之SQL非技术快速入门(7)-字符串截取、切割、删除、替换

知识点&#xff1a; &#xff08;1&#xff09;substring_indexsubstring_index(str,delim,count) str:要处理的字符串 delim:分隔符 count:计数 &#xff08;2&#xff09;切割、截取、删除、替换 1 2 3 4 5 6 7 8 9 10 11 12 13 14 select -- 替换法 replace(string, 被…

俺把所有粉丝显示在地图上啦~【详细教程+完整源码】

文章目录&#x1f332;小逼叨&#x1f332;爬取所有粉丝的IP所属地&#x1f334;爬者基本素养&#xff1a;网页分析&#x1f334;源代码&#x1f332;数据清洗和保存&#x1f334;源代码&#x1f332;绘制地图&#x1f334;源代码&#x1f332;结束语&#x1f332;小逼叨 其实昨…

windows中使用curl

curl这个工具在linux和macOS都经常使用&#xff0c;感觉挺实用的。在windows中默认也带了一个但是用起来不太一样&#xff0c;于是就想自己手动安装一个原汁原味的curl。 下载安装 https://curl.se/windows/ 下载适合自己平台的版本&#xff0c;解压就可以直接运行了。 比如…

剑指 Offer II 026. 重排链表【链表】

难度等级&#xff1a;中等 上一篇算法&#xff1a; 剑指 Offer II 021. 删除链表的倒数第 n 个结点【链表】 力扣此题地址&#xff1a; 剑指 Offer II 026. 重排链表 - 力扣&#xff08;LeetCode&#xff09; 1.题目&#xff1a;重排链表 给定一个单链表 L 的头节点 head &…

Linux用户和权限学习笔记

认识root用户 什么是root用户 无论是Windows、MacOS、Linux均采用多用户的管理模式进行权限管理。 在Linux系统中&#xff0c;拥有最大权限的账户名为&#xff1a;root&#xff08;超级管理员&#xff09;而在前期&#xff0c;我们一直使用的账户是普通账户&#xff1a;itheim…

《Android Studio开发实战 从零基础到App上线(第3版)》出版后记

2018年11月&#xff0c;经过熬夜写作的《Android Studio开发实战 从零基础到App上线(第2版)》正式出版面世。承蒙众多读者的厚爱&#xff0c;第2版的图书在此后的三年多时间&#xff0c;一直保持在移动开发图书的销量排行榜前列&#xff0c;迄今为止京东对该书的评价已达8000多…

设计模式基础-概括

目录 一、设计原则 二、设计模式分类 1、创建型模式&#xff1a;创建对象 2、结构型模式&#xff1a;更大的结构 3、行为型模式&#xff1a;交互以及职责分配 4、对象模式与类模式区别 三、各类型模式简介 1、创建型模式 2、结构型模式 3、行为型模式 一、设计原则 …

JAVA中Function的使用

JAVA中Function的使用一、方法介绍参数类型方法介绍源码二、demo参考&#xff1a; https://blog.csdn.net/boyan_HFUT/article/details/99618833 一、方法介绍 表示接受一个参数并产生结果的函数。 参数类型 T - 函数输入的类型R - 函数的结果类型 方法介绍 R apply(T t) …

【毕业设计】45-基于单片机的智能温度/超温报警计的系统设计(原理图工程+仿真工程+源代码+答辩论文+答辩PPT)

【毕业设计】45-基于单片机的智能温度/超温报警计的系统设计&#xff08;原理图工程仿真工程源代码答辩论文答辩PPT&#xff09; 文章目录【毕业设计】45-基于单片机的智能温度/超温报警计的系统设计&#xff08;原理图工程仿真工程源代码答辩论文答辩PPT&#xff09;资料下载链…

Vue 路由

参考文献&#xff1a;Vue中的路由 目录:一、路由理解&#xff1a;二、路由管理器理解&#xff1a;三、路由的使用&#xff1a;四、嵌套路由&#xff1a;五、路由传参&#xff1a;1.query传参&#xff1a;2. params传参&#xff1a;六、编程式路由导航&#xff1a;七、响应路由参…

数字孪生技术有没有真正的实用价值?

作为一个数字孪生领域的技术公司负责人&#xff0c;我尽可能用比较直白的话来描述一下我对数字孪生行业以及数字孪生价值的理解。 纵观数字孪生相关的公司&#xff0c;主要有两个流派&#xff0c;一派是具有互联网基因的数字孪生创业公司&#xff0c;一派是在工业软件领域实力…

ConfigurableListableBeanFactory和BeanDefinitionRegistry关系

前言 &#xff1a;在查看springBoot源码的过程中&#xff0c;遇到了这个问题&#xff0c;上网查了一些资料&#xff0c;理解了一些&#xff0c;这里顺便把这个问题给记录一下。 在springBoot调用Refresh方法里面 &#xff0c;有一个叫invokeBeanFactoryPostProcessors的方法【…

HIve数仓新零售项目ODS层的构建

HIve数仓新零售项目 注&#xff1a;大家觉得博客好的话&#xff0c;别忘了点赞收藏呀&#xff0c;本人每周都会更新关于人工智能和大数据相关的内容&#xff0c;内容多为原创&#xff0c;Python Java Scala SQL 代码&#xff0c;CV NLP 推荐系统等&#xff0c;Spark Flink Kaf…

WindowsPE(二)空白区添加代码新增,扩大,合并节

空白区添加代码 在 PE 中插入一段调用 MessageBox 的代码。 获取MessageBox地址&#xff0c;构造ShellCode代码 利 OD 定位出 MessageBoxA 函数的地址为 0x77D507EA 。 构造 shellcode &#xff1a; unsigned char shellcode[] {0x6A, 0x00, // pus…

ORB-SLAM2 ---- Initializer::ReconstructF函数

目录 1.函数作用 2.函数解析 2.1 调用函数解析 2.2 Initializer::ReconstructF函数总体思路 2.2.1 代码 2.2.2 总体思路解析 2.2.3 根据基础矩阵和相机的内参数矩阵计算本质矩阵 2.2.4 从本质矩阵求解两个R解和两个t解&#xff0c;共四组解 2.2.5 分别验证求解的4种…

准备面试题【面试】

前言 写作于 2022-11-13 19:27:08 发布于 2022-11-20 16:34:44 准备 程序员囧辉 我要进大厂 面试阿里&#xff0c;HashMap 这一篇就够了 Java 基础高频面试题&#xff08;2022年最新版&#xff09; 问遍了身边的面试官朋友&#xff0c;我整理出这份 Java 集合高频面试题…

【mysql】mysql 数据备份与恢复使用详解

一、前言 对一个运行中的线上系统来说&#xff0c;定期对数据库进行备份是非常重要的&#xff0c;备份不仅可以确保数据的局部完整性&#xff0c;一定程度上也为数据安全性提供了保障&#xff0c;设想如果某种极端的场景下&#xff0c;比如磁盘损坏导致某个时间段数据丢失&…

什么是Spring,Spring的核心和设计思想你了解吗?

目录 1.初识Spring 1.1 什么是容器 1.2 什么是IoC 2.什么是IoC容器. 2.1 什么是DI 哈喽呀,你好呀,欢迎呀,快来看一下这篇宝藏博客吧~~~ 1.初识Spring Srping指的是Spring Framework(Spring 框架).我们经常会听见框架二字,其中java中最最主流的框架当属Spring.Spring是一…