TRT4-trt-integrate - 1 YOLOV5导出、编译、推理

news2025/1/10 16:59:22

 模型导出

 修改Image的Input动态维度

首先可以看到这个模型导出的时候Input有三个维度都是动态,而我们之前说过只需要一个batch维度是动态,所以要在export的export onnx 进行修改,将

torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                        } if dynamic else None)

改为:

        torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch'},  # shape(1,3,640,640)
                                        'output': {0: 'batch'}  # shape(1,25200,85)
                                        } if dynamic else None)

修改完的已经变成了只有batch是动态维度 。

修改output

而且也可以看到,这里的output输出有四个tensor,其中三个都是fpn结构,80*80 , 40*40 , 20*20,这些我们在这里去掉,仅保留拼接后的结果。

将yolov5-6.0/models/yolo.py中的Class detect修改:

return x if self.training else (torch.cat(z, 1), x)

-->

return x if self.training else torch.cat(z, 1)

修改完毕的output仅保留拼接后的结果。

剪去多余节点:

之后发现这个onnx还是很丑啊

发现真正导致变丑的原因在于这些节点 比如Gather。所以下一步就是要干掉他。

这一步就是将:

      for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

改为:

            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = map(int,x[i].shape)  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

可以看到明显改善了不少。

调整reshape

 但是在reshape中可以看到中间维度是-1,而我们要的是batch是-1

 

 而bs是-1 , y.view还有个-1,这肯定是不行的,那么我们就要手动把这个计算出来,首先y的shape和x的shape一样,x的shape是bs*self.na*self.no*ny*nx,那么这里就是y.view(bs , self.na*nx*ny,self,no)

之后保存再次导出,可以看到已经变成batch的-1了

修改多余节点:

 但是发现还有比如expand这种节点,推断可能是由于数据跟踪引起的

    def _make_grid(self, nx=20, ny=20, i=0):
        d = self.anchors[i].device
        yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
        grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
        anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
            .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()

所以在这里就没必要每个都保存起来,直接给一个常量值就可以。

anchor_grid = (self.anchors[i].clone() * self.stride[i]).view(1,-1,1,1,2)

然后将所有用到self.anchor_grid的部分都替换为anchor_grid。

这样看起来就变成很平整的样子了

 

刚刚的那一大堆就变成了1*3*1*1*2这样子的常量值,这个kind是Initializer,就是常量的这个意思。

CPP推理过程:

TRT:

 YOLO:

 可以看到置信度稍微有一些区别。

输入input作warpaffine:

因为我们的输入是一个确认了的输入:是640*640。所以要对图像做一个类似warpaffine。就是等比缩放剧中填充

 


    ///
    // letter box
    auto image = cv::imread("car.jpg");
    // 通过双线性插值对图像进行resize
    float scale_x = input_width / (float)image.cols;
    float scale_y = input_height / (float)image.rows;
    float scale = std::min(scale_x, scale_y);
    float i2d[6], d2i[6];
    // resize图像,源图像和目标图像几何中心的对齐
    i2d[0] = scale;  i2d[1] = 0;  i2d[2] = (-scale * image.cols + input_width + scale  - 1) * 0.5;
    i2d[3] = 0;  i2d[4] = scale;  i2d[5] = (-scale * image.rows + input_height + scale - 1) * 0.5;

    cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);  // image to dst(network), 2x3 matrix
    cv::Mat m2x3_d2i(2, 3, CV_32F, d2i);  // dst to image, 2x3 matrix
    cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);  // 计算一个反仿射变换
//为什么要计算逆矩阵:因为正矩阵是图像变成warpaffine的过程,逆变换是把框变回到图像尺度的过程
    cv::Mat input_image(input_height, input_width, CV_8UC3);
    cv::warpAffine(image, input_image, m2x3_i2d, input_image.size(), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar::all(114));  // 对图像做平移缩放旋转变换,可逆,填充全是常量值,114
    cv::imwrite("input-image.jpg", input_image);
//存储一下warpaffine效果
    int image_area = input_image.cols * input_image.rows;
    unsigned char* pimage = input_image.data;
    float* phost_b = input_data_host + image_area * 0;
    float* phost_g = input_data_host + image_area * 1;
    float* phost_r = input_data_host + image_area * 2;
    for(int i = 0; i < image_area; ++i, pimage += 3){
        // 注意这里的顺序rgb调换了
        *phost_r++ = pimage[0] / 255.0f;
        *phost_g++ = pimage[1] / 255.0f;
        *phost_b++ = pimage[2] / 255.0f;
    }
    ///
    checkRuntime(cudaMemcpyAsync(input_data_device, input_data_host, input_numel * sizeof(float), cudaMemcpyHostToDevice, stream));

存储一下warpaffine效果

之后就是作推理:


    // 3x3输入,对应3x3输出
    auto output_dims = engine->getBindingDimensions(1);
    int output_numbox = output_dims.d[1];
    int output_numprob = output_dims.d[2];
    int num_classes = output_numprob - 5;//类别数
    int output_numel = input_batch * output_numbox * output_numprob;
    float* output_data_host = nullptr;
    float* output_data_device = nullptr;
    checkRuntime(cudaMallocHost(&output_data_host, sizeof(float) * output_numel));
    checkRuntime(cudaMalloc(&output_data_device, sizeof(float) * output_numel));

    // 明确当前推理时,使用的数据输入大小
    auto input_dims = engine->getBindingDimensions(0);
    input_dims.d[0] = input_batch;

    execution_context->setBindingDimensions(0, input_dims);
    float* bindings[] = {input_data_device, output_data_device};
    bool success      = execution_context->enqueueV2((void**)bindings, stream, nullptr);
    checkRuntime(cudaMemcpyAsync(output_data_host, output_data_device, sizeof(float) * output_numel, cudaMemcpyDeviceToHost, stream));
    checkRuntime(cudaStreamSynchronize(stream));

这个结果就是我们之前YOLOV5的predict(https://blog.csdn.net/zhuangtu1999/article/details/131499750?spm=1001.2014.3001.5501)

但在这里根之前不太一样了:

vector<vector<float>> bboxes;
    float confidence_threshold = 0.25;
    float nms_threshold = 0.5;
    for(int i = 0; i < output_numbox; ++i){
        float* ptr = output_data_host + i * output_numprob;
        float objness = ptr[4];
        if(objness < confidence_threshold)
            continue;

        float* pclass = ptr + 5;
        int label     = std::max_element(pclass, pclass + num_classes) - pclass;
        float prob    = pclass[label];
        float confidence = prob * objness;
        if(confidence < confidence_threshold)
            continue;

        // 中心点、宽、高
        float cx     = ptr[0];
        float cy     = ptr[1];
        float width  = ptr[2];
        float height = ptr[3];

        // 预测框
        float left   = cx - width * 0.5;
        float top    = cy - height * 0.5;
        float right  = cx + width * 0.5;
        float bottom = cy + height * 0.5;

        // 对应图上的位置
        float image_base_left   = d2i[0] * left   + d2i[2];
        float image_base_right  = d2i[0] * right  + d2i[2];
        float image_base_top    = d2i[0] * top    + d2i[5];
        float image_base_bottom = d2i[0] * bottom + d2i[5];
        bboxes.push_back({image_base_left, image_base_top, image_base_right, image_base_bottom, (float)label, confidence});
    }
    printf("decoded bboxes.size = %d\n", bboxes.size());

这里的预测框,left,top等等对应的是warpaffine之后的图片,但我们要做的是把他在原来的图片上加入回来,所以还要做一个反变换的过程。

这里也是我们值前提到过,咱们只有缩放和平移的时候,有效的参数只有三个:scale , dx , dy,这里对应的就是d2i[0] , d2i[2] , d2i[5]。

在之后就是nms:


    // nms非极大抑制
    std::sort(bboxes.begin(), bboxes.end(), [](vector<float>& a, vector<float>& b){return a[5] > b[5];});
    std::vector<bool> remove_flags(bboxes.size());
    std::vector<vector<float>> box_result;
    box_result.reserve(bboxes.size());

    auto iou = [](const vector<float>& a, const vector<float>& b){
        float cross_left   = std::max(a[0], b[0]);
        float cross_top    = std::max(a[1], b[1]);
        float cross_right  = std::min(a[2], b[2]);
        float cross_bottom = std::min(a[3], b[3]);

        float cross_area = std::max(0.0f, cross_right - cross_left) * std::max(0.0f, cross_bottom - cross_top);
        float union_area = std::max(0.0f, a[2] - a[0]) * std::max(0.0f, a[3] - a[1]) 
                         + std::max(0.0f, b[2] - b[0]) * std::max(0.0f, b[3] - b[1]) - cross_area;
        if(cross_area == 0 || union_area == 0) return 0.0f;
        return cross_area / union_area;
    };

    for(int i = 0; i < bboxes.size(); ++i){
        if(remove_flags[i]) continue;

        auto& ibox = bboxes[i];
        box_result.emplace_back(ibox);
        for(int j = i + 1; j < bboxes.size(); ++j){
            if(remove_flags[j]) continue;

            auto& jbox = bboxes[j];
            if(ibox[4] == jbox[4]){
                // class matched
                if(iou(ibox, jbox) >= nms_threshold)
                    remove_flags[j] = true;
            }
        }
    }
    printf("box_result.size = %d\n", box_result.size());

通过cv::rectangle画框:

    for(int i = 0; i < box_result.size(); ++i){
        auto& ibox = box_result[i];
        float left = ibox[0];
        float top = ibox[1];
        float right = ibox[2];
        float bottom = ibox[3];
        int class_label = ibox[4];
        float confidence = ibox[5];
        cv::Scalar color;
        tie(color[0], color[1], color[2]) = random_color(class_label);//通过标签随机选择颜色
        cv::rectangle(image, cv::Point(left, top), cv::Point(right, bottom), color, 3);

        auto name      = cocolabels[class_label];
        auto caption   = cv::format("%s %.2f", name, confidence);
        int text_width = cv::getTextSize(caption, 0, 1, 2, nullptr).width + 10;
        cv::rectangle(image, cv::Point(left-3, top-33), cv::Point(left + text_width, top), color, -1);
        cv::putText(image, caption, cv::Point(left, top-5), 0, 1, cv::Scalar::all(0), 2, 16);
    }
    cv::imwrite("image-draw.jpg", image);

    checkRuntime(cudaStreamDestroy(stream));
    checkRuntime(cudaFreeHost(input_data_host));
    checkRuntime(cudaFreeHost(output_data_host));
    checkRuntime(cudaFree(input_data_device));
    checkRuntime(cudaFree(output_data_device));
}

总结:

在这次过程中,warpaffine(预处理)和后处理过程都可以使用我们之前的核函数去处理,这一部分打包到GPU上的话性能会变得更高。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/771696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

百度翻译申请KEY和ID

1.进入百度翻译网址:https://api.fanyi.baidu.com/ 2.右上角登录账号 3.跟着点点点 填写申请信息&#xff0c;剩下不用写 4.跳转到https://api.fanyi.baidu.com/api/trans/product/desktop 在底部查看KEY和ID

每日一题2023.7.19|ACM模式

文章目录 C的输入方式介绍cin>>cin.get(字符变量名)cin.get(数组名,接收字符数目)cin.get()cin.getline() getline()gets()getchar() AB问题|AB问题||AB问题|||ABⅣAB问题ⅤAB问题Ⅵ C的输入方式介绍 参考博客 cin>> 最基本&#xff0c;最常用的字符或者数字的输…

产品管理必备工具:选择最适合你的工具,让产品管理更高效!

Zoho Projects是一个能够帮助企业组织高效研发工作、快速推向市场并赢得用户青睐的有效工具。通过以下六个步骤&#xff0c;企业可以最大化地利用Zoho Projects&#xff0c;实现高效的产品研发和运营。 第一步&#xff1a;规划产品路线 在甘特图上勾画产品路线图&#xff0c;为…

STM32单片机示例:多个定时器同步触发启动

文章目录 前言基础说明关键配置与代码其它补充示例链接 前言 多个定时器同步触发启动是一种比较实用的功能&#xff0c;这里将对此做个示例说明。 基础说明 该示例演示通过一个TIM使能时同步触发使能另一个TIM。 本例中使用TIM1作为主机&#xff0c;使用TIM1的使能信号作为…

OpenCv之图像直方图

目录 一、基本概念 二、使用OpenCv统计直方图 三、使用掩膜的直方图 一、基本概念 图像直方图是用一表示教字图像中亮度分布的直方图&#xff0c;标绘了图像中每个高度值的像素数。可以借助观察该有方图了解需要如何调整亮度分布的直方图。这种直方图中&#xff0c;横坐标的左…

Android 个人开发者如何接入广告SDK,实现app流量变现

接入广告的APP连接 大家可以下载看看&#xff08;无需积分&#xff09; 链接: https://download.csdn.net/download/qq_38355313/88063389 开屏广告示意图&#xff1a; 1.个人开发者如何添加广告SDK&#xff1f; 像大厂的广告SDK&#xff0c;比如穿山甲SDK&#xff0c;点广…

SpringMvc配置静态资源访问路径

文章目录 1. 整体流程2. registry.addResourceHandler()2.1 函数分析2.2 结果演示 3. ResourceHandlerRegistration.addResourceLocations()3.1 函数分析3.2 结果演示 1. 整体流程 1. 写一个配置类继承WebMvcConfigurationSupport 2. 利用 registry.addResourceHandler("…

ylb-接口4投资排行榜

总览&#xff1a; 1、使用Redis存储投资信息 2、Redis常量类 在common模块constants包&#xff0c;创建一个Redis常量类&#xff08;RedisKey&#xff09;&#xff1a; package com.bjpowernode.common.constants;public class RedisKey {/*投资排行榜*/public static fin…

【雕爷学编程】Arduino动手做(164)---Futaba S3003舵机模块3

37款传感器与模块的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&#x…

计算机vcruntime140.dll丢失的解决方法,重新安装教程

vcruntime140.dll是Microsoft Visual C Redistributable文件中的一个动态链接库&#xff08;DLL&#xff09;。这个文件是由Microsoft开发的&#xff0c;用于支持C编程语言的运行环境。vcruntime140.dll是Windows系统非常重要的文件&#xff0c;通常会被一些应用程序或游戏所需…

JS-26 认识防抖和节流函数;自定义防抖、节流函数;自定义深拷贝、事件总线函数

目录 1_防抖和节流1.1_认识防抖和节流函数1.2_认识防抖debounce函数1.3_防抖函数的案例1.4_认识节流throttle函数 2_Underscore实现防抖和节流2.1_Underscore实现防抖和节流2.2_自定义防抖函数2.3_自定义节流函数 3_自定义深拷贝函数4_自定义事件总线 1_防抖和节流 1.1_认识防…

【源码解析】Mybatis执行原理

Mybatis执行原理 1.获取SqlSessionFactory2.创建SqlSession3.创建Mapper、执行SQL MyBatis 是一款优秀的持久层框架&#xff0c;MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集。MyBatis 可以使用简单的 XML 或注解来配置和映射原生信息&#xff0c;将接口和…

深入篇【C++】谈vector中的深浅拷贝与迭代器失效问题

深入篇【C】谈vector中的深浅拷贝与迭代器失效问题 Ⅰ.深浅拷贝问题1.内置类型深拷贝2.自定义类型深拷贝 Ⅱ.迭代器失效问题1.内部迭代器失效2.外部迭代器失效 Ⅰ.深浅拷贝问题 1.内置类型深拷贝 浅拷贝是什么意思&#xff1f;就是单纯的值拷贝。 浅拷贝的坏处&#xff1a; ①…

❤️创意网页:HTML5,canvas创作科技感粒子特效(科技感粒子、js鼠标跟随、粒子连线)

✨博主&#xff1a;命运之光 &#x1f338;专栏&#xff1a;Python星辰秘典 &#x1f433;专栏&#xff1a;web开发&#xff08;简单好用又好看&#xff09; ❤️专栏&#xff1a;Java经典程序设计 ☀️博主的其他文章&#xff1a;点击进入博主的主页 前言&#xff1a;欢迎踏入…

力扣 452. 用最少数量的箭引爆气球

题目来源&#xff1a;https://leetcode.cn/problems/minimum-number-of-arrows-to-burst-balloons/description/ C题解1&#xff1a; 根据x_end排序&#xff0c;x_start小的在前&#xff0c;这样可以保证如果第 i 个球的x_end大于等于第 j 个球的x_start时&#xff0c;第 j 个球…

TabBar和TabBarView实现顶部滑动导航

home.dart子页面主要代码&#xff1a; import package:flutter/material.dart;class HomePage extends StatefulWidget {const HomePage({super.key});overrideState<HomePage> createState() > _HomePageState(); }class _HomePageState extends State<HomePage&…

WMTS 地图切片Web服务 协议数据解析

1. WMTS 描述 WMTS(Web Map Tiles Service):地图切片Web服务。 2. 数据示例&#xff1a; arcgis online 导出的wmts xml&#xff1a; https://sampleserver6.arcgisonline.com/arcgis/rest/services/WorldTimeZones/MapServer/WMTS 内容解析&#xff1a; contents中可能包…

hybridCLR热更遇到问题

报错1&#xff1a; No ‘git‘ executable was found. Please install Git on your system then restart 下载Git安装&#xff1a; Git - Downloading Package 配置&#xff1a;https://blog.csdn.net/baidu_38246836/article/details/106812067 重启电脑 unity&#xff1a;…

Mysql数据库日志和数据的备份恢复

目录 一、数据库备份的重要性 二、数据库备份的分类 三、常见的备份方法 1.物理冷备 2.专用备份工具 3.启用二进制日志进行增量备份 4.第三方工具备份 四、完全备份 1.简介 2.优缺点 五、完全备份与恢复 1.物理冷备份与恢复 2.mysqldump备份与恢复 六、mysql日志管…

SSRF漏洞(原理、挖掘点、漏洞利用、修复建议)

一、介绍SSRF漏洞 SSRF (Server-Side Request Forgery,服务器端请求伪造)是一种由攻击者构造请求&#xff0c;由服务端发起请求的安全漏洞。一般情况下&#xff0c;SSRF攻击的目标是外网无法访问的内部系统(正因为请求是由服务端发起的&#xff0c;所以服务端能请求到与自身相…