TensorRT部署模型基本步骤(C++)

news2024/11/19 5:44:51

TensorRT部署模型基本步骤(C++)


文章目录

  • TensorRT部署模型基本步骤(C++)
  • 前言
  • 一、onnx模型转engine
    • 1.基于C++代码生成engine
    • 2.基于trtexec.exe命令行生成
  • 二、读取本地模型
  • 三、创建推理引擎
  • 四、创建推理上下文
  • 五、创建GPU显存缓冲区
  • 六、 配置输入数据
  • 七、模型推理
  • 八、获得输出数据
  • 总结


前言

经典的一个TensorRT部署模型步骤为:onnx模型转engine读取本地模型创建推理引擎创建推理上下文创建GPU显存缓冲区配置输入数据模型推理以及处理推理结果(后处理)


一、onnx模型转engine

目前多种模型框架都将onnx模型当作中间转换格式,是的该模型结构变得越来越通用,因此TensorRT目前主要在更新的就是针对该模型的转换。TensorRT是可以直接读取engine文件(缺点是:读取engine要求是相同TensorRT版本和相同的平台),对于onnx模型需要进行一些列转换配置,转为engine引擎才可以进行后续的推理,因此在进行模型推理前,需要先进行模型的转换。

1.基于C++代码生成engine

void onnx_to_engine(std::string onnx_file_path, std::string engine_file_path, int type) {

	// 构建器,获取cuda内核目录以获取最快的实现
	// 用于创建config、network、engine的其他对象的核心类
	nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(gLogger);
	const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
	// 解析onnx网络文件
	// tensorRT模型类
	nvinfer1::INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
	// onnx文件解析类
	// 将onnx文件解析,并填充rensorRT网络结构
	nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);
	// 解析onnx文件
	parser->parseFromFile(onnx_file_path.c_str(), 2);
	for (int i = 0; i < parser->getNbErrors(); ++i) {
		std::cout << "load error: " << parser->getError(i)->desc() << std::endl;
	}
	printf("tensorRT load mask onnx model successfully!!!...\n");

	// 创建推理引擎
	// 创建生成器配置对象。
	nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
	// 设置最大工作空间大小。
	config->setMaxWorkspaceSize(16 * (1 << 20));
	// 设置模型输出精度
	if (type == 1) {
		config->setFlag(nvinfer1::BuilderFlag::kFP16);
	}
	if (type == 2) {
		config->setFlag(nvinfer1::BuilderFlag::kINT8);
	}
	// 创建推理引擎
	nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
	// 将推理银枪保存到本地
	std::cout << "try to save engine file now~~~" << std::endl;
	std::ofstream file_ptr(engine_file_path, std::ios::binary);
	if (!file_ptr) {
		std::cerr << "could not open plan output file" << std::endl;
		return;
	}
	// 将模型转化为文件流数据
	nvinfer1::IHostMemory* model_stream = engine->serialize();
	// 将文件保存到本地
	file_ptr.write(reinterpret_cast<const char*>(model_stream->data()), model_stream->size());
	// 销毁创建的对象
	model_stream->destroy();
	engine->destroy();
	network->destroy();
	parser->destroy();
	std::cout << "convert onnx model to TensorRT engine model successfully!" << std::endl;
}

2.基于trtexec.exe命令行生成

图一
图二
备注:

  • –onnx=“filepath” filepath是onnx文件路径
  • –saveEngine=“filepath” filepath是保存engine文件路径

二、读取本地模型

读取onnx转换的engine二进制文件,将模型文件信息读取到内存中。由于engine保存了模型的信息以及电脑的(TensorRT)配置环境信息,所以如果要将模型部署在其他电脑上,要保证电脑的配置环境以及平台(windows/linux/macos)是否相同。

	std::string enginepath = "E:/TensorRT-8.6.0.12/bin/resnet18.engine";	//读取二进制文件engine的路径
	std::ifstream file(enginepath, std::ios::binary);		// 以二进制方式打开
	char* trtModelStream = NULL;							// 定义一个字符指针,用于读取engine文件数据
	int size = 0;											// 存储二进制文件字符的数量
	if (file.good()) {
		file.seekg(0, file.end);							//将文件指针移动到文件末尾
		size = file.tellg();								//获取当前文件指针的位置,即文件的大小
		file.seekg(0, file.beg);							//文件指针移回文件开始处
		trtModelStream = new char[size];					//分配足够的内存储存文件内容
		assert(trtModelStream);								//检查内存是否分配成功
		file.read(trtModelStream, size);					//读取文件信息,并存储在trtModelStream
		file.close();										//关闭文件
	}

三、创建推理引擎

首先需要初始化日志记录接口类,该类用于创建后续反序列化引擎使用;然后创建反序列化引擎,其主要作用是允许对序列化的功能上不安全的引擎进行反序列化,接下调用反序列化引擎来创建推理引擎,这一步只需要输入上一步中读取的模型文件数据以及长度即可。

// 日志记录接口
Logger logger;
// 反序列化引擎
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
assert(runtime != nullptr);
// 推理引擎
nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(model_stream, size);
assert(engine != nullptr);

四、创建推理上下文

创建可执行的IExecutionContent实例 - createExecutionContext,为后面进行模型推理的类。

nvinfer1::IExecutionContext* context = engine->createExecutionContext();
assert(context != nullptr);
delete[] trtModelStream ;		// 释放内存

五、创建GPU显存缓冲区

TensorRT是利用英伟达显卡(GPU)进行模型推理的,但是我们的推理数据以及后续处理数据是在内存(CPU)中实现的,因此需要创建显存缓冲区,用于输入推理数据以及读取推理结果数据。

// 创建GPU显存缓冲区
void** data_buffer = new void* [num_ionode];
// 创建GPU显存输入缓冲区
// getBindingDimensions函数获取指定索引的绑定维度信息,然后从中提取高度和宽度的值
int input_node_index = engine->getBindingIndex(input_node_name);
cudaMalloc(&(data_buffer[input_node_index]), input_data_length * sizeof(float));
// 创建GPU显存输出缓冲区
int output_node_index = engine->getBindingIndex(output_node_name);
cudaMalloc(&(data_buffer[output_node_index]), output_data_length * sizeof(float));

六、 配置输入数据

配置输入数据时只需要调用cudaMemcpyAsync()方法,便可将cuda流数据加载到与i里模型上。但数据需要根据模型要求进行预处理,除此以外需要将数据结果加入到cuda流中。

// 创建输入cuda流
cudaStream_t stream;
cudaStreamCreate(&stream);
/*
*******输入图像预处理*******
*/
// HWC => CHW,转换成张量格式
cv::Mat tensor = ::dnn::blobFromImage(blob);
// 输入数据由内存到GPU显存
cudaMemcpyAsync(data_buffer[input_node_index], tensor.ptr<float>(), input_data_length * sizeof(float), cudaMemcpyHostToDevice, stream);

七、模型推理

context->enqueueV2(data_buffer, stream, nullptr);

八、获得输出数据

最后处理数据是在内存上实现的,首先需要将数据由显存读取到内存中。

std::vector<float> prob;
// 创建临时缓存输出
prob.resize(output_h * output_w);
// GPU显存到内存
cudaMemcpyAsync(prob.data(),  data_buffer[output_node_index], output_data_length * sizeof(float), cudaMemcpyDeviceToHost, stream);
/*
*******输出图像后处理*******
*/

总结

本文主要介绍了TensorRT+C++部署的基本步骤,欢迎阅读交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1710280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Offline RL : Context-Former: Stitching via Latent Conditioned Sequence Modeling

paper 基于HIM的离线RL算法&#xff0c;解决基于序列模型的离线强化学习算法缺乏对序列拼接能力。 Intro 文章提出了ContextFormer&#xff0c;旨在解决决策变换器&#xff08;Decision Transformer, DT&#xff09;在轨迹拼接&#xff08;stitching&#xff09;能力上的不足…

windows11如何安装IIS

目录 IIS是什么&#xff1f; 为什么要配置IIS&#xff1f; 1.打开控制面板进入程序 2.点击启用或者关闭windos功能 3.勾选IIS相关的web项 4.点击确定等待一分钟程序变更即可 5.主页搜索internet 点击进入 6.进入IIS进行查看配置&#xff0c;并测试&#xff0c;也可以浏…

43、Flink 的 Window Join 详解

1.Window Join a&#xff09;概述 Window join 作用在两个流中有相同 key 且处于相同窗口的元素上&#xff0c;窗口可以通过 window assigner 定义&#xff0c;并且两个流中的元素都会被用于计算窗口的结果。 两个流中的元素在组合之后&#xff0c;会被传递给用户定义的 Joi…

stream-实践应用-统计分析

背景 业务部门提供了一个数据&#xff0c;数据甚至不是excel类型的&#xff0c;是data.txt&#xff0c;每一行都是一个数据&#xff0c;需要对此数据进行统计分析 统计各个月份的销量 因为直接获取resources下的data.txt&#xff0c;所以要借助输入流进行获取数据&#xff0c;再…

sqli-labs---第三关

1、判断什么类型注入 ?id1 正常显示 ?id1 &#xff08;报错&#xff1a;1) LIMIT 0,1&#xff09; ?id1 正常显示 ?id1#(报错&#xff1a;1) LIMIT 0,1) 可知闭合方式为) 2、查看列数 ?id1) order by 3 -- (没有报错) ?id1) order by 4 -- (报错) 说明有3列 3、使用联合查…

Scrapy框架简单介绍及Scrapy项目编写详细步骤(Scrapy框架爬取豆瓣网站示例)

引言 Scrapy是一个用Python编写的开源、功能强大的网络爬虫框架&#xff0c;专为网页抓取和数据提取设计。它允许开发者高效地从网站上抓取所需的数据&#xff0c;并通过一系列可扩展和可配置的组件来处理这些数据。Scrapy框架的核心组成部分包括&#xff1a; Scrapy Engine&…

window本地部署Dify

Dify与之前的MaxKB不同&#xff0c;MaxKB可以实现基础的问答以及知识库功能&#xff0c;但是如果要开发一个Agent&#xff0c;或者工作流就还是需要额外开发&#xff0c;而Dify 是一个开源 LLM 应用开发平台。其直观的界面结合了 AI 工作流、RAG 管道、代理功能、模型管理、可观…

python制作一个批量更新文件名称的工具

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一.前言 二.实现 三.使用效果 一.前言 随着数字化时代的到来&#xff0c;文件管理和处理变…

14.微信小程序之地理定位功能

目录 1.地理定位介绍 1.1 申请开通 1.2 使用方法 2.拒绝授权后的解决方案 3.开通腾讯位置服务 4.LBS 逆地址解析 1.地理定位介绍 小程序地理定位是指通过小程序开发平台提供的 API&#xff0c;来获取用户的地理位置信息。用户在使用小程序时&#xff0c;可以授权小程序获…

【LabVIEW FPGA入门】同步C系列模块

1.同步使用循环定时器VI计时循环速率的系列模块 数字模块SAR ADC 模块多路复用模块 数字通道可以在一个时钟周期内执行。模拟通道需要多个时钟周期。 同步模拟模块的每个通道有一个 ADC&#xff0c;采集的数据在通道之间没有明显的偏差。多路复用模块使用多路复用器通过单个 A…

解决:error: failed to push some refs to ‘https://gitee.com/***/***.git‘(高效快速)

解决方案&#xff1a; git pull --rebase origin master 具体原因&#xff1a; 主要原因是gitee(github)中的README.md文件不在本地代码目录中 要执行git pull --rebase origin master命令将README.md拉到本地 然后就可以执行git push啦 写在最后&#xff1a; 要是问题得到…

MySQL简单测试和安装

MySQL 的特点 1、MySQL 性能卓越、服务稳定&#xff0c;很少出现异常宕机。 2、MySQL开放源代码且无版权制约&#xff0c;自主性及使用成本低。 3、MySQL历史悠久(版本众多)&#xff0c;用户使用活跃&#xff0c;遇到问题可以寻求帮助。 4、MySQL体积小(相对大型关系型数据库)…

你还不知道宠物空气净化器的五大好处?难怪家里总有异味和猫毛!

养猫是一件非常令人愉快的事情&#xff0c;猫咪的陪伴能带给我们无尽的欢乐。然而&#xff0c;随着时间的推移&#xff0c;许多养猫的朋友会发现一个问题&#xff0c;那就是家中的猫毛和异味问题。其实&#xff0c;解决这些问题的关键就在于选择一款高效的宠物空气净化器。今天…

嵌入式学习——3——多点通信

1、套接字选项&#xff08;socket options&#xff09; int getsockopt(int sockfd, int level, int optname, void *optval, socklen_t *optlen); int setsockopt(int sockfd, int level, int optname, const void *optval, socklen_t optlen); 功能&#xff1a;获取或设置套接…

二叉树——基础知识详解

前言&#xff1a; 经过前面的学习&#xff0c;我们接下来要开始二叉树的学习&#xff0c;因二叉树有难度&#xff0c;为了方便讲解以及各位的理解&#xff0c;本节知识会分成不同的小节进行学习&#xff0c;在本阶段只学习初阶的二叉树&#xff08;堆&#xff0c;二叉数基本知识…

多分支拓扑阻抗匹配

最近测试信号质量&#xff0c;发现在有过冲、振铃等问题的时候大部分硬件工程师喜欢直接调大匹配电阻或者减小驱动电流&#xff0c;虽然这种操作是有效果的&#xff0c;但是我认为应该还可以更严谨的计算下&#xff0c;而不是选几个电阻多次尝试&#xff0c;显得不是很专业。 …

SOLIDWORKS正版一年多少钱 2024版报价

SOLIDWORKS软件作为一款优秀的三维设计工具&#xff0c;以其强大的功能和优质的设计工具&#xff0c;为设计师们提供了前所未有的便利。SOLIDWORKS三维设计软件是一款多科学集成软件&#xff0c;它在产品开发和制造方面发挥着重要作用。 作为整个SOLIDWORKS产品开发解决方案套件…

boost asio异步服务器(2)实现伪闭包延长连接生命周期

闭包 在函数内部实现一个子函数&#xff0c;子函数的作用域内能访问外部函数的局部变量。闭包就是能够读取其他函数内部变量。但是由于闭包会使得函数中的变量都被保存在内存中&#xff0c;内存消耗很大&#xff0c;所以不能滥用闭包&#xff0c;否则会造成程的性能问题&#x…

Discourse 使用 DiscourseConnect 来进行用户数据同步

我们都知道 Discourse 的用户管理和设置都高度依赖电子邮件。 如果 Discourse 没有设置电子邮件 SMTP 的话&#xff0c;作为管理员是没有办法对用户邮箱进行修改并且通过验证的。 可以采取的办法是通过 Discourse 的 DiscourseConnect 来进行用户同步。 根据官方的说法&…

Golang原生http实现中间件

Golang原生http实现中间件 中间件&#xff08;middleware&#xff09;&#xff1a;常被用来做认证校验、审计等 大家常用的Iris、Gin等web框架&#xff0c;都包含了中间件逻辑。但有时我们引入该框架显得较为繁重&#xff0c;本文将介绍通过golang原生http来实现中间件操作。全…