文章目录
- 前言
- 预备知识
- FlashAttention1
- 传统Attention计算方式
- FlashAttention1的基本原理
- 除去Softmax操作的分块计算
- Softmax分块计算
- Attention分块计算
- FlashAttention2
- 参考资料
前言
FlashAttention系列工作,是一种加速注意力计算方法,目前已经应用在:GPT-3、Falcon2(阿联酋大模型)、Llama2、Megatron-LM、GPT-4等流行LLM上。并且FlashAttention2已经集成到了pytorch2.0中,可以很便捷的调用。
1. FlashAttention动机:Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length.
,可以看出由于Transformer中self-attention 的时间和内存复杂度是序列长度的二次方,所以序列过长时,算法速度会变慢,需要消耗很高的内存,导致低效的
2. FlashAttention主要贡献
- FlashAttention利用底层硬件的内存层次知识,例如GPU的内存层次结构,来提高计算速度和减少内存访问开销
- 核心原理是通过将输入分块,并在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写操作
- FlashAttention减少了内存读写量,从而实现了2-4倍的计算加速
预备知识
1. GPU 内存层次结构
GPU 内存层次结构 包含多种不同大小和速度的内存形式,内存容量越小,读写速度越快。以A100 GPU为例,主要有两种类型,如下图所示:
- 高带宽内存 (HBM):也就是我们常说的GPU显存,A100具有 40-80GB HBM,带宽为 1.5-2.0TB/s
- SRAM:位于GPU片上,每个 108 个流式多处理器(SM, Streaming Multiprocessor)都有 192KB 片上 SRAM,带宽估计约为 19TB/s
- 二者的位置分布如下图所示,其中HBM在VRAM部门,而SRAM在GPU内部:
可以看到,片上 SRAM 比 HBM 快一个数量级,但内存容量小很多数量级。随着计算相对于内存速度变得更快,内存 (HBM) 访问越来越成为操作瓶颈。因此,利用快速 SRAM 变得更加重要。
2. GPU执行过程
GPU 有大量线程(threads )来执行操作(称为内核 Kernel)。每个Kernel将输入从 HBM 加载到寄存器和 SRAM,进行计算,然后将输出写入 HBM。
3. 性能特点
根据计算和内存访问的平衡,操作可以分为计算限制或内存限制。这通常通过算术强度来衡量,即内存访问的每个字节的算术运算数量。
- 计算限制:操作所花费的时间由算术运算的数量决定,而访问 HBM 的时间要少得多。典型的例子是大内部维度的矩阵乘法,以及大量通道的卷积
- 内存限制:操作所花费的时间由内存访问次数决定,而计算所花费的时间要少得多。示例包括大多数其他操作:逐元素(激活、dropout)和归约(求和、softmax、批量归一化、层归一化)
FlashAttention1
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
paper,code
传统Attention计算方式
传统的Attention计算过程如下:
- 首先 Q , K , V Q,K,V Q,K,V矩阵计算好,放在HBM中
- 接着,为了计算QK注意力得分,将 Q , K Q,K Q,K从HBM中取出来,写入SRAM,然后计算 S = Q K T S=QK^T S=QKT,再把 S S S从SRAM写入HBM
- 然后,从HBM加载 S S S到SRAM,计算 P = S o f t m a x ( S ) P=\rm{Softmax}(\it{S}) P=Softmax(S),然后再把 P P P从SRAM写入HBM
- 最后,从HBM加载 P , V P,V P,V到SRAM,计算最终的输出 O = P V O=PV O=PV,然后把 O O O写入HBM
可以看到,这个过程中存在多次HBM和SRAM之间的读写操作。同时由于Attention的计算方式,导致中间的临时变量 S , P S,P S,P的参数量和输入序列长度的平方成正比:
因此,在训练时,长序列输入在计算Attention时会产生更大参数量的临时变量,会占用更大显存空间,导致更多的访问消耗。也就是说,Attention操作主要是内存限制问题,通信时间是制约计算效率的主要因素。
FlashAttention1的基本原理
因此,FlashAttention主要的思想,就是减少通信时间,也就是减少IO操作,使得计算尽可能多的访问片上的SRAM,尽可能少的访问片外的HBM。
- 通过分块计算,融合多个操作,减少中间结果缓存
- 反向传播时,重新计算中间结果(类似于梯度检查点的原理)
除去Softmax操作的分块计算
在计算Attention主要有两个临时变量 S , P S,P S,P,FlashAttention的分块计算,使得不需要存储这两个临时变量,而是直接在SRAM计算得到部分最终结果 O O O,从而减少了内存访问开销。这里我们先忽略Softmax操作,因为在分块计算,他比较麻烦。
这里假设
Q
,
K
,
V
Q,K,V
Q,K,V矩阵的大小为
(
4
,
3
)
(4, 3)
(4,3),那么FlashAttention的分块计算过程如下,由于矩阵乘法的性质,每次分块计算得到的结果,都是最终结果矩阵中的一部分值:
Softmax分块计算
下面,我们来解决Softmax这个麻烦的操作,首先来回顾Softmax的计算公式:
softmax ( { x 1 , … , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N \operatorname{softmax}\left(\left\{x_1, \ldots, x_N\right\}\right)=\left\{\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\right\}_{i=1}^N softmax({x1,…,xN})={∑j=1Nexjexi}i=1N
但是,如果数据类型为FP16,那么最大可以表示为65536,因此当 x i x_i xi为12时, e 12 = 162754 e^{12}=162754 e12=162754,超过了FP16所能表示的最大值。因此,我们需要使用Safe_Softmax方法来避免这个问题:
m = max ( x i ) softmax ( { x 1 , … , x N } ) = { e x i / e m ∑ j = 1 N e x j / e m } i = 1 N = { e x i − m ∑ j = 1 N e x j − m } i = 1 N \begin{aligned} & m=\max \left(x_i\right) \\ & \operatorname{softmax}\left(\left\{x_1, \ldots, x_N\right\}\right)=\left\{\frac{e^{x_i} / e^m}{\sum_{j=1}^N e^{x_j} / e^m}\right\}_{i=1}^N=\left\{\frac{e^{x_i-m}}{\sum_{j=1}^N e^{x_j-m}}\right\}_{i=1}^N \end{aligned} m=max(xi)softmax({x1,…,xN})={∑j=1Nexj/emexi/em}i=1N={∑j=1Nexj−mexi−m}i=1N
也就是在计算Softmax之前,先对输入数据进行归一化处理。此时计算Softmax的流程为:
x = [ x 1 , … , x N ] m ( x ) : = max ( x ) p ( x ) : = [ e x 1 − m ( x ) , … , e x N − m ( x ) ] l ( x ) : = ∑ i p ( x ) i softmax ( x ) : = p ( x ) l ( x ) \begin{aligned} & x=\left[x_1, \ldots, x_N\right] \\ & m(x):=\max (x) \\ & p(x):=\left[e^{x_1-m(x)}, \ldots, e^{x_N-m(x)}\right] \\ & l(x):=\sum_i p(x)_i \\ & \operatorname{softmax}(x):=\frac{p(x)}{l(x)} \end{aligned} x=[x1,…,xN]m(x):=max(x)p(x):=[ex1−m(x),…,exN−m(x)]l(x):=i∑p(x)isoftmax(x):=l(x)p(x)
接下来看分块处理时,假设这里分为两块处理,我们首先需要在每个块内找到最大值(使用临时变量来保存),做归一化处理:
x = [ x 1 , … , x N , … x 2 N ] x 1 = [ x 1 , … , x N ] x 2 = [ x N + 1 , … x 2 N ] m ( x 1 ) p ( x 1 ) l ( x 1 ) m ( x 2 ) p ( x 2 ) l ( x 2 ) \begin{aligned} & x=\left[x_1, \ldots, x_N, \ldots x_{2 N}\right] \\ & x^1=\left[x_1, \ldots, x_N\right] \quad x^2=\left[x_{N+1}, \ldots x_{2 N}\right] \\ & m\left(x^1\right) \quad p\left(x^1\right) \quad l\left(x^1\right) \quad m\left(x^2\right) \quad p\left(x^2\right) \quad l\left(x^2\right) \end{aligned} x=[x1,…,xN,…x2N]x1=[x1,…,xN]x2=[xN+1,…x2N]m(x1)p(x1)l(x1)m(x2)p(x2)l(x2)
然后再计算数据的全局最大值,并且更新 p ( x ) , l ( x ) p(x),l(x) p(x),l(x),最后计算得到输入数据的Softmax值:
- 由于全局最大值,一定是各个块内最大值中的一个
- 因此在更新
p
(
x
)
,
l
(
x
)
p(x),l(x)
p(x),l(x),只需要乘以每个块最大值相对于全局最大值的差值的指数,就可以了
m ( x ) : = max ( m ( x 1 ) , m ( x 2 ) ) p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] l ( x ) : = e m ( x 1 ) − m ( x ) l ( x 1 ) + e m ( x 2 ) − m ( x ) l ( x 2 ) softmax ( x ) : = p ( x ) l ( x ) \begin{aligned} &\begin{aligned} & m(x):=\max \left(m\left(x^1\right), m\left(x^2\right)\right) \\ & p(x):=\left[e^{m\left(x^1\right)-m(x)} p\left(x^1\right), e^{m\left(x^2\right)-m(x)} p\left(x^2\right)\right] \\ & l(x):=e^{m\left(x^1\right)-m(x)} l\left(x^1\right)+e^{m\left(x^2\right)-m(x)} l\left(x^2\right) \end{aligned}\\ &\operatorname{softmax}(x):=\frac{p(x)}{l(x)} \end{aligned} m(x):=max(m(x1),m(x2))p(x):=[em(x1)−m(x)p(x1),em(x2)−m(x)p(x2)]l(x):=em(x1)−m(x)l(x1)+em(x2)−m(x)l(x2)softmax(x):=l(x)p(x)
Attention分块计算
最后,我们来看一下FlashAttention的完整计算流程:
FlashAttention2
相比于FlashAttention1的改进:
- 减少了非矩阵乘法计算,可以利用Tensor Core加速计算
- 调整了内外训练方式,改为 Q 为外层循环,KV 为内层循环,进一步减少HBM读写,增加了并行度
- 如果一个Block处于矩阵上三角部分(Mask机制),则不进行attention计算,进一步优化了计算效率
参考资料
- [1] https://www.bilibili.com/video/BV1UT421k7rA/?share_source=copy_web&vd_source=79b1ab42a5b1cccc2807bc14de489fa7
- [2] https://zhuanlan.zhihu.com/p/676655352