1. nn.functional
torch.nn中还有一个很常用的模块:nn.functional。torch.nn中的大多数layer,在functional中都有一个与之相对应的函数。nn.functional中的函数和nn.Module的主要区别在于:使用nn.Module实现的layer是一个特殊的类,其由class layer(nn.Module)定义,会自动提取可学习的参数;使用nn.functional实现的layer更像是纯函数,由def function(input)定义。
2. nn.functional与nn.Module的区别
下面举例说明functional的使用,并对比它与nn.Module的不同之处:
In: input = t.randn(2, 3)
model = nn.Linear(3, 4)
output1 = model(input)
output2 = nn.functional.linear(input, model.weight, model.bias)
output1.equal(output2)
Out:True
In: b1 = nn.functional.relu(input)
b2 = nn.ReLU()(input)
b1.equal(b2)
Out:True
此时读者可能会问,应该什么时候使用nn.Module,什么时候使用nn.functional呢?答案很简单,如果模型具有可学习的参数,那么最好用nn.Module,否则既可以使用nn.functional,也可以使用nn.Module。二者在性能上没有太大差异,具体的选择取决于个人的喜好。由于激活函数(如ReLU、sigmoid、tanh)、池化(如MaxPool)等层没有可学习参数,可以使用对应的functional函数代替,对于卷积、全连接等具有可学习参数的层,建议使用nn.Module。另外,虽然dropout操作也没有可学习参数,但是建议使用nn.Dropout而不是nn.functional.dropout,因为dropout在训练和测试两个阶段的行为有所差异,使用nn.Module对象能够通过model.eval()操作加以区分。下面举例说明如何在模型中搭配使用nn.Module和nn.functional:
In: from torch.nn import functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.pool(F.relu(self.conv1(x)), 2)
x = F.pool(F.relu(self.conv2(x)), 2)
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
对于不具备可学习参数的层(如激活层、池化层等),可以将它们用函数代替,这样可以不用放置在构造函数__init__()中。对于有可学习参数的模块,也可以用functional代替,只不过实现起来较为烦琐,需要手动定义参数Parameter。例如,前面实现的全连接层,就可以将weight和bias两个参数单独拿出来,在构造函数中初始化为Parameter:
In: class MyLinear(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(t.randn(3, 4))
self.bias = nn.Parameter(t.zeros(3))
def forward(self):
return F.linear(input, weight, bias)
关于nn.functional的设计初衷,以及它和nn.Module的比较说明,读者可参考PyTorch论坛的相关讨论和说明。
3. 采样函数
在nn.functional中还有一个常用的函数:采样函数torch.nn.functional.grid_sample,它的主要作用是对输入的Tensor进行双线性采样,并将输出变换为用户想要的形状。下面以lena为例进行说明:
In: to_pil(lena.data.squeeze(0)) # 原始的lena数据
In: # lena的形状是1×1×200×200,(N,C,Hin,Win)
# 进行仿射变换,对图像进行旋转
angle = -90 * math.pi / 180
theta = t.tensor([[math.cos(angle), math.sin(-angle), 0], \
[math.sin(angle), math.cos(angle), 0]], dtype=t.float)
# grid形状为(N,Hout,Wout,2)
# grid最后一个维度大小为2,表示输入中pixel的位置信息,取值范围在(-1,1)
grid = F.affine_grid(theta.unsqueeze(0), lena.size())
In: import torch
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
out = F.grid_sample(lena, grid=grid, mode='bilinear')
to_pil(out.data.squeeze(0))