import torch
import torch.nn as nn
# With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)
print(output)
print(output.shape)
这段代码展示了如何使用 PyTorch 中的 nn.BatchNorm2d
模块进行批归一化操作。首先,创建了一个具有100个通道的 nn.BatchNorm2d
对象 m
,默认情况下具有可学习的参数。然后,又创建了一个相同的 nn.BatchNorm2d
对象 m
,但通过设置 affine=False
来禁用了可学习的参数。
接下来,创建了一个输入张量 input
,其形状为 (20, 100, 35, 45)
,表示一个批次大小为20的具有100个通道的二维图像。然后,将输入张量传递给 m
进行批归一化操作,得到输出张量 output
。
最后,打印输出张量 output
的形状,以验证批归一化操作的结果。