Pytorch 注意力机制解析与代码实现

news2024/11/26 5:51:11

什么是注意力机制

注意力机制是深度学习常用的一个小技巧,它有多种多样的实现形式,尽管实现方式多样,但是每一种注意力机制的实现的核心都是类似的,就是注意力。

注意力机制的核心重点就是让网络关注到它更需要关注的地方。

当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为重要。

注意力机制就是实现网络自适应注意的一个方式。

一般而言,注意力机制可以分为通道注意力机制,空间注意力机制,以及二者的结合。

1.SENet介绍

SE注意力模块是一种通道注意力模块,SE模块能对输入特征图进行通道特征加强,且不改变输入特征图的大小

  1. SE模块的S(Squeeze):对输入特征图的空间信息进行压缩

  2. SE模块的E(Excitation):学习到的通道注意力信息,与输入特征图进行结合,最终得到具有通道注意力的特征图

  3. SE模块的作用是在保留原始特征的基础上,通过学习不同通道之间的关系,提高模型的表现能力。在卷积神经网络中,通过引入SE模块,可以动态地调整不同通道的权重,从而提高模型的表现能力。

实现方式:
1、对输入进来的特征层进行全局平均池化。
2、然后进行两次全连接,第一次全连接神经元个数较少,第二次全连接神经元个数和输入特征层相同。
3、在完成两次全连接后,我们再取一次Sigmoid将值固定到0-1之间,此时我们获得了输入特征层每一个通道的权值(0-1之间)。
4、在获得这个权值后,我们将这个权值乘上原输入特征层即可。

在这里插入图片描述

实现代码

import torch
from torch import nn


class SEAttention(nn.Module):

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        # 对空间信息进行压缩
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 经过两次全连接层,学习不同通道的重要性
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # 取出batch size和通道数
        b, c, _, _ = x.size()
        # b,c,w,h -> b,c,1,1 -> b,c 压缩与通道信息学习
        y = self.avg_pool(x).view(b, c)
        # b,c->b,c->b,c,1,1
        y = self.fc(y).view(b, c, 1, 1)
        # 激励操作
        return x * y.expand_as(x)


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)
    output = se(input)
    print(input.shape)
    print(output.shape)

SE模块是一个即插即用的模块,在上图中左边是在一个卷积模块之后直接插入SE模块,右边是在ResNet结构中添加了SE模块。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1161090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elasticsearch 集群分片出现 unassigned 其中一种原因详细还原

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹…

miniconda快速安装

目录 一、Linux下miniconda安装 1.1、安装 1.2、miniconda初始化 二、Windows下miniconda安装 三、maOS下miniconda安装 3.1、安装 3.2、miniconda初始化 四、参考: 本文给出windows、macos、linux下快速安装miniconda方法。 对比conda,minicond…

XUbuntu22.04之simplenote支持的Markdown语法总结(一百九十一)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

Qwt 使用QwtDial绘制钟表

1.概述 QwtDial是Qwt库中的一个类,用于绘制一个可旋转的仪表盘,QwtAnalogClock继承自QwtDial, 模拟时钟。 以下是类继承关系: 2.运行结果 自定义Clock类,继承自QwtAnalogClock,增加一个QTimer&#xff0…

【计算机网络笔记】传输层——可靠数据传输之流水线机制与滑动窗口协议

系列文章目录 什么是计算机网络? 什么是网络协议? 计算机网络的结构 数据交换之电路交换 数据交换之报文交换和分组交换 分组交换 vs 电路交换 计算机网络性能(1)——速率、带宽、延迟 计算机网络性能(2)…

基于tpshop开发多商户源码支持手机端+商家+门店 +分销+淘宝数据导入+APP+可视化编辑

tpshop多商户源码,tpshop商城源码,tpshop b2b2c源码-支持手机端商家门店 分销淘宝数据导入APP可视化编辑 tpshop商城源码算是 thinkphp框架里做的比较早 比较好的源码了,写法简明 友好面向程序猿。 这是一款前几年的版本 虽然后台看着好了些,丝毫不影响…

【Linux】关于Nginx的详细使用,部署项目

前言: 今天小编给大家带来的是关于Nginx的详细使用,部署项目,希望可以给正在学习,工作的你带来有效的帮助! 一,Nginx简介 Nginx是一个高性能的开源Web服务器和反向代理服务器。它最初由Igor Sysoev在2004年…

探讨jdk源码中的二分查找算法返回值巧妙之处

文章目录 1.什么是二分查找算法1.1 简介1.2 实现思路 2.二分查找的示例3.jdk 中的 Arrays.binarySearch()4.jdk 中核心二分查找方法解析4.1 为什么 low 是插入点4.2 为什么要进行取反:-(low 1)4.3 为什么不直接返回 插入点 low 的相反数&…

MySQL学习-获取排名,按行更新

获取排名 需求:获取分类平均值的名次? 比如10个班级的平均分,按照班级名称排序,后面跟着名次。 记录表:student ; 字段:banji 班级;AvgS 平均分;pm 排名&#xff1b…

解决问题Conda:CondaValueError: Malformed version string ‘~’ : invalid character(s)

解决问题Conda:CondaValueError: Malformed version string ‘~’ : invalid character(s) 背景 今天使用Conda构建项目运行环境的时候报错::CondaValueError: Malformed version string ‘~’ : invalid character(s) ##报错问题 在安装te…

Express框架开发接口之书城商店原型图

这是利用Axure画的,简单画一下原型图,根据他们的业务逻辑我们完成书城商店API开发 首页 分类 购物车 个人中心

探索C++中的不变之美:const与构造函数的深度剖析

W...Y的主页😊 代码仓库分享💕 🍔前言: 关于C的博客中,我们已经了解了六个默认函数中的四个,分别是构造函数、析构函数、拷贝构造函数以及函数的重载。但是这些函数都是有返回值与参数的。提到参数与返回…

Spring Security 6.1.x 系列(4)—— 基于过滤器链的源码分析

一、自动配置 在 Spring Security 6.1.x 系列(1)—— 初识Spring Security 中我们只引入spring-boot-starter-security 依赖,就可以实现登录认证,这些都得益于Spring Boot 的自动配置。 在spring-boot-autoconfigure模块中集成了…

MyBitis自动拼接了LIMIT

1.前言 最近系统在运营的过程中发现一个很奇怪的问题,莫名其妙的SQL语句会被拼接上一小段SQL,但是发现这被拼接的SQL并不是当前这个API所使用的SQL,因此导致select语句出错。 2.排查思路 2.1.第一步 首先我排查了打印日志里面的错误对应的…

Louis 谈 Restaking:去中心化信任的交流电时刻

人际信任是社会资本的主要形态。信任促成协作(主要是经济交易),是人类文明的基石。 当全球已有数十亿人接入互联网,协作的物理限制已经消除,但传统的人际信任仍然局限于家族、长期积累的声誉和长期相处形成的私人关系…

【JAVA学习笔记】55 - 集合-Map接口、HashMap类、HashTable类、Properties类、TreeMap类(难点)

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter14/src/com/yinhai/map_ Map接口 一、Map接口的特点(难点) 难点在于对Node和Entry和EntrySet的关系 注意:这里讲的是JDK8的Map接口特点 Map java 1) Map与Collect…

【Mquant】2、量化平台的选择

文章目录 一、选择因素二、常见的量化平台三、为什么选择VeighNa?四、参考 一、选择因素 功能和工具集:量化平台应该提供丰富的功能和工具集,包括数据分析、策略回测、实时交易等。不同的平台可能有不同的特点和优势,可以根据自己…

【数据库】形式化关系查询语言(一):关系代数Relational Algebra:基本运算、附加关系代数、扩展的关系代数

目录 一、关系代数Relational Algebra 1. 基本运算 a. 选择运算(Select Operation) b. 投影运算(Project Operation) 组合 c. 并运算(Union Operation) d. 集合差运算(Set Difference Op…

Vue3.0 reactive与ref :VCA模式

简介 Vue3 最大的一个变动应该就是推出了 CompositionAPI,可以说它受ReactHook 启发而来;它我们编写逻辑更灵活,便于提取公共逻辑,代码的复用率得到了提高,也不用再使用 mixin 担心命名冲突的问题。 ref 与 reactive…

pytorch学习第五篇:NN与CNN代码实例

这篇文章详细介绍了全链接神经网络实现方法,以及卷积的实现方法。最后我们发现,卷积的实现方法与全链接大同小异,因为 torch 为我们做了很多工作,我们来看看这两个有什么区别。 我们使用 torch 框架来实现两种神经网络,来对图形进行分类。 NN 首先我们引入依赖包 impor…