注意力机制(一)SE模块(Squeeze-and-Excitation Networks)论文总结和代码实现

news2025/1/11 7:57:19

Squeeze-and-Excitation Networks(压缩和激励网络)

论文地址:Squeeze-and-Excitation Networks

论文中文版:Squeeze-and-Excitation Networks_中文版

代码地址:GitHub - hujie-frank/SENet: Squeeze-and-Excitation Networks

目录

一、论文出发点

二、论文的主要工作

三、Squeeze-and-Excitation模块

(1)Transformation(Ftr): 转型

(2)Squeeze:全局信息嵌入

(3)Excitation:自适应重新校正

(4)Scale:重新加权

四、模型:SE-Inception和SE-ResNet

五、实验

六、结论

七、源码分析

 (1)SE模块

 (2)SE-ResNet完整代码


一、论文出发点

为了提高网络的表示能力,许多现有的工作已经显示出增强空间编码的好处。而作者专注于通道,希望能够提出了一种新的架构单元,通过显式地建模出卷积特征通道之间的相互依赖性来提高网络的表示能力。

这里引用“博文:Squeeze-and-Excitation Networks解读”中的总结:核心思想是不同通道的权重应该自适应分配,由网络自己学习出来的,而不是像Inception net一样留下过多人工干预的痕迹。

二、论文的主要工作

1.提出了一种新的架构单元Squeeze-and-Excitation模块,该模块可以显式地建模卷积特征通道之间的相互依赖性来提高网络的表示能力。

2.提出了一种机制,使网络能够执行特征重新校准,通过这种机制可以学习使用全局信息来选择性地强调信息特征并抑制不太有用的特征。

三、Squeeze-and-Excitation模块

(1)Transformation(Ftr): 转型

F_{tr}:X\rightarrow U,经过F_{tr}特征图X变为特征图U。
F_{tr}可以看作一个标准的卷积算子。该卷积算子公式为U_{c}=V_{c}*X=\sum_{s=1}^{C'}V_{c}^{s}*X^{s}

其中:

1.   U=[U_{1},U_{2}...U_{c}],这里U_c指输出特征图的一个单通道2D特征层。

2.   V=[V_{1},V_{2}...V_{c}]表示学习到的一组滤波器核,Vc指的是第c个滤波器的参数,此外V_{c}=[V_{c}^{1},V_{c}^{2}...V_{c}^{c'}]这里 V_{c}^{s}是指一个通道数为1的2D空间核

3.  X=[X^{1},X^{2}...X^{c'}]这里X^{s}是指输入特征图的一个单通道2D特征层

该卷积算子公式表示,输入特征图X的每一层都经过一个2D空间核的卷积最终得到C个输出的feature map,组成特征图U。

原文内容如下:

  • X∈R^(H′×W′×C′):输入特征图
  • U∈R^(H×W×C):输出特征图
  • V:表示学习到的一组滤波器核
  • Vc:指的是第c个滤波器的参数
  • V_{c}^{s}​:表示一个2D的空间核
  • *:卷积操作

(2)Squeeze:全局信息嵌入

Fsq就是使用通道的全局平均池化。
原文中为了解决利用通道依赖性的问题,选择将全局空间信息压缩到一个信道描述符中,即使用通道的全局平均池化,将包含全局信息的W×H×C 的特征图直接压缩成一个1×1×C的特征向量Z,C个feature map的通道特征都被压缩成了一个数值,这样使得生成的通道级统计数据Z就包含了上下文信息,缓解了通道依赖性的问题。
算子公式如下:

Zc为Z的第c个元素。

(3)Excitation:自适应重新校正

目的为了利用压缩操作中汇聚的信息,我们接下来通过Excitation操作来全面捕获通道依赖性。
实现方法
为了实现这个目标,这个功能必须符合两个标准
第一,它必须是灵活的 (它必须能够学习通道之间的非线性交互)
第二,它必须学习一个非互斥的关系,因为独热激活相反,这里允许强调多个通道。
为了满足这些标准,作者采用了两层全连接构成的门机制,第一个全连接层把C个通道压缩成了C/r个通道来降低计算量,再通过一个RELU非线性激活层,第二个全连接层将通道数恢复回为C个通道,再通过Sigmoid激活得到权重s,最后得到的这个s的维度是1×1×C,它是用来刻画特征图U中C个feature map的权重。r是指压缩的比例。

为什么这里要有两个FC,并且通道先缩小,再放大?

因为一个全连接层无法同时应用relu和sigmoid两个非线性函数,但是两者又缺一不可。为了减少参数,所以设置了r比率。

(4)Scale:重新加权

目的:最后是Scale操作,将前面得到的注意力权重加权到每个通道的特征上

实现方法:
特征图U中的每个feature map乘以对应的权重,得到SE模块的最终输出\widetilde{X}

四、模型:SE-Inception和SE-ResNet

通过将一个整体的Inception模块看作SE模块中F_{tr},为Inception网络构建SE模块。

同理, 将一个整体的Residual模块看作SE模块中F_{tr},为ResNet网络构建SE模块。

五、实验

六、结论

本文提出的SE模块,这是一种新颖的架构单元,旨在通过使网络能够执行动态通道特征重新校准来提高网络的表示能力。大量实验证明了SENets的有效性,其在多个数据集上取得了最先进的性能。

七、源码分析

将SEblock嵌入ResNet的残差模块中

 (1)SE模块

'''-------------一、SE模块-----------------------------'''
#全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
class SE_Block(nn.Module):
    def __init__(self, inchannel, ratio=16):
        super(SE_Block, self).__init__()
        # 全局平均池化(Fsq操作)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # 两个全连接层(Fex操作)
        self.fc = nn.Sequential(
            nn.Linear(inchannel, inchannel // ratio, bias=False),  # 从 c -> c/r
            nn.ReLU(),
            nn.Linear(inchannel // ratio, inchannel, bias=False),  # 从 c/r -> c
            nn.Sigmoid()
        )

    def forward(self, x):
            # 读取批数据图片数量及通道数
            b, c, h, w = x.size()
            # Fsq操作:经池化后输出b*c的矩阵
            y = self.gap(x).view(b, c)
            # Fex操作:经全连接层输出(b,c,1,1)矩阵
            y = self.fc(y).view(b, c, 1, 1)
            # Fscale操作:将得到的权重乘以原来的特征图x
            return x * y.expand_as(x)

 (2SE-ResNet完整代码

不同版本的ResNet各层主要是由BasicBlock模块(18-layer、34-layer)或Bottleneck模块(50-layer、101-layer、152-layer)构成的,因此只要在BasicBlock模块或Bottleneck模块尾部添加SE模块即可,但是要注意放在shortcut之前,因为shortcut仅是为了保存梯度,把SE模块加在作为提取信息的主干上即可。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

'''-------------一、SE模块-----------------------------'''
#全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
class SE_Block(nn.Module):
    def __init__(self, inchannel, ratio=16):
        super(SE_Block, self).__init__()
        # 全局平均池化(Fsq操作)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # 两个全连接层(Fex操作)
        self.fc = nn.Sequential(
            nn.Linear(inchannel, inchannel // ratio, bias=False),  # 从 c -> c/r
            nn.ReLU(),
            nn.Linear(inchannel // ratio, inchannel, bias=False),  # 从 c/r -> c
            nn.Sigmoid()
        )

    def forward(self, x):
            # 读取批数据图片数量及通道数
            b, c, h, w = x.size()
            # Fsq操作:经池化后输出b*c的矩阵
            y = self.gap(x).view(b, c)
            # Fex操作:经全连接层输出(b,c,1,1)矩阵
            y = self.fc(y).view(b, c, 1, 1)
            # Fscale操作:将得到的权重乘以原来的特征图x
            return x * y.expand_as(x)

'''-------------二、BasicBlock模块-----------------------------'''
# 左侧的 residual block 结构(18-layer、34-layer)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inchannel, outchannel, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outchannel)
        # SE_Block放在BN之后,shortcut之前
        self.SE = SE_Block(outchannel)

        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != self.expansion*outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, self.expansion*outchannel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*outchannel)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        SE_out = self.SE(out)
        out = out * SE_out
        out += self.shortcut(x)
        out = F.relu(out)
        return out

'''-------------三、Bottleneck模块-----------------------------'''
# 右侧的 residual block 结构(50-layer、101-layer、152-layer)
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inchannel, outchannel, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outchannel)
        self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(outchannel)
        self.conv3 = nn.Conv2d(outchannel, self.expansion*outchannel,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*outchannel)
        # SE_Block放在BN之后,shortcut之前
        self.SE = SE_Block(self.expansion*outchannel)

        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != self.expansion*outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, self.expansion*outchannel,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*outchannel)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        SE_out = self.SE(out)
        out = out * SE_out
        out += self.shortcut(x)
        out = F.relu(out)
        return out

'''-------------四、搭建SE_ResNet结构-----------------------------'''
class SE_ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(SE_ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)                  # conv1
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)       # conv2_x
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)      # conv3_x
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)      # conv4_x
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)      # conv5_x
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        out = self.linear(x)
        return out


def SE_ResNet18():
    return SE_ResNet(BasicBlock, [2, 2, 2, 2])


def SE_ResNet34():
    return SE_ResNet(BasicBlock, [3, 4, 6, 3])


def SE_ResNet50():
    return SE_ResNet(Bottleneck, [3, 4, 6, 3])


def SE_ResNet101():
    return SE_ResNet(Bottleneck, [3, 4, 23, 3])


def SE_ResNet152():
    return SE_ResNet(Bottleneck, [3, 8, 36, 3])


'''
if __name__ == '__main__':
    model = SE_ResNet50()
    print(model)
    input = torch.randn(1, 3, 224, 224)
    out = model(input)
    print(out.shape)
# test()
'''
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = SE_ResNet50().to(device)
    # 打印网络结构和参数
    summary(net, (3, 224, 224))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/582420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

git 文件恢复与项目还原:008

1. 【文件恢复】:将文件恢复到上一次提交的状态 注意:新建且没有提交的文件无法使用文件恢复 命令: git checkout -- 文件名假如我们的一开始是这样的,这是没有报错的状态文件 然后我添加了一段内容, 比如我添加这段内…

做外贸算运费的时候需不需要多算一些

看到一个网友在一篇文章下留言说:客户算运费的时候需不需要多算一些 听公司老员工说给客户算运费要多加20% 这样合适吗 我个人感觉有点离谱。 那我们就这个话题,谈一谈运费是否要多加一些呢?为什么要多加一些? 首先,要…

Zookeeper学习---3、服务器动态上下线监听案例、ZooKeeper 分布式锁案例、企业面试真题

1、服务器动态上下线监听案例 1、需求 某分布式系统中,主节点可以有多台,可以动态上下线,任意一台客户端都能实时感知到主节点服务器的上下线。 2、需求分析 3、具体实现 (1)先在集群上创建/servers 节点 &#xff…

软考A计划-试题模拟含答案解析-卷八

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&am…

2023年上半年系统集成项目管理工程师下午真题及答案解析

试题一(18分) A公司跨国收购了B公司的主营业务,保留了B公司原有的人员组织结构和内部办公系统。为了解决B公司内部办公系统与A公司原有系统不兼容的问题,财务、人力和行政部门联合向公司高层申请尽快启动系统和业务的整合。 A公司领导指定HR总监王工担…

云容灾部署前的准备指南

据ITIC的研究表明,98%的千人规模企业每年都会遭遇停机危机,每停机一小时就会损失约700,000人民币。当灾难发生时,使用云容灾的企业可以通过云平台提供的资源和服务,快速帮助企业恢复业务。 HyperBDR云容灾,深度对接全…

Kibana:使用 Docker 安装 Kibana - 8.x

Kibana 的 Docker 镜像可从 Elastic Docker 注册中心获得。 基本映像是 ubuntu:20.04。www.docker.elastic.co 上提供了所有已发布的 Docker 图像和标签的列表。 源代码在 GitHub 中。 这些镜像包含免费和订阅功能。 开始 30 天试用以试用所有功能。 如果你还没有安装好自己的…

一文了解什么是ChatGPT

ChatGPT 是一种自然语言人工智能聊天机器人。在最基本的层面上,这意味着你可以问它任何问题,它会生成一个答案。 一、如何使用聊天 GPT 首先,转到chat.openai.com。如果这是您的第一次,您需要在开始之前使用 OpenAI 设置一个免费…

C919中有哪些项目是华为之作?

#C919# C919和华为都是我们国人的骄傲。那你知道在C919中有哪些项目是华为之作吗?C919与华为的合作主要涉及航空电子领域: 1.飞机高清视频传输系统:该系统使用华为的数字视频传输技术,可以将高清视频信号快速地传输到地面监控中心…

Gradio的web界面演示与交互机器学习模型,高级接口特征《6》

大多数模型都是黑盒,其内部逻辑对最终用户是隐藏的。为了鼓励透明度,我们通过简单地将Interface类中的interpretation关键字设置为default,使得向模型添加解释变得非常容易。这允许您的用户了解输入的哪些部分负责输出。 1、Interpret解释 …

NetApp E 系列混合闪存阵列——专为需要高带宽的专用应用程序而构建(如数据分析、视频监控、HPC、基于磁盘的备份)

E 系列混合闪存阵列:专为交付而构建 为什么选择 NetApp E 系列阵列? 超过 100 万次的安装和计数 凭借其提供的精简性和可靠性,我们的 E 系列阵列成为了众多企业的首选系统。从推动数据密集型应用程序(如分析、视频监控和基于磁盘…

PLC/DCS系统常见的干扰现象及判断方法

一般来说,常见的干扰现象有以下几种: 1.系统发指令时,电机无规则地转动; 2.信号等于零时,数字显示表数值乱跳; 3。传感器工作时,DCS/PLC 采集过来的信号与实际参数所对应的信号值不吻合,且误…

微信小程序报错:“该小程序提供的服务出现故障,请稍后再试”(IOS报错,Android则正常)

记录对接微信小程序时遇到的问题,问题表现为: 1、发送消息后出现报错:该小程序提供的服务出现故障,请稍后再试 2、只有IOS会报错,Android则是正常的 3、IOS报错的微信号,即使在电脑端登录,使…

HKPCA Show携手电巢直播开启“云”观展!掀起一场电子人的顶级狂欢!

近日,国际电子电路(深圳)展览会(HKPCA Show)已于深圳国际会展中心圆满举办!本次展览划分七大主题专区,面积超50,000平方米,展位超2500个,汇聚众多行业知名、有影响力的参…

腾讯云3年轻量应用服务器和5年CVM云服务器限制说明

腾讯云轻量服务器2核2G4M带宽三年388元、2核4G5M带宽三年599元、CVM云服务器2核2G配置5年1728元、2核4G配置5年3550元、4核8G配置5年6437元,从性价比角度来看,还是轻量应用服务器比较划算,腾讯云百科分享阿里云3年轻量应用服务器和5年云服务器…

华为手机怎么录屏?分享2个好用的手机录屏方法!

案例:华为手机怎么录制屏幕? 【有些内容通过文字和图片,不能很好地表达。我想把内容录制下来,发给别人,方便他们理解。有人知道华为手机怎么录屏吗?】 华为是一款知名的智能手机品牌,其强大的…

PUSH消息推送的实现原理

PUSH消息推送的实现原理_腾讯新闻 编辑导语:如今,push已经成为了我们手机信息流的一种推广方式,那么push消息推送是如何实现的呢?作者总结了几种消息推送的类型以及实现原理,一起来看看。 一、消息推送的类型 1. 短信…

使用 Elastic Learned Sparse Encoder 和混合评分的卓越相关性

作者:The Elastic Platform team 2023 年 5 月 25 今天,我们很高兴地宣布 Elasticsearch 8.8 正式发布。 此版本为矢量搜索带来了多项关键增强功能,让开发人员无需付出通常的努力和专业知识即可在搜索应用程序中利用一流的 AI 驱动技术。 使…

06- AOP(实现案例:记录日志操作)

目录 1. 通知类型 2. 通知顺序 3. 切入点表达式 execution() annotation() 4. 连接点(JoinPoint) 5. 案例:将CRUD接口的相关操作记录到数据库中 AOP: Aspect Oriented Programming (面向切面编程、面向方面编程),其实就是…

Zookeeper学习---2、客户端API操作、客户端向服务端写数据流程

1、客户端API操作 1.1 IDEA 环境搭建 前提&#xff1a;保证 hadoop102、hadoop103、hadoop104 服务器上 Zookeeper 集群服务端启动。 1、创建一个工程&#xff1a;Zookeeper 2、添加pom文件 <?xml version"1.0" encoding"UTF-8"?> <project …