使用onnxruntime加载YOLOv8生成的onnx文件进行目标检测

news2024/10/5 15:34:18

      在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 LabelMe  工具进行标注,然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容如下:

path: ../datasets/melon # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val  # val images (relative to 'path')
test: # test images (optional)
 
# Classes
names:
  0: watermelon
  1: wintermelon

      使用以下python脚本进行训练生成onnx文件

import argparse
import colorama
from ultralytics import YOLO

def parse_args():
	parser = argparse.ArgumentParser(description="YOLOv8 train")
	parser.add_argument("--yaml", required=True, type=str, help="yaml file")
	parser.add_argument("--epochs", required=True, type=int, help="number of training")
	parser.add_argument("--task", required=True, type=str, choices=["detect", "segment"], help="specify what kind of task")

	args = parser.parse_args()
	return args

def train(task, yaml, epochs):
	if task == "detect":
		model = YOLO("yolov8n.pt") # load a pretrained model
	elif task == "segment":
		model = YOLO("yolov8n-seg.pt") # load a pretrained model
	else:
		print(colorama.Fore.RED + "Error: unsupported task:", task)
		raise

	results = model.train(data=yaml, epochs=epochs, imgsz=640) # train the model

	metrics = model.val() # It'll automatically evaluate the data you trained, no arguments needed, dataset and settings remembered

	model.export(format="onnx") #, dynamic=True) # export the model, cannot specify dynamic=True, opencv does not support
	# model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)
	model.export(format="torchscript") # libtorch

if __name__ == "__main__":
	colorama.init()
	args = parse_args()

	train(args.task, args.yaml, args.epochs)

	print(colorama.Fore.GREEN + "====== execution completed ======")

      以下是使用onnxruntime接口加载onnx文件进行目标检测的实现代码:

namespace {

constexpr bool cuda_enabled{ false };
constexpr int image_size[2]{ 640, 640 }; // {height,width}, input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 6, 8400)
constexpr float model_score_threshold{ 0.45 }; // confidence threshold
constexpr float model_nms_threshold{ 0.50 }; // iou threshold

#ifdef _MSC_VER
constexpr char* onnx_file{ "../../../data/best.onnx" };
constexpr char* torchscript_file{ "../../../data/best.torchscript" };
constexpr char* images_dir{ "../../../data/images/predict" };
constexpr char* result_dir{ "../../../data/result" };
constexpr char* classes_file{ "../../../data/images/labels.txt" };
#else
constexpr char* onnx_file{ "data/best.onnx" };
constexpr char* torchscript_file{ "data/best.torchscript" };
constexpr char* images_dir{ "data/images/predict" };
constexpr char* result_dir{ "data/result" };
constexpr char* classes_file{ "data/images/labels.txt" };
#endif

std::vector<std::string> parse_classes_file(const char* name)
{
	std::vector<std::string> classes;

	std::ifstream file(name);
	if (!file.is_open()) {
		std::cerr << "Error: fail to open classes file: " << name << std::endl;
		return classes;
	}
	
	std::string line;
	while (std::getline(file, line)) {
		auto pos = line.find_first_of(" ");
		classes.emplace_back(line.substr(0, pos));
	}

	file.close();
	return classes;
}

auto get_dir_images(const char* name)
{
	std::map<std::string, std::string> images; // image name, image path + image name

	for (auto const& dir_entry : std::filesystem::directory_iterator(name)) {
		if (dir_entry.is_regular_file())
			images[dir_entry.path().filename().string()] = dir_entry.path().string();
	}

	return images;
}

void draw_boxes(const std::vector<std::string>& classes, const std::vector<int>& ids, const std::vector<float>& confidences,
	const std::vector<cv::Rect>& boxes, const std::string& name, cv::Mat& frame)
{
	if (ids.size() != confidences.size() || ids.size() != boxes.size() || confidences.size() != boxes.size()) {
		std::cerr << "Error: their lengths are inconsistent: " << ids.size() << ", " << confidences.size() << ", " << boxes.size() << std::endl;
		return;
	}

	std::cout << "image name: " << name << ", number of detections: " << ids.size() << std::endl;

	std::random_device rd;
	std::mt19937 gen(rd());
	std::uniform_int_distribution<int> dis(100, 255);

	for (auto i = 0; i < ids.size(); ++i) {
		auto color = cv::Scalar(dis(gen), dis(gen), dis(gen));
		cv::rectangle(frame, boxes[i], color, 2);

		std::string class_string = classes[ids[i]] + ' ' + std::to_string(confidences[i]).substr(0, 4);
		cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
		cv::Rect text_box(boxes[i].x, boxes[i].y - 40, text_size.width + 10, text_size.height + 20);

		cv::rectangle(frame, text_box, color, cv::FILLED);
		cv::putText(frame, class_string, cv::Point(boxes[i].x + 5, boxes[i].y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);
	}

	//cv::imshow("Inference", frame);
	//cv::waitKey(-1);

	std::string path(result_dir);
	path += "/" + name;
	cv::imwrite(path, frame);
}

std::wstring ctow(const char* str)
{
	constexpr size_t len{ 128 };
	wchar_t wch[len];
	swprintf(wch, len, L"%hs", str);

	return std::wstring(wch);
}

float image_preprocess(const cv::Mat& src, cv::Mat& dst)
{
	cv::cvtColor(src, dst, cv::COLOR_BGR2RGB);
	float resize_scales{ 1. };

	if (src.cols >= src.rows) {
		resize_scales = src.cols * 1.f / image_size[1];
		cv::resize(dst, dst, cv::Size(image_size[1], static_cast<int>(src.rows / resize_scales)));
	} else {
		resize_scales = src.rows * 1.f / image_size[0];
		cv::resize(dst, dst, cv::Size(static_cast<int>(src.cols / resize_scales), image_size[0]));
	}

	cv::Mat tmp = cv::Mat::zeros(image_size[0], image_size[1], CV_8UC3);
	dst.copyTo(tmp(cv::Rect(0, 0, dst.cols, dst.rows)));
	dst = tmp;

	return resize_scales;
}

template<typename T>
void image_to_blob(const cv::Mat& src, T* blob)
{
	for (auto c = 0; c < 3; ++c) {
		for (auto h = 0; h < src.rows; ++h) {
			for (auto w = 0; w < src.cols; ++w) {
				blob[c * src.rows * src.cols + h * src.cols + w] = (src.at<cv::Vec3b>(h, w)[c]) / 255.f;
			}
		}
	}
}

void post_process(const float* data, int rows, int stride, float xfactor, float yfactor, const std::vector<std::string>& classes,
	cv::Mat& frame, const std::string& name)
{
	std::vector<int> class_ids;
	std::vector<float> confidences;
	std::vector<cv::Rect> boxes;

	for (auto i = 0; i < rows; ++i) {
		const float* classes_scores = data + 4;

		cv::Mat scores(1, classes.size(), CV_32FC1, (float*)classes_scores);
		cv::Point class_id;
		double max_class_score;

		cv::minMaxLoc(scores, 0, &max_class_score, 0, &class_id);

		if (max_class_score > model_score_threshold) {
			confidences.push_back(max_class_score);
			class_ids.push_back(class_id.x);

			float x = data[0];
			float y = data[1];
			float w = data[2];
			float h = data[3];

			int left = int((x - 0.5 * w) * xfactor);
			int top = int((y - 0.5 * h) * yfactor);

			int width = int(w * xfactor);
			int height = int(h * yfactor);

			boxes.push_back(cv::Rect(left, top, width, height));
		}

		data += stride;
	}

	std::vector<int> nms_result;
	cv::dnn::NMSBoxes(boxes, confidences, model_score_threshold, model_nms_threshold, nms_result);

	std::vector<int> ids;
	std::vector<float> confs;
	std::vector<cv::Rect> rects;
	for (size_t i = 0; i < nms_result.size(); ++i) {
		ids.emplace_back(class_ids[nms_result[i]]);
		confs.emplace_back(confidences[nms_result[i]]);
		rects.emplace_back(boxes[nms_result[i]]);
	}
	draw_boxes(classes, ids, confs, rects, name, frame);
}

} // namespace

int test_yolov8_detect_onnxruntime()
{
	// reference: ultralytics/examples/YOLOv8-ONNXRuntime-CPP
	try {
		Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "Yolo");
		Ort::SessionOptions session_option;

		if (cuda_enabled) {
			OrtCUDAProviderOptions cuda_option;
			cuda_option.device_id = 0;
			session_option.AppendExecutionProvider_CUDA(cuda_option);
		}

		session_option.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
		session_option.SetIntraOpNumThreads(1);
		session_option.SetLogSeverityLevel(3);

		Ort::Session session(env, ctow(onnx_file).c_str(), session_option);
		Ort::AllocatorWithDefaultOptions allocator;
		std::vector<const char*> input_node_names, output_node_names;
		std::vector<std::string> input_node_names_, output_node_names_;

		for (auto i = 0; i < session.GetInputCount(); ++i) {
			Ort::AllocatedStringPtr input_node_name = session.GetInputNameAllocated(i, allocator);
			input_node_names_.emplace_back(input_node_name.get());
		}

		for (auto i = 0; i < session.GetOutputCount(); ++i) {
			Ort::AllocatedStringPtr output_node_name = session.GetOutputNameAllocated(i, allocator);
			output_node_names_.emplace_back(output_node_name.get());
		}

		for (auto i = 0; i < input_node_names_.size(); ++i)
			input_node_names.emplace_back(input_node_names_[i].c_str());
		for (auto i = 0; i < output_node_names_.size(); ++i)
			output_node_names.emplace_back(output_node_names_[i].c_str());

		Ort::RunOptions options(nullptr);
		std::unique_ptr<float[]> blob(new float[image_size[0] * image_size[1] * 3]);
		std::vector<int64_t> input_node_dims{ 1, 3, image_size[1], image_size[0] };

		auto classes = parse_classes_file(classes_file);
		if (classes.size() == 0) {
			std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;
			return -1;
		}

		for (const auto& [key, val] : get_dir_images(images_dir)) {
			cv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);
			if (frame.empty()) {
				std::cerr << "Warning: unable to load image: " << val << std::endl;
				continue;
			}

			auto tstart = std::chrono::high_resolution_clock::now();
			cv::Mat rgb;
			auto resize_scales = image_preprocess(frame, rgb);
			image_to_blob(rgb, blob.get());
			Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
				Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob.get(), 3 * image_size[1] * image_size[0], input_node_dims.data(), input_node_dims.size());
			auto output_tensors = session.Run(options, input_node_names.data(), &input_tensor, 1, output_node_names.data(), output_node_names.size());

			Ort::TypeInfo type_info = output_tensors.front().GetTypeInfo();
			auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
			std::vector<int64_t> output_node_dims = tensor_info.GetShape();
			auto output = output_tensors.front().GetTensorMutableData<float>();
			int stride_num = output_node_dims[1];
			int signal_result_num = output_node_dims[2];
			cv::Mat raw_data = cv::Mat(stride_num, signal_result_num, CV_32F, output);
			raw_data = raw_data.t();
			float* data = (float*)raw_data.data;

			auto tend = std::chrono::high_resolution_clock::now();
			std::cout << "elapsed millisenconds: " << std::chrono::duration_cast<std::chrono::milliseconds>(tend - tstart).count() << " ms" << std::endl;

			post_process(data, signal_result_num, stride_num, resize_scales, resize_scales, classes, frame, key);
		}
	}
	catch (const std::exception& e) {
		std::cerr << "Error: " << e.what() << std::endl;
		return -1;
	}

	return 0;
}

      labels.txt文件内容如下:仅2类

watermelon 0
wintermelon 1

      说明

      1.这里使用的onnxruntime版本为1.18.0;

      2.windows下,onnxruntime库在debug和release为同一套库,在debug和release下均可执行;

      3.通过指定变量cuda_enabled判断走cpu还是gpu流程 ;

      4.windows下,onnxruntime中有些接口参数为wchar_t*,而linux下为char*,因此在windows下需要单独做转换,这里通过ctow函数实现从char*到wchar_t的转换;

      5.yolov8中提供的sample有问题,需要作调整。

      执行结果如下图所示:同样的预测图像集,与opencv dnn结果相似,它们具有相同的后处理流程;下面显示的耗时是在cpu下,gpu下仅20毫秒左右

      其中一幅图像的检测结果如下图所示:

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1799030.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React 18

创建 React 18 脚手架项目 全局安装 create-react-app npm install -g create-react-app yarn global add create-react-app . 确认是否已安装 create-react-app npm list -g create-react-app yarn global list | grep create-react-app . 如果安装失败 有时&#xff0…

Vue3整合Tailwindcss之padding样式类

04 常用基础样式 padding 样式类 什么是内边距 基础样式 ClassPropertiesp-0padding: 0px;px-0padding-left: 0px; padding-right: 0px;py-0padding-top: 0px; padding-bottom: 0px;ps-0padding-inline-start: 0px;pe-0padding-inline-end: 0px;pt-0padding-top: 0px;pr-0pa…

JVM 运行流程

JVM 是 Java 运行的基础&#xff0c;也是实现一次编译到处执行的关键&#xff0c;那么 JVM 是如何执行的呢&#xff1f; JVM 执行流程 程序在执行之前先要把java代码转换成字节码&#xff08;class 文件&#xff09;&#xff0c; JVM 首先需要把字节码通过一定的 方式 类加…

华为面经整理

文章目录 实习第一面准备提问相关算法相关 第一面结果提问环节 总结 实习 第一面准备 提问相关 操作系统有哪些功能 进程管理&#xff1a; 进程调度、进程同步和通信、多任务处理 内存管理&#xff1a; 内存分配、虚拟内存技术、内存保护 文件系统管理&#xff1a; 文件存储…

MMUNet:形态学特征增强网络在结肠癌病理图像分割中的应用

MMUNet: Morphological feature enhancement network for colon cancer segmentation in pathological images. 发表在&#xff1a;Biomedical Signal Processing and Control2024--影响因子&#xff1a;3.137 南华大学的论文 论文地址&#xff1a;main.pdf (sciencedirecta…

【ffmpeg】本地格式转换 mp4转wav||裁剪mp4

个人感受&#xff1a;太爽了&#xff01;&#xff01;&#xff01;&#xff08;可能用惯了转换网站和无良的转换软件&#xff09; ———— 使用FFmpeg把mp4文件转换为WAV文件 - 简书 (jianshu.com) FFMPEG 视频分割和合并 - 简书 (jianshu.com) ———— 示例 ffmpeg -i …

计算机组成结构—IO接口(IO控制器)

目录 一、I/O 接口的功能 二、I/O 接口的基本结构 1. 总线连接的数据通路 2. I/O 接口的基本组成 三、I/O 端口及其编址 1. 统一编址 2. 不统一编址 四、I/O 接口的类型 两个系统或两个部件之间的交接部分&#xff0c;一般就称为 接口。接口可以是硬件上两种设备间的连…

哈夫曼树的创建

要了解哈夫曼树&#xff0c;可以先了解一下哈夫曼编码&#xff0c;假设我们有几个字母&#xff0c;他们的出现频率是A: 1 B: 2 C: 3 D: 4 E: 5 F: 6 G: 7。那么如果想要压缩数据的同时让访问更加快捷&#xff0c;就要让频率高的字母离根节点比较进&#xff0c;容易访问&#xf…

微生物共生与致病性:动态变化与识别挑战

谷禾健康 细菌耐药性 抗生素耐药性细菌感染的发生率正在上升&#xff0c;而新抗生素的开发由于种种原因在制药行业受重视程度下降。 最新在《柳叶刀-微生物》&#xff08;The Lancet Microbe&#xff09;上&#xff0c;科学家提出了基于细菌适应性、竞争和传播的生态原则的跨学…

个人vsCode配置文件<setting.js>

个人vsCode配置文件setting.js 快速打开1、使用快捷键 CtrlShiftP &#xff0c;然后搜索setting2、手动 自用配置 快速打开 1、使用快捷键 CtrlShiftP &#xff0c;然后搜索setting 2、手动 自用配置 {"terminal.integrated.profiles.windows": {"PowerShell&…

阿里云对象存储OSS简单使用

文章目录 概念基本概念Bucket 准备工作控制台操作对象存储OSSJava客户端操作对象存储OSS参考来源 概念 基本概念 阿里云对象存储 OSS是一款海量、安全、低成本、高可靠的云存储服务&#xff0c;提供最高可达 99.995 % 的服务可用性。而且提供了多种存储类型&#xff0c;降低我…

顶顶通呼叫中心中间件-asr录音路径修改(mod_cti基于FreeSWITCH)

顶顶通呼叫中心中间件-asr录音路径修改(mod_cti基于FreeSWITCH) 录音路径模板。如果不是绝对路径&#xff0c;会把这个路径追加到FreeSWITCH的recordings后面。支持变量&#xff0c;比如日期 ${strftime(%Y-%m-%d)}。最后一个录音文件路径会保存到变量 ${cti_asr_last_record_…

PDF 文件的解析

1、文本 PDF 的解析 1.1、文本的提取 进行文本提取的 Python 库包括&#xff1a;pdfminer.six、PyMuPDF、PyPDF2 和 pdfplumber&#xff0c;效果最好的是 PyMuPDF&#xff0c;PyMuPDF 在进行文本提取时能够最大限度地保留 PDF 的阅读顺序&#xff0c;这对于双栏 PDF 文件的抽…

刷完50题,搞定十大网络基础知识

号主&#xff1a;老杨丨11年资深网络工程师&#xff0c;更多网工提升干货&#xff0c;请关注公众号&#xff1a;网络工程师俱乐部 上午好&#xff0c;我的网工朋友 咱新手网工&#xff0c;入行之前最需要做的准备之一&#xff0c;就是抓住网络基础知识&#xff0c;毕竟是饭碗&…

C语言野指针、规避野指针、assert宏断言

目录 a.野指针成因 1.指针未初始化 2.指针越界访问 3.指针指向的空间释放 b.规避野指针 1.指针初始化 2.小心指针越界 3.指针变量不再使用时&#xff0c;及时置NULL&#xff0c;指针使用之前检查有效性 4.避免返回局部变量的地址 c.assert宏断言的使用 概念&#xff1…

上位机快速开发框架

右上角向下按钮 -> 后台配置 系统菜单 角色管理 分配权限 用户管理 设备配置 通道管理 首页界面设计 设备1配置 带反馈按钮&#xff0c;如&#xff1a;用户按键00105&#xff0c;PLC反馈状态00106 设备2配置 参数说明&#xff1a; TagName_Main&#xff1a;主要信息&#…

加密经济浪潮:探索Web3对金融体系的颠覆

随着区块链技术的快速发展&#xff0c;加密经济正在成为全球金融领域的一股新的浪潮。而Web3作为下一代互联网的代表&#xff0c;以其去中心化、可编程的特性&#xff0c;正深刻影响着传统金融体系的格局和运作方式。本文将深入探讨加密经济对金融体系的颠覆&#xff0c;探索We…

SpringBoot+百度地图+Mysql实现中国地图可视化

通过SpringBoot百度地图Mysql实现中国地图可视化 一、申请百度地图的ak值 进入百度开发者平台 编辑以下内容 然后申请成功 二、Springboot写一个接口 确保数据库里有数据 文件目录如下 1、配置application.properties文件 #访问端口号 server.port9090 # 数据库连接信息 spr…

Xamarin.Android实现通知推送功能(1)

目录 1、背景说明1.1 开发环境1.2 实现效果1.2.1 推送的界面1.2.2 推送的设置1.2.3 推送的功能实现1.2.3.1、Activity的设置【重要】1.2.3.2、代码的实现 2、源码下载3、总结4、参考资料 1、背景说明 在App开发中&#xff0c;通知&#xff08;或消息&#xff09;的推送&#x…

CiteScore 2023发布,AI Open斩获45分,位列全球计算机领域前1%

与影响因子&#xff08;IF&#xff09;一样&#xff0c;引用分数&#xff08;CiteScore&#xff09;同样是衡量学术期刊影响力的重要指标之一&#xff0c;且大有赶超前者的势头。 6 月 6 日&#xff0c;CiteScore 2023 正式发布&#xff0c;人工智能领域可自由访问的期刊平台 …