tensorRT C++使用pt转engine模型进行推理

news2025/1/10 6:02:06

目录

  • 1. 前言
  • 2. 模型转换
  • 3. 修改Binding
  • 4. 修改后处理

1. 前言

本文不讲tensorRT的推理流程,因为这种文章很多,这里着重讲从标准yolov5的tensort推理代码(模型转pt->wts->engine)改造成TPH-yolov5(pt->onnx->engine)的过程。

2. 模型转换

请查看上一篇文章https://blog.csdn.net/wyw0000/article/details/139737473?spm=1001.2014.3001.5502

3. 修改Binding

如果不修改Binding,会报下图中的错误。
在这里插入图片描述
该问题是由于Binding有多个,而代码中只申请了input和output,那么如何查看engine模型有几个Bingding呢?代码如下:

int get_model_info(const string& model_path) {
    // 创建 logger
    Logger gLogger;

    // 从文件中读取 engine
    std::ifstream engineFile(model_path, std::ios::binary);
    if (!engineFile) {
        std::cerr << "Failed to open engine file." << std::endl;
        return -1;
    }

    engineFile.seekg(0, engineFile.end);
    long int fsize = engineFile.tellg();
    engineFile.seekg(0, engineFile.beg);

    std::vector<char> engineData(fsize);
    engineFile.read(engineData.data(), fsize);
    if (!engineFile) {
        std::cerr << "Failed to read engine file." << std::endl;
        return -1;
    }

    // 反序列化 engine
    auto runtime = nvinfer1::createInferRuntime(gLogger);
    auto engine = runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);

    // 获取并打印输入和输出绑定信息
    for (int i = 0; i < engine->getNbBindings(); ++i) {
        nvinfer1::Dims dims = engine->getBindingDimensions(i);
        nvinfer1::DataType type = engine->getBindingDataType(i);

        std::cout << "Binding " << i << " (" << engine->getBindingName(i) << "):" << std::endl;
        std::cout << "  Type: " << (int)type << std::endl;
        std::cout << "  Dimensions: ";
        for (int j = 0; j < dims.nbDims; ++j) {
            std::cout << (j ? "x" : "") << dims.d[j];
        }
        std::cout << std::endl;
        std::cout << "  Is Input: " << (engine->bindingIsInput(i) ? "Yes" : "No") << std::endl;
    }

    // 清理资源
    engine->destroy();
    runtime->destroy();

    return 0;
}

下图是我的tph-yolov5的Binding,可以看到有5个Binding,因此在doInference推理之前,要给5个Binding都申请空间,同时要注意获取BindingIndex时,名称和dimension与查询出来的对应。
在这里插入图片描述

//for tph-yolov5
    int Sigmoid_921_index = trt->engine->getBindingIndex("onnx::Sigmoid_921");
    int Sigmoid_1183_index = trt->engine->getBindingIndex("onnx::Sigmoid_1183");
    int Sigmoid_1367_index = trt->engine->getBindingIndex("onnx::Sigmoid_1367");
    CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_921_index], BATCH_SIZE * 3 * 192 * 192 * 7 * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1183_index], BATCH_SIZE * 3 * 96 * 96 * 7 * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1367_index], BATCH_SIZE * 3 * 48 * 48 * 7 * sizeof(float)));

    trt->data = new float[BATCH_SIZE * 3 * INPUT_H * INPUT_W];
    trt->prob = new float[BATCH_SIZE * OUTPUT_SIZE];
    trt->inputIndex = trt->engine->getBindingIndex(INPUT_BLOB_NAME);
    trt->outputIndex = trt->engine->getBindingIndex(OUTPUT_BLOB_NAME);

还有推理的部分也要做修改,原来只有input和output两个Binding时,那么输出是buffers[1],而目前是有5个Binding那么输出就变成了buffers[4]

void doInference(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output, int batchSize) {
    // infer on the batch asynchronously, and DMA output back to host
    context.enqueueV2(buffers, stream, nullptr);
    //CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaMemcpyAsync(output, buffers[4], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    cudaStreamSynchronize(stream);
}

4. 修改后处理

之前的yolov5推理代码是将pt模型转为wts再转为engine的,输出维度只有一维,而TPH输出维度为145152*7,因此要对原来的后处理代码进行修改。

struct BoundingBox {
    //bbox[0],bbox[1],bbox[2],bbox[3],conf, class_id
    float x1, y1, x2, y2, score, index;
};

float iou(const BoundingBox&  box1, const BoundingBox& box2) {
	float max_x = max(box1.x1, box2.x1);  // 找出左上角坐标哪个大
	float min_x = min(box1.x2, box2.x2);  // 找出右上角坐标哪个小
	float max_y = max(box1.y1, box2.y1);
	float min_y = min(box1.y2, box2.y2);
	if (min_x <= max_x || min_y <= max_y) // 如果没有重叠
		return 0;
	float over_area = (min_x - max_x) * (min_y - max_y);  // 计算重叠面积
	float area_a = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
	float area_b = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
	float iou = over_area / (area_a + area_b - over_area);
	return iou;
}

std::vector<BoundingBox> nonMaximumSuppression(std::vector<std::vector<float>>& boxes, float overlapThreshold) {
	std::vector<BoundingBox> convertedBoxes;

	// 将数据转换为BoundingBox结构体
	for (const auto&  box: boxes) {
		if (box.size() == 6) { // Assuming [x1, y1, x2, y2, score]
			BoundingBox bbox;
			bbox.x1 = box[0];
			bbox.y1 = box[1];
			bbox.x2 = box[2];
			bbox.y2 = box[3];
			bbox.score = box[4];
			bbox.index = box[5];
			convertedBoxes.push_back(bbox);
		}
		else {
			std::cerr << "Invalid box format!" << std::endl;
		}
	}

	// 对框按照分数降序排序
	std::sort(convertedBoxes.begin(), convertedBoxes.end(), [](const BoundingBox& a, const BoundingBox&  b) {
		return a.score > b.score;
		});

	// 非最大抑制
	std::vector<BoundingBox> result;
	std::vector<bool> isSuppressed(convertedBoxes.size(), false);

	for (size_t i = 0; i < convertedBoxes.size(); ++i) {
		if (!isSuppressed[i]) {
			result.push_back(convertedBoxes[i]);

			for (size_t j = i + 1; j < convertedBoxes.size(); ++j) {
				if (!isSuppressed[j]) {
					float overlap = iou(convertedBoxes[i], convertedBoxes[j]);

					if (overlap > overlapThreshold) {
						isSuppressed[j] = true;
					}
				}
			}
		}
	}
#if 0
	// 输出结果
	std::cout << "NMS Result:" << std::endl;
	for (const auto& box: result) {
		std::cout << "x1: " << box.x1 << ", y1: " << box.y1
			<< ", x2: " << box.x2 << ", y2: " << box.y2
			<< ", score: " << box.score << ",index:" << box.index << std::endl;
	}
#endif 
	return result;
}

void post_process(float *prob_model, float conf_thres, float overlapThreshold, std::vector<Yolo::Detection> & detResult)
{
	int cols = 7, rows = 145152;
	//  ========== 8. 获取推理结果 =========
	std::vector<std::vector<float>> prediction(rows, std::vector<float>(cols));

	int index = 0;
	for (int i = 0; i < rows; ++i) {
		for (int j = 0; j < cols; ++j) {
			prediction[i][j] = prob_model[index++];
		}
	}

	//  ========== 9. 大于conf_thres加入xc =========
	std::vector<std::vector<float>> xc;
	for (const auto& row : prediction) {
		if (row[4] > conf_thres) {
			xc.push_back(row);
		}
	}
	//  ========== 10. 置信度 = obj_conf * cls_conf =========
	//std::cout << xc[0].size() << std::endl;
	for (auto& row: xc) {
		for (int i = 5; i < xc[0].size(); i++) {
			row[i] *= row[4];
		}
	}

	// ========== 11. 切片取出xywh 转为xyxy=========
	std::vector<std::vector<float>> xywh;
	for (const auto& row: xc) {
		std::vector<float> sliced_row(row.begin(), row.begin() + 4);
		xywh.push_back(sliced_row);
	}
	std::vector<std::vector<float>> box(xywh.size(), std::vector<float>(4, 0.0));

	xywhtoxxyy(xywh, box);
	
	// ========== 12. 获取置信度最高的类别和索引=========
	std::size_t mi = xc[0].size();
	std::vector<float> conf(xc.size(), 0.0);
	std::vector<float> j(xc.size(), 0.0);

	for (std::size_t i = 0; i < xc.size(); ++i) {
		// 模拟切片操作 x[:, 5:mi]
		auto sliced_x = std::vector<float>(xc[i].begin() + 5, xc[i].begin() + mi);

		// 计算 max
		auto max_it = std::max_element(sliced_x.begin(), sliced_x.end());

		// 获取 max 的索引
		std::size_t max_index = std::distance(sliced_x.begin(), max_it);

		// 将 max 的值和索引存储到相应的向量中
		conf[i] = *max_it;
		j[i] = max_index;  // 加上切片的起始索引
	}

	// ========== 13. concat x1, y1, x2, y2, score, index;======== =
	for (int i = 0; i < xc.size(); i++) {
		box[i].push_back(conf[i]);
		box[i].push_back(j[i]);
	}

	std::vector<std::vector<float>> output;
	for (int i = 0; i < xc.size(); i++) {
		output.push_back(box[i]); // 创建一个空的 float 向量并
	}

	// ==========14 应用非最大抑制 ==========
	std::vector<BoundingBox>  result = nonMaximumSuppression(output, overlapThreshold);
	for (const auto& r : result)
	{
		Yolo::Detection det;
		det.bbox[0] = r.x1;
		det.bbox[1] = r.y1;
		det.bbox[2] = r.x2;
		det.bbox[3] = r.y2;
		det.conf = r.score;
		det.class_id = r.index;
		detResult.push_back(det);
	}

}

代码参考:
https://blog.csdn.net/rooftopstars/article/details/136771496
https://blog.csdn.net/qq_73794703/article/details/132147879

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1854912.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序毕业设计-在线厨艺平台系统项目开发实战(附源码+论文)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;微信小程序毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计…

XSS漏洞实验

本篇为xss漏洞实验练习&#xff0c;练习网址来源于网络 练习网址&#xff1a;XSS平台|CTF欢迎来到XSS挑战|XSS之旅|XSS测试 一、前置说明 在测试过程中&#xff0c;有哪些东西是我们可以利用来猜测与判断的&#xff1a; 网页页面的变化&#xff1b;审查网页元素&#xff1b;查…

奇瑞复活经典路虎!中国技术,英国车标,卖向全球

ChatGPT狂飙160天&#xff0c;世界已经不是之前的样子。 更多资源欢迎关注 6月19日&#xff0c;奇瑞和捷豹路虎宣布签署战略合作意向书&#xff0c;将复活“Freelander神行者”品牌。 小编当课代表&#xff0c;做个简单总结&#xff1a; 英国品牌&#xff0c;中国技术&#xf…

数组 (java)

文章目录 一维数组静态初始化动态初始化 二维数组静态初始化动态初始化 数组参数传递可变参数关于 main 方法的形参 argsArray 工具类sort 中的 comparable 和 comparatorcomparator 比较器排序comparable 自然排序 一维数组 线性结构 静态初始化 第一种&#xff1a;int[] a…

基于uni-app和图鸟UI开发上门服务小程序

一、技术栈选择 uni-app&#xff1a;我们选择了uni-app作为开发框架&#xff0c;因为它基于Vue.js&#xff0c;允许我们编写一次代码&#xff0c;发布到多个平台&#xff0c;包括iOS、Android、Web以及各种小程序。uni-app的丰富组件库、高效的状态管理以及便捷的预览调试功能&…

无霍尔BLDC驱动

目前主要的无霍尔控制方案是基于反电势检测信 息判断换相点,本文研究反电势在 PWM - OFF 点的检 测方案确定换相点。 1. 反电动势检测方案 BLDC 的模型做等效,将线圈阻抗看成是一个 线性电阻和一个储能电感的等效,其等效电路图如图 1所示。 电机三相绕组输出端电压的电压…

【PL理论深化】(3) MI 归纳法:归纳假设 (IH) | 结构归纳法 | 归纳假设的证明

&#x1f4ac; 写在前面&#xff1a;所有编程语言都是通过归纳法定义的。因此&#xff0c;虽然编程语言本身是有限的&#xff0c;但用该语言编写的程序数量是没有限制的&#xff0c;本章将学习编程语言研究中最基本的归纳法。本章我们继续讲解归纳法&#xff0c;介绍归纳假设和…

issues.sonatype.org网站废弃,Maven仓库账号被废弃问题解决

问题起因&#xff1a; 今天自己的项目发布了一个新版本&#xff0c;打算通过GitHub流水线直接推送至Maven中央仓库&#xff0c;结果发现报错 401&#xff0c;说我的账号密码认证失败。我充满了疑惑我寻思难度我的号被盗掉了吗。于是我打开Nexus Repository Manager尝试登录账号…

最新下载:XmanagerXShell【软件附加安装教程】

​相信大家都认同支持IPv6&#xff1a;最近越来越多的公司和国家都采用了IPv6&#xff0c;Xmanager的最新版本v5也加入支持这个功能&#xff0c;无论你是同时使用IPv4和IPv6网络或者完全的IPv6网络&#xff0c;Xmanager 5都可完全满足你的要求&#xff0c;使用MIT Kerberos认证…

Vue41 ref属性

ref属性 ref是Vue提供的获取组件的属性 <template><div><h1 v-text"msg" ref"title"></h1><button ref"btn" click"showDOM">点我输出上方的DOM元素</button><MySchool ref"sch"…

Android开发神器:OkHttp框架源码解析

NetworkInterceptors CallServiceInterceptor 1.RealInterceptorChain.proceed() 2.EventListener.callStart()也是在RealCall.execute()嵌入到Request调用过程, EventListener.callEnd()位于StreamAllocation中调用 3.Request.Builder url (String/URL/HttpUrl) header …

WordPress插件:子比zibll主题插件 炙焰美化全开源插件V3.2

在数字时代&#xff0c;拥有一个美观且功能丰富的网站是吸引和保持用户的关键。WordPress作为全球最受欢迎的内容管理系统之一&#xff0c;提供了一个灵活的平台&#xff0c;让网站所有者能够通过插件来增强其网站的功能和外观。"炙焰美化全开源插件V3.2"正是这样一款…

SD卡无法读取?数据恢复全攻略!

SD卡无法读取问题描述 在日常使用电子设备时&#xff0c;我们有时会遇到SD卡无法读取的情况。当插入SD卡后&#xff0c;设备可能无法识别或访问其中的数据&#xff0c;这给我们带来了诸多不便。SD卡无法读取&#xff0c;意味着存储在卡中的重要文件、照片和视频等资料可能面临…

QListView、QTableView或QTreeView截取滚动区域(截长图)

本文以QTreeView为例,理论上继承自QAbstractScrollArea的类都支持本文所述的方法。 一.效果 一共5个文件夹,每个文件文件夹下有5个文件,先把文件夹展开,然后截图。将滚动条拖到居中位置,是为了证明截图对滚动条无影响 下面是截的图 二.原理 将滚动区域的viewport设置为…

Typora最新安装教程2024

Typora是一款广受好评的跨平台Markdown编辑软件&#xff0c;支持Windows、MacOS和Linux操作系统。它的设计旨在提供一个无干扰、高效且直观的写作环境。户快速管理和查找文档&#xff0c;支持直接在软件内浏览和操作文件结构。 Typora以其简洁而强大的功能集合&#xff0c;成为…

SQL Server中CROSS APPLY连接操作

在 SQL Server 中&#xff0c;CROSS APPLY 是一个连接操作&#xff0c;它类似于 INNER JOIN&#xff0c;但有一些关键差异&#xff0c;特别是在处理表值函数&#xff08;TVF&#xff09;、行集函数或子查询时。CROSS APPLY 返回对于外部查询中的每一行&#xff0c;在内部查询或…

【栈和队列】

目录 1&#xff0c;栈&#xff08;Stack&#xff09; 1.1 概念 1.2 栈的使用 1.3 栈的模拟实现 1.4 栈的应用场景 1.5 概念区分 1.6 使用链表来实现栈 2&#xff0c; 队列(Queue) 2.1 概念 2.2 队列的使用 2.3 队列模拟实现 3&#xff0c;双端队列 (Deque) 4&…

r2frida:基于Frida的远程进程安全检测和通信工具

关于r2frida r2frida是一款能够将Radare2和Frida的功能合二为一的强大工具&#xff0c;该工具本质上是一个Radare2的自包含插件&#xff0c;可以帮助广大研究人员利用Frida的功能实现对目标进程的远程安全检测和通信管理。 Radare2项目提供了针对逆向工程分析的完整工具链&…

利用golang_Consul代码实现Prometheus监控目标的注册以及动态发现与配置

文章目录 前言一、prometheus发现方式二、监控指标注册架构图三、部分代码展示1.核心思想2.代码目录3、程序入口函数剖析4、settings配置文件5、初始化配置文件及consul6、全局变量7、配置config8、公共方法目录common9、工具目录tools10、service层展示11、命令行参数12、Make…

使用 axios 进行 HTTP 请求

使用 axios 进行 HTTP 请求 文章目录 使用 axios 进行 HTTP 请求1、介绍2、安装和引入3、axios 基本使用4、axios 发送 GET 请求5、axios 发送 POST 请求6、高级使用7、总结 1、介绍 什么是 axios axios 是一个基于 promise 的 HTTP 库&#xff0c;可以用于浏览器和 Node.js 中…