数据并行 - DP/DDP/ZeRO

news2025/1/18 9:45:11

数据并行DP

数据并行的核心思想是:在各个GPU上都拷贝一份完整模型,各自吃一份数据,算一份梯度,最后对梯度进行累加来更新整体模型。理念不复杂,但到了大模型场景,巨大的存储和GPU间的通讯量,就是系统设计要考虑的重点了。在本文中,我们将递进介绍三种主流数据并行的实现方式:

  1. DP(Data Parallelism):最早的数据并行模式,一般采用参数服务器(Parameters Server)这一编程框架。实际中多用于单机多卡
  2. DDP(Distributed Data Parallelism):分布式数据并行,采用Ring AllReduce的通讯方式,实际中多用于多机场景
  3. ZeRO:零冗余优化器。由微软推出并应用于其DeepSpeed框架中。严格来讲ZeRO采用数据并行+张量并行的方式,旨在降低存储。

在这里插入图片描述
1)若干块计算GPU,如图中GPU0~GPU2;1块梯度收集GPU,如图中AllReduce操作所在GPU。
2)在每块计算GPU上都拷贝一份完整的模型参数。
3)把一份数据X(例如一个batch)均匀分给不同的计算GPU。
4)每块计算GPU做一轮FWD和BWD后,算得一份梯度G。
5)每块计算GPU将自己的梯度push给梯度收集GPU,做聚合操作。这里的聚合操作一般指梯度累加。当然也支持用户自定义。
6)梯度收集GPU聚合完毕后,计算GPU从它那pull下完整的梯度结果,用于更新模型参数W。更新完毕后,计算GPU上的模型参数依然保持一致。
7)聚合再下发梯度的操作,称为AllReduce。

  • 总结一下:打散 – 收集 – 反向分发(更新)

实现DP的一种经典编程框架叫“参数服务器”,在这个框架里,计算GPU称为Worker,梯度聚合GPU称为Server。在实际应用中,为了尽量减少通讯量,一般可选择一个Worker同时作为Server。比如可把梯度全发到GPU0上做聚合。需要再额外说明几点:

  • 1个Worker或者Server下可以不止1块GPU。
  • Server可以只做梯度聚合,也可以梯度聚合+全量参数更新一起做
  • 在参数服务器的语言体系下,DP的过程又可以被描述下图

在这里插入图片描述
那么问题所在:

  • 存储开销大。每块GPU上都存了一份完整的模型,造成冗余。
  • 通讯开销大。Server需要和每一个Worker进行梯度传输。当Server和Worker不在一台机器上时,Server的带宽将会成为整个系统的计算效率瓶颈。

概括一下:每一个节点干完自己的活儿提交上去,等sever的反馈更新,这个等待的过程就是浪费时间,且sever的压力非常大。


所以,梯度异步更新的idea就出来了

在梯度异步更新的场景下,某个Worker的计算顺序为:

  • 在第10轮计算中,该Worker正常计算梯度,并向Server发送push&pull梯度请求。
  • 但是,该Worker并不会实际等到把聚合梯度拿回来,更新完参数W后再做计算。而是直接拿旧的W,吃新的数据,继续第11轮的计算。这样就保证在通讯的时间里,Worker也在马不停蹄做计算,提升计算通讯比。
  • 当然,异步也不能太过份。只计算梯度,不更新权重,那模型就无法收敛。图中刻画的是延迟为1的异步更新,也就是在开始第12轮对的计算时,必须保证W已经用第10、11轮的梯度做完2次更新了。

意思就是,work的参数阶段性更新,隔多久更新一次由延迟时间步决定
三种更新方式:
(a) 无延迟
(b) 延迟但不指定延迟步数。也即在迭代2时,用的可能是老权重,也可能是新权重,听天由命。
(c ) 延迟且指定延迟步数为1。例如做迭代3时,可以不拿回迭代2的梯度,但必须保证迭代0、1的梯度都已拿回且用于参数更新。

总结一下,异步很香,但对一个Worker来说,只是等于W不变,batch的数量增加了而已,在SGD下,会减慢模型的整体收敛速度。异步的整体思想是,比起让Worker闲着,倒不如让它多吃点数据,虽然反馈延迟了,但只要它在干活在学习就行。

分布式数据并行(DDP)

受通讯负载不均的影响,DP一般用于单机多卡场景。
因此,DDP作为一种更通用的解决方案出现了,既能多机,也能单机。

  • DDP首先要解决的就是通讯问题:将Server上的通讯压力均衡转到各个Worker上。实现这一点后,可以进一步去Server,留Worker。

聚合梯度 + 下发梯度这一轮操作,称为AllReduce。

接下来我们介绍目前最通用的AllReduce方法:Ring-AllReduce。它由百度最先提出,非常有效地解决了数据并行中通讯负载不均的问题,使得DDP得以实现。

  • 太妙了,直接看图吧

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

  • 通过分组累加到此,每一块GPU上都有了一个分块完整的梯度,即1/4的完整梯度。谓之,Reduce-Scatter

  • 下面就是将每个分开的1/4完整的梯度同步到其他的GPU,也就是替换操作。谓之,ALLTOGETHER

在这里插入图片描述
在这里插入图片描述

  • 再迭代两轮,就OK了

在这里插入图片描述

在这里插入图片描述
小结一下:

朴素数据并行(DP)与分布式数据并行(DDP)。两者的总通讯量虽然相同,但DP存在负载不均的情况,大部分的通讯压力集中在Server上,而Server的通讯量与GPU数量呈线性关系,导致DP一般适用于单机多卡场景。而DDP通过采用Ring-AllReduce这一NCCL操作,使得通讯量均衡分布到每块GPU上,且该通讯量为一固定常量,不受GPU个数影响,因此可实现跨机器的训练。

  • DDP做了通讯负载不均的优化,但还遗留了一个显存开销问题:数据并行中,每个GPU上都复制了一份完整模型,当模型变大时,很容易打爆GPU的显存

ZeRO

由微软开发的ZeRO(零冗余优化),它是DeepSpeed这一分布式训练框架的核心,被用来解决大模型训练中的显存开销问题。

  • ZeRO的思想就是用通讯换显存

在这里插入图片描述
存储主要分为两大块:Model States和Residual States
Model States指和模型本身息息相关的,必须存储的内容,具体包括:

optimizer states:Adam优化算法中的momentum和variance
gradients:模型梯度
parameters:模型参数W

Residual States指并非模型必须的,但在训练过程中会额外产生的内容,具体包括:

activation:激活值。在流水线并行中我们曾详细介绍过。在backward过程中使用链式法则计算梯度时会用到。有了它算梯度会更快,但它不是必须存储的,因为可以通过重新做Forward来算它。
temporary buffers: 临时存储。例如把梯度发送到某块GPU上做加总聚合时产生的存储。
unusable fragment memory:碎片化的存储空间。虽然总存储空间是够的,但是如果取不到连续的存储空间,相关的请求也会被fail掉。对这类空间浪费可以通过内存整理来解决。


混合精度运算
精度混合训练,对于模型,我们肯定希望其参数越精准越好,也即我们用fp32(单精度浮点数,存储占4byte)来表示参数W。但是在forward和backward的过程中,fp32的计算开销也是庞大的。那么能否在计算的过程中,引入fp16或bf16(半精度浮点数,存储占2byte),来减轻计算压力呢?

于是,混合精度训练就产生了

在这里插入图片描述

  • 存储一份fp32的parameter,momentum和variance(统称model states)
  • 在forward开始之前,额外开辟一块存储空间,将fp32 parameter减半到fp16 parameter。
  • 正常做forward和backward,在此之间产生的activation和gradients,都用fp16进行存储。
  • 用fp16 gradients去更新fp32下的model states。
  • 当模型收敛后,fp32的parameter就是最终的参数输出。

也就是,模型参数存储时使用fp32,模型fw,bw计算时使用fp16

即设模型参数w为 ϕ \phi ϕ
在这里插入图片描述

  • 因为采用了Adam优化,所以才会出现momentum和variance,adam好像很费内存

这里暂不将activation纳入统计范围,原因是:

  1. activation不仅与模型参数相关,还与batch size相关
  2. activation的存储不是必须的。存储activation只是为了在用链式法则做backward的过程中,计算梯度更快一些。但你永远可以通过只保留最初的输入X,重新做forward来得到每一层的activation(虽然实际中并不会这么极端)。

因为activation的这种灵活性,纳入它后不方便衡量系统性能随模型增大的真实变动情况。因此在这里不考虑它,在后面会单开一块说明对activation的优化。


知道了什么东西会占存储,以及它们占了多大的存储之后,我们就可以来谈如何优化存储了。
注意到,在整个训练中,有很多states并不会每时每刻都用到,举例来说;

  1. Adam优化下的optimizer states只在最终做update时才用到
  2. 数据并行中,gradients只在最后做AllReduce和updates时才用到
  3. 参数W只在做forward和backward的那一刻才用到

所以,ZeRO想了一个简单粗暴的办法:如果数据算完即废,等需要的时候,我再想办法从个什么地方拿回来,那不就省了一笔存储空间吗?


优化状态分割

在这里插入图片描述

  • 优化参数在模型的W中,优化状态分割的意思是把W切开,每一个GPU单独自己更新属于自己的,然后在同步一下

在这里插入图片描述
在这里插入图片描述

  • 显存下降的非常明显, 在增加1.5倍单卡通讯开销的基础上,将单卡存储降低了4倍。看起来是个还不错的trade-off

接着切,同理可得,切梯度G
在这里插入图片描述
每块GPU用自己对应的O和G去更新相应的W。更新完毕后,每块GPU维持了一块更新完毕的W。同理,对W做一次All-Gather,将别的GPU算好的W同步到自己这来
在这里插入图片描述
全部都切开!!!

每块GPU置维持对应的optimizer states,gradients和parameters
在这里插入图片描述

最后数据并行的流程如下
(1)每块GPU上只保存部分参数W。将一个batch的数据分成3份,每块GPU各吃一份。
(2)做forward时,对W做一次All-Gather,取回分布在别的GPU上的W,得到一份完整的W,单卡通讯量 。forward做完,立刻把不是自己维护的W抛弃。
(3)做backward时,对W做一次All-Gather,取回完整的W,单卡通讯量 。backward做完,立刻把不是自己维护的W抛弃。
(4)做完backward,算得一份完整的梯度G,对G做一次Reduce-Scatter,从别的GPU上聚合自己维护的那部分梯度,单卡通讯量 。聚合操作结束后,立刻把不是自己维护的G抛弃。
(5)用自己维护的O和G,更新W。由于只维护部分W,因此无需再对W做任何AllReduce操作。

在这里插入图片描述

ZeRO - R

现在来看对residual states的优化

固定大小的内存buffer:

  • 提升带宽利用率。当GPU数量上升,GPU间的通讯次数也上升,每次的通讯量可能下降(但总通讯量不会变)。数据切片小了,就不能很好利用带宽了。所以这个buffer起到了积攒数据的作用:等数据积攒到一定大小,再进行通讯。
  • 使得存储大小可控。在每次通讯前,积攒的存储大小是常量,是已知可控的。更方便使用者对训练中的存储消耗和通讯时间进行预估。

设置机制,对碎片化的存储空间进行重新整合,整出连续的存储空间

防止出现总存储足够,但连续存储不够而引起的存储请求fail

ZeRO-Offload

最后,简单介绍一下ZeRO-Offload。它的核心思想是:显存不够,内存来凑。如果我把要存储的大头卸载(offload)到CPU上,而把计算部分放到GPU上,这样比起跨机,是不是能既降显存,也能减少一些通讯压力呢?

  • forward和backward计算量高,因此和它们相关的部分,例如参数W(fp16),activation,就全放入GPU。
  • update的部分计算量低,因此和它相关的部分,全部放入CPU中。例如W(fp32),optimizer states(fp32)和gradients(fp16)等

核心思想是:显存不够,内存来凑。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/960159.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【OpenCV实战】4.OpenCV 五种滤波使用实战(均值、盒状、中值、高斯、双边)

OpenCV 五种滤波使用实战(均值、盒状、中值、高斯、双边) 〇、Coding实战内容一、滤波、核和卷积1.1 滤波1.2 核 & 滤波器1.3 公式1.4 例子 二、图片边界填充实战2.1 解决问题2.2 相关OpenCV函数2.3 Code 三. 均值滤波实战3.1 理论3.2 Blur3.3 Code 四. 盒状滤波…

集成跨境电商ERP(积加、易仓、马帮等)连接多个应用

场景描述: 基于跨境电商开放平台(积加、易仓、马帮等)能力,无代码集成跨境电商ERP与多个应用互通互连。通过Aboter可搭建业务自动化流程,实现多个应用之间的数据连接。 连接器: 积加ERP马帮ERP易仓ERP……

Visual Studio Code 终端配置使用 MySQL

Visual Studio Code 终端配置使用 MySQL 找到 MySQL 的 bin 目录 在导航栏中搜索–》服务 找到MySQL–>双击 在终端切换上面找到的bin目录下输入指令 终端为Git Bash 输入命令 ./mysql -u root -p 接着输入密码,成功在终端使用 MySQL 数据库。

【LLM】快速开始 LangChain

theme: orange LangChain是一个软件开发工具包,它通过将组件链接在一起并公开简单统一的API,简化了大型语言模型和应用程序的集成。本篇文章将会简要介绍,让各位开发者对其有一个整体的认识。 前言 如果你是一名软件开发人员,努力…

chatGPT讲师AIGC讲师叶梓:大模型这么火,我们在使用时应该关注些什么?-5

以下为叶老师讲义分享: P20-P24 顺便看看某大模型觉得“两头蛇”长啥样? “羊驼-2”的神逻辑 欣赏一下GPT-4给出的满分答案 提示工程的模式 1、说明模式下,您为 ChatGPT 输入内容来解释或阐明一个概念或理论。 它的主要功能是定义各种概念。…

设计封面有诀窍,这5个实用软件让你快人一步

每位作者和出版商都梦想着为他们的作品设计一个引人注目的封面。这样一来,潜在的读者才会被吸引,愿意拿起这本书来阅读,从而提高书籍的销售量。这正是封面设计软件发挥作用的地方。专业的封面设计软件能够添加前沿的效果,以呈现书…

使用代理服务器和pip安装软件包

在开着代理服务器的情况下,直接pip install 软件包名会出现如下错误, WARNING: Retrying (Retry(total4, connectNone, readNone, redirectNone, statusNone)) after connection broken by SSLError(SSLZeroReturnError(6, TLS/SSL connection has been…

【Cadence】Calculator计算sp的3dB带宽

【Cadence】Calculator计算sp的3dB带宽 1.计算最大增益2.cross函数3. 3dB带宽 下面演示如何在Cadence计算s参数(如增益)的3dB带宽 1.计算最大增益 ymax函数 2.cross函数 cross函数可以计算经过y轴给定值对应的x坐标 edge number选择1是经过的第一个点…

B081-Lucene+ElasticSearch

目录 认识全文检索概念lucene原理全文检索的特点常见的全文检索方案 Lucene创建索引导包分析图代码 搜索索引分析图代码 ElasticSearch认识ElasticSearchES与Kibana的安装及使用说明ES相关概念理解和简单增删改查ES查询DSL查询DSL过滤 分词器IK分词器安装测试分词器 文档映射(字…

火热报名中 | 网安朝阳·西门子白帽黑客大赛燃爆来袭

2022年 首届西门子白帽黑客大赛 集结全国网安精英 以热爱之名 引爆整个夏天 2023年 网安朝阳西门子白帽黑客大赛—— 国际精英挑战赛 再度重磅归来 网安骑士的荣耀角斗场 等你来战 赛宁网安持续为第二届赛事 提供全程服务支持 热血战役 即将打响 报名通道现已开启…

风险评估

风险评估概念 风险评估是一种系统性的方法,用于识别、评估和量化潜在的风险和威胁,以便组织或个人能够采取适当的措施来管理和减轻这些风险。 风险评估的目的 风险评估要素关系 技术评估和管理评估 风险评估分析原理 风险评估服务 风险评估实施流程

SQLAlchemy 封装的工具类,数据库pgsql(数据库连接池)

1.SQLAlchemy是什么? SQLAlchemy 是 Python 著名的 ORM 工具包。通过 ORM,开发者可以用面向对象的方式来操作数据库,不再需要编写 SQL 语句。 SQLAlchemy 支持多种数据库,除 sqlite 外,其它数据库需要安装第三方驱动。…

专访远航汽车远勤山:踏踏实实做好产品 直面挑战乘风远航

8月25日,第二十六届成都国际汽车展览会在中国西部国际博览城隆重开幕。车展举办期间,远航汽车董事长远勤山先生、产品研发总监王震先生向媒体分享了远航汽车品牌发展、产品研发、技术创新以及市场布局等内容。 “通过我们的付出和努力,让我们…

景芯SoC 芯片全流程培训

【全网唯一】景芯SoC是一款用于芯片全流程培训的低功耗ISP图像处理SoC,采用低功耗RISC-V处理器,内置ITCM SRAM、DTCM SRAM,集成包括MIPI、ISP、CNN、QSPI、UART、I2C、GPIO、百兆以太网等IP,采用SMIC40工艺设计流片。 培训数据包括…

云计算在智能制造中的应用与前景

文章目录 云计算的基本概念智能制造的基本概念云计算在智能制造中的应用1. 数据存储和管理2. 大数据分析3. 机器学习和预测维护4. 跨地理分布的协作5. 资源弹性和成本优化 未来前景1. 智能工厂2. 预测性维护3. 定制化生产4. 绿色生产5. 全球制造协作 结论 🎉欢迎来到…

QTday3(QT实现文件对话框保存操作、实现键盘触发事件【WASD控制小球的移动】)

1.实现文件对话框保存操作 #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }void Widget::on_fontBtn_clicked() {//调用QFo…

Java的23种设计模式

Java的23种设计模式 一、创建型设计模式1.单例模式 singleton1.1.静态属性单例模式1.2 静态属性变种1.3 基础的懒汉模式1.4 线程安全的懒加载单例1.5 线程安全的懒加载 单例-改进1.6 双重检查锁1.7 静态内部类1.8 枚举单例1.9 注册表单例 2.工厂方法模式 factory3.抽象工厂模式…

Error:Java:无效的源发行版:14

问题描述:项目拉下来,跑的时候发现版本有问题。这个问题可好解决了,只需要看下面几个方面,然后让他们保持一致就OK了 step1:查看本地的jdk版本 打开cmd窗口,输入命令 java -version就可以查看到本地的jdk版…

dji uav建图导航系列(三)模拟建图、导航

前面博文【dji uav建图导航系列()建图】、【dji uav建图导航系列()导航】 使用真实无人机和挂载的激光雷达完成建图、导航的任务。 当需要验证某一个slam算法时,我们通常使用模拟环境进行测试,这里使用stageros进行模拟测试,实际就是通过模拟器,虚拟一个带有传感器(如…

如何一键批量查询全部物流信息?

在日常工作中,快递物流信息的查询是一项常规任务。然而,这个过程往往既耗时又费力,尤其是在面对大量单号的情况下。为了解决这个问题,我们推荐使用固乔快递查询助手,一款能够快速、准确地查询快递物流信息的软件。 首先…