CBAM(Convolutional Block Attention Module)卷积注意力模块用法及代码实现

news2025/1/27 13:09:41

CBAM卷积注意力模块用法及代码实现

  • CBAM
    • Channel Attention模块(CAM)
    • Spatial Attention模块(SAM)
  • 代码实现

CBAM

CBAM( Convolutional Block Attention Module )是一种轻量级注意力模块的提出于2018年。CBAM包含CAM(Channel Attention Module)和SAM(Spartial Attention Module)两个子模块,分别在通道上和空间上添加注意力机制。这样不仅可以节约参数和计算力,而且保证了其能够做为即插即用的模块集成到现有的网络架构中去。可以无缝的集成到CNNs中,并且可以与基本CNNs一起端到端的训练。

通道注意力让网络关注图像“是什么”,而空间注意力则让网络关注图像中物体“在哪“。

CBAM模型结构

Channel Attention模块(CAM)

通道注意力模块:通道维度不变,压缩空间维度。关注输入图片中有意义的信息(不同channel中有不同的信息)
在这里插入图片描述
在通道注意力模块中,通过将输入的特征图分别经过最大池化以及平均池化,将特征图从CHW变为C11的大小,然后经过两层共享全连接层(shared MLP)中,它先将通道数压缩为原来的1/r(Reduction,减少率)倍,再扩张到原通道数,再将这两个输出进行elementwise逐元素相加操作,经过sigmoid激活,最终即可获得通道注意力模块的特征图。再将这个输出结果乘原图,变回CHW的大小。

通道注意力机制(Channel Attention Module)是将特征图在通道维度不变,压缩空间维度,得到一个一维矢量后再进行操作。通道注意力关注的是这张图上哪些内容是有重要作用的。平均值池化对特征图上的每一个像素点都有反馈;而最大值池化在进行梯度反向传播计算时,只有特征图中响应最大的地方有梯度的反馈。

Spatial Attention模块(SAM)

空间注意力模块:空间维度不变,压缩通道维度。该模块关注的是目标的位置信息。
在这里插入图片描述
在空间注意力模块中,是将通道注意力模块输出的特征图作为输入对通道进行压缩。依次做一个基于channel维度的最大池化和平均池化得到两个1HW特征图,最大池化的操作就是在通道上提取最大值,提取的次数是H × W;平均池化的操作就是在通道上提取平均值,提取的次数也是是H × W;从而可以获得一个2通道的特征图。然后将两层进行torch.cat操作。然后进行77卷积,降为1个channel,再经sigmoid获得空间注意力模块输出的特征图。最后将输出结果乘原图变回CH*W大小。

在这里插入图片描述
实验表明,我们可以看出CBAM模型中,先通过channel,再通过spatial,会获得更好的准确率和更低的错误率。

论文:https://arxiv.org/abs/1807.06521

代码实现

https://github.com/Jongchan/attention-module
代码1如下:

import torch
import torch.nn as nn
class CBAM(nn.Module):
    def __init__(self, channel, reduction=16, spatial_kernel=7):
        super(CBAM, self).__init__()
        # channel attention 压缩H,W为1
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # shared MLP
        self.mlp = nn.Sequential(
            # Conv2d比Linear方便操作
            # nn.Linear(channel, channel // reduction, bias=False)
            nn.Conv2d(channel, channel // reduction, 1, bias=False),
            # inplace=True直接替换,节省内存
            nn.ReLU(inplace=True),
            # nn.Linear(channel // reduction, channel,bias=False)
            nn.Conv2d(channel // reduction, channel, 1, bias=False)
        )
        # spatial attention
        self.conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel,
                              padding=spatial_kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        max_out = self.mlp(self.max_pool(x))
        avg_out = self.mlp(self.avg_pool(x))
        channel_out = self.sigmoid(max_out + avg_out)
        x = channel_out * x
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        # print('max_out:',max_out.shape)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        # print('avg_out:',avg_out.shape)
        a=torch.cat([max_out, avg_out], dim=1)
        # print('a:',a.shape)
        spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1)))
        # print('spatial:',spatial_out.shape)
        x = spatial_out * x
        # print('x:',x.shape)
        return x

代码2如下:

class CBAM(nn.Module):
    '''CBAM包含CAM通道注意力模块(Channel Attention Module)和SAM空间注意力模块(Spartial Attention Module)两个子模块,
    分别进行通道和空间上的Attention。这样不只能够节约参数和计算力,并且保证了其能够做为即插即用的模块集成到现有的网络架构中去。
    '''
    def __init__(self, in_channels, out_channels, r = 0.5):
        super(CBAM, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.shared_mlp_cbam_1 = nn.Linear(out_channels, int(out_channels*r))
        self.shared_mlp_cbam_2 = nn.Linear(int(out_channels*r), out_channels)
        self.conv_cbam = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding = 3)
        
    def forward(self, x):
        B, C, H, W = x.shape
        x = self.conv1(x)
        Fc_avg = x.mean(dim = -1).mean(dim = -1)
        Fc_max = x.max(dim = -1)[0].max(dim = -1)[0]
        Fc = torch.sigmoid(self.shared_mlp_cbam_2(torch.relu(self.shared_mlp_cbam_1(Fc_avg))) + 
                self.shared_mlp_cbam_2(torch.relu(self.shared_mlp_cbam_1(Fc_max))))
        
        Fc = Fc.unsqueeze(-1).unsqueeze(-1).repeat((1,1,H,W))
        Fc = torch.mul(x, Fc)  # 

        Fs_avg = Fc.mean(dim = 1, keepdim=True)
        Fs_max = Fc.max(dim = 1, keepdim = True)[0]
        Fs = torch.sigmoid(self.conv_cbam(torch.cat((Fs_avg, Fs_max), dim = 1)))
        Fs = Fs.repeat((1, C, 1, 1))

        Fs = torch.mul(Fc, Fs)

        return (x + Fs)

在网络中,即插即用

class ResnetFPN4_CBAM(nn.Module):
    def __init__(self,
                num_channels,
                cbam_block=ResidualBlockCBAM,
                cfg=None):
        super(ResnetFPN4_CBAM, self).__init__()
        self.cfg = cfg
        
        # Block 1
        block = []
        block.append(nn.Conv2d(num_channels, num_channels, kernel_size = 3, stride = 1, padding = 1))
        block.append(nn.BatchNorm2d(num_channels))
        block.append(cbam_block(num_channels, num_channels))
        self.block1 = nn.Sequential(*block)

        # Block 2
        block = []
        block.append(nn.Conv2d(num_channels, num_channels, kernel_size = 3, stride = 2, padding = 1))
        block.append(nn.BatchNorm2d(num_channels))
        block.append(cbam_block(num_channels, num_channels))
        block.append(cbam_block(num_channels, num_channels))
        self.block2 = nn.Sequential(*block)

        # Block 3
        block = []
        block.append(nn.Conv2d(num_channels, 2*num_channels, kernel_size = 3, stride = 2, padding = 1))
        block.append(nn.BatchNorm2d(2*num_channels))

        block.append(cbam_block(2*num_channels, 2*num_channels))
        block.append(cbam_block(2*num_channels, 2*num_channels))
        self.block3 = nn.Sequential(*block)

        # Block 4
        block = []
        block.append(nn.Conv2d(2*num_channels, 4*num_channels, kernel_size = 3, stride = 2, padding = 1))
        block.append(nn.BatchNorm2d(4*num_channels))

        block.append(cbam_block(4*num_channels, 4*num_channels))
        block.append(cbam_block(4*num_channels, 4*num_channels))
        self.block4 = nn.Sequential(*block)


        # FPN
        self.up1 = nn.ConvTranspose2d(num_channels, 2*num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.up2 = nn.ConvTranspose2d(num_channels, 2*num_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = (1,1))
        self.up3 = nn.ConvTranspose2d(2*num_channels, 2*num_channels, kernel_size = 3, stride = 4, padding = 1, output_padding = (3,3))
        self.up4 = nn.ConvTranspose2d(4*num_channels, 4*num_channels, kernel_size = 5, stride = 8, padding = 1, output_padding = (5,5))

    def forward(self, x):
        ### Backbone ###
        x = self.block1(x)
        up_1 = self.up1(x)

        x = self.block2(x)
        up_2 = self.up2(x)

        x = self.block3(x)
        up_3 = self.up3(x)

        x = self.block4(x)
        up_4 = self.up4(x)

        ### Neck ### 
        out = torch.cat((up_1, up_2, up_3, up_4),1)

        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/89289.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

185-200-spark-核心编程-Streaming

185-spark-核心编程-Streaming: 数据处理延迟的长短分为:实时数据处理(毫秒级别),离线数据处理(小时,天) 数据处理的方式分为:流式数据处理(streaming&…

ORACLE19c数据库随LINUX操作系统自动启动实现方式

1.建立目录 # su - oracle $ mkdir /home/oracle/scripts 2.建立启动脚本: $ cd /home/oracle/scripts $ vim startdb.sh #!/bin/bash export ORACLE_BASE/u01/app/oracle export ORACLE_HOME$ ORACLE_BASE/product/19.16.0/db_1 export ORACLE_SIDemrep export PAT…

【电脑使用】利用diskpart删除电脑的EFI分区

文章目录前言问题描述问题解决扩展:测量磁盘读写速度1 win10自带工具2 第三方工具前言 在Windows的磁盘管理中,往往会发现自己电脑的磁盘中莫名多了一些分区,有一些是系统分区(一般不删),还有一些是还原分区…

m索引OFDM调制解调系统的性能仿真分析

目录 1.算法描述 2.仿真效果预览 3.MATLAB核心程序 4.完整MATLAB 1.算法描述 随着无线通信技术的不断发展,人们对下一代移动通信系统提出了越来越高的要求。在这样的时代背景下,具有低峰均比,强频偏对抗能力和高能量效率的索引调制OFDM系统(Orthogonal Frequency Division …

【跟学C++】C++STL三大主要组件——容器/迭代器/算法(Study19)

文章目录1、前言2、简介2.1、STL是什么?2.2、STL能干什么?2.3、STL组成3、容器3.1、顺序容器3.2、排序容器(关联式容器)3.3、哈希容器3.4、容器适配器3、迭代器3.1、迭代器介绍3.2、迭代器定义方式3.3、迭代器类别3.4、辅助函数4、算法5、总结 【说明】…

【MATLAB教程案例60】使用matlab实现基于GRU网络的数据分类预测功能与仿真分析

欢迎订阅《FPGA学习入门100例教程》、《MATLAB学习入门100例教程》 目录 1.软件版本 2.GRU网络理论概述

【云原生进阶之容器】第一章Docker核心技术1.5.4节——cgroups使用

4 CGroups使用 4.1 挂载cgroup树 开始使用cgroup前需要先挂载cgroup树,下面先看看如何挂载一颗cgroup树,然后再查看其根目录下生成的文件。 #准备需要的目录 #准备需要的目录 dev@ubuntu:~$ mkdir cgroup && cd cgroup dev@ubuntu:~/cgroup$ mkdir demo#由于name=…

[论文解析] Diffusion Guided Domain Adaptation of Image Generators

project link: https://styleganfusion.github.io/ 文章目录OverviewWhat problem is addressed in the paper?What is the key to the solution?What is the main contribution?IntroductionBackgroundLatent diffusion modelClassifier-free guidanceMethodModel Structur…

pytorch深度学习实战lesson36

第三十六课 锚框 因为我们在目标检测里面需要预测边缘框,所以给我们的预测带来了很大的问题。我们在卷积神经网络里面做图片分类的时候,整个代码写起来看上去非常简单,就是一个 soft Max 出去就完事了。但是因为有边框的加入,使得…

第十二期 | 万元的正版课程仅花9.9就可买到?

顶象防御云业务安全情报中心监测发现,某线上教育培训类平台课件遭遇大规模盗取。被盗取的课件,经加工处理后,进行低价转售,严重损害了平台的合法权益。 飞速发展的在线教育和看不见的风险 随着5G、视频编解码等技术融合&#xff…

DevExpress .Net Components 22.2.3 Crack

DevExpress .Net适用于 Windows、Internet 以及您的移动世界的用户界面组件 使用适用于 WinForms、WPF、ASP.NET(WebForms、MVC 和 Core)、Windows 10、VCL 和 JavaScript 的 DevExpress 组件,打造一流的用户体验并模拟最热门的企业生产力程…

产品负责人 VS 产品经理

概述 Scrum框架创造了对新角色的需求,其中就包括 “产品负责人” 。这不可避免额外地导致对产品负责人和产品经理角色的误解和误用,对团队产生不必要的压力。 角色混淆会带来噪音和摩擦,削弱团队对价值、质量、速度和满意度的关注。这种混乱…

让搜狗快速收录网站的方法,批量查询网站有没有被搜狗收录

让搜狗快速收录只需做到以下8点: 1、网页标题要与内容相关。 2、页面少用flash,图片等 3、将网站链接大量推送给搜狗。 4、网页尽量采用静态网页。 5、首页的外部链接不要过多。 6、搜狗更喜欢受用户欢迎的内容的网站。 7、网站不要欺骗用户。 8、网站不…

四道编程题(涉及最大公约数最小公倍数,子序列等)

tips 1. scanf当是读取整数%d的时候,这时候如果它读取到\n,它就会停止读取。并且碰到空格的时候也会跳过。 2. getchar不需要传入参数,读取失败的时候会返回EOF。那getchar或者scanf到底是怎么从键盘上读取我输入的字符呢?在getc…

VSCode入门

VSCode入门 零、文章目录 文章地址 个人博客-CSDN地址:https://blog.csdn.net/liyou123456789个人博客-GiteePages:https://bluecusliyou.gitee.io/techlearn 代码仓库地址 Gitee:https://gitee.com/bluecusliyou/TechLearnGithub&#…

[附源码]Node.js计算机毕业设计高校创新学分申报管理系统Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…

微服务实用篇4-消息队列MQ

今天主要来学习异步通讯技术MQ,主要包括初识MQ,RabbitMQ快速入门,SpringAMQP三大部分,下面就来一起学习吧。路漫漫其修远兮,吾将上下而求索,继续加油吧,少年。 目录 一、初识MQ 1.1、同步通讯…

文件历史记录无法识别此驱动器如何修复?

案例: 在电脑中尝试使用内置工具文件历史记录将文件备份到另一个硬盘时,发现如图所示的错误“文件历史记录无法识别此驱动器”,这可怎么办? 文件历史记录驱动器断开连接的原因 文件历史记录无法识别此驱动器的原因可能是启动类型…

四种排序(选择排序、冒泡排序、快速排序和插入排序)

四种排序(选择排序、冒泡排序、快速排序和插入排序)选择排序:完整代码:运行结果:冒泡排序:完整代码:运行结果:插入排序:完整代码:运行结果:快速排…

linux 环境异常登录的甄别方法

1、关于linux的登录记录 查看最近登录IP和历史命令执行日期 last 显示的最末尾的 使用last -10 看最新的 登录IP地址 时间 still仍在登录 选项: (1)-x:显示系统开关机以及执行等级信息 (2)-a&am…