CBAM注意力机制(结构图加逐行代码注释讲解)

news2024/11/18 7:32:40

学CBAM前建议先学会SEnet(因为本篇涉及SEnet的重合部分会略加带过)->传送门

⒈结构图

下面这个是自绘的,有些许草率。。。

因为CBAM机制是由通道和空间两部分组成的,所以有这两个模块(左边是通道注意力机制,右边是空间注意力机制)

下面这两个是官方论文里的:

⒉机制流程讲解

SEnet只关注了通道注意力机制而忽略了空间上的一些简单特征,相比之下,CBAM将通道注意力机制和空间注意力机制进行一个结合,对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理,而是是先通道后空间,也就是第一张结构图表达的意思。

①首先是通道机制:

对于输入特征层,分别作全局最大池化和全局平均池化,输出结果分别送入一个共享全连接层(官方源码在这里和SEnet的全连接层一模一样),为什么叫共享全连接层?因为最大池化和平均池化的两条路线用的是这同一个全连接层。然后对两个结果(maxout和avgout)做加法,最后进行归一化操作,获得通道上的权重矩阵。

②然后是空间机制:

对于输入特征层,在每一个特征点的通道上取最大值和平均值,(这里和通道机制的最大池化和平均池化完全不同,通道机制里是在H、W两个维度求最大或平均,空间机制是在C一个维度上求最大和平均。)然后对两个结果(maxout和avgout)做拼接,也就是maxout的1*H*W与avgout的1*H*W进行拼接,得到2*H*W的张量,因此紧接着下一步就要进行一个7*7的卷积(conv)将通道压缩回1,最后还是进行归一化操作,获得空间上的权重矩阵。

③整体上:

对于输入特征层,输入特征层先乘上通道机制的输出权重(channel_out),然后再乘上空间上的输出权重(spatial_out)

⒊源码(pytorch框架实现)及逐行解释

import torch
from torch import nn
from torchsummary import summary
 
class ChannelModule(nn.Module):
    def __init__(self, inputs, ratio=16):
        super(ChannelModule, self).__init__()
        _, c, _, _ = inputs.size()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.share_liner = nn.Sequential(
            nn.Linear(c, c // ratio),
            nn.ReLU(),
            nn.Linear(c // ratio, c)
        )
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        x = self.maxpool(inputs).view(inputs.size(0), -1)#nc
        maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#nchw
        y = self.avgpool(inputs).view(inputs.size(0), -1)
        avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(maxout + avgout)
 
 
class SpatialModule(nn.Module):
    def __init__(self):
        super(SpatialModule, self).__init__()
        self.maxpool = torch.max
        self.avgpool = torch.mean
        self.concat = torch.cat
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#n1hw
        avgout = self.avgpool(inputs, dim=1, keepdim=True)#n1hw
        outs = self.concat([maxout, avgout], dim=1)#n2hw
        outs = self.conv(outs)#n1hw
        return self.sigmoid(outs)
 
 
class CBAM(nn.Module):
    def __init__(self, inputs):
        super(CBAM, self).__init__()
        self.channel_out = ChannelModule(inputs)
        self.spatial_out = SpatialModule()
 
    def forward(self, inputs):
        outs = self.channel_out(inputs) * inputs
        return self.spatial_out(outs) * outs

 解释:

①依赖包和SEnet解释的一样。

②整体上看,将通道机制和空间机制分别封装成类,再封装一个CBAM类来对这两个机制调用,其中用到的__init__构造方法(python称魔术方法)和foward函数(前向传播过程),这些模板和上面介绍SEnet时是一模一样的。

先来看通道机制:

class ChannelModule(nn.Module):#继承nn模块的Module类
    def __init__(self, inputs, ratio=16):#self必写,inputs接收输入特征张量,ratio是通道衰减因子
        super(ChannelModule, self).__init__()#调用父类构造
        _, c, _, _ = inputs.size()#获取通道数
        self.maxpool = nn.AdaptiveMaxPool2d(1)#nn模块的自适应二维最大池化
        self.avgpool = nn.AdaptiveAvgPool2d(1)#nn模块的自适应二维平均池化
        self.share_liner = nn.Sequential(
            nn.Linear(c, c // ratio),
            nn.ReLU(),
            nn.Linear(c // ratio, c)
        )#这个共享全连接的3层和SEnet的一模一样,这里借助Sequential这个容器把这3个层整合在一起,方便forward函数去执行,直接调用share_liner(x)相当于直接执行了里面这3层
        self.sigmoid = nn.Sigmoid()#nn模块的Sigmoid函数
 
    def forward(self, inputs):
        x = self.maxpool(inputs).view(inputs.size(0), -1)#对于输入特征张量,做完最大池化后再重塑形状,view的第一个参数inputs.size(0)表示第一维度,显然就是n;-1表示会自适应的调整剩余的维度,在这里就将原来的(n,c,1,1)调整为了(n,c*1*1),后面才能送入全连接层(fc层)
        maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#做完全连接后,再用unsqueeze解压缩,也就是还原指定维度,这里用了两次,分别还原2维度的h,和3维度的w
        y = self.avgpool(inputs).view(inputs.size(0), -1)
        avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)#y走的平均池化路线的代码和x是一样的解释
        return self.sigmoid(maxout + avgout)#最后相加两个结果并作归一化

再来看空间机制:(重复的模板就不再反复赘述了)

class SpatialModule(nn.Module):
    def __init__(self):
        super(SpatialModule, self).__init__()
        self.maxpool = torch.max
        self.avgpool = torch.mean
        #和通道机制不一样!这里要进行的是在C这一个维度上求最大和平均,分别用的是torch库里的max方法和mean方法
        self.concat = torch.cat#torch的cat方法,用于拼接两个张量
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)#nn模块的二维卷积,其中的参数分别是:输入通道(2),输出通道(1),卷积核大小(7*7),步长(1),灰度填充(3)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#maxout接收特征点的最大值很好理解,为什么还要一个占位符?因为torch.max不仅返回张量最大值,还会返回索引,索引用不着所以直接忽略,dim=1表示在维度1(也就是nchw的c)上求最大值,keepdim=True表示要保持原来张量的形状
        avgout = self.avgpool(inputs, dim=1, keepdim=True)#torch.mean则只返回张量的平均值,至于参数的解释和上面是一样的
        outs = self.concat([maxout, avgout], dim=1)#torch.cat方法,传入一个列表,将列表中的张量在指定维度,这里是维度1(也就是nchw的c)拼接,即n*1*h*w拼接n*1*h*w得到n*2*h*w
        outs = self.conv(outs)#卷积压缩上面的n*2*h*w,又得到n*1*h*w
        return self.sigmoid(outs)

 最后看整体:

class CBAM(nn.Module):
    def __init__(self, inputs):
        super(CBAM, self).__init__()
        self.channel_out = ChannelModule(inputs)#获得通道权重
        self.spatial_out = SpatialModule()#获得空间权重
 
    def forward(self, inputs):
        outs = self.channel_out(inputs) * inputs #先乘上通道权重
        return self.spatial_out(outs) * outs #在乘完通道权重的基础上再乘上空间权重

⒋测试结果

大问题没有,但还是少了一些关键层,尤其是空间机制那里的拼接maxout和avgout,通道变为2再用卷积压缩回1的过程都没体现。。。只能说summary确实不太好使,或者说我没用对?网络层简写导致的?(最不可能是这个原因,因为我拿官方的源码测试也是summary出这些结果)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1226440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在el-tree懒加载并且包含下级的情况下进行数据回显-02

上一篇文章如何在el-tree懒加载并且包含下级的情况下进行数据回显-01对于el-tree懒加载,包含下级的情况下,对于回显提出两种方案,第一种方案有一些难题无法解决,我们重点来说说第二种方案。 第二种方案是使用这个变量对其是否全选…

ScalableMap

问题引入 传统方案在处理线性地图元素时忽略了其结构性约束,建图距离太近 方法 简介 结构引导BEV特征提取 一种新的层次稀疏地图表示方法 设计渐进解码机制和基于此表示的监督策略 组件 结构引导BEV表征 通过车载摄像头捕捉的环绕视图图像,利用Res…

tamarin运行

首先我们找到安装tamarin的文件位置,找到以后进入该文件夹下 ubuntuubuntu:~$ sudo find / -name tamarin-prover /home/linuxbrew/.linuxbrew/var/homebrew/linked/tamarin-prover /home/linuxbrew/.linuxbrew/Cellar/tamarin-prover /home/linuxbrew/.linuxbrew/…

数据结构【DS】图的基本概念

定义 完全图(简单完全图) 完全无向图:边数为𝐧𝐧−𝟏𝟐完全有向图:边数为 𝐧(𝐧−𝟏) 子图、生成子图 G的子图:所有的顶点和边都属于图G的图 G的生成子图…

五个必知的速率限制策略,以最大化流量流动

速率限制是一种策略,我们在工作中常常使用,它定义了系统在设定的时间框架内可以处理的最大请求数量。 速率限制定义了系统在指定时间段内可以处理的最大请求数量。 Image.png 速率限制是一种策略,我们在工作中常常使用,它定义了…

【IPC】消息队列

1、IPC对象 除了最原始的进程间通信方式信号、无名管道和有名管道外,还有三种进程间通信方式,这 三种方式称之为IPC对象 IPC对象分类:消息队列、共享内存、信号量(信号灯集) IPC对象也是在内核空间开辟区域,每一种IPC对象创建好…

【开源】基于Vue.js的在线课程教学系统的设计和实现

项目编号: S 014 ,文末获取源码。 \color{red}{项目编号:S014,文末获取源码。} 项目编号:S014,文末获取源码。 目录 一、摘要1.1 系统介绍1.2 项目录屏 二、研究内容2.1 课程类型管理模块2.2 课程管理模块2…

【Feign】 基于 Feign 远程调用、 自定义配置、性能优化、实现 Feign 最佳实践

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 SpringCloud MybatisPlus JVM 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Feign 一、 基于 Feign 远程调用1.1 RestTemplate方式…

MySQL 教程 1.1

MySQL 教程1.1 MySQL 是最流行的关系型数据库管理系统,在 WEB 应用方面 MySQL 是最好的 RDBMS(Relational Database Management System:关系数据库管理系统)应用软件之一。 在本教程中,会让大家快速掌握 MySQL 的基本知识,并轻松…

数字IC前端学习笔记:异步复位,同步释放

相关阅读 数字IC前端https://blog.csdn.net/weixin_45791458/category_12173698.html?spm1001.2014.3001.5482 异步复位 异步复位是一种常见的复位方式,可以使电路进入一个可知的状态。但是不正确地使用异步复位会导致出现意想不到的错误,复位释放便是…

分组表,分桶表

1,启动Hive服务 (1)启动HiveServer2服务 nohup hive --service metastore &(2)启动Metastore服务 nohup hive --service hiveserver2 &(3)查看进程信息 lsof -i:100002,…

hologres 索引与查询优化

hologres 优化部分 1 hologres 建表优化1.1 建表中的配置优化1.1 字典索引 dictionary_encoding_columns1.2 位图索引 bitmap_columns1.2.2 Bitmap和Clustering Key的区别 1.3 聚簇索引Clustering Key 1 hologres 建表优化 1.1 建表中的配置优化 根据 holo的 存储引擎部分的知…

公共字段自动填充-@TableField的fill实现(2)

TheadLocal 客户端发送的每次http请求,在服务端都会分配新的线程。因此登录检查过滤器、controller、元数据对象处理器属于一个线程。 TheadLocal是线程的局部变量: TheadLocal常用方法: 如何在元数据对象处理器中获取当前登录用户的id&…

力扣hot100 两数之和 哈希表

&#x1f468;‍&#x1f3eb; 力扣 两数之和 &#x1f60b; 思路 在一个数组中如何快速找到某一个数的互补数&#xff1a;哈希表 O(1)实现⭐ AC code class Solution {public int[] twoSum(int[] nums, int target){HashMap<Integer, Integer> map new HashMap<&g…

shell脚本学习笔记07

如何让shell实现 可选择性执行 的功能 用了while进行循环&#xff0c;是死循环&#xff0c;在循环时&#xff0c;使用case进行使用哪个脚本进行执行。使用clear进行每一次操作前的清屏&#xff0c;eof代表输入这个会显示目录。read用来读取输入的值&#xff0c;如果不输入值不会…

初始ProtoBuf

目录​​​​​​​ ⼀、初识ProtoBuf 1. 序列化概念 2. ProtoBuf是什么 3. ProtoBuf的使用特点 ⼆、安装ProtoBuf 1、ProtoBuf在window下的安装 2、ProtoBuf在Linux下的安装 ⼀、初识ProtoBuf 1. 序列化概念 序列化和反序列化 序列化&#xff1a;把对象转换为字节序列…

【开源】基于Vue.js的高校实验室管理系统的设计和实现

项目编号&#xff1a; S 015 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S015&#xff0c;文末获取源码。} 项目编号&#xff1a;S015&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 实验室类型模块2.2 实验室模块2.3 实…

LeetCode【12】整数转罗马数字

题目&#xff1a; 思路&#xff1a; https://blog.csdn.net/m0_71120708/article/details/128769894 代码&#xff1a; public String intToRoman(int num) {String[] thousands new String[] {"", "M", "MM", "MMM"};String[] hun…

【C++历练之路】list的重要接口||底层逻辑的三个封装以及模拟实现

W...Y的主页 &#x1f60a; 代码仓库分享&#x1f495; &#x1f354;前言&#xff1a; 在C的世界中&#xff0c;有一种数据结构&#xff0c;它不仅像一个神奇的瑰宝匣&#xff0c;还像一位能够在数据的海洋中航行的智慧舵手。这就是C中的list&#xff0c;一个引人入胜的工具…

千梦网创:外貌与内貌

一、怎样提高身价&#xff1f; 同样的商品或服务怎样卖得更贵&#xff1f; 要么通过更贵的渠道、要么通过更好的包装。 水还是那个水&#xff0c;放在星巴克可以卖很贵&#xff0c;印上不同的logo可以卖不同的价格。 拿线下的教育培训行业来说&#xff0c;真正让你去测评哪…