YOLOv5源码中的参数超详细解析(4)— 推理部分detect.py

news2024/11/6 7:25:48

前言:Hello大家好,我是小哥谈。YOLOv5是一种先进的目标检测算法,它可以实现快速和准确的目标检测。detect.py是YOLOv5项目目录结构中的一个重要的脚本文件,它用于执行目标检测任务,可以通过命令行参数指定要检测的图像或视频文件,以及模型文件的路径。它还可以指定检测的置信度阈值和非极大值抑制(NMS)的阈值,以控制检测结果的准确率和召回率。在运行过程中,detect.py会将检测结果保存为JSON格式的文件,并在图像或视频上绘制出检测框和类别标签。🌈

 前期回顾:

             YOLOv5源码中的参数超详细解析(1)— 项目目录结构解析

             YOLOv5源码中的参数超详细解析(2)— 配置文件yolov5s.yaml

             YOLOv5源码中的参数超详细解析(3)— 训练部分train.py 

             目录

🚀1.detect.py的主要内容

🚀2.detect.py的主要作用

🚀3.detect.py的代码详解

💥💥3.1 导包和基本配置

💥💥3.2 执行main函数

💥💥3.3 设置opt参数

💥💥3.4 执行run函数

💞💞💞3.4.1 载入参数

💞💞​​​​​​​💞3.4.2 初始化配置

💞​​​​​​​💞​​​​​​​💞3.4.3 保存结果

💞​​​​​​​💞​​​​​​​💞3.4.4 加载模型

💞​​​​​​​💞​​​​​​​💞3.4.5 加载数据

💞​​​​​​​💞​​​​​​​💞3.4.6 推理部分

💞​​​​​​​💞​​​​​​​💞3.4.7 打印结果

🚀1.detect.py的主要内容

detect.py是一个代码文件,用于使用YOLOv5模型进行目标检测。在该文件中,有几个主要的函数和模块的定义和使用。🍃

首先,在函数 parse_opt() 中,解析了命令行参数。

然后,在函数 main() 中,调用函数 run() 来执行目标检测的主要逻辑。在函数 run() 中,首先进行了配置的初始化,然后加载数据进行预处理,接着进行目标检测的输入和NMS操作,最后保存结果并进行打印。

另外,在 detect.py 中还有一些自定义模块的定义和使用,包括

  • models/common.py:定义了一些通用的类模块,比如各种卷积模块;
  • utils.dataloaders.py:定义了加载图像或视频帧并进行预处理的类;
  • utils.general.py:定义了一些工具函数,比如日志、坐标转换等;
  • utils.plot.py:用于画图和标框;
  • utils.torch_utils.py:定义了与pytorch相关的工具函数,比如设备选择等。

最后,detect.py中还包含了一些代码的注释和详解,以及使用教程。🍃

等等...🍉 🍓 🍑 🍈 🍌 🍐

总的来说,detect.py是一个用于目标检测的代码文件,使用了YOLOv5算法模型并包含了一些自定义模块和工具函数的定义和使用。通过解析命令行参数,加载数据并进行预处理,执行目标检测操作,并保存结果和打印相关信息。🍃

说明:♨️♨️♨️

YOLOv5官方代码下载地址:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite


🚀2.detect.py的主要作用

YOLOv5是当前目标检测领域非常流行的算法。在YOLOv5项目目录结构中,detect.py是一个非常重要的文件,该文件用于验证自己训练出来的模型在实验中的表现情况。🍃

说明:♨️♨️♨️

关于YOLOv5项目目录结构的详细信息,请参考我的另外一篇文章。

YOLOv5源码中的参数超详细解析(1)— 项目目录结构解析

detect.py文件的作用主要包括以下部分:👇

🍀参数设置

在detect.py文件中,首先进行的是参数设置,包括批处理大小、输入图像分辨率、模型权重、类别文件等等。参数设置的目的是为了使模型在识别目标时能够准确匹配到所有的类别,而且输入分辨率和批处理大小也需要合理设置,以达到最优的性能。

🍀模型载入

导入模型文件和权重文件。这个步骤是比较关键的一步,模型的效果和准确性很大程度上受制于模型的训练数据和权重参数,这个步骤的任务就是导入模型和权重, 使得模型具备识别目标的能力,同时也让模型能够通过图片进行预测。

🍀图片预处理

detect.py文件在进行图片预处理的时候,分别进行了两个操作,一是对输入的图片进行缩放,二是对图像进行中心切割。这样可以确保输入模型的图片大小和比例都是统一的。

🍀目标识别

通过前面的数据预处理,我们得到了输入模型的图片,这时模型会根据图片中的像素信息和感受野对目标进行识别,并对每个目标产生一个置信度、类别和边界框等信息。这个过程是非常关键的步骤,也是模型性能的重要指标之一。

🍀输出目标

在目标识别结束后,detect.py文件会输出识别的结果,结果包括目标的类别、置信度和边界框信息。同时,它也会把识别结果可视化为一张图片,以便人类进行直观观察和判断。

总体而言,YOLOv5中的detect.py文件非常关键,它是识别目标的关键一步,通过对输入数据进行预处理,并结合模型进行目标识别和结果输出,最终得到的结果具备准确性和实用性,为用户提供了高效、精确的目标检测服务。


🚀3.detect.py的代码详解

💥💥3.1 导包和基本配置

"=================导入安装好的python库================="
import argparse   # 解析命令行参数的库
import os        # 与操作系统进行交互的文件库,包含文件路径操作和解析
import platform  # platform是用来获取操作系统的信息的模块
import sys       # sys模块包含了与python的解释器和它的环境有关的函数
from pathlib import Path  # path能够更加方便地对字符串路径进行处理

import torch    # pytorch,深度学习库
import torch.backends.cudnn as cudnn # 让内置的cudnn的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的目的。

首先需要导入的是常用python库

argparse:它是一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息。

os: 它提供了多种操作系统的接口。通过os模块提供的操作系统接口,我们可以对操作系统里文件、终端、进程等进行操作。

platform:是用来获取操作系统的信息的模块。

sys: 它是与python解释器交互的一个接口,该模块提供对解释器使用或维护的一些变量的访问和获取,它提供了许多函数和变量来处理 python 运行时环境的不同部分。

pathlib: 这个库提供了一种面向对象的方式来与文件系统交互,可以让代码更简洁、更易读。

torch: 这是主要的Pytorch库。它提供了构建、训练和评估神经网络的工具。

torch.backends.cudnn: 它提供了一个接口,用于使用cuDNN库,在NVIDIA GPU上高效地进行深度学习。cudnn模块是一个Pytorch库的扩展。

"=================获取当前文件的绝对路径================="
FILE = Path(__file__).resolve()  # __file__指的是当前文件(即detect.py),FILE最终保存着当前文件的绝对路径。
ROOT = FILE.parents[0]  # YOLOv5 root directory  Root保存着当前项目的父目录。
if str(ROOT) not in sys.path:   # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径。
    sys.path.append(str(ROOT))  # add ROOT to PATH 把Root添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative ROOT设置为相对路径

这段代码会获取当前文件的绝对路径,并使用Path库将其转换为Path对象。📚

这一部分的主要作用有两个

  • 将当前项目添加到系统路径上,以使得项目中的模块可以调用。
  • 将当前项目的相对路径保存在ROOT中,便于寻找项目中的文件。
"=================加载自定义模块================="
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode

最后则是一些自定义模块,其中主要包括了:

models/common.py:定义了一些通用的类模块,比如各种卷积模块。

utils.dataloaders.py:这个文件定义了两个类,LoadImages 和 LoadStreams,它们可以加载图像或视频帧,并对它们进行一些预处理,以便进行物体检测或识别。

utils.general.py:定义一些工具函数,比如日志、坐标转换等。

utils.plots.py:画图,标框。

utils.torch_utils.py:定义了一些与pytorch相关的工具函数,比如设备选择等。

通过导入这些模块,可以减少代码的复杂度、耦合性、冗余程度。🌳

💥💥3.2 执行main函数

"=================设置main函数================="
def main(opt):
    # 检查环境/打印参数,主要是 requirement.txt的包是否安装,用彩色显示设置的参数
    check_requirements(exclude=('tensorboard', 'thop'))
    # 执行run()函数
    run(**vars(opt))

# 命令使用
if __name__ == "__main__":
    opt = parse_opt()  # 解析参数
    main(opt)      # 执行主函数

这是程序的主函数。它调用了 check_requirements()函数  run()函数,并将命令行参数 opt 转换为字典作为参数传递给 run() 函数。📚

说明:♨️♨️♨️

if __name__ == "__main__":的作用:
一个python文件通常有两种使用方法,一是作为脚本直接执行,二是 import 到其他的 python 脚本中被调用(模块重用)执行。因此 if __name__ == "__main__":的作用就是控制这两种情况执行代码的过程,在if __name__ == "__main__":下的代码只有在第一种情况下(即文件作为脚本直接执行)才会被执行,而 import 到其他脚本中是不会被执行的。

💥💥3.3 设置opt参数

执行main函数时用到 parse_opt() 这个函数,它的功能主要是为模型进行推理时提供参数,在parse_opt()执行完成之后,会将opt传给函数main()。

"=================parse_opt()用来设置输入参数的子函数================="
def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default= ROOT / r'C:\Users\Lenovo\PycharmProjects\yolov5-master-mookcake\runs\train\exp17\weights\best.pt', help='model path(s)')
    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt))
    return opt

这段代码是一个python脚本的一个函数,用于解析命令行参数并返回这些参数的值。其主要功能是为模型进行推理的时候提供参数。📚

下面是每个参数的作用和默认值:👇

'--weights--' :指定权重文件的路径,默认是yolov5s.pt,可以使用自己训练的权重,也可以使用官网提供的权重,下载好后放在根目录就好。默认官网的权重yolov5s.pt (yolov5n.pt/yolov5s.pt/yolov5m.pt/yolov5l.pt/yolov5x.pt/区别在于网络的宽度和深度依次增加)

'--source':指定网络输入的测试数据的路径文件夹,可以是图片/视频路径,也可以是'0'(电脑自带摄像头),也可以是rtsp等视频流,也可以指定具体的文件或者扩展名。默认是 data/images 文件夹,测试的时候默认测试此文件夹下的图片。

'--data' :配置数据的文件路径,默认为COCO128数据集的配置文件路径,包括数据集的下载路径和一些基本信息,在预测时如果不自己指定数据集,系统会自己下载coco128数据集。

'--imgsz' :模型在检测图片前会把图片resize成640 × 640的尺寸,然后再输入进网络里,即预测时网络输入图片的尺寸,默认为640 × 640。

'--conf-thres' :置信度阈值,默认为 0.25。表示预测置信度大于0.25的值才会被框选出来。

说明:♨️♨️♨️

置信度:指网络对检测出来的目标正确的相信程度。当参数设置为0时,网络只要认为检测目标有一丢丢的正确,就会被框选出来。

'--iou-thres' :非极大抑制时的 IoU 阈值,默认为 0.45。

说明:♨️♨️♨️

关于非极大值抑制原理、解析等知识内容,请参考我的另外一篇文章:

YOLOv5基础知识入门(7)— NMS(非极大值抑制)原理解析

'--max-det' :保留的最大检测框(检测目标)数量,每张图片中检测目标的个数最多为1000类。

'--device' :预测时使用的设备,按照自己需求可以选择GPU或CPU。填写的是cuda 设备的 ID(例如 0,1,2,3或者是 'cpu'),显卡编号可以使用nvidia-smi指令来查看。

'--view-img':是否展示预测之后的图片/视频,默认为False。

'--save-txt' :是否将预测的框坐标以txt文件形式保存,默认为False。

'--save-conf':是否将检测结果的置信度保存起来,并保存成.txt 格式,默认为False。

说明:♨️♨️♨️

必须和'--save-txt' 配合使用,单独使用不报错,但是也没有效果。

'--save-crop' :是否把模型检测的结果裁剪下来,并保存在crops文件夹下,默认为False。

'--nosave'  :不保存图片、视频等预测结果。不设置--nosave 在runs/detect/exp*/会出现预测的结果,若设置了--nosave,则只会产生空文件夹,文件夹里无任何预测结果。

'--classes'  :根据类别编号,仅检测指定类别,检测类别可以多个。

 '--agnostic-nms' :是否使用类别不敏感的非极大抑制(即不考虑类别信息),默认为False。

'--arugment' :是否使用数据增强进行推理,默认为False。

'--visualize' :是否可视化特征图,默认为False。

'--update' :对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False。

'--project' :预测结果保存的项目目录路径,默认为 'ROOT/runs/detect'。

'--name' :预测结果保存的子目录名称,默认为 'exp'。

'--exist-ok':是否覆盖已有结果,默认为False。若指定了此参数,预测的结果保存在上一次保存的文件夹中,若不指定,每次预测结果则会保存一个新的文件夹中。

'--line-thickness':画 bounding box (检测框)时的线条宽度,默认为 3,数值越大,线条越粗(过粗会遮挡检测目标)。

'--hide-labels' :是否隐藏标签信息,默认为False。

'--hide-conf' :是否隐藏置信度信息,默认为False。

'--half':是否使用 FP16 半精度进行推理,默认为 False。

'--dnn​​​':是否使用 OpenCV DNN 进行 ONNX 推理,默认为 False。

💥💥3.4 执行run函数

main()函数中调用了run()函数,run()函数主要分为七个部分

🍀(1)载入参数

🍀(2)初始化配置

🍀(3)保存结果

🍀(4)加载模型的权重

🍀(5)加载待预测的数据

🍀(6)执行模型的推理过程

🍀(7)打印输出信息。

接下来让我们按照执行流程来依次解析run函数吧!!!

💞💞💞3.4.1 载入参数

"=================载入参数================="
@smart_inference_mode()
def run(
        weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
):

这段代码定义了run()函数,并设置了一系列参数,用于指定物体检测或识别的相关参数。📚

这里介绍下这些参数:

weights:模型权重文件的路径,默认为YOLOv5s的权重文件路径。

source:输入图像或视频的路径或URL,或者使用数字0指代摄像头,默认为YOLOv5自带的测试图像文件夹。

data:数据集文件的路径,默认为COCO128数据集的配置文件路径。

imgsz:输入图像的大小,默认为640x640。

conf_thres:置信度阈值,默认为0.25。

iou_thres:非极大值抑制的IoU阈值,默认为0.45。

max_det:每张图像的最大检测框数,默认为1000。

device:使用的设备类型,默认为空,表示自动选择最合适的设备。

view_img:是否在屏幕上显示检测结果,默认为False。

save_txt:是否将检测结果保存为文本文件,默认为False。

save_conf:是否在保存的文本文件中包含置信度信息,默认为False。

save_crop:是否将检测出的目标区域保存为图像文件,默认为False。

nosave:是否不保存检测结果的图像或视频,默认为False。

classes:指定要检测的目标类别,默认为None,表示检测所有类别。

agnostic_nms:是否使用类别无关的非极大值抑制,默认为False。

augment:是否使用数据增强的方式进行检测,默认为False。

visualize:是否可视化模型中的特征图,默认为False。

update:是否自动更新模型权重文件,默认为False。

project:结果保存的项目文件夹路径,默认为“runs/detect”。

name:结果保存的文件名,默认为“exp”。

exist_ok:如果结果保存的文件夹已存在,是否覆盖,默认为False,即不覆盖。

line_thickness:检测框的线条宽度,默认为3。

hide_labels:是否隐藏标签信息,默认为False,即显示标签信息。

hide_conf:是否隐藏置信度信息,默认为False,即显示置信度信息。

half:是否使用FP16的半精度推理模式,默认为False。

dnn:是否使用OpenCV DNN作为ONNX推理的后端,默认为False。

💞​​​​​​​💞​​​​​​​💞3.4.2 初始化配置

 "=================初始化配置================="
    # 输入的路径变成字符串
    source = str(source)
    # 是否保存图片和txt文件,如果nosave(传入的参数)为False且source的结尾不是txt则保存图片。
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    # 判断source是不是视频/图像文件路径
    # Path()提取文件名。suffix:最后一个组件的文件扩展名。
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    # 判断source是否是链接
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    # 判断source是否是摄像头
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    if is_url and is_file:
        # 返回文件。如果source是一个指向图片/视频的链接,则下载输入数据。
        source = check_file(source)  # download

这段代码主要用于处理输入来源,定义了一些布尔值区分输入是图片、视频、网络流还是摄像头。📚

首先将 source 转换为字符串类型,然后判断是否需要保存输出结果。如果 nosavesource 的后缀不是.txt,则会保存输出结果。

接着根据 source 的类型,确定输入数据的类型:

  • 若source的后缀是图像或视频格式之一,那么将is_file设置为True;
  • 若source以rtsp://、rtmp://、http://、https://开头,那么将is_url设置为True;
  • 若source是数字或以.txt结尾或是一个URL,那么将webcam设置为True;
  • 若source既是文件又是URL,那么会调用check_file函数下载文件。

💞​​​​​​​💞​​​​​​​💞3.4.3 保存结果

"=================保存结果================="
    # Directories
    # save_dir是保存运行结果的文件夹名,是通过递增的方式来命名的。
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    # 根据前面生成的路径创建文件夹
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

这段代码主要是用于创建保存输出结果的目录。创建一个新的文件夹exp(在runs文件夹下)来保存运行的结果。 📚

save_dir是保存运行结果的文件夹名,是通过递增的方式来命名的。第一次运行时路径是“runs\detect\exp”,第二次运行时路径是“runs\detect\exp1” 。

首先代码中的 project 指 run 函数中的 project,对应的是 runs/detect 的目录,name 对应 run 函数中的“name=exp”,然后进行拼接操作。使用increment_path函数来确保目录不存在,如果存在,则在名称后面添加递增的数字。

然后判断 save_txt 是否为 true,save_txt 在 run 函数以及 parse_opt() 函数中都有相应操作,如果传入save_txt,新建 “labels” 文件夹存储结果。如果目录已经存在,而exist_ok为False,那么会抛出一个异常,指示目录已存在。如果exist_ok为True,则不会抛出异常,而是直接使用已经存在的目录。

💞​​​​​​​💞​​​​​​​💞3.4.4 加载模型

"=================加载模型================="
    # Load model 加载模型
    # 获取设备的CPU/GPU
    device = select_device(device)
    # DetectMultiBackend定义在models.commom模块中,是我们要加载的网络。其中的weights参数就是输入时指定的权重文件(比如yolov5s.pt)。
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    '''
               stride:推理时所用到的步长,默认为32, 大步长适合于大目标,小步长适合于小目标
               names:保存推理结果名的列表,比如默认模型的值是['person', 'bicycle', 'car', ...] 
               pt: 加载的是否是pytorch模型(也就是pt格式的文件)
               jit:当某段代码即将第一次被执行时进行编译,因而叫“即时编译”
               onnx:利用Pytorch我们可以将model.pt转化为model.onnx格式的权重,在这里onnx充当一个后缀名称,
                     model.onnx就代表ONNX格式的权重文件,这个权重文件不仅包含了权重值,也包含了神经网络的网络流动信息以及每一层网络的输入输出信息和一些其他的辅助信息。
           '''
    # 确保输入图片的尺寸imgsz能整除stride=32 如果不能则调整为能被整除并返回
    imgsz = check_img_size(imgsz, s=stride)  # check image size

这段代码主要是用于选择设备、初始化模型和检查图像大小。📚

首先,调用select_device函数选择设备,如果device为空,则使用默认设备。

然后,使用DetectMultiBackend类来初始化模型,接着从模型中获取stride、names和pt等参数。

最后,调用check_img_size函数检查图像大小是否符合要求,如果不符合则进行调整。

💞​​​​​​​💞​​​​​​​💞3.4.5 加载数据

 "=================加载数据================="
    # Dataloader
    # Dataloader
    # 通过不同的输入源来设置不同的数据加载方式
    if webcam: # 使用摄像头作为输入
        view_img = check_imshow() # 检测cv2.imshow()方法是否可以执行,不能执行则抛出异常
        cudnn.benchmark = True  # set True to speed up constant image size inference 该设置可以加速预测
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt) # 加载输入数据流
        '''
                source:输入数据源;image_size 图片识别前被放缩的大小;stride:识别时的步长, 
                auto的作用可以看utils.augmentations.letterbox方法,它决定了是否需要将图片填充为正方形,如果auto=True则不需要
               '''
        bs = len(dataset)  # batch_size
    else: # 直接从source文件下读取图片
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
        bs = 1  # batch_size
    # 保存视频的路径
    vid_path, vid_writer = [None] * bs, [None] * bs # 前者是视频路径,后者是一个cv2.VideoWriter对象

这段代码通过输入的 source 参数来判断数据输入源(是摄像头还是从source文件下读取的)。 📚

若输入源是摄像头:使用 LoadStreams 加载视频流,并设置 cudnn.benchmark = True 以加速常量图像大小的推理。

若输入源是source文件下读取的(图片/视频):则使用 LoadImages 加载图像。bs:batch_size(批量大小),这里表示 1 或视频流中的帧数,vid_path 和 vid_writer 分别是视频路径和视频编写器,初始化为长度为 batch_size 的空列表。

💞​​​​​​​💞​​​​​​​💞3.4.6 推理部分

推理部分是整个算法的核心部分,通过 for 循环对加载的数据进行遍历,如果是视频流则一帧一帧地推理,然后进行NMS,最后画框,预测类别。

(1)热身部分

# Run inference
    if pt and device.type != 'cpu':
        # 使用空白图片(零矩阵)预先用GPU跑一遍预测流程,可以加速预测
        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
    dt, seen = [0.0, 0.0, 0.0], 0
    '''
     dt: 存储每一步骤的耗时
     seen: 计数功能,已经处理完了多少帧图片
    '''
    # 去遍历图片,进行计数,
    for path, im, im0s, vid_cap, s in dataset:
        '''
         在dataset中,每次迭代的返回值是self.sources, img, img0, None, ''
          path:文件路径(即source)
          im: resize后的图片(经过了放缩操作)
          im0s: 原始图片
          vid_cap=none
          s: 图片的基本信息,比如路径,大小
        '''
        # ===以下部分是做预处理===#
        t1 = time_sync() # 获取当前时间
        im = torch.from_numpy(im).to(device) # 将图片放到指定设备(如GPU)上识别。#torch.size=[3,640,480]
        im = im.half() if half else im.float()  # uint8 to fp16/32 # 把输入从整型转化为半精度/全精度浮点数。
        im /= 255  # 0 - 255 to 0.0 - 1.0 归一化,所有像素点除以255
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim 添加一个第0维。缺少batch这个尺寸,所以将它扩充一下,变成[1,3,640,480]
        t2 = time_sync() # 获取当前时间
        dt[0] += t2 - t1 # 记录该阶段耗时

这段代码让模型进行了一个预热,然后定义 dt、seen 两个变量,遍历 dataset ,整理图片信息。📚

热身操作,即对模型进行一些预处理以加速后续的推理过程。

作用:♨️♨️♨️

深度学习模型训练热身的作用是为了使初始权重更好地适应数据分布,提高最终模型的收敛速度和泛化能力。通过热身训练,可以有效减少梯度下降的震荡,加速收敛速度,并降低局部极小值的影响。

说简单点就是在模型训练初期给他一个较大的学习率,因为较大的学习率就不那么容易会使模型学偏,然后在训练的后期再减小学习率,使其收敛。

在这个阶段,还定义了一些变量,包括seen、windows 和 dt,分别表示已处理的图片数量、窗口列表和时间消耗列表。遍历dataset,整理图片信息。

接着是对数据集的图片进行预处理:

  • 将图片转化为tensor格式,放到device上,并转换为FP16/32。
  • 将像素值0 ~ 255归一化,变为0 ~ 1,并为批处理增加一维度(batch)。
  • 记录时间消耗并更新dt
     

(2)对每张图片/视频进行前向推理

# Inference
        # 可视化文件路径。如果为True则保留推理过程中的特征图,保存在runs文件夹中
        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
        # 推理结果,pred保存的是所有的bound_box的信息,
        pred = model(im, augment=augment, visualize=visualize) #模型预测出来的所有检测框,torch.size=[1,18900,85]
        t3 = time_sync()
        dt[1] += t3 - t2

这里对每张图片进行前向推理。📚

第二行代码,使用model对图像进行预测,augment和visualize参数是用于指示是否在预测时使用数据增强和可视化。后面的代码记录了当前时间,并计算从上一个时间点到这个时间点的时间差,然后将这个时间差加到一个名为dt的时间差列表中的第二个元素上。

(3)NMS后处理除去多余的框

 # NMS
        # 执行非极大值抑制,返回值为过滤后的预测框
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
        '''
         pred: 网络的输出结果
         conf_thres: 置信度阈值
         iou_thres: iou阈值
         classes: 是否只保留特定的类别 默认为None
         agnostic_nms: 进行nms是否也去除不同类别之间的框
         max_det: 检测框结果的最大数量 默认1000
        '''
        # 预测+NMS的时间
        dt[2] += time_sync() - t3

这段是YOLO的经典代码:非极大值抑制(NMS),用于筛选预测结果。再次更新计时器,记录NMS所耗费的时间。📚

(4)预测过程

 # Process predictions
        # 把所有的检测框画到原图中
        for i, det in enumerate(pred):  # per image 每次迭代处理一张图片
            '''
            i:每个batch的信息
            det:表示5个检测框的信息
            '''
            seen += 1 #seen是一个计数的功能
            if webcam:  # batch_size >= 1
                # 如果输入源是webcam则batch_size>=1 取出dataset中的一张图片
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f'{i}: ' # s后面拼接一个字符串i
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
            '''
                大部分我们一般都是从LoadImages流读取本都文件中的照片或者视频 所以batch_size=1
                   p: 当前图片/视频的绝对路径 如 F:\yolo_v5\yolov5-U\data\images\bus.jpg
                   s: 输出信息 初始为 ''
                   im0: 原始图片 letterbox + pad 之前的图片
                   frame: 视频流,此次取的是第几张图片
            '''

对筛选后的结果进行for循环遍历,这一段主要是判断是否采用网络摄像头。📚

如果使用的是网络摄像头,则代码会遍历每个图像并复制一份备份到变量im0中,同时将当前图像的路径和计数器记录到变量p和frame中。最后,将当前处理的物体索引和相关信息记录到字符串变量s中。如果没有使用网络摄像头,则会直接使用im0变量中的图像,将图像路径和计数器记录到变量p和frame中。同时,还会检查数据集中是否有"frame"属性,如果有,则将其值记录到变量frame中。

 p = Path(p)  # to Path
            # 图片/视频的保存路径save_path 如 runs\\detect\\exp8\\fire.jpg
            save_path = str(save_dir / p.name)  # im.jpg
            # 设置保存框坐标的txt文件路径,每张图片对应一个框坐标信息
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
            # 设置输出图片信息。图片shape (w, h)
            s += '%gx%g ' % im.shape[2:]  # print string
            # 得到原图的宽和高
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            # 保存截图。如果save_crop的值为true,则将检测到的bounding_box单独保存成一张图片。
            imc = im0.copy() if save_crop else im0  # for save_crop
            # 得到一个绘图的类,类中预先存储了原图、线条宽度、类名
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))

这一部分主要是路径转换,save_crop来选择是否把检测框裁剪下来保存成一张图片。📚

最后创建了一个annotator对象,以便于在图像上绘制检测结果。

# 判断有没有框
            if len(det):
                # Rescale boxes from img_size to im0 size
                # 将预测信息映射到原图
                # 将标注的bounding_box大小调整为和原图一致(因为训练时原图经过了放缩)此时坐标格式为xyxy
                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round() #scale_coords:坐标映射功能
 
                # Print results
                # 打印检测到的类别数量
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

这一部分是判断有没有框,如果有物体,则会执行操作。📚

首先,scale_coords会将检测结果中的物体坐标从缩放的图片大小变回去。然后遍历det的内容,前面说了det就是一张图片的信息,其实det里面包含了每一个物体的信息,将其类别和数量添加到s字符串中。方便后面打印。

(5)打印目标检测结果

 # Write results
                # 保存预测结果:txt/图片画框/crop-image
                for *xyxy, conf, cls in reversed(det):
                    # 将每个图片的预测信息分别存入save_dir/labels下的xxx.txt中 每行: class_id + score + xywh
                    if save_txt:  # Write to file 保存txt文件
                        # 将xyxy(左上角+右下角)格式转为xywh(中心点+宽长)格式,并归一化,转化为列表再保存
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        # line的形式是: ”类别 x y w h“,若save_conf为true,则line的形式是:”类别 x y w h 置信度“
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            # 写入对应的文件夹里,路径默认为“runs\detect\exp*\labels”
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
 
                    # 在原图上画框+将预测到的目标剪切出来保存成图片,保存在save_dir/crops下,在原图像画图或者保存结果
                    if save_img or save_crop or view_img:  # Add bbox to image
                        c = int(cls)  # integer class # 类别标号
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}') # 类别名
                        annotator.box_label(xyxy, label, color=colors(c, True))  #绘制边框
                        # 在原图上画框+将预测到的目标剪切出来保存成图片,保存在save_dir/crops下(单独保存)
                        if save_crop:
                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

如果存在目标检测结果,就会执行下一步操作。

(6)在窗口中实时查看检测结果

 # Print time (inference-only)
            # 打印耗时
            LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
 
            # Stream results
            # 如果设置展示,则show图片 / 视频
            im0 = annotator.result() # im0是绘制好的图片
            # 显示图片
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 暂停 1 millisecond

这段代码是在输出窗口中实时查看检测结果。📚

(7)设置保存结果

# Save results (image with detections)
            # 设置保存图片/视频
            if save_img: # 如果save_img为true,则保存绘制完的图片
                if dataset.mode == 'image': # 如果是图片,则保存
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'  如果是视频或者"流"
                    if vid_path[i] != save_path:  # new video  vid_path[i] != save_path,说明这张图片属于一段新的视频,需要重新创建视频文件
                        vid_path[i] = save_path
                        # 以下的部分是保存视频文件
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)  # 视频帧速率 FPS
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 获取视频帧宽度
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 获取视频帧高度
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer[i].write(im0)

这段代码是设置保存图片或者视频结果。📚

首先“save_img”判断是否是图片,如果是则保存路径和图片;如果是视频或者流,需要重新创建视频文件。

💞​​​​​​​💞​​​​​​​💞3.4.7 打印结果

'''================在终端里打印出运行的结果============================'''
    # Print results
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image 平均每张图片所耗费时间
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' # 标签保存的路径
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)

这段代码用于打印结果,记录了一些总共的耗时,以及信息保存。📚

输出的结果包括每张图片的预处理、推理和NMS时间,以及结果保存的路径。

如果update为True,则将模型更新,以修复SourceChangeWarning。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/905413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【三维重建】Factor Fields: 超越神经场的统一框架

论文:Factor Fields: A Unified Framework for Neural Fields and Beyond 文章:https://arxiv.org/abs/2302.01226 项目:https://apchenstu.github.io/FactorFields/ 文章目录 摘要一、前言二、Factor Fields2.1.字典场(DiF&#…

二级MySQL(三)——数据库的增删改查

创建一个新的数据库: CREATE DATABASE db_school DEFAULT CHARACTER SET GB2312 DEFAULT COLLATE GB2312_chinese_ci; 查找创建的数据库(全部): 其他的是SQL自带的数据库 查阅我们自己创建的字符集以及对应的字符集…

Qt 实现 360 安全卫士

作者: 一去、二三里 个人微信号: iwaleon 微信公众号: 高效程序员 回想起来,这也算是一个有故事的代码。虽然时间比较久远,但还是记忆犹新。 那就简单说说吧,也不枉费当年的一片心血! 说说我的…

java可变字符串

一、常用方法 以StringBuilder为例 1、append(String str) 添加 StringBuilder str new StringBuilder("hello"); System.out.println(str);//在源字符串后添加world StringBuilder add str.append("world"); System.out.println(add);//结果helloworl…

Platypus:Quick,Cheap,and Powerful Refinement of LLMs

Platypus:Quick,Cheap,and Powerful Refinement of LLMs IntroductionMethod2.1 Curating Open- PlatypusRemoving similar&duplicate questionsContamination CheckFine-tuning & mergingResult参考Introduction 现在大模型已经取得很不错的结果,如何把大模型的能…

PL 侧驱动和fpga 重加载的方法

可以解决很多的问题 时钟稳定后加载特定fpga ip (要不内核崩的一塌糊涂)fpga 稳定复位软件决定fpga ip 加载的时序 dluash load /usr/local/scripts/si5512_setup.lua usleep 30 mkdir -p /lib/firmware cp -rf /usr/local/firmare/{*.bit.bin,*.dtbo} …

css 实现svg动态图标效果

效果演示&#xff1a; 实现思路&#xff1a;主要是通过css的stroke相关属性来设置实现的。 html代码: <svgt"1692441666814"class"icon"viewBox"0 0 1024 1024"version"1.1"xmlns"http://www.w3.org/2000/svg"p-id"…

jps(JVM Process Status Tool):虚拟机进程状况工具

jps&#xff08;JVM Process Status Tool&#xff09;&#xff1a;虚拟机进程状况工具 列出正在运行的虚拟机进程&#xff0c;并显示虚拟机执行主类名称&#xff08;Main Class&#xff0c;main()函数所在的类&#xff09;以及这些进程的本地虚拟机唯一ID&#xff08;LVMID&am…

VMware上搭建的虚拟机突然本地无法连接服务器

长时间没有使用VMware 虚拟机了&#xff0c;今天突然登录上去&#xff0c;启动虚拟服务器后发现本地等不了了&#xff0c; 经过排查发现是开启了&#xff1a;VirtualBox Host-Only Network 关闭之后就本机就可以直连服务器了

java能实现热替换而属性不丢失的原因

1.替换的是klass&#xff0c;数据在oop里面 2.这个没想通说明对java面向对象底层实现不了解。

Midjourney API 申请及使用

在人工智能绘图领域&#xff0c;想必大家听说过 Midjourney 的大名吧&#xff01; Midjourney 以其出色的绘图能力在业界独树一帜。无需过多复杂的操作&#xff0c;只要简单输入绘图指令&#xff0c;这个神奇的工具就能在瞬间为我们呈现出对应的图像。无论是任何物体还是任何风…

linux中shell脚本——shell数组、正则表达式及文件三剑客之AWK

目录 一.shell数组 1.1.数组分类 1.2.定义数组方法 二.正则表达式 2.1.元字符 2.2.表示次数 2.3.位置锚定 2.4.分组 2.5.扩展正则表达式 三.文本三剑客之AWK 3.1.awk介绍及使用格式 3.2.处理动作 3.3.awk选项 3.4.awk处理模式 2.5.awk常见的内置变量 2.6.if条…

5.5.webrtc的线程管理

今天呢&#xff0c;我们来介绍一下线程的管理与绑定&#xff0c;首先我们来看一下web rtc中的线程管理类&#xff0c;也就是thread manager。对于这个类来说呢&#xff0c;其实实现非常简单&#xff0c;对吧&#xff1f; 包括了几个重要的成员&#xff0c;第一个成员呢就是ins…

2021年12月 C/C++(三级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题:我家的门牌号 我家住在一条短胡同里,这条胡同的门牌号从1开始顺序编号。 若所有的门牌号之和减去我家门牌号的两倍,恰好等于n,求我家的门牌号及总共有多少家。 数据保证有唯一解。 时间限制:1000 内存限制:65536 输入 一个正整数n。n < 100000。 输出 一行,包含…

DTC 19服务学习2

紧跟上篇 0x04 reportDTCSnapshotRecordByDTCNumber 通过DTC和快照序列来获取DTC快照记录。 适用以下假设&#xff1a; — 服务器支持存储给定 DTC 的两个 DTCSnapshot 记录的能力。 — 此示例假定是上一个示例的延续。 — 假设服务器请求服务器存储的 DTC 编号 123456 的两个…

【学会动态规划】环绕字符串中唯一的子字符串(25)

目录 动态规划怎么学&#xff1f; 1. 题目解析 2. 算法原理 1. 状态表示 2. 状态转移方程 3. 初始化 4. 填表顺序 5. 返回值 3. 代码编写 写在最后&#xff1a; 动态规划怎么学&#xff1f; 学习一个算法没有捷径&#xff0c;更何况是学习动态规划&#xff0c; 跟我…

使用 Amazon Redshift Serverless 和 Toucan 构建数据故事应用程序

这是由 Toucan 的解决方案工程师 Django Bouchez与亚马逊云科技共同撰写的特约文章。 带有控制面板、报告和分析的商业智能&#xff08;BI&#xff0c;Business Intelligence&#xff09;仍是最受欢迎的数据和分析使用场景之一。它为业务分析师和经理提供企业的过去状态和当前状…

尝试自主打造一个有限状态机(一)

前言 我们都知道Unity有自带的有限状态机Animator&#xff0c;它的功能非常强大&#xff0c;为了探索它背后的原理&#xff0c;我开启了这个系列的文章&#xff0c;尝试通过自主打造一个有限状态机来理解Animator的工作原理&#xff0c;同时我会将这个状态机应用于实际&#xf…

unity 之 Input.GetMouseButtonDown 的使用

文章目录 Input.GetMouseButtonDown Input.GetMouseButtonDown 当涉及到处理鼠标输入的时候&#xff0c;Input.GetMouseButtonDown 是一个常用的函数。它可以用来检测鼠标按键是否在特定帧被按下。下面我会详细介绍这个函数&#xff0c;并举两个例子说明如何使用它。 函数签名…

美国陆军希望大数据技术能够帮助保护其云安全

随着陆军采用更大型的云服务&#xff0c;一位高级官员警告说&#xff0c;一些在私营部门有效的快速软件开发技巧和简单解决方案&#xff08;例如开放代码库&#xff09;如果没有额外的安全性&#xff0c;将无法为军队工作。 我们知道现代软件开发确实依赖于第三方库&#xff…