pytorch单机多卡训练

news2024/12/26 21:28:32

多卡训练的方式

以下内容来自知乎文章:当代研究生应当掌握的并行训练方法(单机多卡)

pytorch上使用多卡训练,可以使用的方式包括:

  • nn.DataParallel
  • torch.nn.parallel.DistributedDataParallel
  • 使用Apex加速。Apex 是 NVIDIA 开源的用于混合精度训练和分布式训练库。Apex 对混合精度训练的过程进行了封装,改两三行配置就可以进行混合精度的训练,从而大幅度降低显存占用,节约运算时间。此外,Apex 也提供了对分布式训练的封装,针对 NVIDIA 的 NCCL 通信库进行了优化。
  • 使用Horovod加速。Horovod 是 Uber 开源的深度学习工具,它的发展吸取了 Facebook “Training ImageNet In 1 Hour” 与百度 “Ring Allreduce” 的优点,可以无痛与 PyTorch/Tensorflow 等深度学习框架结合,实现并行训练。

对比不同的方式,后面几种差别其实并不大,而pytorch官方也建议将第一种DataParallel替换成DistributedDataParallel,下面我们只关注DistributedDataParallel的实现,想了解其他的可以去看引用的文章。
在这里插入图片描述

数据并行

一般来说并行加速可以分为模型并行和数据并行,模型并行指的是将模型的不同组成部分放到不同的gpu上,而数据并行则是把一个模型拷贝到多个gpu上(例如8块gpu),然后将一个完整的batch(例如batch size为128)再分成多个小的batch(每个小batch就是128/8个),分别由这多个gpu进行计算,反向传播之后得到参数的梯度,然后将不同gpu上的梯度求平均(reduce),然后同步不同模型上的参数更新。

DistributedDataParallel

一些概念
Node: 一个节点, 可以理解为一台电脑.
Device: 工作设备, 可以简单理解为一张卡, 即一个GPU.
Process: 一个进程, 可以简单理解为一个Python程序.

DistributedDataParallel和DataParallel的区别

  1. DistributedDataParallel支持模型并行,而DataParallel并不支持,这意味如果模型太大单卡显存不足时只能使用前者;
  2. DataParallel是单进程多线程的,只用于单机情况,而DistributedDataParallel是多进程的,适用于单机和多机情况,真正实现分布式训练;
  3. DistributedDataParallel的训练更高效,因为每个进程都是独立的Python解释器,避免GIL问题,而且通信成本低其训练速度更快,基本上DataParallel已经被弃用;
  4. 必须要说明的是DistributedDataParallel中每个进程都有独立的优化器,执行自己的更新过程,但是梯度通过通信传递到每个进程,所有执行的内容是相同的;

参考链接:https://blog.csdn.net/weixin_43402775/article/details/114318434

如何使用DistributedDataParallel进行单机多卡的训练

启动进程

要使用分布式训练,首先需要在每个训练节点(Node)上生成多个分布式训练进程。对于每一个进程, 它都有一个local_rank和global_rank, local_rank对应的就是该Process在自己的Node上的编号, 而global_rank就是全局的编号。比如你有2个Node, 每个Node上各有4个Proess (Process0, Process1, Process2, Process3). 那么对于Process2来说, 它的local_rank就是0(即它在Node1上是第0个Process), global_rank 就是2。对于单机多卡的情况,那么local_rank和global_rank是一样的。

可以使用的方法有:

  1. torch.distributed.launch:这是一个非常常见的启动方式,在单节点分布式训练或多节点分布式训练的两种情况下,此程序将在每个节点启动给定数量的进程(--nproc_per_node)。如果用于GPU训练,这个数字需要小于或等于当前系统上的GPU数量(nproc_per_node),并且每个进程将运行在单个GPU上,从GPU 0到GPU (nproc_per_node - 1)。
    使用方式为:python -m torch.distributed.launch --nproc_per_node=N --use_env xxx.py,其中-m表示后面加上的是模块名,因此不需要带.py,--nproc_per_node=N表示启动N个进程,--use_env表示pytorch会将当前进程在本机上的rank添加到环境变量“LOCAL_RANK”,因此可以通过os.environ['LOCAL_RANK']来获取当前的gpu编号,如果不加--use_env,那么必须声明一个命令行参数--local_rank,因为启动的时候会给每个进程传入这样一个参数,我们需要用变量来接收他,这边还是建议使用--use_env的方式。最后我们加上要运行的python文件名即可。

  2. torchrun:是为了替代torch.distributed.launch的新型启动方式, 可以支持ELASTIC LAUNCH, 即动态控制启动的节点数量, 但是由于是新功能, 只有最新的torch 1.10, 处于兼容性考虑还是建议使用torch.distributed.launch。torchrun默认有--use_env的。python -m torch.distributed.launch --use-env train_script.py可以用torchrun train_script.py来替代。

初始化进程组

在启动多个进程之后,需要初始化进程组,使用的方法是使用torch.distributed.init_process_group()来初始化默认的分布式进程组。

torch.distributed.init_process_group(backend=None, init_method=None, timeout=datetime.timedelta(seconds=1800), world_size=- 1, rank=- 1, store=None, group_name='', pg_options=None)

一般需要传入的参数有:

  • backend :使用什么后端进行进程之间的通信,选择有:mpi、gloo、nccl、ucc,一般使用nccl。
  • world_size:使用几个进程,每个进程会对应一个gpu。
  • rank:当前进程的编号,大小在[0,world_size-1]

如果使用了--use_env,那么这里的rank和world_size都可以通过os.environ['LOCAL_RANK']os.environ['WORLD_SIZE']来获取,然后传入这个函数。

该语句后面最好加一句torch.distributed.barrier(),这个函数会进行多个进程间的同步,确保每一个进程都执行完了init_process_group()

创建模型

做完以上这些之后,我们就可以创建模型了:

# 创建模型, 并将其移动到local_rank对应的GPU上
model = ToyModel().to(local_rank)
ddp_model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

上面说过,local_rank可以通过环境变量来获取。第一句是将model放到对应的gpu上,也可以通过以下方式来实现:

  • torch.cuda.set_device(local_rank)
  • with torch.cuda.device(local_rank)

注意,这里的ddp_model和原来的model就不一样了,如果你要保存的是原来模型的参数,需要通过ddp_model.module来获取。

读取数据

有了模型之后,如何读取数据进行训练呢?如果一个batch有n个数据,我们有m个gpu参与训练,那么肯定希望m个gpu上都有不同的n/m个数据,这个就是采样的问题了,数据集dataset和dataloader的使用并不受影响。

在分布式训练的时候,我们要使用DistributedSampler这个类。和DistributedDataParallel搭配使用的时候,每个进程都可以将DistributedSampler实例作为DataLoader采样器传递,并加载原始数据集的一个专有子集。使用方式:

torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)

除了dataset都是可选参数,如果不传入默认会从当前的进程组中获取有关的参数,因此我们只需要传入一个dataset即可。然后再dataloader中传入这个sampler就行。

load和save

在训练完,我们一般都会保存模型的权重,这时候最好是保存ddp_model.module.statedict(),然后load的时候也是ddp_model.module.load_state_dict()。

其他可能用到的函数

  • torch.distributed.all_gather():把所有进程中的某个tensor收集起来,比如有8个进程,都有一个tensor a,那么可以把所有进程中的a收集起来得到一个list
  • torch.distributed.all_reduce():汇总所有gpu上的某一个tensor值,可以选择平均或者求和等,然后再分发到所有gpu上使得每个gpu上的值都是相同的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/418035.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式学习笔记汇总

本文整理STM32、STM8和uCOS-III的所有文章链接。 STM32学习笔记目录 源码:mySTM32-learn STM32学习笔记(1)——LED和蜂鸣器 STM32学习笔记(2)——按键输入实验 STM32学习笔记(3)——时钟系统 …

.NET System.Management 获取windows系统和硬件信息

ManagementObject用于创建WMI类的实例与WINDOWS系统进行交互,通过使用WMI我们可以获取服务器硬件信息、收集服务器性能数据、操作Windows服务,甚至可以远程关机或是重启服务器。 WMI 的全称 Windows Management Instrumentation,即 Windows …

音视频八股文(1)--音视频基础

1.1.音视频录制原理 1.2.音视频播放原理 1.3.图像表示RGB-YUV 1.3.1 图像基础概念 ◼ 像素:像素是一个图片的基本单位,pix是英语单词picture的简写,加上英 语单词“元素element”,就得到了“pixel”,简称px&#xff…

使用Mingw64在CLion中搭建Linux开发环境

1.前言: 博主本来一直是在Visual Studio 2017中使用C语言编写程序,但有个问题是Visual Studio 2017默认使用自带的Windows SDK和编译器,我想使用POSIX文件操作就不行(因为Windows中没有Linux SDK),虽然Wind…

【Kafka-架构及基本原理】Kafka生产者、消费者、Broker原理解析 Kafka原理流程图

【Kafka-架构及基本原理】Kafka生产者、消费者、Broker原理解析 & Kafka原理流程图1)Kafka原理1.1.生产者流程细节1.2.Broker 的存储流程细节1.3.消费者流程细节2)Kafka读写流程图1)Kafka原理 1.1.生产者流程细节 1、生产者发送消息到 …

计算机毕业设计源码整合大全_kaic

以下为具体单个列表(单个下载在我主页搜索即可): 1:计算机专业-ASP(499套) ASP学生公寓管理系统的设计与实现(源代码论文).rar 1:计算机专业-ASP(499套) ASP学科建设设计(源代码论文).ra…

Clickhouse 引擎之MergeTree详解

分区详解 数据存储底层分布 # 数据在这个位置 rootfjj001:~# cd /var/lib/clickhouse/data rootfjj001:/var/lib/clickhouse/data# ls # 数据库 default system rootfjj001:/var/lib/clickhouse/data# cd default/ rootfjj001:/var/lib/clickhouse/data/default# ls #表 enu…

ASEMI代理AD8400ARZ10-REEL原装ADI车规级AD8226ARZ-R7

编辑:ll ASEMI代理AD8400ARZ10-REEL原装ADI车规级AD8226ARZ-R7 型号:AD8400ARZ10-REEL 品牌:ADI/亚德诺 封装:SOIC-8 批号:2023 引脚数量:8 安装类型:表面贴装型 AD8400ARZ10-REEL汽车芯…

Zabbix监控系统——附详细步骤和图解

文章目录一、Zabbix概述1、使用zabbix的原因2、zabbix的概念和构成3、zabbix 监控原理:4、zabbix的程序组件二、安装 zabbix 5.01、部署 zabbix 服务端的操作步骤2、实例操作:部署 zabbix 服务端3、部署 zabbix 客户端4、实例操作:部署 zabbi…

【Linux】揭开套接字编程的神秘面纱(下)

​🌠 作者:阿亮joy. 🎆专栏:《学会Linux》 🎇 座右铭:每个优秀的人都有一段沉默的时光,那段时光是付出了很多努力却得不到结果的日子,我们把它叫做扎根 目录👉前言&…

(二十三)槽函数的书写规则导致槽函数触发2次的问题

在创建QT的信号和槽时,经常无意间保留着QT书写槽函数的习惯,或者在QT设计界面直接右键【转到槽】去创建槽函数,但是后期需要用到disconnect时,又重新写了一遍connect函数,那么你会发现实际槽函数执行了2遍。 首先来看…

要在Ubuntu中查找进程的PID,可以使用pgrep或pidof命令。

一 查找进程 1.pgrep命令 pgrep命令可以根据进程名或其他属性查找进程的PID。例如,要查找名为"firefox"的进程的PID,可以在终端中输入以下命令: pgrep firefox如果有多个名为"firefox"的进程,pgrep命令将返…

互联网一个赛道只剩下几家,真要爆品

互联网一个赛道剩下几家,真要爆品 2017年的书,案例基本上是马后炮总结 趣讲大白话:说起来容易,做起来难 【趣讲信息科技136期】 **************************** 书中讲的范冰冰翻车了 书中不看好的线下渠道,现在成香饽饽…

面试篇-Java并发之CAS:掌握原理、优缺点和应用场景分析,避免竞态问题

1、CAS介绍及原理 多线程中的CAS(Compare-and-Swap)操作是一种常见的并发控制方法,用于实现原子性更新共享变量的值。其核心思想是通过比较内存地址上的值和期望值是否相等来确定是否可以进行更新操作,从而避免多线程条件下的竞态…

HMI实时显示网络摄像机监控画面——以海康威视网络摄像机为例

随着IOT技术的快速发展,网络摄像机快速应用于工业领域,结合其他智能设备建立一个智能系统,提高用户与机器设备之间的交互体验,帮助企业优化人员配置。 作为重要的可视化设备,HMI不仅可以采集现场设备数据,…

uniapp系列-使用uniapp携带收件人信息调用手机邮件应用发邮件的2种方案

背景描述 我们使用uniapp打包之后,某些情况下,需要使用uniapp打开手机其他应用去发邮件,携带对方email 信息以及主题信息等,那我们应该怎么处理呢? 方案一:使用uniapp标签-uni-link,注意这种方…

BGP实验(一)

实验要求: 1、As1存在两个环回,一个地址为192.168.1.0/24,该地址不能在任何协议中宣告, As3存在两个环回,.一个地址为192.168.2.0/24,该地址不能在任何协议中宣告, As1还有一个环回地址为10.1.1.0/24&…

研读Rust圣经解析——Rust learn-8(match,if-let简洁控制流,包管理)

研读Rust圣经解析——Rust learn-8(match,if-let简洁控制流,包管理)matchother和占位符_区别easy matchenum matchno valuematch innerOption matchmore better wayif-let整洁控制包管理模块(mod)拆分声明modpub公开use展开引用拆解模块结构m…

docker cmd

sudo docker run --gpus all --name uavrl1 themvs/uav_swarm_reinforcement_learning sudo docker p s-a 86850d5a9dc3 sudo docker run --gpus all --name uavrl12 uavrl:v1.2 ---------- 共享屏幕输入类似指令,实测可行 sudo docker run -it --nethost --ipc…

Leetcode每日一题——“轮转数组”

各位CSDN的uu们你们好呀,今天,小雅兰的内容是轮转数组,下面,让我们进入轮转数组的世界吧 小雅兰之前其实就已经写过了字符串旋转的问题了: C语言刷题(7)(字符串旋转问题&#xff09…