0. 简介
PyTorch是一个深度学习框架,它使用张量(tensor)作为核心数据结构。在可视化PyTorch模型时,了解每个张量运算的意义非常重要。张量运算作为神经网络模型中的基本操作。它们用于处理输入数据、执行权重更新和生成预测结果。同时张量运算还用于计算损失函数。损失函数衡量了模型预测与真实标签之间的差异。通过使用张量运算,可以计算出模型的预测结果与真实标签之间的差异,并将其最小化。所以一款能够可视化任何PyTorch模型的张量显示开源项目非常重要。这里是该项目的Github地址。
1. 了解TorchLens
TorchLens是一个用于完成两个任务的软件包:
- 轻松地从PyTorch模型的每个中间操作中提取激活值,无需进行任何修改,只需一行代码即可。这里的“每个操作”指的是每个操作;“一行代码”指的是一行代码。
- 通过直观的自动可视化和关于网络计算图的详细元数据(部分列表在此处)来理解模型的计算结构。
下面是一个非常简单的循环模型的示例;正如您所看到的,您只需像正常定义模型一样将其传入,TorchLens将返回完整的前向传递日志以及可视化结果:
class SimpleRecurrent(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(in_features=5, out_features=5)
def forward(self, x):
for r in range(4):
x = self.fc(x)
x = x + 1
x = x * 2
return x
simple_recurrent = SimpleRecurrent()
model_history = tl.log_forward_pass(simple_recurrent, x,
layers_to_save='all',
vis_opt='rolled')
print(model_history['linear_1_1:2'].tensor_contents) # second pass of first linear layer
'''
tensor([[-0.0690, -1.3957, -0.3231, -0.1980, 0.7197],
[-0.1083, -1.5051, -0.2570, -0.2024, 0.8248],
[ 0.1031, -1.4315, -0.5999, -0.4017, 0.7580],
[-0.0396, -1.3813, -0.3523, -0.2008, 0.6654],
[ 0.0980, -1.4073, -0.5934, -0.3866, 0.7371],
[-0.1106, -1.2909, -0.3393, -0.2439, 0.7345]])
'''
这是一个非常复杂的变压器模型(swin_v2_b),其前向传递过程中涉及了1932个操作;我们也可以获取每个操作的保存输出。
2. 安装TorchLens
要安装TorchLens,请先安装graphviz(用于生成网络可视化),然后使用pip安装TorchLens:
sudo apt install graphviz
pip install torchlens
TorchLens与PyTorch的1.8.0及以上版本兼容
3. 如何使用TorchLens
TorchLens的主要功能是log_forward_pass:当在模型和输入上调用时,它会在模型上运行前向传递,并返回一个包含中间层激活和相关元数据的ModelHistory对象,同时还提供了在前向传递过程中发生的每个操作的可视化表示: