SEnet注意力机制(逐行代码注释讲解)

news2025/1/16 12:58:53

目录

⒈结构图

⒉机制流程讲解

⒊源码(pytorch框架实现)及逐行解释

⒋测试结果


⒈结构图

左边是我自绘的,右下角是官方论文的。


⒉机制流程讲解

通道注意力机制的思想是,对于输入进来的特征层,我们在每一个通道学习不同的权重,这些权重与不同通道的特征相关,决定了每个通道在任务中的重要性。

对于SENet而言,它会对输入特征层进行这些操作:

①首先对输入特征层做了global average pooling,也就是全局平均池化,全局平均池化将对当前特征层取平均值,显然,高、宽分别为H、W的特征层经过平均池化操作后会得到一个实数,这个实数就是所有输入特征层的平均值;另外,平均池化并不影响通道数,因此,输入为C*H*W的特征经过平均池化后,H和W两个维度被压缩,就将得到只剩下C(也就是通道数)这一个维度的特征层。

②然后,对于平均池化输出的矩阵,进行两次全连接,第一次全连接和第二次是不完全相同的,区别在于:第一次全连接的通道数不完整,而是取原通道数的1/r,也就是这边的C/r,第二次则是用正常的通道数进行全连接。

这样做的目的是——能够减少通道个数从而降低计算量,并在一定程度上防止网络模型过拟合。(我在学习SEnet的结构时,看到第一次全连接减少通道数这个操作时,就有联想到神经网络的另一个trick,叫做dropout,dropout是一种正则化技巧,通过随机让神经网络中的部分神经元暂时失活,从而减少模型的过拟合风险,当时我以为SEnet的第一个全连接层就是运用了这个trick,但后来查阅资料时发现不是这样,dropout是随机减少全连接层中的部分神经元,而SEnet在这里是固定减少特征图的通道数,只能说有些异曲同工之妙吧),刚刚是在分享我学习过程遇到的小问题,现在说回正题,全连接1只取原通道数的1/r以此来减少计算量与防止过拟合,但是全连接2又用回原通道数——这样做是为了输出与原特征层相同的通道数,以便后续的最重要的reweight操作,也就是通过乘法逐通道加权到原先的输入特征层上。

值得注意的是,两个全连接层不是简单的直接相连,而是在全连接1后面经过一个relu激活函数,这是全连接层中很常规的操作,用来对一个全连接层的输出结果进行非线性变换,如果不这样做,所有的全连接层都只是普通的线性组合,这样训练出来的模型无法理解复杂的非线性数据和特征,可想而知这样的模型的检测效果肯定是很差的。

relu激活函数的公式其实很简单:f(x) = max(0, x),在x大于等于零时是线性函数,但当输入为负数时,输出为零,在负数部分截断了线性部分,将其映射到了一个确定的点上,从而实现了非线性变换。

自绘烂图,将就看。

③再然后,需要对全连接2的输出结果映射到sigmoid函数中,sigmoid是很经典的激活函数,它的值域是0到1,画一下函数图像(显然x=0时函数值等于0.5)……然后,它的定义域是整个实数集,值域是0到1,也就是说,全连接2的输出结果映射到sigmoid函数中后,就将得到一组0到1之间的值(因此称此操作为归一化),也就是所谓的不同通道的权重。

公式:

自绘烂图,我真的尽力画了/(ㄒoㄒ)/~~

最后最后,将这组通道权重与原输入2特征层通过乘法逐通道加权,就实现了“增强重要的通道,抑制不重要的通道”,也就是所谓的通道注意力机制

⒊源码(pytorch框架实现)及逐行解释

import torch
from torch import nn
from torchsummary import summary
 
 
class SEAttention(nn.Module):
    def __init__(self, inputs, ratio=4):
        super(SEAttention, self).__init__()  # 调用父类构造方法
        _, c, _, _ = inputs.size()# NCHW
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear1 = nn.Linear(c, c // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(c // ratio, c, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        n, c, _, _ = inputs.size()
        x = self.avgpool(inputs).view(n, c)#nchw,池化加reshape压缩维度
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        x = x.view(n, c, 1, 1) #reshape还原维度
        return inputs * x
 
 
#这边是测试代码,用summary类总结网络模型层
inputs = torch.randn(32, 512, 26, 26)  # NCHW
my_model = SEAttention(inputs)
outputs = my_model(inputs)
summary(my_model.cuda(), input_size=(512, 26, 26))

 解释:

①依赖包为torch,以及torch里的nn模块(导入这个纯粹是省得还要用torch.nn去调用nn的类或方法),summary类是用来测试的,需要提前下载,命令为->pip install torchsummary

②从整体来看,我们运用封装思想将整个模块封装为类,且这个类继承于nn.Moudule这个类,这个类共两部分,

__init__函数用来对实例化对象进行初始化,在python中这个函数属于类的魔术方法。

#代码逐行解释:
def __init__(self, inputs, ratio=4):#self必须写,inputs接收输入张量,ratio是通道衰减因子
        super(SEAttention, self).__init__()  # super关键字调用父类(即nn.Moudule类)的构造方法
        _, c, _, _ = inputs.size()#获取张量的形状(即NCHW),该模块只关注参数C,其余用占位符忽略
        self.avgpool = nn.AdaptiveAvgPool2d(1)#nn模块的自适应二维平均池化,参数1等同于全局平均池化
        self.linear1 = nn.Linear(c, c // ratio, bias=False)#nn模块的全连接,这里输入c,输出c//ratio,bias是偏置参数,网络层是否有偏置,默认存在,若bias=False,则该网络层无偏置,图层不会学习附加偏差
        self.relu = nn.ReLU(inplace=True)#nn模块的ReLU激活函数,inplace=True表示要用引用传递(即地址传递),估计可以减少张量的内存占用(因为值传递要拷贝一份)
        self.linear2 = nn.Linear(c // ratio, c, bias=False)#同全连接1,但输入输出相反
        self.sigmoid = nn.Sigmoid()#nn模块的Sigmoid函数

forward函数进行前向传播,用初始化好的网络模型对输入特征层进行一系列加工。

#代码逐行解释:
def forward(self, inputs):#self必须写,inputs接收输入特征张量
        n, c, _, _ = inputs.size()#获取张量形状(即NCHW),HW被忽略
        x = self.avgpool(inputs).view(n, c)#nchw,池化加view方法重塑(reshape)张量形状,因为全连接层之间的张量必须是二维的(一个输入维度一个输出维度),view的参数是(n,c)表示只保留这两个维度
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)#上面这四行直接调用初始化好的网络层即可
        x = x.view(n, c, 1, 1) #reshape还原维度,因为要和原输入特征相乘,不重塑形状不同无法相乘
        return inputs * x#和原输入特征层相乘

⒋测试结果

感觉summary类没有很好使。。。有些关键网络层的变换没有体现出来,这里是少了最后reshape的一层,但无伤大雅罢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1225391.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

交通 | 神奇动物在哪里?Operations Research经典文章

论文作者:Robert G. Haight, Charles S. Revelle, Stephanie A. Snyder​ 论文原文:Robert G. Haight, Charles S. Revelle, Stephanie A. Snyder, (2000) An Integer Optimization Approach to a Probabilistic Reserve Site Selection Problem. Operat…

软件开发、网络空间安全、人工智能三个方向的就业和前景怎么样?哪个方向更值得学习?

软件开发、网络空间安全、人工智能这三个方向都是当前及未来的热门领域,每个领域都有各自的就业前景和价值,以下是对这三个方向的分析: 1、软件开发: 就业前景:随着信息化的加速,软件开发的需求日益增长。…

16万亿Web3蓝图落地新加坡

作者:秦晋 11月15日,新加坡金管局(MAS)宣布与金融行业合作,以扩大资产代币化计划,并开发扩大代币化市场的基础能力。Project Guardian 由 17 家金融机构组成,启动五个行业试点,以测试…

优化|优化求解器自动调参

原文信息:MindOpt Tuner: Boost the Performance of Numerical Software by Automatic Parameter Tuning 作者:王孟昌 (达摩院决策智能实验室MindOpt团队成员) 一个算法开发者,可能会幻想进入这样的境界:算…

LeetCode【4】寻找两个正序数组中位数

题目: 思路: https://blog.csdn.net/a1111116/article/details/115033098 代码: public double findMedianSortedArrays(int[] nums1, int[] nums2) {int[] ints Arrays.copyOf(nums1, nums1.length nums2.length);System.arraycopy(nums2…

ROS 学习应用篇(八)ROS中的坐标变换管理之tf广播与监听的编程实现

偶吼吼胜利在望,冲冲冲 老规矩新建功能包 工作空间目录下/src下开启终端输入 catkin_create_pkg learning_tf roscpp rospy tf turtlesim 如何实现tf广播 引入库 c python …

报道 | 2023年12月-2024年2月国际运筹优化会议汇总

2023年12月-2024年2月召开会议汇总: The 16th Annual International Conference on Combinatorial Optimization and Applications (COCOA 2023) Location: Virtual Important dates: Conference: December 11, 2023 (Start) - December 13, 2023 (End) Details…

与博主交流

我是一个性格比较随和且有些内敛的人,喜欢与人交流技术。 如果你有一些问题想与我交流,请联系我。 交流说明:请直接描述你的需求。

电子学会C/C++编程等级考试2021年12月(一级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:输出整数部分 输入一个双精度浮点数f, 输出其整数部分。 时间限制:1000 内存限制:65536输入 一个双精度浮点数f(0 < f < 100000000)。输出 一个整数,表示浮点数的整数部分。样例输入 3.8889样例输出 3 答案: //参…

python算法例15 合并数字

1. 问题描述 给出n个数&#xff0c;将这n个数合并成一个数&#xff0c;每次只能选择两个数a、b合并&#xff0c;合并需要消耗的能量为ab&#xff0c;输出将n个数合并成一个数后消耗的最小能量。 2. 问题示例 给出[1&#xff0c;2&#xff0c;3&#xff0c;4]&#xff0c;返回…

【信息安全】浅谈IDOR越权漏洞的原理、危害和防范:直接对象引用导致的越权行为

前言 ┌──────────────────────────────────┐ │ 正在播放《越权访问》 - Hanser │ ●━━━━━━─────── 00:00 / 03:05 │ ↻ ◁ ❚❚ ▷ ⇆ └───────────────────────────────…

C/C++数据结构之堆栈(Stack):理解、实现与运用

当我们讨论堆栈时&#xff0c;我们首先需要了解它的概念和基本原理。堆栈是一种后进先出&#xff08;Last In, First Out&#xff0c;LIFO&#xff09;的数据结构&#xff0c;它的操作主要包括压栈&#xff08;Push&#xff09;和弹栈&#xff08;Pop&#xff09;&#xff0c;以…

学习css过渡动画-transition

文章目录 前言transition属性语法宽度改变效果透明度改变效果位置改变效果如有启发&#xff0c;可点赞收藏哟~ 前言 通常&#xff0c;当一个元素的样式属性值发生变化时&#xff0c;会立即看到页面发生变化。 css属性transition能让页面元素不是立即的、而是慢慢的从一种状态变…

java 访问sqlserver 和 此驱动程序不支持jre1.8错误

sqlserver数据如下&#xff1b; TestSQL.java&#xff1b; import java.sql.*;public class TestSQL {public static void main(String[] args) throws ClassNotFoundException, SQLException {String driverName "com.microsoft.sqlserver.jdbc.SQLServerDriver";…

工程化实战 - 前端AST(进阶)

###脚手架 *快速自动化搭建启动工具 目标: ####第一步:处理依赖 npm i path npm i chalk4.1.0 npm i fs-extra npm i inquirer8.2.2 npm i commander npm i axios npm i download-git-repo //ora easy-table figlet ####第二步:处理工程入口 ####3.加入命令交互 交互好帮手…

LeetCo

题目描述如下&#xff1a; 罗马数字包含以下七种字符: I&#xff0c; V&#xff0c; X&#xff0c; L&#xff0c;C&#xff0c;D 和 M。 字符 数值 I 1 V 5 X 10 L 50 C 100 D 500 M …

开源WIFI继电器之方案介绍

一、实物 1、外观 2、电路板 二、功能说明 输出一路继电器常开信号&#xff0c;最大负载电流10A输入一路开关量检测联网方式2.4G Wi-Fi通信协议MQTT配网方式AIrkiss&#xff0c;SmartConfig设备管理本地Web后台管理&#xff0c;可配置MQTT参数供电AC220V其它一个功能按键&…

08-黑马点评项目发布笔记和查看笔记功能的实现

发布笔记 数据模型 tb_blog探店笔记表,包含笔记的标题、文字、图片等 tb_blog探店笔记表对应的实体类 增加用户图标和和用户姓名以及是否被点赞过了的字段,这些字段不属于Blog表只是为了实现在展示笔记的时候同时展示用户的信息 Data EqualsAndHashCode(callSuper false) …

<MySQL> 什么是JDBC?如何使用JDBC进行编程?

目录 一、JDBC是什么&#xff1f; 二、JDBC常用接口和类 2.1 DataSource 2.2 Connection 2.3 Statement 2.4 ResultSet 三、JDBC的使用 3.1 获得数据库驱动包 3.2 添加到项目依赖 3.3 描述数据库服务器 3.4 建立数据库连接 3.6 执行SQL语句和接收返回数据 3.7 释放…

利用 Pandoc + ChatGPT 优雅地润色论文,并保持 Word 公式格式:Pandoc将Word和LaTeX文件互相转化

论文润色完美解决方案&#xff1a;Pandoc 与 ChatGPT 的强强联合 写在最前面其他说明 一、通过 Pandoc 将 Word 转换为 LaTeX 的完整指南步骤 1: 安装 PandocWindows:macOS:Linux: 步骤 2: 准备 Word 文档步骤 3: 转换文档步骤 4: 检查并调整输出步骤 5: 编译 LaTeX 文档总结 二…