Attention Free Transformer(AFT)
paper: An Attention Free Transformer
date: 2021-05
org: Apple
1 Motivation
原本基于dot product self attention Transformer的时间复杂度和空间复杂度都很高。提出了一个新的AFT层来降低transformer的计算量。
2 Method
2.1 Multi-Head Attention回顾
首先回顾一下经典的Multi-Head Attention(MHA),每一个head的计算如下
f i ( X ) = σ ( Q i ( K i ) T d k ) V i , s . t . Q i = X W i Q , K i = X W i K , V i = X W i V , (1) f _ { i } ( X ) = \sigma ( \frac { Q _ { i } ( K _ { i } ) ^ { T } } { \sqrt { d _ { k } } } ) V _ { i } , \ \mathrm { s . t . } \ Q _ { i } = X W _ { i } ^ { Q } , K _ { i } = X W _ { i } ^ { K } , V _ { i } = X W _ { i } ^ { V } , \tag{1} fi(X)=σ(dkQi(Ki)T)Vi, s.t. Qi=XWiQ,Ki=XWiK,Vi=XWiV,(1)
其中: W i Q ∈ R d × d k , W i K ∈ R d × d k , W i V ∈ R d × d υ W _ { i } ^ { Q } \; \in \; R ^ { d \times d _ { k } } , W _ { i } ^ { K } \; \in \; R ^ { d \times d _ { k } } , W _ { i } ^ { V } \; \in \; R ^ { d \times d _ { \upsilon } } WiQ∈Rd×dk,WiK∈Rd×dk,WiV∈Rd×dυ。 σ \sigma σ是非线性函数,默认为 s o f t m a x softmax softmax。通常情况下 d v = d k , h = d d k d_v = d_k, h = \frac{d}{d_k} dv=dk,h=dkd。假定输入 X ∈ R T × d X \in \mathbb {R}^ {T \times d} X∈RT×d, 经过 f i f_i fi转化后的输出 f i ( X ) ∈ R T × d v f_i{(X)} \in \mathbb{R} ^{T \times d_v} fi(X)∈RT×dv。将所有head的结果拼接起来得到最后的输出 R T × d \mathbb{R} ^{T \times d} RT×d。
单头Attention的时间复杂度计算:
- Q K V QKV QKV 的计算,此处有3个矩阵乘法,计算量为 d × d k × T × 3 d \times d_k \times T \times 3 d×dk×T×3, 时间复杂度为: O ( 1 h T d 2 ) \mathcal{O}(\frac{1}{h}Td^2) O(h1Td2)
- Q K T QK^T QKT的计算,计算量为: d k × T × T d_k \times T \times T dk×T×T, 时间复杂度为: O ( 1 h T 2 d ) \mathcal{O}(\frac{1}{h}T^2d) O(h1T2d)
- scale 的计算量为: T × T T \times T T×T, 时间复杂度为: O ( T 2 ) \mathcal{O}(T^2) O(T2)
- softmax的计算量为: T × T T \times T T×T, 时间复杂度为: O ( T 2 ) \mathcal{O}(T^2) O(T2)
- 最后加权乘法计算量为 d k × T × T d_k \times T \times T dk×T×T,时间复杂度为: O ( 1 h T 2 d ) \mathcal{O}(\frac{1}{h}T^2d) O(h1T2d)
对于MHA,时间复杂度为 O ( T d 2 ) \mathcal{O}(Td^2) O(Td2)
2.2 Attention Free Transofrmer(AFT)
2.2.1 AFT full
第一步和MHA一样,输入 X X X经过三个linear transfer得到 Q K V QKV QKV,3个矩阵, 维度为 R T × d \mathbb{R}^{T \times d} RT×d。AFT引入了一个新的可训练参数矩阵 w ∈ R T × T w \in \mathbb{R}^{T \times T} w∈RT×T,论文将其称之为可学习的一对一位置偏置(learned pair-wise position biases)。
我们以 y t y_t yt 为视角看每一步的具体流程。
SETP1: 求 w e i g h t e d ( K ( t ) ) \mathrm{weighted}(K^{(t)}) weighted(K(t))。从 w w w 取 t = t t=t t=t的向量, 和 K K K做点乘后以列方向进行 s o f t m a x \mathrm{softmax} softmax。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)
W e i g h t e d ( K ( t ) ) = exp ( K + w t ) ∑ i = 1 T exp ( k i + w t i ) (2) \mathrm{Weighted}(K^{(t)}) = \frac{\exp (K + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{2} Weighted(K(t))=∑i=1Texp(ki+wti)exp(K+wt)(2)
STEP2: 求 A t t e n t i o n ( t ) \mathrm{Attention}^{(t)} Attention(t)矩阵。将q_t用sigmoid变换后,点乘wighted(K)。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)
A
t
t
e
n
t
i
o
n
(
t
)
=
σ
(
q
t
)
⊙
W
e
i
g
h
t
e
d
(
K
(
t
)
)
=
σ
(
q
t
)
⊙
exp
(
K
+
w
t
)
∑
i
=
1
T
exp
(
k
i
+
w
t
i
)
(3)
\mathrm{Attention^{(t)}} = \sigma(q_t) \odot \mathrm{Weighted}(K^{(t)})= \frac{\sigma(q_t) \odot \exp (K + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{3}
Attention(t)=σ(qt)⊙Weighted(K(t))=∑i=1Texp(ki+wti)σ(qt)⊙exp(K+wt)(3)
STEP3: 计算 y t y_t yt。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)
y t = ∑ i = 1 T ( A t t e n t i o n ( t ) i ⊙ v i ) = ∑ i = 1 T σ ( q t ) ⊙ exp ( k i + w t ) ∑ i = 1 T exp ( k i + w t i ) ⊙ v i (4) y_t = \sum_{i=1}^{T}(\mathrm{Attention^{(t)}}_i \odot v_i) = \sum_{i=1}^{T} \frac{\sigma(q_t) \odot \exp (k_i + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \odot v_i \tag{4} yt=i=1∑T(Attention(t)i⊙vi)=i=1∑T∑i=1Texp(ki+wti)σ(qt)⊙exp(ki+wt)⊙vi(4)
对式(4)稍做变形,可得论文中的计算公式
y t = σ ( q t ) ⊙ ∑ i = 1 T exp ( k i + w t ) ⊙ v i ∑ i = 1 T exp ( k i + w t i ) (5) y_t = \sigma(q_t)\odot \frac{ \sum_{i=1}^{T}\exp (k_i + w_t ) \odot v_i}{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{5} yt=σ(qt)⊙∑i=1Texp(ki+wti)∑i=1Texp(ki+wt)⊙vi(5)
将所有的步骤串起来的流程如下。可以看到AFT其实也用到了attention的思想。但AFT中的Attention Score的计算并没有用到矩阵乘法,只用到了向量点乘。虽整体的计算复杂度仍然是 O ( T 2 d ) \mathcal{O}(T^2d) O(T2d),但计算量已有所下降。
式(4)计算pipeline
式(5)计算pipeline
2.2.1 AFT local
在许多情况下,局部性是一个很重要的归纳偏置(inductive bias),而标准的Transformer的计算中没有引入局部信息。因此,作者提出AFT-local。其形式与AFT-Full一致。区别在于,引入了下式限制
w t , t ′ = { w t , t ′ , i f ∣ t − t ′ ∣ < s 0 , o t h e r w i s e . (6) w_{t, t'} = \begin{cases} w_{t, t'}, \quad \mathrm{if} |t - t'| < s \\ 0, \quad \mathrm{otherwise.}\end{cases} \tag{6} wt,t′={wt,t′,if∣t−t′∣<s0,otherwise.(6)
式中的 s s s就是定义的局部窗口大小(local window size)。它进一步降低了计算量。变换后的 w w w如下图所示(此时 s = 2 s=2 s=2, 黑色方块为0)。
2.2.2 AFT simple
AFT simple是AFT local当 s = 0 s = 0 s=0时的特殊形式。此时没有位置偏置。可将式5化简为,因为对不同的 t t t, ∑ i = 1 T ( s o f t m a x ( K ) ⊙ V ) i \sum_{i=1}^{T} (\mathrm{softmax}(K) \odot V)_{i} ∑i=1T(softmax(K)⊙V)i都是相同的。AFT simple的时间复杂度为 O ( T d ) \mathcal{O}(Td) O(Td)
y t = σ ( q t ) ⊙ ∑ i = 1 T exp ( k i ) ⊙ v i ∑ i = 1 T exp ( k i ) = σ ( q t ) ⊙ ∑ i = 1 T ( s o f t m a x ( K ) ⊙ V ) i (6) y_t = \sigma(q_t)\odot \frac{ \sum_{i=1}^{T}\exp (k_i) \odot v_i}{\sum_{i=1}^{T} \exp (k_i) } = \sigma(q_t)\odot \sum_{i=1}^{T} (\mathrm{softmax}(K) \odot V)_{i}\tag{6} yt=σ(qt)⊙∑i=1Texp(ki)∑i=1Texp(ki)⊙vi=σ(qt)⊙i=1∑T(softmax(K)⊙V)i(6)
2.2.3 AFT conv
作者进一步将局部性的思想扩展到空间权重共享(如卷积),提出AFT-conv。具体来说,让 w t , t ′ w_{t,t'} wt,t′的值仅依赖 t t t和 t ′ t' t′的相对位置。为了考虑参数数量随着 h e a d head head数增加而增长的情况,作者采用了一个设计选择,将 K K K的维度与head数绑定在一起(MHA的思路)。这使得AFT-conv可以采用深度可分离卷积、全局池化和element-wise操作的实现方式。
可以看到与AFT simple相比,AFT conv引入了head思想,并通过1维卷积的计算结果引入局部信息。其形式与式(6)相比分子分母中新增了 c o n v 1 d ( exp ( K j ) ⊙ V j , exp ( w j ) − 1 ) \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) \odot V ^ { j } , \; \exp ( w ^ { j } ) \, - 1 ) conv1d(exp(Kj)⊙Vj,exp(wj)−1), c o n v 1 d ( exp ( K j ) , exp ( w j ) − 1 ) \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) , \; \exp ( w ^ { j } ) \; - 1 ) conv1d(exp(Kj),exp(wj)−1)。(上标 j j j表示第 j j j个head)。此时的 w w w为conv1d的filter。
y t j = σ q ( q t j ) ⊙ c o n v 1 d ( exp ( K j ) ⊙ V j , exp ( w j ) − 1 ) + ∑ i = 1 T exp ( k i j ) ⊙ v i j c o n v 1 d ( exp ( K j ) , exp ( w j ) − 1 ) + ∑ i = 1 T exp ( k i j ) (7) y _ { t } ^ { j } = \sigma _ { q } ( q _ { t } ^ { j } ) \odot \frac { \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) \odot V ^ { j } , \; \exp ( w ^ { j } ) \, - 1 ) + \sum _ { i = 1 } ^ { T } \exp ( k _ { i } ^ { j } ) \odot v _ { i } ^ { j } } { \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) , \; \exp ( w ^ { j } ) \; - 1 ) + \sum _ { i = 1 } ^ { T } \exp ( k _ {i } ^ { j } ) } \tag{7} ytj=σq(qtj)⊙conv1d(exp(Kj),exp(wj)−1)+∑i=1Texp(kij)conv1d(exp(Kj)⊙Vj,exp(wj)−1)+∑i=1Texp(kij)⊙vij(7)
从ViT可视化attention map中可以看出(横轴为head, 纵轴为layer)。原本的ViT(左边)的不同层,head的attention map的响应最大区域基本都是中心区域。而用了AFT-conv后,不同层、head的attention都有所不同,有助于模型捕获不同尺度的特征。
3 小结
本文提出了一种Dot Product Attention Free的Transformer,最多能将transofmer的时间复杂度从 O ( T 2 d ) \mathcal{O}(T^2d) O(T2d)降低到 O ( T d ) \mathcal{O}(Td) O(Td)(AFT-simple)。