文章目录
- Transformer中使用的Post-Norm
- 大模型常用的Pre-Norm
- Pre-Norm一定比Post-Norm好吗?
- 二者区别总结
- 参考资料
Pre-Norm和Post-Norm的区别,是面试官非常喜欢问的问题。下面我们按照时间线,尽可能直白地讲清楚二者的区别。
直观来讲,Pre-Norm 和 Post-Norm 的区别就是 Layer Norm 和 Residual Connections 组合方式的不同。
- Post-Norm:传统的Layer Norm放在残差之后,做完Add再进行归一化。早期的很多模型都用的是 Post-Norm,比如著名的 Bert
- Pre-Norm:目前大模型大多数的做法是 先对输入做Layer Norm,然后再进行函数计算(例如attention和FFN)以及Add相加。Pre Norm 的训练更快,且更加稳定,所以之后的大模型架构大多都是 Pre Norm 了,比如 GPT,MPT,Falcon,和Llama。
Transformer中使用的Post-Norm
在原始Transformer中,对数据进行Normalization的方法是Post-Layer Norm(想要对Normalization进行深入理解的小伙伴可以看这篇博客:【Transformer】Normalization),如下图所示:
Post-Norm用公式表示为:
y = N o r m ( x l + a t t n ( x l ) ) x l + 1 = N o r m ( y + F F N ( y ) ) \begin{aligned} y&=Norm(x_l + attn(x_l)) \\ x_{l+1}&=Norm(y + FFN(y)) \end{aligned} yxl+1=Norm(xl+attn(xl))=Norm(y+FFN(y))
Post-Norm 之所以这么设计,是把 Normalization 放在一个模块的最后,这样下一个模块接收到的总是归一化后的结果。这比较符合 Normalization 的初衷,就是为了降低梯度的方差。但是层层堆叠起来,从上图可以看出,深度学习的基建 ResNet 的结构其实被破坏了,也就是Residual Connections在梯度反向传播时消失了。这就导致训练 Transformer 并不是那么容易的事情,需要加上各种补偿措施,例如 learning rate warm up, 初始化等。
那么为什么Post-Norm会导致Residual Connections在梯度反向传播时消失呢?
这里我们假设输入 x l x_l xl、 A t t e n t i o n ( x l ) Attention(x_l) Attention(xl)、 F N N ( x l ) FNN(x_l) FNN(xl) 的均值为0,方差都为1,且相互独立。事实上可能没有那么理想,因为权重矩阵的分布在学习过程中并不一定能保持理想的分布,这里为了说明问题对建模进行了简化。
我们知道,对于两个均值为0,方差为1且相互独立的分布 x 1 , x 2 ∈ N ( 0 , 1 ) x_1, x_2 \in N(0, 1) x1,x2∈N(0,1),那么 ( x 1 + x 2 ) (x_1+x_2) (x1+x2)就是均值为0,方差为2的分布,那么 Layer Norm对 ( x 1 + x 2 ) (x_1+x_2) (x1+x2)的计算就可以表示为 y = x 1 + x 2 2 y=\frac{x_1+x_2}{\sqrt{2}} y=2x1+x2
所以,Post-Norm的计算公式可以简化为:
y
=
N
o
r
m
(
x
l
+
a
t
t
n
(
x
l
)
)
=
x
l
+
a
t
t
n
(
x
l
)
2
x
l
+
1
=
N
o
r
m
(
y
+
F
F
N
(
y
)
)
=
y
+
F
F
N
(
y
)
2
=
x
l
+
a
t
t
n
(
x
l
)
2
+
F
F
N
(
y
)
2
=
x
l
+
a
t
t
n
(
x
l
)
+
2
F
F
N
(
y
)
2
=
x
l
+
f
(
x
l
)
2
\begin{aligned} y&=Norm(x_l+attn(x_l)) \\ &=\frac{x_l+attn(x_l)}{\sqrt{2}} \\ x_{l+1}&=Norm(y + FFN(y)) \\ &=\frac{y+FFN(y)}{\sqrt{2}} \\ &=\frac{\frac{x_l+attn(x_l)}{\sqrt{2}}+FFN(y)}{\sqrt{2}} \\ &=\frac{x_l+attn(x_l)+\sqrt{2}FFN(y)}{2} \\ &=\frac{x_l+f(x_l)}{2} \end{aligned}
yxl+1=Norm(xl+attn(xl))=2xl+attn(xl)=Norm(y+FFN(y))=2y+FFN(y)=22xl+attn(xl)+FFN(y)=2xl+attn(xl)+2FFN(y)=2xl+f(xl)
这里我们用 f ( x l ) f(x_l) f(xl)来表示对输入 x l x_l xl的复杂计算(包括attention和FFN)。这里可以看出,输入 x l x_l xl每经过一层,输出就变成 x l + f ( x l ) 2 \frac{x_l+f(x_l)}{2} 2xl+f(xl)。那么最终的输出对于最开始的输入来说:
o u t p u t = x l − 1 + f l ( x l − 1 ) 2 = x l − 1 2 + f l ( x l − 1 ) 2 = x l − 2 2 2 + f l − 1 ( x l − 2 ) 2 2 + f l ( x l − 1 ) 2 = x l − 3 2 3 + f l − 2 ( x l − 3 ) 2 3 + f l − 1 ( x l − 2 ) 2 2 + f l ( x l − 1 ) 2 = x 1 2 l − 1 + f 2 ( x 1 ) 2 l − 1 + . . . + f l − 1 ( x l − 2 ) 2 2 + f l ( x l − 1 ) 2 = x 1 2 l − 1 + g ( x ) \begin{aligned} output&=\frac{x_{l-1}+f_{l}(x_{l-1})}{2}\\ &=\frac{x_{l-1}}{2}+\frac{f_{l}(x_{l-1})}{2} \\ &=\frac{x_{l-2}}{2^2}+\frac{f_{l-1}(x_{l-2})}{2^2}+\frac{f_l(x_{l-1})}{2} \\ &=\frac{x_{l-3}}{2^3}+\frac{f_{l-2}(x_{l-3})}{2^3}+\frac{f_{l-1}(x_{l-2})}{2^2}+\frac{f_l(x_{l-1})}{2} \\ &=\frac{x_{1}}{2^{l-1}}+\frac{f_{2}(x_{1})}{2^{l-1}}+...+\frac{f_{l-1}(x_{l-2})}{2^2}+\frac{f_l(x_{l-1})}{2} \\ &=\frac{x_{1}}{2^{l-1}}+g(\bold{x}) \end{aligned} output=2xl−1+fl(xl−1)=2xl−1+2fl(xl−1)=22xl−2+22fl−1(xl−2)+2fl(xl−1)=23xl−3+23fl−2(xl−3)+22fl−1(xl−2)+2fl(xl−1)=2l−1x1+2l−1f2(x1)+...+22fl−1(xl−2)+2fl(xl−1)=2l−1x1+g(x)
这里我们用
g
(
x
)
g(\bold{x})
g(x)表示全部网络层对输入
x
1
x_1
x1的作用结果。然后我们对最终的输出求导可得:
∂
(
x
2
l
−
1
+
g
(
x
)
)
∂
x
=
1
2
l
−
1
+
∂
g
(
x
)
∂
x
\frac{\partial\left(\frac{x}{2^{l-1}}+g(x)\right)}{\partial x}=\frac{1}{2^{l-1}}+\frac{\partial g(x)}{\partial x}
∂x∂(2l−1x+g(x))=2l−11+∂x∂g(x)
看到了吗!!Resnet的Residual Connections没了,因为一般ResNet求导结果应该是: ∂ ( f ( x ) + x ) ∂ x = 1 + ∂ f ( x ) ∂ x \frac{\partial(f(x)+x)}{\partial x}=1+\frac{\partial f(x)}{\partial x} ∂x∂(f(x)+x)=1+∂x∂f(x),这里的1起到了防止梯度消失的作用,因此可以稳定训练更深的网路模型。但是经过我们推导发现,Post-Norm求导结果中的第一项,会随着网络层数的增加而指数递减。层数较低还好,如果像是现在的大模型一样堆叠32甚至64层,那几乎和0没什么区别了,也就丧失了 ResNet 的意义。
没了 ResNet 的架构,就导致在训练 Transforemr 的时候,需要小心翼翼。一般都要加一个 learning rate warm up 的过程,先让模型在小学习率上适应一段时间,然后再正常训练。warm up 的过程虽然在 Transformers 的论文里就提了一嘴,但是真正训练的时候会发现真的很重要。
大模型常用的Pre-Norm
发表在ACL 2019上的Learning Deep Transformer Models for Machine Translation 这篇文章首次提出了Layer Normalization位置对训练一个深层的Transformer模型至关重要,并且也开启了后续大家对Layer Normalization的探索。
同样的方式,让我们来看看Pre-Norm是什么计算的?下图左侧就是我们刚推导过的,Transformer原始论文中提到的Post-Norm,而右侧,则是现在大模型常用的Pre-Norm。
Pre-Norm用公式表示为:
y = x l + a t t n ( N o r m ( x l ) ) x l + 1 = y + F F N ( N o r m ( y ) ) \begin{aligned} y&=x_l + attn(Norm(x_l)) \\ x_{l+1}&=y + FFN(Norm(y)) \end{aligned} yxl+1=xl+attn(Norm(xl))=y+FFN(Norm(y))
很明显,Pre-Norm很好的保留了ResNet的核心Residual Connections,反向传播计算梯度时很好的缓解了梯度消失的问题。
因此,到这里可以总结出Pre-Norm相比于Post-Norm是有优势的,也就是:
- Post Norm,对模型,尤其是较深的模型训练不稳定,梯度容易爆炸,学习率敏感,初始化权重敏感,收敛困难。因此需要做大量调参工作,以及learning rate warm up的必要工作,费时费力。
- Pre Norm 则在训练稳定和收敛性方面有明显的优势,所以大模型时代基本都无脑使用 Pre Norm 了。
Pre-Norm一定比Post-Norm好吗?
但是 Pre Norm 也并不是都是好的,2020年的Understanding the Difficulty of Training Transformers这篇论文指出,Pre Norm 有潜在的 Representation Collapse (表示塌陷)问题,具体来说就是靠近输出位置的层会变得非常相似,从而对模型的贡献会变小。
因此,2023年微软提出的ResiDual: Transformer with Dual Residual Connections,就试图融合 Pre Norm 和 Post Norm 的优点。
这也就暗示着 Post Norm 虽然不好训练,但是潜力可能比 Pre Norm 更好。同时这篇论文中提到的在 Layer Norm 的时候,调整
x
x
x 和
f
(
x
)
f(x)
f(x) 的比重,其思路被 DeepNorm 借鉴。只不过这里是可学习的权重,而 DeepNorm 则是超参数。
二者区别总结
-
Post Norm
- 对模型,尤其是较深的模型训练不稳定,梯度容易爆炸,学习率敏感,初始化权重敏感,收敛困难。因此需要做大量调参工作,以及learning rate warm up的必要工作,费时费力
- 潜在好处是,在效果上的优势,但是这个事情还需要大量专业的实验来验证,毕竟现在大模型训练太费钱了,Post Norm 在效果上带来的提升很可能不如多扔点数据让 Pre Norm 更快的训练出来
-
Pre Norm
- 在训练稳定和收敛性方面有明显的优势,所以大模型时代基本都无脑使用 Pre Norm 了
- 但是其可能有潜在的Representation Collapse(表示塌陷) 问题,也就是上限可能不如 Post Norm
参考资料
- [1] https://note.mowen.cn/note/detail?noteUuid=xFoeBN-Ez4OcjRDEs1b51
- [2] https://zhuanlan.zhihu.com/p/474988236