YOLOv8 目标检测程序,依赖的库最少,使用onnxruntime推理

news2024/10/5 15:29:36

YOLOv8 目标检测程序,依赖的库最少,使用onnxruntime推理

flyfish
为了方便理解,加入了注释

"""
YOLOv8 目标检测程序
Author: flyfish
Date: 
Description: 该程序使用ONNX运行时进行YOLOv8模型的目标检测。
      它对输入图像进行预处理,运行推理,并对输出进行后处理,在图像上绘制边界框和标签。
"""
import argparse
import cv2
import numpy as np
import onnxruntime as ort

# 定义类别字典,将类别ID映射到类别名称
classes = {
    0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 
    9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 
    16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 
    24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 
    31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 
    37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 
    44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 
    52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 
    60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 
    67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 
    74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'
}

class YOLOv8:
    """YOLOv8目标检测模型类,用于处理推理和可视化。"""

    def __init__(self, onnx_model, input_image, confidence_thres, iou_thres):
        """
        初始化YOLOv8类的实例。

        参数:
            onnx_model: ONNX模型的路径。
            input_image: 输入图像的路径。
            confidence_thres: 用于过滤检测结果的置信度阈值。
            iou_thres: 非极大值抑制的IoU(交并比)阈值。
        """
        self.onnx_model = onnx_model
        self.input_image = input_image
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres

        # 加载类别名称
        self.classes = classes
     
        # 为每个类别生成一个颜色调色板
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))

    def draw_detections(self, img, box, score, class_id):
        """
        在输入图像上绘制检测到的边界框和标签。

        参数:
            img: 要在其上绘制检测结果的输入图像。
            box: 检测到的边界框。
            score: 对应的检测置信度分数。
            class_id: 检测到的对象的类别ID。

        返回值:
            无
        """
        # 提取边界框的坐标
        x1, y1, w, h = box

        # 获取类别ID对应的颜色
        color = self.color_palette[class_id]

        # 在图像上绘制边界框
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        # 创建包含类别名称和置信度分数的标签文本
        label = f"{self.classes[class_id]}: {score:.2f}"

        # 计算标签文本的尺寸
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # 计算标签文本的位置
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # 绘制填充矩形作为标签文本的背景
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )

        # 在图像上绘制标签文本
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def preprocess(self):
        """
        在执行推理之前预处理输入图像。

        返回值:
            image_data: 预处理后的图像数据,准备进行推理。
        """
        # 使用OpenCV读取输入图像
        self.img = cv2.imread(self.input_image)

        # 获取输入图像的高度和宽度
        self.img_height, self.img_width = self.img.shape[:2]

        # 将图像的颜色空间从BGR转换为RGB
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)

        # 调整图像大小以匹配输入形状
        img = cv2.resize(img, (self.input_width, self.input_height))

        # 通过除以255.0来规范化图像数据
        image_data = np.array(img) / 255.0

        # 转置图像,使通道维度为第一个维度
        image_data = np.transpose(image_data, (2, 0, 1))  # 通道优先

        # 扩展图像数据的维度以匹配预期的输入形状
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        # 返回预处理后的图像数据
        return image_data

    def postprocess(self, input_image, output):
        """
        对模型的输出进行后处理,以提取边界框、置信度分数和类别ID。

        参数:
            input_image (numpy.ndarray): 输入图像。
            output (numpy.ndarray): 模型的输出。

        返回值:
            numpy.ndarray: 带有绘制检测结果的输入图像。
        """
        # 转置并压缩输出以匹配预期的形状
        outputs = np.transpose(np.squeeze(output[0]))

        # 获取输出数组中的行数
        rows = outputs.shape[0]

        # 用于存储检测到的边界框、置信度分数和类别ID的列表
        boxes = []
        scores = []
        class_ids = []

        # 计算边界框坐标的缩放因子
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        # 遍历输出数组中的每一行
        for i in range(rows):
            # 从当前行中提取类别分数
            classes_scores = outputs[i][4:]

            # 找到类别分数中的最大值
            max_score = np.amax(classes_scores)

            # 如果最大值大于置信度阈值
            if max_score >= self.confidence_thres:
                # 获取具有最高分数的类别ID
                class_id = np.argmax(classes_scores)

                # 从当前行中提取边界框坐标
                x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]

                # 计算缩放后的边界框坐标
                left = int((x - w / 2) * x_factor)
                top = int((y - h / 2) * y_factor)
                width = int(w * x_factor)
                height = int(h * y_factor)

                # 将类别ID、置信度分数和边界框坐标添加到各自的列表中
                class_ids.append(class_id)
                scores.append(max_score)
                boxes.append([left, top, width, height])

        # 应用非极大值抑制以过滤重叠的边界框
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)

        # 遍历非极大值抑制后选择的索引
        for i in indices:
            # 获取对应于索引的边界框、置信度分数和类别ID
            box = boxes[i]
            score = scores[i]
            class_id = class_ids[i]

            # 在输入图像上绘制检测结果
            self.draw_detections(input_image, box, score, class_id)

        # 返回修改后的输入图像
        return input_image

    def main(self):
        """
        使用ONNX模型执行推理并返回带有绘制检测结果的输出图像。

        返回值:
            output_img: 带有绘制检测结果的输出图像。
        """
        # 使用ONNX模型创建推理会话并指定执行提供程序
        session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

        # 获取模型输入
        model_inputs = session.get_inputs()

        # 存储输入的形状以便稍后使用
        input_shape = model_inputs[0].shape
        self.input_width = input_shape[2]
        self.input_height = input_shape[3]

        # 预处理图像数据
        img_data = self.preprocess()

        # 使用预处理后的图像数据进行推理
        outputs = session.run(None, {model_inputs[0].name: img_data})

        # 对输出进行后处理以获得输出图像
        return self.postprocess(self.img, outputs)  # 输出图像


if __name__ == "__main__":
    # 创建一个参数解析器以处理命令行参数
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="yolov8n.onnx", help="输入您的ONNX模型。")
    parser.add_argument("--img", type=str, default="bus.jpg", help="输入图像的路径。")
    parser.add_argument("--conf-thres", type=float, default=0.5, help="置信度阈值")
    parser.add_argument("--iou-thres", type=float, default=0.5, help="非极大值抑制的IoU阈值")
    args = parser.parse_args()

    # 使用指定的参数创建YOLOv8类的实例
    detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres)

    # 执行目标检测并获取输出图像
    output_image = detection.main()

    # 在窗口中显示输出图像
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", output_image)

    # 等待按键以退出
    cv2.waitKey(0)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1858098.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

尴尬时刻:如何在忘记名字时巧妙应对

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

代理IP知识:导致代理IP访问超时的原因有哪些?

很多用户在使用代理IP进行网络访问时,可能会遇到代理IP超时的情况,也就是代理IP的延迟过高。代理IP延迟过高会影响用户的网络体验和数据获取效率。因此,了解代理IP延迟过高的原因很重要。以下是导致代理IP延迟过高的一些常见原因:…

美容美发店营销版微信小程序源码

打造线上生意新篇章 一、引言:微信小程序,开启美容美发行业新纪元 在数字化时代,微信小程序以其便捷、高效的特点,成为了美容美发行业营销的新宠。本文将带您深入了解美容美发营销微信小程序,探讨其独特优势及如何助…

盘点5款最热门的AI绘画软件!总有一款是你的菜

在数字化艺术日益盛行的今天,AI绘画软件成为了创作者们的新宠。这些软件不仅能够帮助艺术家们快速生成独特的艺术作品,还能为普通用户带来全新的绘画体验。今天,我们就来盘点五款最热门的AI绘画软件,看看哪一款是你的菜&#xff0…

深度学习 --- stanford cs231学习笔记五(训练神经网络的几个重要组成部分之三,权重矩阵的初始化)

权重矩阵的初始化 3,权重矩阵的初始化 深度学习所学习的重点就是要根据损失函数训练权重矩阵中的系数。即便如此,权重函数也不能为空,总是需要初始化为某个值。 3,1 全都初始化为同一个常数可以吗? 首先要简单回顾一下…

技术干货 | AI驱动工程仿真和设计创新

在当今快速发展的技术领域,人工智能(AI)、机器学习和深度学习等技术已经成为推动工程仿真和设计创新的关键力量。Altair技术经理张晨在Altair “AI FOR ENGINEERS”线下研讨会上发表了相关精彩演讲,本文摘自演讲内容,与…

数字化校园平台:引领教育创新的智慧之选

数字化校园平台是信息化技术与传统教育深度结合的产物。在当今这个信息技术日新月异的时代,数字化校园平台正逐渐崭露头角,成为教育领域一股不可小觑的革新力量。它如同一座桥梁,连接起教育资源的各个角落,将繁杂的教学材料、珍贵…

猫狗识别—视频识别

猫狗识别—视频识别 1. 导入所需的库:2. 创建Tkinter主窗口并设置标题:3. 设置窗口的宽度和高度:4. 创建一个Canvas,它将用于显示视频帧:5. 初始化一个视频流变量cap,用于存储OpenCV的视频捕获对象&#xf…

Matlab要这样批量读取txt数据!科研效率UpUp第10期

假如我们有多组txt格式的数据: 其数据格式是这样的: 想要批量读取这些数据,并把他们画在一张图上,该怎么操作呢? ​之前有分享load函数的版本,本期进一步分享适用性更强的readtable函数的实现方法​。 首…

工业的物联网在构建弹性供应链系统中的作用

物联网 (IoT) 可以显着提高供应链系统的效率,因为物联网处理设备之间的连接。简而言之,物联网转化为“连接设备”,物联网的这种能力导致了智能系统或环境。物联网将这些设备与传感器和执行器连接起来,这些传感器和执行器收集数据并…

【计算机网络仿真】b站湖科大教书匠思科Packet Tracer——实验8 IPv4地址 — 分类地址

一、实验目的 1.验证分类IP地址的作用; 2.初步了解路由器的功能。 二、实验要求 1.使用Cisco Packet Tracer仿真平台; 2.观看B站湖科大教书匠仿真实验视频,完成对应实验。 三、实验内容 1.构建网络拓扑; 2.修改网络拓扑&…

原创作品—工业软件界面设计作品

在工业4.0时代,界面设计不仅要追求美观,更要以用户体验为核心。通过简化操作流程、优化交互逻辑,降低用户的学习成本,提高使用效率。这样的设计能够为企业数字化转型提供有力支持,增强用户对产品的黏性。 数字化转型的…

云盘高速视觉检测机如何提升螺丝尺寸检测效率?

螺丝,一种用来连接和固定物体的金属件,通常是长有螺纹的金属棒。螺丝有不同种类和尺寸,常见的用途包括组装家具、机械设备和其他结构。连接和固定物体,通过螺丝的螺纹结构,将两个或多个物体牢固地连接在一起。提供调节…

LabVIEW与C#相互调用dll

C#调用LabVIEW创建的dll 我先讲LabVIEW创建自己的.net类库的方法吧,重点是创建,C#调用的步骤,大家可能都很熟悉了。 1、创建LabVIEW项目,并创建一个简单的add.vi,内容就是abc,各个接线端都正确连接就好。 …

一种改进解卷积算法在旋转机械故障诊断中的应用(MATLAB)

轴承振动是随机振动。在不同的时刻,轴承振动值是不尽相同的,不能用一个确定的时间函数来描述。这是由于滚动体除了有绕轴承公转运动以外,还有绕自身轴线的自旋运动,且在轴承运转时,滚动接触表面形貌是不断变化的&#…

大脑网路分析的进展:基于大规模自监督学习的诊断| 文献速递-先进深度学习疾病诊断

Title 题目 BrainMass: Advancing Brain Network Analysis for Diagnosis with Large-scale Self-Supervised Learning 大脑网路分析的进展:基于大规模自监督学习的诊断 01 文献速递介绍 功能性磁共振成像(fMRI)利用血氧水平依赖&#x…

颠覆传统!支持70+三维格式转换,3D模型格式转换在线即可一键处理!

老子云自研AMRT展示框架及三维格式具有广泛兼容性,同时还会用户提供了3D格式在线转换工具,支持实现70三维格式模型的快速处理和转换。 你是不是也遇到过这种情况:做了半天的3D模型图,好不容易弄好了,到最后插入的时候居…

“硝烟下的量子”:以色列为何坚持让量子计算中心落地?

自2023年10月7日新一轮巴以冲突爆发以来,支持巴勒斯坦伊斯兰抵抗运动(哈马斯)的黎巴嫩真主党不时自黎巴嫩南部向以色列北部发动袭击,以军则用空袭和炮击黎南部目标进行报复,双方在以黎边境的冲突持续至今。 冲突走向扑…

炎炎夏日,矿物质水为你防暑补水

炎炎夏日,整座城市如同一个巨大的“烤箱” 人们行走在炙热烈阳中 汗如雨下,口干舌燥 在这样的高温天气中 中暑的风险也随之增加 烈日当头的夏天 该如何预防中暑呢? 或许答案藏在一杯矿物质水中 为什么矿物质水能够预防中暑?…

AlertDialog和Dialog的区别

在安卓开发过程中,Dialog是我们常用的UI组件之一,它主要用来显示提示信息、与用户进行交互等。在Android中,Dialog有很多种类,其中最常见的就是AlertDialog和普通的Dialog。本文将详细介绍这两者之间的区别,并通过示例…