PyTorch 模型保存与加载的三种常用方式

news2024/9/29 9:16:32

在深度学习的训练过程中,我们不可避免地要保存模型,这是一个非常好的习惯。接下来,文章将通过一个简单的神经网络模型,带你了解 PyTorch 中主要的模型保存与加载方式。

文章目录

  • 为什么保存和加载模型很重要?
  • 代码示例
    • 模型准备
    • 方法一:保存和加载整个模型
    • 方法二:只保存模型的状态字典(state_dict)
      • 使用 `strict=False` 加载模型
    • 方法三:保存完整的训练状态(checkpoint)
    • 定义 checkpont 保存和加载的函数

为什么保存和加载模型很重要?

训练一个神经网络可能需要数小时甚至数天的时间,你需要认知到一点:时间是非常宝贵的,目前3090云服务器租赁一天的价格为 37.92 元。如果你的代码没有保存模型的模块,那就先不要开始,因为不保存基本等于没跑,你的效果再好也没有办法直接呈现给别人。如果你保存了模型,你就可以做到以下的事情:

  • 继续训练:通过保存检查点(checkpoint),你可以在意外中断后继续训练你的模型,这一点可能会节省你大量的时间。
  • 模型部署:训练好的模型可以被部署到生产环境中进行推理,比如 LLM,LoRA 等。
  • 分享模型:将训练好的模型分享给实验室其他成员或开源社区,以便进一步研究或复现结果。

代码示例

模型准备

为了演示,我们先定义一个简单的神经网络模型:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 输入层到隐藏层
        self.fc2 = nn.Linear(128, 64)   # 隐藏层到隐藏层
        self.fc3 = nn.Linear(64, 10)    # 隐藏层到输出层

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 实例化模型和优化器
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

方法一:保存和加载整个模型

保存模型

torch.save(model, 'model.pth')

加载模型

model = torch.load('model.pth')
print(model)

输出

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

这种方法非常简单直观,因为它保存了模型的整个结构和参数。

方法二:只保存模型的状态字典(state_dict)

保存模型状态字典

torch.save(model.state_dict(), 'model_state_dict.pth')

加载模型状态字典
需要注意的是,加载state_dict时你需要手动重新实例化模型。

model = Net()  # 你需要先定义好模型架构
model.load_state_dict(torch.load('model_state_dict.pth'))
print(model)

输出

Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

与保存整个模型相比,保存 state_dict 更加灵活,它只包含模型的参数,而不依赖于完整的模型定义,这意味着你可以在不同的项目中加载模型参数,甚至只加载部分模型的权重。举个例子,对于分类模型,即便你保存的是完整的网络参数,也可以仅导入特征提取层部分,当然,直接导入完整模型再拆分实际上是一样的。对于不完全匹配的模型,加载时可以通过设置 strict=False 来忽略某些不匹配的键:

model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)

这样,你可以灵活地只加载模型的某些部分。

使用 strict=False 加载模型

假设我们在原来的 Net 模型中新增了一个全连接层(fc4),此时如果我们直接加载之前保存的 state_dict,会因为 state_dict 中没有 fc4 的权重信息而导致报错。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 修改后的模型,新增了一层 fc4
class ModifiedNet(nn.Module):
    def __init__(self):
        super(ModifiedNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.fc4 = nn.Linear(10, 5)  # 新增的全连接层

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

# 实例化模型
modified_model = ModifiedNet()

# 尝试加载之前保存的 state_dict,但忽略不匹配的层
modified_model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)

# 输出模型结构
print(modified_model)

输出

ModifiedNet(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (fc4): Linear(in_features=10, out_features=5, bias=True)
)

如果不设置 strict=False,将会报错,提示缺少 fc4 的权重:

RuntimeError: Error(s) in loading state_dict for ModifiedNet: Missing key(s) in state_dict: "fc4.weight", "fc4.bias". 

注意,即便减少层也可以使用 strict=False 也可以使用。例如,如果修改后的网络只保留前两层,仍然可以成功加载原始的 state_dict,并跳过缺失的部分。

方法三:保存完整的训练状态(checkpoint)

有时候,你可能不仅仅需要保存模型参数,还需要保存训练进度,比如当前的轮数、优化器状态等。此时可以使用检查点保存更多信息。

保存检查点

torch.save({
    'epoch': 100,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.01,
}, 'checkpoint.pth')

加载检查点

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f"Epoch: {epoch}, Loss: {loss}")

输出:

Epoch: 100, Loss: 0.01

这种方式适合长时间训练时,可以从中断的地方继续训练。但文件体积相比前面会更大,具体原因见《7. 探究模型参数与显存的关系以及不同精度造成的影响》,加载过程也稍微复杂一些,我们可以写一个函数来打包这个过程。

定义 checkpont 保存和加载的函数

def save_checkpoint(model, optimizer, epoch, loss, filepath='checkpoint.pth'):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filepath)

def load_checkpoint(filepath, model, optimizer):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

# 保存
save_checkpoint(model, optimizer, 100, 0.01)

# 加载
epoch, loss = load_checkpoint('checkpoint.pth', model, optimizer)
print(f"Loaded checkpoint at epoch {epoch} with loss {loss}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2176298.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java | Leetcode Java题解之第445题两数相加II

题目&#xff1a; 题解&#xff1a; class Solution {public ListNode addTwoNumbers(ListNode l1, ListNode l2) {Deque<Integer> stack1 new ArrayDeque<Integer>();Deque<Integer> stack2 new ArrayDeque<Integer>();while (l1 ! null) {stack1.…

AI Agent应用出路到底在哪?

1 Agent/Function Call 的定义 Overview of a LLM-powered autonomous agent system&#xff1a; Agent学会调用外部应用程序接口&#xff0c;以获取模型权重中缺失的额外信息&#xff08;预训练后通常难以更改&#xff09;&#xff0c;包括当前信息、代码执行能力、专有信息源…

《深度学习》OpenCV 角点检测、特征提取SIFT 原理及案例解析

目录 一、角点检测 1、什么是角点检测 2、检测流程 1&#xff09;输入图像 2&#xff09;图像预处理 3&#xff09;特征提取 4&#xff09;角点检测 5&#xff09;角点定位和标记 6&#xff09;角点筛选或后处理&#xff08;可选&#xff09; 7&#xff09;输出结果 3、邻域…

深度学习反向传播-过程举例

深度学习中&#xff0c;一般的参数更新方式都是梯度下降法&#xff0c;在使用梯度下降法时&#xff0c;涉及到梯度反向传播的过程&#xff0c;那么在反向传播过程中梯度到底是怎么传递的&#xff1f;结合自己最近的一点理解&#xff0c;下面举个例子简单说明&#xff01; 一、…

Qt开发技巧(九)去掉切换按钮,直接传样式文件,字体设置,QImage超强,巧用Qt的全局对象,信号槽断连,低量数据就用sqlite

继续讲一些Qt开发中的技巧操作&#xff1a; 1.去掉切换按钮 QTabWidget选项卡有个自动生成按钮切换选项卡的机制&#xff0c;有时候不想看到这个烦人的切换按钮&#xff0c;可以设置usesScrollButtons为假&#xff0c;其实QTabWidget的usesScrollButtons属性最终是应用到QTabWi…

衡石分析平台系统管理手册-功能配置之AI 助手集成嵌入指南

AI 助手集成嵌入指南​ 本文档将引导您通过几个简单的步骤&#xff0c;将 AI 助手集成或嵌入到您的系统中。HENGSHI SENSE AI 助手提供了多种集成方式&#xff0c;您可以通过 iframe、JS SDK 或 API 调用等方式将 AI 助手嵌入集成到您的系统中。 1. 通过 iframe 集成​ ifra…

老板最想要的20套模板!基于 VUE 国产开源 IoT 物联网 Web 可视化大屏设计器

如有需求&#xff0c;文末联系小编 Cola-Designer 是一个基于VUE开发&#xff0c;实现拖拽和配置方式生成数据大屏&#xff0c;提供丰富的可视化模板&#xff0c;满足客户业务监控、数据统计、风险预警、地理信息分析等多种业务的展示需求。Cola-Designer 帮助工程师通过图形化…

MySQL - 单表增删改

1. MySQL 概述 MySQL 是一种流行的开源关系型数据库管理系统 (DBMS)&#xff0c;广泛应用于互联网公司和企业开发中。它支持 SQL 语句操作数据&#xff0c;并提供多种版本供选择。 1.1 MySQL 安装和连接 社区版&#xff1a;免费版本&#xff0c;适合开发者使用。商业版&…

sizeof 和 strlen

一 . sizeof 关键字 这个是我们的老朋友了昂&#xff0c;经常都在使用&#xff0c;它是专门用来计算变量所占内存空间大小的&#xff0c;单位是字节&#xff0c;当然&#xff0c;如果我们的操作对象是类型的话&#xff0c;计算的就是类型所创建的变量所占内存的大小&#xff0…

【笔记】神领物流day1.1.13前后端部署【未完】

使用jenkins 前端部署 需要将前端开发的vue进行编译&#xff0c;发布成html&#xff0c;然后通过nginx进行访问&#xff0c;这个过程已经在Jenkins中配置&#xff0c;执行点击发布即可 网址栏输入神领TMS管理系统 (sl-express.com)即可看见启动成功 后端部署看linux 回到Jenki…

25维谛技术面试最常见问题面试经验分享总结(包含一二三面题目+答案)

开头附上工作招聘面试必备问题噢~~包括综合面试题、无领导小组面试题资源文件免费&#xff01;全文干货。 【免费】25维谛技术面试最常见问题面试经验分享总结&#xff08;包含一二三面题目答案&#xff09;资源-CSDN文库https://download.csdn.net/download/m0_72216164/8979…

单调递增/递减栈

单调栈 单调栈分为单调递增栈和单调递减栈 单调递增栈&#xff1a;栈中元素从栈底到栈顶是递增的 单调递减栈&#xff1a;栈中元素从栈底到栈顶是递减的 应用&#xff1a;求解下一个大于x元素或者是小于x的元素的位置 给一个数组&#xff0c;返回一个大小相同的数组&#x…

一文了解:最新版本 Llama 3.2

Meta AI最近发布了 Llama 3.2。这是他们第一次推出可以同时处理文字和图片的多模态模型。这个版本主要关注两个方面&#xff1a; 视觉功能&#xff1a;他们现在有了能处理图片的模型&#xff0c;参数量从11亿到90亿不等。 轻量级模型&#xff1a;这些模型参数量在1亿到3亿之间…

llamafactory0.9.0微调qwen2vl

LLaMA-Factory/data/README_zh.md at main hiyouga/LLaMA-Factory GitHubEfficiently Fine-Tune 100+ LLMs in WebUI (ACL 2024) - LLaMA-Factory/data/README_zh.md at main hiyouga/LLaMA-Factoryhttps://github.com/hiyouga/LLaMA-Factory/blob/main

【Java SE】初遇Java,数据类型,运算符

&#x1f525;博客主页&#x1f525;&#xff1a;【 坊钰_CSDN博客 】 欢迎各位点赞&#x1f44d;评论✍收藏⭐ 1. Java 概述 1.1 Java 是什么 Java 是一种高级计算机语言&#xff0c;是一种可以编写跨平台应用软件&#xff0c;完全面向对象的程序设计语言。Java 语言简单易学…

Android平台如何获取CPU占用率和电池电量信息

技术背景 我们在做Android平台GB28181设备接入模块、轻量级RTSP服务模块和RTMP推流模块的时候&#xff0c;遇到这样的技术诉求&#xff0c;开发者希望把实时CPU占用、电池信息等叠加在视频界面。 获取CPU占用率 Android平台获取CPU占用情况&#xff0c;可以读取/proc/stat文…

第十三届蓝桥杯真题Java c组D.求和(持续更新)

博客主页&#xff1a;音符犹如代码系列专栏&#xff1a;蓝桥杯关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ 【问题描述】 给定 n 个整数 a1, a2, , an &#xff0c;求它们两两相乘再相…

生信初学者教程(十六):GO富集分析

文章目录 介绍加载R包导入数据所需函数运行输出结果总结介绍 GO(Gene Ontology)是一个在生物信息学中广泛使用的概念,用于描述基因和基因产物的功能、它们所处的细胞位置以及它们参与的生物过程。GO项目是一个协作性的国际努力,旨在建立和维护一个适用于各种物种的、结构化…

用Python实现运筹学——Day 6: 单纯形法求解过程

一、学习内容 1. 单纯形法的详细步骤 单纯形法是通过迭代过程来优化线性规划问题的解决方案。该算法从可行解空间的一个顶点出发&#xff0c;逐步沿着可行解空间的边界移动到另一个顶点&#xff0c;直到找到最优解。单纯形法的求解过程分为以下几个步骤&#xff1a; 初始化&a…

EE trade:黄金T+D是什么意思

黄金TD&#xff0c;全称“黄金延期交割”&#xff0c;是由上海黄金交易所推出的标准化合约&#xff0c;允许投资者以保证金的形式进行黄金交易&#xff0c;并可以选择当日交割或延期交割。它为国内投资者提供了一个全新的黄金投资渠道&#xff0c;但也存在一些风险&#xff0c;…