59. 微调(fine-tuning)代码实现

news2025/1/5 9:16:45

1. 热狗识别

让我们通过具体案例演示微调:热狗识别。 我们将在一个小型数据集上微调ResNet模型。该模型已在ImageNet数据集上进行了预训练。 这个小型数据集包含数千张包含热狗和不包含热狗的图像,我们将使用微调模型来识别图像中是否包含热狗。

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 获取数据集

我们使用的热狗数据集来源于网络。 该数据集包含1400张热狗的“正类”图像,以及包含尽可能多的其他食物的“负类”图像。 含着两个类别的1000张图片用于训练,其余的则用于测试。

解压下载的数据集,我们获得了两个文件夹hotdog/train和hotdog/test。 这两个文件夹都有hotdog(有热狗)和not-hotdog(无热狗)两个子文件夹, 子文件夹内都包含相应类的图像。

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')

我们创建两个实例来分别读取训练和测试数据集中的所有图像文件。

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

下面显示了前8个正类样本图片和最后8张负类样本图片。正如所看到的,图像的大小和纵横比各有不同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在这里插入图片描述

3. 数据增广

在训练期间,我们首先从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为 224×224 输入图像。 在测试过程中,我们将图像的高度和宽度都缩放到256像素,然后裁剪中央 224×224 区域作为输入。 此外,对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。

# 使用RGB通道的均值和标准差,以标准化每个通道
# 为什么要这样做,是因为在ImageNet上做了这个事情,所以也要把这个事情搬过来
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# 因为我们要用ImageNet上的模型做fine-tuning,所以剪裁大小是224 x 224
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

# 测试集使用的方法也是ImageNet常用的,先resize到256 x 256,再剪裁中央224 x 224
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize])

4. 定义和初始化模型

我们使用在ImageNet数据集上预训练的ResNet-18作为源模型。 在这里,我们指定pretrained=True自动下载预训练的模型参数。 如果首次使用此模型,则需要连接互联网才能下载。

# 不仅把模型的定义弄下来了,还把在ImageNet上训练好的parameter拿过来了
pretrained_net = torchvision.models.resnet18(pretrained=True)

预训练的源模型实例包含许多特征层和一个输出层fc

此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调

下面给出了源模型的成员变量fc。

pretrained_net.fc

在这里插入图片描述

在ResNet的全局平均汇聚层后,全连接层转换为ImageNet数据集的1000个类输出。 之后,我们构建一个新的神经网络作为目标模型。 它的定义方式与预训练源模型的定义方式相同,只是最终层中的输出数量被设置为目标数据集中的类数(而不是1000个)。

在下面的代码中,目标模型finetune_net中成员变量features的参数被初始化为源模型相应层的模型参数。 由于模型参数是在ImageNet数据集上预训练的,并且足够好,因此通常只需要较小的学习率即可微调这些参数。

成员变量output的参数是随机初始化的,通常需要更高的学习率才能从头开始训练。 假设Trainer实例中的学习率为 𝜂 ,我们将成员变量output中参数的学习率设置为 10𝜂 。

finetune_net = torchvision.models.resnet18(pretrained=True) # 把pretrained的模型下载下来之后
# 把最后一个输出层fully- connected随机初始化一个线性层,in_features是512,输出的类别树是2
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) 
# 再对最后这个全连接线性层做xavier的随机初始化
# ps:前面层的参数都是用pretrained的模型
nn.init.xavier_uniform_(finetune_net.fc.weight);

5. 微调模型

首先,我们定义了一个训练函数train_fine_tuning,该函数使用微调,因此可以多次调用

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,
                      param_group=True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)
    devices = d2l.try_all_gpus()
    loss = nn.CrossEntropyLoss(reduction="none")
    if param_group:
        # 把不是最后一层的所有层的参数都拿出来
        params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
        trainer = torch.optim.SGD([{'params': params_1x}, # 这些层使用的学习率是默认的学习率
                                   {'params': net.fc.parameters(), # 最后一层用的学习率是前面的10倍
                                    'lr': learning_rate * 10}],
                                lr=learning_rate, weight_decay=0.001)
    else: # 如果没有param_group这个选项,就正常来,和之前一样
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
                   devices)

我们使用较小的学习率,通过微调预训练获得的模型参数。

train_fine_tuning(finetune_net, 5e-5)

运行结果如下:

在这里插入图片描述

为了进行比较,我们定义了一个相同的模型,但是将其(所有模型参数初始化为随机值)。 由于整个模型需要从头开始训练,因此我们需要使用更大的学习率。

在这里插入图片描述
意料之中,微调模型往往表现更好,因为它的初始参数值更有效。

老师建议:从fine-tuning开始,而不是从零开始对数据进行训练,这也是一般的计算机视觉的做法,而且,几乎可以认为,未来所有用于深度学习的应用都会是主要是基于fine-tuning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/137921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

专访中银金科:数字驱动成为新的增长引擎,未来业务转化是关键

大数据和信息科技正在逐步颠覆银行业过往的业务模式。建立以数据驱动为核心,以优化客户体验为目标的可持续营销理念,逐渐成为行业的共识。但是,伴随着银行业数字化转型进程加速发展,海量客户数据和低效营销之间的矛盾日益凸显。在…

Linux apt 命令

apt(Advanced Packaging Tool)是一个在 Debian 和 Ubuntu 中的 Shell 前端软件包管理器。 apt 命令提供了查找、安装、升级、删除某一个、一组甚至全部软件包的命令,而且命令简洁而又好记。 apt 命令执行需要超级管理员权限(root)。 apt 语…

23.2、Junit单元测试反射注解

Java代码执行的三个阶段 Junit单元测试: * 测试分类: 1. 黑盒测试:不需要写代码,给输入值,看程序是否能够输出期望的值。 2. 白盒测试:需要写代码的。关注程序具体的执行流程。 * Junit使用&#…

洛谷千题详解 | P1030 [NOIP2001 普及组] 求先序排列【C/C++、pascal语言】

博主主页:Yu仙笙 专栏地址:洛谷千题详解 目录 题目描述 输入格式 输出格式 输入输出样例 解析: C源码: C源码2: pascal源码: C源码: --------------------------------------------------------…

P4Pi AP转wifi模式

调试时间:2022.11.07 树莓派在安装P4Pi后,会自动设置为AP热点模式。本文档通过配置将树莓派系统从ap模式转变为wifi模式。 1 调试环境 Raspberry 4B 4GB-SDcard 32GB Raspberry Pi Imager v1.7.3 Raspberry Pi OS – Raspberry PiFrom industries lar…

值得信赖的数据同步备份软件 -Allway Sync 安全又可靠,简单又易用!

Allway Sync 是一款可靠的数据同步备份工具,最初的版本发布于 2004 年 4 月 19 号,距离今日大约有 19 年的更新历史了,足以说明软件绝对稳定,时间验证了软件的可靠性!而对于我们用户来说,数据同步备份最重要…

基于线性表的查找

目录 一、查找的基本概念 二、顺序查找 关键代码 完整代码 运行结果 增加哨兵 三、二分查找(折半查找) 关键代码 完整代码 运行结果 四、分块查找 图示 关键代码 完整代码 一、查找的基本概念 对查找表进行的操作 1.查找某个特定的数据元素是否存在 …

攻防世界-fakebook

题目 访问题目场景 我自己尝试了很久&#xff0c;发现怎么都找不到这道题的入手点&#xff0c;然后就去看了大佬们的文章&#xff0c;然后我发现这道题更趋近于真实的场景 解题过程 先使用目录扫描器扫一下发现存在robots.txt访问一下 这里发现存在一个备份文件 <?php…

html、css、js的小米商城

首页的展示 首页的功能 1、搜索栏模糊查询 在我在输入框输入关键字的时候&#xff0c;会匹配关键字&#xff0c;如果我的存放的数据里面包含这些关机键字就会显示出来。做到模糊查询的效果。 2、实现搜索功能 在首页的搜索框点击搜索的时候&#xff0c;就会对你输入的关键字进…

Redis 未授权访问的原理、危害及复现

原理介绍 Redis 未授权访问 准确的来说&#xff0c;其实并不是一个漏洞。而是由于开发人员配置不当&#xff0c;而产生的预料之外的危害。 具体原理&#xff1a; 可能由于部分业务要求&#xff0c;或者开发人员的配置不当&#xff0c;将 redis 服务器的 ip 和 port 暴露在公网…

基础数学(7)——常微分方程数值解法

文章目录期末考核方式基础知识解析解&#xff08;公式法&#xff09;解析解例题&#xff08;使用公式法&#xff0c;必考&#xff09;解析解的局限性数值解数值解的基本流程显示Euler法显示欧拉&#xff08;差值理解&#xff09;显示欧拉&#xff08;Taylor展开理解&#xff09…

ClickHouse表引擎详解看这篇就够了-基本讲解、处理逻辑、测试实例

表引擎是ClickHouse设计实现中的一大特色。表引擎在 ClickHouse 中的作用十分关键&#xff0c;直接决定了数据如何存储和读取、是否支持并发读写、是否支持 index、支持的 query 种类、是否支持主备复制等。1、表引擎概述1.1 介绍ClickHouse 提供了大约 28 种表引擎&#xff0c…

ArcGIS基础实验操作100例--实验43填充面要素空洞

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台&#xff1a;ArcGIS 10.6 实验数据&#xff1a;请访问实验1&#xff08;传送门&#xff09; 高级编辑篇--实验43 填充面要素空洞 目录 一、实验背景 二、实验数据 三、实验步骤 &#xff08;1&a…

JavaScript 条件语句

文章目录JavaScript If...Else 语句条件语句If 语句If...else 语句If...else if...else 语句JavaScript If…Else 语句 条件语句用于基于不同的条件来执行不同的动作。 条件语句 通常在写代码时&#xff0c;您总是需要为不同的决定来执行不同的动作。您可以在代码中使用条件语…

【学习笔记】Shell入门

Shell入门 https://www.bilibili.com/video/BV1WY4y1H7d3 资料&#xff1a;评论区取的 公众号的资料链接 https://pan.baidu.com/s/1_nBKUjE57MB2c96wmfSD5A 提取码&#xff1a;yyds 文章目录一、**Shell** 概述二、**Shell** 脚本入门三、变量1.系统预定义变量2.自定义变量**3…

自学软件测试该如何入门?

互联网行业发展很快技术更新也很快&#xff0c;软件测试技能要求在逐渐提高&#xff0c;自学软件测试要尽快而且入行后需要持续学习。保持好心态&#xff0c;找准教程&#xff0c;按照学习路线和自己的规划一步步学习下去~ 软件测试对代码的要求不像其他编程学科那么高&#x…

30个精品Python练手项目

随着 Python 语言的流行&#xff0c;越来越多的人加入到了 Python 的大家庭中。到底为什么这么多人学 Python &#xff1f;我要喊出那句话了&#xff1a;“人生苦短&#xff0c;我用 Python&#xff01;”&#xff0c;正是因为语法简单、容易学习&#xff0c;所以 Python 深受大…

Java微服务连接云服务器上的ZooKeeper

前言 这次要讲的连接ZooKeeper是在外网的云服务器上&#xff0c;不同于以往的本机上的虚拟机上的ZooKeeper&#xff0c;将会有一些不同于本机的连接方式。连接外网服务器进行操作可以更好的适应企业化的开发&#xff0c;脱离了本机的限制&#xff0c;具有很强的实战意义。 前…

小程序容器产品有何特点?

小程序容器顾名思义&#xff0c;是一个承载小程序的运行环境&#xff0c;可主动干预并进行功能扩展&#xff0c;达到丰富能力、优化性能、提升体验的目的。目前市面已知的技术产品包括&#xff1a;mPaas、FinClip、uniSDK 以及上周微信团队才推出的 Donut。今天&#xff0c;我们…

2022 年,这 20+22 位共建者闪耀 StarRocks 社区

2022 年即将过去&#xff0c;多变波动的大环境之中&#xff0c;一岁多的 StarRocks 社区依然保持了高速成长。这一年里&#xff0c;StarRocks 共发布 47 个大小版本&#xff0c;超过 200 人投入社区建设&#xff0c;每月 PR 数突破 1100。 在项目快速迭代的同时&#xff0c;社…