onnx模型修改:将均值和方差放到模型中

news2024/11/30 8:39:55

训练模型时,一般都会对原始数据进行归一化再送入网络,即减均值和除方差。在部署时,我们也要进行同样的操作。有些推理框架会提供对应的接口,我们只需要设置均值和方差即可,如MNN.也有一些框架不提供这样的功能,如Tensorrt,这时,我们就需要自己去逐像素进行这个操作,不仅繁琐,还可能比较耗时。还有一种方式是将这个操作放到模型中,一个方法是在我们的原始pytorch模型中增加一个固定参数的Batchnorm层,另一种方式就是本文要讲的在导出的onnx模型中插入Sub和Div节点来完成。

插入节点

主要步骤:
1.创建2个常量节点,分别是均值和方差向量
2.分别插入一个Sub和Div节点,Sub节点输入是模型的输入和均值节点,Div节点输入是Sub节点输出和方差向量
3.将输入层后的第一层的输入修改为Div节点的输出。
代码如下:

import onnx
from onnx import numpy_helper
import numpy as np

# 加载ONNX模型
model_path = "xx.onnx"
model = onnx.load(model_path)
# 创建均值和方差张量
mean_value = [0.485*255, 0.456*255, 0.406*255]
variance_value = [0.229*255, 0.224*255, 0.225*255]

mean_tensor = numpy_helper.from_array(np.array(mean_value, dtype=np.float32).reshape(1,3,1,1), "mean")
variance_tensor = numpy_helper.from_array(np.array(variance_value, dtype=np.float32).reshape(1,3,1,1), "variance")

# 插入均值和方差节点
mean_node = onnx.helper.make_node("Constant", [], ["mean"], value=mean_tensor)
variance_node = onnx.helper.make_node("Constant", [], ["variance"], value=variance_tensor)

model.graph.node.insert(0, mean_node)
model.graph.node.insert(1, variance_node)

# 插入归一化节点
input_name = model.graph.input[0].name
normalize_node = onnx.helper.make_node("Sub", [input_name, "mean"], ["sub_output"])
scale_node = onnx.helper.make_node("Div", ["sub_output", "variance"], ["input_norm"])

# 插入节点到模型中
model.graph.node.insert(2, normalize_node)
model.graph.node.insert(3, scale_node)

# 更新模型
model.graph.node[4].input[0] = "input_norm"

# 保存修改后的ONNX模型
modified_model_path = "xx_norm.onnx"
# shape inference
model = onnx.shape_inference.infer_shapes(model)
# check model
onnx.checker.check_model(model)
onnx.save(model, modified_model_path)

在这里插入图片描述

验证模型结果

比对修改前和修改后的模型输出

import onnxruntime as ort
import numpy as np

# 加载原始模型
origin_ort = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
origin_input_name = origin_ort.get_inputs()[0].name
origin_output_name = origin_ort.get_outputs()[0].name
input = np.random.randn(1, 3, 128, 128).astype(np.float32)
input_norm = (input - np.array(mean_value,np.float32).reshape(1, 3, 1, 1)) / np.array(variance_value,np.float32).reshape(1, 3, 1, 1)
origin_output = origin_ort.run([origin_output_name], {origin_input_name: input_norm})[0]

# 加载修改后的模型
modified_ort = ort.InferenceSession(modified_model_path, providers=["CPUExecutionProvider"])
modified_input_name = modified_ort.get_inputs()[0].name
modified_output_name = modified_ort.get_outputs()[0].name
modified_output = modified_ort.run([modified_output_name], {modified_input_name: input})[0]

# 比较两个模型的输出
print(np.allclose(origin_output, modified_output, atol=1e-3))

输出为True证明添加节点后的模型正确,后续使用使,不再需要在外部进行均值和方差操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/951423.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单元测试:优雅编写Kotlin单元测试

一、MockK简介 MockK是一款功能强大、易于使用的Kotlin mocking框架。在编写单元测试时,MockK能够帮助我们简化代码、提高测试覆盖率,并改善测试的可维护性。除了基本用法外,MockK还提供了许多额外的功能和灵活的用法,让我们能够…

剑走偏锋:非传统问题在面试中的应对策略

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

如何优雅地在windows上玩ROS(一个紧致的解决方案)

如何优雅地在windows上玩ROS(一个紧致的解决方案) - 知乎 Excerpt 前言为了节省您的时间,本文适用的OS为win10,win11;适用的ROS1版本为melodic和noetic;适用于ROS2 foxy。如果你的目标OS和ROS不在上述的范围…

雅思写作 三小时浓缩学习顾家北 笔记总结(一)

目录 饥饿网翻译100个句子记录 There are some economically deprived communities in large cities. there is no clear link between grouping student by ability and their levels of attainment. young people without tertiary education qualification normally hav…

java八股文面试[数据库]——索引的基本原理、设计原则

索引的设计原则 索引覆盖是什么: 索引(在MySQL中也叫做“键(key)”) 是存储引擎用于快速找到记录的一种数据结构。这是索引的基本功能。 索引对于良好的性能非常关键。尤其是当表中的数据量越来越大时,索引…

成集云 | 钉钉财务费用单同步至畅捷通 | 解决方案

源系统成集云目标系统 方案介绍 财务管理作为企业管理中重要的组成部分,在企业的发展和成长中扮演着重要角色,成集云以钉钉费用单OA审批与畅捷通TCloud系统为例,与钉钉连接器深度融合,通过数据处理和字段匹配实现了费用…

【LeetCode】290. 单词规律

这里写自定义目录标题 2023-8-30 09:34:23 290. 单词规律 2023-8-30 09:34:23 这道题目,我是根据 205. 同构字符串 的思路一样,都转化为另外一个第三方的字符串,在比较翻译过后的语句是不是一样的。 class Solution {public boolean wordP…

自然语言处理的多行业应用

在我们小时候,甚至是我们会走路或说话之前,就已经在察觉周围发出的声音了。我们倾听其他人发出的声响和声音。我们将声音组合成有意义的词语,例如“母亲”和“门”,并学习解读周围人的面部表情,以加深我们对词组的理解…

火爆全网!HubSpot CRM全面集成,引爆营销业绩!

HubSpot CRM是什么?它是一款强大的客户关系管理工具,专为企业优化销售、服务和市场营销流程而设计。它在B2B行业中扮演着极为重要的角色,让我来告诉你为什么吧! HubSpot CRM不仅拥有用户友好的界面和强大的功能,还能够…

Dimensions网站——一个链接研究知识系统

Dimensions网站——一个链接研究知识系统 一、Dimensions网站简介 Dimensions 是一个链接的研究知识系统,它重新构想了发现和研究的获取。Dimensions 由 Digital Science 与全球 100 多个领先研究组织合作开发,汇集了资助、出版物、引文、替代指标、临…

城市内涝监测预警系统:构筑智慧城市的内涝防控网络

治理城市内涝事关人民群众生命财产安全,既是重大民生工程,又是重大发展工程。近年来,各地区各部门大力推进排水防涝设施建设,城市内涝治理取得积极进展,但仍存在自然调蓄空间不足、排水设施建设滞后、应急管理能力不强…

CTF学习资源

文章目录 一、buuctf靶场1、MD52、一眼解密3、Url编码4、回旋踢5、摩斯6、Password7、变异凯撒8、Quoted-printable9、Rabbit10、篱笆墙的影子11、RSA12、丢失的MD5 二、ctf题型1、PWN,Reverse1)Reverse2)pwn 2、Crypto1)古典密码学2)现代密码学 3、web4、Misc1)Rec…

下岗吧,Excel

ChatGPT的诞生使Excel公式变得过时。通过使用 ChatGPT 的代码解释器你可以做到: 分析数据创建图表 这就像用自然语言与电子表格交谈一样。我将向大家展示如何使用 ChatGPT 执行此操作并将结果导出为Excel格式: 作为示例,我将分析并创建美国…

夸克扫描王App用上了AI大模型 让扫描更清楚、提取文字更方便

对上班族来说,找到一个好用的工具类APP,绝对可以提升工作效率。比如最常见的扫描文件,公司的扫描仪虽然好用但是很难进行深度编辑且不能外出使用;很多手机App也有扫描功能,但技术能力总是差一点,当面对复杂…

Android系统-线程-java线程

引言 Android的框架应用是java环境的。java天生就是多线程。所以对java线程的理解尤为重要。 概念 线程状态转换图 NEW 初始状态 RUNNABLE 运行状态 BLOCKED 阻塞状态 WAITING 等待状态 TIME_WAITING 超时等待状态 TERMINATED 终止状态 注意:调用obj.wait(…

AI大模型的使用-用LangChain链式调用简化多步提示语

众所周知,openAI的prompt对英文比较友好,也就是英文提示它的结果会更准确,假如我们不会英文,我们把中文问题给到OpenAI,然后让它翻译成英文,并把翻译后的英文给到OpenAI,让它帮忙给出解答问题&a…

【升职加薪秘籍】我在服务监控方面的实践(9)-报警设计

大家好,我是蓝胖子,关于性能分析的视频和文章我也大大小小出了有一二十篇了,算是已经有了一个系列,之前的代码已经上传到github.com/HobbyBear/performance-analyze, 接下来这段时间我将在之前内容的基础上,结合自己在…

干了外包3个月,技术退步明显...

先说一下自己的情况,大专生,18年通过校招进入湖南某软件公司,干了接近4年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试…

你觉得 Android 还有必要继续吗?

前言 这些年,总是听到有人说Android 开发岗位要凉了,不好做了。坦白说,市场倾向理性,竞争变强是很正常的事。但你发现总有些人,他们拿的 Offer 薪资是更高的,能达到年薪五六十万,甚至年薪百万。…

贪心算法总结篇

文章转自代码随想录 贪心算法总结篇 我刚刚开始讲解贪心系列的时候就说了,贪心系列并不打算严格的从简单到困难这么个顺序来讲解。 因为贪心的简单题可能往往过于简单甚至感觉不到贪心,如果我连续几天讲解简单的贪心,估计录友们一定会不耐…