【Pytorch】pytorch中保存模型的三种方式

news2024/10/6 4:03:11

【Pytorch】pytorch中保存模型的三种方式

文章目录

  • 【Pytorch】pytorch中保存模型的三种方式
    • 1. torch保存模型相关的api
      • 1.1 torch.save()
      • 1.2 torch.load()
      • 1.3 torch.nn.Module.load_state_dict()
      • 1.4 什么是state_dict()
        • 1.4. 1 举个例子
    • 2. pytorch模型文件后缀
    • 3. 存储整个模型
      • 3.1 直接保存整个模型
      • 3.2 直接加载整个模型
    • 4. 只保存模型的权重
      • 4.1 保存模型权重
      • 4.2 读取模型权重
    • 5. 使用Checkpoint保存中间结果
      • 5.1 保存Checkpoint
      • 5.2 加载Checkpoint
    • Reference

1. torch保存模型相关的api

1.1 torch.save()

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

参考自https://pytorch.org/docs/stable/generated/torch.save.html#torch-save

Image

torch.save()的功能是保存一个序列化的目标到磁盘当中,该函数使用了Python中的pickle库用于序列化,具体参数的解释如下

参数功能
obj需要保存的对象
f指定保存的路径
pickle_module用于 pickling 元数据和对象的模块
pickle_protocol指定 pickle protocal 可以覆盖默认参数

常见用法

# dirctly save entiry model
torch.save('model.pth')
# save model'weights only
torch.save(model.state_dict(), 'model_weights.pth')
# save checkpoint
checkpint = {
	'model_state_dict': model.state_dict(),
	'optimizer_state_dict': optimizer.state_dict(),
	'loss': loss,
	'epoch': epoch
}
torch.save(checkpoint, 'checkpoint_path.pth')

1.2 torch.load()

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)

参考自https://pytorch.org/docs/stable/generated/torch.load.html#torch-load

Image

torch.load()的功能是加载模型,使用python中的unpickle工具来反序列化对象,并且加载到对应的设备上,具体的参数解释如下

参数功能
f对象的存放路径
map_location需要映射到的设备
pickle_module用于 unpickling 元数据和对象的模块

常见用法

# specify the device to use
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load entiry model to cuda if available
model = torch.load('whole_model.pth', map_location=device)
# load model's weight to cuda if available
model.load_state_dict(torch.load('model_weights.pth'), map_location=device)
# load checkpoint
checkpoint = torch.load('checkpoint_path.pth', map_location=device)
# checkpoint加载出来就像个字典,预先保存的是否放置了什么内容,加载之后就可以这样来获取
loss = checkpoint['loss']
epoch = chekpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict']
optimizer.load_state_dict(checkpoint['optimizer_state_dict']


1.3 torch.nn.Module.load_state_dict()

torch.nn.Module.load_state_dict(state_dict, strict=True, assign=False)

参考自https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict

Image

torch.nn.Module.load_state_dict()将参数和缓冲区从 state_dict 复制到此模块及其后代中。 如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。具体的参数描述如下

参数功能
state_dict保存parameters和persistent buffers的字典
strict是否强制要求state_dict中的key和model.state_dict返回的key严格一致

1.4 什么是state_dict()

torch.nn.Module.state_dict()

参考自https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict

Image

其实state_dict可以理解为一种简单的Python Dictionary,其功能是将每层之间的参数进行一一映射并且存储在python的数据类型字典中。因此state_dict可以轻松地进行修改、保存等操作。

除了torch.nn.Module拥有state_dict()方法之外,torch.optim.Optimizer也具有state_dict()方法。如下所示

torch.optim.Optimizer.state_dict()

参考自https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.state_dict.html

1.4. 1 举个例子
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 100)
        self.fc2 = nn.Linear(100, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


if __name__ == "__main__":
    model = SimpleModel(10, 2)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("Check Model's State Dict:")
    for key, value in model.state_dict().items():
        print(key, "\t", value.size())
    
    print("Check Optimizer's State Dict:")
    for key, value in optimizer.state_dict().items():
        print(key, "\t", value)

输出的结果如下

Check Model's State Dict:
fc1.weight       torch.Size([100, 10])
fc1.bias         torch.Size([100])
fc2.weight       torch.Size([2, 100])
fc2.bias         torch.Size([2])
Check Optimizer's State Dict:
state    {}
param_groups     [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3]}]

2. pytorch模型文件后缀

常用的torch模型文件后缀有.pt.pth,这是最常见的PyTorch模型文件后缀,表示模型的权重、结构和状态字典(state_dict)都被保存在其中。

torch.save(model.state_dict(), 'model_weights.pth')
torch.save(model, 'full_model.pt')

还有检查点后缀如.ckpt.checkpoint,这些后缀常被用于保存模型的检查点,包括权重和训练状态等。它们也可以表示模型的中间状态,以便在训练期间从中断的地方继续训练。

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    # 其他信息
}
torch.save(checkpoint, 'model_checkpoint.ckpt')

还有其他跨框架的数据结构例如.h5,PyTorch的模型也可以保存为HDF5文件格式用于跨框架的数据交换,可以使用h5py库来进行读写

import h5py

with h5py.File('model.h5', 'w') as f:
    # 将模型参数逐一保存到HDF5文件
    for name, param in model.named_parameters():
        f.create_dataset(name, data=param.numpy())

3. 存储整个模型

可以直接使用torch.save()torch.load()来加载和保存整个模型到文件中,这种方式保存了模型的所有权重、架构及其其他相关信息,即使不知道模型的结构也能够直接通过权重文件来加载模型

3.1 直接保存整个模型

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


if __name__ == "__main__":
    model = SimpleModel(10, 2)

    # specify the save path
    url = os.path.dirname(os.path.realpath(__file__)) + '/models/'
    # 如果路径不存在则创建
    if not os.path.exists(url):
        os.makedirs(url)
    # specify the model save name
    model_name = 'simple_model.pth'
    # save the model to file
    torch.save(model, url + model_name)

我们直接将模型保存到了当前文件夹下的./models文件夹中,

3.2 直接加载整个模型

由于我们已经保存了模型的所有相关信息,所以我们可以不知道模型的结构也能加载该模型,如下所示

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = SimpleModel(10, 2)

    # specify the save path
    url = os.path.dirname(os.path.realpath(__file__)) + '/models/'

    # 如果路径不存在则创建
    if not os.path.exists(url):
        os.makedirs(url)

    # specify the model save name
    model_name = 'simple_model.pth'
	
	# load the model
    if os.path.exists(url + model_name):
        model = torch.load(url + model_name, map_location=device)
        print("Success Load Model From:\n\t%s"%(url+model_name))

成功加载了模型


4. 只保存模型的权重

4.1 保存模型权重

利用前面提到的state_dict()方法来完成这一操作

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


if __name__ == "__main__":
	# specify device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleModel(10, 2)

    # specify the save path
    url = os.path.dirname(os.path.realpath(__file__)) + '/models/'

    # 如果路径不存在则创建
    if not os.path.exists(url):
        os.makedirs(url)

    # specify the model save name
    model_name = 'simple_model_weights.pth'

    torch.save(model.state_dict(), url + model_name)

我们直接将模型权重保存到了当前文件夹下的./models文件夹中,

4.2 读取模型权重

由于我们只保存了模型的权重信息,不知道模型的结构,所以必须要先实例化模型才行。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


if __name__ == "__main__":
    # specify device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # get model
    model = SimpleModel(10, 2)

    # specify the save path
    url = os.path.dirname(os.path.realpath(__file__)) + '/models/'

    # 如果路径不存在则创建
    if not os.path.exists(url):
        os.makedirs(url)
    # specify the model save name
    model_name = 'simple_model_weights.pth'
    if os.path.exists(url + model_name):
        model.load_state_dict(torch.load(url + model_name, map_location=device))
        print("Success Load Model'weights From:\n\t%s"%(url+model_name))

5. 使用Checkpoint保存中间结果

5.1 保存Checkpoint

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os

# 数据准备
x = torch.tensor(np.random.rand(100, 1), dtype=torch.float32)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)

# 定义模型
class SimpleLinearModel(nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

if __name__=="__main__":
    # specify device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 实例化模型
    model = SimpleLinearModel()


    # 定义损失函数和优化器
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # 训练循环
    num_epochs = 1000
    checkpoint_interval = 100  # 保存检查点的间隔
    url = os.path.dirname(os.path.realpath(__file__))+'/models/'
    if not os.path.exists(url):
        os.makedirs(url)
    checkpoint_file = 'checkpoint.pth'  # 检查点文件路径

    for epoch in range(num_epochs):
        # 前向传播
        outputs = model(x)
        loss = criterion(outputs, y)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 打印训练信息
        if (epoch + 1) % checkpoint_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
            
            # 保存检查点
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            }
            torch.save(checkpoint, url+checkpoint_file)

5.2 加载Checkpoint

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os

# 数据准备
x = torch.tensor(np.random.rand(100, 1), dtype=torch.float32)
y = 3 * x + 2 + 0.1 * torch.randn(100, 1)

# 定义模型
class SimpleLinearModel(nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

if __name__=="__main__":
    # specify device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 实例化模型
    model = SimpleLinearModel()


    # 定义损失函数和优化器
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # 训练循环
    num_epochs = 1000
    checkpoint_interval = 100  # 保存检查点的间隔
    url = os.path.dirname(os.path.realpath(__file__))+'/models/'
    if not os.path.exists(url):
        os.makedirs(url)
    checkpoint_file = 'checkpoint.pth'  # 检查点文件路径

    # load from checkpoint
    checkpoint = torch.load(url+checkpoint_file)
    for key, value in checkpoint.items():
        print(key, '-->', value)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print('Loaded checkpoint from epoch %d. Loss %f' % (epoch, loss))

输出如下

loss --> 0.01629752665758133
(test_ros_python) sjh@sjhR9000X:~/Documents/python_draft$  cd /home/sjh/Documents/python_draft ; /usr/bin/env /home/sjh/anaconda3/envs/metaRL/bin/python /home/sjh/.vscode/extensions/ms-python.python-2023.18.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 40897 -- /home/sjh/Documents/python_draft/check_checkpoint.py 
epoch --> 1000
model_state_dict --> OrderedDict([('linear.weight', tensor([[2.6938]])), ('linear.bias', tensor([2.1635]))])
optimizer_state_dict --> {'state': {0: {'momentum_buffer': None}, 1: {'momentum_buffer': None}}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}
loss --> 0.01629752665758133
Loaded checkpoint from epoch 1000. Loss 0.016298

我们成功从断点处加载checkpoint, 可以再从这个断点处继续训练

Reference

参考一

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1092423.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【2024波哥讲言语视频全套】

2024波哥讲言语视频全套 有需要的同学可以通过百度网盘下载 通过百度网盘分享的文件:【38】2024… 链接:https://pan.baidu.com/s/10GMG9tu8RGrXuU2uJWaysw?pwdfpl6 提取码:fpl6 复制这段内容打开「百度网盘APP 即可获取」

PHP家教系统平台源码/请家教兼职家教网源码/自适应手机端/实测

源码简介: PHP家教系统平台源码/请家教兼职家教网源码/自适应实测,它支持兼职家教和请家教两种模式。该程序提供了完整的功能,包括家教信息发布、家教需求发布、信息匹配、在线支付等。此外,该程序还可以自适应手机端&#xff0c…

CSS之排列系列--顶部导航栏ul、li居中展示的方法

原文网址:CSS之排列系列--顶部导航栏ul、li居中展示的方法_IT利刃出鞘的博客-CSDN博客 简介 说明 本文介绍CSS顶部导航栏ul、li居中展示的方法。 核心方法 ul的父层使用:text-align: center ul元素使用:display: inline-block; 示例 …

C++ 反向迭代器

反向迭代器的即正向迭代器的--,反向迭代器的--即正向迭代器的,反向迭代器和正向迭代器的很多功能都是相似的,因此我们可以复用正向迭代器作为反向迭代器的底层容器来封装,从而实现出反向迭代器,即:反向迭代…

Linux 文件系统逻辑结构图的解释

task_struct进程结构体,表示一个运行的进程。 task_struct中的fs指向fs_struct结构体。fs_struct表示这个进程支持的文件系统。 root指向根目录dentry,dentry中的d_inode指向改进程根目录在存储设备中的inode节点。 pwd指向当前进程所在的目录结构体den…

部署k8s dashboard(这里使用Kubepi)

9. 部署k8s dashboard(这里使用Kubepi) Kubepi是一个简单高效的k8s集群图形化管理工具,方便日常管理K8S集群,高效快速的查询日志定位问题的工具 部署KubePI(随便在哪个节点部署,我这里在主节点部署&#…

记录单片机编码的坑

问题描述 在使用clion调试过程中,发现使用 mbstowcs函数转换后的数组仍为原数组 因而单片机中不能直接将ascii码(此处为编写代码的格式,例如GBK格式)转换为Unicode格式,这个Bug先记录下来,后续解决了再贴上方法!

【ppt密码】ppt的密码忘了,怎么破解

PPT文件设置了保护密码,但是密码忘记了,无法打开PPT文件、无法编辑PPT文件了该怎么办?PPT文件的两种保护密码该如何解密? 首先是打开密码 网上有一种解决方法: 1、重新命名PPT文件,将其后缀改为zip格式&…

【LeetCode刷题(数据结构)】:检查两颗树是否相同

给你两棵二叉树的根节点 p 和 q ,编写一个函数来检验这两棵树是否相同 如果两个树在结构上相同,并且节点具有相同的值,则认为它们是相同的 输入:p [1,2,3], q [1,2,3] 输出:true 输入:p [1,2], q [1,…

一篇文章带你用动态规划解决股票购买时机问题

动态规划的解题步骤可以分为以下五步,大家先好好记住 1.创建dp数组以及明确dp数组下标的含义 2.制定递推公式 3.初始化 4.遍历顺序 5.验证结果 股票购买时机问题的解题核心思路 当天的收益是根据前一天持有股票还是不持有股票的状态决定的 那么很自然的我们就想…

Dijkstra求最短路(图解)

你好,我是Hasity。 今天分享的内容:Dijkstra求最短路这个题目 Dijkstra求最短路I 题目描述 给定一个 n个点 m 条边的有向图,图中可能存在重边和自环,所有边权均为正值。 请你求出 1 号点到 n号点的最短距离,如果无…

关于导入Maven工程项目,更新pom.xml文件仍然爆红的原因

问题描述: 在学习maven工程的时候,把从网上学习的工程导入到IDEA,发现,无论怎么更新,pom.xml文件一直报错,查看settings设置和project Structure仍然没找出问题来。 settings设置如下: 解决问…

短视频如何批量添加水印

在当今的数字时代,短视频已经成为一种非常流行的内容形式。无论是社交媒体还是视频分享网站,短视频都已经成为了一种非常有吸引力的内容。然而,对于一些拥有大量视频内容的创作者来说,添加水印可能是一项繁琐的任务。本文将介绍如…

【windows下docker安装rocketMQ】

namesrv和broker安装就不说了,见如下博客 https://blog.csdn.net/Wonderful1025/article/details/107244434/ 安装rocketMQ-console docker run -d -e "JAVA_OPTS-Drocketmq.config.namesrvAddr192.168.65.2:9876 -Drocketmq.config.isVIPChannelfalse"…

__builtin_return_address()函数的使用方法

__builtin_return_address(0) 是GCC编译器提供的内置函数,用于获取当前函数调用栈中的指定帧(frame)的返回地址。这个函数通常用于调试和性能分析,以了解程序中的函数调用关系。 下面是关于 __builtin_return_address(0) 函数的一…

SystemC入门学习-第5章 同步逻辑建模

本章重点学习同步逻辑中的触发器,锁存器的一些建模规范: 触发器建模带异步置位/复位带同步置位/复位锁存器建模 5.1 触发器建模 触发器建模的关键是敏感列表的规范。SC_MODULE的规范写法中出现过sensitive 参数列表是事件敏感, 对触发器建模…

操作系统学习笔记4-死锁问题

文章目录 1、死锁逻辑图2、死锁三胞胎3、死锁的原因及必要条件4、死锁处理策略之死锁预防5、死锁处理策略之死锁避免(银行家算法)6、死锁处理策略之死锁检测与解除 1、死锁逻辑图 2、死锁三胞胎 3、死锁的原因及必要条件 4、死锁处理策略之死锁预防 5、死…

查找组成一个偶数最接近的两个素数

一、题目 二、代码 #include <iostream> using namespace std; bool isPrime(int num)//判断素数 {if (num < 1)return false;if (num 2)return true;if (num % 2 0)return false;for (int i 3; i < num; i){if (num % i 0){return false;}}return true; } in…

win11下的VS2022+QT6+VTK9.2+PCL1.13.1联合开发环境配置及踩坑记录

准备工作&#xff1a; 安装VS2022&#xff1a;这个比较简单&#xff0c;网上随便找个教程就行 安装QT并为VS2022添加QT Creater插件&#xff1a;VS2022配置Qt6_vs2022 qt6-CSDN博客 安装PCL&#xff1a;vs2022配置pcl1.13.1_pcl配置-CSDN博客 安装PCL过程中本身也会安装VTK&…

小程序入门及案例展示

目录 一、小程序简介 1.1 为什么要使用小程序 1.2 小程序可以干什么 二、前期准备 2.1 申请账号 2.2 开发工具下载与安装 三、电商案例演示 四、入门案例 4.1 项目结构解析 4.2 基础操作及语法 4.3 模拟器 4.4 案例演示 4.4.1 新建页面 4.4.2 头部样式设置 4.4.…