Mamba-minimal Mamba的最小限度实现 (一)

news2024/9/23 21:21:18

文章目录

    • 参数和数据尺寸约定
    • class MambaBlock
      • def forward
      • def __ int__
      • def ssm
      • def selective_scan

johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现

参数和数据尺寸约定

之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示

参数及简写Mamba论文简写
batch_size bB
序列长度 lL
隐藏维度 d / d_model
潜在状态维度 n / d_stateN
扩展因子 expandE
d_in / d_innerD
数据依赖步长 Δ \Delta Δ / delta
delta秩 dt_rank

class MambaBlock

def forward

根据forward简单梳理MambaBlock的结构

在这里插入图片描述

中间变量来源shape
输入x(b, l, d_model)
x_and_resx经过输入映射后(b, l, 2* d_in)
x切分后作为ssm分支输入(b, l, d_in)
res切分后作为门控分支输入(b, l, d_in)
y经过卷积,激活,ssm,门控后的输出(b, l, d_in)
outputy经过输出映射后得到(b, l, d_model)
def forward(self, x):
  
        (b, l, d) = x.shape
        
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')
        
        x = F.silu(x)

        y = self.ssm(x)
        
        y = y * F.silu(res)
        
        output = self.out_proj(y)

        return output
    

def __ int__

初始化主要初始了几个部分

组件定义

操作及简写维度变换
输入映射 in_proj(b, l, d_model) -> (b, l, 2*d_in)
序列变换 conv1d只取前l (b, d_in, l) -> (b, d_in, l)
非线性激活 silu
输出映射 out_proj(b, l, d_in) -> (b, l, d)

在这里插入图片描述
ssm初始化

操作及简写作用
参数生成映射 x_proj生成数据依赖的参数B, C, Δ \Delta Δ
delta映射 dt_proj Δ \Delta Δ从dt_rank映射到d_in
矩阵A初始化简单初始化
矩阵D初始化简单初始化
def __init__(self, args: ModelArgs):

        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )
        
         # ssm模型的初始化部分
         # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

def ssm

这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。

SSM参数shape来源
状态矩阵A(d_in, n)在初始化中定义,非数据依赖
输入矩阵B(b, l, n)由x_db1切分而来,因此数据依赖
输出矩阵C(b, l, n)由x_db1切分而来,因此数据依赖
直接传递矩阵D(d_in)在初始化中定义,非数据依赖
数据依赖步长 Δ \Delta Δ(b, l, d_in)由x_db1切分而来,因此数据依赖

其中一部分变量初始化于class MambaBlock的初始化部分

中间变量及简写来源
数据生成变量 x_db1x经过参数映射x_proj生成
最终delta Δ \Delta Δ切分而来的 Δ \Delta Δ经过映射和softplus
 def ssm(self, x):

        (d_in, n) = self.A_log.shape
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y
SSM参数shape
状态矩阵A(d_in, n)
输入矩阵B(b, l, n)
输出矩阵C(b, l, n)
直接传递矩阵D(d_in)

def selective_scan

我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。

在这里插入图片描述

在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化

前向欧拉离散化
x k = ( I + Δ k A ) x k − 1 + Δ k B ⋅ u k x ( t + Δ ) = ( I + Δ A ) x ( t ) + Δ B ⋅ u ( t ) \begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned} xkx(t+Δ)=(I+ΔkA)xk1+ΔkBuk=(I+ΔA)x(t)+ΔBu(t)

零阶保持离散化
x k = e Δ k A x k − 1 + ( Δ k A ) − 1 ( e Δ k A − I ) ⋅ Δ k B ⋅ u k x ( t + Δ ) = e Δ A x ( t ) + ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ⋅ u ( t ) \begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned} xkx(t+Δ)=eΔkAxk1+(ΔkA)1(eΔkAI)ΔkBuk=eΔAx(t)+(ΔA)1(eΔAI)ΔBu(t)

这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢

def selective_scan(self, u, delta, A, B, C, D):

        (b, l, d_in) = u.shape
        n = A.shape[1]
        

        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])

        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        
        y = y + u * D
    
        return y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1502671.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入了解Kafka中生产者的神奇力量

欢迎来到我的博客,代码的世界里,每一行都是一个故事 深入了解Kafka中生产者的神奇力量 前言生产者的基本概念Kafka 生产者的定义:Kafka 生产者的基本原理:为何生产者是 Kafka 消息传递的创造者: 生产者的创建于配置生产…

新版AndroidStudio的Gradle窗口显示task list not built 问题解决

在使用新版AndroidStudio时,会出现,Task List not built 的问题。如果你记得task的名字,当然可以 直接通过命令 gradle taskname 或者 ./gradlew taskName直接执行即可,但是若是记不住,还是把这个任务构建处理比较好用…

智慧文旅|AI数字人导览:让旅游体验不再局限于传统

AI数字人导览作为一种创新的展示方式,已经逐渐成为了VR全景领域的一大亮点,不仅可以很好的嵌入在VR全景中,更是能够随时随地为观众提供一种声情并茂的讲解介绍,结合VR场景的沉浸式体验,让观众仿佛置身于真实场景之中&a…

『python爬虫』requests实战-精易论坛自动签到(保姆级图文+实现代码)

目录 实现效果API命令解析re.findall 匹配内容,用于在我们得到的网页源码中查找指定的内容session.post() 和 session.get() 实现思路库cookie怎么抓取cookie登录如何实现得到FORMHASH参数自动签到自动评分 实现代码后续优化总结 欢迎关注 『python爬虫』 专栏,持续…

Midjourney绘图欣赏系列(七)

Midjourney介绍 Midjourney 是生成式人工智能的一个很好的例子,它根据文本提示创建图像。它与 Dall-E 和 Stable Diffusion 一起成为最流行的 AI 艺术创作工具之一。与竞争对手不同,Midjourney 是自筹资金且闭源的,因此确切了解其幕后内容尚不…

Idea创建Maven项目

Maven安装配置步骤: 解压安装 bin目录 : 存放的是可执行命令。(mvn 命令重点关注) conf目录 :存放Maven的配置文件。(settings.xml配置文件后期需要修改) lib目录 :存放Maven依赖的j…

KH-MCX-KWE-W

KH-MCX-KWE-W品牌: kinghelm(金航标)封装: 插件 描述: 镀金

【教程】Github环境配置新手指南(超详细)

写在前面: 如果文章对你有帮助,记得点赞关注加收藏一波,利于以后需要的时候复习,多谢支持! 文章目录 一、Github初始设置(一)登入Github(二)新建仓库 二、本地Git配置&am…

在线部署ubuntu20.04服务器,安装jdk、mysql、redis、nginx、minio、开机自启微服务jar包

一、服务器 1、查看服务器版本 查看服务器版本为20.04 lsb_release -a2、服务器信息 服务器初始账号密码 sxd / 123456 首先,更改自身密码都输入123456 sudo passwd 创建最高权限root账号,密码为 123456 su root 3、更新服务器源 1、更新源列表 sudo apt-g…

tomcat优化与部署(三)------nignx优化与nginx +tomcat 部署

在目前流行的互联网架构中,Tomcat在目前的网络编程中是举足轻重的,由于Tomcat的运行依赖于JVM,从虚拟机的角度把Tomcat的调整分为外部环境调优 JVM 和 Tomcat 自身调优两部分 Tomcat 是一个流行的开源 Java 服务器,用于托管 Java …

简单题我重拳出击

有请第一位嘉宾:. - 力扣(LeetCode) 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 n…

代码随想录训练营第40天 | LeetCode 343. 整数拆分

LeetCode 343. 整数拆分 文章讲解:代码随想录(programmercarl.com) 视频讲解:动态规划,本题关键在于理解递推公式!| LeetCode:343. 整数拆分_哔哩哔哩_bilibili 思路 代码如下: ​​​​​​LeetCode 96…

【产品应用】一体化步进伺服电机在绿光激光打标机中的应用

随着科技的不断发展,激光打标技术已经成为现代工业生产中不可或缺的一部分。绿光激光打标机以其高精度、高效率、高可靠性等特点,广泛应用于各种材料的标记与打标。而在绿光激光打标机中,一体化步进电机的应用则为其带来了更高的性能与更稳定…

Lesson 5 Classification(short version)

听课(李宏毅老师的)笔记,方便梳理框架,以作复习之用。本节课主要讲了回归和分类的区别,分类的过程,分类的损失函数。这节课比较简短。 1. 回归和分类的区别 回归只是输出一个预测的值分类是输出预测的cla…

【Leetcode每日一刷】数组|双指针篇:977. 有序数组的平方、76. 最小覆盖子串(附滑动窗口法详解)

力扣每日刷题 一、977. 有序数组的平方1.1题目1.2、解题思路1.3、代码实现——C 二、76. 最小覆盖子串2.1:题目2.2、解题思路2.3:代码实现——c2.4:易错点 一、977. 有序数组的平方 1.1题目 [题目链接]( 1.2、解题思路 题型:双…

请编程输出无向无权图各个顶点的度 ← STL vector 模拟邻接表存图

【题目描述】 请利用 STL vector 模拟邻接表存图,编程输出无向无权图各个顶点的度。【输入样例】 5 6 1 3 2 1 1 4 2 3 3 4 5 1【输出样例】 4 2 3 2 1【算法分析】 本例利用 STL vector 模拟实现邻接表。代码参见:https://blog.csdn.net/hnjzsyjyj/arti…

服务器配置禁止IP直接访问,只允许域名访问

联网信息系统需设置只允许通过域名访问,禁止使用IP地址直接访问,建议同时采用云防护技术隐藏系统真实IP地址且只允许云防护节点IP访问服务器,提升网络安全防护能力。 一、Nginx 修改配置文件nginx.conf,在server段里插入正则表达式…

Redis系列之持久化机制RDB和AOF

Redis系列之持久化机制RDB和AOF 文章目录 1. 为什么需要持久化?2. 持久化的方式3. RDB机制3.1 RDB机制介绍3.2 配置RDB3.3 什么时候触发3.4 操作实例3.5 RDB优势和不足 4. AOF机制4.1 什么是AOF机制?4.2 同步机制4.3 重写机制4.4 AOF的优势和不足 混合模…

C++的面向诗篇:类的叙事与对象的旋律

个人主页:日刷百题 系列专栏:〖C/C小游戏〗〖Linux〗〖数据结构〗 〖C语言〗 🌎欢迎各位→点赞👍收藏⭐️留言📝 ​ ​ 一、面向对象的定义 学习C语言时,我们就经常听说C语言是面向过程的,…

3.7号freeRtoS

1. 串口通信 配置串口为异步通信 设置波特率,数据位,校验位,停止位,数据的方向 同步通信 在同步通信中,数据的传输是在发送端和接收端之间通过一个共享的时钟信号进行同步的。这意味着发送端和接收端的时钟需要保持…