|迁移学习| 迁移学习详解及基于pytorch的相关代码实现

news2024/9/24 5:23:05

🐑 |迁移学习| 迁移学习详解及基于pytorch的相关代码实现 🐑

文章目录

  • 🐑 |迁移学习| 迁移学习详解及基于pytorch的相关代码实现 🐑
    • 🐑 前言🐑
    • 🐑 迁移学习详解🐑
    • 🐑 迁移学习方法🐑
    • 🐑 迁移学习实现🐑
      • 🐑 ResNet50复现🐑
      • 🐑 迁移学习训练🐑
    • 🐑 总结🐑

🐑 前言🐑

前段时间一直在疯狂做实验,各种租服务器跑代码,现在感觉整个人跑的有点神志不清了,但好在最后的结果还可以接受,于是把最近用到的迁移学习相关理论以及整体实现记录一下。在代码实现方面本篇博客使用pytorch作为框架,从搭建迁移学习模型开始到最后使用数据集做迁移学习的训练模板结束。

🐑 迁移学习详解🐑

迁移学习是一种机器学习的方法,它允许模型在新任务上获取从相关任务中获得的知识。这种技术特别有用,因为在某些情况下,我们可能无法为特定问题收集足够的训练数据,或者收集和标注数据的成本非常高昂。通过使用已经在类似或相关任务上训练好的模型,我们可以有效地“转移”这些知识到新的任务上。

这个是迁移学习比较官方的一个注释,解释一下就是将一个表现很好的大模型的性能转移到自己任务训练的模型上。举个例子就是比如学过钢琴的人可以利用他对于五线谱的优势以及音准的优势去学习小提琴;再比如跳舞的人可以利用自身的柔韧性以及肢体协调能力等更容易地学会体操。然而在图片分类(image classification)的问题中,也有很多迁移学习的实例。对于一个已有的可以正确识别出图片中猫和狗的分类器,该分类器通过学习大量带有🐱和🐶标签的图片获得。如果将该分类器应用于一个新的任务,去识别🐘和🐴,而传统机器学习的方法由于缺少大量可用的带有🐘和🐴标签的图片作为训练数据而遇到瓶颈,这时候利用迁移学习的思想,从两者数据集之间图片的相关性可以看出,此时利用🐱 🐶数据集的相关参数去优化🐘 🐴数据集的参数,则可以使模型快速收敛的同时获得一个良好的泛化能力。
在这里插入图片描述
也就是说当自己任务的数据集规模较小的情况下,正常使用神经网络去训练很难获得一个良好的性能,而且质量较低的数据集也直接决定着训练模型的性能的上限不会很高,而此时如果使用迁移学习直接使用在大模型上训练好的参数进行微调的话,一方面会发现收敛速度加快,另一方面效果相较于不用迁移学习相比肯定不会变差。

🐑 迁移学习方法🐑

简单介绍几个关于迁移学习的相关方法。我们根据不同的分类准则,可以使用不同的方式将现有的迁移学习方法进行分类总结。
在这里插入图片描述
其中基于样本的迁移学习很好理解,就像上面例子中的🐱 🐶分类迁移到🐘 🐴分类这种 人为提升目标任务中样本学习的权重
基于特征的迁移学习是目前使用次数最多的一种迁移学习方法,对于两个毫不相干的任务以及数据集,依然可以使用迁移学习的方法。例如我将在ImageNet训练的ResNet50的参数迁移到我在全是心电信号上训练的ResNet50上面,按理说ImageNet的数据集和心电信号的数据集毫无关联,但是在预训练的模型中,模型学习到了大量的图片纹理、色彩等细节的特征,所以一方面可以使我迁移学习后的模型准确率提高,另一方面泛化能力也会加强。

🐑 迁移学习实现🐑

这里以ResNet50为例,复现迁移学习的全过程。

🐑 ResNet50复现🐑

关于网络详细的内容部分这里就不多赘述了,详情可以看学长 @浩浩的科研笔记 ResNet残差网络一维、二维复现pytorch-含残差块复现思路分析。
复现代码如下:

import torch
import os
import torchvision.models as models

class Bottleneck(torch.nn.Module):
    def __init__(self,in_channels,mid_channels,output_channels,down_sample = False,use_1x1conv = False):
        super().__init__()
        if down_sample:
            self.stride = 2
        else:
            self.stride = 1

        self.use_1x1conv = use_1x1conv
        self.conv = torch.nn.Conv2d(in_channels,output_channels,1,self.stride)

        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,mid_channels,1,self.stride),
            torch.nn.BatchNorm2d(mid_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(mid_channels,mid_channels,3,padding=1),
            torch.nn.BatchNorm2d(mid_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(mid_channels,output_channels,1),
            torch.nn.BatchNorm2d(output_channels),
            torch.nn.ReLU()
        )

    def forward(self,x):
        residual = x
        out = self.res(x)
        if self.use_1x1conv:
            residual = self.conv(residual)
        out = out+residual
        out = torch.nn.functional.relu(out)
        return out

class Resnet50(torch.nn.Module):
    def __init__(self,input_channels,classes):
        super().__init__()
        self.input_channels = input_channels
        # self.output_channels = output_channels
        self.classes = classes
        self.conv_1 = torch.nn.Sequential(
            torch.nn.Conv2d(input_channels,64,7,2,3),
            torch.nn.MaxPool2d(3,2,1),

            Bottleneck(64,64,256,True,True),
            Bottleneck(256,64,256),
            Bottleneck(256,64,256),

            Bottleneck(256,128,512,True,True),
            Bottleneck(512,128,512),
            Bottleneck(512, 128, 512),
            Bottleneck(512, 128, 512),

            Bottleneck(512,256,1024,True,True),
            Bottleneck(1024,128,1024),
            Bottleneck(1024, 128, 1024),
            Bottleneck(1024, 128, 1024),
            Bottleneck(1024, 128, 1024),
            Bottleneck(1024, 128, 1024),
        )
        self.conv_2 = torch.nn.Sequential(
            Bottleneck(1024, 512, 2048, True, True),
            Bottleneck(2048, 512, 2048),
            Bottleneck(2048, 512, 2048),

            torch.nn.AdaptiveAvgPool2d(1)
        )

        self.classfier = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(2048,self.classes)
        )

    def forward(self,x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.classfier(x)
        return x


def download_pretrained_model(save_path):
    # 检查保存路径是否存在
    os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, 'resnet50_pretrained.pth')

    if not os.path.exists(model_path):
        # 下载预训练模型权重
        pretrained_resnet50 = models.resnet50(pretrained=True)
        torch.save(pretrained_resnet50.state_dict(), model_path)
        print(f"预训练模型权重已下载并保存到: {model_path}")
    else:
        print(f"预训练模型权重已存在于: {model_path}")

    return model_path

if __name__ == '__main__':
    # 创建自定义的 Resnet50 模型
    model = Resnet50(3, 10)

    # 加载预训练的权重
    pretrained_dict = torch.load('pretrained_weights/resnet50_pretrained.pth')
    model_dict = model.state_dict()

    # 过滤掉不匹配的权重
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # 测试前向传播
    x = torch.randn(200, 3, 224, 224)
    y = model(x)
    print(y.shape)



这里需要注意的一点是因为我们预训练的模型是使用ImageNet训练出来的,那时候输入数据格式为三通道大小为224×224。所以我们自己的数据也需要处理为(3,224,224)

🐑 迁移学习训练🐑

在使用预训练的权重训练时候,为了保证能够有一个更好地泛化能力,我们一般对特征提取的网络层 参数做尽可能少的变化,主要以调整分类层网络参数为主,这方面我们通常以设置不同的学习率来实现,例如这次的ResNet50网络。请添加图片描述
其中红色框和蓝色框标注的为特征提取层,绿色的是分类层。在设置优化器学习率时候为了保证有一个良好的泛化能力,将红色框中的网络设置为1e-5,蓝色为1e-4,绿色为1e-3
下面是训练优化器部分代码:


#optimizer
learning_rate_feature_extractor_1 = 1e-5
learning_rate_feature_extractor_2 = 1e-4
learning_rate_classfier = 1e-3
optimizer = torch.optim.SGD([
    {'params':model.conv_1.parameters(),'lr':learning_rate_feature_extractor_1,'momentum':0.1},
    {'params':model.conv_2.parameters(),'lr':learning_rate_feature_extractor_2,'momentum':0.9},
    {'params':model.classfier.parameters(),'lr':learning_rate_classfier,'momentum':0.9}
])

#criterion
criterion = torch.nn.CrossEntropyLoss()

后续加入到自己的训练中即可。

🐑 总结🐑

由于使用迁移学习的话对于输入数据格式限制比较大,所以在数据预处理的方面一定要先处理好,另外使用迁移学习也完全不用担心模型训练结果变差,只要训练足够多的轮次,模型性能和不使用迁移学习相比可能不会变好,但一定不会变差,但我们一般为了一个较好的额泛化能力一般都会只训练20轮左右。
如果有写的不对的地方欢迎指出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1966438.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第34篇 子程序FINDSUM求和<一>

Q:如何设计汇编语言程序求数组[1:n]的和? A:基本原理:可编写一段实现子程序FINDSUM,子程序中使用一个loop来实现数组的求和运算。子程序FINDSUM的参数N存储在内存中,主程序从该内存中将其读取到一个寄存器…

MES系统如何实现生产任务的自动或辅助调度

MES系统(Manufacturing Execution System,制造执行系统)通过一系列集成化的功能模块和智能算法,实现生产任务的自动或辅助调度。以下是MES系统实现生产任务自动或辅助调度的具体方式: 1. 生产计划与排程 计划制定&am…

【C++从小白到大牛】类和对象

目录 一、面向过程和面向对象初步认识 二、类的引入 三、类的定义 类的成员函数两种定义方式: 1. 声明和定义全部放在类体中 2. 类声明放在.h文件中,成员函数定义放在.cpp文件中 成员变量命名规则的建议: 四、类的访问限定符 【访问限…

4.2.2、存储管理-段式存储和段页式存储

段式存储 段式存储是指将进程空间分为一个个段,每段也有段号和段内地址,与页式存储不同的是,每段物理大小不同,分段是根据逻辑整体分段的. 地址表示:(段号,段内偏移):其中段内偏移不能超过该段号对应的段长,否则越界错误,而此地址对应的真正内存地址应该是:段号对应的基地址段…

lambdafunctionbind

lambda匿名函数 定义: 捕捉:传值/传引用/mutable 混合捕捉,=表全普通捕捉 即使全部捕捉, 编译器实现时也不一定全部传入, 编译器只会传入要用到的变量 lambda内可使用的变量的范围 lambda内只能用捕捉对…

Linux gcc day 9

cpu是一个只可以执行指令,不是cpu要打印而是我们要打印,然后编译成指令再给cpu,再通过操作系统进行操手 进程状态: 为什么会有这些状态? 进程的多状态,本质都是为了满足未来不同的运行场景 有那些状态&am…

linux系统的检测脚本,用于检查linux的网络配置,包括网络接口状态、IP地址、子网掩码、默认网关、DNS服务器、连通性测试等等

目录 一、要求 二、脚本介绍 1、脚本内容 2、脚本解释 (1) 检查是否以 root 用户身份运行 (2)显示脚本标题 (3)打印主机名 (4)获取网络接口信息 (5&#xff09…

React学习之props(父传子,子传父),Context组件之间的传参。

目录 前言 一、什么时候需要使用props? 二、使用 1.父传子 2.子传父 二、什么时候需要使用Context? 第一步: 第二步使用: 第一种: 第二种: 演示: 总结 前言 React学习笔记记录,pr…

python | TypeError: list indices must be integers or slices, not tuple

python | TypeError: list indices must be integers or slices, not tuple 在Python编程中,TypeError: list indices must be integers or slices, not tuple 是一个常见的错误。此错误通常发生在尝试使用非整数(如元组)作为列表索引时。本…

WSL和Windows建立TCP通信协议

1.windows配置 首先是windows端,启动TCP服务端,用来监听指定的端口号,其中IP地址可以设置为任意,否则服务器可能无法正常打开。 addrSer.sin_addr.S_un.S_addr INADDR_ANY; recv函数用来接收客户端传输的数据,其中…

游戏加速器哪个好用

对于游戏加速器,确实有很多不同的选择,每个加速器都有其独特的特点和优势。不过,我可以给你推荐一个最新上线的较受欢迎且评价较高的游戏加速器,供你参考: 深度加速器: 广泛支持:支持国内外众多…

RocketMQ批量消息

RocketMQ消息发送基本示例(推送消费者)-CSDN博客 RocketMQ消费者主动拉取消息示例-CSDN博客 RocketMQ顺序消息-CSDN博客 RocketMQ广播消息-CSDN博客 RocketMQ延时消息-CSDN博客 批量消息 批量消息是指将多条消息合并成一个批量消息,一次发送出去,原先的都是一次发一条.批量…

springboot四川旅游攻略分享互动平台-计算机毕业设计源码70222

摘 要 本研究基于Spring Boot框架开发了一款高效、可靠的四川旅游攻略分享互动平台。该系统主要面向管理员、普通用户和商家用户,涵盖了多个功能模块,包括旅游景点、旅游攻略、景点订单、酒店订单、酒店信息等。通过对系统需求的分析和设计,…

从数据规划到产品运营,拆解数据资产产品化的6大路径

数据资源入表对于企业数据资产的估值影响并不大,要想提升数据资产的整体价值,将数据资产进行产品化是更有效的途径之一。 那么,数据资产产品化的具体路径是怎样的? 在由WakeData惟客数据联合星光数智推出的直播栏目《星光对话》…

打破自闭症束缚:儿童康复案例揭秘

在自闭症的阴霾下,孩子们仿佛被困在一个无形的牢笼中,与外界的世界隔绝。然而,通过不懈的努力和科学的康复方法,许多孩子正在逐渐打破这一束缚,走向充满希望的未来。让我们一同走进几个令人鼓舞的儿童康复案例&#xf…

如何通过阿里云服务器部署hexo博客(超详细)

👏大家好!我是和风coding,希望我的文章能给你带来帮助! 🔥如果感觉博主的文章还不错的话,请👍三连支持👍一下博主哦 📝点击 我的主页 还可以看到和风的其他内容噢&#x…

零基础入门转录组数据分析——机器学习算法之boruta(筛选特征基因)

零基础入门转录组数据分析——机器学习算法之boruta(筛选特征基因) 目录 零基础入门转录组数据分析——机器学习算法之boruta(筛选特征基因)1. boruta基础知识2. boruta(Rstudio)——代码实操2. 1 数据处理…

[Docker][Docker Volume]详细讲解

目录 1.什么是存储卷?2.为什么需要存储卷?1.数据丢失问题2.性能问题3.宿主机和容器互访不方便4.容器和容器共享不方便 3.存储卷分类1.volume docker 管理卷2.bind mount 绑定数据卷3.tmpfs mount 临时数据卷 5.管理卷 Volume1.创建卷1.-v 参数2.--mount …

《Milvus Cloud向量数据库指南》——向量数据库性价比大比拼:谁才是性能之王?

在分析这份向量数据库(Vector Databases)的性价比排名表格时,我们需要从多个维度深入探讨,包括但不限于硬件配置、价格/性能比(QP$,即每百万次查询所花费的价格)、数据集大小、查询类型(无标量过滤、低标量过滤、高标量过滤)以及不同服务提供商之间的比较。以下是一个…

微波治疗仪,美容仪,爆脂仪电源板

分享一下爆脂仪,美容仪,微波治疗仪电源板,高压输出为-2000v,驱动电流最大100mA,匹配磁控管功率输出100w