Megatron-LM:Transformer模型专用分布式张量模型并行方法

news2024/10/7 12:22:17

论文标题:Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

论文链接:https://arxiv.org/abs/1909.08053

论文来源:NVIDIA

一、概述

随着自然语言处理领域预训练语言模型的规模变得越来越大,它们超过了现代处理器的内存限制,需要额外的内存管理技术,如激活检查点(activation checkpointing)。一些广泛使用的优化算法如Adam需要额外的内存来存储其中的动量和其他优化器状态,这降低了可以有效训练的模型大小。几种模型并行方法通过划分模型来克服这个限制,这样权重及其相关的优化器状态就不需要同时驻留在处理器上。例如,GPipe和Mesh-Tensorflow提供了不同种类的模型并行框架。但是,它们需要重写模型,并依赖于仍在开发中的自定义编译器和框架。

在这项工作中,我们使用简单高效的层内模型并行(intra-layer model-parallelism)来实现模型并行。我们利用transformer基础语言模型中的固有结构来实现一个简单的模型并行,它可以在PyTorch中高效训练,而无需自定义C++代码或编译器。这种方法与GPipe等基于流水线的模型并行方法是正交的(可以同时使用,相互独立而不冲突)。

为了证明我们方法的可扩展性,我们建立了一个baseline,在单个NVIDIA V100 32GB GPU上训练了一个12亿参数的模型,保持39 TeraFLOPs的计算速度。这是DGX-2H服务器中单GPU配置的理论峰值浮点运算能力的30%,因此是一个很强的基线。通过在512个GPU上以8路模型并行将模型扩展到83亿参数,我们实现了高达每秒15.1PetaFLOPs的持续计算速度。与单GPU情况相比,这代表了76%的扩展效率。下图显示了更详细的扩展结果。

12e22f6668caf742b8e8b6bba3152b54.jpeg
拓展效率

为了分析模型大小扩展对准确率的影响,我们训练了自回归的GPT-2语言模型以及自编码的BERT双向transformer,并在几个下游任务上对其进行评估。我们发现,随着模型大小的增加,现有的BERT架构会导致模型退化。我们通过重新排列transformer层中的层标准化和残差连接来克服这一挑战,结果表明,进行这一改变之后,下游任务的开发集结果随着模型大小的增加单调提升。此外,我们的模型在WikiText103、LAMBADA上的闭包式预测准确率以及RACE阅读理解数据集上都取得了SOTA的测试集结果。

二、背景

有两种主要的范式可以将深度神经网络的训练扩展到多个硬件加速器中:(1)数据并行,其中minibatch的训练被划分到多个worker中;(2)模型并行,其将模型的内存使用和计算分布到多个worker中。通过按可用worker的数量成比例增加minibatch大小(即weak scaling),可以观察到训练数据吞吐量的近乎线性的扩展。但是,大批量训练会使优化过程更复杂,可能导致准确率降低或收敛时间延长,反而抵消了增加训练吞吐量带来的好处。进一步的研究开发了各种技术来缓解这些影响,降低大型神经网络的训练时间。为了进一步扩展训练规模,一些并行工作将数据并行和激活检查点结合起来:在前向传播中重新计算而不是存储激活,以减少内存需求。

然而,这些技术在其可以处理的问题大小方面存在一个根本的局限:模型必须完全能够在一个worker上进行处理。随着BERT和GPT-2等语言模型大小和复杂度的增加,神经网络已经接近了现代硬件加速器的内存容量。这个问题的一个解决方案是采用参数共享来减少模型的内存占用,但这限制了模型的总体容量。我们的方法是利用模型并行将模型划分到多个加速器上。这不仅减轻了内存压力,而且独立于微批量大小增加了并行度。

在模型并行中,还有两种进一步的范式:逐层流水线并行(layer-wise pipeline parallelism)和更通用的分布式张量计算(distributed tensor computation)。在流水线模型并行中,一组操作首先在一个设备上执行,然后将输出传递到流水线中的下一个设备,在下一个设备上执行不同的另一组操作。一些方法与流水线并行结合使用参数服务器。然而,这些方法存在一致性问题。TensorFlow中的GPipe框架通过使用同步梯度下降来解决这种一致性问题。这种方法需要额外的逻辑来处理这些通信和计算操作的高效流水线,并受到减少效率的流水线bubble的影响,或者对优化器本身的更改会影响准确性。

分布式张量计算是一种正交的、更通用的方法,它将张量操作划分到多个设备上以加速计算或增加模型大小。编排这种并行计算的深度学习框架FlexFlow提供了一种选择最佳并行策略的方法。最近,Mesh-TensorFlow在TensorFlow中引入了一种指定分布式张量计算的通用类的语言。并行维度由终端用户在这一语言中指定,生成的图由适当的集体原语编译。我们利用类似于Mesh-TensorFlow中的见解,并利用transformer注意力头的并行计算来并行化我们的transformer模型。但是,我们没有实现一个用于模型并行的框架和编译器,而是仅对现有的PyTorch transformer实现进行了一些有针对性的修改。我们的方法很简单,不需要任何新的编译器或代码重写,可以通过插入几个简单的基元来完全实现,如下一节所述。

三、方法

我们利用transformer网络的结构来创建一个简单的模型并行实现,只需要添加几个同步原语。一个transformer层由一个自注意力块和一个两层多层感知器(MLP)组成,如下图所示。我们在这两个块中分别引入模型并行。

c6f12e2f9549319f4ccbad8eaacbd00a.jpeg
模型架构

首先详细说明MLP块。该块的第一部分是一个GEMM(General Matrix Multiplication),后面是一个GELU非线性层:

并行化GEMM的一种方法是按行切分权重矩阵,按列切分输入:

这将得到。由于GeLU是非线性函数,,这种切分方式需要在GeLU函数之前进行同步。

另一种切分方法是按列切分。这种切分方式允许独立地对每个切分后的GEMM的输出应用GeLU非线性函数:

这种方法的优点是去除了一个同步点。因此,我们采用这种列并行方式切分第一个GEMM,直接将GeLU层的输出作为第二个GEMM(这个GEMM以按行切分的方式并行)的输入,而不需要任何通信,如下图(a)所示。

d67f26b35663d284e73eb724246a9550.jpeg
并行方法

第二个GEMM的输出通过dropout层之前先跨GPU进行reduce。这种方法将MLP块中的两个GEMM切分到不同的GPU上,前向传播中只需要一个all-reduce操作(g操作符),反向传播中也只需要一个all-reduce(f操作符)。这两个操作符互为共轭,可以通过PyTorch中的几行代码实现。例如,f操作符的实现如下:

class f(torch.autograd.Function):
    def forward(ctx, x):
        return x
    
    def backward(ctx, gradient):
        all_reduce(gradient) 
        return gradient

如上图3(b)所示,对于自注意力块,我们利用了多头注意力操作中固有的并行性,以列并行的方式切分key(K)、query(Q)和value(V)对应的GEMM,这样每个注意力头对应的矩阵乘法在一个GPU上本地计算。这允许我们在GPU之间切分每个注意力头的参数和工作量,并且不需要任何直接的通信就可以完成自注意力。在自注意力之后的输出线性层的GEMM(自注意力之后)沿其行进行并行化,直接获取并行注意力层的输出,而不需要GPU之间的通信。MLP和自注意力这两种块的并行方法都融合了两组GEMM,消除了中间的同步点,从而获得了更好的扩展性。这使我们能够使用仅两个all-reduce在前向路径中完成transformer层中的所有GEMM的计算,并在反向路径中也使用两个all-reduce(如下图所示)。

a3c8e6266e497dae2b8ba587bf87a679.jpeg
通信操作

Transformer语言模型的输出嵌入维度为隐层维度乘以词汇表大小。由于词汇表的大小是万级的(例如,GPT-2使用的词汇表大小为50,257),将输出嵌入GEMM并行化是有益的。但是,在transformer语言模型中,输出嵌入层与输入嵌入共享权重,这需要同时修改这两者。我们沿词典维度切分输入嵌入权重矩阵为 (列向)。由于每个切分部分现在只包含嵌入表的一部分,在输入嵌入之后需要一个all-reduce(g操作符)。对于输出嵌入,一种方法是执行并行GEMM 以获得logits,添加一个all-gather ,并将结果发送到交叉熵损失函数。但是,在这种情况下,all-gather将通信个元素(其中是batch大小,是序列长度),由于词汇表大小很大,这会产生巨大的通信量。为了减小通信量,我们将并行GEMM 的输出与交叉熵损失函数融合,这将维度降低到。通信标量损失而不是logits极大地减少了通信量,这极大地提高了我们的模型并行方法的效率。

我们的模型并行方法的很大一部分可以归纳为针对减少通信并保持GPU计算限度的技术。对于dropout、层标准化、残差连接等计算,我们选择跨GPU重复计算,而不是让一个GPU计算部分然后广播结果到其他GPU。具体来说,我们在每个GPU上维护层标准化参数的副本,并在模型并行区域的输出上运行dropout和残差连接,然后将其馈送作为下一个模型并行区域的输入。对于模型的优化,我们允许每个模型并行worker优化自己的参数集合。由于所有的值要么是本地的(比如GEMM的参数),要么被复制在多个GPU上(比如层标准化的参数),所以在这种形式中不需要通信更新的参数值。

总而言之,我们上述的方法实现起来很监督,只需要在前向和反向传播中添加几个额外的all-reduce操作。它不需要编译器,与GPipe等方法提倡的流水线模型并行是正交互补的。

四、实验

  1. 并行效率

2762b22699f947197e68e10f2cd44211.jpeg
实验
  1. GPT-2实验

b4daec6ed03dcdfff1f3382f3b5f704b.jpeg
实验
cfa06617f0ae9d128d720cf853af58bc.jpeg
实验
f7536d013b51df0d466f3909313b2540.jpeg
实验
  1. BERT实验

在进行BERT相关的实验时,我们发现到模型参数规模超过BERT-large(336M)时会出现模型性能退化。我们发现按照下图7的方式重排层标准化与残差连接的顺序后可以解决这个问题。

d72392b51e38ba7bd767fe30617389d4.jpeg
实验
0bb562a8e3254ac46304e90c49c941bc.jpeg
实验
fdda36f21e32bb7dc0861fd513b03f66.jpeg
实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/783841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

access跨库查询

服务器上面安装了安全狗、Waf这样的安全软件,没有办法下载数据库内容 都在同一个服务器的不同网站,从11查12的数据库 12数据库路径在C:\wwwtest\2AspCMS\AspCms_data 把data.asp后缀改成mdb就能看到里面的表了,data.mdb如下 语句 当前网站…

CSS自学框架之表格和项目列表

表格和项目列表很直观的显示数据,是我们web开发中经常遇到的最简单表现信息形式。具体代码如下: 一、css代码 ul,ol{margin-left: 1.25em;} /* - 表格 */.myth-table{width: 100%;overflow-x: auto;overflow-y: hidden;border-radius: var(--radius);…

《Pytorch深度学习和图神经网络(卷 2)》学习笔记——第二章

基于图片内容的处理任务 主要包括目标检测、图片分割两大任务。 目标检测:精度相对较高,主要是以检测框的方式,找出图片中目标物体所在坐标。模型运算量相对较小,相对较快。 图片分割:精度相对较低,主要是…

【工具-jmeter】jmeter 入门级 demo 练习

目录 前言: 1. Jmeter 准备 1.1 jmeter 安装包下载 1.2 jmeter 启动 1.3 jmeter 语言选择 2. Jmeter 运行 1 个 Web 请求的 demo 2.1 添加 1 个 Thread Group 线程组 2.2 添加 1 个 HTTP Request 请求 2.3 乱码问题 2.4 添加 1 个 HTTP Header 请求头 2.…

开发中遇到的 cookie 问题

1. cookie 无法跨域携带问题 尽管已经登录,但是请求接口返回状态码:202,msg: 未登录,如下图所示; 1.1 XMLHttpRequest.withCredentials未设置 如果需要跨域 AJAX 请求发送 Cookie,需要withCre…

【UE】虚幻网络同步

UE网络官方文档链接:https://docs.unrealengine.com/5.2/zh-CN/networking-overview-for-unreal-engine/ 虚幻的网络模式 服务器作为游戏主机,保留一个真实授权的游戏状态。换句话说,服务器是多人游戏实际发生的地方。客户端会远程控制其在服…

SpringBoot Redis 使用Lettuce和Jedis配置哨兵模式

Redis 从入门到精通【应用篇】之SpringBoot Redis 配置哨兵模式 Lettuce 和Jedis 文章目录 Redis 从入门到精通【应用篇】之SpringBoot Redis 配置哨兵模式 Lettuce 和Jedis前言Lettuce和Jedis区别1. 连接方式2. 线程安全性 教程如下1. Lettuce 方式配置1.1. 添加 Redis 和 Let…

Java项目里添加python解析器

java项目里配置了SDK为1.8,添加python文件时会无法解析。 提示让模块配置Python解析器,点击 配置python解析器 ,弹出如下: 应用即可。

【机器学习】异常检测

异常检测 假设你是一名飞机涡扇引擎工程师,你在每个引擎出厂之前都需要检测两个指标——启动震动幅度和温度,查看其是否正常。在此之前你已经积累了相当多合格的发动机的出厂检测数据,如下图所示 我们把上述的正常启动的数据集总结为 D a t…

【Linux命令200例】chattr改变文件的扩展属性

🏆作者简介,黑夜开发者,全栈领域新星创作者✌,2023年6月csdn上海赛道top4。 🏆本文已收录于专栏:Linux命令大全。 🏆本专栏我们会通过具体的系统的命令讲解加上鲜活的实操案例对各个命令进行深入…

【人工智能】博弈、极小极大值、α-β剪枝、截断测试

文章目录 博弈极小极大值α-β剪枝截断测试博弈 极小极大值 假设两个玩家都以最大化自身利用进行博弈举例: 计算机假设在它移动后,对手会选择最小化的行动计算机在考虑自己的行动和对手的最佳行动后选择最佳行动算法实现

【python】在matlab中调用python

参考 Matlab调用Python - 知乎 (zhihu.com) 说一下我犯的错误: 1、电脑上有没有python都可以,我以为anaconda里的python不行,又重新下了一个python3.8 实际上导入的时候可以用 pyversion(D:\myDownloads\anaconda\envs\pytorch38\pytho…

Docker 全栈体系(五)

Docker 体系(高级篇) 二、DockerFile解析 1. 是什么? Dockerfile是用来构建Docker镜像的文本文件,是由一条条构建镜像所需的指令和参数构成的脚本。 1.1 概述 1.2 官网 https://docs.docker.com/engine/reference/builder/ 1…

freeBSD:ssh登录root

/etc/inetd.conf ee /etc/inetd.conf 去掉# /etc/rc.conf ee /etc/rc.conf 添加一句 sshd_enable"YES" /etc/ssh/sshd_config vi /etc/ssh/sshd_config 22行可以修改端口号,非必要就默认22 36行 去掉# 后面修改成 yes 61 PasswordAuthentication…

Python处理Elasticsearch

简介:Elasticsearch 是一个分布式、高扩展、高实时的搜索与数据分析引擎。它能很方便的使大量数据具有搜索、分析和探索的能力。充分利用Elasticsearch的水平伸缩性,能使数据在生产环境变得更有价值。Elasticsearch 的实现原理主要分为以下几个步骤&…

Golang数据库连接池技术原理与实现

1 为什么需要连接池? 如果不用连接池,而是每次请求都创建一个连接是比较昂贵的,因此需要完成3次tcp握手。同时在高并发场景下,由于没有连接池的最大连接数限制,可以创建无数个连接,耗尽文件描述符。连接池…

【软件测试】什么是selenium

1.seleniumJava环境搭建 前置条件: Java最低版本要求为8,浏览器使用chrome浏览器 1.1下载chrome浏览器 https://www.google.cn/chrome/ 1.2查看浏览器版本 点击关于Google chrome. 记住版本的前三个数. 1.3下载浏览器驱动 http://chromedriver.chromium.org/downloads 下载…

JS案例:在浏览器实现自定义菜单

目录 前言 设计思路 BaseElem Menu CustomElement BaseDrag Drag Resize 最终效果 总结 相关代码 前言 分享一下之前公司实现自定义菜单的思路,禁用浏览器右键菜单,使用自定义的菜单将其代替,主要功能有:鼠标右键调出菜…

二、基本数据类型和表达式

2.1数据类型 数据类型占用字节数取值范围bool1true 或 falsechar1-128 到 127 或 0 到 255 (取决于是否带符号)unsigned char10 到 255short2-32,768 到 32,767unsigned short20 到 65,535int4-2,147,483,648 到 2,147,483,647unsigned int40 到 4,294,…