CBAM: Convolutional Block Attention Module论文总结和代码实现

news2024/11/28 16:37:30

论文:https://arxiv.org/pdf/1807.06521.pdf

中文版:CBAM: Convolutional Block Attention Module中文翻译

源码:https://github.com/Jongchan/attention-module

目录

一、论文的出发点

二、论文的主要工作

三、CBAM模块的具体实现

四、实验

五、总结

六、代码实现


卷积块注意模块(CBAM),一个简单而有效的用于前馈卷积神经网络的注意模块。

给定中间特征图,CBAM模块可以顺序地推导出两个独立维度的注意力图(通道和空间),然后将注意力乘到输入特征图上进行自适应特征细化。

一、论文的出发点

cnn基于其丰富的表征能力,极大地推动了视觉任务的完成,为了提高cnn网络的性能,最近的研究主要聚焦在网络的三个重要因素:深度、宽度和基数。除了这些因素,作者还研究了网络架构的一个不同方面——注意力。注意力研究的目标是通过使用注意机制来增加表现能力:关注重要的特征,并抑制不必要的特征。在本文中,作者提出了一个新的网络模块,名为“卷积块注意模块”(CBAM),该模块用来强调这两个主要维度上的有意义的特征:通道和空间轴,该模块实现方式是通过学习强调或抑制哪些信息,有效地帮助信息在网络中流动

二、论文的主要工作

1. 提出了一种简单而有效的注意力模块(CBAM),可广泛应用于增强cnn的表示能力。
2. 作者验证了我们的注意模块的有效性,通过广泛的消融试验。
3. 通过插入CBAM,作者验证了在多个基准测试(ImageNet-1K, MS COCO,和VOC 2007)上,各种网络的性能都得到了极大的改善。

三、CBAM模块的具体实现

CBAM模块的整体结构图:

该模块有两个顺序子模块:通道(Channel)和空间(Spatial)

1. Channel attention module

目的:利用特征的通道间关系生成通道注意图。

方法通道维度不变,压缩输入特征图的空间维度。

步骤

(1)AP和MP操作:首先通过使用AP(average pooling)和MP(max pooling)操作聚合特征图F的空间信息,生成特征向量\mathbf{F^c_{avg}}\mathbf{F^c_{max}}

(2)转发入共享网络Shared MLP:\mathbf{F^c_{avg}}\mathbf{F^c_{max}}被转发到一个共享网络,共享网络由一个隐含层的多层感知器(MLP)组成,为了减少参数开销,隐含层的激活大小设置为\mathbb{R}^{C/r\times 1\times 1},在该模块中,输入的特征图先再通过一个全连接层将通道数压缩为原来的1/r倍,经过ReLU激活函数进行激活,再通过一个全连接层扩张到原通道数,输出得到两个激活后的特征向量。

(3)特征合并和进行softmax:将共享网络应用到每个特征向量后,使用按元素进行求和并通过softmax函数得到包含通道注意力的特征向量。原文中没有给予这个特征向量命名符,为了方便将其称之为s。

Channel attention module模块整体的算子公式如下所示:

σ是指sigmoid函数,W_0\in \mathbb{R}^{C/r \times C}W_1\in \mathbb{R}^{C \times C/r}

最后将s与原特征图F相乘,得到特征图F',传递给Spatial attention module。

2. Spatial attention module

目的:利用特征间的空间关系生成空间注意图。

方法空间维度不变,压缩通道维度

步骤

(1)AP和MP操作:首先特征图F'使用AP(average pooling)和MP(max pooling)操作得到两个1*H*W的特征图。

(2)拼接和卷积:将它们拼接在一起得到一个2*H*W的特征图,再通过一个7x7的卷积重新得到1*H*W的特征图。

(3)sigmoid:最后,通过一个sigmoid函数,得到包含空间注意力的特征图。原文中没有给予这个特征图命名符,为了方便将其称之为z。

Spatial attention module模块整体的算子公式如下所示:

最后将z与F'进行相乘,就得到了原特征图大小且包含空间和通道注意力的特征图,进行输出。

重复该过程,进行端到端训练,得到最佳的空间和通道注意力。

四、实验

以ResNet作为主干特征提取网络,将CBAM嵌入ResBlock,嵌入位置如下:

实验1:寻找最佳的池化方法,使得通道注意力提取最佳。

实验2:寻找最佳的模块实验顺序。

实验3:寻找最佳的通道池化方法和卷积核大小,使得空间注意力提取最佳。

五、总结

CBAM模块可以顺序地推导出两个独立维度的注意力图(通道和空间),然后将注意力乘到输入特征图上进行自适应特征细化。CBAM模块中子模块Channel attention module与SE模块十分相似,都是经过池化层、全连接层,最后由softmax函数得到channel权重,并且提出了空间注意力提取的子模块,最终得到的特征图同时包含空间和通道注意力。

六、代码实现

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)  # 7,3     3,1
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        out = x * self.ca(x)
        result = out * self.sa(out)
        return result

if __name__ == '__main__':
    x = torch.randn(1, 1024, 32, 32)
    net = CBAM(1024)
    out = net.forward(x)
    criterion = nn.L1Loss()
    loss = criterion(out, x)
    loss.backward()
    # 最终输出特征图V的size和损失值
    print('out shape : {}'.format(out.shape))
    print('loss value : {}'.format(loss))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/576741.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++系列P5】‘类与对象‘-三部曲——[对象特殊成员](3/3)

前言 大家好吖,欢迎来到 YY 滴 C系列 ,热烈欢迎! 【 类与对象-三部曲】的大纲主要内容如下: 如标题所示,本章是【 类与对象-三部曲】三章中的第三章节——对象&成员章节,主要内容如下: 目录…

1723_PolySpace Bug Finder命令行执行探索

全部学习汇总: GreyZhang/g_matlab: MATLAB once used to be my daily tool. After many years when I go back and read my old learning notes I felt maybe I still need it in the future. So, start this repo to keep some of my old learning notes servral …

小兔鲜--项目总结3

目录 结算模块-地址切换交互实现 地址切换交互需求分析 打开弹框交互实现 地址激活交互实现 订单模块-生成订单功能实现 支付模块-实现支付功能 支付业务流程 支付模块-支付结果展示 支付模块-封装倒计时函数 理解需求 实现思路分析 会员中心-个人中心信息渲染 分页…

【JavaSE】Java基础语法(二十六):Collection集合

文章目录 1. 数组和集合的区别2. 集合类体系结构3. Collection 集合概述和使用【应用】4. Collection集合的遍历【应用】5. 增强for循环【应用】 1. 数组和集合的区别 相同点 都是容器,可以存储多个数据不同点 数组的长度是不可变的,集合的长度是可变的 数组可以存基本数据类型…

【C++系列P4】‘类与对象‘-三部曲——[类](2/3)

前言 大家好吖,欢迎来到 YY 滴 C系列 ,热烈欢迎! 【 类与对象-三部曲】的大纲主要内容如下: 如标题所示,本章是【 类与对象-三部曲】三章中的第二章节——类章节,主要内容如下: 目录 一.类 1.…

CodeForces..学习读书吧.[简单].[条件判断].[找最小值]

题目描述: 题目解读: 给定一组数,分别是 “时间 内容”,内容分为00,01,10,11四种,求能够得到11的最小时间。 解题思路: 看似00,01,10&#xff0…

完整卸载office以及重装office 2021

完整卸载office以及重装 一.背景 之前很早安装的word最近发现打开,编辑等操作都很卡,而且占用的CPU很多,20%左右,而在网上搜索了一些结果无法解决问题后,决定卸载重装 二. 卸载的建议方法 直接参考官方链接从PC卸载…

华为OD机试之租车骑绿岛(Java源码)

租车骑绿岛 题目描述 部门组织绿岛骑行团建活动。租用公共双人自行车,每辆自行车最多坐两人,最大载重M。 给出部门每个人的体重,请问最多需要租用多少双人自行车。 输入描述 第一行两个数字m、n,分别代表自行车限重,部…

k8s 对外服务之 ingress|ingress的对外暴露方式|ingress http,https代理|ingress nginx的认证,nginx重写

k8s 对外服务之 ingress|ingress的对外暴露方式|ingress http,https代理|ingress nginx的认证,nginx重写 一 Ingress 简介二 Ingress 组成三 ingress 暴露服务的方式四 部署 nginx-ingress-controller4.1 修改 ClusterRole 资源配置4.2 DaemonSetHostNet…

STM32HAL库RS485-ModBus协议控制伺服电机

STM32HAL库RS485-ModBus协议控制伺服电机 一个月前,接手了一个学长的毕设小车,小车采用rs485通信的modbus协议驱动轮毂电机,与往常我学习的pwm控制电机方法大相径庭,在这里以这篇博客记录下该学习过程。 小车主要架构 电机型号 …

Python期末复习题库(上)——“Python”

小雅兰期末加油冲冲冲!!! 1. (单选题) Python源程序的扩展名为( A ) A. py B. c C. class D. ph 2. (单选题) 下列( A )符合可用于注释Python代码。 A. # B. */ C. // D. $ 3. (单选题)下列…

SMARTPHONE PLATFORM st解决方案

智能手机是最常用的计算设备。 它们展示了强大的硬件功能和复杂的操作系统,支持高级功能和人工智能应用、互联网和云访问、图像和视频采集、游戏以及语音通话和短信等核心电话功能。 要执行如此多样的应用,智能手机必须包含许多设备,包括大量…

一、电路分析的变量

点我回到主目录 ------------------------------------------------------------------------------------------------------------------------- 目录 1.电流 2.电压 3.功率 4.关联参考方向 5.电路吸收或发出功率的判断 1.电流 •电流 单位A(安培…

vue基于Python的图书商城销售系统qo85w

系统以浏览器/服务器模式即B/S模板式为基础。本系统使用MySQL数据库,利用Python开发的操作系统;主要的功能有个人中心、用户管理、图书资讯管理、图书类型管理、图书信息管理、爬虫管理、留言板管理、系统管理、订单管理等组成。 本文首先介绍了现代化图书销售系统管…

2023电工杯B题全保姆论文讲解手把手教程 人工智能影响评价

更新:电工杯B题全保姆论文成品教程,手把手教你完成高质量成品 这次b题是这一道问卷分析题目,是我最擅长的题目之一了,问卷分析看起来简单,实际上没那么那简单,考验的是我们能不能把数据描述清楚&#xff0…

2023哈佛大学博士后/访问学者研究班一览

哈佛大学是全球顶尖的高等教育机构之一,其所拥有的丰富资源和卓越师资吸引了来自全球各地的优秀学者前来攻读博士学位或作为访问学者进行研究。而博士后访问学者研究班则是哈佛大学提供给这些博士后访问学者的一个重要平台。博士后访问学者研究班是一个跨学科的研究…

echarts 被封装后多次复用,图表被覆盖,解决方法

场景:为了方便样式统一,封装了一个盒子,其中包含echarts,option是从父组件传来的 问题: 多个父级页面使用这个盒子后,发现只有第一个盒子展示图表,但展示的是最后一个图片的样式,其他…

【数据结构】如何应用堆解决海量数据的问题

堆(Heap数据结构堆在计算机科学中有着广泛的应用,今天来介绍两种堆的应用:堆排序、Top-k问题🍉 堆排序 ​ 堆排序是一种基于堆数据结构的排序算法。它的基本思想是,将待排序的序列构建成一个大根堆(或小根堆&#xff…

三展齐发,DBF户外展、高博会、健身展隆重开幕,火爆现场燃炸鹏城!

5月25日,深圳建设国家体育消费试点城市系列活动,第四届DBF深圳国际户外运动博览会,DBF深圳国际高尔夫运动博览会暨深圳国际健身运动博览会(以下简称DBF运动户外生活展)在深圳国际会展中心5.7号馆盛大举办!开…

recurdyn接触特征参数含义

一般接触特征设置 Static Threshold Velocity静态门槛速度:判断静态摩擦和动态摩擦的标准,若相对速度小于此值,摩擦为静摩擦;若相对速度大于此值,摩擦为动摩擦。静态摩擦区域内摩擦系数计算函数为 Dynamic Threshold V…