6-3 pytorch使用GPU训练模型

news2025/1/9 15:13:48

深度学习的训练过程常常非常耗时,一个模型训练几个小时是家常便饭,训练几天也是常有的事情,有时候甚至要训练几十天。
训练过程的耗时主要来自于两个部分,一部分来自数据准备,另一部分来自参数迭代
数据准备过程还是模型训练时间的主要瓶颈时,我们可以使用更多进程来准备数据
参数迭代过程成为训练时间的主要瓶颈时,我们通常的方法是应用GPU来进行加速

import torch 
import torchkeras 
import torchmetrics

print("torch.__version__ = ",torch.__version__)
print("torchkeras.__version__ = ",torchkeras.__version__)
print("torchmetrics.__version__ = ",torchmetrics.__version__)

image.png
Pytorch中使用GPU加速模型非常简单,只要将模型和数据移动到GPU上。核心代码只有以下几行。

# 定义模型
... 

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 先定义device
model.to(device) # 移动模型到cuda

# 训练模型
...

features = features.to(device) # 移动数据到cuda
labels = labels.to(device) # 或者  labels = labels.cuda() if torch.cuda.is_available() else labels

如果要使用多个GPU训练模型,也非常简单。只需要在将模型设置为数据并行风格模型。 则模型移动到GPU上之后,会在每一个GPU上拷贝一个副本,并把数据平分到各个GPU上进行训练。核心代码如下:

# 定义模型
... 

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model) # 包装为并行风格模型

# 训练模型
...
features = features.to(device) # 移动数据到cuda
labels = labels.to(device) # 或者 labels = labels.cuda() if torch.cuda.is_available() else labels

GPU相关操作汇总

查看GPU信息

import torch 
from torch import nn 

# 1,查看gpu信息
if_cuda = torch.cuda.is_available()
print("if_cuda=",if_cuda)

gpu_count = torch.cuda.device_count()
print("gpu_count=",gpu_count)

image.png

将张量在gpu和cpu间移动

# 2,将张量在gpu和cpu间移动
tensor = torch.rand((100,100))
tensor_gpu = tensor.to("cuda:0") # 或者 tensor_gpu = tensor.cuda()
print(tensor_gpu.device)
print(tensor_gpu.is_cuda)

tensor_cpu = tensor_gpu.to("cpu") # 或者 tensor_cpu = tensor_gpu.cpu() 
print(tensor_cpu.device)

image.png

将模型中的全部张量移动到gpu上

# 3,将模型中的全部张量移动到gpu上
net = nn.Linear(2,1)
print(next(net.parameters()).is_cuda)
net.to("cuda:0") # 将模型中的全部参数张量依次到GPU上,注意,无需重新赋值为 net = net.to("cuda:0")
print(next(net.parameters()).is_cuda)
print(next(net.parameters()).device)

image.png

创建支持多个gpu数据并行的模型

# 4,创建支持多个gpu数据并行的模型
linear = nn.Linear(2,1)
print(next(linear.parameters()).device)

model = nn.DataParallel(linear) # 并行
print(model.device_ids)
print(next(model.module.parameters()).device) 

#注意保存参数时要指定保存model.module的参数
torch.save(model.module.state_dict(), "model_parameter.pt") 

linear = nn.Linear(2,1)
linear.load_state_dict(torch.load("model_parameter.pt")) 

一、矩阵乘法范例

下面分别使用CPU和GPU作一个矩阵乘法,并比较其计算效率。

# 使用cpu
a = torch.rand((10000,200))
b = torch.rand((200,10000))
tic = time.time()
c = torch.matmul(a,b)
toc = time.time()

print(toc-tic)
print(a.device)
print(b.device)

image.png

# 使用gpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
a = torch.rand((10000,200),device = device) #可以指定在GPU上创建张量
b = torch.rand((200,10000)) #也可以在CPU上创建张量后移动到GPU上
b = b.to(device) #或者 b = b.cuda() if torch.cuda.is_available() else b 
tic = time.time()
c = torch.matmul(a,b)
toc = time.time()
print(toc-tic)
print(a.device)
print(b.device)

image.png

二、线性回归范例

使用CPU

# 准备数据
n = 1000000 #样本数量

X = 10*torch.rand([n,2])-5.0  #torch.rand是均匀分布 
w0 = torch.tensor([[2.0,-3.0]])
b0 = torch.tensor([[10.0]])
Y = X@w0.t() + b0 + torch.normal( 0.0,2.0,size = [n,1])  # @表示矩阵乘法,增加正态扰动
# 定义模型
class LinearRegression(nn.Module): 
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn_like(w0))
        self.b = nn.Parameter(torch.zeros_like(b0))
    #正向传播
    def forward(self,x): 
        return x@self.w.t() + self.b
        
linear = LinearRegression() 
# 训练模型
optimizer = torch.optim.Adam(linear.parameters(),lr = 0.1)
loss_fn = nn.MSELoss()

def train(epoches):
    tic = time.time()
    for epoch in range(epoches):
        optimizer.zero_grad()
        Y_pred = linear(X) 
        loss = loss_fn(Y_pred,Y)
        loss.backward() 
        optimizer.step()
        if epoch%50==0:
            print({"epoch":epoch,"loss":loss.item()})
    toc = time.time()
    print("time used:",toc-tic)

train(500)

image.png

使用GPU

# 准备数据
n = 1000000 #样本数量

X = 10*torch.rand([n,2])-5.0  #torch.rand是均匀分布 
w0 = torch.tensor([[2.0,-3.0]])
b0 = torch.tensor([[10.0]])
Y = X@w0.t() + b0 + torch.normal( 0.0,2.0,size = [n,1])  # @表示矩阵乘法,增加正态扰动

# 数据移动到GPU上
print("torch.cuda.is_available() = ",torch.cuda.is_available())
X = X.cuda()
Y = Y.cuda()
print("X.device:",X.device)
print("Y.device:",Y.device)

image.png

# 定义模型
class LinearRegression(nn.Module): 
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn_like(w0))
        self.b = nn.Parameter(torch.zeros_like(b0))
    #正向传播
    def forward(self,x): 
        return x@self.w.t() + self.b
        
linear = LinearRegression() 

# 移动模型到GPU上
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
linear.to(device)

#查看模型是否已经移动到GPU上 通过参数
print("if on cuda:",next(linear.parameters()).is_cuda)

image.png
image.png

三、图片分类范例

注意需要使用GPU的地方:


loss_fn = nn.CrossEntropyLoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   
metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=10)}
# =========================移动模型到GPU上==============================
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
loss_fn.to(device)
for name,fn in metrics_dict.items():
    fn.to(device)
# ====================================================================
features,labels = batch
        
# =========================移动数据到GPU上==============================
features = features.to(device)
labels = labels.to(device)
# ====================================================================
features,labels = batch
            
# =========================移动数据到GPU上==============================
features = features.to(device)
labels = labels.to(device)
# ====================================================================

四、torchkeras.KerasModel中使用GPU

从上面的例子可以看到,在pytorch中使用GPU并不复杂,但对于经常炼丹的同学来说,模型和数据老是移来移去还是蛮麻烦的。
一不小心就会忘了移动某些数据或者某些module,导致报错。
torchkeras.KerasModel 在设计的时候考虑到了这一点,如果环境当中存在可用的GPU,会自动使用GPU,反之则使用CPU。
通过引入accelerate的一些基础功能,torchkeras.KerasModel以非常优雅的方式在GPU和CPU之间切换。
详细实现可以参考torchkeras.KerasModel的源码。

import  accelerate 
accelerator = accelerate.Accelerator()
print(accelerator.device)  

image.png
image.png
image.png
image.png
参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分享一下微信公众号怎么添加砸金蛋链接

一、砸金蛋活动的优势 砸金蛋活动是一种非常有趣且吸引人的互动方式,在微信公众号中添加砸金蛋链接有以下优势: 提高用户参与度:砸金蛋活动能够激发用户的参与度和好奇心,让用户感到有乐趣和刺激。通过砸金蛋的方式,…

现在全国融资融券两融利率最低是多少?哪家证券公司券商费率低?

融资融券是指投资者通过向券商借入资金(融资)或借入证券(融券),以达到获得更高收益、降低交易风险、提高资金利用效率的目的。通过融资,投资者可以用借入的资金买入更多的证券;通过融券&#xf…

乐器商城小程序开发全攻略

随着互联网的普及和电子商务的快速发展,越来越多的人开始通过在线购物来满足自己的需求。而乐器作为一种特殊的商品,其在线销售市场也在不断扩大。为了满足这一需求,许多乐器商家开始开发自己的小程序商城,以提供更加便捷、高效的…

python使用SMTP发送邮件

SMTP是发送邮件的协议,Python内置对SMTP的支持,可以发送纯文本邮件、HTML邮件以及带附件的邮件。 Python对SMTP支持有smtplib和email两个模块,email负责构造邮件,smtplib负责发送邮件。 首先,我们来构造一个最简单的…

[BJDCTF2020]Mark loves cat foreach导致变量覆盖

这里我们着重了解一下变量覆盖 首先我们要知道函数是什么 foreach foreach (iterable_expression as $value)statement foreach (iterable_expression as $key > $value)statement第一种格式遍历给定的 iterable_expression 迭代器。每次循环中,当前单元的值被…

184_Python 在 Excel 和 Power BI 绘制堆积瀑布图

184_Python 在 Excel 和 Power BI 绘制堆积瀑布图 一、背景 在 2023 年 8 月 22 日 微软 Excel 官方宣布:在 Excel 原生内置的支持了 Python。博客原文 笔者第一时间就更新到了 Excel 的预览版,通过了漫长等待分发,现在可以体验了&#xf…

微信生态全场景方案

微信生态全场景方案 微信生态场景复杂,如何实现快速接入? 企业拥有跨平台数据,平台间数据割裂,如何实现各业务线数据整合? 借助身份云平台可快速接入微信生态全场景,轻松打通微信生态、电商平台、第三方平台…

prometheus 告警

prometheus 告警 1, prometheus 告警简介 告警能力在Prometheus的架构中被划分成两个独立的部分。如下所示,通过在Prometheus中定义AlertRule(告警规则),Prometheus会周期性的对告警规则进行计算,如果满足告警触发条件就会向Alertmanager发送告警信息。 在Prometheus中一…

基于Java建筑装修图纸管理平台设计实现(源码+lw+部署文档+讲解等)

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

(高阶)Redis 7 第13讲 数据双写一致性 canal篇

面试题 问题答案如何保证mysql改动后,立即同步到Rediscanal 简介 https://github.com/alibaba/canal/wikihttps://github.com/alibaba/canal/wiki 基于 MySQL 数据库增量日志解析,提供增量数据订阅和消费 业务 数据库镜像数据库实时备份多级索引 (卖家和买家各自分库索引…

【springMvc】自定义注解的使用方式

🎬 艳艳耶✌️:个人主页 🔥 个人专栏 :《Spring与Mybatis集成整合》 ⛺️ 生活的理想,为了不断更新自己 ! 1.前言 1.1.什么是注解 Annontation是Java5开始引入的新特征,中文名称叫注解。 它提供了一种安全…

【Java并发】聊聊死锁

什么是死锁 死锁出现的条件主要是资源互斥、占有并等待、非抢占、循环等待。 当出现两个线程对不同的资源进行获取的时候,A持有资源1,去获取资源2,B持有资源2,去获取资源1,就回出现死锁。 如何排查死锁 public cla…

计算机视觉与深度学习-经典网络解析-ResNet-[北邮鲁鹏]

这里写目录标题 ResNet参考产生背景贡献残差模块残差结构 批归一化ReLU激活函数的初始化方法 网络结构为什么残差网络性能好? ResNet ResNet(Residual Neural Network)是一种深度卷积神经网络模型,由Kaiming He等人在2015年提出。…

【1++的C++进阶】之智能指针

👍作者主页:进击的1 🤩 专栏链接:【1的C进阶】 文章目录 一,什么是智能指针二,为什么需要智能指针三,智能指针的发展 一,什么是智能指针 要了解智能指针,我们先要了解RA…

Linux上运行Redis服务出现报错及解决方法

近期,有用户反馈在Linux上运行Redis服务时遇到了一个报错:“Sorry, target machine refused connection”。下面我们来分析这个报错的解决方法。 一、报错分析 该报错通常是由于Redis服务无法与目标机器建立连接导致的。可能的原因包括以下几个方面&…

IP模块组装网络包及转发网络包链路

引言 之前协议栈系列的文章讲解了 连接,收发网络包,断开连接这些操作协议栈模块的处理,但是协议栈是上层 接下来会 委托ip模块进行真正的处理。 网络包 网络包的组成 网络包由头部的控制信息和头部后面的传输数据组成。 控制信息代表了包要…

TikTok矩阵玩法:如何最大程度地利用平台资源

在数字时代,TikTok已经成为全球范围内数亿用户的创意天堂,不仅仅是一个娱乐平台,还是一个创收的宝地。 TikTok矩阵玩法的崛起正在引领创作者们探索全新的变现方案,他们通过巧妙地利用平台资源,实现了前所未有的创收机…

为何网站一定要使用SSL证书

当您在浏览器中输入网址并按下回车键时,您是否曾想过您的个人信息和隐私是否会被窃取?在当今数字化的时代,网络安全问题越来越受到人们的关注。而SSL证书正是保护您的网站和用户信息安全的重要工具。 SSL证书是一种数字证书,它使用…

Unity之NetCode多人网络游戏联机对战教程(1)

文章目录 1.什么是NetCode2.安装NGO 1.什么是NetCode 官网链接:https://docs-multiplayer.unity3d.com/netcode/current/about/ Netcode for GameObjects(NGO)是专为Unity构建的高级网络库。它能够在网络会话中将GameObject和世界数据同时发…

unity打包后无法读取Excel解决方法

一、前言 最近几乎遇到了所有能遇到的unity读取Excel 的问题。 因为使用的是unity5.4,而且还是32位。所以出现各种问题在所难免。 废话不多说,现有的现象是:在unity的编辑器里可以完美运行,读取Excel不成问题,但是打包…