通过解读yolov5_gpu_optimization学习如何使用onnx_surgon

news2025/1/22 15:08:09

onnx实战一: 解析yolov5 gpu的onnx优化案例:
这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的。 通过这个案例理解原版的onnx的导出流程然后我们看英伟达是怎么拿gs来优化这个onnx

在这里插入图片描述

原版的export_onnx函数

先看torch.onnx.export函数的参数解释:

  1. model: 要导出的PyTorch模型, 在工程中这里输入的是训练好的pt文件

  2. im: 这里对应torch.onnx.export的args, 这个是用作模型输入的示例张量。这帮助ONNX确定输入的形状和类型。

  3. f: 输出ONNX模型的文件名或文件对象, 用来指定导出模型的路径和文件名。

  4. verbose (默认为 False): 如果设置为 True,则会打印出模型导出时的详细日志。

  5. opset_version: 导出的ONNX模型的操作集版本。不同的版本可能支持不同的操作。

  6. training:

    • torch.onnx.TrainingMode.TRAINING: 表示模型处于训练模式。
    • torch.onnx.TrainingMode.EVAL: 表示模型处于评估模式。
  7. do_constant_folding (默认为 True): 当设置为 True,导出过程中会尝试简化模型,将常量子图折叠为一个常量节点。

  8. input_names: 为模型的输入提供名称, 参数规定是数组

  9. output_names: 为模型的输出提供名称, 参数规定是数组

  10. dynamic_axes: 为模型的输入/输出定义动态轴。对于那些维度在推理时可能会发生变化的情况(例如,批处理大小),此参数允许指定哪些轴是动态的。这里images是输入, 本来是1x3x640x640, 这里通过指定把0, 2, 3维度变成了动态轴的输入, 第二个维度是3这个还是固定的。如果使用动态, 可以输入任意数量和任意大小的图片而不是规定的单张640x640

  • 'images': 对应的张量名称。
    • 0: 'batch': 表示第0个维度(即批处理维度)是动态的,并命名为’batch’。
    • 2: 'height': 表示第2个维度(即图像的高度)是动态的。
    • 3: 'width': 表示第3个维度(即图像的宽度)是动态的。
  • 'output': 对应的张量名称。
    • 0: 'batch': 表示第0个维度(即批处理维度)是动态的。
    • 1: 'anchors': 表示第1个维度是动态的。
  1. dynamic (没有在给定的函数调用中明确给出,但可以从上下文推断):
  • True: 如果你想让某些轴动态,你可以设置此参数为True
  • False: 表示不使用动态轴。

导出了onnx之后开始做onnxsim

  1. model_onnx, check = onnxsim.simplify(...):使用onnxsim的simplify方法简化模型。它返回简化后的onnx模型和一个布尔值check,表示简化是否成功。

  2. 在对动态输入的onnx导出的时候, dynamic_input_shape=dynamic是不够的,还要把输入给他,让onnxsim更加谨慎的优化onnx, 确保满足我们给他的输出,所以这里多了一个input_shapes={'images': list(im.shape)} if dynamic else None

def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    try:
        check_requirements(('onnx',))
        import onnx

        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
        f = file.with_suffix('.onnx')

        torch.onnx.export(
            model,
            im,
            f,
            verbose=False,
            opset_version=opset,
            training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
            do_constant_folding=not train,
            input_names=['images'],
            output_names=['output'],
            dynamic_axes={
                'images': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'},  # shape(1,3,640,640)
                'output': {
                    0: 'batch',
                    1: 'anchors'}  # shape(1,25200,85)
            } if dynamic else None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model

        # Metadata
        d = {'stride': int(max(model.stride)), 'names': model.names}
        for k, v in d.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)
        onnx.save(model_onnx, f)

        # Simplify
        if simplify:
            try:
                check_requirements(('onnx-simplifier',))
                import onnxsim

                LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(model_onnx,
                                                     dynamic_input_shape=dynamic,
                                                     input_shapes={'images': list(im.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                LOGGER.info(f'{prefix} simplifier failure: {e}')
        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return f
    except Exception as e:
        LOGGER.info(f'{prefix} export failure: {e}')

更改过的export_onnx函数

  1. 首先是把onnx的输出由一个改成了3个, 然后指定动态输出, 因为有多个输出,全部都把他们的batch, width, height指定为动态的,满足不同的输入输出。 不过这边的问题是看起来是只改了最后的输出,但是前面在yolo.py的地方已经把sigmoid后面的计算都干掉了, 因为后面的计算映射了一堆的算子导致了在计算图太冗余

在这里插入图片描述
这一坨全部不要了就保留sigmoid就可以了,然后就是直接硬编码t就是int32

diff --git a/models/yolo.py b/models/yolo.py
index 02660e6..c810745 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -55,29 +55,15 @@ class Detect(nn.Module):
         z = []  # inference output
         for i in range(self.nl):
             x[i] = self.m[i](x[i])  # conv
-            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
-            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
-
-            if not self.training:  # inference
-                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
-                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
-
-                y = x[i].sigmoid()
-                if self.inplace:
-                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
-                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
-                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
-                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
-                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
-                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
-                    y = torch.cat((xy, wh, conf), 4)
-                z.append(y.view(bs, -1, self.no))
-
-        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
+            y = x[i].sigmoid()
+            z.append(y)
+        return z
 
     def _make_grid(self, nx=20, ny=20, i=0):
         d = self.anchors[i].device
-        t = self.anchors[i].dtype
+        # t = self.anchors[i].dtype
+        # TODO(tylerz) hard-code data type to int
+        t = torch.int32
         shape = 1, self.na, ny, nx, 2  # grid shape
         y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
         if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
-- 
2.36.0
  1. onnxsim这里跟之前是一样的, 也是直接onnxsim, 如果动态的要给输出给onnxsim然后让它更加的谨慎,满足需求

  2. 这后面的重点是增加了用onnx-surgon来更改onnx, 先把整个onnx导入进来,然后使用然后用Variable做模型的输出, 这里做四个模型输出, 分别是DecodeNumDetection, DecodeDetectionBoxes, DecodeDetectionScores, DecodeDetectionClasses

  3. 然后设置一个attrs, gs设置的attrs使用字典的格式弄的。这里设置max_stride, num_classes, anchors, prenms_score_threshold四个属性,这些属性的操作会在TensorRT中实现的

  4. decode_plugin是中间的节点,这个节点上面是inputs, 下面是四个不同的decodes, 这里就是把这个nodes做出来了

  5. 然后在整体的网络上添加这个节点,然后再把输出改成这个节点的输出保持一致,在计算图中把其他的节点claenup()清洁掉, 最后导出

def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
    # YOLOv5 ONNX export
    # try:
    check_requirements(('onnx',))
    import onnx

    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = file.with_suffix('.onnx')
    print(train)
    torch.onnx.export(
        model,
        im,
        f,
        verbose=False,
        opset_version=opset,
        training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
        do_constant_folding=not train,
        input_names=['images'],
        output_names=['p3', 'p4', 'p5'],
        dynamic_axes={
            'images': {
                0: 'batch',
                2: 'height',
                3: 'width'},  # shape(1,3,640,640)
            'p3': {
                0: 'batch',
                2: 'height',
                3: 'width'},  # shape(1,25200,4)
            'p4': {
                0: 'batch',
                2: 'height',
                3: 'width'},
            'p5': {
                0: 'batch',
                2: 'height',
                3: 'width'}
        } if dynamic else None)

    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model
    
    # Simplify
    if simplify:
        # try:
        check_requirements(('onnx-simplifier',))
        import onnxsim

        LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
        model_onnx, check = onnxsim.simplify(model_onnx,
                                                dynamic_input_shape=dynamic,
                                                input_shapes={'images': list(im.shape)} if dynamic else None)
        assert check, 'assert check failed'
        onnx.save(model_onnx, f)
        # except Exception as e:
        #     LOGGER.info(f'{prefix} simplifier failure: {e}')

    # add yolov5_decoding:
    import onnx_graphsurgeon as onnx_gs
    import numpy as np
    yolo_graph = onnx_gs.import_onnx(model_onnx)
    p3 = yolo_graph.outputs[0]
    p4 = yolo_graph.outputs[1]
    p5 = yolo_graph.outputs[2]
    decode_out_0 = onnx_gs.Variable(
        "DecodeNumDetection",
        dtype=np.int32
    )
    decode_out_1 = onnx_gs.Variable(
        "DecodeDetectionBoxes",
        dtype=np.float32
    )
    decode_out_2 = onnx_gs.Variable(
        "DecodeDetectionScores",
        dtype=np.float32
    )
    decode_out_3 = onnx_gs.Variable(
        "DecodeDetectionClasses",
        dtype=np.int32
    )

    decode_attrs = dict()

    decode_attrs["max_stride"] = int(max(model.stride))
    decode_attrs["num_classes"] = model.model[-1].nc
    decode_attrs["anchors"] = [float(v) for v in [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]]
    decode_attrs["prenms_score_threshold"] = 0.25

    decode_plugin = onnx_gs.Node(
        op="YoloLayer_TRT",
        name="YoloLayer",
        inputs=[p3, p4, p5],
        outputs=[decode_out_0, decode_out_1, decode_out_2, decode_out_3],
        attrs=decode_attrs
    )

    yolo_graph.nodes.append(decode_plugin)
    yolo_graph.outputs = decode_plugin.outputs
    yolo_graph.cleanup().toposort()
    model_onnx = onnx_gs.export_onnx(yolo_graph)

    d = {'stride': int(max(model.stride)), 'names': model.names}
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)

    onnx.save(model_onnx, f)
    LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    return f
    # except Exception as e:
    #     LOGGER.info(f'{prefix} export failure: {e}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1041450.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

伦敦银如何选择最优的交易方法

经常有投资者会问,伦敦银投资中如何选择最好的方法呢?我们进行伦敦银投资,目的就是找到一个能够盈利的交易方法,它能够使我们大部分交易都是盈利,少部分交易亏损,但是可以将亏损控制在一定的范围之内&#…

Windows10关闭此电脑“桌面”“图片”“视频”“3D对象”“文档”等显示,只显示“设备与驱动器”

如何关闭下图"文件夹"等7个子文件夹,只显示“设备和驱动器”? 关闭步骤: 打开cmd,输入regedit打开注册表编辑器打开注册表编辑器后,定位到HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\My…

会声会影和剪映哪个好,2023年全新功能对比详细解析

随着网络视频的蓬勃发展,越来越多的人开始涉足视频剪辑领域,毕竟技多不压身嘛。在众多剪辑软件中,剪映和会声会影是备受新手青睐的两种。那么,会声会影和剪映哪个好呢?在它们之间,哪一个更适合初学者呢接&a…

软件测试之Web安全测试详解

前言 随着互联网时代的蓬勃发展,基于Web环境下的应用系统、应用软件也得到了越来越广泛的使用。 目前,很多企业的业务发展都依赖于互联网,比如,网上银行、网络购物、网络游戏等。但,由于很多恶意攻击者想通过截获他人…

ASCII码-对照表

ASCII 1> ASCII 控制字符2> ASCII 显示字符3> 常用ASCII码3.1> 【CR】\r 回车符3.2> 【LF】\n 换行符3.3> 不同操作系统,文件中换行 1> ASCII 控制字符 2> ASCII 显示字符 3> 常用ASCII码 3.1> 【CR】‘\r’ 回车符 CR Carriage Re…

如何通过git指令加入管理者仓库并提交分支(Github Gitee)

文章目录 创建GitHub、Gitee账户安装git下载gitgit基础配置 管理者创建gitee仓库新建仓库配置公钥 管理者管理仓库开发者通过git指令提交git提交错误原因: 创建GitHub、Gitee账户 GitHub: https://github.com/ Gitee : https://gitee.com/ …

redis做缓存(cache)

什么是缓存 缓存(Cache)的核心思路就是把一些常用的数据放到访问速度更快的地方,方便获取。关于硬件的访问速度来说 CPU寄存器>内存>硬盘>网络 因此常见使用内存作为硬盘的缓存,例如redis。使用硬盘作为网络的缓存,例如浏览器通过h…

2023-9-25 货仓选址

题目链接&#xff1a;货仓选址 #include <iostream> #include <algorithm>using namespace std;const int N 100010;int n; int a[N];int main() {cin >> n;for(int i 0; i < n; i ) cin >> a[i];sort(a, a n);int res 0;for(int i 0; i < …

生产管理看板系统助力高压线束生产车间实现生产任务的可视化管理

随着企业对生产智能化的追求不断提升&#xff0c;生产现场设备联网进行数据采集成为实现生产智能化的第一步&#xff0c;也是打造并实现数字化工厂最基础的一步。在这个过程中&#xff0c;生产管理看板系统发挥着重要的作用&#xff0c;能够实时在线统计车间工业生产设备的运行…

删除链表的倒数第N个节点-双指针法

【题目描述】 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 【示例】 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5] 输入&#xff1a;head [1], n 1 输出&#xff1a;[] 输入&#xff1a;head [1,2], n …

CHAPTER 11: DESIGN A NEWS FEED SYSTEM

Step 1 - Understand the problem and establish design scope Candidate: Is this a mobile app? Or a web app? Or both? Interviewer: Both Candidate: What are the important features? Interview: A user can publish a post and see her friends’ posts on the ne…

《红警3》因计算机中丢失d3dx9_35.dll无法打开游戏怎么办?最新解决方法推荐

d3dx9_35.dll 是 DirectX 9.0c 的一部分&#xff0c;它是一个动态链接库 (DLL)&#xff0c;包含了许多用于支持 DirectX 9.0c 功能的函数和类。DirectX 是一种由微软开发的游戏和多媒体应用程序编程接口&#xff0c;它提供了许多功能&#xff0c;如 3D 图形、音频、输入等&…

电子器件系列55:lm339比较器

以这个比较器为例 电压比较器可以看作是放大倍数接近“无穷大”的运算放大器。 电压比较器的功能&#xff1a;比较两个电压的大小(用输出电压的高或低电平&#xff0c;表示两个输入电压的大小关系)&#xff1a; 当””输入端电压高于”&#xff0d;”输入端时&#xff0c;电压…

阿里巴巴中国站获得1688商品详情 API 返回值说明

1688商品详情API接口可以获得1688商品详情原数据。 这个API接口有两种参数&#xff0c;公共参数和请求参数。 公共参数有以下几个&#xff1a; apikey&#xff1a;这是您自己的API密钥&#xff0c;可以在1688开发者中心获取。 请求参数有以下几个&#xff1a; num_iid&…

115V/400Hz 中频交流航空电源系统测试负载箱

中频交流航空电源系统测试负载箱主要面向战斗机、教练机以及民航飞机的生产及使用单位&#xff0c;用于对航空电源系统&#xff08;28V低压直流电源系统、270V/540V高压电源系统和115V/230V三相400HZ交流电源系统&#xff09;进行维护测试、功能性验证、可靠性负载试验。 系统加…

uni-app:实现图片周围的图片按照圆进行展示

效果 代码 <template><view class"position"><view class"circle"><img src"/static/item1.png" class"center-image"><view v-for"(item, index) in itemList" :key"index" class&q…

通俗讲解MobileNet-v1/v2/v3网络

MobileNet网络是由google团队在2017年提出的&#xff0c;专注于移动端或者嵌入式设备中的轻量级CNN网络。相比传统卷积神经网络&#xff0c;在准确率小幅降低的前提下大大减少模型参数与运算量。(相比VGG16准确率减少了0.9%&#xff0c;但模型参数只有VGG的1/32)。MobileNet网络…

智慧电力平台打造无人值守配电房、变电所

随着科技的发展&#xff0c;电力行业也在不断进步。为了提高电力供应的可靠性和效率&#xff0c;智慧电力平台应运而生。通过智慧电力平台&#xff0c;打造无人值守配电房和变电所成为行业趋势。 一、无人值守配电房和变电所的概念 无人值守配电房和变电所是指通过数字化、…

夜莺启动时报dialector() not supported

sudo nohup /opt/n9e/n9e &>> /opt/n9e/n9e.log &启动的时候&#xff0c;ss -tlnp|grep 17000查看一下监控端口状态&#xff0c;发现提示的信息是[1] Exit 1 sudo nohup /opt/n9e/n9e &>> /opt/n9e/n9e.log&#xff0c;这表明没有正常启动。 原来以为…

linux系统中通过docker安装python包

前提 已经安装docker 已经安装python,python镜像是jetz_python3.7.13 1、将需要安装的python包都写入txt文件中 2、拷贝文件 (1)、通过已安装的python包找到python安装路径 首次安装镜像后,容器启动,进入容器中 docker run -it --name py37 jetz_python3.7.13 /bin/ba…