LLM Algorithms(1): Flash Attention

news2025/1/16 13:59:04

目录

Background

Flash Attention

Flash Attention Algorithm

参考


NIPS-2022:  Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • idea:减少资源消耗,提升或保持模型性能。
  • 普通attention的空间复杂度是O(N^2) --》降低到Flash Attention O(N)
  • Exact 结果相等。这不是attention的近似计算,Flash Attention的计算结果和原始方法一致。
  • IO aware. 和传统attention相比,Flash Attention会考虑硬件特性,而不是把它当作黑盒。 

Background

Nvidia GPU (GPU性能指标 = FLOPS / GB/s,FLOPS, GPU计算能力--每秒计算速度;GB/s,GPU内存吞吐量

  1. 2016-P100
  2. 2018-V100
  3. 2020-A100
  4. 2022-H100

多年来,GPU的计算能力(FLOPS)的增长速度比增加内存吞吐量(TB/s)更快。 

这两者需要紧密配合去达到数据处理的最优比,但自从硬件失去了这种平衡,我们必须通过软件来进行补偿。因此需要算法能够感知IO (IO-aware)。根据计算和内存访问比例,一个操作可以分为:

  1. 计算受限型 (e.g. 矩阵乘法)
  2. 内存受限型
    1. Element-wise 逐元素操作: activation, dropout, masking.
    2. Reduction 操作: softmax, layer norm, sum. 

element-wise操作是指在计算时只依赖当前值,比如每个元素都乘以2。而reduction依赖所有值(比如整个矩阵或矩阵的行),比如softmax。 

attention的计算时内存受限的,因为它的大部分计算都是element-wise的。 

尽管masking、softmax和dropout操作占用了大部分时间,但大部分FLOPS都用在矩阵乘法中,虽然他们花的时间不多。即数据太庞大,attention计算内存不足,或者说内存利用效率太低!

可以通过内存调整去加速masking、softmax和dropout这些操作呢,但是具体咋办? 

人们都知道把大矩阵切分成小块,但如何保证切分小块的计算结果=原attention计算结果?  

扩展:在计算机体系结构里,内存不是单一的构建,内存存储都是分层的。一般规则是:Memory IO speed 内存速度越快,成本越高,容量越小。

  1. GPU SRAM,19TB/s (20 MB),Static RAM, 静态随机存储器
  2. GPU HBM,1.5TB/s (40 GB),high Boardwidth memory, 高带宽内存 
  3. GPU DRAM,12.8GB/s (>1 TB),main memory

实际上,要充分利用内存、实现IO-aware,关键在于充分利用静态随机存取存储器 (SPAM)比高带宽内存 (HBM)快得多的事实,确保减少两者之间的通信。

(HBM,这是导致CUDA内存溢出的因素之一) 

Flash Attention

Flash Attention 采样分而治之的思想,将大矩阵切块加载到SRAM中,计算每个分块的m和l值。利用上一轮m和l值结合新的子块迭代计算,最终计算出整个矩阵的树枝。Flash Attention基本上可以归结为两个主要思想:

  •  Tiling (在前向和后向传递中使用) - 简单讲就是将NxN的softmax分数矩阵划分为块。
  • 重新计算(因为每个块的系数不一样,Flash Attention每融合一个小块,就需要调整一下之前块的系数,去保持一致!)
  • 传统attention需要分配完整的NxN矩阵(S, P),这是main需要解决的瓶颈,这也是Flash Attention主要解决的问题,将复杂度从O(N^2)降低到O(N)

整个过程不用存储中间变量S和P矩阵,节省了效率因为Attention 操作最大的问题就是每次操作都要从HBM把数据加载到GPU SRAM,运算结束后又从SRAM复制到HBM。这类似于cpu的寄存器与内存的关系,因此最容易的优化方法就是避免这种数据的来回移动,即编译器行话"kernel fusion"。

Flash Attention Algorithm

假设输入一个一维向量x^{(i)} = [x_1,x_2,...,x_B],对应于QK=Sij相似度矩阵中的一行向量。 

1. softmax分块计算:

  • m(x) = max(xi),这是rowmax 操作这是单个值
  • f(x) = [e^{x_1-m(x),..., e^{x_B-m(x)}}]。对应原公式的\tilde{P}_{ij}then why xi-m(x)?这是为了数值稳定,每个数减去相同的任一常量,其softmax值不变。==》减去最大的元素,保证最大值为e^0=1,因为在0~1之间时,浮点数的精度是最大的。
  • l(x) = \sum_if(x)_i,对应原公式\tilde{l}_{ij}这是rowsum 操作
  • so\!ftmax = \frac{f(x)}{l(x)}, softmax除法可以写成diag(l(x))^{-1},把l(x)拉伸成diag的主要原因是把更新公式写成矩阵乘法的形式

2. Flash Attention每次都是合并两块:previous blocks result + latest block。如何保证每一个小块的合并结果与原有attention结果相同?搞好softmax系数的一致性!

  •  因为each step都需要重新计算m(x) = max(m^{(i)}),而m(x)是变的,前面blocks在合并之前,需要先通过m_i - m_i^{new}修正之前block的系数,\tilde{m}_{ij}是指第ij单个block的max(x),不涉及之前blocks的max值
  • m(x) = m([x^{(1), x^{(2)}}]) = max(m(x^{(1)}, m^{(2)}))
  • f(x) = [e^{m(x^{(1)})-m(x))}f(x^{(1)}, e^{m(x^{(2)})-m(x))}f(x^{(2)})]
  • l(x) = e^{m(x^{(1)})-m(x))}l(x^{(1)}, e^{m(x^{(2)})-m(x))}l(x^{(2)})修正系数m_i - m_i^{new}保持一致,因为这两个blocks的softmax系数不一致,m(x^{(2)})-m(x)保证最新的single block的softmax系数与之前的一致!
  • so\!ftmax = \frac{f(x)}{l(x)}

举例:假设x \in R^6,并且它被分成3块:x^{(1)} = [1,3]x^{(2)} = [2,4]x^{(3)} = [3,2]

我们先计算前两块:

  • m(x^{(1)})=3, f(x^{(1)})=[e^{-2},1], l(x^{(1)})=(e^{-2}+1)
  • m(x^{(2)})=4, f(x^{(2)})=[e^{-2},1], l(x^{(2)})=(e^{-2}+1)

我们根据上面的结果计算前两块的结果:

  • m(x) = max(m(x^{(1)}), m(x^{(2)})) = max(3,4)=4
  • f(x) = [e^{3-4}f(x^{(1)}), e^{4-4}f(x^{(2)})]
  • l(x) = e^{3-4}l(x^{(1)}) + e^{4-4}l(x^{(2)})

为什么上面的结果是正确的呢?首先m(x)应该非常明显,4个数中的最大数肯定就是分成两组后的最大中的较大者。而f(x)计算的核心就是在𝑓(𝑥(1))𝑓(𝑥(1))前乘以𝑒3−4𝑒3−4以及在𝑓(𝑥(2))𝑓(𝑥(2))前乘以𝑒4−4𝑒4−4。l(x)的计算和f(x)是类似的。为什么需要在𝑓(𝑥(1))𝑓(𝑥(1))前乘以𝑒3−4𝑒3−4?因为在计算𝑓(𝑥(1))𝑓(𝑥(1))时最大的数是3,因此前两个数的指数都乘以了𝑒−3𝑒−3。但是现在前4个数的最大是4了,后面两个数的指数乘以了𝑒−4𝑒−4,因此直接合并为[𝑓(𝑥(1)),𝑓(𝑥(2))][𝑓(𝑥(1)),𝑓(𝑥(2))]是不对的,需要把前面两个数再乘以𝑒3−4=𝑒−1𝑒3−4=𝑒−1。而后面两个数本来就乘以了𝑒−4𝑒−4,所以不用变

计算output Oi:我们把一个很大的x拆分成长度为B的blocks,用上面的算法计算block 1和block 2,然后合并其结果;接着计算第3块,并将above 结果与第三块合并; ... =》所以,我们在定义时,可以把空块x=[], m(x)=-inf, f(x)=[], l(x)=0,这样我们就可以把第一块block的计算转换成block 1和空块的合并,使得循环可以从第一块开始!

  • O_1 = diag(l_1)^{-1}(0 * 0 + e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j)
  •  O_2 = diag(l_i^{new})^{-1}(diag(l_i)O_ie^{m_i-m_i^{new}} + e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j)

因为Flash Attention不存储中间变量S和P矩阵,所以我们用diag(l_i)O_i反推出之前的PV值,再用e^{m_i-m_i^{new}}修正系数,最后加上第ij块e^{\tilde{m}_{ij}-m_i^{new}}\tilde{P}_{ij}V_j) with single e^{\tilde{m}_{ij}},得到的结果最后再除以diag(l_i^{new})^{-1}保持softmax运算完整性。

参考

Flash Attention论文解读 - 李理的博客

https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1810405.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

探究IOC容器刷新环节初始化前的预处理

目录 一、IOC容器的刷新环节快速回顾 二、初始化前的预处理prepareRefresh源码分析 三、初始化属性源 (一)GenericWebApplicationContext初始化属性源 (二)StaticWebApplicationContext初始化属性源 四、初始化早期事件集合…

亚马逊冗余库存处理

在亚马逊放置90天以上的产品,又不在正常的动销,就要采取一定的措施了。清库存方式: 最直接的方式——降价促销(至少要降价百分之三十以上,库龄越久,降价越狠)参加官方的活动促销的话是需要符合…

在 Word 中,如何有效调整文字与下划线之间的距离

🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/ 如果你在使用 Word 时,希望调整文字和下划线之间的距离,让它们看起来更加美观,可以按照以下步骤操作: 1. 在你想要加下划线的文字前后各加一个空格&…

Tdengine的时序数据库简介、单机部署、操作语句及java应用

Tdengine的时序数据库简介、单机部署、操作语句及java应用 本文介绍了Tdengine的功能特点、应用场景、超级表和子表等概念,讲述了Tdengine2.6.0.34的单机部署,并介绍了taos数据库的常见使用方法及特色窗口查询方法,最后介绍了在java中的应用。…

Harmony中的HAP、HAR、HSP区别

Harmony中的HAP、HAR、HSP区别 想要更加合理的开发一个企业级别的Harmony应用,那么就不得不提其中的HAP、HAR、HSP了。 前言 对于普通的用户来说,可能一个普通的应用就等于一个安装文件如安卓下的APK。但是对于Harmony应用开发工程师来讲,…

Python | Leetcode Python题解之第143题重排链表

题目: 题解: class Solution:def reorderList(self, head: ListNode) -> None:if not head:returnmid self.middleNode(head)l1 headl2 mid.nextmid.next Nonel2 self.reverseList(l2)self.mergeList(l1, l2)def middleNode(self, head: ListNo…

单田芳mp3百度网盘,单田芳评书下载百度云百度网盘

单老的评书还注重情感的表达。他善于运用声音、语气、语调等手段,将人物的情感刻画得淋漓尽致。无论是喜怒哀乐,他都能准确地把握人物的情感变化,并通过自己的表演将其传递给听众。这种情感的传递,使得听众能够更加深入地理解故事…

Springboot 开发之任务调度框架(一)Quartz 简介

一、引言 常见的定时任务框架有 Quartz、elastic-job、xxl-job等等,本文主要介绍 Spirng Boot 集成 Quartz 定时任务框架。 二、Quartz 简介 Quartz 是一个功能强大且灵活的开源作业调度库,广泛用于 Java 应用中。它允许开发者创建复杂的调度任务&…

Web--CSS基础

文章目录 定义方式选择器文本字体背景边框元素展示格式内边距与外边距盒子模型位置浮动flex布局响应式布局 定义方式 行内样式表 直接定义在style属性中&#xff0c;作用于当前标签 <img src "/imges/logo.jpg" alt "" style "width 400"…

react修改本地运行项目的端口

一、描述 如果你想让项目在你想要的端口打开的话&#xff0c;就需要进行设置 二、代码 设置一下pages.json文件就可以了&#xff0c;如下&#xff1a; 如果想打开项目不需要点击下面的链接地址&#xff0c;让他运行npm run dev之后自己直接打开到浏览器的话&#xff0c;在后…

智能楼宇的智慧心脏:ARMxy工业计算机在自动化控制中的应用

智能楼宇已成为现代化城市不可或缺的一部分。在这场数字化转型浪潮中&#xff0c;ARMxy工业计算机凭借其强大的处理能力、高度的系统兼容性和灵活的I/O配置&#xff0c;成为了推动楼宇自动化控制领域创新的重要力量。 某大型商业综合体项目&#xff0c;面临着传统HVAC系统效率低…

Android Studio历史版本

android studio的历史版本

OpenAI 宕机事件:GPT 停摆的影响与应对

引言 2024年6月4日&#xff0c;OpenAI 的 GPT 模型发生了一次全球性的宕机&#xff0c;持续时间长达8小时。此次宕机不仅影响了OpenAI自家的服务&#xff0c;还导致大量用户涌向竞争对手平台&#xff0c;如Claude和Gemini&#xff0c;结果也导致这些平台出现故障。这次事件的广…

conda 创建环境失败

conda create -n pylableimg python3.10在conda &#xff08;base&#xff09;环境下&#xff0c;创建新的环境&#xff0c;失败。 报错&#xff1a; LookupError: didn’t find info-scipy-1.11.3-py310h309d312_0 component in C:\Users\Jane.conda\pkgs\scipy-1.11.3-py310h…

VMware Workstation Pro的最新下载地址

前言 VMware被Broadcom收购后现在的下载方式也改变了&#xff0c;Workstation Pro 和 Fusion Pro 产品现在起将免费供个人用户使用下载方式 首先先把下载地址打开 https://support.broadcom.com/group/ecx/productdownloads?subfamilyVMwareWorkstationPro 打开链接&#xff…

人工智能和机器学习这两个概念有什么区别?

什么是人工智能&#xff1f; 先来说下人工智能&#xff0c;人工智能&#xff08;Artificial Intelligence&#xff09;&#xff0c;英文缩写为AI&#xff0c;通俗来讲就是用机器去做在过去只有人能做的事。 人工智能最早是由图灵提出的&#xff0c;在1950年&#xff0c;计算机…

jenkins插件之Jdepend

JDepend插件是一个为构建生成JDepend报告的插件。 安装插件 JDepend Dashboard -->> 系统管理 -->> 插件管理 -->> Available plugins 搜索 Jdepend, 点击安装构建步骤新增执行shell #执行pdepend if docker exec phpfpm82 /tmp/composer/vendor/bin/pdepe…

Python酷库之旅-开启库房之门

目录 一、库的定义 二、库的组成 三、库的分类 四、如何学好Python库&#xff1f; 五、注意事项 六、推荐阅读 1、Python筑基之旅 2、Python函数之旅 3、Python算法之旅 4、Python魔法之旅 5、 博客个人主页 一、库的定义 在Python中&#xff0c;库(Library)是一个封…

基于深度学习网络的USB摄像头实时视频采集与人脸检测matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 将摄像头对这播放视频的显示器&#xff0c;然后进行识别&#xff0c;识别结果如下&#xff1a; 本课题中&#xff0c;使用的USB摄像头为&#xff…

目前比较好用的LabVIEW架构及其选择

LabVIEW提供了多种架构供开发者选择&#xff0c;以满足不同类型项目的需求。选择合适的架构不仅可以提高开发效率&#xff0c;还能确保项目的稳定性和可维护性。本文将介绍几种常用的LabVIEW架构&#xff0c;并根据不同项目需求和个人习惯提供选择建议。 常用LabVIEW架构 1. …