tf2使用savemodel保存之后转化为onnx适合进行om模型部署

news2025/1/16 18:46:40

tf2使用savemodel保存之后转化为onnx适合进行om模型部署

  • tf保存为kears框架h5文件
  • 将h5转化为savemodel格式,方便部署
  • 查看模型架构
  • 将savemodel转化为onnx格式
  • 使用netron
  • onnx模型细微处理
  • 代码转化为om以及推理代码,要么使用midstudio

tf保存为kears框架h5文件

前提环境是tf2.2及其版本以上的框架,模型训练结果保存为h5(也就是kears框架)
Pasted image 20240507233042

将h5转化为savemodel格式,方便部署

之后将h5文件转化为savemodel的格式
Pasted image 20240507233120

custom是在保存模型的时候需要的自定义函数,如果没有则不需要添加

保存结果如下
Pasted image 20240507233222

这个地方记得验证一下savemodel格式是否能成功搭载测试代码

import  os  
import pandas as pd  
import numpy as np  
from sklearn.metrics import accuracy_score  
from sklearn.model_selection import train_test_split  
from tensorflow.keras.models import Sequential  
from tensorflow.keras.layers import LSTM,Dense,Dropout  
from keras.utils import to_categorical  
import tensorflow as tf  
from tensorflow.python.keras.layers import Activation  
  
os.chdir('D:/software_project/心电信号分类/')  
  
  
# 加载 SavedModel 目录  
loaded_model = tf.saved_model.load('tfmodel_save')  
  
# 获取默认的服务签名  
infer = loaded_model.signatures['serving_default']  
print(infer.structured_input_signature)  
print(infer.structured_outputs)  
  
# 加载CSV文件  
file_path = 'data2/shuffled_merged_data.csv'  
data = pd.read_csv(file_path)  
from sklearn.preprocessing import StandardScaler  
  
# 创建StandardScaler实例  
scaler = StandardScaler()  
features = data.iloc[0:1, :-1]  
  
# 获取最后一列作为标签  
labels = data.iloc[0:1, -1]  
features1 = scaler.fit_transform(features)  
# features1 = features1.astype(np.float32)  
  
# # 转化为numpy  
# features = features.to_numpy()  
trainX3 = features1.reshape((features1.shape[0], features1.shape[1], 1))  
  
# # 将数据转换为Tensor  
input_data = tf.convert_to_tensor(trainX3, dtype=tf.float32)  
  
  
output = infer(conv1d_input=input_data)  
output4=output['dense_3']  
print(output4.numpy())  
  
# 为了确定每个样本的预测标签,我们找到概率最高的类别的索引  
predicted_indices = np.argmax(output4.numpy(), axis=1)  
accuracy = accuracy_score(labels, predicted_indices)  
print(accuracy)  
  
output2=output["dense_8"]  
print(output["dense_8"])  
predicted_indices2= np.argmax(output2.numpy(), axis=1)  
accuracy2 = accuracy_score(labels, predicted_indices2)  
print(accuracy2)  
  
  
output2_1=output["dense_8_1"]  
print(output["dense_8_1"])  
predicted_indices2= np.argmax(output2_1.numpy(), axis=1)  
accuracy3 = accuracy_score(labels, predicted_indices2)  
print(accuracy3)  
  
print('nihao')  
# 不可用  
# print(output["StatefulPartitionedCall:0"])

查看模型架构

可以使用这个代码查看模型架构,输入输出的名字

 saved_model_cli show --dir D:\software_project\心电信号分类\tfmodel_save --tag_set serve --sig
nature_def serving_default

结构如下
Pasted image 20240507233536

如果可以用咱们继续进行下一步

将savemodel转化为onnx格式

之后将保存的savemodel格式转化为onnx格式

这里直接上大佬博客
在Atlas 200 DK中部署深度学习模型

基本把每个步骤过一遍即可

注意安装tensorflowgpu的版本是很高的
Pasted image 20240507233803

转换指令

python -m tf2onnx.convert --saved-model tensorflow-model-path --output model.onnx

使用netron

把模型放入到netron中
Netron

导出的onnx模型如下
Pasted image 20240507234346

onnx模型细微处理

获得的onnx模型放入netron中进行查看,发现有些未知输出量需要修改
【tensorflow onnx】TensorFlow2导出ONNX及模型可视化教程_tf2onnx-CSDN博客

主要是这种未知量
Pasted image 20240507233928

代码转化为om以及推理代码,要么使用midstudio

之后即可使用代码进行模型的转化为om

转化成功之后,放到atlks200dk板子中进行模型的推理
代码

import numpy as np  
import acllite_utils as utils  
import constants as const  
from acllite_model import AclLiteModel  
from acllite_resource import AclLiteResource  
import time  
import csv  
import numpy as np  
  
class Reasoning(object):  
    """  
    class for reasoning    """    def __init__(self, model_path):  
        self._model_path = model_path  
        self.device_id = 0  
        self._model = None  
    def init(self):  
        """  
        Initialize        """  
        # Load model  
        self._model = AclLiteModel(self._model_path)  
  
        return const.SUCCESS  
    def inference(self, one_dim_data):  
        """  
        model inference        """        return self._model.execute(one_dim_data)  
  
def main():  
    model_path = 'model_dim_replace.om'  
    # 打开 CSV 文件  
    with open('shuffled_merged_data.csv', newline='') as csvfile:  
        # 创建 CSV 读取器对象  
        csvreader = csv.reader(csvfile, delimiter=',')  
        # 跳过第一行(标题行)  
        next(csvreader)  
        # 读取第二行数据  
        second_row = next(csvreader)  
        # 移除最后一个数据  
        second_row_without_last = second_row[:-1]  
        # 将数据转换为 NumPy 数组  
        np_array = np.array(second_row_without_last, dtype=np.float32)  
        print(np_array.dtype)  
        # 输出转换后的 NumPy 数组  
    acl_resource = AclLiteResource()  
    acl_resource.init()  
    reasoning = Reasoning(model_path)  
    # init  
    ret = reasoning.init()  
    utils.check_ret("Reasoning.init ", ret)  
    start_time = time.time()  
  
    # 假设你有一个名为 input_data 的 NumPy 数组,它包含模型的输入数据  
    input_data = np.array([np_array])  # 替换为你的输入数据  
    result_class = reasoning.inference(input_data)  
  
  
    end_time = time.time()  
    execution_time = end_time - start_time  
  
    print(result_class)  
if __name__ == '__main__':  
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1653698.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows注册表

注册表 一.概述 注册表(Registry)是Microsoft Windows中的一个重要的数据库,用于[存储系统]和[应用程序]的设置信息。早在[Windows 3.0]推出[OLE]技术的时候,注册表就已经出现。随后推出的[Windows NT]是第一个从系统级别广泛使…

IT项目管理 选择/判断 【太原理工大学】

第一章、IT项目管理 判断题 1、搬家属于项目。( 对 ) 2、项目是为了创造一个唯一的产品或提供一个唯一的服务而进行的永久性的努力。( 错 ) 3、项目具有临时性的特征。( 对 ) 4、项目开发过程…

基于RTI Connext使用Simulink的DDS Blockset

MathWorks一直是数据分发服务(DDS)标准的长期支持者。RTI Connext基于DDS,已与Simulink集成多年,使用户能够导入数据进行更逼真的模拟工作。 2021年,MathWorks通过其新推出的Simulink附加产品DDS Blockset提高了标准。…

怎么制作好玩的gif?试试这个工具轻松制作

视频之所以受大众的喜爱是因为有声音、画面的搭配,让观者深入其中体验感会更强。但是视频的体积较大、时长也比较长,给我们的传播和保存造成了一定的影响。那么,我们可以将视频制作成gif图片来使用,不需要下载软件,使用…

在哪个网站找视频素材?在哪可以下视频素材?

在这个视频内容极为重要的时代,高质量的视频素材成为了创作的关键。特别是4K和无水印视频素材,它们不仅提升了视觉效果,也为作品增加了专业度。以下是一些优秀的国内外视频素材网站,希望能助您一臂之力。 1. 蛙学府 专注于为中国…

01-基本概念- 索引,文档和 REST API

# kibana_sample_data_ecommerce 为es 索引#查看索引相关信息 GET kibana_sample_data_ecommerce#查看索引的文档总数 GET kibana_sample_data_ecommerce/_count#查看前10条文档,了解文档格式 POST kibana_sample_data_ecommerce/_search { }#_cat indices API #查看…

福州网站建设如何设计极简风格合理?

福州网站建设如何设计极简风格合理?企业网站逐渐流行,每个人的审美也发生着巨大的改变,开始追求一种极简的风格。简单的 风格才能够凸显原有的主题,不会太过主次不分。 越来越多的网站建设中选择极简的风格,简单的页面…

8款好用的电脑监控软件分享丨好资源不私藏!

电脑已经成为我们日常生活和工作的重要工具。随之而来的是,电脑监控的需求也逐渐增加。为了帮助大家更好地管理和监控电脑使用情况,本文将为您推荐8款好用的电脑监控软件。这些软件功能强大,易于使用,适用于各种场景,让…

什么情况下 MySQL 连查询都能被阻塞?

MySQL 的锁也是不少,在哪种情况下会连查询都能被阻塞?这是一个有意思的问题。 工作中,很多开发和 DBA 可能接触较多的锁也就行锁了。对于行锁,阻塞写能理解,阻塞读实在是想不到。能阻塞读的那肯定是颗粒度更大的锁了&…

电脑怎么压缩图片?压缩图片并不难

电脑怎么压缩图片?随着数字时代的来临,我们每天都在与大量的图片打交道,无论是社交媒体上的个人照片,还是工作中的设计图片,或是网页上的广告图片。然而,高质量的图片往往意味着大文件大小,这不…

如何将jsp项目转成springboot项目

昨天说过,springboot推荐使用Thymeleaf作为前后端渲染的模板引擎,为什么推荐用Thymeleaf呢,有以下几个原因: 动静结合:Thymeleaf支持HTML原型,允许在HTML标签中增加额外的属性来实现模板与数据的结合。这样…

Linux的基础IO:文件系统

目录 学前补充 磁盘的存储结构 OS如何对磁盘的存储进行逻辑抽象 细节内容 文件的增删改查 学前补充 问题:计算机只认二进制,即0、1,什么是0、1? 解释:0、1在物理层面可能有不同的表现,0、1是数字逻辑…

美股订单类型有哪些

美股交易中,订单类型是投资者执行交易指令的重要工具。了解不同类型的订单,可以帮助投资者制定更有效的交易策略,并控制风险。 1. 市价单:快速成交,不惧踏空 市价单(Market Order)是一种以当时…

【NodeMCU实时天气时钟温湿度项目 5】获取关于城市天气实况和天气预报的JSON信息(心知天气版)

| 今天是第五专题内容,主要是介绍如何从心知天气官网,获取包含当前天气实况和未来 3 天天气预报的JSON数据信息。 在学习获取及显示天气信息前,我们务必要对JSON数据格式有个深入的了解。 如您需要了解其它专题的内容&#xf…

鸿蒙内核源码分析(ELF格式篇) | 应用程序入口并不是main

阅读之前的说明 先说明,本篇很长,也很枯燥,若不是绝对的技术偏执狂是看不下去的.将通过一段简单代码去跟踪编译成ELF格式后的内容.看看ELF究竟长了怎样的一副花花肠子,用readelf命令去窥视ELF的全貌,最后用objdump命令…

吴恩达2022机器学习专项课程C2(高级学习算法)W1(神经网络):2.5 更复杂的神经网络

目录 示例填写第三层的层数1.问题2.答案 公式:计算任意层的激活值激活函数 示例 层数有4层,不包括输入层。 填写第三层的层数 1.问题 你能把第二个神经元的上标和下标填写出来吗? 2.答案 根据公式g(wxb),这里的x对应的是上…

Liunx系统怎么设置免密登录?看这一篇!

远程口令爆破也是黑客常用的手段,有些人安全意识薄弱的会设置一些简单的密码,这样分分钟会被黑客爆破进去,一旦操作系统沦陷,里面的数据必将被黑客一览无余,使用免密登录可以有效降低密码被爆破的风险,具体…

C++学习第十二天(继承)

1、继承的概念以及定义 继承的概念 继承机制是面向对象程序设计使代码可以复用的最重要的手段,它允许程序员在保持原有类特性的基础上进行拓展,增加功能,这样产生新的类,称派生类。继承呈现了面向对象程序设计的层次结构&#x…

光伏设计的核心要素有哪些?

光伏设计是可再生能源领域中的一个重要分支,它涉及到将太阳能转换为电能的整个过程。在光伏系统的设计和构建过程中,有几个核心要素需要被充分考虑和精确计算,以确保系统的性能、可靠性和经济效益。 一、光照条件分析 光照条件是光伏系统设计…

从Python整数变量内存大小占用28字节谈起

实验结果 本机环境64位Python 3.12 内存布局图 0 4 8 12 16 20 24 28 |----------|----------|----------|----------|----------|----------|----------| | ob_refcnt | ob_type | ob_digit | …