PyTorch中加载模型权重 A匹配B|A不匹配B

news2024/9/22 13:37:23

在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况:
在这里插入图片描述

未修改网络,A与B一致

很简单,直接.load_state_dict()

net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))

修改了网络,A与B不一致

[pytorch官方文档](Search — PyTorch master documentation):

load_state_dict(state_dict, strict=True)

将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。

state_dict是包含参数和持久缓冲区的字典,可以看出 strict默认为True,所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的

load_state_dict()函数有两个返回值:

missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表

方法一:

将strict改为false,加载键值相同的部分。

model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

但是此时还存在一种情况:键值相同但shape不同,故应进行if…in…的判断:

ANet = torch.load('ANet.pt')  # 加载预训练权重模型(.pt文件)参数
#现成的模型的话,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()  
model = Model() # 创建模型
model_dict = model.state_dict() # 得到模型的参数字典

# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}

# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)

# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)

方法二:

1.将权重导入原模型,之后在加载后的原模型基础上进行修改。
2.修改权重文件参数,再进行导入
适用于改动不大的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/840081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vector - CAPL - 诊断模块函数(发送及流控制帧)

目录 CanTpSendData - 诊断数据的发送 代码示例 CanTpGetHWSTmin & CanTpSetHWSTmin - 获取和设置硬件STMin的值 代码示例 CanTpSetSTminReduction - 将STmin设置需要的值 代码示例 CanTpGetBlockSize & CanTpSetBlockSize 代码示例 CanTpGetSTmin & Can…

一文学透设计模式

设计模式是什么? 设计模式是软件开发人员在软件开发过程中面临的一般问题的解决方案,代表了解决一些问题的最佳实践。这些解决方案是众多软件开发人员经过相当长的一段时间的试验和错误总结出来的。 说白了,设计模式对于软件开发人员来说就…

一百四十四、Kettle——Linux上安装的kettle8.2连接MySQL数据库

一、目的 在Linux上安装好kettle,然后用kettle连接MySQL数据库 注意:kettle版本是8.2 二、实施步骤 (一)到kettle安装目录下启动Linux的kettle服务 # cd /opt/install/data-integration/ # ./spoon.sh (二&#x…

【前端】搭建Vue3框架

目录 一、搭建准备二、node.js安装1、下载并安装2、配置默认安装目录和缓存日志目录①、创建默认安装目录和缓存日志目录(我的node.js目录在D盘,所以直接在node.js文件夹下创建)②、执行命令,配置默认安装目录和缓存日志目录到刚才…

OpenMesh 网格简化算法(基于边长度)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 网格简化的算法有很多种,基于边结构进行简化的方法便是其中一种方式。此类算法主要关注于它们如何选择要收缩的边,并且似乎都是为流形表面设计的,尽管边缘收缩也可以用于非流形表面,但往往会存在变形较大的问题…

2023-08-05 LeetCode每日一题(合并两个有序链表)

2023-08-05每日一题 一、题目编号 21. 合并两个有序链表二、题目链接 点击跳转到题目位置 三、题目描述 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例1: 示例2: 示例3: …

AcWing 372. 棋盘覆盖(二分图匈牙利算法)

输入样例&#xff1a; 8 0输出样例&#xff1a; 32 解析&#xff1a; n为100&#xff0c;状压肯定爆。 将每个骨牌看成二分图的一个匹配&#xff0c;即查找二分图的一个最大匹配&#xff0c;匈牙利算法。 #include<bits/stdc.h> using namespace std; const int N105…

漫画 | TCP/IP之大明邮差

后记&#xff1a; 1973年&#xff0c;卡恩与瑟夫开发出了网络中最核心的两个协议&#xff1a;TCP协议和IP协议&#xff0c;随后为了验证两个协议的可用性&#xff0c;他们做了一个实验&#xff0c;在多个异构网络中进行数据传输&#xff0c;数据包在经过近10万公里的旅程后到达…

合并果子C++详解

题目描述 在一个果园里&#xff0c;多多已经将所有的果子打了下来&#xff0c;而且按果子的不同种类分成了不同的堆。多多决定把所有的果子合成一堆。 每一次合并&#xff0c;多多可以把两堆果子合并到一起&#xff0c;消耗的体力等于两堆果子的重量之和。可以看出&#xff0c;…

Golang之路---04 并发编程——信道/通道

信道/通道 如果说 goroutine 是 Go语言程序的并发体的话&#xff0c;那么 channel&#xff08;信道&#xff09; 就是 它们之间的通信机制。channel&#xff0c;是一个可以让一个 goroutine 与另一个 goroutine 传输信息的通道&#xff0c;我把他叫做信道&#xff0c;也有人将…

MIT 6.824 -- MapReduce -- 01

MIT 6.824 -- MapReduce -- 01 引言抽象和实现可扩展性可用性(容错性)一致性MapReduceMap函数和Reduce函数疑问 课程b站视频地址: MIT 6.824 Distributed Systems Spring 2020 分布式系统 推荐伴读读物: 极客时间 – 大数据经典论文解读DDIA – 数据密集型应用大数据相关论文…

kagNet:对常识推理的知识感知图网络 8.4+8.5

这里写目录标题 摘要介绍概述问题陈述推理流程 模式图基础概念识别模式图构造概念网通过寻找路径来匹配子图基于KG嵌入的路径修剪 知识感知图网络图卷积网络&#xff08;GCN&#xff09;关系路径编码分层注意机制 实验数据集和使用步骤比较方法KAGNET是实施细节性能比较和分析I…

Redis实战(5)——Redis实现消息队列

消息队列&#xff0c;顾名思义&#xff0c;就是一个存放消息的队列。最简单的消息队列包含3个角色 生产者&#xff1a;将消息存入队列中队列&#xff1a;存放和管理消息消费者&#xff1a; 将消息从队列中取出来并做业务处理 R e d i s 提供了三种实现消息队列的方式&#x…

【力扣每日一题】2023.8.5 合并两个有序链表

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们两个有序的链表&#xff0c;要我们保持升序的状态合并它们。 我们可以马上想要把两个链表都遍历一遍&#xff0c;把所有节点的…

AI抠图使用指南:Stable Diffusion WebUI Rembg实用技巧

抠图是图像处理工具的一项必备能力&#xff0c;可以用在重绘、重组、更换背景等场景。最近我一直在探索 Stable Diffusion WebUI 的各项能力&#xff0c;那么 SD WebUI 的抠图能力表现如何呢&#xff1f;这篇文章就给大家分享一下。 安装插件 作为一个生成式AI&#xff0c;SD…

aardio:用 WebView 模仿 mdict 界面

aardio&#xff1a;用 WebView 模仿 mdict 界面 import win.ui; /*DSG{{*/ mainForm win.form(text"aardio2";right889;bottom467) mainForm.add( button{cls"button";text"go";left335;top22;right399;bottom41;z2}; button2{cls"button…

基于以太网的煤矿电力监控系统的设计与应用 安科瑞 许敏

摘 要&#xff1a;针对传统煤矿电力监控系统通讯网络性能较差、无法实现准确故障定位及报警、不具备数据交互功能等问题&#xff0c;结合分布式网络及GPS授时技术设计了一套基于工业以太网及RS485总线架构的煤矿电力监控系统&#xff0c;可实现对井下供电网络及设备的远程监控…

C++ 字符串

在C语言中是使用字符型数组来存放字符串&#xff0c;C程序也仍然可以沿用这种方法。不仅如此&#xff0c;C库中还预定义了string类。 1.用字符数组存储和处理字符串 字符串常量是用一对双引号括起来的字符序列。例如"abcd"&#xff0c;"China"都是字符串…

跨域+四种解决方法

文章目录 一、跨域二、JSONP实现跨域请求三、前端代理实现跨域请求四、后端设置请求头实现跨域请求五、Nginx代理实现跨域请求5.1 安装Nginx软件5.2 使用Ubuntu安装nginx 本文是在学习课程满神yyds后记录的笔记&#xff0c;强烈推荐读者去看此课程。 一、跨域 出于浏览器的同…

【力扣】23. 合并 K 个升序链表 <链表指针、堆排序、分治>

目录 【力扣】23. 合并 K 个升序链表题解方法一&#xff1a;暴力&#xff0c;先遍历取出来值到数组中排序&#xff0c;再生成新链表方法二&#xff1a;基础堆排序&#xff08;使用优先队列 PriorityQueue&#xff09;方法三&#xff1a;基础堆排序&#xff08;使用优先队列 Pri…