YOLO自定义数据集实现K折交叉验证——K-Fold Cross Validation

news2025/2/13 8:08:55

K-Fold

实现K折交叉验证(K-Fold Cross Validation)对于YOLO(You Only Look Once)自定义数据集的目标检测任务可以显著提升模型的可靠性和泛化能力。

1. 数据集准备

首先,你需要确保你的数据集符合YOLO的格式,具体来说,每个图像都有相应的标注文件,格式如下:

  • 每行包含:class_id center_x center_y width height
  • class_id 是类别的编号,center_xcenter_y 是物体中心的归一化坐标,widthheight 是物体框的归一化宽度和高度。

假设你已经准备好了数据集(例如自定义的水果检测数据集),其中图像和标注文件分别存储在 imageslabels 目录下。

2. 必要的Python包

你需要安装一些必要的Python库:

pip install -U ultralytics scikit-learn pandas pyyaml

3. 数据集标注和类定义

假设你有一个 data.yaml 文件,它定义了数据集的路径和类别。一个示例 data.yaml 文件可能如下:

train: ./Fruit-detection/images/train
val: ./Fruit-detection/images/val
names:
  0: Apple
  1: Grapes
  2: Pineapple
  3: Orange
  4: Banana
  5: Watermelon

确保你的数据集标注文件(例如 trainval 目录中的标注文件)符合此结构。

4. 数据准备和生成特征向量

你需要先生成一个表示数据集的特征向量(每个图像包含每个类的数量)。以下是生成特征向量的代码:

import pandas as pd
from pathlib import Path
from collections import Counter
import yaml

# 设置数据集路径
dataset_path = Path("./Fruit-detection")
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # 读取所有标注文件

# 读取data.yaml文件,提取类标签
yaml_file = "path/to/data.yaml"
with open(yaml_file, "r", encoding="utf8") as y:
    classes = yaml.safe_load(y)["names"]

# 初始化一个空的DataFrame
cls_idx = sorted(classes.keys())
index = [label.stem for label in labels]  # 使用文件名作为索引
labels_df = pd.DataFrame([], columns=cls_idx, index=index)

# 统计每个类的实例数量
for label in labels:
    lbl_counter = Counter()
    with open(label, "r") as lf:
        lines = lf.readlines()
    for line in lines:
        lbl_counter[int(line.split(" ")[0])] += 1
    labels_df.loc[label.stem] = lbl_counter

labels_df = labels_df.fillna(0.0)  # 填充缺失值为0

5. 使用K折交叉验证进行数据拆分

使用 sklearn.model_selection.KFold 来拆分数据集。这里我们使用5折交叉验证(k=5),你可以根据需要调整 k 的值。

from sklearn.model_selection import KFold

ksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # 设置随机种子以便结果可复现

# 获取数据集的索引拆分
kfolds = list(kf.split(labels_df))

# 显示每个fold的训练和验证集
folds_df = pd.DataFrame(index=index, columns=[f"split_{n}" for n in range(1, ksplit + 1)])
for i, (train, val) in enumerate(kfolds, start=1):
    folds_df[f"split_{i}"].loc[labels_df.iloc[train].index] = "train"
    folds_df[f"split_{i}"].loc[labels_df.iloc[val].index] = "val"

6. 计算每个fold的标签分布

为了确保每个fold的类别分布平衡,可以计算每个fold中每个类的数量比例。

fold_lbl_distrb = pd.DataFrame(index=[f"split_{n}" for n in range(1, ksplit + 1)], columns=cls_idx)
for n, (train_indices, val_indices) in enumerate(kfolds, start=1):
    train_totals = labels_df.iloc[train_indices].sum()
    val_totals = labels_df.iloc[val_indices].sum()

    # 计算验证集与训练集的标签比例
    ratio = val_totals / (train_totals + 1e-7)  # 避免除0错误
    fold_lbl_distrb.loc[f"split_{n}"] = ratio

7. 创建K折数据集文件夹和YAML文件

为每个fold创建训练和验证数据集的文件夹,并生成相应的 dataset.yaml 配置文件。

import shutil
import datetime

save_path = Path(dataset_path / f"{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val")
save_path.mkdir(parents=True, exist_ok=True)

# 创建目录和YAML文件
ds_yamls = []
for split in folds_df.columns:
    split_dir = save_path / split
    split_dir.mkdir(parents=True, exist_ok=True)
    (split_dir / "train" / "images").mkdir(parents=True, exist_ok=True)
    (split_dir / "train" / "labels").mkdir(parents=True, exist_ok=True)
    (split_dir / "val" / "images").mkdir(parents=True, exist_ok=True)
    (split_dir / "val" / "labels").mkdir(parents=True, exist_ok=True)

    dataset_yaml = split_dir / f"{split}_dataset.yaml"
    ds_yamls.append(dataset_yaml)

    with open(dataset_yaml, "w") as ds_y:
        yaml.safe_dump({
            "path": split_dir.as_posix(),
            "train": "train",
            "val": "val",
            "names": classes,
        }, ds_y)

# 复制图像和标签文件到对应的目录
images = sorted((dataset_path / "images").rglob("*"))
for image, label in zip(images, labels):
    for split, k_split in folds_df.loc[image.stem].items():
        img_to_path = save_path / split / k_split / "images"
        lbl_to_path = save_path / split / k_split / "labels"
        shutil.copy(image, img_to_path / image.name)
        shutil.copy(label, lbl_to_path / label.name)

8. 训练YOLO模型

创建一个YOLO模型并使用每个fold的数据进行训练。训练完成后,你可以保存模型并记录性能指标。

from ultralytics import YOLO

weights_path = "path/to/weights.pt"  # YOLO预训练权重文件路径
model = YOLO(weights_path, task="detect")

# 训练每个fold的数据
results = {}
batch = 16
epochs = 100
project = "kfold_demo"

for k in range(ksplit):
    dataset_yaml = ds_yamls[k]
    model.train(data=dataset_yaml, epochs=epochs, batch=batch, project=project)
    results[k] = model.metrics  # 保存训练结果

9. 结果分析

你可以从 results 中提取每个fold的训练指标进行进一步分析。例如,可以计算每个fold的mAP(mean Average Precision)并进行比较,确保模型的稳定性和泛化能力。


结论

通过上述步骤,你可以在YOLO自定义数据集上实现K折交叉验证。K折交叉验证的优点是能够减少模型过拟合的风险,确保模型在不同数据划分上的泛化能力,提升其性能可靠性。

这些步骤是通用的,可以根据自己的数据集进行修改和优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2297249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

go语言简单快速的按顺序遍历kv结构(map)

文章目录 需求描述用map实现按照map的key排序用二维切片实现用结构体实现 需求描述 在go语言中,如果需要对map遍历,每次输出的顺序是不固定的,可以考虑存储为二维切片或结构体。 假设现在需要在页面的下拉菜单中展示一些基础的选项&#xff…

【竞技宝】LOL-LPL:EDG3-0零封LNG

北京时间2月12日,英雄联盟LPL2025正在如火如荼的进行之中,昨日迎来LNG对阵EDG,以下是本场比赛的详细战报。 第一局: EDG:杰斯、赵信、维克托、女枪、芮尔 LNG:猴子、猪妹、飞机、韦鲁斯、布隆 首局比赛,EDG在蓝色方,LNG在红色方。阵容方面,EDG点出了杰斯、赵信、维克托、女枪…

在fedora41中安装钉钉dingtalk_7.6.25.4122001_amd64

在Fedora-Workstation-Live-x86_64-41-1.4中安装钉钉dingtalk_7.6.25.4122001_amd64.deb 到官网下载钉钉Linux客户端com.alibabainc.dingtalk_7.6.25.4122001_amd64.deb https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload#/ 一、直接使用dpkg命令安装deb包报错…

看期货用的指标,可以提示买卖点和K线转折变颜色的主图指标源码下载

A:MA(CLOSE,17)ABS(MA(CLOSE,17)-REF(MA(CLOSE,17),1)); B:MA(CLOSE,17)MA(CLOSE,17)-REF(MA(CLOSE,17),1); 分界线:IF(MA(CLOSE,17)<B,B,MA(CLOSE,17)),COLORFF00FF,LINETHICK2; 操作线:分界线-(EMA(C,3)-分界线),COLOR00FFFF,LINETHICK2; GUP:MA(C,5),COLORWHITE,LINE…

【PS 2022】Adobe Genuine Service Alert 弹出

电脑总是弹出Adobe Genuine Service Alert弹窗 1. 不关掉弹窗并打开任务管理器&#xff0c;找到Adobe Genuine Service Alert&#xff0c;并右键进入文件所在位置 2 在任务管理器中结束进程并将文件夹中的 .exe 文件都使用空文档替换掉 3. 打开PS不弹出弹窗&#xff0c;解决&a…

30天开发操作系统 第 20 天 -- API

前言 大家早上好&#xff0c;今天我们继续努力哦。 昨天我们已经实现了应用程序的运行, 今天我们来实现由应用程序对操作系统功能的调用(即API, 也叫系统调用)。 为什么这样的功能称为“系统调用”(system call)呢&#xff1f;因为它是由应用程序来调用(操作)系统中的功能来完…

蓝桥杯(B组)-每日一题(求最大公约数最小公倍数)

题目&#xff1a; 代码展现&#xff1a; #include<iostream> using namespace std; int main() {int m,n,x,y;cin>>m>>n;//输入两个整数int b;bm%n;//取余数xm;//赋值yn;while(b)//当余数不为0的时候{xy;//辗转相除求最小公约数yb;bx%y;}cout<<y<&…

arduino扩展:Arduino Mega 控制 32 个舵机(参考表情机器人)

参考&#xff1a;表情机器人中使用22个舵机的案例 引言 在电子制作与自动化控制领域&#xff0c;Arduino 凭借其易用性和强大的扩展性备受青睐。Arduino Mega 作为其中功能较为强大的一款开发板&#xff0c;具备丰富的引脚资源&#xff0c;能够实现复杂的控制任务。舵机作为常…

基于51单片机的门禁刷卡器proteus仿真

地址&#xff1a;https://pan.baidu.com/s/1j0KAmH5pVGWZWRpT6p5hBg 提取码&#xff1a;1234 仿真图&#xff1a; 芯片/模块的特点&#xff1a; AT89C52/AT89C51简介&#xff1a; AT89C52/AT89C51是一款经典的8位单片机&#xff0c;是意法半导体&#xff08;STMicroelectron…

mapbox进阶,添加绘图扩展插件,裁剪线

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️MapboxDraw 绘图控件二、🍀添加绘图扩…

19.4.6 读写数据库中的二进制数据

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 需要北风数据库的请留言自己的信箱。 北风数据库中&#xff0c;类别表的图片字段在【数据表视图】中显示为Bitmap Image&#xff1…

MapReduce到底是个啥?

在聊 MapReduce 之前不妨先看个例子&#xff1a;假设某短视频平台日活用户大约在7000万左右&#xff0c;若平均每一个用户产生3条行为日志&#xff1a;点赞、转发、收藏&#xff1b;这样就是两亿条行为日志&#xff0c;再假设每条日志大小为100个字节&#xff0c;那么一天就会产…

Winform自定义控件与案例 - 构建炫酷的自定义环形进度条控件

文章目录 1、控件效果2、案例实现1、代码实现2、代码解释3、使用示例 4、总结 1、控件效果 2、案例实现 1、代码实现 代码如下&#xff08;示例&#xff09;&#xff1a; using System; using System.ComponentModel; using System.Drawing; using System.Drawing.Drawing2D; …

【SpringBoot苍穹外卖】debugDay03.5

1、AOP面向切面编程 1. Target(ElementType.METHOD) 作用&#xff1a;指定自定义注解可以应用的目标范围。 参数&#xff1a;ElementType 是一个枚举类&#xff0c;定义了注解可以应用的目标类型。 ElementType.METHOD 表示该注解只能用于方法上。 其他常见的 ElementType 值…

flink实时集成利器 - apache seatunnel - 核心架构详解

SeaTunnel&#xff08;原名 Waterdrop&#xff09;是一个分布式、高性能、易扩展的数据集成平台&#xff0c;专注于大数据领域的数据同步、数据迁移和数据转换。它支持多种数据源和数据目标&#xff0c;并可以与 Apache Flink、Spark 等计算引擎集成。以下是 SeaTunnel 的核心架…

视频理解新篇章:Mamba模型的探索与应用

人工智能咨询培训老师叶梓 转载标明出处 想要掌握如何将大模型的力量发挥到极致吗&#xff1f;叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具&#xff08;限时免费&#xff09;。 1小时实战课程&#xff0c;您将学习到如何轻松上手并有效利用 Llama Facto…

分形几何表明数学一直存在有首、末的无穷序列

分形几何表明数学一直存在有首、末的无穷序列。一有穷长直线段S可变为锯齿状图形G而由无穷多无穷短直线段连接而成。G和S一样有左、右两个端点。

Python接口自动化测试—接口数据依赖

一般在做自动化测试时&#xff0c;经常会对一整套业务流程进行一组接口上的测试&#xff0c;这时候接口之间经常会有数据依赖&#xff0c;那又该如何继续呢&#xff1f; 那么有如下思路&#xff1a; 抽取之前接口的返回值存储到全局变量字典中。初始化接口请求时&#xff0c;…

C++ 实践扩展(Qt Creator 联动 Visual Studio 2022)

​ 这里我们将在 VS 上实现 QT 编程&#xff0c;实现如下&#xff1a; 一、Vs 2022 配置&#xff08;若已安装&#xff0c;可直接跳过&#xff09; 点击链接&#xff1a;​​​​​Visual Studio 2022 我们先去 Vs 官网下载&#xff0c;如下&#xff1a; 等待程序安装完成之…

Java中性能瓶颈的定位与调优方法

Java中性能瓶颈的定位与调优方法 Java作为一种高效、跨平台的编程语言&#xff0c;广泛应用于企业级应用、服务器端开发、分布式系统等领域。然而&#xff0c;在面对大量并发、高负载的生产环境时&#xff0c;Java应用的性能瓶颈问题往往会暴露出来。如何定位并优化这些性能瓶…