【pytorch onnx】Pytorch导出ONNX及模型可视化教程

news2024/11/13 18:59:15

文章目录

    • 1 背景介绍
    • 2 实验环境
    • 3 torch.onnx.export函数简介
    • 4 单输入网络导出ONNX模型代码实操
    • 5 多输入网络导出ONNX模型代码实操
    • 6 ONNX模型可视化
    • 7 ir_version和opset_version修改
    • 8 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

1 背景介绍

使用深度学习开源框架Pytorch训练完网络模型后,在部署之前通常需要进行格式转换,例如地平线工具链模型转换目前仅支持Caffe1.0和ONNX(opset_version=10/11 且 ir_version≤7)两种。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,例如Pytorch、PaddlePaddle、TensorFlow等。本文将详细介绍如何将Pytorch格式的模型导出到ONNX格式的模型。

2 实验环境

本文以Python3.6为例,涉及到的whl包及版本信息如下:

torch 1.10.2
onnx 1.8.0
onnxruntime 1.10.0
numpy 1.19.5

3 torch.onnx.export函数简介

torch.onnx.export函数实现了Pytorch模型导出到ONNX模型,在pytorch1.10.2中,torch.onnx.export函数参数如下:

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
           input_names=None, output_names=None, operator_export_type=None,
           opset_version=None, _retain_param_name=None, do_constant_folding=True,
           example_outputs=None, strip_doc_string=None, dynamic_axes=None,
           keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None,
           use_external_data_format=None):

大多数参数使用默认配置即可,下面对常用的几个参数进行介绍:

torch.onnx.export(
                model,             # 需要转换的网络模型
                args,              # ONNX模型输入,通常为 tuple 或 torch.Tensor
                f,                 # ONNX模型导出路径
                input_names=None,  # 按顺序定义ONNX模型输入结点名称,格式为:list of str,若不指定,会使用默认名字
                output_names=None, # 按顺序定义ONNX模型输出结点名称,格式为:list of str,若不指定,会使用默认名字
                opset_version=11   # opset版本,地平线目前仅支持设置为 10 or 11
            )

其它参数的介绍可参考官方torch.onnx.export()函数手册。

4 单输入网络导出ONNX模型代码实操

该节内容主要包括单输入网络构建、模型导出生成ONNX格式、导出的ONNX模型有效性验证三个部分。可直接运行下方代码得到对应的ONNX模型,欢迎参考代码中的注释进行理解。

import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime

# -----------------------------------#
#   定义一个简单的单输入网络   
# -----------------------------------#
class MyNet(nn.Module):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),   # input[3, 28, 28]  output[32, 28, 28]          
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # output[64, 14, 14]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)                             # output[64, 7, 7]
        )

        self.fc = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

# -----------------------------------#
#   导出ONNX模型函数
# -----------------------------------#
def model_convert_onnx(model, input_shape, output_path):
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1])
    input_names = ["input1"]        # 导出的ONNX模型输入节点名称
    output_names = ["output1"]      # 导出的ONNX模型输出节点名称

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        verbose=False,          # 如果指定为True,在导出的ONNX中会有详细的导出过程信息description
        keep_initializers_as_inputs=False,  # 若为True,会出现需要warning消除的问题
        opset_version=11,       # 版本通常为10 or 11
        input_names=input_names,
        output_names=output_names,
    )


if __name__ == '__main__':
    model = MyNet()
    # print(model)
    # 建议将模型转成 eval 模式
    model.eval()
    # 网络模型的输入尺寸
    input_shape = (28, 28)      
    # ONNX模型输出路径
    output_path = './MyNet.onnx'

    # 导出为ONNX模型
    model_convert_onnx(model, input_shape, output_path)
    print("model convert onnx finsh.")

    # -----------------------------------#
    #   复杂模型可以使用下面的方法进行简化   
    # -----------------------------------#
    # import onnxsim
    # MyNet_sim = onnxsim.simplify(onnx.load(output_path))
    # onnx.save(MyNet_sim[0], "MyNet_sim.onnx")

    # -----------------------------------------------------------------------#
    #   第一轮ONNX模型有效性验证,用来检查模型是否满足 ONNX 标准   
    #   这一步是必要的,因为无论模型是否满足标准,ONNX 都允许使用 onnx.save 存储模型,
    #   我们都不会希望生成一个不满足标准的模型~
    # -----------------------------------------------------------------------#
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("onnx model check_1 finsh.")

    # ----------------------------------------------------------------#
    #   第二轮ONNX模型有效性验证,用来验证ONNX模型与Pytorch模型的推理一致性   
    # ----------------------------------------------------------------#
    # 随机初始化一个模型输入,注意输入分辨率
    x = torch.randn(size=(1, 3, input_shape[0], input_shape[1]))
    # torch模型推理
    with torch.no_grad():
        torch_out = model(x)
    print(torch_out)            # tensor([[-0.5728,  0.1695, ..., -0.3256,  1.1357, -0.4081]])
    # print(type(torch_out))      # <class 'torch.Tensor'>

    # 初始化ONNX模型
    ort_session = onnxruntime.InferenceSession(output_path)
    # ONNX模型输入初始化
    ort_inputs = {ort_session.get_inputs()[0].name: x.numpy()}
    # ONNX模型推理
    ort_outs = ort_session.run(None, ort_inputs)
    # print(ort_outs)             # [array([[-0.5727689 ,  0.16947027,  ..., -0.32555276,  1.13574252, -0.40812433]], dtype=float32)]
    # print(type(ort_outs))       # <class 'list'>,里面是个numpy矩阵
    # print(type(ort_outs[0]))    # <class 'numpy.ndarray'>
    ort_outs = ort_outs[0]        # 把内部numpy矩阵取出来,这一步很有必要

    # print(torch_out.numpy().shape)      # (1, 10)
    # print(ort_outs.shape)               # (1, 10)

    # ----------------------------------------------------------------#
    # 比较实际值与期望值的差异,通过继续往下执行,不通过引发AssertionError
    # 需要两个numpy输入
    # ----------------------------------------------------------------#
    np.testing.assert_allclose(torch_out.numpy(), ort_outs, rtol=1e-03, atol=1e-05)
    print("onnx model check_2 finsh.")

5 多输入网络导出ONNX模型代码实操

该节内容主要包括多输入网络构建、模型导出生成ONNX格式、导出的ONNX模型有效性验证三个部分。可直接运行下方代码得到对应的ONNX模型,欢迎参考代码中的注释进行理解。

import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime

# -----------------------------------#
#   定义一个简单的双输入网络   
# -----------------------------------#
class MyNet_multi_input(nn.Module):
    def __init__(self, num_classes=10):
        super(MyNet_multi_input, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)   # input[3, 28, 28]  output[32, 14, 14]
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)   # input[1, 28, 28]  output[32, 14, 14]
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc = nn.Linear(48 * 14 * 14, num_classes)

    def forward(self, x, y):
        x = self.relu1(self.bn1(self.conv1(x)))
        y = self.relu2(self.bn2(self.conv2(y)))
        z = torch.cat((x, y), 1)
        z = torch.flatten(z, start_dim=1)
        z = self.fc(z)
        return z

# -----------------------------------#
#   导出ONNX模型函数
# -----------------------------------#
def multi_input_model_convert_onnx(model, input_shape, output_path):
    dummy_input1 = torch.randn(1, 3, input_shape[0], input_shape[1])
    dummy_input2 = torch.randn(1, 1, input_shape[0], input_shape[1])
    input_names = ["input1", "input2"]        # 导出的ONNX模型输入节点名称
    output_names = ["output1"]      # 导出的ONNX模型输出节点名称

    torch.onnx.export(
        model,
        (dummy_input1, dummy_input2),
        output_path,
        verbose=False,          # 如果指定为True,在导出的ONNX中会有详细的导出过程信息description
        keep_initializers_as_inputs=False,  # 若为True,会出现需要warning消除的问题
        opset_version=11,       # 版本通常为10 or 11
        input_names=input_names,
        output_names=output_names,
    )

if __name__ == '__main__':
    multi_input_model = MyNet_multi_input()
    # print(multi_input_model)
    # 建议将模型转成 eval 模式
    multi_input_model.eval()
    # 网络模型的输入尺寸
    input_shape = (28, 28)      
    # ONNX模型输出路径
    multi_input_model_output_path = './multi_input_model.onnx'

    # 导出为ONNX模型
    multi_input_model_convert_onnx(multi_input_model, input_shape, multi_input_model_output_path)
    print("multi_input_model convert onnx finsh.")

    # -----------------------------------#
    #   复杂模型可以使用下面的方法进行简化   
    # -----------------------------------#
    # import onnxsim
    # multi_input_model_sim = onnxsim.simplify(onnx.load(multi_input_model_output_path))
    # onnx.save(multi_input_model_sim[0], "multi_input_model_sim.onnx")

    # -----------------------------------------------------------------------#
    #   第一轮ONNX模型有效性验证,用来检查模型是否满足 ONNX 标准   
    #   这一步是必要的,因为无论模型是否满足标准,ONNX 都允许使用 onnx.save 存储模型,
    #   我们都不会希望生成一个不满足标准的模型~
    # -----------------------------------------------------------------------#
    onnx_model = onnx.load(multi_input_model_output_path)
    onnx.checker.check_model(multi_input_model_output_path)
    print("onnx model check_1 finsh.")

    # ----------------------------------------------------------------#
    #   第二轮ONNX模型有效性验证,用来验证ONNX模型与Pytorch模型的推理一致性   
    # ----------------------------------------------------------------#
    # 随机初始化一个模型输入,注意输入分辨率
    x = torch.randn(size=(1, 3, input_shape[0], input_shape[1]))
    y = torch.randn(size=(1, 1, input_shape[0], input_shape[1]))
    # torch模型推理
    with torch.no_grad():
        torch_out = multi_input_model(x, y)
    # print(torch_out)            # tensor([[-0.5728,  0.1695, ..., -0.3256,  1.1357, -0.4081]])
    # print(type(torch_out))      # <class 'torch.Tensor'>

    # 初始化ONNX模型
    ort_session = onnxruntime.InferenceSession(multi_input_model_output_path)
    # ONNX模型输入初始化
    ort_inputs = {ort_session.get_inputs()[0].name: x.numpy(), ort_session.get_inputs()[1].name: y.numpy()}
    # ONNX模型推理
    ort_outs = ort_session.run(None, ort_inputs)
    # print(ort_outs)             # [array([[-0.5727689 ,  0.16947027,  ..., -0.32555276,  1.13574252, -0.40812433]], dtype=float32)]
    # print(type(ort_outs))       # <class 'list'>,里面是个numpy矩阵
    # print(type(ort_outs[0]))    # <class 'numpy.ndarray'>
    ort_outs = ort_outs[0]        # 把内部numpy矩阵取出来,这一步很有必要

    # print(torch_out.numpy().shape)      # (1, 10)
    # print(ort_outs.shape)               # (1, 10)

    # ----------------------------------------------------------------#
    # 比较实际值与期望值的差异,通过继续往下执行,不通过引发AssertionError
    # 需要两个numpy输入
    # ----------------------------------------------------------------#
    np.testing.assert_allclose(torch_out.numpy(), ort_outs, rtol=1e-03, atol=1e-05)
    print("onnx model check_2 finsh.")

更多内容可参考 PyTorch官方导出ONNX模型教程。

6 ONNX模型可视化

导出成ONNX模型后,可以使用开源可视化工具Netron来查看网络结构及相关配置信息。Netron的使用方式主要分为两种,一种是使用在线网页版,另一种是下载安装程序。下面以在线网页版打开第4节中导出单输入ONNX模型为例,进行介绍。点击在线网页版链接,打开导出的ONNX模型,可视化效果为:

在这里插入图片描述

地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7。

7 ir_version和opset_version修改

地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7,当拿到的ONNX模型不满足这两个要求时怎么办呢?
如果有条件修改代码重新导出的话,这是一种解决方案。另外一种可尝试的解决方案是直接修改ONNX模型的对应属性,代码示例如下:

import onnx

model = onnx.load("./MyNet.onnx")
model.ir_version = 6
model.opset_import[0].version = 10
onnx.save_model(model, "MyNetOutput.onnx")

注意:高版本向低版本切换时可能会出现问题,这里只是一种可尝试的解决方案。

使用Netron可视化MyNetoutput.onnx,如下图所示:
在这里插入图片描述

8 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/393127.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RocketMQ5.1控制台的安装与启动

RocketMQ控制台的安装与启动下载修改配置开放端口号重启防火墙添加依赖编译 rocketmq-dashboard运行 rocketmq-dashboard本地访问rocketmq无法发送消息失败问题。connect to &#xff1c;公网ip:10911&#xff1e; failed下载 下载地址 修改配置 修改其src/main/resources中…

【操作系统原理实验】银行家算法模拟实现

选择一种高级语言如C/C等&#xff0c;编写一个银行家算法的模拟实现程序。1) 设计相关数据结构&#xff1b;2) 实现系统资源状态查看、资源请求的输入等模块&#xff1b;3) 实现资源的预分配及确认或回滚程序&#xff1b;4) 实现系统状态安全检查程序&#xff1b;5) 组装各模块…

TCP模拟HTTP请求

HTTP的特性HTTP是构建于TCP/IP协议之上&#xff0c;是应用层协议&#xff0c;默认端口号80HTTP协议是无连接无状态的HTTP报文请求报文HTTP协议是以ASCⅡ码传输&#xff0c;建立在TCP/IP协议之上的应用层规范。HTTP请求报文由请求行&#xff08;request line&#xff09;、请求头…

Flutter 自定义今日头条版本的组件,及底部按钮切换静态样式

这里写目录标题1. 左右滑动实现标题切换&#xff0c;点击标题也可实现切换&#xff1b;2. 自定义KeepAliveWrapper 缓存页面&#xff1b;2.2 使用3. 底部导航切换&#xff1b;4. 自定义中间大导航&#xff1b;5.AppBar自定义顶部按钮图标、颜色6. Tabbar TabBarView实现类似头条…

iOS开发之UIStackView基本运用

UIStackView UIStackView是基于自动布局AutoLayout&#xff0c;创建可以动态适应设备方向、屏幕尺寸和可用空间的任何变化的用户界面。UIStackView管理其ArrangedSubview属性中所有视图的布局。这些视图根据它们在数组中的顺序沿堆栈视图的轴排列。由axis, distribution, align…

java医院云HIS系统:融合B/S版电子病历系统 能与公卫、PACS等各类外部系统融合

医院HIS系统源码 云HIS系统源码&#xff1a;SaaS运维平台完整文档 有源码&#xff0c;有演示 java基层医院云his系统 融合B/S版电子病历系统&#xff0c;支持电子病历4级 拥有自主知识产权。 看演示及源码可私信我哦&#xff01; 一、系统概述 一款满足二甲医院、基层医疗机构…

九、会话技术CookieSession

会话技术 1&#xff0c;会话跟踪技术的概述 对于会话跟踪这四个词&#xff0c;我们需要拆开来进行解释&#xff0c;首先要理解什么是会话&#xff0c;然后再去理解什么是会话跟踪: 会话:用户打开浏览器&#xff0c;访问web服务器的资源&#xff0c;会话建立&#xff0c;直到有…

3D目标检测(二)—— 直接处理点云的3D目标检测网络VoteNet、H3DNet

前言上次介绍了基于Point-Based方法处理点云的模块&#xff0c;3D目标检测&#xff08;一&#xff09;—— 基于Point-Based方法的PointNet点云处理系列,其中相关的模块则是构成本次要介绍的&#xff0c;直接在点云的基础上进行3D目标检测网络的基础。VoteNet对于直接在点云上预…

科目一《综合素质》

目录综合素质重点题型分布注意事项章节分解第一章 职业理念第一节 教育观1. 教育观&#xff08;基本内涵&#xff09;一字不差背过第二节 学生观2. 学生观 一字不差背过第三节 教师观3. 教师观 一字不差背过第二章 教育法律法规第一节 教师的权利与义务第二节 学生的权利及其保…

QT基础入门【Demo篇】QString的相关操作

&#x1f4a2;&#x1f4a2;目录总览&#x1f4a2;&#x1f4a2;&#xff1a;QT基础入门目录总览 QString支持的操作符号有&#xff1a; 用于字符串之间比较&#xff1a;"!" "<" "<" "" ">" 用于字符串之间传递&a…

第九章 多机系统

考柿时间是3.9 文章目录多机系统并行性发展及计算机系统的分类开发并行性的途径计算机系统的分类(Flynn分类)SISD与片内并行&#xff08;芯片内的并行机制&#xff09;SIMD分成两个子类&#xff1a;MIMD分为两类&#xff08;**主要区别就是它们是否有共享的内存**、单系统映像&…

MTK平台修改AP 5G模式下所支持的频宽

代码路径 vendor/mediatek/kernel_modules/connectivity/wlan/core/gen4m/common/wlan_lib.c 将 prWifiVar->ucAp5gBandwidth (uint8_t) wlanCfgGetUint32( -- prAdapter, "Ap5gBw", MAX_BW_80MHZ); prAdapter, "Ap5gBw&quo…

Windows下UXP插件开发环境搭建及程序试运行

从PS2021开始&#xff0c;Adobe官方引入了新的插件平台&#xff1a;UXP&#xff0c;它的最终目标任务是取代现有的CEP&#xff0c;所以赶紧来提前做一下准备吧&#xff0c;我对这方面也一直很感兴趣&#xff0c;但是这方面的中文资料太少了&#xff0c;然后在网上查了一些资料和…

【剧前爆米花--爪哇岛寻宝】包装类的装拆箱和泛型的擦除机制

作者&#xff1a;困了电视剧 专栏&#xff1a;《数据结构--Java》 文章分布&#xff1a;这是关于数据结构的基础之一泛型的文章&#xff0c;希望对你有所帮助。 目录 包装类 装箱 装箱源码小细节 拆箱 泛型 什么是泛型 泛型编译的擦除机制 不能实例化泛型类型数组 包装…

LicenseBox Crack,对服务器的要求最低

LicenseBox Crack,对服务器的要求最低 LicenseBox是用于管理基于PHP的软件、WordPress插件或主题、主题、插件和WordPress的更新和许可的完整软件。它易于安装&#xff0c;对服务器的要求最低&#xff0c;用户友好的界面&#xff0c;无限脚本的使用为您的创造力打开了大门。 Li…

基于STM32的水质浑浊度和PH值监测系统设计(仿真+程序+讲解)

基于STM32的水质浑浊度和PH值监测系统设计(仿真程序讲解&#xff09; 仿真图proteus 8.9 程序编译器&#xff1a;keil 5 编程语言&#xff1a;C语言 设计编号&#xff1a;C0077 这里写目录标题演示讲解视频1.主要功能2.仿真3. 程序4.资料清单&下载链接演示讲解视频 基于…

基于JSP的网上书店的设计与实现

技术&#xff1a;Java、JSP等摘要&#xff1a;近年来&#xff0c;随着互联网的迅速普及&#xff0c;网络已经走进了千家万户&#xff0c;作为信息交流的一种平台&#xff0c;它给我们的日常生活带来了很大的便利。今天&#xff0c;各种各样的网站已经深入到了我们的日常生活&am…

单例模式之饿汉式

目录 1 单例模式的程序结构 2 饿汉式单例模式的实现 3 饿汉式线程安全 4 防止反射破坏单例 5 总结 单例模式&#xff08;Singleton Pattern&#xff09;是 Java 中最简单的设计模式之一。所谓单例就是在系统中只有一个该类的实例&#xff0c;并且提供一个访问该实例的全局…

2023年房地产定价模型研究报告

第一章 房地产定价模型概述 受疫情和房地产发展模式影响&#xff0c;目前我国房地产行业遭受着多重冲击&#xff0c;消费者不断降低的购房意愿&#xff0c;频繁出现的烂尾楼问题&#xff0c;建筑材料和工人价格的不断上涨等。而房地产行业本身又是带动如电器&#xff0c;装修&…

《C++程序设计原理与实践》笔记 第14章 设计图形类

本章借助图形接口类介绍接口设计的思想和继承的概念。为此&#xff0c;本章将介绍与面向对象程序设计直接相关的语言特性&#xff1a;类派生、虚函数和访问控制。 14.1 设计原则 我们的图形接口类的设计原则是什么&#xff1f; 14.1.1 类型 我们的程序设计理念是在代码中直…