并行优化策略

news2024/11/22 8:55:33

并行优化策略汇总

并行优化策略

数据并行(DP)

将数据集分散到m个设备中,进行训练。得到训练数据后在进行allreduce操作。确保每个worker都有相同模型参数。

整体流程如下

  • 若干块计算GPU,如图中GPU0~GPU2;1块梯度收集GPU,如图中AllReduce操作所在GPU。
  • 在每块计算GPU上都拷贝一份完整的模型参数。
  • 把一份数据X均匀分给不同的计算GPU。
  • 每块计算GPU做一轮FWD和BWD后,算得一份梯度G。
  • 每块计算GPU将自己的梯度push给梯度收集GPU,做聚合操作。这里的聚合操作一般指梯度累加。当然也支持用户自定义。
  • 梯度收集GPU聚合完毕后,计算GPU从它那pull下完整的梯度结果,用于更新模型参数W。更新完毕后,计算GPU上的模型参数依然保持一致。
  • 聚合再下发梯度的操作,称为AllReduce
梯度异步更新


上图刻画了在梯度异步更新的场景下,某个Worker的计算顺序为:

  • 在第10轮计算中,该Worker正常计算梯度,并向Server发送push&pull梯度请求。
  • 但是,该Worker并不会实际等到把聚合梯度拿回来,更新完参数W后再做计算。而是直接拿旧的W,吃新的数据,继续第11轮的计算。这样就保证在通讯的时间里,Worker也在马不停蹄做计算,提升计算通讯比。
  • 当然,异步也不能太过份。只计算梯度,不更新权重,那模型就无法收敛。图中刻画的是延迟为1的异步更新,也就是在开始第12轮对的计算时,必须保证W已经用第10、11轮的梯度做完2次更新了。

张量并行(TP)

它的基本思想就是把模型的参数纵向切开,放到不同的GPU上进行独立计算,然后再做聚合。

切分权重方式:

设输入数据为X,参数为W。X的维度 = (b, s, h),W的维度 = (h, h')。

横向

当GPU数量为2的时候,将权重分成2份,再将input按照h分开成2列即可做矩阵乘法。

反向传播时


 g  的 backward:

假定现在我们要对  W_i  求梯度,则可推导出:

\frac{\partial L}{\partial W_i} = \frac{\partial L}{\partial Y} \ast \frac{\partial Y}{\partial W_i}

因为Y = Y_1+Y_2所以\frac{\partial Y}{\partial Y_i}=1

也就是说,只要把\frac{\partial L}{\partial Y}同时广播到两块 GPU 上,两块 GPU 就可以独立计算各自权重的梯度了。


的 backward:

在上图中,我们只画了模型其中一层的计算过程。当模型存在多层时,梯度要从上一层向下一层传播。例如图中,梯度要先传播到  X ,然后才能往下一层继续传递。这就是  f  的 backward 的作用。

这里也易推导出:

\frac{\partial L}{\partial X} = \text{concat}\left[\frac{\partial L}{\partial X_1}, \frac{\partial L}{\partial X_2}\right]

• 如果X被分解为多个子部分,则需要先对每个部分计算梯度。

• 最后,通过 concat 将这些分部分梯度拼接,得到输入  X  的整体梯度。

纵向

将w按照列切分

反向传播

f与g的区别

f与g都是算子

f层的作用

前向操作:

将输入 X 分成多个部分,分别通过不同的分支(如 XW1 和 XW2)。

反向传播:

梯度来自多个路径,需要相加汇总

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial X}|_1 + \frac{\partial L}{\partial X}|_2 

\frac{\partial L}{\partial X}|_i表示在第i块gpu上计算x的梯度

g层的作用

前向操作:

合并多个分支结果(如将 Y1 和 Y2 拼接成 Y)。

反向传播:

梯度从整体拆分为各分支,分配到对应的输出:

Pipeline并行(PP)

将模型分成不同层,每层放到一块卡上。

模型做forward和backward过程:

 gpu0中做一次forward,然后数据传入到gpu1,然后gpu1处理完再传到gpu2.。。。

backward反过来即可。

缺点就是利用率太低了

1) 当gpu数量越多,空闲的比例越大,gpu资源都浪费了。

2) 中间结果的内存占用量太大,当backwards的时候会使用forward时的中间结果,假如有n层模型,中间结要存n-1次。

GPipe

把原先的数据再划分成若干个batch,送入GPU进行训练

 Gpipe通过实验证明,当 M>=4K 时,bubble产生的空转时间占比对最终训练时长影响是微小的,可以忽略不计。

内存问题:

1) 检查点机制

只保存上一块GPU的激活值,剩下的结果等backwards时再进行计算。

如何计算:使用上一块的激活值再执行一遍forward得到结果

2) PipeDream

每块GPU峰值时刻存储大小 = 每块GPU上的输入数据大小 + 每块GPU在forward过程中的中间结果大小

使用Pytorch提供的pipeline接口,其中有一个参数叫checkpoint可以实现PP的检查点机制。

PipeDream

GPipe需要等所有的microbatch前向传播完成后,才会开始反向传播。PipeDream则是当一个microbatch的前向传播完成后,立即进入反向传播阶段。(1F1B)

PipeDream在bubbles上与GPipe没有区别,但是由于PipeDream释放显存的时间更早,因此会降低对显存的需求。

混合并行

DP与PP并行

如图所示,GPU13 和GPU24 分别为两个子组,他们的参数相同,形成dp并行;

GPU12和GPU34分别组成一个完整的模型,形成了pp并行。

3D并行

3D并行是由DP、TP和PP组成。

例:

这个并行策略结合了 流水线并行张量并行数据并行,是一种高效利用多GPU的分布式训练方法。

1. 流水线并行

划分方式

• 将整个模型分为 4 份(model_1model_4)。

• 每连续的 4 张 GPU(颜色相同的rank) 负责一个子模型model_i

目的

让不同的 GPU 分别负责模型的不同部分,以减小单 GPU 的计算压力,同时利用流水线方式提升吞吐量。

2. 张量并行

划分方式

• 针对每个 model_i,再细分到多个 GPU 以并行计算单个张量。

• 每个 张量并行组 包含 2 个 GPU。

目的

将张量的计算分散到多个 GPU 上,提高单个子模型的计算效率。

3. 数据并行

划分方式

• 保证并行组中使用相同模型参数的 GPU 读取相同的数据。

• 例如,Rank0 和 Rank2 负责相同的参数,因此它们组成一个 数据并行组

目的

在不同设备上同步模型权重,保证参数一致性,支持大规模数据训练。

结果:

假设有 16 张 GPU(Rank0 到 Rank15):

流水线并行:每 4 张 GPU 分配到一个模型块。

张量并行:每 2 张 GPU 形成一个张量并行组。

数据并行:跨张量并行组同步相同参数的 GPU 形成数据并行组。

分析

3D并行需要平衡显存效率、计算效率和通信开销.

分配策略

1. 优先模型并行组放置在一个节点内

原因:模型并行(包括张量并行和流水线并行)是三种策略中通信开销最大的。

策略:尽量将张量并行组安排在同一个节点内,以利用较大的节点内带宽,减少跨节点通信。

2. 流水线并行跨节点调度

原因:流水线并行通信量最低,不会受限于较低的节点间带宽。

策略:流水线的不同分段可以分布在不同节点上。

3. 数据并行是否跨节点

策略

• 如果张量并行没有跨节点,则数据并行也不需要跨节点;

• 如果张量并行跨节点,则数据并行需要跨节点进行同步。

综合总结

显存效率

• 流水线并行和张量并行通过任务拆分降低单 GPU 的显存需求;

• ZeRO-DP 进一步分布化存储优化器状态和梯度,进一步节省显存。

通信开销

• 优先减少跨节点通信:将模型并行组(特别是张量并行组)优先安排在节点内。

• 流水线并行适合跨节点调度。

• 张量并行的设计决定了数据并行的通信需求,合理分配可以降低数据同步开销。

计算效率

• 模型划分不宜过细,以避免通信开销过大。

• 数据并行结合 ZeRO 技术可以实现显存效率和通信效率的平衡。

参考

PS:看了猛猿答主的文章,非常厉害的up。知乎地址为:https://www.zhihu.com/people/lemonround

https://zhuanlan.zhihu.com/p/617133971

https://zhuanlan.zhihu.com/p/622212228

https://zhuanlan.zhihu.com/p/613196255

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2245221.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决 Android 单元测试 No tests found for given includes:

问题 报错: Execution failed for task :testDebugUnitTest. > No tests found for given includes: 解决方案 1、一开始以为是没有给测试类加public修饰 2、然后替换 Test 注解的包可以解决,将 org.junit.jupiter.api.Test 修改为 org.junit.Tes…

知识见闻 - 数学: 均方根 Root Mean Square

What is Root Mean Square (RMS)? 在统计学上,均方根(RMS)是均方的平方根,而均方是一组数值的平方的算术平均数。均方根也称为二次均值,是指数为 2 的广义均值的一种特例。均方根也被定义为基于一个周期内瞬时值的平方…

寻找用户推荐人(考点:ifnull)【SQL+Pandas】

今天尝试刷一下力扣的sql面试题&#xff0c;这个写法我也是第一次见 题目是 我们需要在这个表中查出referee_id&#xff01;2的 正确写法是 select name from customer where ifnull(referee_id,0) ! 2 -- 不等于还可以这么写&#xff1a;<>

Java Database Connectivity (JDBC + Servlet)

Java Database Connectivity (JDBC)是一个Java API&#xff0c;用于与数据库进行连接和操作。通过JDBC&#xff0c;Java程序可以与各种关系型数据库进行通信&#xff0c;执行SQL查询、更新数据等操作。 一、Java连接数据库两种方式 ​​​​​ ​​ 二、Java中…

【Python爬虫实战】深入解析 Scrapy 爬虫框架:高效抓取与实战搭建全指南

&#x1f308;个人主页&#xff1a;易辰君-CSDN博客 &#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、Srapy简介 &#xff08;一&#xff09;什么是Srapy &#xff08;二&#xff09;Scrapy 的设计目标 …

编程之路,从0开始:动态内存管理

Hello&#xff0c;大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路。 我们今天来学习C语言中的动态内存管理。 目录 1、为什么要有动态内存管理&#xff1f; 2、malloc和free &#xff08;1&#xff09;malloc函数 &…

初始Python篇(4)—— 元组、字典

找往期文章包括但不限于本期文章中不懂的知识点&#xff1a; 个人主页&#xff1a;我要学编程(ಥ_ಥ)-CSDN博客 所属专栏&#xff1a; Python 目录 元组 相关概念 元组的创建与删除 元组的遍历 元组生成式 字典 相关概念 字典的创建与删除 字典的遍历与访问 字典…

d3-ease 的各种方法和示例

D3.js中的d3-ease模块提供了多种缓动函数&#xff0c;用于实现平滑的动画过渡效果。这些缓动函数通过扭曲时间控制动画中运动的方法&#xff0c;使得动画更加自然和流畅。以下是D3中常见的一些ease方法和示例代码&#xff1a; 线性缓动&#xff08;linear&#xff09;&#xff…

HTML5拖拽API学习 托拽排序和可托拽课程表

文章目录 前言拖拽API核心概念拖拽式使用流程例子注意事项综合例子&#x1f330; 可拖拽课程表拖拽排序 前言 前端拖拽功能让网页元素可以通过鼠标或触摸操作移动。HTML5 提供了标准的拖拽API&#xff0c;简化了拖放操作的实现。以下是拖拽API的基本使用指南&#xff1a; 拖拽…

Throwable、IO流、Java虚拟机

Error和Exception stream结尾都是字节流&#xff0c;reader和writer结尾都是字符流 两者的区别就是读写的时候一个是按字节读写&#xff0c;一个是按字符。 实际使用通常差不多。 在读写文件需要对内容按行处理&#xff0c;比如比较特定字符&#xff0c;处理某一行数据的时候一…

lanqiao OJ 364 跳石头

这个题目的条件是移动的石头数量给定&#xff0c;但是最小移动距离的最大值我们不知道&#xff0c;所以要通过mid来“猜测”。如果当前的mid需要移动的最小石头数量超过给定数&#xff0c;则mid不成立&#xff0c;需要缩小&#xff0c;反之则增大mid&#xff0c;直至找到一个最…

「一」HarmonyOS端云一体化概要

关于作者 白晓明 宁夏图尔科技有限公司董事长兼CEO、坚果派联合创始人 华为HDE、润和软件HiHope社区专家、鸿蒙KOL、仓颉KOL 华为开发者学堂/51CTO学堂/CSDN学堂认证讲师 开放原子开源基金会2023开源贡献之星 「目录」 「一」HarmonyOS端云一体化概要 「二」体验HarmonyOS端云一…

Shell基础(7)

声明&#xff01; 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探讨&#xff0c;文章所提到的网站以及内容&#xff0c;只做学习交流&#xff0c;其他均与本人以及泷羽sec团…

音视频pts/dts

现在的视频流有两个非常重要的时间戳&#xff0c;pts和dts&#xff0c;其中pts是显示的时候用&#xff0c;dts在解码的时候用。 pts很好理解&#xff0c;按照pts的顺序以及duration不间断的display就可以了。 dts在解码的时候用&#xff0c;那么这句话怎么理解&#xff0c;解…

sql server怎样用sql profiler捕获带变量值的慢sql

一 新建跟踪 点击工具-SQL Server Profiler&#xff1a; 点击文件-新建跟踪的按钮&#xff1a; 在‘事件选择’选项卡只选择如下两项内容&#xff08;RPC:Completed,SQL:BatchCompleted&#xff09;&#xff0c;把多余的取消勾选&#xff1a; 然后勾选上面截图中右下方的‘显示…

二叉树——输出叶子到根节点的路径

目录 代码 算法思想 例子 思维拓展 代码 int LeaveBit(Bitree T,int flag,int g) {if (!T) {return 0;}if (T->rchild NULL && T->lchild NULL) {//cout << "empty:" << T->data << endl;s.push(T->data);while (!s.emp…

深入理解Spring(三)

目录 2.1.3、Spring配置非自定义Bean 1)配置Druid数据源交由Spring管理 2)配置Connection交由Spring管理 3)配置日期对象交由Spring管理 4)配置MyBatis的SqlSessionFactory交由Spring管理 2.1.4、Bean实例化的基本流程 1)Bean信息定义对象-BeanDefinition 2)DefaultLi…

React Native 基础

React 的核心概念 定义函数式组件 import组件 要定义一个Cat组件,第一步要使用 import 语句来引入React以及React Native的 Text 组件: import React from react; import { Text } from react-native; 定义函数作为组件 const CatApp = () => {}; 渲染Text组件

SpringBoot,IOC,DI,分层解耦,统一响应

目录 详细参考day05 web请求 1、BS架构流程 2、RequestParam注解 完成参数名和形参的映射 3、controller接收json对象&#xff0c;使用RequestBody注解 4、PathVariable注解传递路径参数 5、ResponseBody&#xff08;return 响应数据&#xff09; RestController源码 6、统一响…

Linux:confluence8.5.9的部署(下载+安装+pojie)离线部署全流程 遇到的问题

原文地址Linux&#xff1a;confluence8.5.9的部署&#xff08;下载安装破ji&#xff09;离线部署全流程_atlassian-agent-v1.3.1.zip-CSDN博客 背景&#xff1a;个人使用2核4g 内存扛不住 总是卡住&#xff0c;但是流程通了所以 直接公司开服务器干生产 个人是centos7 公司…