【深度学习】【分布式训练】Collective通信操作及Pytorch示例

news2025/1/16 1:32:21

相关博客
【深度学习】【分布式训练】Collective通信操作及Pytorch示例
【自然语言处理】【大模型】大语言模型BLOOM推理工具测试
【自然语言处理】【大模型】GLM-130B:一个开源双语预训练语言模型
【自然语言处理】【大模型】用于大型Transformer的8-bit矩阵乘法介绍
【自然语言处理】【大模型】BLOOM:一个176B参数且可开放获取的多语言模型

Collective通信操作及Pytorch示例

​ 大模型时代,单机已经无法完成先进模型的训练和推理,分布式训练和推理将会是必然的选择。各类分布式训练和推断工具都会使用到Collective通信。网络上大多数的教程仅简单介绍这些操作的原理,没有代码示例来辅助理解。本文会介绍各类Collective通信操作,并展示pytorch中如何使用

一、Collective通信操作

1. AllReduce

​ 将各个显卡的张量进行聚合(sum、min、max)后,再将结果写回至各个显卡。

在这里插入图片描述

2. Broadcast

​ 将张量从某张卡广播至所有卡。

请添加图片描述

3. Reduce

​ 执行同AllReduce相同的操作,但结果仅写入具有的某个显卡。

请添加图片描述

4. AllGather

​ 每个显卡上有一个大小为N的张量,共有k个显卡。经过AllGather后将所有显卡上的张量合并为一个 N × k N\times k N×k的张量,然后将结果分配至所有显卡上。

请添加图片描述

5. ReduceScatter

​ 执行Reduce相同的操作,但是结果会被分散至不同的显卡。

请添加图片描述

二、Pytorch示例

​ pytorch的分布式包torch.distributed能够方便的实现跨进程和跨机器集群的并行计算。本文代码运行在单机双卡服务器上,并基于下面的模板来执行不同的分布式操作。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, size, fn, backend='nccl'):
    """
    为每个进程初始化分布式环境,保证相互之间可以通信,并调用函数fn。
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
    
    
def run(world_size, func):
    """
    启动world_size个进程,并执行函数func。
    """
    processes = []
    mp.set_start_method("spawn")
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, func))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        
if __name__ == "__main__":
    run(2, func) # 这里的func随后会被替换为不同的分布式示例函数
    pass

​ 先对上面的模板做一些简单的介绍。

  • 函数run会根据传入的参数world_size,生成对应数量的进程。每个进程都会调用init_process来初始化分布式环境,并调用传入的分布式示例函数。
  • torch.distributed.init_process_group(),该方法负责各进程之间的初始协调,保证各进程都会与master进行握手。该方法在调用完成之前会一直阻塞,并且后续的所有操作都必须在该操作之后。调用该方法时需要初始化下面的4个环境变量:
    • MASTER_PORT:rank 0进程所在机器上的空闲端口;
    • MASTER_ADDR:rank 0进程所在机器上的IP地址;
    • WORLD_SIZE:进程总数;
    • RANK:每个进程的RANK,所以每个进程知道其是否是master;

1. 点对点通信

​ 在介绍其他collective通信之前,先看一个简单的点对点通信实现。

def p2p_block_func(rank, size):
    """
    将rank src上的tensor发送至rank dst(阻塞)。
    """
    src = 0
    dst = 1
    group = dist.new_group(list(range(size)))
    # 对于rank src,该tensor用于发送
    # 对于rank dst,该tensor用于接收
    tensor = torch.zeros(1).to(torch.device("cuda", rank))
    if rank == src:
        tensor += 1
        # 发送tensor([1.])
        # group指定了该操作所见进程的范围,默认情况下是整个world
        dist.send(tensor=tensor, dst=1, group=group)
    elif rank == dst:
        # rank dst的tensor初始化为tensor([0.]),但接收后为tensor([1.])
        dist.recv(tensor=tensor, src=0, group=group)
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, p2p_block_func)

p2p_block_func实现从rank 0发送一个tensor([1.0])至rank 1,该操作在发送完成/接收完成之前都会阻塞。

​ 下面是一个不阻塞的版本:

def p2p_unblock_func(rank, size):
    """
    将rank src上的tensor发送至rank dst(非阻塞)。
    """
    src = 0
    dst = 1
    group = dist.new_group(list(range(size)))
    tensor = torch.zeros(1).to(torch.device("cuda", rank))
    if rank == src:
        tensor += 1
        # 非阻塞发送
        req = dist.isend(tensor=tensor, dst=dst, group=group)
        print("Rank 0 started sending")
    elif rank == dst:
        # 非阻塞接收
        req = dist.irecv(tensor=tensor, src=src, group=group)
        print("Rank 1 started receiving")
    req.wait()
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, p2p_unblock_func)

p2p_unblock_func是非阻塞版本的点对点通信。使用非阻塞方法时,因为不知道数据何时送达,所以在req.wait()完成之前不要对发送/接收的tensor进行任何操作。

2. Broadcast

def broadcast_func(rank, size):
    src = 0
    group = dist.new_group(list(range(size)))
    if rank == src:
        # 对于rank src,初始化tensor([1.])
        tensor = torch.zeros(1).to(torch.device("cuda", rank)) + 1
    else:
        # 对于非rank src,初始化tensor([0.])
        tensor = torch.zeros(1).to(torch.device("cuda", rank))
    # 对于rank src,broadcast是发送;否则,则是接收
    dist.broadcast(tensor=tensor, src=0, group=group)
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, broadcast_func)

broadcast_func会将rank 0上的tensor([1.])广播至所有的rank上。

3. Reduce与Allreduce

def reduce_func(rank, size):
    dst = 1
    group = dist.new_group(list(range(size)))
    tensor = torch.ones(1).to(torch.device("cuda", rank))
    # 对于所有rank都会发送, 但仅有dst会接收求和的结果
    dist.reduce(tensor, dst=dst, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, reduce_func)

reduce_func会对group中所有rank的tensor进行聚合,并将结果发送至rank dst。

def allreduce_func(rank, size):
    group = dist.new_group(list(range(size)))
    tensor = torch.ones(1).to(torch.device("cuda", rank))
    # tensor即用来发送,也用来接收
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, allreduce_func)

allreduce_func将group中所有rank的tensor进行聚合,并将结果发送至group中的所有rank。

4. Gather与Allgather

def gather_func(rank, size):
    dst = 1
    group = dist.new_group(list(range(size)))
    # 该tensor用于发送
    tensor = torch.zeros(1).to(torch.device("cuda", rank)) + rank
    gather_list = []
    if rank == dst:
        # gather_list中的tensor数量应该是size个,用于接收其他rank发送来的tensor
        gather_list = [torch.zeros(1).to(torch.device("cuda", dst)) for _ in range(size)]
        # 仅在rank dst上需要指定gather_list
        dist.gather(tensor, gather_list=gather_list, dst=dst, group=group)
    else:
        # 非rank dst,相当于发送tensor
        dist.gather(tensor, dst=dst, group=group)
    print('Rank ', rank, ' has data ', gather_list)
    
if __name__ == "__main__":
    run(2, gather_func)

gather_func从group中所有rank上收集tensor,并发送至rank dst。(相当于不进行聚合操作的reduce)

def allgather_func(rank, size):
    group = dist.new_group(list(range(size)))
    # 该tensor用于发送
    tensor = torch.zeros(1).to(torch.device("cuda", rank)) + rank
    # gether_list用于接收各个rank发送来的tensor
    gather_list = [torch.zeros(1).to(torch.device("cuda", rank)) for _ in range(size)]
    dist.all_gather(gather_list, tensor, group=group)
    # 各个rank的gather_list均一致
    print('Rank ', rank, ' has data ', gather_list)
    
if __name__ == "__main__":
    run(2, allgather_func)

allgather_func从group中所有rank上收集tensor,并将收集到的tensor发送至所有group中的rank。

5. Scatter与ReduceScatter

def scatter_func(rank, size):
    src = 0
    group = dist.new_group(list(range(size)))
    # 各个rank用于接收的tensor
    tensor = torch.empty(1).to(torch.device("cuda", rank))
    if rank == src:
        # 在rank src上,将tensor_list中的tensor分发至不同的rank上
        # tensor_list:[tensor([1.]), tensor([2.])]
        tensor_list = [torch.tensor([i + 1], dtype=torch.float32).to(torch.device("cuda", rank)) for i in range(size)]
        # 将tensor_list发送至各个rank
        # 接收属于rank src的那部分tensor
        dist.scatter(tensor, scatter_list=tensor_list, src=0, group=group)
    else:
        # 接收属于对应rank的tensor
        dist.scatter(tensor, scatter_list=[], src=0, group=group)
    # 每个rank都拥有tensor_list中的一部分tensor
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, scatter_func)

scatter_func会将rank src中的一组tensor逐个分发至其他rank上,每个rank持有的tensor不同。

def reduce_scatter_func(rank, size):
    group = dist.new_group(list(range(size)))
    # 用于接收的tensor
    tensor = torch.empty(1).to(torch.device("cuda", rank))
    # 用于发送的tensor列表
    # 对于每个rank,有tensor_list=[tensor([0.]), tensor([1.])]
    tensor_list = [torch.Tensor([i]).to(torch.device("cuda", rank)) for i in range(size)]
    # step1. 经过reduce的操作会得到tensor列表[tensor([0.]), tensor([2.])]
    # step2. tensor列表[tensor([0.]), tensor([2.])]分发至各个rank
    # rank 0得到tensor([0.]),rank 1得到tensor([2.])
    dist.reduce_scatter(tensor, tensor_list, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor)
    
if __name__ == "__main__":
    run(2, reduce_scatter_func)

参考资料

https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html

https://pytorch.org/tutorials/intermediate/dist_tuto.html#collective-communication

https://pytorch.org/docs/stable/distributed.html#collective-functions

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/415689.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第02章_变量与运算符

第02章_变量与运算符 讲师:尚硅谷-宋红康(江湖人称:康师傅) 官网:http://www.atguigu.com 本章专题与脉络 1. 关键字(keyword) 定义:被Java语言赋予了特殊含义,用做专门…

银河麒麟服务器ky10 sp3 x86 pgadmin使用

目录 打开网页并登录 连接数据库 备份数据库 还原数据库 打开网页并登录 打开浏览器,输入127.0.0.1:5050,输入用户名和密码登录, 我这边设置的用户名是123456qq.com,密码是 123456 连接数据库 右键选择register-Server 输…

Html5版飞机大战游戏中(Boss战)制作

内容在“60行代码,制作飞机大战游戏”的基础上,继续追加入了Boss战的功能。 boss的血量默认设置为100了,可以二次开发调整……(^_^) 玩起来有一定难度哈。 试玩地址:点击试玩 实现功能 添加玩家飞机,并进行控制Boss能…

vue+MapboxGL:从0 到1 搭建开发环境

本系列教程是在vue2.X的基础上加载mapbox 程序,来开发各种示例程序。 安装顺序 1,下载安装nodejs 下载地址:https://nodejs.org/en/download/ 根据用户自己的机器情况进行选择不同版本的软件下载。 本教程示例采用是是windows 64位系统软件。 安装过程很简单,一路下一步…

vue-router3.0处理页面滚动部分源码分析

在使用vue-router3.0时候,会发现不同的路由之间来回切换,会滚动到上次浏览的位置,今天就来看看这部分的vue-router中的源码实现。 无论是基于hash还是history的路由切换,都对滚动进行了处理,这里分析其中一种即可。 无…

TeeChart Pro ActiveX 2023.3.20 Crack

TeeChart Pro ActiveX 图表组件库提供数百种 2D 和 3D 图形样式、56 种数学和统计函数供您选择,以及无限数量的轴和 14 个工具箱组件。图表控件可以有效地用于创建多任务仪表板。 插件的多功能性 ActiveX 图表控件作为服务器端库中的 Web 图表、脚本化 ASP 图表或桌…

0201概述和结构-索引-MySQL

文章目录1 概述1.1 介绍1.2 优缺点2 索引结构2.1 BTree索引2.2 hash索引2.3 对比3 索引分类3.1 通用分类3.2 InnoDB存储引擎分类4 思考题后记1 概述 1.1 介绍 索引是帮忙MySQL 高效获取数据的数据结构(有序)。在数据之外,数据系统还维护着满…

【CF1764C】Doremy‘s City Construction(二分图,贪心)

【题目描述】 有nnn个点,每个点的点权为aia_iai​,你可以在任意两个点之间连边,最终连成的图需要满足:不存在任意的三个点,满足au≤av≤awa_u\le a_v\le a_wau​≤av​≤aw​(非降序)且边(u,v)(…

『pyqt5 从0基础开始项目实战』06. 获取选中多行table 重新初始化数据(保姆级图文)

目录导包和框架代码重新初始化绑定点击事件获取当前选中的所有行id实现初始化数据完整代码main.pythreads.py总结欢迎关注 『pyqt5 从0基础开始项目实战』 专栏,持续更新中 欢迎关注 『pyqt5 从0基础开始项目实战』 专栏,持续更新中 导包和框架代码 请查…

案例分享 | 金融业智能运维AIOps怎么做?看这一篇就够了

​构建双态IT系统,AIOps已经是必然的选择。运维数字化转型已是大势所趋,实体业务的逐步线上化对IT系统的稳定与安全提出更高要求,同时随着双态IT等复杂系统的建立,如何平衡IT运维效率与成本成为区域性银行面临的重要问题&#xff…

Windows编程基础

Windows编程基础 Unit1应用程序分类 控制台程序:Console Dos程序,本身没有窗口,通过windows Dos窗口执行 窗口程序 拥有自己的窗口,可以与用户交互 库程序 存放代码、数据的程序,执行文件可以从中取出代码执行和获取…

【MySQL】索引事务

摄影分享~ 文章目录索引概念使用场景使用事务概念使用事务的特性索引 概念 索引是一种特殊的文件,包含着对数据表里所有记录的引用指针。可以对表中的一列或多列创建索引并指定索引的类型,各类索引有各自的数据结构实现。 通过目录,就可以…

如何使用数字示波器

本文介绍以鼎阳SIGLENT SDS1122E数字示波器为例。 带了一根电源线;两根信号线,每根信号线都有几个小配件,如下所示: 使用概述 我们都知道万用表(又称欧姆表)是工程师最常用的调试电路的工具,但万…

技术+商业“双轮”驱动,量旋科技加速推进全方位的量子计算解决方案

【中国,深圳】4月14日,在第三个“世界量子日”,以“‘双轮’驱动 加速未来”为主题的量旋科技2023战略发布会在线上举办。 本次发布会,量旋科技全线升级了三大业务线产品:其中重点布局的超导量子计算体系产品&#xf…

监控系统 Prometheus 的说明

一、Prometheus 是什么? ELK Stack 日志收集和检索平台想必大家应该比较熟悉,Elasticsearch Filebeat Logstash Kibana。 而 Prometheus 就相当于一整个 ELK,但是它其实并不是适合存储大量日志,也不适合长期存储(默…

【AI绘图学习笔记】transformer

台大李宏毅21年机器学习课程 self-attention和transformer 文章目录Seq2seq实现原理EncoderDecoderAutoregressive自回归解码器Non-Autoregressive非自回归解码器Corss-attention总结TrainingtrickCopy MechanismGuided AttentionBeam Search强化学习(Reinforcement…

AVL树,红黑树,红黑树封装map和set

文章目录AVL树AVL树的实现AVL树的节点AVL树的平衡因子AVL树的插入AVL树的旋转左单旋右单旋左右正旋右左正旋中序遍历打印节点判断子树是否平衡整体代码验证代码红黑树概念性质(规则)红黑树的实现结点定义插入parent在grandparent的左情况一:u…

登录认证功能的统一拦截技术(拦截器)

目录 1.说明 2.使用方法 (1) 定义拦截器 (2)注册配置拦截器 (3)示例: 3.interceptor详细说明 (1)拦截路径 (2)执行流程 (3)过滤器和拦截器的区别 4.登录校验的拦截器实现 5.全局异常处理(补充说明) 1.说明 拦截器是一种动态拦截方法调用的机制&#xff0…

3年功能测试被辞,待业3个月,2023不会自动化测试真的找不到工作吗?

前言 来自一位粉丝的投稿,在测试行业已近打拼了3年,一直兢兢业业,前不久被公司以人员优化的理由辞退,到现在已近过去了3个月还没有找到测试工作,让她很焦虑,我通过和她的交流才发现她最大的问题就是技术方…

从零开始学习Python中UnitTest测试框架:实现高效自动化测试流程

目录:导读 引言 1.白盒测试原理 2.自动化测试用例编写 3.UnitTest测试框架 3.1UnitTest组件(测试固件) 3.1.2测试套件 3.1.3测试运行 3.1.4测试断言 3.1.5测试结果 3.2unittest测试固件的详解 3.2.1测试固件每次均执行 3.2.2测试…