PyTorch多GPU训练模型——使用单GPU或CPU进行推理的方法

news2024/11/18 19:27:51

文章目录

  • 1 问题描述
  • 2 模型保存方式
  • 3 单块GPU上加载模型
  • 4 CPU上加载模型
  • 5 总结

1 问题描述

PyTorch提供了非常便捷的多GPU网络训练方法:DataParallelDistributedDataParallel。在涉及到一些复杂模型时,基本都是采用多个GPU并行训练并保存模型。但在推理阶段往往只采用单个GPU或者CPU运行。这时怎么将多GPU环境下保存的模型权重加载到单GPU/CPU运行环境下的模型上成了一个关键的问题。

如果忽视环境问题直接加载往往会出现两类问题:

1 出现错误:IndexError: list index out of range

出现这个错误的原因是:现有模型参数是在多GPU上获得并保存的,因此在读入时默认会保存至对应的GPU上,但是目前推理环境中只有一块GPU,所以导致那些本来在其它GPU上的参数找不到自己应该去的GPU编号,出现了一个溢出错误,本质是GPU编号溢出。

2 出现错误:Missing key(s) in state_dict:

出现这个错误的原因是:由于模型训练和推理的环境不同,导致一些参数丢失,因此报错。目前在网上的一些解决策略是忽视这些丢失的参数,例如使用命令:model.load_state_dict(torch.load('model.pth'), strict=False)
来成功导入模型。这条命令可以让程序不报错并看似成功的导入模型参数。但实际上这条命令的含义是在导入模型参数时通过设置 strict=False 来忽略丢失的参数,也就是说那些丢失参数地方的模型权重初仍为初始化随机状态,等同于没有进行训练和学习,何谈推理与验证!!!

2 模型保存方式

不论是用哪种方式进行推理,在训练的时候要保证程序保存模型的方式是这样的:

torch.save(model.state_dict(), "model.pth")

3 单块GPU上加载模型

将多GPU训练的权重文件加载到单GPU上:

# 1 加载模型
model = Model()
# 2 指定运行设备,这里为单块GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 3 将模型用DataParallel方法封装一次
model = torch.nn.DataParallel(model)
# 4 将模型读入到GPU设备上
model = model_E2E.to(device)
# 5 加载权重文件
model.load_state_dict(torch.load(weight_path, map_location=device))

通过上面的程序就可以实现将多块GPU上训练得到的权重文件加载到单块GPU环境下的模型中。这里有两点需要注意:

  • 在多GPU训练时,模型使用了 DataParallel
    DistributedDataParallel 方法,这两种并行化工具会修改模型的结构,将模型封装在一个新的模块中,通常名为:module因此在权重文件中保存的模型是经过 DataParallel 封装后的结构。为了能够载入全部参数,需要通过步骤3使推理模型与原始多GPU训练模型在结构上保持一致。

  • 在步骤5加载模型参数时使用了map_location 参数。这个参数会告诉 PyTorch在加载模型时应该将张量放置在哪个设备上。设置map_location=device,那么无论模型原来是在哪个设备上训练的,现在都将放置在指定的设备device='cuda:0'上。

4 CPU上加载模型

在CPU上加载模型:

from collections import OrderedDict

# 1 加载模型
model = Model()
# 2 指定设备CPU
device = "cpu"
# 3 读取权重文件
state_dict = torch.load(weight_path, map_location=device)
# 4 剥除权重文件中的module层
if next(iter(state_dict)).startswith("module."):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    state_dict = new_state_dict
# 5 加载权重文件
model.load_state_dict(state_dict)
# 6 将模型载入到CPU
model = model.to(device)

在CPU上加载模型的逻辑和GPU的差不多,核心都是因为原权重文件中的模型被封装成了module.Model,所以需要将这层外壳去掉,最后再进行读取并将模型加载到CPU上。

5 总结

在深度学习任务中训练与推理环境存在差异的情况十分常见 ,有差异的环境下实现网络权重文件的正确读取十分重要。实际操作中一定要确保正确的权重文件被读入,这是进行推理最基本的前提!最好在推理前做一些对比实验(例如:选取一部分数据,分别套用已有的程序进行训练和推理,对比二者的效果)来确保已经读入到正确的权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/983865.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Liunx环境安装字体(simsun为例)

一:下载simsun字体文件包 链接:https://pan.baidu.com/s/1jelox8MalDJDWTyx4Z9ghw 提取码:tttt二:把解压后的simsun.ttf、simsun.ttc放到 /usr/share/fonts目录 三:安装 // 刷新字体缓存 [rootxxxxxx fonts]# fc-ca…

为什么大家会觉得考PMP没用?

一是在于PMP这套知识体系,是一套底层的项目管理逻辑框架,整体是比较抽象的。大家在学习工作之后,会有人告诉你很多职场的一些做事的规则,比如说对于沟通,有人就会告诉如何跟客户沟通跟同事相处等等,这其实就…

ebay运营思路|学会这些技巧,新店铺销量翻倍

Ebay是一个老牌的跨境电商,目前仍然是稳坐全球前列的平台,也是强手如云的地方,虽然相对于亚马逊他显得没有那么“卷”。 要在这片市场中抢占一番天地,首先一定要学会一些高效的运营技巧,今天就来分享一些Ebay运营技巧…

SVPWM的原理及法则推导和控制算法详解

空间电压矢量调制 SVPWM 技术 SVPWM是近年发展的一种比较新颖的控制方法,是由三相功率逆变器的六个功率开关元件组成的特定开关模式产生的脉宽调制波,能够使输出电流波形尽 可能接近于理想的正弦波形。空间电压矢量PWM与传统的正弦PWM不同,它…

chrome 谷歌浏览器 导出插件拓展和导入插件拓展

给同事部署 微软 RPA时,需要用到对应的chrome浏览器插件;谷歌浏览器没有外网是不能直接下载拓展弄了半小时后才弄好,竟发现没有现成的教程,遂补充; 如何打包导出 谷歌浏览器 地址栏敲 chrome://extensions/在对应的地…

如何配置远程访问以在外部网络中使用公司内部的OA办公系统——“cpolar内网穿透”

文章目录 前言1. 确认在内网下能够使用IP端口号登录OA办公系统2. 安装cpolar内网穿透3. 创建隧道映射内网OA系统服务端口4. 实现外网访问公司内网OA系统总结 前言 现在大部分公司都会在公司内网搭建使用自己的办公管理系统,如OA、ERP、金蝶等,员工只需要…

Instagram Shop如何开通?如何销售?最全面攻略

借助 Instagram 商店,品牌可以策划一系列可购物的商品,这些商品可通过其 Instagram 个人资料直接访问。这使得在应用程序上销售更容易,也被潜在客户发现。 一、什么是Instagram Shop? Instagram 商店为商家提供了一种在 Instagra…

3种等待方式,让你学会Selenium设置自动化等待测试脚本!

一、Selenium脚本为什么要设置等待方式?——即他的应用背景到底是什么 应用Selenium时,浏览器加载过程中无法立即显示对应的页面元素从而无法进行元素操作,需设置一定的等待时间去等待元素的出现。(简单来说,就是设置…

组件以及组件间的通讯

组件 & 组件通讯 :::warning 注意 阅读本文章之前,你应该先要了解 ESM 模块化的 import export,如需要请查看 ESM 模块化。 ::: 上一篇有介绍到什么是组件化,就是把一个页面拆分成若干个小模块,然后重新组成一个页面。其中的…

iPhone 15有始终显示功能吗?它会出现在更多的苹果手机上吗?

和我们一样,你可能也在犹豫,iPhone 15将增加一个“始终显示”的功能,与一年前苹果Pro版手机的功能相匹配。然而,随着苹果9月活动的临近,没有太多传言可以让我们相信我们会如愿以偿。 你可能还记得,去年iPh…

导出Excel的技术分享-综合篇

导出Excel的技术分享-综合篇 简单的EasyExcel使用 /*** 最简单的写*/public void simpleWrite() {// 注意 simpleWrite在数据量不大的情况下可以使用(5000以内,具体也要看实际情况),数据量大参照 重复多次写入// 写法1 JDK8// s…

GMSL技术让汽车数据传输更为高效(转)

目前,大部分车企都在其旗舰车型上配备了达到Level 2水平的自动驾驶技术,也就是高级自动驾驶辅助 ADAS系统。ADAS系统硬件主要由以下几部分组成,包括传感器、串行器、解串器、ADAS处理器等。 除了ADAS系统,包括传感器融合、音视频影…

Python实现SSA智能麻雀搜索算法优化LightGBM回归模型(LGBMRegressor算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 麻雀搜索算法(Sparrow Search Algorithm, SSA)是一种新型的群智能优化算法,在2020年提出&a…

山西电力市场日前价格预测【2023-09-08】

日前价格预测 预测明日(2023-09-08)山西电力市场全天平均日前电价为356.28元/MWh。其中,最高日前电价为409.23元/MWh,预计出现在19: 30。最低日前电价为323.46元/MWh,预计出现在24: 00。 价差方向预测 1: 实…

封装flexible.js,页面替换px为rem,实现不同分辨率适配

做的vue项目需要做个大屏,需要适配不同电脑的分配率,想到了rem,但是直接通过npm install flexible 安装的flexible.js默认设置的分辨率范围不符合现有的需求,所以就把安装包里面的flexible.js单独拿出来,然后改下分辨率…

idea的git入门

(1)安装好git之后,在idea的设置里面,按照下面三步,配置git (2)创建本地git仓库 选择本地仓库的根目录,点击ok (3)创建成功之后,会发现文件名称都变…

C/C++输出第二个整数 2019年9月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 一、题目要求 1、编程实现 2、输入输出 二、解题思路 1、案例分析 三、程序代码 四、程序说明 五、运行结果 六、考点分析 2019年9月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 输入三个整数,把第二个输入的整数输出。 2、输入输出 输…

腾讯混元助手使用指南

一、腾讯混元助手简介 腾讯混元助手是什么? 腾讯混元助手是由腾讯研发的大语言模型的平台产品,具备跨领域知识和自然语言理解能力,实现基于人机自然语言对话的方式,理解用户指令并执行任务,帮助用户实现人获取信息&am…

利用观测云实现业务数据驱动的弹性扩缩容

背景 在使用观测云对业务系统进行观测的过程中,除了可以实现业务系统的全面感知,我们还可以基于观测云数据处理开发平台 DataFlux Func ,结合故障模型对被测系统进行主动管理,例如弹性扩容或系统故障自愈,从而实现系统…

VirtualBox的菜单栏被隐藏

一、virtualbox虚拟机里面最顶部没有控制,设置和帮助选项 的解决办法: 右边的CtrlC 二、linux终端上下滚动 向上滚动:Shift Page Up 向下滚动:Shift Page Down