摘要
https://arxiv.org/pdf/2408.03703
视觉转换器(Vision Transformers,ViTs)以其标记混合器强大的全局上下文能力,在神经网络领域取得了革命性的进展。然而,尽管以往的工作已做出相当大的努力,但成对标记亲和力和复杂的矩阵运算限制了其在资源受限场景和实时应用(如移动设备)中的部署。在本文中,我们介绍了CAS-ViT:卷积加性自注意力视觉转换器,以在移动应用中实现效率和性能之间的平衡。首先,我们认为,标记混合器获取全局上下文信息的能力依赖于多个信息交互,如空间和通道域。随后,我们根据这一范式构建了一个新颖的加性相似度函数,并提出了一种称为卷积加性标记混合器(Convolutional Additive Token Mixer,CATM)的高效实现方法。这种简化显著降低了计算开销。我们在多种视觉任务上对CAS-ViT进行了评估,包括图像分类、目标检测、实例分割和语义分割。我们在GPU、ONNX和iPhone上进行的实验表明,与其他最先进的骨干网络相比,CAS-ViT取得了具有竞争力的性能,证明了它是高效移动视觉应用的可行选择。我们的代码和模型可在以下网址获取:https://github.com/Tianfang-Zhang/CAS-ViT
引言
近年来,视觉转换器(Vision Transformers,ViTs)的出现标志着神经网络架构的革命性转变(Han等人,2021;Huang等人,2023;Zhang等人,2023)。与以较低计算复杂度和较高推理效率著称的卷积神经网络(CNNs)(Zhang等人,2022;Zhang、Li和Peng,2023)相比,ViTs采用了一种名为标记混合器(token mixer)的新型架构(Yu等人,2022)。该模块通过捕获长距离依赖性的能力,促进了ViTs的全局建模和表示能力的提升。
ViT的基本模块包括标记混合器、多层感知机(MLP)以及相应的跳跃连接,其中标记混合器广泛实现为多头自注意力(MSA)。MSA对整个输入序列进行操作,弥补了具有受限感受野的CNNs的局限性,并在模型可扩展性和适应性方面提供了独特优势(Li等人,2023;Wu等人,2024)。尽管MSA备受赞誉,但现实情况是,其相对于输入图像大小的二次复杂性(Wu等人,2021)使得ViT模型更加资源密集,不适合实时应用,也不利于在资源受限的设备(如移动应用)上广泛部署(Mehta和Rastegari,2021;Liu等人,2023)。因此,开发既能高效部署又具有高性能的标记混合器,成为移动设备上的紧迫问题。
已经有许多努力致力于改进标记混合器,包括MSA的改进和异构MSA(H-MSA)。改进主要集中在Query和Key的矩阵乘法之后,以增强捕获长距离依赖性的能力,同时降低算法复杂度。具体技术包括特征偏移(Liu等人,2021)、载体标记(Hatamizadeh等人,2023)、稀疏注意力(Wang等人,2022)、线性注意力(Han等人,2023)等。作为一种替代范式,H-MSA旨在突破 Q \mathbf{Q} Q和 K \mathbf{K} K之间点对点相似性的限制,以探索更灵活的网络设计(Yu等人,2022)。最近提出的池化标记混合器(Yu等人,2022)和上下文向量(Sandler等人,2018;Shaker等人,2023)进一步提高了推理效率。
尽管相关工作已经取得了显著进展,但ViT模型仍然受到以下限制:1. 标记混合器中的Softmax等矩阵运算复杂度较高。2. 在移动设备或实时应用中同时实现准确性、效率和易于部署的难度较大。
为了解决遇到的问题,我们提出了一系列名为卷积加性自注意力(Convolutional Additive Self-attention,CAS)-ViT的轻量级网络,如图1所示,以实现计算与效率的平衡。首先,我们认为标记混合器获取全局上下文信息的能力依赖于多种信息交互,如空间和通道域。同时,我们根据这一范式构建了一种新颖的加性相似度函数,并期望能激发更多有价值的研究。此外,我们提出了卷积加性标记混合器(Convolutional Additive Token Mixer,CATM),采用潜在的空间和通道注意力作为新颖的交互形式。该模块消除了诸如矩阵乘法和Softmax等繁琐的复杂操作。最后,我们在各种视觉任务上评估了我们的方法,并在GPU、ONNX和iPhone上报告了吞吐量。大量实验表明,与其他最先进(State-Of-The-Art,SOTA)骨干网络相比,我们的性能具有竞争力。
相关工作
高效视觉转换器
自ViT(Dosovitskiy等人,2020)问世以来,并在ImageNet(Deng等人,2009)等大型数据集上的图像分类任务中成功验证后,它展示了自注意力机制在计算机视觉应用中的潜力(Dong等人,2022)。然而,网络规模的同时增加对资源受限的场景(如移动设备和实时应用)构成了巨大挑战。为了提升ViT的潜在价值,研究人员在高效ViT方面投入了大量精力。
ViT架构的改进思路涵盖了多个方面。其中之一是优化标记混合器以增强性能或解决自注意力机制的二次复杂度问题。例如,PVT(Wang等人,2021)采用空间降维策略来实现稀疏注意力,以处理高分辨率图像;Swin(Liu等人,2021)则采用窗口分割方法来实现局部自注意力,并通过窗口移动来处理块之间的依赖关系。
另一个角度是探索结合CNN和Transformer的混合模型,以弥补自注意力机制在处理局部信息方面的局限性(Guo等人,2022)。EdgeViT(Pan等人,2022)分别在块内采用卷积层和稀疏注意力进行信息整合和传播。NextViT(Li等人,2022a)通过全面的实验验证了混合模型及其设计策略的有效性。EfficientViT(Liu等人,2023)则进一步分析了各种操作的时间消耗比例,以实现高效的推理。
高效标记混合器
高效标记混合器的设计是Transformer演进中的关键方向之一。人们致力于追求更轻量、计算效率更高的标记混合器,以提高训练和实践应用的可行性。
一部分努力致力于改进MSA。Twins(Chu等人,2021)引入块之间的自注意力,以实现全局和局部依赖的串联。一些工作(Ali等人,2021;Ding等人,2022;Maaz等人,2022)则专注于通道之间的信息整合,并通过拼接和转置来实现不同域之间的自注意力。线性注意力(Shen等人,2021)通过假设相似度函数是线性可微的,解决了自注意力机制的二次复杂度问题。Han等人(2023)进一步推广了基于ReLU的线性注意力,使其通过聚焦函数更加离散化。Bolya等人(2022)则通过增加注意力头来进一步简化注意力复杂度。
异构多头自注意力(Heteromorphic-MSA,简称H-MSA)是自注意力机制发展的一种扩展形式,它突破了多头自注意力(MSA)框架的限制,旨在获得更好的特征关系和更鲁棒的推理效率。最初,MetaFormer(Yu等人,2022)提出,标记混合器并不是影响Transformer性能的关键组件,然而PoolFormer并未被验证为高度高效。随后,MobileViTv2(Mehta和Rastegari,2022)通过为上下文向量赋予全局信息,简化了复杂的矩阵乘法。SwiftFormer(Shaker等人,2023)甚至去除了“值(Value)”分量,并使用更简单的归一化操作实现了特征之间的加权和,从而实现了更简洁、高效的H-MSA。
方法
在本节中,我们首先回顾了MSA及其变体的原理。随后,我们将介绍提出的CATM,并重点介绍其与传统机制的区别和优势。最后,我们将描述CAS-ViT的整体网络架构。
自注意力及其变体的概述
作为视觉Transformer的关键组件,自注意力机制可以有效地捕获不同位置之间的关系。给定一个输入 x ∈ R N × d \mathbf{x} \in \mathbb{R}^{N \times d} x∈RN×d,如图2(a)所示,其中包含 N N N个标记,每个头内部有 d d d维嵌入向量。自注意力可以通过相似度函数 Sim ( Q , K ) = exp ( Q K ⊤ / d ) \operatorname{Sim}(\mathbf{Q}, \mathbf{K})= \exp \left(\mathbf{Q K}^{\top} / \sqrt{d}\right) Sim(Q,K)=exp(QK⊤/d)表示如下:
O = Softmax ( Q K ⊤ d ) V \mathbf{O}=\operatorname{Softmax}\left(\frac{\mathbf{Q K}^{\top}}{\sqrt{d}}\right) \mathbf{V} O=Softmax(dQK⊤)V
可分离自注意力(Sandler等人,2018年),见图2(b),将基于矩阵的特征度量简化为向量,通过降低计算复杂度实现轻量级和高效的推理。随后,通过Softmax(-)计算上下文分数。然后,将上下文分数 q \mathbf{q} q与 K \mathbf{K} K相乘,并沿空间维度求和,以获得衡量全局信息的上下文向量。具体可以描述为:
O = ( ∑ i = 1 N Softmax ( q ) ⋅ K ) V \mathbf{O}=\left(\sum_{i=1}^{N} \operatorname{Softmax}(\mathbf{q}) \cdot \mathbf{K}\right) \mathbf{V} O=(i=1∑NSoftmax(q)⋅K)V
其中, q ∈ R N × 1 \mathbf{q} \in \mathbb{R}^{N \times 1} q∈RN×1是通过线性层从 Q \mathbf{Q} Q键中获得的,-表示逐元素广播乘法。
Swift自注意力,见图2©,是一种备受关注的HMSA架构,它将自注意力的键减少到两个,从而实现快速推理。它利用通过线性变换 Q \mathbf{Q} Q获得的系数 α \alpha α来对每个标记进行加权。随后,沿空间域求和并与 K \mathbf{K} K相乘以获得全局上下文。具体表示为:
O = T ( ( ∑ i = 1 N α i ⋅ Q i ) K ) + Q ^ \mathbf{O}=\mathrm{T}\left(\left(\sum_{i=1}^{N} \alpha_{i} \cdot \mathbf{Q}_{i}\right) \mathbf{K}\right)+\hat{\mathbf{Q}} O=T((i=1∑Nαi⋅Qi)K)+Q^
其中, Q ^ \hat{\mathbf{Q}} Q^表示归一化查询, T ( ⋅ ) \mathrm{T}(\cdot) T(⋅)表示线性变换。
卷积加性自注意力
在本节中,我们认为自注意力机制的信息整合能力在于多种信息交互,例如MSA的空间关联、DaViT(Ding等人,2022年)在通道上的交互,以及MobileViTv2和SwiftFormer在两个维度上的压缩表示,如图2所示。或者,是否存在简单且有效的操作能在满足多种交互的同时表现更好?
基于这一假设,如图2(d)所示,我们创新地将相似度函数定义为 Q ∈ R N × d \mathbf{Q} \in \mathbb{R}^{N \times d} Q∈RN×d和 K ∈ R N × d \mathbf{K} \in \mathbb{R}^{N \times d} K∈RN×d的上下文分数之和:
Sim ( Q , K ) = Φ ( Q ) + Φ ( K ) s.t. Φ ( Q ) = C ( S ( Q ) ) \operatorname{Sim}(\mathbf{Q}, \mathbf{K})=\Phi(\mathbf{Q})+\Phi(\mathbf{K}) \text { s.t. } \Phi(\mathbf{Q})=\mathcal{C}(\mathcal{S}(\mathbf{Q})) Sim(Q,K)=Φ(Q)+Φ(K) s.t. Φ(Q)=C(S(Q))
其中,Query、Key和Value是通过独立的线性变换获得的,例如 Q = W q x \mathbf{Q}=W_{q} \mathbf{x} Q=Wqx, K = W k x \mathbf{K}=W_{k} \mathbf{x} K=Wkx, V = W v x \mathbf{V}=W_{v} \mathbf{x} V=Wvx, Φ ( ⋅ ) \Phi(\cdot) Φ(⋅)表示上下文映射函数,它包含了基本的信息交互。这种泛化的优点是不受手动上下文设计的限制,并且有可能通过卷积操作来实现。在本文中,我们简单地将 Φ ( ⋅ ) \Phi(\cdot) Φ(⋅)具体化为基于Sigmoid的通道注意力 C ( ⋅ ) ∈ R N × d \mathcal{C}(\cdot) \in \mathbb{R}^{N \times d} C(⋅)∈RN×d和空间注意力 S ( ⋅ ) ∈ R N × d \mathcal{S}(\cdot) \in \mathbb{R}^{N \times d} S(⋅)∈RN×d。因此,CATM(卷积加性自注意力模块)的输出可以表示为:
O = Γ ( Φ ( Q ) + Φ ( K ) ) ⋅ V \mathbf{O}=\Gamma(\Phi(\mathbf{Q})+\Phi(\mathbf{K})) \cdot \mathbf{V} O=Γ(Φ(Q)+Φ(K))⋅V
其中, Γ ( ⋅ ) ∈ R N × d \Gamma(\cdot) \in \mathbb{R}^{N \times d} Γ(⋅)∈RN×d表示用于整合上下文信息的线性变换。由于CATM中的操作由卷积表示,因此其复杂度为 O ( N ) \mathcal{O}(N) O(N)。
与可分离自注意力的关系:与(Mehta和Rastegari 2022)仅在Query上提取上下文分数的方法相比,我们在Query和Key分支上都进行了相似度提取,并保留了每个分支上的原始特征维度。这有助于更好地保留视觉稀疏特征,并避免了在二维分数向量上的信息丢失。
与高效加性注意力的关系:首先,在令牌混合器上,我们采用了经Sigmoid激活的注意力提取形式,而不是归一化。这有助于网络的并行化和在移动设备上的部署。此外,(Shaker et al. 2023)中的注意力模块仅应用于网络每个阶段的最后一层,而提出的CATM(卷积加性自注意力模块)将应用于整个ViT(视觉Transformer)架构的每一层。
复杂度分析:在具体实现中, S ( ⋅ ) \mathcal{S}(\cdot) S(⋅)被设计为深度卷积和Sigmoid激活的组合, Ω ( S ) = ( 11 + 4 b ) H W C \Omega(\mathcal{S})=(11+4b)HWC Ω(S)=(11+4b)HWC,其中 b b b表示批量大小,在训练和推理阶段都保持不变。 C ( ⋅ ) \mathcal{C}(\cdot) C(⋅)则是通过简化的通道注意力实现的, Ω ( C ) = ( 2 + b ) H W C \Omega(\mathcal{C})=(2+b)HWC Ω(C)=(2+b)HWC。结合QKV映射和 Γ \Gamma Γ的线性变换等操作,CATM相对于输入大小保持了线性复杂度:
Ω ( C A T M ) = ( 47 + 10 b ) H W C \Omega(\mathrm{CATM})=(47+10b)HWC Ω(CATM)=(47+10b)HWC
网络架构
图3(上部)展示了提出的网络架构。输入一个大小为 H × W H \times W H×W的自然图像。随后,通过Stem中的两个步长为2的连续卷积层,将其下采样到 H 4 × W 4 × C 1 \frac{H}{4} \times \frac{W}{4} \times C_{1} 4H×4W×C1。
之后,它通过四个阶段编码层,每个阶段之间使用Patch Embedding进行两次下采样,并获得大小为
H
8
×
W
8
×
C
2
\frac{H}{8} \times \frac{W}{8} \times C_{2}
8H×8W×C2、
H
16
×
W
16
×
C
3
\frac{H}{16} \times \frac{W}{16} \times C_{3}
16H×16W×C3和
H
32
×
W
32
×
C
4
\frac{H}{32} \times \frac{W}{32} \times C_{4}
32H×32W×C4的特征图。
C
i
C_{i}
Ci,
i
∈
{
1
,
2
,
3
,
4
}
i \in\{1,2,3,4\}
i∈{1,2,3,4}表示特征图通道数。每个阶段包含
N
i
N_{i}
Ni个堆叠块,如图3(下部)所示,特征图大小保持不变。
块引用混合网络的设计,如EfficientViT(Liu et al. 2023)和EdgeViT(Pan et al. 2022),包含带有残差捷径的三个部分:集成子网、CATM和MLP。受SwiftFormer(Shaker et al. 2023)的启发,集成子网由三个经ReLU(Glorot, Bordes, and Bengio 2011)激活的深度卷积层组成。我们通过改变通道数 C i C_{i} Ci和块数 N i N_{i} Ni来构建一系列轻量级ViT模型,具体参数设置请参考附录。
实验
ImageNet-1K 分类
实现细节 ImageNet-1K(Deng et al. 2009)包含超过130万张图像,跨越1000个自然类别。该数据集覆盖了广泛的对象和场景,并因其多样性而成为使用最广泛的数据集之一。我们从零开始训练网络,没有使用任何预训练模型或额外数据。训练策略遵循EdgeNeXt(Maaz et al. 2022),所有模型均在输入大小为 224 × 224 224 \times 224 224×224的条件下,使用AdamW(Loshchilov and Hutter 2018)优化器进行300个周期的训练,批量大小为2048。学习率设置为 6 × 1 0 − 3 6 \times 10^{-3} 6×10−3,并采用带有20个周期预热的余弦(Loshchilov and Hutter 2016)衰减计划。启用了标签平滑0.1(Szegedy et al. 2016)、随机调整大小裁剪、水平翻转、RandAugment(Cubuk et al. 2020)、多尺度采样器(Mehta and Rastegari 2021),并在训练过程中将EMA(Polyak and Juditsky 1992)的动量设置为0.9995。为了充分利用网络的有效性,我们在 384 × 384 384 \times 384 384×384分辨率和批量大小为64的条件下,以 1 0 − 5 10^{-5} 10−5的学习率对模型进行了另外30个周期的微调。
我们在基于PyTorch(Paszke et al. 2019)的TIMM{ }^1上实现了分类模型,并在16个V100 GPU上运行。此外,我们将Torch模型编译为ONNX格式,并分别在V100 GPU和Intel Xeon Gold CPU @ 3.00 GHz上测量了批量大小为64时的吞吐量。对于移动端,我们通过CoreML库{ }^2进行编译,并在部署到iPhone X神经引擎上后测量吞吐量。
结果 表1中ImageNet-1K(Deng et al. 2009)数据集的实验结果清晰地展示了我们的模型在图像分类领域的进步。与已建立的基准相比,我们的方法在显著提高分类精度的同时,还巧妙地处理了模型复杂性与计算需求之间的权衡。值得注意的是,我们的模型XS和S变体在尺寸(以数百万个参数衡量)和计算效率(以Flops量化)之间表现出了非凡的协同作用,同时没有牺牲关键的Top-1准确率指标。这一成就是关键性的,因为它表明我们的模型即使在受限的计算设置中也能保持高水平的准确率,这证明了其优化架构的有效性。我们的模型在性能上的优势,特别是在与传统模型(如MobileNetV3(Howard et al. 2019))相比,以更少的参数和更低的Flops实现更高的Top-1准确率方面,凸显了其在资源受限的现实世界应用中的潜力。此外,跨不同平台(GPU、ONNX和ANE)的吞吐量分析凸显了我们模型的适应性和效率,使其成为从移动设备到高端服务器等各种场景下部署的理想候选,从而拓宽了其在图像识别领域的应用范围。有关额外信息(如微调)的比较,请参阅附录。
计算效率 在所有任务中,我们的模型在计算效率方面保持了竞争优势。这一点在比较Flops和参数数量时尤为明显,其中我们的模型在较低的计算开销下实现了更高或可比的性能指标。例如,CAS-ViT-M在检测和分割任务中均优于所有列出的模型,同时保持了具有竞争力的计算成本,凸显了我们架构选择的有效性。
目标检测和实例分割
实现细节 以在ImageNet-1K上预训练的模型作为主干网络,集成了RetinaNet(Lin et al. 2017)和Mask RCNN(He et al. 2017)来评估我们模型在MS COCO 2017(Lin et al. 2014)数据集上的目标检测和实例分割性能。遵循(Yu et al. 2022)的方法,我们采用了 1 × 1 \times 1×训练策略,并使用AdamW(Loshchilov and Hutter 2018)优化器对网络进行了12个周期的微调,学习率为 2 × 1 0 − 4 2 \times 10^{-4} 2×10−4,批量大小为32。训练图像被调整到短边800像素,长边不超过1333像素。模型在MMDetection(Chen et al. 2019)代码库上实现。
结果 在对COCO val2017(Lin et al. 2014)数据集上的目标检测和实例分割进行全面评估时,我们评估了几个主干模型在不同尺度(小、中、大)和任务(检测和分割)上的平均精度(AP)指标,以及由参数(Par.)和Flops指示的计算效率。分析主要集中在将我们提出的模型与ResNet-50(He et al. 2016)、PVT-T(Wang et al. 2022)、PoolFormer-S12(Yu et al. 2022)、EfficientFormer-L1(Li et al. 2022b)、ResNeXt101-32x4d(Xie et al. 2017)、PVTv2B0(Wang et al. 2022)、PoolFormer-S36(Yu et al. 2022)和SwiftFormer-L1(Shaker et al. 2023)等已建立模型进行性能比较。
语义分割
实现细节 我们在ADE20K(Zhou et al. 2019)数据集上进行了语义分割实验。将提出的模型与Semantic FPN(Kirillov et al. 2019)结合进行评估,其中归一化层被冻结,作为主干网络并加载了ImageNet-1K分类的预训练权重。遵循通用做法(Yu et al. 2022),网络在批量大小为32和AdamW优化器的情况下进行了40K次迭代训练。学习率初始化为
2
×
1
0
−
4
2 \times 10^{-4}
2×10−4,并按照多项式调度策略以0.9的幂次衰减。在训练过程中,图像被调整大小并裁剪为
512
×
512
512 \times 512
512×512。该模型在MMSegmentation^{3}代码库上实现。
结果 如表2所示,对ADE 20K数据集上的分割结果进行定性和定量分析,揭示了我们的模型(CAS-ViT-XS、CAS-ViT-S和CAS-ViT-M)在计算效率和分割精度之间达到了优越的平衡。尽管我们的模型参数数量和Flops相对较低,但它们展示了具有竞争力或更优的平均交并比(mIoU)百分比,这表明它们在语义分割任务中的有效性。具体而言,CAS-ViT-M以 43.6 % 43.6\% 43.6%的mIoU超越了更多计算密集型模型,凸显了其架构效率和在设计上捕获详细语义信息的有效性。我们的模型变体在高效性和准确性之间的高度平衡,突显了它们在各种分割应用中的适用性,尤其是在资源受限的场景中,确立了它们在ADE20K(Zhou et al. 2019)数据集上进行高效且有效的语义分割的吸引力。
可视化
图4展示了在网络最后一层进行的热力图可视化对比,其中三种方法具有相似的参数数量。可以观察到,我们的方法能够准确地将感兴趣区域定位到关键部分。同时,相比之下,我们获得了更大的感受野,这有利于整体性能,特别是对于后续的密集预测任务。有关目标检测和实例分割的更多可视化内容,请参阅附录。
消融研究
我们通过消融研究证明了所提出方法的有效性。如表4(#1-#3)所示,我们分别用Pooling(Yu et al. 2022)和WindowMSA(Dosovitskiy et al. 2020)替换了CATM。没有信息交互的Pooling导致性能下降了-1.72%,而通过W-MSA进行交互仅下降了-0.02%。这验证了CATM通过Sigmoid激活注意力进行信息交互的可行性。表4(#4-#5)探讨了缺少维度交互对结果的影响,其中缺少空间和通道交互分别导致性能下降了0.45%和0.21%。(#6-#7)将查询(Query)和键(Key)分支中的上下文映射函数
Φ
(
⋅
)
\Phi(\cdot)
Φ(⋅)替换为不同的函数,并且只有在#7中保留完全交互才能实现类似的结果。
局限性和未来工作
CAS-ViT在特征处理中广泛使用了卷积层,这本身保留了归纳偏置。其优点包括比纯Transformer更快的收敛速度,但随之而来的问题是在大规模数据集和大参数模型上的效果略差。在未来的工作中,我们将探索更大的数据集和模型规模。我们致力于开发更加高效且便于部署的轻量级网络。
结论
在本文中,我们提出了一种卷积加性自注意力网络,称为CAS-ViT。首先,我们认为使ViT中的标记混合器高效工作的关键是进行包括空间和通道域在内的多重信息交互。随后,遵循这一范式,我们创新性地设计了一个加性相似度函数,并通过Sigmoid激活的注意力简单地实现了它,这有效地避免了纯ViT中矩阵乘法和Softmax等复杂操作。我们构建了一系列轻量级模型,并在图像分类、目标检测、实例/语义分割等任务上验证了它们的优越性能。同时,我们已将网络部署到ONNX和iPhone上,实验表明,CAS-ViT在保持高精度的同时,促进了移动设备上的高效部署和推理。