Transformer Basics
Transformer Model Architecture
Main components: Encoder, Decoder, Generator.
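As an overview of how these three parts connect, here is a minimal PyTorch sketch of the top-level wiring, loosely following the Annotated Transformer's EncoderDecoder module; the class and argument names are illustrative, and the sub-modules themselves are sketched in the sections below.

```python
import torch.nn as nn

class EncoderDecoder(nn.Module):
    """Top-level Transformer: encode the source, then decode conditioned on memory."""
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder      # stack of N EncoderLayers
        self.decoder = decoder      # stack of N DecoderLayers
        self.src_embed = src_embed  # token embedding + positional encoding (source side)
        self.tgt_embed = tgt_embed  # token embedding + positional encoding (target side)
        self.generator = generator  # linear + softmax over vocab_sz, applied separately

    def forward(self, src, tgt, src_mask, tgt_mask):
        # Encoder output "memory" is consumed by the decoder's encoder-decoder attention.
        memory = self.encoder(self.src_embed(src), src_mask)
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
```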
Encoder
Composed of N EncoderLayer networks with identical structure (but independent parameters).
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
EncoderLayer: consists of a self-attention Multi-Head Attention sub-network, a Position-wise Feed-Forward sub-network, and the Residual Connections and Layer Normalization used to join the sub-networks (a code sketch follows at the end of this subsection).
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
- Self-attention Multi-Head Attention network: Q, K, V all come from the previous layer (Input Embedding / EncoderLayer).
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
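A minimal sketch of an EncoderLayer in PyTorch, assuming post-norm residual wiring as in the original paper; the `self_attn(q, k, v, mask)` and `feed_forward(x)` call signatures are assumptions modeled on the Annotated Transformer.

```python
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention + position-wise FFN, each wrapped in a
    residual connection + LayerNorm (post-norm, as in the original paper)."""
    def __init__(self, d_model, self_attn, feed_forward, dropout=0.1):
        super().__init__()
        self.self_attn = self_attn        # multi-head self-attention module
        self.feed_forward = feed_forward  # position-wise feed-forward module
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Q, K, V all come from the previous layer's output x.
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, mask)))
        # [batch_sz, seq_len, d_model] in, same shape out.
        return self.norm2(x + self.dropout(self.feed_forward(x)))
```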
Decoder
Composed of N DecoderLayer networks with identical structure (but independent parameters).
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
DecoderLayer: consists of a self-attention Masked Multi-Head Attention sub-network, an (Encoder-Decoder) Multi-Head Attention sub-network, a Position-wise Feed-Forward sub-network, and the Residual Connections and Layer Normalization used to join the sub-networks (a code sketch follows at the end of this subsection).
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
- Self-attention Masked Multi-Head Attention network: Q, K, V all come from the previous layer (Output Embedding / DecoderLayer). The "Masked" part uses a mask of shape [1, seq\_len, seq\_len] to block subsequent positions, so attention uses only the information available for predicting the next position.
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
- (Encoder-Decoder) Multi-Head Attention network: Q comes from the previous (Masked Multi-Head Attention) sub-network; K and V come from the Encoder output memory.
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
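A corresponding sketch of a DecoderLayer (post-norm, with the same assumed call signatures for the attention and feed-forward modules):

```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, and
    position-wise FFN, each followed by residual connection + LayerNorm (post-norm)."""
    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout=0.1):
        super().__init__()
        self.self_attn = self_attn  # masked multi-head self-attention
        self.src_attn = src_attn    # encoder-decoder multi-head attention
        self.feed_forward = feed_forward
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask, tgt_mask):
        # Self-attention: Q, K, V from the previous decoder layer; tgt_mask hides future positions.
        x = self.norms[0](x + self.dropout(self.self_attn(x, x, x, tgt_mask)))
        # Encoder-decoder attention: Q from the decoder, K and V from the encoder output (memory).
        x = self.norms[1](x + self.dropout(self.src_attn(x, memory, memory, src_mask)))
        return self.norms[2](x + self.dropout(self.feed_forward(x)))
```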
Generator
Composed of a linear layer of [\text{In}: d_{model}, \text{Out}: vocab\_sz] and a Softmax operation.
y = \mathrm{softmax}(\mathrm{Linear}(x)) = \mathrm{softmax}(xA^{\top}+b)
The generator outputs the prediction probabilities for only the next position at a time, following the sequence order.
\textbf{In}: [batch\_sz, d_{model}], \textbf{Out}: [batch\_sz, vocab\_sz]
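A minimal Generator sketch; the Annotated Transformer uses `log_softmax` rather than a plain softmax (log probabilities pair with the KL-divergence-based label-smoothing loss used there), and this sketch follows that choice.

```python
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """Linear projection d_model -> vocab_sz followed by a (log-)softmax."""
    def __init__(self, d_model, vocab_sz):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_sz)

    def forward(self, x):
        # x: [batch_sz, d_model] -> [batch_sz, vocab_sz]
        return F.log_softmax(self.proj(x), dim=-1)
```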
※ Multi-Head Attention
Scaled Dot-Product Attention:
\pmb{\mathrm{Attention}(Q,K,V) = \mathrm{softmax}(\frac{QK^{\top}}{\sqrt{d_k}})V}
Shape changes:
- Input:
  - Q\ [batch\_sz, h, seq\_len, d_k]
  - K\ [batch\_sz, h, seq\_len, d_k], K^{\top}\ [batch\_sz, h, d_k, seq\_len]
  - V\ [batch\_sz, h, seq\_len, d_k]
- QK^{\top}\ [batch\_sz, h, seq\_len, seq\_len]
- \frac{QK^{\top}}{\sqrt{d_k}} and the Mask operation: shape unchanged, [batch\_sz, h, seq\_len, seq\_len]
- \mathrm{softmax}(\frac{QK^{\top}}{\sqrt{d_k}}): Softmax over the last dimension, shape unchanged, [batch\_sz, h, seq\_len, seq\_len]
- \mathrm{softmax}(\frac{QK^{\top}}{\sqrt{d_k}})V: [batch\_sz, h, seq\_len, d_k]
Complete formulas (following the FlashAttention notation):
\begin{aligned}
& S=\tau QK^{\top}\in\mathbb{R}^{N\times N}\\
& S^{\text{masked}}=\text{MASK}(S)\in\mathbb{R}^{N\times N}\\
& P=\text{softmax}(S^{\text{masked}})\in\mathbb{R}^{N\times N}\\
& P^{\text{dropped}}=\text{dropout}(P, p_{drop})\\
& \text{Attention}(Q,K,V)=O=P^{\text{dropped}}V\in\mathbb{R}^{N\times d}
\end{aligned}
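The scaled dot-product attention above maps directly onto a few tensor operations. The following sketch mirrors the Annotated Transformer's `attention` helper: the optional `dropout` argument corresponds to the P^{dropped} step, and masked positions are filled with -1e9 before the softmax.

```python
import math
import torch

def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.
    query/key/value: [batch_sz, h, seq_len, d_k]; returns ([batch_sz, h, seq_len, d_k], weights)."""
    d_k = query.size(-1)
    # scores: [batch_sz, h, seq_len, seq_len]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Replace positions where the mask is 0 with a very small value before the softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)   # softmax over the last dimension
    if dropout is not None:
        p_attn = dropout(p_attn)      # optional dropout on the attention weights
    return torch.matmul(p_attn, value), p_attn
```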
Multi-Head Attention mechanism:
\begin{aligned}
MultiHeadAttn(Q,K,V) &= Concat(head_1, ..., head_h)W^O\\
\mathrm{where}\ head_i &= Attention(QW^Q_i, KW^K_i, VW^V_i)
\end{aligned}
where W^Q_i\in\mathbb{R}^{d_{model}\times d_k}, W^K_i\in\mathbb{R}^{d_{model}\times d_k}, W^V_i\in\mathbb{R}^{d_{model}\times d_v}, W^O\in\mathbb{R}^{hd_v\times d_{model}}
In the implementation, W^Q=(W^Q_1,...,W^Q_h), W^K=(W^K_1,...,W^K_h), W^V=(W^V_1,...,W^V_h), and W^O are realized as 4 linear layers of [\text{In}: d_{model}, \text{Out}: d_{model}], with d_k=d_v=d_{model}/h.
Shape changes:
- Input: X\ [batch\_sz, seq\_len, d_{model}]
- Multi-head pre-processing (head split): X\ [batch\_sz, seq\_len, d_{model}] → [batch\_sz, h, seq\_len, d_k]
- Attention: [batch\_sz, h, seq\_len, d_k] → Q, K, V\ [batch\_sz, h, seq\_len, d_k] → \mathrm{Attention}(Q,K,V)\ [batch\_sz, h, seq\_len, d_k]
- Concatenating the heads: Concat(head_1, ..., head_h): [batch\_sz, h, seq\_len, d_k] → [batch\_sz, seq\_len, h\cdot d_k]
- Output: MultiHeadAttn(Q,K,V)\ [batch\_sz, seq\_len, d_{model}]
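Putting the pieces together, here is a multi-head attention sketch in the style of the Annotated Transformer. It reuses the `attention` helper sketched earlier, and the 4 linear layers play the roles of W^Q, W^K, W^V, W^O.

```python
import torch.nn as nn

class MultiHeadedAttention(nn.Module):
    """Multi-head attention built from 4 [d_model -> d_model] linear layers
    (W^Q, W^K, W^V, W^O), with d_k = d_model / h."""
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast the same mask over all heads
        nbatches = query.size(0)
        # 1) Project and split heads: [batch_sz, seq_len, d_model] -> [batch_sz, h, seq_len, d_k]
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]
        # 2) Scaled dot-product attention per head, using the attention() helper sketched above.
        x, _ = attention(query, key, value, mask=mask, dropout=self.dropout)
        # 3) Concat heads: [batch_sz, h, seq_len, d_k] -> [batch_sz, seq_len, h*d_k], then apply W^O.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
```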
Position-wise Feed-Forward
\mathrm{FFN}(x)=\mathrm{Linear}_2(\mathrm{ReLU}(\mathrm{Linear}_1(x)))=\max(0, xW_1 + b_1) W_2 + b_2
\mathrm{Linear}_1(x): [\text{In}: d_{model},\ \text{Out}: d_{ff}]
\mathrm{Linear}_2(x): [\text{In}: d_{ff},\ \text{Out}: d_{model}]
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
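A position-wise feed-forward sketch; the dropout between the two linear layers is an implementation detail taken from the Annotated Transformer.

```python
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each position independently."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)  # [In: d_model, Out: d_ff]
        self.w_2 = nn.Linear(d_ff, d_model)  # [In: d_ff, Out: d_model]
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch_sz, seq_len, d_model] -> same shape
        return self.w_2(self.dropout(self.w_1(x).relu()))
```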
Add&Norm
In the paper (post-Norm):
\mathrm{SublayerConnection}(X)= \mathrm{LayerNorm}(X + \mathrm{Sublayer}(X))
In the Annotated Transformer implementation (pre-Norm):
\mathrm{SublayerConnection}(X)= X + \mathrm{Sublayer}(\mathrm{LayerNorm}(X))
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
where:
- \mathrm{Sublayer}\in\{\mathrm{MultiHeadAttn},\mathrm{FFN}\}
- Layer Normalization \mathrm{LayerNorm}(X): normalizes each vector x=X[b,pos,:]\in\mathbb{R}^{d_{model}} along the last dimension of the tensor X (the d_{model} dimension, i.e. one vector per sample position).
  \mathrm{Norm}(x)=\frac{x-E(x)}{SD(x)+\epsilon}*\gamma+\beta, where E(x) is the mean (expectation), SD(x) is the standard deviation, \gamma,\beta\in\mathbb{R}^{d_{model}} are learnable parameters, and \epsilon is a tiny scalar added to the denominator for numerical stability (to avoid division by zero).
- Residual Connection: y=x+\mathcal{F}(x)
- Note: for the difference between pre-Norm and post-Norm, see: 【重新了解Transformer模型系列_1】PostNorm/PreNorm的差别 - 知乎
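A pre-Norm SublayerConnection sketch using PyTorch's built-in `nn.LayerNorm` (which already contains the learnable γ, β and the ε term); the `sublayer` argument is any callable, e.g. a lambda wrapping the attention call.

```python
import torch.nn as nn

class SublayerConnection(nn.Module):
    """Pre-norm residual connection, as in the Annotated Transformer:
    output = x + Sublayer(LayerNorm(x))."""
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)  # learnable gamma/beta and eps are built in
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # sublayer: a callable, e.g. lambda x: self_attn(x, x, x, mask) or feed_forward
        return x + self.dropout(sublayer(self.norm(x)))
```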
Token Embedding
A lookup table of size vocab\_sz with embedding dimension d_{model}.
\textbf{In}: [batch\_sz, seq\_len], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
\mathrm{Embedding}(x) = \mathrm{lut}(x)\cdot\sqrt{d_{model}}
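A token-embedding sketch; `nn.Embedding(vocab_sz, d_model)` serves as the lookup table, and the output is scaled by √d_model as in the formula above.

```python
import math
import torch.nn as nn

class Embeddings(nn.Module):
    """Token embedding: a [vocab_sz, d_model] lookup table, scaled by sqrt(d_model)."""
    def __init__(self, d_model, vocab_sz):
        super().__init__()
        self.lut = nn.Embedding(vocab_sz, d_model)
        self.d_model = d_model

    def forward(self, x):
        # x: [batch_sz, seq_len] of token ids -> [batch_sz, seq_len, d_model]
        return self.lut(x) * math.sqrt(self.d_model)
```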
Positional Encoding
Used to inject position information into the token embeddings:
\begin{aligned}
&PE_{(pos,2i)} = \sin(pos / 10000^{2i/d_{\text{model}}})\\
&PE_{(pos,2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})
\end{aligned}
\mathrm{PE}(X)=X+P,\ \text{where}\ (p_{(b,pos,i)})=P,\ p_{(b,pos,i)} = PE_{(pos,i)}
where X,P\in\mathbb{R}^{batch\_sz\times seq\_len\times d_{model}}, i.e. X and P are tensors of shape [batch\_sz, seq\_len, d_{model}]; p_{(b,pos,i)} is the element of P at the corresponding position, pos is the token's position in the sequence of length seq\_len, and i is the dimension index within d_{model}.
\textbf{In}: [batch\_sz, seq\_len, d_{model}], \textbf{Out}: [batch\_sz, seq\_len, d_{model}]
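A sketch of the sin/cos positional encoding; the table P is precomputed once for up to `max_len` positions (an assumed hyperparameter) and broadcast-added to X in the forward pass.

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Adds the fixed sin/cos positional encoding P to the input X."""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        # 10000^(-2i/d_model) for even dimension indices 2i
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # PE(pos, 2i)
        pe[:, 1::2] = torch.cos(position * div_term)  # PE(pos, 2i+1)
        self.register_buffer("pe", pe.unsqueeze(0))   # [1, max_len, d_model]

    def forward(self, x):
        # x: [batch_sz, seq_len, d_model]; P is broadcast over the batch dimension
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
```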
Subsequent Mask
Also called the "Causal Attention Mask" (the term used in FlashAttention). It is used in the Decoder's attention networks to block information from positions after the one being predicted, i.e. the prediction is based only on the current position and the positions before it.
The mask is applied to the matrix QK^{\top}/\sqrt{d_k}. It is a lower-triangular matrix including the diagonal (it keeps the results where the seq\_len index i of Q is greater than or equal to the seq\_len index j of K^{\top}); the entries where the mask is 0 (the upper-triangular part) are replaced with a very small value (e.g. -1e9).
\text{shape}: [1, seq\_len, seq\_len]
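A sketch of the subsequent/causal mask: the returned boolean tensor has shape [1, seq_len, seq_len] with True on and below the diagonal, and the attention code above fills the False (0) positions with -1e9 before the softmax.

```python
import torch

def subsequent_mask(size):
    """Causal attention mask of shape [1, size, size]: True on and below the
    diagonal (positions a query may attend to), False above it."""
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return upper == 0

# Example: subsequent_mask(3) ->
# [[[1, 0, 0],
#   [1, 1, 0],
#   [1, 1, 1]]]
```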
Code Implementations
- Official Colab notebook of The Annotated Transformer: AnnotatedTransformer.ipynb
- Colab notebook with detailed Chinese annotations: AnnotatedTransformer.ipynb
- Official GitHub repository of The Annotated Transformer: harvardnlp/annotated-transformer
- GitHub repository with detailed Chinese annotations and the model code split into separate files: peakcrosser7/annotated-transformer
References
- Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[J]. Advances in neural information processing systems, 2017, 30. https://dl.acm.org/doi/10.5555/3295222.3295349
- The Annotated Transformer - Harvard University
- Self-Attention v/s Attention: understanding the differences | by Nishant Usapkar | Medium
- Self attention vs attention in transformers | MLearning.ai
- 【重新了解Transformer模型系列_1】PostNorm/PreNorm的差别 - 知乎