斯坦福开源FlashAttention,大模型速度翻倍

news2024/11/23 13:51:53

一年时间,斯坦福大学提出的新型 Attention 算法 ——FlashAttention 完成了进化。这次在算法、并行化和工作分区等方面都有了显著改进,对大模型的适用性也更强了。

近来,几种长上下文语言模型陆续问世,包括 GPT-4(上下文长度为 32k)、MosaicML 的 MPT(上下文长度为 65k)Anthropic 的 Claude(上下文长度为 100k)。长文档查询和故事写作等新兴用例已经表明扩展语言模型上下文窗口是非常必要的。

然而,扩大 Transformer 的上下文长度是一个挑战,因为其核心的注意力层在时间复杂度和空间复杂度与输入序列长度的平方成正比。

一年前,来自斯坦福大学、纽约州立大学布法罗分校的研究者共同提出一种快速、内存高效的注意力算法 ——FlashAttention。该算法无需任何近似即可加速注意力并减少内存占用。现在,已经有许多机构和研究实验室采用 FlashAttention 来加速训练和推理。

1fb7c9bf6bb8558e5a5e3438d79c47a1.jpeg

FlashAttention 示意图

尽管 FlashAttention 的速度已经是优化基线的 2-4 倍,但它仍然有相当大的改进空间。FlashAttention 仍然不如优化过的矩阵乘法 (GEMM) 运算快,仅达到理论最大 FLOPs/s 的 25-40%。

现在,研究团队宣布推出 FlashAttention-2。FlashAttention-2 完全从头开始重写,使用 Nvidia 的 CUTLASS 3.x 及其核心库 CuTe 的原语(primitive)。

02353ffa869d5da53c8c218c310dd5c5.jpeg

FlashAttention-2 开发者 Tri Dao。他是斯坦福大学博士生,还是 Together.AI 首席科学家,并将于 2024 年 9 月开始任职普林斯顿大学计算机科学助理教授。

FlashAttention-2 的速度是 FlashAttention 的 2 倍,在 A100 GPU 上达到 230 TFLOPs/s。在端到端训练 GPT 类语言模型时,FlashAttention-2 可让训练速度高达 225 TFLOPs/s(模型 FLOP 利用率为 72%)。

FlashAttention-2 将加速现有模型的训练、微调和推理。这意味着我们可以用相同成本训练 2 倍上下文长度的语言模型。这将有助于语言模型理解长篇书籍和报告、高分辨率图像、音频和视频。


FlashAttention 是什么?

FlashAttention 是一种重新排序注意力计算的算法,它利用平铺、重计算等经典技术来显著提升计算速度,并将序列长度中的内存使用实现从二次到线性减少。其中平铺意味着将输入块从 HBM(GPU 内存)加载到 SRAM(快速缓存),并对该块执行注意力操作,更新 HBM 中的输出。

此外通过不将大型中间注意力矩阵写入 HBM,内存读写量减少,带来了 2-4 倍的时钟时间加速。

下图为 FlashAttention 的前向传递图:通过平铺和 softmax 重新缩放,研究者按块进行操作,避免从 HBM 中读取 / 写入,同时获得正确的输出,无需近似操作。

b90d47c12993c88c28ee6b05688438f7.jpeg

然而,FlashAttention 仍然存在一些低效率问题,原因在于不同线程块之间的工作分区不理想以及 GPU 上的 warp。这些导致低占用率或不必要的共享内存读写。

FlashAttention-2

更好的算法、并行化和工作分区

更少的非矩阵乘法 Flops

研究者调整了 FlashAttention 的算法,从而减少了非矩阵乘法(non-matmul)的 Flops 数量。这点很重要,因为现代 GPU 具有专门的计算单元(例如 Nvidia GPU 上的张量核心),使得矩阵乘法速度更快。

举例而言,A100 GPU 的 FP16/BF16 矩阵乘法的最大理论吞吐量为 312 TFLOPs/s,但非矩阵乘法 FP32 的理论吞吐量仅为 19.5 TFLOPs/s。

换一种思考方式,每个非矩阵乘法 FLOP 比矩阵乘法 FLOP 的代价高 16 倍。为了保持高吞吐量,研究者希望在矩阵乘法 FLOP 上花费尽可能多的时间。因此他们重写了 FlashAttention 中使用的在线 softmax 技巧,以减少重新缩放操作、边界检查和因果掩码操作的数量,而无需更改输出。

更好的并行化

FlashAttention v1 在批大小和头(head)数量上进行并行化。研究者使用 1 个线程块来处理一个注意力头,总共有(批大小 * 头数量)个线程块。每个线程块都计划在流式多处理器(SM)上运行,例如 A100 GPU 上有 108 个这样的 SM。当这个数字非常大(如 >= 80)时,这种调度是有效的,这时可以高效地使用 GPU 上几乎所有计算资源。

在长序列的情况下(通常意味着小批量或少量头),为了更好地利用 GPU 上的多处理器,现在研究者在序列长度维数上额外地进行并行化,使该机制显著加速。

更好的工作分区

即使在每个线程块内,研究者也必须决定如何在不同的 warp 之间划分工作(一组 32 个线程一起工作)。通常情况下,每个线程块使用 4 或 8 个 warp,分区方案如下图所述。

研究者改进了 FlashAttention-2 中的这种分区,减少不同 warp 之间的同步和通信量,进而减少共享内存读写。

9c95005eebe02ffb592cb6bb7a574fb4.jpeg

对于每个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。这被称为「sliced-K」方案。不过,这种方案是低效的,原因在于所有 warp 都需要将它们的中间结果写入共享内存,并同步,然后将中间结果相加。这些共享内存读写会减慢 FlashAttention 中的前向传递速度。

在 FlashAttention-2 中,研究者将 Q 分割在 4 个 warp 上,同时保持 K 和 V 可被所有的 warp 访问。每个 warp 执行矩阵乘法以获得 Q K^T 的切片,然后只需与 V 的共享切片相乘就能获得相应的输出切片。warp 之间不需要通信。共享内存读写的减少也可以提升速度。

新特性:头维数高达 256、多查询注意力

我们知道,FlashAttention 仅支持最高 128 的头维数,这适用于大多数模型,但有一些模型被遗漏了。

因此,FlashAttention-2 支持了高达 256 的头维数,这意味着 GPT-J、CodeGen 和 CodeGen2、StableDiffusion 1.x 等模型可以使用 FlashAttention-2 来获得加速和节省内存。

此外,FlashAttention-2 还支持了多查询注意力(multi-query attention, MQA)以及分组查询注意力(grouped-query attention, GQA)。它们是注意力的变体,其中多个查询头关注相同的键和值头,以减少推理过程中 KV 缓存的大小,并可以显著提高推理吞吐量。

注意力基准结果

研究者在 A100 80GB SXM4 GPU 上,测量不同设置(无 / 有因果掩码、头维数 64 或 128)下不同注意力方法的运行时。

结果发现, FlashAttention-2 的速度是 FlashAttention(以及 xformers 库和 Triton 中的其他实现)的 2 倍。与 PyTorch 中的标准注意力实现相比,FlashAttention-2 的速度最高是它们的 9 倍。

0f9b7a28ec52f4bf6467922559087a9e.jpeg

A100 GPU 上的注意力前向 + 后向速度。

此外只需要在 H100 GPU 上 运行相同的实现(不使用特殊指令来利用 TMA 和第四代 Tensor Core 等新硬件功能),研究者最高获得了 335 TFLOPs/s。

12a6a8c5836bc4fc3cce3abf12b085ce.jpeg

H100 GPU 上的注意力前向 + 后向速度。

当用于端到端 GPT 类模型训练时,FlashAttention-2 有助于在 A100 GPU 上实现最高 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。与优化良好的 FlashAttention 模型相比,端到端实现 1.3 倍加速。

5a5fc404ea63afa8ce313e98687e2bdd.jpeg

这里的基线是不使用 FlashAttention 的 Megatron-LM,它现在也可以选择使用 FlashAttention 了。不久的将来,FlashAttention-2 也将集成到 Megatron-LM 中。

研究团队表示:下一步将针对 H100 GPU 优化 FlashAttention-2,以使用新的硬件功能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/785916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C#性能】C# 语言中的数组迭代

一、说明 可迭代性,是数组等操作的根本;在C程序开发过程中,可迭代操作是非常普遍、非常广泛的,然而,对这种操作知道多少,又不知道多少,都将影响开发灵活性、开发的进度。因此,本文干…

vs2013 编译wxwidgets界面库

首先进入官网下载,本人再git上下载的基本都编译失败 https://www.wxwidgets.org/ 在网站里面找最新的就可以,下载之后放在一个目录,找到vs的目录 然后找到wx_vc12.sln,打开编译即可 Debug、Release编译出来的是静态库 DLL Deb…

P1149火柴棒等式题解

P1149[NOIP2008 提高组] 火柴棒等式 题目描述 给你 n n n 根火柴棍,你可以拼出多少个形如 A B C ABC ABC 的等式?等式中的 A A A、 B B B、 C C C 是用火柴棍拼出的整数(若该数非零,则最高位不能是 0 0 0)。用…

华南理工大学电信学院信号与系统实验1~4(杨俊美老师)

博客中仅放了题目,源码在下方链接: 链接:https://pan.baidu.com/s/1TpRAiNltybYWimETKHFKWA?pwd88ux 提取码:88ux --来自百度网盘超级会员V4的分享 有可能实验的序号可能编错码或者放错文件夹(因为老师给的时候很多…

题目4 命令执行(保姆级教程)

url:http://192.168.154.253:84/ #打开http://XXX:81/,XXX为靶机的ip地址 审题 1、打开题目看到有一个提示,此题目需要通过利用命令执行漏洞执行Linux命令获取webshell,最后从根目录下key.php文件中获得flag 2、开始答题 第一步&…

【Kafka】消息队列Kafka基础

目录 消息队列简介消息队列的应用场景异步处理系统解耦流量削峰日志处理 消息队列的两种模式点对点模式发布订阅模式 Kafka简介及应用场景Kafka比较其他MQ的优势Kafka目录结构搭建Kafka集群编写Kafka一键启动/关闭脚本 Kafka基础操作创建topic生产消息到Kafka从Kafka消费消息使…

GB/T 25000.51解读——软件产品的可靠性怎么测?

GB/T 25000.51-2016《软件产品质量要求和测试细则》是申请软件检测CNAS认可一定会用到的一部国家标准。在前面的文章中,我们为大家整体介绍了GB/T 25000.51-2016《软件产品质量要求和测试细则》国家标准的结构和所涵盖的内容以及对软件产品的八大质量特性中的功能性…

re学习(24)攻防世界 Hello CTF(进制转换)

一.题目链接: https://adworld.xctf.org.cn/challenges/list 二.使用步骤 int __cdecl main(int argc, const char **argv, const char **envp) {int i; // ebxchar v4; // alint result; // eaxint v6; // [esp0h] [ebp-70h]int v7; // [esp0h] [ebp-70h]char Buffer[2]; //…

微信内测朋友圈可以置顶了

我是卢松松,点点上面的头像,欢迎关注我哦! 微信现在算是垄断了聊天平台,每出个小功能都引起争议,这次朋友圈置顶就是个很好的例子:**微信内测朋友圈可以置顶了。已经上了大新闻、大热门!**如下图所示&…

10个流行的硬盘数据恢复软件:轻松恢复丢失的文件!

数据丢失对于个人和企业来说都是一场噩梦。无论是由于意外删除、硬件故障还是恶意攻击,丢失重要数据都可能导致重大挫折。 这就是硬盘数据恢复软件的用武之地!使用正确的数据恢复工具,您可以检索有价值的文件并迅速回到正轨。我们将探索十种…

OpenCV系列__chapter2

这里写目录标题 1 图像加减乘除位运算1.1 加法 img cv2.add(img1, img2)1.2 减法 img cv2.subtract(img1, img2)1.3 乘法 img cv2.multiply(img1, img2)1.4 除法 img cv2.divide(img1, img2)1.5 位运算 2 图像增强2.1 线性变换2.2 非线性变换 3 图像几何变换3.1 裁剪、放大…

需求分析案例:餐厅排队叫号送券需求

本文介绍一个餐饮系统的排队送券需求,针对该需求进行边界划分和技术实现方案的过程。 1、需求提出 还是前文说的互联网餐饮系统,提供了餐厅排队叫号能力。 因为有些热门餐厅,吃饭的人太多了,就出现了很多排队时间长,…

SkyWalking链路追踪-技术文档首页

SkyWalking 文档中文版(社区提供) (skyapm.github.io)https://skyapm.github.io/document-cn-translation-of-skywalking/ SkyWalking-基本概念 SkyWalking链路追踪是一个用于分布式系统的性能监控工具,它帮助开发人员了解系统中各组件之间…

【如何训练一个中译英翻译器】LSTM机器翻译模型部署之ncnn(python)(四)

ncnn:https://github.com/Tencent/ncnn 1、.h5模型保存为TFSaveModel格式 import tensorflow as tf from keras.models import load_model# 加载Keras模型 model load_model(encoder_model.h5)# 转换为SavedModel类型 tf.saved_model.save(model, TFSaveModel)2、…

ssm 图书借阅管理系统 【纯干货分享,免费领源码06780】

大数据时代下,数据呈爆炸式地增长。为了迎合信息化时代的潮流和信息化安全的要求,利用互联网服务于其他行业,促进生产,已经是成为一种势不可挡的趋势。在图书馆的要求下,开发一款整体式结构的图书借阅管理系统&#xf…

5.44 综合案例2.0-矩阵键盘信息输入上传-OLED屏幕

综合案例2.0-矩阵键盘信息输入上传-OLED屏幕 案例说明1、应用场景2、M320矩阵引脚说明3、接线说明 搭建云平台环境1.添加设备2.创建设备类型3.功能定义(创建物模型) 代码1.更改MQTT信息 测试 案例说明 矩阵键盘输入信息显示在OLED显示屏上。按确定键可以…

Django 图书管理系统

一、功能及页面设计 二、页面展示 (1)首页 (2)注册 (3)登录 (4)普通用户登录 4.1查看图书页面 4.2查看图书详情页 4.3修改密码 (5)管理员登录 5.1添加图书 5.2添加图片 三、代码展示 因为代码太多不好一个个展示 所以需要源码的小伙伴可以找我要代码 感谢三连支持&#xff0…

云计算需求激增带来的基础设施挑战及解决方案

云计算的指数级增长迅速改变了我们消费和存储数字信息的方式。随着企业和个人越来越依赖基于云的服务和数据存储,对支持这些服务的强大且可扩展的基础设施的需求已达到前所未有的水平。 云计算需求的快速增长 我们的日常生活越来越多地被新技术所渗透。流媒体服务、…

剑指 Offer 29. 顺时针打印矩阵 / LeetCode 54. 螺旋矩阵(模拟)

题目: 链接:剑指 Offer 29. 顺时针打印矩阵;LeetCode 54. 螺旋矩阵 难度:中等 给你一个 m 行 n 列的矩阵 matrix ,请按照 顺时针螺旋顺序 ,返回矩阵中的所有元素。 示例 1: 输入&#xff1a…

Fiddler Everywhere(TTP调试抓包工具) for Mac苹果电脑版

Fiddler Everywhere for Mac版是Mac电脑上的一款跨平台的HTTP调试抓包工具,Fiddler Everywhere for Mac能够记录客户端与服务器之间的所有HTTP(S)通信,支持对包进行监视、分析、设置断点、甚至修改请求/响应数据等操作。 适用于任…