伪装目标检测之注意力CBAM:《Convolutional Block Attention Module》

news2024/12/26 10:56:10

论文地址:link
代码:link

摘要

我们提出了卷积块注意力模块(CBAM),这是一种简单而有效的用于前馈卷积神经网络的注意力模块。给定一个中间特征图,我们的模块依次推断沿着两个独立维度的注意力图,通道和空间,然后将这些注意力图与输入特征图相乘,进行自适应特征细化。由于CBAM是一个轻量级和通用的模块,它可以无缝地集成到任何CNN架构中,几乎没有额外开销,并且可以与基础CNN一起端到端地进行训练。我们通过在ImageNet-1K、MS COCO检测和VOC 2007检测数据集上进行大量实验来验证我们的CBAM。 我们的实验证明,在各种模型中,分类和检测性能均有一致的提升,展示了CBAM的广泛适用性。

1.介绍

卷积神经网络(CNN)凭借其丰富的表示能力,显著提升了视觉任务的性能[1,2,3]。为了增强CNN的性能,最近的研究主要探讨了网络的三个重要因素:深度、宽度和基数。。除了这些因素,我们研究了架构设计的另一个方面,即注意力。注意力的重要性在前的文献中得到了广泛研究[12,13,14,15,16,17]。注意力不仅告诉我们要关注哪里,还改善了感兴趣的表示。我们的目标是通过使用注意力机制增强表示能力:专注于重要特征并抑制不必要的特征。在本文中,我们提出了一个新的网络模块,名为“卷积块注意力模块”。由于卷积操作通过将跨通道和空间信息混合在一起提取信息特征,我们采用我们的模块来强调这两个主要维度上的有意义特征:通道和空间轴。为了实现这一点,我们依次应用通道和空间注意力模块(如图1所示),以便每个分支可以分别学习在通道和空间轴上关注“什么”和“哪里”。
主要贡献
1.我们提出了一个简单而有效的注意力模块(CBAM),可以广泛应用于提升CNN的表示能力。
2.我们通过广泛的消融研究验证了我们的注意力模块的有效性。
3.我们验证了在多个基准测试(ImageNet-1K、MS COCO和VOC 2007)上,通过插入我们的轻量级模块,各种网络的性能得到了显著提升。

2.卷积块注意模块

在这里插入图片描述
给定一个中间特征图 F ∈ R C × H × W F \in {R^{C \times H \times W}} FRC×H×W作为输入,CBAM 依次推断出一个1D通道注意力图 M c ∈ R C × 1 × 1 {M_c} \in {R^{C \times 1 \times 1}} McRC×1×1和一个 2D 空间注意力图 M s ∈ R 1 × H × W {M_s} \in {R^{1 \times H \times W}} MsR1×H×W,如图1所示。总体注意力过程可以总结为:
在这里插入图片描述
在这里,符号 ⊗ 表示逐元素乘法。在乘法过程中,注意力值相应地进行广播(复制):通道注意力值沿空间维度广播,反之亦然。F ′′ 是最终的精炼输出。图2描述了每个注意力图的计算过程。以下描述了每个注意力模块的细节。

2.1 Channel attention module

利用特征的通道间关系来生成通道注意力图,特征图的每个通道都被视为特征检测器,通道注意力集中在给定输入图像的情况下“什么”是有意义的,为了有效地计算通道注意力,压缩输入特征图的空间维度,使用平均池化和最大池化特征。利用这两个特征可以极大地提高网络的表示能力,而不是单独使用每个特征。
首先利用平均池化和最大池化操作来聚合特征图的空间信息,生成两个不同的空间上下文描述符, F a v g c F_{avg}^c Favgc F m a x c F_{max}^c Fmaxc,分别表示平均池化特征和最大池化特征。然后,这两个描述符都被转发到共享网络以生成我们的通道注意力图 M c ∈ R C × 1 × 1 {M_c} \in {R^{C \times 1 \times 1}} McRC×1×1 。共享网络由具有一个隐藏层的多层感知器(MLP)组成。为了减少参数开销,隐藏激活大小设置为 R C / r × 1 × 1 R ^{C/r×1×1} RC/r×1×1 ,其中 r 是缩减比率。将共享网络应用于每个描述符后,我们使用逐元素求和来合并输出特征向量。简而言之,通道注意力计算如下:
在这里插入图片描述
其中 σ 表示 sigmoid 函数,W0 ∈ R C / r × C R^{C/r×C} RC/r×C ,W1 ∈ R C × C / r R^{C×C/r} RC×C/r 。请注意,两个输入共享 MLP 权重 W 0 W_0 W0 W 1 W_1 W1,并且 ReLU 激活函数后面跟着 W 0 W_0 W0

2.2 Spatial attention module

利用特征的空间关系生成空间注​​意力图。与通道注意力不同,空间注意力关注的是“哪里”,这是信息性的部分,与通道注意力是互补的。为了计算空间注意力,我们首先沿着通道轴应用平均池化和最大池化操作并将它们连接起来以生成有效的特征描述符。沿着通道轴应用池化操作被证明可以有效地突出显示信息区域。在级联特征描述符上,我们应用卷积层来生成空间注​​意力图 M s ( F ) ∈ R H × W {M_s}\left( F \right) \in {R^{H \times W}} Ms(F)RH×W,它对强调或抑制的位置进行编码。下面我们描述详细操作。通过使用两个池化操作来聚合特征图的通道信息,生成两个2D图: F a v g s ∈ R 1 × H × W F_{avg}^s \in {R^{1 \times H \times W}} FavgsR1×H×W F m a x s ∈ R 1 × H × W F_{max}^s \in {R^{1 \times H \times W}} FmaxsR1×H×W。每个表示整个通道的平均池化特征和最大池化特征。然后,它们通过标准卷积层连接和卷积,生成我们的 2D 空间注意力图。简而言之,空间注意力计算如下:
在这里插入图片描述
其中σ表示sigmoid函数,f 7×7表示滤波器尺寸为7×7的卷积运算。

注意力模块的安排。给定输入图像,两个注意力模块(通道和空间)计算互补注意力,分别关注“什么”和“哪里”。考虑到这一点,两个模块可以并行或顺序放置。我们发现顺序排列比并行排列提供更好的结果。对于顺序过程的安排,我们的实验结果表明,通道优先顺序略好于空间优先顺序。
代码实现:

import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )
            elif pool_type=='lp':
                lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( lp_pool )
            elif pool_type=='lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp( lse_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out) # broadcasting
        return x * scale

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1537954.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5.域控服务器都要备份哪些资料?如何备份DNS服务器?如何备份DHCP服务器?如何备份组策略?如何备份服务器状态的备份?

(2.1) NTD(域控数据库)备份 (2.2)DNS备份 (2.3)DHCP备份 (2.4)组策略备份 (2.5)CA证书备份 (2.6)系统状态备份 (2.1)…

乳腺癌分类模型

乳腺癌分类模型的定义中,必须有_init_(初始化)函数和forward(正向传播)函数 乳腺癌分类模型定义 # 自定义模型 class MyModel(torch.nn.Module):def __init__(self,in_features):super(MyModel,self).__init__() #调用…

qt+ffmpeg 实现音视频播放(三)之视频播放

一、视频播放流程 (PS:视频的播放流程跟音频的及其相似!!) 1、打开视频文件 通过 avformat_open_input() 打开媒体文件并分配和初始化 AVFormatContext 结构体。 函数原型如下: int avformat_open_inpu…

Flask python 开发篇:链接mysql

一、历史回顾 根据上一篇:配置文件编写,已经把各种配置根据开发环境做了区分,再config.py中,我们可以分别处理测试、生产的相关配置,这节主要说一下数据库的链接和使用 二、配置数据库连接 Flask定义和链接数据库文…

手机可以格式化存储卡吗?格式化以后出现什么情况

随着智能手机的普及,存储卡(如SD卡、MicroSD卡等)已成为手机存储扩展的重要工具。然而,在使用过程中,我们有时可能会遇到需要格式化存储卡的情况。那么,手机能否直接格式化存储卡呢?格式化后存储…

【Flutter学习笔记】10.3 组合实例:TurnBox

参考资料:《Flutter实战第二版》 10.3 组合实例:TurnBox 这里尝试实现一个更为复杂的例子,其能够旋转子组件。Flutter中的RotatedBox可以旋转子组件,但是它有两个缺点: 一是只能将其子节点以90度的倍数旋转二是当旋转…

在服务器(Ubuntu20.04)安装用户级别的cuda11.8(以及仿照前面教程安装cuda11.3后安装cudnn和pytorch1.9.0)

1、cuda11.8的下载 首先在cuda官网下载我们需要的cuda版本,这里我下载的是cuda11.8(我的最高支持cuda12.0) 这里我直接使用wget命令下载不了,于是我直接在浏览器输入后面的链接下载到本地,之后再上传至服务器的&am…

数据分析概述、Conda环境搭建及JupyterLab的搭建

1. 数据分析职责概述 当今世界对信息技术的依赖程度在不断加深,每天都会有大量的数据产生,我们经常会感到数据越来越多,但是要从中发现有价值的信息却越来越难。这里所说的信息,可以理解为对数据集处理之后的结果,是从…

SQLiteC/C++接口详细介绍sqlite3_stmt类(十)

返回:SQLite—系列文章目录 上一篇:SQLiteC/C接口详细介绍sqlite3_stmt类(九) 下一篇: SQLiteC/C接口详细介绍sqlite3_stmt类(十一) 38、sqlite3_column_value sqlite3_column_valu…

Python:熟悉简单的skfuzzy构建接近生活事件的模糊控制器”(附带详细注释说明)+ 测试结果

参考资料:https: // blog.csdn.net / shelgi / article / details / 126908418 ————通过下面这个例子,终于能理解一点模糊理论的应用了,感谢原作。 熟悉简单的skfuzzy构建接近生活事件的模糊控制器 假设下面这样的场景, 我们希望构建一套…

linux系统------------MySQL 存储引擎

目录 一、存储引擎概念介绍 二、常用的存储引擎 2.1MyISAM 2.1.1MYlSAM的特点 2.1.2MyISAM 表支持 3 种不同的存储格式⭐: (1)静态(固定长度)表 (2)动态表 (3)压缩表 2.1.3MyISAM适…

基于python+vue食品安全信息管理系统flask-django-nodejs-php

食品安全信息管理系统设计的目的是为用户提供食品信息、科普专栏、食品检测、检测结果、交流论坛等方面的平台。 与PC端应用程序相比,食品安全信息管理系统的设计主要面向于用户,旨在为管理员和用户提供一个食品安全信息管理系统。用户可以通过APP及时查…

Gitlab介绍

1.什么是Gitlab GitLab是一个流行的版本控制系统平台,主要用于代码托管、测试和部署。 GitLab是基于Git的一个开源项目,它提供了一个用于仓库管理的Web服务。GitLab使用Ruby on Rails构建,并提供了诸如wiki和issue跟踪等功能。它允许用户通…

欧科云链:2024将聚焦发展与安全,用技术助力链上数据安全和合规

近期,OpenAI和Web3.0两大新技术发展势头迅猛。OpenAI 再次引领AI领域的新浪潮,推出了创新的文本转视频模型——Sora,Sora 可以创建长达60 秒的视频,包含高度详细的场景、複杂的摄像机运动以及情感丰富角色,再次将AI 的…

Django在日志中使用AdminEmailHandler发送邮件(同步),及celery异步发送日志邮件的实现

目录 一、使用AdminEmailHandler实现发送日志通知邮件 1,配置日志项 2,配置邮件项 3,在视图里使用日志 二、继承AdminEmailHandler使用celery实现异步发送邮件 1,安装配置celery 2,继承AdminEmailHandler类&…

python食品安全信息管理系统flask-django-nodejs-php

。 食品安全信息管理系统是在安卓操作系统下的应用平台。为防止出现兼容性及稳定性问题,编辑器选择的是Hbuildex,安卓APP与后台服务端之间的数据存储主要通过MySQL。用户在使用应用时产生的数据通过 python等语言传递给数据库。通过此方式促进食品安全信…

VMware 15 中 Ubuntu与windows 10共享文件夹设置

wmware 15.5.7中安装ubuntu 22.04 物理机为windows 10 1.选中ubuntu中想要共享的文件夹右击,点属性 2.在Local network share中勾选share this folder,第一次会提示你安装samba,安装即可 3.window10的资源管理器中使用 虚拟机计算机名即可…

无人机采集图像的相关知识

1.飞行任务规划 一般使用飞行任务规划软件进行飞行任务的设计,软件可以自动计算相机覆盖和图像重叠情况。比如ArduPilot (ArduPilot - Versatile, Trusted, Open) 和UgCS (http://www.ugcs.com)是两个飞行任务规划软件,可以适用大多数无人机系统。 2.图…

如何减少pdf的文件大小?pdf压缩工具介绍

文件发不出去,有时就会耽误工作进度,文件太大无法发送,这应该是大家在发送PDF时,常常会碰到的问题吧,那么PDF文档压缩大小怎么做呢?因此我们需要对pdf压缩后再发送,那么有没有好用的pdf压缩工具…

解决IE11报错:CSS 因 Mime 类型不匹配而被忽略

简要概述: 本人用springboot开发网站,手动处理js和css文件请求,报错:CSS 因 Mime 类型不匹配而被忽略 后台代码: 如下三个代码块 GetMapping("/Guest/ASN1/{FileName}")public void GetFiles(PathVariab…