【超详细】基于YOLOv8训练无人机视角Visdrone2019数据集

news2024/9/23 14:03:55

主要内容如下:

1、Visdrone2019数据集介绍
2、下载、制作YOLO格式训练集
3、模型训练及预测
4、Onnxruntime推理

运行环境:Python=3.8(要求>=3.8),torch1.12.0+cu113(要求>=1.8),onnxruntime-gpu==1.12.0
原始数据集百度AI stduio下载链接:https://aistudio.baidu.com/datasetdetail/115729
Visdrone-YOLO格式数据集下载链接:https://aistudio.baidu.com/datasetdetail/295374
训练资源占用:bacth=16,workers=8,yolov8s显存需16G,bacth=8的话8G够用,RTX4080大约1min一个epoch。

往期内容:

【超详细】跑通YOLOv8之深度学习环境配置1-Anaconda安装
【超详细】跑通YOLOv8之深度学习环境配置2-CUDA安装
【超详细】跑通YOLOv8之深度学习环境配置3-YOLOv8安装
【超详细】基于YOLOv8的PCB缺陷检测
【超详细】基于YOLOv8改进1-Drone-YOLO复现

1 数据集介绍

1.1 简介

VisDrone数据集是由天津大学等团队开源的一个大型无人机视角的数据集,官方提供的数据中训练集是6471、验证集是548、测试集1610张。数据集共提供了以下12个类,分别是:‘忽略区域’, ‘pedestrian’, ‘people’, ‘bicycle’, ‘car’, ‘van’,‘truck’, ‘tricycle’, ‘awning-tricycle’, ‘bus’, ‘motor’, ‘others’,其中忽略区域、others是非有效目标区域,本项目中予以忽略;

1.2 示例

在这里插入图片描述

1.3 标签格式

在这里插入图片描述

**标签含义:**
1. 边界框左上角的x坐标
2. 边界框左上角的y坐标
3. 边界框的宽度
4. 边界框的高度
5. GROUNDTRUTH文件中的分数设置为101表示在计算中考虑边界框,而0表示将忽略边界框。
6.  类别:忽略区域(0)、行人(1)、人(2)、自行车(3)、汽车(4)、面包车(5)、卡车(6)、三轮车(7)、雨篷三轮车(8)、公共汽车(9)、摩托车(10),其他(11)。
7. GROUNDTRUTH文件中的得分表示对象部分出现在帧外的程度(即,无截断=0(截断比率0%),部分截断=1(截断比率1%°´50%))。
8. GROUNDTRUTH文件中的分数表示被遮挡的对象的分数(即,无遮挡=0(遮挡比率0%),部分遮挡=1(遮挡比率1%°´50%),重度遮挡=2(遮挡率50%~100%))。

2 下载和制作YOLO格式数据集

2.1 下载原始数据集

百度AI stduio下载链接:https://aistudio.baidu.com/datasetdetail/115729
注意:可直接下载已完成转换的YOLO格式数据进行训练,可跳过该阶段,直接训练!链接为:https://aistudio.baidu.com/datasetdetail/295374
在这里插入图片描述
下载解压
在这里插入图片描述
注意:由于格式不是YOLO直接可以训练的格式,所以需进行转换!!!

2.2 制作YOLO格式数据集

(1)新建visdrone2yolo.py脚本,脚本内容如下:
(2)修改路径参数–dir_path的值,即自己下载路径;
(2)结果会在原始每个文件夹下生成一个label文件夹,即YOLO格式标签;

import os
from pathlib import Path
import argparse
 
def visdrone2yolo(dir):
    from PIL import Image
    from tqdm import tqdm
 
    def convert_box(size, box):
        # Convert VisDrone box to YOLO xywh box
        dw = 1. / size[0]
        dh = 1. / size[1]
        return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh
 
    (dir / 'labels').mkdir(parents=True, exist_ok=True)  # make labels directory
    pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
    for f in pbar:
        img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
        lines = []
        with open(f, 'r') as file:  # read annotation.txt
            for row in [x.split(',') for x in file.read().strip().splitlines()]:
                if row[4] == '0':  # VisDrone 'ignored regions' class 0
                    continue
                cls = int(row[5]) - 1  # 类别号-1
                box = convert_box(img_size, tuple(map(int, row[:4])))
                lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
                with open(str(f).replace(os.sep + 'annotations' + os.sep, os.sep + 'labels' + os.sep), 'w') as fl:
                    fl.writelines(lines)  # write label.txt


if __name__ == '__main__':
    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_path', type=str, default=r'E:\datasets\visdrone2019', help='visdrone数据集路径')
    args = parser.parse_args()

    dir = Path(args.dir_path)
    # Convert
    for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
        visdrone2yolo(dir / d)  # convert VisDrone annotations to YOLO labels
 

3 模型训练及预测

3.1 模型训练

3.1.1 修改数据集配置文件

文件路径:ultralytics-main\ultralytics\cfg\datasets\VisDrone.yaml
在这里插入图片描述

3.1.2 创建模型训练脚本

(1)训练方式1-脚本训练
在ultralytics-main目录新建一个train.py脚本,内容如下:
注意:如爆显存,降低batch大小!!!
【如下配置显存需16G,bacth=8的话8G够用,RTX4080大约1min一个epoch】

from ultralytics import YOLO

if __name__ == '__main__':
    # Load a model
    # model = YOLO("yolov8n.yaml")  # build a new model from scratch
    model = YOLO("yolov8s.pt")  # load a pretrained model (recommended for training)

    # Use the model
    model.train(data="VisDrone.yaml", imgsz=640, batch=16, workers=8, cache=True, epochs=100)  # train the model
    metrics = model.val()  # evaluate model performance on the validation set
    # results = model("ultralytics\\assets\\bus.jpg")  # predict on an image
    path = model.export(format="onnx", opset=13)  # export the model to ONNX format

(2)训练方式2-终端命令行

cd ../ultralytics-main
yolo task=detect mode=train model=yolov8s.pt data=ultralytics/cfg/datasets/VisDrone.yaml batch=16 epochs=100 imgsz=640 workers=8 cache=True device=0
3.1.3 数据分布情况可视化

特点:类别不均衡、小目标较多(640*640输入精度不会太高,可提高输入分辨率,如1280、1536等)。
在这里插入图片描述

3.1.4 训练结果可视化

训练100epoch结果如下,增加epoch还能提升。
略

3.2 模型预测

在ultralytics-main目录新建一个predict.py脚本,内容如下:

from ultralytics import YOLO

if __name__ == '__main__':
    # Load a model
    model = YOLO(r"E:\Code\ultralytics-main\runs\detect\train\weights\best.pt")  # load model
    model.predict(source=r"E:\datasets\visdrone2019\VisDrone2019-DET-test-dev\images\0000006_01111_d_0000003.jpg", save=True, save_conf=True, save_txt=True, name='output')

结果如下:
在这里插入图片描述

4 Onnxruntime推理

在ultralytics-main目录新建一个onnx_infer.py脚本,内容如下:
注意:如导出动态onnx,model.export(format=“onnx”, opset=13, dynamic=True)

import argparse
import time 
import cv2
import numpy as np

import onnxruntime as ort  # 使用onnxruntime推理用上,pip install onnxruntime-gpu==1.12.0 -i  https://pypi.tuna.tsinghua.edu.cn/simple,默认安装CPU
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

class YOLOv8:
    """YOLOv8 object detection model class for handling inference and visualization."""

    def __init__(self, onnx_model, imgsz=(640, 640)):
        """
        Initialization.

        Args:
            onnx_model (str): Path to the ONNX model.
        """
        
        # 构建onnxruntime推理引擎
        self.ort_session = ort.InferenceSession(onnx_model,
                                            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
                                            if ort.get_device() == 'GPU' else ['CPUExecutionProvider'])
        print(ort.get_device())
        # Numpy dtype: support both FP32 and FP16 onnx model
        self.ndtype = np.half if self.ort_session.get_inputs()[0].type == 'tensor(float16)' else np.single
       
        self.model_height, self.model_width = imgsz[0], imgsz[1]  # 图像resize大小
     

    def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45):
        """
        The whole pipeline: pre-process -> inference -> post-process.

        Args:
            im0 (Numpy.ndarray): original input image.
            conf_threshold (float): confidence threshold for filtering predictions.
            iou_threshold (float): iou threshold for NMS.

        Returns:
            boxes (List): list of bounding boxes.
        """
        # 前处理Pre-process
        t1 = time.time()
        im, ratio, (pad_w, pad_h) = self.preprocess(im0)
        pre_time = round(time.time() - t1, 3)
        # print('det预处理时间:{:.3f}s'.format(time.time() - t1))
        
        # 推理 inference
        t2 = time.time()
        preds = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: im})[0]
        # print('det推理时间:{:.2f}s'.format(time.time() - t2))
        det_time = round(time.time() - t2, 3)
        
        # 后处理Post-process
        t3 = time.time()
        boxes = self.postprocess(preds,
                                im0=im0,
                                ratio=ratio,
                                pad_w=pad_w,
                                pad_h=pad_h,
                                conf_threshold=conf_threshold,
                                iou_threshold=iou_threshold,
                                )
        # print('det后处理时间:{:.3f}s'.format(time.time() - t3))
        post_time = round(time.time() - t3, 3)

        return boxes, (pre_time, det_time, post_time)
        
    # 前处理,包括:resize, pad, HWC to CHW,BGR to RGB,归一化,增加维度CHW -> BCHW
    def preprocess(self, img):
        """
        Pre-processes the input image.

        Args:
            img (Numpy.ndarray): image about to be processed.

        Returns:
            img_process (Numpy.ndarray): image preprocessed for inference.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
        """
        # Resize and pad input image using letterbox() (Borrowed from Ultralytics)
        shape = img.shape[:2]  # original image shape
        new_shape = (self.model_height, self.model_width)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        ratio = r, r
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding
        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
            
        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))  # 填充

        # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
        img = np.ascontiguousarray(np.einsum('HWC->CHW', img)[::-1], dtype=self.ndtype) / 255.0
        img_process = img[None] if len(img.shape) == 3 else img
        return img_process, ratio, (pad_w, pad_h)
    
    # 后处理,包括:阈值过滤与NMS
    def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold):
        """
        Post-process the prediction.

        Args:
            preds (Numpy.ndarray): predictions come from ort.session.run().
            im0 (Numpy.ndarray): [h, w, c] original input image.
            ratio (tuple): width, height ratios in letterbox.
            pad_w (float): width padding in letterbox.
            pad_h (float): height padding in letterbox.
            conf_threshold (float): conf threshold.
            iou_threshold (float): iou threshold.

        Returns:
            boxes (List): list of bounding boxes.
        """
        x = preds  # outputs: predictions (1, 84, 8400)
        # Transpose the first output: (Batch_size, xywh_conf_cls, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls)
        x = np.einsum('bcn->bnc', x)  # (1, 8400, 84)
   
        # Predictions filtering by conf-threshold
        x = x[np.amax(x[..., 4:], axis=-1) > conf_threshold]

        # Create a new matrix which merge these(box, score, cls) into one
        # For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html
        x = np.c_[x[..., :4], np.amax(x[..., 4:], axis=-1), np.argmax(x[..., 4:], axis=-1)]

        # NMS filtering
        # 经过NMS后的值, np.array([[x, y, w, h, conf, cls], ...]), shape=(-1, 4 + 1 + 1)
        x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)]
       
        # 重新缩放边界框,为画图做准备
        if len(x) > 0:
            # Bounding boxes format change: cxcywh -> xyxy
            x[..., [0, 1]] -= x[..., [2, 3]] / 2
            x[..., [2, 3]] += x[..., [0, 1]]

            # Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image
            x[..., :4] -= [pad_w, pad_h, pad_w, pad_h]
            x[..., :4] /= min(ratio)

            # Bounding boxes boundary clamp
            x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1])
            x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0])

            return x[..., :6]  # boxes
        else:
            return []


if __name__ == '__main__':
    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--det_model', type=str, default=r"E:\Code\ultralytics-main\runs\detect\train\weights\best.onnx", help='Path to ONNX model')
    parser.add_argument('--source', type=str, default=str(r'E:\datasets\visdrone2019\VisDrone2019-DET-test-dev\images'), help='Path to input image')
    parser.add_argument('--out_path', type=str, default=str(r'E:\Code\ultralytics-main\runs/res'), help='结果保存文件夹')
    parser.add_argument('--imgsz_det', type=tuple, default=(640, 640), help='Image input size')
    parser.add_argument('--classes', type=list, default=['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 
                                                         'awning-tricycle', 'bus', 'motor'], help='类别')

    parser.add_argument('--conf', type=float, default=0.25, help='Confidence threshold')
    parser.add_argument('--iou', type=float, default=0.6, help='NMS IoU threshold')
    args = parser.parse_args()

    if not os.path.exists(args.out_path):
        os.mkdir(args.out_path)
    print('开始运行:')
    # Build model
    det_model = YOLOv8(args.det_model, args.imgsz_det)
    color_palette = np.random.uniform(0, 255, size=(len(args.classes), 3))  # 为每个类别生成调色板
    
    for i, img_name in enumerate(os.listdir(args.source)):
        try:
            t1 = time.time()
            # Read image by OpenCV
            img = cv2.imread(os.path.join(args.source, img_name))

            # 检测Inference
            boxes, (pre_time, det_time, post_time) = det_model(img, conf_threshold=args.conf, iou_threshold=args.iou)
            print('{}/{} ==>总耗时间: {:.3f}s, 其中, 预处理: {:.3f}s, 推理: {:.3f}s, 后处理: {:.3f}s, 识别{}个目标'.format(i+1, len(os.listdir(args.source)), time.time() - t1, pre_time, det_time, post_time, len(boxes)))
            
            for (*box, conf, cls_) in boxes:
                cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
                                color_palette[int(cls_)], 2, cv2.LINE_AA)
                cv2.putText(img, f'{args.classes[int(cls_)]}: {conf:.3f}', (int(box[0]), int(box[1] - 9)),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
            cv2.imwrite(os.path.join(args.out_path, img_name), img)
        
        except Exception as e:
            print(e)      

资源占用:显存不到2G,RTX4080推理耗时20几毫秒。
在这里插入图片描述
结果可视化如下
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2157710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网站建设中,sitemap是什么,有什么作用

在网站建设中,Sitemap(站点地图)是一种文件,通常采用txt或XML格式,它列出了网站中的网页、视频或其他文件的相关信息。Sitemap的主要作用是帮助搜索引擎更高效地抓取和索引网站内容。 以下是Sitemap的具体作用&#x…

ABAP 学习t-code DWDM

ABAP 学习t-code DWDM ,里面有很多例子展示,且能看到源代码

【第十四章:Sentosa_DSML社区版-机器学习之时间序列】

目录 【第十四章:Sentosa_DSML社区版-机器学习时间序列】 14.1 ARIMAX 14.2 ARIMA 14.3 HoltWinters 14.4 一次指数平滑预测 14.5 二次指数平滑预测 【第十四章:Sentosa_DSML社区版-机器学习时间序列】 14.1 ARIMAX 1.算子介绍 考虑其他序列对一…

云计算第四阶段---CLOUD Day7---Day8

CLOUD 07 一、Dockerfile详细解析 指令说明FROM指定基础镜像(唯一)RUN在容器内执行命令,可以写多条ADD把文件拷贝到容器内,如果文件是 tar.xx 格式,会自动解压COPY把文件拷贝到容器内,不会自动解压ENV设置…

双十一快来了!什么值得买?分享五款高品质好物~

双十一大促再次拉开帷幕,面对众多优惠是否感到选择困难?为此,我们精心筛选了一系列适合数字生活的好物,旨在帮助每一位朋友都能轻松找到心仪之选。这份推荐清单,不仅实用而且性价比高,是您双十一购物的不二…

C++入门基础知识82(实例)——实例7【 判断一个数是奇数还是偶数】

成长路上不孤单😊😊😊😊😊😊 【14后😊///C爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于C 实例 【判断一个数是奇数还是偶数】相…

【JavaEE初阶】文件IO(上)

欢迎关注个人主页:逸狼 创造不易,可以点点赞吗~ 如有错误,欢迎指出~ 目录 路径 绝对路径 相对路径 文件类型 文件的操作 File类 文件系统操作 创建文件,获取路径 删除文件 列出所有路径 路径修改 创建目录 mkdir和mkdirs 服务器领域,机械…

win系统接入google_auth实现动态密码,加强保护

开源代码地址:windows动态密码: 针对win服务器进行的动态密码管控,需要配合谷歌的身份认证APP使用 (gitee.com) 为什么要搞个动态密码呢? 首先云服务器启用了远程访问,虽然更换了端口以及初始用户名,不过还是是不是被…

go的结构体、方法、接口

结构体: 结构体:不同类型数据集合 结构体成员是由一系列的成员变量构成,这些成员变量也被称为“字段” 先声明一下我们的结构体: type Person struct {name stringage intsex string } 定义结构体法1: var p1 P…

老程序员的数字游戏开发笔记(三) —— Godot出你的第一个2D游戏(一篇文章完整演绎Godot制作2D游戏的全部细节)

忽略代码,忽略素材,忽略逻辑! 游戏的精髓是人性与思想,我一篇一篇地制作,不想动手的小伙伴看一看就可以,感受一下也不错,我们是有目的性的,这一切都是为今后的AI融合打基础&#xf…

详解CORDIC算法以及Verilog实现并且调用Xilinx CORDIC IP核进行验证

系列文章目录 文章目录 系列文章目录一、什么是CORDIC算法?二、CORDIC算法原理推导三、CORDIC模式3.1 旋转模式3.2 向量模式 四、Verilog实现CORDIC4.1 判断象限4.2 定义角度表4.3 迭代公式 五、仿真验证5.1 matlab打印各角度的正余弦值5.2 Verilog仿真结果观察 六、…

大模型学习方向不知道的,看完这篇学习思路好清晰!!

入门大模型并没有想象中复杂,尤其对于普通程序员,建议采用从外到内的学习路径。下面我们通过几个步骤来探索如何系统学习大模型: 1️⃣初步理解应用场景与人才需求 大模型的核心应用涵盖了智能体(AI Agent)、微调&…

NodeFormer:一种用于节点分类的可扩展图结构学习 Transformer

人工智能咨询培训老师叶梓 转载标明出处 现有的神经网络(GNNs)在处理大规模图数据时面临着一些挑战,如过度平滑、异质性、长距离依赖处理、边缘不完整性等问题,尤其是当输入图完全缺失时。为了解决这些问题,上海交通大…

RK3588NPU驱动版本升级至0.9.6教程

RK3588NPU驱动版本升级至0.9.6教程 1、下载RK3588NPU驱动2、修改NPU驱动源码2.0 修改MONITOR_TPYE_DEV写错问题2.1 解决缺少函数rockchip_uninit_opp_table问题2.2 解决缺少函数vm_flags_set、vm_flag_clear的问题2.3 内核编译成功2.4 重新构建系统 3、注意事项4、其他问题处理…

故障诊断 | 基于双路神经网络的滚动轴承故障诊断

故障诊断 | 基于双路神经网络的滚动轴承故障诊断 目录 故障诊断 | 基于双路神经网络的滚动轴承故障诊断效果一览基本介绍程序设计参考资料效果一览 基本介绍 基于双路神经网络的滚动轴承故障诊断 融合了原始振动信号 和 二维信号时频图像的多输入(多通道)故障诊断方法 单路和双…

【原创】java+springboot+mysql党员教育网系统设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

【Linux】常用指令【更详细,带实操】

Linux全套讲解系列,参考视频-B站韩顺平,本文的讲解更为详细 目录 一、文件目录指令 1、cd【change directory】指令 ​ 2、mkdir【make dir..】指令​ 3、cp【copy】指令 ​ 4、rm【remove】指令 5、mv【move】指令 6、cat指令和more指令 7、less和…

【爬虫工具】小红书评论高级采集软件

用python开发的爬虫采集工具【爬小红书搜索评论软件】,支持根据关键词采集评论。 思路:笔记关键词->笔记链接->评论 软件界面: 完整文章、详细了解: https://mp.weixin.qq.com/s/C_TuChFwh8Vw76hTGX679Q 好用的软件一起分…

去除字符串或字符串数组中字符串左侧的空格或指定字符numpy.char.lstrip()

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 去除字符串或字符串数组中 字符串左侧的空格或指定字符 numpy.char.lstrip() [太阳]选择题 请问关于以下代码表述错误的选项是? import numpy as np print("【执行】np.cha…

2024/9/22

系列文章目录 文章目录 系列文章目录前言一、两条腿走路二、编程语言能力提升1.廖雪峰的python课2.Leetcode(数据结构题) 三、机器学习能力提升1.统计学习方法 李航2.kaggle竞赛 四、神经网络能力提升1.神经网络与深度学习 邱锡鹏2.一套自己的万金油模板…