常见注意力机制解析

news2024/12/23 8:42:32

1.Squeeze-and-Excitation(SE)

SE的主要思想是通过对输入特征进行压缩和激励,来提高模型的表现能力。具体来说,SE注意力机制包括两个步骤:Squeeze和Excitation。在Squeeze步骤中,通过全局平均池化操作将输入特征图压缩成一个向量,然后通过一个全连接层将其映射到一个较小的向量。在Excitation步骤中,使用一个sigmoid函数将这个向量中的每个元素压缩到0到1之间,并将其与原始输入特征图相乘,得到加权后的特征图。通过SE注意力机制,模型可以自适应地学习到每个通道的重要性,从而提高模型的表现能力。在实际应用中,SE注意力机制已经被广泛应用于各种深度学习模型中,取得了很好的效果。

 代码如下:

import torch
from torch import nn
from torchstat import stat  # 查看网络参数

# 定义SE注意力机制的类
class SE_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(SE_block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // ratio, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // ratio, channel, bias=False),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

if __name__ == '__main__':
    # 构造输入层
    inputs = torch.rand(2,320,32,32)
    # 获取输入通道数
    channel = inputs.shape[1]
    # 模型实例化
    model = SE_block(channel, ratio=16)

    # 前向传播查看输出结果
    outputs = model(inputs)
    print(outputs.shape)  #[2, 320, 32, 32]

    # print(model)  # 查看模型结构
    stat(model, input_size=[320,32,32])  # 查看参数,不需要指定batch维度

2.Convolutional Block Attention Module(CBAM)

CBAM用于自适应地调整关注图像中的关键区域。CBAM注意力机制由两个模块组成:通道注意力模块和空间注意力模块。通道注意力模块用于学习每个通道的重要性,使得网络更多地关注那些有意义的特征,同时消除那些无意义的特征。这个模块的输入是一个卷积层的输出,输出是经过权重归一化之后的特征图。具体地,对于每个通道,通道注意力模块使用全局平均池化操作来得到一个特征向量,然后通过两个全连接层学习该通道的重要性,并对最终的卷积层输出进行加权。

空间注意力模块用于学习图像的不同空间区域的重要性,即在不同位置上的注意力权值。这个模块的输入是卷积层的输出,输出是通过空间维度上的卷积操作和全连接操作,产生的每个空间位置的注意力矩阵。这个矩阵能够自适应地调整感受野大小并关注图像中的关键区域。

通道和空间注意力模块可以结合使用,构建CBAM注意力机制。通过这种注意力机制的应用,可以提高图像分类、目标检测和语义分割等任务的准确性。

代码如下:

import torch
from torch import nn
from torchstat import stat  # 查看网络参数

# 定义CBAM注意力机制的类
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # 利用1x1卷积代替全连接
        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAM_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(CBAM_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x


if __name__ == '__main__':
    # 构造输入层
    inputs = torch.rand(2,320,32,32)
    # 获取输入通道数
    channel = inputs.shape[1]
    # 模型实例化
    model = CBAM_block(channel, ratio=16, kernel_size=7)

    # 前向传播查看输出结果
    outputs = model(inputs)
    print(outputs.shape)  #[2, 320, 32, 32]

    # print(model)  # 查看模型结构
    stat(model, input_size=[320,32,32])  # 查看参数,不需要指定batch维度

3.Efficient Channel Attention(ECA)

ECA是一种用于图像处理的注意力机制模型。它主要是通过对图像通道进行注意力调控,提高图像特征表示的有效性。具体来说,ECA注意力机制模型由两部分组成:全局平均池化和线性变换。全局平均池化可以对每个通道的信息进行汇聚,从而判断该通道中的信息是否关键;线性变换可以将通道的信息进行缩放和平移,使得关键信息得到更好的保留,非关键信息得到抑制。如果某个通道的信息对于图像表现并不关键,则可以对其进行抑制,以提高其他通道的表现。ECA注意力机制相比于其他注意力机制模型的优势在于其模型复杂度低、计算效率高、效果好。因此,它被广泛应用于图像分类、目标检测和图像分割等领域。

 代码如下:

import torch
import math
from torch import nn
from torchstat import stat  # 查看网络参数

# 定义ECA注意力机制的类
class ECA_block(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super(ECA_block, self).__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


if __name__ == '__main__':
    # 构造输入层
    inputs = torch.rand(2,320,32,32)
    # 获取输入通道数
    channel = inputs.shape[1]
    # 模型实例化
    model = ECA_block(channel)

    # 前向传播查看输出结果
    outputs = model(inputs)
    print(outputs.shape)  #[2, 320, 32, 32]

    # print(model)  # 查看模型结构
    stat(model, input_size=[320,32,32])  # 查看参数,不需要指定batch维度

4.Coordinate attention(CA)

CA可以避免全局pooling-2D操作造成的位置信息丢失,将注意力分别放在宽度和高度两个维度上,有效利用输入特征图的空间坐标信息。CA主要分为两个步骤:第一步是坐标信息的嵌入,给定输入X,使用池化层分别沿着水平坐标和垂直坐标对每个通道进行编码,得到一对方向感知特征图。第二步是坐标信息特征图的生成,首先将提取到的特征信息进行拼接,然后利用一个1×1卷积变换函数进行信息转换,进而得到中间特征图,并沿着空间维度分解为两个单独的张量,再利用两个卷积变换为具有相同通道数的张量,最后将输出结果进行扩展,分别作为注意力权重分配值。CA是一个简单灵活即插即用的模块,可以在不带来任何额外开销的前提下,提升网络的精度。

代码如下:

import torch
from torch import nn
from torchstat import stat  # 查看网络参数

# 定义CA注意力机制的类
class CA_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CA_Block, self).__init__()

        self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1,
                                  bias=False)

        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(channel // reduction)

        self.F_h = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,
                             bias=False)
        self.F_w = nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1,
                             bias=False)

        self.sigmoid_h = nn.Sigmoid()
        self.sigmoid_w = nn.Sigmoid()

    def forward(self, x):
        _, _, h, w = x.size()

        x_h = torch.mean(x, dim=3, keepdim=True).permute(0, 1, 3, 2)
        x_w = torch.mean(x, dim=2, keepdim=True)

        x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))

        x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)

        s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
        s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))

        out = x * s_h.expand_as(x) * s_w.expand_as(x)
        return out

if __name__ == '__main__':
    # 构造输入层
    inputs = torch.rand(2,320,32,32)
    # 获取输入通道数
    channel = inputs.shape[1]
    # 模型实例化
    model = CA_Block(channel, reduction=16)
    
    # 前向传播查看输出结果
    outputs = model(inputs)
    print(outputs.shape)  #[2, 320, 32, 32]

    # print(model)  # 查看模型结构
    stat(model, input_size=[320,32,32])  # 查看参数,不需要指定batch维度

reference

http://t.csdn.cn/0XBy9icon-default.png?t=N3I4http://t.csdn.cn/0XBy9

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/503356.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【2023年Mathorcup杯数学建模竞赛C题】电商物流网络包裹应急调运与结构优化--完整作品分享

1.问题背景 2.论文摘要 为了应对电商物流网络中物流场地和线路电商物流网络中物流场地和线路上货量波动的情况, 设计合理的物流网络调整方案以保障物流网络的正常运行。本文运用 0-1 整数规划模型,多目标动 态规划模型,给出了问题的结果。 针…

深入讲解eMMC简介

1 eMMC是什么 eMMC是embedded MultiMediaCard的简称,即嵌入式多媒体卡,是一种闪存卡的标准,它定义了基于嵌入式多媒体卡的存储系统的物理架构和访问接口及协议,具体由电子设备工程联合委员会JEDEC订立和发布。它是对MMC的一个拓展&#xff0…

redi缓存使用

1、缓存的特征 第一个特征:在一个层次化的系统中,缓存一定是一个快速子系统,数据存在缓存中时,能避免每次从慢速子系统中存取数据。 第二个特征:缓存系统的容量大小总是小于后端慢速系统的,不可能把所有数…

GAMES101 计算机图形学 | 学习笔记 (上)

目录 环境安装什么是计算机图形学物体上点的坐标变换顺序齐次坐标光栅化如何判定一个点在三角形内光栅化填充三角形示例代码光栅化产生的问题 采样不足(欠采样)导致锯齿抗锯齿滤波算法 环境安装 1. C中安装opencv库 2. C中安装eigen库 3. C中安装open…

ChatGPT调教指北,技巧就是效率!

技巧就是效率 很多人都知道ChatGPT很火很强,几乎无所不能,但跨越了重重门槛之才有机会使用的时候却有些迷茫,一时间不知道如何使用它。如果你就是把他当作一个普通的智能助手来看待,那与小爱同学有什么区别?甚至还差劲…

热乎的面经——踏石留印

⭐️前言⭐️ 本篇文章记录博主面试北京某公司所记录的面经,希望能给各位带来帮助。 🍉欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍉博主将持续更新学习记录收获,友友们有任何问题可以在评论…

Origin如何绘制三维图形?

文章目录 0.引言1.使用矩阵簿窗口2.三维数据转换3.三维绘图4.三维曲面图5.三维XYY图6.三维符号、条状、矢量图7.等高线图 0.引言 因科研等多场景需要,绘制专业的图表,笔者对Origin进行了学习,本文通过《Origin 2022科学绘图与数据》及其配套素…

63.空白和视觉层级的实战应用

例如看我们之前的小网页; 这些标题的上下距离一样,这样让我们很容易对这些标题进行混淆,我们可以适当的添加一点空白 header, section {margin-bottom: 96px; }这样看上去似乎就好很多! 除此之外,如我们之间学的空…

【line features】线特征

使用BinaryDescriptor接口提取线条并将其存储在KeyLine对象中,使用相同的接口计算每个提取线条的描述符,使用BinaryDescriptorMatcher确定从不同图像获得的描述符之间的匹配。 opencv提供接口实现 线提取和描述符计算 下面的代码片段展示了如何从图像中…

K8S相关核心概念

个人笔记: 要弄明白k8s的细节,需要知道k8s是个什么东西。它的主要功能,就是容器的调度--也就是把部署实例,根据整体资源的使用状况,部署到任何地方 注意任何这两个字,预示着你并不能够通过常规的IP、端口…

如何全面学习Object-C语言的语法知识 (Xmind Copilot生成)

网址:https://xmind.ai/login/ 登录后直接输入:如何全面学习Object-C语言的语法知识,就可以生成大纲 点击右上角的 按钮,可以显示md格式的问题,再点击生成全文,就可以生成所有内容了, 还有这个…

CentOS7/8 安装 5+ 以上的Linux kernel

CentOS以稳定著称,稳定在另外一方面就是保守。所以CentOS7还在用3.10,CentOS8也才是4.18。而当前最新的Linux Kernel都更新到6.0 rc3了。其他较新的发行版都用上了5.10的版本。本文简单介绍如何在CentOS7、8上直接安装5.1以上版本的第三方内核。 使用ted…

5.8晚间黄金行情走势分析及短线交易策略

近期有哪些消息面影响黄金走势?本周黄金多空该如何研判? ​黄金消息面解析:周一亚洲时段,现货黄金小幅反弹,目前交投于2024.3美元/盎司附近,一方面是金价上周五守住了 2000 整数关口,逢低买盘涌…

java环境Springboot框架中配置使用GDAL,并演示使用GDAL读取shapefile文件

GDAL是应用广泛的空间数据处理库,可以处理几何、栅格数据,Springboot是常用的JAVA后端开发框架。本文讲解如何在Springboot中配置使用GDAL。本文示例中使用的GDAL版本为3.4.1(64位) 图1 GDAL读取shp效果 一、部署GDAL类库 将GDA…

什么是点对点传输?什么是点对多传输

点对点技术(peer-to-peer, 简称P2P)又称对等互联网络技术,是一种网络新技术,依赖网络中参与者的计算能力和带宽,而不是把依赖都聚集在较少的几台服务器上。P2P网络通常用于通过Ad Hoc连接来连接节点。这类网…

WiFi(Wireless Fidelity)基础(四)

目录 一、基本介绍(Introduction) 二、进化发展(Evolution) 三、PHY帧((PHY Frame ) 四、MAC帧(MAC Frame ) 五、协议(Protocol) 六、安全&#x…

功能测试常用的测试用例大全

登录、添加、删除、查询模块是我们经常遇到的,这些模块的测试点该如何考虑 1)登录 ① 用户名和密码都符合要求(格式上的要求) ② 用户名和密码都不符合要求(格式上的要求) ③ 用户名符合要求,密码不符合要求(格式上的要求) ④ 密码符合要求,…

1_1torch学习

一、torch基础知识 1、torch安装 pytorch cuda版本下载地址:https://download.pytorch.org/whl/torch_stable.html 其中先看官网安装torch需要的cuda版本,之后安装cuda版本,之后采用pip 下载对应的torch的gpu版本whl来进行安装。使用pip安装…

Linux内核中的链表(list_head)使用分析

【摘要】本文分析了linux内核中的list_head数据结构的底层实现及其相关的各种调用源码,有助于理解内核中链表对象的使用。 二、内核中的队列/链表对象 在内核中存在4种不同类型的列表数据结构: singly-linked listssingly-linked tail queuesdoubly-lin…

SSM框架学习-bean生命周期理解

Spring启动,查找并加载需要被Spring管理的Bean,进行Bean的实例化(反射机制);利用依赖注入完成 Bean 中所有属性值的配置注入; 第一类Aware接口: 如果 Bean 实现了 BeanNameAware 接口的话&#…