【Block总结】MAB,多尺度注意力块|即插即用

news2025/3/6 17:14:17

文章目录

  • 一、论文信息
  • 二、创新点
  • 三、方法
    • MAB模块解读
      • 1、MAB模块概述
      • 2、MAB模块组成
      • 3、MAB模块的优势
  • 四、效果
  • 五、实验结果
  • 六、总结
  • 代码

一、论文信息

  • 标题: Multi-scale Attention Network for Single Image Super-Resolution
  • 作者: Yan Wang, Yusen Li, Gang Wang, Xiaoguang Liu
  • 机构: 南开大学
  • 发表会议: CVPR 2024 Workshops
  • 论文链接: arXiv
  • GitHub代码库: icandle/MAN
    在这里插入图片描述

二、创新点

  • 多尺度大核注意力(MLKA): 结合了多尺度机制与大核卷积,能够有效捕捉不同尺度的信息,避免了传统方法中常见的“块状伪影”问题。

  • 门控空间注意力单元(GSAU): 通过引入门控机制,优化了空间注意力的计算,去除了不必要的线性层,从而提高了信息聚合的效率和准确性。

  • 灵活的网络结构: 通过堆叠不同数量的MLKA和GSAU模块,构建出多种复杂度的网络,以实现性能与计算量之间的平衡。

三、方法

  1. 网络架构: Multi-scale Attention Network (MAN)由三个主要模块组成:

    • 浅层特征提取模块(SF): 负责初步的特征提取。
    • 深层特征提取模块(DF): 基于多个多尺度注意力块(MAB),进一步提取丰富的特征。
    • 高质量图像重建模块: 将提取的特征用于最终的图像重建。
  2. 多尺度注意力块(MAB):

    • MLKA模块: 结合大核注意力、多个尺度机制和门控聚合,建立不同尺度之间的相关性。
    • GSAU模块: 整合空间注意力和门控机制,简化前馈网络。
  3. MLKA的功能:

    • 大核注意力: 通过分解卷积建立长距离关系。
    • 多尺度机制: 增强固定LKA以学习全尺度信息的注意力图。
    • 门控聚合: 动态调整注意力图以避免伪影。

MAB模块解读

在这里插入图片描述

1、MAB模块概述

MAB(Multi-scale Attention Block)是Multi-scale Attention Network (MAN)中的核心组件,旨在通过结合多尺度大核注意力(MLKA)和门控空间注意力单元(GSAU)来提升图像超分辨率的性能。MAB模块的设计旨在有效捕捉图像中的局部和全局特征,同时避免传统卷积网络中常见的“块状伪影”问题。

2、MAB模块组成

MAB模块主要由以下两个部分构成:

  1. 多尺度大核注意力(MLKA):

    • 功能: MLKA通过引入多尺度机制,结合大核卷积,能够在不同尺度上提取丰富的特征信息。
    • 结构:
      • 首先,MLKA使用点卷积(Point-wise convolution)调整通道数。
      • 然后,将特征分成三组,每组使用不同尺寸的大核卷积(如7×7、21×21、35×35)进行处理,膨胀率分别设置为(2,3,4)。
      • 为了避免膨胀卷积带来的“块状伪影”,MLKA引入了门控聚合机制,通过逐元素乘法将深度卷积的输出与对应组的LKA输出相结合,从而动态调整注意力图的输出。
  2. 门控空间注意力单元(GSAU):

    • 功能: GSAU旨在增强特征表示能力,通过结合空间注意力和门控机制,优化信息聚合过程。
    • 结构:
      • GSAU通常由两个分支组成,其中一个分支使用深度卷积对特征进行加权,另一个分支则通过空间自注意力机制捕捉空间上下文信息。
      • 这种设计减少了不必要的线性层,降低了计算复杂度,同时增强了特征的表达能力。

3、MAB模块的优势

  • 多尺度特征提取: 通过MLKA,MAB能够在多个尺度上提取特征,增强了模型对不同图像细节的敏感性。

  • 减少伪影: 通过门控聚合机制,MAB有效地减少了由于膨胀卷积引起的块状伪影,提升了图像重建的质量。

  • 高效的特征表示: GSAU的引入使得模型能够更好地聚合空间信息,提升了特征的表达能力,进而提高了超分辨率的效果。

四、效果

  • 性能提升: 实验结果表明,MAN在多个超分辨率基准测试中表现优异,能够与当前最先进的模型(如SwinIR)相媲美,同时在计算效率上也有显著改善。

  • 避免伪影: 通过MLKA和GSAU的结合,模型有效减少了图像重建中的伪影现象,提升了视觉效果。

五、实验结果

  • 基准测试: 论文中使用了多个数据集(如Set5、Set14、BSD100、Urban100、Manga109)进行测试,结果显示MAN在PSNR和SSIM指标上均优于传统的超分辨率模型,尤其是在高倍数放大(如×4)时表现突出。
数据集上采样因子MAN PSNR (dB)与SwinIR比较
Set5×238.42相当
×334.91略低
×432.87良好
Set14×234.44相近
×330.92略低
×429.09良好
BSD100×232.53相近
×329.65略低
×427.90良好
Urban100×233.80相近
×334.45略低
×433.73良好
Manga109×240.02相近
×335.21略低
×431.22良好

六、总结

Multi-scale Attention Network (MAN)通过结合多尺度大核注意力和门控空间注意力机制,成功提升了单幅图像超分辨率重建的性能和效率。该研究不仅解决了传统方法中的一些局限性,还为未来的超分辨率模型设计提供了新的思路。MAN在多个基准测试中的优异表现,证明了其在实际应用中的潜力,尤其是在需要高质量图像重建的场景中。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class SGAB(nn.Module):
    def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'):
        super().__init__()
        i_feats = n_feats * 2

        self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
        self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats)
        self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

    def forward(self, x):
        shortcut = x.clone()

        # Ghost Expand
        x = self.Conv1(self.norm(x))
        a, x = torch.chunk(x, 2, dim=1)
        x = x * self.DWConv1(a)
        x = self.Conv2(x)

        return x * self.scale + shortcut


class GroupGLKA(nn.Module):
    def __init__(self, n_feats, k=2, squeeze_factor=15):
        super().__init__()
        i_feats = 2 * n_feats

        self.n_feats = n_feats
        self.i_feats = i_feats

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        # Multiscale Large Kernel Attention
        self.LKA7 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA5 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA3 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))

        self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)
        self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)
        self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)

        self.proj_first = nn.Sequential(
            nn.Conv2d(n_feats, i_feats, 1, 1, 0))

        self.proj_last = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 1, 1, 0))

    def forward(self, x, pre_attn=None, RAA=None):
        shortcut = x.clone()

        x = self.norm(x)

        x = self.proj_first(x)

        a, x = torch.chunk(x, 2, dim=1)

        a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)

        a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],
                      dim=1)

        x = self.proj_last(x * a) * self.scale + shortcut

        return x

    # MAB


class MAB(nn.Module):
    def __init__(
            self, n_feats):
        super().__init__()

        self.LKA = GroupGLKA(n_feats)

        self.LFE = SGAB(n_feats)

    def forward(self, x, pre_attn=None, RAA=None):
        # large kernel attention
        x = self.LKA(x)

        # local feature extraction
        x = self.LFE(x)

        return x


if __name__ == "__main__":
    dim=96 # 通道要被3整除
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, channels,height, width)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 MAB模块

    block = MAB(dim)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2289744.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

移动互联网用户行为习惯哪些变化,对小程序的发展有哪些积极影响

一、碎片化时间利用增加 随着生活节奏的加快,移动互联网用户的碎片化时间越来越多。在等公交、排队、乘坐地铁等间隙,用户更倾向于使用便捷、快速启动的应用来满足即时需求。小程序正好满足了这一需求,无需下载安装,随时可用&…

使用 Tauri 2 + Next.js 开发跨平台桌面应用实践:Singbox GUI 实践

Singbox GUI 实践 最近用 Tauri Next.js 做了个项目 - Singbox GUI,是个给 sing-box 用的图形界面工具。支持 Windows、Linux 和 macOS。作为第一次接触这两个框架的新手,感觉收获还蛮多的,今天来分享下开发过程中的一些经验~ 为啥要做这个…

攻防世界_simple_php

同类型题(更难版->)攻防世界_Web(easyphp)(php代码审计/json格式/php弱类型匹配) php代码审计 show_source(__FILE__):show_source() 函数用于显示指定文件的源代码,并进行语法高亮显示。__FILE__ 是魔…

C++哈希(链地址法)(二)详解

文章目录 1.开放地址法1.1key不能取模的问题1.1.1将字符串转为整型1.1.2将日期类转为整型 2.哈希函数2.1乘法散列法(了解)2.2全域散列法(了解) 3.处理哈希冲突3.1线性探测(挨着找)3.2二次探测(跳…

Solon Cloud Gateway 开发:导引

Solon Cloud Gateway 是 Solon Cloud 体系提供的分布式网关实现(轻量级实现)。 分布式网关的特点(相对于本地网关): 提供服务路由能力提供各种拦截支持 1、分布式网关推荐 建议使用专业的分布式网关产品&#xff0…

dmfldr实战

dmfldr实战 本文使用达梦的快速装载工具,对测试表进行数据导入导出。 新建测试表 create table “BENCHMARK”.“TEST_FLDR” ( “uid” INTEGER identity(1, 1) not null , “name” VARCHAR(24), “begin_date” TIMESTAMP(0), “amount” DECIMAL(6, 2), prim…

Spring AOP 入门教程:基础概念与实现

目录 第一章:AOP概念的引入 第二章:AOP相关的概念 1. AOP概述 2. AOP的优势 3. AOP的底层原理 第三章:Spring的AOP技术 - 配置文件方式 1. AOP相关的术语 2. AOP配置文件方式入门 3. 切入点的表达式 4. AOP的通知类型 第四章&#x…

Upscayl-官方开源免费图像AI增强软件

upscayl 链接:https://pan.xunlei.com/s/VOI0Szqe0fCwSSUSS8zRqKf7A1?pwdhefi#

SpringBoot Web开发(SpringMVC)

SpringBoot Web开发(SpringMVC) MVC 核心组件和调用流程 Spring MVC与许多其他Web框架一样,是围绕前端控制器模式设计的,其中中央 Servlet DispatcherServlet 做整体请求处理调度! . 除了DispatcherServletSpringMVC还会提供其他…

苍穹外卖第一天

角色分工 技术选型 pojo子模块 nginx反向代理 MD5密码加密

C# Winform enter键怎么去关联button

1.关联按钮上的Key事件按钮上的keypress,keydown,keyup事件随便一个即可private void textBox1_KeyDown(object sender, KeyEventArgs e){if (e.KeyCode Keys.Enter){this.textBox2.Focus();}}2.窗体上的事件private void textBox2_KeyPress(object sen…

LeGO LOAM坐标系问题的自我思考

LeGO LOAM坐标系问题的自我思考 IMU坐标系LeGO LOAM代码分析代码 对于IMU输出测量值的integration积分过程欧拉角的旋转矩阵VeloToStartIMU()函数TransformToStartIMU(PointType *p) IMU坐标系 在LeGO LOAM中IMU坐标系的形式采用前(x)-左(y)-上(z)的形式,IMU坐标系…

vim交换文件的作用

1.数据恢复:因为vim异常的退出,使用交换文件可以恢复之前的修改内容。 2.防止多人同时编辑:vim检测到交换文件的存在,会给出提示,以避免一个文件同时被多人编辑。 (vim交换文件的工作原理:vim交换文件的工作…

PHP实现混合加密方式,提高加密的安全性(代码解密)

代码1&#xff1a; <?php // 需要加密的内容 $plaintext 授权服务器拒绝连接;// 1. AES加密部分 $aesKey openssl_random_pseudo_bytes(32); // 生成256位AES密钥 $iv openssl_random_pseudo_bytes(16); // 生成128位IV// AES加密&#xff08;CBC模式&#xff09…

五. Redis 配置内容(详细配置说明)

五. Redis 配置内容(详细配置说明) 文章目录 五. Redis 配置内容(详细配置说明)1. Units 单位配置2. INCLUDES (包含)配置3. NETWORK (网络)配置3.1 bind(配置访问内容)3.2 protected-mode (保护模式)3.3 port(端口)配置3.4 timeout(客户端超时时间)配置3.5 tcp-keepalive()配置…

(9) 上:学习与验证 linux 里的 epoll 对象里的 EPOLLIN、 EPOLLHUP 与 EPOLLRDHUP 的不同

&#xff08;1&#xff09;经过之前的学习。俺认为结论是这样的&#xff0c;因为三次握手到四次挥手&#xff0c;到 RST 报文&#xff0c;都是 tcp 连接上收到了报文&#xff0c;这都属于读事件。所以&#xff1a; EPOLLIN : 包含了读事件&#xff0c; FIN 报文的正常四次挥手、…

因果推断与机器学习—用机器学习解决因果推断问题

Judea Pearl 将当前备受瞩目的机器学习研究戏谑地称为“仅限于曲线拟合”,然而,曲线拟合的实现绝非易事。机器学习模型在图像识别、语音识别、自然语言处理、蛋白质分子结构预测以及搜索推荐等多个领域均展现出显著的应用效果。 在因果推断任务中,在完成因果效应识别之后,需…

2025全自动企业站群镜像管理系统 | 支持繁简转换拼音插入

2025全自动企业站群镜像管理系统 | 支持繁简转换拼音插入 在全球化的今天&#xff0c;企业面临着管理多站点的挑战&#xff0c;尤其是跨语言和地理位置的站点。为此&#xff0c;我们设计了一套基于PHP的全自动企业站群镜像管理系统&#xff0c;它不仅能够自动化站点的管理&…

基于阿里云百炼大模型Sensevoice-1的语音识别与文本保存工具开发

基于阿里云百炼大模型Sensevoice-1的语音识别与文本保存工具开发 摘要 随着人工智能技术的不断发展&#xff0c;语音识别在会议记录、语音笔记等场景中得到了广泛应用。本文介绍了一个基于Python和阿里云百炼大模型的语音识别与文本保存工具的开发过程。该工具能够高效地识别东…

GIS与相关专业软件汇总

闲来无事突然想整理一下看看 GIS及相关领域 究竟有多少软件或者工具包等。 我询问了几个AI工具并汇总了一个软件汇总&#xff0c;不搜不知道&#xff0c;一搜吓一跳&#xff0c;搜索出来了大量的软件&#xff0c;大部分软件或者工具包都没有见过&#xff0c;不知大家还有没有要…