常用组件详解(十):保存与加载模型、检查点机制的使用

news2024/10/6 12:48:53

文章目录

  • 1.保存、加载模型
  • 2.torch.nn.Module.state_dict()
    • 2.1基本使用
    • 2.2保存和加载状态字典
  • 3.创建Checkpoint
    • 3.1基本使用
    • 3.2完整案例


1.保存、加载模型

  torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量和字典等(内部使用pickle模块实现对象的序列化)。数据会被保存为.pt.pth格式,可通过torch.load()从磁盘加载被保存的序列化对象,加载时会重新构造出原来的对象。
  torch.save()有两种保存模型的方式:

  • 1.保存整个模型(继承了torch.nn.Module的类),不推荐使用。
    • torch.load():利用pickle将保存的序列化对象反序列化,得到原始数据。可用于加载完整模型或状态字典。
#保存整个模型
torch.save(model, PATH)
#加载模型
model = torch.load(PATH)
  • 2.仅保存模型的参数(状态字典state_dict),推荐使用。
    • torch.nn.Module.load_state_dict():通过反序列化得到模型的state_dict()(状态字典)来加载模型,传入的参数是状态字典,而非.pt.pth文件。
#只保存模型参数
torch.save(model.state_dict(), PATH)
#加载模型
model=Model()
model.load_state_dict(torch.load(PATH))

  在实际使用中推荐第二种方式,第一种方式往往容易产生各种错误:

  • 设备错误。若在cuda:0上训练好一个模型并保存,则读取出来的模型也是默认在cuda:0上,如果训练过程的其他数据被放到了cuda:1上,那么就会发生错误:
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

此时需要将其他其他数据都保存在cuda:0上,或加载模型时指定使用cuda:1

device = torch.device("cuda:1")
model = torch.load(PATH, map_location=device)
  • 版本错误:比如使用pytorch1.0训练并保存CNN模型,再用pytorch1.1读取模型,则会出现错误:
AttributeError: 'Conv2d' object has no attribute 'padding_mode'

此时只能通过获取该模型的参数来加载新的模型:

#加载模型参数
model_state = torch.load(model_path).state_dict()
#初始化新模型并加载参数
model = Model()
model.load_state_dict(model_state)

2.torch.nn.Module.state_dict()

2.1基本使用

  torch.nn.Module.state_dict()用于返回模型的状态字典,其中保存了模型的可学习参数。其中,只有可学习参数的层(卷积层、全连接层等)和注册缓冲区(batchnorm’s running_mean)才会作为模型参数保存(优化器也有状态字典,也可进行保存)。
【例子】

import torch
import torch.nn as nn
import torch.optim as optim
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
# 初始化模型
model = TheModelClass()
 
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
 
# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

  查看模型与优化器的状态字典:
在这里插入图片描述

2.2保存和加载状态字典

  通过torch.save()来保存模型的状态字典(state_dict),即只保存学习到的模型参数,并通过torch.nn.Module.load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为.pt.pth

#保存模型状态字典
PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)
#根据状态字典加载模型
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
#打印新模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

  注意,模型推理之前,需要调用model.eval()函数将dropoutbatch normalization层设置为评估模式,否则会导致模型推理结果不一致。
在这里插入图片描述

3.创建Checkpoint

3.1基本使用

  模型检查点(checkpoint)是指模型训练过程中保存的模型状态,包括模型参数(权重与偏置)、优化器状态等其他相关的训练信息。通过保存检查点,可以实现在训练过程中定期保存模型的当前状态,以便在需要时恢复训练或用于模型评估和推理。模型检查点常见的保存信息如下:

  • 1.模型权重:模型的状态字典。
  • 2.优化器状态:优化器的状态字典。
  • 3.训练状态:当前的训练轮数(epoch)、批次(batch)等。
  • 4.其他数据:如学习率调度器的状态、自定义指标等。

例如:
【保存检查点】

#将模型参数和优化器状态的状态字典保存到检查点中
checkpoint = {'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss.item(),
			  'epoch':epoch
}

#保存检查点
torch.save(checkpoint, 'checkpoint.pth')

【加载检查点】

# 加载检查点
checkpoint = torch.load('checkpoint.pth')

# 恢复模型和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 恢复训练状态
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 如果是恢复训练,可以从保存的epoch继续
for epoch in range(epoch, num_epochs):
    # 继续训练

3.2完整案例

import torch
import torch.nn as nn
import torch.optim as optim

# 假设有一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    # 假设有输入x和目标y
    x = torch.randn(64, 10)
    y = torch.randn(64, 1)
    
    optimizer.zero_grad()
    output = model(x)
    loss = loss_fn(output, y)
    loss.backward()
    optimizer.step()
    
    # 每10个epoch保存一次检查点
    if epoch % 10 == 0:
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'loss': loss.item()
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')

# 加载检查点并继续训练
checkpoint = torch.load('checkpoint_epoch_10.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 从第11个epoch开始继续训练
for epoch in range(start_epoch + 1, num_epochs):
    # 继续训练
    pass

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2191915.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++基础(10)——初识vector

目录 1.vector 2.vector的使用 2.1vector的定义 2.2vector迭代器的使用 2.2.1begin和end 2.2.2rbegin和rend 2.3增删查改 2.3.1pop_back和push_back 2.3.2inset和erase 2.3.3find函数 2.3.4swap函数 2.3.5元素访问 2.4空间函数 2.4.1size和capacity 2.4.2reserv…

用HTML5+CSS+JavaScript庆祝国庆

用HTML5CSSJavaScript庆祝国庆 中华人民共和国的国庆日是每年的10月1日。 1949年10月1日,中华人民共和国中央人民政府成立,在首都北京天安门广场举行了开国大典,中央人民政府主席毛泽东庄严宣告中华人民共和国成立,并亲手升起了…

茴香豆 + Qwen-7B-Chat-Int8

今天 打开config.ini 发现 茴香豆 支持 qwen/qwen-7b-chat-int8 1.0 拉取qwen/qwen-7b-chat-int8 cd /root/modelsgit clone https://gitee.com/hf-models/Qwen-7B-Chat-Int8.git 1.1 更改配置文件 茴香豆的所有功能开启和模型切换都可以通过 config.ini 文件进行修改 /roo…

【JAVA开源】基于Vue和SpringBoot的洗衣店订单管理系统

本文项目编号 T 068 ,文末自助获取源码 \color{red}{T068,文末自助获取源码} T068,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 顾…

Leetcode—200. 岛屿数量【中等】

2024每日刷题&#xff08;176&#xff09; Leetcode—200. 岛屿数量 C实现代码 class Solution { public:int numIslands(vector<vector<char>>& grid) {int m grid.size();int n grid[0].size();int ans 0;function<void(int, int)> dfs [&](…

企业架构TOGAF的理论指南:数字化转型中的企业架构实践

在当今全球市场的快速变革中&#xff0c;企业的数字化转型已经成为不可避免的趋势。无论是为了提高效率、增强竞争力&#xff0c;还是为了应对技术变革的挑战&#xff0c;企业都需要一个强有力的架构框架来指导其转型。TOGAF&#xff08;The Open Group Architecture Framework…

pytorch版本和cuda版本不匹配问题

文章目录 &#x1f315;问题&#xff1a;Python11.8安装pytorch11.3失败&#x1f315;CUDA版本和pytorch版本的关系&#x1f315;安装Pytorch2.0.0&#x1f319;pip方法&#x1f319;cuda方法 &#x1f315;问题&#xff1a;Python11.8安装pytorch11.3失败 &#x1f315;CUDA版…

【CSS Tricks】试试新思路去处理文本超出情况

目录 引言一、常规套路1. 单行文本省略2. 多行文本省略 二、新思路美化一下1. 单行/多行文本隐藏2. 看下效果 三、总结 引言 本篇为css的一个小技巧 文本溢出问题是一个较为常见的场景。UI设计稿为了整体的美观度会将文本内容限制到一定范围内&#xff0c;然而UI设计阶段并不能…

智慧学生宿舍管理平台|学生宿舍管理平台系统|基于Springboot+VUE的智慧学生宿舍管理平台系统设计与实现(源码+数据库+文档)

智慧学生宿舍管理平台 目录 基于SpringbootVUE的智慧学生宿舍管理平台系统设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕…

余承东直播论道智能驾驶:激光雷达不可或缺,华为ADS 3.0引领安全创新

华为余承东:激光雷达,智能驾驶安全性的关键 9月29日,华为消费者业务集团CEO余承东在一场引人注目的直播中,与知名主持人马东就智能驾驶技术的最新进展进行了深入交流。在这场直播中,余承东针对激光雷达在智能驾驶中的必要性问题,发表了明确且深刻的观点,引发了业界和公众…

STM32F407 HAL库定时器触发ADC采集与DMA数据传输(定时器TIM+ADC+DMA)

在STM32F407系列微控制器的开发中&#xff0c;结合定时器、ADC&#xff08;模数转换器&#xff09;与DMA&#xff08;直接存储器访问&#xff09;控制器&#xff0c;能够显著提升数据采集与传输的效率。本文将指导你如何使用STM32 HAL库&#xff0c;通过定时器触发ADC1的单通道…

认知战认知作战:欧盟向中国纯电动车加关税为背景的认知作战方式与策略

认知战认知作战&#xff1a;欧盟向中国纯电动车加关税为背景的认知作战方式与策略 关键词&#xff1a;欧盟, 中国, 纯电动车, 关税, 认知战, 舆论战, 政治动员, 外交反击, 市场份额, 保护主义, 技术升级, 中立第三方, 友军, 国际贸易, 合作与竞争,认知作战,新质生产力,人类命运…

信号用wire类型还是reg类型定义

wire类型就是一根线&#xff0c;线有两端&#xff0c;一端发生改变&#xff0c;经过线传递的信号当然也会发生改变&#xff0c;reg类型则不同&#xff0c;可以把reg类型理解为存储数据的寄存器&#xff0c;当满足一定条件时&#xff0c;数值才被激活发生改变。 那么&#xff0…

英国本科毕业论文写作如何确立论点

英国本科毕业论文关系到留学生是否能顺利毕业。因此&#xff0c;写好英国本科毕业论文也便成了留学生在毕业季的头等大事。那么应当怎么做才能更好地完成毕业论文呢&#xff1f;在本文中&#xff0c;英国翰思教育将从论点这个内容展开说说&#xff0c;如果高质量地完成毕业论文…

2024 uniapp入门教程 01:含有vue3基础 我的第一个uniapp页面

uni-app官网uni-app,uniCloud,serverless,快速体验,看视频&#xff0c;10分钟了解uni-app,为什么要选择uni-app&#xff1f;,功能框架图,一套代码&#xff0c;运行到多个平台https://uniapp.dcloud.net.cn/ 准备工作&#xff1a;HBuilder X 软件 HBuilder X 官网下载&#xf…

AI产品经理的崛起

“It will be unthinkable not to have artificial intelligence integrated into a product. Because everyone will expect it.” _Sam Altman, CEO & Co-founder (OpenAI)_正如Sam Altman所说的&#xff0c;2024年人工智能技术继续快速发展。我们看到了各种AI模型&#…

[Python] 《人生重开模拟器》游戏实现

文章目录 优化点一&#xff1a;多元化的天赋系统示例天赋&#xff1a;天赋选择代码&#xff1a; 优化点二&#xff1a;更加多样化的随机事件年龄阶段划分&#xff1a;随机事件代码&#xff1a; 优化点三&#xff1a;设定人生目标人生目标示例&#xff1a;人生目标代码&#xff…

ubunut声卡配置 播放视频没有声音的解决方法 alsamixer和pavucontrol的使用方法

文章目录 &#x1f319;ubuntu22.04网页没有声音&#xff0c;声卡提示Dummy Output方法一&#xff1a;切换内核&#x1f319;方法二&#xff1a;使用知乎的方法 &#x1f319;ubuntu22.04 连接蓝牙耳机&#xff0c;1秒后断连解决方法ubuntu声音操作alsamixerpavucontrol通过are…

高校校园交友系统小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;管理员管理&#xff0c;用户管理&#xff0c;基础数据管理&#xff0c;论坛管理&#xff0c;公告信息管理&#xff0c;轮播图信息管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;用户&#…

15分钟学 Python 第40天:Python 爬虫入门(六)第一篇

Day40 &#xff1a;Python 爬取豆瓣网前一百的电影信息 1. 项目背景 在这个项目中&#xff0c;我们将学习如何利用 Python 爬虫技术从豆瓣网抓取前一百部电影的信息。通过这一练习&#xff0c;您将掌握网页抓取的基本流程&#xff0c;包括发送请求、解析HTML、存储数据等核心…