Netron可视化Pytorch保存的网络模型

news2024/11/18 13:57:28

目录

一.理清网络的输入与输出

二. 将模型转换为onnx格式

三.Netron可视化工具


一.理清网络的输入与输出

我自定义的网络模型(主要看看前向传播函数即可):

import torch
import torch.nn as nn

#导入数据预处理之后的相关数据
from dataPreprocessing import n_categories

#*********************************** 参考这篇文章的图 https://www.cnblogs.com/lccxqk/p/14622532.html
class RNN(nn.Module):
    # rnn = RNN(n_letters, 128, n_letters)说明有多少字符就有多少种输入情况,也就有多少种输出情况,所以最后需要一个Softmax层进行多元分类
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        #其实是两层?只不过i2h和i2o其实可以看做一层,只不过传递的方向不一样
        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        #防止过拟合
        self.dropout = nn.Dropout(0.1)
        #多元分类,# 对列做Softmax,最后得到的每行和为1;dim=0则每列和为1
        self.softmax = nn.LogSoftmax(dim=1)

    # 前向传播,三个参数都是行向量,且前俩是one-hot矩阵
    # 前向传播,三个参数都是行向量,结合这篇文章的前向传播那里的图进行分析 https://hanhan.blog.csdn.net/article/details/128062706
    # hidden就是图中的a,即向右传的激活值,
    # 一个单词的从左往右的所有字母依次进行前向传播,每次前向传播就对应图中的一列
    # 三个线性层其实是两层
    def forward(self, category, input, hidden):
        '''
        运行以下代码查看torch.cat的功能,即把这三个行向量连接起来
        category=torch.zeros(1, 3)
        print(category)
        input=torch.ones(1,2)
        print(input)
        hidden=torch.zeros(1,2)
        print(hidden)
        input_combined = torch.cat((category, input, hidden), 1)
        print(input_combined)
        '''
        input_combined = torch.cat((category, input, hidden), 1)
        #往右传
        hidden = self.i2h(input_combined)
        #往上传
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        #行向量(2维,即一行2列的矩阵)
        return torch.zeros(1, self.hidden_size)

二. 将模型转换为onnx格式

因为Netron不支持pytorch保存的模型格式,所以需要将模型进行一下格式转换。

PyTorch中自带的torch.onnx模块包含将模型导出到ONNX IR格式的函数。这些模型可以被ONNX库加载,然后将它们转换成可在其他深度学习框架上运行的模型。 

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)

参数:

  • model(torch.nn.Module)-要被导出的模型
  • args(参数的集合)-模型的输入,例如,这种model(*args)方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。如果args是一个Variable,这等价于用包含这个Variable的1-ary元组调用它。(注意:现在不支持向模型传递关键字参数。)
  • f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
  • export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定。
  • verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
  • training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
  • input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
  • output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。

.pth格式的模型转.onnx格式的模型的代码(注释中已有详细说明):

#RNN是我自己的自定义的一个网络
from buildModel import RNN
# 先创建模型对象并加载已经保存的模型参数
model=RNN(59, 128, 59)
model.load_state_dict(torch.load('./model/myRNN.pth'))
model.eval()

#给网络的输入和输出起名字(注意数量和顺序要和自定义网络的前向传播函数的参数对应起来)
input_names = ['名字种类','一个名字','隐藏状态']
output_names = ['预测结果','新隐藏状态']

#获取输入数据,注意,随便搞点输入数据也行,只要尺寸符合即可
from myTrain import  randomTrainingExample
category_tensor, input_line_tensor, target_line_tensor=randomTrainingExample()
hidden=torch.zeros(1, 128)

#参数参考上面的说明
#args是输入模型中的数据,f是保存模型的路径
torch.onnx.export(model=model,args=(category_tensor,input_line_tensor[0],hidden),f='./model/myRNN.onnx',
                  input_names=input_names, output_names=output_names,verbose='True')

三.Netron可视化工具

用上面的代码生成onnx格式的模型之后,再用netron(安装:pip install netron)生成网络结构图:

import netron
modelData='./model/myRNN.onnx'
netron.start(modelData)

netron会自动打开浏览器显式,然后一些操作也很简单,自己点吧点吧就明白了。

贴一下我生成的网络图:

当然不一定非要先保存了模型再转换,也可以训练完就用torch.onnx模块来保存模型为onnx格式的模型,到时候用到再说吧,先这样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/72623.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Acrel-EMS企业微电网能效管理平台在某食品加工厂35kV变电站应用-Susie 周

1、概述 该食品加工厂变电站工程规模:电压等级:35/10.5kV,规划主变容量16.3MVA1台8MVA。有一个总配电室,包括35kV开关柜、10kV开关柜和0.4kV配电柜,两个独立变压器室,变压器为干式变压器。35kV供电系统采用…

(2)ITK中迭代器的时间效率

背景 ITK对图像处理中,为了提高代码运行效率,通过迭代器Iterator可以实现对时间的优化。 在ITK的官方文档中也有明确的说明: 针对此说明,本次使用对图像获取最大值最小值的方式,来实验和测试其效率。 代码实现 &am…

JDBC 数据库连接池之Driud

1 数据库连接池简介 数据库连接池是个容器,负责分配、管理数据库连接(Connection) 它允许应用程序重复使用一个现有的数据库连接,而不是再重新建立一个; 释放空闲时间超过最大空闲时间的数据库连接来避免因为没有释放数据库连接而引起的数据…

数据安全新战场,EasyMR为企业筑起“安全防线”

2020年1月,时间跨度长达14年的,微软2.5亿条客户服务和支持记录在网上泄露; 同年4月,微盟发生史上最贵“删库跑路”事件,造成微盟市值一夜之间缩水约24亿港币; 今年7月,网信办依据《数据安全法…

PCIEBPMCx4板卡

PCIEBPMCx4本板卡可以使标准的PMC板卡安装于带有PCIE插槽的PC机上使用,安装后占一个槽位,槽位可以为PCIE x4 PCIE x8、PCIE x16,安装后工作在PCIE x4模式。PCIE X1 后开口也可以使用,但只运行在PCIE X1模式。PCIE支持X4 V2.0,板载…

Python对json的操作总结

Json简介:Json,全名 JavaScript Object Notation,是一种轻量级的数据交换格式。Json最广泛的应用是作为AJAX中web服务器和客户端的通讯的数据格式。现在也常用于http请求中,所以对json的各种学习,是自然而然的事情。 J…

C++学习笔记(十四)——vector的模拟实现

vector各函数接口总览 vector当中的成员变量介绍 默认成员函数 构造函数1 构造函数2 构造函数3 拷贝构造函数 赋值运算符重载函数 析构函数 迭代器相关函数 begin和end 容量和大小相关函数 size和capacity reserve resize empty 修改容器内容相关函数 push_ba…

centos8:安装java

一、背景 因为centos 8 安装Jenkins需要java环境,所以本文记录安装java环境过程。 二、环境 开发电脑:Windows 10 CentOS 8.4 64位 三、安装 3.1、java -version检查是否已安装 java -version 没有安装 3.2、检查系统是否自带jdk rpm -qa |grep …

Word控件Spire.Doc 【超链接】教程(1):如何在C#/VB.NET中给Word 文档插入超链接

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下,轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具,专注于创建、编辑、转…

系统移植 uboot 2

一、uboot源码获取 1.1 uboot官网获取 ftp://ftp.denx.de/pub/u-boot/ 前提是是芯片厂家将uboot源码开源到uboot官网上 1.2 ST开发社区获取 https://wiki.stmicroelectronics.cn/stm32mpu/wiki/STM32MP1_Developer_Package 1.3 ST官网 https://www.st.com/en/embedded-sof…

opcj3—人人开源三大套件的简单用法

renren开源是一个很不错的开源开发组件,人人开源 其中目前对我们最有用的有三个:renren-fast、renren-fast-vue和renren-generator。 renren-generator是核心服务,可以根据数据库自动生成从controller层到service层,再到持久层的…

.net开发安卓入门 - 环境安装

文章目录工具VS2022Android SDK Manager如下图,安装一个镜像和工具模拟器设备管理器如下图启动模拟器,看一下效果常见问题工具 VS2022 下载地址:https://visualstudio.microsoft.com/zh-hans/thank-you-downloading-visual-studio/?skuCom…

Linux邮件服务Postfix部署

我们看下邮件协议: 简单邮件传输协议(SMTP):用于发送和中转出的电子邮件。使用TCP/25端口。 邮局协议版本(POP3):用于将邮件存储到本地,占用服务器的TCP/110端口。 Internet 消息访问…

【Python游戏】一个csdn小编用Python语言写了一个足球游戏,成功模拟世界杯决赛现场

前言 halo,包子们下午好 最近世界杯不是很火呀 很多小伙伴应该都知道球赛反正买,别墅靠大海! 今天就给大家实现一个类似世界杯的足球小游戏,咱就说真的堪比国足了! 哈哈哈~ 好啦 直接开整!!&am…

「以代码作画」从数据角度剖析Art Blocks生成艺术

作者:Mia Bao, co-founder of thepass.to, chief partner of WHALE members 数据:Jin, data analyst of thepass.to 出品:ThePASS & BeepCrypto 文章数据:https://docs.google.com/spreadsheets/d/1zDun4eUTwA-BMU5Hl2c5EC…

基于SSM网上商城购物系统的设计与实现

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…

目标检测算法——人体姿态估计数据集汇总 2(附下载链接)

🎄🎄近期,小海带在空闲之余收集整理了一批人体姿态估计数据集供大家参考。 整理不易,小伙伴们记得一键三连喔!!!🎈🎈 目录 一、V-COCO数据集 二、宜家 ASM 数据集 三、…

如何解决在加载、保存或覆盖项目文件时 Lumion 可能无法打开或显示错误的问题?

为什么在加载、保存或覆盖项目文件时 Lumion 可能无法打开或显示错误?那么这个问题大家跟着赞奇云工作站一起来解答吧。 1. 这就是为什么 如果Lumion在加载 .LS Project文件时崩溃或显示错误 ,通常意味着 .LS Project 文件因保存错误而损坏。遗憾的是&…

电脑技巧:分享6个实用的资源网站

❤️作者主页:IT技术分享社区 ❤️作者简介:大家好,我是IT技术分享社区的博主,从事C#、Java开发九年,对数据库、C#、Java、前端、运维、电脑技巧等经验丰富。 ❤️个人荣誉: 数据库领域优质创作者🏆&#x…

一框式检索和高级检索

0. 学习内容 2022年12月8日15:38:07CNKI学习 学会多种检索方式检索基础文献 1. 一框式检索 1.1 简单使用 左侧选择检索字段 根据需求选择 输入想要的检索词输入想要的检索范围 顾名思义:在检索的时候只有一个搜索框,从而实现对文献进行检索 2. 高级检索…