目录
前言
1.使用nn.Module基类
2.使用nn.Sequential容器
3. 使用nn.ModuleList
4. 使用nn.ModuleDict
5. 混合使用nn.Module和原生Python代码
6.表格总结
前言
nn.Module
:最通用、最灵活的方式,适用于几乎所有场景。nn.Sequential
:适合简单的顺序模型,代码简洁。nn.ModuleList
和nn.ModuleDict
:适合需要动态调整层的模型,方便子模块的管理和访问。- 混合使用原生Python代码:适合需要动态逻辑或复杂决策的网络模型。
这些方式可以根据具体项目需求进行选择,通常,
nn.Module
是最常用的方式,它能够满足几乎所有的模型设计需求。
1.使用nn.Module
基类
这是最常用的方法之一。你可以通过继承nn.Module
基类来定义自己的神经网络。nn.Module
提供了神经网络层的封装以及模型参数的管理。
示例:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = SimpleNet()
详细步骤:
__init__
方法:在这里定义网络层。self.fc1 = nn.Linear(784, 128)
表示创建了一个输入大小为784、输出大小为128的全连接层。forward
方法:定义了数据的前向传播方式。输入数据依次通过定义的各个层,最后得到输出。
优点:灵活,适合复杂网络。
2.使用nn.Sequential
容器
如果你的模型是一个简单的顺序网络(即各层按顺序逐个执行,没有复杂的网络结构),可以使用nn.Sequential
来简化代码。
示例:
import torch.nn as nn
model = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
详细步骤:
nn.Sequential
接受一系列的层作为参数,并按顺序逐个应用于输入数据。- 各个层之间的前向传播方式自动处理,减少了手动编写
forward
方法的工作。
优点:简洁,适合简单的线性模型。
3. 使用nn.ModuleList
nn.ModuleList
可以用来存储一个nn.Module
的列表,但不会定义网络的前向传播逻辑,需要在forward
方法中手动实现。
class CustomNet(nn.Module):
def __init__(self):
super(CustomNet, self).__init__()
self.layers = nn.ModuleList([nn.Linear(100, 100) for i in range(5)])
self.relu = nn.ReLU()
def forward(self, x):
for layer in self.layers:
x = self.relu(layer(x))
return x
model = CustomNet()
详细步骤:
- 使用
nn.ModuleList
存储多个相同或不同的层。 - 在
forward
方法中循环这些层,自定义前向传播逻辑。
优点:适合需要灵活定义多个子层的网络结构。
4. 使用nn.ModuleDict
nn.ModuleDict
和nn.ModuleList
类似,但它以字典的形式存储模块,允许通过键值对的方式来访问不同的子模块。
class CustomNet(nn.Module):
def __init__(self):
super(CustomNet, self).__init__()
self.layers = nn.ModuleDict({
'fc1': nn.Linear(784, 128),
'fc2': nn.Linear(128, 64),
'fc3': nn.Linear(64, 10)
})
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.layers['fc1'](x))
x = self.relu(self.layers['fc2'](x))
x = self.layers['fc3'](x)
return x
model = CustomNet()
详细步骤:
- 使用
nn.ModuleDict
来存储模块,可以通过键值访问。 - 灵活构建前向传播路径,适合需要不同路径的网络结构。
优点:适合需要动态访问或选择子模块的网络。
5. 混合使用nn.Module
和原生Python代码
在某些情况下,你可能需要在模型中嵌入一些动态的逻辑。此时,可以将nn.Module
与原生Python控制流(如if-else
、for
循环等)结合使用,构建更加复杂的模型。
class DynamicNet(nn.Module):
def __init__(self):
super(DynamicNet, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
if x.mean() > 0.5:
x = self.fc1(x)
else:
x = self.fc2(x)
return self.relu(x)
model = DynamicNet()
详细步骤:
- 在
forward
中使用Python原生的控制流来决定前向传播路径。 - 这种方式非常灵活,适合复杂的模型逻辑需求。
优点:灵活且强大,适合复杂模型。