作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra...
论文地址: https://arxiv.org/abs/2205.14135
项目地址: https://github.com/Dao-AILab/flash-attention
摘要
Transformers在处理长序列时速度慢且内存消耗大,因为自注意力的时间和内存复杂度与序列长度的平方成正比。近似注意力方法试图通过权衡模型质量来减少计算复杂度,但往往无法实现实际的速度提升。我们认为一个缺失的原则是使注意力算法考虑输入/输出(IO)-- 即考虑在GPU内存层次之间的读写操作。我们提出了FlashAttention,这是一种考虑IO的精确注意力算法,通过分块减少GPU高带宽内存(HBM)和GPU片上SRAM之间的内存读写次数。我们分析了FlashAttention的IO复杂度,显示它比标准注意力需要更少的HBM访问,并且在多种SRAM大小下表现最优。我们还将FlashAttention扩展到块稀疏注意力,产生了一种比现有任何近似注意力方法更快的近似注意力算法。FlashAttention使变压器训练速度比现有基准更快:在BERT-large(序列长度512)上相比MLPerf 1.1训练速度记录提升15%,在GPT-2(序列长度1K)上提升3倍,在长范围arena(序列长度1K-4K)上提升2.4倍。FlashAttention和块稀疏FlashAttention支持Transformers处理更长的上下文,从而生成更高质量的模型(GPT-2上的困惑度提升0.7点,长文档分类提升6.4点),以及全新的能力:在Path-X挑战(序列长度16K,61.4%准确率)和Path-256(序列长度64K,63.1%准确率)上实现了超越偶然表现的首个变压器。
原理概览
对论文总结Flash Attention快的原因是,这种新型的注意力算法,旨在解决传统注意力机制在大型模型中的存储和带宽开销问题。它通过减少内存读写次数,提高了注意力算法在GPU内存IO的效率,从而加快模型训练速度并增加上下文窗口大小。总览图如下:
左图:FlashAttention通过分块技术来避免在(相对较慢的)GPU高带宽内存(HBM)上生成大规模的N×N 注意力矩阵(虚线框)。在外层循环中(红色箭头),FlashAttention遍历K和V矩阵的块,并将它们加载到快速的片上SRAM中。在每个块中,FlashAttention遍历Q矩阵的块(蓝色箭头),将其加载到SRAM中,并将注意力计算的结果写回到HBM中。右图:在GPT-2上相较于PyTorch实现的注意力的加速比。FlashAttention避免了对大型N×N 注意力矩阵的读写操作,导致注意力计算速度提升了7.6倍。
注:HBM通常指的是“高带宽内存”(High Bandwidth Memory)。HBM是一种先进的垂直堆叠式DRAM内存技术,由SK海力士(SK Hynix)和三星(Samsung)等公司开发。这种类型的内存被设计用于高性能计算系统,如高端显卡、超级计算机和数据中心服务器。
原理拆分
1)首先标准注意力机制实现如下公式:
其中输入序列
,N是序列长度,d是隐藏层维度数,要想计算注意力输出
,那计算流程如上图,这里softmax是按行计算的。
标准的注意力实现将矩阵 S 和 P 放到高带宽内存 (HBM) 中,这需要 𝑂(𝑁^2) 的内存。通常 𝑁 ≫ 𝑑(例如,对于 GPT2,𝑁 = 1024 和 𝑑 = 64)。我们在下图伪代码中描述了标准的注意力实现。由于某些或大多数操作是受内存限制的(例如,softmax),大量的内存访问会导致较慢的实际时间。这一问题由于应用于注意力矩阵的其他逐元素操作(如对 S 进行掩码操作或对 P 进行 dropout 操作)而变得更加严重。因此,已经有许多尝试将多个逐元素操作融合在一起,例如将掩码与 softmax 融合 。
从伪代码中可以看到IO的读取非常的繁琐耗时:
均存储在HBM中
1. 从HBM中加载Q,K到SRAM中;
2. 计算
,将s写入HBM中;
3. 从HBM中读取s加载到SRAM;
4. 计算
,将P写入HBM中;
5. 从HBM中读取P,V并加载到SRAM;
6. 计算
,将O写入HBM,返回O
2)从这张图我们也可以知道,大的矩阵乘法和多通道卷积计算耗时比较少,而注意力计算中的dropout,softmax,mask比较耗时:
因此对其的优化一般是进行分块计算(Tiling and Recomputation)再融合操作,不对中间结果缓存,减少HBM的访问耗时。
Tiling and Recomputation的主要思想是,把输入Q、K、V分成多个块,将它们从slow HBM加载到 fast SRAM,然后计算相对于这些块的注意力输出。在通过正确的归一化因子缩放每个块的输出然后相加,便可得到正确结果。可以用一个CUDA kernel来执行注意力的所有操作。从HBM中加载输入数据,在SRAM中执行所有的计算操作(矩阵乘法,mask,softmax,dropout,矩阵乘法),再将计算结果写回到HBM中。通过kernel融合将多个操作融合为一个操作,避免了反复地从HBM中读写数据。
那么注意力计算中sofamax是如何切分计算的,需要对softmax进行缩放计算(在实际硬件中,浮点数表示的范围是有限的。对于float32和bfloat16来说,当 x≥89时,e^x就会变成inf,发生数据上溢的问题。为了避免发生数值溢出的问题,保证数值稳定性,计算时通常会“减去最大值”,称为“safe softmax”),具体公式如下:
对于重计算,在本文注意力实现中,后向传递计算Q,K,V 的梯度时,需要用NxN的中间矩阵S,P,但这两个矩阵并没有保存下来。这里的技巧是重计算,保存了两个统计量m(x), l(x),后向传递时在高速的SRAM上快速地重新计算Attention,通过分块的方式重新计算注意力矩阵S,P,因为这种方式,将显存复杂度从O(N^2)降到了O(N)。相比于标准注意力中,从HBM中读取很大的中间注意力矩阵的方法,重计算的方法要快得多。
Flash Attention通过kernel融合和分块计算,大量减少了HBM访问次数,尽管由于后向传递中的重计算增加了额外的计算量FLOPs,但依旧减少了运行时间,下图是GPT-2 medium在标准注意力和 FlashAttention 的向前 + 向后运行时间,可以看到提升了4.7倍。
最后对于flash attention2做了一下三点的优化:
1. 减少非矩阵乘法的计算,利用TensorCore加速
2. 调整O为外层训练,K,V为内层循环,减少HBM读取
3. 对于一个Block块处于矩阵上三角部分(被mask的部分),则不进行Attention计算
参考:
1. https://www.bilibili.com/video/BV1UT421k7rA/?spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=01b54c990198e640f937517e2d38c7db
2.https://zhuanlan.zhihu.com/p/639228219?s_r=0