SE、CBAM、ECA 、CA注意力机制

news2024/11/17 3:38:08

文章目录

  • 1. SE (Squeeze-and-Excitation)
  • 2. CBAM (Convolutional Block Attention Module)
  • 3. ECA (Efficient Channel Attention)
  • 4. CA (Coordinate Attention)

1. SE (Squeeze-and-Excitation)

SENet是通道注意力机制的典型实现
对于SENet而言,其重点是获得输入进来的特征层,每一个通道的权值。利用SENet,我们可以让网络关注它最需要关注的通道。


实现方式:
1、对输入的特征层进行全局平局池化
2、然后进行两次全连接,第一次全连接输出的通道数会少一些,第二次全连接输出的通道数和输入的特征层相同
3、在完成两次全连接之后,会使用一次sigmoid将值固定在[0,1]之间,此时我们获得了输入特征层每一个通道的权值
4、将获得的权值与输入特征层相乘
在这里插入图片描述


优点:
简单有效:SE注意力机制提出简单,易于实现,同时在各种视觉任务中证明了其有效性。
参数少:相较于其他注意力机制,SE模块的参数量相对较少,因此在性能和计算开销之间取得了平衡。

缺点:
计算相对复杂:虽然参数少,但在网络中引入SE模块可能增加计算的复杂性,特别是在大规模网络中。


代码

import torch
from torch import nn

class senet(nn.Module):
    def __init__(self, channel, ration=16):
        super(senet, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // ration, bias=False),
            nn.ReLU(),
            nn.Linear(channel // ration, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, h, w = x.size()
        # b, c ,h, w --> b, c, 1, 1
        avg = self.avg_pool(x).view([b,c])
        fc = self.fc(avg).view([b, c, 1, 1])

        return x * fc

2. CBAM (Convolutional Block Attention Module)

CBAM将通道注意力机制和空间注意力机制进行一个结合,相比于SENet只关注通道的注意力机制可以取得更好的效果。其实现示意图如下所示,CBAM会对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理。
在这里插入图片描述
实现方式
图像的上半部分为通道注意力机制,通道注意力机制的实现可以分为两个部分,我们会对输入进来的单个特征层,分别进行全局平均池化和全局最大池化。之后对平均池化和最大池化的结果,利用共享的全连接层进行处理,我们会对处理后的两个结果进行相加,然后取一个sigmoid,此时我们获得了输入特征层每一个通道的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。

图像的下半部分为空间注意力机制,我们会对输入进来的特征层,在每一个特征点的通道上取最大值和平均值。之后将这两个结果进行一个堆叠,利用一次通道数为1的卷积调整通道数,然后取一个sigmoid,此时我们获得了输入特征层每一个特征点的权值(0-1之间)。在获得这个权值后,我们将这个权值乘上原输入特征层即可。
在这里插入图片描述


优点:
结合了卷积和注意力机制,可以从空间和通道两个方面上对图像进行关注。
缺点:
需要更多的计算资源,计算复杂度更高。


代码

import torch
from torch import nn

#通道注意力
class channel_attention(nn.Module):
    def __init__(self, channel, ration=16):
        super(channel_attention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel//ration, bias=False),
            nn.ReLU(),
            nn.Linear(channel//ration, channel, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        avg_pool = self.avg_pool(x).view([b, c])
        max_pool = self.max_pool(x).view([b, c])

        avg_fc = self.fc(avg_pool)
        max_fc = self.fc(max_pool)

        out = self.sigmoid(max_fc+avg_fc).view([b, c, 1, 1])
        return x * out

#空间注意力
class spatial_attention(nn.Module):
    def __init__(self, kernel_size=7):
        super(spatial_attention, self).__init__()

        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, stride=1,
                              padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        #通道的最大池化
        max_pool = torch.max(x, dim=1, keepdim=True).values
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        pool_out = torch.cat([max_pool, avg_pool], dim=1)
        conv = self.conv(pool_out)
        out = self.sigmoid(conv)

        return out * x

#将通道注意力和空间注意力进行融合
class CBAM(nn.Module):
    def __init__(self, channel, ration=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = channel_attention(channel, ration)
        self.spatial_attention = spatial_attention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x)
        out = self.spatial_attention(out)

        return out


model = CBAM(512)
print(model)
inputs = torch.ones([2,512,26,26])
out = model(inputs)

3. ECA (Efficient Channel Attention)

CANet可以看作是SENet的改进版。
ECANet的作者认为SENet对通道注意力机制的预测带来了副作用,捕获所有通道的依赖关系是低效并且是不必要的。
在ECANet的论文中,作者认为卷积具有良好的跨通道信息获取能力。

ECA模块去除了原来SE模块中的全连接层,直接在全局平均池化之后的特征上通过一个1D卷积进行学习。

既然使用到了1D卷积,那么1D卷积的卷积核大小的选择就变得非常重要了,1D卷积的卷积核大小会影响注意力机制每个权重的计算要考虑的通道数量。用更专业的名词就是跨通道交互的覆盖率。
在这里插入图片描述


优点:
计算效率高:ECA模块采用了一维卷积的方式,相较于二维卷积,在保持性能的前提下降低了计算复杂度。
缺点:
空间信息未利用:ECA主要关注通道信息,相对忽略了空间信息,这可能在某些任务中不是最优的选择。


代码

import torch
from torch import nn
import math
class eca_block(nn.Module):
    def __init__(self, channel, gamma=2, b=1):
        super(eca_block, self).__init__()
        kernel_size = int(abs((math.log(channel,2)+  b)/gamma))
        kernel_size = kernel_size if kernel_size % 2  else kernel_size+1
        padding = kernel_size//2
        self.avg_pool =nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        #变成序列的形式
        avg = self.avg_pool(x).view([b, 1, c])
        out = self.conv(avg)
        out = self.sigmoid(out).view([b, c, 1, 1])
        return  out * x

model = eca_block(512)
print(model)
inputs = torch.ones([2,512,26,26])
outputs = model(inputs)

4. CA (Coordinate Attention)

该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[C, H, W],在经过宽方向的平均池化后,获得的特征层shape为[C, H, 1],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[C, 1, W],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C, 1, H+W],利用卷积+标准化+激活函数获得特征。

之后再次分开为两个并行阶段,再将宽高分开成为:[C, 1, H]和[C, 1, W],之后进行转置。获得两个特征层[C, H, 1]和[C, 1, W]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况。乘上原有的特征就是CA注意力机制

在这里插入图片描述


优点:
准确性高:CA注意力机制能够准确地捕捉不同通道之间的关系,提高了特征表达的准确性。
通用性强:CA注意力机制可以适用于各种不同的网络结构和任务。
缺点:
计算复杂度高:CA模块的计算复杂度较高,特别是在大规模网络中,可能会增加显著的计算开销。


代码

import torch
from torch import nn

class CA_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CA_Block, self).__init__()

        self.conv_1x1 = nn.Conv2d(channel, channel//reduction, kernel_size=1, stride=1, bias=False)

        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(channel//reduction)

        self.F_h = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
        self.F_w = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)

        self.sigmoid_h = nn.Sigmoid()
        self.sigmoid_w = nn.Sigmoid()

    def forward(self, x):
        #b,c,h,w
        _, _, h, w = x.size()
        #(b, c, h, w) --> (b, c, h, 1)  --> (b, c, 1, h)
        x_h = torch.mean(x, dim=3, keepdim=True).permute(0, 1, 3, 2)
        #(b, c, h, w) --> (b, c, 1, w)
        x_w = torch.mean(x, dim=2, keepdim=True)
        #(b, c, 1, w) cat (b, c, 1, h) --->  (b, c, 1, h+w)
        #(b, c, 1, h+w) ---> (b, c/r, 1, h+w)
        x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h,x_w), 3))))
        #(b, c/r, 1, h+w) ---> (b, c/r, 1, h)  、 (b, c/r, 1, w)
        x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h,w], 3)
        #(b, c/r, 1, h) ---> (b, c, h, 1)
        s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
        #(b, c/r, 1, w) ---> (b, c, 1, w)
        s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
        #s_h往宽方向进行扩展, s_w往高方向进行扩展
        out = (s_h.expand_as(x) * s_w.expand_as(x)) * x

        return out

model = CA_Block(512)
print(model)

inputs = torch.ones([2,512,26,26])
model(inputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1060295.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

螺杆支撑座有哪些品牌?

螺杆支撑座是机械设备中重要的支撑部件,用于固定和支撑螺杆,以确保机械设备的稳定性和精度。以下是一些生产螺杆支撑座的品牌以及它们的特点: 1、NSK:提供各种高质量的轴承和机械部件,他们的螺杆支撑座采用先进的制造技…

2023.9.26 IO 文件操作详解

目录 文件 文件路径 文件类型 Java 文件操作 文件系统操作 文件内容操作 字节流 InputStream OutputStream 字符流 Reader Writer 补充 close 的必要性 Scanner 的基本了解 文件 当前指硬盘上的文件和文件夹相对于 变量 在内存中,文件 则是在硬盘上 …

竞赛选题 机器视觉目标检测 - opencv 深度学习

文章目录 0 前言2 目标检测概念3 目标分类、定位、检测示例4 传统目标检测5 两类目标检测算法5.1 相关研究5.1.1 选择性搜索5.1.2 OverFeat 5.2 基于区域提名的方法5.2.1 R-CNN5.2.2 SPP-net5.2.3 Fast R-CNN 5.3 端到端的方法YOLOSSD 6 人体检测结果7 最后 0 前言 &#x1f5…

sheng的学习笔记-【中文】【吴恩达课后测验】Course 2 - 改善深层神经网络 - 第二周测验

课程2_第2周_测验题 目录:目录 第一题 1.当输入从第8个mini-batch的第7个的例子的时候,你会用哪种符号表示第3层的激活? A. 【  】 a [ 3 ] { 8 } ( 7 ) a^{[3]\{8\}(7)} a[3]{8}(7) B. 【  】 a [ 8 ] { 7 } ( 3 ) a^{[8]\{7\}(3)} a…

【iptables 实战】9 docker网络原理分析

在开始本章阅读之前,需要提前了解以下的知识 阅读本节需要一些docker的基础知识,最好是在linux上安装好docker环境。提前掌握iptables的基础知识,前文参考【iptables 实战】 一、docker网络模型 docker网络模型如下图所示 说明&#xff1…

23.2 Bootstrap框架3

1.卡片 1.1卡片样式 在Bootstrap 5中, .card, card-header, .card-body, .card-footer类是用于创建卡片样式.下面是这些类的简单介绍: * 1. .card: 用于创建一个基本的卡片容器它作为一个包裹元素,通常与其他卡片类一起使用.* 2. .card-header: 用于创建卡片的头部部分.通常在…

双重差分模型(DID)论文写作指南与操作手册

手册链接:双重差分模型(DID)论文写作指南与操作手册https://www.cctalk.com/m/group/90983583?xh_fshareuid60953990 简介: 当前,对于准应届生们来说,毕设季叠加就业季,写作时间显得十分宝贵…

Sentinel安装

Sentinel 微服务保护的技术有很多,但在目前国内使用较多的还是Sentinel,所以接下来我们学习Sentinel的使用。 1.介绍和安装 Sentinel是阿里巴巴开源的一款服务保护框架,目前已经加入SpringCloudAlibaba中。官方网站: 首页 | Se…

Sublime Text 4 for Mac激活下载

Sublime Text for Mac是一款适用于Mac平台的文本编辑器。它具有快速的性能和丰富的功能,可以帮助用户快速进行代码编写和文本编辑。 软件下载:Sublime Text 4 for Mac激活下载 该软件具有直观的界面和强大的功能,包括多行选择、代码折叠、自动…

【数据开发】DW数仓分层设计架构与同步策略(ODS、DWD、DWS等字段含义)

文章目录 1、什么是数据仓库(DW)2、DW分层设计架构(ODS,DWD,DWS)3、数仓同步策略 1、什么是数据仓库(DW) Data warehouse(可简写为DW或者DWH)数据仓库是什么…

【软考】系统集成项目管理工程师(六)项目整体管理【6分】

一、 前言 1、项目管理三从四得 2、ITO共性总结 1、上一个过程的输出大部分是下-个过程的输入 2、计划和文件是不一样的 (每个输入都有计划和文件) 3、被批准的变更请求约等于计划 4、在执行和监控过程产生新的变更请求(变更请求包括变什么和怎么变,这是变更请求和…

[JAVAee]SpringBoot-AOP

目录 Spring AOP ​编辑AOP适用场景 AOP的组成 连接点(Join Point) 切点(Pointcut) 通知(Advice) Spring AOP的实现 添加依赖 定义切面与切点 切点表达式的说明 定义相关的通知 Spring AOP AOP(Aspect Oriented Programming)是面向切面编程,是一种设计思想.对某一类…

联想Lenovo 威6 15-ITL(82F2)原厂Win10系统

lenovo联想原装出厂系统 自带所有驱动、出厂主题壁纸LOGO、Office办公软件、联想电脑管家等预装程序 下载链接:https://pan.baidu.com/s/1darORHmIyAXkD7HvKRNHNw?pwddh6e 所需要工具:16G或以上的U盘 文件格式:ISO 文件大小:11.…

号卡推广管理系统源码/手机流量卡推广网站源码/PHP源码+带后台版本+分销系统

源码简介: 号卡推广管理系统源码/手机流量卡推广网站源码,基于PHP源码,而且它是带后台版本,分销系统。运用全新UI流量卡官网系统源码有后台带文章。 这个流量卡销售网站源码,PHP流量卡分销系统,它可以支持…

mysql技术文档--阿里巴巴java准则《Mysql数据库建表规约》--结合阿丹理解尝试解读--国庆开卷

阿丹: 国庆快乐呀大家! 在项目开始前一个好的设计、一个健康的表关系,不仅会让开发变的有趣舒服,也会在后期的维护和升级迭代中让系统不断的成长。那么今天就认识和解读一下阿里的准则!! 建表规约 表达是…

【科学文献计量】关于使用metaknowledge读取文献后转化字典结构URLError报错问题的解决方式

关于使用metaknowledge读取文献后转化字典结构URLError报错问题的解决方式 1 报错提醒2 问题解决 1 报错提醒 读入数据后,转化为字典数据结构中,出现URLError报错 2 问题解决 (1) 网络波动 重新运行几次后,自动连…

大厂生产级Redis高并发分布式锁实战

文章目录 一、扣减库存不加锁二、加一把jvm锁试试看三、引入分布式锁四、try finally五、设置key的过期时间六、原子设置锁和过期时间七、给线程设置唯一id八、锁续命redisson九、redisson加锁释放锁的逻辑十、redisson源码分析 一、扣减库存不加锁 先看一段扣减库存的代码 Au…

vscode登录租的新服务器

1.connect to…… 选择 connect current window to host 2.configure SSH Host 选择本地配置文件 打开配置文件,把主机名端口号写进去 再返回vscode远程登录页面,左侧栏就会出现这个主机名了。

Hadoop启动后jps发现没有DateNode解决办法

多次使用 Hadoop namenode -format 格式化节点后DateNode丢失 找到hadoop配置文件core-site.xml查找tmp路径 进入该路径,使用rm -rf data删除data文件 再次使用Hadoop namenode -format 格式化后jps后出现DateNode节点

实现springboot的简单使用~

在之前学习SpringSpringMVCMybatis框架时,我们学习了多种配置spring程序的方式,例如:使用XML,注解,Java配置类,或者是将它们结合使用,但配置文件配置起来依然过于复杂,而我们接下来要…