将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

news2024/12/29 8:46:50

将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

  • 一.测试数据
  • 二.测试环境
  • 三.普通实现
  • 四.分块实现

本文演示了如何将数据切分成N份,采用NCCL异步通信,让all_gather+matmul尽量Overlap

一.测试数据

  • 1.测试规模:8192*8192 world_size=2
  • 2.单算子:all_gather:0.03508s matmul:0.05689s e2e:0.09197s。matmul耗时最长
  • 3.按输入和权值切分成8份,async_op=True。e2e:0.75ms
  • 4.e2e耗时从91ms缩短到75ms 缩短了17%。耗时为纯matmul算子的:1.34倍

二.测试环境

docker run --gpus all --shm-size=32g -ti -e NVIDIA_VISIBLE_DEVICES=all \
        --privileged --net=host -v $PWD:/home \
        -w /home --name all_gather_mm \
        nvcr.io/nvidia/pytorch:23.07-py3 /bin/bash

三.普通实现

tee all_gather_mm_native.py <<-'EOF'
import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import numpy as np
from torch.profiler import profile
import nvtx

dev_type="cuda"
dist.init_process_group(backend='nccl')

torch.manual_seed(1)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
local_rank=int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(dev_type,local_rank)
shape=(8192,8192)

input_tensor=torch.rand((shape[0],shape[1]),dtype=torch.float).to(device)
weight=torch.rand((shape[1],8192),dtype=torch.float).to(device)
all_gather_buffer=torch.zeros((shape[0]*world_size,shape[1]),dtype=torch.float).to(device)

for i in range(10):
    with nvtx.annotate(f"iter:{i}", color="blue"): 
        dist.barrier()
        t0=time.time()
        torch.distributed._all_gather_base(all_gather_buffer, input_tensor)
        dist.barrier()
        torch.cuda.synchronize()
        t1=time.time()
        output = torch.matmul(all_gather_buffer, weight)
        torch.cuda.synchronize()
        t2=time.time()
        if rank==0:
            print(f"iter:{i} all_gather:{t1-t0:.5f} matmul:{t2-t1:.5f} e2e:{t2-t0:.5f} data:{output.mean()}")
EOF
export NCCL_DEBUG=error
export NCCL_IB_DISABLE=1
export CUDA_VISIBLE_DEVICES="1,3"
torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_native

nsys profile --stats=true -o all_gather_mm_native.nsys-rep -f true -t cuda,nvtx --gpu-metrics-device=1,3 \
        torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_native

输出

iter:0 all_gather:0.03809 matmul:0.84971 e2e:0.88780 data:2047.62548828125
iter:1 all_gather:0.03327 matmul:0.06595 e2e:0.09922 data:2047.62548828125
iter:2 all_gather:0.03720 matmul:0.06082 e2e:0.09802 data:2047.62548828125
iter:3 all_gather:0.03682 matmul:0.05644 e2e:0.09326 data:2047.62548828125
iter:4 all_gather:0.03382 matmul:0.05648 e2e:0.09030 data:2047.62548828125
iter:5 all_gather:0.03404 matmul:0.05635 e2e:0.09039 data:2047.62548828125
iter:6 all_gather:0.03657 matmul:0.05701 e2e:0.09359 data:2047.62548828125
iter:7 all_gather:0.03840 matmul:0.05695 e2e:0.09535 data:2047.62548828125
iter:8 all_gather:0.03721 matmul:0.05685 e2e:0.09406 data:2047.62548828125
iter:9 all_gather:0.03508 matmul:0.05689 e2e:0.09197 data:2047.62548828125

在这里插入图片描述

四.分块实现

tee all_gather_mm_tiling.py <<-'EOF'
import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
import time
import numpy as np
import nvtx

# 分几块
num_blocks = 8

dev_type="cuda"
dist.init_process_group(backend='nccl')

torch.manual_seed(1)
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
local_rank=int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
device = torch.device(dev_type,local_rank)

streams = [torch.cuda.Stream(device=device) for _ in range(num_blocks)]

def all_gather_matmul(rank, world_size, input, weight,gathered_buffer,output_buffer, num_blocks, device):
    input_chunk_size = input.size(0) // num_blocks  # 每块的大小
    weight_chunk_size = weight.size(1) // num_blocks
    handles = []
    for i in range(num_blocks):
        with torch.cuda.stream(streams[i]):
            # 划分块并进行 all_gather
            input_chunk = input[i * input_chunk_size: (i + 1) * input_chunk_size]
            gather_start_idx = i * input_chunk_size * world_size  # 起始索引
            handle = dist.all_gather_into_tensor(gathered_buffer[gather_start_idx:gather_start_idx + input_chunk_size * world_size], input_chunk, async_op=True)
            handles.append((handle, gather_start_idx))
    outputs = torch.zeros_like(output_buffer)
    for i in range(num_blocks):
        with torch.cuda.stream(streams[i]):
            handle, gather_start_idx = handles[i]
            handle.wait()  # 等待通信完成
            # 直接在通信结果上进行矩阵乘法
            gathered_input = gathered_buffer[gather_start_idx:gather_start_idx + input_chunk_size * world_size]
            for j in range(num_blocks):
                weight_chunk = weight[:, j * weight_chunk_size: (j + 1) * weight_chunk_size]
                output_chunk = outputs[i * input_chunk_size * world_size: (i + 1) * input_chunk_size * world_size, j * weight_chunk_size: (j + 1) * weight_chunk_size]             
                # 进行局部矩阵相乘
                output_chunk.add_(torch.matmul(gathered_input, weight_chunk))
    torch.cuda.synchronize(device)
    return outputs

# 初始化
input = torch.rand((8192, 8192),dtype=torch.float).to(device) 
weight = torch.rand((8192, 8192),dtype=torch.float).to(device) 
all_gather_buffer = torch.zeros((8192 * world_size, 8192),dtype=torch.float).to(device)

for i in range(10):
    output = torch.zeros(input.size(0) * world_size, weight.size(1),dtype=torch.float,device=device)
    dist.barrier()
    t0=time.time()
    with nvtx.annotate(f"iter:{i}", color="blue"):
        output = all_gather_matmul(rank, world_size, input, weight,all_gather_buffer,output,num_blocks,device)
    torch.cuda.synchronize()
    t1=time.time()
    if rank == 0:
        print(f"iter:{i} e2e:{t1-t0:.5f} data:{output.mean()}")
EOF

export NCCL_DEBUG=error
export NCCL_IB_DISABLE=1
torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_tiling

nsys profile --stats=true -o all_gather_mm_tiling.nsys-rep -f true -t cuda,nvtx --gpu-metrics-device=1,3 \
        torchrun -m --nnodes=1 --nproc_per_node=2 all_gather_mm_tiling

输出

iter:0 e2e:0.13553 data:2047.62548828125
iter:1 e2e:0.07687 data:2047.62548828125
iter:2 e2e:0.07717 data:2047.62548828125
iter:3 e2e:0.07645 data:2047.62548828125
iter:4 e2e:0.07724 data:2047.62548828125
iter:5 e2e:0.07586 data:2047.62548828125
iter:6 e2e:0.07587 data:2047.62548828125
iter:7 e2e:0.07589 data:2047.62548828125
iter:8 e2e:0.07626 data:2047.62548828125
iter:9 e2e:0.07549 data:2047.62548828125

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1887348.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字化供应链:背景特点

​背景 1、外部环境 近年来&#xff0c;供应链脆弱性凸显&#xff0c;企业供应链压力难以缓解。 美国媒体针对美国零售联合会、美国服装和鞋类协会、美国供应链管理专业委员会等主体进行的一项供应链调查显示&#xff1a; 61%的供应链经理预计&#xff0c;供应链紊乱问题至少…

在IDEA中创建Maven项目

相关内容&#xff1a; Maven的安装与配置 在IDEA中配置Maven环境 IDEA中导入Maven项目 2023版IDEA创建Maven项目&#xff08;新版&#xff09; 1.打开IDEA&#xff0c;点击 文件 -> 新建 -> 项目 2.创建Maven项目 3.编写java文件并运行 在src -> java -> 创建…

xcode运行ios项目报错Sandbox: rsync.samba(24352) deny(1) file-write-create

xcode运行ios项目报错 Sandbox: rsync.samba(24352) deny(1) file-write-create 解决方案&#xff1a; Update your Xcode project build option ENABLE_USER_SCRIPT_SANDBOXING to No.

谷歌GenType:1分钟生成AI艺术字母表,小众但好用,完全免费!(附教程)

文章首发于公众号&#xff1a;X小鹿AI副业 大家好&#xff0c;我是程序员X小鹿&#xff0c;前互联网大厂程序员&#xff0c;自由职业2年&#xff0c;也一名 AIGC 爱好者&#xff0c;持续分享更多前沿的「AI 工具」和「AI副业玩法」&#xff0c;欢迎一起交流~ 最近发现一个好玩的…

2024最新版Redis常见面试题包含详细讲解

Redis适用于哪些场景&#xff1f; 缓存分布式锁降级限流消息队列延迟消息队 说一说缓存穿透 缓存穿透的概念 用户频繁的发起恶意请求查询缓存中和数据库中都不存在的数据&#xff0c;查询积累到一定量级导致数据库压力过大甚至宕机。 缓存穿透的原因 比如正常情况下用户发…

维护Nginx千字经验总结

Hello , 我是恒 。 维护putty和nginx两个项目好久了&#xff0c;用面向底层的思路去接触 在nginx社区的收获不少&#xff0c;在这里谈谈我的感悟 Nginx的夺冠不是偶然 高速:一方面&#xff0c;在正常情况下&#xff0c;单次请求会得到更快的响应&#xff1b;另一方面&#xff0…

1996-2023年各省财政收支数据(无缺失)(地方财政一般预算收入、地方财政一般预算支出)

1996-2023年各省财政收支数据&#xff08;无缺失&#xff09;&#xff08;地方财政一般预算收入、地方财政一般预算支出&#xff09; 1、时间&#xff1a;1996-2023年 2、来源&#xff1a;国家统计局、统计年鉴、 3、指标&#xff1a;地方财政一般预算收入、地方财政一般预算…

51单片机第23步_定时器1工作在模式0(13位定时器)

重点学习51单片机定时器1工作在模式0的应用。 在51单片机中&#xff0c;定时器1工作在模式0&#xff0c;它和定时器0一样&#xff0c;TL1占低5位&#xff0c;TH1占高8位&#xff0c;合计13位&#xff0c;也是向上计数。 1、定时器1工作在模式0 1)、定时器1工作在模式0的框图…

SUPERVIVE无法联机、联机失败、联机报错的解决办法分享

SUPERVIVE是一款战术竞技游戏&#xff0c;核心玩法为多人大逃杀&#xff0c;40名玩家可以自愿或随机组成2或4人小分队&#xff0c;空降进入末日地图&#xff0c;一边苟着收集资源&#xff0c;一边武装自己&#xff0c;在生存区不断首夺的同时&#xff0c;努力战到最后&#xff…

pycharm中新建的临时python文件存放在哪里?

在pycharm中建立的临时python文件&#xff0c;从哪里可以找到呢&#xff1f; 1.我们打开cmd窗口&#xff0c;进入根目录&#xff0c;用dos命令“dir scratch*.py/a/s”进行查找&#xff0c;发现这些临时文件存放在Roaming\JetBrains\PyCharmCE2022.2\scratches 的目录里面 2.…

2Python的Pandas:读取数据

1.读取Excel文件 1.1.读取数据 import pandas as pd# Excel 文件的 URL 或本地路径 url "https://www.gairuo.com/file/data/dataset/team.xlsx"# 使用 Pandas 的 read_excel 函数读取数据 try:df pd.read_excel(url)print(df.head()) # 打印 DataFrame 的前几行…

在 Mac 上使用 本地 LLM 文本终结

我们可使用本地大型语言模型&#xff0c;如Mistral、Llama等&#xff0c;来给文本做总结&#xff0c;相比在线的 Kimi &#xff0c;ChatGPT&#xff0c; 我们不用担心数据泄露&#xff0c;因为整个操作都是在本地电脑完成的。 我们用 ollama 举例 首先安装 ollama https://ol…

观测云赋能「阿里云飞天企业版」,打造全方位监控观测解决方案

近日&#xff0c;观测云成功通过了「阿里云飞天企业版」的生态集成认证测试&#xff0c;并荣获阿里云颁发的产品生态集成认证证书。作为监控观测领域的领军者&#xff0c;观测云一直专注于提供统一的数据视角&#xff0c;助力用户构建起全球范围内的端到端全链路可观测服务。此…

荣耀大横评,睿蓝7-450荣耀版卷出来的性价比之王

手握11万左右预算,如何在市场内选出一辆合适自己的车?荣耀版车型无疑是当下的最佳答案。在众多荣耀版车型中,比亚迪宋PLUS荣耀版EV520km领先型(后统称宋PLUS荣耀版)、比亚迪元PLUS荣耀版430km领先型(后统称元PLUS荣耀版)、比亚迪海豚PLUS荣耀版420km时尚版(后统称海豚荣耀版)、…

YOLO-V2

一、V2版本细节升级 1、YOLO-V2&#xff1a; 更快&#xff01;更强 1.1 做的改进内容 1. YOLO-V2-Batch Normalization V2版本舍弃Dropout&#xff0c;卷积后每一层全部加入Batch Normalization网络的每一层的输入都做了归一化&#xff0c;收敛相对更容易经过Batch Norma…

Python协作运动机器人刚体力学解耦模型

&#x1f3af;要点 &#x1f3af;腿式或固定式机器人模型 | &#x1f3af;网格、点云和体素网格碰撞检测 | &#x1f3af;正反向运动学和动力学 | &#x1f3af;机器人刚体力学计算 | &#x1f3af;编辑参考系姿势和路径 | &#x1f3af;软件接口实体机器人模拟 | &#x1f3a…

MyBatis-plus这么好用,不允许还有人不会

你好呀&#xff0c;我是 javapub. 做 Java 的同学都会用到的三件套&#xff0c;Spring、SpringMV、MyBatis。但是由于使用起来配置较多&#xff0c;依赖冲突频发。所有&#xff0c;各路大佬又在这上边做了包装&#xff0c;像我们常用的 SpringBoot、MyBatisPlus。 基于当前要…

2024年7月2日 (周二) 叶子游戏新闻

老板键工具来唤去: 它可以为常用程序自定义快捷键&#xff0c;实现一键唤起、一键隐藏的 Windows 工具&#xff0c;并且支持窗口动态绑定快捷键&#xff08;无需设置自动实现&#xff09;。 卸载工具 HiBitUninstaller: Windows上的软件卸载工具 经典名作30周年新篇《恐怖惊魂夜…

谷歌网站SEO服务有哪些?

Seo其实说来说去就包含三样&#xff0c;网站本身技术优化&#xff0c;内容以及外链&#xff0c;而这三样里&#xff0c;网站的技术优化是前提本身&#xff0c;确保网站符合谷歌搜索规范&#xff0c;包括调整网站的结构、速度和移动设备兼容性&#xff0c;以提高用户体验和搜索引…

【C语言】const 关键字

在C语言中&#xff0c;const关键字用于定义常量&#xff0c;使得变量的值在其声明之后无法被修改。这可以帮助防止意外修改数据&#xff0c;提高代码的安全性和可读性。以下是有关const关键字的一些详细说明&#xff1a; 基本用法 const int max_value 100;在这个例子中&…