onnx代码解读

news2024/11/29 22:28:23

一、定义

  1. torch.jit.trace 相关代码解读
  2. onnx 内部实现
    3 查看是否为aten 算子
  3. aten 算子实现
  4. torch.autograd.Functions 算子实现
  5. 自定义算子实现
  6. 查找未实现的节点
  7. 一次性发现所有的未实现 aten 算子

二、实现

  1. torch.jit.trace 相关代码解读
    1. torch.jit.script() : 将其转换为可运行的脚本。转换后的脚本可以像普通的 Python 函数一样调用,也可以保存到磁盘并在没有 PyTorch 依赖的环境中执行。
    2. torch.jit.trace : 跟踪了给定输入张量的执行路径,因此在使用转换后的模块对象进行推理时,输入张量的维度和数据类型必须与跟踪时使用的相同。

3 查看是否为aten 算子

import torch

print(
    torch.jit.trace(
        torch.nn.ELU(), # module
        torch.ones(1)   # example input
    ).graph
)

算子追踪,在这里插入图片描述
3. aten 算子实现
  1.查看torch 接口定义    torch/nn/functional.pyi
  2.查看onnx 算子命名    https://github.com/onnx/onnx/blob/main/docs/Operators.md
  3. 查看注册函数书写   symbolic_opset9.py

import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15
# Define a custom symbolic function for aten::relu.
# The custom symbolic function is incorrect, which will result in mismatches.

#def relu(input: Tensor) -> Tensor: ...   查看接口定义,
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)             #查看onnx 实现

torch.onnx.register_custom_op_symbolic(     #注册
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
  1. torch.autograd.Functions 算子实现
    如果算子是torch.autograd.Functions 的子模块,可以使用该方法实现。
import torch

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))


import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15

myrelu = MyRelu.apply        #核心
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.Linear(4, 5),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return myrelu(self.layers(x))

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
  1. 自定义算子实现
    1. onnx 算子实现

    1. 自定义c++ 算子 +Extending TorchScript with Custom C++ Operators 实现
  2. 查找未实现的节点

import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15
# Define a custom symbolic function for aten::relu.
# The custom symbolic function is incorrect, which will result in mismatches.   注册函数错误,导致find_mismatch 算子
def incorrect_relu_symbolic_function(g, self):
    return self
torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    incorrect_relu_symbolic_function,
    opset_version=opset_version,
)
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)
graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)

#===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 \u2713
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 \u2713
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 \u2713
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}

修改后:
aten 算子实现

import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15
# Define a custom symbolic function for aten::relu.
# The custom symbolic function is incorrect, which will result in mismatches.

#def relu(input: Tensor) -> Tensor: ...   查看接口定义,
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)             #查看onnx 实现

torch.onnx.register_custom_op_symbolic(     #注册
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)

方式二、
c++ 自定义算子


import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15


from torch.onnx import register_custom_op_symbolic        # 为 TorchScript 算子补充注册符号函数
from torch.onnx.symbolic_helper import parse_args
# '''
# 装饰器 @parse_args 了。简单来说,TorchScript 算子的符号函数要求标注出每一个输入参数的类型。比如"v"表示 Torch 库里的 value 类型,
# 一般用于标注张量,而"i"表示 int 类型,"f"表示 float 类型,"none"表示该参数为空。具体的类型含义可以在 torch.onnx.symbolic_helper.py
# '''
@parse_args("v", "v")
def correct_relu_symbolic_function(g,input):
    return g.op("Relu", input)


torch.onnx.register_custom_op_symbolic(     #注册
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)

  1. 一次性发现所有的未实现 aten 算子
import torch
import torch.onnx.verification
torch.manual_seed(0)
opset_version = 15

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )
    def forward(self, x):
        return self.layers(x)


torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    Model(), (torch.randn(2, 3),), opset_version=opset_version
)

print(set(unconvertible_ops))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2206029.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库的基本概念、安装MySQL及基础运用

目录 一、数据库的基本概念 1. 使用数据库的必要性 2. 数据(Data) 3. 表 4. 数据库 5. 数据库管理系统(DBMS) 6. 数据库管理系统DBMS的优点 7. 使用数据库的必要性总结 8. 访问数据库的流程 二、数据库发展及基本功能 1.…

宠物空气净化器怎么选?希喂、霍尼韦尔、美的宠物哪款除毛好?

身为养宠五年的资深铲屎官,最近收到了很多新手养宠朋友关于宠物空气净化器的挑选疑问。宠物空气净化器作为宠物领域目前最火热的产品,谈论度一直很高,评价也褒贬不一。双十一购物节又即将到来,大家都想赶上这一波优惠活动。 铺天盖…

Automa插件之js脚本小技巧:零依赖的日期时间格式化,亲测好用!

背景 在使用 Automa 插件自动下载文件时,有时候需要根据当前时间重新命名文件,如果是时间戳的话倒是也可以防重复文件命名,只不过那样的话,没有了时间可读性. 所以需要日期时间格式化,分享一个一直在用的纯 js 格式化日期脚本,可实现简单的日期格式化. 文末附完整代码,直接复制…

时序约束进阶四:set_input_delay和set_output_delay详解

目录 一、前言 二、set_input_delay/set_output_delay 2.1 延时约束 2.2 约束设置界面 2.3 示例工程 2.4 Delay Value 2.5 Delay value is relative to clock edge 2.6 Delay value already includes latencies of the specified clock edge 2.7 Rise/Fall 2.8 Max/M…

教育部白名单赛事到底是什么?大家为什么那么重视它?

近年来,随着素质教育的推广和升学竞争的加剧,白名单赛事这一概念变得越来越热门。所谓的白名单赛事,是指经过教育部批准并公布的竞赛名单。这些比赛不仅具备权威性和高含金量,还受到各大中小学、重点高中和高校的广泛认可。在升学…

文件句柄泄漏排查及方法总结

如果只是怀疑文件句柄泄漏,可以通过Process Explorer 找到对应进程,双击点开查看performance中的handles变化即可,然后结合I/O项变化进行大致分析。 ——当然对于程序员而言,不光是要发现问题,还要定位问题。 针对li…

Qt 自绘开关按钮以及设计器中的提升为用法

文章目录 自绘按钮实现概要效果图代码 提升为用法介绍步骤 总结 自绘按钮实现 概要 当我们需要一个开关样式的QPushbutton,没有图片的话,我们可以采用自绘的形式实现。且使用QtDesinger中提升为Promote to的功能加入界面中,而不是使用代码的…

C++入门基础知识107—【关于C++continue 语句】

成长路上不孤单😊😊😊😊😊😊 【14后😊///C爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于C continue 语句的相关内容!…

初始爬虫13(js逆向)

为了解决网页端的动态加载,加密设置等,所以需要js逆向操作。 JavaScript逆向可以分为三大部分:寻找入口,调试分析和模拟执行。 1.chrome在爬虫中的作用 1.1preserve log的使用 默认情况下,页面发生跳转之后&#xf…

基于html的大模型调试页面

效果1 源码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>API Call Example</title><st…

C++面向对象--------继承篇

目录 一.继承&#xff08;重点&#xff09; 1.1 概念 1.2 构造函数 1.2.1 派生类与基类的构造函数关系 1.2.2 解决方案 1.2.2.1 补充基类的无参构造函数 1.2.2.2 手动在派生类中调用基类构造函数 1.2.2.2.1 透传构造 1.2.2.2.2 委托构造 1.2.2.2.3 继承构造 1.3 对象…

为什么SEO是一个不断学习和适应的过程?

SEO并不是一成不变的&#xff0c;它是一个需要不断学习和适应的过程。谷歌的算法经常更新&#xff0c;用户的搜索行为也在不断变化&#xff0c;这使得SEO策略必须与时俱进&#xff0c;才能保持有效性。企业需要认识到&#xff0c;SEO是一项长期的投资&#xff0c;需要持续的关注…

Spring WebFlux 响应式概述(1)

1、响应式编程概述 1.1、响应式编程介绍 1.1.1、为什么需要响应式 传统的命令式编程在面对当前的需求时的一些限制。在应用负载较高时&#xff0c;要求应用需要有更高的可用性&#xff0c;并提供低的延迟时间。 1、Thread per Request 模型 比如使用Servlet开发的单体应用&a…

PostgreSQL学习笔记十:锁机制详解

一、PostgreSQL 的锁机制 PostgreSQL中的锁机制是确保数据一致性和完整性的关键。它通过不同级别的锁来控制对数据对象的并发访问&#xff0c;主要包括表级锁、行级锁、页级锁、咨询锁&#xff08;Advisory Locks&#xff09;以及死锁&#xff08;Deadlocks&#xff09;。 1. …

基于Java+Springboot+Vue开发的大学竞赛报名管理系统

项目简介 该项目是基于JavaSpringbootVue开发的大学竞赛报名管理系统&#xff08;前后端分离&#xff09;&#xff0c;这是一项为大学生课程设计作业而开发的项目。该系统旨在帮助大学生学习并掌握Java编程技能&#xff0c;同时锻炼他们的项目设计与开发能力。通过学习基于Java…

# linux从入门到精通-从基础学起,逐步提升,探索linux奥秘(九)--网络设置与文件上传下载

linux从入门到精通-从基础学起&#xff0c;逐步提升&#xff0c;探索linux奥秘&#xff08;九&#xff09;–网络设置与文件上传下载 一、网络设置 1、首先知道网卡配置文件位置&#xff1a;/etc/sysconfig/network-scripts [rootlocalhost test1]# ls /etc/sysconfig/netwo…

JSON 格式化工具:快速便捷地格式化和查看 JSON 数据

JSON 格式化工具&#xff1a;快速便捷地格式化和查看 JSON 数据 为什么需要 JSON 格式化工具&#xff1f; 在日常开发和调试中&#xff0c;JSON 是非常常见的数据交换格式。无论是前端与后端的接口调用&#xff0c;还是数据存储和处理&#xff0c;JSON 格式都扮演着重要角色。…

【HarmonyOS开发笔记 2 】 -- ArkTS语法中的变量与常量

ArkTS是HarmonyOS开发的编程语言 ArkTS语法中的变量 【语法格式】&#xff1a; let 变量名: 类型 值 let&#xff1a;是定义变量的关键字类型&#xff1a; 值数据类型&#xff0c; 常用的数据类型 字符型&#xff08;string&#xff09;、数字型&#xff08;number&#xf…

PG 17 增量备份功能介绍

背景 PG 17 新增了增量备份功能&#xff0c;可以通过 pg_basebackup --incrementalPATH_TO_MANIFEST 命令进行增量备份。 官方文档&#xff1a;https://www.postgresql.org/docs/current/app-pgbasebackup.html 使用方法 全量备份 启动实例后&#xff0c;首先配置参数 sum…

【北京迅为】《STM32MP157开发板嵌入式开发指南》- 第三十五章 嵌入式开发概述及环境构建

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器&#xff0c;既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构&#xff0c;主频650M、1G内存、8G存储&#xff0c;核心板采用工业级板对板连接器&#xff0c;高可靠&#xff0c;牢固耐…