【tensorflow onnx】TensorFlow2导出ONNX及模型可视化教程

news2025/1/23 21:32:34

文章目录

  • 1 背景介绍
  • 2 实验环境
  • 3 tf2onnx工具介绍
  • 4 代码实操
    • 4.1 TensorFlow2与ONNX模型导出
    • 4.2 ONNX正确性验证
    • 4.3 TensorFlow2与ONNX的一致性检查
    • 4.4 多输入的情况
    • 4.5 设定输入/输出节点
  • 5 ONNX模型可视化
  • 6 ir_version和opset_version修改
  • 7 ONNX输入输出维度修改
  • 8 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

1 背景介绍

使用深度学习开源框架Pytorch训练完网络模型后,在部署之前通常需要进行格式转换,地平线工具链模型转换目前支持Caffe1.0和ONNX(opset_version=10/11 且 ir_version≤7)两种。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,例如Pytorch、PaddlePaddle、TensorFlow等。本文将详细介绍如何将TensorFlow2得到的模型导出为ONNX格式。

2 实验环境

本教程的实验环境如下:

Python库Version
tensorflow-cpu2.11.0
tensorflow-intel2.11.0
tf2onnx1.13.0
protobuf3.20.2
onnx1.13.0
onnxruntime1.14.0

3 tf2onnx工具介绍

tf2onnx可以通过命令行的方式将TensorFlow/Keras的模型转换为ONNX,该工具的主要配置参数如下:

python -m tf2onnx.convert
    --saved-model          #以save-model方式保存的tf模型文件夹
    --output               #转换为ONNX格式的完整模型名称
    --opset                #默认为13,请手动配置10或11
    --inputs               #可选,用于指定导出的首节点
    --outputs              #可选,用于指定导出的尾节点

tf2onnx的更多详细介绍可以参考: https://github.com/onnx/tensorflow-onnx

4 代码实操

4.1 TensorFlow2与ONNX模型导出

以下代码展示了如何搭建一个简单分类模型以TensorFlow2的save-model方式保存并转换为ONNX格式。

import tensorflow as tf
import os
import onnx

def MyNet():
    input1 = tf.keras.layers.Input(shape=(7, 7, 3))

    x = tf.keras.layers.Conv2D(16, (3, 3),
               activation='relu',
               padding='same',
               name='conv1')(input1)
    x = tf.keras.layers.Conv2D(16, (3, 3),
               activation='relu',
               padding='same',
               name='conv2')(x)

    x = tf.keras.layers.Flatten(name='flatten')(x)
    x = tf.keras.layers.Dense(100, activation='relu', name='fc1')(x)
    output = tf.keras.layers.Dense(2, activation='softmax', name='predictions')(x)

    input_1 = input1
    model = tf.keras.models.Model(inputs=[input_1], outputs=output)
    return model

model = MyNet()

#需要先使用model.save方法保存模型
model.save('model')
#调用tf2onnx将上一步保存的模型导出为ONNX
os.system("python -m tf2onnx.convert --saved-model model --output model.onnx --opset 11")

4.2 ONNX正确性验证

可以用以下代码验证ONNX模型的正确性,会检查模型的版本,图的结构,节点及输入输出。若输出为 Check: None 则表示无报错信息,模型导出正确。

import onnx

onnx_model = onnx.load("./model.onnx")
check = onnx.checker.check_model(onnx_model)
print('Check: ', check)

4.3 TensorFlow2与ONNX的一致性检查

可以使用以下代码检查导出的ONNX模型和原始的PaddlePaddle模型是否有相同的计算结果。

import tensorflow as tf
import onnxruntime
import numpy as np

input1 = np.random.random((1, 7, 7, 3)).astype('float32')

ort_sess = onnxruntime.InferenceSession("./model.onnx")
ort_inputs = {ort_sess.get_inputs()[0].name: input1}
ort_outs = ort_sess.run(None, ort_inputs)

tf_model = tf.saved_model.load(export_dir="model")
tf_outs = tf_model(inputs=input1)

print(ort_outs[0])
print(tf_outs.numpy())
np.testing.assert_allclose(tf_outs.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print("onnx model check finsh.")

4.4 多输入的情况

若您的模型存在多输入,则可参考下方代码以TensorFlow2的save-model方式保存并转换为ONNX格式。

import tensorflow as tf
import os

def MyNet():
    input1 = tf.keras.layers.Input(shape=(7, 7, 3))
    input2 = tf.keras.layers.Input(shape=(7, 7, 3))

    x = tf.keras.layers.Conv2D(16, (3, 3),
               activation='relu',
               padding='same',
               name='conv1')(input1)
    y = tf.keras.layers.Conv2D(16, (3, 3),
               activation='relu',
               padding='same',
               name='conv2')(input2)
    z = tf.keras.layers.Concatenate(axis=-1)([x, y])
    z = tf.keras.layers.Flatten(name='flatten')(z)
    z = tf.keras.layers.Dense(100, activation='relu', name='fc1')(z)
    output = tf.keras.layers.Dense(2, activation='softmax', name='predictions')(z)

    input_1 = input1
    input_2 = input2
    model = tf.keras.models.Model(inputs=[input_1,input_2], outputs=output)
    return model

model = MyNet()

model.save('model')
os.system("python -m tf2onnx.convert --saved-model model --output model.onnx --opset 11")

4.5 设定输入/输出节点

有时考虑到部署难度,我们不希望TensorFlow网络结构的前后处理部分也导入进ONNX模型。此时可以使用tf2onnx工具的inputs和outputs参数,指定导出的首尾节点,这样首节点之前和尾节点之后的部分都不会导入进ONNX模型。

5 ONNX模型可视化

导出成ONNX模型后,可以使用开源可视化工具Netron来查看网络结构及相关配置信息。Netron的使用方式主要分为两种,一种是使用在线网页版 https://netron.app/ ,另一种是下载安装程序 https://github.com/lutzroeder/netron 。此教程中模型的可视化效果为:

6 ir_version和opset_version修改

地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7,当拿到的ONNX模型不满足这两个要求时,可以修改代码重新导出,或者尝试编写脚本直接修改ONNX模型的对应属性,第二种方式的示例代码如下:

import onnx

model = onnx.load("./model.onnx")
model.ir_version = 6
model.opset_import[0].version = 11
onnx.save_model(model, "./model_version.onnx")

注意: 高版本向低版本切换时可能会出现问题,这里只是一种可尝试的解决方案。

7 ONNX输入输出维度修改

当发现使用tf2onnx工具保存的ONNX模型的输入输出节点出现异常值时,比如以下情况:

在这里插入图片描述

可以使用如下代码进行修改:

import onnx

onnx_model = onnx.load("./model.onnx")
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 1
onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 1
onnx.save(onnx_model, './model_dim.onnx')

打开保存的ONNX模型文件,可以看到输入输出节点的维度已经正常:
在这里插入图片描述

至此,该ONNX模型已满足地平线工具链的转换条件。

8 致谢

原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/397014.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【教学典型案例】18.开门小例子理解面向对象

目录一:背景介绍业务场景:业务分析:二:实现思路1、面向过程:2、面向对象(抽象、封装、继承、多态)3、面向对象(抽象、封装、继承、多态、反射)三:实现过程1、…

如何在 Istio 中使用 SkyWalking 进行分布式追踪

在云原生应用中,一次请求往往需要经过一系列的 API 或后台服务处理才能完成,这些服务有些是并行的,有些是串行的,而且位于不同的平台或节点。那么如何确定一次调用的经过的服务路径和节点以帮助我们进行问题排查?这时候…

二极管损坏的原因有哪些?

大家好,我是记得诚。 最近项目上肖特基二极管出问题了,概率性损坏,二极管本来是一个很简单的器件,这次重新整理一下,供大家参考。 二极管损坏,个人总结有如下几种情况。 1、过压 在Ta=25℃下,超过二极管的最大反向电压VR,二极管可能会被击穿,导致损坏。 2、过流 …

SpringBoot的基本概念和使用

文章目录一、什么是SpringBoot二、Spring Boot优点三、Spring Boot项目创建四、Spring Boot 配置文件1. yml语法2.properties与yml关系3.多系统的配置五、Spring Boot日志文件1.日志对象2.日志级别日志级别的设置System.out.println VS 日志的两个致命缺点3.日志持久化4.更简单…

[ 常用工具篇 ] windows安装phpStudy_v8.1_X64

🍬 博主介绍 👨‍🎓 博主介绍:大家好,我是 _PowerShell ,很高兴认识大家~ ✨主攻领域:【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 🎉点赞➕评论➕收藏 养成习…

如何实现大文件断点续传、秒传

大家先来了解一下几个概念: 「文件分块」:将大文件拆分成小文件,将小文件上传\下载,最后再将小文件组装成大文件; 「断点续传」:在文件分块的基础上,将每个小文件采用单独的线程进行上传\下载&…

CobaltStrike密码爆破、伪造上线以及DDos——csIntruder

Git仓库: https://github.com/ljy1058318852/csIntruder0x01 概述 本项目包含CobaltStrike密码爆破、伪造上线以及DDos功能。其中伪造上线支持常见魔改版CS。 This project includes CobaltStrike password blasting, fake online and DDos functions. Among them…

云计算创新展望-精耕细作的超级云计算平台

前言在当今云计算深入各行业、计算量暴增现状之下,云计算生态迎来百花齐放。但用户不希望将所有鸡蛋放在一个篮子里面,因此每个企业都在发展自己的私有云、公有云等多云、混合云结构。因云计算的高灵活性、可扩展性、高性价比,在本地10台服务…

ubuntu的快速安装与配置

文章目录前言一、快速安装二 、基础配置1 Sudo免密码2 ubuntu20.04 pip更新源3 安装和配置oneapi(infort/mpi/mkl) apt下载第一次下载的要建立apt源apt下载(infort/mpi/mkl)4 安装一些依赖库等5 卸载WSLpython总结前言 win11系统 ubuntu20.04 提示:以下…

【力扣-10天SQL入门】5~8天刷题 知识点总结

https://leetcode.cn/study-plan/sql/?progressjgmzq5s第5天 合并175. 组合两个表就是一个简单的left join1581. 进店却未进行过交易的顾客Q:两个表Visits(有visit_id和customer_id两列)和Transactions(有transaction_id、visit_…

Go垃圾回收原理

术语介绍 赋值器:说白了就是你写的程序代码,在程序的执行过程中,可能会改变对象的引用关系,或者创建新的引用。 回收器:垃圾回收器的责任就是去干掉那些程序中不再被引用得对象。 STW:全称是stop the word,GC期间某个阶段会停止…

插值多项式的龙格现象的介绍与模拟

在文章拉格朗日插值多项式的原理介绍及其应用中,笔者介绍了如何使用拉格朗日插值多项式来拟合任意数据点集。   事实上,插值多项式会更倾向于某些形状。德国数学家卡尔龙格Carl Runge发现,插值多项式在差值区间的端点附近会发生扭动&#x…

一篇文章彻底理解setState是同步还是异步!

本文内容均针对于18.x以下版本setState 到底是同步还是异步?很多人可能都有这种经历,面试的时候面试官给了你一段代码,让你说出输出的内容,比如这样:constructor(props) {super(props);this.state {data: data} }comp…

Sentinel架构篇 - 来源访问控制

来源访问控制(黑白名单) 概念 Sentinel 提供了黑白名单限制资源能否通过的功能。如果配置了白名单,则只有位于白名单的请求来源的对应的请求才能通过;如果配置了黑名单,则位于黑名单的请求来源对应的请求不能通过。 …

图形报表ECharts

图形报表ECharts1 图形报表ECharts1.1 ECharts简介-富客户端图表库ECharts缩写来自Enterprise Charts,商业级数据图表,是百度的一个开源的使用JavaScript实现的数据可视化工具,可以流畅的运行在PC和移动设备上,兼容当前绝大部分浏…

【3.8】操作系统内存管理、Redis数据结构、哈希表

内存满了,会发生什么? 当应用程序读写了这块虚拟内存,CPU 就会去访问这个虚拟内存, 这时会发现这个虚拟内存没有映射到物理内存, CPU 就会产生缺页中断,进程会从用户态切换到内核态,并将缺页中…

MySQL索引15连问,抗住!

1. 索引是什么?索引是一种能提高数据库查询效率的数据结构。它可以比作一本字典的目录,可以帮你快速找到对应的记录。索引一般存储在磁盘的文件中,它是占用物理空间的。正所谓水能载舟,也能覆舟。适当的索引能提高查询效率&#x…

实战小项目之视频监控(1-2)

实战小项目之视频监控(1-2) Nginx 移植 前面也给大家提到了,我们可以使用 Nginx 来搭建 RTMP 流媒体服务器,譬如你可以在一台公网 IP 主 机上搭建流媒体服务器,当然,笔者并没有这个条件;这里我…

2023年计算语言学和自然语言处理国际会议(CLNLP 2023)

2023年计算语言学和自然语言处理国际会议(CLNLP 2023) 重要信息 会议网址:www.clnlp.org 会议时间:2023年8月18-20日 召开地点:中国南京 截稿时间:2023年6月31日 录用通知:投稿后2周内 收…

MATLAB绘制三Y轴坐标图:补充坐标轴及字体设置

三轴坐标图 1 函数 MATLAB绘制三轴图函数可见MATLAB帮助-multiplotyyy 基础图形绘制是很简单,但坐标轴及字体设置该如何实现呢? 本文以以下几个例子为例,希望可以解决在利用MATLAB绘制三轴坐标图时常见的疑惑。 2 案例 2.1 案例1&#xf…