飞桨首创 FlashMask :加速大模型灵活注意力掩码计算,长序列训练的利器

news2024/11/27 0:22:40

在 Transformer 类大模型训练任务中,注意力掩码(Attention Mask)一方面带来了大量的冗余计算,另一方面因其 O ( N 2 ) O(N^2) O(N2)巨大的存储占用导致难以实现长序列场景的高效训练(其中 N N N为序列长度)。虽然业界已有 FlashAttention 等针对特定注意力掩码的计算加速方法,但其支持的注意力掩码模式有限,难以满足大模型训练任务对灵活注意力掩码的需求。为了解决上述问题,飞桨独创 FlashMask 技术,提出了列式稀疏的注意力掩码表示方法,支持灵活多样的注意力掩码模式,使得存储复杂度从 O ( N 2 ) O(N^2) O(N2)降低至 O ( N ) O(N) O(N),并在此基础上实现了高效的算子 Kernel,极致加速大模型训练效率,尤其是长序列场景下的训练效率。

我们在NVIDIA A100 (80G) GPU上对 FlashMask 在大语言模型微调和对齐训练中的表现进行了评估,包括 SFT、LoRA、DPO 和 RM。与现有的 FlashAttention 稠密掩码方法相比,FlashMask 在端到端训练速度上实现了显著提升,速度提高幅度在 1.65 倍到 3.22 倍之间。此外,我们还评估了其 Kernel 层次上的性能。FlashMask 在理论最大浮点运算次数上达到了37.8%到62.3%,在 Kernel 每秒浮点运算次数(TFLOPs/s)方面,其性能超过 FlexAttention,提升幅度为 12.1% 到 60.7%。

arXiv 论文: https://arxiv.org/abs/2410.01359

PaddlePaddle 官方文档: https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/nn/functional/flashmask_attention_en.html

PaddleNLP 开源集成:https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/docs/flashmask.md

星河社区快速体验:https://aistudio.baidu.com/projectdetail/8459413

一、大语言模型的挑战

随着人工智能技术的迅猛发展,以 Transformer 为代表的大模型在自然语言处理、计算机视觉和多模态应用中展现出了非凡的能力。在这些大模型中,注意力(Attention)机制是一个关键环节。为了在大模型训练任务中确定哪些 Query-Key token 之间需要进行有效的 Attention 计算,业界通常使用注意力掩码(Attention Mask)。然而,目前的注意力掩码通常采用二维稠密矩阵表示,这导致了一些问题。一方面,这种表示方法引入了大量冗余计算,因为许多无效 token 的 Attention 仍需计算;另一方面,这种掩码的空间复杂度为 O ( N 2 ) O(N^2) O(N2)(其中 N N N为序列长度),在长序列的训练场景中会造成巨大的存储压力,难以进行高效训练。为了解决这些问题,业界已经提出了一些方案,如 Memory Efficient Attention (MEA) [1] 和 FlashAttention [2]。然而,这些方案支持的注意力掩码类型较为有限。正如图 1 所示,FlashAttention 只能支持如纯因果掩码(Causal)、滑动窗口掩码(Sliding Window)、因果文档掩码(Causal Document Mask)和文档掩码(Document Mask)等几种固定形式的掩码。然而,实际训练任务中使用的注意力掩码形式往往丰富多变,当前技术难以满足大模型不同训练任务对注意力掩码灵活性的要求。
图1 常见的注意力掩码类型
图1 常见的注意力掩码类型

二、FlashMask 的创新:列式稀疏掩码表示方法与高效计算

1. 关键洞察

FlashMask 的核心发现是,在大模型常见的注意力掩码模式中,Query-Key token 的掩码模式具有一定的连续性。具体而言,对于每一个 Key token,无效注意力计算的 Query token 是相邻排列的。也就是说,在图 1 中二维掩码矩阵中,Query token 作用在每一列的 Key token 的灰色部分沿列方向连续分布。基于这一洞察,FlashMask 巧妙地将二维稠密掩码矩阵转换为一维的行索引区间,从而实现更为紧凑的表示形式,并显著降低了存储需求。我们可以公式化表示为:

M j = [ s t a r t j , e n d j ) , ∀ j ∈ { 1 , … , N } M_{j} = [start_j, end_j), \quad \forall j \in \{1, \ldots, N\} Mj=[startj,endj),j{1,,N}

其中 N N N为 Key 的序列长度, M j M_j Mj为二维的稠密掩码矩阵的第 j j j列, [ s t a r t j , e n d j ) [start_j, end_j) [startj,endj)为连续的行索引区间,表示 s t a r t j start_j startj e n d j − 1 end_{j} - 1 endj1的连续 Query token 是被 mask 掉,置为无效 Attention 计算。

2. 注意力掩码的列式稀疏掩码表示方法

为了高效处理因果和双向注意力场景中的复杂掩码模式,FlashMask 提出了一种新颖的列式稀疏表示方法。以对角线为区分,它使用四个一维向量来表示掩码:

  • 下三角起始行索引(Lower Triangular Start,简称 L T S LTS LTS
  • 下三角结束行索引(Lower Triangular End,简称 L T E LTE LTE
  • 上三角起始行索引(Upper Triangular Start,简称 U T S UTS UTS
  • 上三角结束行索引(Upper Triangular End,简称 U T E UTE UTE
    其中下三角被 mask 掉的行索引区间使用 [ 𝐿 𝑇 𝑆 , 𝐿 𝑇 𝐸 ) [𝐿𝑇𝑆, 𝐿𝑇𝐸) [LTS,LTE)表示,上三角被 mask 掉的行索引区间使用 [ 𝑈 𝑇 𝑆 , 𝑈 𝑇 𝐸 ) [𝑈𝑇𝑆, 𝑈𝑇𝐸) [UTS,UTE) 表示。 如图2所示,我们展示了16个Query token 和16个Key token 做 Attention 计算时较为复杂的二维稠密因果注意力的掩码矩阵,灰色单元格是 mask 区域。
    图2 较为复杂的二维稠密因果注意力的掩码矩阵示意图
    图2 较为复杂的二维稠密因果注意力的掩码矩阵示意图

可以通过 [ 𝐿 𝑇 𝑆 , 𝐿 𝑇 𝐸 ) [𝐿𝑇𝑆, 𝐿𝑇𝐸) [LTS,LTE)两个向量进行表达,如下所示:
在这里插入图片描述

以第 0 列为例,开始 mask 的行为 13,结束 mask 的行为 15(开区间),表示位置为 13 和 14 的 Query token 不与位置为 0 的 Key token 做有效 Attention 计算。

更多的例子参考图 3,FlashMask 使用列式稀疏掩码表示方法,表达了图 1 中所有的注意力掩码模式。其中 - 的空缺表示在不同的场景下有不同的默认值, L T S LTS LTS U T S UTS UTS中的默认值是 0,表示 mask 区域默认从第 0 行开始, L T E LTE LTE U T E UTE UTE中的默认值是 Query 的序列长度,表示 mask 区域默认结束于最后一行。
图3 使用 FlashMask 的列式稀疏掩码表示方法表示图1的注意力掩码模式
图3 使用 FlashMask 的列式稀疏掩码表示方法表示图1的注意力掩码模式

3. 扩展 FlashAttention 支持复杂掩码

FlashMask 将列式掩码表示方法集成到 FlashAttention-2 算法中,增强了其对注意力掩码的支持能力。在 FlashAttention Kernel 的分块计算基础上,FlashMask 利用上述的 L T S LTS LTS L T E LTE LTE U T S UTS UTS U T E UTE UTE 等掩码向量,来判断当前分块的掩码类型:

  • 完全掩码块:此类块的所有元素均被掩码,计算时可直接跳过。
  • 部分掩码块:此类块仅部分元素被掩码,因此需要对该块进行逐元素的掩码处理。
  • 未掩码块:此类块中的所有元素均未被掩码,可以简化计算过程,无需额外的掩码操作。
    通过这种分类处理,FlashMask 显著提升了计算效率:完全掩码块的计算被直接跳过,未掩码块的计算得到简化,仅对部分掩码块执行必要的掩码操作,如图4所示。
    图4 FlashMask 计算过程示意图图4 FlashMask 计算过程示意图

算法1详细描述了 FlashMask 扩展 FlashAttention-2 的前向计算过程,其中浅蓝色阴影部分表示 FlashMask 新增的计算步骤 [3]。
算法1 FlashMask 的前向计算伪代码
算法1 FlashMask 的前向计算伪代码

4. 效率提升与精度保证

FlashMask 充分利用了注意力掩码中的稀疏性,通过跳过完全掩码块的计算,减少了计算开销,同时不改变算法的精度。与使用稠密掩码矩阵的注意力计算保持比特级别的数值等效性,确保了精度无损。

三、FlashMask 的优势:速度与存储的双重提升

1. 端到端训练吞吐量提升

在 Llama-2 7B、13B、70B 等模型规模下,针对 SFT、LoRA、DPO、RM 四种下游训练场景和不同序列长度的实验表明,FlashMask 在各个模型规模和序列长度下均实现了端到端的加速和存储效率的提升。相比现有的基于稠密掩码矩阵的计算方法,FlashMask 实现了 1.65 倍至 3.22 倍的吞吐量提升,并支持更长的序列长度。
图5 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,3 个 Llama2 模型规模,在不同序列长度下的端到端训练吞吐量
图5 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,3 个 Llama2 模型规模,在不同序列长度下的端到端训练吞吐量
图6 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,3 个 Llama2 模型规模,不同序列长度下的端到端训练峰值显存消耗
图6 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,3 个 Llama2 模型规模,不同序列长度下的端到端训练峰值显存消耗
表2 在 Llama2 7B 模型上 FlashMask 对比 FlashAttention (Causal=True) 的显存消耗,单位(GB)
表2 在 Llama2 7B 模型上 FlashMask 对比 FlashAttention (Causal=True) 的显存消耗,单位(GB)

2. 端到端训练收敛验证

在 Llama 3.1 模型上的实验验证了 FlashMask 对收敛精度没有影响。作为一种精确的算法,通过控制计算过程的随机性(如 FlashAttention 反向 Query 梯度计算使用 atomicAdd 操作),FlashMask 可以与使用稠密掩码的 FlashAttention 在比特级别精确对齐。
图7 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,Llama3.1 8B 模型端到端训练 Loss 对比
图7 在四个下游训练任务(SFT、LoRA、DPO 和 RM)中,Llama3.1 8B 模型端到端训练 Loss 对比

3. 稀疏度与 Kernel 计算时延的线性关系

FlashMask 利用注意力掩码的块稀疏性,跳过完全掩码块的计算,将计算复杂度降低到 O ( ( 1 − ρ ) T r T c ) O((1 - ρ)T_rT_c) O((1ρ)TrTc),其中 ρ ρ ρ 表示块稀疏性。为了验证这一关系,FlashMask 进行了多组实验,测试了三种不同的掩码类型(因果文档掩码、共享问题掩码和文档掩码),并使用不同稀疏度的数据。实验结果(如图 5 所示)表明,Kernel 执行延迟与稀疏性之间呈线性关系,意味着随着稀疏性的增加,FlashMask 的计算速度进一步提升。
图8 不同块稀疏度下的 Kernel 计算时延
图8 不同块稀疏度下的 Kernel 计算时延

4. Kernel 性能对比

关注到近期 PyTorch 推出了 FlexAttention [4](使用编译器技术支持 Attention Mask),FlashMask 与之在 Kernel 级别进行了对比。在各种常见的注意力掩码模式下,FlashMask 展现了更高的计算效率。在 TFLOPs/s 指标上,FlashMask 比 FlexAttention 高出 12.1% 至 60.7%,在 A100 GPU 上实现了 37.8% 至 62.3% 的理论峰值计算性能。
 图9 在 A100-SXM 80G GPU 上的 Kernel 前向和反向速度对比。FlexAttention 使用 PyTorch 2.6.0.dev20240920+cu124
图9 在 A100-SXM 80G GPU 上的 Kernel 前向和反向速度对比。FlexAttention 使用 PyTorch 2.6.0.dev20240920+cu124

四、FlashMask 的应用:赋能大语言模型

FlashMask 的创新和优势为 Transformer 类大模型的注意力机制训练加速开辟了新的可能,可广泛应用于各种任务,并支持超长序列高效训练。

1. 可广泛应用于大语言模型的下游训练加速

FlashMask 可以应用于大语言模型的下游任务训练,例如 SFT、LoRA、DPO、RM 等。特别是在 DPO 和 RM 的训练中,其数据由问题和回答对组成,训练时多个答案可以共享一个问题,从而大幅减少对问题 token 的冗余计算。

2. 支持单向/双向混合注意力掩码模式训练

FlashMask 支持多种注意力模式,包括因果掩码(单向注意力)和文档掩码(双向注意力),因此能够灵活地应用于需要混合注意力的场景。例如:

  • 全局 + 滑动窗口掩码:这种掩码结合了全局注意力和滑动窗口注意力,既能捕捉全局上下文信息,又能关注局部细节。FlashMask 能高效处理这种混合掩码,提升模型性能。
  • 前缀语言模型:在生成文本时,前缀部分需要关注所有的 token,而其他部分使用因果掩码(如 T5 模型的预训练)。FlashMask 可以同时支持这两种注意力模式,提高前缀语言模型的训练和推理效率。

3. 支持多模态图文数据的混合多分辨率训练

在多模态数据处理中,不同模态的数据可能具有不同的分辨率。虽然文中未明确提及 FlashMask 在多模态和多分辨率训练中的应用,但 FlashMask 可以通过不同的注意力模式和掩码策略,有效处理这些具有不同分辨率的数据。针对长序列处理能力的优化,使得FlashMask能够帮助模型更好地学习不同模态数据之间的关联。例如,在图文匹配任务中,FlashMask 可以帮助模型更有效地对齐图像和文本中的关键信息。

FlashMask 的开源代码已在 PaddlePaddle 和 PaddleNLP 平台发布,支持超过千亿参数的模型以及超过 128K tokens 的上下文长度。我们相信,FlashMask 将成为推动大语言模型发展的重要力量,为算法研究人员提供更广阔的注意力掩码创新与研究空间。

结语

飞桨首创 FlashMask 技术,通过创新的列式稀疏掩码表示方法和高效的Kernel实现方式,解决了传统注意力掩码方法冗余计算和存储占用过高等难题,助力大模型训练加速,尤其是基于长文场景的训练加速。未来,飞桨将持续研发高效的大模型训练加速技术,不断推进大模型技术更新迭代。

参考文献

[1] Self-attention Does Not Need O(n^2) Memory. https://arxiv.org/abs/2112.05682

[2] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691

[3] FlashMask: Efficient and Rich Mask Extension of FlashAttention. https://arxiv.org/abs/2410.01359

[4] FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention. https://pytorch.org/blog/flexattention/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2231914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高电压、真差分信号采集的SAR ADC驱动电路设计

1 简介 本设计展示了一种用于驱动高压 SAR ADC 以实现高压全差分信号数据采集的解决方案。该差分信号可能具有广泛的共模电压范围,具体取决于放大器的电源和输入信号振幅。使用一个通用高压精密放大器来执行差分到单端信号转换,并以最高吞吐量驱动 10V 的…

在VS Code中操作MySQL数据库

【基础篇】 【小白专用24.5.26 已验证】VSCode下载和安装与配置PHP开发环境(详细版)_vscode php-CSDN博客 ~~~~~~~~~~~~~~~~~~~~~~~~~ 在VS Code中下载插件 Prettier SQL VSCode 和 MySQL : 随后在VS Code中点击Database图标 在连接界面输入MySQL数据库…

Java唯一键实现方案

数据唯一性 1、生成UUID1.1 代码中实现1.2 数据库中实现优点缺点 2、数据库递增主键优点 3、数据库递增序列3.1 创建序列3.2 使用序列优点缺点 在Java项目开发中,对数据的唯一性要求,业务数据入库的时候保持单表只有一条记录,因此对记录中要求…

【MySQL】可重复读级别下基于Next Key Lock解决幻读

昨天读到了一篇文章[1],里面讲,面试官说mysql的可重复读级别下有解决幻读的方式,最后公布了答案,是在sql后面加for update。这么说倒是没错,但是这种问法给我一种奇怪的感觉,因为for update无论在哪个隔离级…

vscode通过.vscode/launch.json 内置php服务启动thinkphp 应用后无法加载路由解决方法

我们在使用vscode的 .vscode/launch.json Launch built-in server and debug 启动thinkphp应用后默认是未加载thinkphp的路由文件的, 这个就导致了,某些thinkphp的一些url路由无法访问的情况, 如http://0.0.0.0:8000/api/auth.admin/info这…

【canal 中间件】canal 实时监听 binlog

文章目录 一、安装 MySQL1.1 启动 mysql 服务器1.2 开启 Binlog 写入功能1.2.1创建 binlog 配置文件1.2.2 修改配置文件权限1.2.3 挂载配置文件1.2.4 检测 binlog 配置是否成功 1.3 创建账户并授权 二、安装 canal2.1 安装 canal-admin(可选)2.1.1 启动 canal-admin 容器2.1.2 …

在阿里云快速启动Umami玩转网页分析

阿里云计算巢提供了Umami快速部署能力,使用者不需要自己下载代码,不需要自己安装复杂的依赖,不需要了解底层技术,只需要在控制台图形界面点击几下鼠标就可以快速部署并启动Umami,非技术同学也能轻松搞定。 什么是Umam…

【模型学习之路】手写+分析bert

手写分析bert 目录 前言 架构 embeddings Bertmodel 预训练任务 MLM NSP Bert 后话 netron可视化 code2flow可视化 fine tuning 前言 Attention is all you need! 读本文前,建议至少看懂【模型学习之路】手写分析Transformer-CSDN博客。 毕竟Bert是tr…

stm32移植LVGL(LVGL 8.2.0)

目录 1.下载LVGL源码 2.修改LVGL文件夹 (1)文件夹 examples 改动 (2)文件夹 demos 改动 3.最终LVGL文件夹内容 4.软件Keil配置、添加头文件 5.程序配置 6.其它配置参考链接 1.下载LVGL源码 LVGL源码地址:https://github.com/lvgl/lvgl 2.修改LVGL文件夹…

海南华志亿星电子商务有限公司电商服务的璀璨新星

在这个全民直播、短视频带货风起云涌的时代,抖音电商以其独特的魅力成为了众多商家竞相追逐的蓝海市场。而在这片波澜壮阔的商海中,海南华志亿星电子商务有限公司犹如一颗璀璨的新星,以其专业的服务、创新的策略,为无数商家照亮了…

动手学深度学习66 使用注意力机制的seq2seq

1. 使用注意力机制的seq2seq key value等价 是一个东西 第i个词rnn的输出 根据加权的不同,解码器前面用编码器前面的输出,到后面用后面的输出。 2. code 核心代码: context 怎么算 embedding没变,Decoder加了attention层 class Seq2SeqAt…

高校大数据实训平台介绍

高校大数据实验室架构 具体实训平台介绍 编程实训平台 1、大数据开发实训平台 大数据开发实训平台是面向实训课和课后训练的编程实训平台,平台底层基于Docker技术,采用容器云部署方案,预装大数据相关课程教学所需的实训环境…

【快速上手】pyspark 集群环境下的搭建(Yarn模式)

目录 前言: 一、安装步骤 安装前准备 1.第一步:安装python 2.第二步:在bigdata01上安装spark 3.第三步:同步bigdata01中的spark到bigdata02和03上 二、启动 三、可打开yarn界面查看任务 前言: 上一篇介绍的是…

【ARM Linux 系统稳定性分析入门及渐进 1.2 -- Crash 工具依赖内容】

文章目录 Prerequisites1. 内核对象文件2. 内存镜像3. 平台处理器类型4. Linux 内核版本 Prerequisites crash 工具需要依赖下面的内容: 1. 内核对象文件 vmlinux 文件:需要一个 vmlinux 内核对象文件,在本文中称为命名列表(na…

【Canal 中间件】Canal 实现 MySQL 增量数据的异步缓存更新

文章目录 一、安装 MySQL1.1 启动 mysql 服务器1.2 开启 Binlog 写入功能1.2.1创建 binlog 配置文件1.2.2 修改配置文件权限1.2.3 挂载配置文件1.2.4 检测 binlog 配置是否成功 1.3 创建账户并授权 二、安装 RocketMQ2.1 创建容器共享网络2.2 启动 NameServer2.3 启动 Broker2.…

Spring Boot2.x教程:(十)从Field injection is not recommended谈谈依赖注入

从Field injection is not recommended谈谈依赖注入 1、问题引入2、依赖注入的三种方式2.1、字段注入(Field Injection)2.2、构造器注入(Constructor Injection)2.3、setter注入(Setter Injection) 3、为什…

解决 ClickHouse 高可用集群中 VRID 冲突问题:基于 chproxy 和 keepalived 的实践分析

Part1背景描述 近期,我们部署了两套 ClickHouse 生产集群,分别位于同城的两个数据中心。这两套集群的数据保持一致,以便在一个数据中心发生故障时,能够迅速切换应用至另一个数据中心的 ClickHouse 实例,确保服务连续性…

【Android】View的事件分发机制

文章目录 分发顺序ActivityViewGroupView 协作方法整体流程注意 Activity事件分发ViewGroup事件分发View点击事件总结 分发顺序 Activity->ViewGroup->View Activity 分发事件:Activity 通过 dispatchTouchEvent 方法分发事件,首先尝试将事件传递…

java项目之微服务在线教育系统设计与实现(springcloud)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的闲一品交易平台。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 微服务在线教育系统设计与…

ChatGPT:真如吹的那般神乎其神吗?

ChatGPT的确是个神奇的东西。短短600多天,就已成全球访问量最大的网站之一。 ChatGPT已经出现在与这些大佬顶级大佬Google、Youtube、X.com、Baidu、Yahoo、amazon、Tiktok一起。 当然ChatGPT很优秀,这没有疑问,主要问题还是对度的把握上。…