注意力机制(SE,ECA,CBAM) Pytorch代码

news2024/11/27 8:48:46

注意力机制

    • 1 SENet
    • 2 ECANet
    • 3 CBAM
      • 3.1 通道注意力
      • 3.2 空间注意力
      • 3.3 CBAM
    • 4 展示网络层具体信息

1 SENet

SE注意力机制(Squeeze-and-Excitation Networks):是一种通道类型的注意力机制,就是在通道维度上增加注意力机制,主要内容是是squeezeexcitation.

在这里插入图片描述
就是使用另外一个新的神经网络(两个Linear层),针对通道维度的数据进行学习,获取到特征图每个通道的重要程度,然后再和原始通道数据相乘即可。
具体参考Blog:
CNN中的注意力机制

小结:

  1. SENet的核心思想是通过全连接网络根据loss损失来自动学习特征权重,而不是直接根据特征通道的数值分配来判断,使有效的特征通道的权重大。

  2. 论文认为excitation操作中使用两个全连接层相比直接使用一个全连接层,它的好处在于,具有更多的非线性,可以更好地拟合通道间的复杂关联。

代码:
拆解步骤,forward代码写的比较细节


import torch
from torch import nn
from torchstat import stat  # 查看网络参数
 
# 定义SE注意力机制的类
class se_block(nn.Module):
    # 初始化, in_channel代表输入特征图的通道数, ratio代表第一个全连接下降通道的倍数
    def __init__(self, in_channel, ratio=4):
        # 继承父类初始化方法
        super(se_block, self).__init__()
        
        # 属性分配
        # 全局平均池化,输出的特征图的宽高=1
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        # 第一个全连接层将特征图的通道数下降4倍
        self.fc1 = nn.Linear(in_features=in_channel, out_features=in_channel//ratio, bias=False)
        # relu激活
        self.relu = nn.ReLU()
        # 第二个全连接层恢复通道数
        self.fc2 = nn.Linear(in_features=in_channel//ratio, out_features=in_channel, bias=False)
        # sigmoid激活函数,将权值归一化到0-1
        self.sigmoid = nn.Sigmoid()
        
    # 前向传播
    def forward(self, inputs):  # inputs 代表输入特征图
    
        # 获取输入特征图的shape
        b, c, h, w = inputs.shape
        # 全局平均池化 [b,c,h,w]==>[b,c,1,1]
        x = self.avg_pool(inputs)
        # 维度调整 [b,c,1,1]==>[b,c]
        x = x.view([b,c])
        
        # 第一个全连接下降通道 [b,c]==>[b,c//4]  # 这里也是使用Linear层的原因,只是对Channel进行线性变换
        x = self.fc1(x)
        x = self.relu(x)
        # 第二个全连接上升通道 [b,c//4]==>[b,c]  # 再通过Linear层恢复Channel数目
        x = self.fc2(x)
        # 对通道权重归一化处理  # 将数值转化为(0,1)之间,体现不同通道之间重要程度
        x = self.sigmoid(x)
        
        # 调整维度 [b,c]==>[b,c,1,1]  
        x = x.view([b,c,1,1])
        
        # 将输入特征图和通道权重相乘
        outputs = x * inputs
        return outputs

结果展示:
在这里插入图片描述
提示: in_channel/ratio需要大于0,否则线性层输入是0维度,没有意义,可以根据自己需求调整ratio的大小

2 ECANet

作者表明 SENet 中的降维会给通道注意力机制带来副作用,并且捕获所有通道之间的依存关系是效率不高的,而且是不必要的。
参考Blog:
CNN中的注意力机制

代码:
详细版本:在forward中,介绍了每一步的作用

import torch
from torch import nn
import math
from torchstat import stat  # 查看网络参数
 
# 定义ECANet的类
class eca_block(nn.Module):
    # 初始化, in_channel代表特征图的输入通道数, b和gama代表公式中的两个系数
    def __init__(self, in_channel, b=1, gama=2):
        # 继承父类初始化
        super(eca_block, self).__init__()
        
        # 根据输入通道数自适应调整卷积核大小
        kernel_size = int(abs((math.log(in_channel, 2)+b)/gama))
        # 如果卷积核大小是奇数,就使用它
        if kernel_size % 2:
            kernel_size = kernel_size
        # 如果卷积核大小是偶数,就把它变成奇数
        else:
            kernel_size = kernel_size + 1
        
        # 卷积时,为例保证卷积前后的size不变,需要0填充的数量
        padding = kernel_size // 2
        
        # 全局平均池化,输出的特征图的宽高=1
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        # 1D卷积,输入和输出通道数都=1,卷积核大小是自适应的
        # 这个1维卷积需要好好了解一下机制,这是改进SENet的重要不同点
        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size,
                              bias=False, padding=padding)
        # sigmoid激活函数,权值归一化
        self.sigmoid = nn.Sigmoid()
    
    # 前向传播
    def forward(self, inputs):
        # 获得输入图像的shape
        b, c, h, w = inputs.shape
        
        # 全局平均池化 [b,c,h,w]==>[b,c,1,1]
        x = self.avg_pool(inputs)
        # 维度调整,变成序列形式 [b,c,1,1]==>[b,1,c]
        x = x.view([b,1,c])   # 这是为了给一维卷积
        # 1D卷积 [b,1,c]==>[b,1,c]
        x = self.conv(x)
        # 权值归一化
        x = self.sigmoid(x)
        # 维度调整 [b,1,c]==>[b,c,1,1]
        x = x.view([b,c,1,1])
        
        # 将输入特征图和通道权重相乘[b,c,h,w]*[b,c,1,1]==>[b,c,h,w]
        outputs = x * inputs
        return outputs

精简版:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
import math


class EfficientChannelAttention(nn.Module):           # Efficient Channel Attention module
    def __init__(self, c, b=1, gamma=2):
        super(EfficientChannelAttention, self).__init__()
        t = int(abs((math.log(c, 2) + b) / gamma))
        k = t if t % 2 else t + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k/2), bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.avg_pool(x)
        # 这里可以对照上一版代码,理解每一个函数的作用
        x = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        out = self.sigmoid(x)
        return out

效果展示:
在这里插入图片描述
总结:
ECANet参数更少!

3 CBAM

CBAM注意力机制是由**通道注意力机制(channel)空间注意力机制(spatial)**组成。
先通道注意力,后空间注意力的顺序注意力模块!
在这里插入图片描述

3.1 通道注意力

在这里插入图片描述
输入数据,对数据分别做最大池化操作和平均池化操作(输出都是batchchannel11),然后使用SENet的方法,针对channel进行先降维后升维操作,之后将输出的两个结果相加,再使用Sigmoid得到通道权重,再之后使用View函数恢复**(batchchannel11)**维度,和原始数据相乘得到通道注意力结果!
通道注意力代码:

#(1)通道注意力机制
class channel_attention(nn.Module):
    # 初始化, in_channel代表输入特征图的通道数, ratio代表第一个全连接的通道下降倍数
    def __init__(self, in_channel, ratio=4):
        # 继承父类初始化方法
        super(channel_attention, self).__init__()
        
        # 全局最大池化 [b,c,h,w]==>[b,c,1,1]
        self.max_pool = nn.AdaptiveMaxPool2d(output_size=1)
        # 全局平均池化 [b,c,h,w]==>[b,c,1,1]
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        
        # 第一个全连接层, 通道数下降4倍
        self.fc1 = nn.Linear(in_features=in_channel, out_features=in_channel//ratio, bias=False)
        # 第二个全连接层, 恢复通道数
        self.fc2 = nn.Linear(in_features=in_channel//ratio, out_features=in_channel, bias=False)
        
        # relu激活函数
        self.relu = nn.ReLU()
        # sigmoid激活函数
        self.sigmoid = nn.Sigmoid()
    
    # 前向传播
    def forward(self, inputs):
        # 获取输入特征图的shape
        b, c, h, w = inputs.shape
        
        # 输入图像做全局最大池化 [b,c,h,w]==>[b,c,1,1]
        max_pool = self.max_pool(inputs)
        # 输入图像的全局平均池化 [b,c,h,w]==>[b,c,1,1]
        avg_pool = self.avg_pool(inputs)
 
        # 调整池化结果的维度 [b,c,1,1]==>[b,c]
        max_pool = max_pool.view([b,c])
        avg_pool = avg_pool.view([b,c])
 
        # 第一个全连接层下降通道数 [b,c]==>[b,c//4]
        x_maxpool = self.fc1(max_pool)
        x_avgpool = self.fc1(avg_pool)
 
        # 激活函数
        x_maxpool = self.relu(x_maxpool)
        x_avgpool = self.relu(x_avgpool)
        
        # 第二个全连接层恢复通道数 [b,c//4]==>[b,c]
        x_maxpool = self.fc2(x_maxpool)
        x_avgpool = self.fc2(x_avgpool)
        
        # 将这两种池化结果相加 [b,c]==>[b,c]
        x = x_maxpool + x_avgpool
        # sigmoid函数权值归一化
        x = self.sigmoid(x)
        # 调整维度 [b,c]==>[b,c,1,1]
        x = x.view([b,c,1,1])
        # 输入特征图和通道权重相乘 [b,c,h,w]
        outputs = inputs * x
        
        return outputs

3.2 空间注意力

在这里插入图片描述
针对输入数据,分别选取数据中最大值所在的维度(batch1h*w),和按照维度进行数据平均操作(batch1hw),然后将两个数据做通道连接(batch2hw),使用卷积操作,将channel维度降为1,之后对结果取sigmoid,得到空间注意力权重,和原始数据相乘得到空间注意力结果。

代码:

#(2)空间注意力机制
class spatial_attention(nn.Module):
    # 初始化,卷积核大小为7*7
    def __init__(self, kernel_size=7):
        # 继承父类初始化方法
        super(spatial_attention, self).__init__()
        
        # 为了保持卷积前后的特征图shape相同,卷积时需要padding
        padding = kernel_size // 2
        # 7*7卷积融合通道信息 [b,2,h,w]==>[b,1,h,w]
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size,
                              padding=padding, bias=False)
        # sigmoid函数
        self.sigmoid = nn.Sigmoid()
    
    # 前向传播
    def forward(self, inputs):
        
        # 在通道维度上最大池化 [b,1,h,w]  keepdim保留原有深度
        # 返回值是在某维度的最大值和对应的索引
        x_maxpool, _ = torch.max(inputs, dim=1, keepdim=True)
        
        # 在通道维度上平均池化 [b,1,h,w]
        x_avgpool = torch.mean(inputs, dim=1, keepdim=True)
        # 池化后的结果在通道维度上堆叠 [b,2,h,w]
        x = torch.cat([x_maxpool, x_avgpool], dim=1)
        
        # 卷积融合通道信息 [b,2,h,w]==>[b,1,h,w]
        x = self.conv(x)
        # 空间权重归一化
        x = self.sigmoid(x)
        # 输入特征图和空间权重相乘
        outputs = inputs * x
        return outputs

3.3 CBAM

将通道注意力模块和空间注意力模块顺序串联得到CBAM模块!
代码:

class cbam(nn.Module):
    # 初始化,in_channel和ratio=4代表通道注意力机制的输入通道数和第一个全连接下降的通道数
    # kernel_size代表空间注意力机制的卷积核大小
    def __init__(self, in_channel, ratio=4, kernel_size=7):
        # 继承父类初始化方法
        super(cbam, self).__init__()
        
        # 实例化通道注意力机制
        self.channel_attention = channel_attention(in_channel=in_channel, ratio=ratio)
        # 实例化空间注意力机制
        self.spatial_attention = spatial_attention(kernel_size=kernel_size)
    
    # 前向传播
    def forward(self, inputs):
        
        # 先将输入图像经过通道注意力机制
        x = self.channel_attention(inputs)
        # 然后经过空间注意力机制
        x = self.spatial_attention(x)
        
        return x

结果:
在这里插入图片描述

4 展示网络层具体信息

安装包

pip install torchstat

使用

from torchstat import stat 

net = cbam(16)
stat(net, (16, 256, 256))  # 不需要Batch维度

注意力机制后期学习到再持续更新!!

参考博客:
CNN注意力机制
ECANet

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/336054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【光线追踪】光线追踪重投影方法(Ray Tracing Reprojection)

光线追踪重投影方法 重投影这项技术一般用于时间性帧复用技术上,例如TAA(Temporal Anti-Aliasing)反走样或者抗锯齿技术。读这篇文章最好先对TAA这类技术的算法流程有了解。 1.TAA抗锯齿技术简介 先简单介绍下TAA抗锯齿的原理,在游戏中,当前…

解决ThinkPHP5.1出现MISS缓存未命中问题

一淘模板(56admin.com)给大家带来了关于ThinkPHP5.1的相关知识,其中主要介绍了CDN是什么?为什么使用它?怎么解决ThinkPHP5.1 MISS缓存未命中问题?感兴趣的朋友下面一起来看一下吧,希望对大家有帮…

疑难杂症篇(二十一)--Ubuntu18.04安装usb-cam过程出现的问题

对Ubuntu18.04{\rm Ubuntu 18.04}Ubuntu18.04环境下的ROS{\rm ROS}ROS的melodic{\rm melodic}melodic版本安装usb−cam{\rm usb-cam}usb−cam过程出现的两个常见问题提出解决方案。 1.问题1:usb-cam功能包编译时出现"未定义的引用"的问题 问题描述&#…

@RefreshScope 加在 Quartz 触发器类导致异常问题分析

背景 承接上篇,测试过程中又遇到了 Nacos Config 的动态刷新注解 RefreshScope 与 Quartz 框架结合的问题,Bug 排查路上,顺手记录一下吧。 问题 有个模块使用了Quartz ,通过配置控制任务调度的周期和分组名称。 因为引用了动态…

如何解决thinkphp验证码不能显示问题?

thinkPHP做验证码这一块,可以使用自带的验证码扩展,具体步骤如下: 一、安装扩展 composer require topthink/think-captcha 二、模版中使用 将原来静态页面的验证码图片替换为{:captcha_img()},这个会自动生成验证码图片。 <div>{:captcha_img()}</div> 或者 &…

如何理解 K8s 动态伸缩与触发上线?

K8s 版本&#xff1a;1.23.6 目录一、动态伸缩二、触发上线一般地&#xff0c;如果仅修改 Pod 的副本数&#xff08;如新增/缩减&#xff09;&#xff0c;这就属于动态伸缩。如果是修改容器镜像的版本&#xff0c;则会触发上线&#xff0c;具体看下面例子。 一、动态伸缩 1、…

00后整顿职场?公司测试岗却新来了个00后卷王,3个月薪资干到20K...

最近聊到软件测试的行业内卷&#xff0c;越来越多的转行和大学生进入测试行业。想要获得更好的待遇和机会&#xff0c;不断提升自己的技能栈成了测试老人迫在眉睫的问题。 不论是面试哪个级别的测试工程师&#xff0c;面试官都会问一句“会编程吗&#xff1f;有没有自动化测试…

Mybatis框架(全部基础知识)

&#x1f44c; 棒棒有言&#xff1a;也许我一直照着别人的方向飞&#xff0c;可是这次&#xff0c;我想要用我的方式飞翔一次&#xff01;人生&#xff0c;既要淡&#xff0c;又要有味。凡事不必太在意&#xff0c;一切随缘&#xff0c;缘深多聚聚&#xff0c;缘浅随它去。凡事…

2023年了,来试试前端格式化工具

在大前端时代&#xff0c;前端的各种工具链穷出不断&#xff0c;有eslint, prettier, husky, commitlint 等, 东西太多有的时候也是trouble&#x1f602;&#x1f602;&#x1f602;,怎么正确的使用这个是每一个前端开发者都需要掌握的内容&#xff0c;请上车&#x1f697;&…

DHCP Relay中继实验

DHCP Relay实验拓扑图设备配置结果验证拓扑图 要求PC1按照地址池自动分配&#xff0c;而PC要求分配固定的地址&#xff0c;网段信息已经在图中进行标明。 设备配置 AR1&#xff1a; AR1作为DHCP Server基本配置跟DHCP Server没区别&#xff0c;不过要加一条静态路由&#xff…

基础篇:02-SpringCloud概述

1.SpringCloud诞生 基于前面章节&#xff0c;我们深知微服务已成为当前开发的主流技术栈&#xff0c;但是如dubbo、zookeeper、nacos、rocketmq、rabbitmq、springboot、redis、es这般众多技术都只解决了一个或一类问题&#xff0c;微服务并没有一个统一的解决方案。开发人员或…

计算机组成原理(三)

5.掌握定点数的表示和应用&#xff08;主要是无符号数和有符号数的表示、机器数的定点表示、数的机器码表示&#xff09;&#xff1b; 定点数&#xff1a;小数点位置固定不变。   定点小数&#xff1a;小数点固定在数值位与符号位之间&#xff1b;   定点整数&#xff1a;小…

R语言贝叶斯方法在生态环境领域中的高阶技术

贝叶斯统计学即贝叶斯学派是一门基本思想与传统基于频率思想的统计学即频率学派完全不同的统计学方法&#xff0c;它在统计建模中具有灵活性和先进性特点&#xff0c;使其可以轻松应对复杂数据和模型结构。然而&#xff0c;很多初学者在面对思想、技术和方法都与传统统计学有着…

Springcloud----Nacos快速搭建使用

Nacos使用指南 Nacos完整的搭建和项目配置流程&#xff0c;上手简单 一、Nacos安装启动 1.Windows安装 开发阶段采用单机安装即可。 1.1.下载安装包 在Nacos的GitHub页面&#xff0c;提供有下载链接&#xff0c;可以下载编译好的Nacos服务端或者源代码&#xff1a; GitHub主…

尚硅谷的尚融宝项目

先建立一个Maven springboot项目 进来先把src删掉&#xff0c;因为是一个父项目&#xff0c;我们删掉src之后&#xff0c;pom里配置的东西&#xff0c;也能给别的模块使用。 改一下springboot的版本号码 加入依赖和依赖管理&#xff1a; <properties><java.versi…

大型智慧校园系统源码 智慧校园源码 Android电子班牌源码

一款针对中小学研发的智慧校园系统源码&#xff0c;智慧学校源码带电子班牌、人脸识别系统。系统有演示&#xff0c;可正常上线运营正版授权。 私信了解更多&#xff01; 技术架构&#xff1a; 后端&#xff1a;Java 框架&#xff1a;springboot 前端页面&#xff1a;vue e…

keepalived+mysql高可用

一.设置mysql同步信息两节点安装msyql略#配置节点11.配置权限允许远程访问mysql -u root -p grant all on *.* to root% identified by Root1212# with grant option; flush privileges;2.修改my.cnf#作为主节点配置(节点1)#作为主节点配置 server-id 1 …

leetcode刷题 | 关于前缀和题型总结1

leetcode刷题 | 关于前缀和题型总结1 文章目录leetcode刷题 | 关于前缀和题型总结1题目链接和为K的子数组连续数组/0 和 1 个数相同的子数组和大于等于 target 的最短子数组/长度最小的子数组路经总和Ⅲ题目链接 560. 和为 K 的子数组 - 力扣&#xff08;LeetCode&#xff09;…

Python-第三天 Python判断语句

Python-第三天 Python判断语句一、 布尔类型和比较运算符1.布尔类型2.比较运算符二、if语句的基本格式1.if 判断语句语法2.案例三、 if else 语句1.语法2.案例四 if elif else语句1.语法五、判断语句的嵌套1.语法六、实战案例一、 布尔类型和比较运算符 1.布尔类型 布尔&…

【学习笔记】Nginx实战

反向代理实战 解压Tomcat两次-Tomcat8081、Tomcat8082两个文件夹Tomcat8081只需要修改http协议端口8081Tomcat8082&#xff1a;&#xff08;三个都需要改&#xff0c;不然只会启动其中一个&#xff09;1.修改server的默认端口2.修改http协议的默认端口3.膝盖默认ajp协议的默认端…