【学习笔记】【Pytorch】六、nn.Module的使用
- 学习地址
- 主要内容
- 一、torch.nn模块概述
- 二、nn.Module类的使用
- 1.使用说明
- 2.代码实现
学习地址
PyTorch深度学习快速入门教程【小土堆】.
主要内容
一、torch.nn模块概述
概述:帮助程序员方便执行与神经网络相关的行为。
二、nn.Module类的使用
概述:所有神经网络模块的基类,既可以表示神经网络中的某个层(layer),也可以表示一个包含很多层的神经网络。
一、torch.nn模块概述
from torch import nn
概述:nn是Neural Network的简称,帮助程序员方便执行如下的与神经网络相关的行为:
(1)创建神经网络
(2)训练神经网络
(3)保存神经网络
(4)恢复神经网络
nn文件夹结构:
二、nn.Module类的使用
概述:所有神经网络模块的基类,既可以表示神经网络中的某个层(layer),也可以表示一个包含很多层的神经网络。
1.使用说明
【继承】官方示例:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
# 前向传播,继承 nn.Module 必需被重写的函数,
# 否则使用 object(x) 实现会报错
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
前向传播:
注:relu是激活函数,小于0就是0,大于0就是本身。
2.代码实现
import torch
from torch import nn
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
# 前向传播,继承 nn.Module 必需被重写的函数,
# 否则使用 object(x) 实现会报错
def forward(self, input):
output = input + 1
return output
nn_model = Model() # 创建一个实例化
x = torch.tensor(1.0) # 将1.0这个数转换为tensor数据类型
output = nn_model(x) # nn.Module 的__call__()调用了forward()
print(output) # tensor(2.)
控制台输出:
tensor(2.)