paper: https://arxiv.org/abs/2308.00442
code: https://github.com/LeapLabTHU/FLatten-Transformer
摘要
当将transformer模型应用于视觉任务时,自注意的二次计算复杂度( n 2 n^2 n2)一直是一个持续存在的挑战。另一方面,线性注意通过精心设计的映射函数来近似Softmax操作,提供了一个更有效的替代方法。然而,当前的线性注意方法要么遭受显著的性能下降,要么从映射函数中引入额外的计算开销。在本文中,作者提出了一种新的聚焦线性注意模块,以实现高效率和表达性。具体来说,首先从聚焦能力和特征多样性两个角度分析了导致线性注意性能下降的因素。为了克服这些限制,引入了一个简单而有效的映射函数和一个高效的秩恢复模块,在保持低计算复杂度的同时提高自我注意的表达性。大量的实验表明,线性注意模块适用于各种高级视觉transorfomer,并在多个基准上取得了持续改进的性能。
介绍
将transorfomer应用于视觉模型是一项艰巨的任务。在将自注意力机制应用于全局感受野,与轻量级卷积神经网络不同,相对于序列长度
n
n
n的二次计算复杂度
O
(
n
2
)
O(n^2)
O(n2)导致了较高的计算成本。之前的工作通过限制全局感受野为一个更小的区域感受野,比如稀疏的全局注意模式、更小的注意力窗口。尽管这些方法很有效,但由于它们的注意力模式,它们要么倾向于忽略其他区域的信息特征,要么不可避免地牺牲了建模长期依赖关系的能力。
另一方面,线性注意被认为是一种简单而有效的替代方法,通过降低一般的复杂性来解决计算困境。早期的研究利用了一种局部敏感的哈希方案,该方案将计算复杂度从
O
(
n
2
)
O(n^2)
O(n2)压缩为
O
(
n
l
o
g
(
n
)
)
O(nlog(n))
O(nlog(n))。然而,它在复杂度项之前引入了一个很大的常数,这使得它在常见情况下仍然负担不起。最近的研究注意到,在自注意操作中使用Softmax函数实际上迫使所有查询和键之间进行两两计算,从而导致主要的
O
(
n
2
)
O(n^2)
O(n2)复杂度。为了解决这个问题,一些方法开始采用简单的激活函数或者定制映射函数去近似原始的Softmax函数。如图1所示,通过将计算顺序从(查询键)值改为查询值(键值),可以将整体计算复杂度降低到
O
(
n
)
O (n)
O(n)。然而,与Softmax注意相比,目前的线性注意方法仍然存在严重的性能下降,可能涉及映射函数的额外计算开销,从而限制了其实际应用。
本文针对当前线性注意方法的局限性,提出了一种新的Focused Linear Attention,该模块既具有高效率和表达性。具体来说,对导致线性关注性能下降的因素进行了双管齐下的分析,并提出了相应的解决方案。首先,前一种线性注意模块的注意权重分布相对平滑,缺乏处理信息最丰富的特征的聚焦能力。作为一种补救措施,本文提出了一个简单的映射函数来调整查询和键的特征方向,使注意力权重更容易区分。其次,注意矩阵的秩的减少限制了线性注意特征的多样性。为了解决这一问题,提出了一个对原始注意矩阵进行深度卷积(DWC)的秩恢复模块,这有助于恢复矩阵的秩,并保持不同位置的输出特征多样化。利用这些改进的技术,模块展示了与Softmax的同类产品相当或更优越的性能,同时享受了低计算复杂度的好处。
相关工作
Vision Transformer
transformer和自我注意机制首次引入自然语言处理领域,在计算机视觉领域获得了广泛的研究兴趣。然而,自注意集的高计算复杂度限制了其直接应用于视觉任务。之前的视觉transformer考虑通过将相邻像素合并为单个令牌来降低输入分辨率。接下来的研究也采用了类似的见解,也扩展到下游任务。另一项研究逐渐降低了特征的分辨率,并采用了精心设计的注意模式来限制token数量。例如,PVT 使用一个稀疏的注意模式,并从全局的角度选择注意令牌。AT 遵循该路径,设计了一个可变形的注意模块,以实现数据依赖的注意模式。Swin变压器通过将输入划分为孤立的窗口来局部选择注意令牌。NAT在卷积中遵循以查询为中心的模式,并为所有查询设计独立的注意标记。一些研究也注意到,卷积运算对变压器模型很有价值,可能有助于提高整体效率。CMT将变压器块与高效的卷积算子相结合,像深度可分离卷积(depthwise convolution),并实现了更好的效率和性能的权衡。ACmix 共享了卷积和自注意的计算开销,并以有限的成本集成了这两个模块。并提出了有效训练变压器的方法。
然而,这些方法仍然依赖于Softmax算子,其继承了较高的计算复杂度,不可避免地给模型架构设计和实际应用带来了不便。
Linear Attention
除了上述方法外,另一项研究利用线性注意解决高计算复杂度。具体地说,线性注意用单独的核函数代替了自注意中的Softmax函数。在这种情况下,线性注意不需要首先计算成对相似度
Q
K
T
QK^T
QKT,正如图1所示,这种情况softmax函数将不可用,因此需要再设计一个映射函数。如何设计与softmax注意力机制一样有效的线性注意模块仍然是一个重要的问题。Performer近似于具有正交随机特征的Softmax操作。Efficient attention 将Softmax函数分别应用于Q和K,这自然保证了每一行
Q
K
T
QK^T
QKT的总和为1。Nystromformer和SOFT通过矩阵分解近似全自注意矩阵。Hydra attention用余弦相似度代替Softmax。EfficientVit利用深度可分离卷积来提高线性注意的局部特征提取能力。然而,目前的线性注意设计要么没有足够的表达能力来赶上sotmax注意,要么涉及来自复杂核函数的额外计算开销。在本研究中,作者从聚焦能力和特征多样性的角度分析了线性注意性能下降的原因。在此基础上,作者提出了一种新的线性注意模块,称为聚焦线性注意,它在计算复杂度较低的情况下比Softmax注意更好的性能(图2)。
具体来说:
自注意力机制公式:对于每一个token有
O
i
=
∑
j
=
1
N
S
i
m
(
Q
i
,
K
j
)
∑
j
=
1
N
S
i
m
(
Q
i
,
K
j
)
V
j
O_i = \sum_{j=1}^N\frac{Sim(Q_i,K_j)}{\sum_{j=1}^N Sim(Q_i,K_j)}V_j
Oi=∑j=1N∑j=1NSim(Qi,Kj)Sim(Qi,Kj)Vj,
S
i
m
Sim
Sim表示相似度函数,
S
i
m
(
Q
i
,
K
j
)
=
e
x
p
(
Q
j
K
j
T
/
d
)
Sim(Q_i,K_j)=exp(Q_j{K_j}^T/\sqrt{d})
Sim(Qi,Kj)=exp(QjKjT/d)
线性注意力机制:精心设计的核作为原始相似度函数的近似值
S
i
m
(
Q
i
,
K
j
)
=
ϕ
(
Q
i
)
ϕ
(
K
j
)
T
Sim(Q_i,K_j)=\phi (Q_i)\phi(K_j)^T
Sim(Qi,Kj)=ϕ(Qi)ϕ(Kj)T
那么自注意力机制公式就可以被重写为
O
i
=
∑
j
=
1
N
ϕ
(
Q
i
)
ϕ
(
K
i
)
T
∑
j
=
1
N
ϕ
(
Q
i
)
ϕ
(
K
j
)
T
V
j
O_i = \sum_{j=1}^N\frac{\phi(Q_i)\phi(K_i)^T}{\sum_{j=1}^N\phi(Q_i)\phi(K_j)T}V_j
Oi=∑j=1N∑j=1Nϕ(Qi)ϕ(Kj)Tϕ(Qi)ϕ(Ki)TVj
这样就可以将
(
Q
K
T
)
V
(QK^T)V
(QKT)V转化为
Q
(
K
T
V
)
Q(K^TV)
Q(KTV),即
O
i
=
ϕ
(
Q
i
)
(
∑
j
=
1
N
ϕ
(
K
j
)
T
V
j
)
ϕ
(
Q
i
)
(
∑
j
=
1
N
ϕ
(
K
I
)
T
)
O_i = \frac{\phi(Q_i)(\sum_{j=1}^N\phi(K_j)^TV_j)}{\phi(Q_i)(\sum_{j=1}^N\phi(K_I)^T)}
Oi=ϕ(Qi)(∑j=1Nϕ(KI)T)ϕ(Qi)(∑j=1Nϕ(Kj)TVj)
注意
Q
i
Q_i
Qi为query向量,
K
j
K_j
Kj为key向量,
V
j
V_j
Vj为value向量。
方法(Focused Linear Attention)
Focus ability
softmax注意力机制实际上提供了一种非线性重加权机制,使其很容易集中在重要的特征。如图3所示,来自Softmax注意的注意图在某些区域的分布特别明显,如前景物体。相比之下,线性注意的分布是相对的平滑,使其输出更接近所有特征的平均值,而不能关注信息更丰富的区域。
作为补救措施,作者提出了一个简单而有效的解决方案,通过调整每个查询和关键特征的方向,接近相似的查询键对,同时消除不同的查询键对。具体来说,作者提出了一个简单的映射函数
f
p
f_p
fp,称为Focused函数:
S
i
m
(
Q
i
,
K
j
)
=
ϕ
p
(
Q
i
)
ϕ
p
(
K
j
)
T
Sim(Q_i,K_j)=\phi_p(Q_i)\phi_p(K_j)^T
Sim(Qi,Kj)=ϕp(Qi)ϕp(Kj)T
where
ϕ
p
(
x
)
=
f
p
(
R
e
L
U
(
x
)
)
,
f
p
(
x
)
=
∣
∣
x
∣
∣
∣
∣
x
∗
∗
p
∣
∣
x
∗
∗
p
\phi_p(x)=f_p(ReLU(x)),f_p(x)=\frac{||x||}{||x^{**p}||}x^{**p}
ϕp(x)=fp(ReLU(x)),fp(x)=∣∣x∗∗p∣∣∣∣x∣∣x∗∗p
其中
x
∗
∗
p
x^{**p}
x∗∗p表示x按元素的p次方。作者证明了所提出的映射函数
f
p
f_p
fp实际上影响了注意力的分布。
命题1: f p f_p fp调整特征方向
令 x = ( x 1 , . . . , x n ) , y = ( y 1 , . . . , y n ) ∈ R n , x i , y i ≥ 0 x=(x_1,...,x_n),y=(y_1,...,y_n) \in \mathbb{R}^n,x_i,y_i\ge 0 x=(x1,...,xn),y=(y1,...,yn)∈Rn,xi,yi≥0假设x和y分别有一个最大的值 x m x_m xm和 y n y_n yn。
当 m = n m=n m=n时,有 ∃ p > 1 , s . t . ⟨ ϕ p ( x ) , ϕ p ( y ) ⟩ > ⟨ x , y ⟩ \exists p> 1, s.t. \left \langle \phi_p(x),\phi_p(y) \right \rangle > \left \langle x,y \right \rangle ∃p>1,s.t.⟨ϕp(x),ϕp(y)⟩>⟨x,y⟩
当 m ≠ n m\ne n m=n时,有 ∃ p > 1 , s . t . ⟨ ϕ p ( x ) , ϕ p ( y ) ⟩ < ⟨ x , y ⟩ \exists p> 1, s.t. \left \langle \phi_p(x),\phi_p(y) \right \rangle < \left \langle x,y \right \rangle ∃p>1,s.t.⟨ϕp(x),ϕp(y)⟩<⟨x,y⟩
⟨ x , y ⟩ \left \langle x,y \right \rangle ⟨x,y⟩表示內积 x y T xy^T xyT
这个命题可以这样理解,
f
p
f_p
fp使相似的query-key更明显的区别(
m
=
n
m=n
m=n时內积相比原始值变大),不相似的query-key恢复了尖锐的注意力分布作为原来的Softmax函数(
m
≠
n
m\ne n
m=n內积更小),从而实现focus ablity。
为了更好地理解,我们给出了一个例子来显示图4中
f
p
f_p
fp的影响。可以看出,
f
p
f_p
fp实际上将每个向量“拉”到它最近的轴上,而p决定了这种“拉”的程度。通过这样做,
f
p
f_p
fp有助于根据特征最近的轴将特征划分为几个组,提高每个组内的相似性,同时减少组之间的相似性。可视化与我们上面的分析一致。
Feature diversity
除了focus ablity外,特征多样性也是限制线性注意表达能力的因素之一。其中一个可能的原因可能归功于注意力矩阵的秩,其中可以看到显著的差异。以
N
=
14
×
14
N=14×14
N=14×14的DeiT-Tiny的transformer层为例,从图5 (a)可以看出,注意力矩阵具有完整的秩(196中的196),显示了从值聚合特征时的多样性。
然而,在线性注意的情况下,这一点很难实现。事实上,注意矩阵在线性注意中的秩受到每个头部的令牌数N和通道维数d的限制:
r
a
n
k
(
ϕ
(
Q
)
ϕ
(
K
)
T
)
≤
m
i
n
{
r
a
n
k
(
ϕ
(
Q
)
)
,
r
a
n
k
(
ϕ
(
Q
)
)
}
≤
m
i
n
{
N
,
d
}
rank(\phi(Q)\phi(K)^T) \le min\{rank(\phi(Q)),rank(\phi (Q))\} \le min\{N,d\}
rank(ϕ(Q)ϕ(K)T)≤min{rank(ϕ(Q)),rank(ϕ(Q))}≤min{N,d}
因为d通常小于N,所以线性注意力机制的注意力矩阵小于等于d,而softmax注意力小于等于N(大概率是等于d和等于n)。在这种情况下,注意矩阵秩的上界被限制在一个较低的比率,这表明注意映射的许多行被严重均质化。由于自注意力的输出是同一组V的加权和,注意力权重的均匀化不可避免地导致聚合特征之间的相似性。
为了更好地说明,我们将DeiT-Tiny中的原始Softmax注意替换为线性注意,并显示了图5 (b).中的注意图的rank,可以观察到,rank大大下降(196中有54),注意矩阵的许多行是相似的。
作为一种补救方法,我们提出了一个简单而有效的解决方案来解决线性注意的限制。具体地说,在注意矩阵中添加了一个深度卷积(DWC,depthwise convolution)模块,输出可以表示为:
O
=
ϕ
(
Q
)
ϕ
(
K
)
T
V
+
D
W
C
(
V
)
O=\phi(Q)\phi(K)^TV+DWC(V)
O=ϕ(Q)ϕ(K)TV+DWC(V)
为了更好地理解这个DWC模块的效果,我们可以把它看作是一种attention,即每个query只关注空间中的几个相邻特征,而不是所有特征V。这种局部性保证了即使两个查询对应的线性注意值相同,我们仍然可以从不同的局部特征中得到不同的输出,从而保持特征的多样性。DWC的影响也可以从矩阵秩的角度来解释:
O
=
(
ϕ
(
Q
)
ϕ
(
K
)
T
+
M
D
W
C
)
V
=
M
e
q
V
O=(\phi(Q)\phi(K)^T+M_{DWC})V=M_{eq}V
O=(ϕ(Q)ϕ(K)T+MDWC)V=MeqV
M
D
W
C
M_{DWC}
MDWC是深度卷积的稀疏矩阵,
M
e
q
M_{eq}
Meq对应于注意力矩阵,因为
M
D
W
C
M_{DWC}
MDWC是满秩,所以
M
e
q
M_{eq}
Meq大概率也满秩。
为了更好地说明,我们在DeiT-Tiny上进行了类似的修改。通过附加的DWC模块,注意图在线性注意中的秩可以恢复到全秩(196/196,如图5©所示),从而保持特征多样性作为原来的Softmax注意。
模块构成
我们的模块可以表述为: O = S i m ( Q , K ) V = ϕ p ( Q ) ϕ p ( K ) T V + D W C ( V ) O=Sim(Q,K)V=\phi_p(Q)\phi_p(K)^TV+DWC(V) O=Sim(Q,K)V=ϕp(Q)ϕp(K)TV+DWC(V)