pytorch多GPU训练简明教程

news2024/12/24 9:07:00

1. Torch 的两种并行化模型封装

图片

1.1 DataParallel

DataParallel 是 PyTorch 提供的一种数据并行方法,用于在单台机器上的多个 GPU 上进行模型训练。它通过将输入数据划分成多个子部分(mini-batches),并将这些子部分分配给不同的 GPU,以实现并行计算。


在前向传播过程中,输入数据会被划分成多个副本并发送到不同的设备(device)上进行计算。模型(module)会被复制到每个设备上,这意味着输入的批次(batch)会被平均分配到每个设备,但模型会在每个设备上有一个副本。每个模型副本只需要处理对应的子部分。需要注意的是,批次大小应大于GPU数量。在反向传播过程中,每个副本的梯度会被累加到原始模型中。总结来说,DataParallel会自动将数据切分并加载到相应的GPU上,将模型复制到每个GPU上,进行正向传播以计算梯度并汇总。


注意:DataParallel是单进程多线程的,仅仅能工作在单机中。


封装示例:

import torchimport torch.nn as nnimport torch.optim as optim
# 定义模型class SimpleModel(nn.Module):    def __init__(self):        super(SimpleModel, self).__init__()        self.fc = nn.Linear(10, 1)
    def forward(self, x):        return self.fc(x)
# 初始化模型model = SimpleModel()
# 使用 DataParallel 将模型分布到多个 GPU 上model = nn.DataParallel(model)

1.2 DistributedDataParallel

DistributedDataParallel (DDP) 是 PyTorch 提供的一个用于分布式数据并行训练的模块,适用于单机多卡和多机多卡的场景。相比于 DataParallel,DDP 更加高效和灵活,能够在多个 GPU 和多个节点上进行并行训练。


DistributedDataParallel是多进程的,可以工作在单机或多机器中。DataParallel通常会慢于DistributedDataParallel。所以目前主流的方法是DistributedDataParallel。


封装示例:​​​​​​​

import torchimport torch.nn as nnimport torch.optim as optimimport torch.distributed as distfrom torch.nn.parallel import DistributedDataParallel as DDP
def main(rank, world_size):    # 初始化进程组    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # 创建模型并移动到GPU    model = SimpleModel().to(rank)
    # 包装模型为DDP模型    ddp_model = DDP(model, device_ids=[rank])

if __name__ == "__main__":    import os    import torch.multiprocessing as mp
    # 世界大小:总共的进程数    world_size = 4
    # 使用mp.spawn启动多个进程    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)

2. 多GPU训练的三种架构组织方式

图片

由于上一节的讨论,本节所有源码均由DDP封装实现。

2.1 数据拆分,模型不拆分(Data Parallelism)

数据并行(Data Parallelism)将输入数据拆分成多个子部分(mini-batches),并分配给不同的 GPU 进行计算。每个 GPU 上都有一份完整的模型副本。这种方式适用于模型相对较小,但需要处理大量数据的场景。


由于下面的代码涉及了rank、world_size等概念,这里先做一下简要普及。


Rank

rank 是一个整数,用于标识当前进程在整个分布式训练中的身份。每个进程都有一个唯一的 rank。rank 的范围是 0 到 world_size - 1。

用于区分不同的进程。

可以根据 rank 来分配不同的数据和模型部分。

World Size
world_size 是一个整数,表示参与分布式训练的所有进程的总数。

确定分布式训练中所有进程的数量。

用于初始化通信组,确保所有进程能够正确地进行通信和同步。

Backend
backend 指定了用于进程间通信的后端库。常用的后端有 nccl(适用于 GPU)、gloo(适用于 CPU 和 GPU)和 mpi(适用于多种设备)。

决定了进程间通信的具体实现方式。

影响训练的效率和性能。

Init Method
init_method 指定了初始化分布式环境的方法。常用的初始化方法有 TCP、共享文件系统和环境变量。

用于设置进程间通信的初始化方式,确保所有进程能够正确加入到分布式训练中。

Local Rank
local_rank 是每个进程在其所在节点(机器)上的本地标识。不同节点上的进程可能会有相同的 local_rank。

用于将每个进程绑定到特定的 GPU 或 CPU。

图片

​​​​​​​

import torchimport torch.nn as nnimport torch.optim as optimimport torch.distributed as distfrom torch.nn.parallel import DistributedDataParallel as DDPimport torch.multiprocessing as mp
class SimpleModel(nn.Module):    def __init__(self):        super(SimpleModel, self).__init__()        self.fc = nn.Linear(10, 1)
    def forward(self, x):        return self.fc(x)
def train(rank, world_size):    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:29500', rank=rank, world_size=world_size)
    model = SimpleModel().to(rank)    ddp_model = DDP(model, device_ids=[rank])
    criterion = nn.MSELoss().to(rank)    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    inputs = torch.randn(64, 10).to(rank)    targets = torch.randn(64, 1).to(rank)
    outputs = ddp_model(inputs)    loss = criterion(outputs, targets)
    optimizer.zero_grad()    loss.backward()    optimizer.step()
    dist.destroy_process_group()
if __name__ == "__main__":    world_size = 4    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

2.2 数据不拆分,模型拆分(Model Parallelism)

模型并行(Model Parallelism)将模型拆分成多个部分,并分配给不同的 GPU。输入数据不拆分,但需要通过不同的 GPU 处理模型的不同部分。这种方式适用于模型非常大,单个 GPU 无法容纳完整模型的场景。

点击pytorch多GPU训练简明教程可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1995770.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件测试面试200问(含答案+文档)

1、你的测试职业发展是什么? 测试经验越多,测试能力越高。所以我的职业发展是需要时间积累的,一步步向着高级测试工程师奔去。而且我也有初步的职业规划,前3年积累测试经验,按如何做好测试工程师的要点去要求自己&…

旋转图像

旋转图像 思路: 第一意识是找一个数学规律,一个公式可以找到对应的位置。 唉 想不出 没啥思路 看题解了。 一看就懂了 规律就是。。。。。。:原来第 i 行第 j 列的元素 在旋转后 会在第 j 行倒数第i列。 这种题目做少了,多做…

正点原子安装buildroot构建根文件系统

1:何为 buildroot? 1.1:buildroot 简介 在《第三篇 系统移植篇》我们最后讲解了如何使用 busybox 构建文件系统,busybox 仅仅 只是帮我们构建好了一些常用的命令和文件,像 lib 库、/etc 目录下的一些文件都需要我们自…

初识Spring、SpringIOC

Spring 一、什么是Spring框架?(重要) ---对Spring的理解 记忆关键字:1.核心思想(IOC、AOP) 2.作用(解耦、简化) 3.简单描述框架组成 答:定义:Spring是一个轻量级的控制反转(IoC)和…

SpringBoot中解决文件application.properties中文注释乱码的问题

如图看到中文注释乱码 很影响代码的阅读 原因是字符编码使用了ISO-8859-1 这里演示如何在idea里面把ISO-8859-1改为UTF-8 点击右上角设置 搜索框输入UTF-8 把默认的改成UTF-8就行了 可以看到中文注释正常显示 希望能够点点赞和收藏!!

猫咪浮毛大作战!希喂、安德迈宠物空气净化器PK,实测数据大公开

宠物空气净化器作为宠物领域的新产品,凭借自身独特的功能受到铲屎官们的喜爱,越来越多的商家关注到这个市场。然而,市面上品牌逐渐增多,质量却参差不齐,一些不良商家以次充好,容易让消费者陷入消费陷阱。因…

PHYS_OPT_MODIFIED

当对原始单元执行物理优化时,PHYS_OPT_MODIFIED 更新单元的属性以反映对单元执行的优化。什么时候? 对同一单元格执行多次优化,PHYS_OPT_MODIFIED值 包含按发生顺序排列的优化列表。 架构支持 所有架构。 适用对象 PHYS_OPT_MODIFIED属性放置…

Linux嵌入式学习——C++学习(2)

一、标识符的作用域和可见性 (一)作用域 1、全局作用域 在函数外部声明的变量和函数具有全局作用域。这些变量和函数在程序的任何地方都可以被访问。 2.局部作用域 在函数内部、循环体内部或条件语句内部声明的变量具有局部作用域。这些变量只能在其…

<数据集>航拍屋顶识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:3516张 标注数量(xml文件个数):3516 标注数量(txt文件个数):3516 标注类别数:1 标注类别名称:[roof] 序号类别名称图片数框数1roof351643938 使用标注工具&#xf…

模具3D打印:成本缩减与产能提升的新引擎

近年来,3D打印技术,特别是在航空航天、汽车制造、生物医疗等前沿领域,已成为复杂结构件研发与生产的关键技术。针对广大制造企业而言,评估金属3D打印技术的经济效能,即其能否有效助力企业成本控制与产能提升&#xff0…

魔众文库系统v7.0.0版本推荐店铺功能,管理菜单逻辑优化

推荐店铺功能,管理菜单逻辑优化 [新功能] RandomImageProvider 逻辑升级重构,支持更丰富的随机图片生成 [新功能] 资源篮订单参数字段 [新功能] 首页推荐店铺功能,需要在后台 文库系统 → 文库店铺 开启推荐 [系统优化] Grid 快捷编辑请求…

Yolo-World初步使用

Yolo v8目前已经支持Yolo-World,整理一下初步使用步骤。 使用步骤 1 先下载Yolo-World的pt文件,下载地址:GitHub - AILab-CVC/YOLO-World: [CVPR 2024] Real-Time Open-Vocabulary Object Detection 官网应该是点这里(有个笑脸…

C++入门:C语言到C++的过渡

前言:C——为弥补C缺陷而生的语言 C起源于 1979 年,当时 Bjarne Stroustrup 在贝尔实验室工作,面对复杂软件开发任务,他感到 C 语言在表达能力、可维护性和可扩展性方面存在不足。 1983 年,Bjarne Stroustrup 在 C 语言…

大数据应用型产品设计方法及行业案例介绍(可编辑110页PPT)

引言:随着信息技术的飞速发展,大数据已成为推动各行各业创新与变革的重要力量。大数据应用型产品,作为连接海量数据与实际应用需求的桥梁,其设计方法不仅要求深入理解数据特性,还需精准把握用户需求,以实现…

git:安装 / 设置环境变量 / 使用

一、下载 https://github.com/git-for-windows/git/releases/download/v2.45.1.windows.1/Git-2.45.1-64-bit.exe 下载成功-双击打开 下一步-Next 二、添加环境变量 1、找到git安装地址 win r cmd 回车 where git 设置环境变量 C:\Program Files\Git\cmd\git.exe 此电…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 最大括号深度(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是 春秋招笔试突围 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-D卷的三语言AC题解 💻 ACM金牌🏅️团队| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 🍿 最新华为OD机试D卷目录,全、新、准,题目覆盖率达 95% 以上,…

分布式事务学习整理

一、整体背景 最近在分布式事务领域这块的了解比较少,对自己来说是一个业务盲点,所以想抽空学习以及整理下关于分布式事务的相关知识。 1、分布式事务的发展 总所周知,我们为什么要考虑分布式事务,从一开始发展来说&#xff0c…

vscode源代码管理的传入传出更改视图如何关闭

传入传出更改视图关闭: vscode源代码管理中下面显示的大量传入传出记录,不想显示的话 在设置里搜索 scm.showHistoryGraph 可以关闭。

[Meachines] [Easy] valentine SSL心脏滴血+SSH-RSA解密+trp00f自动化权限提升+Tmux进程劫持权限提升

信息收集 IP AddressOpening Ports10.10.10.79TCP:22,80,443 $ nmap 10.10.10.79 --min-rate 1000 -sC -sV PORT STATE SERVICE VERSION 22/tcp open ssh OpenSSH 5.9p1 Debian 5ubuntu1.10 (Ubuntu Linux; protocol 2.0) | ssh-hostkey: | 1024 96:4c:51:42:…

以树莓集团的视角:探索AI技术如何重塑数字媒体产业发展

在科技日新月异的今天,AI技术如同一股不可阻挡的潮流,正深刻改变着我们的世界,尤其是数字媒体产业发展。作为数字产业生态链的杰出建设者,树莓集团始终站在时代前沿,积极探索AI技术如何为数字媒体产业注入新活力。 在树…