论文名称:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
论文地址:https://arxiv.org/abs/2205.14135
一、研究FlashAttention的Motivate
FlashAttention技术在现在的主流大语言模型中均有应用,其主要作用是减少Transformer结构中运算(主要是self-attention,包括softmax、dropout等)的显存消耗,进一步解除文本处理长度限制,使得模型能够处理更长、更复杂的文本数据 以及 多轮对话功能。
让我们继续看看原论文中的说法:
Transformer 作为语言模型的基础架构提供了强大的特征表达能力,已经作为LLM的基础模型构件被大量使用。
在Transformer中核心组件是多头自注意力(multi-head selft-attention),这里的计算复杂度和空间复杂度是序列长度的二次方O(n2)。因此长文本处理仍然面临挑战。
当然有许多尝试用于减少注意力的计算和内存开销。例如,稀疏近似和低秩近似得方法,将计算复杂度降低到序列长度的线性或亚线性,但这些方法主要关注FLOPs(浮点数计算次数)的减少(这部分消耗主要由矩阵运算提供),而忽略了IO读写的内存访问开销。
由下图可以看到,GPT-2中的标准attention,耗时对比:矩阵运算 < softmax < Dropout。在现代GPU中,计算速度超过显存访问速度。基于这样的发现,论文作者将突破‘超长文本处理’的契机放在了注意力的IO瓶颈。论文团队在对GPU硬件和注意力实现进行性能剖析后,将性能瓶颈锁定在‘HBM内存的读写压力过大’,指标论文的主要优化方向为‘降低HBM的IO次数’。
二、标准注意力机制与HBM的访问关系
2.1 标准Attention机制推理过程
将上面的步骤进行拆解可以得到
,,
如下图所示,一次标准Attention的实现需要多次读写HBM:
1. 按块从HBM中读取矩阵Q和K,计算S,并将S写入HBM;
2. 从HBM中读取S,计算完P=softmax(S)之后,将P写入HBM;
3. 按块从HBM读取中间结果P和V,计算O=PV,将O写入HBM;
4. 返回O
注意:笔者不清楚Q和K是同时读取还是分为2次;有相关科普说是分别读取(读两次HBM)
2.2 GPU结构的一些知识
这里是论文中给出的GPU A-100的内存结构:
1. HBM(High Bandwidth Memory,高带宽存取存储器)
由多个DRAM堆叠。
2. SRAM(Static Random-Access Memory, 静态随机访问存储器)
用于高速缓存等内部存储器,具有更快的访问速度和更低的延迟,但成本更高。由图中可见SRAM的执行/读写速度是HBM的12.67倍,但存储空间远远小于HBM。
三、FlashAttention
FlashAttention,总的来说是一种优化访问开销的精准注意力算法。
motivation:从GPU内存结构来看,要想提升Attention的性能,应该让计算过程尽可能在SRAM中进行。由于序列长度 N 可能会很长,无法将Q、K、V 以及中间结果完整存储在SRAM中,因此FlashAttention就采用了‘分块’操作,每块的计算所需内存不超过SRAM大小。
这里有两种核心操作:tiling(平铺) 、recomputation(重计算) ,最后使用 kernel fusion 进行融合。
1. tiling:利用更高速的SRAM代替HBM;
2. recomputation:放弃中间结果写回,需要使用时再次计算,用计算Trade-off访存;
3. kernel fusion:基于Tiling使用一个kernel完成整个计算。
3.1 tiling 平铺
tiling 基本思路:不直接对整个输入序列计算注意力,而是根据SRAM大小将其分为多个较小的块,逐个对‘块’进行计算,在计算过程中增量式(详见3.1.2)地进行softmax的逼近。在整个计算过程中只需要更新某些中间变量(如全局最大值,详见3.1.2),不需要计算整个注意力权重矩阵。
而‘分块’操作的难点在于softmax的计算,softmax计算中分母位置包含所有元素的求和项(该项用于归一化),论文重点描述了softmax的‘分块’计算。
3.1.1 先来看看标准softmax计算流程(无分块)
这里有两个版本的softmax,softmax(x)应该是我们常见的理论上的实现方式;但是在实际操作中,我们通常使用safe_softmax(x),笔者已经替大家试过了,两者结果一致。
# safe softmax / 安全 softmax
def safe_softmax(x):
# 防止数值计算时的下溢,先将x中的每个元素减去x中的最大值
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)
# 一般形式softmax
def softmax(x):
e_x_2 = np.exp(x)
return e_x_2 / e_x.sum(axis=0)
# input = np.array([1, 2, 3, 4])
# output = [0.0320586 0.08714432 0.23688282 0.64391426]
让我们继续用一个case,推理下safe_softmax的计算:x = [1, 2, 3, 4]
a. 计算组间最大值,防止计算下溢,m(x) = max(x) = 4
b. 指数计算,
c. 计算softmax分母 / 归一化因子,
d. softmax计算,
3.1.2 继续看看分块softmax计算流程:举例推理
举例推理:简单起见,这里分为2块计算,x1 = [1, 2],x2 = [3, 4]
a. 计算第一块内的最大值 m(x1) = max(x1) = 2 = m(x) {记录全局最大值m(x)}
b. 第一个块内,进行指数计算 。初始化赋值f(x)=f(x1)
c. 第一个块内,计算归一化因子 。注意,这里是中间变量
d. 开始操作第二个模块,更新到此刻为止的最大值 m(x)=m(x2) = max(m(x1), x2) = 4
补充:论文提供的伪代码中,使用for循环处理每个块,每一步都会更新最大值。
e. 第二个块内,进行指数计算
f. 第二个块内,计算归一化因子 。注意,这里仍然是中间变量。
g. 柔和两个块的中间结果计算全局f(x) 和 l(x)
注:至此,各位会发现分块计算的f(x)和l(x)到了这一步的结果和不分块计算的结果一致。
3.1.3 补充 + 尚存问题
tiling 操作在FlashAttention中是一个贯穿正向传播和反向传播的重要策略。它不仅在正向传播中用于分块处理输入矩阵以提高计算效率和减少内存使用,还在反向传播中用于优化内存访问和重新计算必要的中间变量。
3.2 recomputation 重计算
Recomputation是一种算力换内存的操作,即基于trade-off的思想。在上述分析中重点在于优化访问开销,既然GPU计算时间 小于 HBM读写时间,那么就不存储注意力计算过程中的中间结果,而是在某层反向传播中临时计算梯度更新所需的正向传播的中间状态。
相对于标准注意力机制从HBM中读取很大的中间注意力矩阵,重新计算尽管增加了额外的计算量FLOPs,但仍能够减少运行时间。由下图可见,虽然增加了FLOPs,但是减少了HBM的读写量,最终耗时性能收益明显。
注1:在这里(反向传播),仅保存了前向 tiling 过程中的两个统计量 m(x) 和 l(x);
注2:在正向传播中,变量S、P(见2.1)不会被保存;但是在反向传播中需要计算S、P关于Q、K、V的偏导,然后用于更新权重,在这里是重新计算中间结果S和P。
注3:在recomputation中同样基于 tiling 平铺的思想重新计算所需的注意力权重矩阵。看到这么一种说法:“recomputation 可以看作是一种基于 tiling 的特殊的 gradient checkpointing”。
3.3 Kernal Fusion
核心思想是将多个操作融合成一个操作,以此减少HBM的访问次数。tiling 分块计算使得可以用一个Kernal完成注意力的所有操作。
例如:在 SRAM 中计算完 𝑆 之后紧接着就通过 𝑆 计算 𝑃 ,这样可以避免在 HBM 和 SRAM 交换 𝑆 。
3.4 不确定的部分
笔者猜测全流程:从HBM加载输入数据(如完整的Q、K),然后‘分块’加载到SRAM执行计算,在SRAM基于一个Kernal Fusion的概念,将mask、softmax、dropout等计算完整,最后将结果写回HBM。整个流程只有‘两次’读写HBM操作?
是否是这个样子,各位可以评论区留言。
但是,看伪代码,for循环不断的从HBM加载数据到SRAM,这一步也需要消耗吧。
四、论文伪代码解析
4.1 FlashAttention前向传播
按行数进行代码描述
首先确定SRAM的大小,记M,保证Q、K、V和结果O的分块能够保留在SRAM内;
1. 计算 ‘块数’ or 列大小 Bc
2. 在HBM中初始化输出矩阵O,中间变量l和m,其中m用于记录每一行中行最大值,初始化-inf;
3. 将Q、K、V切块,块数分别为Tr 、Tc、Tc;
4. 将2中初始化的O、l、m切块,块数和Q一样,均为Tr;
5+6. 外层循环,将 Kj、Vj 从HBM加载到SRAM;
7+8. 内层循环,将Qi、Oi、li、mi 从HBM加载到SRAM;
9. 开始注意力机制的计算,计算中间变量 Sij;
10. 计算Sij每一行的最大值,记mij(Sij是一个Br * Bc的矩阵,有Br行);按行开展safe_softmax指数运算得到Pij(约等于第三章中的f(x));计算Pij每一行的和,记Lij(softmax分母);
11. 计算 mi(new)、li(new),这一步类比3.1.2中(d,e,f,g),再更新最大值之后,计算分母累计值;
12. 累加计算注意力(KV部分)更新Oi并写入HBM,供下一轮循环读取;
13. 重新赋值并将当前累积 li、mi 写入HBM;在下一轮中,将作为上一轮的累积结果
补充:GPU内多线程分块读取 + 计算。
作者还将Flash Attention扩展到了块稀疏注意力,产生了一种更优的近似的注意力算法。
4.2 反向传播过程(我要开始偷懒了)
已知前向过程只将Oi、li、mi 写入了HBM,并没有保存S和P,再根据标准self_attention反向传播计算dQ、dK、dV的公式(如下图,图来自于原论文最后的补充材料),分块计算结果。
‘分块’attention 反向传播伪代码如下:
1~4. 前向过程会保留Q,K,V,O,l,m在HBM中,dO由反向传播计算得到后,按照和前向传播相同的分块模式重新分块;
5. 初始化dQ,dK,dV为全0矩阵,并按照对等Q,K,V的分割方式分割dQ,dK,dV;
6~10. 外循环:从HBM中读取K、V 块到SRAM;内循环:读取Q块到SRAM;
11~20. 根据前向过程重新计算对应的Sij和Pij;按分块矩阵的方式分别计算对应梯度d(Sij)和d(Pij)
21~end. 累积形式更新dQ、dK、dV
五、总结
FlashAttention是通过减少HBM访问开销、以内存换时间等操作优化后的精准Attention,虽然多了很多计算步骤,可能会导致一定的精度损失,但仍然能够保证模型在处理复杂任务时的精确性和可靠性。
核心收益如下:
长文本处理能力:更小的内存(显存复杂度从O(N^2)降低到了O(N)) + 更快的推理速度(减少HBM访问),这些特性扩展了文本处理长度限制,C哈她GLM2应用该技术后,将文本可处理长度从2K提升到了32K。大预言模型能够处理更长、更复杂的文本数据。这一改进推进了‘长文本’的处理和模型效果优化。
增强上下文理解能力:更长的输入,可能会增强对长历史对话的理解能力,确保模型在多轮对话中能够准确捕捉和整合上下文信息。
灵活的组件:FlashAttention可以应用于各种类型的神经网络,包括卷积神经网络(CNN)、循环神经网络(RNN)和Transformer等。这种灵活性使得FlashAttention能够在多种场景和任务中发挥作用。
主要缺陷:
硬件依赖:FlashAttention起作用的一部分起因是计算开销 < 访问开销,因此能够起到更好的作用,就比较依赖于内存带宽和计算带宽。
额外的调度配置:分块、动态规划(累积计算中间结果和最终结果)和缓存机制等方法来优化计算过程,那么在GPU内不同线程之间如何调度、如何分区的配置需要根据任务和数据反复调试,以找到最佳配置。