VSSM VMamba实现

news2025/1/14 0:44:38

文章目录

    • VSSM
      • 维度变换
      • 初始化
      • 模型参数初始化
      • 模型搭建
        • def_make_layer
        • def _make_downsample
      • patch embed
      • 第一至四阶段
      • 分类器
    • VSSBlock
      • def __ init__
        • ssm分支
        • mlp分支
      • def forward

VSSM

Mamba实现可以参照之前的
mamba_minimal系列
论文地址:
VMamba
论文阅读:
VMamba:视觉状态空间模型
代码地址:
https://github.com/MzeroMiko/VMamba.git
SS2D实现

以分类任务用到的VMamba为例。

维度变换

操作的具体参数定义见初始化

阶段维度
输入x [ B , C , H , W ] [B, C, H, W] [B,C,H,W]
embed [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
阶段1 [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
阶段2 [ B , H / 8 , W / 8 , C 2 ] [B, H/8, W/8, C_2 ] [B,H/8,W/8,C2]
阶段3 [ B , H / 16 , W / 16 , C 3 ] [B, H/16, W/16, C_3 ] [B,H/16,W/16,C3]
阶段4 [ B , H / 32 , W / 32 , C 4 ] [B, H/32, W/32, C_4 ] [B,H/32,W/32,C4]
分类器 [ B , 1000 ] [B, 1000 ] [B,1000]

在这里插入图片描述

初始化

参数定义说明
in_chans3输入图像的通道数
depths[2, 2, 9, 2]定义每层的VSS Block数
dims[96, 192, 384, 768]定义每层的输出通道数
downsample_versionv2下采样操作的版本
patchembed_versionv1图像嵌入
mlp_ratio4.0定义mlp隐藏维度缩放
ssm_d_state16ssm隐状态的维度
ssm_ratio2.0d_inner = d_state * ssm_ratio
ssm_initv0ssm初始化版本
forward_typev2ssm前向版本

模型参数初始化

大部分参数即SS2D,VSS块中的参数由定义的ssm初始化版本初始化,剩下的线性层和归一化层参数由下面的函数初始化。

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

模型搭建

def_make_layer

构建VSSM的4个阶段,即4层,VSSBlock本身并不改变输入的尺寸,因此需要下采样模块将输出维度变换为下一阶段的输入维度

def _make_layer(
        dim=96, 
        drop_path=[0.1, 0.1], 
        use_checkpoint=False, 
        norm_layer=nn.LayerNorm,
        downsample=nn.Identity(),
        # ===========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",       
        ssm_act_layer=nn.SiLU,
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0, 
        ssm_init="v0",
        forward_type="v2",
        # ===========================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate=0.0,
        **kwargs,
    ):
        depth = len(drop_path)
        blocks = []
        for d in range(depth):
            blocks.append(VSSBlock(
                hidden_dim=dim, 
                drop_path=drop_path[d],
                norm_layer=norm_layer,
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                use_checkpoint=use_checkpoint,
            ))
        
        return nn.Sequential(OrderedDict(
            blocks=nn.Sequential(*blocks,),
            downsample=downsample,
        ))
def _make_downsample

默认下采样版本v2

下采样模块,通过2D卷积之后,长宽变为原来的一半,通道数不变

    def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm):
        return nn.Sequential(
            Permute(0, 3, 1, 2),
            nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
            Permute(0, 2, 3, 1),
            norm_layer(out_dim),
        )

patch embed

默认嵌入版本v1,对输入图像进行embed

输入x维度 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W],嵌入后通道维变为96, H = H p a t c h _ s i z e H = \frac{H}{patch\_size} H=patch_sizeH W = W p a t c h _ s i z e W = \frac{W}{patch\_size} W=patch_sizeW [ B , 96 , H 4 , W 4 ] [B, 96, \frac{H}{4}, \frac{W}{4}] [B,96,4H,4W]

 def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm):
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
            Permute(0, 2, 3, 1),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()), 
        )

第一至四阶段

这几个阶段的差别在于每一层的VSSBlock数不同,由depths定义分别为 [2, 2, 9, 2],输出维度由dims定义分别为[96, 192, 384, 768]。其组成元素除一阶段外,均在VSSBlock前包含下采样模块以变换维度。

具体介绍见VSSBlock

分类器

池化后长宽变为1,则变量尺寸变为 [ B , C , 1 , 1 ] [B, C, 1, 1] [B,C,1,1],展平后变为 [ B , C ] [B, C] [B,C]最后线性投影到类别维度1000

[ B , 1000 ] [B, 1000] [B,1000]

self.classifier = nn.Sequential(OrderedDict(
            norm=norm_layer(self.num_features), # B,H,W,C
            permute=Permute(0, 3, 1, 2),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

VSSBlock

对于ssm分支来说,其输入输出维度不变为(B, H, W, d_model) ,对于mlp分支来说中间的隐藏维度根据mlp_ratio参数定义会有所增加,但是最后又会映射为原来的维度,因此整体上并不改变输入的维度。

def __ init__

主要分为两个分支ssm分支和mlp分支

ssm分支

主要组成部分是SS2D块
SS2D实现

if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            self.op = _SS2D(
                d_model=hidden_dim, 
                d_state=ssm_d_state, 
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                # ==========================
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                # ==========================
                dropout=ssm_drop_rate,
                # =========================
                initialize=ssm_init,
                forward_type=forward_type,
            )      

图中的SS2D和SS2D类的定义有偏差,简单来说是是包含SS2D块加一个残差连接,图中所示SS2D应表示状态空间模型SSM部分,即VSS块相比SS2D块只增加了残差连接和入口的归一化。如果定义了MLP分支,VSS块的输出还会经过一个残差连接的两层MLP

在这里插入图片描述

mlp分支
 if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False)
            
 class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
        self.fc1 = Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def forward

    def _forward(self, input: torch.Tensor):
        if self.ssm_branch:
            if self.post_norm:
                x = input + self.drop_path(self.norm(self.op(input)))
            else:
                x = input + self.drop_path(self.op(self.norm(input)))
        if self.mlp_branch:
            if self.post_norm:
                x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
            else:
                x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
        return x

     
    
    def forward(self, input: torch.Tensor):
        if self.use_checkpoint:
            return checkpoint.checkpoint(self._forward, input)
        else:
            return self._forward(input)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1519538.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis 的常用基本全局命令【小林优选】

前言 Redis 常用的有 5 种数据结构,字符串,列表,哈希表,集合,有序集合,每一种数据结构都有自己独特的命令,但也有些通用的全局命令,本文所提到的是最基本的命令,Redis 的…

CIDR网络地址、广播地址、网段区间计算说明与计算工具

文章目录 开始问题参考答案 答案解析计算工具测试 开始 好久没有看计算网络,感觉已经完全返给老师了。 最近,有同事遇到个问题,网络一直不对,又开始重新看一下。 相信很多朋友长时间不看也忘了,所以,这里…

TCP机械臂控制

通过w(红色臂角度增大)s(红色臂角度减小)d(蓝色臂角度增大)a(蓝色臂角度减小)按键控制机械臂 注意:关闭计算机的杀毒软件,电脑管家,防火墙 1)基于TCP服务器…

数据泄露态势(2024年2月)

监控说明:以下数据由零零信安0.zone安全开源情报系统提供,该系统监控范围包括约10万个明网、深网、暗网、匿名社交社群威胁源。在进行抽样事件分析时,涉及到我国的数据不会选取任何政府、安全与公共事务的事件进行分析。如遇到影响较大的伪造…

专业款希亦、小米、必胜、云鲸洗地机怎么样?深度测评利弊

洗地机可以说是一种非常实用的清洁工具,尤其是对于那些需要经常给家里地板清洁的人来说。它能够高效、彻底清洁地板,去除顽固污渍、灰尘和细菌,让家居环境更加洁净卫生。可是面对型号繁多的洗地机,我们应该怎么挑选呢?…

架构设计-复杂度来源:高性能

对性能孜孜不倦的追求是整个人类技术不断发展的根本驱动力。例如计算机,从电子管计算机到晶体管计算机再到集成电路计算机,运算性能从每秒几次提升到每秒几亿次。但伴随性能越来越高,相应的方法和系统复杂度也是越来越高。现代的计算机 CPU 集…

实现elasticsearch和数据库的数据同步

1. 数据同步 elasticsearch中的酒店数据来自于mysql数据库,因此mysql数据发生改变时,elasticsearch也必须跟着改变,这个就是elasticsearch与mysql之间的数据同步。 1.1. 思路分析 常见的数据同步方案有三种: 同步调用 异步通知…

element-plus表格,多样本比较,动态渲染表头

问题: 公司给了个excel表格,让比较两个样本之间的数据,并且动态渲染,搞了半天没搞出来,最后让大佬解决了,特此写篇博客记录一下。 我之前的思路是合并行,大概效果是这样: 但是最终…

微服务学习day02 -- nacos配置管理 -- Feign远程调用 -- Gateway服务网关

0.学习目标 1.Nacos配置管理 Nacos除了可以做注册中心,同样可以做配置管理来使用。 1.1.统一配置管理 当微服务部署的实例越来越多,达到数十、数百时,逐个修改微服务配置就会让人抓狂,而且很容易出错。我们需要一种统一配置管理…

雷卯有多种封装的超低电容ESD供您选择

1.应用 HDMI 1.3、1.4、2.0、2.1 接口 LCD、HDTV MIPI 接口 手机、 天线(手机、GPS...) 高速以太网 100/1000 以太网 USB 2.0 和 USB3.0、DVI、LVDS、IEEE 1394 接口 车载信息箱、VN(车载导航) 卫星导航、便携式导航 …

3.2网安学习第三阶段第二周回顾(个人学习记录使用)

本周重点 ①SQL语句的基本用法 ②SQL注入的基本概念和原理 ③SQL注入类型(**重点) ④SQL注入的防御和绕过手段 本周主要内容–SQL 一、SQL语句的基本用法 limit用法:显示查询结果中从第n条开始显示m条记录 select * from tb_user limit 1,2union用法&#x…

构建部署_Docker常用命令

构建部署_Docker常见命令 启动命令镜像命令容器命令 启动命令 启动docker:systemctl start docker 停止docker:systemctl stop docker 重启docker:systemctl restart docker 查看docker状态:systemctl status docker 开机启动&…

【Shiro反序列化漏洞】Shiro-550反序列化漏洞复现

🍬 博主介绍👨‍🎓 博主介绍:大家好,我是 hacker-routing ,很高兴认识大家~ ✨主攻领域:【渗透领域】【应急响应】 【Java、PHP】 【VulnHub靶场复现】【面试分析】 🎉点赞➕评论➕收…

腾讯云4核8G服务器性能怎么样?搭建网站够用吗?

腾讯云轻量4核8G12M服务器配置446元一年,646元12个月,腾讯云轻量应用服务器具有100%CPU性能,系统盘为180GB SSD盘,12M带宽下载速度1536KB/秒,月流量2000GB,折合每天66.6GB流量,超出月流量包的流…

【Datawhale组队学习:Sora原理与技术实战】训练一个 sora 模型的准备工作,video caption 和算力评估

训练 Sora 模型 在 Sora 的技术报告中,Sora 使用视频压缩网络将各种大小的视频压缩为潜在空间中的时空 patches sequence,然后使用 Diffusion Transformer 进行去噪,最后解码生成视频。 Open-Sora 在下图中总结了 Sora 可能使用的训练流程。…

快手发布革命性视频运动控制技术 DragAnything,拖动锚点精准控制视频物体和镜头运动

快手联合浙江大学、新加坡国立大学发布了DragAnything ,利用实体表示实现对任何物体的运动控制。该技术可以精确控制物体的运动,包括前景、背景和相机等不同元素。 该项目提供了对实体级别运动控制的新见解,通过实体表示揭示了像素级运动和实…

layuiAdmin-通用型后台模板框架【广泛用于各类管理平台】

1. 主页 1.1 控制台 2. 组件 3. 页面 3.1 个人主页 3.2 通讯录 3.3 客户列表 3.4 商品列表 3.5 留言板 3.6 搜索结果 3.7 注册 3.8 登入 3.9 忘记密码 4. 应用 4.1 内容系统 4.1.1 文章列表 4.1.2 分类管理 4.1.3 评论管理 4.2 社区系统 4.2.1 帖子列表 4.2.2 回…

数据结构---C语言栈队列

知识点: 栈: 只允许在一端进行插入或删除操作的线性表,先进后出LIFO 类似一摞书,按顺序拿,先放的书只能最后拿; 顺序栈:栈的顺序存储 typedef struct{Elemtype data[50];int top; }SqStack; SqS…

724.寻找数组的中心下标

题目:给你一个整数数组 nums ,请计算数组的 中心下标 。 数组 中心下标 是数组的一个下标,其左侧所有元素相加的和等于右侧所有元素相加的和。 如果中心下标位于数组最左端,那么左侧数之和视为 0 ,因为在下标的左侧不…

力扣爆刷第96天之hot100五连刷66-70

力扣爆刷第96天之hot100五连刷66-70 文章目录 力扣爆刷第96天之hot100五连刷66-70一、33. 搜索旋转排序数组二、153. 寻找旋转排序数组中的最小值三、4. 寻找两个正序数组的中位数四、20. 有效的括号五、155. 最小栈 一、33. 搜索旋转排序数组 题目链接:https://le…