使用TensorRT对YOLOv8模型进行加速推理

news2025/1/23 21:24:10

      这里使用GitHub上shouxieai的 infer框架 对YOLOv8模型进行加速推理,操作过程如下所示:

      1.配置环境,依赖项,包括:

      (1).CUDA: 11.8

      (2).cuDNN: 8.7.0

      (3).TensorRT: 8.5.3.1

      (4).ONNX: 1.16.0

      (5).OpenCV: 4.10.0

      2.clone infer代码:https://github.com/shouxieai/infer

      3.使用 https://blog.csdn.net/fengbingchun/article/details/140691177 中采用的数据集生成best.onnx,训练代码如下所示:

import argparse
import colorama
from ultralytics import YOLO
import torch

def parse_args():
	parser = argparse.ArgumentParser(description="YOLOv8 train")
	parser.add_argument("--yaml", required=True, type=str, help="yaml file")
	parser.add_argument("--epochs", required=True, type=int, help="number of training")
	parser.add_argument("--task", required=True, type=str, choices=["detect", "segment"], help="specify what kind of task")

	args = parser.parse_args()
	return args

def train(task, yaml, epochs):
	if task == "detect":
		model = YOLO("yolov8n.pt") # load a pretrained model
	elif task == "segment":
		model = YOLO("yolov8n-seg.pt") # load a pretrained model
	else:
		print(colorama.Fore.RED + "Error: unsupported task:", task)
		raise

	results = model.train(data=yaml, epochs=epochs, imgsz=640) # train the model

	metrics = model.val() # It'll automatically evaluate the data you trained, no arguments needed, dataset and settings remembered

	# model.export(format="onnx") #, dynamic=True) # export the model, cannot specify dynamic=True, opencv does not support
	model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)
	model.export(format="torchscript") # libtorch
	model.export(format="engine", imgsz=640, dynamic=False, verbose=False, batch=1, workspace=2) # tensorrt fp32
	# model.export(format="engine", imgsz=640, dynamic=True, verbose=True, batch=4, workspace=2, half=True) # tensorrt fp16
	# model.export(format="engine", imgsz=640, dynamic=True, verbose=True, batch=4, workspace=2, int8=True, data=yaml) # tensorrt int8

if __name__ == "__main__":
	# python test_yolov8_train.py --yaml datasets/melon_new_detect/melon_new_detect.yaml --epochs 1000 --task detect
	colorama.init()
	args = parse_args()

	if torch.cuda.is_available():
		print("Runging on GPU")
	else:
		print("Runting on CPU")

	train(args.task, args.yaml, args.epochs)

	print(colorama.Fore.GREEN + "====== execution completed ======")

      4.将best.onnx文件通过infer中的v8trans.py转换为best.transd.onnx,执行如下命令:增加Transpose层,YOLOv5不需要

python v8trans.py best.onnx

      注:yolov8 onnx的输出为NHW,而inter框架的输出只支持NWH,因此需要在原始onnx的输出之前添加一个Transpose节点

      5.从 https://docs.nvidia.com/deeplearning/cudnn/archives/cudnn-870/install-guide/index.html#install-zlib-windows 下载zlib123dllx64.zip,解压缩将其中的zlibwapi.dll拷贝到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin目录下

      6.通过TensorRT中的trtexec.exe将best.transd.onnx转换为best.transd.engine,分别执行如下命令:

trtexec.exe --onnx=best.transd.onnx --saveEngine=best.transd.fp32.engine
trtexec.exe --onnx=best.transd.onnx --fp16 --saveEngine=best.transd.fp16.engine
trtexec.exe --onnx=best.transd.onnx --int8 --saveEngine=best.transd.int8.engine

      :产生的best.transd.fp32.engine和best.transd.fp16.engine大小类似,推理耗时和准确度也类似;best.transd.int8.engine大小约是best.transd.fp32.engine的四分之一,推理耗时也小,但是准确度非常低

      7.测试代码TensorRT_infer.cpp如下:工程见:TensorRT_Infer

#include <iostream>
#include <filesystem>
#include <vector>
#include <fstream>
#include <sstream>
#include <random>
#include <map>
#include <memory>
#include <chrono>
#include <string>
#include <algorithm>

#include <opencv2/opencv.hpp>
#include "yolo.hpp"

namespace {

constexpr float confidence_threshold{ 0.45f }; // confidence threshold
constexpr float nms_threshold{ 0.50f }; // nms threshold
constexpr char* engine_file{ "../../../data/best.transd.fp32.engine" };
constexpr char* images_dir{ "../../../data/images/predict" };
constexpr char* result_dir{ "../../../data/result" };
constexpr char* classes_file{ "../../../data/images/labels.txt" };

std::vector<std::string> parse_classes_file(const char* name)
{
	std::vector<std::string> classes;

	std::ifstream file(name);
	if (!file.is_open()) {
		std::cerr << "Error: fail to open classes file: " << name << std::endl;
		return classes;
	}

	std::string line;
	while (std::getline(file, line)) {
		auto pos = line.find_first_of(" ");
		classes.emplace_back(line.substr(0, pos));
	}

	file.close();
	return classes;
}

auto get_dir_images(const char* name)
{
	std::map<std::string, std::string> images; // image name, image path + image name

	for (auto const& dir_entry : std::filesystem::directory_iterator(name)) {
		if (dir_entry.is_regular_file())
			images[dir_entry.path().filename().string()] = dir_entry.path().string();
	}

	return images;
}

auto get_random_color(int labels_number)
{
	std::random_device rd;
	std::mt19937 gen(rd());
	std::uniform_int_distribution<int> dis(100, 255);

	std::vector<cv::Scalar> colors;

	for (auto i = 0; i < labels_number; ++i) {
		colors.emplace_back(cv::Scalar(dis(gen), dis(gen), dis(gen)));
	}

	return colors;
}

} // namespace

int main()
{
    namespace fs = std::filesystem;

    if (!fs::exists(result_dir)) {
        fs::create_directories(result_dir);
    }

    auto classes = parse_classes_file(classes_file);
    if (classes.size() == 0) {
        std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;
        return -1;
    }

	std::cout << "classes: ";
	for (const auto& val : classes) {
		std::cout << val << " ";
	}
	std::cout << std::endl;

	auto colors = get_random_color(classes.size());

	auto model = yolo::load(engine_file, yolo::Type::V8, confidence_threshold, nms_threshold);

	for (auto i = 0; i < 10; ++i) {
		std::cout << "i: " << i << std::endl;
		for (const auto& [key, val] : get_dir_images(images_dir)) {
			cv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);
			if (frame.empty()) {
				std::cerr << "Warning: unable to load image: " << val << std::endl;
				continue;
			}

			auto tstart = std::chrono::high_resolution_clock::now();
			auto objs = model->forward(yolo::Image(frame.data, frame.cols, frame.rows));
			auto tend = std::chrono::high_resolution_clock::now();
			std::cout << "elapsed millisenconds: " << std::chrono::duration_cast<std::chrono::milliseconds>(tend - tstart).count() << " ms" << std::endl;

			for (const auto& obj : objs) {
				cv::rectangle(frame, cv::Point(obj.left, obj.top), cv::Point(obj.right, obj.bottom), colors[obj.class_label], 2);

				std::string class_string = classes[obj.class_label] + ' ' + std::to_string(obj.confidence).substr(0, 4);
				cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
				cv::Rect text_box(obj.left, obj.top - 40, text_size.width + 10, text_size.height + 20);

				cv::rectangle(frame, text_box, colors[obj.class_label], cv::FILLED);
				cv::putText(frame, class_string, cv::Point(obj.left + 5, obj.top - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);
			}

			std::string path(result_dir);
			path += "/" + key;
			cv::imwrite(path, frame);
		}
	}

	std::cout << "test finish" << std::endl;
    return 0;
}

      执行结果如下图所示:

      检测结果如下图所示:

      trtexec.exe是一个快速使用TensorRT的工具,无需开发自己的应用程序。此工具有三个主要用途:

      (1).根据随机或用户提供的输入数据对网络进行基准测试。

      (2).从模型生成序列化引擎(engine)。

      (3).从构建器生成序列化时序缓存(serialized timing cache)。

      trtexec.exe常用flags说明:

      1.构建阶段flags

      (1).--onnx=<model>:指定输入ONNX模型。如果输入模型为ONNX格式,使用--minShapes、--optShapes和--maxShapes标志来控制输入shapes的范围(包括batch大小)。

      (2).--minShapes=<shapes>, --optShapes=<shapes>, and --maxShapes=<shapes>:指定用于构建engine的输入shapes的范围。仅当输入模型为ONNX格式时才需要。

      (3).–-memPoolSize=<pool_spec>:指定策略允许使用的workspace的最大大小。

      (4).--saveEngine=<file>:指定保存engine的路径。

      (5).--fp16, --bf16, --int8, --fp8, --noTF32, and --best:指定network-level精度。

      (6).--stronglyTyped:创建strongly typed网络。

      (7).--sparsity=[disable|enable|force]:指定是否使用支持结构化稀疏性(structured sparsity)的策略。

      (8).--noCompilationCache:禁用构建中的编译缓存(默认是启用编译缓存)。

      (9).--verbose:开启详细日志。

      (10).--skipInference:构建并保存engine而不运行推理。

      (11).--dumpLayerInfo, --exportLayerInfo=<file>:打印/保存engine的layer信息。

      (12).--precisionConstraints=spec:控制精度约束设置。指定的值可为:none、prefer、obey。

      (13).--layerPrecisions=spec:控制每层精度约束。仅当precisionConstraints设置为obey或prefer时才有效。规范从左到右读取,后面的会覆盖前面的。"*"可用作layerName,以指定所有未指定层的默认精度。

      如:--layerPrecisions=*:fp16,layer_1:fp32 将除layer_1之外的所有层的精度设置为FP16,而layer_1的精度将设置为FP32。

      (14).--layerOutputTypes=spec:控制每层输出类型约束。仅当precisionConstraints设置为obey或prefer时才有效。规范从左到右读取,后面的会覆盖前面的。"*"可用作layerName,以指定所有未指定层的默认精度。

      (15).--versionCompatible, --vc:为engine构建和推理启用版本兼容模式。

      (16).--tempdir=<dir>:覆盖TensorRT在创建临时文件时将使用的默认临时目录。

      2.推理阶段flags

      (1).--loadEngine=<file>:从序列化计划文件加载engine,而不是从输入ONNX模型构建它。如果输入模型是ONNX格式或者engine是使用明确的batch dimension构建的,则改用--shapes。

      (2).--shapes=<shapes>:指定用于运行推理的输入shapes。

      (3).--loadInputs=<specs>:从文件加载输入值。默认生成随机输入。

      (4).--noDataTransfers:关闭host to device和device to host的数据传输。

      (5).--verbose:开启详细日志。

      (6).--dumpProfile, --exportProfile=<file>:打印/保存每层性能概况。

      (7).--dumpLayerInfo, --exportLayerInfo=<file>:打印engine的层信息。

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1965382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

redis:Linux安装redis,redis常用的数据类型及相关命令

1. 什么是NoSQL nosql[not only sql]不仅仅是sql。所有非关系型数据库的统称。除去关系型数据库之外的都是非关系数据库。 1.1为什么使用NoSQL ​ NoSQL数据库相较于传统关系型数据库具有灵活性、可扩展性和高性能等优势&#xff0c;适合处理非结构化和半结构化数据&#xff0c…

服务运营|摘要:INFORMS 近期收益管理(Revenue Management )相关文章

编者按&#xff1a; 本期涵盖了INFORMS与收益管理相关的文章及其基本信息。 Title: Online Learning for Constrained Assortment Optimization Under Markov Chain Choice Model 基于马尔可夫链选择模型的约束下选品优化的在线学习 Link: https://pubsonline.informs.org/do…

召唤生命,阻止轻生——《生命门外》

本书的目的&#xff0c;就是阻止自杀&#xff01;拉回那些深陷在这样的思维当中正在挣扎犹豫的人&#xff0c;提醒他们珍爱生命&#xff0c;让更多的人&#xff0c;尤其是年轻人从执迷不悟的犹豫徘徊中幡然醒悟&#xff0c;回归正常的生活。 网络上抱孩子跳桥轻生的母亲&#…

Linux中gdb调试器的使用

Linux调试器&#xff1a;gdb gdb简介基本使用和常见的指令断点相关运行相关命令 gdb简介 我们都知道一个程序一般有两个版本分别是debug&#xff0c;和release版本&#xff0c;后者就是发布给用户的版本&#xff0c;而前者就是我们程序员用来调试用的版本。 他们有什么区别呢&…

Docker搭建Mysql主从复制,最新,最详细

Docker搭建Mysql主从复制&#xff0c;最新&#xff0c;最详细 这次搭建Mysql主从复制的时候&#xff0c;遇到不少问题&#xff0c;所以本次重新记录一下&#xff0c;使用Docker搭建一主三从的Mysql 一、Docker-Compose创建4个Mysql容器 1.1 创建对应的映射文件夹和对应的配置…

GitLab的安装步骤与代码拉取上传操作

一、GitLab的安装 详情见如下博客链接&#xff1a;gitlab安装 二、GitLab配置ssh key &#xff08;1&#xff09;打开Git Bash终端生成SSH和添加步骤 1、全局配置git用户名 git config --global user.name "xxx"注意&#xff1a;xxx为你自己gitlab的名字 2、全局…

JavaScript递归菜单栏

HTML就一个div大框架 <div class"treemenu"></div> 重中之重的JavaScript部分他来啦&#xff01; 注释也很清楚哟家人们&#xff01; let data; let arr []; let cons;let xhr new XMLHttpRequest(); // 设置请求方式和请求地址 xhr.open(get, ./js…

Linux上如何分析进程内存分配,优化进程内存占用大小

云计算场景下,服务器上内存宝贵,只有尽可能让服务器上服务进程占用更少的内存,方才可以提供更多的内存给虚拟机,卖给云客户。 虚拟化三大件:libvirt、qemu、kvm内存开销不小,可以优化占用更少的内存。如何找到进程内存开销的地方直观重要,以qemu为例说明。 一、查看进…

别让不专业的HR逼走你的人才!人力资源管理应该遵循哪些原则?

优秀的HR能够带领整个人力资源部门为企业招揽人才、培养人才和留住人才&#xff0c;促使人才为企业的业务增长提供支持。而不专业的HR&#xff0c;不仅无法做到这些&#xff0c;还会把企业原有的人才逼走&#xff0c;因为不合适的人力管理也是导致人才离职的原因。所以&#xf…

【C++】前缀和算法专题

目录 介绍 【模版】一维前缀和 算法思路&#xff1a; 代码实现 【模版】二维前缀和 算法思路 代码实现 寻找数组中心的下标 算法思路 代码实现 总结 除自身以外数组的乘积 算法思路 代码实现 和为K的子数组 算法思路 代码实现 和可被整除的K的子数组 算法思…

C++ 操作Git仓库

代码 #include "common.h" #include "args.c" #include "common.c"enum index_mode {INDEX_NONE,INDEX_ADD };struct index_options {int dry_run;int verbose;git_repository* repo;enum index_mode mode;int add_update; };/* Forward declar…

Python零基础详细入门教程

Python零基础详细入门教程可以从以下几个方面展开&#xff0c;帮助初学者系统地学习Python编程&#xff1a; 一、Python基础入门 1. Python简介 Python的由来与发展&#xff1a;Python是一种广泛使用的高级编程语言&#xff0c;以其简洁的语法和强大的功能而受到开发者的喜爱…

2024第二十届中国国际粮油产品及设备技术展示交易会

2024第二十届中国国际粮油产品及设备技术展示交易会 时间&#xff1a;2024年11月15-17日 地点&#xff1a; 南昌绿地国际博览中心 展会介绍&#xff1a; 随着国家逐年加大对农业的投入&#xff0c;调整农业产业结构&#xff0c;提高农产品附加值&#xff0c;促进农民增收。…

CRMEB-众邦科技 使用笔记

1.启动项目报错 Unable to load authentication plugin ‘caching_sha2_password’. 参考&#xff1a;http://t.csdnimg.cn/5EqaE 解决办法&#xff1a;升级mysql驱动 <dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</ar…

超级弱口令检查工具

一、背景 弱口令问题主要源于用户和管理员的安全意识不足&#xff0c;以及为了方便记忆而采用简单易记的密码。这些密码往往仅包含简单的数字和字母&#xff0c;缺乏复杂性和多样性&#xff0c;因此极易被破解。弱口令的存在严重威胁到系统和用户的数据安全&#xff0c;使得攻击…

在局域网中的另一台主机如何访问windows10WSL中的服务

文章目录 1&#xff0c;开启win10 路由功能2&#xff0c;配置转发规则 1&#xff0c;开启win10 路由功能 2&#xff0c;配置转发规则 netsh advfirewall firewall add rule name"Allowing LAN connections" dirin actionallow protocolTCP localport80 netsh interf…

计算机体系结构:缓存一致性ESI

集中式缓存处理器结构&#xff08;SMP&#xff09; 不同核访问存储器时间相同。 分布式缓存处理器结构&#xff08;NUMA&#xff09; 共享存储器按模块分散在各处理器附近&#xff0c;处理器访问本地存储器和远程存储器的延迟不同&#xff0c;共享数据可进入处理器私有高速缓存…

程序员自曝接单:三年时间接了25个单子,收入12万

程序员接单在程序员的副业中并不少见。程序员接单作为一个起步快、门槛低、类型多样的副业选择&#xff0c;一直深受程序员的青睐。就算你没有接触过接单&#xff0c;也一定对接单有过了解。 程序员接单是指程序员通过接取开发者发布的项目或任务来获取收入的一种工作方式。程序…

“八股文”的江湖:助力、阻力还是空谈?深度解析程序员面试的敲门砖

一、引言&#xff1a;八股文的江湖——助力、阻力还是空谈&#xff1f; 1.1 八股文的定义与背景 八股文&#xff0c;原指我国明清时期科举考试的一种应试文体&#xff0c;因其固定模式和空洞内容而备受诟病。在当今的程序员面试中&#xff0c;程序员的“八股文”通常指的是在技…

告别手动操作:这个微信自动化工具你一定要试试!

随着科技的发展&#xff0c;越来越多的自动化工具应运而生&#xff0c;帮助我们轻松管理微信号。 今天&#xff0c;就给大家揭开这个能让微信自动化的工具的神秘面纱&#xff0c;看看它们能为我们的工作带来哪些便利。 1、批量自动加好友 通过个微管理系统&#xff0c;你可以…