本文重点
我们前面一章学习了自动求导,这很有用,但是在实际使用中我们基本不会使用,因为这个技术过于底层,我们接下来将学习pytorch中的nn模块,它是构建于autograd之上的神经网络模块,也就是说我们使用pytorch封装好的神经网络层,它自动会具有求导的功能,也就是说这部分我们根本不用关系。此专栏主要学习步骤2(神经网络的搭建),详细步骤请看前面的文章。
神经网络工具箱
torch.nn是专门为深度学习设计的工具箱,它的核心数据结构是Module类,它是一个抽象的概念,它既可以表示神经网络的一层,又可以表示一个包含很多层的神经网络。
我们在搭建网络模型的时候,最常见的做法就是继承nn.module,然后编写自己的网络层,下面通过一个简单的例子来看一下,我们如何通过nn.module模块来实现一个自己的全连接层。
自定义全连接层
import torch
from torch import nn
class MyLinear(nn.Module):
def __init__(self, inp, outp):
super(MyLinear, self).__init__()
self.w = nn.Parameter(torch.randn(outp, inp))
self.b = nn.Parameter(torch.randn(outp))
def forward(self, x):
x = x @ self.w.t() + self.b
return x
layer=MyLi