在《大语言模型之四-LlaMA-2从模型到应用》的LLama-2推理图中可以看到,在输入“你好!”时,是串行进行的,即先输入“你”这个token,然后是“好”,再然后是“!”token,前一个token需要保留前面的k和v矩阵,这就意味着随着输入sequence length的增长,需要的内存也会快速增长,计算量也会快速增长。这也显示了Transformer尽管在模型训练的时候并发(相比RNN)性能好,且模型的效果也好,但是推理的时候效率就比较低。
RetNet特点
微软提出的RetNet在训练并发、模型效果以及推理效率上都取得了不错的效果。下图是其paper中关于模型性能和推理效率和Transformer的对比情况。
其官方paper宣称,其实验数据显示,在语言建模任务上:
RetNet 可以达到与 Transformer 相当的困惑度(perplexity)
推理速度达8.4倍
推理算法延迟降低90%
内存占用减少70%
具有良好的扩展性
实现以上改进是因为RetNet 在 Transformer 的基础上,使用多尺度保持(Retention)机制替代了标准的自注意力机制。
与标准自注意力机制相比,保持机制有几大特点:
- 引入位置相关的指数衰减项取代 softmax,简化了计算,同时使前步的信息以衰减的形式保留下来。
- 引入复数空间表达位置信息,取代绝对或相对位置编码,容易转换为递归形式。
- 保持机制使用多尺度的衰减率,增加了模型的表达能力,并利用 GroupNorm 的缩放不变性来提高 Retention 层的数值精度。
训练并发
因为Transformer采用了self-attention机制,每一阶段的输出都可以用Q,K,V进行并发处理,这大大提高了GPU利用率,提高了训练效率。RNN网络的好处是推理的效率高(相比开篇提的KV历史是不需要保留的,这使得计算量和内存都极大减少了),内存复杂度低O(1)。
RetNet的巧妙之处在于,训练的时候依然类似Transformer的并发结构(如图3左边),而在推理的时候则可以采用RNN的结构(如图3右边)。即实现了parallel training, recurrent/chunk-wise inference。
RetNet改进点
RetNet相比于Transformer主要有两点改进:
1.引入multi-scale retention替代了multi-head attention;
2.RetNet可以用三种(parallel/recurrent/chunk-wise )方式实现,公式和结果上是相等的,因而可以在训练和推理的时候选择最为高效的方式实现;chunk-wise可以更高效的处理长sequence的序列情况。
并行训练
RetNet摒弃了softmax操作,引入了基于D矩阵的Hadamard积(对应元素相乘),然后是GroupNorm操作。Transformer中softmax为输入序列中的每个token提供了相对的注意力权重,有助于模型学习和保留长期依赖关系。之前也有一些研究是舍弃softmax运算,但是会降低模型性能。
具体来说Transformer中的softmax主要实现了两个目标:
1.对不同的时间步长采用不同的方式加权,这有助于模型注意力放在序列中应该感兴趣的部分,这是相比RNN而言最重要的贡献。RetNet论文中的D矩阵实现注意力机制,D矩阵是一个因果注意力矩阵,即序列的当前输入只能看到过去的信息,D-矩阵假设最近的时间步长比过去的时间步长更重要,因此采用了指数衰减权重。因此,softmax足够灵活,可以对不同的步骤做不同的权重估计,而D-矩阵以固定的预定义方式(指数衰减)权衡所有步骤。最终paper里的结果显示RetNet效果是比Transformer好的。
2.引入非线性,当没有softmax的时候,
Q
K
T
QK^T
QKT就是一种仿射变换(即从一种空间的标识变为另一种空间的表示),再多层的注意力堆叠也依然是一种仿射变换,非线性是通过GroupNorm的方式实现的,为什么是GroupNorm实现非线性,论文中并没有提及,似乎是盲测多种结构在GroupNorm时效果是最好的。
Transformer和RetNet的异同
这里我将博客《大语言模型之四-LlaMA-2从模型到应用》中LlaMA-2的推理流程图展示在这里了,该图的详细说明见博客。
图中第一步是用
W
q
,
W
v
,
W
k
W_q, W_v, W_k
Wq,Wv,Wk得到
Q
,
V
,
和
K
Q,V,和K
Q,V,和K,即
Q
=
X
W
q
Q=XW_q
Q=XWq,
K
=
X
W
k
K=XW_k
K=XWk,
V
=
X
W
v
V=XW_v
V=XWv,然后通过softmax得到Attention score
o
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
V
(
1
)
o= softmax (\frac{Q K^T}{\sqrt{d_k}})V (1)
o=softmax(dkQKT)V(1)
由于RetNet在循环网络结构和并行网络结构都可以得到一样的结果,所以论文的作者现在循环结构中实现激励的“保留”模块,然后将“保留”模块向量化,因此得到的第n个序列输入的结果(保留注意力)为:
o
n
=
∑
m
=
1
n
(
Q
n
A
n
−
m
K
m
T
)
v
m
,
Q
n
∈
R
1
×
d
(
2
)
o_n=\sum_{m=1}^n(Q_nA^{n-m}K_m^T)v_m, Q_n \in \mathbb{R}^{1\times d} (2)
on=m=1∑n(QnAn−mKmT)vm,Qn∈R1×d(2)
从上面公式1和2可以看到RetNet网络使用pos矩阵
A
n
−
m
A^{n-m}
An−m替换掉了Transformer中的softmax的Attention部分(非线性部分这里还没替换掉)。Retention score的想法是这样的,首先通过循环网络状态
s
(
n
)
\mathbf s(n)
s(n)做映射
v
(
n
)
→
o
(
n
)
v(n) \rightarrow o(n)
v(n)→o(n),即
s
n
=
A
s
n
−
1
+
K
n
T
v
n
,
A
∈
R
d
×
d
,
K
n
∈
R
1
×
d
(
3
)
\mathbf s_n = A \mathbf s_{n-1} +K_n^Tv_n, A \in \mathbb R^{d\times d}, K_n \in \mathbb R^{1 \times d} (3)
sn=Asn−1+KnTvn,A∈Rd×d,Kn∈R1×d(3)
然后使用线性变换递归对序列进行编码:
o
n
=
Q
n
s
n
=
∑
m
=
1
n
(
Q
n
A
n
−
m
K
m
T
)
v
m
,
Q
n
∈
R
1
×
d
(
4
)
o_n=Q_n \mathbf s_n=\sum_{m=1}^n(Q_nA^{n-m}K_m^T)v_m, Q_n \in \mathbb{R}^{1\times d}(4)
on=Qnsn=m=1∑n(QnAn−mKmT)vm,Qn∈R1×d(4)
根据原论文的公式3,可以将位置分为共轭的两个部分。
R
e
t
e
n
t
i
o
n
(
x
)
=
∑
m
=
1
n
(
Q
n
(
γ
e
i
θ
)
n
)
(
K
m
(
γ
e
i
θ
)
−
m
)
T
v
m
,
γ
,
θ
∈
R
d
(
5
)
Retention(x) = \sum_{m=1}^n(Q_n(\gamma e^{i \theta})^{n})(K_m(\gamma e^{i\theta})^{-m})^Tv_m, \gamma, \theta \in \mathbb R^d (5)
Retention(x)=m=1∑n(Qn(γeiθ)n)(Km(γeiθ)−m)Tvm,γ,θ∈Rd(5)
其中
Q
n
(
γ
e
i
θ
)
n
Q_n(\gamma e^{i \theta})^{n}
Qn(γeiθ)n和
K
m
(
γ
e
i
θ
)
−
m
K_m(\gamma e^{i\theta})^{-m}
Km(γeiθ)−m是位置矩阵,可以采用论文所述的xPos,为了简化上述的方程,可以用标量值替代
γ
\gamma
γ,这样在训练的时候可以采用并发的方式训练:
这样可以看到在RetNet时候,在得到Q、K以及V的方法和原始的Transformer一样是可以并发进行的,
e
i
n
θ
e^{in\theta}
einθ对应的位置信息逐点相乘即可。最后一步的Retention score使用到的D矩阵也是可以提前计算的,因为它只是一个相对位置嵌入+因果掩码。
图 3 RetNet的两种实现方式
公式6对应的就是图3左边的实现方式,在推理的时候采用循环网络的架构,即图3中右边的实现方式,公式如下7所示。
S
n
=
γ
S
n
−
1
+
K
n
T
V
n
,
R
e
t
e
n
t
i
o
n
(
X
n
)
=
Q
n
S
n
,
n
=
1
,
⋯
,
∣
x
∣
(
7
)
S_n=\gamma S_{n-1}+K_n^TV_n , Retention(X_n)=Q_nS_n, n=1,\cdots, |x| (7)
Sn=γSn−1+KnTVn,Retention(Xn)=QnSn,n=1,⋯,∣x∣(7)
其中的Q、K、V以及
γ
\gamma
γ的作用和意义和上式6是一样的。
RetNet并行计算过程
假设仅有“你好”这两个token输入序列,长度记为N,Embedding dim,D=3,则得到的QKV是矩阵的维度是NxD,假设初始的矩阵为:
Q
=
[
1
2
1
3
2
3
]
,
K
=
[
1
2
3
4
5
6
]
,
V
=
[
5
4
3
2
1
0
]
Q=\begin{bmatrix} 1 & 2 & 1 \\ 3 & 2 &3 \end{bmatrix},K=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 &6 \end{bmatrix},V=\begin{bmatrix} 5 & 4 & 3 \\ 2 & 1 &0 \end{bmatrix}
Q=[132213],K=[142536],V=[524130]
第一步:计算
Q
K
T
QK^T
QKT
Q
K
T
=
[
1
2
1
3
2
3
]
[
1
4
2
5
3
6
]
=
[
8
20
16
40
]
(
8
)
QK^T=\begin{bmatrix} 1 & 2 & 1 \\ 3 & 2 &3 \end{bmatrix} \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix}=\begin{bmatrix} 8 & 20 \\ 16 & 40 \end{bmatrix} (8)
QKT=[132213]
123456
=[8162040](8)
第二步:计算
Q
K
T
QK^T
QKT和
D
D
D矩阵的Hadamard积(对应元素相乘)
Q
K
T
⊙
D
=
[
8
20
16
40
]
⊙
[
1
0
0.25
1
]
=
[
8
4
0
40
]
(
9
)
QK^T\odot D=\begin{bmatrix} 8 & 20 \\ 16 & 40 \end{bmatrix} \odot \begin{bmatrix} 1 & 0 \\ 0.25 & 1 \end{bmatrix}= \begin{bmatrix} 8 & 4 \\ 0& 40 \end{bmatrix}(9)
QKT⊙D=[8162040]⊙[10.2501]=[80440](9)
其中
D
=
[
γ
1
−
0
γ
1
−
1
γ
2
−
0
γ
2
−
1
]
D=\begin{bmatrix} \gamma^{1-0} & \gamma^{1-1} \\ \gamma^{2-0} & \gamma^{2-1} \end{bmatrix}
D=[γ1−0γ2−0γ1−1γ2−1],当
γ
=
0.5
\gamma=0.5
γ=0.5时可以得到式9。
第三步和V相乘:
(
Q
K
T
⊙
D
)
V
=
[
8
4
0
40
]
[
5
4
3
2
1
0
]
=
[
40
32
24
100
56
12
]
(
10
)
(QK^T \odot D)V = \begin{bmatrix} 8 & 4 \\ 0& 40 \end{bmatrix} \begin{bmatrix} 5 & 4 & 3\\ 2& 1 & 0 \end{bmatrix} = \begin{bmatrix} 40 & 32 & 24\\ 100& 56 & 12 \end{bmatrix} (10)
(QKT⊙D)V=[80440][524130]=[4010032562412](10)
这样就得到了两个token输入时的最终上下文Embedding结果。
RetNet循环网络计算过程
图3的右侧所示过程。这里的Q、K、V和上面的并行计算并不是一样的,这里有下标n,这表明了是第n个输入token对应的矩阵,因而是一个1xD维矩阵(上一小节是NxD维),另外一个区别是含有当前token之前时间和位置信息的状态S,当前状态和前一个状态使用指数衰减因子,如公式7所示。
在计算的时候先是KV不再是QK相乘,图3中可以看到,和并行计算一样,初始的矩阵为
Q
=
[
1
2
1
3
2
3
]
,
K
=
[
1
2
3
4
5
6
]
,
V
=
[
5
4
3
2
1
0
]
Q=\begin{bmatrix} 1 & 2 & 1 \\ 3 & 2 &3 \end{bmatrix},K=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 &6 \end{bmatrix},V=\begin{bmatrix} 5 & 4 & 3 \\ 2 & 1 &0 \end{bmatrix}
Q=[132213],K=[142536],V=[524130]
第一步计算
K
1
T
⊗
V
1
K_1^T\otimes V_1
K1T⊗V1
K
1
T
V
1
=
[
1
2
3
]
[
5
4
3
]
=
[
5
4
3
10
8
6
15
12
9
]
(
11
)
K_1^T V_1 = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} \begin{bmatrix} 5 & 4 & 3\end{bmatrix} = \begin{bmatrix} 5 & 4 & 3\\ 10& 8 & 6 \\ 15 & 12 &9 \end{bmatrix} (11)
K1TV1=
123
[543]=
510154812369
(11)
第二步计算
S
1
S_1
S1
因为在输入“你”token之前并没有其他token输入,因而
S
0
S_0
S0是不存在的,因而先前没有状态叠加到“你”这个token对应的状态的。
S
1
=
γ
0
S
0
+
K
1
T
V
1
=
[
5
4
3
10
8
6
15
12
9
]
(
12
)
S_1=\gamma^{0}S_0 + K_1^TV_1=\begin{bmatrix} 5 & 4 & 3\\ 10& 8 & 6 \\ 15 & 12 &9 \end{bmatrix} (12)
S1=γ0S0+K1TV1=
510154812369
(12)
第三步将Q和
S
1
S_1
S1相乘得到最终的Attention score。
Q
1
⊗
S
1
=
[
5
4
3
10
8
6
15
12
9
]
⊗
[
1
2
1
]
=
[
5
4
3
20
16
12
15
12
9
]
s
u
m
=
[
40
32
24
]
(
13
)
Q_1 \otimes S_1 = \begin{bmatrix} 5 & 4 & 3\\ 10& 8 & 6 \\ 15 & 12 &9 \end{bmatrix} \otimes \begin{bmatrix} 1 \\ 2 \\ 1 \end{bmatrix} = \begin{bmatrix} 5 & 4 & 3\\ 20& 16 & 12 \\ 15 & 12 &9 \end{bmatrix}_{sum} = \begin{bmatrix} 40 & 32 & 24\end{bmatrix} (13)
Q1⊗S1=
510154812369
⊗
121
=
52015416123129
sum=[403224](13)
这和公式10最终结果的第一行是一样的。
第四步计算
K
2
T
V
2
K_2^TV_2
K2TV2,和第一步类似,得到:
K
2
T
V
2
=
[
4
5
6
]
[
2
1
0
]
=
[
8
4
0
10
5
0
12
6
0
]
(
14
)
K_2^T V_2 = \begin{bmatrix} 4 \\ 5 \\ 6 \end{bmatrix} \begin{bmatrix} 2 & 1 & 0\end{bmatrix} = \begin{bmatrix} 8 & 4 & 0\\ 10& 5 & 0 \\ 12 & 6 & 0 \end{bmatrix} (14)
K2TV2=
456
[210]=
81012456000
(14)
第五步计算
S
2
S_2
S2,
S
2
=
γ
2
S
1
+
K
2
T
V
2
=
0.
5
2
[
5
4
3
20
16
12
15
12
9
]
+
[
8
4
0
10
5
0
12
6
0
]
=
[
9.25
0.25
0.75
12.5
7
1.5
15.5
9
2.25
]
(
15
)
S_2=\gamma^{2}S_1 + K_2^TV_2=0.5^2\begin{bmatrix} 5 & 4 & 3\\ 20& 16 & 12 \\ 15 & 12 & 9 \end{bmatrix} + \begin{bmatrix} 8 & 4 & 0\\ 10& 5 & 0 \\ 12 & 6 & 0 \end{bmatrix} = \begin{bmatrix} 9.25 & 0.25 & 0.75\\ 12.5& 7 & 1.5 \\ 15.5 & 9 & 2.25 \end{bmatrix} (15)
S2=γ2S1+K2TV2=0.52
52015416123129
+
81012456000
=
9.2512.515.50.25790.751.52.25
(15)
第六步计算最终RetNet score。
Q
2
⊗
S
2
=
[
9.25
0.25
0.75
12.5
7
1.5
15.5
9
2.25
]
⊗
[
3
2
3
]
(
16
)
=
[
27.25
15
2.25
25
14
3
47.25
27
6.75
]
s
u
m
=
[
100
56
12
]
(
16
)
Q_2 \otimes S_2 = \begin{bmatrix} 9.25 & 0.25 & 0.75\\ 12.5& 7 & 1.5 \\ 15.5 & 9 & 2.25 \end{bmatrix} \otimes \begin{bmatrix} 3 \\ 2 \\ 3 \end{bmatrix}(16) =\begin{bmatrix} 27.25 & 15 & 2.25 \\ 25& 14 & 3 \\ 47.25 & 27 & 6.75 \end{bmatrix}_{sum} = \begin{bmatrix} 100 & 56 & 12\end{bmatrix} (16)
Q2⊗S2=
9.2512.515.50.25790.751.52.25
⊗
323
(16)=
27.252547.251514272.2536.75
sum=[1005612](16)
这样公式16的结果正好等于公式10,从公式13和16结果看和并行计算的结果是一样的,但是这中间并没有太多的存储空间需求,也没有非线性softmax计算。