1、 Theory
See this link for more information. In practice, this one image is enough to illustrate group normalization; you can ignore the rest of the article.
2、 Code
Check this link for details on the formulation and theory of group normalization; it is a very good resource for understanding the more advanced aspects of group normalization (GN).
# group normalization
import torch
# Step-by-step group normalization of a small demo tensor.
#
# The C channels are split into `group` groups; every group is normalized
# to zero mean / unit variance over all of its (C // group) * H * W values,
# then a per-channel scale (gamma) and shift (beta) are applied.

# Demo input dimensions: input has shape (batch_size, C, H, W).
batch_size = 1
C = 6
H = 3
W = 3
# Number of groups the C channels are divided into.
group = 2
if C % group != 0:
    # Without this check the reshape below still succeeds whenever
    # C * H * W is divisible by `group`, silently producing groups that
    # cut through the middle of a channel.
    raise ValueError(f"C ({C}) must be divisible by group ({group})")

# NOTE: `input` shadows the builtin of the same name; the name is kept so
# that anything referring to this script's variables keeps working.
input = torch.arange(batch_size * C * H * W).reshape(batch_size, C, H, W).float()
# Flatten each group's values into the last dimension:
# (batch_size, C, H, W) -> (batch_size, group, (C // group) * H * W).
group_input = input.reshape(batch_size, group, -1)

# Per-group mean and (biased) variance. The variance is computed from the
# centered values rather than as E[x^2] - E[x]^2: the latter loses precision
# in float32 and can even turn slightly negative, which would make the
# square root below produce NaN.
mean = group_input.mean(dim=-1, keepdim=True)
centered = group_input - mean
var = (centered ** 2).mean(dim=-1, keepdim=True)
# eps keeps the division finite when a group has zero variance.
eps = 1e-5
group_input_norm = centered / torch.sqrt(var + eps)
# Back to one row per channel so the per-channel affine parameters broadcast.
group_input_norm = group_input_norm.reshape(batch_size, C, -1)

# gamma (scale) and beta (shift): one value per channel, shaped (1, C, 1).
# Ones / zeros make the affine step an identity here.
scale = torch.ones(C).view(1, -1, 1)
shift = torch.zeros(C).view(1, -1, 1)
output = group_input_norm * scale + shift
# Restore the original (batch_size, C, H, W) layout.
output = output.reshape(batch_size, C, H, W)