1.Softmax 如何并行?
Softmax 计算公式:
安全的 Softmax 运算:
softmax 有个问题,那就是很容易溢出。比如采用半精度,由于float16的最大值为65504,所以只要x>=11,那么softmax就溢出了。即使是float32,x也不能超过88。
好在 exp 有这么一个性质,那就是
根据这个性质,可以在分子分母上同时除以一个数,这样可以将 的范围都挪到非正实数域。
这样,就可以保证计算 softmax 时的数值稳定性。
这个算法可以分成三次迭代来执行。
-
求 x 的最大值 m
2. 计算 softmax 分母
3.求对应位置的 softmax
分析上面的步骤,可以发现,如果是不做任何优化的话,至少要进行和 GPU 进行6次通信(3次写入,3次写出)。
如果对每一步的for 循环进行一些并行切分的的话,还要加上 reduce_sum 和 reduce_max 之类的通信成本。
是否能将某些操作进行融合,减少通信呢?按照之前 layernorm 并行的经验,我们需要寻找一个 Online Algorithm。
Online Softmax
2018年 Nvidia 提出了《Online normalizer calculation for softmax》
既然是 Online 的算法,我们需要找出递归的表达式。
对于第二步中的我们期望去掉这个式子对
的依赖。
设 ,,注意,这里减去的全局最大值变成了当前最大值。这个式子有如下的性质:
还能不能进一步融合算子呢?没办法了,因为第二步的分母依赖于第一步的计算。
但是可以借助 GPU 的 share memory 来存储中间结果,将上面的两步只用一个 kernel 实现,这样就只需要与 global memory 通信两次,一次写入数据,一次读取结果。
整体来说,有两个重要的优化点:
-
将前两步的算子融合,减少 Reduce_max 和 Reduce_sum 之类的通信成本。
-
借助 share memory 存储中间结果,减少与 global memory 的通信成本。
这一篇只是从数学上给出了一些 Softmax 的并行理论基础。具体实现还有很多细节上的优化点,比如:
感兴趣的可以看看 oneflow 的一个 softmax 深度优化:https://www.oneflow.org/a/share/jishuboke/54.html . 源代码在https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/softmax.cuh
还有 Nvidia 自己实现的一个可读性很好的版本:https://github.com/NVIDIA/FasterTransformer/blob/release/v1.0_tag/fastertransformer/cuda/open_attention.cu#L189-L268 但是速度没有 oneflow 的好。
2.Transformers 中的 Layer Norm 可以并行加速么?
这个问题我之前觉得可以加速,而且给出了一个简单的实现方案。后来看 Transformers 的一些 GPU 训练的代码后,才发现我真是 too young too simple, sometimes even naive。
layernorm 的计算,重点就是计算均值和方差。分两步:
实际上的并行方案
上面的方案当然没什么问题,但是并不是最优的。
上面的算法需要遍历2次数据,一次计算均值,一次计算方差。能不能只遍历数据一次就能并行的把均值和方差算出来呢?
相信你会立马想到这个公式:
并行的时候,一边算平方和,一边算全部的和。最后平方和与均值都可以算出来,然后按公式一减就出来了,看上去十分的 Perfect。
但是这个公式只是理论上很完美,受限于计算机计算精度的问题,这个公式当两个平方项都很大的时候,精度会失真,导致算出来的方差很不稳定,甚至有可能是负数。后面会有代码演示数值稳定性的问题。
那能像上一节的算法那样,分别计算均值和方差最后聚合么?似乎有些反直觉,不需要知道全局的均值就可以计算方差。但是我们要相信数学家的折腾能力,搞出了无数匪夷所思的东西。就连加百列号角(Gabriel's Horn)这种鬼玩意都能搞出来,数学有无限的可能。(注:加百列号角 Gabriel's Horn 的体积是有限的,但是表面积是无限的。)
Transfromers 无论是在 pytorch,还是在 apex,还是在其他一些加速框架比如 oneflow 中,都采用了 Welford online Algorithm。这个算法是 Welford 在1962年发表的《Note on a Method for Calculating Corrected Sums of Squares and Products》中提出。他给出的算法,可以在一个集合新增一个元素的时候,均值和方差的不需要把所有的数都遍历一遍,而是根据之前集合的均值和方差就可以直接计算出来。
而在1972年,Chan 发表了《Updating Formulae and a Pairwise Algorithm for Computing Sample Variances》,可以认为是 Welford Algorithm 的一个升级版本,可以根据两个集合的均值和方差直接计算出整体的均值和方差。当然如果两个集合中,某一个集合只有一个元素,算法就退化成 Welford Algorithm 了。这个算法为大规模并行计算均值和方差提供了理论基础。
由于 Welford's Algorithm 是 Chan's Algorithm 的一个特例,所以下面简单说一下 Chan's Algorithm 是怎么一回事。
这里首先给出一个定义,定义与均值差的平方和为
也就是说我们只需要两个集合各自的均值和 M2 我们就可以计算出方差。
上面这个式子怎么来的呢?我们来证明一下:
证明完了好像也没那么神奇,陷入了人生三大错觉之一:我上我也行,只恨自己生的太晚。事后诸葛亮就是这么自信满满。
代码学习
由于Chan 和 Welford 算法的并行体质,Nvidia 的 Apex 库率先实现了这个方法,叫做 Fused Layer Norm。为啥叫 Fused ?因为把所有的计算都融合(fuse)到一个核函数里了,不需要与 CPU 来回通信。可以重点看代码开头的 cuWelfordOnlineSum 和 cuChanOnlineSum 两个函数。对应 python 代码入口为 apex.normalization.fused_layer_norm.FusedLayerNorm。代码见:https://github.com/NVIDIA/apex/blob/c3fad1ad120b23055f6630da0b029c8b626db78f/csrc/layer_norm_cuda_kernel.cu#L670
pytorch 后来也实现实现了,可以看 cuWelfordOnlineSum 和 cuWelfordCombine 两个函数,代码见:https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/layer_norm_kernel.cu
Oneflow 后来又进一步根据输入的大小优化了 Fused Layer Norm 的性能,代码见:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/layer_norm.cuh
所以现在我们使用 pytorch 和其他加速库的 layernorm 函数底层已经实现了并行。我们这些调包侠在用 python 写代码的时候,要记住,哪有什么岁月静好,都是 C++ 和 Cuda 大佬们在负重前行。
下面我用 python 模拟了一下 c++ cuda 的实现,同时测试了一下在数字比较大的时候的数值稳定性问题。可以发现,用平方和减去均值平方的方法,方差就算错了,成为了负值。
测试普通数字...
全局均值: -0.10384651739409387
Welford 并行全局均值: -0.10384651739409384
串行全局方差: 0.8165221946938586
平方差串行全局方差: 0.816522194693858
Welford 并行全局方差: 0.8165221946938584
--------------------------------------------------
测试大数...
全局均值: 999999999.8961536
Welford 全局均值: 999999999.8961536
串行全局方差: 0.8165221933047772
平方差串行全局方差: -512.0
Welford 并行全局方差: 0.8165221874239014
--------------------------------------------------
核心代码如下,全部的代码实在是有些又臭又长,就放在开篇提到的电子书里了。
# 核心代码
def welford_combine(val, mean, m2, count):
"""新增一个数"""
count += 1
delta1 = val - mean
mean += delta1 / count
delta2 = val - mean
m2 += delta1 * delta2
return mean, m2, count
def welford_combine_two(b_mean, b_m2, b_count, mean, m2, count):
"""合并两个集合"""
if b_count == 0:
return mean, m2, count
new_count = count + b_count
nb_over_n = b_count / new_count
delta = b_mean - mean
mean += delta * nb_over_n
m2 += b_m2 + delta * delta * count * nb_over_n
count = new_count
return mean, m2, count