魔改并封装 YoloV5 Version7 的 detect.py 成 API接口以供 python 程序使用

news2025/1/10 22:35:18

文章目录

  • Introduction
  • Section 1 起因
  • Section 2 魔改的思路
  • Section 3 代码
    • Part 1 参数部分
    • Part 2 识别 API
    • Part 3 完整的 `DetectAPI.py`
    • Part 4 修改 `dataloaders.py`
  • Section 4 调用
  • Reference

Introduction

YoloV5 作为 YoloV4 之后的改进型,在算法上做出了优化,检测的性能得到了一定的提升。其特点之一就是权重文件非常的小,可以在一些配置更低的移动设备上运行,且提高速度的同时准确度更高。具体的性能见下图[^1]。本次使用的是最新推出的 YoloV5 Version7 版本。
GitHub 地址:YOLOv5 🚀 是世界上最受欢迎的视觉 AI,代表 Ultralytics 对未来视觉 AI 方法的开源研究,结合在数千小时的研究和开发中积累的经验教训和最佳实践。

YOLOV5 v7.0 SOTA Realtime Instance Segmentation


Section 1 起因

本人目前的一个项目需要使用到手势识别,得益于 YoloV5 的优秀的识别速度与准确率,因此识别部分的模型均使用 YoloV5 Version7 版本进行训练。训练之后需要使用这个模型,原始的 detect.py 程序使用 argparse 对参数进行封装,这为初期验证模型提供了一定的便利,我们可以通过 Pycharm 或者 Terminal 来快速地执行程序,然后在 run/detect 路径下快速地查看到结果。但是在实际的应用中,识别程序往往是作为整个系统的一个组件来运行的,现有的 detect.py 无法满足使用需求,因此需要将其封装成一个可供多个程序调用的 API 接口。通过这个接口可以获得 种类、坐标、置信度 这三个信息。通过这些信息来控制系统软件做出对应的操作。


Section 2 魔改的思路

这部分的代码与思路参照了[^2] 爆改YOLOV7的detect.py制作成API接口供其他python程序调用(超低延时) 这篇文章的思路。由于 YoloV5 和 YoloV7 的程序有些许不一样,因此做了一些修改。

大体的思路是去除掉 argparse 部分,通过类将参数封装进去,去除掉识别这个核心功能之外的其它功能。


Section 3 代码

Part 1 参数部分

需要传入一些常用的参数,后面的 API 会使用到这个类里面的参数

class YoloOpt:
    def __init__(self, weights='weights/last.pt',
                 imgsz=(640, 640), conf_thres=0.25,
                 iou_thres=0.45, device='cpu', view_img=False,
                 classes=None, agnostic_nms=False,
                 augment=False, update=False, exist_ok=False,
                 project='/detect/result', name='result_exp',
                 save_csv=True):
        self.weights = weights  # 权重文件地址
        self.source = None  # 待识别的图像
        if imgsz is None:
            self.imgsz = (640, 640)
        self.imgsz = imgsz  # 输入图片的大小,默认 (640,640)
        self.conf_thres = conf_thres  # object置信度阈值 默认0.25  用在nms中
        self.iou_thres = iou_thres  # 做nms的iou阈值 默认0.45   用在nms中
        self.device = device  # 执行代码的设备,由于项目只能用 CPU,这里只封装了 CPU 的方法
        self.view_img = view_img  # 是否展示预测之后的图片或视频 默认False
        self.classes = classes  # 只保留一部分的类别,默认是全部保留
        self.agnostic_nms = agnostic_nms  # 进行NMS去除不同类别之间的框, 默认False
        self.augment = augment  # augmented inference TTA测试时增强/多尺度预测,可以提分
        self.update = update  # 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
        self.exist_ok = exist_ok  # 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
        self.project = project  # 保存测试日志的参数,本程序没有用到
        self.name = name  # 每次实验的名称,本程序也没有用到
        self.save_csv = save_csv  # 是否保存成 csv 文件,本程序目前也没有用到

Part 2 识别 API

class DetectAPI:
    def __init__(self, weights, imgsz=640):
        self.opt = YoloOpt(weights=weights, imgsz=imgsz)
        weights = self.opt.weights
        imgsz = self.opt.imgsz

        # Initialize 初始化
        # 获取设备 CPU/CUDA
        self.device = select_device(self.opt.device)
        # 不使用半精度
        self.half = self.device.type != 'cpu'  # # FP16 supported on limited backends with CUDA

        # Load model 加载模型
        self.model = DetectMultiBackend(weights, self.device, dnn=False)
        self.stride = self.model.stride
        self.names = self.model.names
        self.pt = self.model.pt
        self.imgsz = check_img_size(imgsz, s=self.stride)

        # 不使用半精度
        if self.half:
            self.model.half() # switch to FP16

        # read names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def detect(self, source):
        # 输入 detect([img])
        if type(source) != list:
            raise TypeError('source must a list and contain picture read by cv2')

        # DataLoader 加载数据
        # 直接从 source 加载数据
        dataset = LoadImages(source)
        # 源程序通过路径加载数据,现在 source 就是加载好的数据,因此 LoadImages 就要重写
        bs = 1 # set batch size

        # 保存的路径
        vid_path, vid_writer = [None] * bs, [None] * bs

        # Run inference
        result = []
        if self.device.type != 'cpu':
            self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(
                next(self.model.parameters())))  # run once
        dt, seen = (Profile(), Profile(), Profile()), 0

        for im, im0s in dataset:
            with dt[0]:
                im = torch.from_numpy(im).to(self.model.device)
                im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim

                # Inference
                pred = self.model(im, augment=self.opt.augment)[0]

                # NMS
                with dt[2]:
                    pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, self.opt.classes, self.opt.agnostic_nms, max_det=2)

                # Process predictions
                # 处理每一张图片
                det = pred[0]  # API 一次只处理一张图片,因此不需要 for 循环
                im0 = im0s.copy()  # copy 一个原图片的副本图片
                result_txt = []  # 储存检测结果,每新检测出一个物品,长度就加一。
                                 # 每一个元素是列表形式,储存着 类别,坐标,置信度
                # 设置图片上绘制框的粗细,类别名称
                annotator = Annotator(im0, line_width=3, example=str(self.names))
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    # 映射预测信息到原图
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                    # 
                    for *xyxy, conf, cls in reversed(det):
                        line = (int(cls.item()), [int(_.item()) for _ in xyxy], conf.item())  # label format
                        result_txt.append(line)
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        annotator.box_label(xyxy, label, color=self.colors[int(cls)])
                result.append((im0, result_txt))  # 对于每张图片,返回画完框的图片,以及该图片的标签列表。
            return result, self.names

Part 3 完整的 DetectAPI.py

import argparse
import os
import platform
import random
import sys
from pathlib import Path

import torch
from torch.backends import cudnn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode, time_sync

"""
使用面向对象编程中的类来封装,需要去除掉原始 detect.py 中的结果保存方法,重写
保存方法将结果保存到一个 csv 文件中并打上视频的对应帧率

"""


class YoloOpt:
    def __init__(self, weights='weights/last.pt',
                 imgsz=(640, 640), conf_thres=0.25,
                 iou_thres=0.45, device='cpu', view_img=False,
                 classes=None, agnostic_nms=False,
                 augment=False, update=False, exist_ok=False,
                 project='/detect/result', name='result_exp',
                 save_csv=True):
        self.weights = weights  # 权重文件地址
        self.source = None  # 待识别的图像
        if imgsz is None:
            self.imgsz = (640, 640)
        self.imgsz = imgsz  # 输入图片的大小,默认 (640,640)
        self.conf_thres = conf_thres  # object置信度阈值 默认0.25  用在nms中
        self.iou_thres = iou_thres  # 做nms的iou阈值 默认0.45   用在nms中
        self.device = device  # 执行代码的设备,由于项目只能用 CPU,这里只封装了 CPU 的方法
        self.view_img = view_img  # 是否展示预测之后的图片或视频 默认False
        self.classes = classes  # 只保留一部分的类别,默认是全部保留
        self.agnostic_nms = agnostic_nms  # 进行NMS去除不同类别之间的框, 默认False
        self.augment = augment  # augmented inference TTA测试时增强/多尺度预测,可以提分
        self.update = update  # 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
        self.exist_ok = exist_ok  # 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
        self.project = project  # 保存测试日志的参数,本程序没有用到
        self.name = name  # 每次实验的名称,本程序也没有用到
        self.save_csv = save_csv  # 是否保存成 csv 文件,本程序目前也没有用到


class DetectAPI:
    def __init__(self, weights, imgsz=640):
        self.opt = YoloOpt(weights=weights, imgsz=imgsz)
        weights = self.opt.weights
        imgsz = self.opt.imgsz

        # Initialize 初始化
        # 获取设备 CPU/CUDA
        self.device = select_device(self.opt.device)
        # 不使用半精度
        self.half = self.device.type != 'cpu'  # # FP16 supported on limited backends with CUDA

        # Load model 加载模型
        self.model = DetectMultiBackend(weights, self.device, dnn=False)
        self.stride = self.model.stride
        self.names = self.model.names
        self.pt = self.model.pt
        self.imgsz = check_img_size(imgsz, s=self.stride)

        # 不使用半精度
        if self.half:
            self.model.half() # switch to FP16

        # read names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def detect(self, source):
        # 输入 detect([img])
        if type(source) != list:
            raise TypeError('source must a list and contain picture read by cv2')

        # DataLoader 加载数据
        # 直接从 source 加载数据
        dataset = LoadImages(source)
        # 源程序通过路径加载数据,现在 source 就是加载好的数据,因此 LoadImages 就要重写
        bs = 1 # set batch size

        # 保存的路径
        vid_path, vid_writer = [None] * bs, [None] * bs

        # Run inference
        result = []
        if self.device.type != 'cpu':
            self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(
                next(self.model.parameters())))  # run once
        dt, seen = (Profile(), Profile(), Profile()), 0

        for im, im0s in dataset:
            with dt[0]:
                im = torch.from_numpy(im).to(self.model.device)
                im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim

                # Inference
                pred = self.model(im, augment=self.opt.augment)[0]

                # NMS
                with dt[2]:
                    pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, self.opt.classes, self.opt.agnostic_nms, max_det=2)

                # Process predictions
                # 处理每一张图片
                det = pred[0]  # API 一次只处理一张图片,因此不需要 for 循环
                im0 = im0s.copy()  # copy 一个原图片的副本图片
                result_txt = []  # 储存检测结果,每新检测出一个物品,长度就加一。
                                 # 每一个元素是列表形式,储存着 类别,坐标,置信度
                # 设置图片上绘制框的粗细,类别名称
                annotator = Annotator(im0, line_width=3, example=str(self.names))
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    # 映射预测信息到原图
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                    #
                    for *xyxy, conf, cls in reversed(det):
                        line = (int(cls.item()), [int(_.item()) for _ in xyxy], conf.item())  # label format
                        result_txt.append(line)
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        annotator.box_label(xyxy, label, color=self.colors[int(cls)])
                result.append((im0, result_txt))  # 对于每张图片,返回画完框的图片,以及该图片的标签列表。
            return result, self.names

Part 4 修改 dataloaders.py

文件路径在 utils/dataloaders.py ,修改其中的 LoadImages 类,将下面的代码完整替换掉就可以了。

class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32):

       for img in path:
           if type(img) != np.ndarray or len(img.shape) != 3:
               raise TypeError('item is not a picture read by cv2')

       self.img_size = img_size
       self.stride = stride
       self.files = path
       self.nf = len(path)
       self.mode = 'image'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        # Read image
        self.count += 1

        # Padded resize
        img = letterbox(path, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img, path

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files

Section 4 调用

if __name__ == '__main__':
    cap = cv2.VideoCapture(0)
    a = DetectAPI.DetectAPI(weights='weights/last.pt')
    with torch.no_grad():
        while True:
            rec, img = cap.read()
            result, names = a.detect([img])
            img = result[0][0]  # 每一帧图片的处理结果图片
            # 每一帧图像的识别结果(可包含多个物体)
            for cls, (x1, y1, x2, y2), conf in result[0][1]:
                print(names[cls], x1, y1, x2, y2, conf)  # 识别物体种类、左上角x坐标、左上角y轴坐标、右下角x轴坐标、右下角y轴坐标,置信度
                '''
                cv2.rectangle(img,(x1,y1),(x2,y2),(0,255,0))
                cv2.putText(img,names[cls],(x1,y1-20),cv2.FONT_HERSHEY_DUPLEX,1.5,(255,0,0))'''
            print()  # 将每一帧的结果输出分开
            cv2.imshow("vedio", img)

            if cv2.waitKey(1) == ord('q'):
                break

实测效果


Reference

本程序的修改参考了以下的资料,在此为前人做出的努力与贡献表示感谢!

https://github.com/ultralytics/yolov5/releases/tag/v7.0
https://blog.csdn.net/weixin_51331359/article/details/126012620
https://blog.csdn.net/CharmsLUO/article/details/123422822

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/350131.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

errgroup 原理简析

golang.org/x/sync/errgroup errgroup提供了一组并行任务中错误采集的方案。 先看注释 Package errgroup provides synchronization, error propagation, and Context cancelation for groups of goroutines working on subtasks of a common task. Group 结构体 // A Gro…

Sphinx : 高性能SQL全文检索引擎

Sphinx是一款基于SQL的高性能全文检索引擎,Sphinx的性能在众多全文检索引擎中也是数一数二的,利用Sphinx,我们可以完成比数据库本身更专业的搜索功能,而且可以有很多针对性的性能优化。 Sphinx的特点 快速创建索引:3分…

Barra模型因子的构建及应用系列三之Momentum因子

一、摘要 在之前的Barra模型系列文章中,我们已经初步讲解、构建了Size因子和Beta因子,并分别创建了对应的单因子策略。通过回测发现,其中Size因子的小市值效应具有很强的收益能力。而本篇文章将在该系列下进一步构建Momentum因子。 二、模型…

90%企业在探索的敏捷开发怎么做?极狐GitLab总结了这些逻辑与流程

本文来自: 彭亮 极狐(GitLab) 高级产品经理 毛超 极狐(GitLab) 研发工程师 极狐(GitLab) 市场部内容团队 “敏捷” 是指能够驾驭变化,保持组织竞争优势的一种能力。自 2001 年《敏捷宣言》以来,敏捷及敏捷开发理念逐渐席卷全球。中国信通院《…

面试已上岸,成功拿到阿里和腾讯的入职offer,Java程序员面经全在这了,希望能帮到你!

前言 一开始的时候简历海投大多数都被拒绝了,后来自己找在腾讯上班的朋友帮忙改了一下简历,果然不一样了大多都能拿到面试机会,当然拿到后也没有那么顺利,面了差不多有十几家公司的样子,大大小小的都有,其中…

C++和QML混合编程_QML发送信号到C++端(信号和槽绑定)

C和QML混合编程_QML发送信号到C端(信号和槽绑定) 前言: 下面是之前讲解过的三种方法 1、使用Q_INVOKABLE声明一下普通函数,在QML端可以直接调用 2、使用Connections绑定QML的信号和C端的槽函数 3、使用connect绑定QML的信号和C端的…

通俗易懂理解——布隆过滤器

文章目录概述本质优缺点优点:缺点:实际应用解决redis缓存穿透问题:概述 本质 本质:很长的二进制向量(数组) 主要作用:判断一个数据在这个数组中是否存在,如果不存在为0&#xff0c…

NR PDCP duplication

欢迎关注同名微信公众号“modem协议笔记”。 PDCP duplication 是PDCP 的一个功能,主要是为满足URLLC 场景的可靠性/延迟要求,而产生的一种提高传输可靠性的机制,具体就是在信号状况比较差的情况下,网络侧通过配置PDCP duplicati…

集中式存储和分布式存储

分布式存储是相对于集中式存储来说的,在介绍分布式存储之前,我们先看看什么是集中式存储。不久之前,企业级的存储设备都是集中式存储。所谓集中式存储,从概念上可以看出来是具有集中性的,也就是整个存储是集中在一个系…

Zynq非Video Mixer方案实现视频叠加输出,无需SDK配置,提供工程源码和技术支持

目录1、前言2、Video Mixer的不便之处3、FDMA取代Video Mixer实现视频叠加输出4、Vivado工程详解5、上板调试验证并演示6、福利:工程代码的获取1、前言 关于Zynq使用Video Mixer方案实现视频叠加输出方案请参考点击查看:Video Mixer方案 对于Zynq和Micr…

Elasticsearch:Security API 介绍

在我之前的文章 “Elasticsearch:运用 API 创建 roles 及 users” ,我展示了如何使用 Security API 来创建用户及角色来控制访问 Elasticsearch 中的索引。在今天的文章中,我将展示一个使用 Security API 来创建一个用户及角色来访问一个索引…

双指针【灵神基础精讲】

来源0x3f:https://space.bilibili.com/206214 文章目录同向双指针[209. 长度最小的子数组](https://leetcode.cn/problems/minimum-size-subarray-sum/)[713. 乘积小于 K 的子数组](https://leetcode.cn/problems/subarray-product-less-than-k/)[3. 无重复字符的最…

计算机相关专业毕业论文选题推荐

计算机科学以下是我推荐的20个计算机科学专业的本科论文选题:基于机器学习的推荐算法研究与实现基于区块链技术的数字身份认证方案设计与实现基于深度学习的图像识别技术研究与应用基于虚拟现实技术的教育培训平台设计与实现基于物联网技术的智能家居系统研究与开发…

Dubbo与Spring Cloud优缺点分析(文档学习个人理解)

文章目录核心部件1、总体框架1.1 Dubbo 核心部件如下1.2 Spring Cloud 总体架构2、微服务架构核心要素3、通讯协议3.1 Dubbo3.2 Spring Cloud3.3 性能比较4、服务依赖方式4.1 Dubbo4.2 Spring Cloud5、组件运行流程5.1 Dubbo5.2 Dubbo 运行组件5.3 Spring Cloud5.4 Spring Clou…

[数据治理-02]一个例子搞懂元数据、参考数据、主数据、交易数据...的关系

杜威说过“所有知识都是分类”!很好理解,分类是认知经济,任何有效分类,都可以极大地节省我们的认知精力。谈到数据就必须做个分类,谈到数据分类可以从多个维度出发,比如按业务维度、这是财务数据、那是人力…

C++ ——多态 下 (图解多态原理、虚函数的再认知)

目录 一、抽象类 1)抽象类定义 2)抽象类的继承 3)抽象类实现多态 4)抽象类的好处 二、多态的实现原理 1)虚函数的存储方式 2)子类中虚函数的存储方式 ① 子类将基类中的虚表原封不动的拷贝到自己的…

【原创】java+swing+mysql教师管理系统设计与实现

教师管理系统主要是方便学校对教师进行管理,本文主要介绍如何使用java的swing窗体控件和mysql数据库去设计一个简单的教师管理系统。 功能分析: 本系统为javaswingmysql的教师管理系统,管理员、教师 功能如下: 管理员&#xff…

Quartz入门教程

本文参考文章编写 Quartz 官网 Quartz 是 OpenSymphony 开源组织在 Job Scheduling 领域又一个开源项目,是完全由 Java 开发的一个开源任务日程管理系统,“任务进度管理器”就是一个在预先确定(被纳入日程)的时间到达时&#xff…

2022——寒假总结

文章目录背景报名摸索结果总结背景 大一上学期,刚上大学没有尽快适应,什么都没有学到。 因为疫情,所以平时的测试以及期末都是线上进行的,就没怎么认真学,网课直接划水。 我的生活与学习很不平衡,还热衷于参…

搭建hadoop高可用集群(二)

搭建hadoop高可用集群(一)配置hadoophadoop-env.shworkerscore-site.xmlhdfs-site.xmlmapred-site.xmlyarn-site.xml/etc/profile拷贝集群首次启动1、先启动zk集群(自动化脚本)2、在hadoop151,hadoop152,hadoop153启动JournalNode…