torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
功能:对输入的四维数组进行批量标准化处理(归一化)
计算公式如下:
对于所有的batch中样本的同一个channel的数据元素进行标准化处理,即如果有C个通道,无论batch中有多少个样本,都会在通道维度上进行标准化处理,一共进行C次
num_features:通道数
eps:分母中添加的值,目的是计算的稳定性(分母不出现0),默认1e-5
momentum:用于运行过程中均值方差的估计参数,默认0.1
affine:设为true时,给定开易学习的系数矩阵r和b
track_running_stats:BN中存储的均值方差是否需要更新,true需要更新
举个例子
>import torch
>import torch.nn as nn
>input = torch.arange(0, 12, dtype=torch.float32).view(1, 3, 2, 2)
>print(m)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]],
[[ 8., 9.],
[10., 11.]]]])
>bn = nn.BatchNorm2d(3)
>print(bn.weight)
tensor([1., 1., 1.], requires_grad=True)
>print(bn.bias)
tensor([0., 0., 0.], requires_grad=True)
>output = m(input)
>print(output)
tensor([[[[-1.3416, -0.4472],
[ 0.4472, 1.3416]],
[[-1.3416, -0.4472],
[ 0.4472, 1.3416]],
[[-1.3416, -0.4472],
[ 0.4472, 1.3416]]]], grad_fn=<NativeBatchNormBackward0>)
上面是使用nn接口计算,现在我们拿第一个数据计算一下验证
公式:
#先计算第一个通道的均值、方差
>first_channel = input[0][0] #第一个通道
tensor([[0., 1.],
[2., 3.]])
#1、计算均值方差
>mean = torch.Tensor.mean(first_channel)
tensor(1.5000) #均值
>var=torch.Tensor.var(first_channel,False)
tensor(1.2500) #方差
#2、按照公式计算
>bn_value =((input[0][0][0][0] -mean)/(torch.pow(var,0.5)+bn.eps))*bn.weight[0]+bn.bias[0]
#这里就是(0-1.5)/sqrt(1.25+1e-5)*1.0 + 1.0
tensor(-1.3416, grad_fn=<AddBackward0>)
第一个值都是-1.3416,对上了,其他都是一样。