什么是 Flash Attention

news2025/1/16 3:36:05

Flash Attention 是 由 Tri Dao 和 Dan Fu 等人在2022年的论文 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 中 提出的,

论文可以从 https://arxiv.org/abs/2205.14135 页面下载,点击 View PDF 就可以下载。

下面我们通过详细解读这篇论文,来说明什么是Flash Attention。

Transformer在处理长序列时速度慢且占用大量内存,因为自注意力的时间和内存复杂度与序列长度的平方成正比。近似注意力方法尝试通过牺牲模型质量来减少计算复杂度来解决这个问题,但往往不能实现实际速度提升。我们认为一个缺失的原则是使注意力算法具有IO感知能力——考虑GPU内存层级之间的读取和写入。我们提出了FlashAttention,这是一种IO感知的精确注意力算法,使用平铺技术来减少GPU高带宽内存(HBM)和GPU片上静态随机存储器(SRAM)之间的内存读写次数。我们分析了FlashAttention的IO复杂度,表明它需要比标准注意力更少的HBM访问,并且在一定范围内的SRAM大小下是最优的。FlashAttention训练Transformer比现有基准更快:与MLPerf 1.1训练速度记录相比,BERT-large(序列长度为512)的端到端墙钟速度提高了15%,GPT-2(序列长度为1K)的速度提高了3倍,长距离竞技场(序列长度为1K-4K)的速度提高了2.4倍。

左图:FlashAttention使用平铺技术,防止在(相对)较慢的GPU HBM上生成大型的𝑁×𝑁注意力矩阵(虚线框)。在外部循环(红色箭头)中,FlashAttention循环遍历K和V矩阵的块,并将它们加载到快速的片上静态随机存储器(SRAM)中。在每个块中,FlashAttention循环遍历Q矩阵的块(蓝色箭头),将它们加载到SRAM,并将注意力计算的输出写回HBM。右图:与GPT-2上注意力的PyTorch实现相比的加速效果。FlashAttention不会将大型的𝑁×𝑁注意力矩阵读取和写入HBM,从而使注意力计算加速了7.6倍。

在现代GPU上,计算速度已经超过了内存读取速度,而Transformer中的大多数操作都受到内存访问的限制。对于类似受内存限制的操作,IO感知算法对于读取和写入数据可能占据运行时较大部分的情况至关重要,比如数据库连接、图像处理、数值线性代数等等。然而,诸如PyTorch和Tensorflow等常见的Python深度学习接口并不允许对内存访问进行精细控制。

我们提出了FlashAttention,这是一种新的注意力机制算法,可以在极少的内存访问次数下计算精确的注意力。我们的主要目标是避免从高带宽存储器(HBM)读取和写入注意力矩阵。这需要 (i) 在没有访问整个输入的情况下计算softmax归一化 (ii) 在反向传播过程中不存储大型的中间注意力矩阵。我们应用了两种成熟的技术来解决这些挑战。 (i) 我们重新构造了注意力计算,将输入分成块,并对输入块进行多次遍历,逐步执行softmax归一化(也称为平铺)。 (ii) 我们在正向传播过程中存储了softmax归一化因子,以便在反向传播过程中快速重新计算注意力,这比从HBM读取中间注意力矩阵的标准方法更快。我们在CUDA中实现了FlashAttention,以实现对内存访问的精细控制,并将所有的注意力操作融合到一个GPU核函数中。即使由于重新计算而增加了浮点运算次数,我们的算法在运行时比标准注意力更快,并且使用的内存量(与序列长度成正比)更少,这得益于大大减少的HBM访问量。

我们分析了FlashAttention的IO复杂度,证明它需要𝑂(𝑁^2𝑑^2/𝑀)次HBM访问,其中𝑑是头部维度,𝑀是SRAM的大小,而标准注意力的HBM访问次数为Ω(𝑁𝑑 + 𝑁^2)。对于𝑑和𝑀的典型值,我们证明FlashAttention相比标准注意力需要更少的HBM访问次数。此外,我们提供了一个下界,表明没有精确的注意力算法能够在所有SRAM大小上渐近地改进HBM访问次数。

GPU内存层次结构。GPU内存层次结构包括多种不同大小和速度的内存,较小的内存速度更快。以A100 GPU为例,其具有40-80GB的高带宽存储器(HBM),带宽为1.5-2.0TB/s,每个108个流多处理器具有192KB的片上静态随机存储器(SRAM),其带宽估计约为19TB/s。片上SRAM的速度比HBM快一个数量级,但在尺寸上小几个数量级。随着计算相对内存速度变得更快,操作越来越受到内存(HBM)访问的限制。因此,利用快速的SRAM变得更加重要。

执行模型。GPU拥有大量线程来执行操作(称为核函数)。每个核函数从HBM加载输入到寄存器和SRAM,进行计算,然后将输出写入HBM。

性能特征。根据计算和内存访问的平衡,操作可以被归类为计算受限或内存受限。这通常由算术密集度来衡量,即每个内存访问的算术操作数。

计算受限:操作所需时间取决于算术操作的数量,而在HBM上的访问时间要小得多。典型的例子包括具有大内部维度的矩阵乘法和具有大通道数量的卷积。
内存受限:操作所需时间取决于内存访问次数,而计算所需时间要小得多。例如,大多数其他操作:逐元素(例如,激活,丢弃)和缩减(例如,求和,softmax,批标准化,层标准化)。
核函数融合。加速内存受限操作的最常见方法是核函数融合:如果有多个操作应用于相同的输入,那么可以一次从HBM加载输入,而不是每个操作都进行多次加载。编译器可以自动融合许多逐元素操作。

然而,在模型训练的情境中,中间值仍然需要被写入HBM以供反向传播保存,降低了简单核函数融合的效果。

标准注意力实现

给定输入序列 Q、K、V ∈ R^(𝑁 ×𝑑) ,其中 𝑁 是序列长度,𝑑 是头维度,我们希望计算注意力输出 O ∈ R^(𝑁 ×𝑑) : S = QK^T∈ R^(𝑁 ×𝑁) ,P = softmax(S) ∈ R^(𝑁 ×𝑁) ,O = PV ∈ R^(𝑁 ×𝑑) , 其中 softmax 按行应用。 标准的注意力实现会将矩阵 S 和 P 实现在 HBM,这将占用 𝑂(𝑁^2 ) 的内存。 通常 𝑁 远大于 𝑑(例如,对于 GPT2,𝑁 = 1024,𝑑 = 64)。我们在算法 0 中描述了标准的注意力实现。由于一些或大部分操作是内存绑定的(例如 softmax),大量的内存访问会导致较慢的墙钟时间。 这个问题会因为其他应用于注意力矩阵的逐元素操作而变得更加严重,比如应用于 S 的掩码或应用于 P 的 dropout。因此,已经有许多尝试将多个逐元素操作融合在一起,比如将掩码与 softmax 融合在一起。标准注意力实现在序列长度 𝑁 的 HBM 访问方面呈二次增长。

算法 0

标准注意力实现 要求:矩阵 Q、K、V ∈ R^(𝑁 ×𝑑) 在 HBM 中。

 1: 从 HBM 中按块加载 Q、K,计算 S = QK>,将 S 写入 HBM。

 2: 从 HBM 中读取 S,计算 P = softmax(S),将 P 写入 HBM。

 3: 从 HBM 中按块加载 P 和 V,计算 O = PV,将 O 写入 HBM。

 4: 返回 O。

下面我们展示如何以更少的 HBM 读/写操作并且在不存储大型中间矩阵的情况下计算精确的注意力。这产生了一种既内存高效又在墙钟时间上更快的注意力算法。

一种具有平铺和重计算的高效注意力算法
给定输入Q、K、V ∈ R 𝑁 ×𝑑 在HBM中,我们的目标是计算注意力输出O ∈ R^(𝑁 ×𝑑) 并将其写入HBM。我们的目标是减少HBM访问的数量(至少是𝑁的平方级别)。我们应用了两种已建立的技术(平铺、重计算)来克服在子二次HBM访问中计算精确注意力的技术挑战。我们在算法1中描述了这一点。其主要思想是我们将输入Q、K、V分成块,从较慢的HBM加载到较快的SRAM,然后针对这些块计算注意力输出。通过在将每个块的输出按正确的归一化因子进行缩放后相加,我们最终得到了正确的结果。
平铺:我们通过块来计算注意力。Softmax将K的列耦合起来,因此我们分解了大型softmax并进行了缩放。为了数值稳定性,向量𝑥 ∈ R^𝐵的softmax计算如下:

对于向量 𝑥(1), 𝑥(2) ∈ R^𝐵,我们可以将拼接向量 𝑥 = [𝑥(1)𝑥(2)] ∈ R^(2𝐵) 的 softmax 分解为:

因此,如果我们跟踪一些额外的统计数据(𝑚(𝑥), ℓ(𝑥)),我们可以逐块计算softmax。因此,我们将输入 Q、K、V 分割成块(算法1第3行),计算softmax值以及额外的统计数据(算法1第10行),然后将结果合并(算法1第12行)。

重新计算。我们的一个目标是不要为反向传播存储𝑂(𝑁^2)个中间值。反向传播通常需要矩阵 S、P ∈ R^𝑁×𝑁 来计算相对于 Q、K、V 的梯度。然而,通过存储输出 O 和 softmax 归一化统计数据 (𝑚, ℓ),我们可以在反向传播中轻松地从 Q、K、V 的块中重新计算注意力矩阵 S 和 P。这可以看作是一种选择性梯度检查点技术。虽然梯度检查点技术被建议用于减少最大所需内存,但所有已知的实现都必须以速度为代价来换取内存。相比之下,即使有更多的浮点操作数,我们的重新计算也能加速反向传播,因为减少了高带宽内存访问。

实施细节:内核融合。分块使我们能够在一个CUDA内核中实现我们的算法,从高带宽内存加载输入,执行所有的计算步骤(矩阵乘法,softmax,可选的掩码和丢弃,矩阵乘法),然后将结果写回到高带宽内存。这避免了重复地从高带宽内存读取和写入输入和输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2121505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Php数组函数中的那些什么sort排序函数是不是很乱? 可以这样看。以及php搜索给定的值在数组中最后一次出现的位置的实现思考

一、Php数组函数中的那些什么sort排序函数是不是很乱? 可以这样看 PHP的数组函数真不少,甚至对一个程序员来说,在其整个程序生涯中有些方法他永远也不会用上。不过每一个方法都有其价值、或者在出现的时候有其价值。所以偶尔有空时还是可以去看看。在这…

并发编程:线程池(下)

一、线程池常用的阻塞队列有哪些? 新任务来的时候会先判断当前运行的线程数量是否达到核心线程数,如果达到的话,新任务就会被存放在队列中。 不同的线程池会选用不同的阻塞队列,我们可以结合内置线程池来分析。 容量为 Integer…

UE5 半透明阴影 快速解决方案

Step 1: 打开该选项 Step 2: 将半透明材质给到模型后,设置光照的Shadow Resolution Scale,越大,阴影的效果越好 Step 3: 用这种方式去做,阴影会因为半透明的程度,降低阴影的浓度 要…

Spring security 动态权限管理(基于数据库)

一、简介 如果对该篇文章不了解,请移步上一篇文章:spring security 中的授权使用-CSDN博客 当我们配置的 URL 拦截规则请求 URL 所需要的权限都是通过代码来配置的,这样就比较死板,如果想要调整访问某一个 URL 所需要的权限&…

【专项刷题】— 队列

1、N 叉树的层序遍历 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 每次遍历一层节点的时候就把当前节点的值添加到列表中再将当前层的节点的子节点添加到队列中每次遍历完一层之后就添加到总表中代码&#xff1a; public List<List<Integer>> levelO…

如何远程实时监控员工的电脑屏幕?远程桌面监控的五个可实现方法分享

想象一下&#xff0c;你在办公室喝着咖啡&#xff0c;员工的电脑屏幕却在数百公里之外实时呈现在你的眼前。你可以看到他们在干什么&#xff0c;是埋头工作还是悄悄摸鱼&#xff1f;远程桌面监控让这一切变得触手可及&#xff0c;简直像给了管理者一双“千里眼”&#xff01; 如…

RedisTemplate操作String的API

文章目录 1 String 介绍2 命令3 对应 RedisTemplate API❄️❄️ 3.1 添加缓存❄️❄️ 3.2 设置过期时间(单独设置)❄️❄️ 3.3 获取缓存值❄️❄️ 3.4 删除key❄️❄️ 3.5 顺序递增❄️❄️ 3.6 顺序递减 ⛄4 以下是一些常用的API⛄5 应用场景 1 String 介绍 String 类型…

9.10-AutoAWQ代码解析

1、首先要去官网下载源码。https://github.com/casper-hansen/AutoAWQ.githttps://github.com/casper-hansen/AutoAWQ.git 2、git clone后&#xff0c;下载AutoAWQ所需环境。 pip install -e . 3、查看quantize.py代码&#xff0c;修改model_path部分&#xff0c;修改为想要量…

系统架构师考试学习笔记第四篇——架构设计实践知识(19)嵌入式系统架构设计理论与实践

本章考点&#xff1a; 第19课时主要学习嵌入式系统架构设计的理论和工作中的实践。根据新版考试大纲&#xff0c;本课时知识点会涉及案例分析题&#xff08;25分&#xff09;。在历年考试中&#xff0c;案例题对该部分内容都有固定考查&#xff0c;综合知识选择题目中有固定分值…

您与该网站的连接不是私密连接,存在安全隐患

您与该网站的连接不是私密连接&#xff0c;存在安全隐患。 攻击者可能会试图窃取您的信息&#xff08;例如&#xff1a;密码、通讯内容或信用卡信息&#xff09;。为避免您的信息失窃&#xff0c;建议您停止访问该页面。了解详情 解决办法如下&#xff1a; 1、查看电脑时间&…

使用FastJson2将对象转成JSON字符串时,小数位“0”开头时转换出错

maven坐标&#xff1a; <dependency> <groupId>com.alibaba.fastjson2</groupId> <artifactId>fastjson2</artifactId> <version>2.0.40</version> </dependency> 问题现象&#xff1a; 问题原因&#xff1a; I…

IP路由选择

文章目录 1. 基本概念2. RIP(路由选择信息协议)3. OSPF 1. 基本概念 路由选择协议 路由选择协议让路由器能够动态地发现互联网络&#xff0c;并确保所有路由器的路由选择表都相同。路由选择协议还用于找出最佳路径&#xff0c;让分组穿越互联网络前往目的地的效率最高。RIP、R…

领夹麦克风哪个品牌好?无线领夹麦克风品牌大全,麦克风推荐

在这个全民直播、Vlog盛行的时代&#xff0c;一款轻便高效的无线领夹麦克风成了不少内容创作者的必备神器。但市面上产品五花八门&#xff0c;有的打着“超远传输、无损音质”的旗号&#xff0c;实则性能平平&#xff0c;甚至存在信号干扰、噪音大等问题&#xff0c;让人直呼交…

SpringBoot集成MyBatis-PlusDruid

目录 MyBatis-Plus简介 实例演示 创建Springboot项目 初始化Springboot项目 添加关键依赖 application.properties添加相关配置 启动类 编写实体类 编写mapper接口 条件构造器 分页插件 自定义 SQL 映射 MyBatis-Plus简介 MyBatis-Plus简介‌MyBatis-Plus‌&…

铁威马秋季新品即将上线,你想要的NAS我都有!

各位铁粉们&#xff0c;注意啦&#xff01; 一场关于存储的饕餮盛宴即将拉开帷幕 铁威马&#xff0c;带着九款全新力作NAS 将于9月19日席卷全球市场 是的&#xff0c;你没听错 九款&#xff01; 从入门级到专业级 从桌面型到机架式 全系搭载TOS 6 总有一款能击中你的心…

PCI 9054应用总结

1 PCI配置空间 1.1 BAR大小的确定 Linux kernel读取PCI BARn表示的内存长度时&#xff0c;先直接读取BARn的值&#xff0c;这个就是地址&#xff0c;然后再向BARn写入0xffff,ffff&#xff0c;再读取BARn的值就是需要的内存长度&#xff08;忽略bit3到bit0的处理&#xff09;&a…

微波无源器件 3 一种用于Ka频带双极化波束形成网络的双模三路功分器

摘要&#xff1a; 本文给出了一种用于Ka频带的双极化工作的双模3路功分器的设计和性能。对有着三个输出端口的平衡地很好的功分的TE10和TE01模式和27.5-30GHz上优于-23dB的输入匹配可以获得相似的性能。与双模定向耦合器相连结&#xff0c;此三路功分器对于双极化波束形成网络具…

【Go】Go语言介绍与开发环境搭建

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

Qt篇——Qt获取Windows电脑上所有外接设备的名称、物理端口位置等信息

我之前有发过一篇文章《Qt篇——获取Windows系统上插入的串口设备的物理序号》&#xff0c;文章中主要获取的是插入的USB串口设备的物理序号&#xff1b;而本篇文章则进行拓展&#xff0c;可以获取所有外接设备的相关信息&#xff08;比如USB摄像头、USB蓝牙、USB网卡、其它一些…

膨胀腐蚀操作opencv dilate膨胀白膨胀,erode腐蚀是黑吃白。主要针对二值图

效果&#xff1a; 代码&#xff1a; import cv2 import numpy as np from matplotlib import pyplot as pltif __name__ "__main__":h 10w 10data np.random.normal(0, 1, [h, w]) # sigma, 2*sigma, 3*sigma之间的数的比例分别为0.68&#xff0c; 0.96&#…