Transformers 中的Softmax 和 Layer Norm 如何并行?

news2024/12/27 11:49:55

1.Softmax 如何并行?

        Softmax 计算公式:

        安全的 Softmax 运算:

        softmax 有个问题,那就是很容易溢出。比如采用半精度,由于float16的最大值为65504,所以只要x>=11,那么softmax就溢出了。即使是float32,x也不能超过88。

        好在 exp 有这么一个性质,那就是

        

         根据这个性质,可以在分子分母上同时除以一个数,这样可以将  的范围都挪到非正实数域。

这样,就可以保证计算 softmax 时的数值稳定性。

这个算法可以分成三次迭代来执行。

  1. 求 x 的最大值 m

       2. 计算 softmax 分母

        3.求对应位置的 softmax

        分析上面的步骤,可以发现,如果是不做任何优化的话,至少要进行和 GPU 进行6次通信(3次写入,3次写出)。

        如果对每一步的for 循环进行一些并行切分的的话,还要加上 reduce_sum 和 reduce_max 之类的通信成本。

        是否能将某些操作进行融合,减少通信呢?按照之前 layernorm 并行的经验,我们需要寻找一个 Online Algorithm。

Online Softmax

        2018年 Nvidia 提出了《Online normalizer calculation for softmax》

        既然是 Online 的算法,我们需要找出递归的表达式。

        对于第二步中的我们期望去掉这个式子对

的依赖。

设 ,,注意,这里减去的全局最大值变成了当前最大值。这个式子有如下的性质:

        还能不能进一步融合算子呢?没办法了,因为第二步的分母依赖于第一步的计算。

        但是可以借助 GPU 的 share memory 来存储中间结果,将上面的两步只用一个 kernel 实现,这样就只需要与 global memory 通信两次,一次写入数据,一次读取结果。

整体来说,有两个重要的优化点:

  1. 将前两步的算子融合,减少 Reduce_max 和 Reduce_sum 之类的通信成本。

  2. 借助 share memory 存储中间结果,减少与 global memory 的通信成本。

        这一篇只是从数学上给出了一些 Softmax 的并行理论基础。具体实现还有很多细节上的优化点,比如:

感兴趣的可以看看 oneflow 的一个 softmax 深度优化:https://www.oneflow.org/a/share/jishuboke/54.html . 源代码在https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh

还有 Nvidia 自己实现的一个可读性很好的版本:https://github.com/NVIDIA/FasterTransformer/blob/release/v1.0_tag/fastertransformer/cuda/open_attention.cu#L189-L268 但是速度没有 oneflow 的好。

2.Transformers 中的 Layer Norm 可以并行加速么?

        这个问题我之前觉得可以加速,而且给出了一个简单的实现方案。后来看 Transformers 的一些 GPU 训练的代码后,才发现我真是 too young too simple, sometimes even naive。

        layernorm 的计算,重点就是计算均值和方差。分两步:

实际上的并行方案

上面的方案当然没什么问题,但是并不是最优的。

上面的算法需要遍历2次数据,一次计算均值,一次计算方差。能不能只遍历数据一次就能并行的把均值和方差算出来呢?

相信你会立马想到这个公式:

 

        并行的时候,一边算平方和,一边算全部的和。最后平方和与均值都可以算出来,然后按公式一减就出来了,看上去十分的 Perfect。

        但是这个公式只是理论上很完美,受限于计算机计算精度的问题,这个公式当两个平方项都很大的时候,精度会失真,导致算出来的方差很不稳定,甚至有可能是负数。后面会有代码演示数值稳定性的问题。

        那能像上一节的算法那样,分别计算均值和方差最后聚合么?似乎有些反直觉,不需要知道全局的均值就可以计算方差。但是我们要相信数学家的折腾能力,搞出了无数匪夷所思的东西。就连加百列号角(Gabriel's Horn)这种鬼玩意都能搞出来,数学有无限的可能。(注:加百列号角 Gabriel's Horn 的体积是有限的,但是表面积是无限的。)

        Transfromers 无论是在 pytorch,还是在 apex,还是在其他一些加速框架比如 oneflow 中,都采用了 Welford online Algorithm。这个算法是 Welford 在1962年发表的《Note on a Method for Calculating Corrected Sums of Squares and Products》中提出。他给出的算法,可以在一个集合新增一个元素的时候,均值和方差的不需要把所有的数都遍历一遍,而是根据之前集合的均值和方差就可以直接计算出来。

        而在1972年,Chan 发表了《Updating Formulae and a Pairwise Algorithm for Computing Sample Variances》,可以认为是 Welford Algorithm 的一个升级版本,可以根据两个集合的均值和方差直接计算出整体的均值和方差。当然如果两个集合中,某一个集合只有一个元素,算法就退化成 Welford Algorithm 了。这个算法为大规模并行计算均值和方差提供了理论基础。

        由于 Welford's Algorithm 是 Chan's Algorithm 的一个特例,所以下面简单说一下 Chan's Algorithm 是怎么一回事。

        这里首先给出一个定义,定义与均值差的平方和为

        也就是说我们只需要两个集合各自的均值和 M2 我们就可以计算出方差。

        上面这个式子怎么来的呢?我们来证明一下:

        证明完了好像也没那么神奇,陷入了人生三大错觉之一:我上我也行,只恨自己生的太晚。事后诸葛亮就是这么自信满满。 

代码学习

        由于Chan 和 Welford 算法的并行体质,Nvidia 的 Apex 库率先实现了这个方法,叫做 Fused Layer Norm。为啥叫 Fused ?因为把所有的计算都融合(fuse)到一个核函数里了,不需要与 CPU 来回通信。可以重点看代码开头的 cuWelfordOnlineSum 和 cuChanOnlineSum 两个函数。对应 python 代码入口为 apex.normalization.fused_layer_norm.FusedLayerNorm。代码见:https://github.com/NVIDIA/apex/blob/c3fad1ad120b23055f6630da0b029c8b626db78f/csrc/layer_norm_cuda_kernel.cu#L670

        pytorch 后来也实现实现了,可以看 cuWelfordOnlineSum 和 cuWelfordCombine 两个函数,代码见:https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/layer_norm_kernel.cu

        Oneflow 后来又进一步根据输入的大小优化了 Fused Layer Norm 的性能,代码见:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh

        所以现在我们使用 pytorch 和其他加速库的 layernorm 函数底层已经实现了并行。我们这些调包侠在用 python 写代码的时候,要记住,哪有什么岁月静好,都是 C++ 和 Cuda 大佬们在负重前行

        下面我用 python 模拟了一下 c++ cuda 的实现,同时测试了一下在数字比较大的时候的数值稳定性问题。可以发现,用平方和减去均值平方的方法,方差就算错了,成为了负值。

  测试普通数字...
    全局均值: -0.10384651739409387
    Welford 并行全局均值: -0.10384651739409384
    串行全局方差: 0.8165221946938586
    平方差串行全局方差: 0.816522194693858
    Welford 并行全局方差: 0.8165221946938584
    --------------------------------------------------
    测试大数...
    全局均值: 999999999.8961536
    Welford 全局均值: 999999999.8961536
    串行全局方差: 0.8165221933047772
    平方差串行全局方差: -512.0
    Welford 并行全局方差: 0.8165221874239014
    --------------------------------------------------

        核心代码如下,全部的代码实在是有些又臭又长,就放在开篇提到的电子书里了。

# 核心代码
    def welford_combine(val, mean, m2, count):
        """新增一个数"""
        count += 1
        delta1 = val - mean
        mean += delta1 / count
        delta2 = val - mean
        m2 += delta1 * delta2
        return mean, m2, count
    
    def welford_combine_two(b_mean, b_m2, b_count, mean, m2, count):
        """合并两个集合"""
        if b_count == 0:
            return mean, m2, count
        new_count = count + b_count
        nb_over_n = b_count / new_count
        delta = b_mean - mean
        mean += delta * nb_over_n
        m2 += b_m2 + delta * delta * count * nb_over_n
        count = new_count
        return mean, m2, count

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1984753.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言新手小白详细教程(6)函数

希望文章能够给到初学的你一些启发~ 如果觉得文章对你有帮助的话,点赞 关注 收藏支持一下笔者吧~ 阅读指南: 开篇说明为什么要使用函数?1.定义一个函数2.调用函数3.定义函数详解 开篇说明 截止目前,我们已…

华清IOday7 24-8-5

文章目录 使用有名管道实现,一个进程用于给另一个进程发消息,另一个进程收到消息后,展示到终端上,并且将消息保存到文件上一份使用有名管道实现两个进程间相互通信 使用有名管道实现,一个进程用于给另一个进程发消息&a…

服务器数据恢复—raid5阵列上层Oracle数据库数据恢复案例

服务器数据恢复环境&故障: 一台服务器上有8块SAS硬盘,其中的7块硬盘组建了一组RAID5阵列,另外1块硬盘作为热备盘使用。划分了6个LUN,服务器上部署有oracle数据库。 RAID5磁盘阵列中有2块硬盘出现故障并离线,RAID5阵…

浮点数在计算机中的编码方式

一、前言 我们常能听到,直接用浮点数做运算得出的结果是不准确的了;或者也能看到涉及到浮点数时,会出现一些奇奇怪怪的问题,比如: public class DecimalTest {public static void main(String[] args) {float f1 1.…

STK12.2+Python开发(二):添加访问约束,新建场景、卫星、地面站等,获取当前场景的信息

新建场景 1.获取当前打开的场景 #获取当前打开的场景 scenario root.CurrentScenario2.设置当前场景的时间 scenario.SetTimePeriod(Today,24hr)3.添加一个地面目标到当前的场景 scenario.SetTimePeriod(Today,24hr)4.添加一个地面目标到当前的场景,括号内是三…

空气质量传感器 - 从零开始认识各种传感器【二十八期】

空气质量传感器|从零开始认识各种传感器 1、什么是空气质量传感器? 空气质量传感器是一种检测空气中污染物浓度的设备,广泛应用于环境监测、智能家居、工业控制和健康管理等领域。 2、空气质量传感器是如何工作的? 对于每个人都关心的空气质…

效率何止10倍!利用输入法瞬间调用提示词

我们在日常工作/学习/生活有很多场景需要使用提示词,比如说: 快速总结文章快速排版解释概念翻译其它经常面对的任务 但是使用提示词有几个痛点: 你很难临时写一个非常完整的提示词你凑合写的提示词,又担心结果不满意如果已经保…

前端使用css动画绘制简易的进度条,数据多条的时候可以切换

文章目录 一、效果图二、使用步骤1.公共的进度条组件2.使用 总结 一、效果图 二、使用步骤 1.公共的进度条组件 我这里命名的progressBar.vue&#xff0c; 你们使用的时候直接复制粘贴到自己的项目里面即可。 文件中代码如下&#xff08;示例&#xff09;&#xff1a; <t…

EasyX 碰撞检测

代码&#xff1a; #define _UNICODE #define UNICODE#include <array> #include <cmath> #include <ctime> #include <format> #include <graphics.h> #include <vector>typedef struct tagRECTF {double left;double top;double right;d…

You Only Look Once:Unified, Real-Time Object Detection 论文阅读

论文名&#xff1a;You Only Look Once:Unified, Real-Time Object Detection 论文作者&#xff1a;Joseph Redmon et.al. 期刊/会议名&#xff1a;CVPR 2016 发表时间&#xff1a;2016-5 ​论文地址&#xff1a;https://arxiv.org/pdf/1506.02640 1.摘要 我们提出了一种新的目…

论文辅导 | 结合变种残差模型和 Transformer 的城市公路短时交通流预测

辅导文章 模型描述 城市公路交通流的预测受到历史交通流量和相邻车道交通流量的影响&#xff0c;蕴含了复杂的时空特征。针对传统交通流预测模型卷积长短时记忆网络(ConvLSTM)进行交通流预测时&#xff0c;未将时空特征分开提取而造成的提取不充分、特征信息混淆和特征信息缺失…

视频融合技术

三维视频融合技术遵循数字孪生多源数据融合的原则&#xff0c;比视频窗口、矩阵更加直观高效&#xff0c;省去了人脑理解空间的时间&#xff0c;可有效提升数字孪生城市在物联感知操作、虚实融合交互等方面的能力&#xff0c;动静一体、虚实结合&#xff0c;让三维场景“动起来…

常见的SQL注入

联合查询 如下&#xff0c;要求我们传入一个id值过去。传参?id1&#xff0c;当我们输入id1和id2时&#xff0c;页面中name值和password的值会发生变化&#xff0c;说明此时我们输入的数据和数据库有交互并且将数据显示在屏幕上了 输入?id1&#xff0c;页面发生报错&#xf…

手机联网如何设置动态ip

在现代社会&#xff0c;手机已成为我们日常生活中不可或缺的一部分&#xff0c;无论是工作、学习还是娱乐&#xff0c;都离不开网络的支持。而在手机联网的过程中&#xff0c;IP地址的分配方式显得尤为重要。动态IP地址因其灵活性和安全性&#xff0c;成为了许多用户的首选。那…

电子合同怎么制作?9款常用电子合同软件

文章将介绍了以下9个工具&#xff1a;e签宝、文书宝、签通云、快签宝、法天使、Zycus iContract、airSlate WorkFlow、Lightico、KeepSolid Sign。 在数字化快速发展的今天&#xff0c;电子合同成为了业务操作中不可或缺的一部分&#xff0c;但许多人仍然面临如何有效创建和管理…

Redis vs Memcached:Redis的三大优势

Redis vs Memcached&#xff1a;Redis的三大优势 1. 数据类型2. 数据持久化能力3. 高性能与灵活性 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 1. 数据类型 Redis&#xff1a;支持多样化的数据类型&#xff0c;包括字符串&#xff08;S…

前端性能优化-回流与重绘

前言 本文总结回流与重绘相关的知识点 回流与重绘的基本概念 重绘&#xff08;Repaint&#xff09;&#xff1a; 当元素样式发生改变&#xff0c;但不影响其几何属性的时候&#xff0c;浏览器只需要重新绘制这个元素&#xff0c;这个过程被称为重绘。 回流&#xff08;Refl…

Linux_监测CPU和内存

通过TOP持续获取进程的CPU和内存消耗&#xff0c;并写入到表格 # 配置进程名 processvm-agent # 配置次数 number100 # 配置间隔时间 time5 # csv结果文件 filecm_$(date %s).csv echo "%CPU,%MEM">${file} pid$(ps -aux | grep ${process} | awk -F {OFS"…

debug\moc\mocinclude.tmp dose not exist

先把jom禁用&#xff0c;然后清理工程&#xff0c;重新编译&#xff0c;编译通过后再重新打开jom

MybatisPlus的主键策略

ASSIGN_ID(默认策略) 生成唯一的值&#xff0c;包含数字&#xff0c;表对应字段类型bigint或者varchar类型 ASSIGN_UUID() 生成唯一的值&#xff0c;包含数字和字母&#xff0c;表对应字段类型varchar类型 AUTO 主键自动增长效果&#xff0c;和表字段auto_increment INPUT …