文章目录
- PyTorch 神经网络模型可视化(Netron)
- ONNX
- torch.save
- torch.jit.script
- torch.jit.trace
PyTorch 神经网络模型可视化(Netron)
Netron 是一个用于可视化深度学习模型的工具,可以帮助我们更好地理解模型的结构和参数。
支持以下格式的模型存储文件:
格式 | 模板(文件) | 免下载打开 |
---|---|---|
ONNX | squeezenet | open |
TensorFlow Lite | yamnet | open |
TensorFlow | chessbot | open |
Keras | mobilenet | open |
TorchScript | traced_online_pred_layer | open |
Core ML | exermote | open |
Darknet | yolo | open |
GitHub 链接:https://github.com/lutzroeder/netron
官网:https://netron.app
ONNX
(1)在 PyTorch 中,可以使用 torch.onnx.export
函数将模型导出为 ONNX 格式:
import torch
import netron
# 定义 PyTorch 模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.bn = torch.nn.BatchNorm2d(64)
self.relu = torch.nn.ReLU(inplace=True)
self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
self.fc = torch.nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.pool(x)
x = x.view(-1, 64 * 8 * 8)
x = self.fc(x)
return x
# 创建模型实例并加载预训练权重
model = MyModel()
# 设置示例输入
input = torch.randn(1, 3, 32, 32)
# 将模型导出为 ONNX 格式
torch.onnx.export(model, input, './model/Test/onnx_model.onnx') # 导出后 netron.start(path) 打开
(2)再使用 Netron 的 netron.start
指令打开导出的 ONNX 模型文件:
import netron
# 打开导出的 ONNX 模型文件
netron.start('./model/Test/onnx_model.onnx')
Serving './model/Test/onnx_model.onnx' at http://localhost:8080
将在浏览器中自动启动 Netron 工具,并对该模型文件进行可视化。
注意:
当模型被导出为 ONNX 格式,会在指定目录生成以 .onnx
为后缀的文件,只需将其上传至 Netron 官网 也可实现可视化:
在 Netron 中,可以查看模型的结构、参数和输入输出等信息。可以通过缩放、旋转和平移等操作来调整模型的可视化效果,以更好地理解模型的结构和参数。
torch.save
当使用 torch.save
对保存的模型进行可视化时:
# 保存模型
torch.save(model.state_dict(), './model/Test/saved_model.pt')
# 可视化
netron.start('./model/Test/saved_model.pt')
如下图,这种方式并不能显示该模型的详细信息:
所以: Netron 不支持 PyTorch 通过 torch.save
方式导出的模型文件。
torch.jit.script
可参考:torch.jit.script 与 torch.jit.trace
使用 torch.jit.script
先将模型转换为脚本,再使用 torch.jit.save
保存模型,最后进行可视化:
# TorchScript:script
scripted_model = torch.jit.script(model)
# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
# 可视化
netron.start('./model/Test/scripted_model.pth')
torch.jit.trace
可参考:torch.jit.script 与 torch.jit.trace
使用 torch.jit.trace
先将模型转换为跟踪模型执行的工具,再使用 torch.jit.save
保存模型,最后进行可视化:
# TorchScript:trace
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# 保存模型
torch.jit.save(traced_model, './model/Test/traced_model.pth')
# 可视化
netron.start('./model/Test/traced_model.pth')