使用Pytorch导出自定义ONNX算子

news2024/11/23 13:50:59

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes


class CustomAffineGrid(Function):
    @staticmethod
    def forward(ctx, theta: torch.Tensor, size: torch.Tensor):
        grid = F.affine_grid(theta=theta, size=size.cpu().tolist())
        return grid

    @staticmethod
    def symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):
        return g.op("AffineGrid", theta, size)


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):
        grid = CustomAffineGrid.apply(theta, size)
        x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")
        return x


def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)
        theta = torch.randn(1, 2, 3)
        size = torch.as_tensor([1, 3, 512, 512])

        torch.onnx.export(model=custum_model,
                          args=(x, theta, size),
                          f="custom.onnx",
                          input_names=["input0_x", "input1_theta", "input2_size"],
                          output_names=["output"],
                          dynamic_axes={"input0_x": {2: "h0", 3: "w0"},
                                        "output": {2: "h1", 3: "w1"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)


if __name__ == '__main__':
    main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes


class CustomRot90AndScale(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90
        x *= 1.2
        return x

    @staticmethod
    def symbolic(g: torch.Graph, x: torch.Tensor):
        return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        return CustomRot90AndScale.apply(x)


def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)

        torch.onnx.export(model=custum_model,
                          args=(x,),
                          f="custom_rot90.onnx",
                          input_names=["input"],
                          output_names=["output"],
                          dynamic_axes={"input": {2: "h0", 3: "w0"},
                                        "output": {2: "w0", 3: "h0"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)


if __name__ == '__main__':
    main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1489454.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Maven】Maven 基础教程(四):搭建 Maven 私服 Nexus

《Maven 基础教程》系列,包含以下 4 篇文章: Maven 基础教程(一):基础介绍、开发环境配置Maven 基础教程(二):Maven 的使用Maven 基础教程(三):b…

JVM(类加载机制)

类加载就是 .class 文件, 从文件(硬盘) 被加载到内存(元数据区)中的过程 类加载的过程 加载: 找 .class 文件的过程, 打开文件, 读文件, 把文件读到内存中 验证: 检查 .class 文件的格式是否正确 .class 是一个二进制文件, 其格式有严格的说明 准备: 给类对象分配内存空间 (先在…

c++ thread的使用 调用类里面的函数和调用类外的函数的区别

1.thread 调用类外的函数。 在使用thread之前要加上#include <thread>。 例1&#xff1a; #include <iostream> #include <thread> using namespace std; void Threadfunc1() {cout << "Threadfunc1" << endl; }void Threadfunc2(in…

sqllab 11-22

11.有回显&#xff0c;单引号 首先判断是字符型还是数字型 通过order by 来获取字段数 方便后续union联合 注意这里mime表明了内容要进行url编码&#xff0c;测试3报错&#xff0c;2正常&#xff0c;所以有2列。 还需要判断显示位&#xff0c;因为只有显示位的数据才能被爆出…

基于springboot+vue的多媒体素材库的开发与应用系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

Python处理表格数据库之Agate使用详解

概要 您是否有时觉得在处理表格数据时感到不知所措? 也许你在处理一个大型 CSV 文件,遇到了各种数据不一致的问题,或者需要验证数据,确保其准确无误才能进行下一步分析。 传统的数据分析库或许功能强大,但学习曲线陡峭,用起来有点杀鸡用牛刀的感觉。 这时,有一个更…

Redis是AP的还是CP的?

redis是一个开源的内存数据库&#xff0c;那么他到底是AP的还是CP的呢&#xff1f; 有人说&#xff1a;单机的是redis是cp的&#xff0c;而集群的redis是ap的&#xff1f; 但是我不这么认为&#xff0c;我觉得redis就是ap的&#xff0c;虽然在单机redis中&#xff0c;因为只有…

计算机网络实验 基于ENSP的协议分析

实验二 基于eNSP的协议分析 一、实验目的&#xff1a; 1&#xff09;熟悉VRP的基本操作命令 2&#xff09;掌握ARP协议的基本工作原理 3&#xff09;掌握IP协议的基本工作原理 4&#xff09;掌握ICMP协议的基本工作原理 二、实验内容&#xff1a; 1、场景1&#xff1a;两台PC机…

力扣每日一题 用队列实现栈 模拟

Problem: 225. 用队列实现栈 文章目录 思路复杂度Code 思路 &#x1f468;‍&#x1f3eb; 力扣官解 辅助队列存栈顶元素主队列存逆序序列 复杂度 时间复杂度: 添加时间复杂度, 示例&#xff1a; O ( n ) O(n) O(n) 空间复杂度: 添加空间复杂度, 示例&#xff1a; O ( …

Markdown的语法使用

目录 前言一、Markdown基本使用二、基本使用补充2.1 三、特殊字符四、数学公式 前言 Markdown是网页版的文本编辑器&#xff0c;Markdown 允许您使用易于阅读、易于编写的纯文本格式进行编写&#xff0c;然后将其转换为结构有效的 XHTML&#xff08;或 HTML&#xff09;。本文主…

【C语言】内存操作篇---动态内存管理----malloc,realloc,calloc和free的用法【图文详解】

欢迎来CILMY23的博客喔&#xff0c;本篇为【C语言】内存操作篇---动态内存管理----malloc&#xff0c;realloc&#xff0c;calloc和free的用法【图文详解】&#xff0c;感谢观看&#xff0c;支持的可以给个一键三连&#xff0c;点赞关注收藏。 前言 在学完结构体后&#xff08;…

《Trustzone/TEE/安全-实践版》介绍

第一章&#xff1a;课程说明和准备 课程介绍和说明 资料准备 为什么使用qemu_v8环境&#xff1f; 为什么选择香橙派开发板&#xff1f; optee qemu_v8环境展示 香橙派optee环境展示 第二章&#xff1a;Qemu环境搭建 ubuntu20.04的安装(virtualboxubuntu20.04) 搭建optee qem…

云手机的境外舆情监控应用——助力品牌公关

在当今数字化时代&#xff0c;社交媒体已成为品牌传播和互动的主要平台。随之而来的是海量的信息涌入&#xff0c;品牌需要及时了解并应对海外社交媒体上的舆情变化。本文将介绍如何通过云手机进行境外舆情监控&#xff0c;更好地帮助企业公关及时作出决策。 1. 境外舆情监控与…

BUUCTF---[BJDCTF2020]藏藏藏1

1.题目描述 2.下载附件&#xff0c;解压之后是一张图片和一个文本 3.把图片放在winhex,发现图片里面包含压缩包 4.在kali中使用binwalk查看&#xff0c;然后使用foremost分离&#xff0c;在使用tree查看分离出来的文件&#xff0c;最后将zip文件使用unzip进行解压。步骤如下 5.…

rtt的io设备框架面向对象学习-电阻屏LCD设备

目录 1.8080通信的电阻屏LCD设备1.1 构造流程1.2 使用2.i2c和spi通信的电阻屏LCD 电阻屏LCD通信接口有支持I2c、SPI和8080通信接口的。 1.8080通信的电阻屏LCD设备 rtt没有实现的设备驱动框架层&#xff0c;那么是在驱动层直接实现的&#xff0c;以stm32f407-atk-explorer为例…

NOIP 2010普及组初赛试题及解析

NOIP 2010普及组初赛试题及解析 一. 单项选择题 &#xff08;共20题&#xff0c;每题1.5分&#xff0c;共计30分。每题有且仅有一个正确答案.&#xff09;。二. 问题求解&#xff08;共2题&#xff0c;每题5分&#xff0c;共计10分&#xff09;三. 阅读程序写结果&#xff08;共…

SpringBoot-首页和图标定制

1.静态资源导入 SpringBoot中的静态资源&#xff0c;默认有以下四个路径可以访问&#xff1a; classpath:/META-INF/resources/ classpath:/resources/ classpath:/static/ classpath:/public/ 其中第一个路径&#xff0c;一般不常用&#xff0c;它是来获取用maven导入webj…

[LeetBook]【学习日记】寻找链表相交节点

来源于「Krahets」的《图解算法数据结构》 https://leetcode.cn/leetbook/detail/illustration-of-algorithm/ 本题与主站 160 题相同&#xff1a;https://leetcode-cn.com/problems/intersection-of-two-linked-lists/ 训练计划 V 某教练同时带教两位学员&#xff0c;分别以…

xshell安装java/jdk

1.下载jdk wget https://download.java.net/java/GA/jdk11/13/GPL/openjdk-11.0.1_linux-x64_bin.tar.gz 2.解压jdk安装包 tar -zxvf openjdk-11.0.1_linux-x64_bin.tar.gz 其中第三步 编辑 ~/.bashrc 或 ~/.bash_profile 文件 打开vim文本编辑器 vim ~/.bash_profile export …

基于springboot+vue的校园网上店铺

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…