目录
- 一、注意力机制介绍
- 1、什么是注意力机制?
- 2、注意力机制的分类
- 3、注意力机制的核心
- 二、SOCA模块
- 1、SOCA模块的原理
- 2、实验结果
- 3、应用示例
- 三、SimAM模块
- 1、SimAM模块的原理
- 2、实验结果
- 3、应用示例
大家好,我是哪吒。
🏆本文收录于,目标检测YOLO改进指南。
本专栏均为全网独家首发,内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。
在机器学习和自然语言处理领域,随着数据的不断增长和任务的复杂性提高,传统的模型在处理长序列或大型输入时面临一些困难。传统模型无法有效地区分每个输入的重要性,导致模型难以捕捉到与当前任务相关的关键信息。为了解决这个问题,注意力机制(Attention Mechanism)应运而生。
一、注意力机制介绍
1、什么是注意力机制?
注意力机制(Attention Mechanism)是一种在机器学习和自然语言处理领域中广泛应用的重要概念。它的出现解决了模型在处理长序列或大型输入时的困难,使得模型能够更加关注与当前任务相关的信息,从而提高模型的性能和效果。
本文将详细介绍注意力机制的原理、应用示例以及应用示例。
2、注意力机制的分类
类别 | 描述 |
---|---|
全局注意力机制(Global Attention) | 在计算注意力权重时,考虑输入序列中的所有位置或元素,适用于需要全局信息的任务。 |
局部注意力机制(Local Attention) | 在计算注意力权重时,只考虑输入序列中的局部区域或邻近元素,适用于需要关注局部信息的任务。 |
自注意力机制(Self Attention) | 在计算注意力权重时,根据输入序列内部的关系来决定每个位置的注意力权重,适用于序列中元素之间存在依赖关系的任务。 |
Bahdanau 注意力机制 | 全局注意力机制的一种变体,通过引入可学习的对齐模型,对输入序列的每个位置计算注意力权重。 |
Luong 注意力机制 | 全局注意力机制的另一种变体,通过引入不同的计算方式,对输入序列的每个位置计算注意力权重。 |
Transformer 注意力机制 | 自注意力机制在Transformer模型中的具体实现,用于对输入序列中的元素进行关联建模和特征提取。 |
3、注意力机制的核心
注意力机制的核心思想是根据输入的上下文信息来动态地计算每个输入的权重。这个过程可以分为三个关键步骤:计算注意力权重、对输入进行加权和输出。首先,计算注意力权重是通过将输入与模型的当前状态进行比较,从而得到每个输入的注意力分数。这些注意力分数反映了每个输入对当前任务的重要性。对输入进行加权是将每个输入乘以其对应的注意力分数,从而根据其重要性对输入进行加权。最后,将加权后的输入进行求和或者拼接,得到最终的输出。注意力机制的关键之处在于它允许模型在不同的时间步或位置上关注不同的输入,从而捕捉到与任务相关的信息。
🏆YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块
🏆YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块
🏆YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块
🏆YOLOv5/v7 添加注意力机制,30多种模块分析④,CA模块,ECA模块
二、SOCA模块
1、SOCA模块的原理
SOCA(Second-order Channel Attention,二阶通道注意力)模块是一种用于图像超分辨率的注意力机制。它可以通过对输入特征张量进行协方差计算,并使用计算出的协方差矩阵作为权重,来提高模型对重要通道的关注度,从而提高模型的超分辨率效果。
SOCA模块的主要思想是将输入特征张量中每个通道之间的关系进行建模,即通过计算协方差矩阵来表示各个通道之间的相关性。具体来说,SOCA模块包括以下几个步骤:
- 将输入特征张量进行拉平操作,得到一个二维矩阵。
- 对该矩阵进行归一化处理,使得每个通道的均值为0、标准差为1。
- 计算该矩阵的协方差矩阵,并对其进行特征分解,得到其特征向量和特征值。
- 使用特征向量来构造一个注意力向量,并对输入特征张量进行加权平均。
- 将加权平均后的特征张量与原始特征张量相加,得到最终的输出特征张量。
2、实验结果
5.6×105次迭代中在Set5(4×)上的最佳PSNR(分贝)值。
3、应用示例
下面是使用SOCA模块的应用示例:
import torch.nn as nn
import torch.nn.functional as F
class SOCA(nn.Module):
def __init__(self, in_channels):
super(SOCA, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=False)
self.bn = nn.BatchNorm2d(in_channels)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.conv(y.view(b, c, 1, 1)).view(b, c, 1, 1)
y = self.bn(y)
y = self.sigmoid(y)
return x * y
该代码定义了一个SOCA(Second-Order Channel Attention)模块的类,包括以下几个部分:
__init__(self, in_channels)
:构造函数,接受输入通道数in_channels
作为参数。avg_pool = nn.AdaptiveAvgPool2d(1)
:创建一个自适应平均池化层,将输入特征张量缩小到 1x1 的大小。conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=False)
:创建一个 1x1 的卷积层,用于降低通道维度。bn = nn.BatchNorm2d(in_channels)
:创建一个批归一化层,对输出进行标准化处理。sigmoid = nn.Sigmoid()
:创建一个 Sigmoid 激活函数,用于生成注意力权重。forward(self, x)
:前向传播函数,接受输入特征张量x
作为参数,并返回加权后的特征张量。
在前向传播函数中,首先对输入特征张量进行自适应平均池化操作,然后通过卷积和批归一化层来降低通道维度,并使用 Sigmoid 激活函数生成注意力权重。最后将输入特征张量与注意力权重相乘,得到加权后的特征张量作为输出。
三、SimAM模块
1、SimAM模块的原理
SimAM模块是一种用于图像分类任务的自适应注意力机制。它的核心思想是利用相似性信息来调整每个通道的注意力权重,从而提升图像分类的性能。
SimAM模块首先通过两个全局池化操作获取特征图中每个通道的全局平均值和标准差。对于每个通道,SimAM将其与其他通道之间的相似性定义为该通道与其他通道的余弦相似度,并将这些相似性作为一个矩阵输入到一个子网络中。该子网络使用多层感知器(MLP)来学习如何将相似性信息转换为注意力权重。SimAM根据学习到的注意力权重对输入的特征进行加权求和,得到了调整后的特征表示。
SimAM模块的优点在于它能够灵活地适应不同的数据分布,从而提高图像分类的泛化性能。此外,由于它只依赖于全局平均值和标准差,SimAM的计算成本比较低,适合大规模图像分类任务。
2、实验结果
在ImageNet-1K上,ResNet-50有和没有我们的模块的训练曲线比较。左侧和右侧分别显示了两个网络的Top-1(%)和Top-5(%)准确性。可以看出,将SimAM集成到ResNet-50中,在训练和验证中都比基线模型效果更好。
使用经过训练的带有SimAM的ResNet-50的特征激活可视化。对于每个图像,从左到右的地图是SimAM之前和之后的特征GradCAM,注意力权重和注意力权重的GradCAM。注意力映射是通过沿通道维度平均3-D注意力权重获得的。
3、应用示例
下面是使用SimAM模块在YOLOv5中进行目标检测的应用示例:
(1)在YOLOv5模型中导入SimAM模块:
# 定义SimAM模块
class SimAM(nn.Module):
def __init__(self, out_channels, groups=8):
super().__init__()
self.groups = groups
self.out_channels = out_channels
self.conv_theta = nn.Conv2d(out_channels, out_channels // groups, kernel_size=1, bias=False)
self.conv_phi = nn.Conv2d(out_channels, out_channels // groups, kernel_size=1, bias=False)
self.conv_g = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
self.conv_attn = nn.Conv2d(out_channels // groups, out_channels, kernel_size=1, bias=False)
def forward(self, x):
# 获取输入张量的高度、宽度和通道数
b, c, h, w = x.shape
# 计算query、key、value
theta = self.conv_theta(x).view(b, self.groups, -1, h * w).permute(0, 1, 3, 2) # b, g, hw, c//g
phi = self.conv_phi(x).view(b, self.groups, -1, h * w) # b, g, c//g, hw
g = self.conv_g(x).view(b, self.out_channels, -1).permute(0, 2, 1) # b, hw, c
# 计算相似度矩阵
attn = torch.matmul(theta, phi) / math.sqrt(c // self.groups)
attn = F.softmax(attn, dim=-1)
# 计算加权后的value
attn_g = torch.matmul(attn, g).permute(0, 2, 1).reshape(b, self.out_channels // self.groups, h, w)
attn_g = self.conv_attn(attn_g)
# 输出结果
return x + attn_g
以上代码是SimAM(Similarity Attention Module)模块的定义及其前向传播过程。代码说明如下:
- 初始化函数中,out_channels表示输入张量的通道数,groups表示将输入张量的通道数分成几组进行相似度计算。
- conv_theta、conv_phi、conv_g和conv_attn分别表示用于计算query、key、value和加权后的value的卷积层。
- forward函数中,首先获取输入张量的高度、宽度和通道数,然后按照SimAM的原理,计算出query、key、value,并通过相似度矩阵计算出加权后的value。
- 最后将加权后的value与输入张量相加作为输出结果。
(2)在YOLOv5模型中使用SimAM模块:
# 定义YOLOv5模型
class YOLOv5(nn.Module):
def __init__(self, num_classes=80):
super().__init__()
self.num_classes = num_classes
# 省略其余部分...
# 添加SimAM模块
self.sa1 = SimAM(256)
self.sa2 = SimAM(512)
self.sa3 = SimAM(1024)
def forward(self, x):
# 省略起始部分...
# 使用SimAM模块提取特征
x = self.sa1(x)
x = self.conv3(x)
x = self.sa2(x)
x = self.conv4(x)
x = self.sa3(x)
x = self.conv5(x)
# 省略其余部分...
# 输出结果
return x
代码说明如下:
- 在初始化函数中,num_classes表示模型需要识别的物体类别数。
- 添加了三个SimAM模块,分别用于提取不同层级的特征,即sa1对应backbone中的C3层,sa2对应C4层,sa3对应C5层。
- 在forward函数中,首先省略掉起始部分(包括输入张量的大小调整、Backbone和FPN网络),然后使用SimAM模块提取特征,并通过卷积层预测目标框、置信度和类别信息。
- 最后输出预测结果。
参考论文:
- https://openaccess.thecvf.com/content_CVPR_2019/papers/Dai_Second-Order_Attention_Network_for_Single_Image_Super-Resolution_CVPR_2019_paper.pdf
- http://proceedings.mlr.press/v139/yang21o/yang21o.pdf
🏆本文收录于,目标检测YOLO改进指南。
本专栏均为全网独家首发,🚀内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。
🏆华为OD机试(JAVA)真题(A卷+B卷)
每一题都有详细的答题思路、详细的代码注释、样例测试,订阅后,专栏内的文章都可看,可加入华为OD刷题群(私信即可),发现新题目,随时更新,全天CSDN在线答疑。
🏆哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师。
🏆往期回顾:
YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块
YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块
YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块
YOLOv5/v7 添加注意力机制,30多种模块分析④,CA模块,ECA模块