基于 PyTorch 的迁移学习介绍 (图像分类实战演示)

news2024/9/9 1:13:23

1. 介绍

迁移学习(Transfer Learning)允许我们采用另一个模型从另一个问题中学到的模式(也称为权重)并将它们用于我们自己的问题。

例如,我们可以采用计算机视觉模型从 ImageNet(包含数百万张不同对象的图像)等数据集学习到的模式,并使用它们来为我们的 FoodVision Mini 模型提供支持。

前提仍然是:找到一个性能良好的现有模型并将其应用于您自己的问题。

 

2. 迁移学习的优点

使用迁移学习有两个主要好处:

  • 可以利用已证明可以解决与我们类似的问题的现有模型(通常是神经网络架构)。

  • 可以利用一个工作模型,该模型已经学习了与我们自己类似的数据的模式。这通常会用更少的自定义数据获得很好的结果。

针对 FoodVision Mini 问题对这些进行测试,将采用在 ImageNet 上预训练的计算机视觉模型,并尝试利用其底层学习表示对披萨、牛排和寿司的图像进行分类。

研究和实践也支持迁移学习的使用。最近的一篇机器学习研究论文《How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers 》的发现建议从业者尽可能使用迁移学习。

Jeremy Howard(fastai 创始人)是迁移学习的大力支持者。“真正产生影响的事情(迁移学习),如果我们能在迁移学习方面做得更好,那就是改变世界的事情。突然之间,更多的人可以用更少的资源和数据完成世界级的工作。”

3. 如何获得预训练模型 ?

如下:

PyTorch 为每个图像、文本、音频和推荐系统这些领域都提供了预训练模型。

图像 torchvision.models

https://pytorch.org/vision/stable/models.html

文本 torchtext.models

https://pytorch.org/text/main/models.html

音频  torchaudio.models

https://pytorch.org/audio/stable/models.html

推荐系统 torchrec.models

https://pytorch.org/torchrec/torchrec.models.html

Hugging Face Hub

涵盖了来自世界各地的组织针对许多不同领域(视觉、文本、音频等)的一系列预训练模型。

https://huggingface.co/models

还有很多不同的数据集。

https://huggingface.co/datasets

timm

涵盖了几乎所有最新、最好的计算机视觉模型以及许多其他有用的计算机视觉功能的 PyTorch 实现。

https://github.com/rwightman/pytorch-image-models

Paperswithcode

最新最先进的机器学习论文的集合,并附有代码实现。您还可以在此处找到不同任务的模型性能基准。

https://paperswithcode.com/

4. 数据集的处理

我们会使用在 将notebook中的PyTorch代码模块化 介绍的 data_setup.py 脚本来获取 DataLoaders。

使用预训练模型时,最重要的是要以与预训练模型的原始训练数据相同转换的方式转换我们的自定义数据。

一般情况下,所有预训练模型要求输入数据的形状为 [N, 3, H, W],其中:

  • N 一般是32

  • 3表示 RGB 颜色通道

  • H 和 W 分别表示图像高度和宽度,一般是224

而且要求输入数据为取值在 [0, 1] 的浮点型张量,其中每个颜色通道的均值和标准差分别为:

  • [0.485, 0.456, 0.406]

  • [0.229, 0.224, 0.225]

均值和标准差是根据 ImageNet 数据集的图像子集计算出来的。不强制设置均值和标准差,神经网络通常能够很好地计算出适当的数据分布(它们将自行计算均值和标准差),但在开始时设置它们可以帮助我们的网络更快地实现更好的性能。

手工按照要求创建 transforms:

手工方式的好处是非常可定制化,我们可以在 transforms 中增加数据增强模块。

但是在 torchvision 0.13 版本之后,我们可以通过下面的方式获取 torchvision.models 中的预训练模型的权重:

weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT

其中:

  • EfficientNet_B0_Weights 是我们想要使用的模型架构权重(torchvision.models 中有许多不同的模型架构选项)。

  • DEFAULT 表示最佳可用权重(ImageNet 中的最佳性能)。

    注意:根据您选择的模型架构,您可能还会看到其他选项,例如 IMAGENET_V1 和 IMAGENET_V2,其中通常版本号越高越好。不过,如果您想要最好的可用选项,DEFAULT 是最简单的选择。有关更多信息,请参阅 torchvision.models 文档。

    https://pytorch.org/vision/main/models.html

我们就可以获得在 ImageNet 上训练 EfficientNet_B0_Weights 时的 transforms 了:

这种自动获取 transforms 的好处是可以确保这个正是当初训练模型时用的 transforms,但是坏处是不太方便定制化了。

5. 获取预训练模型

迁移学习的核心是在与您的问题空间类似的问题空间上选一个已经表现良好的模型,然后根据您的用例对其进行定制。

由于我们正在研究计算机视觉问题(使用 FoodVision Mini 进行图像分类),因此我们可以在 torchvision.models 中找到预训练的分类模型。

常见的计算机视觉架构:

架构代码
ResNettorchvision.models.resnet18(),
torchvision.models.resnet50(),
...
VGGtorchvision.models.vgg16()
EfficientNettorchvision.models.efficientnet_b0(),
torchvision.models.efficientnet_b1(),
...
VisionTransformer(ViT)torchvision.models.vit_b_16(),
torchvision.models.vit_b_32(),
...
ConvNeXttorchvision.models.convnext_tiny(),
torchvision.models.convnext_small(),
...

怎么选择预训练模型 ?

torchvision.models 中还有很多的预训练模型可用,可谓是琳琅满目,那么我们应该怎么选?

这取决于您当前正在处理的问题可用的设备

一般来说,模型名称中的数字越大(例如 effectivenet_b0() -> effectivenet_b1() -> effectivenet_b7())意味着性能更好,但模型更大。

但是更大的模型要求更强劲的设备。例如,假设您想在移动设备上运行模型,则必须考虑设备上有限的计算资源,因此您需要寻找更小的模型。

但如果您拥有无限的计算能力,您可能会选择最大、最需要计算的模型。了解这种性能、速度与尺寸的权衡需要时间和实践的帮助。

对我来说,我在 effectivenet_bX 模型中找到了一个很好的平衡。

注意:即使我们使用 effectivenet_bX,重要的是不要过于依赖任何一种架构,因为随着新研究的发布,它们总是在变化。最好不断地实验、实验、实验,看看什么对您的问题有效。

6. 设置预训练模型

我们将使用的预训练模型是 torchvision.models.efficientnet_b0()。该架构来自论文《EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks》。

https://arxiv.org/abs/1905.11946

我们将要创建的示例是来自 torchvision.models 的预训练 EfficientNet_B0 模型,其输出层针对我们对披萨、牛排和寿司图像进行分类的用例进行了调整。

我们可以使用与创建 transforms 相同的代码来设置 EfficientNet_B0 在 ImageNet 上预训练得到的权重。

weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT # .DEFAULT = ImageNet

这意味着该模型已经过数百万张图像的训练,并且具有良好的图像数据基础表示。该预训练模型的 PyTorch 版本能够在 ImageNet 的 1000 个类别中实现约 77.7% 的准确率。

下载预训练模型,其实是下载权重:

如果我们打印出模型,我们会得到一大串输出:

这是迁移学习的好处之一,它采用现有的模型,该模型是由世界上一些最好的工程师精心设计的,并且可以应用于您自己的问题。

我们的 effectivenet_b0 分为三个主要部分:

  • 特征 - 卷积层和其他各种激活层的集合,用于学习视觉数据的基本表示(此基本表示/层的集合通常称为特征或特征提取器,“模型的基本层学习图片数据的不同特征”)。

  • avgpool - 取特征层输出的平均值并将其转换为特征向量。

  • 分类器 - 将特征向量转换为与所需输出类数量具有相同维度的向量(因为 effectivenet_b0 是在 ImageNet 上预训练的,并且 ImageNet 有 1000 个类,所以 out_features=1000 是默认值)。

7. 获取模型摘要信息

要了解有关我们模型的更多信息,让我们使用 torchinfo 的summary() 方法。

为此,我们将传入:

  • model - 我们想要获得摘要的模型。

  • input_size - 我们想要传递给模型的数据的形状,对于 effectivenet_b0 的情况,输入大小为 (batch_size, 3, 224, 224),尽管 effectivenet_bX 的其他变体具有不同的输入大小。

  • 注意:由于 torch.nn.AdaptiveAvgPool2d(),许多现代模型可以处理不同大小的输入图像,该层根据需要自适应调整给定输入的 output_size。您可以通过将不同大小的输入图像传递给 summary() 或您的模型来尝试这一点。

  • col_names - 我们希望看到的有关模型的各种信息列。

  • col_width - 摘要的列应有多宽。

  • row_settings - 连续显示哪些功能。

从摘要的输出中,我们可以看到当图像数据通过模型时所有各种输入和输出形状的变化。还有一大堆参数(预训练权重)来识别数据中的不同模式。

作为参考,我们之前章节中的模型 TinyVGG 有 8083 个参数,而 effectivenet_b0 有 5288548 个参数,增加了约 654 倍!您认为这是否意味着更好的性能?

8. 冻结基础模型并根据我们的需求改变输出层

迁移学习的过程通常是这样的:冻结预训练模型的一些基础层(通常是特征部分),然后调整输出层(也称为头/分类器层)以满足您的需求。

您可以通过更改输出层来自定义预训练模型的输出以适应您的问题。

原始的 torchvision.models.efficientnet_b0() 带有 out_features=1000 ,

因为 ImageNet 中有 1000 个类。 

然而,对于我们的问题,对披萨、牛排和寿司的图像进行分类,

我们只需要 out_features=3。

让我们冻结 effectivenet_b0 模型的特征部分中的所有层/参数。

注意:冻结层意味着保持它们在训练期间的原样。例如,如果您的模型具有预训练层,那么冻结它们就是说,“在训练期间不要更改这些层中的任何模式,保持它们的原样。” 本质上,我们希望保留我们的模型从 ImageNet 学到的预训练权重/模式作为主干,然后只更改输出层。

我们可以通过设置属性 requires_grad=False 来冻结特征部分中的所有层/参数。对于 require_grad=False 的参数,PyTorch 不会跟踪梯度更新,反过来,我们的优化器在训练期间也不会更改这些参数。

本质上,requires_grad=False 的参数是“无法训练”或“冻结”的。

现在让我们根据需要调整预训练模型的输出层或分类器部分。现在我们的预训练模型有 out_features=1000,因为 ImageNet 中有 1000 个类。然而,我们没有1000个类别,我们只有三个——披萨,牛排和寿司。

我们可以通过创建一系列新的层来更改模型的分类器部分。

当前的分类器包括:

(classifier): Sequential(
    (0): Dropout(p=0.2, inplace=True)
    (1): Linear(in_features=1280, out_features=1000, bias=True)

我们将使用 torch.nn.Dropout(p=0.2, inplace=True) 保持 Dropout 层相同。

注意:Dropout 层以 p 的概率随机删除两个神经网络层之间的连接。例如,如果 p=0.2,则每次传递都会随机删除神经网络层之间 20% 的连接。这种做法的目的是通过确保保留的连接学习特征来补偿其他连接的删除(希望这些剩余的特征更通用),从而防止模型过拟合。

我们将为线性输出层保留 in_features=1280,但将 out_features 值更改为 class_names 的长度 (len(['pizza', 'steak', 'sushi']) = 3)。我们的新分类器层应该与我们的模型位于同一设备上。

更新完输出层后再看一下模型摘要信息:

我们可以看到有一点变化:

  • 可训练列 - 您将看到许多基础层(特征部分中的层)的可训练值为 False。这是因为我们设置了它们的属性 requires_grad=False。除非我们改变这一点,否则这些层在未来的训练期间不会更新。

  • 分类器的输出形状 - 模型的分类器部分现在的输出形状值为 [32, 3],而不是 [32, 1000]。它的 Trainable 值也是 True。这意味着它的参数将在训练期间更新。本质上,我们使用特征部分为我们的分类器部分提供图像的基本表示,然后我们的分类器层将学习如何使基本表示与我们的问题保持一致。

  • 可训练参数较少 - 之前有 5288548 个可训练参数。但由于我们冻结了模型的许多层,只留下分类器可训练,因此现在只有 3843 个可训练参数(甚至比我们的 TinyVGG 模型还要少)。尽管还有 4007548 个不可训练的参数,但这些参数将创建输入图像的基本表示,以输入到分类器层中。

注意:模型的可训练参数越多,计算能力就越强/训练时间就越长。冻结模型的基础层并保留较少的可训练参数意味着我们的模型应该训练得相当快。这是迁移学习的一大好处,它采用针对与您的问题类似的问题训练的模型的已学习参数,并且仅稍微调整输出以适应您的问题。

9. 训练模型

现在我们已经有了一个半冻结的预训练模型,并且有一个定制的分类器,我们来看看迁移学习的实际应用怎么样?

为了开始训练,我们创建一个损失函数和一个优化器。因为我们仍在处理多类分类,所以我们将使用 nn.CrossEntropyLoss() 作为损失函数。我们将坚持使用 torch.optim.Adam() 优化器且设置学习率 lr=0.001。

我们位于 going_modular 目录中的 engine.py 脚本中的 train() 函数训练模型。让我们看看训练我们的模型 5 个 epoch 需要多长时间。

注意:我们只会在这里训练参数分类器,因为模型中的所有其他参数都已被冻结。

借助 efficientnet_b0 主干,我们的模型在测试数据集上实现了近 89% 以上的准确率,几乎是我们使用 TinyVGG 实现的准确率的两倍。这对于我们用几行代码下载的模型来说还不错。

10. 使用损失曲线评估模型

损失曲线看起来很棒!看起来两个数据集(训练和测试)的损失都朝着正确的方向发展。准确度值也是如此,呈上升趋势。

这证明了迁移学习的力量。使用预训练模型通常可以在更短的时间内用少量数据产生相当好的结果。

11. 用测试集中的图像预测

看起来我们的模型在定量上表现良好,但在定性上又如何呢?让我们通过使用我们的模型对测试集中的图像(这些在训练期间看不到)进行一些预测并绘制它们来找出答案。

我们必须记住的一件事是,为了让我们的模型对图像进行预测,该图像必须与我们的模型所训练的图像具有相同的格式。

这意味着我们需要确保我们的图像具有:

  • 相同的形状 - 如果我们的图像与模型训练时的形状不同,我们就会得到形状错误。

  • 相同的数据类型 - 如果我们的图像是不同的数据类型(例如 torch.int8 与 torch.float32),我们将收到数据类型错误。

  • 相同的设备 - 如果我们的图像与我们的模型位于不同的设备上,我们将收到设备错误。

  • 相同的 transforms - 如果我们的模型是在以某种方式转换的图像上进行训练的(例如,用特定的平均值和标准差进行标准化),并且我们尝试对以不同方式转换的图像进行预测,那么这些预测可能会失败。

注意:如果您尝试使用经过训练的模型进行预测,这些要求适用于所有类型的数据。您想要预测的数据应采用与训练模型时相同的格式。

我们创建了一个函数 pred_and_plot_image():

  1. 接受经过训练的模型、类名列表、目标图像的文件路径、图像大小、transforms 和目标设备。

  2. 使用 PIL.Image.open() 打开图像。

  3. 为图像创建一个 transforms(这将默认为我们上面创建的 manual_transforms,或者它可以使用从 weights.transforms() 生成的 transforms)。

  4. 确保该模型位于目标设备上。

  5. 使用 model.eval() 打开模型评估模式(这会关闭 nn.Dropout() 等层,因此它们不用于推理)和推理模式上下文管理器。

  6. 使用步骤 3 中的 transforms 转换目标图像,并使用 torch.unsqueeze(dim=0) 添加额外的批量尺寸,以便我们的输入图像具有形状 [batch_size, color_channels, height, width]。

  7. 通过将图像传递给模型来对图像进行预测,确保它位于目标设备上。

  8. 使用 torch.softmax() 将模型的输出 logits 转换为预测概率。

  9. 使用 torch.argmax() 将模型的预测概率转换为预测标签。

  10. 使用 matplotlib 绘制图像,并将标题设置为步骤 9 中的预测标签和步骤 8 中的预测概率。

我们通过对测试集中的一些随机图像进行预测来测试模型。我们可以使用 list(Path(test_dir).glob("*/*.jpg")) 获取所有测试图像路径的列表。

然后我们可以使用 Python 的 random.sample(population, k) 随机采样其中的一些,其中 Population 是要采样的序列,k 是要检索的样本数。

结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1917243.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

51单片机嵌入式开发:8、 STC89C52RC 操作LCD1602原理

STC89C52RC 操作LCD1602原理 1 LCD1602概述1.1 LCD1602介绍1.2 LCD1602引脚说明1.3 LCD1602指令介绍 2 LCD1602外围电路2.1 LCD1602接线方法2.2 LCD1602电路原理 3 LCD1602软件操作3.1 LCD1602显示3.2 LCD1602 protues仿真 4 总结 1 LCD1602概述 1.1 LCD1602介绍 LCD1602是一种…

java如何判断某个数在区间是否存在?

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

数字园区新视界:智慧管理的全景体验

通过图扑可视化技术,智慧园区管理实现实时监控与数据分析,优化各类资源的配置与使用,提高整体运营效率与智能决策能力。

Java中关于File类的详解

File类 File类是文件和目录路径名称的抽象表示,主要用于文件和目录的创建、查找和删除等操作。在创建File对象的时候,需要传递一个路径,这个路径定位到哪个文件或者文件夹上,File就代表哪个对象。 File file new File("D:…

记一次Ueditor上传Bypss

前言 前一段时间和小伙伴在某内网进行渗透测试,目标不给加白,只能进行硬刚了,队友fscan一把梭发现某资产疑似存在Ueditor组件,但初步测试是存在waf和杀软的,无法进行getshell,经过一番折腾最终getshell&am…

揭秘”大模型加速器”如何助力大模型应用

文章目录 一、大模型发展面临的问题二、“大模型加速器”助力突破困难2.1 现场效果展示2.1.1 大模型加速器——文档解析引擎2.2.2 图表数据提取 三、TextIn智能文档处理平台3.1 在线免费体验3.1.1 数学公式提取3.1.2 表格数据提取 四、acge文本向量化模型4.1 介绍4.2 技术创新4…

Python基础语法:运算符详解(算术运算符、比较运算符、逻辑运算符、赋值运算符)②

文章目录 Python中的运算符详解一、算术运算符二、比较运算符三、逻辑运算符四、赋值运算符五、综合示例结论 Python中的运算符详解 在Python编程中,运算符用于执行各种操作,例如算术计算、比较、逻辑判断和赋值。了解并掌握这些运算符的使用方法是编写…

CTF php RCE (四)

0x08 取反以及异或、或 这两个东西呢相当的好玩&#xff0c;也能够达到一下小极限的操作 <?php error_reporting(0); if(isset($_GET[code])){$code$_GET[code];if(strlen($code)>40){die("This is too Long.");}if(preg_match("/[A-Za-z0-9]/",$…

Firealpaca 解锁版下载及安装教程 (火焰羊驼绘画软件)

前言 FireAlpaca是一款简单易用的电脑绘画软件&#xff0c;采用了类似于Photoshop的图层绘画方式。对于喜欢手绘和创作漫画的朋友来说&#xff0c;FireAlpaca的多图层功能使得绘画过程更加便捷和简单。作为一个小型图像编辑软件&#xff0c;它能够轻松处理多个图层或手绘图&am…

拥抱UniHttp,规范Http接口对接之旅

前言 如果你项目里还在用传统的编程式Http客户端比如HttpClient、Okhttp去直接对接第三方Http接口&#xff0c; 那么你项目一定充斥着大量的对接逻辑和代码&#xff0c; 并且针对不同的对接渠道方需要每次封装一次调用的简化&#xff0c; 一旦封装不好系统将会变得难以维护&am…

策略模式(大话设计模式)C/C++版本

策略模式 商场收银软件 根据客户所购买商品的单价和数量来收费 需求分析&#xff1a; 1. 输入单价数量 > 界面逻辑 2. 计算&#xff08;可能打折或者促销&#xff09; > 业务逻辑 3. 输出结果 > 界面逻辑感觉和计算器的逻辑流程差不多&#xff0c;可以用简单工厂模式…

浪潮天启防火墙TQ2000远程配置方法SSL-xxx、L2xx 配置方法

前言 本次设置只针对配置VXX&#xff0c;其他防火墙配置不涉及。建议把防火墙内外网都调通后再进行Vxx配置。 其他配置可参考&#xff1a;浪潮天启防火墙配置手册 配置SSLVxx 在外网端口开启SSLVxx信息 开启SSLVxx功能 1、勾选 “启用SSL-Vxx” 2、设置登录端口号&#xff0…

Unity3D 太空大战射击游戏

一、前言 本案例是初级案例&#xff0c;意在帮助想使用unity的初级开发者能较快的入门&#xff0c;体验unity开发的方便性和简易性能。 本次我们将使用团结引擎进行开发&#xff0c;帮助想体验团结引擎的入门开发者进行较快的环境熟悉。 本游戏案例以太空作战为背景&#xff0c…

如何分析软件测试中发现的Bug!

假如你是一名软件测试工程师&#xff0c;每天面对的就是那些“刁钻”的Bug&#xff0c;它们像是隐藏在黑暗中的敌人&#xff0c;时不时跳出来给你一个“惊喜”。那么&#xff0c;如何才能有效地分析和处理这些Bug&#xff0c;让你的测试工作变得高效且有趣呢&#xff1f;今天我…

Threadlocal使用获取最后更新人信息

Threadlocal 的作用范围是一个线程&#xff0c;tomcat启动默认开启一个线程 首先点击登录&#xff0c;登录方法会返回token 拿到token后放在请求头中发送商品的插入请求&#xff0c;在插入是设置拿到token中的nickName&#xff08;花名&#xff09;放入&#xff08;lastUpdate…

C 语言中如何实现字符串的拼接?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; &#x1f4d9;C 语言百万年薪修炼课程 【https://dwz.mosong.cc/cyyjc】通俗易懂&#xff0c;深入浅出&#xff0c;匠心打磨&#xff0c;死磕细节&#xff0c;6年迭代&…

轻松搭建RAG:澳鹏RAG开发工具

我们很高兴地宣布推出RAG开发工具&#xff0c;这是澳鹏大模型智能开发平台的一项新功能。此功能可帮助团队轻松创建高质量的检索增强生成 (RAG) 模型。 什么是 RAG&#xff1f; 检索增强生成 (RAG) 通过利用大量外部数据源&#xff08;例如企业的知识库&#xff09;显著增强了…

git查看版本,查看安装路径、更新版本

git version 查看版本 git update-git-for-windows 更新版本 git version 查看版本

美客多卖家必备:自养号测评补单技术的实战策略

构建美客多&#xff08;MercadoLibre&#xff09;自养号测评体系的稳健策略 一、确立目标与前期筹备 深入理解平台规范&#xff1a;首要任务是深入研究美客多平台的规则与指导方针&#xff0c;确保所有行动均符合平台要求&#xff0c;避免任何违规行为导致账号受限。 明确测评…

光电门验证动量守恒实验

本实验所需器件与第二个实验相同。但是连线方式有所区别&#xff0c;先将Arduino的电源输出接到两个光电门&#xff0c;然后再将光电门1的信号输出线接到Arduino的第10个端口&#xff0c;光电门2的信号输出线接到Arduino的第11个端口。对Arduino写入下列程序&#xff08;只有主…