【进阶篇】YOLOv8实现K折交叉验证——解决数据集样本稀少和类别不平衡的难题,让你的模型评估更加稳健

news2024/12/24 9:03:29

在这里插入图片描述


在这里插入图片描述
YOLOv8专栏导航:点击此处跳转


K折交叉验证

K折交叉验证(K-Fold Cross-Validation)是一种常用的机器学习模型评估方法,可以帮助我们评估模型的性能,特别适用于数据集相对较小的情况。

在K折交叉验证中,将原始数据集分成K个子集,然后依次将其中一个子集作为验证集,其余K-1个子集作为训练集进行模型训练和评估。这样可以得到K个模型,每个模型都在不同的验证集上进行评估,最后将K个模型的评估结果,求平均或取最优结果作为最终评估。

优点

  • 充分利用数据集: 在K折交叉验证中,整个数据集被划分为K个互斥的折叠(Folds)。每次训练模型时,都有K-1个折叠用于训练,而剩下的一个用于验证。这样,每个样本都有机会作为验证集的一部分,从而充分利用了数据集中的所有样本。
  • 减轻过拟合风险: 由于每个样本都会在训练集和验证集中出现,模型在验证集上的性能评估更具有代表性。这有助于减轻由于数据稀疏或类别不平衡导致的过拟合问题。
  • 更稳健的模型评估: K折交叉验证计算多次模型性能的平均值,提供了更稳健的性能评估。这对于对抗数据稀疏和类别不平衡等挑战的模型评估尤为重要,因为它减少了单次评估可能引入的随机性。
  • 处理类别不平衡: 如果数据集中某些类别的样本数量较少,K折交叉验证可以确保每个折叠中都包含这些少见类别的样本。这有助于确保模型在少见类别上的性能得到充分评估,并提高对类别不平衡的鲁棒性。

YOLOv8

YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。

YOLOv8是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。

YOLOv8的K折交叉验证实现步骤

  1. 首先,将数据集划分成K个子集,可以使用现有的数据集划分函数或手动划分。
  2. 然后,使用一个循环迭代K次,每次将其中一个子集作为验证集,其余K-1个子集作为训练集。
  3. 在每次迭代中,使用训练集进行模型训练,并使用验证集进行模型评估。
  4. 可以根据需要调整模型的超参数或其他设置。例如,可以尝试不同的学习率、迭代次数等。
  5. 最后,将K次迭代中的评估结果进行平均或取最优结果作为最终的模型评估结果。

官方项目下载

git clone https://github.com/ultralytics/ultralytics.git

🚀 YOLOv8-K折交叉验证代码实现

在上述下载好的项目文件夹下新建k-fold-train.py文件,并添加下述代码:

import argparse
import datetime
from itertools import chain
import os
from pathlib import Path
import shutil
import yaml
import pandas as pd
from collections import Counter
from sklearn.model_selection import KFold
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from ultralytics import YOLO

NUM_THREADS = min(8, max(1, os.cpu_count() - 1))

def parse_opt():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default=r'./data')  # 数据集路径
    parser.add_argument('--ksplit', default=5, type=int)  # K-Fold交叉验证拆分数据集
    parser.add_argument('--im_suffixes', default=['jpg', 'png', 'jpeg'], help='images suffix')  # 图片后缀名

    return parser.parse_args()

def run(func, this_iter, desc="Processing"):
    with ThreadPoolExecutor(max_workers=NUM_THREADS, thread_name_prefix='MyThread') as executor:
        results = list(
            tqdm(executor.map(func, this_iter), total=len(this_iter), desc=desc)
        )
    return results

def main(opt):
    dataset_path, ksplit, im_suffixes = Path(opt.data), opt.ksplit, opt.im_suffixes

    save_path = Path(dataset_path / f'{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-Valid')
    save_path.mkdir(parents=True, exist_ok=True)

    # 获取所有图像和标签文件的列表
    images = sorted(list(chain(*[(dataset_path / "images").rglob(f'*.{ext}') for ext in im_suffixes])))
    # images = sorted(image_files)
    labels = sorted((dataset_path / "labels").rglob("*.txt"))

    root_directory = Path.cwd()
    print("当前文件运行根目录:", root_directory)
    if len(images) != len(labels):
        print('*' * 20)
        print('当前数据集和标签数量不一致!!!')
        print('*' * 20)

    # 从YAML文件加载类名
    classes_file = sorted(dataset_path.rglob('classes.yaml'))[0]
    assert classes_file.exists(), "请创建classes.yaml类别文件"
    if classes_file.suffix == ".txt":
        pass
    elif classes_file.suffix == ".yaml":
        with open(classes_file, 'r', encoding="utf8") as f:
            classes = yaml.safe_load(f)['names']
    cls_idx = sorted(classes.keys())

    # 创建DataFrame来存储每张图像的标签计数
    indx = [l.stem for l in labels]  # 使用基本文件名作为ID(无扩展名)
    labels_df = pd.DataFrame([], columns=cls_idx, index=indx)

    # 计算每张图像的标签计数
    for label in labels:
        lbl_counter = Counter()
        with open(label, 'r') as lf:
            lines = lf.readlines()
        for l in lines:
            # YOLO标签使用每行的第一个位置的整数作为类别
            lbl_counter[int(l.split(' ')[0])] += 1
        labels_df.loc[label.stem] = lbl_counter

    # 用0.0替换NaN值
    labels_df = labels_df.fillna(0.0)

    kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # 设置random_state以获得可重复的结果
    kfolds = list(kf.split(labels_df))
    folds = [f'split_{n}' for n in range(1, ksplit + 1)]
    folds_df = pd.DataFrame(index=indx, columns=folds)

    # 为每个折叠分配图像到训练集或验证集
    for idx, (train, val) in enumerate(kfolds, start=1):
        folds_df[f'split_{idx}'].loc[labels_df.iloc[train].index] = 'train'
        folds_df[f'split_{idx}'].loc[labels_df.iloc[val].index] = 'val'

    # 计算每个折叠的标签分布比例
    fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)
    for n, (train_indices, val_indices) in enumerate(kfolds, start=1):
        train_totals = labels_df.iloc[train_indices].sum()
        val_totals = labels_df.iloc[val_indices].sum()

        # 为避免分母为零,向分母添加一个小值(1E-7)
        ratio = val_totals / (train_totals + 1E-7)
        fold_lbl_distrb.loc[f'split_{n}'] = ratio

    ds_yamls = []

    for split in folds_df.columns:
        split_dir = save_path / split
        split_dir.mkdir(parents=True, exist_ok=True)
        (split_dir / 'train' / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'train' / 'labels').mkdir(parents=True, exist_ok=True)
        (split_dir / 'val' / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'val' / 'labels').mkdir(parents=True, exist_ok=True)

        dataset_yaml = split_dir / f'{split}_dataset.yaml'
        ds_yamls.append(dataset_yaml.as_posix())
        split_dir = (root_directory / split_dir).as_posix()

        with open(dataset_yaml, 'w') as ds_y:
            yaml.safe_dump({
                'train': split_dir + '/train/images',
                'val': split_dir + '/val/images',
                'names': classes
            }, ds_y)
    # print(ds_yamls)
    with open(dataset_path / 'yaml_paths.txt', 'w') as f:
        for path in ds_yamls:
            f.write(path + '\n')

    args_list = [(image, save_path, folds_df) for image in images]

    run(split_images_labels, args_list, desc=f"Creating dataset")

def split_images_labels(args):
    image, save_path, folds_df = args
    label = image.parents[1] / 'labels' / f'{image.stem}.txt'
    if label.exists():
        for split, k_split in folds_df.loc[image.stem].items():
            # 目标目录
            img_to_path = save_path / split / k_split / 'images'
            lbl_to_path = save_path / split / k_split / 'labels'
            shutil.copy(image, img_to_path / image.name)
            shutil.copy(label, lbl_to_path / label.name)


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)

    model = YOLO('yolov8n.pt', task='train')
    # 从文本文件中加载内容并存储到一个列表中
    ds_yamls = []
    with open(Path(opt.data) / 'yaml_paths.txt', 'r') as f:
        for line in f:
            # 去除每行末尾的换行符
            line = line.strip()
            ds_yamls.append(line)

    # 打印加载的文件路径列表
    print(ds_yamls)

    for k in range(opt.ksplit):
        dataset_yaml = ds_yamls[k]
        name = Path(dataset_yaml).stem
        model.train(
            data=dataset_yaml,
            batch=16,
            epochs=100,
            imgsz=640,
            device=0,
            workers=8,
            project="runs/train",
            name=name,)

    print("*"*40)
    print("K-Fold Cross Validation Completed.")
    print("*"*40)

命令行运行

python k-fold-train.py --data ./data --ksplit 5
  • data:数据集所在目录,数据集的分布如下:
    - data
    	- images
    	- labels
    
    images存放图片文件,labels存放txt标注文件。
  • ksplit:数据集拆分为ksplit个子集,该值通常取5或者10,某些情况下可取其他数值。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1326139.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JMeter接口测试高阶——精通JMeter接口测试之BeanShell及调用java和python脚本

文章目录 一、BeanShell组件二、BeanShell自带的语法(BeanShell常用变量和语法)1.log打印2.vars用来操作JMeter的局部变量(只能在一个线程组里面使用的变量)3.props用来操作JMeter的全局变量(能够跨线程组取值的变量&a…

2023 英特尔On技术创新大会直播 | 边云协同加速 AI 解决方案商业化落地

目录 前言边云协同时代背景边缘人工智能边缘挑战英特尔边云协同的创新成果最后 前言 最近观看了英特尔On技术创新大会直播,学到了挺多知识,其中对英特尔高级首席 AI 工程张宇博士讲解的边云协同加速 AI 解决方案商业化落地特别感兴趣。张宇博士讲解了英…

python实现一个图片查看器——可拖动、缩放和颜色画笔

目录 0 前言1 准备工作2 窗口布局3 图片显示功能3 图片拖拽功能4 图片缩放功能(难度大)5 画笔功能6 颜色选择功能后记源码 0 前言 在现如今的数字时代,我们对于图片的需求越来越大。无论是在工作中,还是在日常生活中,…

SLAM算法与工程实践——SLAM基本库的安装与使用(6):g2o优化库(1)g2o库的安装

SLAM算法与工程实践系列文章 下面是SLAM算法与工程实践系列文章的总链接,本人发表这个系列的文章链接均收录于此 SLAM算法与工程实践系列文章链接 下面是专栏地址: SLAM算法与工程实践系列专栏 文章目录 SLAM算法与工程实践系列文章SLAM算法与工程实践…

PostGIS轨迹分析——横跨180°经线

问题描述 在处理AIS数据中,经常会遇到轨迹线横穿180经线的情况,这种数据绘制到地图上显示的非常乱,如下图所示: 数据模拟 在geojson.io上模拟一条轨迹线,可以看到轨迹显示的非常好,红框里面的经纬度超过…

使用Alpha Vantage API和Python进行金融数据分析

Alpha Vantage通过一套强大且开发者友好的数据API和电子表格,提供实时和历史的金融市场数据。从传统资产类别(例如股票、ETF、共同基金)到经济指标,从外汇汇率到大宗商品,从基本数据到技术指标,Alpha Vanta…

HashSet使用-力扣349做题总结

349. 两个数组的交集 分析代码HashSet出错的知识点1、HashSet新建2、HashSet添加add3、是否包含某元素4、集合->数组5、增强for循环 分析 没做出来的原因代码随想录的视频文字学习 为什么没做出来,因为没有理解好题意。根据示例1可知是去重的。且题目明确说“不考…

机器学习算法(12) — 集成技术(Boosting — Xgboost 分类)

一、说明 时间这是集成技术下的第 4 篇文章,如果您想了解有关集成技术的更多信息,您可以参考我的第 1 篇集成技术文章。 机器学习算法(9) - 集成技术(装袋 - 随机森林分类器和...... 在这篇文章中,我将解释…

先进制造身份治理现状洞察:从手动运维迈向自动化身份治理时代

在新一轮科技革命和产业变革的推动下,制造业正面临绿色化、智能化、服务化和定制化发展趋势。为顺应新技术革命及工业发展模式变化趋势,传统工业化理论需要进行修正和创新。其中,对工业化水平的判断标准从以三次产业比重标准为主回归到工业技…

服务器数据恢复-昆腾存储StorNext文件系统下raid5数据恢复案例

服务器数据恢复环境: 昆腾某型号存储,StorNext文件存储系统。 共有9个分别配置了24块磁盘的磁盘柜,其中8个磁盘柜存放普通数据,1个磁盘柜存放元数据。 存放元数据的磁盘柜中的24块磁盘组建了8组RAID1阵列和1组4盘RAID10阵列&#…

Ubuntu 常用命令之 history 命令用法介绍

📑Linux/Ubuntu 常用命令归类整理 history命令在Ubuntu系统中用于显示用户执行过的命令列表。这个命令在bash shell中非常有用,特别是当你需要记住你之前执行过的命令时。 history命令的参数如下 -c:清除历史记录。-d offset:删…

全功能知识付费小程序系统源码是什么?有什么好处?

全功能知识付费小程序系统源码,是一个集课程管理、用户管理、支付管理、数据分析等于一体的综合性解决方案。它支持多种形式的课程内容,如视频、音频、图文等,满足不同用户的学习需求。同时,系统具备完善的支付功能,保…

怎么开通独立站支付?独立站客户退款谁支付运费?——站斧浏览器

怎么开通独立站支付? 选择支付服务提供商:开通独立站支付首先需要选择一个可靠的支付服务提供商。目前市场上有许多知名的支付服务提供商,如支付宝、微信支付、PayPal等。根据自己的业务需求和目标市场选择合适的支付服务提供商。 注册账号…

目前电视盒子哪个最好?工程师盘点超值电视盒子推荐

因工作原因每天都会跟各种各样类型的电视盒子打交道,拆机、维修,身边朋友在挑选电视盒子的时候会问我目前电视盒子哪个最好,哪些电视盒子最值得入手,我整理了五款超值电视盒子推荐给大家,在挑选电视盒子时可以把这几款…

两套高质量可视化模板套件,需要进!

小编整理了两套高质量可视化模板套件,均来自于山海鲸可视化,需要源文件可私。 一、「星曜蓝」主题可视化模板 可以自由调用模板库中的所有内容,轻松搭建风格统一的地图、工厂、城市多种数字孪生项目。真免费、0代码数字孪生设计搭建&#xf…

可狱可囚的爬虫系列课程 07:BeautifulSoup4(bs4)库的使用

前面一直在讲 Requests 模块如何使用,那都是在请求阶段要做的事情,相信很多网友都在等一个能够开始爬网站信息的教程,今天它来了,今天我要给大家讲一个很简单易懂的库:BeautifulSoup4。 一、概述&安装 Beautiful…

BWS2000倾角传感器c++测试代码【2】

问题一:串口频率的初始化 由于本次项目之中使用的线长为40米的倾角传感器,需要对于其频率输出存在要求,如何测试其频率如下所示: 如上所示相应的软件,软件中存在一句如果设置后不保存,则存在传感器断电后设…

众和策略:大盘涨手中的股票却大跌,到底怎么回事?

大盘涨手中的股票却大跌,究竟怎么回事: 1、大盘上涨是权重股所造成的 大盘上涨可能是受一些权重比较大的工作所影响,比如证券工作、钢铁工作、银行工作等等,这些工作的大涨,可以拉升大盘的上涨,可是其它工…

C++20形式的utf-8字符串转宽字符串,不依赖编译器编码形式

默认的char[]编码都是要看编译器编译选项的,你选了ANSI那它就是ANSI,你选了UTF8那它就是UTF8. 【注意:经典DevC只支持ANSI编码(痛苦);上图是小熊猫DevC,则有这个选项】 这一点对我的代码造成了…

20231220将NanoPC-T4(RK3399)开发板的Android10的SDK按照Rockchip官方挖掘机开发板编译打包刷机之后启动跑飞

20231220将NanoPC-T4(RK3399)开发板的Android10的SDK按照Rockchip官方挖掘机开发板编译打包刷机之后启动跑飞 2023/12/20 17:19 简略步骤:rootrootrootroot-X99-Turbo:~/3TB$ tar --use-compress-programpigz -xvpf rk3399-android-10.git-20210201.tgz rootrootro…