大模型训练之加速篇 -attention优化【MQA-> flashAttention】

news2025/2/24 14:21:40

MQA (multi query attention)

Fast Transformer Decoding: One Write-Head is All You Need

MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。
那到底能提升多少的速度呢,我们来看论文中给出的结果图[生成每个token消耗的时间ms]:
在这里插入图片描述

从字面上看,Multi Query Attention(MQA) 和 Multi Head Attention(MHA)只差了一个单词,
就是从「Head」变成了「Query」。
我们知道,在 transformer 中是包含若干个注意力头(head)组成的,
而每个 head 又是由: query(Q),key(K),value(V) 3 个矩阵共同实现的。

「参数共享」并不是一个很新奇的思路,在 Albert 里也有通过使用跨层共享参数(Cross-layer parameter sharing)的方式来大大减少 bert 的参数量,具体做法可以参考这里:何枝:基于BERT的几种改进模型
现在,
我们知道了 MQA 实际上是将 head 中的 key 和 value 矩阵抽出来单独存为一份共享参数,
而 query 则是依旧保留在原来的 head 中,每个 head 有一份自己独有的 query 参数。

FlashAttention V1

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

当输入序列(sequence length)较长时,Transformer的计算过程缓慢且耗费内存,这是因为self-attention的time和memory complexity会随着sequence length的增加成二次增长。
标准Attention的中间结果S,P通常需要通过高带宽内存(HBM)进行存取,两者所需内存空间复杂度为O(N2)。

FlashAttention对HBM访问的次数为O(N2d2M-1)
Attention对HBM访问的次数为O(Nd+ N2)
往往N远远大于d(例如GPT2中N=1024,d=64),因此FlashAttention会快很多。下图展示了两者在GPT-2上的Forward+Backward的GFLOPs、HBM、Runtime对比(A100 GPU):
在这里插入图片描述

GPU中存储单元主要有HBM和SRAM:HBM容量大但是访问速度慢,SRAM容量小却有着较高的访问速度。例如:A100 GPU有40-80GB的HBM,带宽为1.5-2.0TB/s;每108个流式多核处理器各有192KB的片上SRAM,带宽估计约为19TB/s。可以看出,片上的SRAM比HBM快一个数量级,但尺寸要小许多数量级。
综上,FlashAttention目的不是节约FLOPs,而是减少对HBM的访问。重点是FlashAttention在训练和预测过程中的结果和标准Attention一样,对用户是无感的,而其他加速方法做不到这点。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
对应的计算过程:
每次外循环(outer loop,j)Kj Vj载入的的大小为Bcd=768d,一共循环次Tc=2次
每次内循环(inner loop,i)载入的Qi 的大小为Brd=64d,一共循环Tr=16次(总次数还需要乘以外循环)
Sij = Qi*KjT,即为(下标表示维度):。
Pij,表示和标准attention Pij计算的有区别,因为得到的row_max(Sij)最大值可能不是S第i行的最大值。的大小和一样,都为。
Pij和Sij只是部分结果,如下图所示,外循环是横向(特征维d)移动的,内循环是纵向(序列维N)移动的。换句话说,外循环在顺序计算特征,内循环在顺序计算序列。
Oi的大小为Br*d,第二维d是满的(和最终一样),这意味着每次外循环都要重新更新当前批次中的特征,即虽然第一次外循环P00*V0和第二次外循环P01*V1都会得到O0,但是第二次的O0是基于第一次O0重新生成的。
diag(……)作用是将vector生成为一个对角矩阵,从而实现相同长度的两个vector进行element-wise相乘。

在这里插入图片描述

GPU 知识

在这里插入图片描述
在这里插入图片描述

从Hardware角度来看:
Streaming Processor(SP):是最基本的处理单元,从fermi架构开始被叫做CUDA core。
Streaming MultiProcessor(SM):一个SM由多个CUDA core(SP)组成,每个SM在不同GPU架构上有不同数量的CUDA core,例如Pascal架构中一个SM有128个CUDA core。
SM还包括特殊运算单元(SFU),共享内存(shared memory),寄存器文件(Register File)和调度器(Warp Scheduler)等。register和shared memory是稀缺资源,这些有限的资源就使每个SM中active warps有非常严格的限制,也就限制了并行能力。

从Software(编程)角度来看:
thread:一个CUDA并行程序由多个thread来执行

thread是最基本的执行单元(the basic unit of execution)。

warp:一个warp通常包含32个thread。每个warp中的thread可以同时执行相同的指令,从而实现SIMT(单指令多线程)并行。

warp是SM中最小的调度单位(the smallest scheduling unit on an SM),一个SM可以同时处理多个warp

thread block:一个thread block可以包含多个warp,同一个block中的thread可以同步,也可以通过shared memory进行通信。

thread block是GPU执行的最小单位(the smallest unit of execution on the GPU)。

一个warp中的threads必然在同一个block中,如果block所含thread数量不是warp大小的整数倍,那么多出的那个warp中会剩余一些inactive的thread。也就是说,即使warp的thread数量不足,硬件也会为warp凑足thread,只不过这些thread是inactive状态,但也会消耗SM资源。
grid: 在GPU编程中,grid是一个由多个thread block组成的二维或三维数组。grid的大小取决于计算任务的规模和thread block的大小,通常根据计算任务的特点和GPU性能来进行调整。

Hardware和Software的联系:
SM采用的是Single-Instruction Multiple-Thread(SIMT,单指令多线程)架构,warp是最基本的执行单元,一个warp包含32个并行thread,这些thread以不同数据资源执行相同的指令。
当一个kernel被执行时,grid中的thread block被分配到SM上,大量的thread可能被分到不同的SM上,但是一个线程块的thread只能在一个SM上调度,SM一般可以调度多个block。每个thread拥有自己的程序计数器和状态寄存器,并且可以使用不同的数据来执行指令,从而实现并行计算,这就是所谓的Single Instruction Multiple Thread。
一个CUDA core可以执行一个thread,一个SM中的CUDA core会被分成几个warp,由warp scheduler负责调度。GPU规定warp中所有thread在同一周期执行相同的指令,尽管这些thread执行同一程序地址,但可能产生不同的行为,比如分支结构。一个SM同时并发的warp是有限的,由于资源限制,SM要为每个block分配共享内存,也要为每个warp中的thread分配独立的寄存器,所以SM的配置会影响其所支持的block和warp并发数量。

GPU执行模型小结:
GPU有大量的threads用于执行操作(an operation,也称为a kernel)。这些thread组成了thread block,接着这些blocks被调度在SMs上运行。在每个thread block中,threads被组成了warps(32个threads为一组)。一个warp内的threads可以通过快速shuffle指令进行通信或者合作执行矩阵乘法。在每个thread block内部,warps可以通过读取/写入共享内存进行通信。每个kernel从HBM加载数据到寄存器和SRAM中,进行计算,最后将结果写回HBM中。

FlashAttention V2

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
FlashAttention利用GPU非匀称的存储器层次结构,实现了显著的内存节省(从平方增加转为线性增加)和计算加速(提速2-4倍),而且计算结果保持一致。但是,FlashAttention仍然不如优化的矩阵乘法(GEMM)操作快,只达到理论最大FLOPs/s的25-40%。作者观察到,这种低效是由于GPU对不同thread blocks和warps工作分配不是最优的,造成了利用率低和不必要的共享内存读写。因此,本文提出了FlashAttention-2以解决这些问题。
虽然相比标准Attention,FlashAttention快了24倍,节约了1020倍内存,但是离设备理论最大throughput和flops还差了很多。本文提出了FlashAttention-2,它具有更好的并行性和工作分区。实验结果显示,FlashAttention-2在正向传递中实现了约2倍的速度提升,达到了理论最大吞吐量的73%,在反向传递中达到了理论最大吞吐量的63%。在每个A100 GPU上的训练速度可达到225 TFLOPs/s。
本文主要贡献和创新点为:

减少了non-matmul FLOPs的数量(消除了原先频繁rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的执行时间较长,这是因为GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽可能多地执行matmul FLOPs非常重要。
提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。即使对于单个head,也在不同的thread block之间进行并行计算。
在一个attention计算块内,将工作分配在一个thread block的不同warp上,以减少通信和共享内存读/写。

在这里插入图片描述
Causal masking是attention的一个常见操作,特别是在自回归语言建模中,需要对注意力矩阵S应用因果掩码(即任何S ,其中 > 的条目都设置为−∞)。
由于FlashAttention和FlashAttention-2已经通过块操作来实现,对于所有列索引都大于行索引的块(大约占总块数的一半),我们可以跳过该块的计算。这比没有应用因果掩码的注意力计算速度提高了1.7-1.8倍。
不需要对那些行索引严格小于列索引的块应用因果掩码。这意味着对于每一行,我们只需要对1个块应用因果掩码。

并行处理
FlashAttention在batch和heads两个维度上进行了并行化:使用一个thread block来处理一个attention head,总共需要thread block的数量等于batch size × number of heads。每个block被调到到一个SM上运行,例如A100 GPU上有108个SMs。当block数量很大时(例如≥80),这种调度方式是高效的,因为几乎可以有效利用GPU上所有计算资源。
但是在处理长序列输入时,由于内存限制,通常会减小batch size和head数量,这样并行化成都就降低了。因此,FlashAttention-2还在序列长度这一维度上进行并行化,显著提升了计算速度。此外,当batch size和head数量较小时,在序列长度上增加并行性有助于提高GPU占用率。
Forward pass. FlashAttention算法有两个循环,K,V在外循环,Q,O在内循环。FlashAttention-2将Q移到了外循环i,K,V移到了内循环j,由于改进了算法使得warps之间不再需要相互通信去处理Qi,所以外循环可以放在不同的thread block上。这个交换的优化方法是由Phil Tillet在Triton[17]提出并实现的。

转载于:https://zhuanlan.zhihu.com/p/645376942

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1027835.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

慢SQL原因分析之索引失效 | 京东物流技术团队

现象 最近收到一个慢sql工单,慢sql大概是这样:“select xxx from tabel where type 1”。 咦,type字段明明有索引啊,为啥是慢sql呢? 原因 通过执行explain,发现实际上数据库执行了全表扫描,从而被系统…

excel的vlookup函数用法

vlookup函数功能:在表格的首列查找指定的数值,并返回表格当前行中指定列处的数值。 结构:VLOOKUP(查找值,查找区域,列序数,匹配条件) 解释:VLOOKUP(找谁,在哪里找,第几列,0或1) 说明: 1、第一参数:查找…

【Linux is not Unix】Linux前言

目录 二战军工的产物——第一台现代电子数字计算机ENIAC(埃尼阿克) Unix Linux Linux企业应用现状 如今计算机已经应用在我们生活的各个层面,像我们日常使用的笔记本是计算机的一类,可以解决我们生活中遇到的很多问题&#xff…

【数据分享】2023年全国乡镇(街道)点位数据(免费获取\shp格式\excel格式)

乡镇(街道)点位数据是我们各项研究中经常使用到的数据,在之前的文章中我们分享过2022年度的乡镇(街道)点位数据(可查看之前推送的文章获悉详情)。本次我们带来的是2023年度的全国范围的乡镇&…

Labelme分割标注软件

Labelme分割标注软件 1、环境配置与安装1.1 创建conda虚拟环境(建议)1.2 安装Labelme 2、简单使用2.1 创建label标签文件2.2 启动labelme2.3 打开文件/文件夹2.4 设置保存结果路径2.5 标注目标2.6 保存json文件格式 3 格式转换3.1 转换语义分割标签3.2 转换实例分割标签 相关重…

CUDA小白 - NPP(11) 图像处理 Comparison Operations

cuda小白 原始API链接 NPP GPU架构近些年也有不少的变化,具体的可以参考别的博主的介绍,都比较详细。还有一些cuda中的专有名词的含义,可以参考《详解CUDA的Context、Stream、Warp、SM、SP、Kernel、Block、Grid》 常见的NppStatus&#xf…

显示器有白点闪烁、间歇黑屏解决办法

问题描述 以上三张图片是不到一秒内通过手机视频拍摄显示器画面,可以看到第一张图大桥下和第二张图片右下角岛屿初均有红点闪烁。当触发黑屏时,显示器整体白点闪烁。并且时常黑屏,几秒后恢复。 解决办法 检查HDMI连接线是否脱落&#xff0c…

初识canvas

对于一个前端人员来说,canvas是必须掌握的技能之一。如果你想像画画一样在浏览器上作画,那么canvas就可以做你的画布。 接下啦我们就以画画的标准来初步认识下canvas 1.画布 画画的第一步你得有一张画纸或者画布,canvas标签就是我们的画布…

查看mysql 容量

SQL SELECT table_schema "database", sum( data_length index_length) / 1024 / 1024 /1024 "size in GB" FROM information_schema.TABLES GROUP BY table_schema;结果

CSS 之 grid 网格布局

一、简介 ​ display: grid;用于设置元素内部的布局类型为网格布局,其外显类型为块级元素。该类型的元素将内部分为行和列,划分成一个个单元格,并通过一系列相关属性控制单元格及其内容的布局和大小。 ​ 该属性值的主要应用场景为&#xf…

1297. 子串的最大出现次数

1297. 子串的最大出现次数 // 返回子串的最大出现次数&#xff1a;用hash表 // 子串中 不同字母的次数 < maxLetters && 子串长度> minSize && 子串长度 < maxSizeint maxFreq(char * s, int maxLetters, int minSize, int maxSize){}

【算法】矩阵快速幂优化动态规划

文章目录 知识讲解题目列表[矩阵快速幂] 题目列表&#x1f4d5;70. 爬楼梯解法1——线性DP解法2——矩阵快速幂 509. 斐波那契数1137. 第 N 个泰波那契数1220. 统计元音字母序列的数目解法1——线性DP解法2——矩阵快速幂优化DP 552. 学生出勤记录 II&#xff08;&#x1f6b9;…

Kafka 常见问题

文章目录 kafka 如何确保消息的可靠性传输Kafka 高性能的体现利用Partition实现并行处理利用PageCache 如何提高 Kafka 性能调整内核参数来优化IO性能减少网络开销批处理数据压缩降低网络负载高效的序列化方式 kafka 如何确保消息的可靠性传输 消费端弄丢了数据 唯一可能导致…

第N个数字

给你一个整数 n &#xff0c;请你在无限的整数序列 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, …] 中找出并返回第 n 位上的数字。 我觉得这题是哪以理解的 看这个题解 func findNthDigit(n int) int {digit : 1start : 1count : 9for n > count {n - countdigitstart start …

这个发表在 Nature Genetics的水稻全基因组关联数据库 RHRD,很赞!!!

历经半个世纪的发展&#xff0c;杂交水稻育种取得了巨大的成就&#xff0c;培育出了大量的高产、优质、适应环境变化的品系。本数据库是一个综合性的杂交水稻数据库&#xff08;http://ricehybridresource.cemps.ac.cn/#/&#xff09;&#xff0c;涵盖了从1976年至2017年间发布…

【Unity】简单的深度虚化shader

【Unity】简单的深度虚化shader 实现效果 可以用于对地图场景边界的白模处理 实现方法 1.关键方法 UnityObjectToClipPos&#xff1a;将物体坐标转换为屏幕坐标 LinearEyeDepth&#xff1a;将屏幕坐标中的z值转换为实际的深度值 saturate&#xff1a;将值规范到0~1之间&am…

Java 消息策略的实现 - Kafak 是怎么设计的

这个也是开放讨论题&#xff0c;主要讨论下 Kafka 在消息中是如何进行实现的。 1_cCyPNzf95ygMFUgsrleHtw976506 21.4 KB 总结 这个题目的开发性太强了。 Kafka 可以用的地方非常多&#xff0c;我经历过的项目有 Kafka 用在消息处理策略上的。这个主要是 IoT 项目&#xff0c…

three.js中的3D模型分层显示(分类型显示);使用dat.gui控制three.js中的3D模型分层显示;dat.gui调用一次但是渲染了多个

效果如上&#xff0c;就是可以通过dat.gui控制3D模型中仅仅显示管线或者是仅仅显示除了管线之外的模型。 1.在模型导入的时候就按照类型&#xff08;分层的类别标识&#xff09; 区别开&#xff08;我这里是按照是否是管线&#xff09; 这里是new THREE.Object3D();必须的否则…

Python基础学习笔记3

深度学习实践 深度学习离不开编程 深度学习离不开数学分析&#xff08;高等数学&#xff09;、线性代数、概率论等知识&#xff0c;更离不开以编程为核心的动手实践。 Python编程语言 无论是在机器学习还是深度学习中&#xff0c;Python已经成为主导性的编程语言。而且&…

OJ练习第178题——收集树中金币

收集树中金币 力扣链接&#xff1a;2603. 收集树中金币 题目描述 给你一个 n 个节点的无向无根树&#xff0c;节点编号从 0 到 n - 1 。给你整数 n 和一个长度为 n - 1 的二维整数数组 edges &#xff0c;其中 edges[i] [ai, bi] 表示树中节点 ai 和 bi 之间有一条边。再给…