MiniMax-01中Lightning Attention的由来(线性注意力进化史)

news2025/3/3 21:24:29

目录

  • 引言
  • 原始注意力
  • 线性注意力
  • 因果模型存在的问题
  • 累加求和操作的限制
  • Lightning Attention
    • Lightning Attention-1
    • Lightning Attention-2
  • 备注

引言

MiniMax-01: Scaling Foundation Models with Lightning Attention表明自己是第一个将线性注意力应用到如此大规模的模型,他所使用的核心技术就是Lightning Attention。

那为什么线性注意力20年在文章Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention中就提出了,现在才出第一个线性注意力的大模型呢?

本文就从线性注意力机制入手,详细探讨其起源、存在的显著局限性,以及Lightning Attention的具体实现细节。

原始注意力

现在主流的有两类模型,一种是应用双向注意力的bert类模型,另一种是应用单向注意力的gpt类模型,他们所使用的注意力其实是有细微差别的。

  • 双向注意力(bert类),就是传统认知中标准的注意力

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKT)V

  • 单向注意力(因果模型,gpt类),只能看到当前和前面的token,所有要在softmax之前乘上一个掩码矩阵,M为单向掩码矩阵

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T ⊙ M d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T\odot M}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKTM)V

其中Q、K、V每个矩阵的维度都是[n, d],即[序列长度,隐层维度],此时 Q K T QK^T QKT的维度是[n, n],所以整体复杂度是 O ( n 2 d ) O(n^2d) O(n2d)。其中d是固定大小, n 2 n^2 n2随着序列长度平方增加,就主导了整体的复杂度。

线性注意力

原始注意力中softmax的作用主要是引入非线性(取概率化再与V乘都是次要的),那就可以将其换成其他的非线性激活函数。
Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)T)V
这里的 ϕ \phi ϕ代表所使用的激活函数,有很多种可以选择(论文常用的有1+elu)。这里的归一化就先省略掉了,有一些论文就将K矩阵的归一化放到分母上(或者说K矩阵归一化的逆)。

此时观察,使用softmax必须等 Q K T QK^T QKT先计算完,而使用其他的激活函数只对单个Q或者K进行运算,不需要绑定 Q K T QK^T QKT。所以就可以将左乘变成右乘
( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)
此时 ϕ ( K ) T V \phi(K)^TV ϕ(K)TV的复杂度是 O ( d 2 ) O(d^2) O(d2),所以整体复杂度变成了 O ( n d 2 ) O(nd^2) O(nd2),随着序列长度n线性增长,此时就是线性注意力了。

(可选):通常线性注意力的公式还有如下形式

O = Δ − 1 ∗ ( Q ∗ K T ∗ V ) O = Δ^{-1} * (Q * K^T * V) O=Δ1(QKTV)

(可选)其中,Δ起到了归一化的作用。Δ的每个对角元素是 K T ∗ 1 K^T*1 KT1的值,这反映了每个键向量的重要程度。将 Δ − 1 Δ^{-1} Δ1乘到结果上,就相当于对注意力输出进行了逆归一化。相当于只对K归一化,Q本身就是一个合适的查询向量,不需要归一化。

因果模型存在的问题

注意上面的线性注意力是类bert模型的情况下,并没有与掩码矩阵相乘,此时可以顺畅的先右乘来降低复杂度。但现在的大模型都是生成模型,使用的因果模型结构,都是单向注意力,就必须要乘以掩码矩阵,所以不能顺畅的右乘了。
左乘线性注意力公式如下,输出为O,每个step的输出为当前的 q t q_t qt乘以前面的 k j k_j kj,再乘以 v j v_j vj累加求和。此时 Q K T QK^T QKT可以正常进行矩阵运算,然后使用 ⊙ \odot (Hadamard Product)进行逐元素相乘,得到掩码后的矩阵。

O = ( Q K T ⊙ M ) V O=(QK^T\odot M)V O=(QKTM)V

o t = ∑ j = 1 t ( q t T k j ) v j o_t=\sum_{j=1}^t(q_t^Tk_j)v_j ot=j=1t(qtTkj)vj

此时注意,上面公式的运算涉及 ⊙ \odot ,它不适用于矩阵乘法交换律和结合律,即无法 Q ( K T ⊙ M V ) Q(K^T\odot MV) Q(KTMV) ⊙ \odot 是逐元素相乘,所以两个矩阵的维度必须相同,即使将M的位置放到前面, K T V K^TV KTV的维度是[d, d],也无法与M逐元素相乘。

累加求和操作的限制

双向注意力模型(bert)中使用的线性注意力如下,可以先算KV

( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)

QKV的维度都为[n, d],这里假设序列长度为4,双向和单向注意力如下图

在这里插入图片描述

  • 双向注意力计算
    K和V的矩阵如下,得到的 K T V K^TV KTV的维度是[d, d]

K T = [ k 1 T k 2 T k 3 T k 4 T ] = [ k 11 k 21 k 31 k 41 k 12 k 22 k 32 k 42 ⋮ ⋮ ⋮ ⋮ k 1 d k 2 d k 3 d k 4 d ] K^{T}= \begin{bmatrix} k_{1}^T & k_{2}^T & k_{3}^T & k_{4}^T \\ \end{bmatrix}= \begin{bmatrix} k_{11} & k_{21} & k_{31} & k_{41} \\ k_{12} & k_{22} & k_{32} & k_{42} \\ \vdots & \vdots & \vdots & \vdots \\ k_{1d} & k_{2d} & k_{3d} & k_{4d}\\ \end{bmatrix} KT=[k1Tk2Tk3Tk4T]= k11k12k1dk21k22k2dk31k32k3dk41k42k4d

V = [ v 1 v 2 v 3 v 4 ] = [ v 11 v 12 . . . v 1 d v 21 v 22 . . . v 2 d v 31 v 32 . . . v 3 d v 41 v 42 . . . v 4 d ] V= \begin{bmatrix} v_{1} \\ v_{2} \\ v_{3} \\ v_{4} \\ \end{bmatrix}= \begin{bmatrix} v_{11} & v_{12} & ... & v_{1d} \\ v_{21} & v_{22} & ... & v_{2d} \\ v_{31} & v_{32} & ... & v_{3d} \\ v_{41} & v_{42} & ... & v_{4d} \end{bmatrix} V= v1v2v3v4 = v11v21v31v41v12v22v32v42............v1dv2dv3dv4d

K T V = [ k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ] = [ [ K T V ] 1 [ K T V ] 2 ⋮ [ K T V ] d ] K^{T}V= \begin{bmatrix} k_{1}^Tv_1 + k_{2}^Tv_2 + k_{3}^Tv_3 + k_{4}^Tv_4 \\ \end{bmatrix}= \begin{bmatrix} [K^{T}V]_{1} \\ [K^{T}V]_{2} \\ \vdots \\ [K^{T}V]_{d} \\ \end{bmatrix} KTV=[k1Tv1+k2Tv2+k3Tv3+k4Tv4]= [KTV]1[KTV]2[KTV]d

此时计算 q 3 q_3 q3的注意力输出就可以使用以下方法。注意这是点积,q3是一个向量, K T V K^{T}V KTV是一个矩阵,向量在与矩阵点积的时候会进行广播拓展,复制成多份分别与矩阵中的向量点积。 [ K T V ] 1 [K^{T}V]_{1} [KTV]1是一个向量, q 3 [ K T V ] 1 q_3[K^{T}V]_{1} q3[KTV]1点积后会得到一个值,所以 q 3 K T V q_3K^{T}V q3KTV最终的结果是一个向量,长度为隐层维度d。

q 3 K T V = q 3 [ [ K T V ] 1 [ K T V ] 2 ⋮ [ K T V ] d ] = [ q 3 [ K T V ] 1 q 3 [ K T V ] 2 ⋮ q 3 [ K T V ] d ] q_3K^{T}V= q_3 \begin{bmatrix} [K^{T}V]_{1} \\ [K^{T}V]_{2} \\ \vdots \\ [K^{T}V]_{d} \\ \end{bmatrix}= \begin{bmatrix} q_3[K^{T}V]_{1} \\ q_3[K^{T}V]_{2} \\ \vdots \\ q_3[K^{T}V]_{d} \\ \end{bmatrix} q3KTV=q3 [KTV]1[KTV]2[KTV]d = q3[KTV]1q3[KTV]2q3[KTV]d

也可以使用以下代码测试

q3 = torch.tensor([1, 2, 3, 4, 5, 6])
print(q3)

# [n, d] = [4, 6]
kT = torch.tensor([[1, 1, 1, 1], 
                   [2, 2, 2, 2], 
                   [3, 3, 3, 3], 
                   [4, 4, 4, 4],
                   [5, 5, 5, 5],
                   [6, 6, 6, 6]])
v = torch.tensor([[1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1]])

print('kT @ v', kT @ v)
# q与(k.T @ v)的点积
result = torch.matmul(q, kT @ v)
print('result', result)

此时 K T V K^TV KTV的结果是双向的, k 3 k_3 k3的输出矩阵中使用了 v 4 v_4 v4,这样双向注意力就可以顺畅的右乘得到 K T V K^TV KTV结果再与Q相乘,得到所有token的输出。

但因果模型的注意力是单向的, K T V K^TV KTV在计算的时候前面的K不能与后面的V相乘,所以只能一个一个算然后累加求和。

o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

这样的累加操作无法进行高效的矩阵乘法,虽然计算复杂度降低了,但实际运算的效率并不高。

Lightning Attention

到这里可以引出MiniMax-01 中所使用的Lightning Attention了,但其实这个注意力有两个版本,MiniMax-01中所提到的就是是Lightning Attention-2,那咱们先看看第一个版本做了什么。

Lightning Attention-1

源自:TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer

Lightning Attention-1针对于原始注意力取消了softmax,使用Swish激活函数代替。即先变成了
Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ⊙ M ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T\odot M)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)TM)V
然后还是先左乘计算,并没有解决线性注意力的根本问题,但是借鉴了flash attention中的硬件加速。

其前向和反向传播流程如下,就是将QKV切块,放到高速SRAM中去计算。虽然变快了,但此时的复杂度还是 O ( n 2 d ) O(n^2d) O(n2d)
在这里插入图片描述
在这里插入图片描述

Lightning Attention-2

源自:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models

Lightning Attention-2解决了因果模型在计算单向注意力时,需要进行累加求和操作导致无法矩阵运算的情况,实现了单向注意力先计算右乘,成功将复杂度降为 O ( n d 2 ) O(nd^2) O(nd2)
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

再将这个累加求和公式拿过来,配合下图观察发现,之前的问题是每次计算 Q K T QK^T QKT都在整个序列上计算,这样每次都是所有序列的token互相注意到。那如果在序列这个维度拆分成小份,比如图中右侧先计算 k 1 k_1 k1 k 2 k_2 k2,然后用于 q 3 q_3 q3的计算就完全没有问题, k 4 k_4 k4后面的就不计算了。这样就既能矩阵运算,又能符合单向掩码。

公式中也可以发现,当前step之前的k和v是可以相乘的,比如 q 3 q_3 q3在计算时,可以将 k 1 T v 1 + k 2 T v 2 + k 3 T v 3 k_1^Tv_1+k_2^Tv_2+k_3^Tv_3 k1Tv1+k2Tv2+k3Tv3使用矩阵操作运算。所以Lightning Attention-2将大矩阵拆开,类似flash attention拆成多个block。
在这里插入图片描述
这些 block 不能拆分成 n 份,这样block的意义就没有了,for循环计算反而更慢。所以每个 block 中会有多个时间步的token。

此时这些 block 就可以分为两类,一类是块内(intra block),一类是块间(inter block)。块内代表当前块 q 的序列下标和 kv 序列下标相同,块间即不同。

块内在计算 q i q_i qi时直接矩阵右乘很容易算上 k i + 1 v i + 1 k_{i+1}v_{i+1} ki+1vi+1,所以块内使用传统的左乘并与掩码矩阵相乘。块间计算时就可以先右乘计算 K t V K^tV KtV,因为之前的kv是可以双向注意力的。然后将之前的kv结果缓存下来并更新,用于下一个step计算。

下图是Lightning Attention-2的结构图, λ \lambda λ是它的模型所使用的位置编码,忽略即可。
在这里插入图片描述
以下是前向传播和反向传播流程。
在这里插入图片描述
在这里插入图片描述
问题:M矩阵维度是[B, B],相当于每一个块代表了多个序列步n,在对角线位置是1,那在这个块内前面的q就可以注意到后面的kv了

解答:M矩阵维度虽然是[B, B],但只是这么切割,其内部值仍然是下三角。

备注

个人理解,若有不对请指出,谢谢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2284043.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

汽车网络信息安全-ISO/SAE 21434解析(中)

目录 第七章-分布式网络安全活动 1. 供应商能力评估 2. 报价 3. 网络安全职责界定 第八章-持续的网络安全活动 1. 网路安全监控 2. 网络安全事件评估 3. 漏洞分析 4. 漏洞管理 第九章-概念阶段 1. 对象定义 2. 网路安全目标 3. 网络安全概念 第十章 - 产品开发 第十…

LLaMA-Factory 微调LLaMA3

LoRA介绍 LoRA(Low-Rank Adaptation)是一种用于大模型微调的技术, 通过引入低秩矩阵来减少微调时的参数量。在预训练的模型中, LoRA通过添加两个小矩阵B和A来近似原始的大矩阵ΔW,从而减 少需要更新的参数数量。具体来…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.17 时间魔法:处理千万级时间序列的秘籍

1.17 时间魔法:处理千万级时间序列的秘籍 目录 #mermaid-svg-fa6SvjKCpmJ6C2BY {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-fa6SvjKCpmJ6C2BY .error-icon{fill:#552222;}#mermaid-svg-fa6SvjKCpmJ6…

WPS数据分析000009

一、函数与数据透视表统计数据时效率差异 函数 F4绝对引用 数据透视表 二、数据透视表基础操作 数据透视表:一个快速的生成报表的工具 显示详细信息 方式一; 方式二: 移动数据透视表 删除数据透视表 复制粘贴数据透视表 留足空间,否则拖动字…

Ansible自动化运维实战--script、unarchive和shell模块(6/8)

文章目录 一、script模块1.1、功能1.2、常用参数1.3、举例 二、unarchive模块2.1、功能2.2、常用参数2.3、举例 三、shell模块3.1、功能3.2、常用参数3.3、举例 一、script模块 1.1、功能 Ansible 的 script 模块允许你在远程主机上运行本地的脚本文件,其提供了一…

K8S 快速实战

K8S 核心架构原理: 我们已经知道了 K8S 的核心功能:自动化运维管理多个容器化程序。那么 K8S 怎么做到的呢?这里,我们从宏观架构上来学习 K8S 的设计思想。首先看下图: K8S 是属于主从设备模型(Master-Slave 架构),即有 Master 节点负责核心的调度、管理和运维,Slave…

用Python和PyQt5打造一个股票涨幅统计工具

在当今的金融市场中,股票数据的实时获取和分析是投资者和金融从业者的核心需求之一。无论是个人投资者还是专业机构,都需要一个高效的工具来帮助他们快速获取股票数据并进行分析。本文将带你一步步用Python和PyQt5打造一个股票涨幅统计工具,不…

猿人学第一题 js混淆源码乱码

首先检查刷新网络可知,m参数被加密,这是一个ajax请求 那么我们直接去定位该路径 定位成功 观察堆栈之后可以分析出来这应该是一个混淆,我们放到解码平台去还原一下 window["url"] "/api/match/1";request function…

【学术会议征稿】第五届能源、电力与先进热力系统学术会议(EPATS 2025)

能源、电力与先进热力系统设计是指结合物理理论、工程技术和计算机模拟,对能源转换、利用和传输过程进行设计的学科领域。它涵盖了从能源的生产到最终的利用整个流程,旨在提高能源利用效率,减少能源消耗和环境污染。 重要信息 官网&#xf…

对神经网络基础的理解

目录 一、《python神经网络编程》 二、一些粗浅的认识 1) 神经网络也是一种拟合 2)神经网络不是真的大脑 3)网络构建需要反复迭代 三、数字图像识别的实现思路 1)建立一个神经网络类 2)权重更新的具体实现 3&am…

redis的分片集群模式

redis的分片集群模式 1 主从哨兵集群的问题和分片集群特点 主从哨兵集群可应对高并发写和高可用性,但是还有2个问题没有解决: (1)海量数据存储 (2)高并发写的问题 使用分片集群可解决,分片集群…

【29】Word:李楠-学术期刊❗

目录 题目​ NO1.2.3.4.5 NO6.7.8 NO9.10.11 NO12.13.14.15 NO16 题目 NO1.2.3.4.5 另存为手动/F12Fn光标来到开头位置处→插入→封面→选择花丝→根据样例图片,对应位置填入对应文字 (手动调整即可)复制样式:开始→样式对话框→管理…

基于 AI Coding 「RTC + STT」 Web Demo

文章目录 1. 写在最前面1.1 旧测试流程1.2 新测试流程 2. Cursor 编程 vs Copilot 编程2.1 coding 速度2.2 coding 正确性 3. 碎碎念 1. 写在最前面 为了 Fix 语音转文字(STT)产品在 Json 协议支持上的问题,笔者需要将推送到 RTC 的数据按照…

dup2 + fgets + printf 实现文件拷贝

思路 将源文件的内容读取到内存中,然后将这些内容写入到目标文件。 1: 打开源文件、目标文件 fopen() 以读模式打开源文件。 open ()以写模式打开目标文件。 2: 读取源文件、写入目标文件 fgets ()从源文件中读取内容。 printf ()将内容写入目标文件。 printf…

[ACTF2020 新生赛]Upload1

题目 以为是前端验证&#xff0c;试了一下PHP传不上去 可以创建一个1.phtml文件。对.phtml文件的解释: 是一个嵌入了PHP脚本的html页面。将以下代码写入该文件中 <script languagephp>eval($_POST[md]);</script><script languagephp>system(cat /flag);&l…

SpringBoot整合Swagger UI 用于提供接口可视化界面

目录 一、引入相关依赖 二、添加配置文件 三、测试 四、Swagger 相关注解 一、引入相关依赖 图像化依赖 Swagger UI 用于提供可视化界面&#xff1a; <dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger-ui</artifactI…

深度学习项目--基于LSTM的糖尿病预测探究(pytorch实现)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 前言 LSTM模型一直是一个很经典的模型&#xff0c;一般用于序列数据预测&#xff0c;这个可以很好的挖掘数据上下文信息&#xff0c;本文将使用LSTM进行糖尿病…

LeetCode - Google 大模型校招10题 第1天 Attention 汇总 (3题)

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/145368666 GroupQueryAttention(分组查询注意力机制) 和 KVCache(键值缓存) 是大语言模型中的常见架构&#xff0c;GroupQueryAttention 是注意力…

Kotlin开发(七):对象表达式、对象声明和委托的奥秘

Kotlin 让代码更优雅&#xff01; 每个程序员都希望写出优雅高效的代码&#xff0c;但现实往往不尽人意。对象表达式、对象声明和 Kotlin 委托正是为了解决代码中的复杂性而诞生的。为什么选择这个主题&#xff1f;因为它不仅是 Kotlin 语言的亮点之一&#xff0c;还能极大地提…

数据库、数据仓库、数据湖有什么不同

数据库、数据仓库和数据湖是三种不同的数据存储和管理技术&#xff0c;它们在用途、设计目标、数据处理方式以及适用场景上存在显著差异。以下将从多个角度详细说明它们之间的区别&#xff1a; 1. 数据结构与存储方式 数据库&#xff1a; 数据库主要用于存储结构化的数据&…