分布式训练数据并行极致优化:ZeRO
导言
随着 ChatGPT 的爆火,大模型成为了近些年人工智能的研究热点。大模型能力惊艳,但是训练起来成本也不小。大模型,顾名思义,最大的特点就是 “大”。这里的 “大” 通常指的就是模型的参数量大。因此,在分布式训练中,如何利用有限的显存训练更大的模型就是重点。分布式的训练的常用范式包括数据并行和模型并行,其中模型并行又包括张量并行和流水线并行。Megatron-ML 等框架中实现的的张量并行已经是是训练大模型的标配,但是数据并行作为最简洁、最易理解、最易实现的分布式训练范式,近些年还是有了完善的优化。本文主要介绍分布式训练数据并行的极致优化:ZeRO。
数据并行中,一个显著的问题就是每张卡都需要保存一个完整的模型及其优化参数(包括模型梯度、Adam 参数等),这其中有极大的冗余性,能否每张卡只保存 全部模型参数的一部分呢。ZeRO(Zero Redundancy Optimizer,另冗余优化)是由微软在 2019 年提出的一种高效的数据并行方案。ZeRO 能够消除分布式训练数据并行中的冗余性,并同时能够维持较低的通信量和较高的计算粒度。这使得我们能够在显存有限的条件下,训练更大的模型。近年来比较知名的分布式训练框架,如微软的 DeepSpeed、Pytorch 的 FSDP 都是基于 ZeRO 的数据并行思想。
数据并行空间复杂度分析
我们以如今比较常用的 ADAM 优化器和混合精度训练的情况为例,来分析训练过程中的显存占用。
ADAM 维护梯度的一阶动量(momentum)和二阶动量(variance),具有动态的学习率,是现今常用的优化器。从显存占用的角度来看,ADAM 优化器除了需要维护模型参数及其梯度之外,还需要维护 momentum 和 variance。
混合精度训练已经是如今训练大规模训练的标配,它能在几乎不损失性能的情况下减小显存占用并加快训练速度。混合精度训练过程中一般有 fp16 和 fp32 两种精度类型的数值。fp16 类型包含模型参数及其梯度,fp32 类型包括模型参数的 fp32 备份,以及优化器需要维护的参数,比如 ADAM 中的 momentum 和 variance。
以上是 ADAM 优化器 + 混合精度训练情况下模型状态的显存占用。除此之外,训练中还有激活值、临时缓冲区和显存碎片等。
综上所述,训练过程中的显存占用可分为两大部分:
- 模型状态:记模型本身参数量为 Φ \Phi Φ ,在 Adam + 混合精度训练的情况下,模型状态包括 fp16 的模型参数 2 Φ 2\Phi 2Φ 和参数梯度 2 Φ 2\Phi 2Φ 和 fp32 的模型参数备份 4 Φ 4\Phi 4Φ ,momentum 4 Φ 4\Phi 4Φ 和 variance 4 Φ 4\Phi 4Φ ,即总共 2 Φ + 2 Φ + 4 Φ + 4 Φ + 4 Φ = 16 Φ 2\Phi+2\Phi+4\Phi+4\Phi+4\Phi=16\Phi 2Φ+2Φ+4Φ+4Φ+4Φ=16Φ 。(注意 fp16 占两个字节,fp32 占四个字节)
- 剩余状态:即训练中的激活值、临时缓冲区和显存碎片等。
以 GPT-2 为例,GPT-2 模型含有 1.5B 个参数,如果用 fp16 格式,模型本身只占 3GB 显存,但是实际训练过程中的模型状态需要耗费 24GB!可以看到。模型状态是成倍于模型本身的大小,是显存消耗的大头。并且,对于剩余状态中的激活值等,已经有 activation checkpointing 等以时间换空间的优化方式,可以有效减小这部分显存消耗。因此,优化模型状态的显存占用是重点。
ZeRO 由 ZeRO-DP 和 ZeRO-R 组成,分别是对模型状态和剩余状态的显存优化。
ZeRO-DP
模型状态是 ZeRO 显存优化的重点。在导言中提到,数据并行的分布式训练方式中,每个 GPU 都要保存一份独立、完整的模型状态参数,即 12 Φ 12\Phi 12Φ 的显存占用。显然,这其中是存在大量冗余的,按理说,我们只要保存一份模型状态参数即可。这正是 ZeRO 优化的思路:分片(partition),在分布式训练的 N N N 个 GPU 中,每个 GPU 保存 1 N \frac{1}{N} N1 的模型状态参数,当计算需要其他部分的模型状态参数时,将其他 GPU 保存的参数传过来即可。这是一种以带宽换显存的思路。
下面的图来自 ZeRO 论文原文,比较直观地展示了 ZeRO 显存优化的思路。
ZeRO-DP 的显存优化有三个优化等级,一般称为 ZeRO-1,ZeRO-2,ZeRO-3,对应图中的 P o s P_{os} Pos 、 P o s + g P_{os+g} Pos+g、 P o s + g + p P_{os+g+p} Pos+g+p 。未进行优化是,显存占用为 ( 2 + 2 + K ) ∗ Φ (2+2+K)*\Phi (2+2+K)∗Φ
- ZeRO-1:首先,根据之前的分析,Adam 优化器状态(Optimizer States,os)是占用显存最多的,对应图中绿色部分。将优化器状态分片,在不同的 GPU 上维护。从而 ZeRO-1 的显存占用为 ( 2 + 2 + K N ) ∗ Φ (2+2+\frac{K}{N})*\Phi (2+2+NK)∗Φ,当 K → ∞ K\rightarrow \infty K→∞ 是,约为 4 Φ 4\Phi 4Φ。
- ZeRO-2:其次要优化的是梯度(Gradients,g),对应图中橙色部分,同样切片保存到不同的 GPU 上,显存占用为 ( 2 + 2 + K N ) ∗ Φ (2+\frac{2+K}{N})*\Phi (2+N2+K)∗Φ ,当 K → ∞ K\rightarrow \infty K→∞ 是,约为 2 Φ 2\Phi 2Φ。
- ZeRO-3:最后要优化的是模型参数(Parameter,p),对应图中绿色部分,此时显存占用为 2 + 2 + K N ∗ Φ \frac{2+2+K}{N}*\Phi N2+2+K∗Φ,当 K → ∞ K\rightarrow \infty K→∞ 是,模型状态所占显存接近于零。
可以看到,使用 ZeRO 策略将模型状态进行分片保存,随着 GPU 增加,分片越来越多,该部分的显存占用越来越小,甚至理论上会趋于零。
但实际中,要考虑各个 GPU 之间通讯的开销,别忘了,我们现存的节省,使用带宽和通讯“换”来的。结论是:ZeRO-1 和 ZeRO-2 与不使用 ZeRO 策略传统数据并行方式的通讯量一致,而 ZeRO-3,则要额外的通讯量。具体的分析后面会单独讲。权衡显存占用和通讯开销,实际中我们一般选择 ZeRO-1 或 ZeRO-2 即可。DeepSpeed 中可以设置 ZeRO-1/2/3,而 Pytorch 的 FSDP,即 Fully Sharded Data Parallel,既然是 Fully,即是完全切片了,相当于 ZeRO-3。
ZeRO-R
ZeRO-DP 优化了模型状态的显存占用,而 ZeRO-R 则优化剩余状态,也就是激活值(activation)、临时缓冲区(buffer)以及显存碎片(fragmentation)。
- 激活值同样使用分片方法,并且配合 activation-checkpointing 来进一步减小显存占用;
- 模型训练过程中经常会创建一些大小不等的临时缓冲区,比如对梯度进行 AllReduce 等,解决办法就是预先创建一个固定的缓冲区,训练过程中不再动态创建,如果要传输的数据较小,则多组数据 bucket 后再一次性传输,提高效率
- 显存出现碎片的一大原因是时候 gradient checkpointing 后,不断地创建和销毁那些不保存的激活值,解决方法是预先分配一块连续的显存,将常驻显存的模型状态和 checkpointed activation 存在里面,剩余显存用于动态创建和销毁 discarded activation
ZeRO-R 部分都是计算机系统中一些比较常用的缓存方式。
分片通讯量分析
集合通讯原语复习
在分析 ZeRO 分片策略的通讯量之前,我们先回顾一下常用的集合通讯原语,包括 AllReduce、Broadcast、Reduce、AllGather、ReduceScatter。这里参考英伟达 NCCL 的官方文档。
AllReduce
AllReduce 操作对所有节点上的数据进行规约操作(如 sum、min、max 等),并将结果保存在每个节点的缓冲区中。
以 k k k 个节点执行 sun 操作为例,每个节点提供一个含有 N N N 个元素的向量 V i V_i Vi ,得到所有节点上的 V i V_i Vi 加和之后的结果,同样是一个含有 N N N 个元素的向量 S S S。即有: S [ i ] = V 0 [ i ] + V 1 [ i ] + ⋯ + V k − 1 [ i ] S[i]=V_0[i]+V_1[i]+\dots+V_{k-1}[i] S[i]=V0[i]+V1[i]+⋯+Vk−1[i]。
AllReduce 是数据并行的通信基础,目前分布式训练中常用的是 Ring AllReduce,有兴趣可以读一下袁进辉老师的手把手推导Ring All-reduce的数学性质。
Broadcast
Broadcast 将某个节点上的向量复制到其他所有节点上。
Reduce
Reduce 操作的计算过程与 AllReduce 一致,只是只将结果写入到一个节点中。
注意:Reduce + Broadcast 等价于 AllReduce。
AllGather
AllGather 操作收集 k k k 个节点上的各自 N N N 个值,得到一个 k ∗ N k*N k∗N 的矩阵,并将其分发到所有节点上。
注意:执行 ReduceScatter + AllGather,等价于 AllReduce。
ReduceScatter
ReduceScatter 操作的计算过程与 Reduce 操作一致,只是将结果等分开来,按照节点序号分发给不同的节点。
通讯量分析
之前我们提到:ZeRO-1 和 ZeRO-2 与不使用 ZeRO 策略传统数据并行方式的通讯量一致,而 ZeRO-3,则要额外的通讯量。
传统数据数据并行在每一步(step/iteration)计算梯度后,需要进行一次 AllReduce 操作来计算梯度均值。常见的 Ring AllReduce,分为 ReduceScatter 和AllGather 两步,每张卡的通信数据量(发送+接受)近似为 2 Φ 2\Phi 2Φ。
我们直接分析 P o s + g P_{os+g} Pos+g ,每张卡只存储 1 N \frac{1}{N} N1 的优化器状态和梯度,对于 gpu0 来说,为了计算它这 1 N \frac{1}{N} N1 梯度的均值,需要进行一次 Reduce 操作,通信数据量是 1 N Φ ∗ N = Φ \frac{1}{N}\Phi*N=\Phi N1Φ∗N=Φ,然后其余显卡则不需要保存这部分梯度值了。实现中使用了 bucket 策略,保证 1 N \frac{1}{N} N1 的梯度每张卡只发送一次。
这里还要注意一点,假如模型最后两层的梯度落在 gpu0 ,为了节省显存,其他卡将这两层梯度删除,怎么计算倒数第三层的梯度呢?还是因为用了 bucket,其他卡可以将梯度发送和计算倒数第三层梯度同时进行,当二者都结束,就可以放心将后两层梯度删除了。
当 gpu0 计算好梯度均值后,就可以更新局部的优化器状态(包括 1 N Φ \frac{1}{N}\Phi N1Φ 的参数),当反向传播过程结束,进行一次Gather操作,更新 ( 1 − 1 N ) Φ (1-\frac{1}{N})\Phi (1−N1)Φ 的模型参数,通信数据量是 1 N Φ ∗ N = Φ \frac{1}{N}\Phi*N=\Phi N1Φ∗N=Φ 。
从全局来看,相当于用 Reduce-Scatter 和 AllGather 两步,与传统数据并行一致。
而对于 ZeRO-3, P o s + g + p P_{os+g+p} Pos+g+p 使得每张卡只存了 1 N \frac{1}{N} N1 的模型本身参数,不管是在前向计算还是反向传播,都涉及一次 Broadcast 操作。
ZeRO-Offload
GPU 显存是制约能够训练模型大小的关键因素。内存比 GPU 显存要廉价许多,ZeRO-Offload 的思路就是将暂时不用的张量放到内存中,来扩大可训练模型的规模。有点像内存将磁盘作为交换 swap 的思路。
ZeRO-Infinity
同样是进行 offload,ZeRO-Offload 更侧重单卡场景,而 ZeRO-Infinity 则是典型的工业界风格,试图打破大规模训练的内存墙,奔着极大规模训练去了。
Ref
- DeepSpeed之ZeRO系列:将显存优化进行到底
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- 大模型高效训练的关键技术|AI 盐沙龙
- 数据并行Deep-dive: 从DP 到 Fully Sharded Data Parallel (FSDP)完全分片数据并行
- Nvidia NCCL Collective Operations
- ZeRO-Offload: Democratizing Billion-Scale Model Training
- ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning
- AI算力的阿喀琉斯之踵:内存墙