学习笔记01----即插即用的解码器模块DEPICT
- 前言
- 源码下载
- DEPICT实现
- 实验
前言
文 章 标 题:《Rethinking Decoders for Transformer-based Semantic Segmentation: Compression is All You Need》
当前的 Transformer-based 方法(如 DETR 和其变体)取得了显著进展。但这些解码器(decoder)的设计更多是基于经验,缺乏理论解释,难以确定性能瓶颈并进行进一步改进。
该论文将语义分割任务建模为“从主空间到子空间的信息压缩”问题,强调从高维图像特征中提取类别相关的紧凑表示。
提出 DEPICT 解码器:
- 基于 自注意力(MSSA) 和 交叉注意力(MSCA) 设计简单高效的解码器。
- MSSA 构建主子空间,去除冗余,优化图像特征。
- MSCA 动态提取类别相关特征,生成类别嵌入的低维表示。
源码下载
源代码地址:https://github.com/QishuaiWen/DEPICT
DEPICT实现
DEPICT流程:
1. 图像特征输入: 通过vit的主干网络对图像进行特征提取。这些特征中可能包含很多不重要的信息,比如背景噪声。我们的目标是提取出与分类相关的特征。
2.sa模式—自注意力模块(MSSA): 通过自注意力机制(Multi-head Subspace Self-Attention, MSSA),捕捉图像块之间的全局关系,去掉不相关信息,优化出更加紧凑的主要特征(主子空间)。它的具体操作是将 类别嵌入向量 与 图像特征进行 拼接操作 输入 MSSA模块进行特征优化。
3.ca模式—交叉注意力模块(MSCA):类别嵌入(这是一个可学习的特征向量)作为查询,图像特征作为键和值,通过交叉注意力(Multi-head Subspace Cross-Attention, MSCA)提取每个类别的相关特征,生成类别嵌入的低维表示。它的具体操作是将 类别嵌入向量 作为 查询向量 通过MSCA进行特征优化。
类别嵌入向量是一个可学习的参数,是从 主空间中提取 出的,与类别强相关的特征子集,是图像特征的降维。
4.生成分割掩码:用点积操作比较图像特征和类别嵌入,生成每块图像属于每个类别的概率。
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_
from dec_blocks import Transformer
from segm.model.utils import init_weights
class MaskTransformer(nn.Module):
def __init__(
self,
n_cls,#类别数量
patch_size,# 图像分块大小
n_layers, # Transformer 的层数
n_heads, # 多头注意力中的头数
d_model, # 特征的嵌入维度
dropout, # dropout 概率
mode='ca', # 模式选择:'ca' (交叉注意力) 或 'sa' (自注意力)
):
super().__init__()
self.patch_size = patch_size
self.n_cls = n_cls
self.mode = mode
# cls_emb 是类别嵌入矩阵,初始化为随机值,形状为 (1, n_cls, d_model)。
# 在 DEPICT 中,类别嵌入对应于主子空间的基向量 P
self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
if mode == 'sa':
# 提取图像主特征
self.net = Transformer(d_model, n_layers, n_heads, 100, dropout)
self.decoder_norm = nn.LayerNorm(d_model)
elif mode == 'ca':
# 用于优化图像特征的主特征
self.snet = Transformer(d_model, n_layers, n_heads, 100, dropout)
# 用于进一步提取类别嵌入
self.cnet = Transformer(d_model, 3, n_heads, 50, dropout)
self.snorm = nn.LayerNorm(d_model)
self.cnorm = nn.LayerNorm(d_model)
else:
raise ValueError(f"Provided mode: {mode} is not valid.")
self.mask_norm = nn.LayerNorm(n_cls)
self.apply(init_weights)
trunc_normal_(self.cls_emb, std=0.02)
@torch.jit.ignore
def no_weight_decay(self):
return {"cls_emb"}
def forward(self, x, im_size=None):
H, W = im_size
GS = H // self.patch_size
# 扩张维度从(1, n_cls, d_model)到(batch_size,n_cls,d_model)
cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
if self.mode == 'sa':
# 拼接图像特征和类别嵌入
# (batch_size,num_patches,d_model)
x = torch.cat((x, cls_emb), 1)
# 通过 Transformer 网络
x = self.net(x)
# 归一化处理
x = self.decoder_norm(x)
# patches优化后的图像特征。
# cls_seg_feat:更新后的类别嵌入
patches, cls_seg_feat = x[:, :-self.n_cls], x[:, -self.n_cls:]
else:
# 优化图像特征
x = self.snet(x)
# 归一化处理
x = self.snorm(x)
# 通过交叉注意力提取类别嵌入
cls_emb = self.cnet(x, query=cls_emb)
# 归一化
cls_emb = self.cnorm(cls_emb)
# patches优化后的图像特征。
# cls_seg_feat:更新后的类别嵌入
patches, cls_seg_feat = x, cls_emb
# 向量标准化
patches = patches / patches.norm(dim=-1, keepdim=True)
cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
# 点积操作:生成掩码
# patches:形状为 (batch_size, num_patches, d_model)。
# cls_seg_feat:形状为 (batch_size, n_cls, d_model)
# 转为 (batch_size, d_model, n_cls),方便点积运算。
# 输出 masks 的形状为 (batch_size, num_patches, n_cls),表示每个 patch 属于每个类别的得分。
masks = patches @ cls_seg_feat.transpose(1, 2)
# 标准化为了简化训练
masks = self.mask_norm(masks)
# 重排掩码形状
masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
return masks
调用测试代码:
def main():
# 配置参数
n_cls = 10 # 类别数,例如分割任务有 10 个类别
patch_size = 16 # 图像分块大小
n_layers = 4 # Transformer 层数
n_heads = 8 # 多头注意力头数
d_model = 128 # 特征嵌入维度
dropout = 0.1 # dropout 比例
mode = 'ca' # 模式选择:'ca' 或 'sa'
# 初始化 MaskTransformer
model = MaskTransformer(
n_cls=n_cls,
patch_size=patch_size,
n_layers=n_layers,
n_heads=n_heads,
d_model=d_model,
dropout=dropout,
mode=mode
)
# 测试输入
batch_size = 2 # 批次大小
image_size = 128 # 图像尺寸(假设输入为 128x128)
num_patches = (image_size // patch_size) ** 2 # 分块后有多少个 patch
# 生成随机的图像特征输入 (batch_size, num_patches, d_model)
x = torch.randn(batch_size, num_patches, d_model)
# 设置 im_size
im_size = (image_size, image_size)
# 运行模型
masks = model(x, im_size=im_size)
# 输出形状
print("Output masks shape:", masks.shape)
实验
在ADE20K、cityscape和PascalContext数据集