Pytorch单、多GPU和CPU训练模型保存和加载

news2025/1/7 17:35:52

Pytorch多GPU训练模型保存和加载

在多GPU训练中,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_dict()来获取模型的状态字典,以确保保存的参数名与模型定义中的参数名一致。(本质上原来的model还是存在的,参数也会同步更新)

  1. 多GPU训练模型保存
    在多GPU训练时,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_dict()来获取模型的状态字典,以确保保存的参数名与模型定义中的参数名一致。

  2. 单GPU或CPU加载模型
    当在单GPU或CPU上加载模型时,如果直接使用model.state_dict()保存的模型,由于缺少module前缀,会导致参数名不匹配,从而无法正确加载模型。因此,在保存多GPU训练的模型时,应该使用model.module.state_dict()来保存模型的状态字典,这样在单GPU或CPU上加载模型时,可以直接加载,不会出现参数名不匹配的问题。

  3. 示例代码
    以下是一个示例代码,展示了如何在多GPU训练时保存模型,并在单GPU或CPU上加载模型:

import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"	#设置GPU编号
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设这是你的模型定义
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = YourModel()

# 将模型移动到多GPU上
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    model = model.to(device)
else:
	model = model.to(device)
······
# 假设这是你的训练代码,训练完成后保存模型
if torch.cuda.device_count() > 1:
    torch.save(model.module.state_dict(), 'model.pth')
else:
    torch.save(model.state_dict(), 'model.pth')

# 在单、多GPU或CPU上加载模型
model = YourModel()
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to(device)

2 在多GPU训练得到的模型加载时,通常需要考虑以下几个步骤:

  1. 模型保存
    在多GPU训练时,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中。因此,在保存模型时,需要确保保存的是模型的state_dict而不是整个模型对象。例如:
if torch.cuda.device_count() > 1:
    torch.save(model.module.state_dict(), 'model.pth')
else:
    torch.save(model.state_dict(), 'model.pth')
  1. 模型加载
    在加载模型时,首先需要创建模型的实例,然后使用load_state_dict方法来加载保存的权重。如果模型是在多GPU环境下训练的,那么在加载时也应该使用torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel来包装模型。例如:
model = YourModel()
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to('cuda')
  1. 注意事项
    在加载模型时,需要注意以下几点:

如果模型是在多GPU环境下训练的,那么在加载时也应该使用相同数量的GPU,或者使用torch.nn.DataParallel来包装模型,即使只有一个GPU可用。
如果模型是在分布式训练环境下训练的,那么在加载时也应该使用torch.nn.parallel.DistributedDataParallel来包装模型。
如果模型是在混合精度训练(如使用了torch.cuda.amp)下训练的,那么在加载模型后,应该恢复之前的精度设置。

3 为了避免模型保存和加载出错

在多GPU训练的模型使用了torch.nn.DataParallel来包装模型,但本质上原来的model是依然存在的,且参数会同步更新:

  1. torch.nn.DataParallel 的工作原理
    torch.nn.DataParallel 是 PyTorch 提供的一个类,用于在多个 GPU 上并行训练模型。它的工作原理如下:
    模型复制:DataParallel 会在每个 GPU 上创建模型的副本。
    数据分发:输入数据会被分发到各个 GPU 上。
    前向传播:每个 GPU 上的模型副本会独立进行前向传播计算。
    梯度收集:所有 GPU 上的梯度会被收集并汇总到主 GPU 上。
    参数更新:主 GPU 上的优化器会根据汇总后的梯度更新模型参数,然后将更新后的参数同步回其他 GPU。
  2. 模型参数更新
    当你使用 model_train = torch.nn.DataParallel(model) 后,model_train 实际上是一个包装了原始模型 model 的对象。虽然 model_train 是多GPU并行的版本,但它的参数更新是通过主 GPU 上的优化器完成的,并且这些更新会同步回原始模型 model
    因此,model 的参数确实会被更新。具体来说:
    前向传播和反向传播:在 train_model 函数中,model_train 用于前向传播和反向传播。
    参数更新:优化器 optimizer 使用的是 model.parameters(),即原始模型的参数。在每次迭代中,优化器会根据汇总后的梯度更新这些参数。
    参数同步:更新后的参数会自动同步到 model_train 中的各个 GPU 副本。
    因此可以使用如下代码,加载模型和保存模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"	#设置GPU编号
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设这是你的模型定义
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = YourModel()

# 将模型移动到多GPU上,单GPU依然适用
if torch.cuda.device_count() > 1:
	model_train = nn.DataParallel(model)
	model_train = model_train.to(device)
else:
	model_train = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#注意这是model的参数
······
output = model_train(input)	# 多卡时训练的输入和输出,注意这是model_train

# 假设这是你的训练代码,训练完成后保存模型
torch.save(model.state_dict(), 'model.pth')	#注意这是model

  • 再在单/多GPU或CPU上加载模型,都不会报错,因为这里的model不是包装体,不带module
model = YourModel()
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth',map_location = device))
model = model.to(device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272028.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

让 Agent 具备语音交互能力:技术突破与应用前景(16/30)

让 Agent 具备语音交互能力:技术突破与应用前景 一、引言 在当今数字化时代,人机交互方式正经历着深刻的变革。从早期的命令行界面到图形用户界面,再到如今日益普及的语音交互,人们对于与机器沟通的便捷性和自然性有了更高的追求…

学生作业完成情况管理程序

网上看到的一个课程设计,正好练练手。 首先设计数据库 数据库有三张表,分别是班级表,学生表,作业成绩表。 学生表中外键关联班级表,作业成绩表中外键关联学生表。具体如下图所示 班级表 学生表学生表外键关联 …

基于vue的商城小程序的毕业设计与实现(源码及报告)

环境搭建 ☞☞☞ ​​​Vue入手篇(一),防踩雷(全网最详细教程)_vue force-CSDN博客 目录 一、功能介绍 二、登录注册功能 三、首页 四、项目截图 五、源码获取 一、功能介绍 用户信息展示:页面顶部设有用户头像和昵称展示区,方便用户识别…

DeepSeek V3“报错家门”:我是ChatGPT

搜 :海讯无双Ai 要说这两天大模型圈的顶流话题,那绝对是非DeepSeek V3莫属了。 不过在网友们纷纷测试之际,有个bug也成了热议的焦点—— 只是少了一个问号,DeepSeek V3竟然称自己是ChatGPT。 甚至让它讲个笑话,生成…

利用webworker解决性能瓶颈案例

目录 js单线程的问题webworker的基本使用webworker的常见应用可视化优化导出Excel js单线程的问题 众所周知,js不擅长计算,计算是同步的,大规模的计算会让js主线程阻塞,导致界面完成卡死。比如有一个600多亿次的计算,…

深入理解卷积神经网络(CNN):图像识别的强大工具

1、引言 卷积神经网络(CNN)是一种深度学习模型,特别适合分析视觉数据。它们在处理图像和视频任务时表现尤为出色。由于CNN在物体识别方面的高效性,这种网络架构广泛应用于计算机视觉领域,例如图像分类、物体检测、面部…

R语言安装教程与常见问题

生物信息基础入门笔记 R语言安装教程与常见问题 今天和大家聊一个非常基础但是很重要的技术问题——如何在不同操作系统上安装R语言?作为生物信息学数据分析的神兵利器,R语言的安装可谓是入门第一步,学术打工人的必备技能。今天分享在Windows…

VOC数据集格式转YOLO格式

将VOC格式的数据集转换为YOLO格式通常涉及以下几个步骤。YOLO格式的标注文件是每个图像对应一个.txt文件&#xff0c;文件中每一行表示一个目标&#xff0c;格式为&#xff1a; <class_id> <x_center> <y_center> <width> <height>其中&#xf…

win10搭建zephyr开发环境

搭建环境基于 zephyr官方文档 基于官方文档一步一步走很快就可以搞定 一、安装chocolatey 打开官网 https://community.chocolatey.org/courses/installation/installing?methodinstall-from-powershell-v3 1、用管理员身份打开PowerShell &#xff08;1&#xff09;执行 …

物体切割效果

1、物体切割效果是什么 在游戏开发中&#xff0c;物体切割效果就是物体看似被切割、分割或隐藏一部分的视觉效果。 这种效果常用与游戏和动画中&#xff0c;比如角色攻击时的切割效果&#xff0c;场景中的墙壁切割效果等等。 2、物体切割效果的基本原理 在片元着色器中判断片…

k8s集群监控系统部署方案

1.方案介绍 本文介绍一种k8s集群监控系统,该系统可以监控k8s集群中的pod和node的性能指标,以及K8s资源对象的使用情况。 监控流程: 集群资源数据采集(cadvisor、node-exporter、kube-state-metrics)-- 数据收集、存储、处理等(prometheus)-- 数据可视化查询和展示(gra…

RP2K:一个面向细粒度图像的大规模零售商品数据集

这是一种用于细粒度图像分类的新的大规模零售产品数据集。与以往专注于相对较少产品的数据集不同&#xff0c;我们收集了2000多种不同零售产品的35万张图像&#xff0c;这些图像直接在真实的零售商店的货架上拍摄。我们的数据集旨在推进零售对象识别的研究&#xff0c;该研究具…

Linux(Centos 7.6)命令详解:ls

1.命令作用 列出目录内容(list directory contents) 2.命令语法 Usage: ls [OPTION]... [FILE]... 3.参数详解 OPTION: -l&#xff0c;long list 使用长列表格式-a&#xff0c;all 不忽略.开头的条目&#xff08;打印所有条目&#xff0c;包括.开头的隐藏条目&#xff09…

比QT更高效的一款开源嵌入式图形工具EGT-Ensemble Graphics Toolkit

文章目录 EGT-Ensemble Graphics Toolkit介绍EGT具备非常高的图形渲染效率EGT采用了非常优秀的开源2D图形处理引擎-Cairo开源2D图形处理引擎Cairo的优势Cairo 2D图像引擎的性能Cairo 2D图像引擎的实际应用案例彩蛋 - 开源EDA软件KiCAD也在使用Cairo EGT高效的秘诀还有哪些Cairo…

密码学精简版

密码学是数学上的一个分支&#xff0c;同时也是计算机安全方向上很重要的基础原理&#xff0c;设置密码的目的是保证信息的机密性、完整性和不可抵赖性&#xff0c;安全方向上另外的功能——可用性则无法保证&#xff0c;可用性有两种方案保证&#xff0c;冗余和备份&#xff0…

WPF通过反射机制动态加载控件

Activator.CreateInstance 是 .NET 提供的一个静态方法&#xff0c;它属于 System 命名空间。此方法通过反射机制根据提供的类型信息。 写一个小demo演示一下 要求&#xff1a;在用户反馈界面点击建议或者评分按钮 弹出相应界面 编写MainWindow.xmal 主窗体 <Window x:C…

C语言 递归编程练习

1.将参数字符串中的字符反向排列&#xff0c;不是逆序打印。 要求&#xff1a;不能使用C函数库中的字符串操作函数。 比如&#xff1a; char arr[] "abcdef"; 逆序之后数组的内容变成&#xff1a;fedcba 1.非函数实现&#xff08;循环&#xff09; 2.用递归方法…

数据插入操作的深度分析:INSERT 语句使用及实践

title: 数据插入操作的深度分析:INSERT 语句使用及实践 date: 2025/1/5 updated: 2025/1/5 author: cmdragon excerpt: 在数据库管理系统中,数据插入(INSERT)操作是数据持久化的基础,也是应用程序与用户交互的核心功能之一。它不仅影响数据的完整性与一致性,还在数据建…

【Linux系列】使用 `nohup` 命令运行 Python 脚本并保存输出日志的详细解析

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【USRP】教程:在Macos M1(Apple芯片)上安装UHD驱动(最正确的安装方法)

Apple芯片 前言安装Homebrew安装uhd安装gnuradio使用b200mini安装好的路径下载固件后续启动频谱仪功能启动 gnu radio关于博主 前言 请参考本文进行安装&#xff0c;好多人买了Apple芯片的电脑&#xff0c;这种情况下&#xff0c;可以使用UHD吗&#xff1f;答案是肯定的&#…