pytorch保存、加载和解析模型权重

news2024/12/27 13:05:14

1、模型保存和加载

         主要有两种情况:一是仅保存参数,二是保存参数及模型结构。

保存参数:

         torch.save(net.state_dict())

加载参数(加载参数前需要先实例化模型):

         param = torch.load('param.pth')

         net.load_state_dict(param)

保存模型结构和参数:

         torch.save(net)

加载模型:

         net = torch.load('model.pt')

2、解析模型权重文件

         当加载某个模型文件后,如果需要查看模型中的算子和参数,可以将模型解析为字典,然后逐一打印。

以lent5为例,将lenet5模型保存为权重文件,然后重新加载权重文件并解析其中每一层的参数。

参考代码:

def pytorch_params(pth_file):
    par_dict = torch.load(pth_file, map_location='cpu')
    for name in par_dict:
        parameter = par_dict[name]
        print(name, parameter.numpy().shape)

        以上代码是加载的权重文件,文件只有参数,没有模型结构,如果加载的是包含模型结构的权重文件,可以做如下修改:

def pytorch_params(pt_file):
    net = torch.load(pt_file, map_location='cpu')
	par_dict = net.state_dict()
    for name in par_dict:
        parameter = par_dict[name]
        print(name, parameter.numpy().shape)

解析结果:

3、加载自定义参数

        某些情况下可能需要对某个算子进行单独调试,如加载特定参数进行推理计算,用来确定输出结果符合预期。以Conv2d算子为例进行测试,首先设定卷积层输入为3,输出为3,卷积核为3*3,偏置bias为False。通过numpy随机一个3*3*3*3的矩阵作为自定义参数,将参数转换为Tensor以后,添加到dict中,然后通过load_state_dict将参数加载进网络。

参考脚本:

 

import torch
import torch.nn as nn
import numpy as np
net = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, bias=False)
param = np.random.random((3, 3, 3, 3))
param = param.astype(np.float32)
torch_param = {'weight': torch.Tensor(param)}
net.load_state_dict(torch_param)
net.eval()
data = np.random.random((1, 3, 16, 16))
data = data.astype(np.float32)
result = net(torch.Tensor(data))
print(result)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/745706.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AWS 中文入门开发教学 47- S3 - 基本的使用

知识点 S3 - 基本的使用方法实战演习 创建存储桶 阻止所有公网访问: 打开版本控制、添加标签: KMS是收费的: 创建成功: 上传文件 选择存储类:

这是中国人工智能AI激情澎湃的一周

融资 贝联珠贯完成 5000 万元天使轮融资,业务涵盖 AI 型算力市场据投中网报道,近日,云资源管理服务提供商浙江贝联珠贯宣布完成 5000 万元天使轮融资,由元璟资本、红杉中国种子基金和舟轩股权投资。 盛大网络 CEO 陈天桥再投 1…

springboot就业信息管理系统

本次设计任务是要设计一个就业信息管理系统,通过这个系统能够满足就业信息管理功能。系统的主要功能包括:首页,个人中心,学生管理,导师管理,企业管理,招聘信息管理,应聘信息管理&…

DMA是一个超级简化版的cpu吗?

来自群友的讨论 我的理解是DMA某种程度相当于一个CPU是因为DMA拥有访问其他地址空间的权利。 从系统角度考虑,对整个系统的观测者一般CPU DSP GPU DMA是一个级别,其他都是slave。cache一致性POC是要保证所有观测者,包括DMA观测到相同数据。 …

【学习bubbliiiing代码-2】从txt中获取类别名称以及类别数量

本系列主要用于自我学习,参考的为bubbliiiing的代码 写一个优雅的:从txt文件中获得类别名与类别数的函数,如下: #---------------------------------------------------# # 获得类别名与类别数 #-----------------------------…

Python爬虫:利用JS逆向抓取携程网景点评论区图片的下载链接

Python爬虫:利用JS逆向抓取携程网景点评论区图片的下载链接 1. 前言2. 实现过程3. 运行结果 1. 前言 文章内容可能存在版权问题,为此,小编不提供相关实现代码,只是从js逆向说一说到底怎样实现这个的过程,希望能够帮助到那些正在做…

主动配电网故障恢复的重构与孤岛划分统一模型(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

SIP协议学习(2)

文章目录 注册(REGISTER)1、AOR和Contact区别2、注册概述3、注册与定位服务4、注册超时处理5、注册消息6、多Contact地址处理7、下期预告 注册(REGISTER) 1、AOR和Contact区别 在学习注册之前,首先区分一下AOR和Cont…

Threads上线5天用户增至1亿,Threads软件常见问题百问百答

7月10日,脸书(Facebook)母公司Meta旗下新应用程序Threads上线的第5天,其用户数量已经超过1亿。这一增长速度打破聊天机器人ChatGPT的纪录——推出两个月内活跃用户量才破亿。 Threads或成为史上用户数增长速度最快的消费者应用。 …

Mysql数据库基础和增删改查操作

目录 一、数据库基本概念 二、数据库类型和常用数据库 1.关系型数据库 2.非关系型数据库 三、数据库的数据类型 四、SQL语句 1.简介 2.分类 五、SQL语句的使用 1.数据库操作 (1)创建数据库 ​编辑 (2)查看数据库 &am…

mac MySQL修改密码

简介: MySQL是一种常用的关系型数据库管理系统。在某些情况下,您可能需要关闭MySQL服务或修改root密码。本文将向您展示如何执行这些操作的步骤。 步骤1:关闭MySQL服务 打开MySQL软件并关闭它。 或者使用以下命令关闭MySQL服务&#xff1a…

conda的使用

一、conda 1、为什么使用conda 在安装Python包的过程中,可能遇到依赖包的问题。例如,要安装numpy,需要先安装BLAS和LAPACK等库。在使用pip等包管理工具时,这些依赖包需要手动安装,操作起来可能比较繁琐。而conda是一个…

pdf怎么添加水印图片?分享3个超实用解决方法

在使用PDF文件时,我们经常会看到一些设置的水印,这提醒观看者文件的所有权。给PDF文件添加水印是一种常见且实用的功能。为了解决如何给PDF添加水印的问题,我将介绍几种常用的方法。 方法一:使用WPS添加水印 WPS是我们常用的办公…

【分布式应用】zookeeper集群

目录 一、zookeeper概述1.1zookeeper工作机制1.2Zookeeper 数据结构1.3Zookeeper 应用场景1.4Zookeeper 选举机制第一次启动选举机制**非第一次启动选举机制 二、部署 Zookeeper 集群2.1环境配置2.2安装 Zookeeper 一、zookeeper概述 Zookeeper是一个开源的分布式的&#xff0c…

1.内核驱动中,驱动注册,阻塞IO,gpio子系统,中断处理的整体结合示例

一,功能实现要求 /*功能实现 在stm32开发板上实现功能 1.使用阻塞IO读取number变量的值,当number的值改变时打印number的值 2.注册KEY1按键的驱动和LED1的驱动以及对应的设备文件, 3.按键和指示灯设备信息放在同一个设备树的节点中 4.当KEY1…

TypeScript 类型体操:合并映射类型的处理结果为联合类型(记录)

一般索引索引 type boy {name : string,age : number } 对索引类型映射 type onlyBoy<obj> {readonly [key in keyof obj] : obj[key] } 使用 type res onlyBoy<boy>; 输出 这些都是对索引类型整体做的变换&#xff0c;变换的结果依然是一个索引类型。有的…

提示“无法向会话状态服务器发出会话状态请求。请确保 ASP.NET State Service (ASP.NET 状态服务)已启动”,如何解决?

在aspx网站部署过程中&#xff0c;出现“无法向会话状态服务器发出会话状态请求。请确保 ASP.NET State Service (ASP.NET 状态服务)已启动”的提示&#xff0c;如下图&#xff0c;如何解决&#xff1f; 解决方案1&#xff1a; Web.Config里面 把sessionState 的mode改为&quo…

ChatGPT对高校人才培养模式的挑战与应对策略思考

酷吗&#xff1f;输入指令后直接就能生成一大串代码&#xff0c;即使不懂相关技术也能玩转编程&#xff0c;这就是ChatGPT赋予你的“新能力”&#xff0c;除了写代码&#xff0c;ChatGPT还能帮你执行各种五花八门的任务。 AI工具如ChatGPT在行业中的广泛应用对于行业的人才结…

全网最牛,Python自动化测试-日志Log处理(超细)一篇打通...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 日志就是用于记录…

如何安装本地Go Tour教程(或者叫A Tour of Go离线版),以及中文版安装不了该怎么办

Go 官方是有一个在线教程 A Tour of Go&#xff0c;可以在线学习 Go 的编程&#xff0c;并且有中文版。英文原版页面如下&#xff1a; 出人意料的是&#xff0c;Go 提供了离线版&#xff08;各个语言都有&#xff09;&#xff0c;下载安装之后就可以在本地编译运行查看结果&a…