Pytorch~ONNX

news2025/1/13 13:36:58

pytorch转onnx其实也就是python转的 ,之前有个帖子了讲的怎么操作,这个就是在说说为什么这么做~~~

(1)Pytorch转ONNX的意义

一般来说转ONNX只是一个手段,在之后得到ONNX模型后还需要再将它做转换,比如转换到TensorRT上完成部署,或者有的人多加一步,从ONNX先转换到caffe,再从caffe到tensorRT。原因是Caffe对tensorRT更为友好,这里关于友好的定义后面会谈。

因此在转ONNX工作开展之前,首先必须明确目标后端。ONNX只是一个格式,就和json一样。只要你满足一定的规则,都算是合法的,因此单纯从Pytorch转成一个ONNX文件很简单。但是不同后端设备接受的onnx是不一样的,因此这才是坑的来源。

Pytorch自带的torch.onnx.export转换得到的ONNX,ONNXRuntime需要的ONNX,TensorRT需要的ONNX都是不同的。

这里面举一个最简单的Maxpool的例:

Maxunpool可以被看作Maxpool的逆运算,咱们先来看一个Maxpool的例子,假设有如下一个C*H*W的tensor(shape[2, 3, 3]),其中每个channel的二维矩阵都是一样的,如下所示

在这种情况下,如果我们在Pytorch对它调用MaxPool(kernel_size=2, stride=1,pad=0)

那么会得到两个输出,第一个输出是Maxpool之后的值:

另一个是Maxpool的Idx,即每个输出对应原来的哪个输入,这样做反向传播的时候就可以直接把输出的梯度传给对应的输入: 

细心的同学会发现其实Maxpool的Idx还可以有另一种写法: 

即每个channel的idx放到一起,并不是每个channel单独从0开始。这两种写法都没什么问题,毕竟只要反向传播的时候一致就可以。

但是当我在支持OpenMMEditing的时候,会涉及到Maxunpool,即Maxpool的逆运算:输入MaxpoolId和Maxpool的输出,得到Maxpool的输入。

Pytorch的MaxUnpool实现是接收每个channel都从0开始的Idx格式,而Onnxruntime则相反。因此如果你希望用Onnxruntime跑一样的结果,那么必须对输入的Idx(即和Pytorch一样的输入)做额外的处理才可以。换言之,Pytorch转出来的神经网络图和ONNXRuntime需要的神经网络图是不一样的。

(2)ONNX与Caffe

主流的模型部署有两种路径,以TensorRT为例,一种是Pytorch->ONNX->TensorRT,另一种是Pytorch->Caffe->TensorRT。个人认为目前后者更为成熟,这主要是ONNX,Caffe和TensorRT的性质共同决定的

上面的表列了ONNX和Caffe的几点区别,其中最重要的区别就是op的粒度。举个例子,如果对Bert的Attention层做转换,ONNX会把它变成MatMul,Scale,SoftMax的组合,而Caffe可能会直接生成一个叫做Multi-Head Attention的层,同时告诉CUDA工程师:“你去给我写一个大kernel“(很怀疑发展到最后会不会把ResNet50都变成一个层。。。)

 

因此如果某天一个研究员提了一个新的State-of-the-art的op,很可能它直接就可以被转换成ONNX(如果这个op在Pytorch的实现全都是用Aten的库拼接的),但是对于Caffe的工程师,需要重新写一个kernel。

细粒度op的好处就是非常灵活,坏处就是速度会比较慢。这几年有很多工作都是在做op fushion(比如把卷积和它后面的relu合到一起算),XLA和TVM都有很多工作投入到了op fushion,也就是把小op拼成大op。

TensorRT是NVIDIA推出的部署框架,自然性能是首要考量的,因此他们的layer粒度都很粗。在这种情况下把Caffe转换过去有天然的优势。

除此之外粗粒度也可以解决分支的问题。TensorRT眼里的神经网络就是一个单纯的DAG:给定固定shape的输入,执行相同的运算,得到固定shape的输出。

**目前TensorRT的一个发展方向是支持dynamic shape,但是还很不成熟。

tensor i = funcA();
if(i==0)
j = funcB(i);
else
j = funcC(i);
funcD(j);

对于上面的网络,假设funcA,funcB,funcC和funcD都是onnx支持的细粒度算子,那么ONNX就会面临一个困难,它转换得到的DAG要么长这样:funcA->funcB->funcD,要么funcA->funcC->funcD。但是无论哪种肯定都是有问题的。

而Caffe可以用粗粒度绕开这个问题

开这个问题

tensor i = funcA();
coarse_func(tensor i) {
if(i==0) return funcB(i);
else return funcC(i);
}
funcD(coarse_func(i))

因此它得到的DAG是:funcA->coarse_func->funcD

当然,Caffe的代价就是苦逼的HPC工程师就要手写一个coarse_func kernel。。。(希望Deep Learning Compiler可以早日解放HPC工程师)

(3)Pytorch本身的局限

熟悉深度学习框架的同学都知道,Pytorch之所以可以在tensorflow已经占据主流的情况下横空出世,成功抢占半壁江山,主要的原因是它很灵活。举个不恰当的例子,tensorflow就像是C++,而Pytorch就是Python。

tensorflow会把整个神经网络在运行前做一次编译,生成一个DAG(有向无环图),然后再去跑这张图。Pytorch则相反,属于走一步看一步,直到运行到这个节点算出结果,才知道下一个节点该算啥。

ONNX其实就是把上层深度学习框架中的网络模型转换成一张图,因为tensorflow本身就有一张图,因此只需要直接把这张图拿到手,修修补补就可以。

但是对于Pytorch,没有任何图的概念,因此如果想完成Pytorch到ONNX的转换,就需要让ONNX再旁边拿个小本子,然后跑一遍Pytorch,跑到什么就把什么记下来,把记录的结果抽象成一张图。因此Pytorch转ONNX有两个天然的局限。

1. 转换的结果只对特定的输入。如果换一个输入导致网络结构发生了变化,ONNX是无法察觉的(最常见的情况是如果网络中有if语句,这次的输入走了if的话,ONNX就只会生成if对应的图,把else里面全部的信息都丢掉)。

2. 需要比较多的计算量,因为需要真刀真枪的跑一遍神经网络。

PS:针对于以上的两个局限,我的本科毕设论文提出了一种解决方案,就是通过编译器里面的词法分析,语法分析直接扫描Pytorch或者tensorflow的源代码得到图结构,这样可以轻量级的完成模型到ONNX的转换,同时也可以得到分支判断等信息,这里放一个github链接(https://github.com/drcut/NN_transform),希望大家多多支持

*目前Pytorch官方希望通过用TorchScript的方式解决分支语句的问题,但据我所知还不是很成熟。

whaosoft aiot http://143ai.com   


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/110247.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

韩顺平java-枚举和注解异常包装类

文章目录11章 枚举和注解11.1枚举11.2注解12章 异常12.1 异常类型12.2异常处理1)try - catch - finally2)throws12.3 自定义异常13章 包装类wrapper13.1包装类13.2 String——不可变字符序列13.2 StringBuffer——可变字符序列13.3 StringBuilder13.4 Ma…

【深入浅出Spring原理及实战】「开发实战系列」SpringSecurity技术实战之通过注解表达式控制方法权限

Spring Security权限控制机制 Spring Security中可以通过表达式控制方法权限,其中有四个支持使用表达式的注解,分别是PreAuthorize、PostAuthorize、PreFilter和PostFilter。其中前两者可以用来在方法调用前或者调用后进行权限检查,后两者可…

蓝牙不正常因为 unmet condition check

蓝牙不正常因为 unmet condition check date: 2022-12-22lastmod: 2022-12-23 现象:蓝牙键盘鼠标均不工作,图标也不显示,KDE系统设置显示“无已配对设备”,但是配对设备的按钮没有显示,啥按钮也没有显示 事前&#x…

使用msf生成木马反弹shell(windows系统)

一、理论总结 使用msf进行木马攻击总共分为三步即可: 1、生成攻击木马 2、配置监控主机 3、上传木马使得靶机中招 1.1生成攻击木马 使用msf来生成木马是一个比较方便的事情,使用msfvenom即可,框架模版为: msfvenom -p wind…

我国户外广告行业现状 电梯广告触达率最高 传统户外媒体刊例花费下降

户外广告特指以具体形式展示广告或告示、宣传品的载体,主要包含LED显示屏、LED幕墙、门头招牌、广告字、户外(室内)灯箱、大型立牌,甚至喷绘印刷品等生产制作环节,以及发布公益公告、商业广告、标识标牌、指示导向等广告内容。户外广告是面向…

Python - Order in chaos 混乱中的秩序之随机点中值连线

一.引言 刷短视频刷到了一个有趣的图形变化,随机给定 N 个点,将 N 个点首尾连接生成一个多边形,随后将每个边的中点连接并得到新的多边形,如此多次循环,最终总会得到一个椭圆形。 A.初始化 N 个点并生成多边形 B.取多…

Transformer实现以及Pytorch源码解读(四)-Encoder层

Transformer结构图 先放一张原论文中的图。从inputs到Poitional Encoding在前三部分中已经分析清楚,接下来往后分析。 Pytorch中对Transformer的调用 Pytorch将图1中左半部分的神经网络层用一个TransformerEncdoer(encoder_layer,num_layers)类进行封装&#xf…

【Kotlin 协程】Flow 异步流 ⑥ ( 调用 Flow#launchIn 函数指定流收集协程 | 通过取消流收集所在的协程取消流 )

文章目录一、调用 Flow#launchIn 函数指定流收集协程1、指定流收集协程2、Flow#launchIn 函数原型3、代码示例二、通过取消流收集所在的协程取消流一、调用 Flow#launchIn 函数指定流收集协程 1、指定流收集协程 响应式编程 , 是 基于事件驱动 的 , 在 Flow 流中会产生源源不断…

MySQL的数据类型和存储引擎介绍

一. MySQL数据类型 1. 整数类型 注:MySQL可以为整数类型指定宽度,比如 int(3)、int(5),这个限制不是限制value的合法范围,所以对绝大数应用没有任何意义,对于存储而言,int(3) 和 int(5) 是相同的&#xff…

机器学习中的数学原理——随机梯度下降法

这个专栏主要是用来分享一下我在机器学习中的学习笔记及一些感悟,也希望对你的学习有帮助哦!感兴趣的小伙伴欢迎私信或者评论区留言!这一篇就更新一下《白话机器学习中的数学——随机梯度下降法》! 一、什么是随机梯度下降法 随机…

NVM安装

注意事项: 1、不能安装任何node版本(如存在请删除后安装nvm); 安装步骤: 1、下载nvm ![在这里插入图片描述](https://img-blog.csdnimg.cn/c9dcc27383aa41888347080438c0914e.png 解压后点击exe文件进行安装: &#x…

负载均衡简介

一、什么是负载均衡? 互联网早期,业务流量比较小并且业务逻辑比较简单,单台服务器便可以满足基本的需求;但随着互联网的发展,业务流量越来越大并且业务逻辑也越来越复杂,单台机器的性能问题以及单点问题凸显…

SMMP:一种基于稳定成员资格的多峰聚类算法(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

伦茨LENZE GDC操作指南

1、GDC软件综述 GDC程序可以“在线模式”和“离线模式”使用: 离线模式 可以在没有与目标系统(控制器)相连接条件下使用。该功能允许离线设定参数、编程等工作。 在线模式 通过PC的串口、并口或系统总线实现GDC与目标系统(控制器…

376. 机器任务——最小点覆盖+匈牙利算法

有两台机器 A,B 以及 K 个任务。 机器 A 有 N 种不同的模式(模式 0∼N−1),机器 B 有 M 种不同的模式(模式 0∼M−1)。 两台机器最开始都处于模式 0。 每个任务既可以在 A 上执行,也可以在 B…

艾美捷游离巯基检测试剂盒基本参数和特点说明

游离硫醇(即蛋白质上的游离半胱氨酸、谷胱甘肽和半胱氨酸残基)的检测和测量是研究许多生物系统中的生物过程和事件的基本任务之一。 艾美捷游离巯基检测试剂盒提供了一种简单、可重复和灵敏的工具,用于测定样品(即血浆、血清、组织…

3D格式转换工具HOOPS Exchange助力3D 打印软件实现质的飞跃

HOOPS SDK是用于3D工业软件开发的工具包,其中包括4款工具,分别是用于读取和写入30多种CAD文件格式的HOOPS Exchange、专注于Web端工程图形渲染的HOOPS Communicator、用于移动端和PC端工程图形渲染的HOOPS Visualize、支持将3D数据以原生3D PDF、HTML和标…

解决电脑C盘空间不足,发现微信和qq文件占用了大量内存

项目场景: 电脑C盘空间不足,需要隔一段时间清理垃圾,分析占用空间的文件,将C盘文件迁移到E盘。 问题描述 C盘提示空间不足 原因分析: 通过扫描磁盘发现微信和qq文件占用了几十G的内存,由于微信和qq的一…

C++成员函数当作参数调用的两种方式

平时编程时,多用来将数据进行传参,在考虑回调场景下我们会将函数单做参数传给被调用函数,让被调用函数在时机成熟时进行调用。在某些场景下,需要将类的成员函数当作参数进行回调,此时定义成员函数形参的方式通常有两种…

我的python学习经历及资源整理

对于小白来说,有个人引导会比自学要高效的多,尤其容易坚持不下去的小伙伴。可以试试下面这个入门课程,不用本地安装Python环境,能直接在网页上敲代码,还有大牛老师带着入门,能少走很多弯路!只要…