TenserRT(三)PYTORCH 转 ONNX 详解

news2025/1/23 3:03:06

第三章:PyTorch 转 ONNX 详解 — mmdeploy 0.12.0 文档

torch.onnx — PyTorch 2.0 documentation

torch.onnx.export 细解

计算图导出方法

TorchScript是一种序列化和优化PyTorch模型的格式,将torch.nn.Module模型转换为TorchScript的torch.jit.ScriptModule模型,也是一种中间表示。

torch.onnx.export中使用的模型实际上是torch.jit.ScriptModule。

将torch.nn.Module转化为TorchScript模型(导出计算图)有两种模式:跟踪(trace)和脚本化(script)。

torch.onnx.export输入一个torch.nn.Module,默认会使用跟踪(trace)的方法导出。

 

import torch

class Model(torch.nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        for i in range(self.n):#控制输入张量被卷积的次数
            x = self.conv(x)
        return x

models = [Model(2), Model(3)]# n=2和n=3的模型
model_names = ['model_2', 'model_3']

for model, model_name in zip(models, model_names):
    dummy_input = torch.rand(1, 3, 10, 10)
    dummy_output = model(dummy_input)
    model_trace = torch.jit.trace(model, dummy_input)
    model_script = torch.jit.script(model)

    #torch.onnx.export默认使用trace,所有不需要先trace
    # 跟踪法与直接 torch.onnx.export(model, ...)等价
    # torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output, opset_version = 11)
    torch.onnx.export(model,  dummy_input,f'{model_name}_trace.onnx', example_outputs=dummy_output, opset_version = 11)
    # 脚本化必须先调用 torch.jit.sciprt
    torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)

    # 如果是先运行了torch.jit.script,将模型转化成TorchScript,则export函数不需要再运行一遍
    # 如果输入不是TorchScript,则export需要运行一遍模型
    # dummy_input和dummy_output表示输入输出张量的数据类型和形状


跟踪法trace中,不同的n得到的ONNX模型结构是不一样的。

脚本法script中,Loop节点表示循环,不同的n可以有相同的结构。

推理引擎对静态图支持更好,不需要显式的将PyTorch模型转换为TorchScript,直接使用torch.onnx.export跟踪法导出即可。

虽然在代码中没有直接将trace的脚本作为export输入,但是可以通过trace来定位export问题是否出现在trace中。

参数讲解

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
           input_names=None, output_names=None, aten=False, export_raw_ir=False,
           operator_export_type=None, opset_version=None, _retain_param_name=True,
           do_constant_folding=True, example_outputs=None, strip_doc_string=True,
           dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
           enable_onnx_checker=True, use_external_data_format=False):
  • 模型(model):必选
  • 输入(args):必选
  • 导出的 onnx 文件名(f):必选
  • 模型中是否保存权重(export_params):一般模型结构和模型权重放在一个文件里存储,所以默认是true,如果是在不同的框架间传递模型,而不是用于部署,则设置为false。
  • 输入/输出张量名称(input_names, output_names):推理引擎一般都需要通过“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据,保证ONNX和推理引擎中使用同一套名称。
  • opset_version:ONNX算子集版本。
  • dynamic_axes:指定输入输出张量的哪些维度是动态的,为了追求效率,ONNX默认所有参与运算的张量都是静态的(张量的形状不发生改变)。可以显式的指明输入输出张量的哪几个维度的大小是可变的
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()#继承父类构造函数中
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        x = self.conv(x)
        return x


model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
model_names = ['model_static.onnx',
'model_dynamic_0.onnx',
'model_dynamic_23.onnx']

# dynamic_axes_0 = {#第0维动态
#     'in' : [0],
#     'out' : [0]
# }
dynamic_axes_0 = {
    'in' : {0: 'batch'},
    'out' : {0: 'batch'}
}
dynamic_axes_23 = {#第2、3维动态
    'in' : [2, 3],
    'out' : [2, 3]
}

torch.onnx.export(model, dummy_input, model_names[0], input_names=['in'], output_names=['out'])#没有动态维度
torch.onnx.export(model, dummy_input, model_names[1], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0)#第0维动态
torch.onnx.export(model, dummy_input, model_names[2], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23)#第2、3维动态
# ONNX 要求每个动态维度都有一个名字,直接这样写会引出一条UserWarning,警告我们通过列表方式设置动态维度的话,系统会自动为它们分配名字

 

import onnxruntime
import numpy as np

origin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32)
mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32)
big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32)

inputs = [origin_tensor, mult_batch_tensor, big_tensor]
exceptions = dict()
model_names = ['model_static.onnx',#批量或者维度增加就会出错
'model_dynamic_0.onnx',#维度增加就会出错
'model_dynamic_23.onnx']#批量增加就会出错
for model_name in model_names:
    for i, input in enumerate(inputs):
        try:
            ort_session = onnxruntime.InferenceSession(model_name)
            ort_inputs = {'in': input}
            ort_session.run(['out'], ort_inputs)#只有在设置了对应的动态维度后才不会出错
        except Exception as e:
            exceptions[(i, model_name)] = e
            print(f'Input[{i}] on model {model_name} error.')
            print(exceptions[(1, 'model_static.onnx')])

        else:
            print(f'Input[{i}] on model {model_name} succeed.')

使用技巧

torch.onnx.is_in_onnx_export():PyTorch推理时不运行,但是在执行torch.onnx.export()时为真。

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        x = self.conv(x)
        if torch.onnx.is_in_onnx_export():# 仅在模型导出时把输出张量的数值限制在[0,1]之间
            #可以在代码中添加和模型部署相关的逻辑
            x = torch.clip(x, 0, 1)
        return x

利用中断张量跟踪的操作

import torch
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        #item、for、list等方法都会导致ONNX模型不太正确
        x = x * x[0].item()#跟踪法会把某些取决于输入的中间结果变成常量
        # .item()把torch中的张量转换成普通的Python遍历
        return x, torch.Tensor([i for i in x])#遍历torch张量,并用一个列表新建一个torch张量。

model = Model()
dummy_input = torch.rand(10)
torch.onnx.export(model, dummy_input, 'a.onnx')

涉及到张量与普通变量转换的逻辑都会导致最终ONNX模型不太正确。

利用这个性质,在保证正确性的前提下令模型中间结果变成常量。

这个技巧尝尝用于模型的静态化上,即令模型中所有张量形状都变成常量。

使用张量为输入(PyTorch版本 < 1.9.0)

PyTorch 对 ONNX 的算子支持

如果torch.onnx.export()正常执行后,另一个容易出现的问题就是算子不兼容。

在转换普通torch.nn.Module模型时:

  • Pytorch利用跟踪法执行前向推理,把遇到的算子整合成计算图;
  • Pytorch把遇到的算子翻译成ONNX定义的算子。

算子翻译的过程可能遇到的情况:

  • 算子可以一对一翻译成ONNX算子。
  • 算子没有一对一的ONNX算子,被翻译成一个或多个ONNX算子。
  • 算子没有翻译成ONNX的规则。

ONNX 算子文档

onnx/Operators.md at main · onnx/onnx · GitHub

算子变更表格(算子名,算子变更版本号opset_version),第一次变更的版本号,表示算子第一次被支持,且第一个改动记录可以知道当前算子集中该算子的定义规则。

 表格中的链接可以说明该算子的输入输出参数规定使用示例。

PyTorch 对 ONNX 算子的映射

pytorch/torch/onnx at master · pytorch/pytorch · GitHub

 

symbloic_opset{n}.py表示pytorch对应的ONNX算子集版本。

在vscode中限定在torch/onnx文件夹搜索对应算子

 

 按照调用逻辑直接跳转到

@_onnx_symbolic(
    "aten::upsample_bicubic2d",
    decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
)
->

@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
    return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)

->

@_beartype.beartype
def _interpolate_helper(name, dim, interpolate_mode):
    @quantized_args(True, False, False)
    def symbolic_fn(g, input, output_size, *args):
        ...

    return symbolic_fn

symbolic_fn中插值算子被映射成多个ONNX算子,一个g.op对应ONNX

return g.op(
                "Resize",
                input,
                empty_roi,
                empty_scales,
                output_size,
                coordinate_transformation_mode_s=coordinate_transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=interpolate_mode,  # nearest, linear, or cubic
                nearest_mode_s="floor",
            )  # only valid when mode="nearest"

 查找对应的ONNXonnx/Operators.md at main · onnx/onnx · GitHub resize算子定义,可以知道对应参数含义。

查询PyTorch到ONNX的映射关系,然后在torch.onnx.export()的opset_version设定一个版本号,然后去PyTorch符号表文件里去查。如果没有对应算子,就需要考虑用其他算子替代,或者自定义算子。

总结

  • 跟踪法和脚本化在导出待控制语句的计算图时有什么区别。
  • torch.onnx.export()中如何设置input_names, output_names, dynamic_axes。
  • 使用torch.onnx.is_in_onnx_export()来使得模型在转换到ONNX时有不同的行为。
  • 查询ONNX 算子文档。
  • 查询ONNX算子对PyTorch算子支持情况。
  • 查询ONNX算子对PyTorch算子使用方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/430153.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ERTEC200P-2 PROFINET设备完全开发手册(6-1)

6 报警和诊断 Profinet提供了强大的诊断功能&#xff0c;这是其他通讯协议所无法比拟的。PN设备检测到问题后可以向控制器发送报警信息&#xff0c;报警分为三大类&#xff1a; 诊断报警 &#xff08;PN设备本身故障触发的报警&#xff0c;例如&#xff1a;温度测量通道变送电…

Activiti学习02

这里写目录标题一、流对象简介1.1 事件1.2 活动1.3 条件二、Activiti系统服务结构图核心类:服务类:RepositoryServiceRuntimeServiceTaskServiceHistoryServiceFormServiceIdentityServiceManagementService三、Activiti数据库支持一、流对象简介 一个业务流程图有三个流对象的…

ATFX国际:中国一季度GDP同比增长4.5%,社消总额约11.5万亿元

ATFX国际&#xff1a;中国统计局发布一季度国民经济运行报告&#xff0c;其中值得关注两大数据分别为GDP同比增速、社会消费品零售总额增速。统计显示&#xff0c;一季度GDP总额28.5万亿元&#xff0c;同比增长4.5%&#xff0c;其中第一产业和第二产业的增速低于平均值&#xf…

Pyqt案例讲解(实现模拟计算器效果)

PyQt5是一个用于Python的GUI框架&#xff0c;它提供了一个简单易用的GUI工具包&#xff0c;可以用于创建各种类型的应用程序&#xff0c;包括计算器。下面是一个简单的计算器的实现&#xff0c;其中包括了一些难点和复杂的地方。 难点&#xff1a; 使用Qt的布局管理器来创建窗…

证书扫描件怎么弄?手机也能轻松扫描

现代社会中&#xff0c;证书是人们展示自己能力和经历的重要凭证。然而&#xff0c;我们有时需要将证书扫描并保存在电脑或手机中&#xff0c;以备不时之需。本文将介绍如何扫描证书以及手机上是否能进行扫描。 证书扫描的方法 将证书扫描成电子文档可以方便地将其存储在电脑或…

C++ Primer 第7章 类 - 中(零基础学习C++,精简学习笔记)

&#x1f916; 作者简介&#xff1a;努力的clz &#xff0c;一个努力编程的菜鸟 &#x1f423;&#x1f424;&#x1f425; &#x1f440; 文章专栏&#xff1a;C Primer 学习笔记 &#x1f4d4;专栏简介&#xff1a; 本专栏是博主学习 C Primer 的学习笔记&#xff0c;因为…

技巧:WIN10手动指定某个应用程序使用独立显卡

目录1. 背景2. 解决方法&#xff0c;假如要让剪映始终使用独立显卡2.1 步骤1&#xff0c;右击电脑桌面空白处&#xff0c;选择“显示设置”2.2 步骤2&#xff0c;拉到最下面&#xff0c;点击图形设置2.3 步骤3&#xff0c;选择桌面应用&#xff0c;点击浏览2.4 步骤4&#xff0…

领课在线教育系统源码 各行业都适用的分布式在线教育系统+支持讲师入驻功能

领课教育系统&#xff08;roncoo-education&#xff09;是基于领课网络多年的在线教育平台开发和运营经验打造出来的产品&#xff0c;致力于打造一个各行业都适用的分布式在线教育系统。系统采用前后端分离模式&#xff0c;前台采用vue.js为核心框架&#xff0c;后台采用Spring…

bash shell 无法使用 perl 正则

1.案例现象 前几天有一个小伙伴在群里求助&#xff0c;说他这个 shell 脚本有问题&#xff0c;让大家帮忙看看 #!/bin/bash regularExpression"^\[(\d)\].$" contentcat $1 for i in ${content} doif [[ $i ~ $regularExpression ]]thenecho -e "\033[32m 【 i…

一款多参数多合一的空气质量传感器【温湿度、TVOC甲醛CO2粉尘等】

档案馆库房专用的一款智能型空气质量云测仪 空气质量检测仪 空气质量传感器 环境集成传感器 集成/温湿度、粉尘PM2.5 PM10/甲醛/TVOC/CO2等高度集成的一款传感器/RS485信号输出 ◆温度测量参数: (1)温度测量范围: -40~80℃(2&#xff09;输出分辨率:0.1oC (3&#xff09;…

从零开始学架构——高性能负载均衡

高性能负载均衡 单服务器无论如何优化&#xff0c;无论采用多好的硬件&#xff0c;总会有一个性能天花板&#xff0c;当单服务器的性能无法满足业务需求时&#xff0c;就需要设计高性能集群来提升系统整体的处理性能。高性能集群的本质很简单——通过增加更多的服务器来提升系…

Pandas入门实践1 -初探

我们将开始介绍Series、DataFrame和Index类&#xff0c;它们是pandas的基本构建块&#xff0c;并展示如何使用它们。在本节结束时&#xff0c;您将能够创建DataFrame并对它们执行操作以检查和筛选数据。 DataFrame剖析 DataFrame由一个或多个Series组成。Series的名称构成列名…

( “树” 之 DFS) 111. 二叉树的最小深度 ——【Leetcode每日一题】

111. 二叉树的最小深度 给定一个二叉树&#xff0c;找出其最小深度。 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 说明&#xff1a; 叶子节点是指没有子节点的节点。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;2…

matplotlib 笔记:subplot之间间距拉开

0 前情介绍 使用matplotlib的subplot时&#xff0c;由于默认间距不大&#xff0c;所以可能导致出的图会挤在一起 import matplotlib.pyplot as pltplt.subplot(221) plt.plot([1, 2, 3])plt.subplot(222) plt.bar([1, 2, 3], [4, 5, 6])plt.xlabel(xlabel, fontsize15, color…

码云私有仓库+宝塔面板部署WebHooks实现代码同步

权限问题&#xff0c;要分清楚两个帐号的权限www和root sudo -u www ssh -T gitgitee.com sudo -u root ssh -T gitgitee.com看清楚是用那个&#xff0c;建议用WWW帐号权限&#xff0c;不能用ROOT权限(最高权限不建议) 所以生成SSH是要注意要用那个帐号来生成&#xff1a; 一…

关于yolov7的一些理解

论文: https://arxiv.org/abs/2207.02696 Github: https://github.com/WongKinYiu/yolov7 YOLOV7的一些理解 1.摘要2.创新点3.具体工作3.1.网络结构优化3.2.辅助头训练3.3.标签分配策略3.4.重参数结构3.5.其它 1.摘要 Yolov7是Yolov4团队的作品&#xff0c;受到了yolo原作者…

Vue3瀑布流(Waterfall)

Vue2瀑布流&#xff08;Waterfall&#xff09; 可自定义设置以下属性&#xff1a; 图片数组&#xff08;images&#xff09;&#xff0c;类型&#xff1a;Array<{title: string, src: string}>&#xff0c;默认 [] 要划分的列数&#xff08;columnCount&#xff09;&a…

Linux下_多线程

线程 1. 为什么使用线程? 使用fork创建进程以执行新的任务&#xff0c;该方式的代价很高。多个进程间不会直接共享内存线程是进程的基本执行单元&#xff0c;一个进程的所有任务都在线程中执行&#xff0c;进程要想执行任务&#xff0c;必须得有线程&#xff0c;进程至少要有一…

11、响应数据

文章目录1、响应JSON1.1、引入开发场景1.2 、jackson.jar ResponseBody1、装填返回值处理器2、返回值初步处理3、获取并使用返回值处理器4、观察如何获取返回值处理器5、返回值处理器接口内部6、返回值处理器支持的类型7、返回值解析器原理1.3、HTTPMessageConverter 原理1、M…

c# 通过webView2模拟登陆小红书网页版,解析无水印视频图片,以及解决X-s,X-t签名验证【2023年4月15日】

一、c# WebView2简介 1.一开始使用WebBrowser&#xff0c;因为WebBrowser控件使用的是ie内核&#xff0c;经过修改注册表切换为Edge内核后&#xff0c; 发现Edge内核版本较低&#xff0c;加载一些视频网站提示“浏览器版本过低“&#xff0c;”视频无法加载“。 2.WebBrowser…