前言
实话说,过去一两月一直忙着我司两大类项目的推进
- 一类是正在逐一上线基于大模型的论文翻译、论文审稿、论文对话、论文修订/润色、论文idea提炼等等
- 一类是正在抓紧做面向一个个工厂的具身智能机器人的解决方案,且很快会分别在我司在各地的办公室(南京、长沙、武汉、北京)一一摆上一两台干活的具身机器人
所以虽然说mamba2已发布一月有余,但实在是没有一块完整的时间来对其做详尽而细致的解读,而最终促使我来写的最大的动力还是来源于我半年前对mamba1的解读,实在是太受欢迎了且影响力巨大(截止到24年7月初,半年下来阅读量10万,2千余次收藏,在同样发表半年内文章中的表现很突出)
加之之前就有读者在我对上面mamba1做解读的文章下留言,什么时候出mamba2的解读,让我好几次跃跃欲试想开写
然,在我下定决心写本文之前,内心还是有过一阵小纠结的
- 一方面,怕没有一大块完整的时间(回想过去,23年上半年因为ChatGPT,公司重新焕发生机,个人也前所未有的沉迷于技术,又因23年下半年做大模型项目延续至今,今后因为业务的增长 大量的各种会议 可能难以再像过去一年半百分百沉迷于技术了)
- 二方面,mamba2的论文特别长,即《Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality》一文长达52页,全是各种概念、公式,故为了更好的理解mamba2,建议先熟练mamba1
当然,mamba2的核心主要解决两个问题:1 打通SSM与transformer之间的联系,2 将mamba2表述为矩阵乘法以加速训练
不过还是因为过去十多年写博客的经验,使得自己在面对再难啃的算法都有足够的自信与底气,坚信都可以一步步拆解、一步步抽丝剥茧并清晰易懂的写出来(读者在看本文时,也不用急,一步步来,可以慢慢看懂的,且未来一两月 我也会不断修订本文以让之不断更加通俗易懂),故本文最终还是来了
第一部分 背景回顾:从SSM、结构化矩阵到结构化状态空间对偶SSD
1.1 结构化状态空间模型Structured State Space Model
1.1.1 离散化、循环结构表示、卷积结构表示
虽然在之前对mamba1的讲解中已经讲过了很多背景,但为本文的完整性起见,还是把一系列背景知识按照mamba2论文的思路,再度逐一梳理下
首先,结构化状态空间序列模型S4是受到的特定连续系统的启发(如下述公式1所示,是结构化SSM的一般离散形式),该系统将一维序列通过隐式潜在状态 做映射(相当于将SSM简单地写成矩阵乘法)
- 其中、均是标量,则被视为具有N维的向量,且
- 其中的 矩阵 控制时间动态,从而必须是结构化的(结构化SSM也因此得名),以便能够足够高效地计算这种序列到序列的转换,从而在深度神经网络中使用
梳理一下结构化SSM的发展历史
- 最初的结构化SSM起源于函数的连续时间映射,而不是直接对序列进行操作
在连续时间视角中,在公式(1a)中,矩阵 (𝐴, 𝐵)不是直接学习的,而是从底层参数生成的,并且伴随着一个参数化的步长 Δ
“连续参数”通过固定公式和转换为“离散参数”(𝐴, 𝐵),其中这对 (𝑓𝐴, 𝑓𝐵 )被称为discretization rule - 结构化 SSM 可以被视为一种递归神经网络RNN,其中线性赋予它们额外的属性,并使它们能够避免传统 RNN 的顺序计算。相反,尽管有这种简化,SSM 仍然可以完全表达为序列变换 更多详见此文《一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba》的第2.1.2节
- 当SSM的动态在时间上是恒定的,如公式(1)所示,该模型称为线性时不变(linear time-invariant,简称LTI)模型,在这种情况下,它们等同于卷积 因此,SSM也可以被视为CNN的一种类型,但卷积核通过SSM参数 (𝐴, 𝐵, 𝐶)隐式参数化,且卷积核通常是全局的而不是局部的
反过来,通过经典的信号处理理论,所有充分良好的卷积都可以表示为SSM
通常,以前的LTI SSM会
- 使用卷积模式进行高效的可并行训练(整个输入序列提前看到)
- 并切换到递归模式(如本节开头的公式1所述)进行高效的自回归推理(输入逐步看到)
1.1.2 mamba一代的问题:没法用矩阵乘法
当在 Mamba1 中被引入为选择性 SSM时,则相当于允许(A, B, C)这三个参数随时间而变化(如下面公式2所示),此时,、、
公式2与标准的 LTI 公式1相比,该模型可以在每个时间步选择性地关注或忽略输入
在信息密集型数据如语言上,它的表现被证明远优于 LTI SSM,特别是随着其状态大小 N的增加,允许更多的信息容量。 然而,它只能在递归模式下计算,而不是卷积模式,并且需要仔细的硬件感知实现才能高效,即如下图所示
即便如此,它仍然不如硬件友好的模型(如 CNN 和 Transformer)高效,因为它没有利用矩阵乘法单元,而现代加速器(如 GPU 和 TPU)正是为此而专门设计的
总之,虽然时间不变SSM 与连续、递归和卷积序列模型密切相关,但它们与注意力机制没有直接关系。所以mamba2想揭示选择性SSM和注意力机制之间的更深层次关系,并利用这一点显著提高SSM的训练速度,同时允许更大的状态规模N
1.1.3 结构化SSM作为序列变换:三个定义之2.1 2.2 2.3
请直接看一下三个定义(分别定义序列变换、S6和注意力机制的序列变换形式、序列变换与矩阵的联系)
- 定义 2.1 一般而言,所谓序列变换指的是序列上的参数化映射
其中,,并且𝜃是任意参数集合
表示序列或时间轴,可以作为下标索引到第一个维度,例如
序列变换(例如SSM或自注意力机制)是深度序列模型的基石,它们被整合到神经网络架构中 例如Transformer
其实上面的公式1或2中的SSM便是一个序列变换,且 P = 1 当然,它可以通过简单地在此维度上来推广到 P > 1(换句话说,将输入视为 P 个独立序列并对每个序列应用SSM,即可以将 P视为一个头维度) - 定义 2.2 定义SSM 操作符作为序列变换,由上面的公式2定义
在 SSM 中, N维度是一个称为状态大小或状态维度的自由参数,也称之为状态扩展因子,因为它将输入/输出的大小扩展了 𝑁倍,这对这些模型的计算效率有影响(其实许多类型的序列变换,例如注意力机制,都可以表示为跨序列维度的单一矩阵乘法) - 定义 2.3 如果一个序列变换可以写成形式,其中𝑀是一个依赖于参数𝜃的矩阵,称其为矩阵变换,且用矩阵𝑀来表示序列变换
当然,在上下文明确时,通常省略对𝜃的依赖
1.2 注意力机制、结构化矩阵、结构化状态空间对偶SSD
1.2.1 线性注意力机制
注意力机制已经非常经典了(如果还不熟悉注意力机制的,请参见此文:Transformer通俗笔记:从Word2Vec、Seq2Seq逐步理解到GPT、BERT),屡见不鲜,其为序列中每对位置分配分数,使每个元素能够“关注”其余部分。 迄今为止,最常见和最重要的注意力机制变体是softmax自注意力机制,其定义如下
对于,由于注意力机制需要一次次计算两两token之间的注意力(毕竟有这个计算),导致了二次方的计算复杂度
为了降低二次方的复杂度,已经提出了许多注意力的变体,其中最重要的变体是线性注意力(详见此文的2.2.1 什么是线性transformer:Transformers are RNNs与cosformer)
粗略地说,这类方法通过将softmax折叠到核特征映射中,并利用矩阵乘法的结合性将注意力计算中的矩阵左乘改成右乘,即,如下图右侧所示,将QKV的左乘变成右乘后,从⽽将理论计算复杂度降为线性「更多详见此文《七月论文审稿GPT第1版:通过3万多篇paper和10多万的review数据微调RWKV》的2.2节」
值得一提的是
- 提出线性注意力的这个标题:
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention「作者:A Katharopoulos · 2020」 - 是否与提出mamba2的论文标题:
Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
有着某种千丝万缕的联系呢
再进一步,既然transformer是RNN,而SSM某种意义上也是RNN,那mamba2和transformer是否有着直接的联系?不急,请继续看下文的讲解
此外,在因果(自回归)注意力的重要情况下,他们表明,当因果掩码被合并到左侧作为,其中是下三角1矩阵时,右侧可以扩展为递归(Moreover, in the important case of causal (autoregressive) attention, they show that when the causal mask is incorporated into the left-hand side as (𝐿 ◦ 𝑄𝐾⊤) · 𝑉 , where 𝐿 is the lower-triangular 1’s matrix, then the right-hand side can be expanded as a recurrence)
最近的一些工作,如RetNet(Y. Sun等,2023)和GateLoop(Katsch 2023)将其加强为更一般形式的
1.2.2 结构化矩阵(Structured Matrices):方便做矩阵乘法
一般矩阵需要 个参数来表示,并且执行诸如矩阵-向量乘法等基本操作需要时间。而所谓的结构化矩阵是指那些
- 可以压缩表示,比如在亚二次(理想情况下是线性)参数中表示,并且
- 通过快速算法(最重要的是矩阵乘法),直接操作这种压缩表示
也许最典型的结构化矩阵家族是稀疏矩阵和低秩矩阵。 然而,还存在许多其他家族,例如Toeplitz矩阵、Cauchy矩阵、Vandermonde矩阵和蝶形矩阵
1.2.3 结构化状态空间对偶(SSD):注意力矩阵乘以掩码矩阵
状态空间对偶(SSD)层可以定义为选择性SSM(如之前公式2所示)的特例
可以应用SSM作为递归(或并行扫描)的标准计算,其在序列长度上具有线性复杂度。 与Mamba中使用的版本相比,SSD有两个小的不同点:
- 的结构从对角线进一步简化为标量乘以单位矩阵结构。 在这种情况下,每个也可以仅用一个标量来表示
- 使用了更大的头维度 ,相比于Mamba1中使用的 P = 1,通常选择,而Transformer一般也会这样设置头的维度
与原始选择性SSM相比,这些变化可以被视为在略微降低表达能力的同时 显著提高训练效率。 特别是,新算法将允许在现代加速器上使用矩阵乘法单元
更进一步,SSD的对偶形式是一种与注意力密切相关的平方计算,其定义为
其中是依赖于输入的标量,范围在 [0, 1]之间
SSD与标准的softmax注意力相比,有两个主要区别
- 去掉了softmax
- 注意力矩阵按元素乘以一个额外的掩码矩阵
这两种变化都可以被视为解决了原始注意力中的问题。 例如,最近观察到softmax在注意力分数中会引起问题,如“注意力陷阱”现象(Darcet等,2024;Xiao等,2024)。 更重要的是,掩码矩阵可以被视为用不同的数据依赖位置掩码替换Transformer的启发式位置嵌入,从而控制跨时间传递的信息量(the mask matrix 𝐿 can be viewed as replacing the heuristic positional embeddings of Transformers with a different data-dependent positional mask that controls how much information is transfered across time)
更广泛地说,这种形式是下文定义的线性注意力的结构化掩码注意力泛化的一个实例
- 总之,通过展示SSM具有矩阵变换形式,对于一个依赖于的矩阵,各种形式的SSD可以通过统一的矩阵表示连接起来
- 特别地,SSD的对偶形式等价于通过矩阵 𝑀进行的朴素(平方时间)乘法,而递归形式是一种利用 𝑀结构的特定高效(线性时间)算法
以上之外,任何用于乘以 𝑀的算法都可以应用,此次提出的硬件高效SSD算法是一种新的结构化矩阵乘法方法,涉及 𝑀的块分解,比纯线性或二次形式获得更好的效率权衡。且与一般选择性SSM——mamba1(Gu和Dao 2023,即Albert Gu and Tri Dao. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”)相比,它相对简单且易于实现
第二部分 从状态空间模型是结构化矩阵、使用结构化矩阵推广线性注意力到SSD
2.1 状态空间模型是结构化矩阵:State Space Models are Structured Matrices
2.1.1 状态空间模型的矩阵变换形式:公式3
回顾一下,对选择性SSM——即mamba1的定义是通过之前的公式2定义的参数化映射
根据定义,,通过归纳法,可知时刻 的状态 ,可以表示为之前各个时刻的状态的加权和,即如下
上述公式中的
- 第一行的每一项表示的是之前某个时刻 的状态 经过一系列线性变换后的结果,最后这些结果加在一起得到了当前时刻的状态
- 第二行中的表示从一直乘到
为方便大伙一目了然,加之十多年前,我就提醒自己,写博客的目标之一是 如果某个算法看别的资料看不懂、看不动,那可以看懂、看动我的(坚持10多年来了,好处是博客影响力巨大,不好是累人),故还是要不厌其烦的解释下
其实理解上面那个公式很简单,直接一步一步推导一下即可,如下所示便可一目了然
- t = 0时
在这种情况下,是单位矩阵,故有- t = 1时
- t = 2时
- t = 3时
当然,这个过程可以借鉴下mamba1的这个图,只是、还没加上这个参数而已,所以在不同的输入x之下,便不会存在不同的、(下图来自mamba1解读一文的第2.1.2节SSM的循环结构表示:方便快速推理)
通过乘以矩阵来生成并将方程在上向量化,可推导出SSM的矩阵变换形式,如下(称之为公式3)
对于上述公式3,我举个例子,比如因为有
故可得
好比
2.1.2 半可分矩阵(定义3.1 3.2/公式4 3.3/公式5 3.4):顺序半可分、1-半分离矩阵(标量SSM递归)
首先,我们先来看下半可分矩阵(Semiseparable Matrices)的定义「称之为定义3.1,在有的文献中也被称为 (N, 0)-半可分性)」
一个(下三角)矩阵 𝑀是 N-半可分的,如果包含在下三角部分(即对角线或以下)的每个子矩阵的秩最多为 N,则称 N为半可分矩阵的阶数或秩
Definition 3.1. A (lower triangular) matrix 𝑀 is N-semiseparable if every submatrix contained in the lower triangular portion
(i.e. on or below the diagonal) has rank at most N. We call N the order or rank of the semiseparable matrix
其和其他形式的相关“可分”结构(例如准可分矩阵和其他半可分矩阵的定义)有时被称为结构化秩矩阵(或秩结构矩阵),因为它们的子矩阵由秩条件表征
半可分矩阵有许多结构化表示,包括分层半可分HSS、顺序半可分SSS和Bruhat形式(Pernet和Storjohann 2018),此处将主要使用SSS形式
2.1.2.1 顺序半可分SSS表示(The Sequentially Semiseparable (SSS) Representat):每个 N-半可分矩阵都有一个 N-SSS 表示
先看顺序半可分矩阵SSS表示的定义(称其为定义3.2,公式4)
一个下三角矩阵 𝑀 ∈ R(T,T)如果它可以写成以下形式,则具有 N-顺序半可分(SSS)表示
对于向量和矩阵,定义算子 SSS使得
换言之,如果是且,使得
相当于(注意,是竖着看每一列)
这个SSS表示带来的好处是如定义3.3所示
一个 N-SSS 矩阵 𝑀具有上面公式(4)的表示,则便是 N-半可分的「Lemma 3.3 An N-SSS matrix 𝑀 with representation (4) is N-semiseparable」
证明如下(定义为公式5)
考虑任何非对角块,其中 𝑗 ′ > 𝑗 ≥ 𝑖 > 𝑖′,这具有显式的秩-N分解为
为了避免正在阅读此文的你头疼,我还是用一个具体的示例来形象的说明下上述公式5
- 假设有以下矩阵
- 根据公式5的结构,选择j' = 2、j = 1、i = 1、i’ = 0,然后有:
- 先算左上角那个式子的结果,可得
先算前两项
再算前两项与第三项的结果- 计算右上角那个式子的结果
由于有
故而有- 接下来,计算左下角那个式子的结果
假设与相同,则
且假设与相同,则- 最后,计算右下角那个式子的结果
由于有
且有
则可得- 最终,将上面这些结果全部合并起来,则可以得到矩阵
且有定义 3.4 即每个 N-半可分矩阵都有一个 N-SSS 表示
2.1.2.2 1-半可分矩阵(标量SSM递归):许多序列模型算法可以归结为结构化矩阵乘法算法
接下来,列出1-SS矩阵的特殊情况,此时𝐶𝑗和 𝐵𝑖是标量,可以从SSS的表示(如上面的公式4所示),即中提取出来
由于对角矩阵易于处理(例如,对角矩阵的乘法与元素级标量乘法相同),故可以忽略这些项。 因此,对1-SS矩阵的基本表示是或如下(定义为公式6,依然是竖着看每一列)
其等同于标量递归的最小形式——即状态维度 N = 1且没有 (𝐵, 𝐶)投影的退化SSM情况
值得注意的是,矩阵乘法𝑦 = 𝑀𝑥 可以通过如下的式子进行递归计算(定义为公式7)
即
因此,也将1-SS矩阵的矩阵乘法称为标量SSM递归或累积乘积和(累积乘积和的广义形式)作为递归的基本形式,同时也是本次mamba2主要算法的构建模块
也从侧面说明,许多序列模型的算法可以归结为结构化矩阵乘法算法。1-SS矩阵体现了这一联系:有许多快速算法可以计算原始标量递归或cumprodsum算子,所有这些算法实际上都等价于1-SS矩阵的不同结构分解
2.1.3 状态空间模型是半可分矩阵:使得SSM问题转化为结构化矩阵乘法
回顾一下,我们对SSM的定义是通过定义2.1定义的参数化映射
SSM与半可分矩阵之间的联系仅仅是通过将这种变换写成矩阵乘法,将向量
- 公式(3)
直接建立了状态空间模型与顺序半可分表示之间的联系
而顺序半可分表示又等价于一般的半可分矩阵(定义3.3和定义3.4) - 定义 3.5 状态空间模型变换 𝑦 = SSM(𝐴, 𝐵, 𝐶)(𝑥)具有状态大小 N,等同于按顺序半可分表示的 N-SS 矩阵的矩阵乘法 𝑦 = SSS(𝐴, 𝐵, 𝐶) · 𝑥
换句话说,序列变换算子SSM(定义 2.2)
与矩阵构造算子 SSS(定义3.2) 一致
可以互换使用它们(有时也用SS作为简写)。 此外,巧合的是 结构化状态空间模型和顺序半可分矩阵具有相同的缩写,强调了它们的等价性
且可以使用这些缩写 SSM(状态空间模型或半可分矩阵)、SSS(结构化状态空间或顺序半可分) 或 SS(状态空间或半可分)互换使用,以明确地指代任一概念
当然,最终的约定一般是:SSM指状态空间模型,SS指半可分,SSS指顺序半可分
如下图所示,说明了将状态空间模型视为半可分矩阵的序列变换视角
- 作为序列变换,状态空间模型可以表示为作用于序列维度T上的矩阵变换𝑀∈R(T,T),在一个头的每个通道中共享相同的矩阵(左)
- 这个矩阵是一个半可分矩阵(右),它是一个秩结构矩阵,其中包含在对角线及其以下的每个子矩阵(蓝色)的秩最多为N,等于SSM的状态维度
这个意味着所有计算状态空间模型的算法都可以看作是对半可分矩阵进行结构化矩阵乘法的算法,总之,上面的定义3.5 使得可以将高效计算SSM(及其他序列模型)的问题转化为高效的结构化矩阵乘法算法
2.1.4 通过结构化矩阵算法计算状态空间模型
既然上文已经证明了SSM的计算可以转化为结构化矩阵乘法,那接下来,咱们便通过结构化矩阵算法计算状态空间模型
如前所述,半可分矩阵(即秩结构矩阵)是一种经典的结构化矩阵类型:
- 它们具有压缩表示形式,例如SSS形式只有参数,而不是参数
- 它们有直接在压缩表示上操作的快速算法
此外,参数化和矩阵乘法成本在半可分阶中可以非常紧凑
定义3.6 (Pernet, Signargout, 和 Villard (2023))表示:一个 N-SS 矩阵大小为 T可以用 𝑂 (NT)参数表示,并且矩阵-向量乘法在时间和空间上的复杂度为𝑂 (NT)
例如,1-SS 矩阵说明了这种连接的本质。 矩阵 𝑀 = 1SS(𝑎)由正好 T − 1 个参数,并且可以通过遵循上文提过的标量递归公式7在 𝑂 (T)时间内计算
根据上面的定义3.6可知,只需利用公式(2)
展开递归即可,具体过程如下公式8所示(三个公式分别被定义为8a、8b、8c)
这里, 𝐿 ∈ R(T,T)被定义为 1SS(𝐴),换句话说对于𝑖∈ [N]。 该算法涉及三个步骤,对应于上文的公式2:
- 通过输入矩阵 𝐵 (8a)扩展输入 𝑋
- 展开独立的标量SSM递归 (8b),且在在步骤(8b)中使用了标量SSM和1-SS矩阵之间的等价关系
- 通过输出矩阵 𝐶 (8c)收缩隐藏状态 𝐻
其实,整个公式8算是mamba1(S6)模型的一个特例,其中扩展的张量Z和H的大小为
2.2 结构化掩码注意力:使用结构化矩阵推广线性注意力
2.2.1 从自注意力、核注意力到掩码(核)注意力
注意力的基本形式(单头)是对三个向量序列的映射 (𝑄, 𝐾, 𝑉) ↦ →𝑌,如下所示(定义为公式9)
我们使用““shape annotation”来表示张量的维度,例如 𝑄 ∈ R(T,N),其中
- S和 T表示源和目标序列长度,分别意指:source、target之意
- N表示特征维度
- P表示头维度
最常见的softmax注意力变体使用softmax激活 𝑓 = softmax来规范 𝐺矩阵的行。
此外,虽然注意力通常被框定为对这三个对称视图输入𝑄, 𝐾, 𝑉的操作,但(9)中的输入和输出维度表明情况并非如此,特别是,输出中不存在特征维度 N时。因此在 S = T(例如自注意力)的情况下,将 𝑉视为主要输入,因此 (9)定义了一个适当的序列变换 𝑉 → 𝑌
2.2.1.1 自注意力
对于自注意力,其中(i)源序列和目标序列相同(即 S = T)
- (ii) 通常特征维度和头维度相同(即 N = P)
- (iii) 并且𝑄, 𝐾, 𝑉是通过对同一输入向量的线性投影生成的,即
2.2.1.2 核注意力
// 待更
2.2.1.3 掩码(核)注意力
设 𝐿为形状为 (T, S)的掩码。 最常见的是,在自回归自注意力情况下,当 S = T时, 𝐿可能是一个下三角矩阵,表示因果掩码
除了强制因果关系外,还可以应用许多其他类型的掩码——特别是各种稀疏模式,如带状、扩展
或块对角线——这些都是为了减少密集注意力的复杂性
掩码注意力通常用矩阵表示法表示为(定义为公式10)
更准确地说,带有shape annotation并将其分解为精确的计算序列(定义为公式11):
我们在本节中改进的注意力变体推导从注意到这个公式可以写成一个单一收缩开始(定义为公式12):
而算法11可以通过特定的成对收缩顺序重新表述为算法12的形式,如下公式13所示
2.2.2 线性注意力
如下公式14所示的线性注意力
等价于10:
接下来,以另一种顺序执行上面的公式12,从而得到下面的公式15
其中
- 第一步(15a)通过特征维度 N的因子执行“扩展”到更多特征
- 第二步(15b)是最关键的,并解释了线性注意力的线性部分
首先注意到 (15b) 只是通过 𝐿进行直接矩阵乘法「因为 (P, N)轴可以被展平」。 还要注意,这是唯一涉及 T和 S轴的项,因此应该具有 Ω(TS)复杂度(即序列长度的二次方)
然而,当掩码 𝐿是标准的因果注意力掩码(下三角全为1)时,通过 𝐿进行矩阵-向量乘法与特征逐项累积和相同 - 第三步(15c)收缩扩展的特征维度。 如果将 𝐾视为输入(如上文2.2.1节开头所述),那么 𝑉和 𝑄分别执行扩展和收缩
2.2.3 结构化掩码注意力SMA
通过掩码注意力的张量收缩视角(如公式15所示),得知原始线性注意力的关键在于因果掩码的矩阵-向量乘法等同于累积和操作符
然而,我们观察到没有理由注意力掩码必须全是1。 线性注意力快速的必要条件是 𝐿是一个结构化矩阵,根据定义,这些矩阵具有快速矩阵乘法(根据上文1.2.2节所述的结构化矩阵 所述)
特别是,我们可以使用任何掩码矩阵 𝐿,其矩阵-向量乘法具有次二次(理想情况下是线性)复杂度,这将通过加速瓶颈公式(15b)使其具有与标准线性注意力相同的复杂度
定义 4.2 结构化掩码注意力(或简称结构化注意力)被定义为一个函数作用于查询/键/值𝑄, 𝐾, 𝑉以及
任何结构化矩阵 𝐿(即具有次二次矩阵乘法),通过四维张量收缩
- SMA二次模式算法是通过(公式13)定义的成对收缩序列,对应于标准的(掩码)注意力计算
- SMA线性模式算法是通过(公式15)定义的成对收缩序列,其中步骤(15b)通过次二次结构矩阵乘法进行优化
可以将结构化掩码注意力实例化为任何给定的矩阵结构类别,比如如下图所示的一些实例
- 线性注意力使用因果掩码
- RetNet使用衰减掩码对于某些衰减因子
- 至于SSD下文介绍
- 衰减掩码可以推广到Toeplitz矩阵对于某些可学习的(或依赖于输入的)参数集
这可以解释为一种相对位置编码形式,类似于其他方法如AliBi,但乘法而不是加法 - 另一种变体可以使用傅里叶矩阵(Fourier matrix)以不同的方式编码位置结构
2.3 再谈状态空间对偶性SSD
回想一下,SSM由 𝑦 = SSM(𝐴, 𝐵, 𝐶)(𝑥)定义,SSM的矩阵形式使用SSS(顺序半可分)表示,其中公式3
现在让我们考虑 𝐴𝑗只是一个标量的情况;换句话说,一个结构化SSM的实例,其中𝐴矩阵是极其结构化的:对于标量和恒等矩阵
然后我们可以重新排列
这可以向量化为
其中 𝐵, 𝐶 ∈ R(T,N)
使用这种公式,完整的输出 𝑌 = 𝑀X精确计算为公式16
其中 S = T,从而可以看到这与掩码核注意力公式13的原始定义完全相同
因此,如「第2.1.4 通过结构化矩阵算法计算状态空间模型节」所述,计算标量结构化SSM——通过实现半可分矩阵𝑀并执行二次矩阵-向量乘法——与二次掩码核注意力完全相同
结构化掩码注意力允许使用任何结构化掩码 𝐿。 当 𝐿是因果掩码时,它是标准的线性注意力。 注意,因果掩码是,即1-SS掩码由公式6中的生成
这激发了将𝐿推广到1-半可分掩码类,或1-半可分结构化
掩码注意力(1-SS SMA),其中线性注意力递归中的cumsum被更一般的递归——标量SSM扫描,即1-半可分矩阵乘法所取代
最后,我们考虑1-半可分SMA的最重要原因是计算它的线性形式是对角状态空间模型的一个特例。SMA的线性形式是算法(15),其中瓶颈步骤(15b)可以看作是通过1-SS掩码进行矩阵乘法
// 待更
第三部分 从硬件高效的SSD算法、到Mamba-2 架构
3.1 硬件高效的SSD算法:块分解、对角块、低秩块
定义6.1 考虑一个具有状态扩展因子 N和头部维度 P = N的SSD模型,存在一种算法可以在任何输入上计算模型,该算法只需要训练FLOPs,推理FLOPs,推理内存,其工作主要由矩阵乘法主导
注意,所有这些界限都是紧的,因为具有状态扩展 N的状态空间模型在头部大小为时,总状态大小为 「分别得出训练和推理 FLOPs 的下界为和 」。此外,输入本身有个元素,从而产生了内存下限
如下图所示,状态空间对偶描述了状态空间模型和掩码注意力之间的密切关系
- 上图左侧:一般的 SSM和 SMA 都具有线性和二次形式,在符号上有直接的类比
- 上图右侧:SSM 和 SMA 在一大类状态空间对偶模型(SSD) 上相交,这些模型捕捉了许多序列模型作为特例
定义6.1背后的主要思想是再次将计算状态空间模型的问题视为半可分矩阵乘法,但以一种新的方式利用其结构
即不是在递归或注意模式下计算整个矩阵,而是对矩阵进行块分解
- 对角块可以使用对偶注意模式计算,这可以通过矩阵乘法高效完成
- 而非对角块可以通过半可分矩阵的秩结构进行分解并简化为较小的递归
3.1.1 块分解
首先,我们将矩阵 𝑀划分为一个的子矩阵网格,每个子矩阵的大小为 Q × Q,对于某个块大小 Q。 注意,根据半可分矩阵的定义性质(定义3.1),非对角块是低秩的
如下图所示,分别体现的是块分解、对角块、低秩块
举个例子,例如对于 T = 9 并分解成长度为 Q = 3 的块
上图中的阴影部分是半可分矩阵的非对角块的低秩分
从这里我们可以将问题简化为这两个部分。 这些也可以解释为将“块” 的输出分为两个部分:
- 块内输入的影响
- 以及块之前输入的影响
继续深入之前,先再次回顾下SSM的核心架构,直接贴张图吧(来源于此文中1.3.4节的“建立对SSM中两个核心方程的统一视角”的最后)
然后,如果要完成状态空间对偶(SSD)模型的完整 PyTorch代码,则可以先定义符号来定义批量矩阵乘法与批次维度 B
从而可以推断出效率的三个方面:
- 计算成本:总共FLOPs
- 内存成本:总共空间
- 并行化:更大的 M, N, K项可以利用现代加速器上的专用矩阵乘法单元
def segsum(x):
"""朴素的段和计算。exp(segsum(A)) 生成一个 1-SS 矩阵,等价于一个标量 SSM """
T = x.size(-1)
x_cumsum = torch.cumsum(x, dim=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd(X, A, B, C, block_len=64, initial_states=None):
"""
Arguments:
X: (batch, length, n_heads, d_head)
A: (batch, length, n_heads)
B: (batch, length, n_heads, d_state)
C: (batch, length, n_heads, d_state)
Return:
Y: (batch, length, n_heads, d_head)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# 重新排列成块/段
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1
为方便形象理解,再贴个图,如下(来源于此文中的3.1.1.3节“mamba:从S4到S6的算法变化流程”)
3.1.2 对角块
对角块很容易处理,因为它们只是较小规模的自相似问题。 𝑗-th 块表示计算范围内的答案
- 特别地,对于小块长度 Q,这个问题可以通过对偶二次SMA形式更有效地计算
其中,二次SMA计算的成本包括三个步骤: i) 计算核矩阵,其成本为 BMM( T/Q, Q, Q, N)
ii) 乘以掩码矩阵,这是对形状为 ( T/Q, Q, Q)的张量进行的逐元素操作
iii) 乘以 𝑋值,其成本为 BMM( T/Q, Q, P, N)
此外,这些块可以并行计算 - 这些子问题可以解释为:假设初始状态(到块)为 0,每块的输出是什么。换句话说,对于块 𝑗,这将计算正确的输出,仅考虑块输入
对应的代码为
# 1. 计算每个块内(对角块)的输出
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
3.1.3 低秩块:右B-块因子、中心A-块因子、左C-块因子三个部分的结算
低秩分解由3个项组成,相应地有三部分计算
- 像下面这样的项被称为右因子或 𝐵-块因子
此步骤计算低秩分解的右 𝐵-块因子的乘法。 注意,对于每个块,这是一个(N, Q)乘(Q, P)的矩阵乘法,
其中 N是状态维度, 𝑃是头维度。 每个块的结果是一个(N, P)张量,其维度与扩展的隐藏状态ℎ相同
这可以解释为:假设初始状态(到块)为 0,每个块的最终状态是什么。 换句话说,这计算了 ,其中
对应的代码为
这一步是一个单一的矩阵乘法,成本为 BMM( T/Q, N, P, Q)# 2. 计算每个块内的状态 # (低秩分解的非对角块的右项;B项) decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
- 像这样的项被称为中心因子或 𝐴-块因子
这一步计算了低秩分解中中心 𝐴-块因子项的影响。 在前一步中,每个块的最终状态的总形状为
现在通过一个由这现生成的1-SS矩阵相乘:
这一步可以通过任何用于计算1-SS乘法的算法来计算(也称为标量SSM扫描或累积乘积和操作符)
这可以解释为:每个块的实际最终状态是什么考虑到所有先前的输入; 换句话说,这计算了真实的隐藏状态(考虑到所有的)
对应的代码为
这一步是长度为 T/Q的标量SSM扫描(或1-SS乘法),在 (N, P)独立通道上进行。 这次扫描的工作是 TNP/Q,这是相对于其他因素可以忽略不计的# 3. 计算块间SSM递归;在块边界生成正确的SSM状态 # (非对角块分解的中间项;A项) if initial_states is None : initial_states = torch.zeros_like(states[:, :1]) states = torch.cat([initial_states, states], dim=1) decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) states, final_state = new_states[:, :-1], new_states[:, -1]
请注意,由于阻塞将序列长度从 T减少到 T/Q,这次扫描的成本比纯SSM扫描(例如Mamba的选择性扫描)小 Q倍
因此,我们观察到在大多数问题长度上,其他算法(附录B)可能更有效或更容易实现,而不会显著减慢速度
例如,通过1-SS矩阵乘法的简单实现成本为 BMM(1, T/Q, NP, T/Q),这比简单的递归/扫描实现更容易实现且可能更有效 - 像下面这样的项被称为左因子或 𝐶-块因子
这一步计算了左 𝐶-块因子的低秩分解的乘法。 对于每个块,这可以通过矩阵乘法来表示
这可以解释为:每个块的输出是什么考虑到正确的初始状态,并假设输入为 0
换句话说,对于块 𝑗,这计算了仅考虑先前输入的正确输出
对应的代码为
这一步是一个单一的矩阵乘法,成本为 BMM(T/Q, Q, P, N)# 4. 计算每个块的状态到输出的转换 # 低秩分解的非对角块的左项;C项 state_decay_out = torch.exp(A_cumsum) Y_off = torch.einsum( 'bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out
最后,如下# 添加块内和块间项的输出(对角块和非对角块) Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p") return Y, final_state
整个的过程,可以用下图表示
- 通过使用状态空间模型的矩阵变换视角将其写成半可分矩阵,通过块分解矩阵乘法算法开发了更硬件高效的SSD模型计算
- 矩阵乘法也可以解释为状态空间模型,其中块表示输入和输出序列的分块
- 对角块表示块内计算,非对角块表示通过SSM的隐藏状态进行的块间计算
注意,上图可以配合下图一块看(来源于此文中2.1.1节的“离散数据的连续化:基于零阶保持技术做连续化并采样”的最后)
// 待更
3.2 Mamba-2 架构
如下图所示,Mamba-2模块通过去除序列线性投影简化了Mamba模块(The Mamba-2 block simplifies the Mamba block by removing sequential linear projections)
- SSM参数𝐴, 𝐵, 𝐶在模块开始时生成,而不是作为SSM输入𝑋 的函数
the SSM parameters 𝐴, 𝐵, 𝐶 are produced at the beginning of the block instead of as a function of the SSM input 𝑋 . - 添加了一个额外的归一化层,如NormFormer,提高了稳定性
An additional normalization layer is added as in NormFormer (Shleifer, Weston, and Ott 2021), improving stability. - 𝐵和𝐶投影只有一个头部,在𝑋头部之间共享,类似于多值注意力(MVA)
The 𝐵 and 𝐶 projections only have a single head shared across the 𝑋 heads, analogous to multi-value attention (MVA)
3.2.1 模块设计:并行参数投影、额外的归一化
我们首先讨论对神经网络模块的修改,这些修改独立于内部序列混合层(即核心SSD层之外)
3.2.1.1 并行参数投影
对比mamba1和mamba2可知
- Mamba-1的动机是基于SSM中心的观点,其中选择性SSM层被视为从 𝑋 → 𝑌的映射(Mamba-1 was motivated by an SSM-centric point of view where the selective SSM layer is viewed as a map from 𝑋 → 𝑌 )
SSM参数, 𝐵, 𝐶被视为辅助参数,是SSM输入 𝑋的函数。 因此,定义(𝐴, 𝐵, 𝐶)的线性投影——在初始线性投影创建𝑋之后进行(The SSM parameters 𝐴, 𝐵, 𝐶 are viewed as subsidiary and are functions of the SSM input 𝑋 . Thus the linear projections defining (𝐴, 𝐵, 𝐶) occur after the initial linear projection to create 𝑋) - 在Mamba-2中,SSD层被视为从𝐴, 𝑋, 𝐵, 𝐶 → 𝑌的映射。 因此,有必要在块的开头通过单个投影并行生成𝐴, 𝑋,𝐵, 𝐶(In Mamba-2, the SSD layer is viewed as a map from 𝐴, 𝑋, 𝐵, 𝐶 ↦ → 𝑌 . It therefore makes sense to produce 𝐴, 𝑋, 𝐵, 𝐶 in parallel with a single projection at the beginning of the block) 值得注意的是
这与标准注意力架构类比,其中𝑋, 𝐵, 𝐶对应于并行创建的𝑄, 𝐾, 𝑉投影(Note the analogy to standard attention architectures, where 𝑋, 𝐵, 𝐶 correspond to the 𝑄, 𝐾, 𝑉 projections that are created in parallel.)
为SSM的𝐴, 𝐵, 𝐶, 𝑋输入采用并行投影略微减少了参数,更重要的是,通过使用标准的Megatron分片模式,更适合于较大模型的张量并行(Note that adopting parallel projections for the 𝐴, 𝐵, 𝐶, 𝑋 inputs to the SSM slightly reduces parameters and more importantly is more amenable to tensor parallelism for larger models, by using standard Megatron sharding patterns)
3.2.1.2 额外的归一化
在初步实验中,发现较大模型中容易出现不稳定性
通过在最终输出投影之前的块中添加一个额外的归一化层(例如LayerNorm、GroupNorm或RMSNorm)来缓解这一问题。 这种归一化的使用与NormFormer架构最直接相关,该架构也在MLP和MHA块的末端添加了归一化层
且mamba2的作者还注意到,这一变化类似于其他最近与Mamba-2相关的模型,这些模型是从线性注意力视角推导出来的
- 原始的线性注意力公式通过一个分母项进行归一化,该分母项模拟了标准注意力中softmax函数的归一化
而TransNormerLLM和RetNet发现这种归一化是不稳定的,并在线性注意力层之后添加了额外的LayerNorm或GroupNorm - mamba2的额外归一化层与这些略有不同,发生在乘法门分支之后而不是之前
3.2.2 序列变换的多头模式:多查询、多键、多值
回想一下,SSM被定义为一个序列变换
其中:
- 𝐴, 𝐵, 𝐶 参数具有状态维度 N
- 它们定义了一个序列变换,例如可以表示为矩阵
- 该变换作用于输入序列,独立于 P轴
可以将其视为定义了序列变换的一个 head
定义 7.1(多头模式) 多头序列变换由 H个独立的头组成,总模型维度为 D = d_model。参数可以在各头之间共享,形成一个head模式
状态大小 N和头维度 P类似于注意力机制中的 𝑄K头维度和 𝑉头维度(The state size N and head dimension P are analogous to the 𝑄𝐾 head dimension and 𝑉 head dimension of attention, respectively)
正如在现代Transformer架构中(比如Google的PaLM、Meta的Llama),在Mamba-2中我们通常选择这些常数为64或128;当模型维度 D增加时,我们增加头的数量,同时保持头维度 N和 P不变(when the model dimension D increases, we increase the number of heads while keeping the head dimensions N and P fixed)
为了描述如何做到这一点,我们可以从多头注意力中转移和推广想法,以定义SSM或任何一般序列变换的类似模式(in order to describe how to do this, we can transfer and generalize ideas from multihead attention to define similar patterns for SSMs, or any general sequence transformation)
- 多头状态空间模型 (MHS) / 多头注意力机制 (MHA) 模式
Multihead SSM (MHS) / Multihead Attention (MHA) Pattern
经典的 MHA 模式假设头维度 P可以整除模型维度 D
头的数量定义为 H = D/P(比如transformer论文中,模型维度512,8个头,每个头的维度为512/8 = 64),然后,通过创建 H个核心序列变换的副本,通过创建每个参数的 H个独立副本来实现
请注意,虽然MHA模式最初是为注意力序列变换描述的,但它可以应用于与定义2.1兼容的任何事物。例如,多头SSD层将接受形状符合方程(17)的输入,其中SSD算法在 H = n_heads维度上广播 - Multi-contract SSM (MCS)/多查询注意力(MQA)模式
Multi-contract SSM (MCS) / Multi-query Attention (MQA) Pattern
多查询注意力(详见此文:一文通透各种注意力:从多头注意力MHA到分组查询注意力GQA、多查询注意力MQA),顾名思义,即多个query 单个key value,如下图最右侧所示:Multi-query 可以显著提高自回归推理的速度,这依赖于缓存𝐾和𝑉张量。 这种技术只是避免给𝐾和𝑉额外的头维度,换句话说,就是将(𝐾, 𝑉)的单个头广播到𝑄的所有头上
利用状态空间对偶性,我们可以将MQA的等效SSM版本定义为方程(18) 其中, 𝑋 和 𝐵(注意力的 𝑉 和 𝐾 的SSM类比)在 H个头之间共享,也称之为多收缩SSM (MCS)头模式,因为控制SSM状态收缩的 𝐶 参数在每个头中都有独立的副本
相当于X B C类比于V K Q
此外,多查询注意力的思想可以扩展到分组查询注意力(分组头模式Grouped Head Patterns):而不是1个K和V头,可以创建 G个独立的K和V头,其中1 < G且 G整除 H(如上图中部所示)
这既是为了弥合多查询和多头注意力之间的性能差异,也是为了通过将 G设置为分片数量的倍数来实现更高效的张量并行 - 多键注意力 (MKA) 或多扩展SSM (MES)头模式
其中控制SSM扩展的 𝐵在每个头中是独立的,而 𝐶和 𝑋在头之间共享 - 多输入SSM (MIS) / 多值注意力(MVA) 模式
Multi-input SSM (MIS) / Multi-value Attention (MVA) Pattern
虽然MQA对于注意力来说是有意义的,因为它有KV缓存,但它不是SSM的自然选择
在Mamba中, 𝑋被视为SSM的主要输入,因此 𝐵和 𝐶是跨输入通道共享的参数,而在公式(20)中定义了一种新的多值注意力 (MVA) 的多输入 SSM (MIS) 模式,这同样可以应用于任何序列变换,例如 SSD
上面的描述可能比较绕,我给大家画个图,便一目了然了
首先,对于下图三种模式中的C B X都是可以逐一和注意力中的Q K V对应的,且当某个模式中的或C、或B、或X被圈起来时,则代表它的数量是更多的 属于多个,而没被圈起来的则可能是单个
具体而言,可以简单粗暴的理解为:
- 多查询便是多个Query 单个Key 单个Value
相当于对应:多个C 单个B 单个X- 多键便是多个Key 单个Query 单个Value
相当于对应:多个B 单个C 单个X- 多值便是多个Value 单个Query 单个Key
相当于对应:多个X 单个C 单个B
定义7.2 mamba1的重新定义
Mamba 架构的选择性SSM(S6)层可以被视为具有
- 头维度 𝑃 = 1: 每个通道都有独立的 SSM 动态 𝐴
- 多输入SSM(MIS) 或多值注意力(MVA)头结构(如上图最右侧所示):输入𝑋的所有通道共享𝐵、𝐶矩阵(对应于注意力对偶中的𝐾、Q)
因为通过实验证明,Mamba中最初使用的MVA模式表现最佳
此外,值得一提的是,Mamba-2中使用的多输入SSM头模式(multi-input SSM head pattern,比如8个X 1个C 一个B),可以轻松扩展到分组输入SSM(grouped-input SSM,GIS,比如8个X 4个C 4个B),或同义的分组值注意力(grouped-value attention,GVA,还是value对应的X最多,然后 C B相对少)
3.2.3 线性注意力的其他SSD扩展
// 待更
3.3 SSM的系统优化:张量并行、序列并行、可变长度
3.3.1 张量并行Tensor Parallel
张量并行「Tensor parallelism,简称TP,详见此文《大模型并行训练指南:通俗理解Megatron-DeepSpeed之模型并行与数据并行》的第二部分 张量并行(Tensor Parallelism,算模型并行的一种)」是一种模型并行技术,它将每一层(例如,注意力机制,MLP)拆分在多个加速器(如 GPU)上运行。这种技术被广泛用于在 GPU 集群上训练大多数大型模型(Brown 等,2020;Chow dhery 等,2023;Touvron, Lavril 等,2023;Touvron, L. Martin 等,2023),其中每个节点通常有 4-8 个 GPU,并具有快速网络连接,如 NVLink
TP 最初是为 Transformer 架构开发的,没法直接适应其他架构,故在Mamba 架构中使用 TP 有一定的挑战,进一步,Mamba-2 架构用起来TP之后,还得考虑如何设计以使 TP 高效
回顾 Mamba 架构,单个输入(为简单起见,不进行批处理),输入投影矩阵,其中 是扩展因子(通常为2),输出投影矩阵
使用 TP,假设想将计算分配到 2 个 GPU 上
- 很容易将输入投影矩阵和分成两个大小为的分区
It is easy to split the input projection matrices 𝑊 (𝑥 ) and 𝑊 (𝑧 ) into two partitions each of size 𝑑 × 𝑒𝑑/2 - 然后每个 GPU 将持有大小为的一半
Then each GPU would hold half of 𝑥𝑐 of size 𝐿 × 𝑒𝑑/2 - 然而,由于 Δ, 𝐵, 𝐶是的函数,所以需要在 GPU 之间进行额外的全归约,以在计算Δ, 𝐵, 𝐶之前获得整个
However,we see that since Δ, 𝐵, 𝐶 are functions are 𝑥𝑐 , so we would need an extra all-reduce between the GPUs to get the whole of 𝑥𝑐 before computing Δ, 𝐵, 𝐶 - 之后,由于它们在𝑑上是独立的,因此两个 GPU 可以并行计算 SSM
After that the two GPUs can compute the SSM in parallel since they are independent
along 𝑑 - 最后,我们可以将输出投影矩阵分成两个大小为的分区,并在最后进行一次全规约
At the end, we can split the output projection matrices 𝑊 (𝑜 ) into two partitions each of size 𝑒𝑑/2 × 𝑑, and do an all-reduce at the end
上述整个过程,与Transformer相比,将进行两次全规约,而不是一次,从而使通信时间加倍(Compared to Transformers, we would incur two all-reduces instead of one, doubling the time spent in communication)
对于大规模Transformer训练,通信可能已经占用了相当大的一部分时间(例如10-20%),加倍通信将使Mamba在大规模训练中效率不高「For large-scale Transformers training, communication might already take a significant fraction of time(e.g. 10-20%), and doubling communication would make Mamba not as efficient for large-scale training」
使用Mamba-2的目标是每个块只有一次全规约,类似于Transformer中的注意力或MLP块。因此,我们通过投影直接从𝑢得到Δ, 𝐵, 𝐶,而不是从得到,从而允许拆分这些投影矩阵
这意味着我们在不同的GPU上有不同的 Δ, 𝐵, 𝐶集合,这相当于在一个更大的“逻辑GPU”上有几个“组”的 Δ, 𝐵, 𝐶。此外,在每个块内使用GroupNorm,组的数量可被TP度整除,这样TP组中的GPU在块内无需通信:
可以看到,只需要拆分输入投影矩阵和输出投影矩阵,并且只需要在块的末尾进行全归约。 这类似于注意力和MLP层的TP设计
特别地,如果有TP度为2,则会拆分
- ,其中
- ,其中
- ,其中
对于 𝑖 = 1, 2,TP Mamba- 2层可以写成
总之,如下图所示
- 左侧是张量并行,分割输入投影矩阵、和输出投影矩阵
每个SSM头 (𝐴, 𝐵, 𝐶, 𝑋) →𝑌存在于单个设备上,选择GroupNorm作为最终归一化层可以避免额外的通信。每层需要一次全归约,就像Transformer中的MLP或注意力块一样 - 右侧是序列/上下文并行,类似于SSD算法,使用多个设备,可以沿序列维度进行分割,每个设备计算其序列的状态,然后将该状态传递给下一个GPU
3.3.2 序列并行
对于非常长的序列,可能需要沿序列长度维度将输入和激活拆分到不同的GPU上。 有两种主要技术:
- 用于残差和归一化操作的序列并行(SP):由Korthikanti等人首次提出,这种技术将TP中的all-reduce分解为reduce-scatter和all-gather
注意到在同一TP组中的所有GPU上,残差和归一化操作在相同输入上重复进行,SP通过执行:reduce-scatter、残差和归一化,然后all-gather,沿序列长度维度拆分激活
由于Mamba-2架构使用相同的残差和归一化结构,SP无需修改即可应用 - 序列并行用于token混合操作(注意力或SSM),也称为“上下文并行”(context parallelism,简称CP)。已经为注意力层开发了几种技术「例如,环形注意力(Liu, Yan, et al. 2024; Liu, Zaharia和 Abbeel 2023),使用复杂的负载均衡技术(Brandon 等人,2023)
注意力机制中的序列并行问题在于可以将查询和键分成块,但每个查询块需要与键块交互,导致通信带宽与工作者数量呈二次方关系
使用 SSMs,可以简单地分割序列:每个工作者获取一个初始状态,计算其输入的 SSM,返回最终状态,并将最终状态传递给下一个工作者。 通信带宽与工作者数量呈线性关系。 这种分解与 SSD 算法中的块分解完全相同,可以分成块/块
且在上图 中说明了这种上下文并行性
3.3.3 可变长度
虽然预训练通常对批次使用相同的序列长度,但在微调或推理过程中,模型可能需要处理不同长度的输入序列。
一种处理这种情况的简单方法是将批处理中所有序列右填充到最大长度,但如果序列长度差异很大,这可能效率低下。 对于Transformer,已经开发了复杂的技术来避免填充,并在GPU之间进行负载平衡(Zeng等,2022;Y.Zhai等,2023),或者在同一批次中打包多个序列并调整注意力掩码(Ding等,2024;Pouransari等,2024)
对于SSM,特别是Mamba,可以通过简单地将整个批次视为一个长序列来处理可变序列长度,并避免在单个序列之间传递状态。 这相当于简单地设置,对于一个序列末尾的token 𝑡,以防止它将信息传递给属于不同序列的token 𝑡 + 1
// 待更
由于之前计划解读完mamba2之后,便解读open-television、我司7方面review微调gemma2,再之后是TTT、nature审稿微调,没想7.12这天,flashattention3又来了..,实属应接不暇
故打算暂停对本文mamba2的修订,过几天后继续,by july,24年7.12于长沙办公室