PyTorch如何修改模型(魔改)

news2024/11/24 23:49:17

文章目录

  • PyTorch如何修改模型(魔改)
    • 1.修改模型层(模型框架⭐)
      • 1.1通过继承修改模型
      • 1.2通过组合修改模型(重点学👀)
      • 1.3通过猴子补丁修改模型
    • 2.添加外部输入
    • 3.添加额外输出
    • 参考

PyTorch如何修改模型(魔改)

对模型缝缝补补、修修改改,是我们必须要掌握的技能,本文详细介绍了如何修改PyTorch模型?也就是我们经常说的如何魔改。👍

PyTorch 的模型是一个 torch.nn.Module 的某个子类的对象,修改模型实际就等价于修改某个类,对面向对象熟悉的同学应该知道,对类做修改有两个经典的方法:组合继承

1.修改模型层(模型框架⭐)

1.1通过继承修改模型

首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,最后在子类中实现魔改的逻辑。其大致的框架代码如下所示:

from torchvision.models import ResNet

class CustomizedResNet(ResNet):

    def __init__(self):
        super().__init__()
        ...
        
    def forward(self, x):
        ...

下面这个例子,将对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet
import torch

# 定义一个自定义的ResNet类,继承自torchvision的ResNet类
class CustomizedResNet(ResNet):
    def __init__(self, block, layers, num_classes=2):
        """
        初始化函数
        block: ResNet中的基本块类型,可以是BasicBlock或Bottleneck
        layers: 每个层级的基本块数量,是一个列表
        num_classes: 输出的类别数量,默认为2
        """
        # 调用父类的初始化方法
        super().__init__(block, layers, num_classes)
        # 重新定义全连接层,改变输出的特征数量
        self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)

    def forward(self, x):
        # 以下是ResNet的前向传播过程
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # 通过四个残差层
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        # 将四个残差层的输出进行拼接
        x = torch.cat(
            [self.avgpool(x1),
             self.avgpool(x2),
             self.avgpool(x3),
             self.avgpool(x4),], dim=1)

        # 将拼接后的张量展平
        x = torch.flatten(x, 1)
        # 通过全连接层,得到最终的输出
        x = self.fc(x)

        return x

# 创建不同版本的ResNet模型
new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)
new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)
new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)
new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)

1.2通过组合修改模型(重点学👀)

在面向对象编程中,可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。

首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module,然后将需要魔改的类作为类变量融入到模型中,下面是大致的框架代码:

from torchvision.models import resnet18
import torch.nn as nn

class CustomizedResNet(nn.Module):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        ...

    def forward(self, x):
        ...

my_resnet18 = CustomizedResNet(resnet18)

同样,实现对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models import resnet50

class CustomizedResNet(torch.nn.Module):
    def __init__(self, backbone, num_classes=2):
        super().__init__()
        self.backbone = backbone
        self.fc = torch.nn.Linear(3840, num_classes)

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x1 = self.backbone.layer1(x)
        x2 = self.backbone.layer2(x1)
        x3 = self.backbone.layer3(x2)
        x4 = self.backbone.layer4(x3)

        x = torch.cat(
            [
                self.backbone.avgpool(x1),
                self.backbone.avgpool(x2),
                self.backbone.avgpool(x3),
                self.backbone.avgpool(x4),
            ],
            dim=1,
        )

        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

new_resnet50 = CustomizedResNet(resnet50())

1.3通过猴子补丁修改模型

最简单粗暴的方法:猴子补丁(Monkey Patch)。之所以叫猴子补丁,是因为这种方法从程序设计的角度上来说,是具有破坏性的。而且这种方法仅能实现一些简单的修改需求,所以还是推荐使用继承或组合去修改我们的模型。😉

猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改。下面是把 ResNet34 的输出从 1000 改为 1 的简单例子:

from torchvision.models import resnet50
import torch.nn as nn

model = resnet50()
model.fc = nn.Linear(2048, 1)

还有一个例子,以 PyTorch 官方视觉库 torchvision 预定义好的模型 ResNet50 为例,修改模型的某一层或者某几层。先观察一下它的网络结构:

import torch
import torch.nn as nn
from collections import OrderedDict
import torchvision.models as models
net = models.resnet50()
print(net)

假设要用这个模型去做一个10分类的问题,就应该修改模型的 fc 层,将其输出节点数替换为10。另外,想再加一层全连接层。可以做如下修改:

classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
                          ('relu1', nn.ReLU()), 
                          ('dropout1',nn.Dropout(0.5)),
                          ('fc2', nn.Linear(128, 10)),
                          ('output', nn.Softmax(dim=1))
                          ]))

net.fc = classifier

这里的操作相当于将模型(net)最后名称为“fc”的层替换成了名称为“classifier”的结构。

2.添加外部输入

有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加输入和后续层之间的连接关系,从而完成模型的修改。

以 torchvision 的 resnet50 模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量 add_variable 来辅助预测。具体实现如下:

class Model(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc_add = nn.Linear(1001, 10, bias=True)
        self.output = nn.Softmax(dim=1)
        
    def forward(self, x, add_variable):
        x = self.net(x)
        x = torch.cat((self.dropout(self.relu(x)),
                       add_variable.unsqueeze(1)),1)
        x = self.fc_add(x)
        x = self.output(x)
        return x

这里的实现要点是通过torch.cat实现了tensor的拼接。torchvision 中的 resnet50 输出是一个1000维的 tensor,通过修改 forward 函数,先将 1000 维的 tensor 通过激活函数层和dropout层,再和外部输入变量"add_variable"拼接,最后通过全连接层映射到指定的输出维度 10。

另外这里对外部输入变量"add_variable"进行 unsqueeze 操作是为了和 net 输出的 tensor 保持维度一致,常用于 add_variable 是单一数值 (scalar) 的情况,此时 add_variable 的维度是 (batch_size, ),需要在第二维补充维数1,从而可以和 tensor 进行torch.cat操作。
unsqueeze与sequeeze语法说明

最后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()

另外别忘了,训练中在输入数据的时候要给两个inputs:

outputs = model(inputs, add_var)

3.添加额外输出

有时候在模型训练中,除了模型最后的输出外,我们需要输出模型某一中间层的结果,以施加额外的监督,获得更好的中间层结果。基本的思路是修改模型定义中 forward 函数的 return 变量。

依然以 resnet50 做 10 分类任务为例,在已经定义好的模型结构上,同时输出 1000 维的倒数第二层和 10 维的最后一层结果。具体实现如下:

class Model(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(1000, 10, bias=True)
        self.output = nn.Softmax(dim=1)
        
    def forward(self, x, add_variable):
        x1000 = self.net(x)
        x10 = self.dropout(self.relu(x1000))
        x10 = self.fc1(x10)
        x10 = self.output(x10)
        return x10, x1000

之后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()

out10, out1000 = model(inputs, add_var)

参考

  • Chenglu’s Log

  • Pytorch修改预训练模型的方法汇总

😃😃😃

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1634723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue知识

一、初始vue Vue核心 Vue简介 初识 (yuque.com) 1.想让Vue工作,就必须创建一个Vue实例,且要传入一个配置对象 2.root容器里的代码依然符合html规范,只不过混入了一些特殊的Vue语法 3.root容器里的代码被称为【Vue模板】 4.Vue实例和容器…

TreeSet 和 TreeMap 和 HashSet 和 HashMap

一、二叉搜索树 1、概念 (1)二叉搜索树 要么是一棵空树,要么就得满足左子树上所有结点的值都小于根结点的值,右子树上所有结点的值都大于根结点的值,即左边比我小,右边比我大。二叉树的左右子树也分别都是…

ssm092基于Tomcat技术的车库智能管理平台+jsp

车库智能管理平台设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本车库智能管理平台就是在这样的大环境下诞生,其可以帮助管理者在短…

【稳定检索|投稿优惠】2024年应用数学、建模与计算机工程国际会议(IASAMCE 2024)

2024 International Conference on Applied Mathematics, Modeling, and Computer Engineering 一、大会信息 会议名称:2024年应用数学、建模与计算机工程国际会议 会议简称:IASAMCE 2024 收录检索:提交Ei Compendex,CPCI,CNKI,Google Schola…

C#编程模式之装饰模式

创作背景:朋友们,我们继续C#编程模式的学习,本文我们将一起探讨装饰模式。装饰模式也是一种结构型设计模式,它允许你通过在运行时向对象添加额外的功能,从而动态的修改对象的行为。装饰模式本质上还是继承的一种替换方…

设计模式之监听器模式ListenerPattern(三)

一、介绍 监听器模式是一种软件设计模式,在对象的状态发生改变时,允许依赖它的其他对象获得通知。在Java中,可以使用接口和回调机制来实现监听器模式。 二、代码实例 1、事件Event类 package com.xu.demo.listener;// 事件类 public class…

QT-QTCreator环境配置

准备工作: 下载QT: 链接:https://pan.baidu.com/s/1prJcsC4DGqhKiXvLuPQFVA?pwd60b3 提取码:60b3下载WindowsKits: 链接:https://pan.baidu.com/s/1QNiS3HpbH5M5kXx5AhkqnQ?pwde2h8 提取码:e2h8安装的…

SpringBoot配置HTTPS及开发调试

前言 在实际开发过程中,如果后端需要启用https访问,通常项目启动后配置nginx代理再配置https,前端调用时高版本的chrome还会因为证书未信任导致调用失败,通过摸索整理一套开发调试下的https方案,特此分享 后端配置 …

影响外汇交易盈利的因素有哪些?

外汇交易就是通过汇率的差价来赚取相应的利润。在外汇交易中,投资者是否可以盈利,主要取决于是否正确的判断了市场趋势和行情。投资者在交易过程中受到主观和客观的因素影响,具体包含这些内容。 影响外汇交易盈利的因素有哪些? 1、…

5月软考中级软件设计师100条知识点速记!

最近有一些小伙伴问我:现在开始备考软考还来得及吗?其实只是备考中级的话时间还是比较充足的,5月底考试,每年都有不少人五一假期才开始备考并通过的,大家抓紧时间学起来吧! 今天为大家分享“24上半年软考软…

GIT入门到实战

文章目录 版本控制常见的版本控制工具版本控制分类Git与SVN的主要区别 Git基本理论(重要)三个区域工作流程 GIT文件操作文件的四种状态查看文件状态忽略文件 GIT 常见问题 版本控制 版本控制(Revision control)是一种在开发的过程…

java连锁美业收银系统源码-美业SaaS系统【微信小程序端】功能及应用场景介绍

博弈美业管理系统源码 连锁多门店美业收银系统源码 多门店管理 / 会员管理 / 预约管理 / 排班管理 / 商品管理 / 促销活动 PC管理后台、手机APP、iPad APP、微信小程序 ( 需要系统演示视频可联系观看 ) ▶ 顾客微信小程序端: 场景名称 场…

prime1--vulnhub靶场通关教程

一. 信息收集 1. 探测目标主机IP地址 arp-scan -l //查看网段 vm 编辑--查看虚拟网络编辑器,看到靶机的网段 网段是: 192.168.83.0 是c段网络 2. 全面检测目标IP nmap -sP 192.168.83.1/24 靶机ip是: 192.168.83.145 攻击机的ip是&…

邦注科技 模具清洗机 干冰清洗机 干冰清洗设备原理介绍

干冰清洗机,这款神奇的清洁设备,以干冰颗粒——固态的二氧化碳,作为其独特的清洁介质。它的工作原理可谓独具匠心,利用高压空气将干冰颗粒推送至超音速的速度,犹如一颗颗银色的流星,疾速喷射至待清洗的物体…

【大模型系列】指令微调

概述 指令微调(Instruction Tuning)是指使用自然语言形式的数据对预训练后的大语言模型进行参数微调,22年谷歌ICLR论文中提出这个概念。在其它文献中,指令微调也被称为有监督微调(Supervised Fine-tuning)…

Python-VBA函数之旅-object基类(非函数)

目录 一、object基类的常见应用场景 二、object基类使用注意事项 三、如何用好object基类? 1、object基类: 1-1、Python: 1-2、VBA: 2、推荐阅读: 个人主页:神奇夜光杯-CSDN博客 一、object基类的…

YOLOV8 pycharm

1 下载pycharm 社区版 https://www.jetbrains.com/zh-cn/pycharm/download/?sectionwindows 2 安装 3 新建 4 选择 文件-> setting 配置环境变量 5 添加conda 环境

MyBatis-plus笔记——常用注解

TableName 在开发的过程中,我们经常遇到以上的问题,即实体类所对应的表有固定的前缀,例如 t_ 或 tbl_ 此时,可以使用 TableName 指定表前缀 Data TableName("t_user") public class User {private Long id;private Stri…

起薪4万的AI产品经理,必须掌握的技术模型与3大知识体系

这是求职产品经理系列的第170篇文章 一、AI行业的招聘趋势以及人才紧缺度 根据脉脉《2023年人才报告》显示:人工智能成为2022最缺人行业,⼈⼯智能⾏业的⼈才紧缺指数(⼈才需求量/⼈才投递量)为0.83,也就是说这个领域人…

react props传参

props是父子传参的常用方法。 一、主要功能 1.传参 定义:父级组件向子级组件传递参数。 2.验证数据类型格式 定义:可以指定父组件传递过来数据为指定类型。 3.设置默认值 定义:在参数未使用时,直接默认为指定值。 二、实例代…