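BN 的计算公式（先按通道标准化，再做可学习的缩放与平移）：

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma \hat{x} + \beta$$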
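BN 中均值与方差的计算：对每个通道 $c$，在 batch 维和空间维上求统计量（共 $m = b \cdot h \cdot w$ 个元素）：

$$\mu_c = \frac{1}{m}\sum_{n,i,j} x_{n,c,i,j}, \qquad \sigma_c^2 = \frac{1}{m}\sum_{n,i,j}\left(x_{n,c,i,j} - \mu_c\right)^2$$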
因此，对于形状为 (b, c, h, w) 的输入 x，均值 mean 与方差 var 的形状均为 (1, c, 1, 1)；对于二维输入 (b, c)，二者的形状为 (1, c)。
代码
class BatchNorm(nn.Module):
    """Batch normalization implemented from scratch for 2-D or 4-D inputs.

    In training mode each channel is normalized with the statistics of the
    current mini-batch, and exponential moving averages of those statistics
    are updated. In eval mode the moving averages are used instead. In both
    modes the result is scaled by a learnable ``gamma`` and shifted by a
    learnable ``beta``:

        x_hat = (x - mean) / sqrt(var + eps)
        out   = gamma * x_hat + beta
    """

    def __init__(self, num_features, num_dims):
        """Create the learnable affine parameters and the running statistics.

        Args:
            num_features: number of channels ``c`` (dimension 1 of the input).
            num_dims: 2 selects the ``(b, c)`` layout; any other value selects
                the 4-D ``(b, c, h, w)`` layout.
        """
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Registered as buffers (not plain attributes) so they move with the
        # module on .to()/.cuda() and are saved in state_dict(), while staying
        # out of parameters() and therefore out of the optimizer.
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_var", torch.ones(shape))

    def forward(self, x, momentum=0.9, eps=1e-5):
        """Normalize ``x`` per channel and apply the affine transform.

        Args:
            x: tensor of shape ``(b, c)`` or ``(b, c, h, w)``; its rank must
                match the layout chosen by ``num_dims`` at construction.
            momentum: weight of the *old* moving average in each update:
                ``moving = momentum * moving + (1 - momentum) * batch_stat``.
                This is the complement of ``torch.nn.BatchNorm*``'s
                ``momentum`` (0.9 here corresponds to 0.1 there).
            eps: added to the variance for numerical stability.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 2-D or 4-D, or its rank does not match
                the layout of the learnable parameters.
        """
        if x.dim() not in (2, 4):
            raise ValueError(f"expected a 2-D or 4-D input, got {x.dim()}-D")
        if x.dim() != self.gamma.dim():
            raise ValueError(
                f"input is {x.dim()}-D but this layer was built for "
                f"{self.gamma.dim()}-D inputs"
            )

        if self.training:
            # Reduce over every axis except the channel axis (dim 1), so the
            # statistics have shape (1, c) or (1, c, 1, 1).
            dims = 0 if x.dim() == 2 else (0, 2, 3)
            mean = x.mean(dim=dims, keepdim=True)
            # Normalize with the biased (population) variance of the batch,
            # as in the BN paper and torch.nn.BatchNorm*.
            var = x.var(dim=dims, keepdim=True, unbiased=False)
            x_hat = (x - mean) / torch.sqrt(var + eps)

            # The running statistics are bookkeeping, not part of the model's
            # computation graph. Updating them under no_grad prevents every
            # step's graph from being chained onto (and kept alive by) them.
            with torch.no_grad():
                n = x.numel() // x.shape[1]  # elements reduced per channel
                # Track the unbiased variance estimate for inference. With a
                # single element per channel it is undefined, so fall back to
                # the biased value instead of producing inf/NaN.
                running_var = var * (n / (n - 1)) if n > 1 else var
                self.moving_mean = (
                    momentum * self.moving_mean + (1.0 - momentum) * mean
                )
                self.moving_var = (
                    momentum * self.moving_var + (1.0 - momentum) * running_var
                )
        else:
            x_hat = (x - self.moving_mean) / torch.sqrt(self.moving_var + eps)
        return self.gamma * x_hat + self.beta