FlashAttention解析——大预言模型核心组建

news2024/9/17 8:38:53

论文名称:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

论文地址:https://arxiv.org/abs/2205.14135        

一、研究FlashAttention的Motivate       

        FlashAttention技术在现在的主流大语言模型中均有应用,其主要作用是减少Transformer结构中运算(主要是self-attention,包括softmax、dropout等)的显存消耗,进一步解除文本处理长度限制,使得模型能够处理更长、更复杂的文本数据 以及 多轮对话功能。

        让我们继续看看原论文中的说法:

        Transformer 作为语言模型的基础架构提供了强大的特征表达能力,已经作为LLM的基础模型构件被大量使用。

        在Transformer中核心组件是多头自注意力(multi-head selft-attention),这里的计算复杂度和空间复杂度是序列长度的二次方O(n2)。因此长文本处理仍然面临挑战。

        当然有许多尝试用于减少注意力的计算和内存开销。例如,稀疏近似和低秩近似得方法,将计算复杂度降低到序列长度的线性或亚线性,但这些方法主要关注FLOPs(浮点数计算次数)的减少(这部分消耗主要由矩阵运算提供),而忽略了IO读写的内存访问开销。

        由下图可以看到,GPT-2中的标准attention,耗时对比:矩阵运算 < softmax < Dropout。在现代GPU中,计算速度超过显存访问速度。基于这样的发现,论文作者将突破‘超长文本处理’的契机放在了注意力的IO瓶颈。论文团队在对GPU硬件和注意力实现进行性能剖析后,将性能瓶颈锁定在‘HBM内存的读写压力过大’,指标论文的主要优化方向为‘降低HBM的IO次数’。

二、标准注意力机制与HBM的访问关系

2.1 标准Attention机制推理过程

        Q,K,V\epsilon R^{N\times D}, Attention(Q, K, V)=softmax(\frac{Q*K^{T}}{\sqrt{d}})V

        将上面的步骤进行拆解可以得到

        S=QK^{T} \epsilon R^{N\times N}P=softmax(S) \epsilon R^{N\times N}O=PV \epsilon R^{N\times d}

        如下图所示,一次标准Attention的实现需要多次读写HBM

        1. 按块从HBM中读取矩阵Q和K,计算S,并将S写入HBM;

        2. 从HBM中读取S,计算完P=softmax(S)之后,将P写入HBM;

        3. 按块从HBM读取中间结果P和V,计算O=PV,将O写入HBM;

        4. 返回O

        注意:笔者不清楚Q和K是同时读取还是分为2次;有相关科普说是分别读取(读两次HBM)

2.2 GPU结构的一些知识

        这里是论文中给出的GPU A-100的内存结构:

        1. HBM(High Bandwidth Memory,高带宽存取存储器)

                由多个DRAM堆叠。

        2. SRAM(Static Random-Access Memory, 静态随机访问存储器)

                用于高速缓存等内部存储器,具有更快的访问速度和更低的延迟,但成本更高。由图中可见SRAM的执行/读写速度是HBM的12.67倍,但存储空间远远小于HBM。

三、FlashAttention

        FlashAttention,总的来说是一种优化访问开销精准注意力算法。

        motivation:从GPU内存结构来看,要想提升Attention的性能,应该让计算过程尽可能在SRAM中进行。由于序列长度 N 可能会很长,无法将Q、K、V 以及中间结果完整存储在SRAM中,因此FlashAttention就采用了‘分块’操作,每块的计算所需内存不超过SRAM大小。

        这里有两种核心操作:tiling(平铺) 、recomputation(重计算) ,最后使用 kernel fusion 进行融合。

        1. tiling:利用更高速的SRAM代替HBM;

        2. recomputation:放弃中间结果写回,需要使用时再次计算,用计算Trade-off访存;

        3. kernel fusion:基于Tiling使用一个kernel完成整个计算。

3.1 tiling 平铺

        tiling 基本思路:不直接对整个输入序列计算注意力,而是根据SRAM大小将其分为多个较小的块,逐个对‘块’进行计算,在计算过程中增量式(详见3.1.2)地进行softmax的逼近。在整个计算过程中只需要更新某些中间变量(如全局最大值,详见3.1.2),不需要计算整个注意力权重矩阵。

        而‘分块’操作的难点在于softmax的计算,softmax计算中分母位置包含所有元素的求和项(该项用于归一化),论文重点描述了softmax的‘分块’计算。

3.1.1 先来看看标准softmax计算流程(无分块)

        这里有两个版本的softmax,softmax(x)应该是我们常见的理论上的实现方式;但是在实际操作中,我们通常使用safe_softmax(x),笔者已经替大家试过了,两者结果一致

# safe softmax / 安全 softmax
def safe_softmax(x):
    # 防止数值计算时的下溢,先将x中的每个元素减去x中的最大值
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

# 一般形式softmax
def softmax(x):
    e_x_2 = np.exp(x)
    return e_x_2 / e_x.sum(axis=0)

# input = np.array([1, 2, 3, 4])
# output = [0.0320586  0.08714432 0.23688282 0.64391426]

        让我们继续用一个case,推理下safe_softmax的计算:x = [1, 2, 3, 4]

        a. 计算组间最大值,防止计算下溢,m(x) = max(x) = 4

        b. 指数计算,f(x) = [e^{1-m(x)}, e^{2-m(x)}, e^{3-m(x)}, e^{4-m(x)}] = [e^{-3}, e^{-2}, e^{-1}, e^{0}]

        c. 计算softmax分母 / 归一化因子, l(x) = e^{-3} + e^{-2} + e^{-1} + e^{-0}

        d. softmax计算, softmax(x) = \frac{f(x)}{l(x)}

3.1.2 继续看看分块softmax计算流程:举例推理

举例推理:简单起见,这里分为2块计算,x1 = [1, 2],x2 = [3, 4]

        a. 计算第一块内的最大值 m(x1) = max(x1) = 2 = m(x) {记录全局最大值m(x)}

        b. 第一个块内,进行指数计算 f(x1) = [e^{1-m(x1)}, e^{2-m(x1)}] = [e^{-1}, e^{0}]。初始化赋值f(x)=f(x1)

        c. 第一个块内,计算归一化因子 l(x1) = e^{-1} + e^{-0}。注意,这里是中间变量

        d. 开始操作第二个模块,更新到此刻为止的最大值 m(x)=m(x2) = max(m(x1), x2) = 4

                补充:论文提供的伪代码中,使用for循环处理每个块,每一步都会更新最大值。

        e. 第二个块内,进行指数计算  f(x2) = [e^{1-m(x2)}, e^{2-m(x2)}] = [e^{3-4}, e^{4-4}] = [e^{-1}, e^{0}]

        f. 第二个块内,计算归一化因子 l(x2) = e^{-1} + e^{-0}。注意,这里仍然是中间变量。

        g. 柔和两个块的中间结果计算全局f(x) 和 l(x)

f(x) = [e^{m(x1) - m(x)}f(x1), e^{m(x2) - m(x)}f(x2)]

f(x) = [e^{2-4}(e^{-1}, e^{0}), e^{4-4}(e^{-1}, e^{0})] = [e^{-2}(e^{-1}, e^{0}), e^{0}(e^{-1}, e^{0})] 

       l(x) = e^{m(x1) - m(x)} * l(x1) + e^{m(x2) - m(x)} * l(x2)

 l(x) = e^{2-4} * (e^{-1} + e^{-0}) + e^{4-4} * (e^{-1} + e^{-0}) = e^{-3} + e^{-2} + e^{-1} + e^{-0}

        :至此,各位会发现分块计算的f(x)和l(x)到了这一步的结果和不分块计算的结果一致。

3.1.3 补充 + 尚存问题

        tiling 操作在FlashAttention中是一个贯穿正向传播和反向传播的重要策略。它不仅在正向传播中用于分块处理输入矩阵以提高计算效率和减少内存使用,还在反向传播中用于优化内存访问和重新计算必要的中间变量。

3.2 recomputation 重计算

        Recomputation是一种算力换内存的操作,即基于trade-off的思想。在上述分析中重点在于优化访问开销,既然GPU计算时间 小于 HBM读写时间,那么就不存储注意力计算过程中的中间结果,而是在某层反向传播中临时计算梯度更新所需的正向传播的中间状态。

        相对于标准注意力机制从HBM中读取很大的中间注意力矩阵,重新计算尽管增加了额外的计算量FLOPs,但仍能够减少运行时间。由下图可见,虽然增加了FLOPs,但是减少了HBM的读写量,最终耗时性能收益明显。

        注1:在这里(反向传播),仅保存了前向 tiling 过程中的两个统计量 m(x) 和 l(x);

        注2:在正向传播中,变量S、P(见2.1)不会被保存;但是在反向传播中需要计算S、P关于Q、K、V的偏导,然后用于更新权重,在这里是重新计算中间结果S和P。

        注3:在recomputation中同样基于 tiling 平铺的思想重新计算所需的注意力权重矩阵。看到这么一种说法:“recomputation 可以看作是一种基于 tiling 的特殊的 gradient checkpointing”。

3.3 Kernal Fusion

        核心思想是将多个操作融合成一个操作,以此减少HBM的访问次数。tiling 分块计算使得可以用一个Kernal完成注意力的所有操作。        

        例如:在 SRAM 中计算完 𝑆 之后紧接着就通过 𝑆 计算 𝑃 ,这样可以避免在 HBM 和 SRAM 交换 𝑆 。

3.4 不确定的部分

        笔者猜测全流程:从HBM加载输入数据(如完整的Q、K),然后‘分块’加载到SRAM执行计算,在SRAM基于一个Kernal Fusion的概念,将mask、softmax、dropout等计算完整,最后将结果写回HBM。整个流程只有‘两次’读写HBM操作?

        是否是这个样子,各位可以评论区留言。

        但是,看伪代码,for循环不断的从HBM加载数据到SRAM,这一步也需要消耗吧。

        

四、论文伪代码解析

4.1 FlashAttention前向传播

按行数进行代码描述

首先确定SRAM的大小,记M,保证Q、K、V和结果O的分块能够保留在SRAM内;

1. 计算 ‘块数’ or 列大小 Bc

2. 在HBM中初始化输出矩阵O,中间变量l和m,其中m用于记录每一行中行最大值,初始化-inf;

3. 将Q、K、V切块,块数分别为Tr 、Tc、Tc;

4. 将2中初始化的O、l、m切块,块数和Q一样,均为Tr;

5+6. 外层循环,将 Kj、Vj 从HBM加载到SRAM;

7+8. 内层循环,将Qi、Oi、li、mi 从HBM加载到SRAM;

9. 开始注意力机制的计算,计算中间变量 Sij;

10. 计算Sij每一行的最大值,记mij(Sij是一个Br * Bc的矩阵,有Br行);按行开展safe_softmax指数运算得到Pij(约等于第三章中的f(x));计算Pij每一行的和,记Lij(softmax分母);

11. 计算 mi(new)、li(new),这一步类比3.1.2中(d,e,f,g),再更新最大值之后,计算分母累计值;

12. 累加计算注意力(KV部分)更新Oi并写入HBM,供下一轮循环读取;

13. 重新赋值并将当前累积 li、mi 写入HBM;在下一轮中,将作为上一轮的累积结果

补充:GPU内多线程分块读取 + 计算。

作者还将Flash Attention扩展到了块稀疏注意力,产生了一种更优的近似的注意力算法。

4.2 反向传播过程(我要开始偷懒了)

        已知前向过程只将Oi、li、mi 写入了HBM,并没有保存S和P,再根据标准self_attention反向传播计算dQ、dK、dV的公式(如下图,图来自于原论文最后的补充材料),分块计算结果。

        ‘分块’attention 反向传播伪代码如下:

        1~4. 前向过程会保留Q,K,V,O,l,m在HBM中,dO由反向传播计算得到后,按照和前向传播相同的分块模式重新分块;

        5. 初始化dQ,dK,dV为全0矩阵,并按照对等Q,K,V的分割方式分割dQ,dK,dV;

        6~10. 外循环:从HBM中读取K、V 块到SRAM;内循环:读取Q块到SRAM;

        11~20. 根据前向过程重新计算对应的Sij和Pij;按分块矩阵的方式分别计算对应梯度d(Sij)和d(Pij)

        21~end. 累积形式更新dQ、dK、dV

五、总结  

        FlashAttention是通过减少HBM访问开销、以内存换时间等操作优化后的精准Attention,虽然多了很多计算步骤,可能会导致一定的精度损失,但仍然能够保证模型在处理复杂任务时的精确性和可靠性。

        核心收益如下:

        长文本处理能力:更小的内存(显存复杂度从O(N^2)降低到了O(N)) + 更快的推理速度(减少HBM访问),这些特性扩展了文本处理长度限制,C哈她GLM2应用该技术后,将文本可处理长度从2K提升到了32K。大预言模型能够处理更长、更复杂的文本数据。这一改进推进了‘长文本’的处理和模型效果优化。

        增强上下文理解能力:更长的输入,可能会增强对长历史对话的理解能力,确保模型在多轮对话中能够准确捕捉和整合上下文信息。

        灵活的组件:FlashAttention可以应用于各种类型的神经网络,包括卷积神经网络(CNN)、循环神经网络(RNN)和Transformer等。这种灵活性使得FlashAttention能够在多种场景和任务中发挥作用。

        主要缺陷

        硬件依赖:FlashAttention起作用的一部分起因是计算开销 < 访问开销,因此能够起到更好的作用,就比较依赖于内存带宽和计算带宽。

        额外的调度配置:分块、动态规划(累积计算中间结果和最终结果)和缓存机制等方法来优化计算过程,那么在GPU内不同线程之间如何调度、如何分区的配置需要根据任务和数据反复调试,以找到最佳配置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024靠这份软件测试面试题宝典已成功上岸,跳槽成功

上月很多朋友靠这份面试宝典拿到大厂的office&#xff0c;跳槽成功&#xff0c;面试找工作的小白和要跳槽进阶都很适合&#xff0c;没有一点准备怎么能上岸成功呢&#xff1f; 这份面试题宝库&#xff0c;包含了很多部分&#xff1a;测试理论&#xff0c;Linux基础&#xff0c…

Java中级

IDAE介绍 IDEA&#xff0c;全称为IntelliJ IDEA&#xff0c;是一款由JetBrains公司开发的集成开发环境&#xff08;IDE&#xff09;&#xff0c;主要用于Java开发&#xff0c;但也支持多种其他编程语言和框架&#xff0c;如Kotlin、Scala、Groovy、Android、Spring、Hibernate…

SpringBoot Mysql->达梦8 activiti6.0.0 项目迁移

全部源码&#xff1a;公众号搜索资小库&#xff0c;回复dm获取源码 1.整合达梦 1.1 达梦驱动下载 MyBatis-Plus 框架 | 达梦技术文档 (dameng.com) 1.2 数据迁移 怎么安装数据库&#xff0c;很多大佬有帖子&#xff0c;搜一下达梦先建立用户&#xff0c;使用DM管理工具 链…

【进阶篇-Day13:JAVA中IO流之字节缓冲流的介绍】

目录 1、IO流介绍2、IO流体系结构2.1 FileOutputStream 字节输出流&#xff08;1&#xff09;字节输出流操作方法&#xff1a;&#xff08;2&#xff09; 标准的关流代码&#xff1a; 2.2 FileInputStream 字节输入流&#xff08;1&#xff09;字节输入流操作方法&#xff1a; …

Glove-词向量

文章目录 共现矩阵共线概率共线概率比词向量训练总结词向量存在的问题 上一篇文章词的向量化介绍了词的向量化&#xff0c;词向量的训练方式可以基于语言模型、基于窗口的CBOW和SKipGram的这几种方法。今天介绍的Glove也是一种训练词向量的一种方法&#xff0c;他是基于共现概率…

【每日一题】【回溯+二进制优化】[USACO1.5] 八皇后 Checker Challenge C\C++\Java\Python3

P1219 [USACO1.5] 八皇后 Checker Challenge [USACO1.5] 八皇后 Checker Challenge 题目描述 一个如下的 6 6 6 \times 6 66 的跳棋棋盘&#xff0c;有六个棋子被放置在棋盘上&#xff0c;使得每行、每列有且只有一个&#xff0c;每条对角线&#xff08;包括两条主对角线的…

Python设置Excel单元格中的部分文本颜色

文章目录 一、概述二、效果三、示例 一、概述 openpyxl &#xff08;目前&#xff09;不支持设置单元格内部分字体颜色 xlsxwriter 支持设置单元格内部分字体颜色&#xff08;创建新的Excel&#xff09; 二、效果 三、示例 """ Python设置Excel单元格中的部分…

昇思 25 天学习打卡营第 24 天 | MindSpore Pix2Pix 实现图像转换

1. 背景&#xff1a; 使用 MindSpore 学习神经网络&#xff0c;打卡第 24 天&#xff1b;主要内容也依据 mindspore 的学习记录。 2. PixPix 介绍&#xff1a; MindSpore 的 Pix2Pix 图像转换 介绍 Pix2Pix是基于条件生成对抗网络&#xff08;cGAN, Condition Generative Ad…

Oracle如何跨越incarnation进行数据恢复

作者介绍&#xff1a;老苏&#xff0c;10余年DBA工作运维经验&#xff0c;擅长Oracle、MySQL、PG、Mongodb数据库运维&#xff08;如安装迁移&#xff0c;性能优化、故障应急处理等&#xff09; 公众号&#xff1a;老苏畅谈运维 欢迎关注本人公众号&#xff0c;更多精彩与您分享…

Skywalking 入门与实战

一 什么是 Skywalking? Skywalking 时一个开源的分布式追踪系统&#xff0c;用于检测、诊断和优化分布式系统的功能。它可以帮助开发者和运维人员深入了解分布式系统中各个组件之间的调用关系、性能瓶颈以及异常情况&#xff0c;从而提供系统级的性能优化和故障排查。 1.1 为…

笑谈“八股文”,人生不成文

一、“八股文”在实际工作中是助力、阻力还是空谈&#xff1f; 作为现在各类大中小企业面试程序员时的必问内容&#xff0c;“八股文”似乎是很重要的存在。但“八股文”是否能在实际工作中发挥它“敲门砖”应有的作用呢&#xff1f;有IT人士不禁发出疑问&#xff1a;程序员面试…

AcWing3302. 表达式求值

代码解释 while(j<str.size()&&isdigit(str[j])){xx*10str[j]-0;}把字符串中里面连续的数字转化为int类型变量&#xff0c;比如输入996/3328,正常的挨个字符扫描只能扫到’9’,‘9’,‘6’,但是按照上面代码的算法是重新开了一个循环&#xff0c;直接把’9’,‘9’,…

【网络请求调试神器,curl -vvv 返回都有什么】

curl -vvv 是一个用于在命令行中执行 HTTP 请求的命令&#xff0c;其中 -vvv 是一个选项&#xff0c;用于启用详细的调试输出。 vvv: 这是一个选项&#xff0c;表示启用详细的调试输出。每个 v 增加调试信息的详细程度&#xff0c;vvv 是最高级别的详细输出。 详细输出包括&a…

【shell脚本快速一键部署项目】

目录 一、环境拓扑图二、主机环境描述三、注意四、需求描述五、shell代码的编写六、总结 一、环境拓扑图 二、主机环境描述 主机名主机地址需要提供的服务content.exam.com172.25.250.101提供基于 httpd/nginx 的 YUM仓库服务ntp.exam.com172.25.250.102提供基于Chronyd 的 NT…

GPU池化:点燃Jupyter Notebook中的AI算力之火

数据科学的火花在Jupyter Notebook中点燃&#xff0c;而GPU的加入&#xff0c;让这火焰更加炽热&#xff01;随着人工智能领域的飞速发展&#xff0c;利用GPU加速已成为数据科学和机器学习领域的新常态。 今天&#xff0c;我们要探索的&#xff0c;是Jupyter Notebook与GPU池化…

PHP学习:PHP基础

以.php作为后缀结尾的文件&#xff0c;由服务器解析和运行的语言。 一、语法 PHP 脚本可以放在文档中的任何位置。 PHP 脚本以 <?php 开始&#xff0c;以 ?> 结束。 <!DOCTYPE html> <html> <body><h1>My first PHP page</h1><?php …

spaCy语言模型下载

spaCy 是一个基于 Python 编写的开源自然语言处理&#xff08;NLP&#xff09;库&#xff0c;它提供了一系列的工具和功能&#xff0c;用于文本预处理、文本解析、命名实体识别、词性标注、句法分析和文本分类等任务。 spaCy支持多种语言模型对文本进行处理&#xff0c;包括中文…

自己在Vmware中搭建mqtt服务器

前言 在学习某个HMI的使用的时候&#xff0c;这个HMI带有MQTT功能&#xff0c;就想着自己是不是能够搭建一个自己的MQTT的服务器呢&#xff1f; 一、mqtt 自己搭建之一&#xff1a;Mosquitto 自己搭建MQTT服务器需要安装和运行MQTT服务软件&#xff0c;比如常用的是Mosquitto…

Tkinter简介与实战(1)

Tkinter简介与实战---实现一个计算器 Tkinter简介安装环境和安装命令WindowsmacOSLinux 注意事项使用正确的包管理器&#xff1a;检查安装完整性&#xff1a;更新 Python&#xff1a;使用虚拟环境&#xff1a; 一个实战例子-----计算器1.创建窗口&#xff1a;2.创建 GUI 组件&a…

学习大数据DAY27 Linux最终阶段测试

满分&#xff1a;100 得分&#xff1a;72 目录 一选择题&#xff08;每题 3 分&#xff0c;共计 30 分&#xff09; 二、编程题&#xff08;共 70…