知识回顾 - 《Flash Attention为什么这么快?》

news2025/1/12 3:52:08

作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra...论文地址: https://arxiv.org/abs/2205.14135项目地址: https://github.com/Dao-AILab/flash-attention

摘要

        Transformers在处理长序列时速度慢且内存消耗大,因为自注意力的时间和内存复杂度与序列长度的平方成正比。近似注意力方法试图通过权衡模型质量来减少计算复杂度,但往往无法实现实际的速度提升。我们认为一个缺失的原则是使注意力算法考虑输入/输出(IO)-- 即考虑在GPU内存层次之间的读写操作。我们提出了FlashAttention,这是一种考虑IO的精确注意力算法,通过分块减少GPU高带宽内存(HBM)和GPU片上SRAM之间的内存读写次数。我们分析了FlashAttention的IO复杂度,显示它比标准注意力需要更少的HBM访问,并且在多种SRAM大小下表现最优。我们还将FlashAttention扩展到块稀疏注意力,产生了一种比现有任何近似注意力方法更快的近似注意力算法。FlashAttention使变压器训练速度比现有基准更快:在BERT-large(序列长度512)上相比MLPerf 1.1训练速度记录提升15%,在GPT-2(序列长度1K)上提升3倍,在长范围arena(序列长度1K-4K)上提升2.4倍。FlashAttention和块稀疏FlashAttention支持Transformers处理更长的上下文,从而生成更高质量的模型(GPT-2上的困惑度提升0.7点,长文档分类提升6.4点),以及全新的能力:在Path-X挑战(序列长度16K,61.4%准确率)和Path-256(序列长度64K,63.1%准确率)上实现了超越偶然表现的首个变压器。

原理概览

        对论文总结Flash Attention快的原因是,这种新型的注意力算法,旨在解决传统注意力机制在大型模型中的存储和带宽开销问题。它通过减少内存读写次数,提高了注意力算法在GPU内存IO的效率,从而加快模型训练速度并增加上下文窗口大小。总览图如下:

 

左图:FlashAttention通过分块技术来避免在(相对较慢的)GPU高带宽内存(HBM)上生成大规模的N×N 注意力矩阵(虚线框)。在外层循环中(红色箭头),FlashAttention遍历K和V矩阵的块,并将它们加载到快速的片上SRAM中。在每个块中,FlashAttention遍历Q矩阵的块(蓝色箭头),将其加载到SRAM中,并将注意力计算的结果写回到HBM中。右图:在GPT-2上相较于PyTorch实现的注意力的加速比。FlashAttention避免了对大型N×N 注意力矩阵的读写操作,导致注意力计算速度提升了7.6倍。

注:HBM通常指的是“高带宽内存”(High Bandwidth Memory)。HBM是一种先进的垂直堆叠式DRAM内存技术,由SK海力士(SK Hynix)和三星(Samsung)等公司开发。这种类型的内存被设计用于高性能计算系统,如高端显卡、超级计算机和数据中心服务器。

原理拆分

1)首先标准注意力机制实现如下公式:

图片

其中输入序列

图片

,N是序列长度,d是隐藏层维度数,要想计算注意力输出

图片

,那计算流程如上图,这里softmax是按行计算的。

标准的注意力实现将矩阵 SP 放到高带宽内存 (HBM) 中,这需要 𝑂(𝑁^2) 的内存。通常 𝑁 ≫ 𝑑(例如,对于 GPT2,𝑁 = 1024 和 𝑑 = 64)。我们在下图伪代码中描述了标准的注意力实现。由于某些或大多数操作是受内存限制的(例如,softmax),大量的内存访问会导致较慢的实际时间。这一问题由于应用于注意力矩阵的其他逐元素操作(如对 S 进行掩码操作或对 P 进行 dropout 操作)而变得更加严重。因此,已经有许多尝试将多个逐元素操作融合在一起,例如将掩码与 softmax 融合 。

 

从伪代码中可以看到IO的读取非常的繁琐耗时:

图片

均存储在HBM中

1. 从HBM中加载QK到SRAM中;

2. 计算

图片

,将s写入HBM中;

3. 从HBM中读取s加载到SRAM;

4. 计算

图片

,将P写入HBM中;

5. 从HBM中读取PV并加载到SRAM;

6. 计算

图片

,将O写入HBM,返回O

2)从这张图我们也可以知道,大的矩阵乘法和多通道卷积计算耗时比较少,注意力计算中的dropout,softmax,mask比较耗时:

图片

因此对其的优化一般是进行分块计算(Tiling and Recomputation)再融合操作,不对中间结果缓存,减少HBM的访问耗时。

Tiling and Recomputation的主要思想是,把输入Q、K、V分成多个块,将它们从slow HBM加载到 fast SRAM,然后计算相对于这些块的注意力输出。在通过正确的归一化因子缩放每个块的输出然后相加,便可得到正确结果。可以用一个CUDA kernel来执行注意力的所有操作。从HBM中加载输入数据,在SRAM中执行所有的计算操作(矩阵乘法,mask,softmax,dropout,矩阵乘法),再将计算结果写回到HBM中。通过kernel融合将多个操作融合为一个操作,避免了反复地从HBM中读写数据。

那么注意力计算中sofamax是如何切分计算的,需要对softmax进行缩放计算(在实际硬件中,浮点数表示的范围是有限的。对于float32和bfloat16来说,当 x≥89时,e^x就会变成inf,发生数据上溢的问题。为了避免发生数值溢出的问题,保证数值稳定性,计算时通常会“减去最大值”,称为“safe softmax”),具体公式如下:

        对于重计算,在本文注意力实现中,后向传递计算Q,K,V 的梯度时,需要用NxN的中间矩阵S,P,但这两个矩阵并没有保存下来。这里的技巧是重计算,保存了两个统计量m(x), l(x),后向传递时在高速的SRAM上快速地重新计算Attention,通过分块的方式重新计算注意力矩阵S,P,因为这种方式,将显存复杂度从O(N^2)降到了O(N)。相比于标准注意力中,从HBM中读取很大的中间注意力矩阵的方法,重计算的方法要快得多。

        Flash Attention通过kernel融合和分块计算,大量减少了HBM访问次数,尽管由于后向传递中的重计算增加了额外的计算量FLOPs,但依旧减少了运行时间,下图是GPT-2 medium在标准注意力和 FlashAttention 的向前 + 向后运行时间,可以看到提升了4.7倍。

 

最后对于flash attention2做了一下三点的优化:

1. 减少非矩阵乘法的计算,利用TensorCore加速

2. 调整O为外层训练,K,V为内层循环,减少HBM读取

3. 对于一个Block块处于矩阵上三角部分(被mask的部分),则不进行Attention计算

参考:

1. https://www.bilibili.com/video/BV1UT421k7rA/?spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=01b54c990198e640f937517e2d38c7db

2.https://zhuanlan.zhihu.com/p/639228219?s_r=0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2144060.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

误删分区后的数据救赎恢复实战解析

在数字化时代,数据不仅是信息的载体,更是个人记忆与企业资产的宝贵财富。然而,误删分区这一操作失误,却如同暗流涌动,悄无声息地吞噬着用户的重要数据。本文将深入探讨误删分区的现象、影响,并详细介绍一种…

【Linux】探索文件I/O奥秘,解锁软硬链接与生成动静态库知识

目录 1、C文件接口 1.1什么是当前路径? 1.2程序默认打开的文件流: 2、系统文件I/O 2.1.接口介绍: 2.1.1open: 参数讲解; flags如何实现一个参数就可以有多个参数传参的效果? open函数的返回值: 3…

线程池ThreadPoolExecutor实战及其原理分析

1. 线程池简介 线程池(Thread Pool)是一种基于池化思想管理线程的工具,经常出现在多线程服务器中,如Tomcat。 线程过多会带来额外的开销,其中包括创建销毁线程的开销、调度线程的开销等等,同时也降低了计算…

香港科技大学工学院2025/2026年度硕士研究生(MSc)项目招生宣讲会——华南师范大学佛山校区

🔔香港科技大学工学院2025/2026年度硕士研究生(MSc)项目招生宣讲会 🕙时间:2024年9月26日(星期四)19:00 🏠地点:华南师范大学佛山校区图书馆电影院 🎆2024T…

Spring6梳理9—— 依赖注入之外部注入对象类型属性

9.1 依赖注入之外部注入对象类型属性 9.1.1 创建dept与emp类 1.dept类 package com.atguigu.spring6.iocxml.ditest;//部门类 public class Dept {private String dname;public String getDname() {return dname;}public void setDname(String dname) {this.dname dname;…

【算法】遗传算法

一、引言 遗传算法(Genetic Algorithm, GA)是一种模拟生物进化过程的启发式搜索算法,它通过模拟自然选择、遗传、交叉和突变等生物学机制来优化问题的解决方案。遗传算法因其通用性、高效性和鲁棒性,在多个领域中得到了广泛应用&a…

【Java】网络编程:TCP_IP协议详解(IP协议数据报文及如何解决IPv4不够的状况)

🌈个人主页:努力学编程’ ⛅个人推荐: c语言从初阶到进阶 JavaEE详解 数据结构 ⚡学好数据结构,刷题刻不容缓:点击一起刷题 🌙心灵鸡汤:总有人要赢,为什么不能是我呢 &#x1f354…

Nest.js

Nestjs中文文档链接 TypeORM 中文文档 小满视频 1. 安装Nest.js 安装脚手架 npm i -g nestjs/cli创建nestjs工程 nest new工程目录 app.module.ts 根模块用于处理其他类的引用与共享。app.controller.ts 常见功能是用来处理http请求(处理请求的路径&#xff09…

.net core8 使用JWT鉴权(附当前源码)

说明 该文章是属于OverallAuth2.0系列文章,每周更新一篇该系列文章(从0到1完成系统开发)。 该系统文章,我会尽量说的非常详细,做到不管新手、老手都能看懂。 说明:OverallAuth2.0 是一个简单、易懂、功能强…

焦虑拜拜!这些维生素是你的情绪小太阳✨,焦虑星人必看!

🌿 ‌维生素B群:情绪的调节大师‌ 🎯 说到缓解焦虑,怎能不提维生素B群?它可是个大家庭,包括B1、B2、B6、B12等,每一个都是调节神经系统的关键角色。维生素B群能够促进神经递质的合成&#xff0…

Prometheus监控k8s环境构建

传统架构中比较流行的监控工具有 Zabbix、Nagios 等,这些监控工具对于 Kubernetes 这类云平台的监控不是很友好,特别是当 Kubernetes 集群中有了成千上万的容器后更是如此,本章节学习下一代的云原生监控平台---Prometheus。 一、基于kuberne…

DNS解析域名详解

你有没有想过,当一个url传过来网络对它进行了哪些操作~DNS又是怎样对域名进行解析的~或者我们为什么要用到域名,为什么不直接使用ip地址~ 对于我们而言,面对长串的ip地址,我们更喜欢记忆较短的域名,但是对于路由器来说…

第二证券:降息升温!资金涌入港股,行情还能持续多久?

在美联储行将打开降息影响下,多国股指改写高点。 当时,商场环绕美联储是25个基点仍是50个基点的降息展开预期买卖,资金流向风险财物规划扩大显着。17日,澳大利亚S&P/ASX 200指数股指、印度孟买SENSEX30指数、新加坡富时海峡指…

MySQL函数:日期函数

先贴一张黑马程序员的听课截图 1.返回当前日期 CURDATE(); select CURDATE(); //获取当前日期2. 返回当前时间 CURTIME(); select CURTIME(); //获取当前时间3.返回当前日期和时间NOW() select NOW(); //获取当前日期和时间 4.获取指定date的年份YEAR(date) select YEAR…

力扣(LeetCode)每日一题 2848. 与车相交的点

题目链接https://leetcode.cn/problems/points-that-intersect-with-cars/description/?envTypedaily-question&envId2024-09-15 给你一个下标从 0 开始的二维整数数组 nums 表示汽车停放在数轴上的坐标。对于任意下标 i,nums[i] [starti, endi] ,…

[Python]一、Python基础编程

F:\BaiduNetdiskDownload\2023人工智能开发学习路线图\1、人工智能开发入门\1、零基础Python编程 1. Python简介 Python优点: 学习成本低开源适应人群广泛应用领域广泛1.1 Python解释器 下载地址:Download Python | Python.org 1.2 Python开发IDE -- Pycharm 2. 基础语法…

人工智能(AI)的影响下人类的生活样子

讨论在人工智能(AI)的影响下人类的生活是什么样子 在21世纪的今天,人工智能(AI)已经不再是遥不可及的未来科技,而是悄然渗透到我们日常生活的每一个角落,以一种前所未有的方式改变着我们的生活方式、工作模式乃至社会…

使用 Python 绘制 BTC 期权的波动率曲面

波动率曲面(Volatility Surface)是期权交易中展示隐含波动率随行权价(strike price)和到期时间(expiry time)变化的一种三维图形。 本文尝试通过 Python,通过 ccxt 基于从交易所获取期权的指标…

远程连接MySQL并操作

配置MySQL开发环境 如果你使用的是基于Debian的系统(如Ubuntu),可以在终端通过如下步骤安装MySQL开发包。 更新软件包列表 运行以下命令以确保你拥有最新的软件包列表。 sudo apt-get update安装libmysqlclient-dev开发包 执行以下命令以…

python-字符排列问题

题目描述 有 n 个字母,列出由该字母组成的字符串的全排列(相同的排列只计一次)。输入格式 第一行输入是字母个数 n 。 接下来一行输入的是待排列的 n 个字母。输出格式 计算出的 n 个字母的所有不同排列总数。样例输入输出样例输入 4 aacc样例…