[Bug] Why are all the values of the Transformer output tensor the same?!
- Symptom
- Cause
- Fix
Symptom
After the input passes through a TransformerEncoderLayer, essentially all output vectors are identical. The core code is as follows:
import torch
from torch.nn import TransformerEncoderLayer

# norm_first defaults to False, i.e. the Post-LN variant
self.trans = TransformerEncoderLayer(d_model=2,
                                     nhead=2,
                                     batch_first=True)
...
x = torch.randn(2, 8, 2)
print("x before transformer", x, x.shape)
x = self.trans(x)  # Transformer encoder layer
print("x after transformer", x, x.shape)
Output:
x before transformer tensor([[[ 0.2244, -1.9497],
[ 0.4710, -0.7532],
[-1.4016, 0.5266],
[-1.1386, -2.5170],
[-0.0733, 0.0240],
[-0.9647, -0.9760],
[ 2.4195, -0.0135],
[-0.3929, 1.2231]],
[[ 0.1451, -1.2050],
[-1.1139, -1.7213],
[ 0.5105, 0.4111],
[ 2.1308, 2.5476],
[ 1.2611, -0.7307],
[-2.0910, 0.1941],
[-0.3903, 1.3022],
[-0.2442, 0.5787]]]) torch.Size([2, 8, 2])
x after transformer tensor([[[ 1.0000, -1.0000],
[ 1.0000, -1.0000],
[-1.0000, 1.0000],
[ 1.0000, -1.0000],
[-1.0000, 1.0000],
[ 1.0000, -1.0000],
[ 1.0000, -1.0000],
[-1.0000, 1.0000]],
[[ 1.0000, -1.0000],
[ 1.0000, -1.0000],
[ 1.0000, -1.0000],
[-1.0000, 1.0000],
[ 1.0000, -1.0000],
[-1.0000, 1.0000],
[-1.0000, 1.0000],
[-1.0000, 1.0000]]], grad_fn=<NativeLayerNormBackward0>) torch.Size([2, 8, 2])
Cause
After consulting the all-knowing, all-powerful New Bing, I found a paper.

A brief introduction to techniques for simplifying Transformer training:
Understanding the Difficulty of Training Transformers
Year: 2020
Citations: 124
Venue: EMNLP 2020
Code: https://github.com/LiyuanLucasLiu/Transformer-Clinic

The placement of Layer Norm in a Transformer is critical. With Post-LN, the model can be unstable with respect to its parameters and training may fail outright; Pre-LN does not suffer from this. The original Transformer paper uses Post-LN, and in general Post-LN performs better than Pre-LN. To address the instability, Understanding the Difficulty of Training Transformers proposes the Admin initialization, which retains Post-LN's performance while keeping training stable.

There is also a direct way to see why our particular output collapses: in the Post-LN variant, the very last operation of the layer is a LayerNorm (note the grad_fn=<NativeLayerNormBackward0> in the printout above). With d_model=2, normalizing each 2-dimensional token vector to zero mean and unit variance maps every vector to approximately (+1, -1) or (-1, +1), which is exactly what we observed.
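To see this collapse in isolation, here is a minimal sketch of LayerNorm acting on 2-dimensional vectors (the input values are arbitrary):

import torch

ln = torch.nn.LayerNorm(2)  # same normalized width as d_model=2

x = torch.randn(5, 2)
print(ln(x))
# Every row comes out as roughly (+1, -1) or (-1, +1): after removing
# the mean and dividing by the std, a 2-vector keeps only the sign of
# the difference between its two components.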
Fix
Here we use Pre-LN instead. torch.nn.TransformerEncoderLayer already exposes this through the norm_first option:
self.trans = TransformerEncoderLayer(d_model=2,
                                     nhead=2,
                                     batch_first=True,
                                     norm_first=True)  # Pre-LN
After this change, the output:
x before transformer tensor([[[ 0.5373, 0.9244],
[ 0.6239, -1.0643],
[-0.5129, -1.1713],
[ 0.5635, -0.7778],
[ 0.4507, -0.0937],
[ 0.2720, 0.7870],
[-0.5518, 0.8583],
[ 1.5244, 0.5447]],
[[ 0.3450, -1.9995],
[ 0.0530, -0.9778],
[ 0.8687, -0.6834],
[-1.6290, 1.6586],
[ 1.2630, 0.4155],
[-2.0108, 0.9131],
[-0.0511, -0.8622],
[ 1.5726, -0.7042]]]) torch.Size([2, 8, 2])
x after transformer tensor([[[ 0.5587, 0.9392],
[ 0.5943, -1.0631],
[-0.5196, -1.1681],
[ 0.5635, -0.7765],
[ 0.4341, -0.0819],
[ 0.2943, 0.7998],
[-0.5329, 0.8661],
[ 1.5166, 0.5528]],
[[ 0.3450, -1.9860],
[ 0.0273, -0.9603],
[ 0.8415, -0.6682],
[-1.6297, 1.6686],
[ 1.2261, 0.4175],
[-2.0205, 0.9314],
[-0.0595, -0.8421],
[ 1.5567, -0.6847]]], grad_fn=<AddBackward0>) torch.Size([2, 8, 2])
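For completeness, here is a self-contained sketch that puts the two variants side by side (the seed and sizes are my own choices, and the layers are randomly initialized, so the exact numbers will differ from the printouts above):

import torch
from torch.nn import TransformerEncoderLayer

torch.manual_seed(0)
x = torch.randn(2, 8, 2)

# Post-LN (default, norm_first=False): each sublayer ends with LayerNorm,
#   x = norm1(x + attn(x)); x = norm2(x + ffn(x))
post_ln = TransformerEncoderLayer(d_model=2, nhead=2, batch_first=True).eval()

# Pre-LN (norm_first=True): LayerNorm comes before each sublayer,
#   x = x + attn(norm1(x)); x = x + ffn(norm2(x))
pre_ln = TransformerEncoderLayer(d_model=2, nhead=2, batch_first=True,
                                 norm_first=True).eval()

with torch.no_grad():
    print(post_ln(x))  # every token collapses to roughly (+1, -1) or (-1, +1)
    print(pre_ln(x))   # stays close to x: the residual path is never normalized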