FlashAttention和PagedAttention

news2024/11/23 13:24:53

FlashAttention

FlashAttention一般指的是FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness这篇,当然Transformer Quality in Linear Time这篇里非要说FLASH = Fast Linear Attention with a Single Head,命名有点无语,关于FLASH的细节参考 FLASH:可能是近来最有意思的高效Transformer设计 ,下面重点写写FlashAttention:
在这里插入图片描述
tiling中文是瓦片化,实际上就是把计算像瓦片一样铺向SRAM,保证运算不要频繁在SRAM和**HBM(High-Bandwidth Memory,HBM是高带宽内存,也就是我们常说的显存)**频繁切换,提高速度。

标准注意力的内存复杂度

对于标准注意力实现,初期我们需要把输入 Q , K , V \mathbf{Q}, \mathbf{K}, \mathbf{V} Q,K,V从HBM中读取,并计算完毕后把输出 O \mathbf{O} O写入到HBM中。
第一步把 Q , K \mathbf{Q}, \mathbf{K} Q,K读取出来计算出 S = Q K ⊤ \mathbf{S}=\mathbf{Q K}^{\top} S=QK,然后把 S \mathbf{S} S存回去,内存访问复杂度 Θ ( N d + N 2 ) \Theta\left(N d+N^2\right) Θ(Nd+N2)

第二步把 S \mathbf{S} S读取出来计算出 P = softmax ⁡ ( S ) \mathbf{P}=\operatorname{softmax}(\mathbf{S}) P=softmax(S),然后把 P \mathbf{P} P存回去,内存访问复杂度 Θ ( N 2 ) \Theta\left(N^2\right) Θ(N2)

第三步把 V , P \mathbf{V}, \mathbf{P} V,P读取出来计算出 O = P V \mathbf{O}=\mathbf{P} \mathbf{V} O=PV,然后计算出结果 O \mathbf{O} O,内存访问复杂度 Θ ( N d + N 2 ) \Theta\left(N d+N^2\right) Θ(Nd+N2)

综上所述,整体的内存访问复杂度为 Θ ( N d + N 2 ) \Theta\left(N d+N^2\right) Θ(Nd+N2)

FlashAttention的算法

前向传播时减少对内存的访问次数

FlashAttention关键的想法就是tile(分块),把QKV都拆成块。这里一个关键点是softmax怎么算,有点绕,简单说就是把每部分分子分母的和给存下来,归一化到相同的比例。下面是个具体的例子, l _ p r e l\_pre l_pre是分母缩最大倍数后的和,也是最绕的点。假设QK结果是[1,2],那么softmax结果就是
[ e 1 e 1 + e 2 , e 2 e 1 + e 2 ] [\frac{e^1}{e^1+e^2},\frac{e^2}{e^1+e^2}] [e1+e2e1,e1+e2e2]
再乘以V的结果就是:
e 1 ∗ v 1 e 1 + e 2 + e 2 ∗ v 2 e 1 + e 2 \frac{e^1*v_1}{e^1+e^2}+\frac{e^2*v_2}{e^1+e^2} e1+e2e1v1+e1+e2e2v2
如果拆成两步算,第一步:
c u r _ s u m = e 1 ∗ v 1 e 1 m _ p r e = m a x ( e 1 ) = e 1 , 是分子 e 的和 l _ p r e = s u m ( e 1 ) = e 1 , 是分母 e 的和 cur\_sum = \frac{e^1*v_1}{e^1} \\ m\_pre = max(e^1)=e^1,是分子e的和 \\ l\_pre = sum(e^1)=e^1,是分母e的和 cur_sum=e1e1v1m_pre=max(e1)=e1,是分子e的和l_pre=sum(e1)=e1,是分母e的和
第二步:
m _ c u r = m a x ( e 2 , m _ p r e ) = e 2 l _ p r e ∗ = e m _ p r e − m _ c u r = e 1 − 2 ,分母缩共同倍数后相加 l _ c u r = s u m ( e 2 − 2 ) + l _ p r e c u r _ s u m = c u r _ s u m ∗ l _ p r e l _ c u r = e 1 ∗ v 1 e 1 ∗ e − 1 e − 1 + e 0 c u r _ s u m + = v 2 ∗ c u r _ s u m l _ p r e = e 1 ∗ v 1 e 1 + e 2 + e 2 ∗ v 2 e 1 + e 2 m\_cur = max(e^2,m\_pre)=e^2 \\ l\_pre *= e^{m\_pre - m\_cur}=e^{1-2} ,分母缩共同倍数后相加\\ l\_cur = sum(e^{2-2})+l\_pre\\ cur\_sum=cur\_sum*\frac{l\_pre}{l\_cur}=\frac{e^1*v_1}{e^1}*\frac{e^{-1}}{e^{-1}+e^0}\\ cur\_sum+=\frac{v_2*cur\_sum}{l\_pre}=\frac{e^1*v_1}{e^1+e^2}+\frac{e^2*v_2}{e^1+e^2} m_cur=max(e2,m_pre)=e2l_pre=em_prem_cur=e12,分母缩共同倍数后相加l_cur=sum(e22)+l_precur_sum=cur_suml_curl_pre=e1e1v1e1+e0e1cur_sum+=l_prev2cur_sum=e1+e2e1v1+e1+e2e2v2

这样,在前向的过程中,我们采用分块计算的方式,避免了矩阵的存储开销,整体的运算都在SRAM内进行,降低了HBM访问次数,大大提升了计算的速度,减少了对存储的消耗。详细的复杂度分析可以参考原文和https://readpaper.feishu.cn/docx/AC7JdtLrhoKpgxxSRM8cfUounsh

反向传播时使用重新计算(recompute的方式来更新梯度)

我们这里则采用重新计算的方式来计算对应的梯度。在上面前向计算的时候我们不会存储 S , P \mathbf{S}, \mathbf{P} S,P矩阵,但是我们会存储对应的指数项之和 L L L来进行梯度的计算。这里不展开写了,细节可以参考原文和https://readpaper.feishu.cn/docx/AC7JdtLrhoKpgxxSRM8cfUounsh
目前,Flash Attention已经集成至torch2.0,并且社区也提供了多种实现

PagedAttention

源自vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention这篇paper,关键的技术有两点:

KVCache

KV Cache是大模型推理优化的一个常用技术,该技术以空间换时间的思想,通过使用上次推理的KV缓存,可以在不影响任何计算精度的前提下,提高推理性能,降低端到端的时延。

以GPT为代表的Decoder-Only自回归语言模型在生成每一个新的 token 时,接受所有之前生成的 tokens 作为输入。然而,对于这些先前生成的 tokens,每次生成新的 token 时都需要重新计算他们的表示,这个过程造成了大量的计算浪费。KV Cache 的引入就是为了解决这个问题。

KV Cache实质上是存储了之前计算过的 key-value 对用于下一个Token的生成。在 Transformer 结构中,self-attention 中的k_proj, v_proj会将输入的每个 token 转化为一个 key 和一个 value,然后使用这些 key-value 以及当前的query对来计算下一个 token。引入 KV Cache,我们就可以将之前生成的 tokens 对应的 key-value 对存储起来,当生成新的 token 时,直接从 KV Cache 中取出这些已经计算好的 key-value 对,再把当前token的key-value做一个连结在进行计算,这样就避免了KV的重复计算,大大提高了计算效率。

整体来说,使用KV Cache包含以下两个步骤:

  • 预填充阶段:在计算第一个输出token过程中,此时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢,这时属于Compute-bound类型计算。
  • KV Cache阶段:在计算第二个输出token至最后一个token过程中,此时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

PagedAttention

通过KV Cache的技术,我们已经可以极大地提升LLM地推理速度,但是现有的Cache仍存在一些问题,

  • Large:对于LLaMA-13B中的单个序列,它占用高达1.7GB的内存。
  • Dynamic:它的大小取决于序列长度,而序列长度具有高度可变和不可预测的特点。
    因此,高效地管理KV Cache是一个重大挑战。现有系统(HuggingFace 默认实现是pytorch的内存分配策略)由于内存碎片化和过度预留而浪费了60%至80%的内存。
    为了解决这个问题,我们引入了PagedAttention,这是一种受传统操作系统虚拟内存和分页概念启发的注意力算法。与传统的注意力算法不同,PagedAttention允许将连续的键和值存储在非连续的内存空间中。具体而言,PagedAttention将每个序列的KV缓存分成多个块,每个块包含固定数量的标记的键和值。在注意力计算过程中,PagedAttention Kernel高效地识别和获取这些块,采用并行的方式加速计算。(和ByteTransformer的思想有点像)

内存布局

由于块在内存中不需要连续存储,我们可以像操作系统的虚拟内存那样以更加灵活的方式管理键和值的缓存:可以将块看作页,标记看作字节,序列看作进程。序列的连续逻辑块通过块表映射到非连续的物理块。随着生成新的标记,序列的边长,物理块按需进行分配。

在PagedAttention中,内存浪费仅发生在序列的最后一个块中。这样就使得我们的方案接近最优的内存使用率,仅有不到4%的浪费。通过内存效率的提升,我们能够显著提升BatchSize,同时进行多个序列的推理,提高GPU利用率,从而显著提高吞吐量。

PagedAttention:Cache在物理上不必连续
PagedAttention:Cache在物理上不必连续
使用 PagedAttention 的请求的示例生成过程使用 PagedAttention 的请求的示例生成过程

内存共享

在并行采样中,从相同的提示生成多个输出序列。在这种情况下,可以在输出序列之间共享提示的计算和内存。通过其块表,PagedAttention能够自然地实现内存共享。类似于进程共享物理页,PagedAttention中的不同序列可以通过将它们的逻辑块映射到相同的物理块来共享块。为确保安全共享,PagedAttention跟踪物理块的引用计数并实现 Copy-on-Write 机制。

通过PagedAttention的内存共享机制,极大地降低了复杂采样算法(如ParallelSampling和BeamSearch)的内存开销,使其内存使用量下降了高达55%。这项优化可以直接带来最多2.2倍的吞吐量提升,从而使得LLM服务中使用这些采样方法变得更加实用。

同时进行多输出的采样
同时进行多输出的采样
多输出采样的物理展示
多输出采样的物理展示

部分引用自:

  1. FLASH:https://arxiv.org/pdf/2202.10447.pdf
  2. FlashAttention:https://arxiv.org/pdf/2205.14135.pdf
  3. https://zhuanlan.zhihu.com/p/582606847
  4. https://readpaper.feishu.cn/docx/AC7JdtLrhoKpgxxSRM8cfUounsh
  5. https://readpaper.feishu.cn/docx/EcZxdsf4uozCoixdU3NcW03snwV
  6. https://zhuanlan.zhihu.com/p/638468472

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/685637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript数据存储方式

内置对象 js内部提供的对象,包含各种属性和方法给开发者调用 document.write() console.log() Math Math对象是js提供的一个 “数学”对象,提供了一系列做数学运算的方法 max找最大值Math.max(3,8,5,4) 返回8min找最小值Math.min(3,8,5,4) 返回4ab…

printf不一样的玩法

Printf不一样的玩法 ❝ 在使用linux终端命令的时候,我们可以看到像more命令,它的显示方式与一般的字符串不同,是用了反显。同样,linux C下printf还有很多其他不常见的格式化输出形式。本文主要为你盘点这些形式。 ❞ 先看下效果&a…

MySQL:单行函数(全面详解)

MySQL:单行函数 前言一、函数的理解1、什么是函数2、不同DBMS函数的差异3、MySQL的内置函数及分类 二、数值函数1、基本函数2、角度与弧度互换函数3、三角函数4、指数与对数5、进制间的转换 三、字符串函数四、日期和时间函数1、获取日期、时间2、日期与时间戳的转换…

Bpmn.js流程建模结合业务整合工作流(二)

上一篇文章讲述了bpmn.js的基本搭建使用过程,下面介绍工具栏的按钮使用 以及右侧属性如何绑定到节点保存的 保存方法 /** 保存xml */async save() {await this.getNewXML() //获取最新的xmlawait this.getRootElement() //获取流程基本信息 节点信息const params = {name: th…

SSM+校园网上订餐系统 毕业设计-附源码211510

校园网上订餐系统的设计与实现 摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对校园网上订…

API 接口协作,swagger不再是第一选择了

目录 一、前言 1.1. 场景一、后端视角: 1.2. 场景二、前端视角: 1.3. 场景三、测试视角: 二、Apifox 2.1 场景一、后端视角: 漂亮的接口文档 2.2 场景二、前端视角: 2.3 场景三、测试视角: 三、总…

SOAP教程

参考 SOAP 教程 1、介绍 SOAP 中文解释为:简单对象访问协议。 SOAP 是一种简单的基于 XML 的协议,它使应用程序通过 HTTP 来交换信息。 SOAP 是基于 XML 的简易协议,可使应用程序在 HTTP 之上进行信息交换。或者更简单地说:SOAP…

ASP.NET Core Web API之Token验证

在实际开发中,我们经常需要对外提供接口以便客户获取数据,由于数据属于私密信息,并不能随意供其他人访问,所以就需要验证客户身份。那么如何才能验证客户的身份呢?今天以一个简单的小例子,简述ASP.NET Core…

一、枚举类型——使用枚举类型分发

如果将 RoShamBo1.java 直接转换为基于枚举的实现版本,则会出现问题。因为枚举实例并不是类型,所以无法重载 eval() 方法,你无法将枚举实例作为参数类型。不过,还有别的方法可以利用枚举来实现多路分发。 一种方法是通过构造方法…

实战:SonarQube平台安装配置-2023.6.24(安装成功)(docker方式)

实战:SonarQube平台安装配置-2023.6.24(安装成功)(docker方式) 目录 推荐文章 https://www.yuque.com/xyy-onlyone/aevhhf?# 《玩转Typora》 实验环境 sonarqube:9.9.0-community (docker方式部署) SonarScanner 4.8.0.2856 (部署在宿主机上)实验软件 链接&…

开关电源-PFC驱动电路的工作原理

PFC驱动电路的工作原理 由于PFC的控制地和MOS管组成的双向开关的源极不共地,因此需要解决开关管浮地驱动问题。 图2 驱动电路图 电路图说明: PFCPWM是DSP的PWM信号;VCC_4V和AGND是DSP侧的电源和控制地;Vccp_14V和AGND_DRV是MO…

echarts 的 一个图表容器,使用grid存放多个折线图,并配置x轴联动

效果图 配置参数 // prettier-ignore const data [["2000-06-05", 116], ["2000-06-06", 129], ["2000-06-07", 135], ["2000-06-08", 86], ["2000-06-09", 73], ["2000-06-10", 85], ["2000-06-11",…

开关电源- 用PFC拓扑电路对比

用PFC拓扑电路对比 最基本的有桥boost PFC电路 有桥boostPFC电路是最基本的电路,就不叙述了。 双Boost无桥PFC 双boost无桥拓扑的优点是使用功率元件比较少, 两个管子可以一起驱动, 这简化了驱动电路的设计, 同时让直接使用传统APFC的控制芯片成为可能.但是这种拓扑…

3.41 - haas506与esp8266-01s的串口通信(TCP透传)

haas506与esp8266-01s的串口通信 PC端调试wifi模块1.接线(与电脑通信)2.模式案例3.指令演示 开发板与wifi模块通信1.接线(TTL串口通信)2.代码测试 PC端调试wifi模块 esp8266-01s 1.接线(与电脑通信) 与电脑通信时引脚连接,wifi模块需要稳定3.3v供电,…

一个618项目的复盘总结反思

一、前言 618期间上线一个活动项目。但上线不顺利,当天就出现了性能问题,接口超时,用户无法打开网页,最后不得的临时下线。花了三天两夜,重构了后台核心代码,才让活动进行下去。 回头看了一下自己的时间记…

【零基础入门学习Python---Python条件和循环语句】

🚀 Python 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

OpenCV下载、环境变量配置

https://sourceforge.net/projects/opencvlibrary/files/ 选择合适的版本下载即可 这里使用opencv-3.1.0.exe执行 将需要bin目录加入到环境变量Path D:\opencv\31\opencv\build\x64\vc14\bin #include<opencv2/opencv.hpp> #include <iostream> using namespace …

【学习日记2023.6.25】之ElasticSearch搜索引擎

文章目录 分布式搜索引擎1.初识elasticsearch1.1.了解ES1.1.1.elasticsearch的作用1.1.2 ELK技术栈1.1.3 elasticsearch和lucene1.1.4 为什么不是其他搜索技术&#xff1f;1.1.5 总结 1.2 倒排索引1.2.1 正向索引1.2.2 倒排索引1.2.3 正向和倒排 1.3 es的一些概念1.3.1 文档和字…

[Web程序设计]实验: Servlet基础应用

一、实验目的 &#xff08;1&#xff09;掌握java web应用的基础和核心知识&#xff1a;servlet。 &#xff08;2&#xff09;理解servlet的具体使用。 二、实验内容 &#xff08;1&#xff09;编写一个servlet&#xff0c;实现统计网站被访问次数的功能&#xff1b; &…

SpringBoot 集成测试主要组件及其特点

SpringBoot 集成测试主要组件及其特点 随着SpringBoot的流行&#xff0c;集成测试也变得越来越重要。SpringBoot提供了一些主要组件来支持集成测试&#xff0c;本文将介绍这些组件及其特点。 1. Spring Test Spring Test是Spring框架提供的测试工具集&#xff0c;其主要目的是…