笔记01----Transformer高效语义分割解码器模块DEPICT(即插即用)

news2024/11/21 15:23:46

学习笔记01----即插即用的解码器模块DEPICT

    • 前言
    • 源码下载
    • DEPICT实现
    • 实验

前言

文 章 标 题:《Rethinking Decoders for Transformer-based Semantic Segmentation: Compression is All You Need》
当前的 Transformer-based 方法(如 DETR 和其变体)取得了显著进展。但这些解码器(decoder)的设计更多是基于经验,缺乏理论解释,难以确定性能瓶颈并进行进一步改进。
该论文将语义分割任务建模为“从主空间到子空间的信息压缩”问题,强调从高维图像特征中提取类别相关的紧凑表示。
提出 DEPICT 解码器:

  • 基于 自注意力(MSSA) 和 交叉注意力(MSCA) 设计简单高效的解码器。
  • MSSA 构建主子空间,去除冗余,优化图像特征。
  • MSCA 动态提取类别相关特征,生成类别嵌入的低维表示。

源码下载

源代码地址:https://github.com/QishuaiWen/DEPICT

DEPICT实现

在这里插入图片描述
DEPICT流程:
1. 图像特征输入: 通过vit的主干网络对图像进行特征提取。这些特征中可能包含很多不重要的信息,比如背景噪声。我们的目标是提取出与分类相关的特征。
2.sa模式—自注意力模块(MSSA): 通过自注意力机制(Multi-head Subspace Self-Attention, MSSA),捕捉图像块之间的全局关系,去掉不相关信息,优化出更加紧凑的主要特征(主子空间)。它的具体操作是将 类别嵌入向量图像特征进行 拼接操作 输入 MSSA模块进行特征优化。
3.ca模式—交叉注意力模块(MSCA):类别嵌入(这是一个可学习的特征向量)作为查询,图像特征作为键和值,通过交叉注意力(Multi-head Subspace Cross-Attention, MSCA)提取每个类别的相关特征,生成类别嵌入的低维表示。它的具体操作是将 类别嵌入向量 作为 查询向量 通过MSCA进行特征优化。
类别嵌入向量是一个可学习的参数,是从 主空间中提取 出的,与类别强相关的特征子集,是图像特征的降维。
4.生成分割掩码:用点积操作比较图像特征和类别嵌入,生成每块图像属于每个类别的概率。

import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_
from dec_blocks import Transformer
from segm.model.utils import init_weights
class MaskTransformer(nn.Module):
    def __init__(
            self,
            n_cls,#类别数量
            patch_size,# 图像分块大小
            n_layers,  # Transformer 的层数
            n_heads,  # 多头注意力中的头数
            d_model,  # 特征的嵌入维度
            dropout,  # dropout 概率
            mode='ca',  # 模式选择:'ca' (交叉注意力) 或 'sa' (自注意力)
    ):
        super().__init__()

        self.patch_size = patch_size
        self.n_cls = n_cls
        self.mode = mode

        # cls_emb 是类别嵌入矩阵,初始化为随机值,形状为 (1, n_cls, d_model)。
        # 在 DEPICT 中,类别嵌入对应于主子空间的基向量 P
        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        
        if mode == 'sa':
            # 提取图像主特征
            self.net = Transformer(d_model, n_layers, n_heads, 100, dropout)
            self.decoder_norm = nn.LayerNorm(d_model)
        elif mode == 'ca':
            # 用于优化图像特征的主特征
            self.snet = Transformer(d_model, n_layers, n_heads, 100, dropout)
            # 用于进一步提取类别嵌入
            self.cnet = Transformer(d_model, 3, n_heads, 50, dropout)
            self.snorm = nn.LayerNorm(d_model)
            self.cnorm = nn.LayerNorm(d_model)
        else:
            raise ValueError(f"Provided mode: {mode} is not valid.")
            
        self.mask_norm = nn.LayerNorm(n_cls)

        self.apply(init_weights)
        trunc_normal_(self.cls_emb, std=0.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_emb"}

    def forward(self, x, im_size=None):
        H, W = im_size

        GS = H // self.patch_size

        # 扩张维度从(1, n_cls, d_model)到(batch_size,n_cls,d_model)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        
        if self.mode == 'sa':
            # 拼接图像特征和类别嵌入
            # (batch_size,num_patches,d_model)
            x = torch.cat((x, cls_emb), 1)
            # 通过 Transformer 网络
            x = self.net(x)
            # 归一化处理
            x = self.decoder_norm(x)
            # patches优化后的图像特征。
            # cls_seg_feat:更新后的类别嵌入
            patches, cls_seg_feat = x[:, :-self.n_cls], x[:, -self.n_cls:]
        else:
            # 优化图像特征
            x = self.snet(x)
            # 归一化处理
            x = self.snorm(x)
            # 通过交叉注意力提取类别嵌入
            cls_emb = self.cnet(x, query=cls_emb)
            # 归一化
            cls_emb = self.cnorm(cls_emb)
            # patches优化后的图像特征。
            # cls_seg_feat:更新后的类别嵌入
            patches, cls_seg_feat = x, cls_emb

        #  向量标准化
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)

        # 点积操作:生成掩码
        # patches:形状为 (batch_size, num_patches, d_model)。
        # cls_seg_feat:形状为 (batch_size, n_cls, d_model)
        # 转为 (batch_size, d_model, n_cls),方便点积运算。
        # 输出 masks 的形状为 (batch_size, num_patches, n_cls),表示每个 patch 属于每个类别的得分。
        masks = patches @ cls_seg_feat.transpose(1, 2)
        # 标准化为了简化训练
        masks = self.mask_norm(masks)
        # 重排掩码形状
        masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))

        return masks

调用测试代码

def main():
    # 配置参数
    n_cls = 10           # 类别数,例如分割任务有 10 个类别
    patch_size = 16       # 图像分块大小
    n_layers = 4          # Transformer 层数
    n_heads = 8           # 多头注意力头数
    d_model = 128         # 特征嵌入维度
    dropout = 0.1         # dropout 比例
    mode = 'ca'           # 模式选择:'ca' 或 'sa'
    # 初始化 MaskTransformer
    model = MaskTransformer(
        n_cls=n_cls,
        patch_size=patch_size,
        n_layers=n_layers,
        n_heads=n_heads,
        d_model=d_model,
        dropout=dropout,
        mode=mode
    )
    # 测试输入
    batch_size = 2        # 批次大小
    image_size = 128      # 图像尺寸(假设输入为 128x128)
    num_patches = (image_size // patch_size) ** 2  # 分块后有多少个 patch
    # 生成随机的图像特征输入 (batch_size, num_patches, d_model)
    x = torch.randn(batch_size, num_patches, d_model)
    # 设置 im_size
    im_size = (image_size, image_size)
    # 运行模型
    masks = model(x, im_size=im_size)
    # 输出形状
    print("Output masks shape:", masks.shape)

实验

ADE20KcityscapePascalContext数据集
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2244741.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

layui合并table相同内的行

<table border"1" id"table1" class"layui-table"><thead><tr><th><b>姓名</b></th><th><b>项目</b></th><th><b>任务</b></th><th><b>…

【大模型】大模型RAG检索增强生成技术使用详解

目录 一、前言 二、RAG技术介绍 2.1 RAG是什么 2.2 RAG工作原理 2.3 RAG优势 2.4 RAG应用场景 三、在线大模型平台RAG技术使用 3.1 阿里百炼平台 3.1.1 创建知识库 3.1.2 导入文档数据 3.1.3 文档数据解析 3.1.4 查看数据 3.2 百度文心智能体 3.2.1 创建知识库 3…

人工智能与SEO优化中的关键词策略解析

内容概要 在当今数字化快速发展的时代&#xff0c;人工智能&#xff08;AI&#xff09;与搜索引擎优化&#xff08;SEO&#xff09;的结合正变得愈发重要。关键词策略是SEO优化的一项基础工作&#xff0c;它直接影响到网站的可见性和流量。通过运用智能算法&#xff0c;企业能…

【WRF-Urban】WRF 4.3版本中城市模块更新总结

【WRF-Urban】WRF 4.3版本中城市模块更新总结 WRF 4.3 版本中城市模块更新1. 局地气候区&#xff08;LCZ&#xff09;的引入WRF 查找表的修改&#xff1a;如何启用 11 类 LCZ 分类&#xff1a; 2. 屋顶缓解策略与建筑材料渗透性3. 新的建筑物阻力系数处理 使用LCZ的WRF-Urban模…

【Apache Paimon】-- 6 -- 清理过期数据

目录 1、简要介绍 2、操作方式和步骤 2.1、调整快照文件过期时间 2.2、设置分区过期时间 2.2.1、举例1 2.2.2、举例2 2.3、清理废弃文件 3、参考 1、简要介绍 清理 paimon &#xff08;表&#xff09;过期数据可以释放存储空间&#xff0c;优化资源利用并提升系统运行效…

第二十周:机器学习

目录 摘要 ABSTRACT 一、吴恩达机器学习exp2——逻辑回归 1、logistic函数 2、数据预处理 3、损失函数 4、梯度下降 5、设定评价指标 6、决策边界 7、正则化 二、动手深度学习pytorch——数据预处理 1、数据集读取 2、缺失值处理 3、转换为张量格式 总结 摘要…

反转链表、链表内指定区间反转

反转链表 给定一个单链表的头结点pHead&#xff08;该头节点是有值的&#xff0c;比如在下图&#xff0c;它的val是1&#xff09;&#xff0c;长度为n&#xff0c;反转该链表后&#xff0c;返回新链表的表头。 如当输入链表{1,2,3}时&#xff0c;经反转后&#xff0c;原链表变…

VScode学习前端-01

小问题合集&#xff1a; vscode按&#xff01;有时候没反应&#xff0c;有时候出来&#xff0c;是因为------>必须在英文状态下输入&#xff01; 把鼠标放在函数、变量等上面&#xff0c;会自动弹出提示&#xff0c;但挡住视线&#xff0c;有点不习惯。 打开file->pre…

【AI图像生成网站Golang】JWT认证与令牌桶算法

AI图像生成网站 目录 一、项目介绍 二、雪花算法 三、JWT认证与令牌桶算法 四、项目架构 五、图床上传与图像生成API搭建 六、项目测试与调试(等待更新) 三、JWT认证与令牌桶算法 在现代后端开发中&#xff0c;用户认证和接口限流是确保系统安全性和性能的两大关键要素…

基于Kafka2.1解读Consumer原理

文章目录 概要整体架构流程技术名词解释技术细节coordinatorfetcherclientconsumer#poll的主要流程 全局总览小结 概要 继上一篇讲Producer原理的文章过去已经一个多月了&#xff0c;今天来讲讲Consumer的原理。 其实源码早就读了部分了&#xff0c;但是最近工作比较忙&#x…

测试使用vite搭建的uni-app打包app区分开发环境和生产环境

用脚手架搭建的uniapp项目&#xff0c;打包H5和小程序可以和web端一样&#xff0c;能够通过env.dev和env.prod区分开发环境和生产环境&#xff0c;但是不知道打包成app时如何区分开发环境和生产环境&#xff0c;在此做一个测试记录。 打开package.json文件&#xff0c;在scrip…

【提效工具开发】管理Python脚本执行系统实现页面展示

Python脚本执行&#xff1a;工具管理Python脚本执行系统 背景 在现代的软件开发和测试过程中&#xff0c;自动化工具和脚本的管理变得至关重要。为了更高效地管理工具、关联文件、提取执行参数并支持动态执行Python代码&#xff0c;我们设计并实现了一套基于Django框架的工具…

Qt-常用的显示类控件

QLabel QLabel有如下核心属性&#xff1a; 关于文本格式的验证&#xff1a; 其中<b>xxx<b>&#xff0c;就是加粗的意思。 效果&#xff1a; 或者再把它改为markdown形式的&#xff1a; 在markd中&#xff0c;#就是表示一级标题&#xff0c;我们在加上##后&#x…

2024 RISC-V中国峰会 安全相关议题汇总

安全之安全(security)博客目录导读 第四届 RISC-V 中国峰会(RISC-V Summit China 2024)于8月21日至23日在杭州成功举办。此次峰会汇聚了 RISC-V 国际基金会、百余家重点企业及研究机构,约3000人线下参与,并在19日至25日间举办了超过20场同期活动,与全球开发者共同…

聊一聊Elasticsearch的索引分片的恢复机制

1、什么是索引分片的恢复&#xff1f; 所谓索引分片的恢复指的是在某些条件下&#xff0c;索引分片丢失&#xff0c;ES会把某索引的分片复制一份来得到该分片副本的过程。 2、触发分片恢复的场景有哪些&#xff1f; 分片的分配 当集群中节点的数量发生变化&#xff0c;或者配…

典型的 SOME/IP 多绑定用例总结

SOME/IP 部署中 AP SWC 不自行打开套接字连接的原因 在典型的 SOME/IP 网络协议部署场景里&#xff0c;AP SWC 不太可能自己打开套接字连接与远程服务通信&#xff0c;因为 SOME/IP 被设计为尽可能少用端口。这一需求源于低功耗 / 低资源的嵌入式 ECU&#xff0c;并行管理大量…

MySQL查询执行(八):Memory引擎

思考&#xff1a;两个group by语句都用了order bynull&#xff0c; 为什么使用内存临时表得到的语句结果里&#xff0c; 0这个值在最后一行&#xff1b; 而使用磁盘临时表得到的结果里&#xff0c; 0这个值在第一行&#xff1f; 答&#xff1a;答案对应第一小节&#xff1a;内…

canva 画图 UI 设计

起因&#xff0c; 目的: 来源: 客户需求。 目的&#xff1a; 用数据讲故事。 数据可以瞎编&#xff0c;图表一定要漂亮。 文件分享地址 读者可以在此文件的基础上&#xff0c;继续编辑。 效果图 过程: 我还是喜欢 canva. figma&#xff0c; 我用的时候&#xff0c;每每都想…

ES分词环境实战

文章目录 安装下载1.1 下载镜像1.2 单节点启动 防火墙设置异常处理【1】iptable链路中断 参考文档 参加完2024年11月软考&#xff0c;对ES的分词进行考查&#xff0c;前期有【 Docker 环境下安装部署 Elasticsearch 和 kibana】和【 Docker 环境下为 Elasticsearch 安装IK 分…

论文精读: PRB LiVSe2 Zigzag链序实验与理论计算

DOI: 10.1103/PhysRevB.108.094107 摘要节选 在具有轨道自由度的过渡金属化合物中&#xff0c;组成元素在低温下自组装形成分子的现象普遍存在。 在本研究中从实验和理论两方面讨论了钒二维三角形晶格层状LiVX2 &#xff08;X O&#xff0c; S, Se&#xff09;体系中出现的三…