参考博客
BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm总结
PyTorch学习之归一化层(BatchNorm、LayerNorm、InstanceNorm、GroupNorm)
BN,LN,IN,GN从学术化上解释差异:
BatchNorm:batch方向做归一化,算NHW的均值,对小batchsize效果不好;BN主要缺点是对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布GroupNorm:将channel方向分group,然后每个group内做归一化,算(C//G)HW的均值;这样与batchsize无关,不受其约束。
BatchNorm
沿着通道计算每个batch的均值和方差, 因此计算的结果和batch_size有关===>缺点:对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布;
# x_shape:[B, C, H, W] x_mean = np.mean(x, axis=(0, 2, 3), keepdims=True) x_var = np.var(x, axis=(0, 2, 3), keepdims=True0)
算法过程:
- 沿着通道计算每个batch的均值u
- 沿着通道计算每个batch的方差σ^2
- 对x做归一化,x’=(x-u)/开根号(σ^2+ε)
- 加入缩放和平移变量γ和β ,归一化后的值,y=γx’+β
加入缩放平移变量的原因是:保证每一次数据经过归一化后还保留原有学习来的特征,同时又能完成归一化操作,加速训练。 这两个参数是用来学习的参数。
实现公式:
torch.nn.BatchNorm1d
(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.nn.BatchNorm2d
(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.nn.BatchNorm3d
(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
使用
BN = torch.nn.BatchNorm2d(num_features=3, eps=1e-6, affine=True)
print(x.shape) #torch.Size([6, 3, 2, 2])
print(x)
x = BN(x)
print(x)
print(x.shape) #torch.Size([6, 3, 2, 2])
GroupNorm
GN 特点是与批处理大小无关,不受其约束
主要是针对Batch Normalization对小batchsize效果差,GN将channel方向分group,然后每个group内做归一化,算(C//G)HW的均值,这样与batchsize无关,不受其约束。
x_mean = np.mean(x, axis=(2, 3, 4), keepdims=True) x_var = np.var(x, axis=(2, 3, 4), keepdims=True0)
torch.nn.GroupNorm
(num_groups, num_channels, eps=1e-05, affine=True)
参数:
num_groups:需要划分为的groups
num_features: 来自期望输入的特征数,该期望输入的大小为’batch_size x num_features [x width]’
eps: 为保证数值稳定性(分母不能趋近或取0),给分母加上的值。默认为1e-5。
momentum: 动态均值和动态方差所使用的动量。默认为0.1。
affine: 布尔值,当设为true,给该层添加可学习的仿射变换参数。
实现公式
使用
# 随机生成1-10范围内的随机数, 【批处理大小,通道数,宽,高】
x = np.random.randint(1,10, [6,3,2,2])
x = torch.FloatTensor(x)
GN= torch.nn.GroupNorm(num_groups=3, num_channels=3, eps=1e-6, affine=True)
print(x.shape) #torch.Size([6, 3, 2, 2])
print(x)
x = GN(x)
print(x)
print(x.shape) #torch.Size([6, 3, 2, 2])
生成的数据,验证计算公式:
以下图床