Mamba以及我们看的第一篇MambaOcc

news2024/9/30 21:29:18

0. 简介

深度学习架构有很多,但近些年最成功的莫过于 Transformer,其已经在多个应用领域确立了自己的主导地位。如此成功的一大关键推动力是注意力机制,这能让基于 Transformer 的模型关注与输入序列相关的部分,实现更好的上下文理解。但是,注意力机制的缺点是计算开销大,会随输入规模而二次增长,也因此就难以处理非常长的文本。而Mamba的出现则是解决了这个问题,通过结构化的状态空间序列模型(SSM)。该架构能高效地捕获序列数据中的复杂依赖关系,并由此成为 Transformer 的一大强劲对手。

1. 理解Mamba和Transformer

我们可以在Mamba模型底层技术详解,与Transformer到底有何不同?一文中看到状态空间模型Mamba的整体流程。而我们这里借鉴一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba一文来大致概括为什么Mamba为什么能从Transformer中占得一席之地。我这里将花两节概括一下重点,如果有需要,请看一下这篇原文,讲的非常好

1.1 Transformer模块

我们知道其是由Attention模块组成的,主要模块为:

  1. 自注意力机制(Self-Attention): 自注意力允许模型在处理输入序列时考虑序列中所有其他位置的信息。通过计算输入序列中每个位置的加权和,模型能够捕捉到不同位置之间的关系。这种机制使得模型能够更好地理解上下文。
  2. 多头注意力(Multi-Head Attention): 多头注意力是自注意力机制的扩展,它将输入分成多个“头”,并在每个头上独立执行自注意力计算。最后,将所有头的输出拼接在一起并通过线性变换得到最终结果。多头注意力使模型能够学习到不同的表示和特征,从而增强模型的表达能力。
  3. 前馈神经网络(Feed-Forward Neural Network): 在自注意力层之后,Transformer 模块还包含一个前馈神经网络,这个网络对每个位置的表示进行独立的非线性变换。这个前馈网络通常由两个线性变换和一个激活函数(如ReLU)组成。

而计算复杂度和序列长度的平方 N 2 N^2 N2成正比,可以看一个小例子,比如两个相乘的矩阵大小分别为 N × d N \times d N×d d × N d \times N d×N,矩阵乘法的一种计算方式是使用第一个矩阵的每一行与第二个矩阵的每一列做​点乘​。相关的动图可以参考经典文献阅读之–Deformable DETR中的做法,其需要将 Q Q Q K K K相乘。然后在最后还要再乘上 V V V值向量
在这里插入图片描述

在这个乘法过程中,计算每个元素 C [ i ] [ j ] C[i][j] C[i][j]的值需要将矩阵 A A A的第$ i$ 行与矩阵 B B B 的第 j j j列进行点乘。具体来说,点乘的计算公式为:

C [ i ] [ j ] = ∑ k = 1 d A [ i ] [ k ] × B [ k ] [ j ] C[i][j] = \sum_{k=1}^{d} A[i][k] \times B[k][j] C[i][j]=k=1dA[i][k]×B[k][j]

因此,为了计算矩阵 C C C 中的每个元素,我们需要进行 d d d次乘法和 d − 1 d-1 d1次加法。由于 C C C中有 N 2 N^2 N2 个元素(每个 $i $ 和 j j j 的组合),所以整个矩阵乘法的计算复杂度为:

O ( N 2 ⋅ d ) O(N^2 \cdot d) O(N2d)

1.2 状态空间与SSM

我们知道RNN在每一个时刻的隐藏状态 h t ​ h_t​ ht都是基于当前的输入 x t x_t xt和前一个时刻的隐藏状态 h t − 1 ​ h_{t-1}​ ht1计算得到的,比如泛化到任一时刻
在这里插入图片描述
但从上图中可以看到整个RNN是一个线性的结构,这就导致虽然每个隐藏状态都是所有先前隐藏状态的聚合,然随着时间的推移,RNN 往往会忘记某一部分信息,另外RNN这个结构,也导致其没法写成卷积形式,也没有办法并行训练,相当于推理快但训练慢(这也是Transformer要Attention的原因)。

为此Mamba在此基础上使用了状态空间与SSM来避免这个问题。

状态空间可以想象成我们正在穿过一个迷宫,图中每个小框代表迷宫中的一个位置,并附有某个隐式的信息,例如你距离出口有多远
在这里插入图片描述
而上述迷宫可以简化建模为一个“状态空间表示state space representation”,每一个小框显示

  • 你当前所在的位置(当前状态current state)
  • 下一步可以去哪里(未来可能的状态possible future states)
  • 以及哪些变化会将你带到下一个状态(向右或向左)

而描述状态的变量(在我们的示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量state vectors”
在这里插入图片描述
而在状态空间中的状态空间模型SSM也是一个RNN的变体,其主要用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么的模型。一般SSMs包括以下组成

  • 映射输入序列x(t),比如在迷宫中向左和向下移动
  • 到潜在状态表示h(t),比如距离出口距离和 x/y 坐标
  • 并导出预测输出序列y(t),比如再次向左移动以更快到达出口

在这里插入图片描述
SSM 假设动态系统(例如在 3D 空间中移动的物体)可以通过两个方程从其在时间 t t t时的状态进行预测。总之,SSM的关键是找到:状态表示(state representation)—— h ( t ) h(t) h(t),以便结合「其与输入序列」预测输出序列。

  1. 下图的第一个方程是不是和RNN循环结构:非常类似?——通过上一个隐藏状态和当前输入综合得到当前的隐藏状态,只是两个权重 W W W U U U换成了、两个系数,且去掉了非线性的激活函数 t a n h tanh tanh
  2. 但系数代表着什么,这点其实非常关键,然我看过的几乎所有讲解SSM/S4/mamba的文章都没有一针见血的指出来,其实A就是存储着之前所有历史信息的浓缩精华(可以通过一系列系数组成的矩阵表示之),以基于 A A A更新下一个时刻的空间状态 h i d d e n s t a t e hidden state hiddenstate这样解决了第一个遗忘的问题
    在这里插入图片描述

第一个方程:状态方程,矩阵B与输入 x ( t ) x(t) x(t)相乘之后,再加上矩阵A与前一个状态 h ( t ) h(t) h(t)相乘的结果
在这里插入图片描述
换言之,B矩阵影响输入 x ( t ) x(t) x(t),A矩阵影响前一个状态 h ( t ) h(t) h(t) → h ( t ) \rightarrow h(t) h(t)指的是任何给定时间 t t t的潜在状态表示(latent state representation), → x ( t ) \rightarrow x(t) x(t)指的是某个输入
第二个方程:输出方程,描述了状态如何转换为输出(通过矩阵 C),以及输入如何影响输出(通过矩阵 D)
在这里插入图片描述
最终的方程流程如下图所示
在这里插入图片描述

2. Mamba的三大创新

2.1 S4模块改进

作为Mamba而言其核心主要是从SSM引申的S4来改进的。其公式为如下图所示。首先是从连续 SSM 转变为离散SSM,使得不再是函数到函数 x ( t ) → y ( t ) x(t) \rightarrow y(t) x(t)y(t),而是序列到序列 x k → y k x_{k} \rightarrow y_{k} xkyk,其次不存在D,完成了简化,因为D并不是SSM的核心
在这里插入图片描述
上图矩阵 A ‾ \overline{\mathbf{A}} A B ‾ \overline{\mathbf{B}} B现在表示模型的离散参数,且这里使用 k k k,而不是 t t t 来表示离散的时间步长。在每个时间步,都会涉及到隐藏状态的更新(比如 h k h_k hk取决于 B ‾ x k \overline{\mathbf{B}} \mathbf{x}_{\mathrm{k}} Bxk A ‾ h k − 1 \overline{\mathbf{A}} \mathbf{h}_{\mathrm{k}-1} Ahk1的共同作用结果,然后通过 C h k Ch_k Chk预测输出 y k y_k yk)
在这里插入图片描述
对应的 y 2 y_2 y2展开为:
在这里插入图片描述
如此,便可以RNN的结构来处理
在这里插入图片描述

此外S4也可以表示成卷积的形式。这里我们处理的是文本而不是图像,因此我们需要一维视角

在这里插入图片描述
而用来表示这个“过滤器”的内核源自 SSM 公式

在这里插入图片描述
这正好和我们上面 y 2 y_2 y2计算公式一致,对应的核就是 y 2 y_2 y2的系数
在这里插入图片描述
以此内推,可得
y 3 = C A ‾ A ‾ A ‾ B ‾ x 0 + C A ‾ A ‾ B ‾ x 1 + C A ‾ B ‾ x 2 + C B ‾ x 3 y_{3}=\mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{A}} \overline{\mathbf{A}} \overline{\mathbf{B}} x_{0}+\mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{A}} \overline{\mathbf{B}} x_{1}+\mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{B}} x_{2}+\mathbf{C} \overline{\mathbf{B}} x_{3} y3=CAAABx0+CAABx1+CABx2+CBx3

在这里插入图片描述
为此SSMs可以当做是RNN与CNN的结合。即推理用RNN结构,训练用CNN结构。这样解决了训练过慢的问题
在这里插入图片描述

S4到S6

表格总结下各个模型的核心特点
在这里插入图片描述
总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度:

  • 高效的模型必须有一个小的状态(比如RNN或S4)
  • 而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)

而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的。避免了SSM和S4存在的不随输入变化(即与输入无关)得问题。-----即缺少Attention性质

在这里插入图片描述
在Mamaba中,作者让 B B B矩阵、 C C C矩阵、 Δ \Delta Δ成为输入的函数,让模型能够根据输入内容自适应地调整其行为
在这里插入图片描述
其中批量大小为 B B B,长度为 L L L,通道为 D D D(比如一个颜色的变量一般有R G B三个维度),SSM的隐藏层维度hidden为 N N N

从S4到S6的过程中,将影响输入的B矩阵、影响状态的C矩阵的大小从原来的 ( D , N ) (D,N) (D,N)
在这里插入图片描述

变成了 ( B , L , N ) (B,L,N) (B,L,N)【这三个参数分别对应batch size、sequence length、hidden state size】。
在这里插入图片描述

Δ \Delta Δ的大小由原来的 D D D变成了 ( B , L , D ) (B,L,D) (B,L,D),意味着对于一个 batch 里的 每个 token。

在这里插入图片描述
讲到这里,我们将大多数SSM架构比如H3的基础块,与现代神经网络比如transformer中普遍存在的门控MLP相结合,组成新的Mamba块,重复这个块,与归一化和残差连接结合,便构成了Mamba架构
在这里插入图片描述
关于mamba的整体架构,有两点值得强调下

  1. 为何要做线性投影
  • 经过线性投影后,输入嵌入的维度可能会增加,以便让模型能够处理更高维度的特征空间,从而捕获更细致、更复杂的特征
  1. 为什么SSM前面有个卷积?
    本质是对数据做进一步的预处理,更细节的原因在于:
  • SSM之前的CNN负责提取局部特征(因其擅长捕捉局部的短距离特征),而SSM则负责处理这些特征并捕捉序列数据中的长期依赖关系,两者算互为补充
  • CNN有助于建立token之间的局部上下文关系,从而防止独立的token计算毕竟如果每个 token 独立计算,那么模型就会丢失序列中 token 之间的上下文信息。通过先进行卷积操作,可以确保在进入 SSM 之前,序列中的每个 token 已经考虑了其邻居 token 的信息。这样,模型就不会单独地处理每个 token,而是在处理时考虑了整个局部上下文

下图就是整个Mamba的示意图,其中Selection SSM就是S6
在这里插入图片描述
与Transformer结构类似,Mamba结构也是由若干Mamba块堆叠而成。一个基本的Mamba块结构如图7所示:Mamba块由H3块以及门控MLP组合而成。H3为Hungry Hungry Hippos,是一种状态空间模型的执行方式。Mamba块简化了H3的结构,并与门控MLP结合,添加了残差项防止梯度消失。

Mamba的主要优势还是其优于Transformer的计算效率。Mamba的网络结构对于GPU的计算来说十分友好,特别是在数据存取交互上,Mamba结构的数据交互主要集中在GPU何SRAM间,而这部分的数据交互是快速的。

3. MambaOcc

《MambaOcc: Visual State Space Model for BEV-based Occupancy Prediction with Local Adaptive Reordering》提出了一种基于Mamba框架的新型占用率预测方法,旨在实现轻量级,同时提供高效的远距离信息建模,我们称之为MambaOcc算法模型。相关的工作也在Github上有链接了。个人感觉在这种长序列的情况中,也许Mamba其实是更有竞争力的。

MambaOcc方法设计轻量化,同时提供高效的长距离建模。首先,我们利用四方向视觉Mamba [7]来提取图像特征。为了减轻与3D体素相关的高计算负担,我们使用BEV特征作为占用预测的中间表示,并开发了一种结合卷积层和Mamba层的混合BEV编码器。鉴于Mamba架构在特征提取过程中对令牌顺序的敏感性,我们引入了一个利用可变形卷积(DCN)层的局部自适应重排序模块。该模块旨在动态更新每个位置的上下文,使模型能够更好地捕捉和利用数据中的局部依赖关系。这种方法不仅缓解了刚性令牌序列带来的问题,还通过确保在提取过程中优先考虑相关的上下文信息,提高了占用预测的整体准确性。本文的贡献如下:

  1. 提出了一种基于Mamba的轻量化占用预测方法(MambaOcc),在显著降低计算成本的同时提升了基于BEV的方法的性能。据我们所知,这是首个将Mamba集成到基于BEV的占用网络中的工作。
  2. 提出了一种具有局部自适应重排序机制的新型LAR-SS2D混合编码器,使得序列顺序优化更加灵活,并提升了状态空间模型的性能。
  3. 在Occ3DnuScenes数据集上,我们在参数和计算量有限的情况下实现了最先进的性能,例如,我们在减少42%参数和39%计算成本的同时,取得了比FlashOcc更好的结果。

4. 主要方法

在本节中,我们将从四个方面详细阐述所提出的MambaOcc:用于图像特征提取的基于Mamba的图像骨干网络(VM-Backbone),用于获取BEV格式特征和聚合多帧特征的视图变换和时间融合模块,带有自适应局部重排序模块的LAR-SS2D混合BEV编码器,以及占用预测头

…详情请参照古月居

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2180754.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

动手测试:CPU的L1~L3级缓存和内存的读取速度测试

引言 在许多文章中指出了这些缓存的架构,速度差异等。纸上得来终觉浅,今天想实际写代码简单测试一下。 背景 现代计算机系统中,CPU缓存(L1、L2、L3)和主内存(RAM)之间的读取速度有着显著的差…

数据结构之链表(2),双向链表

目录 前言 一、链表的分类详细 二、双向链表 三、双向链表的实现 四、List.c文件的完整代码 五、使用演示 总结 前言 接着上一篇单链表来详细说说链表中什么是带头和不带头,“哨兵位”是什么,什么是单向什么是双向,什么是循环和不循环。然后实…

U盘恢复数据工具:让数据失而复得的魔法

优盘里数据丢失无疑会给我们的工作和生活带来诸多不便。幸运的是,优盘数据恢复软件应运而生,它们如同数据的守护者,为我们提供了找回丢失数据的希望。这次我们就一同来探讨u盘恢复数据有什么方法吧。 1.福昕恢复数据 链接直达:h…

AutoSar 通信服务架构,CAN通信诊断详解

文章目录 Com(通信服务模块)PDU的定义和结构PDU的分类IPDU Mux 模块PDU R 模块(路由)Bus TP 模块BUS InterfaceCanIf模块LinIf模块 发送数据示例(CAN报文)接收数据示例(CAN报文)通信…

监控告警功能详细介绍及操作演示:运维团队的智能保障

在当今这个信息化高速发展的时代,运维团队面临着前所未有的挑战。为了确保系统的稳定性和高效运维,监控告警功能成为了运维团队不可或缺的得力助手。本文将详细介绍我们的监控告警功能,并结合实际操作页面进行演示,帮助运维团队更…

Docker入门指南:快速学习Docker的基本操作

为什么需要Docker 有时我们在本地开发好程序并成功运行之后,却在服务器上运行不起来,通过观察日志通常会发现,哦原来是这个库没安装,于是我们就需要先安装需要用到的库,然后再启动服务你可能还会发现用到的数据库信息…

《Linux从小白到高手》理论篇(六):Linux软件安装一篇通

List item 本篇介绍Linux软件安装相关的操作命令,看完本文,有关Linux软件安装相关操作的常用命令你就掌握了99%了。 Linux软件安装 RPM RPM软件的安装、删除、更新只有root权限才能使用;查询功能任何用户都可以操作;如果普通用…

真正的Open AI ——LLaMA颠覆开源大模型

1. LLaMA 简介 LLaMA(Large Language Model Meta AI)是由Meta(原Facebook)推出的一个大型语言模型系列,旨在通过更小的模型规模和更少的计算资源,实现与其他主流语言模型(如GPT)相媲…

spring简短注入

新建bean 创建set方法 jpackage com.dependency.spring6.bean;import org.slf4j.Logger; import org.slf4j.LoggerFactory;public class User {private static final Logger LOGGER LoggerFactory.getLogger(User.class);private String username;private String password;pr…

RPA跨流程复用元素技巧|实在RPA研究

为什么要跨流程复用元素 在 RPA 操作中,元素至关重要,因为自动化的本质就是模拟人类对元素的操作。基本上,每个流程都会包含若干个元素。对于同时维护多个流程的用户而言,相似的流程包,甚至是同一个元素。例如电商用户…

Solidworks斜接法兰快速绘制钣金箱体

Solidworks斜接法兰快速绘制钣金箱体 Chapter1 Solidworks斜接法兰快速绘制钣金箱体 Chapter1 Solidworks斜接法兰快速绘制钣金箱体 0.5mm间距为钣金焊接的预留焊缝。

Linux云计算 |【第四阶段】RDBMS1-DAY6

主要内容: MySQL索引(索引分类、创建索引)、用户及授权(创建用户并授权、查看授权、撤销授权、授权库mysql)、root密码恢复、备份、使用mysqldump进行逻辑备份、Percona 一、MySQL索引 1、基本概念 MySQL 索引(Inde…

给虚拟机安装操作系统以及无密码SSH登录

安装完虚拟化软件VMware Workstation Pro 17之后,我们下载了Ubuntu光盘映像文件,上次说演示desktop版的安装,但是考虑到后面要部署数据库,所以为了方便起见还是下载sever服务器版。 文件还挺大,在等待下载完成这会我们…

基于SpringBoot的休闲娱乐代理售票系统设计与实现

1.1研究背景 21世纪,我国早在上世纪就已普及互联网信息,互联网对人们生活中带来了无限的便利。像大部分的企事业单位都有自己的系统,由从今传统的管理模式向互联网发展,如今开发自己的系统是理所当然的。那么开发休闲娱乐代理售票…

C++那些事之内存优化

C那些事之内存优化 通常程序运行时内存是一个比较大的问题,如何减少内存占用和提升访问速度是至关重要。为了解决这些问题,C20 引入了 no_unique_address 特性,并结合空基类优化(EBO, Empty Base Optimization)&#x…

33 指针与数组:数组名与指针的关系、使用指针遍历数组、数组指针、指针数组、字符指针

目录​​​​​​​ 1 数组名与指针的关系 1.1 数组名 1.2 对数组名取地址 1.3 数组名与指针的区别 1.3.1 类型不同 1.3.2 sizeof 操作符的行为不同 1.3.3 & 操作符的行为不同 1.3.4 自增自减运算的行为不同 1.3.5 可变性不同 2 使用指针遍历数组 2.1 使用 *(nu…

智能网联汽车飞速发展,安全危机竟如影随形,如何破局?

随着人工智能、5G通信、大数据等技术的飞速发展,智能网联汽车正在成为全球汽车行业的焦点。特别是我国智能网联汽车市场规模近年来呈现快速增长态势,彰显了行业蓬勃发展的活力与潜力。然而,车联网技术的广泛应用也带来了一系列网络安全问题&a…

Mybatis知识

1. 基础知识 mybatis是基于java的持久层框架,它内部封装了jdbc,使开发者只需要关注sql语句本身,而不需要花费精力去处理加载驱动,创建连接,创建statement等繁杂的过程。 通过xml或者注解的方式将要执行的各种sta…

序列化方式五——ProtoStuff

介绍 Protostuff是一个基于Java的高效序列化库,它使用Protocol Buffers(简称protobuf)协议,为Java对象提供高效、灵活且易用的序列化和反序列化方法。Protostuff的主要优势在于其高性能和简单的使用方式,相对于其他序…

C#多线程数据同步的几种方式(不同的锁)

无锁 多个关联数据无法完整获取修改 internal class Program{static void Main(string[] args){Console.WriteLine("Hello, World!");ThreadPool.QueueUserWorkItem(Thread1);ThreadPool.QueueUserWorkItem(Thread2);ThreadPool.QueueUserWorkItem(Thread3);Console…