Pytorch分布式训练(一)

news2024/11/25 1:00:17

参考文献:

33、完整讲解PyTorch多GPU分布式训练代码编写_哔哩哔哩_bilibili

pytorch进程间通信 - 文举的博客 (liwenju0.com)

前言

2023年,训练模型时,使用DDP(DistributedDataParallel)已经成为Pytorch炼丹师的标准技能。本文主要讲述实现Pytorch分布式要做哪些事情,以及如何理解Pytorch分布式训练背后的通信原理(不会很深入)。

分布式训练流程

单机多卡训练流程

算法工程师实现单机多卡训练流程的思维导图如下:

 其中 init_process_proup(初始化进程组) 是实现整个训练的前提,其中比较重要的 world_size、rank 等参数的意义 我们后面会讲解,而像 backend 参数(指定通信库),如算法工程师没有进行指定,则 pytorch做出默认选择(在较低版本的 pytorch,如V1.0.0中,backend参数为必传)。

多机多卡训练流程

代码编写单机多卡一致,可复用代码。

Pytorch进程间通信

DDP 本身是依赖 torch.distributed 提供的进程间通信能力。所以理解torch.distributed提供的进程间通信的原理,对理解DDP的运行机制有很大的帮助。官方的tutorial中,实现了依靠torch.distributed实现DDP功能的demo代码,学习一下,很有裨益。

这部分,其实就两件事儿,建立进程组和实现进程组之间的通信。

创建进程组

说到底,就是建立多个进程,并且将这些进程归并到一起,成为一个group,在group内,每个进程一个id,用于标识自己。

建立多个进程,归根到底,还是一个进程一个进程建立。 那我们想,建立一个进程时,需要怎么做才能实现进程间的寻址呢。 torch.distributed给我们答案是四个参数

  • MASTER_ADDR
  • MASTER_PORT
  • WORLD_SIZE
  • RANK

MASTER_PORT和MASTER_ADDR的目的是告诉进程组中负责进程通信协调的核心进程的IP地址和端口。当然如果该进程就是核心进程,它会发现这就是自己。 RANK参数是该进程的id,WORLD_SIZE是说明进程组中进程的个数。

了解以上这些知识,就可以看一下创建进程组的代码:

"""run.py:"""
#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, size):
    """ Distributed function to be implemented later. """
    pass

def init_process(rank, size, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size) #这段代码就是将该进程加入到进程组中的核心代码
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    mp.set_start_method("spawn")
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

观察代码,可以看到, MASTER_ADDR 和 MASTER_PORT 是通过在代码中设置环境变量传递给 torch.distributed 的, RANK 和 WORLD_SIZE 是通过参数传递的,其实也可以通过设置环境变量的方式传递。如下方法:

    # 一般分布式GPU训练使用nccl后端,分布式CPU训练使用gloo后端
    # init_method的值设为"env://",表示需要的四组参数信息都在环境变量里取。
    dist.init_process_group(backend="nccl", init_method="env://")

上面代码中的run就是在初始化好进程组之后执行的函数,这里之所以传入rank 和size,是想在执行过程中根据不同的rank,来给不同的进程赋予不同的行为,比如,日志只在rank==0的进程中打印等。实际上,如果已经初始化了进程组,也可以通过如下两个函数获取相应的值,避免在函数中传递这两个参数。

def get_world_size() -> int:
    """Return the number of processes in the current process group."""
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    """Return the rank of the current process in the current process group."""
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()

这段代码是官方给的demo,看过之后,不免有些疑惑。这个代码似乎只适用于单机多卡的情况。 对于多机多卡的情况,在不同的机器上执行这个代码,MASTER_PORT 和MASTER_ADDR不用变,WORLD_SIZE需要调整为4,因为我们的代码每台机器上都启动两个进程, RANK这个时候就会发生冲突,不同的机器上的进程有相同的编号。解决方法就是在执行初始化函数时,传递一个NODE_RANK和NPROC_PER_NODE的参数,通过NODE_RANK和NPROC_PER_NODE计算出各个进程的RANK值,就可以保证不冲突了。 代码示例如下。

for r in range(NPROC_PER_NODE):
  RANK = NODE_RANK*NPROC_PER_NODE + r

实际上,torch已经将上面这些计算过程帮我们封装好了。代码如下所示:

python -m torch.distributed.launch \
            --master_port 12355 \ #主节点的端口
            --nproc_per_node=8 \ #每个节点启动的进程数
            --nnodes=nnodes  \ #节点总数
            --node_rank=1  \  # 当前节点的rank
            --master_addr=master_addr \ #主节点的ip地址
            --use_env \ #在环境变量中设置LOCAL_RANK
            train.py

使用这段代码启动train.py时,原先的

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(size)
    # 一般分布式GPU训练使用nccl后端,分布式CPU训练使用gloo后端
    dist.init_process_group(backend="nccl", init_method="env://")
    fn(rank, size)

可以简单改成:

dist.init_process_group(backend="nccl", init_method="env://")
fn()

需要的四个环境变量参数,torch.distributed.launch都会帮我们设置好。fn中需要rank和size的地方,使用上面的两个便利函数即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/702911.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

sourcetree打开就闪退

1、问题分析 一直未出现如题描述情况,今早到公司可能是异常重启或者系统更新的愿意导致没有正常关机,出现了此种情况 2、问题解决 注:本文sourcetree安装在win11系统,win10等系统目录大同小异 ① (若快捷方式在桌面步骤①省略…

你的服务器还安全吗?用户数据是否面临泄露风险?

一系列严重的网络安全事件引起了广泛关注,多家知名公司的服务器遭到黑客挟持,用户的个人数据和敏感信息面临泄露的风险。这些事件揭示了网络安全的脆弱性和黑客攻击的威胁性,提醒着企业和个人加强对网络安全的重视。 一、入侵案例1.1 蔚来数据…

复习Javascript第二章

JavaScript 函数 JavaScript 函数是被设计为执行特定任务的代码块。 JavaScript 函数会在某代码调用它时被执行。 function myFunction(p1, p2) {return p1 * p2; // 该函数返回 p1 和 p2 的乘积 } JavaScript 函数语法 JavaScript 函数通过 function 关键词…

爆款视频生成器小程序源码搭建方案

爆款视频生成器是一种可以帮助用户快速制作出高质量视频的工具。它可以根据用户提供的素材、模板和音乐等要素,自动生成一个精美的视频。这种工具可以大大节省用户的时间和精力,同时还能够提高视频制作的效率和质量,使视频更易于被观众接受和…

开发一个商城小程序有哪些功能?

✔️近年来,随着微信小程序的不断优化和推出,越来越多的商家开始选择使用小程序作为销售渠道。商城小程序作为一种便捷、快速、高效的销售渠道,已经成为商家们打造线上商城的重要手段。商城小程序拥有着丰富的功能,可以满足不同商…

使用itextpdf填充表单域并生成pdf

文章目录 前言一、准备工作1.1 安装软件1.2 准备pdf1.3 设置表单域 二、创建项目三、编写代码3.1 编写工具类3.2 测试 四、测试结果 前言 最近手上有个任务,就是需要做一个pdf导出的功能。 可选择的技术点比较多,我这边综合考虑之后,使用的…

品牌推广的新路径:邀请歌手出席活动的独特策略“

在当今的市场竞争中,品牌推广和市场营销已经成为企业取得成功的重要因素之一。而邀请知名歌手出席活动则是一种备受瞩目的策略,可以为品牌带来巨大的优势和机遇。无论是与赵丽颖、迪丽热巴、张子枫、关晓彤、周冬雨还是孙俪等知名歌手合作,都…

WPF中Binding的数据转换—ValueConverters

WPF中Binding的数据转换—ValueConverters 在WPF中使用Binding经常会遇到需要转换的情况,如bool转为visibility,通常情况需要自己写一个类继承自IValueConverter接口,使用详情请参见Binding对数据的转换和校验,这种方法虽然不难&…

C 模拟包装机

一种自动包装机的结构如图 1 所示。首先机器中有 N 条轨道,放置了一些物品。轨道下面有一个筐。当某条轨道的按钮被按下时,活塞向左推动,将轨道尽头的一件物品推落筐中。当 0 号按钮被按下时,机械手将抓取筐顶部的一件物品&#x…

中间件漏洞解析

服务器解析漏洞算是历史比较悠久了,但如今依然广泛存在。在此记录汇总一些常见服务器(WEB server)的解析漏洞,比如IIS6.0、IIS7.5、apache、nginx等 2|0 二、IIS5.x-6.x解析漏洞(针对asa/asp/cer) 2|11、打…

学习笔记|盘点一些Linux 常用的命令

目录 1、apt-get Debian/Ubuntu系统包管理器 2、uname 获取 操作系统信息 3、date 查看/设置 系统时间 4、yum CentOS系统包管理器 5、mkdir 新建 文件夹 6、free 查看内存使用信息 7、wget 下载工具 8、cd 进入 文件夹 8、cp 复制或重命名 文件/文件夹 9、VI、VIM …

机器学习之支持向量机(SVM)

1 支持向量机介绍 支持向量机(support vector machine,SVM)是有监督学习中最有影响力的机器学习算法之一,该算法的诞生可追溯至上世纪 60 年代, 前苏联学者 Vapnik 在解决模式识别问题时提出这种算法模型,…

synchronized监视器锁

1、synchronized&监视器锁 1.1 synchronized 介绍 在 Java 中,synchronized 是一种关键字,用于实现线程的同步和互斥控制。它可以修饰方法或代码块,用于保护共享资源的访问,避免多个线程同时修改数据而引发的并发问题。 具…

chatgpt赋能python:Python重写父类__init__方法的必要性与实现方法

Python重写父类__init__方法的必要性与实现方法 在Python中,一个类可以继承自另一个类,从而获得另一个类的属性和方法。当我们继承一个父类时,通常我们需要重写其中的一些方法,以满足我们自己的需求。在这篇文章中,我…

玩机搞机-----带你了解高通刷机平台中的一些选项释义 玩转平台

很多刷机工具玩家都使用过,但对于一些新手来说。有些选项所表达的意义不太了解,选择与否严重会导致机型固件刷完个别功能出现故障,今天的这个博文对有些刷机平台中的选项做个简单的说明。 一 小米刷机平台 MiFlash.截止目前最新的版本是2022…

最新|2024年QS世界大学排名前100榜单发布

6月28日世界高等教育研究机构Quacquarelli Symonds(QS)率先公布了2024年世界大学排名,本次QS排名因指标和权重的重大调整,导致排名发生较大变化。知识人网小编将新的评分标准及前100的大学榜单整理如下,供读者参考。 前…

Unity渲染工程收集

NPR 非真实渲染 UnityURP-AnimeStyleCelShader SSR 屏幕空间反射 UnitySSReflectionURP

消息传输不丢失:掌握消息中间件的持久化机制

当涉及到消息的持久化和重放时,我们可以使用Spring Boot与RabbitMQ来实现这个场景。RabbitMQ支持消息的持久化,以确保在发送和接收过程中消息不会丢失。同时,我们可以使用消息的重放机制,以便在需要时重新发送消息。 首先&#xf…

leetcode:387. 字符串中的第一个唯一字符(python3解法)

难度:简单 给定一个字符串 s ,找到 它的第一个不重复的字符,并返回它的索引 。如果不存在,则返回 -1 。 示例 1: 输入: s "leetcode" 输出: 0示例 2: 输入: s "loveleetcode" 输出: 2示例 3: 输…

1253. 重构 2 行二进制矩阵(力扣)

1253. 重构 2 行二进制矩阵(力扣) 题目第一种方式分析测试代码运行结果 第二种方式测试代码运行结果 题目 给你一个 2 行 n 列的二进制数组: 矩阵是一个二进制矩阵,这意味着矩阵中的每个元素不是 0 就是 1。 第 0 行的元素之和为…