一、训练时间预估
1、耗时计算内容
- 浮点数计算
2、训练时间
- 训练时间 = 运算量 / (GPU卡数 X GPU每秒运算数)
3、每秒运算数
实际GPU每秒运算数约等于GPU理论极限能力的30%到70%
二、训练显存预估
1、字母映射
- P:模型参数量
- G:GPU数量
- N𝑡:张量并行数
- Np:流水线并行数
- N𝑑:训练的数据并行数
2、占用显存的部分
- 模型参数与优化器
- 训练中需要保存的激活值
- 其他显存占用
3、模型参数与优化器的显存占用
3.1、简介
- 组成部分
- 模型参数
- 模型梯度
- 优化器等
- 数据存储格式
- 16位、2字节浮点数
- 模型参数
- 模型梯度
- 32位、4字节浮点数
- 模型参数
- 动量参数
- 动量二阶矩阵参数
- 16位、2字节浮点数
3.2、不同方案对比
- 不使用ZeRO优化技术:每张卡需要16P显存
- 模型参数:2P
- 模型梯度:2P
- Adam优化器
- 模型参数:4P
- 动量参数:4P
- 动量二阶矩参数:4P
- 使用ZeRO的优化器参数分区方案(ZeRO-1):
- 特点:将优化器的参数平摊到每张GPU上,模型参数和模型梯度各自保留
- 显存占用:( 2 + 2 )* P + ( 4 + 4 + 4 )* P / N𝑑 = 4P + 12P / N𝑑
- 优点:相比不使用ZeRO的方案,在GPU数量足够多的情况下,模型参数和优化器所占用显存会减少到原来的1/4
- 使用ZeRO的模型梯度分区方案(ZeRO-2):
- 简介:在ZeRO-1的基础上,进一步将模型梯度平分到每张GPU上
- 显存占用:2P + ( 2 + 4 + 4 + 4 )* P / N𝑑 = 2P + 14P / N𝑑
- 优点:GPU数量足够多的情况下,用于存储模型参数和优化器的显存会减少到1/8
- 使用ZeRO的模型参数分区方案(ZeRO-3):
- 简介:基于ZeRO-2方案,进一步将模型参数均分到每张GPU上
- 显存占用:16P / N𝑑
- 优点:用于存储模型参数和优化器的显存会减少至原来的1 / N𝑑
- 使用张量并行和流水线并行的方案
- 简介:与上述四种方式兼容
- 显存占用:上述几种方案的情况下 / N𝑡 * Np
4、训练激活值的显存占用
4.1、简介
- 显存存储内容:前向传播需要保留每层的激活值(中间状态),来用于后续反向传播中计算梯度并更新模型参数
4.2、不同层显存占用
- 多头自注意力层(不考虑张量并行、流水线并行、激活重计算等优化方法):
- 显存占用
- 查询、键、只的线性变换需要保存其输入,占用2BTH字节
- 多头注意力计算保存输入的查询、键、值,共占用6BTH字节
- 合并多头结果需要保存其输入,占用2BTH字节
- 若未使用FlashAttention优化,需要占用2BT²N字节,若未使用,此部分无开销
- 显存占用
- 前馈网络层
- 显存占用:2BTH + 6BTH’
- 保存SwiGLU激活函数的输入,占用2BTH字节
- 同时要保存WᴳX和WᵁX的值,占用4BTH’字节
- 保留SwiGLU的输出值,作为后续线性变换的输入,共计2BTH’字节
- 显存占用:2BTH + 6BTH’
- 归一化层
- 显存占用:每层解码器包含两个归一化层,每个归一化层需要保存其输入,占用4BTH字节
- 输出层
- 显存占用:
- 经过L层解码器解析后,需要经过归一化层处理,需要保存其输入,占用2BTH字节
- 保存词表映射的输入,占用2BTH字节
- 保存softmax输入,占用,在实践中,为了提升softmax的精度,回见输入转化为32位浮点数来进行后续计算,占用4BTV字节
- 显存占用:
4.3、不同优化方案显存占用(默认不使用FlashAttention,使用的话,去掉2BT²N即可)
- 未使用FlashAttention的情况下:(16BTH + 6BTH’ + 2BT²N)* L + 4BTH + 4BTV
- 使用FlashAttention的情况下(去掉2BT²N):(16BTH + 6BTH’ )* L + 4BTH + 4BTV
- 流水线并行:每个GPU仅需要保存分配到的层对应的激活值即可,激活值占用显存 = (16BTH + 6BTH’ + 2BT²N)* L / Np + 4BTH + 4BTV
- 张量并行:多头注意力层和前馈网络层中的线性变换操作可以通过拆分参数矩阵,分配到不同GPU上并行计算结果,对应的激活值也分配到响应GPU上去,激活值占用显存 = (( 8 + 8 / N𝑡)BTH + 6BTH’ / N𝑡 + 2BT²N / N𝑡 )* L + 4BTH + 4BTV
- 激活重计算:激活重计算在前向传播时仅保存Transformer每一层的输入和最后层softmax层的输入,在反向传播时按需重新计算激活值来减少显存占用,激活值占用 = ( 4 + 2L )BTH + 4BTV
5、其他显存占用
- 代码库内核:0.8-1GB
- ZeRO优化技术实现:1-4GB
- 训练过程中的中间结果和显存碎片:0.5-1GB
6、实际训练所需显存估计
至少需要16别参数数量的显存资源,例如13B的模型,至少需要13 X 16 = 208GB的显存