CBAM注意力机制详解(附pytorch复现)

news2025/2/26 11:42:09

简介

论文原址:1807.06521.pdf (arxiv.org)

CBAM(Convolutional Block Attention Module)是一种卷积神经网络模块,旨在通过引入注意力机制来提升网络的表示能力。CBAM包含两个顺序子模块:通道注意力模块和空间注意力模块。

通过在深度网络的每个卷积块中自适应地优化中间特征图,CBAM通过强调通道和空间维度上的有意义特征,实现了对关键信息的关注和不必要信息的抑制。研究表明,CBAM在ImageNet-1K数据集上能够显著提高各种基线网络的准确性,通过grad-CAM可视化验证,CBAM增强的网络能够更准确地关注目标对象。在MS COCO和VOC 2007数据集上的目标检测任务中,CBAM也展现出显著的性能改进,而由于CBAM精心设计为轻量级模块,其在大多数情况下几乎没有参数和计算开销。CBAM注意力模块可广泛应用于提升卷积神经网络的表示能力。

Channel attention module(CAM)

通过平均池化和最大池化操作,整合输入特征图的空间信息,生成两个不同的空间上下文描述符,得到两个 1×1×C 的特征图,分别表示为 F_c_avg 和 F_c_max。将 F_c_avg 和 F_c_max 分别送入一个共享的多层感知机(MLP),该 MLP 具有一个隐藏层,其中第一层神经元个数为 C/r(r 为减少率),激活函数为 ReLU,第二层神经元个数为 C。这两层神经网络是共享的,即它们的权重相同。将两个 MLP 的输出特征进行逐元素相加,并通过 sigmoid 激活函数,生成通道注意力图 Mc。

这是对池化操作的使用进行实验比较的结果。研究者发现,采用平均池化和最大池化并行的方式能够取得更好的效果。可能是因为采用并行连接方式,相比于单一的池化,能够更有效地保留有用的信息,进而提升模型性能。

Spatial attention module(SAM)

首先,将 Channel Attention 模块输出的特征图作为 Spatial Attention 模块的输入特征图。接着,对输入特征图进行基于通道的全局最大池化和全局平均池化操作,得到两个 H×W×1 的特征图。然后,将这两个特征图在通道维度上进行拼接,经过一个 7×7 的卷积操作,将通道数降维为 1,即得到 H×W×1 的特征图。最后,经过 sigmoid 操作生成空间注意力特征,即 Ms。将该特征与输入特征图进行乘法操作,得到最终生成的特征。这一过程有助于模型关注输入特征图中的重要区域,从而增强表示能力。

CBAM的pytorch实现

"""
Original paper addresshttps: https://arxiv.org/pdf/1807.06521.pdf
Time: 2024-02-28
"""
import torch
from torch import nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # shared MLP
        self.mlp = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_planes // reduction, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7, padding=3):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        out = x * self.ca(x)
        result = out * self.sa(out)
        return result

if __name__ == '__main__':
    block = CBAM(16)
    input = torch.rand(1, 16, 8, 8)
    output = block(input)
    print(output.shape)

参考文章

CBAM——即插即用的注意力模块(附代码)_cbam模块-CSDN博客

[ 注意力机制 ] 经典网络模型2——CBAM 详解与复现_cbam代码复现-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1477453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

视频拉流推流技术梳理

概况 视频的整个流程主要分为推流和拉流 摄像头场景: 摄像头捕捉视频画面,推流到服务器,服务器分发到CDN, 客户端从CDN地址拉流,客户端进行播放 直播场景: 主播通过手机,电脑等客户端&…

强化学习(六)时序差分

时序差分(TD)是强化学习的核心,其是蒙特卡罗(MC)和动态规划(DP)的结合。 1、TD 预测 TD 和 MC 都是利用经验来解决预测问题。一种非平稳环境的一般访问蒙特卡罗方法是 V ( S t ) ← V ( S t …

力扣-H指数

问题 给你一个整数数组 citations ,其中 citations[i] 表示研究者的第 i 篇论文被引用的次数。计算并返回该研究者的 h 指数。 根据维基百科上 h 指数的定义:h 代表“高引用次数” ,一名科研人员的 h 指数 是指他(她&#xff09…

android开发平台,Java+性能优化+APP开发+NDK+跨平台技术

开头 通常作为一个Android APP开发者,我们并不关心Android的源代码实现,不过随着Android开发者越来越多,企业在筛选Android程序员时越来越看中一个程序员对于Android底层的理解和思考,这里的底层主要就是Android Framewok中各个组…

【Linux深入剖析】再续环境变量 | 进程地址空间

📙 作者简介 :RO-BERRY 📗 学习方向:致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 📒 日后方向 : 偏向于CPP开发以及大数据方向,欢迎各位关注,谢谢各位的支持 目录 1.环境变量再续1.1 和…

图书管理系统(使用IO流实现数据的读取和写入)--version4.0

目录 一、项目要求: 二、项目环境 三、项目使用的知识点 四、项目代码 五、项目运行结果 六、项目难点分析 图书管理系统--versions1.0: 图书管理系统--versions1.0-CSDN博客文章浏览阅读981次,点赞29次,收藏17次。本文使用…

Encoding, Encryption, Tokenization 傻傻分不清楚

Encoding, Encryption, Tokenization 傻傻分不清楚 本文转自 公众号 ByteByteGo,如有侵权,请联系,立即删除 今天来聊聊编码 (Encoding), 加密 (Encryption) 和 令牌化 (Tokenization) 的区别。 编码、加密和标记化是三种不同的流程&#xff…

游泳耳机哪个牌子质量好?4大高口碑产品推荐入手

游泳耳机作为一种专业的水上音频装备,能够使游泳者在游泳过程中享受音乐的同时保持安全和舒适。随着科技的发展,市面上涌现出许多品牌和型号的游泳耳机,但是其中哪个牌子的质量更好呢?下面这篇文章将为大家介绍四大热门口碑产品&a…

项目流程图

实现便利店自助付款项目 服务器: 1、并发服务器(多进程、多线程、IO多路复用) 2、SQL数据库的创建和使用(增删改查) 3、以模块化编写项目代码,按照不同模块编写.h/.c文件 客户端: 1、QT客户端界…

dolphinscheduler伪集群部署教程

文章目录 前言一、配置免密登录1. 配置root用户免密登录2. 创建用户2.1 创建dolphinscheduler用户2.2 配置dolphinscheduler用户免密登录2.3 退出dolphinscheduler用户 二、安装准备1. 安装条件2. 安装jdk3. 安装MySQL4. 安装zookeeper4.1 zookeeper单机部署4.1.1 zookeeper3.1…

js 手写深拷贝方法

文章目录 一、深拷贝实现代码二、代码讲解2.1 obj.constructor(obj)2.2 防止循环引用 手写一个深拷贝是我们常见的面试题,在实现过程中我们需要考虑的类型很多,包括对象、数组、函数、日期等。以下就是深拷贝实现逻辑 一、深拷贝实现代码 const origin…

扫码看视频的效果怎么做?在电脑上制作视频活码只需3步

怎么做扫码看视频的效果呢?通过二维码来储存视频并用来做展示用途,是现在很常见的一种二维码应用类型,这种方式可以有效的提升内容的快速传播,而且用户体验也比较好。 那么如何通过视频二维码生成器的功能来制作自己的二维码图片…

提升媒体文字质量:常见错误及改进措施解析

在现代媒体出版中,文字质量直接影响着信息的传递效率和准确性。近期,中国产业报协会全国行业报质检办公室对中央及国家机关主管的84家行业报纸进行了质量检查,发现了一系列共性的文字使用错误。本文旨在深入探讨这些错误,并提出改…

Springboot中ApplicationContextInitializer的使用及源码分析

文章目录 一、认识ApplicationContextInitializer1、ApplicationContextInitializer的作用2、认识ApplicationContextInitializer接口3、ApplicationContextInitializer的常用用法(1)注册BeanFactoryPostProcessor(2)注册Applicat…

关于StartAI本地部署相关问题解答

很多小伙伴们都有接入自己本地SD的需求,对此小编整理了一些相关问题~ 一、本地部署相关条件 对于想要本地部署的小伙伴要了解,相对于使用StartAI试用引擎本地部署更加考验电脑硬件配置备噢~ 流畅使用要nvidia显卡,6g以上显存(最…

Google发布Genie硬杠Sora:通过大量无监督视频训练最终生成可交互虚拟世界

前言 Sora 问世才不到两个星期,谷歌的世界模型也来了,能力看似更强大(嗯,看似):它生成的虚拟世界自主可控 第一部分 首个基础世界模型Genie 1.1 Genie是什么 Genie是第一个以无监督方式从未标记的互联网视频中训练的生成式交互…

浅析前端的堆栈原理以及深浅拷贝原理

浅析前端的堆栈原理以及深浅拷贝原理 首先来看一个案例 const obj {name:hzw,age:18 } let objName2 obj objName2.age 12 console.log(obj,objName2) // {name: hzw, age: 12} {name: hzw, age: 12}这里是不是很奇怪,为什么,为什么我改变objName2的…

使用 Gradle 版本目录进行依赖管理 - Android

/ 前言 / 在软件开发中,依赖管理是一个至关重要的方面。合理的依赖版本控制有助于确保项目的稳定性、安全性和可维护性。 Gradle版本目录(Version Catalogs)是 Gradle 构建工具的一个强大功能,它为项目提供了一种集中管理依赖…

轻松玩转Git

轻松玩转Git 快速入门什么是Git为什么要做版本控制安装git Git实战单枪匹马开始干拓展新功能小结 紧急修复bug分支紧急修复bug方案命令总结工作流 上传GitHub第一天上班前在家上传代码初次在公司新电脑下载代码下班回到家继续写代码到公司继续开发在公司约妹子忘记提交代码回家…

算法——滑动窗口之最大连续1的个数、将x减到0的最小操作数、水果成篮

3.最大连续1的个数 题目:. - 力扣(LeetCode) 题目要求的是给定一个二进制数组 nums 和一个整数 k,如果可以翻转最多 k 个 0 ,则返回 数组中连续 1 的最大个数 。 按照题目正面去做,还要替换0,很麻烦 反正我们最后要求的是最长…