【目标检测】模型验证:K-Fold 交叉验证

news2025/2/8 2:55:24

K-Fold 交叉验证

  • 1、引言
    • 1.1 K 折交叉验证概述
  • 2、配置
    • 2.1 数据集
    • 2.2 安装包
  • 3、 实战
    • 3.1 生成物体检测数据集的特征向量
    • 3.2 K 折数据集拆分
    • 3.3 保存记录
    • 3.4 使用 K 折数据分割训练YOLO
  • 4、总结

1、引言

我们将利用YOLO 检测格式和关键的Python 库(如 sklearn、pandas 和 PyYaml),完成必要的设置、生成特征向量的过程以及 K-Fold 数据集拆分的执行。

1.1 K 折交叉验证概述

无论你的项目涉及水果检测数据集还是自定义数据源,都可以使用 K 折交叉验证,
以提高项目的可靠性和稳健性。

书说简短,闲言少叙,咱进入正题
在这里插入图片描述

2、配置

2.1 数据集

该数据集共包含 8479 幅图像。
它包括 6 个类别标签,每个标签的实例总数如下:

类别计数
苹果7049
葡萄7202
菠萝1613
橙色15549
香蕉3536
西瓜1976

2.2 安装包

必要的Python 软件包包括

  • ultralytics
  • sklearn
  • pandas
  • pyyaml

这次实例中,我们使用 k=5 折叠次数

3、 实战

3.1 生成物体检测数据集的特征向量

具体步骤如下:

  • 1、首先创建一个新的 demo.py Python 文件来执行下面的步骤。

  • 2、继续检索数据集的所有标签文件。

from pathlib import Path

dataset_path = Path("./Fruit-detection")  # replace with 'path/to/dataset' for your custom data
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # all data in 'labels'
  • 3、现在,读取数据集 YAML 文件的内容并提取类标签的索引。
yaml_file = "path/to/data.yaml"  # your data YAML with data directories and names dictionary
with open(yaml_file, "r", encoding="utf8") as y:
    classes = yaml.safe_load(y)["names"]
cls_idx = sorted(classes.keys())
  • 4、初始化一个空的 pandas DataFrame.
import pandas as pd

index = [label.stem for label in labels]  # uses base filename as ID (no extension)
labels_df = pd.DataFrame([], columns=cls_idx, index=index)
  • 5、计算注释文件中每个类别标签的实例数。
from collections import Counter

for label in labels:
    lbl_counter = Counter()

    with open(label, "r") as lf:
        lines = lf.readlines()

    for line in lines:
        # classes for YOLO label uses integer at first position of each line
        lbl_counter[int(line.split(" ")[0])] += 1

    labels_df.loc[label.stem] = lbl_counter

labels_df = labels_df.fillna(0.0)  # replace `nan` values with `0.0`
  • 6、以下是已填充 DataFrame 的示例视图:
                                                       0    1    2    3    4    5
'0000a16e4b057580_jpg.rf.00ab48988370f64f5ca8ea4...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.7e6dce029fb67f01eb19aa7...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.bc4d31cdcbe229dd022957a...'  0.0  0.0  0.0  0.0  0.0  7.0
'00020ebf74c4881c_jpg.rf.508192a0a97aa6c4a3b6882...'  0.0  0.0  0.0  1.0  0.0  0.0
'00020ebf74c4881c_jpg.rf.5af192a2254c8ecc4188a25...'  0.0  0.0  0.0  1.0  0.0  0.0
 ...                                                  ...  ...  ...  ...  ...  ...
'ff4cd45896de38be_jpg.rf.c4b5e967ca10c7ced3b9e97...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff4cd45896de38be_jpg.rf.ea4c1d37d2884b3e3cbce08...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff5fd9c3c624b7dc_jpg.rf.bb519feaa36fc4bf630a033...'  1.0  0.0  0.0  0.0  0.0  0.0
'ff5fd9c3c624b7dc_jpg.rf.f0751c9c3aa4519ea3c9d6a...'  1.0  0.0  0.0  0.0  0.0  0.0
'fffe28b31f2a70d4_jpg.rf.7ea16bd637ba0711c53b540...'  0.0  6.0  0.0  0.0  0.0  0.0

解析

  • 行是标签文件的索引,每个标签文件对应数据集中的一幅图像,列则对应类标签索引。
  • 每一行代表一个伪特征向量,其中包含数据集中每个类标签的计数。
  • 这种数据结构可以将 K 折交叉验证应用于对象检测数据集。

3.2 K 折数据集拆分

  • 1、使用 KFold 从 sklearn.model_selection 以产生 k 对数据集进行分割。

    • 敲黑板:
      • 设置 shuffle=True 确保了分班中班级的随机分布。
      • 通过设置 random_state=M 其中 M 是一个选定的整数,这样就可以得到可重复的结果。
from sklearn.model_selection import KFold

ksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # setting random_state for repeatable results

kfolds = list(kf.split(labels_df))
  • 2、数据集现已分为 k 折叠,每个折叠都有一个 train 和 val 指数。我们将构建一个 DataFrame 来更清晰地显示这些结果。
folds = [f"split_{n}" for n in range(1, ksplit + 1)]
folds_df = pd.DataFrame(index=index, columns=folds)

for i, (train, val) in enumerate(kfolds, start=1):
    folds_df[f"split_{i}"].loc[labels_df.iloc[train].index] = "train"
    folds_df[f"split_{i}"].loc[labels_df.iloc[val].index] = "val"
  • 3、将计算每个褶皱的类别标签分布,并将其作为褶皱中出现的类别的比率。
fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)

for n, (train_indices, val_indices) in enumerate(kfolds, start=1):
    train_totals = labels_df.iloc[train_indices].sum()
    val_totals = labels_df.iloc[val_indices].sum()

    # To avoid division by zero, we add a small value (1E-7) to the denominator
    ratio = val_totals / (train_totals + 1e-7)
    fold_lbl_distrb.loc[f"split_{n}"] = ratio
最理想的情况是,每次分割和不同类别的所有类别比率都相当相似。不过,这取决于数据集的具体情况。
  • 4、为每个分割创建目录和数据集 YAML 文件。
import datetime

supported_extensions = [".jpg", ".jpeg", ".png"]

# Initialize an empty list to store image file paths
images = []

# Loop through supported extensions and gather image files
for ext in supported_extensions:
    images.extend(sorted((dataset_path / "images").rglob(f"*{ext}")))

# Create the necessary directories and dataset YAML files (unchanged)
save_path = Path(dataset_path / f"{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val")
save_path.mkdir(parents=True, exist_ok=True)
ds_yamls = []

for split in folds_df.columns:
    # Create directories
    split_dir = save_path / split
    split_dir.mkdir(parents=True, exist_ok=True)
    (split_dir / "train" / "images").mkdir(parents=True, exist_ok=True)
    (split_dir / "train" / "labels").mkdir(parents=True, exist_ok=True)
    (split_dir / "val" / "images").mkdir(parents=True, exist_ok=True)
    (split_dir / "val" / "labels").mkdir(parents=True, exist_ok=True)

    # Create dataset YAML files
    dataset_yaml = split_dir / f"{split}_dataset.yaml"
    ds_yamls.append(dataset_yaml)

    with open(dataset_yaml, "w") as ds_y:
        yaml.safe_dump(
            {
                "path": split_dir.as_posix(),
                "train": "train",
                "val": "val",
                "names": classes,
            },
            ds_y,
        )
  • 5、最后,将图像和标签复制到每个分割的相应目录("train "或 “val”)中。
import shutil

for image, label in zip(images, labels):
    for split, k_split in folds_df.loc[image.stem].items():
        # Destination directory
        img_to_path = save_path / split / k_split / "images"
        lbl_to_path = save_path / split / k_split / "labels"

        # Copy image and label files to new directory (SamefileError if file already exists)
        shutil.copy(image, img_to_path / image.name)
        shutil.copy(label, lbl_to_path / label.name)

3.3 保存记录

将 K 折分割和标签分布数据框的记录保存为 CSV 文件。

folds_df.to_csv(save_path / "kfold_datasplit.csv")
fold_lbl_distrb.to_csv(save_path / "kfold_label_distribution.csv")

3.4 使用 K 折数据分割训练YOLO

  • 首先,加载YOLO 模型。
from ultralytics import YOLO

weights_path = "path/to/weights.pt"
model = YOLO(weights_path, task="detect")
  • 其次,遍历数据集 YAML 文件以运行训练。结果将保存到由 project 和 name 参数。默认情况下,该目录为 “exp/runs#”,其中 # 为整数索引。
results = {}

# Define your additional arguments here
batch = 16
project = "kfold_demo"
epochs = 100

for k in range(ksplit):
    dataset_yaml = ds_yamls[k]
    model = YOLO(weights_path, task="detect")
    model.train(data=dataset_yaml, epochs=epochs, batch=batch, project=project)  # include any train arguments
    results[k] = model.metrics  # save output metrics for further analysis

4、总结

这篇小鱼使用了 K 折交叉验证来训练YOLO 物体检测模型的过程。

还创建报告 DataFrames 的程序,以可视化数据拆分和标签在这些拆分中的分布,清楚地了解训练集和验证集的结构。

此外,还保存了记录,这在大型项目或排除模型性能故障时尤为有用。

最后,在一个循环中使用每个拆分来执行实际的模型训练,保存训练结果,以便进一步分析和比较。

这种 K 折交叉验证技术是充分利用可用数据的一种稳健方法,有助于确保模型在不同数据子集中的性能是可靠和一致的。这将产生一个更具通用性和可靠性的模型,从而减少对特定数据模式的过度拟合。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)评测一等奖获得者

关注小鱼,学习【人工智能&大模型】/【深度学习&机器学习】领域最新最全的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2294568.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity 2D实战小游戏开发跳跳鸟 - 计分逻辑开发

上文对障碍物的碰撞逻辑进行了开发,接下来就是进行跳跳鸟成功穿越过障碍物进行计分的逻辑开发,同时将对应的分数以UI的形式显示告诉玩家。 计分逻辑 在跳跳鸟通过障碍物的一瞬间就进行一次计分,计分后会同步更新分数的UI显示来告知玩家当前获得的分数。 首先我们创建一个用…

京准:NTP卫星时钟服务器对于DeepSeek安全的重要性

京准:NTP卫星时钟服务器对于DeepSeek安全的重要性 京准:NTP卫星时钟服务器对于DeepSeek安全的重要性 在网络安全领域,分布式拒绝服务(DDoS)攻击一直是企业和网络服务商面临的重大威胁之一。随着攻击技术的不断演化…

Android学习20 -- 手搓App2(Gradle)

1 前言 昨天写了一个完全手搓的:Android学习19 -- 手搓App-CSDN博客 后面谷歌说不要用aapt,d8这些来搞。其实不想弄Gradle的,不过想着既然开始了,就多看一些。之前写过一篇Gradle,不过是最简单的编译,不涉…

车型检测7种YOLOV8

车型检测7种YOLOV8,采用YOLOV8NANO训练,得到PT模型,转换成ONNX,然后OPENCV的DNN调用,支持C,python,android开发 车型检测7种YOLOV8

IDEA 中集成 Maven,配置环境、创建以及导入项目

目录 在 IntelliJ IDEA 中集成 Maven 并配置环境 1. 打开 IDEA 设置 2. 定位 Maven 配置选项 3. 配置 Maven 路径 4. 应用配置 创建 Maven 项目 1. 新建项目 2. 选择项目类型 3. 配置项目信息 4. 确认 Maven 设置 5. 完成项目创建 导入 Maven 项目 1. 打开导入窗口…

react关于手搓antd pro面包屑的经验(写的不好请见谅)

我们先上代码,代码里面都有注释,我是单独写了一个组件,方便使用,在其他页面引入就行了 还使用了官方的Breadcrumb组件 import React, { useEffect, useState } from react; import { Breadcrumb, Button } from antd; import { …

[含文档+PPT+源码等]精品大数据项目-Django基于大数据实现的心血管疾病分析系统

大数据项目-Django基于大数据实现的心血管疾病分析系统背景可以从以下几个方面进行阐述: 一、项目背景与意义 1. 心血管疾病现状 心血管疾病是当前全球面临的主要健康挑战之一,其高发病率、高致残率和高死亡率严重威胁着人类的生命健康。根据权威机构…

【Rust自学】19.5. 高级类型

喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 19.5.1.使用newtype模式实现类型安全和抽象 在 19.2. 高级trait 中(具体来说是…

113,【5】 功防世界 web unseping

进入靶场 代码审计 <?php // 高亮显示当前 PHP 文件的源代码&#xff0c;方便开发者查看代码结构和内容 highlight_file(__FILE__);// 定义一个名为 ease 的类 class ease {// 私有属性 $method&#xff0c;用于存储要调用的方法名private $method;// 私有属性 $args&…

leetCode刷题-图、回溯相关

岛屿数量 class Solution { private:int mi;int mj; public:int numIslands(vector<vector<char>>& grid) {mi grid.size() - 1; // i的范围 0~mimj grid[0].size() - 1; // j的范围 0~mjint landnum 0;bool sea false;do {pair<int, int> res …

Windows编程:下载与安装 Visual Studio 2010

本节前言 在写作本节的时候&#xff0c;本来呢&#xff0c;我正在写的专栏&#xff0c;是 MFC 专栏。而 VS2010 和 VS2019&#xff0c;正是 MFC 学习与开发中&#xff0c;可以使用的两款软件。然而呢&#xff0c;如果你去学习 Windows API 知识的话&#xff0c;那么&#xff0…

OpenEuler学习笔记(十八):搭建企业云盘服务

要在 OpenEuler 上搭建企业云盘&#xff0c;可借助一些开源软件来实现&#xff0c;以下以 Nextcloud 为例详细介绍搭建步骤。Nextcloud 是一款功能丰富的开源云存储解决方案&#xff0c;支持文件共享、同步、协作等多种功能。 1. 系统环境准备 确保 OpenEuler 系统已更新到最…

什么是三层交换技术?与二层有什么区别?

什么是三层交换技术&#xff1f;让你的网络飞起来&#xff01; 一. 什么是三层交换技术&#xff1f;二. 工作原理三. 优点四. 应用场景五. 总结 前言 点个免费的赞和关注&#xff0c;有错误的地方请指出&#xff0c;看个人主页有惊喜。 作者&#xff1a;神的孩子都在歌唱 大家好…

Ollama+deepseek+Docker+Open WebUI实现与AI聊天

1、下载并安装Ollama 官方网址&#xff1a;Ollama 安装好后&#xff0c;在命令行输入&#xff0c; ollama --version 返回以下信息&#xff0c;则表明安装成功&#xff0c; 2、 下载AI大模型 这里以deepseek-r1:1.5b模型为例&#xff0c; 在命令行中&#xff0c;执行&…

Linux生成自签证书【Nginx】

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;CSDN博客专家   &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01…

网络安全 | 加密技术揭秘:保护数据隐私的核心

网络安全 | 加密技术揭秘&#xff1a;保护数据隐私的核心 一、前言二、对称加密技术2.1 原理2.2 优点2.3 缺点2.4 应用场景 三、非对称加密技术3.1 原理3.2 优点3.3 缺点3.4 应用场景 四、哈希函数4.1 原理4.2 优点4.3 缺点4.4 应用场景 五、数字签名5.1 原理5.2 优点5.3 缺点5…

使用服务器部署DeepSeek-R1模型【详细版】

文章目录 引言deepseek-r1IDE或者终端工具算力平台体验deepseek-r1模型总结 引言 在现代的机器学习和深度学习应用中&#xff0c;模型部署和服务化是每个开发者面临的重要任务。无论是用于智能推荐、自然语言处理还是图像识别&#xff0c;如何高效、稳定地将深度学习模型部署到…

DirectX11 With Windows SDK--02 顶点/像素着色器的创建、顶点缓冲区

Direct3D 11 总结 —— 4 绘制三角形_direct绘制三角形-CSDN博客 DirectX11 With Windows SDK--02 顶点/像素着色器的创建、顶点缓冲区 - X_Jun - 博客园 练习题 粗体字为自定义题目 尝试交换三角形第一个和第三个顶点的数据&#xff0c;屏幕将显示什么&#xff1f;为什么&…

第二次连接k8s平台注意事项

第二次重新打开集群平台 1.三台机子要在VMware打开 2.MobaBXterm连接Session 3.三个机子docker重启 systemctl restart docker4.主节点进行平台链接 docker pull kubeoperator/kubepi-server[rootnode1 home]# docker pull kubeoperator/kubepi-server [rootnode1 home]# # 运…

Mybatis篇

1&#xff0c;什么是Mybatis &#xff08; 1 &#xff09;Mybatis 是一个半 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;它内部封装了 JDBC&#xff0c;开发时只需要关注 SQL 语句本身&#xff0c;不需要花费精力去处理加载驱动、创建连接、创建 statement 等繁…