使用libtorch加载YOLOv8生成的torchscript文件进行目标检测

news2024/12/30 3:25:19

      在网上下载了60多幅包含西瓜和冬瓜的图像组成melon数据集,使用 LabelMe  工具进行标注,然后使用 labelme2yolov8 脚本将json文件转换成YOLOv8支持的.txt文件,并自动生成YOLOv8支持的目录结构,包括melon.yaml文件,其内容如下:

path: ../datasets/melon # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val  # val images (relative to 'path')
test: # test images (optional)
 
# Classes
names:
  0: watermelon
  1: wintermelon

      使用以下python脚本进行训练生成torchscript文件

import argparse
import colorama
from ultralytics import YOLO
 
def parse_args():
	parser = argparse.ArgumentParser(description="YOLOv8 object detect")
	parser.add_argument("--yaml", required=True, type=str, help="yaml file")
	parser.add_argument("--epochs", required=True, type=int, help="number of training")
 
	args = parser.parse_args()
	return args
 
def train(yaml, epochs):
	model = YOLO("yolov8n.pt") # load a pretrained model
	results = model.train(data=yaml, epochs=epochs, imgsz=640) # train the model
 
	metrics = model.val() # It'll automatically evaluate the data you trained, no arguments needed, dataset and settings remembered
 
	model.export(format="onnx") #, dynamic=True) # export the model, cannot specify dynamic=True, opencv does not support
	# model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)
	model.export(format="torchscript") # libtorch
 
if __name__ == "__main__":
	colorama.init()
	args = parse_args()
 
	train(args.yaml, args.epochs)
 
	print(colorama.Fore.GREEN + "====== execution completed ======")

      以下是使用libtorch接口加载torchscript文件进行目标检测的实现代码:

namespace {

constexpr bool cuda_enabled{ false };
constexpr int image_size[2]{ 640, 640 }; // {height,width}, input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 6, 8400)
constexpr float model_score_threshold{ 0.45 }; // confidence threshold
constexpr float model_nms_threshold{ 0.50 }; // iou threshold

#ifdef _MSC_VER
constexpr char* onnx_file{ "../../../data/best.onnx" };
constexpr char* torchscript_file{ "../../../data/best.torchscript" };
constexpr char* images_dir{ "../../../data/images/predict" };
constexpr char* result_dir{ "../../../data/result" };
constexpr char* classes_file{ "../../../data/images/labels.txt" };
#else
constexpr char* onnx_file{ "data/best.onnx" };
constexpr char* torchscript_file{ "data/best.torchscript" };
constexpr char* images_dir{ "data/images/predict" };
constexpr char* result_dir{ "data/result" };
constexpr char* classes_file{ "data/images/labels.txt" };
#endif

std::vector<std::string> parse_classes_file(const char* name)
{
	std::vector<std::string> classes;

	std::ifstream file(name);
	if (!file.is_open()) {
		std::cerr << "Error: fail to open classes file: " << name << std::endl;
		return classes;
	}
	
	std::string line;
	while (std::getline(file, line)) {
		auto pos = line.find_first_of(" ");
		classes.emplace_back(line.substr(0, pos));
	}

	file.close();
	return classes;
}

auto get_dir_images(const char* name)
{
	std::map<std::string, std::string> images; // image name, image path + image name

	for (auto const& dir_entry : std::filesystem::directory_iterator(name)) {
		if (dir_entry.is_regular_file())
			images[dir_entry.path().filename().string()] = dir_entry.path().string();
	}

	return images;
}

void draw_boxes(const std::vector<std::string>& classes, const std::vector<int>& ids, const std::vector<float>& confidences,
	const std::vector<cv::Rect>& boxes, const std::string& name, cv::Mat& frame)
{
	if (ids.size() != confidences.size() || ids.size() != boxes.size() || confidences.size() != boxes.size()) {
		std::cerr << "Error: their lengths are inconsistent: " << ids.size() << ", " << confidences.size() << ", " << boxes.size() << std::endl;
		return;
	}

	std::cout << "image name: " << name << ", number of detections: " << ids.size() << std::endl;

	std::random_device rd;
	std::mt19937 gen(rd());
	std::uniform_int_distribution<int> dis(100, 255);

	for (auto i = 0; i < ids.size(); ++i) {
		auto color = cv::Scalar(dis(gen), dis(gen), dis(gen));
		cv::rectangle(frame, boxes[i], color, 2);

		std::string class_string = classes[ids[i]] + ' ' + std::to_string(confidences[i]).substr(0, 4);
		cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
		cv::Rect text_box(boxes[i].x, boxes[i].y - 40, text_size.width + 10, text_size.height + 20);

		cv::rectangle(frame, text_box, color, cv::FILLED);
		cv::putText(frame, class_string, cv::Point(boxes[i].x + 5, boxes[i].y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);
	}

	cv::imshow("Inference", frame);
	cv::waitKey(-1);

	std::string path(result_dir);
	path += "/" + name;
	cv::imwrite(path, frame);
}

float letter_box(const cv::Mat& src, cv::Mat& dst, const std::vector<int>& imgsz)
{
	if (src.cols == imgsz[1] && src.rows == imgsz[0]) {
		if (src.data == dst.data) {
			return 1.;
		} else {
			dst = src.clone();
			return 1.;
		}
	}

	auto resize_scale = std::min(imgsz[0] * 1. / src.rows, imgsz[1] * 1. / src.cols);
	int new_shape_w = std::round(src.cols * resize_scale);
	int new_shape_h = std::round(src.rows * resize_scale);
	float padw = (imgsz[1] - new_shape_w) / 2.;
	float padh = (imgsz[0] - new_shape_h) / 2.;

	int top = std::round(padh - 0.1);
	int bottom = std::round(padh + 0.1);
	int left = std::round(padw - 0.1);
	int right = std::round(padw + 0.1);

	cv::resize(src, dst, cv::Size(new_shape_w, new_shape_h), 0, 0, cv::INTER_AREA);
	cv::copyMakeBorder(dst, dst, top, bottom, left, right, cv::BORDER_CONSTANT, cv::Scalar(114.));

	return resize_scale;
}

torch::Tensor xywh2xyxy(const torch::Tensor& x)
{
	auto y = torch::empty_like(x);
	auto dw = x.index({ "...", 2 }).div(2);
	auto dh = x.index({ "...", 3 }).div(2);
	y.index_put_({ "...", 0 }, x.index({ "...", 0 }) - dw);
	y.index_put_({ "...", 1 }, x.index({ "...", 1 }) - dh);
	y.index_put_({ "...", 2 }, x.index({ "...", 0 }) + dw);
	y.index_put_({ "...", 3 }, x.index({ "...", 1 }) + dh);

	return y;
}

// reference: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/nms_kernel.cpp
torch::Tensor nms(const torch::Tensor& bboxes, const torch::Tensor& scores, float iou_threshold)
{
	if (bboxes.numel() == 0)
		return torch::empty({ 0 }, bboxes.options().dtype(torch::kLong));

	auto x1_t = bboxes.select(1, 0).contiguous();
	auto y1_t = bboxes.select(1, 1).contiguous();
	auto x2_t = bboxes.select(1, 2).contiguous();
	auto y2_t = bboxes.select(1, 3).contiguous();

	torch::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);

	auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));

	auto ndets = bboxes.size(0);
	torch::Tensor suppressed_t = torch::zeros({ ndets }, bboxes.options().dtype(torch::kByte));
	torch::Tensor keep_t = torch::zeros({ ndets }, bboxes.options().dtype(torch::kLong));

	auto suppressed = suppressed_t.data_ptr<uint8_t>();
	auto keep = keep_t.data_ptr<int64_t>();
	auto order = order_t.data_ptr<int64_t>();
	auto x1 = x1_t.data_ptr<float>();
	auto y1 = y1_t.data_ptr<float>();
	auto x2 = x2_t.data_ptr<float>();
	auto y2 = y2_t.data_ptr<float>();
	auto areas = areas_t.data_ptr<float>();

	int64_t num_to_keep = 0;

	for (int64_t _i = 0; _i < ndets; _i++) {
		auto i = order[_i];
		if (suppressed[i] == 1)
			continue;
		keep[num_to_keep++] = i;
		auto ix1 = x1[i];
		auto iy1 = y1[i];
		auto ix2 = x2[i];
		auto iy2 = y2[i];
		auto iarea = areas[i];

		for (int64_t _j = _i + 1; _j < ndets; _j++) {
			auto j = order[_j];
			if (suppressed[j] == 1)
				continue;
			auto xx1 = std::max(ix1, x1[j]);
			auto yy1 = std::max(iy1, y1[j]);
			auto xx2 = std::min(ix2, x2[j]);
			auto yy2 = std::min(iy2, y2[j]);

			auto w = std::max(static_cast<float>(0), xx2 - xx1);
			auto h = std::max(static_cast<float>(0), yy2 - yy1);
			auto inter = w * h;
			auto ovr = inter / (iarea + areas[j] - inter);
			if (ovr > iou_threshold)
				suppressed[j] = 1;
		}
	}

	return keep_t.narrow(0, 0, num_to_keep);
}

torch::Tensor non_max_suppression(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300)
{
	using torch::indexing::Slice;
	using torch::indexing::None;

	auto bs = prediction.size(0);
	auto nc = prediction.size(1) - 4;
	auto nm = prediction.size(1) - nc - 4;
	auto mi = 4 + nc;
	auto xc = prediction.index({ Slice(), Slice(4, mi) }).amax(1) > conf_thres;

	prediction = prediction.transpose(-1, -2);
	prediction.index_put_({ "...", Slice({None, 4}) }, xywh2xyxy(prediction.index({ "...", Slice(None, 4) })));

	std::vector<torch::Tensor> output;
	for (int i = 0; i < bs; i++) {
		output.push_back(torch::zeros({ 0, 6 + nm }, prediction.device()));
	}

	for (int xi = 0; xi < prediction.size(0); xi++) {
		auto x = prediction[xi];
		x = x.index({ xc[xi] });
		auto x_split = x.split({ 4, nc, nm }, 1);
		auto box = x_split[0], cls = x_split[1], mask = x_split[2];
		auto [conf, j] = cls.max(1, true);
		x = torch::cat({ box, conf, j.toType(torch::kFloat), mask }, 1);
		x = x.index({ conf.view(-1) > conf_thres });
		int n = x.size(0);
		if (!n) { continue; }

		// NMS
		auto c = x.index({ Slice(), Slice{5, 6} }) * 7680;
		auto boxes = x.index({ Slice(), Slice(None, 4) }) + c;
		auto scores = x.index({ Slice(), 4 });
		auto i = nms(boxes, scores, iou_thres);
		i = i.index({ Slice(None, max_det) });
		output[xi] = x.index({ i });
	}

	return torch::stack(output);
}

} // namespace

int test_yolov8_detect_libtorch()
{
	// reference: ultralytics/examples/YOLOv8-LibTorch-CPP-Inference
	if (auto flag = torch::cuda::is_available(); flag == true)
		std::cout << "cuda is available" << std::endl;
	else
		std::cout << "cuda is not available" << std::endl;

	torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

	auto classes = parse_classes_file(classes_file);
	if (classes.size() == 0) {
		std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;
		return -1;
	}

	std::cout << "classes: ";
	for (const auto& val : classes) {
		std::cout << val << " ";
	}
	std::cout << std::endl;

	try {
		// load model
		torch::jit::script::Module model;
		if (torch::cuda::is_available() == true)
			model = torch::jit::load(torchscript_file, torch::kCUDA);
		else
			model = torch::jit::load(torchscript_file, torch::kCPU);
		model.eval();
		// note: cpu is normal; gpu is abnormal: the model may not be fully placed on the gpu 
		// model = torch::jit::load(file); model.to(torch::kCUDA) ==> model = torch::jit::load(file, torch::kCUDA)
		// model.to(device, torch::kFloat32);

		for (const auto& [key, val] : get_dir_images(images_dir)) {
			// load image and preprocess
			cv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);
			if (frame.empty()) {
				std::cerr << "Warning: unable to load image: " << val << std::endl;
				continue;
			}

			cv::Mat bgr;
			letter_box(frame, bgr, {image_size[0], image_size[1]});

			torch::Tensor tensor = torch::from_blob(bgr.data, { bgr.rows, bgr.cols, 3 }, torch::kByte).to(device);
			tensor = tensor.toType(torch::kFloat32).div(255);
			tensor = tensor.permute({ 2, 0, 1 });
			tensor = tensor.unsqueeze(0);
			std::vector<torch::jit::IValue> inputs{ tensor };

			// inference
			torch::Tensor output = model.forward(inputs).toTensor().cpu();

			// NMS
			auto keep = non_max_suppression(output, 0.1f, 0.1f, 300)[0];

			std::vector<int> ids;
			std::vector<float> confidences;
			std::vector<cv::Rect> boxes;
			for (auto i = 0; i < keep.size(0); ++i) {
				int x1 = keep[i][0].item().toFloat();
				int y1 = keep[i][1].item().toFloat();
				int x2 = keep[i][2].item().toFloat();
				int y2 = keep[i][3].item().toFloat();
				boxes.emplace_back(cv::Rect(x1, y1, x2 - x1, y2 - y1));

				confidences.emplace_back(keep[i][4].item().toFloat());
				ids.emplace_back(keep[i][5].item().toInt());
			}

			draw_boxes(classes, ids, confidences, boxes, key, bgr);
		}
	} catch (const c10::Error& e) {
		std::cerr << "Error: " << e.msg() << std::endl;
	}

	return 0;
}

      labels.txt文件内容如下:仅2类

watermelon 0
wintermelon 1

      说明

      1.这里使用的libtorch版本为2.2.2;

      2.通过函数torch::cuda::is_available()判断执行cpu还是gpu

      3.通过非cmake构建项目时,调用torch::cuda::is_available()时即使在gpu下也会返回false,解决方法:项目属性:链接器 --> 命令行:其他选项中添加如下语句:

/INCLUDE:?warp_size@cuda@at@@YAHXZ

      4.gpu下,语句model.to(torch::kCUDA)有问题,好像并不能将模型全部放置在gpu上,应调整为如下语句:

model = torch::jit::load(torchscript_file, torch::kCUDA);

      执行结果如下图所示:同样的预测图像集,结果不如使用 opencv dnn 方法好,它们的前处理和后处理方式不同

      其中一幅图像的检测结果如下图所示:

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1697900.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络通信(二)

UDP通信 特点&#xff1a;无连不是先接、不可靠通信 不事先建立连接&#xff1b;发送端每次把要发送的数据&#xff08;限制在64KB内&#xff09;、接收端IP、等信息封装成一个数据包&#xff0c;发出去就不管了 java提供了一个java.net.DatagramSocket类来实现UDP通信 Dat…

20.SkyWalking

一.简介 SkyWalking用于应用性能监控、分布式链路跟踪、诊断&#xff1a; 参考连接如下&#xff1a; https://github.com/apache/skywalking https://skywalking.apache.org/docs/ 二.示例 通过官网连接进入下载页面&#xff1a;https://archive.apache.org/dist/skywalkin…

普通人转行程序员,最大的困难是找不到就业方向

来百度APP畅享高清图片 大家好&#xff0c;这里是程序员晚枫&#xff0c;小破站也叫这个名。 我自己是法学院毕业后&#xff0c;通过2年的努力才转行程序员成功的。[吃瓜R] 我发现对于一个外行来说&#xff0c;找不到一个适合自己的方向&#xff0c;光靠努力在一个新的行业里…

美团Java社招面试题真题,最新面试题

如何处理Java中的内存泄露&#xff1f; 1、识别泄露&#xff1a; 使用内存分析工具&#xff08;如Eclipse Memory Analyzer Tool、VisualVM&#xff09;来识别内存泄露的源头。 2、代码审查&#xff1a; 定期进行代码审查&#xff0c;关注静态集合类属性和监听器注册等常见内…

Leetcode算法题笔记(3)

目录 矩阵101. 生命游戏解法一解法二 栈102. 移掉 K 位数字解法一 103. 去除重复字母解法一 矩阵 101. 生命游戏 根据 百度百科 &#xff0c; 生命游戏 &#xff0c;简称为 生命 &#xff0c;是英国数学家约翰何顿康威在 1970 年发明的细胞自动机。 给定一个包含 m n 个格子…

Redis简介与安装到python的调用

前言 本文只不对redis的具体用法做详细描述&#xff0c;做简单的介绍&#xff0c;安装&#xff0c;和python代码调用详细使用教程可查看一下网站 https://www.runoob.com/redis/redis-tutorial.html https://pypi.org/project/redis/ 官方原版: https://redis.io/ 中文官网:…

齿轮常见故障学习笔记

大家好&#xff0c;这期咱们聊一聊齿轮常见的失效形式&#xff0c;查阅了相关的资料&#xff0c;做个笔记分享给大家&#xff0c;共同学习。 介绍 齿轮故障可能以多种方式发生。如果在设计阶段本身就尽量防止这些故障的产生&#xff0c;则可以产生改更为优化的齿轮设计。齿轮…

Top期刊:针对论文Figure图片的7个改进建议

我是娜姐 迪娜学姐 &#xff0c;一个SCI医学期刊编辑&#xff0c;探索用AI工具提效论文写作和发表。 通过对来自细胞生物学、生理学和植物学领域的580篇论文&#xff0c;进行检查和归纳总结&#xff0c;来自德国德累斯顿工业大学的Helena Jambor及合作者&#xff0c;在PLOS Bio…

五分钟搭建一个Suno AI音乐站点

五分钟搭建一个Suno AI音乐站点 在这个数字化时代&#xff0c;人工智能技术正以惊人的速度改变着我们的生活方式和创造方式。音乐作为一种最直接、最感性的艺术形式&#xff0c;自然也成为了人工智能技术的应用场景之一。今天&#xff0c;我们将以Vue和Node.js为基础&#xff…

第12章-ADC采集电压和显示 基于STM32的ADC—电压采集(详细讲解+HAL库)

我们的智能小车用到了ADC测量电池电压的功能&#xff0c;这章节我们做一下。 我们的一篇在这里 第一篇 什么是ADC 百度百科介绍&#xff1a; 我们知道万用表 电压表可以测量电池&#xff0c;或者电路电压。那么我们是否可以通过单片机获得电压&#xff0c;方便我 们监控电池状…

WPF学习日常篇(一)--开发界面视图布局

接下来开始日常篇&#xff0c;我在主线篇&#xff08;正文&#xff09;中说过要介绍一下我的界面排布&#xff0c;科学的排布才更科学更有效率的进行敲代码和开发。日常篇中主要记录我的一些小想法和所考虑的一些细节。 一、主界面设置 主界面分为左右两部分&#xff0c;分为…

查分数组总结

文章目录 查分数组定义应用举例LeetCode 1109 题「[航班预订统计] 查分数组定义 差分数组的主要适用场景是频繁对原始数组的某个区间的元素进行增减。 通过这个 diff 差分数组是可以反推出原始数组 nums 的&#xff0c;代码逻辑如下&#xff1a; int res[diff.size()]; // 根…

(2024,SDE,对抗薛定谔桥匹配,离散时间迭代马尔可夫拟合,去噪扩散 GAN)

Adversarial Schrdinger Bridge Matching 公众号&#xff1a;EDPJ&#xff08;进 Q 交流群&#xff1a;922230617 或加 VX&#xff1a;CV_EDPJ 进 V 交流群&#xff09; 目录 0. 摘要 1. 简介 4. 实验 0. 摘要 薛定谔桥&#xff08;Schrdinger Bridge&#xff0c;SB&…

flash-linear-attention中的Chunkwise并行算法的理解

这里提一下&#xff0c;我维护的几三个记录个人学习笔记以及社区中其它大佬们的优秀博客链接的仓库都获得了不少star&#xff0c;感谢读者们的认可&#xff0c;我也会继续在开源社区多做贡献。github主页&#xff1a;https://github.com/BBuf &#xff0c;欢迎来踩 0x0. 前言 …

MySQL——索引与事务

目录 前言 一、索引 1.索引概述 &#xff08;1&#xff09;基本概念 &#xff08;2&#xff09;索引作用 &#xff08;3&#xff09;索引特点 &#xff08;4&#xff09;适用场景 2.索引的操作 &#xff08;1&#xff09;查看索引 &#xff08;2&#xff09;创建索引…

【算法题】520 钻石争霸赛 2024 全解析

都是自己写的代码&#xff0c;发现自己的问题是做题速度还是不够快 520-1 爱之恒久远 在 520 这个特殊的日子里&#xff0c;请你直接在屏幕上输出&#xff1a;Forever and always。 输入格式&#xff1a; 本题没有输入。 输出格式&#xff1a; 在一行中输出 Forever and always…

2024年推荐的适合电脑和手机操作的线上兼职副业平台

总是会有人在找寻着线上兼职副业&#xff0c;那么在如今的2024年&#xff0c;互联网提供了诸多方便&#xff0c;无论你是宝妈、大学生、程序员、外卖小哥还是打工族&#xff0c;如果你正在寻找副业机会&#xff0c;那么这篇文章将为你提供一些适合电脑和手机操作的线上兼职副业…

[Linux]文件/文件描述符fd

一、关于文件 文件&#xff1d;内容&#xff0b;属性 那么所有对文件的操作&#xff0c;就是对内容/属性操作。内容是数据&#xff0c;属性也是数据&#xff0c;那么存储文件&#xff0c;就必须既存储内容数据&#xff0c;又存储属性数据。默认就是在磁盘中的文件。当进程访问…

知识分享:隔多久查询一次网贷大数据信用报告比较好?

随着互联网金融的快速发展&#xff0c;越来越多的人开始接触和使用网络贷款。而在这个过程中&#xff0c;网贷大数据信用报告成为了评估借款人信用状况的重要依据。那么&#xff0c;隔多久查询一次网贷大数据信用报告比较好呢?接下来随小易大数据平台小编去看看吧。 首先&…

YOLOV5 改进:替换backbone为EfficientNet

1、介绍 本章将会把yolov5的主干网络替换成EfficientNet V2,这里直接粘贴代码 详细的可以参考之前的内容:YOLOV5 改进:替换backbone(MobileNet为例)_yolov5主干网络更换为mobile-CSDN博客 更多的backbone更换参考本专栏: YOLOV5 实战项目(训练、部署、改进等等)_听风吹…