Megatron-LM源码系列(八): Context Parallel并行

news2024/12/23 9:15:28

1. Context Parallel并行原理介绍

megatron中的context并行(简称CP)与sequence并行(简称SP)不同点在于,SP只针对LayernormDropout输出的activation在sequence维度上进行切分,CP则是对所有的input输入和所有的输出activation在sequence维度上进行切分,可以看成是增强版的SP。除了Attention模块以外,其他的模块(Layernorm、Dropout)由于没有多token的处理,在CP并行时都不用任何修改。

为什么Attention模块是个例外? 因为Attention计算过程中每个token的Q(query)要跟同一个sequence中其他token的K(key)和V(value)一起进行计算,存在计算上的依赖,所以通过CP并行后,在计算Attention前要通过allgather通信拿到所有token的KV向量,在反向计算时对应需要通过reduce_scatter分发gradient梯度。

为了减少显存占用,在前向时每个gpu只用保存一部分KV块,反向时通过allgather通信拿到所有的KV数据。KV的通信发生在相邻TP通信组相同位置的rank之间。allgather和reduce_scatter在ring拓扑架构实现时,底层会通过send和recv来进行实现。
在这里插入图片描述

以上图TP2-CP2的transformer网络为例,在Attention前的是CP的通信算子,其他都是TP的通信算子。AG表示allgather, RS表示reduce_scatter, AG/RS表示前向allgather反向reduce_scatter, RS/AG表示前向reduce_scatter反向allgather。

这里TP2对应为[GPU0, GPU1], [GPU2, GPU3], CP2对应为TP组相同位置的rank号,也就是[GPU0, GPU2], [GPU1, GPU3]。CP并行与Ring Attention类似,但是提供了新的OSS与FlashAttention版本,也去除了low-triangle causal masking的冗余计算。

LLM经常由于sequence长度过长导致显存OOM,这时之前的一种方式是通过重计算的方式保存中间的activation产出,全量重计算的劣势会带来30%的计算代价;另外一种方式是扩大TP(tensor parallel)的大小,扩大TP的劣势在于会对tensor切的更小,从而导致linear fc的计算时间变少,从而与通信很难进行计算的掩盖。

通过CP可以更好解决OOM的问题,每个GPU只用处理一部分的sequence, 同时减少CP倍的通信和计算,但保持TP不变,同时activation也会减少CP倍。CP优化的性能参考如下图,在Megatron中(Megatron-Core>=0.5.0 && Transformer Engine >=1.1)通过指定--context-parallel-size可以进行使用。 t o t a l _ g p u _ c o u n t = T P × C P × P P × D P total\_gpu\_count = TP \times CP \times PP \times DP total_gpu_count=TP×CP×PP×DP

在这里插入图片描述

2. 源码

以Megatron-Core 0.5.0为例进行介绍

  • 首先在megatron/arguments.py中定义了--context-parallel-size参数, 同时也要求了world_size能要整除TP*PP*CP
    group.add_argument('--context-parallel-size', type=int, default=1,
                       help='Degree of context parallelism.')
    ....
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \
        'world size ({}) is not divisible by tensor parallel size ({}) times ' \
        'pipeline parallel size ({}) times context parallel size ({})'.format(
        args.world_size, args.tensor_model_parallel_size,
        args.pipeline_model_parallel_size, args.context_parallel_size)
    args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
  • megatron/core/parallel_state.py中初始化通信组时会初始化相关CP通信组, 以TP-PP-DP-CP=8-1-1-2为例,TP通信组为[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15], CP通信组为[0,8],[1,9],[2,10],[3,11],[4,12],[5,13],[6,14],[7,15]
def initialize_model_parallel(...):
    ...
    for i in range(pipeline_model_parallel_size):
        for j in range(data_parallel_size):
            start_rank = (
                i * num_pipeline_model_parallel_groups
                + j * tensor_model_parallel_size * context_parallel_size
            )
            end_rank = (
                i * num_pipeline_model_parallel_groups
                + (j + 1) * tensor_model_parallel_size * context_parallel_size
            )
            for k in range(tensor_model_parallel_size):
                ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
                group = torch.distributed.new_group(
                    ranks, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
                )
                if rank in ranks:
                    _CONTEXT_PARALLEL_GROUP = group
                    _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
  • megatron/core/transformer/custom_layers/transformer_engine.pyTEDotProductAttention会初始化相关CP通信组相关参数,TEDotProductAttention继承自te.pytorch.DotProductAttention,在前向中直接调用父类的的forward函数。
class TEDotProductAttention(te.pytorch.DotProductAttention):
    def __init__(...):
        ...
        if te_version >= packaging.version.Version("1.0.0"):
            if getattr(TEDotProductAttention, "cp_stream") is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
                check_initialized=False
            )
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
        ...

    def forward(...):
        ...
            core_attn_out = super().forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type.name,
                **packed_seq_kwargs,
            )
        ...
  • Transformer Engine中DotProductAttention定义在transformer_engine/pytorch/attention.py中,CP相关参数通过attn_kwargs进行传入。接下来会调用FlashAttention的前反向。
class DotProductAttention(torch.nn.Module):
    def __init__(...):
        ...
        if self.use_flash_attention:
            self.flash_attention = FlashAttention(norm_factor,
                                                  attention_type=attention_type,
                                                  layer_number=layer_number,
                                                  deterministic=self.deterministic,
                                                  **attn_kwargs)
        ...

class FlashAttention(torch.nn.Module):
    def forward(...):
        ...
        if context_parallel:
            with self.attention_dropout_ctx():
                output = attn_forward_func_with_cp(
                    self.training, query_layer, key_layer, value_layer,
                    cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,
                    self.attention_dropout if self.training else 0.0,
                    cp_group, cp_global_ranks, cp_stream,
                    softmax_scale=1.0/self.norm_factor,
                    qkv_format="bshd" if qkv_format=="sbhd" else qkv_format,
                    attn_mask_type=attn_mask_type,
                    deterministic=self.deterministic
                )
  • 在FlashAttention中会通过函数attn_forward_func_with_cp进行调用,最终Attn前的all_gather通信是在AttnFuncWithCP中通过send、recv通信来实现的, 执行完通信就执行对应的flash_attention算子的调用。
def attn_forward_func_with_cp(...):
    out = AttnFuncWithCP.apply(
        is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format,
        attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention
    )
    return out
    
class AttnFuncWithCP(torch.autograd.Function):
    def forward(...):
        for i in range(cp_size+1):
            if i < cp_size:
                with torch.cuda.stream(flash_attn_streams[i%2]):
                    # wait until KV is received
                    for req in send_recv_reqs[(i+1)%2]:
                        req.wait()

                    if i < (cp_size-1):
                        p2p_comm_buffers[i+1] = torch.empty_like(p2p_comm_buffers[i])
                        send_recv_reqs[i%2] = flash_attn_p2p_communicate(rank,
                                                                         p2p_comm_buffers[i],
                                                                         send_dst,
                                                                         p2p_comm_buffers[i+1],
                                                                         recv_src,
                                                                         cp_group,
                                                                         batch_p2p_comm)
                    ...
                                fused_attn_fwd(
                                    is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q,
                                    cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0],
                                    kv_inputs[i%2][1], TE_DType[q.dtype],
                                    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
                                    attn_scale=softmax_scale, dropout=dropout_p,
                                    qkv_layout=qkv_layout, attn_mask_type="causal",
                                    attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2],
                                )

3. 参考

  • Context parallelism overview

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1705088.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

「小明赠书活动」第四期《Java开发坑点解析:从根因分析到最佳实践》

⭐️ 赠书 - 《Java开发坑点解析&#xff1a;从根因分析到最佳实践》 这是一本帮助Java开发人员规避常见错误的书。从业务代码开发、项目技术设计、代码安全3个层面剖析150多个常见坑点。 业务代码开发层面&#xff0c;近20个方面的坑&#xff0c;涉及多线程、数据访问、池技术…

【VTKExamples::Utilities】第四期 CameraModifiedEvent

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例CameraModifiedEvent,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 1. CameraModifi…

521源码-免费音乐源码-最新流媒体在线音乐系统网站源码| 英文版源码| 音乐社区 | 多语言 | 开心版

免费音乐源码 一键自动安装&#xff1a;安装用翻译看提示操作即可 本源码下载地址&#xff1a;最新流媒体在线音乐系统网站源码| 英文版源码| 音乐社区 | 多语言 | 开心版 - 521源码 更多网站源码学习教程&#xff0c;请点击&#x1f449;-521源码-&#x1f448;获取最新资源…

APM2.8如何供电

APM2.8飞控供电有两种&#xff0c; 1.电流计供电&#xff0c; 2.带BEC&#xff08;稳压功能&#xff09;的电调供电 飞控有一个JP1&#xff0c;它是一个供电选择接口&#xff0c;当插入跳线帽时&#xff0c;飞控用带BEC电调供电&#xff0c;当不插入时&#xff0c;用电流计供…

基于springboot的论坛管理系统(含源码+sql+视频导入教程)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于springboot的论坛管理系统3拥有两种角色 管理员&#xff1a;用户管理、公告管理、帖子管理、分类管理、留言管理、系统管理等 用户&#xff1a;登录注册、查看发布帖子等 1.1 背景…

深度学习论文: YOLOv10: Real-Time End-to-End Object Detection

深度学习论文: YOLOv10: Real-Time End-to-End Object Detection YOLOv10: Real-Time End-to-End Object Detection PDF: https://arxiv.org/pdf/2405.14458 PyTorch代码: https://github.com/shanglianlm0525/CvPytorch PyTorch代码: https://github.com/shanglianlm0525/PyTo…

如何评价 OpenAI 最新发布支持实时语音对话的模型GPT-4o?OpenAI发完GTP-4o,国内大模型行业还有哪些机会?

文章目录 OpenAI发完GTP-4o&#xff0c;国内大模型行业还有哪些机会&#xff1f;详细了解一下OpenAI最新发布的支持实时语音对话的模型GPT-4o国内大模型如何寻找发展机会&#xff1f;想要发展技术必须要创新与追赶或许应用场景拓展也是一种出路产业生态构建 ChatGPT 问世才 17 …

隆道专属商城 | 助力企业跨平台整合优势资源,解决采购寻源比价难题!

数字化采购时代&#xff0c;企业面临着日益激烈的市场竞争&#xff0c;如何优化资源配置、降低采购成本、提高采购效率成为企业追求的核心目标。当前&#xff0c;网上商城凭借其强大的供应链资源整合能力&#xff0c;为企业内部采购商城的搭建提供了独特的优势&#xff0c;已然…

【Lexus.4】Executive Sedan——Dismantling Follow-up

文章目录 【碰撞测试】前后防撞钢梁偏置碰撞A/B/C柱&#xff0c;边梁抗拉、屈服强度 【底盘】平整度护板&#xff08;发动机&#xff0c;底盘&#xff09;前副车架结构前悬架形式后悬架形式与材质簧下质量 【发动机】【轮上马力】【零部件供应商】 来自2021《懂车大爆炸》——是…

网络风暴:揭秘DDoS攻击的幕后黑手

在数字化时代的浪潮中&#xff0c;网络攻击已成为一种新型的战争手段。其中&#xff0c;分布式拒绝服务攻击&#xff08;DDoS&#xff09;以其强大的破坏力和隐蔽性&#xff0c;成为网络安全领域的一大挑战。DDoS攻击通过发动海量的恶意流量&#xff0c;如同狂风暴雨般席卷目标…

Springboot项目——博客平台

前言&#xff1a;为巩固之前学习的知识&#xff0c;同时锻炼自己的代码能力&#xff0c;项目经验&#xff0c;熟悉前后端交互方式等&#xff0c;特此完成一个博客平台系统。&#xff08;总之&#xff0c;为了学习&#xff0c;为了进步&#xff09; 博客平台&#xff1a;本项目…

干货|图生代码实例整理,让你的代码更高效

前言 “图生代码”。这项新功能允许开发人员直接利用产品设计图一键生成相应的代码&#xff0c;极大地提高了编程效率和研发速度。甚至会未来软件开发可能迎来一场革命性的变革。但图生代码究竟能直到什么程度&#xff1f;本文结合一款图生代码的实例程序整理了一些有代表意义…

如何在 DigitalOcean Droplet 云主机上创建 Ubuntu 服务器

在本文中&#xff0c;你将通过 DigitalOcean 的管理面板创建一个 Ubuntu 服务器&#xff0c;并将其配置为使用你的 SSH 密钥。设置好服务器后&#xff0c;你可以在其上部署应用程序和网站。 本教程是DigitalOcean云课程简介的一部分&#xff0c;它指导用户完成将应用程序安全地…

期望薪资30k字节java2面,A给B转账的同时B给A转账怎么并发量最高

一面 1、自我介绍 2、详细介绍一下自己的做的项目&#xff1f;根据项目提了一些问题 3、hashmap原理 4、B树原理&#xff1f; 5、final禁止重排序原理&#xff1f; 6、设计一个榨汁机类&#xff0c;面向对象怎么设计&#xff1f; 7、get、post区别&#xff0c;使用场景&…

mysql实战——mysql5.7保姆级安装教程

1、上传 上传5.7压缩包到/usr/local目录下 2、解压 cd /usr/local tar -zxvf mysql--5.7.38-linux-glibc2.12-x86_64.tar.gz mv mysql-5.7.38-linux-glibc2.12-x86_64/ mysql 3、创建mysql用户组和用户 groupadd mysql useradd -g mysql mysql 4、创建数据目录data&#xf…

OneForall工具的下载安装和使用(Windows和Linux)

目录 OneForall的介绍 OneForall的下载 OneForall的安装 安装要求 安装步骤&#xff08;git 版&#xff09; 安装&#xff08;kali&#xff09; OneForall的使用命令 在Windows 在Linux&#xff08;kali&#xff09; OneForall的结果说明 免责声明 本文所提供的文字和…

基于Java的高校学生勤工助学优派系统的设计与实现(论文+源码)_kaic

摘 要 高校勤工助学管理系统的出现&#xff0c;让学生的工作更加标准,不仅仅使高校办公室的办公水平以及管理水平大大提高&#xff0c;还优化了勤工助学资金的使用方式方法&#xff0c;完善了资助所需费用的资源配置&#xff0c;可以卓有成效地缩减学校的管理经费。本系统主…

智能SQL代码生成器,开发者的得力助手

&#x1f3e1; 博客首页&#xff1a;IT 派同学 ⛳️ 欢迎关注 &#x1f433; 点赞 &#x1f392; 收藏 ✏️ 留言 &#x1f3a2; 本文由 IT 派同学原创编撰 &#x1f6a7; 系列专栏&#xff1a;《开源专栏》 &#x1f388; 本系列主要输出作者自创的开源项目 &#x1f517; 作品…

B端产品C端化设计,趋势不可挡呀。

一、B端产品和C端产品设计的不同 在设计上&#xff0c;B端&#xff08;Business-to-Business&#xff09;和C端&#xff08;Consumer&#xff09;之间存在一些区别。 用户群体&#xff1a;B端产品的用户是企业或组织&#xff0c;而C端产品的用户是普通消费者。B端产品的用户通…

面向对象编程的魅力与实战:以坦克飞机大战为例

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、面向对象编程的引言 二、理解面向对象编程与面向过程编程的差异 三、创建类与对象&…