分布式训练数据并行极致优化:ZeRO

news2024/10/6 8:58:25

分布式训练数据并行极致优化:ZeRO

导言

随着 ChatGPT 的爆火,大模型成为了近些年人工智能的研究热点。大模型能力惊艳,但是训练起来成本也不小。大模型,顾名思义,最大的特点就是 “大”。这里的 “大” 通常指的就是模型的参数量大。因此,在分布式训练中,如何利用有限的显存训练更大的模型就是重点。分布式的训练的常用范式包括数据并行和模型并行,其中模型并行又包括张量并行和流水线并行。Megatron-ML 等框架中实现的的张量并行已经是是训练大模型的标配,但是数据并行作为最简洁、最易理解、最易实现的分布式训练范式,近些年还是有了完善的优化。本文主要介绍分布式训练数据并行的极致优化:ZeRO。

数据并行中,一个显著的问题就是每张卡都需要保存一个完整的模型及其优化参数(包括模型梯度、Adam 参数等),这其中有极大的冗余性,能否每张卡只保存 全部模型参数的一部分呢。ZeRO(Zero Redundancy Optimizer,另冗余优化)是由微软在 2019 年提出的一种高效的数据并行方案。ZeRO 能够消除分布式训练数据并行中的冗余性,并同时能够维持较低的通信量和较高的计算粒度。这使得我们能够在显存有限的条件下,训练更大的模型。近年来比较知名的分布式训练框架,如微软的 DeepSpeed、Pytorch 的 FSDP 都是基于 ZeRO 的数据并行思想。

数据并行空间复杂度分析

我们以如今比较常用的 ADAM 优化器和混合精度训练的情况为例,来分析训练过程中的显存占用。

ADAM 维护梯度的一阶动量(momentum)和二阶动量(variance),具有动态的学习率,是现今常用的优化器。从显存占用的角度来看,ADAM 优化器除了需要维护模型参数及其梯度之外,还需要维护 momentum 和 variance。

混合精度训练已经是如今训练大规模训练的标配,它能在几乎不损失性能的情况下减小显存占用并加快训练速度。混合精度训练过程中一般有 fp16 和 fp32 两种精度类型的数值。fp16 类型包含模型参数及其梯度,fp32 类型包括模型参数的 fp32 备份,以及优化器需要维护的参数,比如 ADAM 中的 momentum 和 variance。

以上是 ADAM 优化器 + 混合精度训练情况下模型状态的显存占用。除此之外,训练中还有激活值、临时缓冲区和显存碎片等。

综上所述,训练过程中的显存占用可分为两大部分:

  1. 模型状态:记模型本身参数量为 Φ \Phi Φ ,在 Adam + 混合精度训练的情况下,模型状态包括 fp16 的模型参数 2 Φ 2\Phi 和参数梯度 2 Φ 2\Phi 和 fp32 的模型参数备份 4 Φ 4\Phi ,momentum 4 Φ 4\Phi 和 variance 4 Φ 4\Phi ,即总共 2 Φ + 2 Φ + 4 Φ + 4 Φ + 4 Φ = 16 Φ 2\Phi+2\Phi+4\Phi+4\Phi+4\Phi=16\Phi ++++=16Φ 。(注意 fp16 占两个字节,fp32 占四个字节)
  2. 剩余状态:即训练中的激活值、临时缓冲区和显存碎片等。

以 GPT-2 为例,GPT-2 模型含有 1.5B 个参数,如果用 fp16 格式,模型本身只占 3GB 显存,但是实际训练过程中的模型状态需要耗费 24GB!可以看到。模型状态是成倍于模型本身的大小,是显存消耗的大头。并且,对于剩余状态中的激活值等,已经有 activation checkpointing 等以时间换空间的优化方式,可以有效减小这部分显存消耗。因此,优化模型状态的显存占用是重点。

ZeRO 由 ZeRO-DP 和 ZeRO-R 组成,分别是对模型状态和剩余状态的显存优化。

ZeRO-DP

模型状态是 ZeRO 显存优化的重点。在导言中提到,数据并行的分布式训练方式中,每个 GPU 都要保存一份独立、完整的模型状态参数,即 12 Φ 12\Phi 12Φ 的显存占用。显然,这其中是存在大量冗余的,按理说,我们只要保存一份模型状态参数即可。这正是 ZeRO 优化的思路:分片(partition),在分布式训练的 N N N 个 GPU 中,每个 GPU 保存 1 N \frac{1}{N} N1 的模型状态参数,当计算需要其他部分的模型状态参数时,将其他 GPU 保存的参数传过来即可。这是一种以带宽换显存的思路。

下面的图来自 ZeRO 论文原文,比较直观地展示了 ZeRO 显存优化的思路。

在这里插入图片描述

ZeRO-DP 的显存优化有三个优化等级,一般称为 ZeRO-1,ZeRO-2,ZeRO-3,对应图中的 P o s P_{os} Pos P o s + g P_{os+g} Pos+g P o s + g + p P_{os+g+p} Pos+g+p 。未进行优化是,显存占用为 ( 2 + 2 + K ) ∗ Φ (2+2+K)*\Phi (2+2+K)Φ

  • ZeRO-1:首先,根据之前的分析,Adam 优化器状态(Optimizer States,os)是占用显存最多的,对应图中绿色部分。将优化器状态分片,在不同的 GPU 上维护。从而 ZeRO-1 的显存占用为 ( 2 + 2 + K N ) ∗ Φ (2+2+\frac{K}{N})*\Phi (2+2+NK)Φ,当 K → ∞ K\rightarrow \infty K 是,约为 4 Φ 4\Phi
  • ZeRO-2:其次要优化的是梯度(Gradients,g),对应图中橙色部分,同样切片保存到不同的 GPU 上,显存占用为 ( 2 + 2 + K N ) ∗ Φ (2+\frac{2+K}{N})*\Phi (2+N2+K)Φ ,当 K → ∞ K\rightarrow \infty K 是,约为 2 Φ 2\Phi
  • ZeRO-3:最后要优化的是模型参数(Parameter,p),对应图中绿色部分,此时显存占用为 2 + 2 + K N ∗ Φ \frac{2+2+K}{N}*\Phi N2+2+KΦ,当 K → ∞ K\rightarrow \infty K 是,模型状态所占显存接近于零。

可以看到,使用 ZeRO 策略将模型状态进行分片保存,随着 GPU 增加,分片越来越多,该部分的显存占用越来越小,甚至理论上会趋于零。

但实际中,要考虑各个 GPU 之间通讯的开销,别忘了,我们现存的节省,使用带宽和通讯“换”来的。结论是:ZeRO-1 和 ZeRO-2 与不使用 ZeRO 策略传统数据并行方式的通讯量一致,而 ZeRO-3,则要额外的通讯量。具体的分析后面会单独讲。权衡显存占用和通讯开销,实际中我们一般选择 ZeRO-1 或 ZeRO-2 即可。DeepSpeed 中可以设置 ZeRO-1/2/3,而 Pytorch 的 FSDP,即 Fully Sharded Data Parallel,既然是 Fully,即是完全切片了,相当于 ZeRO-3。

ZeRO-R

ZeRO-DP 优化了模型状态的显存占用,而 ZeRO-R 则优化剩余状态,也就是激活值(activation)、临时缓冲区(buffer)以及显存碎片(fragmentation)。

  • 激活值同样使用分片方法,并且配合 activation-checkpointing 来进一步减小显存占用;
  • 模型训练过程中经常会创建一些大小不等的临时缓冲区,比如对梯度进行 AllReduce 等,解决办法就是预先创建一个固定的缓冲区,训练过程中不再动态创建,如果要传输的数据较小,则多组数据 bucket 后再一次性传输,提高效率
  • 显存出现碎片的一大原因是时候 gradient checkpointing 后,不断地创建和销毁那些不保存的激活值,解决方法是预先分配一块连续的显存,将常驻显存的模型状态和 checkpointed activation 存在里面,剩余显存用于动态创建和销毁 discarded activation

ZeRO-R 部分都是计算机系统中一些比较常用的缓存方式。

分片通讯量分析

集合通讯原语复习

在分析 ZeRO 分片策略的通讯量之前,我们先回顾一下常用的集合通讯原语,包括 AllReduce、Broadcast、Reduce、AllGather、ReduceScatter。这里参考英伟达 NCCL 的官方文档。

AllReduce

AllReduce 操作对所有节点上的数据进行规约操作(如 sum、min、max 等),并将结果保存在每个节点的缓冲区中。

k k k 个节点执行 sun 操作为例,每个节点提供一个含有 N N N 个元素的向量 V i V_i Vi ,得到所有节点上的 V i V_i Vi 加和之后的结果,同样是一个含有 N N N 个元素的向量 S S S。即有: S [ i ] = V 0 [ i ] + V 1 [ i ] + ⋯ + V k − 1 [ i ] S[i]=V_0[i]+V_1[i]+\dots+V_{k-1}[i] S[i]=V0[i]+V1[i]++Vk1[i]

AllReduce 是数据并行的通信基础,目前分布式训练中常用的是 Ring AllReduce,有兴趣可以读一下袁进辉老师的手把手推导Ring All-reduce的数学性质。

在这里插入图片描述

Broadcast

Broadcast 将某个节点上的向量复制到其他所有节点上。

在这里插入图片描述

Reduce

Reduce 操作的计算过程与 AllReduce 一致,只是只将结果写入到一个节点中。

注意:Reduce + Broadcast 等价于 AllReduce。

在这里插入图片描述

AllGather

AllGather 操作收集 k k k 个节点上的各自 N N N 个值,得到一个 k ∗ N k*N kN 的矩阵,并将其分发到所有节点上。

注意:执行 ReduceScatter + AllGather,等价于 AllReduce。

在这里插入图片描述

ReduceScatter

ReduceScatter 操作的计算过程与 Reduce 操作一致,只是将结果等分开来,按照节点序号分发给不同的节点。

在这里插入图片描述

通讯量分析

之前我们提到:ZeRO-1 和 ZeRO-2 与不使用 ZeRO 策略传统数据并行方式的通讯量一致,而 ZeRO-3,则要额外的通讯量。

传统数据数据并行在每一步(step/iteration)计算梯度后,需要进行一次 AllReduce 操作来计算梯度均值。常见的 Ring AllReduce,分为 ReduceScatter 和AllGather 两步,每张卡的通信数据量(发送+接受)近似为 2 Φ 2\Phi

我们直接分析 P o s + g P_{os+g} Pos+g ,每张卡只存储 1 N \frac{1}{N} N1 的优化器状态和梯度,对于 gpu0 来说,为了计算它这 1 N \frac{1}{N} N1 梯度的均值,需要进行一次 Reduce 操作,通信数据量是 1 N Φ ∗ N = Φ \frac{1}{N}\Phi*N=\Phi N1ΦN=Φ,然后其余显卡则不需要保存这部分梯度值了。实现中使用了 bucket 策略,保证 1 N \frac{1}{N} N1 的梯度每张卡只发送一次。

这里还要注意一点,假如模型最后两层的梯度落在 gpu0 ,为了节省显存,其他卡将这两层梯度删除,怎么计算倒数第三层的梯度呢?还是因为用了 bucket,其他卡可以将梯度发送和计算倒数第三层梯度同时进行,当二者都结束,就可以放心将后两层梯度删除了。

当 gpu0 计算好梯度均值后,就可以更新局部的优化器状态(包括 1 N Φ \frac{1}{N}\Phi N1Φ 的参数),当反向传播过程结束,进行一次Gather操作,更新 ( 1 − 1 N ) Φ (1-\frac{1}{N})\Phi (1N1)Φ 的模型参数,通信数据量是 1 N Φ ∗ N = Φ \frac{1}{N}\Phi*N=\Phi N1ΦN=Φ

从全局来看,相当于用 Reduce-Scatter 和 AllGather 两步,与传统数据并行一致。

而对于 ZeRO-3, P o s + g + p P_{os+g+p} Pos+g+p 使得每张卡只存了 1 N \frac{1}{N} N1 的模型本身参数,不管是在前向计算还是反向传播,都涉及一次 Broadcast 操作。

ZeRO-Offload

GPU 显存是制约能够训练模型大小的关键因素。内存比 GPU 显存要廉价许多,ZeRO-Offload 的思路就是将暂时不用的张量放到内存中,来扩大可训练模型的规模。有点像内存将磁盘作为交换 swap 的思路。

ZeRO-Infinity

同样是进行 offload,ZeRO-Offload 更侧重单卡场景,而 ZeRO-Infinity 则是典型的工业界风格,试图打破大规模训练的内存墙,奔着极大规模训练去了。

Ref

  • DeepSpeed之ZeRO系列:将显存优化进行到底
  • ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
  • 大模型高效训练的关键技术|AI 盐沙龙
  • 数据并行Deep-dive: 从DP 到 Fully Sharded Data Parallel (FSDP)完全分片数据并行
  • Nvidia NCCL Collective Operations
  • ZeRO-Offload: Democratizing Billion-Scale Model Training
  • ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning
  • AI算力的阿喀琉斯之踵:内存墙

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/795934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

本地部署 Stable Diffusion XL 1.0 Gradio Demo WebUI

StableDiffusion XL 1.0 Gradio Demo WebUI 0. 先展示几张 StableDiffusion XL 生成的图片1. 什么是 Stable Diffusion XL Gradio Demo WebUI2. Github 地址3. 安装 Miniconda34. 创建虚拟环境5. 安装 Stable Diffusion XL Gradio Demo WebUI6. 启动 Stable Diffusion XL Gradi…

c语言内存函数的深度解析

本章对 memcpy,memmove,memcmp 三个函数进行详解和模拟实现; 本章重点:3个常见内存函数的使用方法及注意事项并学会模拟实现; 如果您觉得文章不错,期待你的一键三连哦,你的鼓励是我创作的动力…

基于深度学习的裂纹图像分类研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

2023最新ChatGPT商业运营版网站源码+支持ChatGPT4.0+GPT联网+支持ai绘画(Midjourney)+支持Mind思维导图生成

本系统使用Nestjs和Vue3框架技术,持续集成AI能力到本系统! 支持GPT3模型、GPT4模型Midjourney专业绘画(全自定义调参)、Midjourney以图生图、Dall-E2绘画Mind思维导图生成应用工作台(Prompt)AI绘画广场自定…

行业动态 - Zhaga 常见问题解答

本文采用chatGPT 3.5翻译润色,内容来自于Zhaga联盟官网Zhaga FAQ [1],原文网页提供了更多的延伸阅读资料,可点击文末链接访问。另外不得不说,chatGPT对文字的优化调整功能太好用了。 ​ 1. "Zhaga"这个名字的由来和含义…

Nuxt 菜鸟入门学习笔记二:配置

文章目录 Nuxt 配置环境覆盖环境变量和私有令牌 应用配置runtimeConfig 与 app.config外部配置文件Vue 配置支持配置 Vite配置 webpack启用试验性 Vue 功能 Nuxt 官网地址: https://nuxt.com/ 默认情况下,Nuxt 的配置涵盖了大多数用例。nuxt.config.ts …

【雕爷学编程】Arduino动手做(172)---WeMos D1开发板模块4

37款传感器与执行器的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&am…

为什么 Raft 原生系统是流数据的未来

虽然Apache Kafka正在逐步引入KRaft以简化其一致性方法,但基于Raft构建的系统对于未来的超大规模工作负载显示出更多的潜力。 ​共识是一致性分布式系统的基础。为了在不可避免的崩溃事件中保证系统的可用性,系统需要一种方式来确保集群中的每个节点保持…

快速跑 nerf instant-ngp 快速调试与配置,跑自己的数据

1.下载Anaconda3 2.打开Anaconda Prompt (Anaconda) 创建虚拟环境 conda create -n nerf-ngp python3.8切换到虚拟环境 conda activate nerf-ngp安装相关依赖包 pip install commentjson imageio numpy opencv-python-headless pybind11 pyquaternion scipy tqdm安装完毕后…

现在设计师都在用哪些工具做UI设计

随着国内企业在用户交互方面的竞争,UI设计的未来是无限的。 如果你仍然或只是在寻找一个合适的UI设计工具,那么这篇文章应该非常适合你。 1.即时设计 即时设计是一款免费的在线 UI 设计工具,无系统限制,浏览器打开即可使用&…

Java面试准备篇:全面了解面试流程与常见问题

文章目录 1.1 Java面试概述1.2 面试流程和注意事项1.3 自我介绍及项目介绍1.4 常见面试问题 在现代职场中,面试是求职过程中至关重要的一环,特别是对于Java开发者而言。为了帮助广大Java开发者更好地应对面试,本文将提供一份全面的Java面试准…

Python中安装pyinstaller并打包为exe可执行程序

环境:vs2022 win10 python3.7.8 工具:pyinstaller 1、安装pyinstaller,cmd --> pip install pyinstaller 2、安装完成后,打开cmd,输入命令:pyinstaller -F xxx.py ,xxx为py文件的全路径&am…

超细整理,Python接口自动化测试-关联参数(购物接口实例)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 什么是参数关联&a…

Hadoop生态体系-2

目录标题 1、MapReduce介绍2、数据仓库3、HIVE4、HQL4.1 hive读写文件机制4.2 Hive数据存储路径 1、MapReduce介绍 思想:分而治之 map:“分”,即把复杂的任务分解为若干个“简单的任务”来处理。可以进行拆分的前提是这些小任务可以并行计算&#xff0c…

3D 渲染技巧-如何创建高质量写实渲染?

掌握创建高质量建筑渲染和任何 3D 渲染的艺术是一项复杂且需要技巧的工作,通常需要多年的经验和实践。实现逼真的结果需要仔细考虑众多因素,并避免可能导致缺乏真实性的假渲染效果的常见错误。 避免常见错误 - 提升渲染游戏的技巧 在追求创建真正逼真的…

数据中心机房机柜配电新模式的探讨与选型

安科瑞 华楠 摘 要:对数据中心机房列头柜配电方式特征和问题进行深入研究,分析机房末端配电安全性及可用性,主要阐述了数据中心机房机柜配电新模式。 关键词:数据中心;机房机柜;配电模式 1 原始配电方案 …

Pycharm中如何设置在新窗口打开项目

settingAppearance&Behavior–System SettingsOpen project in - new window

抑郁症的自我治疗:警惕隐藏在微笑背后的抑郁症

抑郁症是一种常见的心理疾病,它可以隐藏在微笑背后。许多人经常感到沮丧、情绪低落,这时候可能是抑郁症的前兆。然而,自我治疗也是一种非常有效的抑郁症治疗方法。在本文中,我将分享一些关于如何自我治疗抑郁症的方法。 首先&…

递归对比对象函数

在JavaScript中,对象之间的比较通常通过引用进行。当你使用运算符比较两个对象时,它会检查它们是否引用了同一个内存地址,而不是逐个比较对象的属性。 上图可见,obj1和{}是两个不同的对象,尽管它们具有相同的结构&…

运算方法与运算器

一、定点数运算及溢出检测 1. 定点数加法运算 2. 定点数减法运算 3. 数溢出的概念及其判断方法 运算结果超出了某种数据类型的表示范围 (1)溢出的概念 (2)溢出的检测方法 溢出只可能发生在同符号数相加 方法1:对操…