1. 基本信息
- 论文标题: 《Rethinking Local Perception in Lightweight Vision Transformer》
- 中文标题: 《重新思考轻量化视觉Transformer中的局部感知》
- 作者单位: 清华大学
- 发表时间: 2023
- 论文地址: https://arxiv.org/abs/2303.17803
- 代码地址: https://github.com/qhfan/CloFormer
2. 应用场景
- 图像分类、目标检测、语义分割等领域。
3. 研究背景
- 现阶段,
Transformer
在图像分类、目标检测、语义分割等领域表现出优异的性能。然而Transformer
参数量和计算量太大,不适合部署到移动设备。 - 在现有的轻量级
Transformer
模型中,大多数方法只注重设计稀疏注意力以有效处理低频全局信息,而处理高频局部信息的方法相对简单。
4. 方法概述
为了同时利用共享权重和上下文感知权重的优势,提出了CloFormer
,这是一种具有上下文感知局部增强功能的轻量级视觉转换器,具体贡献如下:
- 在
CloFormer
中,引入了一种名为AttnConv
的卷积算子,它采用注意力机制,充分利用共享权重和上下文感知权重的优势来实现局部感知。 此外,它使用了一种新方法,该方法结合了比普通局部自注意力更强的非线性来生成上下文感知权重。 - 在
CloFormer
中,采用双分支架构,其中一个分支使用AttnConv
捕获高频信息,而另一个分支使用带有下采样的普通注意力捕获低频信息。 双分支结构使CloFormer
能够同时捕获高频和低频信息。 - 该方法在图像分类、目标检测和语义分割方面的广泛实验证明了
CloFormer
的优越性。CloFormer
在ImageNet1k
上仅用4.2M
参数和0.6G FLOP
就实现了77.0%
的准确率,明显优于其他模型。
4.1 整体网络结构
如上图所示,CloFormer包含一个卷积主干和四个阶段。每个阶段由Clo block和ConvFFN组成, 先通过卷积主干传递输入图像以获得tokens。 该系统由四个卷积组成,每个卷积的步幅分别为2、2、1和1。 随后,标记经过四个阶段的Clo block和ConvFFN来提取层次特征。 最后,利用全局平均池化和全连接层来生成预测。
为了将局部信息整合到FFN过程中,用ConvFFN取代了传统的FFN。 ConvFFN和常用的FFN之间的主要区别是,ConvFFN在GELU激活后使用深度卷积(DWconv),这使得ConvFFN能够聚合局部信息。 由于DWconv,下行采样可以直接在ConvFFN中执行,而无需引入PatchMerge模块。 CloFormer使用了两种类型的ConvFFN。 第一种是级内ConvFFN,它直接利用跳过连接。 另一个是连接两个阶段的ConvFFN。 在这种类型的ConvFFN的跳过连接中,使用DWconv和全连接层分别对输入进行下采样和上维。
每个块由一个本地分支和一个全局分支组成。 在全局分支中,首先对K和V进行下采样,然后对Q、K和V进行标准attention处理,提取低频全局信息。
4.2 AttnConv模块
全局分支的模式有效地减少了需要注意的flop的数量,也产生了一个全局接受野。 然而,它在有效捕获低频全局信息的同时,对高频局部信息的处理能力不足。
在AttnConv中,首先应用线性变换得到Q,K, V,这与标准注意力相同,在进行线性变换后,首先对V进行共享权值的局部特征聚合处理,然后基于处理后的V和Q, K进行上下文感知的局部增强。具体分为为如下三个步骤:
使用一个简单的深度卷积(DWconv)来对 V 进行局部信息聚合。
使用两个DWconv分别聚合Q和K的本地信息。 然后,计算Q和K的Hadamard积,并对结果进行一系列变换,以获得−1到1之间的上下文感知权重。 最后,利用生成的权值对局部特征进行增强。
使用简单的方法将局部分支的输出与全局分支的输出融合。
4.3 代码
可以将
Clo block
当作注意力机制使用,具体代码如下:
import torch
import torch.nn as nn
from efficientnet_pytorch.model import MemoryEfficientSwish # 从 EfficientNet 的库中引入高效激活函数 Swish
class AttnMap(nn.Module):
def __init__(self, dim):
super().__init__()
# 定义一个包含两层卷积和激活函数的块,用于生成注意力映射
self.act_block = nn.Sequential(
nn.Conv2d(dim, dim, 1, 1, 0), # 1x1 卷积,保持通道数不变
MemoryEfficientSwish(), # Swish 激活函数
nn.Conv2d(dim, dim, 1, 1, 0) # 再次使用 1x1 卷积
)
def forward(self, x):
return self.act_block(x) # 前向传播,返回处理后的张量
class CloAttention(nn.Module):
def __init__(self, dim, num_heads=8, group_split=[4, 4], kernel_sizes=[5], window_size=4,
attn_drop=0., proj_drop=0., qkv_bias=True):
super().__init__()
# 参数初始化和断言检查
assert sum(group_split) == num_heads # 确保分组的头总数等于注意力头总数
assert len(kernel_sizes) + 1 == len(group_split) # 核大小和分组数一致
self.dim = dim # 输入通道数
self.num_heads = num_heads # 总的多头注意力头数
self.dim_head = dim // num_heads # 每个头的通道数
self.scalor = self.dim_head ** -0.5 # 注意力缩放因子
self.kernel_sizes = kernel_sizes # 高频分支的卷积核大小
self.window_size = window_size # 低频分支窗口大小
self.group_split = group_split # 每个分支分配的头数
# 创建高频和低频分支的模块
convs = [] # 高频卷积
act_blocks = [] # 高频注意力模块
qkvs = [] # 高频分支的 QKV 卷积
for i in range(len(kernel_sizes)):
kernel_size = kernel_sizes[i]
group_head = group_split[i]
if group_head == 0:
continue # 如果分组头数为 0,跳过此分支
convs.append(nn.Conv2d(3 * self.dim_head * group_head, 3 * self.dim_head * group_head, kernel_size,
1, kernel_size // 2, groups=3 * self.dim_head * group_head)) # 高频卷积
act_blocks.append(AttnMap(self.dim_head * group_head)) # 注意力映射模块
qkvs.append(nn.Conv2d(dim, 3 * group_head * self.dim_head, 1, 1, 0, bias=qkv_bias)) # QKV 卷积
# 定义低频全局注意力分支
if group_split[-1] != 0:
self.global_q = nn.Conv2d(dim, group_split[-1] * self.dim_head, 1, 1, 0, bias=qkv_bias) # Q 卷积
self.global_kv = nn.Conv2d(dim, group_split[-1] * self.dim_head * 2, 1, 1, 0, bias=qkv_bias) # KV 卷积
self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity() # 平均池化
# 将模块添加到 ModuleList 中
self.convs = nn.ModuleList(convs)
self.act_blocks = nn.ModuleList(act_blocks)
self.qkvs = nn.ModuleList(qkvs)
self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias) # 投影层
self.attn_drop = nn.Dropout(attn_drop) # 注意力权重的 dropout self.proj_drop = nn.Dropout(proj_drop) # 输出的 dropout
def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):
'''
高频分支的注意力计算
x: (b c h w) 输入特征
''' b, c, h, w = x.size()
qkv = to_qkv(x) # 计算 QKV,得到 (b, 3*m*d, h, w) qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous() # 混合后得到 (3, b, m*d, h, w) q, k, v = qkv # 分解为 Q、K、V
attn = attn_block(q.mul(k)).mul(self.scalor) # 计算缩放后的注意力
attn = self.attn_drop(torch.tanh(attn)) # 使用 tanh 激活并应用 dropout res = attn.mul(v) # 应用注意力权重到 V return res
def low_fre_attention(self, x: torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):
'''
低频分支的注意力计算
x: (b c h w) 输入特征
''' b, c, h, w = x.size()
q = to_q(x).reshape(b, -1, self.dim_head, h * w).transpose(-1, -2).contiguous() # 计算 Q 并调整形状为 (b, m, h*w, d) kv = avgpool(x) # 对输入特征进行平均池化
kv = to_kv(kv).view(b, 2, -1, self.dim_head, (h * w) // (self.window_size ** 2)).permute(1, 0, 2, 4, 3).contiguous() # 计算 KV k, v = kv # 分解为 K、V
attn = self.scalor * q @ k.transpose(-1, -2) # 计算缩放后的注意力
attn = self.attn_drop(attn.softmax(dim=-1)) # 对注意力进行 softmax 和 dropout res = attn @ v # 应用注意力权重到 V res = res.transpose(2, 3).reshape(b, -1, h, w).contiguous() # 调整形状为原始形状
return res
def forward(self, x: torch.Tensor):
'''
x: (b c h w) 输入特征
''' res = [] # 保存各分支的输出
for i in range(len(self.kernel_sizes)):
if self.group_split[i] == 0:
continue
res.append(self.high_fre_attntion(x, self.qkvs[i], self.convs[i], self.act_blocks[i])) # 高频分支输出
if self.group_split[-1] != 0:
res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool)) # 低频分支输出
return self.proj_drop(self.proj(torch.cat(res, dim=1))) # 合并分支输出并应用投影
# 输入 N C HW, 输出 N C H W
if __name__ == '__main__':
block = CloAttention(64).cuda() # 初始化 CloAttention 模块
input = torch.rand(1, 64, 64, 64).cuda() # 创建一个随机输入
output = block(input) # 前向传播
print(f"Input_Size:{input.size()}\nOutput_Size:{output.size()}") # 打印输入和输出的张量形状
5. 结果
表中报告了ImageNet1K分类结果。 结果表明,当模型大小和FLOPs相似时,模型比以前的模型性能更好。 其中,CloFormer-XXS仅使用4.2万个参数和0.6G FLOPs, Top-1准确率达到77.0%,分别超过ShuffleNetV22x、MobileViT-XS和EdgeViT-XXS 1.6%、2.2%和2.6%