即插即用模块(1) -MAFM特征融合

news2025/4/21 12:17:20

(即插即用模块-特征处理部分) 一、(2024) MAFM&MCM 特征融合+特征解码

在这里插入图片描述

paper:MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection

1. 多尺度感知融合模块 (MAFM)

多尺度感知融合模块 (MAFM) 旨在高效融合 RGB 和深度模态的互补信息,充分利用 RGB 图像的丰富纹理细节和深度图像的空间结构特性,同时克服 RGB 对光照变化的敏感性以及深度图像细节不足的局限性。通过多尺度特征整合和非线性变换,MAFM 实现高效的特征融合,同时降低计算复杂度。

实现流程:
  1. 特征拼接:将 RGB 和深度特征图沿通道维度拼接,形成统一的多模态特征表示,保留各模态的独特信息。
  2. 深度可分离卷积 (DW 层):应用深度可分离卷积高效提取空间局部特征,随后进行批归一化 (BN) 以稳定训练,并通过 GELU 激活函数引入非线性,提升特征表达能力。
  3. 点卷积 (PW 层):通过点卷积优化通道间交互,再次应用 BN 和 GELU 激活,确保特征的有效 recalibration。
  4. 多头多尺度卷积 (MHMC):将融合特征输入 MHMC 模块,通过多尺度卷积捕捉不同尺度的上下文信息,进一步增强特征融合效果。
  5. 残差融合:通过残差结构和元素级求和,整合不同分支的特征图,保留全局和局部信息。
  6. 非线性变换:最终通过 GELU 激活函数进行非线性变换,生成融合特征图。

Multi-scale Awareness Fusion Module 结构图:
在这里插入图片描述

2. 多级卷积模块 (MCM)

多级卷积模块 (MCM) 旨在通过多尺度特征融合,逐步生成包含丰富细节的噪声目标预测图。MCM 采用残差结构,包含多个卷积块,通过整合不同尺度的特征图显著提升解码器的学习能力和泛化性能。

实现流程:
  1. 特征上采样与拼接:对高级特征图进行上采样,并与下一级特征图沿通道维度拼接,构建多尺度特征表示。
  2. 深度可分离卷积 (DW 层):使用深度可分离卷积提取空间特征,随后进行 BN 和 GELU 激活,以高效处理多尺度信息。
  3. 点卷积 (PW 层):通过点卷积优化通道间特征交互,再次应用 BN 和 GELU 激活,确保特征鲁棒性。
  4. 残差连接:将融合特征图与残差连接的结果进行元素级求和,生成最终输出,保留细节并增强稳定性。

Multi-level Convolution Module 结构图:

在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from timm.models.layers import trunc_normal_


# Conv_One_Identity
class COI(nn.Module):
    def __init__(self, inc, k=3, p=1):
        super().__init__()
        self.outc = inc
        self.dw = nn.Conv2d(inc, self.outc, kernel_size=k, padding=p, groups=inc)
        self.conv1_1 = nn.Conv2d(inc, self.outc, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(self.outc)
        self.bn2 = nn.BatchNorm2d(self.outc)
        self.bn3 = nn.BatchNorm2d(self.outc)
        self.act = nn.GELU()
        self.apply(self._init_weights)

    def forward(self, x):
        shortcut = self.bn1(x)

        x_dw = self.bn2(self.dw(x))
        x_conv1_1 = self.bn3(self.conv1_1(x))
        return self.act(shortcut + x_dw + x_conv1_1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


class MHMC(nn.Module):
    def __init__(self, dim, ca_num_heads=4, qkv_bias=True, proj_drop=0., ca_attention=1, expand_ratio=2):
        super().__init__()

        self.ca_attention = ca_attention
        self.dim = dim
        self.ca_num_heads = ca_num_heads

        assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.split_groups = self.dim // ca_num_heads

        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.s = nn.Linear(dim, dim, bias=qkv_bias)
        for i in range(self.ca_num_heads):
            local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads, kernel_size=(3 + i * 2),
                                   padding=(1 + i), stride=1,
                                   groups=dim // self.ca_num_heads)  # kernel_size 3,5,7,9 大核dw卷积,padding 1,2,3,4
            setattr(self, f"local_conv_{i + 1}", local_conv)
        self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1,
                               groups=self.split_groups)
        self.bn = nn.BatchNorm2d(dim * expand_ratio)
        self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        v = self.v(x)
        s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads).permute(3, 0, 4, 1,
                                                                                          2)  # num_heads,B,C,H,W
        for i in range(self.ca_num_heads):
            local_conv = getattr(self, f"local_conv_{i + 1}")
            s_i = s[i]  # B,C,H,W
            s_i = local_conv(s_i).reshape(B, self.split_groups, -1, H, W)
            if i == 0:
                s_out = s_i
            else:
                s_out = torch.cat([s_out, s_i], 2)
        s_out = s_out.reshape(B, C, H, W)
        s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
        self.modulator = s_out
        s_out = s_out.reshape(B, C, N).permute(0, 2, 1)
        x = s_out * v

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# Multi-scale Awareness Fusion Module
class MAFM(nn.Module):
    def __init__(self, inc):
        super().__init__()
        self.outc = inc
        self.attention = MHMC(dim=inc)
        self.coi = COI(inc)
        self.pw = nn.Sequential(
            nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=1, stride=1),
            nn.BatchNorm2d(inc),
            nn.GELU()
        )
        self.pre_att = nn.Sequential(
            nn.Conv2d(inc * 2, inc * 2, kernel_size=3, padding=1, groups=inc * 2),
            nn.BatchNorm2d(inc * 2),
            nn.GELU(),
            nn.Conv2d(inc * 2, inc, kernel_size=1),
            nn.BatchNorm2d(inc),
            nn.GELU()
        )

        self.apply(self._init_weights)

    def forward(self, x, d):
        B, C, H, W = x.shape
        x_cat = torch.cat((x, d), dim=1)
        x_pre = self.pre_att(x_cat)
        # Attention
        x_reshape = x_pre.flatten(2).permute(0, 2, 1)  # B,C,H,W to B,N,C
        attention = self.attention(x_reshape, H, W)  # attention
        attention = attention.permute(0, 2, 1).reshape(B, C, H, W)  # B,N,C to B,C,H,W

        # COI
        x_conv = self.coi(attention)  # dw3*3,1*1,identity
        x_conv = self.pw(x_conv)  # pw

        return x_conv

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


# Decoder
class MCM(nn.Module):
    def __init__(self, inc, outc):
        super().__init__()
        self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.rc = nn.Sequential(
            nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc),
            nn.BatchNorm2d(inc),
            nn.GELU(),
            nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1),
            nn.BatchNorm2d(outc),
            nn.GELU()
        )
        self.predtrans = nn.Sequential(
            nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, padding=1, groups=outc),
            nn.BatchNorm2d(outc),
            nn.GELU(),
            nn.Conv2d(in_channels=outc, out_channels=1, kernel_size=1)
        )

        self.rc2 = nn.Sequential(
            nn.Conv2d(in_channels=outc * 2, out_channels=outc * 2, kernel_size=3, padding=1, groups=outc * 2),
            nn.BatchNorm2d(outc * 2),
            nn.GELU(),
            nn.Conv2d(in_channels=outc * 2, out_channels=outc, kernel_size=1, stride=1),
            nn.BatchNorm2d(outc),
            nn.GELU()
        )

        self.apply(self._init_weights)

    def forward(self, x1, x2):
        x2_upsample = self.upsample2(x2)  # 上采样
        x2_rc = self.rc(x2_upsample)  # 减少通道数
        shortcut = x2_rc

        x_cat = torch.cat((x1, x2_rc), dim=1)  # 拼接
        x_forward = self.rc2(x_cat)  # 减少通道数2
        x_forward = x_forward + shortcut
        pred = F.interpolate(self.predtrans(x_forward), 384, mode="bilinear", align_corners=True)  # 预测图

        return pred, x_forward

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


if __name__ == '__main__':
    x = torch.randn(4, 16, 128, 128).cuda()
    y = torch.randn(4, 16, 128, 128).cuda()
    z = torch.randn(4, 32, 64, 64).cuda()
    model = MAFM(16).cuda()
    out = model(x, y)

    # model = MCM(32, 16).cuda()
    # _, out = model(x, z)
    # print(out.shape)
':
    x = torch.randn(4, 16, 128, 128).cuda()
    y = torch.randn(4, 16, 128, 128).cuda()
    z = torch.randn(4, 32, 64, 64).cuda()
    model = MAFM(16).cuda()
    out = model(x, y)

    # model = MCM(32, 16).cuda()
    # _, out = model(x, z)
    # print(out.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2339416.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(学习总结34)Linux 库制作与原理

Linux 库制作与原理 库的概念静态库操作归档文件命令 ar静态库制作静态库使用 动态库动态库制作动态库使用与运行搜索路径问题解决方案方案2:建立同名软链接方案3:使用环境变量 LD_LIBRARY_PATH方案4:ldconfig 方案 使用外部库目标文件ELF 文…

DSP28335入门学习——第一节:工程项目创建

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难,但我还是想去做! 本文写于:2025.04.20 DSP28335开发板学习——第一节:工程项目创建 前言开发板说明引用解答…

MDG 实现后端主数据变更后快照自动刷新的相关设置

文章目录 前言实现过程BGRFC期初配置(可选)设置 MDG快照 BGRFC维护BP出站功能模块 监控 前言 众所周知,在MDG变更请求创建的同时,所有reuse模型实体对应的快照snapshot数据都会记录下来。随后在CR中,用户可以修改这些…

【Linux】Linux 操作系统 - 05 , 软件包管理器和 vim 编辑器的使用 !

文章目录 前言一、软件包管理器1 . 软件安装2 . 包管理器3 . Linux 生态 二、软件安装 、卸载三、vim 的使用1 . 什么是 vim ?2 . vim 多模式3 . 命令模式 - 命令4 . 底行模式 - 命令5. 插入模式6 . 替换模式7 . V-BLOCK 模式8 . 技巧补充 总结 前言 本篇笔者将会对软件包管理…

【操作系统原理05】存储器管理

大纲 文章目录 大纲一. 内存基础知识0.大纲1.什么是内存2.进程运行基本原理2.1 指令工作原理2.2逻辑地址VS物理地址2.3 从写程序到程序运行完整运行三种链接方式 二.内存管理0.大纲1.操作系统进行内存管理 三.覆盖与交换0.大纲1.覆盖技术2.交换技术 四.连续分配管理方式0.大纲1…

学习笔记—C++—string(练习题)

练习题 仅仅反转字母 917. 仅仅反转字母 - 力扣(LeetCode) 题目 给你一个字符串 s ,根据下述规则反转字符串: 所有非英文字母保留在原有位置。所有英文字母(小写或大写)位置反转。 返回反转后的 s 。…

[Swift]Xcode模拟器无法请求http接口问题

1.以前偷懒一直是这样设置 <key>NSAppTransportSecurity</key> <dict><key>NSAllowsArbitraryLoads</key><true/><key>NSAllowsArbitraryLoadsInWebContent</key><true/> </dict> 现在我在Xcode16.3上&#xff…

返回之术:用 navigate(-1) 闯荡前端江湖

前言 在前端这片江湖,页面跳转宛如轻功水上漂,来去无踪,飘忽不定。但其中有一门绝学,专治“回头是岸”之需求,那便是 React Router 中的 navigate(-1) 身法。 昔日我闯荡项目林,误入“下一页”禁地,一脚踏空,身陷页面迷阵。正当我焦头烂额之际,师父袖袍一挥,口吐一…

网络编程3

day3 一、服务器模型 1.循环服务器模型 同一个时刻只能响应一个客户端的请求 2.并发服务器模型 2.1含义 同一个时刻可以响应多个客户端的请求&#xff0c;常用的模型有多进程模型/多线程模型/IO多路复用模型。 2.2多进程模型 每来一个客户端连接&#xff0c;开一个子进程来专门…

海拔与大气压关系,大气压单位,气压传感器对比

mbmbar 毫巴(百帕) mbar 毫巴(百帕) hPa 百帕 1百帕1毫巴3/4毫米水银柱 1Kpa10百帕7.5毫米汞柱7.5mmhg 1Bar0.1MPa1000mba1000hpa100*7.5mmhg75mmhg1个大气压 HP303B HP303S HP203N BMP280

Linux 进程概念补充 (自用)

进程概念 内核进程进程状态内存泄漏进程调度。Linux真实调度算法环境变量 内核 狭义上的操作系统指的是 内核就是进程管理进程调度&#xff0c;文件系统等等。 广义上的操作系统其实在外壳指令这些。封装了系统调用的东西。 进程 课本概念程序的一个基本实例 内核观点&#…

PyTorch - Tensor 学习笔记

上层链接&#xff1a;PyTorch 学习笔记-CSDN博客 Tensor 初始化Tensor import torch import numpy as np# 1、直接从数据创建张量。数据类型是自动推断的 data [[1, 2],[3, 4]] x_data torch.tensor(data)torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])输出&am…

Navicat、DataGrip、DBeaver在渲染 BOOLEAN 类型字段时的一种特殊“视觉风格”

文章目录 前言✅ 为什么 Boolean 字段显示为 [ ]&#xff1f;✅ 如何验证实际数据类型&#xff1f;✅ 小结 前言 看到的 deleted: [ ] 并不是 Prisma 的问题&#xff0c;而是数据库客户端&#xff08;如 Navicat、DataGrip、DBeaver&#xff09;在渲染 BOOLEAN 类型字段时的一种…

基于 Vue3 + ECharts + GeoJson 实现区域地图钻取功能详解

文章目录 前言一、实现步骤1. 项目初始化2. 准备GeoJson数据3. 创建地图组件4. 创建主页面组件5. 使用组件 二、功能亮点三、性能优化建议四、常见问题解决五、结语六、实战demo七、资源下载 前言 在数据可视化领域&#xff0c;地图展示是一种非常直观的表现形式。而地图钻取&…

安卓学习24 -- 网络

1 整体架构 &#xff08;出处见水印&#xff09; 这两张是能找到的比较清楚的图。目前可以看出&#xff0c;底层的网络业务&#xff0c;还是传统的linux内核提供。&#xff08;注&#xff1a;这两个图我个人觉得不是非常对。。。&#xff09; 在安卓上增加的两个比较重要的部…

github新建一个远程仓库并添加了README.md,本地git仓库无法push

1.本地git仓库与远程仓库绑定 2.push时报错&#xff0c;本地的 main 分支落后于远程仓库的 main 分支&#xff08;即远程有更新&#xff0c;但你本地没有&#xff09;&#xff0c;需要拉取远程的仓库--->在merge合并&#xff08;解决冲突&#xff09;--->push 3.但是git …

Python:使用web框架Flask搭建网站

Date: 2025.04.19 20:30:43 author: lijianzhan Flask 是一个轻量级的 Python Web 开发框架&#xff0c;以简洁灵活著称&#xff0c;适合快速构建中小型 Web 应用或 API 服务。以下是 Flask 的核心概念、使用方法和实践指南 Flask 的核心特点&#xff1a; 轻量级 核心代码仅约…

Kotlin delay方法解析

本文记录了kotlin协程(Android)中delay方法的字节码实现&#xff0c;并解析了delay方法如何实现挂起操作。 一、delay方法介绍 1.1、delay方法使用举例 class TestDelay {suspend fun testDelay() {Log.d("TestDelay", "before delay")delay(1000)Log.d…

【Vulkan 入门系列】创建描述符集布局和图形管线(五)

描述符集布局定义了着色器如何访问资源&#xff08;如缓冲区和图像&#xff09;&#xff0c;是渲染管线配置的关键部分。图形管线定义了从顶点数据到最终像素输出的整个处理流程&#xff0c;包括可编程阶段&#xff08;如顶点和片段着色器&#xff09;和固定功能阶段&#xff0…

mysql中in的用法详解

MySQL 中 IN 操作符用法详解 IN 是 MySQL 中用于多值筛选的高效操作符&#xff0c;常用于 WHERE 子句&#xff0c;可替代多个 OR 条件&#xff0c;简化查询逻辑并提升可读性。以下从基础语法、应用场景、性能优化、常见问题及高级技巧进行全方位解析。 一、基础语法与优势 1.…