【学习记录】pytorch载入模型的部分参数

news2025/4/27 22:32:27

需要从PointNet网络框架中提取encoder部分的参数,然后赋予自己的模型。因此,需要从一个已有的.pth文件读取部分参数,加载到自定义模型上面。做了一些尝试,记录如下。

关于模型保存与载入

torch.save(): 使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,可以保存各种对象,包括模型、张量和字典等。
torch.load(): 使用pickle unpickle工具将pickle的对象文件反序列化为内存。
可以看出,pth文件本质上是一个序列化的dict。

我们在save时,代码如下:

state = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}

然后以下代码load进来:

checkpoint = torch.load(args.model_file,  map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

查看checkpoint,可以看到包含的就是自己保存时的3个dict,分别是epoch,model_state_dict,和optimizer信息。
在这里插入图片描述

这里我们重点关注 model_state_dict,数据类型是一个 OrderedDict,有序字典。展开如下:
在这里插入图片描述
可以看到里面包含了自己定义的encoder,bn1-3,mlp 1-4层,以及每个层对应的参数(权重、bias,对于bn层还有mean, var等)。
这个Dict的顺序就是在Model中我们定义的顺序,这个和模型是一致的。
因此,如果载入时的模型和保存模型完全一致,直接用load_state_dict()就可以按顺序把数据载入进来。但如,如果定义不同怎么办?这就需要手动载入。

方法1:手动载入指定层的参数

从debug的断点可以看到,每个参数就是存在dict中的一个tensor。因此,我们只要读取对应的dict即可。
例如,encoder的conv1的权重,就是 checkpoint['model_state_dict']['encoder.conv1.weight'],那么我们在自己的模型对应的位置读取这个dict即可。
具体载入方式如下:

# 定义模型
model = MyPointNetSegmentation(channel=3, get_feature=True, batch_size=1)
model.to('cpu')

# 载入其他模型的参数
checkpoint = torch.load(model_file, map_location='cpu')
model_dict = checkpoint['model_state_dict']

# 将其他模型的参数,赋值给自己模型对应参数
model.encoder.conv1.weight.data.copy_(model_dict['encoder.conv1.weight'])
model.encoder.conv1.bias.data.copy_(model_dict['encoder.conv1.bias'])

把所有有用的参数都赋值过来就好,但要注意参数对应的tensor维度是一样的。
在这里插入图片描述

方法2:一次性载入key值相同的参数

如果说两个model的某些key值相同,可以用python的字典推导方式,将名称相关的参数提取出来。例如:

def load_dict_from_pointnet(model : Point2VoxelNet, checkpoint):
    my_model_dict = model.state_dict()
    pretrained_dict =  checkpoint['model_state_dict']
    # 只将pretraind_dict中那些在model_dict中的参数,提取出来
    state_dict = {k:v for k,v in pretrained_dict.items() if k in my_model_dict .keys()}

    my_model_dict.update(state_dict)		# 注意要更新state的变量,如果直接赋值,会出现某些key没有定义,导致运行失败
    model.load_state_dict(my_model_dict)
	
	# 对比参数是否一致
    print(f"{checkpoint['model_state_dict']['feat.stn.conv1.weight'][1]}")
    print(f"{model.feat.stn.conv1.weight[1]}")
    return model

看到这里,可以知道如果自己的模型改了名称,例如.pth的参数是:feat.stn.conv1,我这边叫做了 encoder.stn.conv1,那么是无法直接赋值的。可以用方法1,一个个载入,但是太慢了。另一种方式,是做一个键值映射,如果读到的是 feat.xxx,则赋予自定义模型中的 encoder.xxx ,简单处理即可。

注意事项

  • conv层需要载入的参数有:weight 和 bias
  • BN层涉及的参数有:
    1. weight,bias
    2. running_mean,running_var:这两个参数用于归一化的均值和方差, 因此也需要载入
    3. num_batches_tracked:在训练时需要载入,在test时不需要载入
  • 载入参数后,如果用于测试,需要调用 eval()。注意不能在载入参数前调用 eval。eval 会将 bn 层的training参数设置为 false ,这样在测试时 batch_size 时如果是 1 也能够正常运行。

测试

用默认方式载入参数,以及手动方式载入后的两个模型,预测结果一致。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2327481.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

写Prompt的技巧和基本原则

一.基本原则 1.一定要描述清晰你需要大模型做的事情,不要模棱两可 2.告诉大模型需要它做什么,不需要做什么 改写前: 请帮我推荐一些电影 改写后: 请帮我推荐2025年新出的10部评分比较高的喜剧电影,不要问我个人喜好等其他问题&#xff…

水下成像机理分析

一般情况下, 水下环境泛指浸入到人工水体 (如水库、人工湖等)或自然水体(如海洋、河流、湖 泊、含水层等)中的区域。在水下环境中所拍摄 的图像由于普遍受到光照、波长、水中悬浮颗粒物 等因素的影响,导致生成的水下图像出现模糊、退 化、偏色等现象,图像…

JVM类加载器详解

文章目录 1.类与类加载器2.类加载器加载规则3.JVM 中内置的三个重要类加载器为什么 获取到 ClassLoader 为null就是 BootstrapClassLoader 加载的呢? 4.自定义类加载器什么时候需要自定义类加载器代码示例 5.双亲委派模式类与类加载器双亲委派模型双亲委派模型的执行…

从一到无穷大 #44:AWS Glue: Data integration + Catalog

本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 本作品 (李兆龙 博文, 由 李兆龙 创作),由 李兆龙 确认,转载请注明版权。 文章目录 引言Glue的历史,设计原则与挑战Serverless ETL 功能设计Glue StudioGlue …

实战打靶集锦-35-GitRoot

文章目录 1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查5. 系统提权6. 写在最后 靶机地址:https://download.vulnhub.com/gitroot/GitRoot.ova 1. 主机发现 目前只知道目标靶机在192.168.56.xx网段,通过如下的命令,看看这个网段上在线的主机…

英语口语 -- 常用 1368 词汇

英语口语 -- 常用 1368 词汇 介绍常用单词List1 (96 个)时间类气候类自然类植物类动物类昆虫类其他生物地点类 List2 (95 个)机构类声音类食品类餐饮类蔬菜类水果类食材类饮料类营养类疾病类房屋类家具类服装类首饰类化妆品类 Lis…

SpringBoot+Vue 中 WebSocket 的使用

WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,它使得客户端和服务器之间可以进行实时数据传输,打破了传统 HTTP 协议请求 - 响应模式的限制。 下面我会展示在 SpringBoot Vue 中,使用WebSocket进行前后端通信。 后端 1、引入 j…

关于依赖注入框架VContainer DIIOC 的学习记录

文章目录 前言一、VContainer核心概念1.DI(Dependency Injection(依赖注入))2.scope(域,作用域) 二、练习例子1.Hello,World!步骤一,编写一个底类。HelloWorldService步骤二,编写使用低类的类。GamePresenter步骤三&am…

Qt常用控件第一部分

1.控件概述 Widget 是 Qt 中的核⼼概念. 英⽂原义是 "⼩部件", 我们此处也把它翻译为 "控件" . 控件是构成⼀个图形化界⾯的基本要素. 像上述⽰例中的, 按钮, 列表视图, 树形视图, 单⾏输⼊框, 多⾏输⼊框, 滚动条, 下拉框等, 都可以称为 "控件"…

docker存储卷及dockers容器源码部署httpd

1. COW机制 Docker镜像由多个只读层叠加而成,启动容器时,Docker会加载只读镜像层并在镜像栈顶部添加一个读写层。 如果运行中的容器修改了现有的一个已经存在的文件,那么该文件将会从读写层下面的只读层复制到读写层,该文件的只读版本依然存在,只是已经被读写层中该文件…

JMeter接口自动化发包与示例

前言 JMeter接口自动化发包与示例 近期需要完成对于接口的测试,于是了解并简单做了个测试示例,看了看这款江湖上声名远播的强大的软件-Jmeter靠不靠谱。 官网:Apache JMeter - Apache JMeter™ 1简介 Apache-Jmeter是一个使用java语言编写且开源&…

INFINI Console 极限控制台密码忘记了,如何重置?

在使用 INFINI Console(极限控制台)时,可能会遇到忘记密码的情况,这对于管理员来说是一个常见但棘手的问题。 本文将详细介绍如何处理 INFINI Console 密码忘记的情况,并提供两种可能的解决方案,帮助您快速…

汇编学习之《jcc指令》

JCC(Jump on Condition Code)指的是条件跳转指令,c中的就是if-else, while, for 等分支循环条件判断的逻辑。它包括很多指令集,各自都不太一样,接下来我尽量将每一个指令的c 源码和汇编代码结合起来看,加深…

从零构建大语言模型全栈开发指南:第四部分:工程实践与部署-4.3.3低代码开发:快速构建行业应用(电商推荐与金融风控案例)

👉 点击关注不迷路 👉 点击关注不迷路 👉 点击关注不迷路 文章大纲 从零构建大语言模型全栈开发指南-第四部分:工程实践与部署4.3.3 低代码开发:快速构建行业应用(电商推荐与金融风控案例)1. 低代码与AI结合的核心价值2. 电商推荐系统案例2.1 技术架构与实现2.2 性能…

基于vue框架的智能服务旅游管理系统54kd3(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表 项目功能:用户,景点信息,门票预订,酒店客房,客房预订,旅游意向,推荐景点,景点分类 开题报告内容 基于Vue框架的智能服务旅游管理系统开题报告 一、研究背景与意义 1.1 行业现状与挑战 传统系统局限性:当前旅游管理系统普遍存在信息…

用Python实现TCP代理

依旧是Python黑帽子这本书 先附上代码,我在原书代码上加了注释,更好理解 import sys import socket import threading#生成可打印字符映射 HEX_FILTER.join([(len(repr(chr(i)))3) and chr(i) or . for i in range(256)])#接收bytes或string类型的输入…

MySQL的进阶语法7(索引-B+Tree 、Hash、聚集索引 、二级索引(回表查询)、索引的使用及设计原则

目录 一、索引概述 1.1 基本介绍 1.2 基本演示 1.3 特点及优势 二、索引结构 2.1 概述 2.2 二叉树 2.3 B-Tree 2.4 BTree 2.5 Hash 2.5.1 结构 2.5.2 特点 2.5.3 存储引擎支持 三、索引的分类 3.1 索引分类 3.2 聚集索引和二级索引 3.2.1 聚集索引和二级…

【CSS3】04-标准流 + 浮动 + flex布局

本文介绍浮动与flex布局。 目录 1. 标准流 2. 浮动 2.1 基本使用 特点 脱标 2.2 清除浮动 2.2.1 额外标签法 2.2.2 单伪元素法 2.2.3 双伪元素法(推荐) 2.2.4 overflow(最简单) 3. flex布局 3.1 组成 3.2 主轴与侧轴对齐方式 3.2.1 主轴 3.2.2 侧轴 3.3 修改主…

论坛系统的测试

项目背景 论坛系统采用前后端分离的方式来实现,同时使用数据库 来处理相关的数据,同时将其部署到服务器上。前端主要有7个页面组成:登录页,列表页,论坛详情页,编辑页,个人信息页,我…

宠物店小程序怎么做?助力实体店实现营销突破

宠物店小程序怎么做?助力实体店实现营销突破 ——一个宠物店老板的“真香”实战分享 ​一、行业现状:线下宠物店的“流量焦虑”​ 作为开了3年宠物店的“铲屎官供应商”,这两年明显感觉生意难做了:某宝9.9包邮的狗粮、某团“满…