onnx进阶算子优化

news2025/1/22 16:50:36

一、定义

  1. 如何保证pytorch 模型顺利转为onnx. 前言
  2. pytorch 算子是如何与onnx 算子对齐的?
  3. Asinh 算子出现于第 9 个 ONNX 算子集。PyTorch 在 9 号版本的符号表文件中是怎样支持这个算子的?
  4. BitShift 算子出现于第11个 ONNX 算子集。PyTorch 在 11 号版本的符号表文件中是怎样支持这个算子的?
  5. 算子在pytorch 中已经实现,onnx 算子也实现,缺少映射方法,自己注册,实现转换。
  6. 自定义onnx 算子。
  7. 构造onnx 模型,并测试。
  8. onnx提取子模型

二、实现

  1. 如何保证pytorch 模型顺利转为onnx. 前言, 参考:https://zhuanlan.zhihu.com/p/513387413
    要使 PyTorch 算子顺利转换到 ONNX ,我们需要保证以下三个环节都不出错:
    算子在 PyTorch 中有实现
    有把该 PyTorch 算子映射成一个或多个 ONNX 算子的方法
    ONNX 有相应的算子
    可在实际部署中,这三部分的内容都可能有所缺失。其中最坏的情况是:我们定义了一个全新的算子,它不仅缺少 PyTorch 实现,还缺少 PyTorch 到 ONNX 的映射关系。但所谓车到山前必有路,对于这三个环节,我们也分别都有以下的添加支持的方法:
    PyTorch 算子
    组合现有算子
    添加 TorchScript 算子
    添加普通 C++ 拓展算子
    映射方法
    为 ATen 算子添加符号函数
    为 TorchScript 算子添加符号函数
    封装成 torch.autograd.Function 并添加符号函数
    ONNX 算子
    使用现有 ONNX 算子
    定义新 ONNX 算子

  2. pytorch 算子是如何与onnx 算子对齐的?
    onnx 算子文档:https://github.com/onnx/onnx/blob/main/docs/Operators.md
    torch 对onnx算子映射:https://github.com/pytorch/pytorch/tree/main/torch/onnx
    在这里插入图片描述
    表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。在这里插入图片描述
    symbolic_opset{n}.py(符号表文件)即表示 PyTorch 在支持第 n 版 ONNX 算子集时新加入的内容。判定是否存在映射方法。

  3. Asinh 算子出现于第 9 个 ONNX 算子集。PyTorch 在 9 号版本的符号表文件中是怎样支持这个算子的?
    Asinh 在第9版本onnx 中实现,检查symbolic_opset9.py 发现,但pytorch 中已经实现torch.asinh(), 即缺少映射方法。

  4. BitShift 算子出现于第11个 ONNX 算子集。PyTorch 在 11 号版本的符号表文件中是怎样支持这个算子的?
    通过在 torch.onnx.symbolic_opset11.py 搜索 BitShift,我们可以发现 PyTorch 在 _lshift 和 _rshift 里用到了ONNX的 BitShift 算子。当输入类型为 Byte 时,PyTorch会把算子直接翻译翻译
    BitShift,以代替乘除 2 的次幂的操作。

  5. 算子在pytorch 中已经实现,onnx 算子也实现,缺少映射方法,自己注册,实现转换。
    1. 获取 ATen 中算子接口定义
    2. 添加符号函数

  6. 整合模型,导出onnx文件

  7. 测试算子
    ================================================================

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.asinh(x)

from torch.onnx.symbolic_registry import register_op

def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

register_op('asinh', asinh_symbolic, '', 9)

model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'asinh.onnx')

测试

import onnxruntime 
import torch 
import numpy as np 
 
class Model(torch.nn.Module): 
    def __init__(self): 
        super().__init__() 
 
    def forward(self, x): 
        return torch.asinh(x) 
 
model = Model() 
input = torch.rand(1, 3, 10, 10) 
torch_output = model(input).detach().numpy() 
 
sess = onnxruntime.InferenceSession('asinh.onnx') 
ort_output = sess.run(None, {'0': input.numpy()})[0] 
 
assert np.allclose(torch_output, ort_output) 
  1. 自定义onnx 算子。
    https://zhuanlan.zhihu.com/p/513387413
  2. 构造onnx 模型,并测试。
import onnx
from onnx import helper
from onnx import TensorProto

# input and output
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])

# Mul
mul = helper.make_node('Mul', ['a', 'x'], ['c'])

# Add
add = helper.make_node('Add', ['c', 'b'], ['output'])

# graph and model
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
model = helper.make_model(graph)

# save model
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx') 
import onnxruntime 
import numpy as np 
 
sess = onnxruntime.InferenceSession('linear_func.onnx') 
a = np.random.rand(10, 10).astype(np.float32) 
b = np.random.rand(10, 10).astype(np.float32) 
x = np.random.rand(10, 10).astype(np.float32) 
 
output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0] 
 
assert np.allclose(output, a * x + b)
  1. onnx提取子模型
    https://zhuanlan.zhihu.com/p/516920606

https://zhuanlan.zhihu.com/p/543973749
https://zhuanlan.zhihu.com/p/516920606

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1840810.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QT——信号和槽

一、信号的概念 信号成员会在满足一定条件时某个对象自动触发该对象的信号(触发条件可以通过QT帮助文档进行查看每个类中标示了signals的栏位),触发后程序会调用与信号相连接的槽函数。 二、槽函数的概念 Qt中的槽函数(Slot)是与信号(Signal)对应的概念,用于接收信号…

TIME_WAIT的危害

前言 该文章主要讨论下TIME_WAIT的存在意义和潜在危害,以及解决措施。 具体内容 首先看一下下面这幅图 这幅图来自《TCP IP详解卷1:协议 原书第2版中文》TCP状态变迁图。 TIME_WAIT存在意义 可靠的终止TCP连接。 保证让迟来的TCP报文有足够的时间被…

spring中@Conditional

多环境切换 java配置使用profile Profile设置在某个环境下,spring注入对应的bean public class JavaConfig {BeanProfile("dev")DataSource devDs(){DataSource ds new DataSource();ds.setUrl("dev");ds.setUsername("dev");ret…

Win11 删除文件时提示“找不到该项目,请重试”的解决办法

1、Win R 打开运行窗口,输入 notepad 并回车打开文本文档(记事本)软件,如下图: 2、在文本文档(记事本)软件中复制粘贴以下代码,如下图: del /f /a /q \\?\%1 rd /s /q \\?\%1或DEL /F /A /Q \\?\%1 RD /S /Q \\?…

基于SpringBoot+Vue农产品管理与销售APP设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝1W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码数据库🌟 感兴趣的可以先收藏起来,还…

python 版本切换,更换当前默认版本

电脑可以安装多个版本,但是好像没有正规的维护python版本的工具,比如前端就有nvm切换node版本,但是python我没找到比较好的(有大佬知道路过方便留言一下,跪谢。。) 废话不多说,更改默认版本很简…

反射复习(java)

文章目录 反射机制的作用反射机制的原理加载机制详细解释 获取 Class 对象反射获取构造方法:获取 Class 对象里面 Constructor 对象反射获取成员变量:获取Class 对象里面的 Field 对象反射获取成员方法:获取 Class 对象里的 Method 对象其他常…

融合创新,共筑未来 | 人大金仓为轨道交通发展注入新质力量

METROTRANS 2024 6月13日~15日,以“多元融合 高质量可持续发展”为主题的2024北京-南京国际城市轨道交通展览会暨高峰论坛在南京国际博览中心隆重举行。人大金仓受邀亮相本次大会,展示了其在轨道交通的创新应用和解决方案。 ‍点击视频 直击现场‍ 规模最…

机器学习课程复习——奇异值分解

1. 三种奇异值分解 奇异值分解(Singular Value Decomposition, SVD)包含了: 完全奇异值分解(Complete Singular Value Decomposition, CSVD)紧奇异值分解(Tight Singular Value Decomposition, TSVD&…

年终奖发放没几天,提离职领导指责我不厚道,我该怎么办?

“年终奖都发了,你还跳槽?太不厚道了吧!” “拿完年终奖就走人,这不是典型的‘骑驴找马’吗?” 每到岁末年初,关于“拿到年终奖后是否应该立即辞职”的话题总会引发热议。支持者认为,这是个人…

GLSB是什么?带你深入了解GLSB核心功能

伴随互联网的快速发展,大型企业等组织单位通过建设多数据中心,以提升用户体验。然而想要在多个数据中心实现流量的智能管理,提高网站的可靠性和可用性,则需要全局服务器负载均衡技术——GLSB的助力。GLSB是什么?它又有…

笔记本系统盘移植与windowsLinux双系统安装

目录 一、 前言二、 Windows系统移植二、 安装Linux三、 Windows分区配置 一、 前言 笔记本内存不够了,之前给笔记本添加了一个机械硬盘,也几乎爆满了,于是购置了1T的固态硬盘,打算用这个固态硬盘安装双系统,剩余空间…

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【11】ElasticSearch

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【11】ElasticSearch 简介基本概念ElasticSearch概念-倒排索引安装基本命令ik 分词器SpringBoot整合测试存储数据:测试复杂检索同步与异步调用 参考 简介 Elasticsearch 是一…

【UE数字孪生学习笔记】 Apifox一体化接口测试平台

声明:部分内容来自于b站,知乎,慕课,公开课等的课件,仅供学习使用。如有问题,请联系删除。 部分内容来自UE官方文档,博客等 Apifox接口测试 Apifox 是集 API 文档、API 调试、API Mock、API 自动…

DDMA信号处理以及数据处理的流程---距离速度测量

Hello,大家好,我是Xiaojie,好久不见,欢迎大家能够和Xiaojie一起学习毫米波雷达知识,Xiaojie准备连载一个系列的文章—DDMA信号处理以及数据处理的流程,本系列文章将从目标生成、信号仿真、测距、测速、cfar检测、测角、目标聚类、目标跟踪这几个模块逐步介绍,这个系列的…

腾讯《地下城与勇士:起源》手游在部分安卓平台停止更新

原标题:因合约到期 《DNF手游》停止安卓平台更新 易采游戏网6月19日消息:《地下城与勇士:起源》(简称DNF手游)官方今天公告,因合作协议到期,自6月20日起,该游戏将不再在某些安卓应用商店提供。腾讯公司已经…

OpenAI 发布多模态 GPT-4 模型,会开创哪些新的研究方向?

作者:JioNLP 链接:https://www.zhihu.com/question/589640227/answer/2936760622 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。 短期看,GPT4 就是个终结者 。开创不了什么新的方…

Redis学习|Jedis、SpringBoot整合Redis

Jedis 我们要使用Java 来操作 Redis,知其然并知其所以然,授人以渔!学习不能急躁,慢慢来会很快!什么是Jedis 是 Redis 官方推荐的java连接开发工具!使用java 操作Redis 中间件!如果你要使用 java操作redis,那么一定要对Jedis 十分的熟悉! 1、…

C++初学者指南第一步---7.控制流(基础)

C初学者指南第一步—7.控制流(基础) 文章目录 C初学者指南第一步---7.控制流(基础)1.术语:表达式/语句Expressions表达式Statements语句 2.条件分支3.Switching(切换):基于值的分支4.三元条件运算符5.循环迭代基于范围的循环   C…

STM32人工智能检测-筛选机器人

前言 本文描述了一种使用STM32进行机器人筛选的办法。筛选对象是我的粉s,删选办法是瞪眼法。 问题现象 每次当我的STM32 向外界发出一篇新的的报文,总能在1H之内得到focus,格式如下 [title][body][tail]于是我对各个focus 我报文的对象进…