AI Infra论文阅读之《在LLM训练中减少激活值内存》

news2024/11/15 23:46:19

写了一个Megatron-LM的3D Parallel进程组可视化的Playground,界面长下面这样:

在这里插入图片描述

可以直接访问:https://huggingface.co/spaces/BBuf/megatron-lm-parallel-group-playground

脚本也开源在:https://github.com/BBuf/megatron-lm-parallel-group-playground 请随意获取和修改。

0x0. 前言

本次阅读一篇Megatron-LM的README贴出的一篇paper,是NVIDIA在2022年上传的,讲的是Megatron-LM里面的Sequence Parallel和Selective Activation Recomputation如何降低大语言模型训练中的激活内存。这里来看一下理论讲解,另外Paper的第4节的激活内存公式估计比较有用。paper链接为:https://arxiv.org/pdf/2205.05198.pdf 。序列并行目前是非常常用的,但是选择性激活重计算可能用得人不多,我想一个很重要的原因应该是 FlashAttention 的出现大大降低了激活值内存大小。但这篇Paper的一些公式仍然是有用处的,至少给我们提供了理论依据。此外,Meagtron-LM推出的Context Parallel是在Sequence Parallel的基础上更近一步,可以继续降低激活值的大小,并且支持更长的序列。

摘要就是说通过Sequence Parallel和Selective Activation Recomputation可以减少激活重计算,把Sequece Parallel和Tensor Parallel结合在一起基本可以避免激活重计算的需要。然后在高达一万亿参数规模的语言模型上评估了上面的两个方法,并展示了这里的放大把激活内存降低了5倍,同时把激活重计算的执行时间开销减少了超过90%。例如,在2240 NVIDIA A100 GPUs上训练一个530B参数的GPT-3风格模型时,实现了54.2%的MFU,比使用重计算实现的42.1%快了29%。

0x1. 介绍

这里简单总结一下,一般来说Meagtron-LM里面张量并行都是放在同一个GPU的节点里面,节点内部由NVLink连接。然后,流水线并行虽然可以减少存储模型参数和优化器状态的内存,但是由于要存储一些Micro Batch的激活,所以并不能减少激活需要的内存。因此,激活内存的存储成为了训练大语言模型的一个关键问题。图1显示了从220亿参数到1万亿参数的四种模型配置所需的内存(模型配置的详细信息在表3中提供)。

在这里插入图片描述
在这里插入图片描述
这里的present work就是通过激活重计算(也叫Gradient Checkpointing)来减轻Activation的存储大小。之前标准的做法是在每一个Transformer层的边界进行重计算,paper也把这种方法叫作完全激活重计算。但完全激活重计算会导致增加30-40%左右的时间开销。为了节省这部分计算开销,但又要Scale Up模型,所以就引入了Paper介绍的两种方法Sequence Parallel和Selective Activation Recomputation。

0x2. 相关工作rt

对Megatron-LM的张量并行进行了简单的介绍,没什么干货,忽略。

0x3. Transformer的结构

如下图的图2所示:输入token被送入一个大小为 v × h v\times h v×h的词嵌入表中,token嵌入与学习到的位置嵌入(大小为 s × h s\times h s×h)结合,其中 s s s是序列长度, h h h是隐藏维度, v v v是词表大小。嵌入层的输出,即Transformer块的输入,是一个大小为 s × b × h s\times b\times h s×b×h的3维张量,其中 b b b是微批量大小。每个Transformer层由一个自注意力块组成,该块有 a a a个注意力头,接着是一个增加隐藏大小到 4 h 4h 4h然后再减少回 h h h的两层多层感知器(MLP)。每个Transformer层的输入和输出大小相同,为 s × b × h s×b×h s×b×h。最后一个Transformer层的输出被投影回词汇维度以计算交叉熵损失。paper假设词嵌入和输出层权重是共享的。变量名在表1中列出以供参考。

在这里插入图片描述
在这里插入图片描述

0x4. Activation Memory

首先,Paper导出了一个近似的公式来估计激活内存的大小,这里的激活指的是在Forward过程中创建并且在Backward中用于梯度计算所必需的任何张量。然后,这里只考虑对激活内存贡献最大的部分不考虑小的Buffer,比如对于LayerNorm来说,输入包含bsh个元素,但均值和方差每个只有sb个元素,由于h一般很大,所以bsh远远大于2sb,所以就忽略掉2sb,以sbh来计算LayerNorm层的激活大小。

0x4.1 每个Transformer层的Activation估计

注意Paper发表的时候还没有FlashAttention

如上图2所示,每个Transformer层由一个注意力层和一个MLP块组成,两者通过两个LayerNorm链接。下面,paper到处来存储每个元素激活所需的内存:

  • Attention块。包括自注意力机制后跟一个线性投影和一个注意力dropout。线性投影存储其输入激活,大小为2sbh,而注意力dropout需要一个大小为sbh的掩码。如图3所示的自注意力包含几个元素:
    • 查询(Q)、键(K)和值(V)矩阵乘法:我们只需要存储它们共享的输入,大小为 2 s b h 2sbh 2sbh
      QKT矩阵乘法:它需要存储Q和K,总大小为 4 s b h 4sbh 4sbh
    • Softmax:反向传播需要大小为 2 a s 2 b 2as^2b 2as2b的Softmax输出。
    • Softmax dropout:只需要一个大小为 a s 2 b as^2b as2b的掩码。
    • 对值(V)的注意力:我们需要存储dropout输出( 2 a s 2 b 2as^2b 2as2b)和值( 2 s b h 2sbh 2sbh),因此需要 2 a s 2 b + 2 s b h 2as^2b + 2sbh 2as2b+2sbh的存储空间。
      将上述值相加,总的来说,注意力块需要 11 s b h + 5 a s 2 b 11sbh + 5as^2b 11sbh+5as2b字节的存储空间。

在这里插入图片描述

  • MLP块。两个线性层存储它们的输入,大小为 2 s b h 2sbh 2sbh 8 s b h 8sbh 8sbh。GeLU非线性也需要其大小为 8 s b h 8sbh 8sbh的输入以进行反向传播。最后,dropout存储其掩码,大小为 s b h sbh sbh。总的来说,MLP块需要 19 s b h 19sbh 19sbh字节的存储空间。
  • LayerNorm。每个LayerNorm存储其输入,大小为 2 s b h 2sbh 2sbh,因此总共我们将需要 4 s b h 4sbh 4sbh的存储空间。

将注意力、MLP和层LayerNorm所需的内存相加,存储单层Transformer网络激活所需的内存是:

在这里插入图片描述

这是在没有应用模型并行时的计算公式,也就是单张卡需要的激活内存计算大小。

0x4.2 模型并行

这一节量化了张量并行对每个Transformer层的激活内存的影响。然后引入了序列并行的新方法,进一步减少了每一层的激活所需内存。最后还讨论了Pipline并行对激活内存的影响,并推导了激活内存的理论公式。

0x4.2.1 张量并行

指的就是Megatron-LM的张量并行,如下图所示:

然后应用了张量并行之后上面的公式就变成:

在这里插入图片描述

这里的10分别表示两个LayerNorm的输入,以及SelfAttention和MLP模块的输入以及输出部分Dropout所需要的激活内存。

0x4.2.2 序列并行

Megatron-LM序列并行的原理就是下面这张图,对比图4来看我们可以发现在非Tensor Parallel的部分使用了Sequence Parallel,同时通信原语也发生了变化:

在这里插入图片描述

在Figure4中,由于LayerNorm和Dropout必须接收完整的数据,对于一个Transformer Layer来说前向和后向都分别有2次all-reduce。而在序列并行中,前后向的2次allreduce分别被拆成了allgather+reduce-scatter,总的通信量没发生变化。paper在这一节对此有一个证明,这里就忽略了,直接给出同时使用序列并行和Tensor并行下的激活内存计算公式:

在这里插入图片描述

和单纯的张量并行相比,现在两个LayerNorm的输入,以及SelfAttention和MLP模块的输入以及输出部分Dropout所需要的激活内存都减少了 t t t倍,因为按照序列的维度进行了切分。

0x4.2.3 Pipline并行

GPipe->1F1B

Pipline并行可以读我之前写的这篇paper解读:AI Infra论文阅读之将流水线并行气泡几乎降到零(附基于Meagtron-LM的ZB-H1开源代码实现解读)。在这篇文章里面提到过对于GPipe来说流水线中最长驻留了 m m m 个未完成的 micro batch(上半部分图). 而 1F1B 则限制其最多驻留流水线深度 p p p 个未完成的 micro batch,如此形成了上图中的下半部分的流水线。这个流水线的特点是一个迭代的时间没有变化,但是 p ≪ m p \ll m pm ,所以驻留的未完成的 micro batch极大减少,减少了显存峰值。(重点是减少了显存的峰值,但是气泡还是不变)。这也是下图为什么估计第一个Stage的激活内存时分子乘以了L的原因,而和micro bacth的大小无关。

在这里插入图片描述

对于VPP来说,公式有一些变化,第一个Stage的显存会增加。

0x4.3 总的激活内存

上面的公式5没有考虑输入嵌入,最后一层的LayerNorm以及如图2所示的输出层所需的激活内存。位置和词嵌入在反向传播中不需要存储任何大量的激活内存。但Dropout操作需要激活内存。嵌入层中的Dropout也沿序列维度并行化。因此,它将需要 s b h p / t sbhp/t sbhp/t的存储空间。这里的p是Pipline并行维度,以及我们需要存储 p p p个micro batch的事实。
输出层之前的LayerNorm也使用序列并行,因此需要 2 s b h / t 2sbh/t 2sbh/t的存储空间。输出层投影到词汇维度需要 2 s b h / t 2sbh/t 2sbh/t的存储空间。最后,交叉熵损失需要存储以32位浮点数计算的对数值,因此将需要 4 s b v / t 4sbv/t 4sbv/t的存储空间。总共 4 s b h / t ( 1 + v / h ) 4sbh/t(1 + v/h) 4sbh/t(1+v/h),仅在没有Pipline并行的情况下包括( p = 1 p = 1 p=1)。
加上上述内存,由输入嵌入、最后一层LayerNorm和输出层引起的额外激活内存公式是:

s b h L t ( p L + δ p = 1 4 L ( 1 + v h ) ) \frac{sbhL}{t} \left( \frac{p}{L} + \delta_{p=1} \frac{4}{L} \left(1 + \frac{v}{h}\right) \right) tsbhL(Lp+δp=1L4(1+hv))

其中, δ p = 1 \delta_{p=1} δp=1在p=1时为1,否则为0。实际上这里的额外激活相比于公式5来说就太小了,例如对于22B的模型来说,额外激活的占比只有0.01%,所以一般直接用公式5估计激活内存就比较准确了。

0x5. 选择性的激活重计算

这一节翻译一下原文。

公式5得出的所需总激活内存对于大型模型来说仍然可能相当大。通过存储(或“checkpointing”)一组层的输入激活并在反向传播期间使用额外的前向pass重计算其它所需激活,激活重计算[5]克服了这一内存限制(这在本文中被称为完全激活重计算)。假设组只包含单个层,并忽略Transformer层外的激活,这种方法将激活所需的总内存减少到2sbhL。我们注意到,如果我们只在每个张量并行等级中存储部分激活,则这个所需内存可以进一步减少到2sbhL/t。然而,这种方法需要每层额外进行一次全收集操作,并将增加通信开销,因此,我们不考虑这种方法。

与存储所有激活(公式5)相比,对所有Transformer层进行checkpointing显著减少了训练模型所需的内存量。这种减少确实以重新计算(一个额外的前向pass)的成本为代价,可能引入高达30-40%的计算时间开销。为了平衡内存节省和计算开销,理想情况下应该只checkpointing足够的激活,以允许给定的模型并行配置在设备内存的限制下进行训练。序列并行性提供的内存节省使得许多更多的配置能够在无需重计算的情况下进行训练,但大型模型的最佳模型并行配置通常仍需要保存和重计算一些激活。选择存储与重计算激活数量的一个简单方法是只对一些Transformer层进行检查点,并存储其它层的所有激活。这种方法对大型模型的扩展性不是很好;例如,在训练MT-NLG时,每个设备只有三层,限制了你在内存与计算之间平衡的粒度。此外,我们注意到,并非所有激活都需要相同数量的操作来重新计算,因此,更加明智地选择哪些激活要存储和哪些需要重计算是有益的。

我们提出的不是对整个Transformer层进行checkpointing和重新计算,而是只对每个Transformer层中占用大量内存但重计算计算成本不高的部分进行checkpointing和重计算,或称为选择性激活重计算 。为此,我们注意到,公式5中的 5 a s / h 5as/h 5as/h项是由于网络宽度通过计算Q、K和V值的线性层增加后的注意力操作所致;即, Q K T QK^T QKT矩阵乘法、softmax、softmax dropout和对V的注意力操作。这些操作通常具有大的输入大小,因此激活量大,然而,每个输入元素的浮点操作数(FLOPs)非常低。Transformer层的其余部分占据了公式5中的 34 34 34项。因此,对于大型模型,其中 5 a s / h > 34 5as/h > 34 5as/h>34,如果我们checkpointing并重新计算Transformer层的这一部分,我们存储的激活几乎可以少一半,并且重计算那些未存储的激活只有一个相对不高的成本。

使用这种形式的选择性激活重计算,存储激活所需的内存从公式5减少到:

在这里插入图片描述

上述公式展示了,使用选择性激活重计算允许所需的激活内存与序列长度线性比例增长,并且独立于注意力头的数量。正如第4.2.3节中讨论的,在使用VPP Schedule的情况下,上述公式需要乘以 1 + p − 1 p m 1 + \frac{p-1}{pm} 1+pmp1

在使用Pipline并行时,如第4.2.3节讨论的,尽管给定设备只有 L / p L/p L/p层,但第一个Stage仍必须存储相当于L层激活的量,因为它必须为 p p p个micro batch存储Activation来流水。在这种情况下,可以采用的另一种技术是尽可能根据可用设备内存存储尽可能多的micro-batch的所有激活,并对其余部分进行完全或选择性重计算。实践中我们发现,应用序列并行和选择性激活重计算后,重计算开销足够小,以至于这种额外技术提供的改进非常有限。这种技术在附录C中有更详细的描述和分析。

简而言之,通过选择性激活重计算,可以有效减少存储激活所需的内存,使其与序列长度线性相关,而与注意力头数量无关。尤其在使用管道并行性时,采用额外技术进一步降低重计算成本是可能的,但在实际应用中,序列并行性和选择性激活重计算已经能够显著降低重计算开销,使得额外技术的效果较为有限。

这一节的Table2值得注意一下,是对上面各种并行和重计算方式的中间激活内存的计算公式。

在这里插入图片描述

0x6. 实验部分

Table3展示了进行试验的几个模型的尺寸大小和超参数。

在这里插入图片描述

然后实验部分看下几个图和表格就可以了。

在这里插入图片描述

这张图是实测了下相比于单纯的模型并行,Sequence Parallel,Selective Recompute,Full Compute等能节省的显存比例,可以看到序列并行和选择性重计算很有作用。

在这里插入图片描述

Table4展示了序列并行和选择性重计算分别对前后向时间的影响,是在22B的模型上实验的,可以看到序列并行和选择性重计算同时作用的情况下也只增加了4%的overhead。

在这里插入图片描述

这张图的结论就是序列并行和选择性重计算相比于完全重计算来说增加的算力开销非常少。

在这里插入图片描述

通过序列并行和选择性重计算可以提升各个尺寸大模型的吞吐和MFU。

0x7. 结论

序列并行目前是非常常用的,但是选择性激活重计算可能用得人不多,我想一个很重要的原因应该是FlashAttention的出现大大降低了激活值内存大小。但这篇Paper的一些公式仍然是有用处的,至少给我们提供了理论依据。此外,Meagtron-LM推出的Context Parallel是在Sequence Parallel的基础上更近一步,可以继续降低激活值的大小,并且支持更长的序列。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1542549.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux部署seata-2.x整合SpringCloud使用(Nacos实现配置与注册中心)

😊 作者: 一恍过去 💖 主页: https://blog.csdn.net/zhuocailing3390 🎊 社区: Java技术栈交流 🎉 主题: Linux部署seata-2.x整合SpringCloud使用(Nacos实现配置与注册中心) ⏱️…

Request请求参数----中文乱码问题

一: GET POST获取请求参数: 在处理为什么会出现中文乱码的情况之前, 首先我们要直到GET 以及 POST两种获取请求参数的不同 1>POST POST获取请求参数是通过输入流getReader来进行获取的, 通过字符输入流来获取响应的请求参数, 并且在解码的时候, 默认的情况是 ISO_885…

Elasticsearch:虚拟形象辅助和对话驱动的语音到 RAG 搜索

作者:来自 Elastic Sunile Manjee 搜索的演变 搜索已经从产生简单结果的简单文本查询发展成为容纳文本、图像、视频和问题等各种格式的复杂系统。 如今的搜索结果通过生成式人工智能、机器学习和交互式聊天功能得到增强,提供更丰富、更动态且与上下文相…

一张表看懂阿里云服务器优惠价格表(CPU内存价格+带宽费用+磁盘价格)

2024年腾讯云服务器优惠价格表,一张表整理阿里云服务器最新报价,阿里云服务器网整理云服务器ECS和轻量应用服务器详细CPU内存、公网带宽和系统盘详细配置报价单,大家也可以直接移步到阿里云CLUB中心查看 aliyun.club 当前最新的云服务器优惠券…

【Redis】Redis特性

Redis 认识redisRedis特性在内存中存储数据可编程可扩展性持久化Clustering高可用性 认识redis Redis,英文全称是Remote Dictionary Server(远程字典服务),是一个开源的使用ANSIC语言编写、支持网络、可基于内存亦可持久化的日志…

Window全网解析网站下载视频

全网解析网站下载视频 介绍m3u8格式cbox格式 解析视频下载的方法方法一解析视频下载视频 方法二老王浏览器下载使用浏览器解析下载视频 总结 介绍 今天分享一下如何解析网页中的视频进行下载。通常情况下我们打开的某某网站的视频是不提供下载接口的,甚至说你下载了…

Verilog刷题笔记45

题目:Given the finite state machine circuit as shown, assume that the D flip-flops are initially reset to zero before the machine begins. Build this circuit. 解题: module top_module (input clk,input x,output z ); wire [2:0]size;dtou…

性能测试丨GreatSQL TPC-H 性能测试报告正式发布!

1、测试背景概述 本次测试针对GreatSQL开源数据库基于标准 TPC-H 场景的测试。 TPC-H(商业智能计算测试)是美国交易处理效能委员会(TPC,TransactionProcessing Performance Council)组织制定的用来模拟决策支持类应用…

StarRocks 助力金融营销数字化进化之路

作者:平安银行 数据资产中心数据及 AI 平台团队负责人 廖晓格 平安银行五位一体,做零售金融的领先银行,五位一体是由开放银行、AI 银行、远程银行、线下银行、综合化银行协同构建的数据化、智能化的零售客户经营模式,这套模式以数…

37、Linux中Xsync数据同步备份工具

37、Linux中Xsync数据同步备份工具 一、介绍二、配置集群hostname三、修改xsync文件四、赋权五、安装Rsync六、验证一七、配置免密登录1、生成rsa密钥2、copy机器自身公钥到目标机器3、.ssh/文件目录赋权 八、验证二 ⚠️ 注:本文全程在普通用户下操作,…

设计模式之建造者模式详解

建造者模式 1)概述 将一个复杂对象的构建与它的表示分离,使得同样的构建过程可以创建不同的表示。 1.复杂对象 复杂对象是指包含多个成员属性的对象。 2.结构图 Builder(抽象建造者):它为创建一个产品Product对象…

项目2-用户登录

1.创建项目 2.引入前端代码并检查是否有误 3.定义接口 需求分析 对于后端开发⼈员⽽⾔, 不涉及前端⻚⾯的展⽰, 只需要提供两个功能 1. 登录⻚⾯: 通过账号和密码, 校验输⼊的账号密码是否正确, 并告知前端 2. ⾸⻚: 告知前端当前登录⽤⼾. 如果当前已有⽤⼾登录, 返回登录的账…

看看Java Web怎么上传文件到服务器

旁白不多说了直接上主题了。 1、新建上传文件夹 在eclipse中&#xff0c;在我们前面文章中用到的项目HelloJSP&#xff0c;在webapp目录下新建uploadfiles文件夹&#xff0c;如下所示&#xff1a; 2、修改HelloWorld.jsp文件 <body><h1>文件上传</h1><…

稻飞虱在线监测仪的工作原理

TH-DF122随着现代农业科技的快速发展&#xff0c;智能化、精准化的农业管理工具日益成为农业生产的得力助手。其中&#xff0c;稻飞虱在线监测仪作为一种创新的农业监测设备&#xff0c;正以其独特的工作原理和显著的应用效果&#xff0c;成为保障稻田生态安全和提高稻米产量的…

ideaSSM 人才引进管理系统bootstrap开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 idea 开发 SSM 人才引进管理系统是一套完善的信息管理系统&#xff0c;结合SSM框架和bootstrap完成本系统&#xff0c;对理解JSP java编程开发语言有帮助系统采用SSM框架&#xff08;MVC模式开发&#xff09;&#xff0c;系统具有完整的源代码和数据库&#xff…

[python]bar_chart_race设置日期格式

1、设置日期标签的时间格式 # 设置日期格式&#xff0c;默认为%Y-%m-%dbcr.bar_chart_race(df, covid19_horiz.gif, period_fmt%b %-d, %Y) 2、更改日期标签为数值 # 设置日期标签为数值bcr.bar_chart_race(df.reset_index(dropTrue), covid19_horiz.gif, interpolate_period…

关于在CentOS中卸载MySQL

想要卸载MySQL当然要知道自己的MySQL是用那种方法来安装的了&#xff0c;一般来说MySQL的安装方法在市面上有三种 编译安装、YUM安装、RPM安装&#xff0c;下面会介绍到后两种安装的卸载方法 首先查看是否安装MySQL&#xff0c;一般可以看到版本信息就证明安装了 mysql -V 卸载…

day04_JDBC_课后练习(创建数据库,表格,添加模拟数据,搭建开发环境,编写实体类,实现接口,测试)

文章目录 day04_JDBC_课后练习1、创建数据库2、创建如下表格3、添加模拟数据4、搭建开发环境&#xff0c;准备各个工具组件&#xff08;1&#xff09;使用druid&#xff08;德鲁伊&#xff09;数据库连接池&#xff08;2&#xff09;使用尚硅谷的JDBCTools工具类&#xff08;直…

高中信息技术教资刷题笔记_选择题篇

1.信息技术基础 位与字节的换算 模2除法运算 网页保存 进制之间的计算 教你快速学会二进制、十进制、十六进制之间的转换 - 知乎 (zhihu.com) 原码、补码、反码计算 物联网技术 位运算 按位与&#xff1a;同位置为1&#xff0c;则为1&#xff0c;其他都是0按位或&#xff1a;有…

[Windows常用软件] word 复制粘贴报错修复

背景 在word 内 ctrlv 会报这个错。 microsoft visual basic MathPage.Wll 运行时错误 网上查了一下是 mathtype 导致的&#xff0c;应该是我之前卸载 mathtype 没有卸载干净导致的。 解决方案 参考知乎里面的一个回答解决的&#xff1a;https://www.zhihu.com/question/37…