When writing neural networks, we often use layer norm to speed up convergence. But over which dimension does layer norm actually normalize?
1. Problem Description
First, let me borrow a figure from a Zhihu answer; the original post is also very well written and worth reading, and the link is given in the references. As shown on the left of the figure, suppose the input has shape (bs, seq_len, embedding), where bs is the batch size, seq_len is the sequence length, and embedding is the embedding size.
When applying layer norm, do we take the mean and variance over the whole (seq_len, embedding) matrix (upper figure), or only over the embedding dimension (lower figure)? The former yields bs means and variances, while the latter yields bs * seq_len of them. Below we verify which one it is in code.
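To make the two options concrete, here is a minimal sketch (the tensor x and the variable names are my own) that compares only the shapes of the resulting statistics:
import torch

bs, seq_len, embedding_dim = 2, 3, 4
x = torch.randn(bs, seq_len, embedding_dim)

# Hypothesis 1: statistics over (seq_len, embedding) -> one mean/var per sample
mean_1 = x.mean(dim=(1, 2), keepdim=True)   # shape (2, 1, 1), i.e. bs statistics
# Hypothesis 2: statistics over embedding only -> one mean/var per position
mean_2 = x.mean(dim=-1, keepdim=True)       # shape (2, 3, 1), i.e. bs * seq_len statistics

print(mean_1.shape, mean_2.shape)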
2. Implementation in Code
import torch

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)

# PyTorch's LayerNorm with normalized_shape=dim normalizes over the last dimension
layer_norm = torch.nn.LayerNorm(dim, elementwise_affine=False)
print("Result from PyTorch's LayerNorm:\n", layer_norm(embedding))

# Hand-written version: mean and biased variance over the last dimension only
print("Result from the hand-written layer norm:")
eps = 1e-5
mean = embedding.mean(dim=-1, keepdim=True)
var = torch.square(embedding - mean).mean(dim=-1, keepdim=True)
print("mean: ", mean.shape)
print("y_custom: ", (embedding - mean) / torch.sqrt(var + eps))
Output:
Result from PyTorch's LayerNorm:
tensor([[[ 0.7475, -1.7061, 0.6676, 0.2910],
[ 0.1144, -0.6476, 1.5753, -1.0421],
[-1.0278, -0.7498, 0.2559, 1.5218]],
[[-1.0527, -0.8723, 1.3354, 0.5895],
[-0.6403, -1.1399, 1.4842, 0.2961],
[ 0.7352, -0.8236, -1.1342, 1.2226]]])
Result from the hand-written layer norm:
mean: torch.Size([2, 3, 1])
y_custom: tensor([[[ 0.7475, -1.7061, 0.6676, 0.2910],
[ 0.1144, -0.6476, 1.5753, -1.0421],
[-1.0278, -0.7498, 0.2559, 1.5218]],
[[-1.0527, -0.8723, 1.3354, 0.5895],
[-0.6403, -1.1399, 1.4842, 0.2961],
[ 0.7352, -0.8236, -1.1342, 1.2226]]])
The two results are identical. As we can see, the mean and variance are computed over the last dimension only, so we get bs * seq_len means and variances rather than bs of them. The second hypothesis is the correct one.
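Note that if you do want the first behavior, torch.nn.LayerNorm can provide it as well: passing a multi-dimensional normalized_shape makes it normalize over the last len(normalized_shape) dimensions jointly. A small sketch, reusing embedding from the snippet above:
# normalized_shape=[seq_size, dim] normalizes over the last two dimensions
# jointly, i.e., one mean/variance per sample (the first hypothesis)
layer_norm_2d = torch.nn.LayerNorm([seq_size, dim], elementwise_affine=False)
out = layer_norm_2d(embedding)

mean = embedding.mean(dim=(-2, -1), keepdim=True)                # shape (2, 1, 1)
var = embedding.var(dim=(-2, -1), unbiased=False, keepdim=True)  # biased variance
print(torch.allclose(out, (embedding - mean) / torch.sqrt(var + 1e-5), atol=1e-6))  # True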
In fact, for an input of this shape, this implementation is identical to Instance Norm:
from torch.nn import InstanceNorm2d

# InstanceNorm2d expects input of shape (N, C, H, W), so reshape (2, 3, 4) to (2, 3, 4, 1)
instance_norm = InstanceNorm2d(3, affine=False)
output = instance_norm(embedding.reshape(2, 3, 4, 1))
print(output.reshape(2, 3, 4))

# Apply LayerNorm over the last dimension of the same tensor for comparison
layer_norm = torch.nn.LayerNorm(4, elementwise_affine=False)
print(layer_norm(embedding))
Output:
tensor([[[ 0.7475, -1.7061, 0.6676, 0.2910],
[ 0.1144, -0.6476, 1.5753, -1.0421],
[-1.0278, -0.7498, 0.2559, 1.5218]],
[[-1.0527, -0.8723, 1.3354, 0.5895],
[-0.6403, -1.1399, 1.4842, 0.2961],
[ 0.7352, -0.8236, -1.1342, 1.2226]]])
tensor([[[ 0.7475, -1.7061, 0.6676, 0.2910],
[ 0.1144, -0.6476, 1.5753, -1.0421],
[-1.0278, -0.7498, 0.2559, 1.5218]],
[[-1.0527, -0.8723, 1.3354, 0.5895],
[-0.6403, -1.1399, 1.4842, 0.2961],
[ 0.7352, -0.8236, -1.1342, 1.2226]]])
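The two printouts match exactly. Rather than comparing them by eye, we can also check the equivalence programmatically; a short sketch reusing embedding, instance_norm, and layer_norm from above:
# Both modules normalize each length-4 row (one per sample and position)
# independently with eps = 1e-5, so the outputs agree up to float tolerance
ln_out = layer_norm(embedding)
in_out = instance_norm(embedding.reshape(2, 3, 4, 1)).reshape(2, 3, 4)
print(torch.allclose(ln_out, in_out, atol=1e-6))  # True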
3. References
Why does the Transformer use LayerNorm? (为什么Transformer要用LayerNorm?) - Zhihu: https://www.zhihu.com/question/487766088/answer/2644783144