pytorch小记(七):pytorch中的保存/加载模型操作

news2025/1/14 3:33:34

pytorch小记(七):pytorch中的保存/加载模型操作

  • 1. 加载模型参数 (`state_dict`)
    • 1.1 保存模型参数
    • 1.2 加载模型参数
    • 1.3 常见变种
      • 1.3.1 指定加载设备
      • 1.3.2 非严格加载(跳过部分层)
      • 1.3.3 打印加载的参数
  • 2. 加载整个模型
    • 2.1 保存整个模型
    • 2.2 加载整个模型
    • 2.3 注意事项
  • 3. 总结
  • 4. 加载模型的完整代码示例
    • 4.1 保存和加载参数
    • 4.2 保存和加载整个模型
    • 4.3 加载到不同设备
    • 4.4 忽略部分参数(非严格加载)
    • 5. 检查模型是否加载成功


在 PyTorch 中,加载模型通常分为两种情况:加载模型参数(state_dict)加载整个模型。以下是加载模型的所有相关操作及其详细步骤:


1. 加载模型参数 (state_dict)

当仅保存了模型的参数时(使用 model.state_dict() 保存),加载模型的步骤如下:

1.1 保存模型参数

torch.save(model.state_dict(), 'model.pth')
  • 文件内容:只保存模型的参数(权重和偏置)。
  • 优点
    • 节省存储空间。
    • 灵活性更高,可以与不同的模型架构配合使用。
  • 缺点
    • 需要手动重新定义模型结构。

1.2 加载模型参数

  1. 重新定义模型架构:

    model = MyModel()  # 替换为你的模型类
    
  2. 加载参数:

    state_dict = torch.load('model.pth')  # 加载参数字典
    model.load_state_dict(state_dict)    # 加载参数到模型
    
  3. 选择运行设备:

    model.to('cuda')  # 如果需要运行在 GPU 上
    

1.3 常见变种

1.3.1 指定加载设备

  • 如果保存时模型在 GPU 上,而加载时在 CPU 环境中,可以使用 map_location
    state_dict = torch.load('model.pth', map_location='cpu')
    

1.3.2 非严格加载(跳过部分层)

  • 如果保存的参数与模型结构不完全匹配(例如额外的层或不同的顺序),可以使用 strict=False
    model.load_state_dict(state_dict, strict=False)
    

1.3.3 打印加载的参数

  • 可以检查参数字典的内容:
    print(state_dict.keys())
    

2. 加载整个模型

当模型是通过 torch.save(model) 保存时,文件包含了模型的结构和参数,加载更为简单。

2.1 保存整个模型

torch.save(model, 'model_full.pth')
  • 文件内容:包含模型的架构和参数。
  • 优点
    • 无需重新定义模型结构。
    • 直接加载并使用。
  • 缺点
    • 文件依赖于保存时的代码版本(如模型定义)。
    • 文件体积较大。

2.2 加载整个模型

model = torch.load('model_full.pth')
model.to('cuda')  # 如果需要在 GPU 上运行

2.3 注意事项

  • 动态定义的模型
    • 如果模型结构是动态定义的(如包含条件逻辑),保存和加载整个模型可能会依赖于代码的一致性。
    • 确保在加载时导入了与保存时相同的模型类。

3. 总结

操作使用场景优点缺点
保存参数 (state_dict)推荐大多数情况文件小、灵活性高需要手动定义模型架构
保存整个模型模型复杂且固定时不需要重新定义模型,直接加载文件大、依赖保存时的代码版本

4. 加载模型的完整代码示例

4.1 保存和加载参数

import torch
import torch.nn as nn

# 定义模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 保存参数
model = MyModel()
torch.save(model.state_dict(), 'model.pth')

# 加载参数
model = MyModel()  # 重新定义模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda')  # 运行在 GPU

4.2 保存和加载整个模型

# 保存整个模型
torch.save(model, 'model_full.pth')

# 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda')  # 运行在 GPU

4.3 加载到不同设备

# 保存参数
torch.save(model.state_dict(), 'model.pth')

# 加载到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)

# 加载到 GPU
model.to('cuda')

4.4 忽略部分参数(非严格加载)

# 保存参数
torch.save(model.state_dict(), 'model.pth')

# 加载参数(非严格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)

5. 检查模型是否加载成功

  1. 验证权重是否加载

    for name, param in model.named_parameters():
        print(f"{name}: {param.data}")
    
  2. 进行推理验证

    x = torch.randn(1, 10).to('cuda')  # 假设输入维度为 10
    output = model(x)
    print(output)
    

通过以上操作,你可以灵活加载 PyTorch 模型,无论是仅加载参数还是加载整个模型结构和权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2276269.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

初学stm32 --- DAC输出三角波和正弦波

输出三角波实验简要: 1,功能描述 通过DAC1通道1(PA4)输出三角波,然后通过DS100示波器查看波形 2,关闭通道1触发(即自动) TEN1位置0 3,关闭输出缓冲 BOFF1位置1 4,使用12位右对齐模式 将数字量写入DAC_…

专题 - STM32

基础 基础知识 STM所有产品线(列举型号): STM产品的3内核架构(列举ARM芯片架构): STM32的3开发方式: STM32的5开发工具和套件: 若要在电脑上直接硬件级调试STM32设备,则…

25年无人机行业资讯 | 1.1 - 1.5

25年无人机行业资讯 | 1.1 - 1.5 中央党报《经济日报》刊文:低空经济蓄势待发,高质量发展需的平衡三大关系 据新华网消息,2025年1月3日,中央党报《经济日报》发表文章指出,随着国家发展改革委低空经济发展司的成立&a…

时序数据库InfluxDB—介绍与性能测试

目录 一、简述 二、主要特点 三、基本概念 1、主要概念 2、保留策略 3、连续查询 4、存储引擎—TSM Tree 5、存储目录 四、基本操作 1、Java-API操作 五、项目中的应用 六、单节点的硬件配置 七、性能测试 1、测试环境 2、测试程序 3、写入测试 4、查询测试 一…

计算机网络 (35)TCP报文段的首部格式

前言 计算机网络中的TCP(传输控制协议)报文段的首部格式是TCP协议的核心组成部分,它包含了控制TCP连接的各种信息和参数。 一、TCP报文段的结构 TCP报文段由首部和数据两部分组成。其中,首部包含了控制TCP连接的各种字段&#xff…

GelSight Mini视触觉传感器凝胶触头升级:增加40%耐用性,拓展机器人与触觉AI 应用边界

马萨诸塞州沃尔瑟姆-2025年1月6日-触觉智能技术领军企业Gelsight宣布,旗下Gelsight Mini视触觉传感器迎来凝胶触头的更新。经内部测试,新Gel凝胶触头耐用性提升40%,外观与触感与原凝胶触头保持一致。此次升级有效满足了客户在机器人应用中对设…

burpsiute的基础使用(2)

爆破模块(intruder): csrf请求伪造访问(模拟攻击): 方法一: 通过burp将修改,删除等行为的数据包压缩成一个可访问链接,通过本地浏览器访问(该浏览器用户处于登陆状态&a…

【ASP.NET学习】ASP.NET MVC基本编程

文章目录 ASP.NET MVCMVC 编程模式ASP.NET MVC - Internet 应用程序创建MVC web应用程序应用程序信息应用程序文件配置文件 用新建的ASP.NET MVC程序做一个简单计算器1. **修改视图文件**2. **修改控制器文件** 用新建的ASP.NET MVC程序做一个复杂计算器1.创建模型(…

Git 命令代码管理详解

一、Git 初相识:版本控制的神器 在当今的软件开发领域,版本控制如同基石般重要,而 Git 无疑是其中最耀眼的明珠。它由 Linus Torvalds 在 2005 年创造,最初是为了更好地管理 Linux 内核源代码。随着时间的推移,Git 凭借…

OpenCV实现基于交叉双边滤波的红外可见光融合算法

1 算法原理 CBF是*Cross Bilateral Filter(交叉双边滤波)*的缩写,论文《IMAGE FUSION BASED ON PIXEL SIGNIFICANCE USING CROSS BILATERAL FILTER》。 论文中,作者使用交叉双边滤波算法对原始图像 A A A, B B B 进行处理得到细节&#xff0…

项目实战--网页五子棋(用户模块)(1)

接下来我将使用Java语言,和Spring框架,实现一个简单的网页五子棋。 主要功能包括用户登录注册,人机对战,在线匹配对局,房间邀请对局,积分排行版等。 这篇文件讲解用户模块的后端代码 1. 用户表与实体类 …

机器学习之随机森林算法实现和特征重要性排名可视化

随机森林算法实现和特征重要性排名可视化 目录 随机森林算法实现和特征重要性排名可视化1 随机森林算法1.1 概念1.2 主要特点1.3 优缺点1.4 步骤1.5 函数及参数1.5.1 函数导入1.5.2 参数 1.6 特征重要性排名 2 实际代码测试 1 随机森林算法 1.1 概念 是一种基于树模型的集成学…

MySQL存储引擎、索引、索引失效

MySQL Docker 安装 MySQL8.0,安装见docker-compose.yaml 操作类型 SQL 程序语言有四种类型,对数据库的基本操作都属于这四种类,分为 DDL、DML、DQL、DCL DDL(Dara Definition Language 数据定义语言),是负责数据结构定义与数据…

WPF基础(1.1):ComboBox的使用

本篇文章介绍ComboBox的基本使用。 本篇文章的例子实现的功能:后端获取前端复选框中的选项之后,点击“确定”按钮,弹出一个MessageBox,显示用户选择的选项。 文章目录 1. 效果展示2. 代码逻辑2.1 前端代码2.2 后端代码 1. 效果展…

前端炫酷动画--文字(二)

目录 一、弧形边框选项卡 二、零宽字符 三、目录滚动时自动高亮 四、高亮关键字 五、文字描边 六、按钮边框的旋转动画 七、视频文字特效 八、立体文字特效让文字立起来 九、文字连续光影特效 十、重复渐变的边框 十一、磨砂玻璃效果 十二、FLIP动画 一、弧形边框…

android 官网刷机和线刷

nexus、pixel可使用google官网线上刷机的方法。网址:https://flash.android.com/ 本文使用google线上刷机,将Android14 刷为Android12 以下是失败的线刷经历。 准备工作 下载升级包。https://developers.google.com/android/images?hlzh-cn 注意&…

25/1/12 嵌入式笔记 学习esp32

了解了一下位选线和段选线的知识: 位选线: 作用:用于选择数码管的某一位,例如4位数码管的第1位,第2位) 通过控制位选线的电平(高低电平),决定当前哪一位数码管处于激活状…

探秘block原理

01 概述 在iOS开发中,block大家用的都很熟悉了,是iOS开发中闭包的一种实现方式,可以对一段代码逻辑进行封装,使其可以像数据一样被传递、存储、调用,并且可以保存相关的上下文状态。 很多block原理性的文章都比较老&am…

【Docker】入门教程

目录 一、Docker的安装 二、Docker的命令 Docker命令实验 1.下载镜像 2.启动容器 3.修改页面 4.保存镜像 5.分享社区 三、Docker存储 1.目录挂载 2.卷映射 四、Docker网络 1.容器间相互访问 2.Redis主从同步集群 3.启动MySQL 五、Docker Compose 1.命令式安装 …

Bootstrap 前端 UI 框架

Bootstrap官网:Bootstrap中文网 铂特优选 Bootstrap 下载 点击进入中文文档 点击下载 生产文件是开发响应式网页应用,源码是底层逻辑代码,因为是要制作响应式网页,所以下载开发文件 引入 css 文件, bootstrap.css 和 …