pytorch 演示 tensor并行

news2024/11/15 14:10:17

pytorch 演示 tensor并行

  • 一.原理
  • 二.实现代码

本文演示了tensor并行的原理。如何将二个mlp切分到多张GPU上分别计算自己的分块,最后做一次reduce。
1.为了避免中间数据产生集合通信,A矩阵只能列切分,只计算全部batch*seqlen的部分feature
2.因为上面的步骤每张GPU只有部分feature,只因B矩阵按行切分,可与之进行矩阵乘,生成部分和
3.最后把每张GPU上的部分和加起来,就是最张的结果
以下demo,先实现了非分块的模型,然后模拟nccl分块,最后是分布式的实现

一.原理

在这里插入图片描述

二.实现代码

# torch_tp_demo.py
import os
import torch
from torch import nn
import torch.nn.functional as F 
import numpy as np
import torch.distributed as dist
from torch.distributed import ReduceOp
  
import time
import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument('--hidden_size', default=512, type=int, help='')
parser.add_argument('--ffn_size', default=1024, type=int, help='')
parser.add_argument('--seq_len', default=512, type=int, help='')
parser.add_argument('--batch_size', default=8, type=int, help='')
parser.add_argument('--world_size', default=4, type=int, help='')
parser.add_argument('--device', default="cuda", type=str, help='')

class FeedForward(nn.Module): 

    def __init__(self,hidden_size,ffn_size): 
        super(FeedForward, self).__init__() 
        self.fc1 = nn.Linear(hidden_size, ffn_size,bias=False)
        self.fc2 = nn.Linear(ffn_size, hidden_size,bias=False)

    def forward(self, input): 
        return self.fc2(self.fc1(input))

class FeedForwardTp(nn.Module):

    def __init__(self,hidden_size,ffn_size,tp_size,rank): 
        super(FeedForwardTp, self).__init__() 
        self.fc1 = nn.Linear(hidden_size, ffn_size//tp_size,bias=False)
        self.fc2 = nn.Linear(ffn_size//tp_size, hidden_size,bias=False)
      
        self.fc1.weight.data=torch.from_numpy(np.fromfile(f"fc1_{rank}.bin",dtype=np.float32)).reshape(self.fc1.weight.data.shape)
        self.fc2.weight.data=torch.from_numpy(np.fromfile(f"fc2_{rank}.bin",dtype=np.float32)).reshape(self.fc2.weight.data.shape)

    def forward(self, input): 
        return self.fc2(self.fc1(input))


args = parser.parse_args()
hidden_size = args.hidden_size
ffn_size = args.ffn_size
seq_len = args.seq_len
batch_size = args.batch_size
world_size = args.world_size
device = args.device

def native_mode():
  print(args)
  torch.random.manual_seed(1)
  model = FeedForward(hidden_size,ffn_size)
  model.eval()
  input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)

  for idx,chunk in enumerate(torch.split(model.fc1.weight, ffn_size//world_size, dim=0)):
      chunk.data.numpy().tofile(f"fc1_{idx}.bin")
  
  for idx,chunk in enumerate(torch.split(model.fc2.weight, ffn_size//world_size, dim=1)):
      chunk.data.numpy().tofile(f"fc2_{idx}.bin")
  
  model=model.half().to(device)
  

  usetime=[]
  for i in range(32):
    t0=time.time()    
    out = model(input)
    torch.cuda.synchronize()
    t1=time.time()
    if i>3:
      usetime.append(t1-t0)
  
  print("[INFO] native: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))
  

  result=[]
  for rank in range(world_size):
      model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)
      model.eval()
      out=model(input)
      torch.cuda.synchronize()
      result.append(out)
  
  sum_all=result[0]
  for t in result[1:]:
      sum_all=sum_all+t
  
  print("[INFO] tp_simulate: shape:{},sum:{:.5f}".format(sum_all.shape,sum_all.sum().item()))

def tp_mode():
  torch.random.manual_seed(1)
  dist.init_process_group(backend='nccl')
    
  world_size = torch.distributed.get_world_size()
  rank=rank = torch.distributed.get_rank()
  local_rank=int(os.environ['LOCAL_RANK'])
  
  torch.cuda.set_device(local_rank)
  device = torch.device("cuda",local_rank)
  
  input = torch.rand((batch_size, seq_len, hidden_size),dtype=torch.float32).half().to(device)  
  model = FeedForwardTp(hidden_size,ffn_size,world_size,rank).half().to(device)
  model.eval()
  if rank==0:
    print(args)
    
  usetime=[]
  for i in range(32):        
    dist.barrier()
    t0=time.time()
    out=model(input)
    #dist.reduce(out,0, op=ReduceOp.SUM) 
    dist.all_reduce(out,op=ReduceOp.SUM)
    torch.cuda.synchronize()
    if rank==0:
      t1=time.time()
      if i>3:
        usetime.append(t1-t0)
  
  if rank==0:
    print("[INFO] tp: shape:{},sum:{:.5f},mean:{:.5f},min:{:.5f},max:{:.5f}".format(out.shape,out.sum().item(),np.mean(usetime),np.min(usetime),np.max(usetime)))


if __name__ == "__main__":
  num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
  is_distributed = num_gpus > 1
  if is_distributed:
    tp_mode()
  else:
    native_mode()

运行命令:

python3 torch_tp_demo.py --hidden_size 512 \
			--ffn_size 4096 --seq_len 512 \
			--batch_size 8 --world_size 4 --device "cuda"
torchrun -m --nnodes=1 --nproc_per_node=4 \
			torch_tp_demo --hidden_size 512 \
			--ffn_size 4096 --seq_len 512 \
			--batch_size 8 --world_size 4 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1575384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leetcode刷题-哈希表详细总结(Java)

哈希表 当我们想使⽤哈希法来解决问题的时候,我们⼀般会选择如下三种数据结构。 数组set (集合)map(映射) 当我们遇到了要快速判断⼀个元素是否出现集合⾥的时候,就要考虑哈希法。如果在做⾯试题⽬的时候…

搭建好WordPress网站后的基本操作流程

考虑到很多朋友是第一次使用WordPress,这里给大家分享一下基本的WordPress操作流程,你可以跟着实际情况决定操作步骤。 1.设置网站SSL安全证书。 我采用的是Hostease家的Linux主机产品,自带免费SSL证书 。支持一键安装wordpress程序。 2.进…

智慧驿站式的“智慧公厕”,给城市新基建带来新变化

随着智慧城市建设的推进,智慧驿站作为一种多功能城市部件,正逐渐在城市中崭露头角。这些智慧驿站集合了智慧公厕的管理功能,为城市的新基建带来了全新的变革。本文以智慧驿站智慧公厕源头实力厂家广州中期科技有限公司,大量精品案…

MyBatis操作数据库(1)

前言 在应用分层的学习时, 我们了解到web应用程序一般分为三层,即Controller, Service, Dao. 之前的案例中, 请求流程如下: 浏览器发起请求, 先请求Controller, Controller接受到请求后,调用Service进行业务逻辑处理, Service再调用Dao, 但是Dao层的数据是Mock的, 真实的数据…

基于 Vue3 + Webpack5 + Element Plus Table 二次构建表格组件

基于 Vue3 Webpack5 Element Plus Table 二次构建表格组件 文章目录 基于 Vue3 Webpack5 Element Plus Table 二次构建表格组件一、组件特点二、安装三、快速启动四、单元格渲染配置说明五、源码下载地址 基于 Vue3 Webpack5 Element Plus Table 二次构建表格组件&#x…

【白菜基础】蛋白组学之生信分析(1)

刚换了一个新课题组,新老板的研究方向为蛋白组学,从未接触过蛋白组学的我准备找一组模拟数据进行生信分析的入门学习。 蛋白组学数据挖掘流程图,参考公众号:蛋白质组学数据挖掘思路解析 (qq.com) 一、认识数据 我们组的数据主要…

【高校科研前沿】中国科学院南京地理与湖泊研究所肖启涛博士为一作在Sci. Bull发文:我国湖泊二氧化碳从大气的源向汇转变

目录 1.文章简介 2.研究内容 3.文章引用 1.文章简介 论文名称:Lakes shifted from a carbon dioxide source to a sink over past two decades in China 第一作者及通讯作者:肖启涛(博士生),段洪涛(研究…

【已解决】HalconDotNet.HOperatorException:“HALCON error #1201: Wrong type of control

前言 最近在学习Halcon视觉和C#的联合开发,碰到一个比较有意思的问题记录一下,大致的报错信息是说我用的halcondotnet版本和我在halcon导出的使用的halcondotnet.dll版本不一致,所以才报错的! 解决 首先你得找到你安装halcon的…

接口自动化入门:Jmeter的多组数据测试、JDBC驱动及数据断言!

在进行接口测试时,我们经常需要对接口进行多组数据测试,以验证接口在不同输入条件下的表现。同时,我们也需要对接口返回的数据进行断言,以确保接口返回的数据符合预期结果。JMeter正是一个强大的工具,可以帮助我们实现…

【Linux】正则表达式实验操作实例

正则表达式是一种强大的工具,用于在文本中查找、匹配和替换特定的字符串模式。 实验目的 掌握正则表达式的表达方式掌握grep/egrep命令的用法掌握sed 命令的用法掌握awk命令的用法 正则表达式 实验目的实验内容实验过程创建grep文件来进行如下操作用sed命令完成下列…

寻找排序数组中的最小值

题目描述 已知一个长度为 n 的数组,预先按照升序排列,经由 1 到 n 次 旋转 后,得到输入数组。例如,原数组 nums [0,1,2,4,5,6,7] 在变化后可能得到: 若旋转 4 次,则可以得到 [4,5,6,7,0,1,2]若旋转 7 次…

如何水出第一篇SCI:SCI发刊历程,从0到1全过程经验分享!!!

如何水出第一篇SCI:SCI发刊历程,从0到1全路程经验分享!!! 详细的改进教程以及源码,戳这!戳这!!戳这!!!B站:Ai学术叫叫兽e…

机器学习(30)

文章目录 摘要一、文献阅读1. 题目2. abstract3. 网络架构3.1 Sequence Generative Adversarial Nets3.2 SeqGAN via Policy Gradient3.3 The Generative Model for Sequences3.4 The Discriminative Model for Sequences(CNN) 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过…

UWB 雷达动目标检测

1. 静态载波滤除 1. 首先对所有接收脉冲求平均得出参考接收脉冲 [Cir数据为二维数组64*n, 其中n为慢时间域采样的数据帧数] 2. 接着利用每一束接收脉冲减去参考接收脉冲就可以得到目标回波信号,参考接收脉冲的表达式为 2. RD 谱 对雷达回波做静态载波滤…

Linux:IO多路转接之epoll

文章目录 epoll历史epoll的接口epoll_createepoll_waitepoll_ctl epoll原理代码实验 前面的内容介绍了select多路转接,也分析了其利弊,后面用poll改良了select,解决了部分的缺点,但是对于一些核心的缺点还是不能保证,比…

Langchain教程 | langchain+OpenAI+PostgreSQL(PGVector) 实现全链路教程,简单易懂入门

前提: 在阅读本文前,建议要有一定的langchain基础,以及langchain中document loader和text spliter有相关的认知,不然会比较难理解文本内容。 如果是没有任何基础的同学建议看下这个专栏:人工智能 | 大模型 | 实战与教程…

品牌定位升级|飞雕开关如何从家庭作坊走上国际之路?

飞雕电器,这个名字在中国开关插座行业中如同一面旗帜,自1987年起就扬帆在电工领域的大海中。它不仅见证了这个行业的起起伏伏,还始终以其创新的姿态站在浪尖之上。 飞雕的产品线丰富多彩,除主营的墙壁开关插座领域外,飞雕电器还涉足了与墙壁开关紧密相关的其它领域,现已推出移…

vmware 中的Ubuntu系统虚拟机忘记root密码强制重置操作

忘记密码情况下,vmware虚拟机重置Ubuntu的root密码 在企业使用的vmware ESXI中重置Ubuntu系统root密码 1-本地电脑安装个人版的vmware workstation,目的:vmware ESXI自带的远程控制台无法输入指定的键盘按键,需要借助外部的远程辅…

Ceph学习 -3.存储简介

文章目录 1.存储简介1.1 存储类型1.1.1 储备知识1.1.2 三种存储1.1.3 块存储1.1.4 文件存储1.1.5 对象存储1.1.6 三种存储之间的关系1.1.7 总结 1.2 Ceph简介1.2.1 官方介绍1.2.2 软件特点1.2.3 基本结构1.2.4 应用场景 1.3 小结 1.存储简介 学习目标:这一节&#x…

免疫检查点信号转导和癌症免疫治疗(文献)

目录 基础 介绍 免疫检查点的表面调控(细胞膜层面) ​编辑 PD-1调节 PD-L1调节 CTLA-4 调节 检查点信号通路 关于靶点研究 展望 Immune checkpoint signaling and cancer immunotherapy - PubMed (nih.gov) 基础 【中英字幕】肿瘤免疫疗法之免…