【深度学习注意力机制系列】—— CBAM注意力机制(附pytorch实现)

news2024/12/23 9:04:53

CBAM(Convolutional Block Attention Module)是一种用于增强卷积神经网络(CNN)性能的注意力机制模块。它由Sanghyun Woo等人在2018年的论文[1807.06521] CBAM: Convolutional Block Attention Module (arxiv.org)中提出。CBAM的主要目标是通过在CNN中引入通道注意力和空间注意力来提高模型的感知能力,从而在不增加网络复杂性的情况下改善性能。

1、概述

CBAM旨在克服传统卷积神经网络在处理不同尺度、形状和方向信息时的局限性。为此,CBAM引入了两种注意力机制:通道注意力和空间注意力。通道注意力有助于增强不同通道的特征表示,而空间注意力有助于提取空间中不同位置的关键信息。

2、模型结构

CBAM由两个关键部分组成:通道注意力模块(C-channel)空间注意力模块(S-channel)。这两个模块可以分别嵌入到CNN中的不同层,以增强特征表示。

2.1 通道注意力模块

在这里插入图片描述

通道注意力模块的目标是增强每个通道的特征表达。以下是实现通道注意力模块的步骤:

  1. 全局最大池化和全局平均池化: 对于输入特征图,首先对每个通道执行全局最大池化和全局平均池化操作,计算每个通道上的最大特征值和平均特征值。这会生成两个包含通道数的向量,分别表示每个通道的全局最大特征和平均特征。

  2. 全连接层: 将全局最大池化和平均池化后的特征向量输入到一个共享全连接层中。这个全连接层用于学习每个通道的注意力权重。通过学习,网络可以自适应地决定哪些通道对于当前任务更加重要。将全局最大特征向量和平均特征向相交,得到最终注意力权重向量。

  3. Sigmoid激活: 为了确保注意力权重位于0到1之间,应用Sigmoid激活函数来产生通道注意力权重。这些权重将应用于原始特征图的每个通道。

  4. 注意力加权: 使用得到的注意力权重,将它们与原始特征图的每个通道相乘,得到注意力加权后的通道特征图。这将强调对当前任务有帮助的通道,并抑制无关的通道。

代码实现

class ChannelAttention(nn.Module):
    """
    CBAM混合注意力机制的通道注意力
    """

    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            # 全连接层
            # nn.Linear(in_planes, in_planes // ratio, bias=False),
            # nn.ReLU(),
            # nn.Linear(in_planes // ratio, in_planes, bias=False)

            # 利用1x1卷积代替全连接,避免输入必须尺度固定的问题,并减小计算量
            nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False)
        )

        self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            avg_out = self.fc(self.avg_pool(x))
            max_out = self.fc(self.max_pool(x))
            out = avg_out + max_out
            out = self.sigmoid(out)
            return out * x

2.2 空间注意力模块

在这里插入图片描述

空间注意力模块的目标是强调图像中不同位置的重要性。以下是实现空间注意力模块的步骤:

  1. 全局最大池化和全局平均池化: 对于输入特征图,分别执行全局最大池化和全局平均池化操作,生成不同上下文尺度的特征。
  2. 连接和卷积: 将全局最大池化和全局平均池化后的特征沿着通道维度进行连接(拼接),得到一个具有不同尺度上下文信息的特征图。然后,通过卷积层处理这个特征图,以生成空间注意力权重。
  3. Sigmoid激活: 类似于通道注意力模块,对生成的空间注意力权重应用Sigmoid激活函数,将权重限制在0到1之间。
  4. 注意力加权: 将得到的空间注意力权重应用于原始特征图,对每个空间位置的特征进行加权。这样可以突出重要的图像区域,并减少不重要的区域的影响。

代码实现

class SpatialAttention(nn.Module):
    """
    CBAM混合注意力机制的空间注意力
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv1(out))
        return out * x

2.3 混合注意力模块

在这里插入图片描述

CBAM就是将通道注意力模块和空间注意力模块的输出特征逐元素相乘,得到最终的注意力增强特征。这个增强的特征将用作后续网络层的输入,以在保留关键信息的同时,抑制噪声和无关信息。原文实验证明先进行通道维度的整合,再进行空间维度的整合,模型效果更好(有效玄学炼丹的感觉)。

代码实现

class CBAM(nn.Module):
    """
    CBAM混合注意力机制
    """

    def __init__(self, in_channels, ratio=16, kernel_size=3):
        super(CBAM_Block, self).__init__()
        self.channelattention = ChannelAttention(in_channels, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = self.channelattention(x)
        x = self.spatialattention(x)
        return x

总结

总之,CBAM模块通过自适应地学习通道和空间注意力权重,以提高卷积神经网络的特征表达能力。通过将通道注意力和空间注意力结合起来,CBAM模块能够在不同维度上捕获特征之间的相关性,从而提升图像识别任务的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/856642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实现静态资源访问的几种方法

什么是静态资源? 静态资源是指在服务器端存储的不会变化的文件,如HTML、CSS、JavaScript、图片、音频、视频等文件。这些文件一般不包含动态内容,每次请求时返回的内容都是固定的。 为什么要使用静态资源? 提升网站性能&#xf…

gitblit-使用

1.登入GitBlit服务器 默认用户和密码: admin/admin 2.创建一个新的版本库 点击图中的“版本库”,然后点击图中“创建版本库” 填写名称和描述,注意名称最后一定要加 .git选择限制查看、克隆和推送勾选“加入README”和“加入.gitignore文件”在图中的1处…

kafka-2.12使用记录

kafka-2.12使用记录 安装kafka 2.12版本 下载安装包 根据你的系统下载rpm /deb /zip包等等, 这里我使用的是rpm包 安装命令 rpm -ivh kafka-2.12-1.nfs.x86_64.rpm启动内置Zookeeper 以下命令要写在同一行上 /opt/kafka-2.12/bin/zookeeper-server-start.sh /opt/kafka-2…

5.3.7.自动创建字符设备驱动的设备文件 class_create device_create

5.3.7.自动创建字符设备驱动的设备文件 5.3.7.1、问题描述: (1)整体流程回顾 (2)使用mknod创建设备文件的缺点 (3)能否自动生成和删除设备文件 5.3.7.2、解决方案:udev是PC机(嵌入式中用的是mdev) (1)什么是udev?应用层…

C语言笔记6

关于microsoft visual 的学习笔记 CtrlF5就是启动编译程序 先CtrlA进行全选&#xff0c;然后AitF8就自动的调节代码的格式 #include <stdio.h> #include <stdlib.h> int main() {//system启动程序(在一个程序中启动另外一个程序)//如果程序环境变量中找不到程序&am…

OpenCV实战(29)——视频对象追踪

OpenCV实战&#xff08;29&#xff09;——视频对象追踪 0. 前言1. 追踪视频中的对象2. 中值流追踪器算法原理3. 完整代码小结系列链接 0. 前言 我们已经学习了如何跟踪图像序列中点和像素的运动。但在多数应用中&#xff0c;通常要求追踪视频中的特定移动对象。首先确定感兴趣…

FFmpeg安装和使用

sudo apt install ffmpeg sudo apt-get install libavfilter-devcmakelist模板 CMakeLists.txt cmake_minimum_required(VERSION 3.16) project(ffmpeg_demo)# 设置ffmpeg依赖库及头文件所在目录&#xff0c;并存进指定变量 set(ffmpeg_libs_DIR /usr/lib/x86_64-linux-gnu) …

SpringBoot自动装配及run方法原理探究

自动装配 1、pom.xml spring-boot-dependencies&#xff1a;核心依赖在父工程中&#xff01;我们在写或者引入一些SpringBoot依赖的时候&#xff0c;不需要指定版本&#xff0c;就因为有这些版本仓库 1.1 其中它主要是依赖一个父工程&#xff0c;作用是管理项目的资源过滤及…

冠达管理:“高温超导”不是“室温超导”,5天4板百利电气再次澄清

短短半个月&#xff0c;“室温超导”在惊喜、质疑间回转&#xff0c;但资本市场对“超导概念股”的炒作还在进行&#xff0c;8月7日室温超导概念持续疯涨。同花顺显现&#xff0c;到8月7日收盘&#xff0c;18只超导概念股中&#xff0c;有16只股票飘红。 广东研山私募证券投资&…

如何将GPS坐标点如何网格化?

目录 题主问题&#xff1a; 解答&#xff1a; 高效判断点是否在正六边形蜂窝内的方法 代码实现&#xff1a;ArcGIS中实现指定面积蜂窝&#xff08;正六边形&#xff09;方法 碰巧自己前段时间处理过类似的数据&#xff0c;讲一下自己的解决思路。 题主问题&#xff1a; 解…

【小练习】交互式网格自定义增删改(进行中)

学习SQL和PLISQL数据类型的区别和应用场景 Oracle plsql 基础篇1 数据类型以及流程控制_bb_tarek的博客-CSDN博客https://blog.csdn.net/bb_tarek/article/details/17555713?ops_request_misc&request_id&biz_id102&utm_termplsql%E5%9F%BA%E6%9C%AC%E6%95%B0%E6…

9.异常

文章目录 9.1 Java 异常类层次结构图9.2 Throwable 类常用方法9.3 try-catch-finally9.4使用 try-with-resources 来代替try-catch-finally 9.1 Java 异常类层次结构图 在 Java 中&#xff0c;所有的异常都有一个共同的祖先 java.lang 包中的 Throwable 类。Throwable 类有两个…

CentOS安装Postgresql

PG基本安装步骤 安装postgresql&#xff1a; sudo yum install postgresql-server初始化数据库&#xff1a;安装完毕后&#xff0c;需要初始化数据库并创建初始用户&#xff1a; sudo postgresql-setup initdb启动和停止服务&#xff1a; sudo systemctl start postgresql sudo…

06微服务间的通信方式

一句话导读 微服务设计的一个挑战就是服务间的通信问题&#xff0c;服务间通信理论上可以归结为进程间通信&#xff0c;进程可以是同一个机器上的&#xff0c;也可以是不同机器的。服务可以使用同步请求响应机制通信&#xff0c;也可以使用异步的基于消息中间件间的通信机制。同…

【TS第三讲】完善TS开发环境

文章目录 &#x1f31f; 写在前面&#x1f31f; ts-node&#x1f31f; nodemon&#x1f31f; nodemon文件类型&#x1f31f; nodemon文件范围&#x1f31f; 写在最后 &#x1f31f; 写在前面 &#x1f525;探索TypeScript世界&#xff0c;驭Vue3Ts潮流&#xff0c;开启前端之旅…

【Ubuntu】简化反向代理和个性化标签页体验

本文将介绍如何使用Docker部署Nginx Proxy Manager和OneNav&#xff0c;两个功能强大且易用的工具。Nginx Proxy Manager用于简化和管理Nginx反向代理服务器的配置&#xff0c;而OneNav则提供个性化的新标签页体验和导航功能。通过本文的指导&#xff0c;您将学习如何安装和配置…

【打印整数二进制的奇数位和偶数位】

打印整数二进制的奇数位和偶数位 1.题目 获取一个整数二进制序列中所有的偶数位和奇数位&#xff0c;分别打印出二进制序列 2.题目分析 打印一个整数的二进制位中的偶数位和奇数位&#xff0c;可以对整数进行移位操作&#xff0c;再将移位的二进制位与1进行&操作。 按位&a…

HarmonyOS/OpenHarmony应用开发-ArkTS语言渲染控制概述

ArkUI通过自定义组件的build()函数和builder装饰器中的声明式UI描述语句构建相应的UI。 在声明式描述语句中开发者除了使用系统组件外&#xff0c;还可以使用渲染控制语句来辅助UI的构建&#xff0c;这些渲染控制语句包括控制组件是否显示的条件渲染语句&#xff0c;基于数组数…

Rocky Linux更换为国内源

Rocky Linux提供的可供切换的源列表&#xff1a;Mirrors - Mirror Manager 其中以 COUNTRY 列为 CN 的是国内源。 选择其中一个Rocky Linux 源使用帮助 — USTC Mirror Help 文档 操作前请做好备份 对于 Rocky Linux 8&#xff0c;使用以下命令替换默认的配置 sed -e s|^mirr…

Java用方法实现登录名和密码的校验

Java用方法实现登录名和密码的校验 需求分析代码实现小结Time 需求分析 系统正确的登录名和密码是:学习/123&#xff0c;请在控制台开发一个登录界面&#xff0c;接收用户输入的登录名和密码&#xff0c;判断用户是否登录成功&#xff0c;登录成功后展示:“欢迎进入系统!”&…