pytorch小记(十三):pytorch中`nn.ModuleList` 详解
- PyTorch 中的 `nn.ModuleList` 详解
- 1. 什么是 `nn.ModuleList`?
- 2. 为什么不直接使用普通的 Python 列表?
- 3. `nn.ModuleList` 的基本用法
- 示例:构建一个包含两层全连接网络的模型
- 4. 使用 `nn.ModuleList` 计算参数总数(与普通列表对比)
- 示例代码
- 5. `nn.ModuleList` 的其他应用
- 示例:构建动态 MLP 模型
- Transformers中的多头注意力机制
- 6. 总结
PyTorch 中的 nn.ModuleList
详解
在构建深度学习模型时,经常需要管理多个网络层(例如多个 nn.Linear
、nn.Conv2d
等)。在 PyTorch 中,nn.ModuleList
是一个非常有用的容器,可以帮助我们存储多个子模块,并自动注册它们的参数。这对于确保所有参数能够参与训练非常重要。本文将详细介绍 nn.ModuleList
的作用、使用方法及与普通 Python 列表的区别,并给出清晰的代码示例。
1. 什么是 nn.ModuleList
?
nn.ModuleList
是一个类似于 Python 列表的容器,但专门用来存储 PyTorch 的子模块(也就是继承自 nn.Module
的对象)。其主要特点是:
-
自动注册子模块:将
nn.Module
存储在ModuleList
中后,这些模块的参数会自动被添加到父模块的参数列表中。这意味着当你调用model.parameters()
时,这些子模块的参数也会被包含进去,从而参与梯度计算和优化。 -
灵活管理:它可以像普通列表一样进行索引、迭代和切片操作,方便构建动态网络结构。
注意:
nn.ModuleList
不会像nn.Sequential
那样自动定义前向传播(forward)流程。你需要在模型的forward()
方法中手动遍历ModuleList
并调用各个子模块。
2. 为什么不直接使用普通的 Python 列表?
虽然可以将 nn.Module
对象存储在普通列表中,但这样做有一个主要问题:
普通列表中的模块不会自动注册为父模块的子模块。
这会导致:
- 调用
model.parameters()
时无法获取这些模块的参数; - 优化器无法更新这些参数,从而影响模型训练。
而使用 nn.ModuleList
可以避免这个问题,因为它会自动将内部所有的模块注册到父模块中。
3. nn.ModuleList
的基本用法
下面通过一个简单的示例来说明如何使用 nn.ModuleList
构建一个简单的神经网络模型。
示例:构建一个包含两层全连接网络的模型
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 创建一个 ModuleList 来存储各层
self.layers = nn.ModuleList([
nn.Linear(10, 20), # 第 1 层:输入 10 个特征,输出 20 个特征
nn.ReLU(), # 激活层
nn.Linear(20, 5) # 第 2 层:输入 20 个特征,输出 5 个特征
])
def forward(self, x):
# 手动遍历 ModuleList 中的每个模块,并依次调用 forward
for layer in self.layers:
x = layer(x)
return x
# 创建模型实例
model = MyModel()
# 打印模型结构
print("模型结构:")
print(model)
# 生成一组示例输入
input_tensor = torch.randn(3, 10) # 3 个样本,每个样本 10 个特征
# 得到模型输出
output = model(input_tensor)
print("\n模型输出:")
print(output)
模型结构:
MyModel(
(layers): ModuleList(
(0): Linear(in_features=10, out_features=20, bias=True)
(1): ReLU()
(2): Linear(in_features=20, out_features=5, bias=True)
)
)
模型输出:
tensor([[ 0.3741, 0.0883, 0.3550, -0.3930, 0.5173],
[ 0.2171, -0.0978, -0.0585, -0.4568, 0.3331],
[ 0.1232, -0.1491, 0.2026, -0.0978, 0.5478]],
grad_fn=<AddmmBackward0>)
说明:
- 在
__init__()
方法中,我们将各个层放在了nn.ModuleList
中。 - 在
forward()
方法中,我们使用了一个简单的 for 循环,依次调用self.layers
中的每个子模块。
4. 使用 nn.ModuleList
计算参数总数(与普通列表对比)
为了进一步说明 nn.ModuleList
与普通列表的区别,我们分别计算一下两种方式下模型的参数总数。
示例代码
import torch.nn as nn
# 使用 ModuleList 存储模型层
layers_ml = nn.ModuleList([
nn.Linear(10, 20),
nn.Linear(20, 5)
])
# 计算 ModuleList 中的参数总数
ml_params = 0
for p in layers_ml.parameters():
ml_params += p.numel()
# 使用普通 Python 列表存储模型层
layers_list = [
nn.Linear(10, 20),
nn.Linear(20, 5)
]
# 计算普通列表中的参数总数
list_params = 0
# 先遍历列表中的每个层
for layer in layers_list:
# 再遍历每个层的参数
for p in layer.parameters():
list_params += p.numel()
print("ModuleList 参数总数:", ml_params)
print("普通列表参数总数:", list_params)
ModuleList 参数总数: 325
普通列表参数总数: 325
说明:
- 第一个 for 循环遍历
layers_ml.parameters()
,直接累加所有参数的元素数。 - 第二部分中,我们先遍历普通列表中的每个
layer
,再单独遍历每个层的参数。这样做使每一步都清晰易懂。
5. nn.ModuleList
的其他应用
示例:构建动态 MLP 模型
当网络结构比较复杂或层数不固定时,可以利用列表生成器动态构建 ModuleList
。
class DynamicMLP(nn.Module):
def __init__(self, layer_sizes):
super(DynamicMLP, self).__init__()
# 使用 for 循环构造每一层,存储在 ModuleList 中
layers = [] # 先用普通列表保存层
for i in range(len(layer_sizes) - 1):
linear_layer = nn.Linear(layer_sizes[i], layer_sizes[i + 1])
layers.append(linear_layer)
# 将普通列表转换为 ModuleList
self.layers = nn.ModuleList(layers)
def forward(self, x):
# 遍历每一层(没有嵌套循环,逐个执行)
for layer in self.layers:
x = torch.relu(layer(x))
return x
# 创建一个动态 MLP:输入 10,隐藏层 20, 30,输出 5
dynamic_model = DynamicMLP([10, 20, 30, 5])
print("动态 MLP 模型:")
print(dynamic_model)
# 测试模型
input_tensor = torch.randn(4, 10) # 4 个样本,每个样本 10 个特征
output = dynamic_model(input_tensor)
print("\n动态 MLP 模型输出:")
print(output)
说明:
- 在
__init__()
方法中,我们使用一个普通列表layers
存储每个nn.Linear
层,然后再将它转换为nn.ModuleList
。 - 在
forward()
方法中,使用单独的 for 循环逐个调用每一层,并对输出应用 ReLU 激活函数。 - 这种写法适用于层数动态变化的网络(例如 MLP、RNN、Transformer 中部分模块)。
Transformers中的多头注意力机制
class SingleHeadAttention(nn.Module):
def __init__(self, embed_dim, head_dim):
super().__init__()
self.query = nn.Linear(embed_dim, head_dim)
self.key = nn.Linear(embed_dim, head_dim)
self.value = nn.Linear(embed_dim, head_dim)
def forward(self, x):
# 实现注意力计算逻辑...
return attended_values
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.head_dim = embed_dim // num_heads
# 显式创建每个注意力头
self.head1 = SingleHeadAttention(embed_dim, self.head_dim)
self.head2 = SingleHeadAttention(embed_dim, self.head_dim)
self.head3 = SingleHeadAttention(embed_dim, self.head_dim)
# 使用ModuleList管理多个头
self.heads = nn.ModuleList([
self.head1,
self.head2,
self.head3
])
self.output_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
# 分别处理每个头
head1_out = self.head1(x)
head2_out = self.head2(x)
head3_out = self.head3(x)
# 拼接结果
combined = torch.cat([head1_out, head2_out, head3_out], dim=-1)
return self.output_proj(combined)
关键点解析:
-
显式声明每个注意力头(避免循环)
-
使用ModuleList统一管理注意力头
-
在forward中分别调用每个头
-
保持各头独立性,便于后续调试
6. 总结
nn.ModuleList
是专门用于存储多个子模块的容器,它会自动注册子模块,确保所有参数能参与训练。- 与普通 Python 列表相比,
ModuleList
可以直接通过model.parameters()
获取其中所有参数,从而方便地进行优化。 - 使用
ModuleList
时,前向传播需要手动遍历其中的模块,这提供了更大的灵活性,但也要求开发者理解循环过程。