c++通过tensorRT调用模型进行推理

news2025/1/26 5:03:18

模型来源
算法工程师训练得到的onnx模型

c++对模型的转换
拿到onnx模型后,通过tensorRT将onnx模型转换为对应的engine模型,注意:训练用的tensorRT版本和c++调用的tensorRT版本必须一致。

如何转换:

  1. 算法工程师直接转换为.engine文件进行交付。
  2. 自己转换,进入tensorRT安装目录\bin目录下,将onnx模型拷贝到bin目录,地址栏中输入cmd回车弹出控制台窗口,然后输入转换命令,如:

trtexec --onnx=model.onnx --saveEngine=model.engine --workspace=1024 --optShapes=input:1x13x512x640 --fp16

然后回车,等待转换完成,完成后如图所示:
在这里插入图片描述
并且在bin目录下生成.engine模型文件。

c++对.engine模型文件的调用和推理
首先将tensorRT对模型的加载及推理进行封装,命名为CTensorRT.cpp,老样子贴代码:

//CTensorRT.cpp
class Logger : public nvinfer1::ILogger {
	void log(Severity severity, const char* msg) noexcept override {
		if (severity <= Severity::kWARNING)
			std::cout << msg << std::endl;
	}
};

Logger logger;
class CtensorRT
{
public:
	CtensorRT() {}
	~CtensorRT() {}

private:
	std::shared_ptr<nvinfer1::IExecutionContext> _context;
	std::shared_ptr<nvinfer1::ICudaEngine> _engine;

	nvinfer1::Dims _inputDims;
	nvinfer1::Dims _outputDims;
public:
	void cudaCheck(cudaError_t ret, std::ostream& err = std::cerr)
	{
		if (ret != cudaSuccess)
		{
			err << "Cuda failure: " << cudaGetErrorString(ret) << std::endl;
			abort();
		}
	}

	bool loadOnnxModel(const std::string& filepath)
	{
		auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
		if (!builder)
		{
			return false;
		}

		const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
		auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));
		if (!network)
		{
			return false;
		}

		auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
		if (!config)
		{
			return false;
		}

		auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));
		if (!parser)
		{
			return false;
		}

		parser->parseFromFile(filepath.c_str(), static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));
		std::unique_ptr<IHostMemory> plan{ builder->buildSerializedNetwork(*network, *config) };
		if (!plan)
		{
			return false;
		}

		std::unique_ptr<IRuntime> runtime{ createInferRuntime(logger) };
		if (!runtime)
		{
			return false;
		}

		_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(plan->data(), plan->size()));
		if (!_engine)
		{
			return false;
		}
		_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());
		if (!_context)
		{
			return false;
		}

		int nbBindings = _engine->getNbBindings();
		assert(nbBindings == 2); // 输入和输出,一共是2个

		for (int i = 0; i < nbBindings; i++)
		{
			if (_engine->bindingIsInput(i))
				_inputDims = _engine->getBindingDimensions(i);    // (1,3,752,752)
			else
				_outputDims = _engine->getBindingDimensions(i);
		}
		return true;
	}

	bool loadEngineModel(const std::string& filepath)
	{
		std::ifstream file(filepath, std::ios::binary);
		if (!file.good())
		{
			return false;
		}

		std::vector<char> data;
		try
		{
			file.seekg(0, file.end);
			const auto size = file.tellg();
			file.seekg(0, file.beg);

			data.resize(size);
			file.read(data.data(), size);
		}
		catch (const std::exception& e)
		{
			file.close();
			return false;
		}
		file.close();

		auto runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger));
		_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(data.data(), data.size()));
		if (!_engine)
		{
			return false;
		}

		_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());
		if (!_context)
		{
			return false;
		}

		int nbBindings = _engine->getNbBindings();
		assert(nbBindings == 2); // 输入和输出,一共是2个

		// 为输入和输出创建空间
		for (int i = 0; i < nbBindings; i++)
		{
			if (_engine->bindingIsInput(i))
				_inputDims = _engine->getBindingDimensions(i);    //得到输入结构
			else
				_outputDims = _engine->getBindingDimensions(i);//得到输出结构
		}
		return true;
	}

	void ONNX2TensorRT(const char* ONNX_file, std::string save_ngine)
	{
		// 1.创建构建器的实例
		nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);

		// 2.创建网络定义
		uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
		nvinfer1::INetworkDefinition* network = builder->createNetworkV2(flag);

		// 3.创建一个 ONNX 解析器来填充网络
		nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);

		// 4.读取模型文件并处理任何错误
		parser->parseFromFile(ONNX_file, static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));
		for (int32_t i = 0; i < parser->getNbErrors(); ++i)
		{
			std::cout << parser->getError(i)->desc() << std::endl;
		}

		// 5.创建一个构建配置,指定 TensorRT 应该如何优化模型
		nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();

		// 7.指定配置后,构建引擎
		nvinfer1::IHostMemory* serializedModel = builder->buildSerializedNetwork(*network, *config);

		// 8.保存TensorRT模型
		std::ofstream p(save_ngine, std::ios::binary);
		p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());

		// 9.序列化引擎包含权重的必要副本,因此不再需要解析器、网络定义、构建器配置和构建器,可以安全地删除
		delete parser;
		delete network;
		delete config;
		delete builder;

		// 10.将引擎保存到磁盘,并且可以删除它被序列化到的缓冲区
		delete serializedModel;
	}

	uint32_t getElementSize(nvinfer1::DataType t) noexcept
	{
		switch (t)
		{
		case nvinfer1::DataType::kINT32: return 4;
		case nvinfer1::DataType::kFLOAT: return 4;
		case nvinfer1::DataType::kHALF: return 2;
		case nvinfer1::DataType::kBOOL:
		case nvinfer1::DataType::kINT8: return 1;
		}
		return 0;
	}

	int64_t volume(const nvinfer1::Dims& d)
	{
		return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());
	}

	bool infer(unsigned char* input, int real_input_size, cv::Mat& out_mat)
	{
		tensor_custom::BufferManager buffer(_engine);

		cudaStream_t stream;
		cudaStreamCreate(&stream); // 创建异步cuda流

		int binds = _engine->getNbBindings();
		for (int i = 0; i < binds; i++)
		{
			if (_engine->bindingIsInput(i))
			{
				size_t input_size;
				float* host_buf = static_cast<float*>(buffer.getHostBufferData(i, input_size));
				memcpy(host_buf, input, real_input_size);
				break;
			}
		}

		// 将输入传递到GPU
		buffer.copyInputToDeviceAsync(stream);
		// 异步执行
		bool status = _context->enqueueV2(buffer.getDeviceBindngs().data(), stream, nullptr);
		if (!status)
			return false;

		buffer.copyOutputToHostAsync(stream);
		for (int i = 0; i < binds; i++)
		{
			if (!_engine->bindingIsInput(i))
			{
				size_t output_size;
				float* tmp_out = static_cast<float*>(buffer.getHostBufferData(i, output_size));
				//do your something here
				break;
			}
		}
		cudaStreamSynchronize(stream);
		cudaStreamDestroy(stream);
		return true;
	}
};

调用方式

int main()
{
	vector<int> dims = { 1,13,512,640 };
	vector<float> vall;
	for (int i=0;i<13;i++)
	{
		string file = "D:\\xxx\\" + to_string(i) + ".png";
		cv::Mat mt = imread(file, IMREAD_GRAYSCALE);
		cv::resize(mt, mt, cv::Size(640,512));
		mt.convertTo(mt, CV_32F, 1.0 / 255);
		cv::Mat shape_xr = mt.reshape(1, mt.total() * mt.channels());
		std::vector<float> vec_xr = mt.isContinuous() ? shape_xr : shape_xr.clone();
		vall.insert(vall.end(), vec_xr.begin(), vec_xr.end());
	}
	cv::Mat mt_4d(4, &dims[0], CV_32F, vall.data());

	string engine_model_file = "model.engine";
	CtensorRT cTensor;
	if (cTensor.loadEngineModel(engine_model_file))
	{
		cv::Mat out_mat;
		if (!cTensor.infer(mt_4d.data, vall.size() * 4, out_mat))
			std::cout << "infer error!" << endl;
		else
			cv::imshow("out", out_mat);
	}
	else
		std::cout << "load model file failed!" << endl;
	cv::waitKey(0);
	return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/984499.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Json“牵手”亚马逊商品详情数据方法,亚马逊商品详情API接口,亚马逊API申请指南

亚马逊平台是美国最大的一家网络电子商务公司&#xff0c;亚马逊公司是1995年成立&#xff0c;刚开始只做网上书籍售卖业务&#xff0c;后来扩展到了其他产品。现在已经是全世界商品品种最多的网上零售商和第二互联网公司&#xff0c;亚马逊是北美洲、欧洲等地区的主流购物平台…

为什么5G 要分离 CU 和DU?(4G分离RRU 和BBU)

在 Blog 一文中&#xff0c;5G--BBU RRU 如何演化到 CU DU&#xff1f;_5g rru_qq_38480311的博客-CSDN博客 解释了4G的RRU BBU 以及 5G CU DU AAU&#xff0c;主要是讲了它们分别是什么。但是没有讲清楚 为什么&#xff0c;此篇主要回答why。 4G 为什么分离基站为 RRU 和 BBU…

什么是原生IP?原生IP与住宅IP有何区别?

相信许多做跨境的都会接触到IP代理&#xff0c;比如电商平台、社媒平台、收款平台等等&#xff0c;都会检测IP。那也会经常听到一些词汇&#xff1a;原生IP、住宅IP&#xff0c;这两者之间有什么区别呢&#xff1f;什么业务需要用到呢&#xff1f;接下来带大家具体了解一下。 什…

React Antd可编辑单元格,非官网写法,不使用可编辑行和form验证

antd3以上的写法乍一看还挺复杂&#xff0c;自己写了个精简版 没用EditableRowCell的结构&#xff0c;也不使用Context、高阶组件等&#xff0c;不使用form验证 最终效果&#xff1a; class EditableCell extends React.Component {state {editing: false};toggleEdit () &…

SFUD固件移植

SFUD作用 SFUD 是一款开源的串行 SPI Flash 通用驱动库。由于现有市面的串行 Flash 种类居多&#xff0c;各个 Flash 的规格及命令存在差异&#xff0c; SFUD 就是为了解决这些 Flash 的差异现状而设计&#xff0c;让我们的产品能够支持不同品牌及规格的 Flash&#xff0c;提高…

Android 修改代码后不生效问题的终极方案

前言&#xff1a; 最近遇到几个项目&#xff0c;都出现了代码修改后&#xff0c;直接点studio上的run&#xff0c;跑起来后代码没生效&#xff0c;如果重新clean rebuild可以生效&#xff0c;但是这太浪费时间了。网上找了各种方案&#xff0c;前面几个项目&#xff0c;有的是可…

手写Spring:第19章-JDBC功能整合

文章目录 一、目标&#xff1a;JDBC功能整合二、设计&#xff1a;JDBC功能整合三、实现&#xff1a;JDBC功能整合3.1 工程结构3.2 整合JDBC功能核心类图3.3 数据源操作3.3.1 数据源操作抽象类3.3.2 JDBC 工具类 3.4 数据库执行3.4.1 语句处理器接口3.4.2 结果处理器接口3.4.3 行…

嵌入式Linux驱动开发(LCD屏幕专题)(四)

单Buffer的缺点与改进方法 1. 单Buffer的缺点 如果APP速度很慢&#xff0c;可以看到它在LCD上缓慢绘制图案 即使APP速度很高&#xff0c;LCD控制器不断从Framebuffer中读取数据来显示&#xff0c;而APP不断把数据写入Framebuffer 假设APP想把LCD显示为整屏幕的蓝色、红色 很…

线程池的实现

目录 一、线程池的实现 1.什么是线程池 2.设计线程类 3.设计线程池类 4.运行 5.RAII加锁改造 二、利用单例模式改造线程池 1.复习 2.饿汉模式 3.懒汉模式 关于系统编程的知识我们已经学完了&#xff0c;最后我们需要利用之前写过的代码实现一个线程池&#xff0c;彻底…

如何理解张量、张量索引、切片、张量维度变换

Tensor 张量 Tensor&#xff0c;中文翻译“张量”&#xff0c;是一种特殊的数据结构&#xff0c;与数组和矩阵非常相似。在 PyTorch 中&#xff0c;使用张量对模型的输入和输出以及模型的参数进行编码。 Tensor 是一个 Python Class。PyTorch 官方文档中定义“Tensor&#xff0…

Datawhale × 和鲸科技丨《2023 中国人工智能人才学习白皮书》发布!

2023 是一个历史性的年份&#xff0c;它标志着人工智能技术的崛起与普及&#xff0c;这一年里&#xff0c;AI 不仅在科技、经济、社会、文化等各个领域取得突破性的进展&#xff0c;也在人类日常生活中扮演愈加重要的角色。随着人工智能时代的加速到来&#xff0c;我国 AI 人才…

msvcp140.dll丢失的有哪些解决方法,丢失msvcp140.dll是什么意思

在我们的日常生活中&#xff0c;电脑问题是无处不在的&#xff0c;而msvcp140.dll丢失又是其中比较常见的一种。msvcp140.dll是Microsoft Visual C运行时库的一部分&#xff0c;它包含了一些重要的函数和资源。当这个文件丢失时&#xff0c;可能会导致电脑出现各种问题&#xf…

链路追踪Skywalking快速入门

目录 1 Skywalking概述1.1 微服务系统监控三要素1.2 什么是链路追踪1.2.1 链路追踪1.2.2 OpenTracing1、数据模型&#xff1a;2、核心接口语义 1.3 常见APM系统1.4 Skywalking介绍1、SkyWalking 核心功能&#xff1a;2、SkyWalking 特点&#xff1a;3、Skywalking架构图&#x…

mysql之DML的select分组排序

目录 一、创建表employee和department表 1.创建department表 2.创建employee表 3.给employee表格和department表格建立外键 4.给department插入数据 5.给employee表插入数据 6.删除名字为那个的数据 二、分组查询和排序查询&#xff0c;以及对数据的处理&#xff08;av…

ARM/X86工业级数据采集 (DAQ) 与控制产品解决方案

I/O设备&#xff0c;包括信号调理模块、嵌入式PCI/PCIE卡、便携式USB模块、DAQ嵌入式计算机、模块化DAQ系统&#xff0c;以及DAQNavi/SDK软件开发包和DAQNavi/MCM设备状态监测软件。 工业I/O产品适用于各种工业自动化应用&#xff0c;从机器自动化控制、测试测量到设备状态监测…

Java“牵手”京东商品详情数据,京东商品详情API接口,京东API接口申请指南

京东平台商品详情接口是开放平台提供的一种API接口&#xff0c;通过调用API接口&#xff0c;开发者可以获取京东商品的标题、价格、库存、月销量、总销量、库存、详情描述、图片等详细信息 。 获取商品详情接口API是一种用于获取电商平台上商品详情数据的接口&#xff0c;通过…

kuiper安装

1:使用docker方式安装 docker pull lfedge/ekuiper:latest docker run -p 9081:9081 -d --name kuiper -e MQTT_SOURCE__DEFAULT__SERVERtcp://127.0.0.1:1883 lfedge/ekuiper:latest这样就安装好了&#xff0c;但是操作只能通过命令完成&#xff0c;如果想要通过页面来操作&…

@DS注解方式springboot多数据源配置及失效场景解决

1.使用教程 导入依赖 <!--多数据源--><dependency><groupId>com.baomidou</groupId><artifactId>dynamic-datasource-spring-boot-starter</artifactId><version>3.5.0</version></dependency>配置数据源 datasource:…

MT36291 2.5A,高效型1.2MHz电流模式升压转换器芯片

MT36291 2.5A&#xff0c;高效型1.2MHz电流模式升压转换器芯片 特征 ●集成了80ms功率的MOSFET ●2.2V到16V的输入电压 ●1.2MHz固定开关频率 ●可调过电流保护&#xff1a; 0.5A ~2.5A ●内部2.5开关限流&#xff08;OC引脚浮动&#xff09; ●可调输出电压 ●内部补偿 ●过电…