【Pytorch】自定义模型、自定义损失函数及模型删除修改层的常用操作

news2025/2/26 22:01:37

在这里插入图片描述

目录

  • 问题一:更改模型最后一层,删除最后一层
  • 问题二:自定义模型及参数冻结
  • 问题三:自定义损失函数及广播机制

问题1:更改模型最后一层,删除最后一层,添加层。

改变模型最后一层

# Load the model
model = models.resnet18(pretrained = False)

# Get number of parameters going in to the last layer. we need this to change the final layer. 
num_final_in = model.fc.in_features

# The final layer of the model is model.fc so we can basically just overwrite it 
#to have the output = number of classes we need. Say, 300 classes.
NUM_CLASSES = 300
model.fc = nn.Linear(num_final_in, NUM_CLASSES)

若有些网络的最后一层不是FC层,那么我们可以先去获取最后一层的层名,再根据层名进行替换

# Load the model
model = models.resnet18(pretrained = False)

# 打印所有层的层名
for name, module in model.named_modules():
    print(name)

删除最后一层

我们可以像以前一样使用 model.children() 来获取层。然后,我们可以通过在其上使用 list() 命令将其转换为列表。然后,我们可以通过索引列表来删除最后一层。最后,我们可以使用 PyTorch 函数 nn.Sequential() 将这个修改后的列表一起堆叠到一个新模型中。可以以任何你想要的方式编辑列表。也就是说,如果你想要倒数第 3 层图像的特征,你可以删除最后 2 层!

甚至可以从模型中间删除层。但很明显,这会导致进入其后层的特征数量不正确,因为大多数层都会改变图像的大小。在这种情况下,你可以索引模型的特定层并覆盖它!

# Load the model
model = models.resnet18(pretrained = False)

new_model = nn.Sequential(*list(model.children())[:-1])

# 获取倒数第3层
new_model_2_removed = nn.Sequential(*list(model.children())[:-2])

添加图层

比如说,想向我们现在拥有的模型添加一个全连接的层。一种明显的方法是编辑我上面讨论的列表并向其附加另一层。然而,通常我们训练了这样一个模型,并想看看我们是否可以加载该模型,并在其之上添加一个新层。如上所述,加载的模型应该与保存的模型具有相同的体系结构,因此我们不能使用列表方法。

我们需要在上面添加层。在 PyTorch 中执行此操作的方法很简单——我们只需要创建一个自定义模型!这将我们带到下一节 - 创建自定义模型!

自定义模型

让我们制作一个自定义模型。如上所述,我们将从预训练网络加载一半模型。这看起来很复杂,对吧?模型的一半是经过训练的,一半是新的。此外,我们希望其中一些被冻结。有些是可更新的。一旦你完成了这个,你就可以在 PyTorch 中对模型架构做任何事情。

# Some imports first
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
from torch.autograd.variable import Variable
from torchvision import datasets, models, transforms

# New models are defined as classes. Then, when we want to create a model we create an object instantiating this class.
class Resnet_Added_Layers_Half_Frozen(nn.Module):
    def __init__(self, LOAD_VIS_URL=None):
        super(ResnetCombinedFull2, self).__init__()
    
         # Start with half the resnet model, swap out the final layer because that's the model we had defined above. 
        model = models.resnet18(pretrained = False)
        num_final_in = model.fc.in_features
        model.fc = nn.Linear(num_final_in, 300)
        
        # Now that the architecture is defined same as above, let's load the model we would have trained above. 
        checkpoint = torch.load(MODEL_PATH)
        model.load_state_dict(checkpoint)
        
        
        # Let's freeze the same as above. Same code as above without the print statements
        child_counter = 0
        for child in model.children():
            if child_counter < 6:
                for param in child.parameters():
                    param.requires_grad = False
            elif child_counter == 6:
                children_of_child_counter = 0
                for children_of_child in child.children():
                    if children_of_child_counter < 1:
                        for param in children_of_child.parameters():
                            param.requires_grad = False
                    else:
                        children_of_child_counter += 1

            else:
                print("child ",child_counter," was not frozen")

            child_counter += 1
        
        # Now, let's define new layers that we want to add on top. 
        # Basically, these are just objects we define here. The "adding on top" is defined by the forward()
        # function which decides the flow of the input data into the model.
        
        # NOTE - Even the above model needs to be passed to self.
        self.vismodel = nn.Sequential(*list(model.children()))
        self.projective = nn.Linear(512,400)
        self.nonlinearity = nn.ReLU(inplace=True)
        self.projective2 = nn.Linear(400,300)
        
    
    # The forward function defines the flow of the input data and thus decides which layer/chunk goes on top of what.
    def forward(self,x):
        x = self.vismodel(x)
        x = torch.squeeze(x)
        x = self.projective(x)
        x = self.nonlinearity(x)
        x = self.projective2(x)
        return x

自定义损失函数

现在我们已经有了我们的模型,我们可以加载任何东西并创建我们想要的任何架构。这给我们留下了任何管道中的 2 个重要组件 - 加载数据和训练部分。我们来看看训练部分。这一步最重要的两个组成部分是优化器和损失函数。损失函数量化了我们现有模型与我们想要达到的目标之间的距离,优化器决定如何更新参数,以便我们可以最大限度地减少损失。

有时,我们需要定义自己的损失函数。这里有一些事情要知道

  • 自定义损失函数也是使用自定义类定义的。它们像自定义模型一样继承自 torch.nn.Module。
  • 通常,我们需要更改其中一项输入的维度。这可以使用 view() 函数来完成。
  • 如果我们想为张量添加维度,请使用 unsqueeze() 函数。
  • 损失函数最终返回的值必须是标量值。不是矢量/张量。
  • 返回的值必须是一个变量。这样它就可以用于更新参数。最好的方法是确保传入的 x 和 y 都是变量。这样,两者的任何函数也将是一个变量。
  • Pytorch 变量只是一个 Pytorch 张量,但 Pytorch 正在跟踪对其进行的操作,以便它可以反向传播以获得梯度。

这里我展示了一个名为 Regress_Loss 的自定义损失,它将 2 种输入 x 和 y 作为输入。然后将 x 重塑为与 y 相似,最后通过计算重塑后的 x 和 y 之间的 L2 差来返回损失。这是你在训练网络中经常遇到的标准事情。

将 x 视为形状 (5,10),将 y 视为形状 (5,5,10)。所以,我们需要给 x 添加一个维度,然后沿着添加的维度重复它以匹配 y 的维度。然后,(xy) 将是形状 (5,5,10)。我们必须将所有三个维度相加,即三个 torch.sum() 以获得标量。

该操作经常遇到,和numpy中的广播机制一致,需要掌握

# 
class Regress_Loss(torch.nn.Module):
    
    def __init__(self):
        super(Regress_Loss,self).__init__()
        
    def forward(self,x,y):
        y_shape = y.size()[1]
        x_added_dim = x.unsqueeze(1)
        x_stacked_along_dimension1 = x_added_dim.repeat(1, NUM_WORDS, 1)
        diff = torch.sum((y - x_stacked_along_dimension1)**2, 2)
        totloss = torch.sum(torch.sum(torch.sum(diff)))
        return totloss

请关注博主,一起玩转人工智能及深度学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/618720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jmter自动化

一、接口测试流程 1、拿到api接口文档&#xff0c;熟悉接口业务。 2、编写测试用例。 正例&#xff1a;正常参数&#xff0c;是否接口正常 反例&#xff1a;鉴权异常情况、参数异常、兼容性、黑名单、调用次数异常 3、使用接口测试用具&#xff08;Jmeter&#xff09; 4、…

chatgpt赋能python:Python安装gym:入门指南

Python安装gym: 入门指南 如果您是一位正在学习强化学习的学生&#xff0c;或者是一位研究者、开发人员&#xff0c;那么您一定会对OpenAI出品的gym库感兴趣。该库为编写和比较强化学习算法提供了一组标准环境。但是&#xff0c;在使用gym之前&#xff0c;您需要将其安装到您的…

ThinkPad无法进系统的解决方案(实测)

ThinkPad无法进系统如何解决&#xff1f; 不一样的笔记本进到BIOS的方法是不太一样的&#xff0c;下面就和大伙儿具体解读电脑上进到thinkpad的bios设置启动项的方法吧。 在开机或重启的Lenovo画面自检处&#xff0c;快速、连续多次按键盘的“F1”按键&#xff0c;即可进入BI…

基于html+css的图展示112

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

PHP实现文件上传

上传文件的必备三个条件&#xff1a; 1、上传到后台的文件 2、method "post";&#xff08;不可以为get方法&#xff09; 3、enctype "multipart/form-data";&#xff08;注意哦&#xff0c;是data&#xff0c;不是date&#xff09; 三者缺一不可 后台…

抛弃传统网络?SDN协议、标准、接口对比分析!

概要&#xff1a; 随着网络规模的不断扩大和复杂性的增加&#xff0c;传统的网络架构已经难以满足日益增长的网络需求。SDN&#xff08;Software Defined Networking&#xff09;技术的出现&#xff0c;为网络的管理和控制带来了革命性的变化。SDN的核心思想是将网络的控制和管…

chatgpt赋能python:Python如何访问文件

Python如何访问文件 Python是一种优秀的编程语言&#xff0c;被广泛应用于各种领域&#xff0c;包括文件处理。在Python中&#xff0c;我们可以使用内置的文件处理功能访问文件。 什么是文件&#xff1f; 文件是计算机系统中的一种数据存储形式。它们可以包含任何类型的信息…

u盘视频丢失怎么找回?居然还得靠它

u盘视频丢失怎么找回&#xff1f;U盘作为一款常用的存储数据的工具&#xff0c;因为其自身的小巧便携&#xff0c;方便我们随身携带&#xff0c;深受广大用户的喜爱。在使用U盘的过程中&#xff0c;我们也会遇到一些文件丢失的麻烦&#xff0c;比如误删除里面的视频文件。当遇到…

浅谈 Tarjan 算法

在了解 Tarjan 算法之前&#xff0c;我们先来了解 dfs 搜索树。 1 dfs 生成树 定义&#xff1a; dfs 遍历整张图&#xff0c;按照 dfs 序构成一棵树。 1.1 有向图的 dfs 生成树 有向图的 dfs 生成树包括四种边&#xff1a; 树边&#xff08;tree edge&#xff09;&#xff…

CDC是什么?有没有合适的技术方案?

CDC 是 Change Data Capture(变更数据获取)的简称。核心思想是&#xff0c;监测并捕获数据库的 变动&#xff08;包括数据或数据表的插入、更新以及删除等&#xff09;&#xff0c;将这些变更按发生的顺序完整记录下 来&#xff0c;写入到消息中间件中以供其他服务进行订阅及…

阿里、百度、值得买齐发声,电商的“AIGC式”进化

配图来自Canva可画 一年一度618要来了&#xff0c;和往年一样折扣力度、明星直播等话题被炒得火热&#xff0c;不同的是今年618的科技属性更强。 究其原因&#xff0c;过去半年AIGC技术被电商平台应用到实际运营中&#xff0c;“AIGC选品”、“虚拟货场”、“智能客服”成为电商…

《MySQL(六):基础篇- 事务》

文章目录 6. 事务6.1 事务简介6.2 事务操作6.2.1 未控制事务6.2.2 控制事务一6.2.3 控制事务二 6.3 事务四大特性6.4 并发事务问题6.5 事务隔离级别 6. 事务 6.1 事务简介 事务 是一组操作的集合&#xff0c;它是一个不可分割的工作单位&#xff0c;事务会把所有的操作作为一…

【机器学习】神经网络入门

神经网络 非线性假设 如果对于下图使用Logistics回归算法&#xff0c;如果只有x1和x2两个特征的时候&#xff0c;Logistics回归还是可以较好地处理的。它可以将x1和x2包含到多项式中 但是有很多问题所具有的特征远不止两个&#xff0c;甚至是上万个&#xff0c;如果我们想要…

MySQL数据库给表添加索引

说明&#xff1a;当数据库中的记录数过多时&#xff0c;查询速度会显著变慢。此时可以给表创建索引&#xff0c;提高查询速度。 一、创建索引前 我现在有一张表&#xff0c;有1000万条记录&#xff0c;根据username值&#xff0c;查询一条记录&#xff0c;测试下查询时间&…

赛宁网安助力智能网联汽车发展 | “饶派杯”XCTF车联网安全挑战赛圆满收官

​​ 2023年5月31日&#xff0c;“饶派杯”XCTF车联网安全挑战赛在江西省上饶市圆满落幕。本次大赛特邀国内21支精英战队参与比拼&#xff0c;参赛选手覆盖全国知名高校、自动驾驶汽车和科研院所等车联网安全人才。最终&#xff0c;经过9个小时激烈角逐&#xff0c;来自南京邮电…

chatgpt赋能python:Python自动更新技术的应用

Python自动更新技术的应用 Python是一款高效的编程语言&#xff0c;广泛应用于各种软件开发、数据分析及人工智能等领域。随着大数据和人工智能的快速发展&#xff0c;Python语言的应用也日益普及&#xff0c;更多的企业和个人开始使用Python编写自己的程序。而随着程序的使用…

上榜“网络安全企业科技能力百强”啦!

最新公布的《2023网络安全企业科技能力报告》显示&#xff0c;顶象在“2023网络安全企业科技能力百强”和“2023网络安全企业有效专利数量百强”等两个榜单中均处于前列。 《2023网络安全企业科技能力报告》由中关村网络安全与信息化产业联盟发布&#xff0c;旨在探究网络安全…

微信开发者工具公众号网页项目实现本地项目调试

背景 最近业务场景中有需要微信H5进行实现,需要网页授权,需要用户进行点击授权的操作,跳转一个微信公众号后台设置的授权域名下的网页后才能获取到code,其他网页授权步骤这里不进行展开,不想频繁的打包上传的服务器看实现效果,所以考虑从微信开发者工具中实现本地调试,搜索过相…

如何开发原生的 JavaScript 插件(知识点+写法)

一、前言 通过 "WWW" 原则我们来了解 JavaScript 插件这个东西 第一个 W "What" -- 是什么?什么是插件,我就不照搬书本上的抽象概念了,我个人简单理解就是,能方便实现某个功能的扩展工具.(下面我会通过简单的例子来帮助读者理解) 第二个 W "Why&q…

(9)基于发射器的调优

文章目录 前言 1 概述 2 调优值 3 用任务规划器设置 前言 你可以在飞行中使用你的遥控发射器进行广泛的参数调优。这是为那些无法使用自动调优功能的高级用户准备的&#xff0c;或者希望通过对每个参数的完全手动调优控制来进行微调。 1 概述 基于发射机的调优允许你在飞行…