超出认知的数据压缩 用1-bit数据来表示32-bit的梯度 语音识别分布式机器学习 梯度压缩 论文精读

news2024/11/17 3:12:01

说明

介绍 1 − b i t 1-bit 1bit论文内容。
原文链接:1-bit stochastic gradient descent and its application to data-parallel distributed training of speech DNNs | Semantic Scholar

ABS

实验证明在分布式机器学习的过程中能够通过将同步所传递的梯度进行量化(从 32 32 32位到 1 1 1位),同时加上量化容错机制能够很好的加快整个训练过程。最值得注意的时,模型的收敛情况并没有收到影响。

本文将该发现与 S G D , A d a G r a d SGD,AdaGrad SGD,AdaGrad, 自动批量选择, 双缓冲区,模型并行技术相结合设计了。意料之外的是,量化甚至能让精度有一个很小的提升。

不同的模型在使用该方法后加速效果明显。

1 Introduction and Related Work

现在上下文相关的深度神经网络模型基本都是通过反向传播进行训练(常用的例如 S G D SGD SGD),而上述的模型的训练是非常耗时的。

注:上下文相关的神经网络应该是指类似与解决翻译问题所用的模型。

使用分布式训练上述的模型已经取得了一定成功。(这里介绍了一部分相关工作,这里不进行赘述)

在分布式训练中,一个非常重要的问题就是带宽瓶颈的问题:即各个节点训练一定时间之后需要进行数据的发送(这不管是模型并行还是数据并行都会发生),而这个过程交换的数据量往往取决于模型的大小,而现在使用的模型一般都比较大,所以网络的带宽成为了分布式机器学习的一个瓶颈。

缓解这个方法一般有两类方法:

  1. 增加批量大小,这样每轮会进行的计算会增加,而参与通信的时间会相对减少;
  2. 减少每一次通信的数据交换量。

本文提出的方法属于第二种,同时本文更多关注的是将量化方法应用于数据并行的分布式训练中。

2 Data-Parallel Deterministically Distributed SGD Training

通常采用 B P BP BP进行模型的训练,过程可以被概述为下列的两个方程:
KaTeX parse error: Undefined control sequence: \part at position 114: …^{t+N-1}\frac {\̲p̲a̲r̲t̲ ̲F_{\lambda}(o(\…
上述过程对应反向传播求导和梯度下降进行更新。

2.1 Data-Parallel Distributed SGD

上述的方程可以进行分布式计算,只需要将方程 ( 2 ) (2) (2)的梯度计算部分按照节点的个数,让每个处理一部分数据,然后进行梯度的计算,计算完成后,相加即可得到某个时刻的梯度。

p e r f e c t   o v e r l a p perfect\ overlap perfect overlap:选取合适的工作节点个数能够让计算和数据交换进行最优的并行化,也就是通信资源和计算资源同时饱和。
T c a l c ( K ^ ) = T c o m m ( K ^ ) T_{calc}(\hat K)=T_{comm}(\hat K) Tcalc(K^)=Tcomm(K^)
也就是选取 K K K让每轮的通信时间和计算时间相等,这样当通信完成是下一轮的计算也完成可以继续进行通信。

计算时间可以被分解成固定的计算时间和可变的计算时间两部分。固定的计算时间一般是完成一些必要的操作时间所花费的时间,而可变计算时间往往会根据模型的大小,批量大小发生改变。

F i g u r e   1 Figure\ 1 Figure 1展示了当节点个数下降到通信时间与计算时间相等时获得的加速。

在这里插入图片描述

K ^ \hat K K^可以通过如下公式进行计算:
K ^ = N / 2 ∗ T c a l c f r m + C ∗ T c a l c p o s t 1 Z T c o m m f l o a t − T c a l c u p d \hat K = \frac {N/2*T^{frm}_{calc}+C*T^{post}_{calc}}{\frac 1 ZT^{float}_{comm}-T^{upd}_{calc}} K^=Z1TcommfloatTcalcupdN/2Tcalcfrm+CTcalcpost
各参数的含义:(设模型的大小为 M M M

  • N N N:批量的大小。
  • T c a l c f r m T^{frm}_{calc} Tcalcfrm:计算梯度所花费的时间,大约为 M F L O P S \frac M {FLOPS} FLOPSM
  • C C C:一个常量,使用特殊方法处理所携带的常数。
  • T c a l c p o s t T^{post}_{calc} Tcalcpost:后续进行处理所需要的时间,例如使用 A d a G r a d , m o m e n t u m AdaGrad, momentum AdaGrad,momentum需要花费的额外时间(需要和 C C C相乘才能获取实际时间),大约为 M r \frac M r rM r r r为内存的带宽。
  • Z Z Z:数据传输之前的压缩率,在本文中将 32 32 32位压缩成 1 1 1位,那么 Z = 32 Z=32 Z=32
  • T c o m m f l o a t T_{comm}^{float} Tcommfloat:不压缩梯度进行通信所需要花费的时间,大约为 M b \frac M b bM b b b为两个结点之间的网络带宽。
  • T c a l c u p d T^{upd}_{calc} Tcalcupd:参数服务器将受到的梯度用于更新所需要的时间,大约为 M r \frac M r rM

上面的公式的理解:

分母是实际进行通信的时间,分子计算所有 N N N个数据会花费的时间,而前面的推断需要让每个结点的计算时间和通信时间相等,由于在分布式机器学习中每个结点拥有的数据量是相同的,所以分子的计算时间除以 K K K就是每个结点的计算时间,让两者相等并将 K K K移到方程的一边即可得到上面的公式。上面计算时间的计算出现了一个除以 2 2 2,这在下一小节会解释。

2.2 Double Buffering with Half Batches

上一节的公式里面出现了除以 2 2 2是因为使用了双缓冲区。整个过程将每一个批量会分成大小相等的两部分放入缓冲区中,当前一部分计算完成后,开始计算后一部分,这个时候前一部分即可进行通信,当后一部分计算完成的时候,即可进行通信,此时又可以从缓冲区读取下一个批量的前一部分,所以每次实际上只计算了整个批量的一半,所以上面会出现除以 2 2 2

2.3 Potential Faster-Than-Fixed-Cost Communication

当通信时间降到固定计算时间之下时,那么此时上面的计算公式将不再适用。此时整个系统的限制在于固定的计算时间,通信很难饱和。此时在使用双缓冲区就没有什么意义了,因为此时的通信带宽是足够的,而双缓冲区是为了缓解通信的压力。

2.4 Relation to Hogwild/ASGD

异步更新能够增加并行度,但是并没有改变基础性的东西。

3 1-Bit SGD with Error Feedback

将需要交换的数据进行压缩,将其压缩为一位的数据,这样能够减少数据的交换量,从而减少通信的瓶颈。

不过如果只是简单的将数据压缩成一位,然后任由 S G D SGD SGD进行缺失数据的修正,那么这是很容易导致模型发散的。

为了防止模型的发散提出了错误反馈机制,该机制并不会将丢失的精度给丢弃掉,而是会对其进行记录下一次更新的时候需要同时考虑之前丢失的精度。

整个过程可以用下面的公式来表述:
G q u a n t ( t ) = Q ( G ( t ) + Δ ( t − N ) ) Δ ( t ) = G ( t ) − Q − 1 ( G q u a n t ( t ) ) \begin{aligned} &G^{quant}(t) = Q(G(t)+\Delta(t-N))\\ &\Delta(t)=G(t)-Q^{-1}(G^{quant}(t)) \end{aligned} Gquant(t)=Q(G(t)+Δ(tN))Δ(t)=G(t)Q1(Gquant(t))
参数说明:

  • G G G:本轮计算出来的梯度。
  • Q Q Q:量化函数,具体的,如果输入大于 0 0 0,则量化成 1 1 1,输出小于 0 0 0,则量化成 0 0 0,这样就可以将原本需要用 32 32 32位表示的梯度只用 1 1 1位进行表示。
  • G q u a n t G^{quant} Gquant:量化后的梯度。
  • Q − 1 Q^{-1} Q1:反量化函数,输入只能是 0 0 0或者 1 1 1,当输入是 1 1 1的时候输出是 1 1 1,当输入是 0 0 0的时候,输出是 − 1 -1 1
  • Δ \Delta Δ:梯度误差。

上述的过程就是压缩之后会记录压缩误差,压缩误差在下一轮的时候继续参与压缩。

3.1 Aggregating the Gradients

本文使用的聚合算法的复杂度关于结点个数是 O ( 1 ) O(1) O(1)(这里的描述有一点奇怪,但是大概的意思就是每个结点聚合的数据量为所有梯度的一部分)。

具体的过程如下:

  • 如果有 K K K个结点,那么每个结点会处理 1 K \frac 1 K K1的梯度。
  • 每个结点会从其他的 K − 1 K-1 K1个结点接收属于自己处理的梯度部分。
  • 收到后结点聚合自己负责的这一部分梯度。
  • 聚合完成后分发给其他所有结点。

4 System Description

根据最佳节点个数的计算公式,至少存在三种方法能够提升并行度:(也就是通过改变变量让公式的结果变大)

  1. 提升 N N N:增加每轮处理的批量个数。
  2. 增加 Z Z Z:增加压缩度。
  3. 减少固定计算时间(固定计算时间减少意味着相同的计算时间中有更多的时间用于了梯度的计算)

本文的压缩算法属于第二种方法。

对于方法一本文的实验发现 N N N的增加是有一个限制的,当增大的太多时,整个模型可能会发散。同时本文发现一个成熟的模型能够处理的 N N N的值要大一些。

为了防止模型发散的情况,本文后面的实验会隔一定的时间增大 N N N而不是一开始就将 N N N设置的很大,这样能够防止模型发散或是准确率下降严重。

除了这些之外,学习率也是采用递减的方式。

最后使用了 A d a G r a d AdaGrad AdaGrad进行优化,这样模型会收敛的更快,同时这也使得批量大小能够进一步的增加。

本文的系统可以在三个不同的地方使用 A d a G r a d AdaGrad AdaGrad

  • 在本地梯度量化之前;(可能会导致不一致性,但是可能对量化有益)
  • 聚合结束后的数据交换期间;(可能会与量化冲突)
  • 使用动量平滑后;(可以减少内存的使用和固定计算时间但是效果不好,因为峰值被动量磨平了)

作者发现 A d a G r a d AdaGrad AdaGrad在量化后动量平滑前使用效果最好。

为了也利用方法 3 3 3,本文在使用数据并行的同时,在多 G P U GPU GPU上做了模型并行。

5 Experimental Results

实验的细节可以在原文中找到。
论文的实验做的都是语音识别相关的,所以并没有证明所有方面都适合该方法。

5.1 Cost Measurements

这一部分主要测量几个耗时。

T c o m m f l o a t T^{float}_{comm} Tcommfloat大约为 3 − 10 m s 3-10ms 310ms$。

T c a l c p o s t + T c a l c u p d = 18.2 m s T^{post}_{calc}+T^{upd}_{calc}=18.2ms Tcalcpost+Tcalcupd=18.2ms

T c a l c u p d ≈ 9 T^{upd}_{calc}\approx9 Tcalcupd9

T a b l e   1 Table\ 1 Table 1给出了 T c a l c f r m T^{frm}_{calc} Tcalcfrm与批量大小的关系。

在这里插入图片描述

5.2 Effect of 1-Bit Quantization

T a b l e   2 Table\ 2 Table 2展示了不同的模式下的三种方式的对比,可以看出 1 − b i t 1-bit 1bit的效果并没有受到太大的影响。

在这里插入图片描述

5.3 When to do AdaGrad?

T a b l e   3 Table\ 3 Table 3展示了在不同环节使用 A d a G r a d AdaGrad AdaGrad进行优化对于准确率的影响,可以看出应该在动量平衡之前使用错误率会更低,作者指出这可能是因为动量平滑减少了梯度的标准差从而导致 A d a G r a d AdaGrad AdaGrad的效果不好。

p a r t i a l   g r a d i e n t s partial\ gradients partial gradients代表量化前各个结点自身的梯度, a g g r e g a t e   g r a d i e n t aggregate\ gradient aggregate gradient代表量化后聚合的梯度。

在这里插入图片描述

5.4 Impact of MB-Size Selection and Double Buffering

这一部分讲的是选取特别大的批量的时间耗时。从 T a b l e   3 Table\ 3 Table 3中每一栏花费的时间可以得到一些信息。

第二行的时间比第一行小主要是因为:第一行的实验选取了较大的批量大小。

第四行增加到了四个结点进行数据进行,同时每个结点进行两个 G P U GPU GPU的模型并行导致整个的时间下降到 8.1 h 8.1h 8.1h.

第五行相比于第四行的下降则是因为第四行的实验每 24 h 24h 24h选择一次新的批量大小,而第五行每 72 h 72h 72h选择一次。

第六行与第五行对比可以发现在这种情况下并没有进一步带来速度的提升,不过当使用双重缓冲区了之后自动选择的批量大小将会有所下降。

第七行代表的是什么意思本人并没有看懂。

T a b l e   4 Table\ 4 Table 4展示了固定不同的批量大小下双重缓冲对于速度的影响。

在这里插入图片描述

5.5 Combination with Model Parallelism

下图展示了不同的数据并行与模型并行的速率,可以发现在显卡数量相同的情况下,只有 8 × 2 8\times2 8×2的时候模型并行提升了速率,在大多数情况下模型并并没有提升速率,这代表着在大多数情况下模型并行没有数据并行高效,这是因为如果使用模型并行则不能很好的利用缓存机制,模型并行的时候,每交换一次数据缓冲就会失效,而如果是更多的使用数据并行,那么就会减少一定缓存失效的次数。

在这里插入图片描述

5.6 Training a Production-Scale Model

表格中的 r e a l i g n realign realign代表将数据进行对齐。

在这里插入图片描述

6 Conclusion

将通信传播的通信量从 32 32 32位降为 1 1 1位,同时提出误差反馈机制保证模型的收敛。

一位的量化能够大大降低通信量从而减少通信所带来的瓶颈。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/375420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是企业商机管理 管理销售商机流程方法

现代企业发展道路上,市场竞争愈演愈烈,很多企业都开始重视客户信息化管理来促成销售交易,在销售管理中的商机需按照轻重缓急进行分类、跟进、监控,才能对商机进行有效管理。 从某种程度上来说,一个订单成功与否的关键…

SpringBoot整合XxlJob

SpringBoot整合XxlJob 1.XxlJob简介 官方网址:https://www.xuxueli.com/xxl-job XXL-JOB是一个分布式任务调度平台,其核心设计目标是开发迅速、学习简单、轻量级、易扩展。现已开放源代码并接入多家公司线上产品线,开箱即用。 为什么要使…

【10k~30k的区别】=== 功能测试、自动化测试、性能测试的区别

按测试执行的类型来分:功能测试、自动化测试、性能测试 1.功能测试 功能测试俗称点点点测试。初级测试人员的主要测试任务就是执行测试工程师所写的测试用 例,记录用例的执行状态及bug情况。与开发人员进行交互直到bug被修复。 功能测试理论…

Java查漏补缺(14)数据结构剖析、一维数组、链表、栈、队列、树与二叉树、List接口分析、Map接口分析、Set接口分析、HashMap的相关问题

Java查漏补缺(14)数据结构剖析、一维数组、链表、栈、队列、树与二叉树、List接口分析、Map接口分析、Set接口分析、HashMap的相关问题本章专题与脉络1. 数据结构剖析1.1 研究对象一:数据间逻辑关系1.2 研究对象二:数据的存储结构…

(pytorch进阶之路)Informer

论文:Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting (AAAI’21 Best Paper) 看了一下以前的论文学习学习,我也是重应用吧,所以代码部分会比较多,理论部分就一笔带过吧 论文作者也很良心的…

微服务的Feign到底是什么

Feign是什么 分区是一种数据库优化技术,它可以将大表按照一定的规则分成多个小表,从而提高查询和维护的效率。在分区的过程中,数据库会将数据按照分区规则分配到不同的分区中,并且可以在分区中使用索引和其他优化技术来提高查询效…

目标检测论文阅读:CBNet算法笔记

标题:CBNet: A Composite Backbone Network Architecture for Object Detection 期刊:TIP2022 论文地址:https://ieeexplore.ieee.org/document/9932281/ 官方代码:https://github.com/VDIGPKU/CBNetV2 作者单位:北京大…

【正点原子FPGA连载】第二十章AXI4接口之DDR读写实验 摘自【正点原子】DFZU2EG_4EV MPSoC之嵌入式Vitis开发指南

1)实验平台:正点原子MPSoC开发板 2)平台购买地址:https://detail.tmall.com/item.htm?id692450874670 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html 第二十章AXI4接口…

如何查看Spring Boot各版本的变化

目录 1.版本 2.基础特性和使用 3.新增特性和Bug修复 1.版本 打开Spring官网,点进Spring Boot项目我们会发现在不同版本后面会跟着不同的标签: 这些标签对应不同的版本,其意思如下: GA正式版本,通常意味着该版本已…

VsCode安装PlatformIO 开发ESP arduino,买的板子或者随便ESP,PlatformIO添加Board(不是自定义Board)

这次主要记录怎么给新建选板子的时候没有的板子下程序 我这里是一块 WiFi Kit 32 (V3) PlatformIO里面只有到V2 先从头开始,安装PlatformIO 安装PlatformIO 直接搜索安装 安装有时候会比较慢,左侧出现蚂蚁图标之后点击会显示 右下角会提示正在安…

【神经网络】Transformer基础问答

1.Transforme与LSTM的区别 transformer和LSTM最大的区别就是LSTM的训练是迭代的,无法并行训练,LSTM单元计算完T时刻信息后,才会处理T1时刻的信息,T 1时刻的计算依赖 T-时刻的隐层计算结果。而transformer的训练是并行了&#xff0…

AndroidStudio打包HBuilderX的H5+项目为安卓App【一次过,无任何异常报错】

目录 1.查看HBuilderX的版本号 2.下载Dcloud上对应的安卓SDK 3.下载完安卓SDK后,我们解压它,注意不要放在任何有中文组成的文件夹中【是否有中文决定于你鼠标单击上面路径后,第一张图还没鼠标单击,第二张已鼠标单击&#xff0c…

【前端工程化】01-Node.js基础

Node.js基础认识NodeNode的定义Node的应用场景Node的安装和版本管理Node的基本操作Node.js执行文件Node的参数传递Node的REPL认识Node Node的定义 Node.js是一个基于V8 JavaScript引擎的JavaScript运行时环境 Node.js为JavaScript提供了一些服务器级别的操作API 文件读写网…

背靠“湘潭系”的谭新乔,能带领湖南裕能再上一个台阶吗?

文丨熔财经作者|kinki近日,磷酸铁锂正极材料龙头湖南裕能正式登陆A股,上市当天市值超过了400亿元,投资者中一签可赚1.49万元,可谓近年低迷的资本市场中一支“大肉签”。不过在 “开门红”之后,湖南裕能的股价便一路下挫…

ETL工具(kettle) 与 ETL产品(BeeloadBeeDI) 差之毫厘,谬以千里

E T L——是英文Extract-Transform-Load的缩写,用来描述将数据从来源端经过抽取(extract)、转换(transform)、加载(load)至目的端的过程。工具——原指工作时所需用的器具,后引申为达…

Clickhouse学习(一):MergeTree概述

MergeTree一、Clickhouse表引擎概述二、MergeTree表引擎<一>、ReplacingMergeTree引擎<二>、SummingMergeTree引擎<三>、AggregatingMergeTree引擎三、MergeTree分区一、Clickhouse表引擎概述 MergeTree表引擎:允许根据日期和主键创建索引 1、ReplacingMerge…

实践IC-GVINS: 以惯导为核心的GNSS-Visual-INS组合导航系统

视觉导航系统对环境比较敏感&#xff0c;受到光照变化、重复纹理、动态物体等影响&#xff1b;惯性导航系统(INS)则完全自主工作&#xff0c;不受外部环境影响&#xff0c;能够实现连续、高频的自主导航&#xff0c;但其误差发散较快。两者组合能够取长补短&#xff0c;形成视觉…

毕业设计 基于STM32单片机生理监控心率脉搏TFT彩屏波形曲线设计

基于STM32单片机生理监控心率脉搏TFT彩屏波形曲线设计1、项目简介1.1 系统构成1.2 系统功能2、部分电路设计2.1 STM32F103C8T6核心系统电路设计2.2心率检测电路设计2.3 TFT2.4寸彩屏电路设计3、部分代码展示3.1 ADC初始化3.2 获取ADC采样值3.3 LCD引脚初始化3.3 在LCD指定位置显…

15 Nacos客户端实例注册源码分析

Nacos客户端实例注册源码分析 实例客户端注册入口 流程图&#xff1a; 实际上我们在真实的生产环境中&#xff0c;我们要让某一个服务注册到Nacos中&#xff0c;我们首先要引入一个依赖&#xff1a; <dependency><groupId>com.alibaba.cloud</groupId>&l…

Android与flutter混合开发

这里我使用的android studio版本是2020.3.1&#xff1b;flutter版本2.5.3。此前在网上搜索的很多教教程版本都不一样&#xff0c;新版的IDE和SDK让我遇到了很多坑故这里整理一下。一、创建项目1.在Android项目中点击File->New->New Flutter Project。File->New->Ne…