论文阅读与源码解析:CMX

news2024/9/20 18:45:44

论文阅读与源码解析:CMX: Cross-Modal Fusion for RGB-X Semantic Segmentation with Transformers

论文地址:https://arxiv.org/pdf/2203.04838
GitHub项目地址:https://github.com/huaaaliu/RGBX_Semantic_Segmentation
源码:https://github.com/huaaaliu/RGBX_Semantic_Segmentation/blob/main/models/net_utils.py

Motivation

不同类型的传感器可以提供具有丰富互补信息的RGB图像。例如,深度测量可以帮助识别物体的边界,并提供密集场景元素的几何信息。热图像有助于通过特定的红外成像识别不同的物体。此外,极化和事件信息有利于镜面反射和动态真实场景的感知。激光雷达数据可以在驾驶场景中提供空间信息。
img_3
现有的多模态语义分割方法可以分为两类:(1)第一类采用单一网络从RGB和另一种模态中提取特征,融合在输入阶段(见图2a)。(2) 第二类方法部署两个主干分别从 RGB- 和另一种模态中提取特征,然后将提取的两个特征融合为一个特征以进行语义预测(见图 2b)。然而,这两种类型通常是针对单个特定模态对(例如 RGB-D 或 RGB-T)量身定制的,但很难扩展到其他模态组合进行操作。为了解决上述挑战,我们提出了 CMX,这是一种通用跨模态融合框架,用于交互式融合方式的 RGB-X 语义分割(图 2c)。具体来说,CMX 被构建为双流架构,即 RGB 和 X 模态流。设计了两个特定的模块,用于两者之间的特征交互和特征融合。
20240828105110

Method

作者在特征提取阶段让两种模态的特征进行交互,并且融合。
20240828105351

  1. 跨模态特征校正模块(CM-FRM),通过利用它们的空间和通道相关性来校准双模态特征,这使得两个流能够更多地关注彼此的互补信息线索,并减轻来自不同模态的不确定性和噪声测量的影响。这种特征校正解决了不同模式的不同噪声和不确定性。它支持更好的多模态特征提取和交互。
  2. 特征融合模块(FFM)分两个阶段构建,在合并特征之前进行充分的信息交换。受自我注意获得的大接受域的启发,在FFM的第一阶段设计了一种交叉注意机制来实现跨模态全局推理。在第二阶段,应用混合通道嵌入来产生增强的输出特征。

源码解读

  1. 所需要的包
import torch
import torch.nn as nn

from timm.models.layers import trunc_normal_
import math
  1. CM-FRM模块本质就是利用通道注意力和空间注意力时两个模态的特征进行交互,也就是将两种模态cat后计算的通道注意力和空间注意力后,在计算注意力时进行了一定的交互,然后再将对应的注意力权重乘以特征后按一定比例加到另一个模态特征中。
# Feature Rectify Module  
# 计算通道注意力,这里利用了两种池化(最大池化和平均池化)
# 将两种池化的结果cat然后计算通道注意力权重
class ChannelWeights(nn.Module):
    def __init__(self, dim, reduction=1):
        super(ChannelWeights, self).__init__()
        self.dim = dim
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
                    nn.Linear(self.dim * 4, self.dim * 4 // reduction),
                    nn.ReLU(inplace=True),
                    nn.Linear(self.dim * 4 // reduction, self.dim * 2), 
                    nn.Sigmoid())

    def forward(self, x1, x2):
        B, _, H, W = x1.shape
        x = torch.cat((x1, x2), dim=1)
        avg = self.avg_pool(x).view(B, self.dim * 2)
        max = self.max_pool(x).view(B, self.dim * 2)
        y = torch.cat((avg, max), dim=1) # B 4C
        y = self.mlp(y).view(B, self.dim * 2, 1)
        channel_weights = y.reshape(B, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4) # 2 B C 1 1
        return channel_weights

# 计算空间注意力
class SpatialWeights(nn.Module):
    def __init__(self, dim, reduction=1):
        super(SpatialWeights, self).__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
                    nn.Conv2d(self.dim * 2, self.dim // reduction, kernel_size=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(self.dim // reduction, 2, kernel_size=1), 
                    nn.Sigmoid())

    def forward(self, x1, x2):
        B, _, H, W = x1.shape
        x = torch.cat((x1, x2), dim=1) # B 2C H W
        spatial_weights = self.mlp(x).reshape(B, 2, 1, H, W).permute(1, 0, 2, 3, 4) # 2 B 1 H W
        return spatial_weights

# FRM模块
class FeatureRectifyModule(nn.Module):
    def __init__(self, dim, reduction=1, lambda_c=.5, lambda_s=.5):
        super(FeatureRectifyModule, self).__init__()
        self.lambda_c = lambda_c
        self.lambda_s = lambda_s
        self.channel_weights = ChannelWeights(dim=dim, reduction=reduction)
        self.spatial_weights = SpatialWeights(dim=dim, reduction=reduction)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    
    def forward(self, x1, x2):
        channel_weights = self.channel_weights(x1, x2)
        spatial_weights = self.spatial_weights(x1, x2)
        out_x1 = x1 + self.lambda_c * channel_weights[1] * x2 + self.lambda_s * spatial_weights[1] * x2
        out_x2 = x2 + self.lambda_c * channel_weights[0] * x1 + self.lambda_s * spatial_weights[0] * x1
        return out_x1, out_x2 

# Stage 1
# 利用的是线性注意力公式,也就是先将kv相乘,然后再与q相乘, 这样可以减少计算量
# 交叉注意力也就是一种模态的q去另一种模态生成的kv图里面去查询
class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
        super(CrossAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.kv1 = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.kv2 = nn.Linear(dim, dim * 2, bias=qkv_bias)

    def forward(self, x1, x2):
        B, N, C = x1.shape
        q1 = x1.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
        q2 = x2.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
        k1, v1 = self.kv1(x1).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        k2, v2 = self.kv2(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()

        ctx1 = (k1.transpose(-2, -1) @ v1) * self.scale
        ctx1 = ctx1.softmax(dim=-2)
        ctx2 = (k2.transpose(-2, -1) @ v2) * self.scale
        ctx2 = ctx2.softmax(dim=-2)

        x1 = (q1 @ ctx2).permute(0, 2, 1, 3).reshape(B, N, C).contiguous() 
        x2 = (q2 @ ctx1).permute(0, 2, 1, 3).reshape(B, N, C).contiguous() 

        return x1, x2


class CrossPath(nn.Module):
    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.channel_proj1 = nn.Linear(dim, dim // reduction * 2)
        self.channel_proj2 = nn.Linear(dim, dim // reduction * 2)
        self.act1 = nn.ReLU(inplace=True)
        self.act2 = nn.ReLU(inplace=True)
        self.cross_attn = CrossAttention(dim // reduction, num_heads=num_heads)
        self.end_proj1 = nn.Linear(dim // reduction * 2, dim)
        self.end_proj2 = nn.Linear(dim // reduction * 2, dim)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

    def forward(self, x1, x2):
        y1, u1 = self.act1(self.channel_proj1(x1)).chunk(2, dim=-1)
        y2, u2 = self.act2(self.channel_proj2(x2)).chunk(2, dim=-1)
        v1, v2 = self.cross_attn(u1, u2)
        y1 = torch.cat((y1, v1), dim=-1)
        y2 = torch.cat((y2, v2), dim=-1)
        out_x1 = self.norm1(x1 + self.end_proj1(y1))
        out_x2 = self.norm2(x2 + self.end_proj2(y2))
        return out_x1, out_x2


# Stage 2
class ChannelEmbed(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d):
        super(ChannelEmbed, self).__init__()
        self.out_channels = out_channels
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.channel_embed = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels//reduction, kernel_size=1, bias=True),
                        nn.Conv2d(out_channels//reduction, out_channels//reduction, kernel_size=3, stride=1, padding=1, bias=True, groups=out_channels//reduction),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(out_channels//reduction, out_channels, kernel_size=1, bias=True),
                        norm_layer(out_channels) 
                        )
        self.norm = norm_layer(out_channels)
        
    def forward(self, x, H, W):
        B, N, _C = x.shape
        x = x.permute(0, 2, 1).reshape(B, _C, H, W).contiguous()
        residual = self.residual(x)
        x = self.channel_embed(x)
        out = self.norm(residual + x)
        return out

# FFM模块,先进行Attention交互特征,然后利用卷积和线性层融合双模态特征
class FeatureFusionModule(nn.Module):
    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.cross = CrossPath(dim=dim, reduction=reduction, num_heads=num_heads)
        self.channel_emb = ChannelEmbed(in_channels=dim*2, out_channels=dim, reduction=reduction, norm_layer=norm_layer)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2):
        B, C, H, W = x1.shape
        x1 = x1.flatten(2).transpose(1, 2)
        x2 = x2.flatten(2).transpose(1, 2)
        x1, x2 = self.cross(x1, x2) 
        merge = torch.cat((x1, x2), dim=-1)
        merge = self.channel_emb(merge, H, W)
        
        return merge

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2083593.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

生产es所有节点全部掉线 排查

生产es所有节点全部掉线 查看message日志发现 内存溢出 修改jvm的改小 清理buff/cache sync && echo 1 > /proc/sys/vm/drop_caches sync && echo 2 > /proc/sys/vm/drop_caches sync && echo 3 > /proc/sys/vm/drop_caches 把es内存的…

GenAI 斜杠计划丨开启职业加速密码:图文设计专场参会体验

目录 前言 活动概览 活动开始:AI时代的召唤 主题分享一:《看到GenAI的力量:Amazon Bedrock功能操作》 1. 大模型的选择与理解 2. Amazon Bedrock的神奇魅力 主题分享二:《创意与技术的交汇:Zilliz向量数据库助力…

element的日期时间修改时间没有秒以及默认的时间时分修改

<el-form-item label"上架时间" required"required"><el-form:model"courseForm"ref"unmountFormRef"inlinestyle"text-align: left"label-position"left":rules"sjtimeRules"><el-form…

搜维尔科技:人形机器人的动作捕捉技术是实现机器人拟人化动作的关键技术之一

人形机器人的动作捕捉技术是实现机器人拟人化动作的关键技术之一&#xff0c;以下为您详细介绍几款动作捕捉系统&#xff1a; 1.光学式动作捕捉&#xff1a; • 原理&#xff1a;通过在人体关键部位&#xff08;如关节&#xff09;贴上反光标记点&#xff0c;利用多个高速摄像…

如何使用mcu 内置 flash 实现fatfs

一、环境与目的 AT32F403AVGT7&#xff0c;FLASH从0x80e0000到最后&#xff0c;共128K。扇区大小为512。 注意&#xff1a;Flash 的扇区大小为2KB。 fatfs 80286 /* Revision ID */ 目标在于利用单片机1MBflash后面的一小部分&#xff0c;以方便应用程序存储系统参数。 …

Ubuntu上安装剪切板管理软件

1. 更新系统和软件 确保你的系统和软件是最新的&#xff0c;有时更新可以修复这类错误。 sudo apt update sudo apt upgrade 2. 重新安装 Diodon 尝试卸载并重新安装 Diodon。 sudo apt remove diodon sudo apt install diodon 3. 检查依赖项 确保系统中安装了所有必要…

Aiseesoft Data Recovery for Mac:专业级数据恢复解决方案

在数字时代&#xff0c;数据的安全与恢复成为了我们不可忽视的重要议题。对于Mac用户而言&#xff0c;Aiseesoft Data Recovery无疑是一款值得信赖的专业级数据恢复软件。它以其强大的恢复能力、简洁的操作界面以及广泛的兼容性&#xff0c;在众多数据恢复工具中脱颖而出&#…

I 2U-Net: 一种具有丰富信息交互的双路径U-Net用于医学图像分割|文献速递-大模型与多模态诊断阿尔茨海默症与帕金森疾病

Title 题目 I 2U-Net: A dual-path U-Net with rich information interaction for medical image segmentation I 2U-Net: 一种具有丰富信息交互的双路径U-Net用于医学图像分割 01 文献速递介绍 在计算机视觉领域&#xff0c;医学图像分割是主要的挑战之一&#xff0c;例如…

【Java】—— Java面向对象基础:使用Java模拟银行账户与客户交易系统

目录 账户类&#xff08;Account&#xff09; 客户类&#xff08;Customer&#xff09; 测试类&#xff08;CustomerTest&#xff09; 运行结果 在今天的博文中&#xff0c;我们将通过Java编程语言来模拟一个简单的银行账户与客户交易系统。这个系统将包括两个主要类&#…

算法设计:实验四回溯法

【实验目的】 应用回溯法求解图的着色问题 【实验要求】 设下图G(V,E)是一连通无向图&#xff0c;有3种颜色&#xff0c;用这些颜色为G的各顶点着色&#xff0c;每个顶点着一种颜色&#xff0c;且相邻顶点颜色不同。试用回溯法设计一个算法&#xff0c;找出所有可能满足上述…

Cookie、Session、Token、JWT的区别

先总结 其实比较的话就只是Session、Token、JWT的区别&#xff0c;Session是基于Cookie的 这里暂时只比较Session和JWT的区别 存放位置不同 Session基于Cookie存储在服务端JWT存放在客户端&#xff0c;通常是在浏览器的Cookie或LocalStorage中。 JWT将 Token 和 Payload 加…

从学习到工作,2024年不可或缺的翻译助手精选

翻译工具利用先进的机器学习和自然语言处理技术&#xff0c;能够迅速将一种语言的文档转换为另一种语言&#xff0c;极大地促进了信息的无障碍流通。接下来&#xff0c;我们将介绍几款功能强大、操作简便的类似deepl翻译的工具&#xff0c;帮助你轻松应对各种翻译需求。 第一款…

容器网络(桥接、host、none)及跨主机网络(etcd、flannel、docker)

1.本地网络 1.bridge 所有容器连接到桥就可以使用外网&#xff0c;使用nat让容器可以访问外网&#xff0c;使用ip a s指令查看桥&#xff0c;所有容器连接到此桥&#xff0c;ip地址都是 172.17.0.0/16网段&#xff0c;桥是启动docker服务后出现&#xff0c;在centos使用bridge…

深度强化学习算法(五)(附带MATLAB程序)

深度强化学习&#xff08;Deep Reinforcement Learning, DRL&#xff09;结合了深度学习和强化学习的优点&#xff0c;能够处理具有高维状态和动作空间的复杂任务。它的核心思想是利用深度神经网络来逼近强化学习中的策略函数和价值函数&#xff0c;从而提高学习能力和决策效率…

数据结构(6.4_2)——最短路径问题_BFS算法

最短路径问题 BFS求无权图的单源最短路径 原代码 改造visit函数后

list的使用及其相关知识点

目录 ◉list的底层逻辑 ◉关于list的新增功能 ▲splice功能 ▲remove函数 ▲unique函数 ▲merge函数 ▲sort函数 ▣迭代器类型 ▲reverse函数 作为数据容器之一的list和其他容器的使用上有很多相似的地方&#xff0c;比如都有大致相同的构造函数&#xff0c;大致相同的头插尾插…

CUDA编程之CUDA Sample-5_Domain_Specific-volumeFiltering(光线追踪)

volumeFiltering演示了使用 3D 纹理和 3D 表面写入进行 3D 体积过滤。它从磁盘加载一个 3D 体积&#xff0c;并使用光线步进和 3D 纹理进行显示。 以下是该示例的主要内容和功能&#xff1a; 主要功能 3D 体积加载: 从磁盘加载 3D 体积数据&#xff0c;通常为医学成像数据或体…

图像处理中的腐蚀与膨胀算法详解

引言 在图像处理领域&#xff0c;形态学操作&#xff08;Morphological Operations&#xff09;是处理二值图像的重要工具。腐蚀&#xff08;Erosion&#xff09;和膨胀&#xff08;Dilation&#xff09;是形态学操作的两种基本形式&#xff0c;它们常用于消除噪声、分割图像、…

深入解析C#中的锁机制:`lock(this)`、`lock(privateObj)`与`lock(staticObj)`的区别

前言 在C#的多线程编程中&#xff0c;lock关键字是确保线程安全的重要工具。它通过锁定特定的对象&#xff0c;防止多个线程同时访问同一块代码&#xff0c;从而避免数据竞争和资源冲突。然而&#xff0c;选择适当的锁对象对于实现高效的线程同步至关重要。本文将深入探讨使用…

三种tcp并发服务器实现程序

都需先进行tcp连接 1、多进程并发 2、多线程并发 3、IO多路复用并发 &#xff08;1&#xff09;select &#xff08;2&#xff09;epoll