金字塔切分注意力模块PSA学习笔记 (附代码)

news2025/1/17 3:03:44

已有研究表明:将注意力模块嵌入到现有CNN中可以带来显著的性能提升。比如,SENet、BAM、CBAM、ECANet、GCNet、FcaNet等注意力机制均带来了可观的性能提升。但是,目前仍然存在两个具有挑战性的问题需要解决。一是如何有效地获取和利用不同尺度的特征图的空间信息,丰富特征空间。二是通道注意力或者或空间注意力只能有效捕获局部信息,而不能建立长期的依赖关系。最新的一些方法虽然能有效解决上述问题,但是他们同时会带来巨大的计算负担。基于此,本文首先提出了一种新颖的轻量且高效的PSA注意力模块。PSA模块可以处理多尺度的输入特征图的空间信息并且能够有效地建立多尺度通道注意力间的长期依赖关系。然后,我们将PSA 模块替换掉ResNet网络Bottleneck中的3x3x卷积,其余保持不变,最后得到了新的EPSA(efficient pyramid split attention) block.基于EPSA block我们构建了一个新的骨干网络称作:EPSANet。它既可以提供强有力的多尺度特征表示能力。与此同时,EPSANet不仅在图像识别任务中的Top-1 Acc大幅度优于现有技术,而且在计算参数量上有更加高效。具体效果:如下图所示,

论文地址:https://arxiv.org/pdf/2105.14447v1.pdf

代码地址:https://gitcode.com/mirrors/murufeng/epsanet/blob/master/models/epsanet.py

1.是什么?

Pyramid Split Attention (PSA)是一种基于注意力机制的模块,它可以用于图像分类、目标检测等任务中。PSA模块通过将不同大小的卷积核的卷积结果进行拼接,形成一个金字塔状的特征图,然后在这个特征图上应用注意力机制,以提取更加丰富的特征信息。与其他注意力模块相比,PSA模块具有轻量、简单高效等特点,可以与ResNet等主流网络结构结合使用,提高模型的性能。

2.为什么?

1. SE仅仅考虑了通道注意力,忽略了空间注意力。
2. BAM和CBAM考虑了通道注意力和空间注意力,但仍存在两个最重要的缺点:(1)没有捕获不同尺度的空间信息来丰富特征空间。(2)空间注意力仅仅考虑了局部区域的信息,而无法建立远距离的依赖。
3. 后续出现的PyConv,Res2Net和HS-ResNet都用于解决CBAM的这两个缺点,但计算量太大。

基于以上三点分析,提出了Pyramid Split Attention。

3.怎么样?

3.1网络结构

PSA模块主要通过四个步骤实现

  • 首先,利用SPC模块来对通道进行切分,然后针对每个通道特征图上的空间信息进行多尺度特征提取;
  • 其次,利用SEWeight模块提取不同尺度特征图的通道注意力,得到每个不同尺度上的通道注意力向量;
  • 第三,利用Softmax对多尺度通道注意力向量进行特征重新标定,得到新的多尺度通道交互之后的注意力权重。
  • 第四,对重新校准的权重和相应的特征图按元素进行点乘操作,输出得到一个多尺度特征信息注意力加权之后的特征图。该特征图多尺度信息表示能力更丰富。

3.2 SPC module

从前面我们了解到:PSA的关键在于多尺度特征提取,即SPC模块。假设输入为X,我们先将其拆分为S部分[X_{0},X_{1},...,X_{S-1}],然后对不同部分提取不同尺度特征,最后将所提取的多尺度特征通过Concat进行拼接。上述过程可以简单描述如下:

在上述特征基础上,我们对不同部分特征提取注意力权值,公式如下:

为更好的实现注意力信息交互并融合跨维度信息,我们将上述所得注意力向量进行拼接,即
Z = Z_{0} \bigoplus Z_{1} \bigoplus \cdot \cdot \cdot \bigoplus Z_{S-1}.然后,我们再对所得注意力权值进行归一化,定义如下:

最后,我们即可得到校正后的特征:Y = F ⊙ att

3.3 框图


 

3.4 代码实现

SEWeightModule


import torch.nn as nn

class SEWeightModule(nn.Module):

    def __init__(self, channels, reduction=16):
        super(SEWeightModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)

        return weight

PSAModule 

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """standard convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class PSAModule(nn.Module):

    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):
        super(PSAModule, self).__init__()
        self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,
                            stride=stride, groups=conv_groups[0])
        self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,
                            stride=stride, groups=conv_groups[1])
        self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,
                            stride=stride, groups=conv_groups[2])
        self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,
                            stride=stride, groups=conv_groups[3])
        self.se = SEWeightModule(planes // 4)
        self.split_channel = planes // 4
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.shape[0]
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        feats = torch.cat((x1, x2, x3, x4), dim=1)
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])

        x1_se = self.se(x1)
        x2_se = self.se(x2)
        x3_se = self.se(x3)
        x4_se = self.se(x4)

        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_weight = feats * attention_vectors
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)

        return out

 EPSANET

import torch
import torch.nn as nn
import math
from .SE_weight_module import SEWeightModule

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """standard convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class PSAModule(nn.Module):

    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):
        super(PSAModule, self).__init__()
        self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,
                            stride=stride, groups=conv_groups[0])
        self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,
                            stride=stride, groups=conv_groups[1])
        self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,
                            stride=stride, groups=conv_groups[2])
        self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,
                            stride=stride, groups=conv_groups[3])
        self.se = SEWeightModule(planes // 4)
        self.split_channel = planes // 4
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.shape[0]
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        feats = torch.cat((x1, x2, x3, x4), dim=1)
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])

        x1_se = self.se(x1)
        x2_se = self.se(x2)
        x3_se = self.se(x3)
        x4_se = self.se(x4)

        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_weight = feats * attention_vectors
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)

        return out


class EPSABlock(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, conv_kernels=[3, 5, 7, 9],
                 conv_groups=[1, 4, 8, 16]):
        super(EPSABlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = PSAModule(planes, planes, stride=stride, conv_kernels=conv_kernels, conv_groups=conv_groups)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class EPSANet(nn.Module):
    def __init__(self,block, layers, num_classes=1000):
        super(EPSANet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layers(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layers(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layers(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layers(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layers(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def epsanet50():
    model = EPSANet(EPSABlock, [3, 4, 6, 3], num_classes=1000)
    return model

def epsanet101():
    model = EPSANet(EPSABlock, [3, 4, 23, 3], num_classes=1000)
    return model

参考:

EPSANet: 一种高效的多尺度通道注意力机制,主要提出了金字塔注意力模块,即插即用,效果显著,已开源!

EPSANet:金字塔拆分注意力模块 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1131401.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

主流电商平台价格如何高频监测

双十一来临在即,除了商家很兴奋,品牌和消费者同样持续关注,除了关注不同平台的产品上架情况,价格也是这些渠道参与者最为关注的,品牌需要通过掌握各店铺的价格情况,了解市场情况以及各经销商的渠道治理现状…

从零搭建一个PWA应用需要了解哪些知识

在国内由于小程序的风生水起,PWA 应用在国内的状况一直都不是很好,PWA 和小程序有很多的相似性,但是 PWA 是由谷歌发起的技术,小程序是微信发起的技术,所以小程序在国内得到了大力的扶持,很快就在国内技术界…

docker制作java项目镜像

docker制作java项目镜像 环境步骤Dockerfile 运行容器 环境 当前使用win10安装的docker win10安装Docker参考文章 步骤 将Dockerfile文件和jar包放在同一个目录下 编写Dockerfile文件 Dockerfile #设置镜像基础: jdk8-jre , 比jdk内存小 FROM java:8-jre #维护人员信息 MA…

众和策略可靠吗?dde大单净量可信吗?

可靠 DDE大单净量是指股票成交中的单笔生意量较大且净买入或净卖出的数量。这个方针在股票商场中被广泛运用,尤其是在技术剖析中。但是,有时候人们会怀疑DDE大单净量的可信度,下面我们从几个角度进行剖析。 首要,有些人以为DDE大…

持续性输出,继续推荐5款好用的软件

​ 分享是一种神奇的东西,它使快乐增大,它使悲伤减小,坚持分享一些好用的软件给大家,今天继续为大家带来五款好用的小软件。 1.文件解锁工具——Unlocker ​ Unlocker是一款解决Windows文件被占用无法删除或重命名的问题的小工具…

mysql数据库迁移达梦

迁移前准备: 授权给要迁移的数据库的用户,例如此时是 mysql迁移到达梦里面,所以得把你连接这个mysql数据库的这个用户root授权, CREATE USER root IDENTIFIED BY1123456;GRANT ALL privileges ON *.* TO rootroot WITH GRANT OPTI…

微信小程序菜单导航选中自动居中

菜单导航选中自动居中 示例库 代码片段

Qt之自定义事件QEvent

在Qt中,自定义事件的步骤大概如下: 1.创建自定义事件,自定义事件需要继承QEvent 2.使用QEvent::registerEventType()注册自定义事件类型,事件的类型需要在 QEvent::User 和 QEvent::MaxUser 范围之间,在QEvent::User之前是预留给系统的事件 3.使用sendEvent() 和 postEv…

迅为RK3399开发板Android 系统--打印级别设置(printk日志等级设置)

在内核源码 include/linux/kern_levels.h 文件中预定义了内核 log 等级,一共有八个等级,从 0 到 7,优先级依次降低,如下所示: // include/linux/kern_levels.h #define KERN_SOH "\001" /* ASCII Start Of…

Khronos: 面向万亿规模时间线的性能监控引擎建设实践

作者:余文清 阿里巴巴智能引擎事业部自研的 Khronos 系统是阿里内部接入规模最大的性能数据存储引擎。Khronos 支持动态生命周期的存储计算分离架构,采用 schemaless 的 data model 设计,在万亿数据规模下为业务提供易用、高效、经济的服务&a…

自媒体创业秘籍:视频号视频下载助你打造热门账号

​自媒体创业者们都知道,视频号已经成为拓展影响力和吸引更多用户的热门平台之一。然而,要想在这个竞争激烈的市场中脱颖而出,并打造一个热门账号,你需要掌握一些技巧和秘籍。在本文中,我将分享关于视频号视频下载的方…

ModelSim【紫光】

这软件是查看波形的。 如果ModelSim频繁弹窗,关闭电脑杀毒软件和电脑管家。 尤其是荣耀管家

Python学习8

前言:相信看到这篇文章的小伙伴都或多或少有一些编程基础,懂得一些linux的基本命令了吧,本篇文章将带领大家服务器如何部署一个使用django框架开发的一个网站进行云服务器端的部署。 文章使用到的的工具 Python:一种编程语言&…

UMMKD

方法 对于“Y”形模型,绿线之前的层是分开的,绿线之后的层在模态之间共享。对于“X”形模型,第一条蓝线之前和第二条蓝线之后的层是分开的,蓝线之间的层在模态之间共享 作者未提供数据

shell脚本中循环语句(极其粗糙版)

分界点:以下内容需要更改,正常放假更改 循环语句: 循环:重复执行一段代码的结构,通过循环,可以在满足一定的条件情况下,多次的执行相同的代码 循环包括:循环体以及循环条件&#…

第二证券:macd指标如何使用?

MACD方针是一种用于技术分析的方针,被广泛应用于股票、期货和外汇市场。本文将从什么是MACD,MACD的原理,怎样运用MACD和MACD的优缺点四个方面分析MACD方针怎样运用。 一、什么是MACD MACD是Moving Average Convergence Divergence的缩写&…

第二证券:股指预计维持蓄势震荡格局 关注证券、计算机设备等板块

第二证券指出,增发1万亿国债、自2000年以来首度年内调整预算传递财务积极发力的信号,估量会对四季度及明年GDP将产生显著带动作用,有利于股市整体情绪的提振。对于债市而言,在支撑信贷增加和实体经济批改的目标下,我们…

在线客服系统源码 客服系统源码

在线客服系统源码 客服系统源码 框架:Thinkphp5workerman,环境:nginxphp7.3mysql5.6 多商户客服、不限坐席、独立系统--数据存储自己服务器上,支持开启SSL、支持离线对话。 新款在线客服系统全开源无加密:多商户、国…

IN动态|小达智能科技领导一行莅临英码科技调研,携手打造时代特色的AI教学平台

近日,小达智能科技CEO何方明领导一行莅临英码科技进行走访调研,我司业务中心总监刘力搏及电商运营经理潘怡霏等一同热情接待,并展开深入的座谈交流。 01 广州小达智能科技有限公司 广州小达智能科技有限公司是于2023年融合砺锋信息科技有限公…