第五章 模型篇: 模型保存与加载

news2025/1/13 7:39:06

参考教程
https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

文章目录

  • pytorch中的保存与加载
    • torch.save()
    • torch.load()
    • 代码示例
  • 模型的保存与加载
    • 保存 state_dict()
    • nn.Module().load_state_dict()
    • 加载模型参数
    • 保存模型本身
    • 加载模型本身
  • checkpoint
    • 保存与读取
    • 多个模型的保存与读取

训练好的模型,可以保存下来,用于后续的预测或者训练过程的重启。
为了便于理解模型保存和加载的过程,我们定义一个简单的小模型作为例子,进行后续的讲解。

这个模型里面包含一个名为self.p1的Parameter和一个名为conv1的卷积层。我们没有给模型定义forward()函数,是因为暂时不需要用到该方法。假如你想使用这个模型对数据进行前向传播,会返回 “NotImplementedError: Module [Model] is missing the required “forward” function”

import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.t1 = torch.randn((3,2))
        self.p1 = nn.Parameter(self.t1)
        self.conv1 = nn.Conv2d(1, 1, 5)
net = Model()

pytorch中的保存与加载

首先我们来看一下pytorch中的保存和加载的方法是怎么实现的。

torch.save()

参考文档:https://pytorch.org/docs/stable/generated/torch.save.html
首先来看一下torch.save()函数。

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

torch.save()函数传入的第一个参数,就是我们要保存的对象,它的类别要求是object,而没有限定在nn.Module()或者nn.Parameters()等等之间。说明它可以保存的类型是多种多样的,很灵活。
传入的第二个参数是f,f是一个file-like object或者文件路径,也就是我们想要保存的位置。
后面的几个参数可以不用管它,一般也不会用到。从参数名称可以看到,我们想要保存的object是以pickle的形式保存的。因为pickle支持多种数据类型。
在源码中给了两个使用torch.save的例子。

  >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)

第一个例子把一个tensor保存在了‘tensor.pt’中,第二个则是将tensor保存在一个buffer中。这都是允许的。

torch.load()

参考文档:https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
再来看一下torch.load()函数。

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)

torch.load()传入的第一个参数f对应着torch.save()中的f,它可以是一个路径,也可以是一个file-like object。
因为我们的模型训练支持cpu也支持gpu等设备,所以我们保存的object也可能处于多种设备环境中,在torch.load()时,这个object会现在CPU上进行反序列化,然后移动到其保存时所处的设备上。假如当前的系统不支持这个设备,就会出现问题,这个时候就需要使用map_location参数,这个参数可以指定你想要放置object的设备,假如没有特别指定,在设备不能实现时就会报错。
weights_only参数可以限定你先要unpickle的object的种类,在使用weights_only参数的同时,你必须明确定义pickle_moduel这个参数(默认为pickle,这也是对的),否则就会报错RuntimeError(“Can not safely load weights when explicit pickle_module is specified”。一般情况下我们也不需要管这个参数。

代码示例

给出一个简单的例子,我们将一个tensor保存在’tensor.pt’中,又使用torch.load()加载进来。

因为保存支持的输入是object,所以我们即使只保存一个字符串也是可以的。(可以,但没必要)
在这里插入图片描述

模型的保存与加载

保存 state_dict()

在之前的章节中有说过,调用model.state_dict()方法时,得到的返回结果是一个orderdict,这个字典的key是模型中参数的名字,value是模型的参数值。
我们通常说的保存模型,保存的就是模型的state_dict(),也就是只保存了模型的参数名和参数值,因此我们是不知道模型的正确结构和forward()中的运算顺序的,你也没有办法直接使用这个state_dict()进行预测。
现在我们保存最开始定义的笨蛋小模型的state_dict()
在这里插入图片描述
我们只保存了模型的参数名和参数值,这个’test.pth’的大小只有1.39 KB (1,428 字节)。

nn.Module().load_state_dict()

def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True):

load_state_dict()传入的参数是一个key和value的mapping。这里的keys对应的当前模型自己的state_dict的key,或者说参数名。
在使用load_state_dict()时,该方法会对传入的mapping中的key和模型本身的key进行对比。如果key可以匹配上,就会进行一些操作后,更改模型的key对应的参数值。假如没有匹配上,这个key就会被放进missing_keys或者unexpected_keys中去。
strict这个参数默认是True,所以当有不匹配的key时,就会返回报错。

加载模型参数

我们只保存的模型的参数,所以想要使用这个参数,就需要把它放置在一个现有的模型中去。比如说我们现在有一个新模型model2,它和model1有着一样的结构,但是因为初始化的随机性,它们的参数值可能是不一样的。
在这里插入图片描述
可以看到我们的model2中的参数名和model1一样,但是对应的值不一样。
我们可以使用load_state_dict()方法将model1的参数值根据参数名放到model2中去。
在这里插入图片描述
现在model1和model2中的参数值也都变得一样了。
假如我们手动修改一下我们使用torch.load()加载的state_dict,给它增加一个新的值。加载时就会报错,出现了unexpected_keys。相应地,假如给它删除一个值,就会出现Missing key(s) 的错误,在这里不举例子。

在这里插入图片描述

保存模型本身

torch.save()支持保存的对象是object,而我们的模型本身,作为nn.Module(),自然也是符合object的要求的。因此你也可以直接保存整个模型。
在这里插入图片描述
我们保存的是整个模型,包括了模型的结构和模型的参数名+参数值。这个’test2.pth’的大小是2.39 KB (2,457 字节)。

加载模型本身

我们在上面将整个模型都保存在了’test2.pth’中,因此我们使用torch.load('test2.pth)时,获得的结果就是模型本身,它的类型是nn.Module()。
在这里插入图片描述

checkpoint

保存与读取

假如我们现在有一个保存好的模型’model.pth’,我们想要继续当前模型的状态继续训练。这个时候我们就会发现,'model.pth’中拥有我们模型的参数名和参数值,但是随着我们之前的训练的进行,我们使用的optimizer或者lr_scheluder的状态我们是无法获取的,它们中也有一些参数可能在训练时发生了变化。
因此为了帮助我们重启训练状态,我们需要保存更多的信息,而不是只保存一个模型的state_dict。这些被保存的信息,统称为checkpoint。
在保存checkpoint时,我们同样使用torch.save()方法,在加载时,也是用torch.load()方法。因为torch.save支持保存各种格式,我们可以将想要保存的信息按照key和value组成一个dict,并将这个dict保存下来。
在下面这个例子中,被保存下来的信息包括当前的epoch数,模型的state_dict, 优化器的state_dict还有louss。

# Additional information
torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

在加载时,我们只要按照key取其中的value就可以。

# Additional information
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

多个模型的保存与读取

我们已经知道可以将key和value对应的dict保存成checkpoint的形式,帮助我们重启训练状态。当我们有多个模型时,只不过是增加了要保存到信息而已,方法是一样的。

# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

在这个checkpoint中,我们分别保存了modelA和modelB的state_dict,和它们对应的优化器optimizerA和optimizerB的state_dict。
因此在使用时,只要分别放置到对应的object中就可以。

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/651045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言编程语法—排序算法

一、冒泡排序 冒泡排序(英语:Bubble Sort)是一种简单的排序算法。它重复地走访过要排序的数列,一次比较两个元素,如果他们的顺序(如从大到小、首字母从A到Z)错误就把他们交换过来。 过程演示&…

ansible playbook脚本,安装LAMP套件

yum 集中式安装lamp --- - name: LAMP installhosts: dbserverstasks:- name: disable firewalldservice:name: firewalldenabled: nostate: stopped- name: disabled selinuxshell: "sudo sed -i s/SELINUXenforcing/SELINUXdisabled/g /etc/selinux/config"shell:…

Random random = SecureRandom.getInstanceStrong();堵塞线程问题解决

sonar扫描到使用Random随机函数不安全, 推荐使用SecureRandom替换,就是他–》【SecureRandom.getInstanceStrong()】,分别在本地,测试环境测过没问题上生产,但是运行了一段时间突然报错!!! 然后…

简要介绍 | 基于双风机振动的燕麦清选与筛选

注1:本文系“简要介绍”系列之一,仅从概念上对基于双风机振动的燕麦清选和筛选装置设计与仿真进行非常简要的介绍,不适合用于深入和详细的了解。 注2:"简要介绍"系列的所有创作均使用了AIGC工具辅助 基于双风机振动的燕…

37 # commonjs 规范流程梳理

require 源码大致过程 mod.require 会默认调用 require 语法Module.prototype.require 模块的原型上有 require 方法Module._load 调用模块的加载方法,最终返回的是 module.exportsModule._resolveFilename 解析文件名,将文件名变成绝对路径&#xff0c…

EndNote下载安装与引用

哎!写论文这个事真是没有头绪啊,今天研究一下参考文献怎么搞,发现了EndNote,但是这玩意感觉写中文的论文用还可以,英文的不太会用。这里记录一下安装使用过程,方便以后查阅。 EndNote下载安装与引用 EndNot…

React学习[一]

React学习[一] React概述React特点声明式基于组件学习一次,随处可用 React基本使用React使用方法说明 React脚手架意义脚手架初始化项目npx命令介绍 在脚手架中使用react JSXJSX的基本使用JSX使用步骤 JSX中使用JavaScript表达式嵌入式JS表达式 JSX的条件渲染JSX的列…

Model Checking(模型检测)

1. Definition 给定一个系统和一个我们期待拥有的属性P, Model checking 会探索这个系统的每个状态,验证系统是否满足定义的性质。如果满足直接返回True,否则会给出一个反例(counter example)。如果系统被证明是正确的,说明该系统的所有的行…

H3C-HCL模拟器-STP生成树协议实验

一、实验拓扑图 二、实验步骤 1)CRT连接并重命名 若遇到连接失败,先在HCL中启动命令行配置 2)启动所有设备 3)4台交换机重新命令 4)查看信息 ① SW1的MAC地址:SW1是根桥 为什么SW1是根桥? HC…

图像坐标转换:一个点绕着另一个点逆时针旋转角度平移后的坐标

图像坐标系:x向右增大, y向下增大。 点A在图像中的坐标(x1, y1) 点B在图像中的坐标(x2, y2) 点B绕着点A逆时针旋转a弧度,旋转后的点B坐标为{x (x2 - x1)*cos(a) (y2 - y1)*sin(a) x1&#xf…

java基础——有多少是你不知道的?

java基础——有多少是你不知道的&#xff1f; 一、&&和||二、Integer和int三、String、StringBuffer、StringBuilder的区别四、i1<i居然是成立的&#xff1f;五、一脸懵逼的null问题六、整数除法向上取整你知道多少种&#xff1f;七、这也能运行&#xff1f; 一、&a…

QML 与 Python 交互

在 Qt 中&#xff0c;C 和 QML 交互一般有如下三种方法 上下文属性&#xff1a;setContextProperty( )向引擎注册类型&#xff1a;调用 qmlRegisterType( )QML 扩展插件&#xff1a;虽然有很大的灵活性&#xff0c;但是用 Python 创建 QML 插件比较麻烦&#xff0c;所以这种方法…

【补充:CAN卡通信的下位机-STM32cubeIDE-hal库+STMF1xx+数据发送和接收+中断接收方式+基础样例3】

【CAN卡通信的下位机-STM32cubeIDE-hal库STMF4xx数据发送和接收中断接收方式基础样例3】 1、概述2、实验环境3、问题描述4、大佬指点与解决问题5、实验效果截图6、代码连接7、细节部分8、总结 ) 1、概述 从第一篇F1和F4上采用轮询的方式调试can&#xff0c; 【CAN卡通信的下位…

如何用Jmeter进行接口测试 ,这应该是全网最详细的教程了

一、Jmeter 的使用步骤 打开Jmeter 安装包&#xff0c;进入\bin 中&#xff0c;找到"jmeter.bat", 点击打开即可。 在下图打开的Jmeter 页面中&#xff0c;右键“测试计划” -> “添加” -> "Threads(Users)" -> “线程组”&#xff0c; 建立线…

Allure安装、使用、Jenkins集成

目录 一、allure介绍 二、安装allure服务 三、安装pytest、allure-pytest 插件 四、生成报告 五、allure其他使用 5.1 给测试报告添加各种附件 5.2 添加用例标题和描述信息 5.3 添加链接 5.4 标记测试用例 5.5 优先级 六、allure和jenkins集成 一、allure介绍 all…

2023年5月青少年软件编程(图形化) 等级考试试卷(三级)

青少年软件编程&#xff08;图形化&#xff09; 等级考试试卷&#xff08;三级&#xff09; 一、 单选题(共 25 题&#xff0c; 共 50 分) 1.关于变量&#xff0c; 下列描述错误的是&#xff1f; &#xff08; &#xff09; A.只能建一个变量 B.变量可以隐藏 C.变量可以删除 D.…

【抽样调查】实验

文章目录 1、数组矩阵简单抽样&#xff08;1&#xff09;构造数组&#xff08;2&#xff09;构造矩阵&#xff08;3&#xff09;产生来自正态分布的随机数&#xff08;4&#xff09;从正态总体中抽取若干个样本&#xff08;5&#xff09;对矩阵的行或列进行统计计算 2、R软件作…

输入信号、冲激响应与卷积

输入信号与冲激响应的离散卷积 系统冲激响应&#xff1a; h ( t ) ∑ τ 0 ∞ x ( t ) δ ( t − τ ) h(t)\sum_{\tau0}^{\infty}x(t)\delta(t-\tau ) h(t)τ0∑∞​x(t)δ(t−τ) 上式中 h ( t ) h(t) h(t)是冲激信号输入到系统后系统的输出&#xff0c;也是系统对外在激…

stl容器vector笔记

Vector 一、初始化二、常用方法1. 访问元素at()、下标、data()、front()、back()2. push_back()、pop_back()尾部增删元素3. insert()在pos前插入元素&#xff0c;返回插入位置4. erase()擦除元素&#xff0c;返回擦除元素后的元素位置5. clear()清空容器6. resize()改变容器元…

C语言中函数返回数组(一维和二维)

文章目录 函数返回一维数组函数返回二维数组 C语言中函数返回数组是很重要的一种应用&#xff0c;有时候在程序中调用函数返回数组可以更容易的实现我们想要的某些操作&#xff0c;比如一次返回多个值&#xff0c;这篇文章带来的是C语言中函数返回一维数组和二维数组的例子。 函…