YOLOV8训练好的best.pt模型转best.onnx并部署成python可调用

news2025/1/8 18:36:11

今天这篇博文是学习大佬作品以后,执行我的需求后的总结,做了一些代码调整,就此记录一下,非常感谢大佬提供如此好的输出。
已知yolov8 训练好的模型一般是pt格式,比如best.pt,现在我期望这个模型可以转成可以部署的格式,不那么明晃晃地调用yolo,于是乎就查到可以转成onnx格式。
1、onnx是什么格式
在这里插入图片描述
好像是一堆废话,就是可以用 ONNX Runtime 加载,还有高版本的OpenCV的dnn 也可以加载。
2、基于python 加载onnx模型
(1)将best.pt转成 best.onnx

from ultralytics import YOLO

# 加载训练好的 YOLOv8 模型
model = YOLO('E:/skin_yolo/runs/detect/spot_detection60/weights/best.pt')

# 导出为 ONNX 格式
#model.export(format='onnx')

model.export(format='onnx', imgsz=640)#我的输入是图像尺寸固定的640*640,所以我写死了

(2)python加载模型并做目标检测
首先需要安装onnxruntime、numpy、cv2等库。如果使用 GPU 进行推理,还需安装onnxruntime-gpu。

test_detector.py

#基于yolo模型检测皮肤图像上的目标
#2025-01-06

import cv2

#引用文件中的函数
from targetDetect import TargetDetection
from forDraw import draw_detections

# yolov8 onnx 模型推理
class YOLOV8NDetector:
    def __init__(self,model_path):
        super(YOLOV8NDetector, self).__init__()
        self.model_path = model_path
        self.detector = TargetDetection(self.model_path, conf_thres=0.5, iou_thres=0.3)

    def detect_image(self, input_image, output_image):
        cv_img = cv2.imread(input_image)
        boxes, scores, class_ids = self.detector.detect_objects(cv_img)
        cv_img = draw_detections(cv_img, boxes, scores, class_ids)
        cv2.namedWindow("output", cv2.WINDOW_NORMAL)
        cv2.imwrite(output_image, cv_img)
        cv2.imshow('output', cv_img)
        cv2.waitKey(0)

    def detect_video(self, input_video, output_video):
        cap = cv2.VideoCapture(input_video)
        fps = int(cap.get(5))
        videoWriter = None
        while True:
            _, cv_img = cap.read()
            if cv_img is None:
                break
            boxes, scores, class_ids = self.detector.detect_objects(cv_img)
            cv_img = draw_detections(cv_img, boxes, scores, class_ids)

            # 如果视频写入器未初始化,则使用输出视频路径和参数进行初始化
            if videoWriter is None:
                fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
                # 在这里给值了,它就不是None, 下次判断它就不进这里了
                videoWriter = cv2.VideoWriter(output_video, fourcc, fps, (cv_img.shape[1], cv_img.shape[0]))

            videoWriter.write(cv_img)
            cv2.imshow("aod", cv_img)
            cv2.waitKey(5)

            # 等待按键并检查窗口是否关闭
            if cv2.getWindowProperty("aod", cv2.WND_PROP_AUTOSIZE) < 1:
                # 点x退出
                break
        cap.release()
        videoWriter.release()
        cv2.destroyAllWindows()



if __name__ == '__main__':

    modelpath ="E:/skin_yolo/runs/detect/spot_detection60/weights/best.onnx"#模型路径
    det = YOLOV8NDetector(modelpath)

    #检测图片时调用
    input_image = "E:/Skin_Color/skin_pic/test/12/test.jpg"
    output_image = 'E:/Skin_Color/skin_pic/test/12/test_out.jpg'
    det.detect_image(input_image, output_image)

    #检测视频是调用
    # input_video = r"E:\yolodataset\video\A13.mp4"
    # output_video = "../testdata/fortest.mp4"
    # det.detect_video(input_video, output_video)

可以看出上面的代码依赖两个文件:targetDetect.py 和 forDraw .py

targetDetect.py 中定义了检测目标处理,forDraw .py 中定义了一些目标画框。

targetDetect.py

import time
import cv2
import numpy as np
import onnxruntime

#引用文件中的函数
from forDraw import xywh2xyxy, draw_detections,nms # 单类目标用nms , 多类目标用multiclass_nms


class TargetDetection:
    def __init__(self, path, conf_thres=0.7, iou_thres=0.5):
        self.conf_threshold = conf_thres
        self.iou_threshold = iou_thres

        # Initialize model
        self.initialize_model(path)

    def __call__(self, image):
        return self.detect_objects(image)

    def initialize_model(self, path):
        self.session = onnxruntime.InferenceSession(path, providers=onnxruntime.get_available_providers())
        # Get model info
        self.get_input_details()
        self.get_output_details()

    def detect_objects(self, image):
        input_tensor = self.prepare_input(image)

        # Perform inference on the image
        outputs = self.inference(input_tensor)

        self.boxes, self.scores, self.class_ids = self.process_output(outputs)

        return self.boxes, self.scores, self.class_ids

    def prepare_input(self, image):
        self.img_height, self.img_width = image.shape[:2]

        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Resize input image
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))

        # Scale input pixel values to 0 to 1
        input_img = input_img / 255.0
        input_img = input_img.transpose(2, 0, 1)
        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
        return input_tensor

    def inference(self, input_tensor):
        start = time.perf_counter()
        outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor})

        # print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
        return outputs

    def process_output(self, output):
        predictions = np.squeeze(output[0]).T

        # Filter out object confidence scores below threshold
        scores = np.max(predictions[:, 4:], axis=1)
        predictions = predictions[scores > self.conf_threshold, :]
        scores = scores[scores > self.conf_threshold]

        if len(scores) == 0:
            return [], [], []

        # Get the class with the highest confidence
        class_ids = np.argmax(predictions[:, 4:], axis=1)

        # Get bounding boxes for each object
        boxes = self.extract_boxes(predictions)

        # Apply non-maxima suppression to suppress weak, overlapping bounding boxes
        indices = nms(boxes, scores, self.iou_threshold)#我的目标只有一个类
        #indices = multiclass_nms(boxes, scores, class_ids, self.iou_threshold)#多类

        return boxes[indices], scores[indices], class_ids[indices]

    def extract_boxes(self, predictions):
        # Extract boxes from predictions
        boxes = predictions[:, :4]

        # Scale boxes to original image dimensions
        boxes = self.rescale_boxes(boxes)

        # Convert boxes to xyxy format
        boxes = xywh2xyxy(boxes)

        return boxes

    def rescale_boxes(self, boxes):
        # Rescale boxes to original image dimensions
        input_shape = np.array([self.input_width, self.input_height, self.input_width, self.input_height])
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array([self.img_width, self.img_height, self.img_width, self.img_height])
        return boxes

    def draw_detections(self, image, draw_scores=True, mask_alpha=0.4):
        return draw_detections(image, self.boxes, self.scores,
                               self.class_ids, mask_alpha)

    def get_input_details(self):
        model_inputs = self.session.get_inputs()
        self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]

        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

        print(self.input_width, self.input_height)

    def get_output_details(self):
        model_outputs = self.session.get_outputs()
        self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]

forDraw .py

import numpy as np
import cv2

class_names = ['spot'] #我的类标记

# Create a list of colors for each class where each color is a tuple of class number integer values

rng = np.random.default_rng(1)#此处是1,我的目标只有一个分类
colors = rng.uniform(0, 255, size=(len(class_names), 1))#此处是1,我的目标只有一个分类


def nms(boxes, scores, iou_threshold):
    # 根据 scores 对检测框从高到低进行排序,得到排序后的索引
    sorted_indices = np.argsort(scores)[::-1]  # [::-1] 反转排序顺序

    keep_boxes = []
    while sorted_indices.size > 0:
        # 保留最高分数的边界框
        box_id = sorted_indices[0]
        keep_boxes.append(box_id)

        # 计算当前最高分数的边界框与剩余边界框的 IoU
        ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

        # 找出 IoU 小于阈值的边界框索引,保留这些框,过滤重叠框
        keep_indices = np.where(ious < iou_threshold)[0]

        # 注意:由于 keep_indices 是相对于 sorted_indices[1:] 的索引,
        # 需要将其整体偏移 +1 来匹配到原始 sorted_indices
        sorted_indices = sorted_indices[keep_indices + 1]

    return keep_boxes

def multiclass_nms(boxes, scores, class_ids, iou_threshold):
    # 获取所有唯一的类别索引
    unique_class_ids = np.unique(class_ids)

    keep_boxes = []  # 存储最终保留的边界框索引

    for class_id in unique_class_ids:
        # 筛选出属于当前类别的边界框索引
        class_indices = np.where(class_ids == class_id)[0]  # np.where返回元组

        # 提取属于当前类别的边界框和分数
        class_boxes = boxes[class_indices, :]  # 当前类别的边界框
        class_scores = scores[class_indices]  # 当前类别的分数

        # 执行 NMS 并获取保留下来的索引
        class_keep_boxes = nms(class_boxes, class_scores, iou_threshold)

        # 将保留的索引(对应原始的索引)添加到结果中
        keep_boxes.extend(class_indices[class_keep_boxes])

    return keep_boxes
    
def compute_iou(box, boxes):
    # 计算交集区域的坐标,xmin 和 ymin: 交集左上角的坐标,xmax 和 ymax: 交集右下角的坐标
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])

    # 计算交集区域面积,如果两个框没有重叠,交集宽度和高度会为负,使用 np.maximum 保证面积非负
    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

    # 计算每个边界框的面积
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # 计算并集区域面积
    union_area = box_area + boxes_area - intersection_area

    # 计算 IoU(交并比)
    iou = intersection_area / union_area  # 交集区域面积 / 并集区域面积

    return iou

def xywh2xyxy(x):
    # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
    # 将边界框从 (x_center, y_center, w, h) 格式转换为 (x1, y1, x2, y2)
    y = np.copy(x)
    # 计算左上角坐标 x1 和 y1
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    # 计算右下角坐标 x2 和 y2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y


def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3):
    #画检测目标
    det_img = image.copy()

    img_height, img_width = image.shape[:2]
    font_size = min([img_height, img_width]) * 0.0006
    text_thickness = int(min([img_height, img_width]) * 0.001)

    det_img = draw_masks(det_img, boxes, class_ids, mask_alpha)

    # Draw bounding boxes and labels of detections
    for class_id, box, score in zip(class_ids, boxes, scores):
        color = colors[class_id]

        draw_box(det_img, box, color)

        label = class_names[class_id]
        caption = f'{label} {int(score * 100)}%'
        draw_text(det_img, caption, box, color, font_size, text_thickness)

    return det_img


def draw_box(image: np.ndarray, box: np.ndarray, color: tuple[int, int, int] = (0, 0, 255),
             thickness: int = 2) -> np.ndarray:
    x1, y1, x2, y2 = box.astype(int)
    return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)


def draw_text(image: np.ndarray, text: str, box: np.ndarray, color: tuple[int, int, int] = (0, 0, 255),
              font_size: float = 0.001, text_thickness: int = 2) -> np.ndarray:
    #显示注释
    x1, y1, x2, y2 = box.astype(int)
    (tw, th), _ = cv2.getTextSize(text=text, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                  fontScale=font_size, thickness=text_thickness)
    th = int(th * 1.2)#线宽

    cv2.rectangle(image, (x1, y1),(x1 + tw, y1 - th), color, -1)#画注释框

    return cv2.putText(image, text, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, font_size, (255, 255, 255), text_thickness,cv2.LINE_AA)


def draw_masks(image: np.ndarray, boxes: np.ndarray, classes: np.ndarray, mask_alpha: float = 0.3) -> np.ndarray:
    mask_img = image.copy()
    # 画检测到的目标框
    for box, class_id in zip(boxes, classes):
        color = colors[class_id]

        x1, y1, x2, y2 = box.astype(int)

        cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)

    # return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)#返回半透明框
    return image #返回全透明框

看看处理结果
在这里插入图片描述
(3)一些感慨
看模型流程

yolo模型导出以后,要加载处理其实还是需要理解透彻模型的过程,首先是输入和输出
一个不错的网站,可以在线查看模型拓扑结构 https://netron.app/
巨长的流程拓扑结构,小白暂时就只盯着输入和输出看了。
在这里插入图片描述
预处理和中间过程都很重要
1)预处理可以是一个很大的绊脚石
2)读取图像并将图像的颜色空间从 BGR 格式转换为 RGB 格式 ONNX 模型则期望输入是 RGB 格式;
3)图像大小resize,我训练就将图像用640了,所以需要 resize 到模型要求的输入尺寸;
4)归一化处理,将像素值归一化到 [0, 1] 区间。
5)调整图像通道顺序,一般从 HWC(Height, Width, Channel)转换为 CHW ( Channel,Height, Width,)格式,并增加一个批次维度,使其变为 NCHW 格式,N 为批次大小,通常设为 1。

最后特别感谢大佬的参考:https://blog.csdn.net/MariLN/article/details/144330414

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272780.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

君正T41交叉编译ffmpeg、opencv并做h264软解,利用君正SDK做h264硬件编码

目录 1 交叉编译ffmpeg----错误解决过程&#xff0c;不要看 1.1 下载源码 1.2 配置 1.3 编译 安装 1.3.1 报错&#xff1a;libavfilter/libavfilter.so: undefined reference to fminf 1.3.2 报错&#xff1a;error: unknown type name HEVCContext; did you mean HEVCPr…

基于ASP.NET的动漫网站

一、系统架构与技术实现 系统架构&#xff1a;基于ASP.NET的MVC框架构建&#xff0c;实现网站的层次结构&#xff0c;使得网站更加易于维护和扩展。 技术实现&#xff1a;利用ASP.NET的技术特点&#xff0c;如强大的后端开发能力、丰富的UI控件等&#xff0c;结合前端技术如HT…

「Java 数据结构全面解读」:从基础到进阶的实战指南

「Java 数据结构全面解读」&#xff1a;从基础到进阶的实战指南 数据结构是程序设计中的核心部分&#xff0c;用于组织和管理数据。Java 提供了丰富的集合框架和工具类&#xff0c;涵盖了常见的数据结构如数组、链表、栈、队列和树等。本文将系统性地介绍这些数据结构的概念、…

安卓NDK视觉开发——手机拍照文档边缘检测实现方法与库封装

一、项目创建 创建NDK项目有两种方式&#xff0c;一种从新创建整个项目&#xff0c;一个在创建好的项目添加NDK接口。 1.创建NDK项目 创建 一个Native C项目&#xff1a; 选择包名、API版本与算法交互的语言&#xff1a; 选择C版本&#xff1a; 创建完之后&#xff0c;可…

MATLAB仿真:基于GS算法的经大气湍流畸变涡旋光束波前校正仿真

GS算法流程 GS&#xff08;Gerchberg-Saxton&#xff09;相位恢复算法是一种基于傅里叶变换的最速下降算法&#xff0c;可以通过输出平面和输入平面上光束的光强分布计算出光束的相位分布。图1是基于GS算法的涡旋光束畸变波前校正系统框图&#xff0c;在该框图中&#xff0c;已…

【React+TypeScript+DeepSeek】穿越时空对话机

引言 在这个数字化的时代&#xff0c;历史学习常常给人一种距离感。教科书中的历史人物似乎永远停留在文字里&#xff0c;我们无法真正理解他们的思想和智慧。如何让这些伟大的历史人物"活"起来&#xff1f;如何让历史学习变得生动有趣&#xff1f;带着这些思考&…

深入刨析数据结构之排序(上)

目录 1.内部排序 1.1概述 1.2插入排序 1.2.1其他插入排序 1.2.1.1 折半插入排序 1.2.1.2 2-路插入排序 1.3希尔排序 1.4快速排序 1.4.1起泡排序 1.4.2快速排序 1.4.2.1hoare版本 1.4.2.2挖坑版本 1.4.2.3前后指针版本 1.4.2.4优化版本 1.4.2.4.1小区间插入排序优…

AIA - APLIC之三(附APLIC处理流程图)

本文属于《 RISC-V指令集基础系列教程》之一,欢迎查看其它文章。 1 APLIC复位 APLIC复位后,其所有状态都变得有效且一致,但以下情况除外: 每个中断域的domaincfg寄存器(spec第 4.5.1 节);可能是machine-level interrupt domain的MSI地址配置寄存器(spec第4.5.3 和4.5…

openwrt 清缓存命令行

一、查看缓存 &#xff1a; free -m 二、清缓存&#xff1a;echo 3 > /proc/sys/vm/drop_caches  三、详解。 释放物理页缓存 echo 1 > /proc/sys/vm/drop_caches 释放可回收的slab对象&#xff0c;包含inode and dentry echo 2 > /proc/sys/vm/drop_caches 同时…

Linux -- 端口号、套接字、网络字节序、sockaddr 结构体

目录 什么是端口号&#xff1f; 什么是套接字&#xff1f; 网络字节序 struct sockaddr 结构体 什么是端口号&#xff1f; 我们日常上网的时候&#xff0c;主机其实是在进行两种操作&#xff1a; 1、把远端的数据拉取到本地&#xff0c;比如刷抖音的时候&#xff0c;手机向…

《数据结构》期末考试测试题【中】

《数据结构》期末考试测试题【中】 21.循环队列队空的判断条件为&#xff1f;22. 单链表的存储密度比1&#xff1f;23.单链表的那些操作的效率受链表长度的影响&#xff1f;24.顺序表中某元素的地址为&#xff1f;25.m叉树第K层的结点数为&#xff1f;26. 在双向循环链表某节点…

实际开发中,常见pdf|word|excel等文件的预览和下载

实际开发中,常见pdf|word|excel等文件的预览和下载 背景相关类型数据之间的转换1、File转Blob2、File转ArrayBuffer3、Blob转ArrayBuffer4、Blob转File5、ArrayBuffer转Blob6、ArrayBuffer转File 根据Blob/File类型生成可预览的Base64地址基于Blob类型的各种文件的下载各种类型…

《Opencv》基础操作详解(4)

接上篇&#xff1a;《Opencv》基础操作详解&#xff08;3&#xff09;-CSDN博客 目录 22、图像形态学操作 &#xff08;1&#xff09;、顶帽&#xff08;原图-开运算&#xff09; 公式&#xff1a; 应用场景&#xff1a; 代码示例&#xff1a; &#xff08;2&#xff09;…

大数据高级ACP学习笔记(2)

钻取&#xff1a;变换维度的层次&#xff0c;改变粒度的大小 星型模型 雪花模型 MaxCompute DataHub

尚硅谷· vue3+ts 知识点学习整理 |14h的课程(持续更ing)

vue3 主要内容 核心&#xff1a;ref、reactive、computed、watch、生命周期 常用&#xff1a;hooks、自定义ref、路由、pinia、miit 面试&#xff1a;组件通信、响应式相关api ----> 笔记&#xff1a;ts快速梳理&#xff1b;vue3快速上手.pdf 笔记及大纲 如下&#xff…

阻抗(Impedance)、容抗(Capacitive Reactance)、感抗(Inductive Reactance)

阻抗&#xff08;Impedance&#xff09;、容抗&#xff08;Capacitive Reactance&#xff09;、感抗&#xff08;Inductive Reactance&#xff09; 都是交流电路中描述电流和电压之间关系的参数&#xff0c;但它们的含义、单位和作用不同。下面是它们的定义和区别&#xff1a; …

在 SQL 中,区分 聚合列 和 非聚合列(nonaggregated column)

文章目录 1. 什么是聚合列&#xff1f;2. 什么是非聚合列&#xff1f;3. 在 GROUP BY 查询中的非聚合列问题示例解决方案 4. 为什么 only_full_group_by 要求非聚合列出现在 GROUP BY 中&#xff1f;5. 如何判断一个列是聚合列还是非聚合列&#xff1f;6. 总结 在 SQL 中&#…

B树与B+树:数据库索引的秘密武器

想象一下&#xff0c;你正在构建一个超级大的图书馆&#xff0c;里面摆满了各种各样的书籍。B树和B树就像是两种不同的图书分类和摆放方式&#xff0c;它们都能帮助你快速找到想要的书籍&#xff0c;但各有特点。 B树就像是一个传统的图书馆摆放方式&#xff1a; 1. 书籍摆放&…

回归预测 | MATLAB实现CNN-SVM多输入单输出回归预测

回归预测 | MATLAB实现CNN-SVM多输入单输出回归预测 目录 回归预测 | MATLAB实现CNN-SVM多输入单输出回归预测预测效果基本介绍模型架构程序设计参考资料 预测效果 基本介绍 CNN-SVM多输入单输出回归预测是一种结合卷积神经网络&#xff08;CNN&#xff09;和支持向量机&#…

Linux-Ubuntu之裸机驱动最后一弹PWM控制显示亮度

Linux-Ubuntu之裸机驱动最后一弹PWM控制显示亮度 一&#xff0c; PWM实现原理二&#xff0c;软件实现三&#xff0c;正点原子裸机开发总结 一&#xff0c; PWM实现原理 PWM和学习51时候基本上一致&#xff0c;控制频率&#xff08;周期&#xff09;和占空比&#xff0c;51实验…