TenserRT(四)在 PYTORCH 中支持更多 ONNX 算子

news2025/1/19 16:29:21

第四章:在 PyTorch 中支持更多 ONNX 算子 — mmdeploy 0.12.0 文档

PyTorch扩充。

PyTorch转换成ONNX:

  • PyTorch有实现。
  • PyTorch可以转化成一个或者多个ONNX算子。
  • ONNX有相应算子。

如果即没有PyTorch实现,且缺少PyTorch与ONNX的映射关系,则需要:

  • Pytorch算子
    • 组合现有算子
    • 添加TorchScript算子
    • 添加普通C++拓展算子
  • 映射方法
    • 为ATen算子添加符号函数
    • 为TorchScript算子添加符号函数
    • 封装成torch.autograd.Function并添加符号函数
  • ONNX算子
    • 使用现用ONNX算子
    • 定义新ONNX算子

不同的情况需要灵活的选用和组合这些方法。

支持 ATen 算子

算子在ATen中已经实现,ONNX也有相关算子定义,但是相关算子映射成ONNX的规则没有写。为ATen算子补充描述映射规则的符号函数。

PyTorch C++ API — PyTorch master documentation ATen是PyTorch内置的C++张量计算库,PyTorch算子在底层绝大多数计算都是用ATen实现的。

例如ONNX的Asinh https://github.com/onnx/onnx/blob/main/docs/Operators.md#Asinh 算子在ATen中有实现,但缺少映射到ONNX算子的符号函数,则需要补全符号函数,并导出一个包含该算子的ONNX模型。

获取 ATen 中算子接口定义

torch/_C/_VariableFunctions.pyi和 torch/nn/functional.pyi 两个文件可以获取函数的输入定义,这两个文件是编译pytorch时自动生成的,里面包含了ATen算子的pytorch调用接口,在torch/_C/_VariableFunctions.pyi中搜索asinh接口为

def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...

缺失算子为asinh,在ATen中实现的算子,在_VariableFunctions.pyi找到对应接口,需要补充对应的符号函数,使其在转场ONNX时不在报错。

添加符号函数

符号函数,看成pytorch的静态方法,将pytroch转换成ONNX模型时,pytroch算子的符号函数将被依次调用,已完成Pytorch算子到ONNX算子的转换。

def symbolic(g: torch._C.Graph, input_0: torch._C.Value, input_1: torch._C.Value, ...):

torch._C.Graph 和 torch._C.Value对应Pytorch的C++实现里的一些类。第一个参数g,表示和计算图相关内容;后面的参数input是算子输入,需要和算子的前向推理接口输入已知。对于ATen算子来说就时两个.pyi文件里的函数接口。

g有一个op方法。把pytorc算子转换成ONNX算子时,需要在符号函数中调用此方法来最终计算投图添加一个ONNX算子

def op(name: str, input_0: torch._C.Value, input_1: torch._C.Value, ...)

name:算子名称,如ONNX算子名称。

简单情况,将pytorch算子的输入用g.op()一一对应到ONNX算子上,并把g.op()的返回值作为符号函数的返回值。复杂的情况,将一个pytorch算子新建为若干个ONNX算子。

from torch.onnx import register_custom_op_symbolic
def asinh_symbolic(g, input):
    return g.op("custom_domain::Asinh", input)

register_custom_op_symbolic('custom_ops::asinh', asinh_symbolic, 9)

asinh_symbolic就是asinh符号函数,输入参数需要按照在ATen中的定义

def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...

符号函数的函数体重g.op("custom_domain::Asinh", input)完成了ONNX算子的定义,第一个参数custom_domain::Asinh是算子在ONNX中的名称,之于第二个参数input,这个算子只有一个输入,主需要把符号函数的输入参数input对应过去就可以了。ONNX的custom_domain::Asinh输出和ATen的asinh的输出一直,因此直接把g.op()结果返回即可。

使用pytorch API中register_op将富含函数和原来的ATen算子绑定在一起,

register_custom_op_symbolic('custom_ops::asinh', asinh_symbolic, 9)

第一个参数custom_ops::asinh是目标ATen算子名。

第二个参数asinh_symbolic是要注册的符号函数。

第三个参数9表示算子集注册。

import torch
from torch.onnx import register_custom_op_symbolic

#创建一个简单的神经网络层,实现forward方法,
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.asinh(x)#计算反双曲正弦值

def asinh_symbolic(g, input, *, out=None):
    #pytorch的计算图有节点Node和边edge组成,Node表示操作(加减乘除卷积)
    #边表示张量间数据流关系,一个Asinh的节点,以input作为输入,输出就是Asinh(input)
    return g.op("Asinh", input)

#将Model中的asinh操作重新绑定为asinh_symbolic,重命名该节点为“Asinh”,使用9号算法集
register_custom_op_symbolic('aten::asinh', asinh_symbolic,  9)

model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'asinh.onnx')
#总结,就是声明一个torch.asinh的操作,该操作通过register_custom_op_symbolic注册为,9号算法集中Asinh操作

 


import onnxruntime
import torch
import numpy as np

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.asinh(x)

model = Model()
input = torch.rand(1, 3, 10, 10)
torch_output = model(input).detach().numpy()#torch做了一次推理,然后转成numpy格式

sess = onnxruntime.InferenceSession('asinh.onnx')
ort_output = sess.run(None, {'onnx::Asinh_0': input.numpy()})[0]
#这里的名字要和onnx中图节点的名字一致啊
assert np.allclose(torch_output, ort_output)#判断torch值和onnxruntime值一致,assert断言返回值,allclose判断两个张量是否一致

支持torcscript算子

pytorch算子无法直接满足复杂实现,需要自定义一个pytorch算子,然后转成ONNX形式。

为算子添加符号函数:

1、获取原算子的前向推理接口。#forward

2、获取目标ONNX算子的定义。#https://github.com/onnx/onnx/blob/main/docs/Operators.md

3、编写符号函数并绑定。#asinh_symbolic,register_custom_op_symbolic

使用torchscript算子

import torch
import torchvision

#定义一个包含算子的模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 18, 3)#作为形变卷积的偏移张量
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)#形变卷积

    def forward(self, x):
        return self.conv2(x, self.conv1(x))

自定义 ONNX 算子


#定义一个包含算子的模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 18, 3)#作为形变卷积的偏移张量
        self.conv2 = torchvision.ops.DeformConv2d(3, 3, 3)#形变卷积点https://pytorch.org/vision/stable/ops.html

    def forward(self, x):
        return self.conv2(x, self.conv1(x))
#@parse_args装饰器,torchscripp算子的符号函数要求标注出么米一个输入参数的数据类型,
#v表示torch库中的value类型,一般用于标注张量
#i表示int类型
#f表示float
#none表示该数据为空。
#可以在torch.onny.symbolic_helper.py中查看
@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "none")
def symbolic(g,
        input,#张量v
        weight,#张量v
        offset,#张量v
        mask,#张量v
        bias,#张量v
        stride_h, stride_w, #int i
        pad_h, pad_w,
        dil_h, dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask):
    #以查询到的DeformConv2d算子输入参数作为符号函数的输入
    #custom是命名空间,以区别官方的算子
    #只使用input和offset来构造ONNX算子
    return g.op("custom::deform_conv2d", input, offset)#只是简单的例子,如何定义一个onnx中的deform节点,所以不做具体实现。
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("torchvision::deform_conv2d", symbolic, 9)

model = Model()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, input, 'dcn.onnx')

 使用torch.autograd.Function

为 PyTorch 添加 C++ 拓展

//my_add.cpp
#include <torch/torch.h>

torch::Tensor my_add(torch::Tensor a, torch::Tensor b)//torch::Tensor就是c++中torch张量
{
    return 2 * a + b;
}

PYBIND11_MODULE(my_lib, m)//PYBIND11_MODULE为C++提供python调用接口。这里的my_lib是将来要在python中导入的模块名
{
    m.def("my_add", my_add);//my_add是python调用的接口名称,这里的接口名称与c++函数名称不一定要一样,但是这一命名辨识度比较高。
}
python setup.py develop
#编译文件

用 torch.autograd.Function 封装底层调用

import torch
import my_lib
#Function类本身是pytorch的一个可导函数,只需要实现前向推理和反向传播实现。
class MyAddFunction(torch.autograd.Function):
    #pytorch自动调用该函数,合适的执行前向和反向计算。
    @staticmethod
    def forward(ctx, a, b):
        #forward函数中调用c++函数,my_lib是库名,my_add函数名,这两个名字是在
        #PYBIND11_MODULE中定义的。
        return my_lib.my_add(a, b)
    #对模型部署来说,function类有个很好的性质:如果定义了symbolic静态方法,
    #该function在执行torch.onnx.export()时就可以根据symbolic中定义的规则,
    #转换成ONNX算子,这个symbolic就是前面提到的符号函数,只是这里的名称必须是symbolic而已。
    @staticmethod
    def symbolic(g, a, b):
        #g.op()只需要根据ONNX算子定义的规则把输入参数填入即可
        #ONNX中把新建常量当成一个算子来看待,尽管这个算子并不会以节点形式出现在ONNX模型的可视化结果里。
        two = g.op("Constant", value_t=torch.tensor([2]))#常量算子,把pytorch张量值传入value_t参数
        a = g.op('Mul', a, two)#乘法
        return g.op('Add', a, b)#加法
my_add = MyAddFunction.apply#apply是torch.autograd.Function的方法,这个方法完成了Function在前向推理或者反向传播的调度。
#在使用Function的派生类做推理时,不应该显示的调用forward,而应该调用apply方法。

class MyAdd(torch.nn.Module):#把my_add封装成一个神经网络中的计算层。
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        return my_add(a, b)

测试算子

import torch
import my_lib
#Function类本身是pytorch的一个可导函数,只需要实现前向推理和反向传播实现。
class MyAddFunction(torch.autograd.Function):
    #pytorch自动调用该函数,合适的执行前向和反向计算。
    @staticmethod
    def forward(ctx, a, b):
        #forward函数中调用c++函数,my_lib是库名,my_add函数名,这两个名字是在
        #PYBIND11_MODULE中定义的。
        return my_lib.my_add(a, b)
    #对模型部署来说,function类有个很好的性质:如果定义了symbolic静态方法,
    #该function在执行torch.onnx.export()时就可以根据symbolic中定义的规则,
    #转换成ONNX算子,这个symbolic就是前面提到的符号函数,只是这里的名称必须是symbolic而已。
    @staticmethod
    def symbolic(g, a, b):
        #g.op()只需要根据ONNX算子定义的规则把输入参数填入即可
        #ONNX中把新建常量当成一个算子来看待,尽管这个算子并不会以节点形式出现在ONNX模型的可视化结果里。
        two = g.op("Constant", value_t=torch.tensor([2]))#常量算子,把pytorch张量值传入value_t参数
        a = g.op('Mul', a, two)#乘法
        return g.op('Add', a, b)#加法
my_add = MyAddFunction.apply#apply是torch.autograd.Function的方法,这个方法完成了Function在前向推理或者反向传播的调度。
#在使用Function的派生类做推理时,不应该显示的调用forward,而应该调用apply方法。

class MyAdd(torch.nn.Module):#把my_add封装成一个神经网络中的计算层。
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        return my_add(a, b)

model = MyAdd()
input = torch.rand(1, 3, 10, 10)
torch.onnx.export(model, (input, input), 'my_add.onnx')
torch_output = model(input, input).detach().numpy()

import onnxruntime
import numpy as np
sess = onnxruntime.InferenceSession('my_add.onnx')
ort_output = sess.run(None, {'a.1': input.numpy(), 'b.1': input.numpy()})[0]

assert np.allclose(torch_output, ort_output)

 总结

  • ATen是pytorch的C++张量库,~\Lib\site-packages\torch\_C\_VariableFunctions.pyi 和~\\Lib\site-packages\torch\nn可以知道ATen算子的python接口定义。
  • register_op可以为ATen算子补充注册符号函数。
  • register_custom_op_symbolic可以为TorchScript算子补充注册符号函数
  • 在Pytorch里添加C++扩展,#include <torch/torch.h> 、PYBIND11_MODULE(my_lib, m)、setup
  • torch.autograd.Function封装一个自定义的pytorch算子
  • symbolic编写符号函数。
  • g.op()把一个pytorch算子映射成一个或者多个ONNX算子。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/813628.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

太猛了,靠“吹牛”过顺丰一面,月薪30K

说在前面 在40岁老架构师尼恩的&#xff08;50&#xff09;读者社群中&#xff0c;经常有小伙伴&#xff0c;需要面试美团、京东、阿里、 百度、头条等大厂。 下面是一个5年小伙伴成功拿到通过了顺丰面试&#xff0c;拿到offer&#xff0c;月薪30K。 现在把面试真题和参考答…

一起学算法(插入排序篇)

概念&#xff1a; 插入排序&#xff08;inertion Sort&#xff09;一般也被称为直接插入排序&#xff0c;是一种简单的直观的排序算法 工作原理&#xff1a;将待排列元素划分为&#xff08;已排序&#xff09;和&#xff08;未排序&#xff09;两部分&#xff0c;每次从&…

Python毕业设计可用小游戏:5个热门类型,引爆学生热情!每个类型附单独案例!

游戏大全 前言1.格斗技能类小游戏2.益智塔防类小游戏3.MMO类型游戏4.养成类游戏5.经济类游戏 总结 前言 大家好&#xff0c;我是辣条哥 在当今数字化时代&#xff0c;编程已经成为一项不可或缺的技能。而Python作为一门简洁易学的编程语言&#xff0c;正受到越来越多学生的青睐…

03_使用execle表生成甘特图

背景 每次排期都需要话很多时间 很可能排期还不对头 这时候需要一个表能看到 1.什么时候项目结束 开始 转阶段 2.当前手上的活能不能做完 当前阶段手上有多少活 3.产品经理每次修改完计划迅速排期 甘特图生成 execle表生成 1.需要使用亿图创建甘特图 2.把当前的甘特图数据进…

使用Excel建立贷款损失计算器

前几天上了一门Excel课程&#xff0c;掌握了一些新的小技能&#xff0c;比如模拟运算表和控件以及动态图表的使用&#xff0c;结合工作内容进行了下实操练习。 一、控件和动态图表的使用 以贷款产品的损益测算为例&#xff0c;计算在不同资金成本、获客成本、提前还款损失以及风…

SpringBoot2.5.6整合Elasticsearch7.12.1

SpringBoot2.5.6整合Elasticsearch7.12.1 下面将通过SpringBoot整合Elasticseach&#xff0c;SpringBoot的版本是2.5.6&#xff0c;Elasticsearch的版本是7.12.1。 SpringBoot整合Elasticsearch主要有三种方式&#xff0c;一种是通过elasticsearch-rest-high-level-client&am…

c++里的基础类 is_empty_v<_Ty1>

&#xff08;1&#xff09;为什么要研究这个问题&#xff0c;因为包括智能指针等很多源代码里都会使用 _Compressed_pair 这个类&#xff0c;其是一对值。研究这个类&#xff0c;就牵涉另一个更基础的类 is_empty_v<_Ty1> &#xff08;2&#xff09; is_empty_v<_Ty1&…

内部类(下)匿名内部类,静态内部类的使用

文章目录 前言一、匿名内部类二、静态内部类三、内部类的继承总结 前言 该文将会介绍匿名内部类、静态内部类的使用&#xff0c;补充完毕java中的内部类。补充内容为向上转型为接口、使用this关键字获取引用、内部类的继承。 一、匿名内部类 定义&#xff1a;没有名称的内部类。…

redis 淘汰策略和持久化

文章目录 一、淘汰策略1.1 背景1.2 淘汰策略 二、持久化2.1 AOF日志2.1.1 AOF配置2.1.2 AOF策略2.1.3 AOF缺点2.1.4 AOF Rewrite2.1.5 AOF Rewrite配置2.1.6 AOF Rewrite缺点2.1.7 fork进程时的写时复制2.1.8 大key对持久化的影响 2.2 RDB快照2.2.1 RDB配置2.2.2 RDB缺点 2.3 混…

二分查找算法(全网最详细代码演示)

二分查找也称 半查找&#xff08;Binary Search&#xff09;&#xff0c;它时一种效率较高的查找方法。但是&#xff0c;折半查找要求线性表必须采用顺序存储结构&#xff0c;而且表中元素按关键字 有序 排列。 注意&#xff1a;使用二分查找的前提是 该数组是有序的。 在实际开…

web前端常用调试工具

概述 当我们写 webapp 或者 移动端H5网页时&#xff0c;要在手机上调试并不容易。 alert&#xff1a;很早之前的调试办法&#xff08;已被抛弃&#xff09; vconsole&#xff1a;是2016年由微信公众平台前端团队推出&#xff08;目前大量使用&#xff09; eruda&#xff1a…

解读随机森林的决策树:揭示模型背后的奥秘

一、引言 随机森林[1]是一种强大的机器学习算法&#xff0c;在许多领域都取得了显著的成功。它由多个决策树组成&#xff0c;而决策树则是构建随机森林的基本组件之一。通过深入解析决策树&#xff0c;我们可以更好地理解随机森林模型的工作原理和内在机制。 决策树是一种树状结…

虚拟现实技术(VR)

目录 1.什么是虚拟现实技术 2.虚拟现实技术的由来 3.虚拟现实技术给人类带来的好处 4.虚拟现实技术未来的走向 1.什么是虚拟现实技术 虚拟现实技术&#xff08;Virtual Reality&#xff0c;简称VR&#xff09;是一种通过计算机生成的模拟环境&#xff0c;使用户能够身临其境…

【js】经纬度位置获取navigator.geolocation.getCurrentPosition:

文章目录 一、经纬度位置获取navigator.geolocation.getCurrentPosition二、getCurrentPosition()在google chrome上不起作用 一、经纬度位置获取navigator.geolocation.getCurrentPosition 【文档】https://developer.mozilla.org/zh-CN/docs/Web/API/Window/navigator // 获取…

Redis 数据库高可用

Redis 数据库的高可用 一.Redis 数据库的持久化 1.Redis 高可用概念 &#xff08;1&#xff09;在web服务器中&#xff0c;高可用是指服务器可以正常访问的时间&#xff0c;衡量的标准是在多长时间内可以提供正常服务&#xff08;99.9%、99.99%、99.999%等等&#xff09;。 …

《MySQL 实战 45 讲》课程学习笔记(三)

事务隔离 事务就是要保证一组数据库操作&#xff0c;要么全部成功&#xff0c;要么全部失败。 隔离性与隔离级别 事务特性&#xff1a;ACID&#xff08;Atomicity、Consistency、Isolation、Durability&#xff0c;即原子性、一致性、隔离性、持久性&#xff09;。当数据库上…

Web-1-网站工作流程介绍

我们学习web开发&#xff0c;首先要知道什么是Web&#xff1f; Web: 全球广域网&#xff0c;也称为万维网(www World Wide Web)&#xff0c;能够通过浏览器访问的网站 比如我展示的这京东&#xff0c;淘宝唯品会都叫做网站&#xff0c;那么现在大家想一下&#xff0c;你还知道什…

用Ubuntu交叉编译Linux内核源码并部署到树莓派4B上

参考文章 1. 配置交叉编译环境 之前在ubuntu上配置过了&#xff0c;直接跳过 2.获取Linux内核源码 Linux内核源码链接 到链接里面选择自己合适版本的内核源码下载下来&#xff0c;然后传到ubuntu中进行解压 3.Linux内核源码的配置 参考文章 厂家配linux内核源码&#xff…

数据可视化库pyecharts简单入门

文章目录 0. 介绍1. 快速开始1.1 安装1.1.1 pip安装1.1.2 源码安装 1.2 快速上手1.2.1 柱状图1.2.2 链式调用1.2.3 使用options选项配置参数&#xff08;一切皆options&#xff09;1.2.4 渲染成图片文件1.2.5 使用主题1.2.6 地图 2. 全局配置项2.1 使用指南2.2 常用全局配置项2…

直呼牛逼!阿里最新 SpringBoot 进阶笔记涵盖了 SpringBoot 所有骚操作

相信从事 Java 开发的朋友都听说过 SSM 框架&#xff0c;老点的甚至经历过 SSH&#xff0c;说起来有点恐怖&#xff0c;比如我就是经历过 SSH 那个时代未流。当然无论是 SSM 还是 SSH 都不是今天的重点&#xff0c;今天要说的是 Spring Boot&#xff0c;一个令人眼前一亮的框架…