基于Ascend C的FlashAttention算子性能优化最佳实践

news2025/2/27 19:55:52

 LLM的Attention部分处理给计算系统带来巨大的计算和访存压力。业界先后出现FlashAttention、FlashAttention2等算法,通过计算等价和切分有效降低HBM数据访问量。 
昇腾异构计算架构CANN针对昇腾AI处理器的片上内存和缓存大小,以及数据搬运通路,基于Ascend C算子编程语言优化实现FlashAttention融合算子,充分利用片上缓存,提升Attention处理性能。根据实测,在一些典型场景中CANN的FlashAttention算子相比小算子取得了5倍以上的性能提升,开发者可直接调用相关算子API接口使能大模型极致性能优化。
本文针对FlashAttention反向融合算子的性能优化方案展开介绍,并通过优化实现了典型场景4倍左右的性能提升,希望对开发者优化此类基于Ascend C开发的融合算子带来启发。 

FlashAttention算法简介 

在主流大模型网络模型中,大量使用典型的Multi-Head Attention结构,带来了巨大的计算和内存开销。其运行过程中,矩阵乘和softmax结果存放在片上内存会带来巨大的内存消耗,访存性能严重下降,甚至会导致模型无法正常运行,同时网络中的矩阵和向量计算串行执行,也会导致硬件算力发挥受限。

斯坦福的Tri DAO提出了FlashAttention融合算子,其原理是对attention处理过程进行切分和计算等价,使得attention的多个步骤在一个算子中完成,并且通过多重循环、每次处理一小部分数据,以近似流式的方式访问片上内存,减少了片上内存访问的总数据量,并能够将计算和数据搬运更好的重叠隐藏。

 注意力的正向计算公式为:

为方便表达,以变量S和P表示计算公式: 

注意力的反向计算公式为: 

昇腾CANN基于Ascend C编程语言实现了FlashAttention正反向融合算子,其中反向算子计算流程可参考下图所示: 

本案例对FlashAttention反向算子进行了性能优化,主要涉及的优化手段包括tiling基本块大小调整,核间负载均衡,CV流水并行,MTE2流水优化以及FixPipe流水优化等,并在Atlas A2训练系列产品/Atlas 800I A2推理产品 验证平台下收益4倍左右的性能提升。下面以如下两个输入场景为例,介绍整个优化过程。

  • 第一个场景的输入维度信息为:B=1,N1=12,N2=12,S1=6144,S2=6144,D=128,并且为casual场景,casual场景即atten_mask的形状为下三角。
  • 第二个场景的输入维度信息为:B=24,N1=5,N2=5,S1=9216,S2=9216,D=64,不带atten_mask和drop_mask输入。

tiling基本块调整 

 根据以往优化的经验,循环间可能存在一些不必要的头开销,循环越多性能可能越差;满足UB最大空间限制的情况下,UB切分的基本块越大,循环越少,算子中通过InitBuffer接口分配UB buffer大小。

pipe->InitBuffer(ubBuffer, 120 * 1024);   
pipe->InitBuffer(tmpBuffer, 30 * 1024);   
pipe->InitBuffer(vecClc3, 8 * 1024);

 如上代码所示,InitBuffer接口的第二个参数表示buffer占用的大小,所有buffer大小的和即为占用的总空间。这里120 * 1024 + 30 * 1024 + 8 * 1024 = 158KB < UB Size,没有充分利用UB空间。
接下来试图通过调整tiling基本块进行性能优化,在满足UB空间大小够用的情况下,tiling基本块切分的越大越好。下图为优化前按照(64, 128)切分计算,总共需要循环计算32次:

考虑到UB空间没有用满,基本块调整到(128, 128),如下图优化后只需循环计算16次,切分后算子性能提升一倍:

 CV流水并行

 从流水图可以看到,可以看出两侧的流水都存在大段的空隙(图中绿色为vector部分流水,橙色为cube侧流水),CV之间流水很大程度上未并行,需要考虑CV流水优化。

 由于FAG算子中cube计算比vector计算快且存在依赖性,同时为了减少CV之间的通信次数,通过缓存机制实现让matmul提前计算多块,这里的缓存机制指的是将mm一次性计算多个基本块缓存到GM上。如下代码中,SetTail设置的SingleM和SingleN大小为BaseM,BaseN的倍数,即matmul一次发起多个基本块的计算,实现matmul结果的缓存,vector侧分多次取matmul的结果。

mm3.SetTail(s2CvExtend, -1, preS1Extend);   
mm3.SetTensorA(mulWorkSpaceGm[pingpongIdx * coreNum * cubeBaseMN + cBlockIdx * cubeBaseMN], true);  
mm3.SetTensorB(queryGm[mm2aTensorOffsetCv]);   
mm3.template IterateAll<false>(dkWorkSpaceGm[bTensorOffsetCv], true);

下图是实现mm1、mm2和mm3缓存的流水图,绿色的vector流水与橙色的cube流水均变得更密集,并行度提高,cv的间隔减小,提升了算子性能:  

基于缓存mm1/mm2/mm3的优化后,在本轮Vector等Cube流水的间隔,插入下一轮循环的Vector计算,这样使Vector流水与Cube流水之间的并行度更高,反映到流水图中为Vector计算更密集: 

相关优化点实现伪代码如下所示: 

 mm1计算;
dropout();
Sub();
dropout(); // 下一轮循环的Vector计算 
Sub();  // 下一轮循环的Vector计算 
mm2计算;
Softmax();
AttenMask();
...

 核间负载均衡

对于上述场景一,casual场景下可能存在核间分布不均匀的情况,如下图经过atten mask掩码后,红色部分是算子需要计算的部分,绿色无需计算;如果不按照基本块的个数来分核,按照第一根轴的大小8(行)来分核,假设平均分到9个核上,每个核做ceil(8 / 9) = 1行,则第一个核只需做1个基本块,但是第8个核需要做8个基本块的计算,出现严重的负载不均衡: 

因此需要考虑将红色块均匀分到多个核上计算,尽量实现每个核的计算量均匀,负载均衡。优化后,红色块总共36个基本块,均分到每个核上,每个核的计算量为4块,性能提升一倍。

 FixPipe流水优化

通过对场景一的Profilling数据进行分析可以看到,aic_fixpipe_ratio占比极高,占比高达81%,出现了很严重的bound: 

同时,CAModel工具打印发现存在很多异常的128B搬运,经过代码排查,发现workspace地址未512B对齐。代码实现中使用SetGlobalBuffer接口设置workspace的起始地址,如果起始地址不是按照512B对齐,搬运效率会很低,可以强制地址512B对齐来避免这个情况,下面代码中ADDR_ALIGN_SIZE即为512:

// init workspace address   
syncGlobal.SetGlobalBuffer((__gm__ int32_t*)workspace);   
uint64_t workspaceOffsets = SYNC_GLOBAL_WORKSPACE_SIZE;   
dqWorkSpaceGm.SetGlobalBuffer((__gm__ float*)workspace + workspaceOffsets / sizeof(T2));   
workspaceOffsets = (workspaceOffsets + qPostBlockTotal * sizeof(float) + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;  dkWorkSpaceGm.SetGlobalBuffer((__gm__ float*)workspace + workspaceOffsets / sizeof(T2));   
workspaceOffsets = (workspaceOffsets + kvPostBlockTotal * sizeof(float) + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;  dvWorkSpaceGm.SetGlobalBuffer((__gm__ float*)workspace + workspaceOffsets / sizeof(T2));   
workspaceOffsets = (workspaceOffsets + kvPostBlockTotal * sizeof(float) + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;  
// matmul1 and matmul2 workspace size   
matmulWorkspaceSize = cubeBaseMN * sizeof(float);  
mm1WorkspaceGm.SetGlobalBuffer((__gm__ T2*)(workspace + workspaceOffsets + cBlockIdx * matmulWorkspaceSize));  mm2WorkspaceGm.SetGlobalBuffer((__gm__ T2*)(workspace + workspaceOffsets + coreNum * matmulWorkspaceSize + cBlockIdx * matmulWorkspaceSize));   // drop workspace offset   
workspaceOffsets = (workspaceOffsets + coreNum * cubeBaseMN * sizeof(float) * INPUT_NUMS + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;   
dropWorkSpaceGm.SetGlobalBuffer((__gm__ T1*)workspace + workspaceOffsets / sizeof(T1));    
// mul workspace offset   
workspaceOffsets = (workspaceOffsets + coreNum * cubeBaseMN * sizeof(half) * 2 + ADDR_ALIGN_SIZE) / ADDR_ALIGN_SIZE * ADDR_ALIGN_SIZE;   
mulWorkSpaceGm.SetGlobalBuffer((__gm__ T1*)workspace + workspaceOffsets / sizeof(T1));

 修改代码,workspace地址经过512B对齐后,fixpipe时间减半:

 MTE2流水优化

 从场景二采集的profiling和打点图来看,mte2_ratio占比高,cube MTE2出现了明显bound,且部分MTE2搬运时间异常。

 

将输入数据排布格式从BSH更改为BNSD后,数据搬运连续,不需要跳地址读取数据,搬运效率提升一倍,部分异常搬运时长降低了一半。 

 优化方案性能收益

  • 调整tiling基本块:理论评估vector切块越大,计算和搬运循环次数越少,同时能够充分利用搬运带宽和vector算力。基本块大小从(64, 128)增大到(128, 128)后,性能提升一倍,实测与理论分析一致。
  • CV流水并行:CV流水掩盖的时间即为提升的性能,符合预期的收益。
  • 核间负载均衡:优化前负载最多的核的计算量减少的倍数,即为预期提升的性能;案例中优化前负载最多的核的计算量大小为8块,优化后为4块,实际性能提升一倍,符合预期的收益。
  • FixPipe优化:从Profiling数据看出FixPipe占比0.8,优化后占比0.55,实测算子性能提升45%,与理论分析一致。
  • MTE2优化:从Profiling数据看出MTE2占比0.52,优化后占比减少一半,实测算子性能提升30%,与理论分析一致。

 开发者在对基于Ascend C开发的融合算子进行性能优化时,可参考此案例中的优化思路。

更多学习资源 

 了解更多Ascend C算子性能优化手段和实践案例,请访问:昇腾Ascend C-入门课程-学习资源-算子文档-昇腾社区

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1823469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ozon如何上架产品,ozon平台怎么上架产品

在电子商务领域&#xff0c;产品上架是商家成功运营的关键步骤之一。对于正在或计划进军俄罗斯市场的卖家来说&#xff0c;了解如何在Ozon平台高效上架产品至关重要。接下来讲解下ozon如何上架产品&#xff0c;ozon平台怎么上架产品&#xff01; 产品上架工具&#xff1a;D.DDq…

基础-02-数据通信基础

文章目录 1.信道特征1.1 数据通信概念1.2 信道特性-信道带宽W1.3 信道特性-码元和码元速率1.4 信道特性-奈奎斯特定理1.5 信道特性-香农定理1.6 带宽/码元速率/数据速率关系梳理1.7 练习题 2.信道延迟2.1 信道延迟概念2.2 信道延迟计算2.3 练习题 3. 传输介质3.1 传输介质概念3…

电脑屏幕监控软件有哪些?2025年监控软件排行榜

电脑屏幕监控软件有哪些&#xff1f;2025年监控软件排行榜 虽然现在还是2024年&#xff0c;但是有一些被广泛讨论和推荐的电脑屏幕监控软件&#xff0c;它们将在2025年异军突起&#xff0c;成为行业的引领者。 1.安企神软件&#xff1a; 功能全面的电脑屏幕监控软件&#xf…

短视频矩阵系统/源码搭建---拆解热门视频功能开发上线

短视频矩阵系统/源码搭建 一、短视频矩阵系统源码开发需要用到以下技术&#xff1a; 1.前端技术&#xff1a;HTML、CSS、JavaScript、Vue.js等前端框架。 2.后端技术&#xff1a;Java、Python、PHP等后端语言及相关框架&#xff0c;如Spring Boot、Django、Laravel等。 3.移…

火车头采集怎么使用GPT等AI原创文章

火车头采集官方并没有GPT、百度文心一言AI、阿里通义千问AI、Kimi大模型等AI功能&#xff0c;但支持接入插件&#xff0c;可以编写相应人工智能AI原创文章插件&#xff08;火车头采集支持PHP和c#这2种语言的插件编写&#xff09;&#xff0c;或者导入第三方封装好的GPT等AI原创…

文件操作学不懂,小代老师带你深入理解文件操作(下卷)

文件操作学不懂&#xff0c;小代老师带你深入理解文件操作下卷 6. ⽂件的随机读写6.1 fseek6.2 ftell6.3 rewind 7. ⽂件读取结束的判定7.1 被错误使⽤的 feof 8. ⽂件缓冲区 6. ⽂件的随机读写 6.1 fseek 根据⽂件指针的位置和偏移量来定位⽂件指针&#xff08;⽂件内容的光…

RV32F\RV32D指令集

RV32F\RV32D指令集 F扩展1、浮点控制状态寄存器2、指令类型F扩展 F扩展增加了32个浮点寄存器f0-f31,每个32位宽,以及一个浮点控制和状态寄存器fcsr,其中包含浮点单元的工作模式和异常状态。FLEN=32表示F单精度浮点扩展,大多数浮点指令对浮点寄存器中的值进行操作。浮点加载…

苹果AI入华探讨及Apple Intelligence体验分析

引言 近日&#xff0c;苹果在WWDC 2024上引起了广泛关注。尽管苹果在发布会上并未明确提到“AI”一词&#xff0c;但从其展示的众多新功能中可以看出&#xff0c;AI已深深嵌入到其产品中。那么&#xff0c;苹果AI何时能在中国落地&#xff1f;它的模型大小是多少&#xff1f;用…

appproxy 一个轻量级的VPN代理工具,支持HTTP, SOCKS5协议

appproxy 项目背景 在分析app的时候,偶尔需要抓包,尝试了目前比较常见的代理工具Drony Postern ProxyDroid 发现都有一个相同的问题,对于较新的Android系统不太友好,要么app列表显示不正常,或者界面过于复杂,往往设置之后经常会失效,偶然在play上发现一个比较新的代理工具,界…

沉睡而且“狡猾”的特工:大模型也可以是!

大模型技术论文不断&#xff0c;每个月总会新增上千篇。本专栏精选论文重点解读&#xff0c;主题还是围绕着行业实践和工程量产。若在某个环节出现卡点&#xff0c;可以回到大模型必备腔调或者LLM背后的基础模型新阅读。而最新科技&#xff08;Mamba,xLSTM,KAN&#xff09;则提…

JVC摄像机SD卡变成RAW的恢复方法

JVC小日本胜利公司&#xff0c;公司名字绕口且产品线极广&#xff0c;涉及汽车、影音、娱乐……&#xff0c;而JVC在摄像机产品方面也有涉及&#xff0c;不过市场上极为少见。下边我们来看下这个JVC摄像机MP4恢复案例。 故障存储: 32G存储卡 RAW文件系统 故障现象: 客户无…

通配符SSL证书的应用范围

首先带大家先来了解一下通配符SSL证书&#xff1a; 通配符证书又名泛域名证书&#xff0c;是一种SSL/TLS证书&#xff0c;用于保护多个域名或无限多个域名。是由域名字段中的通配符*表示。通配符证书最大的亮点在于可以通过绑定一个主域名&#xff0c;从而间接绑定无数的次级子…

性能测试常见的内存溢出问题: JVM 内存溢出如何调优?

针对java项目做性能测试的时候,很多同学都见过一个报错,就是OOM【Out Of MemoryError】;那出现这种报错就是项目发生了内存溢出的问题,这是比较严重的性能问题。所以,作为一个性能测试工程师,我们要能够分析JVM内存的问题以及理解其中的原理,才能更好的给JVM内存出现的性…

国网I6000请求,出现缺少SOAPAction头信息如何解决?

错误代码: Client.NoSOAPAction 这表示客户端请求中缺少SOAPAction头信息。 错误消息: no SOAPAction header! 这明确指出请求中没有包含SOAPAction头。 详细信息: hostname: 9f1957926889&#xff0c;这是服务器的主机名&#xff0c;不直接影响错误分析。 解决方案 添加SOAP…

案例学习-存量更新规划实施探索(武汉)

案例学习-存量更新规划实施探索&#xff08;武汉&#xff09; 武汉市在早期旧城更新实践中发现零散化的更新往往导致资源配置分散、城市建设破碎化等弊病&#xff0c;特别是由于过于强调项目自身“经济平衡”&#xff0c;在实施过程中也逐步暴露出住宅占比过大、强度偏高、公服…

【工具】新手如何正确使用Pycharm?

1. 什么是JetBrains Toolbox JetBrains Toolbox是一个管理工具&#xff0c;用于安装、更新和管理JetBrains开发工具的所有版本。它可以简化多个IDE的管理&#xff0c;并确保你总是使用最新版本的软件。 2. 安装JetBrains Toolbox 步骤1&#xff1a;下载Toolbox 访问JetBrai…

List 列表

文章目录 一、什么是 List 列表1.1 创建 List 列表的方式1.2 列表的新增函数方法1.3 列表的删除函数方法1.4 修改列表数据的方法1.5 列表的查询函数方法1.6 列表的排序和反序1.7 列表的复制 一、什么是 List 列表 List 列表&#xff1a;该数据类型定义的变量可以理解为是一个数…

用Python向Word文档添加页眉和页脚

用Python向Word文档添加页眉和页脚 添加页眉和页脚效果代码 添加页眉和页脚 在本文中&#xff0c;我们将用python向文档中添加页眉和页脚。 效果 添加前的文档&#xff1a; 添加页眉和页脚后&#xff1a; 代码 from docx import Documentdef add_header_footer(doc_path…

C#——静态成员和非静态成员详情

静态成员和非静态成员 调用: 静态属性(static) : 类名.属性名调用 非静态属性(没static) : 1.先创建对象 2.对象.属性 特点: 静态方法里面只能访问静态成员 非经态方法中可以访问所有的属性 static数据成员在类的内部声明&#xff0c;但只能在类的外部定义&#xff0c;…

时序预测 | MATLAB实现TCN-Attention自注意力机制结合时间卷积神经网络时间序列预测

时序预测 | MATLAB实现TCN-Attention自注意力机制结合时间卷积神经网络时间序列预测 目录 时序预测 | MATLAB实现TCN-Attention自注意力机制结合时间卷积神经网络时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.MATLAB实现TCN-Attention自注意力机制结合时…