mamba->jamba

news2024/12/23 3:10:55

1 mamba解决什么问题

Transformer的问题,其计算复杂度与序列长度的平方成正比,导致在处理长序列时效率低下。
Transformer 的注意力创建一个矩阵,将每个 token 与之前的每个 token 进行比较。矩阵中的权重由 token 对之间的相关性决定。
长度为 L 的序列生成 token 大约需要 L² 的计算量,如果序列长度增加,计算量会平方级增长。因此,需要重新计算整个序列是 Transformer 体系结构的主要瓶颈。

RNN 的问题:
每个隐藏状态都是之前所有隐藏状态的聚合。这会导致随着时间的推移,RNN 会忘记更久的信息,因为它只考虑前一个状态。
并且 RNN 的这种顺序性产生了另一个问题。训练不能并行进行,因为它需要按顺序完成每一步。

人们一直在寻找一种既能像 Transformer 那样并行化训练,能够记住先前的信息,又能在推理时时间是随序列长度线性增长的模型,Mamba 就是这样应运而生的。

Mamba基于“选择性状态空间模型”(selective state space model),在处理长序列时展现出更高的效率和性能。

Mamba的主要创新点包括:
线性时间复杂度:与Transformer不同,Mamba在序列长度方面实现了线性时间运行,特别适合处理非常长的序列。
选择性状态空间:Mamba利用选择性状态空间,能够更高效和有效地捕获相关信息,特别是在长序列中。
硬件感知算法:Mamba使用针对现代硬件(尤其是GPU)优化的并行算法,减少内存需求,提高计算效率。
简化架构:Mamba的结构比Transformer更简单,它去除了传统的注意力和MLP块,提供了更好的可扩展性和性能。
在性能方面,Mamba在语言、音频和基因组学等多个领域表现出色,能够与大型Transformer模型相媲美甚至超越。特别是在语言建模中,Mamba展示了卓越的性能,其预训练模型和代码已公开供社区使用。

2 mamba

论文地址:https://arxiv.org/abs/2312.00752
模型地址:state-spaces (State Space Models)
代码地址:https://github.com/state-spaces/mamba

2.1 Mamba的主要特点

选择机制(Selection Mechanism):Mamba采用选择机制来改进状态空间模型(SSM),允许模型基于输入内容有选择地传播或遗忘信息,从而增强了模型的表达能力。
硬件感知算法(Hardware-aware Algorithm):为使选择机制SSM在硬件上高效运行,Mamba设计了融合了内核和重新计算的硬件感知算法,避免了中间状态的存储,提高了速度和内存效率。
简化的架构(Simplified Architecture):Mamba将H3中的SSM块和Transformer中的MLP块合并为一个简化的块,重复堆叠这些块形成整体架构。这种简化的设计提高了训练和推理的效率。

2.2 SSM->HIPPO->S4

在这里插入图片描述

2.2.1 标准SSM

状态空间模型(State Space Models,SSM)由简单的方程定义。它将一维输入信号 x(t)映射到 N 维潜在状态 h(t)然后再投影到一维输出信号 y(t)。
在这里插入图片描述

SSM的两个方程:状态方程x(t)与输出方程y(t)
总之,SSM的关键是找到:状态表示(state representation)—— h(t),以便结合「其与输入序列」预测输出序列

(矩阵A描述了所有内部状态如何连接影响当前内部状态)
(矩阵B描述了当前输入如何影响当前内部状态)
(矩阵C描述了所有内部状态如何影响输出)
(矩阵D描述了当前输入如何影响输出)在这里插入图片描述
在这里插入图片描述

简化的结构:
在这里插入图片描述

2.2.2 零阶保持技术–从连续信号到离散信号

如果你有一个连续信号,找到状态表示***h(t)***在分析上是具有挑战性的。此外,由于我们通常有离散输入(如文本序列),我们希望将模型离散化。

为此,我们使用零阶保持技术。它的工作原理如下。首先,每当我们接收到一个离散信号时,我们保持其值,直到我们接收到一个新的离散信号。这个过程创建了一个连续信号,SSM可以使用:
我们保持值的时间由一个新的可学习参数表示,称为步长 ∆。它表示输入的分辨率。
现在,我们有了连续信号作为输入,我们可以生成连续输出,并根据输入的时间步长仅对值进行采样。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里,矩阵A和B现在表示模型的离散化参数。
我们使用k而不是t来表示离散化的时间步长,并且在引用连续SSM和离散SSM时更加清晰。

离散SSM计算–像RNN一样
在这里插入图片描述在这里插入图片描述

2.2.3 长距离依赖问题的解决之道——HiPPO

可以说,SSM公式中最重要的一个方面是矩阵A。正如我们之前在循环表示中看到的那样,它捕捉了关于先前状态的信息,以构建新状态。
由于矩阵A只记住之前的几个token和捕获迄今为止看到的每个token之间的区别,特别是在循环表示的上下文中,因为它只回顾以前的状态

Hippo的全称是High-order Polynomial Projection Operator,
其对应的论文为:HiPPO: Recurrent Memory with Optimal Polynomial Projections

作者讨论了如何处理长距离依赖(Long-Range Dependencies,LRDs)的问题,LRDs 是序列建模中的一个关键挑战,因为它们涉及到在序列中跨越大量时间步的依赖关系。
作者指出,基本的 SSM 在实际应用中表现不佳,特别是在处理 LRDs 时。这是因为线性一阶常微分方程(ODEs)的解通常是指数函数,这可能导致梯度在序列长度上呈指数级增长,从而引发梯度消失或爆炸的问题。

为了解决这个问题,作者利用了 HiPPO 理论。HiPPO 理论指定了一类特殊的矩阵 A,当这些矩阵被纳入 SSM 的方程中时,可以使状态 x(t) 能够记住输入 u(t) 的历史信息。这些特殊矩阵被称为 HiPPO 矩阵,它们具有特定的数学形式,可以有效地捕捉长期依赖关系。
HiPPO 矩阵的一个关键特性是它们允许 SSM 在数学和实证上捕捉 LRDs。例如,通过将随机矩阵 A 替换为 HiPPO 矩阵,可以在序列 MNIST 基准测试上显著提高 SSM 的性能。

它使用矩阵A来构建一个状态表示,能够很好地捕捉最近的标记,并衰减较旧的标记。其公式可以表示如下:
在这里插入图片描述
在这里插入图片描述

使用HiPPO构建矩阵A被证明比将其初始化为随机矩阵要好得多。因此,与初始标记相比,它更准确地重构了新的信号(最近的标记)。

HiPPO矩阵背后的思想是它产生一个隐藏状态,可以记住其历史。
从数学上讲,它通过跟踪Legendre多项式的系数来实现这一点,这使得它能够近似所有先前的历史。

2.2.4 S4——HiPPO的应用

然后,HiPPO被应用于我们之前看到的循环和卷积表示,以处理长距离依赖关系。结果是Sequences的结构化状态空间(S4),这是一类可以高效处理长序列的SSM。
且对矩阵A 做了改进

它由三个部分组成:
状态空间模型
用于处理长距离依赖关系的HiPPO
用于创建循环和卷积表示的离散化

在这里插入图片描述
在这里插入图片描述

S4 是 HiPPO 的后续工作,论文名称为:Efficiently Modeling Long Sequences with Structured State Spaces。

S4 的主要工作是将 HiPPO 中的矩阵 A(称为 HiPPO 矩阵)转换为正规矩阵(正规矩阵可以分解为对角矩阵)和低秩矩阵的和,以此提高计算效率。

S4 通过这种分解,将计算复杂度降低到了O(N+L) ,其中 N 是 HiPPO 矩阵的维度,L 是序列长度。
在处理长度为 16000 的序列的语音分类任务中,S4 模型将专门设计的语音卷积神经网络(Speech CNNs)的测试错误率降低了一半,达到了1.7%。相比之下,所有的循环神经网络(RNN)和 Transformer 基线模型都无法学习,错误率均在70%以上。

S4 在推理时,使用递归形式,每次只需要和上一个状态进行计算,具有和 RNN 相似的推理效率。
由于离散时间 SSM 的递归性质,它在硬件上进行训练时存在效率问题。因此,作者将离散时间 SSM 的递归方程转换为离散卷积的形式。通过展开递归方程,可以得到一个卷积核,这个卷积核可以用来在序列数据上应用卷积操作。这种转换允许 SSM 利用快速傅里叶变换(FFT)等高效的卷积计算方法,从而在训练过程中提高计算效率。

在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述

为什么对角化可以减少 SSM 计算复杂度

  • 为了进一步提升计算效率,作者讨论了对角化在计算离散时间状态空间模型(SSM)中的应用,以及为什么直接应用对角化方法在实践中并不可行。
  • 对角化是一种线性代数技术,它可以将一个矩阵转换为对角形式,从而简化矩阵的乘法和其他运算。在 SSM 的上下文中,对角化可以显著减少计算复杂度,因为对角矩阵的幂运算(如在递归方程中出现的)可以通过简单的元素指数运算来完成。
    在这里插入图片描述

直接对角化 HiPPO 矩阵导致数值溢出
在这里插入图片描述

S4 参数化:正规矩阵+低秩矩阵
虽然矩阵 A不能直接对角化,但是可以表示为正规矩阵+低秩矩阵。
在这里插入图片描述

在这里插入图片描述

HiPPO 矩阵是 S4 模型中用于处理长距离依赖(LRDs)的关键组件。
在这一节中,作者通过以下几个方面的实验来验证 HiPPO 矩阵的重要性:

  • HiPPO 初始化:作者首先研究了不同初始化方法对 SSM 性能的影响,包括随机高斯初始化、HiPPO 初始化以及随机对角高斯矩阵初始化。实验结果表明,HiPPO 初始化在提高模型性能方面起到了关键作用。
  • HiPPO 矩阵是否可训练:作者还探讨了 HiPPO 矩阵固定以及可训练的效果。他们发现,固定 HiPPO 和可训练的差异不大。
  • NPLR SSMs:作者进一步研究了在没有 HiPPO 矩阵的情况下,随机 NPLR(Normal Plus Low-Rank,正规+低秩矩阵)的表现。结果表明,即使在 NPLR 形式下,这些随机矩阵的性能仍然不佳,这验证了 HiPPO 矩阵在 S4 模型中的核心作用。

通过这些消融实验,作者强调了 HiPPO 矩阵在 S4 模型中的重要性。这些实验结果不仅证实了 HiPPO 矩阵在处理长距离依赖方面的有效性,而且也表明了它在提升模型整体性能方面的关键作用。这些发现对于理解 S4 模型的设计和优化至关重要。

2.3 mamba的SSM【S4–>S6】

虽然 S4 在保证了计算效率的同时,优化了长距离依赖问题。
但是由于矩阵 A,B,C是固定不变的,和输入 token 无关,这就导致了 S4 在一些合成任务上效果不佳
状态空间模型,甚至是S4(结构化状态空间模型),在某些对语言建模和生成至关重要的任务上表现不佳,即关注或忽略特定输入的能力

  • 由于(循环/卷积)SSM 是线性时间不变的,对于 SSM生成的每个 token,矩阵 A、B 和 C 都是相同的。它无法选择从历史中回忆哪些之前的 token。无论输入 u 是什么,矩阵 B 都保持不变,因此与 u 无关,同理,无论输入是什么,A 和 C 也不变,这就是我们上面说的静态。即矩阵 A、B 和 C 的静态性质导致内容感知方面的问题。

为了解决上面的问题,作者提出了一种新的选择性 SSM(Selective State Space Models,简称 S6 或 Mamba)。这种模型通过让 SSM 的矩阵 A、B、C 依赖于输入数据,从而实现了选择性。这意味着模型可以根据当前的输入动态地调整其状态,选择性地传播或忽略信息。Mamba 集成了 S4 和 Transformer 的精华,一个更加高效(S4),一个更加强大(Transformer)。

Mamba: Linear-Time Sequence Modeling with Selective State Spaces

在本节中,我们将介绍 Mamba 的两大主要贡献:
一种选择性扫描算法,该算法允许模型过滤(不)相关信息;
一种硬件感知算法,该算法允许通过并行扫描、内核融合和重新计算来高效存储(中间)结果。
它们共同创建了选择性 SSM 或 S6 模型,这些模型可以像自注意力一样用于创建 Mamba 块。

选择性 SSM,这种架构通常被称为选择性 SSM或S6模型,因为它本质上是使用选择性扫描算法计算的 S4 模型。
在这里插入图片描述
S4 和 选择性 SSM 的核心区别在于,它们将几个关键参数(∆, B, C)设定为输入的函数,并且伴随着整个 tensor 形状的相关变化。特别是,这些参数现在具有一个长度维度 L,这意味着模型已经从时间不变(time-invariant)转变为时间变化(time-varying)。
最后作者选择把 A设成了与输入无关,作者给出的解释是离散化之后 A¯=exp⁡(ΔA),Δ的数据依赖能够让整体的 A¯与输入相关。
它们一起选择性地选择在隐藏状态中保留什么和忽略什么,因为它们现在依赖于输入。
较小的步长 ∆ 导致忽略特定单词,而更大的步长 ∆ 则更多地关注输入单词而不是上下文:

因为现在的参数 A,B,C都是输入相关了,所以不再是线性时间不变系统,也就失去了卷积的性质,不能用 FFT来进行高效训练了。
Mamba 作者采用了一种称为硬件感知的算法,实际上就是用三种经典技术来解决这个问题:
内核融合(kernel fusion)、并行扫描(parallel scan)和重计算(recomputation)。

  • 一般的实现会
    提前先把大小为 (B,L,D,N)的 A¯,B¯先算出来,
    然后把它们从 HBM (high-bandwidth memory 或 GPU memory) 读到SRAM,
    然后调用 scan 算子算出 (B,L,D,N)的 output,写到 HBM 里面。
    再开一个kernel 把 (B,L,D,N)的 output 以及(B,L,N)的 C 读进来,
    multiply and sum with C 得到最后的 (B,L,D)output 。
    整个过程的读写是 O(BLDN)。
  • 而 Mambda 作者的方法是:
    把 (Δ,A,B,C)读到 SRAM 里面,总共大小是 O(BLN+DN)
    在 SRAM 里面做离散化,得到 (B,L,D,N)的 A¯,B¯
    在 SRAM 里面做 scan,得到(B,L,D,N)的 output
    multiply and sum with C,得到最后的(B,L,D)output 写入HBM
    整个过程的总读写量是 O(BLN),比之前省了 O(N)倍。 backward 的时候就把 A¯,B¯重算一遍,类似于flashattn 重算 attention 分数矩阵的思想。只要重算的时间比读 O(BLND)快就算有效。

Mamba 的实现比其它方法实现快很多倍,scan 在输入长度 2k 的时候就开始比 FlashAttention 快了,之后越长越快。同时 scan 也比 Convolution 快。在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1602532.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

B树(B-tree)

B树(B-tree) B树(B-tree)是一种自平衡的多路查找树,主要用于磁盘或其他直接存取的辅助存储设备 B树能够保持数据有序,并允许在对数时间内完成查找、插入及删除等操作 这种数据结构常被应用在数据库和文件系统的实现上 B树的特点包括: B树为…

EelasticSearch是什么?及EelasticSearch的安装

一、概述 Elasticsearch 是一个基于 Apache Lucene 构建的开源分布式搜索引擎和分析引擎。它专为云计算环境设计,提供了一个分布式的、高可用的实时分析和搜索平台。Elasticsearch 可以处理大量数据,并且具备横向扩展能力,能够通过增加更多的…

如何获取手机root权限?

获取手机的 root 权限通常是指在 Android 设备上获取超级用户权限,这样用户就可以访问和修改系统文件、安装定制的 ROM、管理应用权限等。然而,需要注意的是,获取 root 权限可能会导致手机失去保修、安全性降低以及使系统变得不稳定。在获取 …

大话设计模式之单例模式

单例模式是一种创建型设计模式,它确保类只有一个实例,并提供一个全局访问点来访问该实例。 单例模式通常在以下情况下使用: 当一个类只能有一个实例,并且客户端需要访问该实例时。当该唯一实例需要被公开访问,以便在…

基于8B/10BGT收发器的PHY层设计(1)

一、PHY层简介 PHY层(Physical Layer)是OSI模型中最低的一层,也是最基本的一层,PHY是物理接口收发器,它实现物理层。包括MII/GMII(介质独立接口)子层、PCS(物理编码子层&#xff09…

c++的学习之路:24、 二叉搜索树概念

摘要 本章主要是讲一下二叉搜索树的实现 目录 摘要 一、二叉搜索树概念 二、 二叉搜索树操作 1、二叉搜索树的查找 2、二叉搜索树的插入 3、二叉搜索树的删除 三、二叉搜索树的实现 1、插入 2、中序遍历 3、删除 4、查找 四、二叉搜索树的递归实现 1、插入 2、删…

Java的maven项目导入本地jar包的三种方式

一、使用本地jar包 在项目中创建一个lib文件夹&#xff0c;将想要使用的本地jar包放进去 然后直接在pom.xml中添加下列依赖&#xff08;项目协作推荐&#xff09; <dependency><groupId>com.fpl</groupId><artifactId>spring</artifactId><…

牛客NC197 跳跃游戏(一)【中等 动态规划 Java、Go、PHP】

题目 题目链接&#xff1a; https://www.nowcoder.com/practice/23407eccb76447038d7c0f568370c1bd 思路 答案说的merge区间就是每个A[i]的地方能跳到的最远坐标是A[i] [i]&#xff0c; 有一个maxReach&#xff0c;遍历一遍A[i], 不断刷新MaxReach, 如果某个i 位置比maxReac…

你觉得职场能力重要还是情商重要?

职场能力和情商都是职业成功的关键因素&#xff0c;它们在不同的情境和角色中扮演着不同的作用。很难简单地说哪一个更重要&#xff0c;因为它们通常是相辅相成的。 职场能力包括专业技能、知识水平、解决问题的能力、工作效率、创新思维等。这些能力是完成工作任务、达成职业目…

通讯录的实现(顺序表)

前言&#xff1a;上篇文章我们讲解的顺序表以及顺序表的具体实现过程&#xff0c;那么我们的顺序表在实际应用中又有什么作用呢&#xff1f;今天我们就基于顺序表来实现一下通讯录。 目录 一.准备工作 二.通讯录的实现 1.通讯录的初始化 2.插入联系人 3.删除联系人 4.…

一篇文章详细介绍Stable Diffusion模型原理及实现过程(附常用模型网站、下载方式)

目录 前言 何为Stable Diffusion模型&#xff1f; Stable Diffusion工作原理&#xff1a; Stable Diffusion模型的应用场景 Stable Diffusion免费使用网站 stability.ai: 本地部署 Stable Diffusion方法&#xff1a; StableDiffusion中文网 博主介绍&#xff1a;✌专注于前后端…

任务管理与守护进程

1.前台进程与后台进程 1.1守护进程 在上一章中&#xff0c;我们实现了一个Tcp服务器&#xff0c;但是这个服务器还存在一些问题&#xff0c;例如&#xff0c;我们将云服务器&#xff08;xshell&#xff09;关闭之后&#xff0c;服务器就无法使用了。 但是真正的服务器肯定不…

Stable Diffusion WebUI 控制网络 ControlNet 插件实现精准控图-详细教程

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里&#xff0c;订阅后可阅读专栏内所有文章。 大家好&#xff0c;我是水滴~~ 本文主要介绍 Stable Diffusion WebUI 一个比较重要的插件 ControlNet&#xff08;控制网络&#xff09;&#xff0c;主…

第46篇:随机存取存储器(RAM)模块<五>

Q&#xff1a;本期我们使用Quartus软件的IP Catalog工具创建双端口RAM。 A&#xff1a;前期创建的RAM存储模块只有一个端口&#xff0c;同时为读/写操作提供地址。我们将再创建一个具有两个地址输入端口的RAM模块&#xff0c;分别为读操作和写操作提供地址。选择Basic Functio…

Ubuntu:VSCode中编译运行C++代码

版本&#xff1a;Ubuntu22.04.1 LTS 目录 1 安装VSCode并汉化 2 检查Ubuntu是否已经安装了 GCC 3 在VScode中安装C/C扩展 4 在VSCode中进行C/C配置 1 安装VSCode并汉化 安装VSCode&#xff08;参考之前博客Ubuntu&#xff1a;安装VSCode_ubuntu vscode-CSDN博客&#xff…

两数相加(链表)

2. 两数相加 - 力扣&#xff08;LeetCode&#xff09; 题解 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。…

深入理解 pytest Fixture 方法及其应用

当涉及到编写自动化测试时&#xff0c;测试框架和工具的选择对于测试用例的设计和执行非常重要。在Python 中&#xff0c;pytest是一种广泛使用的测试框架&#xff0c;它提供了丰富的功能和灵活的扩展性。其中一个很有用的功 能是fixture方法&#xff0c;它允许我们初始化测试环…

Ypay源支付最新免授权牛角魔改版

YPay是专为个人站长打造的聚合免签系统&#xff0c;拥有卓越的性能和丰富的功能。采用全新轻量化的界面UI&#xff0c;让您可以更加方便快捷地解决知识付费和运营赞助的难题。同时&#xff0c;它基于高性能的ThinkPHP 6.1.2 Layui PearAdmin架构&#xff0c;提供实时监控和管…

【JavaWeb】Day47.Mybatis基础操作——删除

Mybatis基础操作 需求 准备数据库表 emp 创建一个新的springboot工程&#xff0c;选择引入对应的起步依赖&#xff08;mybatis、mysql驱动、lombok&#xff09; application.properties中引入数据库连接信息 创建对应的实体类 Emp&#xff08;实体类属性采用驼峰命名&#xf…

反转二叉树(力扣226)

解题思路&#xff1a;用队列进行前序遍历的同时把节点的左节点和右节点交换 具体代码如下&#xff1a; class Solution { public:TreeNode* invertTree(TreeNode* root) {if (root NULL) return root;swap(root->left, root->right); // 中invertTree(root->left)…