使用Unit Scaling进行FP16 和 FP8 训练

news2024/10/7 6:51:26

Unit Scaling 是一种新的低精度机器学习方法,能够在没有损失缩放的情况下训练 FP16 和 FP8 中的语言模型。

使用FP16和BFLOAT16替代FP32可以将内存、带宽和计算需求的大幅减少,这也是目前越来越大的模型所需要的。

背景介绍

随着支持fp8的硬件的发展,在不影响效率的前提下,进一步降低精度也成为了可能。但是这些较小的、低精度的格式在实践中并不总是易于使用。对于FP8来说则更加困难。因为这些较小的格式通常将用户限制在更窄的可表示值范围内。为了解决这个问题,Graphcore Research开发了一种新方法,我们称之为Unit Scaling。

上图为FP16和FP8中量化的不同尺度的正态分布的信噪比(SNR)。对于较小的数字格式,信号在较窄的尺度范围内较强。

Unit Scaling是一种模型设计技术,它在初始化时根据缩放原则进行操作:也就是说对激活、权重和梯度的单位方差进行缩放。模型会自动生成针对低精度数字格式进行良好缩放的张量。并且使用更简单,并最大限度地减少这些表示的缺点,与低精度训练的替代方法不同,它引入的开销和额外的复杂性很小。

论文的方法取得了突破性的成果:首次在 FP16 甚至 FP8 中准确地训练了 BERT Base 和 BERT Large 模型,并且没有缩放的性能损失。模型也不需要额外的超参数,可以直接使用。

对于关心结果并因此希望在 FP16 和 FP8 中进行训练的人来说,Unit Scaling提供了一个直接的解决方案。

FP16/FP8训练的现有方法

FP16和FP8训练需要某种形式的缩放来保持值在范围内。目前的做法如下:

1、(静态)损失缩放

缩小范围对于训练期间的向反向传播是具有挑战性的通常会导致梯度下溢。为了解决这个问题,最常见的方法是将损失乘以超参数以增加梯度的大小 [1]。由于没有原则性的方法来提前选择损失的规模,所以这个超参数通常需要多次运行。

2、自动损失缩放

通过基于运行时的梯度溢出(或直方图)[2] 动态调整损失比例,可以避免超参数扫描的需要。但是这种自动方案会增加开销和复杂性。

3、张量缩放

上述方法的另一个缺点是它们只提供单一的全局损失尺度。另外一种解决方案是根据张量统计 [3] 重新缩放值。这也是一种自动/运行时方案,很复杂且难以有效实施。

Unit Scaling

Unit Scaling 在前向和反向传播中引入局部缩放因子控制值的范围。选择的范围是基于每个操作符如何影响值规模的理解,而并不是使用运行时分析得到的。通过选择正确的比例因子,每个操作都大致保持其输入的比例。通过将其应用于所有操作,可以控制整个模型中传播初始(单位)比例,从而实现全局的缩放。

这种方法比自动缩放方案更简单,因为唯一的额外开销是应用缩放因子。对于 BERT Large,这会将 FLOPs 增加 0.2%,应该可以忽略不计。

模型可以通过应用以下方法进行Unit Scaling:

  • 用单位方差初始化无偏差参数
  • 计算所有操作的理想比例因子
  • 识别非切边并限制使用它们的操作具有相同的缩放比例
  • 用加权的加法替换加法

下面我们将更详细地解释这些规则。

1、理想的比例因子

我们可以对一些操作进行数学分析,以确定它们如何影响输入的方差。

比如基本矩阵乘法 XW(其中 X 是 (b × m) 矩阵,W 是 (m × n) 矩阵)的输出方差为 σ(X)² · σ(W)² · m。要缩放此操作,我们必须确保 σ(X)² = σ(W)² = 1,然后将 1/√m 乘法添加到输出。

对于反向传播,需要引入了两个新的矩阵乘法,理想的比例因子为 1/√n 和 1/√b。其他操作也可以类似分析,输出方差不容易分析,所以可以使用经验方法来找到缩放因子。

在论文作者中提供了更详细的分析,以及常见操作的概要及其理想的比例因子。

2、切边

直接将这些理想的比例因子应用于正向和反向传播中会产生无效的梯度。为了避免这种情况,某些操作需要使用共享的缩放因子。

我们使用前向计算图并找到所有没有用切边表示的变量(如果去掉这些边,会将图分割成两个不相连的更小的图)。比如,下面是一个transformer的FFN层:

在权重、输入和输出变量上有切边。该图还显示了为第二个matmul的反向传播生成的梯度操作(我们只考虑正向图的切边)。

因为 x₃ 不是切边,所以可以限制 ∇x₃ 的 matmul 使用与前向传播中相同的比例因子,但是由于 w2 是切边,它允许有自己的反向缩放因子,所以为受约束的操作选择共享比例因子,采用之前计算的理想比例因子的几何平均值。

这个规则听起来很复杂,但实际上它通常可以归结为一个简单的过程:为权重梯度提供它们自己的比例因子(也就是模型中的任何编码器/解码器层)。

3、加权加法操作

最后一步是用加权的加法替换加法操作。根据设计的单位缩放产生的变量具有相等的尺度,如果我们将两个张量相加,它们实际上都具有相等的权重。但是在某些情况下,例如残差连接就需要一个不平衡的权重来获得良好的性能。所以将加法操作替换为加权(和单位缩放)加法等效操作。

对于残差连接,可以推导出以下方案:

代码实现

下面的代码展示了一个在PyTorch中实现Unit Scaling的FFN层。

首先定义创建基本操作的缩放版本,例如scaled_projection:

 classScaledGrad(autograd.Function):
   @staticmethod
   defforward(ctx, X, alpha, beta):
     ctx.save_for_backward(tensor(beta, dtype=X.dtype))
     returnalpha*X
 
   @staticmethod
   defbackward(ctx, grad_Y):
     beta, =ctx.saved_tensors
     returnbeta*grad_Y, None, None
 
 defscaled(X, alpha=1, beta=1):
   """forward: Y = X * alpha, backward: grad_X = grad_Y * beta"""
   returnScaledGrad.apply(X, alpha, beta)
 
 defscaled_projection(X, W):
   (b, _), (m, n) =X.shape, W.shape
   alpha=beta_X= (m*n) **-(1/4) beta_W=b**-(1/2)
   X=scaled(X, beta=beta_X)
   W=scaled(W, beta=beta_W)
   returnscaled(matmul(X, W), alpha)

这样我们就可以创建完整的层。我们只演示一个标准FFN和它的缩放版本:

 classFFN(nn.Module):
   def__init__(self, d, h):
     super().__init__()
     self.norm=LayerNorm(d)
     sigma= (d*h) **-(1/4)
     self.W_1=Parameter(randn(d, h) *sigma)
     self.W_2=Parameter(randn(h, d) *sigma)
 
   defforward(self, X):
     Z=self.norm(X)
     Z=matmul(Z, self.W_1) Z=gelu(Z)
     Z=matmul(Z, self.W_2) returnX+Z
 
 
 classScaledFFN(nn.Module):
   def__init__(self, d, h, tau):
     super().__init__()
     self.norm=ScaledLayerNorm(d)  # Not defined here
     self.W1=Parameter(randn(d, h))
     self.W2=Parameter(randn(h, d))
     self.tau=tau
 
   defforward(self, X):
     a= (1-self.tau) ** (1/2)
     b=self.tau** (1/2)
     Z=self.norm(scaled(X, beta=b))
     Z=scaled_projection(Z, self.W1)
     Z=scaled_gelu(Z)  # Not defined here
     Z=scaled_projection(Z, self.W2)
     returnX*a+scaled(Z, b)  # fixed(𝜏) weighted add

结果展示

实验结果表明,这个方法在广泛的模型中是有效的,并且可以开箱即用,不需要额外的超参数调优。

1、小规模的实验

第一组实验验证了在不同模型架构上的广泛适用性。在FP32和FP16中训练了大量具有和不具有Unit Scaling的小型字符级语言模型,并比较了结果。

在几乎所有情况下,它都与基线性能匹配,甚至略有提高。当从FP32切换到FP16时,不需要调优。

2、大规模的实验

第二组实验在一个更大、更现实的生产级模型BERT[4]上验证了有效性。对单Unit Scaling模型进行调整,使其与标准BERT实现保持一致,然后使用来自英文维基百科文章的文本对其进行训练。

我们对SQuAD v1.0和SQuAD v2.0评估任务的结果如下:

Unit Scaling能够获得与标准(基线)模型相同的性能,并且在所有情况下都可以直接使用。基线模型和Unit Scaling模型并不完全相同,但是它们下游性能的偏差很小(Unit Scaling的BERT Base略低于基线,BERT Large略高于基线)。

FP8的实现是基于Graphcore、AMD和Qualcomm最近提出的标准化格式。Graphcore研究之前证明了在FP8中训练损失缩放BERT而没有退化[5],论文也证明了通过Unit Scaling也可以实现同样的效果。

要使FP8优于FP16,不需要额外的技术。只是简单地将matmul输入量化到FP8中,并能够准确地训练(FP8 E4变体中的权重和激活,以及E5中的梯度)。

低精度训练的未来

随着支持FP8的硬件在人工智能社区的采用越来越多,有效、直接且有原则的模型缩放方法也变得越来越重要。Unit Scaling可以适用于广泛的模型和优化器,并且计算开销最小。

下一代大型模型可能会广泛使用低精度格式,所以这种缩放的方法非常的必要。低精度训练的效率优势是巨大的,Unit Scaling也证明了低精度并不一定会 降低模型的表现。

论文地址:
https://avoid.overfit.cn/post/dfcaa9c45d70421a98f4df52a9e83610

引用

[1] P. Micikevicius et al., Mixed precision training (2018). 6th International Conference on Learning Representations

[2] O. Kuchaiev et al., Mixed-precision training for nlp and speech recognition with openseq2seq (2018), arXiv preprint arXiv:1805.10387

[3] P. Micikevicius et al., FP8 formats for deep learning (2022). arXiv preprint arXiv:2209.05433

[4] J. Devlin et al., BERT: Pre-training of deep bidirectional transformers for language understanding (2019). NAACL-HLT

[5] B. Noune et al., 8-bit numerical formats for deep neural networks (2019). arXiv preprint arXiv:2206.02915

本文作者:Charlie Blake

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/416082.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

webrtc入门系列(三)云服务器coturn环境搭建

《webrtc入门系列(一)easy_webrtc_server 入门环境搭建》 《webrtc入门系列(二)easy_webrtc_server 入门example测试》 《webrtc入门系列(三)云服务器coturn环境搭建》 《webrtc入门系列(四&…

测试题目气死人

服了差不多每一题都要错几个案例我真的服了wok,什么鬼东西!!! lx学长的羊圈 Description lx学长是一个养羊大户,有成千上百个羊圈。可是却一次也没来羊圈帮过忙,今天他被叫来羊圈给羊羊们施展成双成对大法…

力扣算法系统刷题题解记录

力扣算法系统刷题题解记录 文章目录力扣算法系统刷题题解记录前言一、数组704二分查找示意图:解题思路代码27.移除元素示意图解题思路代码前言 参考顺序和资料:《代码随想录》 二刷要认真做笔记啦,加油! 一、数组 704二分查找 …

2023-04-12 面试中常见的数组题目

数组中的问题其实最常见 通过基础问题,掌握写出正确算法的“秘诀”巧妙使用双索引技术,解决复杂问题对撞指针- 滑动窗口 1 从二分查找法看如何写出正确的程序 本节学习重点:处理边界问题! 1.确定边界范围方法,先用区…

13、Qt生成dll-QLibrary方式使用

Qt创建dll,使用QLibrary类方式调用dll 一、创建项目 1、新建项目->其他项目->Empty qmake Project->Choose 2、输入项目名,选择项目位置,下一步 3、选择MinGW,下一步 4、完成 5、.pro中添加TEMPLATE subdirs&#xff…

定时任务之时间轮算法

初识时间轮 我们先来考虑一个简单的情况,目前有三个任务A、B、C,分别需要在3点钟,4点钟和9点钟执行,可以把时间想象成一个钟表。 如上图中所示,我只需要把任务放到它需要被执行的时刻,然后等着时针转到这个…

IP协议(网络层重点协议)

目录 一、IP协议报头格式 二、地址选择 1、IP地址 (1)格式 (2)组成 (3)分类 (4)子网掩码 三、路由选择 IP协议是网络层的协议,它主要完成两个方面的任务&#xf…

4.16--设计模式之创建型之代理模式(总复习版本)---脚踏实地,一步一个脚印

1.代理对象 定义:代理模式给某一个对象提供一个代理对象,并由代理对象控制对原对象的引用,从而实现对真实对象的操作。 通俗的来讲代理模式就是我们生活中常见的中介。 在代理模式中,代理对象主要起到一个中介的作用,…

初识Docker并在linux完成安装

文章目录一、 初识Docker1.1 简介1.2 Docker和虚拟机的异同1.3 Docker架构二、 DockerHub三、Docker的安装一、 初识Docker 1.1 简介 Docker是一种开源的容器化平台,可以让开发者在容器中打包、发布、运行和管理应用程序。它使用轻量级的容器来隔离应用程序和它们的…

Scrapy爬虫基本使用与股票数据Scrapy爬虫

Scrapy爬虫的常用命令 scrapy命令行格式 红色是常用的三种命令 为什么Scrapy采用命令行创建和运行爬虫? 命令行(不是图形界面)更容易自动化,适合脚本控制 本质上,Scrapy是给程序员用的,功能&#xff08…

vue打包之后,可以进行修改配置后端地址、端口等信息方法

前言 用vue-cli构建的项目通常是采用前后端分离的开发模式,也就是前端与后台完全分离,此时就需要将后台接口地址打包进项目中,但是,难道我们只是改个接口地址也要重新打包吗?当然不行了,那就太麻烦了&#…

支付宝沙箱环境+SpringBoot+内网穿透整合开发

目录 1.查看沙箱账号 2.内网穿透 3.沙箱环境整合SpringBoot开发 下面我将以实际案例详细介绍如何使用沙箱环境进行支付宝支付对接的开发 1.查看沙箱账号 首先什么是沙箱账号? 沙箱账号是指在支付宝沙箱环境中创建的测试账户,用于模拟真实的支付流程…

The 2022 ICPC Asia Xian Regional Contest

题目顺序大致按照难度排序。 F. Hotel 现在酒店中有单人间和双人间,价格分别是c1,c2,现在有n个队,每队三个人,性别分别用字母表示,当两个人性别相同且在同一个队时,他们可以住在双人间中。求最…

【跑跑Github开源项目系列】基于YOLO和Streamlit的车辆识别系统demo

【跑跑Github开源项目系列】基于YOLO和Streamlit的车辆识别系统demo写在前面环境配置创建虚拟环境安装库项目运行写在前面 相信很多朋友跟我一样在github等平台上偷代码 (读书人的事怎么能叫偷呢) 的时候会发现伟大且无私的作者虽然开源了代码但是readme文件该写的没写&#x…

2023TYUT移动应用软件开发程序设计和填空

目录 程序设计 程序设计1:根据要求设计UI,补充相应布局文件,即.xml文件 程序设计2:根据要求,补充Activity.java文件 程序填空 说明: 程序设计 程序设计1:根据要求设计UI,补充相应布局文件,即.xml文件…

【C++初阶】第十篇:list模拟实现

文章目录一、list的模拟实现三个类及其成员函数接口总览结点类的模拟实现迭代器类的模拟实现迭代器类的模板参数说明迭代器operator->的重载迭代器模拟实现代码list的模拟实现无参构造函数带参构造拷贝构造函数赋值运算符重载函数析构函数begin和endinserteraselist的迭代器…

WordPress添加阿里云OSS对象云储存配置教程

背景:随着页面文章增多,内置图片存储拖连网站响应速度,这里对我来说主要是想提升速度 目的:使用第三方云存储作为图片外存储(图床),这样处理可以为服务器节省很多磁盘空间,在网站搬家的时候减少文件迁移的工…

【数据结构】堆(笔记总结)

👦个人主页:Weraphael ✍🏻作者简介:目前学习C和算法 ✈️专栏:数据结构 🐋 希望大家多多支持,咱一起进步!😁 如果文章对你有帮助的话 欢迎 评论💬 点赞&…

MySQL--数据库基础--0406

目录 1.什么是数据库? 2. 基本使用 2.1 连接服务器 2.2 数据库的操作在Linux中的体现 2.3 使用案例 3.服务器,数据库,表关系 4.数据逻辑存储 5.SQL的分类 6.存储引擎 1.什么是数据库? 数据库和文件 文件或者数据库&…

OK-MX93开发板-实现Web页面无线点灯

上篇文章:i.MX9352——介绍一款多核异构开发板,介绍了OK-MX9352开发板的基础硬件功能。 本篇来使用OK-MX9352开发板,通过Web界面进行点灯测试,最终的效果如下: 在进行代码编写之前,先在Ubuntu虚拟机上把这…