题目:Salient Positions based Attention Network for Image Classification
论文地址:https://arxiv.org/pdf/2106.04996
创新点
-
提出了基于显著位置的注意力机制:论文提出了一种名为SPAblock的显著位置选择算法(SPS),通过在注意力计算中仅选择显著位置,减少了计算复杂度和内存需求,同时提取了对图像分类有用的上下文信息。这种方法有效减少了对非关键信息的处理,特别是在图像背景复杂的情况下能更好地避免噪声干扰。
-
采用了通道维度的聚合:与传统的非局部模块相比,SPAblock在通道维度上对信息进行聚合,而不是空间维度。这种聚合方式在减少计算资源的同时,能够更有效地提取特征信息,提高图像分类的准确性。
-
引入了适用于特征图的显著性度量:针对特征图的高维特性,论文设计了一种基于平方和的显著性度量方法,利用其近似高斯分布的特点,通过平方和来选择显著位置。这种方法适用于神经网络生成的高维特征图,而非传统的视觉图像。
-
在低层网络中取得更好效果:实验表明,SPAblock在低层网络中表现更佳,尤其在CIFAR和Tiny-ImageNet等数据集上优于传统的非局部模块,并显著减少了内存使用。这种设计更适合于低层网络的特性,能够在低层网络上更好地进行上下文建模。
方法
整体结构
这篇论文提出了SPAblock模型结构,其核心是基于显著位置的注意力机制(SPS算法),在输入特征中选择少量显著位置进行注意力计算,从而减少计算量和内存占用。模型首先生成查询和数值矩阵,通过SPS算法筛选显著位置,计算注意力矩阵后将数值矩阵的上下文信息聚合到输出特征中,并通过1×11 \times 1 卷积更新特征,最终与输入特征相加形成输出。该设计在图像分类任务中提升了精度,特别适合应用在网络的低层次。
-
输入特征生成查询和数值矩阵:特征图首先经过两个二维卷积层,分别生成查询矩阵QQ 和数值矩阵VV。这样可以将输入特征图转化为适合注意力计算的形式。
-
显著位置选择(Salient Positions Selection, SPS)算法:SPS算法根据查询矩阵的平方和选择出前kk 个显著位置。具体来说,SPS算法先计算查询矩阵各通道的平方和,并对每个通道进行求和,再根据该值选择显著位置。这一步骤减少了关注位置的数量,从而降低了计算复杂度。
-
计算注意力矩阵:利用SPS选出的显著位置构建注意力矩阵AA,并将其进行softmax归一化。此过程相当于计算查询和键的相似度,但仅限于显著位置,节省了大量计算资源。
-
特征聚合与更新:使用数值矩阵VV和注意力矩阵AA 进行矩阵乘法,将结果重新整形为与输入相同的尺寸。然后通过一个1×11 \times 1 卷积进行变换,并将结果与输入特征相加,以形成输出特征。
-
逐层级应用的灵活性:SPAblock可以插入到ResNet等深度网络的不同层级,尤其是在低层级的效果尤为显著。这是因为低层特征图通常包含更多空间细节,而SPAblock能够有效地提取其中的显著信息。
即插即用模块作用
SPAblock 作为一个即插即用模块,主要适用于:
-
图像分类:在分类任务中增强对重要特征的关注,忽略无关背景。
-
目标检测:提高对目标区域的聚焦,减少背景噪声的影响。
-
实时应用:在资源受限的环境中,如移动设备或嵌入式系统中,用于减少计算量和内存需求。
-
深度网络的低层或中层:在特征图信息丰富的低层或中层加入SPAblock,可以更有效地提取关键细节。
消融实验结果
该表比较了SPAblock在ResNet不同层级(从第1到第4层)加入时的性能。结果表明,在低层(尤其是第1层和第2层)加入SPAblock能显著提升分类精度,而在第4层的效果较弱。这说明SPAblock更适合用于低层次的网络结构,因为在低层特征图中显著区域的信息更为丰富,有助于提升模型性能。
该表比较了SPAblock在ResNet不同层级(从第1到第4层)加入时的性能。结果表明,在低层(尤其是第1层和第2层)加入SPAblock能显著提升分类精度,而在第4层的效果较弱。这说明SPAblock更适合用于低层次的网络结构,因为在低层特征图中显著区域的信息更为丰富,有助于提升模型性能。
即插即用模块
import torch
from torch import nn
class SPABlock(nn.Module):
def __init__(self, in_channels, k=784, adaptive = False, reduction=16, learning=False, mode='pow'):
super(SPABlock, self).__init__()
self.in_channels = in_channels
self.reduction = reduction
self.k = k
self.adptive = adaptive
self.reduction = reduction
self.learing = learning
if self.learing is True:
self.k = nn.Parameter(torch.tensor(self.k))
self.mode = mode
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
def forward(self, x, return_info=False):
input_shape = x.shape
if len(input_shape)==4:
x = x.view(x.size(0), self.in_channels, -1)
x = x.permute(0, 2, 1)
batch_size,N = x.size(0),x.size(1)
#(B, H*W,C)
if self.mode == 'pow':
x_pow = torch.pow(x,2)# (batchsize,H*W,channel)
x_powsum = torch.sum(x_pow,dim=2)# (batchsize,H*W)
if self.adptive is True:
self.k = N//self.reduction
if self.k == 0:
self.k = 1
outvalue, outindices = x_powsum.topk(k=self.k, dim=-1, largest=True, sorted=True)
outindices = outindices.unsqueeze(2).expand(batch_size, self.k, x.size(2))
out = x.gather(dim=1, index=outindices).to(self.device)
if return_info is True:
return out, outindices, outvalue
else:
return out
if __name__ == '__main__':
block = SPABlock(in_channels=128)
input = torch.rand(32, 784, 128)
output = block(input)
print(input.size()) print(output.size())