1、原理
在图像预处理过程中会对图像进行标准化处理,这样能够加速网络的收敛速度。
如下图所示,对于Conv1来说输入的是满足某一分布的特征矩阵,但对于Conv2来说输入的feature map就不一定满足某一分布规律。 Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。
注意:这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,即要计算出整个训练集的feature map然后再进行标准化处理。
对于一个大型的数据集来说,这明显是不可能计算的,所以论文中说的是Batch Normalization,也就是我们计算一个Batch的feature map然后再进行标准化(Batch越大越接近整个数据集的分布,效果越好)。
原论文中说“对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。” 假设输入的x是RGB三通道的彩色图像,那么d就是输入图像的channels,即d=3。x=(x1, x2, x3),其中x1, x2, x3分别代表R、G、B通道所对应的特征矩阵。所以标准化就是:分别对我们的R通道,G通道,B通道进行处理。公式如下所示:
2、计算过程
上图展示了一个batch size=2(两张图片)的Batch Normalization的计算过程:假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,它们的channel都为2,x(1)代表这个batch里所有feature的channel1的数据,x(2)代表这个batch里所有feature的channel2的数据,根据公式分别计算每个channel的标准化值。
3、计算细节
在训练网络的过程中,我们通过一个batch一个batch的数据进行训练,但是在预测过程中,通常都是输入一张图片进行预测,此时batch size=1,如果在通过上述方法计算均值和方差就没有意义了。所以在训练过程中要不断的计算每个batch的均值和方差,并使用移动平均(moving average)的方法来记录统计的均值和方差,在训练完后可以近似认为所统计的均值和方差等于整个训练集的均值和方差。然后在验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理。
4、Pytorch实验
bn_process函数是自定义的BN处理方法,用来验证与使用官方的BN处理方法得到结果是否一致:
- 在bn_process中计算batch里所有feature的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方);
- 然后使用均值和总体标准差对feature的每个维度进行标准化处理;
- 最后使用均值和样本标准差对均值和标准差进行更新统计。
import numpy as np
import torch.nn as nn
import torch
def bn_process(feature, mean, var):
feature_shape = feature.shape
for i in range(feature_shape[1]):
# [batch, channel, height, width]
feature_t = feature[:, i, :, :]
mean_t = feature_t.mean()
# 总体标准差
std_t1 = feature_t.std()
# 样本标准差
std_t2 = feature_t.std(ddof=1)
# bn process
# 这里记得加上eps和pytorch保持一致
feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)
# update calculating mean and var
mean[i] = mean[i] * 0.9 + mean_t * 0.1
var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
print(feature)
# 随机生成一个batch为2,channel为2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())
# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)
bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)
打印自定义bn_process函数得到的输出以及使用官方的BN处理方法得到输出,明显结果是一样的(只是精度不同)。
5、使用BN时的注意事项
- 训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。
- batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。(设置的大小具体还是要看电脑配置)
- 建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,即使使用了偏置bias求出的结果也是一样的。