题目:Fast Vision Transformers with HiLo Attention
论文地址: https://arxiv.org/abs/2205.13213
创新点
-
HiLo自注意力机制:作者提出了一种新的自注意力机制,称为HiLo注意力,旨在同时捕捉图像中的高频和低频信息。该方法通过将自注意力分为两个分支,高频分支(Hi-Fi)处理局部的高分辨率细节,低频分支(Lo-Fi)处理全局的低分辨率结构。这样可以提高计算效率,特别是在高分辨率图像上,同时保持准确性。
-
LITv2模型:基于HiLo注意力机制,文献引入了LITv2模型,该模型在多个主流计算机视觉任务(如图像分类、物体检测和语义分割)上表现优越。LITv2通过在早期阶段删除多头自注意力(MSA)层,并在后期阶段使用高效的HiLo注意力机制,提升了模型的速度和内存效率。
-
速度优化:作者通过实际平台上的速度评估(而非通常的FLOPs计算)设计了该模型,以确保其在GPU和CPU上的实际速度更快。例如,HiLo机制在CPU上比局部窗口注意力机制快1.6倍,比空间缩减注意力机制快1.4倍。
-
相对位置编码优化:文献还对相对位置编码进行了优化,采用了3×3的深度卷积层代替传统的固定相对位置编码,这大大加快了密集预测任务(如分割)的训练和推理速度。
方法
整体结构
LITv2模型基于HiLo注意力机制,分离处理高频和低频信息,通过局部窗口自注意力捕捉细节、高效全局注意力处理全局结构。此外,模型采用3×3深度卷积层替代位置编码,减少计算复杂度并扩大感受野。整体架构分为多阶段,生成金字塔特征图,适用于密集预测任务,结合残差连接和全局自注意力确保性能与效率的平衡。
-
Patch Embedding层:模型首先将输入图像切分为固定大小的图像块(patch),然后通过线性变换将每个patch映射到一个高维特征空间,这与大多数Vision Transformer类似。
-
HiLo注意力机制:这是模型的核心创新点。HiLo注意力机制将多头自注意力(MSA)分成两个部分:
-
高频(Hi-Fi)注意力:处理局部的高频细节信息,使用的是局部窗口自注意力(例如2×2窗口),能够高效捕获图像中的细节信息。
-
低频(Lo-Fi)注意力:处理全局的低频信息,先通过平均池化获得低频特征,再进行全局自注意力计算,从而减少计算复杂度。
-
深度卷积层(Depthwise Convolution Layer):为了进一步提高效率,LITv2引入了3×3的深度卷积层用于代替传统的多层感知机(MLP)中的位置编码。这种设计不仅减少了位置编码的计算负担,还扩大了早期阶段特征的感受野。
-
多阶段结构:模型通常分为多个阶段(例如4个阶段),在每个阶段生成金字塔结构的特征图(pyramid feature maps),用于处理不同分辨率的特征。这使得模型在图像分类之外的密集预测任务(如物体检测和语义分割)中更具优势。
-
残差连接和归一化:在每个Transformer模块中,模型使用标准的残差连接和LayerNorm层。这些是标准的ViT组件,用于稳定训练并保持特征的传递。
-
后期的全局自注意力:在模型的后期阶段,虽然早期阶段使用了高效的局部自注意力和低频注意力机制,但后期阶段会使用标准的多头自注意力机制来处理下采样后的低分辨率特征图,以进一步提升性能。
即插即用模块
将HiLo注意力机制提取为即插即用模块,主要适用于以下场景:
-
高分辨率图像处理:在需要处理高分辨率图像的任务中,例如图像分类、目标检测、语义分割等,HiLo通过高效分离高频和低频信息,显著减少计算复杂度和内存占用,提升推理速度和处理能力。
-
低延迟应用场景:HiLo能够在实际硬件平台(如GPU和CPU)上加快推理速度,特别适用于需要低延迟的场景,例如无人机图像处理、自动驾驶中的实时感知系统等。
-
视觉任务中的密集预测:在需要对每个像素进行精细预测的任务中,如语义分割和实例分割,HiLo能够高效处理局部细节和全局结构,提升预测的准确性和速度。
消融实验
-
该表展示了LITv1-S模型在引入不同结构修改后的性能变化,包括加入3×3深度卷积层(ConvFFN)、去除相对位置编码(RPE)、以及使用HiLo注意力机制后的影响。
-
结果表明:引入深度卷积层后,模型在ImageNet分类和COCO检测任务中的性能提升显著,移除RPE后虽然有轻微的性能下降,但推理速度(FPS)显著提升,使用HiLo注意力机制后进一步提升了模型效率,特别是在FLOPs和推理速度上。
-
该图展示了HiLo注意力机制中高频和低频头部分配比例(α)的影响。随着α值的增加(更多头部用于低频注意力),FLOPs逐渐减少,模型的Top-1准确率在α=0.9时达到最佳。
-
该实验表明高频和低频信息在自注意力中的合理分配对模型效率和性能有重要影响。
-
该图通过Fast Fourier Transform(FFT)可视化了Hi-Fi和Lo-Fi注意力机制输出特征中的频率成分。结果显示,Hi-Fi注意力捕捉更多的高频信息,而Lo-Fi主要关注低频信息。
-
该实验验证了HiLo注意力机制在分离高频和低频特征时的有效性,符合文献提出的设计理念。
即插即用模块HiLo
import math
import torch
import torch.nn as nn
# 论文:Fast Vision Transformers with HiLo Attention
# 论文地址:https://arxiv.org/abs/2205.13213
class HiLo(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2, alpha=0.5):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
head_dim = int(dim/num_heads)
self.dim = dim
# self-attention heads in Lo-Fi
self.l_heads = int(num_heads * alpha)
# token dimension in Lo-Fi
self.l_dim = self.l_heads * head_dim
# self-attention heads in Hi-Fi
self.h_heads = num_heads - self.l_heads
# token dimension in Hi-Fi
self.h_dim = self.h_heads * head_dim
# local window size. The `s` in our paper.
self.ws = window_size
if self.ws == 1:
# ws == 1 is equal to a standard multi-head self-attention
self.h_heads = 0
self.h_dim = 0
self.l_heads = num_heads
self.l_dim = dim
self.scale = qk_scale or head_dim ** -0.5
# Low frequence attention (Lo-Fi)
if self.l_heads > 0:
if self.ws != 1:
self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
self.l_proj = nn.Linear(self.l_dim, self.l_dim)
# High frequence attention (Hi-Fi)
if self.h_heads > 0:
self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
self.h_proj = nn.Linear(self.h_dim, self.h_dim)
def hifi(self, x):
B, H, W, C = x.shape
h_group, w_group = H // self.ws, W // self.ws
total_groups = h_group * w_group
x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)
qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2] # B, hw, n_head, ws*ws, head_dim
attn = (q @ k.transpose(-2, -1)) * self.scale # B, hw, n_head, ws*ws, ws*ws
attn = attn.softmax(dim=-1)
attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)
x = self.h_proj(x)
return x
def lofi(self, x):
B, H, W, C = x.shape
q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)
if self.ws > 1:
x_ = x.permute(0, 3, 1, 2)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
x = self.l_proj(x)
return x
def forward(self, x, H, W):
B, N, C = x.shape
x = x.reshape(B, H, W, C)
if self.h_heads == 0:
x = self.lofi(x)
return x.reshape(B, N, C)
if self.l_heads == 0:
x = self.hifi(x)
return x.reshape(B, N, C)
hifi_out = self.hifi(x)
lofi_out = self.lofi(x)
x = torch.cat((hifi_out, lofi_out), dim=-1)
x = x.reshape(B, N, C)
return x
def flops(self, H, W):
# pad the feature map when the height and width cannot be divided by window size
Hp = self.ws * math.ceil(H / self.ws)
Wp = self.ws * math.ceil(W / self.ws)
Np = Hp * Wp
# For Hi-Fi
# qkv
hifi_flops = Np * self.dim * self.h_dim * 3
nW = (Hp // self.ws) * (Wp // self.ws)
window_len = self.ws * self.ws
# q @ k and attn @ v
window_flops = window_len * window_len * self.h_dim * 2
hifi_flops += nW * window_flops
# projection
hifi_flops += Np * self.h_dim * self.h_dim
# for Lo-Fi
# q
lofi_flops = Np * self.dim * self.l_dim
kv_len = (Hp // self.ws) * (Wp // self.ws)
# k, v
lofi_flops += kv_len * self.dim * self.l_dim * 2
# q @ k and attn @ v
lofi_flops += Np * self.l_dim * kv_len * 2
# projection
lofi_flops += Np * self.l_dim * self.l_dim
return hifi_flops + lofi_flops
if __name__ == '__main__':
block = HiLo(dim=128)
input = torch.rand(32, 128, 128) # input with shape (B, N, C)
output = block(input, 16, 8) # H = 16, W = 8, since H * W should equal N
print(input.size())
print(output.size())