Mamba.py: 状态空间模型的并行扫描

news2024/11/20 3:29:03

文章目录

  • Mamba.py:扫描和并行扫描
    • 什么是扫描
    • 什么是并行扫描
    • 累加计算的例子
      • 矩阵求和简化
      • 矩阵求和python实现
      • 累加求和的并行
    • Blelloch 算法
      • Up-sweep
      • Down-sweep
    • selective_scan

Mamba.py:扫描和并行扫描

mamba.py/docs/pscan.ipynb at main · alxndrTL/mamba.py (github.com)

是Mamba.py作者对其编写的pscan即并行扫描的解释。pscan是Belloch算法的pytorch实现,即并行前缀和扫描,Belloch算法可以参照Understanding the implementation of the Blelloch Algorithm (Work-Efficient Parallel Prefix Scan) | by Shivam Mohan | Nerd For Tech | Medium以及Prefix Sums and Their Applications(cmu.edu)

但是根据作者所说,因没有应用原始实现的recomputation技术,因此会占用巨大的显存,所以仅作教学用途。
Mamba.py的训练速度见下图
在这里插入图片描述

什么是扫描

一个扫描定义为一个操作,把一个矩阵作为输入,产生一个矩阵作为输出。

我个人理解:

系统的输入或者说外部环境随时间变化,而我们系统也要随之不断更新,不断将这些变量“扫描”进去,因为我们只能处理离散信息,所以我们会有一个采样步长 Δ \Delta Δ,我们根据这些外部信息更新我们所需要的量,比如我们的输出量,控制量。

扫描是外部变量的扫描,也是内部变量的扫描,所以在我看来,扫描的同义词是更新,扫描就是根据输入更新状态空间模型的参数和输出。

什么是并行扫描

下面是一个因果卷积网络

输出依赖于之前时刻的输入,因此按顺序输入也一定按顺序输出,我们不可能在还可能有输入的情况下,就盖棺定论或者未卜先知得到下一个时刻的输出。

但在输入确定的情况下,例如我们只有6个点的输入,我们不需要再按顺序一个个计算,因为不确定的是输入,而不是输入与输出的关系,只要输入确定,我们没有必要对输入做不必要地等待,我们可以并行计算输出,我们可以并行的计算上面橙色的输出。
在这里插入图片描述

累加计算的例子

一个简单著名的扫描的例子是一个矩阵的累加求和。

X = torch.tensor([1, 2, 3, 4])

torch.cumsum(X, dim=0)

最简单的一个实现是for循环

Y = torch.zeros_like(X)

cumulative_sum = 0
for t in range(X.size(0)):
    cumulative_sum += X[t]
    Y[t] = cumulative_sum

我们使用了一个累加变量cumulative_sum

一个等价形式如下

Y = torch.zeros_like(X)

Y[0] = X[0]
for t in range(1, X.size(0)):
    Y[t] = Y[t-1] + X[t]

我们不再显式表达累加变量,但它实际在Y里面。

表达为了递归形式 Y [ t ] = Y [ t − 1 ] + X [ t ] Y[t] = Y[t-1]+X[t] Y[t]=Y[t1]+X[t]

在这里插入图片描述

有点像是RNN的形式,从某种角度来说,Y相当于隐状态,X相当于输入,当我们处理输入时,我们不断更新隐状态。

我们看到这种计算方法时顺序循环,又没有可能并行扫描操作。

矩阵求和简化

更进一步简化我们的目标为计算输入矩阵X的总和。

在这里插入图片描述

可以写成一个树状结构

L = 2 d L = 2^{d} L=2d则我们可以通过两两求和将总的计算次数由 L L L变为 d d d次顺序求和,如图,原有7次加法8个阶段,可以变成7次加法3个阶段,每个阶段内的加法可以同时进行。

矩阵求和python实现

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # input array
L = X.size(0)

Xa = X

for k in range(int(math.log2(L))):
    T = 2 * (Xa.size(0) // 2)

    # 两两分成一对
    Xa = Xa.view(T//2, 2) 

    # 每一对的两个元素相加
    Xa[:, 1].add_(Xa[:, 0])

    # 更新Xa
    Xa = Xa[:, 1]

因为Xa是X的一个view,因此我们实际上在原地进行更新,更新完毕的X实际为[1, 3, 3, 10, 5, 11, 7, 36]

但我们将每次循环中不更新的部分去掉,变成一个树状结构。

累加求和的并行

在这里插入图片描述

我们看到一些节点值是输入矩阵的部分和,例如10是1到4的和。每个节点实际上是它们的子节点或子孙节点的和。

因此我们可以利用这些节点值来计算,如果一个节点是左节点,我们将他们与其兄弟右节点的左子节点相连,将相加得到的值更新到所连接的后代节点。由上而下更新

在这里插入图片描述

但是需要注意的是,我们从始至终没有开辟过新的空间,在原地进行值的更新。因此最下层,即真正的X为如下值,而红框圈出的部分实际并不是sum(3->5),因为他所加的值,即X[3]并不是上层即中间过程中的值即3到4的值,而是最后更新的值,即1到4的和,因此应更正为1->5的求和,至此,所有的累加和均已求出

在这里插入图片描述

因此整个过程分为两部分,一部分是up-sweep,即从下往上,以求整个矩阵总和的形式原地更新矩阵值,而原地操作,使得矩阵中的值实际上是不同层次的,再通过down-sweep,从上到下,利用已更新的值更新完剩下的值

总结下来:

  • 首先向上扫描,把顶部元素,即总和作为根
  • 将当前节点值赋值给右节点,而当前节点的左节点值,是当前节点值减去UP树对应右节点的值
  • 重复

在这里插入图片描述

例如,我们得到28,在最右边节点首先我们得到前八个数的和36,要得到前七个,那么就是减去第八个,即原来这个节点的值为8,36-8 = 28

Blelloch 算法

Blelloch前缀求和,和累加求和的区别是,前缀和不包括自身的值。

Up-sweep

两两成对求和

在这里插入图片描述

在这里插入图片描述

Down-sweep

在这里插入图片描述

在这里插入图片描述

最后在第七步我们得到前缀和

selective_scan

对于状态空间模型的扫描,主要更新的是隐状态,得到隐状态即可得到我们的输入。

通过pscan函数得到隐状态hs(B, L, ED, N)

 def selective_scan(self, x, delta, A, B, C, D):
        # x : (B, L, ED)
        # Δ : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
        
        hs = pscan(deltaA, BX) #隐状态x

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

deltaA实际为 e Δ A e^{\Delta A} eΔA BX实际为 Δ B u \Delta B u ΔBu

deltaA重定义为Aa, BX重定义为Xa
x k = e Δ k A x k − 1 + Δ k B ⋅ u k \begin{aligned} x_k=e^{\Delta_{k}\boldsymbol A}x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k} \end{aligned} xk=eΔkAxk1+ΔkBuk

x k = A a   x k − 1 + X a \begin{aligned} x_k=Aa\ x_{k-1}+Xa \end{aligned} xk=Aa xk1+Xa
具体实现见pscan.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1507393.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka可用与可靠机制

文章目录 kafka的副本机制ACKSIn-Sync Replicas(ISR)Unclean leader electionmin.insync.replicas acks1的情况acks-1的情况acks-1和min.insync.replicas2的情况 同步机制1.follower不对外提供服务的原因2.幂等性的实现 kafka的副本机制 假如3分区&…

Linux/Windows下部署OpenCV环境(Java/SpringBoot/IDEA)

环境 本文基于Linux(CentOS 7)、SpringBoot部署运行OpenCV 4.5.5,并顺带记录Windows/IDEA下如何调试SpringBoot调用OpenCV项目。 Windows下调试 首先我们编写代码,并在Windows/IDEA下调试通过。 下载Windows版安装包&#xff0…

OD_2024_C卷_200分_6、六_连续出牌数量【JAVA】【回溯算法】

题目描述 package odjava;import java.util.Arrays; import java.util.Scanner;public class 六_连续出牌数量 {// 定义扑克牌类static class Card {int num; // 牌号char color; // 花色public Card(int num, String color) {this.num num;this.color color.charAt(0); // 取…

网络套接字1

网络套接字1 📟作者主页:慢热的陕西人 🌴专栏链接:Linux 📣欢迎各位大佬👍点赞🔥关注🚓收藏,🍉留言 本博客主要内容讲解了udp的Linux环境下的使用&#xff0c…

JVM3_数据库连接池虚引用ConnectionFinalizerPhantomReference引起的FullGC压力问题排查

背景 XOP服务运行期间,查看Grafana面板,发现堆内存周期性堆积,观察FullGC的时间,xxx,需要调查下原因 目录 垃圾收集器概述 常见的垃圾收集器分区收集策略为什么CMS没成为默认收集器 查看JVM运行时环境分析快照 Pha…

msfconsole中db_namp的使用方法以及如何让msf连接数据库

一、db_nmap使用方法 1.打开数据库 1.1查看数据库postgresql连接状态 systemctl status postgresql查看数据库postgresql连接状态、 1.2启动postgresql systemctl start postgresql启动postgresql 1.3初始化 msfdb init初始化 2.C段扫描(db_nmap的使用) 2.1 db_nmap -sP 192…

AIGC实战——GPT(Generative Pre-trained Transformer)

AIGC实战——GPT 0. 前言1. GPT 简介2. 葡萄酒评论数据集3. 注意力机制3.1 查询、键和值3.2 多头注意力3.3 因果掩码 4. Transformer4.1 Transformer 块4.2 位置编码 5. 训练GPT6. GPT 分析6.1 生成文本6.2 注意力分数 小结系列链接 0. 前言 注意力机制能够用于构建先进的文本…

windows和linux系统安装redis

Redis安装 Redis安装与启动windows服务 Redis 安装 这样安装完在系统服务中并没有redis服务 redis服务启动 Redis安装与启动Linux服务 1.下载压缩包到服务器 我下载的是最新版本7.0.12,这里我是直接下载到了root目录下 wget https://github.com/redis/redis…

ChatGPT逐步进入留学圈但并不能解决留学规划的问题

2022 年底,一个能像人类一样对话的AI软件ChatGPT,在5天内突破一百万用户,风靡全球,如今用户已达1.8亿。 四个月后,ChatGPT进化为GPT4版本。该版本逻辑、数学推理能力卓越。拿留美标准化考试举例,GPT4能够在…

图论练习6

[NOIP2013]车站分级 Here 解题思路 由于起始点之间所选的站号,相互之间一定满足那么对于起始点间未选择的站号,一定满足选择的站号考虑用边来维护信息,表示的级别大于按题意,则车站会被分为几个联通块,且保证块内无环…

使用Java和PostGis的全国A级风景区数据入库实战

目录 前言 一、数据介绍 1、空间数据 2、属性表说明 3、QGIS数据预览 二、PostGIS空间数据库设计 1、空间表结构 三、Java空间入库 1、实体定义 2、数据操作Mapper 3、业务层实现 4、入库 5、数据入库验证 总结 前言 星垂平野阔,月涌大江流”“晴川历历…

WinoGrande数据集分享

来源: AINLPer公众号(每日干货分享!!) 编辑: ShuYini 校稿: ShuYini 时间: 2024-3-11 该数据集由华盛顿大学的研究人员提出,它是一个大规模的常识推理挑战数据集,包含约44,000个问题,旨在评估和…

【实战项目】网络编程:在Linux环境下基于opencv和socket的人脸识别系统--C++实现

🌞前言 这里我们会实现一个项目:在linux操作系统下基于OpenCV和Socket的人脸识别系统。 目录 🌞前言 🌞一、项目介绍 🌞二、项目分工 🌞三、项目难题 🌞四、实现细节 🌼4.1 关…

【APP逆向】酒仙网预约茅台程序,包含逆向过程详解

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 所属的专栏:爬虫实战,零基础、进阶教学 景天的主页:景天科技苑 文章目录 酒仙网预约抢购茅台1.抓包分析,账户名和密码登录2.短信登录3.登录+茅台预约 密码登录酒仙网预约抢购茅台 目标:账号登…

重启 explorer 进程的正确做法(二)

重启资源管理器进程的方法不唯一,但长期以来大家对实施方法用的不到位。 在上一篇中我认为:“我们往往使用 TerminateProcess 并传入 PID 和特殊结束代码 1 或者 taskkill /f /im 等方法重启资源管理器( explorer.exe ),其实这是不正确的。我…

jdk17出现错误无法初始化主类 和NoClassDefFoundError:Vector的解决方法

概述:网上流传文章大多都是编译和运行都加下面这串代码 --add-modulesjdk.incubator.vector我估计他们大多都是复制粘贴的文章,这种东西就是电子垃圾,在idea中,大多人都习惯用maven来构建java项目,接下来我将讲解使用…

Android14音频进阶:AudioTrack如何巧妙衔接AudioFlinger(五十七)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只…

设计模式十:原型模式

文章目录 1、原型模式1.1 类创建过程1.2 浅拷贝1.3 深拷贝 2、示例2.1 简单形式2.2 复杂形式 3、spring中的原型模式3.1 ArrayList的原型模式3.2 spring中的原型模式 1、原型模式 原型模式就是从一个对象再创建另外一个可定制的对象, 而且不需要知道任何创建的细节。…

前端解决跨域问题( 6种方法 )

本专栏是汇集了一些HTML常常被遗忘的知识,这里算是温故而知新,往往这些零碎的知识点,在你开发中能起到炸惊效果。我们每个人都没有过目不忘,过久不忘的本事,就让这一点点知识慢慢渗透你的脑海。 本专栏的风格是力求简洁…

使用JDBC操作数据库

意志、工作和等待是成功的金字塔的基石。 Will, work and wait are the pyramidal cornerstones for success. 文章目录 JDBC简介:JDBC访问数据库步骤StatementPreparedStatement JDBC简介: 在Java应用程序中,JDBC(Java Database…