0. 简介
深度学习架构有很多,但近些年最成功的莫过于 Transformer,其已经在多个应用领域确立了自己的主导地位。如此成功的一大关键推动力是注意力机制,这能让基于 Transformer 的模型关注与输入序列相关的部分,实现更好的上下文理解。但是,注意力机制的缺点是计算开销大,会随输入规模而二次增长,也因此就难以处理非常长的文本。而Mamba的出现则是解决了这个问题,通过结构化的状态空间序列模型(SSM)。该架构能高效地捕获序列数据中的复杂依赖关系,并由此成为 Transformer 的一大强劲对手。
1. 理解Mamba和Transformer
我们可以在Mamba模型底层技术详解,与Transformer到底有何不同?一文中看到状态空间模型Mamba的整体流程。而我们这里借鉴一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba一文来大致概括为什么Mamba为什么能从Transformer中占得一席之地。我这里将花两节概括一下重点,如果有需要,请看一下这篇原文,讲的非常好
1.1 Transformer模块
我们知道其是由Attention模块组成的,主要模块为:
- 自注意力机制(Self-Attention): 自注意力允许模型在处理输入序列时考虑序列中所有其他位置的信息。通过计算输入序列中每个位置的加权和,模型能够捕捉到不同位置之间的关系。这种机制使得模型能够更好地理解上下文。
- 多头注意力(Multi-Head Attention): 多头注意力是自注意力机制的扩展,它将输入分成多个“头”,并在每个头上独立执行自注意力计算。最后,将所有头的输出拼接在一起并通过线性变换得到最终结果。多头注意力使模型能够学习到不同的表示和特征,从而增强模型的表达能力。
- 前馈神经网络(Feed-Forward Neural Network): 在自注意力层之后,Transformer 模块还包含一个前馈神经网络,这个网络对每个位置的表示进行独立的非线性变换。这个前馈网络通常由两个线性变换和一个激活函数(如ReLU)组成。
而计算复杂度和序列长度的平方
N
2
N^2
N2成正比,可以看一个小例子,比如两个相乘的矩阵大小分别为
N
×
d
N \times d
N×d和
d
×
N
d \times N
d×N,矩阵乘法的一种计算方式是使用第一个矩阵的每一行与第二个矩阵的每一列做点乘。相关的动图可以参考经典文献阅读之–Deformable DETR中的做法,其需要将
Q
Q
Q和
K
K
K相乘。然后在最后还要再乘上
V
V
V值向量
在这个乘法过程中,计算每个元素 C [ i ] [ j ] C[i][j] C[i][j]的值需要将矩阵 A A A的第$ i$ 行与矩阵 B B B 的第 j j j列进行点乘。具体来说,点乘的计算公式为:
C [ i ] [ j ] = ∑ k = 1 d A [ i ] [ k ] × B [ k ] [ j ] C[i][j] = \sum_{k=1}^{d} A[i][k] \times B[k][j] C[i][j]=k=1∑dA[i][k]×B[k][j]
因此,为了计算矩阵 C C C 中的每个元素,我们需要进行 d d d次乘法和 d − 1 d-1 d−1次加法。由于 C C C中有 N 2 N^2 N2 个元素(每个 $i $ 和 j j j 的组合),所以整个矩阵乘法的计算复杂度为:
O ( N 2 ⋅ d ) O(N^2 \cdot d) O(N2⋅d)
1.2 状态空间与SSM
我们知道RNN在每一个时刻的隐藏状态
h
t
h_t
ht都是基于当前的输入
x
t
x_t
xt和前一个时刻的隐藏状态
h
t
−
1
h_{t-1}
ht−1计算得到的,比如泛化到任一时刻
但从上图中可以看到整个RNN是一个线性的结构,这就导致虽然每个隐藏状态都是所有先前隐藏状态的聚合,然随着时间的推移,RNN 往往会忘记某一部分信息,另外RNN这个结构,也导致其没法写成卷积形式,也没有办法并行训练,相当于推理快但训练慢(这也是Transformer要Attention的原因)。
为此Mamba在此基础上使用了状态空间与SSM来避免这个问题。
状态空间可以想象成我们正在穿过一个迷宫,图中每个小框代表迷宫中的一个位置,并附有某个隐式的信息,例如你距离出口有多远
而上述迷宫可以简化建模为一个“状态空间表示state space representation”,每一个小框显示
- 你当前所在的位置(当前状态current state)
- 下一步可以去哪里(未来可能的状态possible future states)
- 以及哪些变化会将你带到下一个状态(向右或向左)
而描述状态的变量(在我们的示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量state vectors”
而在状态空间中的状态空间模型SSM也是一个RNN的变体,其主要用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么的模型。一般SSMs包括以下组成
- 映射输入序列x(t),比如在迷宫中向左和向下移动
- 到潜在状态表示h(t),比如距离出口距离和 x/y 坐标
- 并导出预测输出序列y(t),比如再次向左移动以更快到达出口
SSM 假设动态系统(例如在 3D 空间中移动的物体)可以通过两个方程从其在时间
t
t
t时的状态进行预测。总之,SSM的关键是找到:状态表示(state representation)——
h
(
t
)
h(t)
h(t),以便结合「其与输入序列」预测输出序列。
- 下图的第一个方程是不是和RNN循环结构:非常类似?——通过上一个隐藏状态和当前输入综合得到当前的隐藏状态,只是两个权重 W W W、 U U U换成了、两个系数,且去掉了非线性的激活函数 t a n h tanh tanh
- 但系数代表着什么,这点其实非常关键,然我看过的几乎所有讲解SSM/S4/mamba的文章都没有一针见血的指出来,其实A就是存储着之前所有历史信息的浓缩精华(可以通过一系列系数组成的矩阵表示之),以基于
A
A
A更新下一个时刻的空间状态
h
i
d
d
e
n
s
t
a
t
e
hidden state
hiddenstate。这样解决了第一个遗忘的问题
第一个方程:状态方程,矩阵B与输入 x ( t ) x(t) x(t)相乘之后,再加上矩阵A与前一个状态 h ( t ) h(t) h(t)相乘的结果
换言之,B矩阵影响输入 x ( t ) x(t) x(t),A矩阵影响前一个状态 h ( t ) h(t) h(t), → h ( t ) \rightarrow h(t) →h(t)指的是任何给定时间 t t t的潜在状态表示(latent state representation), → x ( t ) \rightarrow x(t) →x(t)指的是某个输入
第二个方程:输出方程,描述了状态如何转换为输出(通过矩阵 C),以及输入如何影响输出(通过矩阵 D)
最终的方程流程如下图所示
2. Mamba的三大创新
2.1 S4模块改进
作为Mamba而言其核心主要是从SSM引申的S4来改进的。其公式为如下图所示。首先是从连续 SSM 转变为离散SSM,使得不再是函数到函数
x
(
t
)
→
y
(
t
)
x(t) \rightarrow y(t)
x(t)→y(t),而是序列到序列
x
k
→
y
k
x_{k} \rightarrow y_{k}
xk→yk,其次不存在D,完成了简化,因为D并不是SSM的核心
上图矩阵
A
‾
\overline{\mathbf{A}}
A和
B
‾
\overline{\mathbf{B}}
B现在表示模型的离散参数,且这里使用
k
k
k,而不是
t
t
t 来表示离散的时间步长。在每个时间步,都会涉及到隐藏状态的更新(比如
h
k
h_k
hk取决于
B
‾
x
k
\overline{\mathbf{B}} \mathbf{x}_{\mathrm{k}}
Bxk和
A
‾
h
k
−
1
\overline{\mathbf{A}} \mathbf{h}_{\mathrm{k}-1}
Ahk−1的共同作用结果,然后通过
C
h
k
Ch_k
Chk预测输出
y
k
y_k
yk)
对应的
y
2
y_2
y2展开为:
如此,便可以RNN的结构来处理
此外S4也可以表示成卷积的形式。这里我们处理的是文本而不是图像,因此我们需要一维视角
而用来表示这个“过滤器”的内核源自 SSM 公式
这正好和我们上面
y
2
y_2
y2计算公式一致,对应的核就是
y
2
y_2
y2的系数
以此内推,可得
y
3
=
C
A
‾
A
‾
A
‾
B
‾
x
0
+
C
A
‾
A
‾
B
‾
x
1
+
C
A
‾
B
‾
x
2
+
C
B
‾
x
3
y_{3}=\mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{A}} \overline{\mathbf{A}} \overline{\mathbf{B}} x_{0}+\mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{A}} \overline{\mathbf{B}} x_{1}+\mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{B}} x_{2}+\mathbf{C} \overline{\mathbf{B}} x_{3}
y3=CAAABx0+CAABx1+CABx2+CBx3
为此SSMs可以当做是RNN与CNN的结合。即推理用RNN结构,训练用CNN结构。这样解决了训练过慢的问题
S4到S6
表格总结下各个模型的核心特点
总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:
- 高效的模型必须有一个小的状态(比如RNN或S4)
- 而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)
而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的。避免了SSM和S4存在的不随输入变化(即与输入无关)得问题。-----即缺少Attention性质
在Mamaba中,作者让
B
B
B矩阵、
C
C
C矩阵、
Δ
\Delta
Δ成为输入的函数,让模型能够根据输入内容自适应地调整其行为
其中批量大小为
B
B
B,长度为
L
L
L,通道为
D
D
D(比如一个颜色的变量一般有R G B三个维度),SSM的隐藏层维度hidden为
N
N
N。
从S4到S6的过程中,将影响输入的B矩阵、影响状态的C矩阵的大小从原来的
(
D
,
N
)
(D,N)
(D,N)
变成了
(
B
,
L
,
N
)
(B,L,N)
(B,L,N)【这三个参数分别对应batch size、sequence length、hidden state size】。
且 Δ \Delta Δ的大小由原来的 D D D变成了 ( B , L , D ) (B,L,D) (B,L,D),意味着对于一个 batch 里的 每个 token。
讲到这里,我们将大多数SSM架构比如H3的基础块,与现代神经网络比如transformer中普遍存在的门控MLP相结合,组成新的Mamba块,重复这个块,与归一化和残差连接结合,便构成了Mamba架构
关于mamba的整体架构,有两点值得强调下
- 为何要做线性投影
- 经过线性投影后,输入嵌入的维度可能会增加,以便让模型能够处理更高维度的特征空间,从而捕获更细致、更复杂的特征
- 为什么SSM前面有个卷积?
本质是对数据做进一步的预处理,更细节的原因在于:
- SSM之前的CNN负责提取局部特征(因其擅长捕捉局部的短距离特征),而SSM则负责处理这些特征并捕捉序列数据中的长期依赖关系,两者算互为补充
- CNN有助于建立token之间的局部上下文关系,从而防止独立的token计算毕竟如果每个 token 独立计算,那么模型就会丢失序列中 token 之间的上下文信息。通过先进行卷积操作,可以确保在进入 SSM 之前,序列中的每个 token 已经考虑了其邻居 token 的信息。这样,模型就不会单独地处理每个 token,而是在处理时考虑了整个局部上下文
下图就是整个Mamba的示意图,其中Selection SSM就是S6
与Transformer结构类似,Mamba结构也是由若干Mamba块堆叠而成。一个基本的Mamba块结构如图7所示:Mamba块由H3块以及门控MLP组合而成。H3为Hungry Hungry Hippos,是一种状态空间模型的执行方式。Mamba块简化了H3的结构,并与门控MLP结合,添加了残差项防止梯度消失。
Mamba的主要优势还是其优于Transformer的计算效率。Mamba的网络结构对于GPU的计算来说十分友好,特别是在数据存取交互上,Mamba结构的数据交互主要集中在GPU何SRAM间,而这部分的数据交互是快速的。
3. MambaOcc
《MambaOcc: Visual State Space Model for BEV-based Occupancy Prediction with Local Adaptive Reordering》提出了一种基于Mamba框架的新型占用率预测方法,旨在实现轻量级,同时提供高效的远距离信息建模,我们称之为MambaOcc算法模型。相关的工作也在Github上有链接了。个人感觉在这种长序列的情况中,也许Mamba其实是更有竞争力的。
MambaOcc方法设计轻量化,同时提供高效的长距离建模。首先,我们利用四方向视觉Mamba [7]来提取图像特征。为了减轻与3D体素相关的高计算负担,我们使用BEV特征作为占用预测的中间表示,并开发了一种结合卷积层和Mamba层的混合BEV编码器。鉴于Mamba架构在特征提取过程中对令牌顺序的敏感性,我们引入了一个利用可变形卷积(DCN)层的局部自适应重排序模块。该模块旨在动态更新每个位置的上下文,使模型能够更好地捕捉和利用数据中的局部依赖关系。这种方法不仅缓解了刚性令牌序列带来的问题,还通过确保在提取过程中优先考虑相关的上下文信息,提高了占用预测的整体准确性。本文的贡献如下:
- 提出了一种基于Mamba的轻量化占用预测方法(MambaOcc),在显著降低计算成本的同时提升了基于BEV的方法的性能。据我们所知,这是首个将Mamba集成到基于BEV的占用网络中的工作。
- 提出了一种具有局部自适应重排序机制的新型LAR-SS2D混合编码器,使得序列顺序优化更加灵活,并提升了状态空间模型的性能。
- 在Occ3DnuScenes数据集上,我们在参数和计算量有限的情况下实现了最先进的性能,例如,我们在减少42%参数和39%计算成本的同时,取得了比FlashOcc更好的结果。
4. 主要方法
在本节中,我们将从四个方面详细阐述所提出的MambaOcc:用于图像特征提取的基于Mamba的图像骨干网络(VM-Backbone),用于获取BEV格式特征和聚合多帧特征的视图变换和时间融合模块,带有自适应局部重排序模块的LAR-SS2D混合BEV编码器,以及占用预测头。