如何正确拆分数据集?常见方法最全汇总

news2024/11/25 12:41:59

将数据集划分为训练集(Training)和测试集(Testing)是机器学习和统计建模中的重要步骤:

训练集(Training):一般来说 Train 训练集会进一步再分为 Train 训练集与 Validation 验证集两部分,以评价不同参数组合的效果,以确定最终的模型

测试集(Testing):Test 测试集自始至终没有参与到模型的训练过程;它的目的只有一个:在确定一个最终模型后,评价其泛化能力

所以综上,数据集划分有以下原因和好处:

  • 评估模型的泛化能力: 训练集用于构建模型,测试集用于评估模型在未见数据上的表现,通过测试集的性能来估计模型的泛化能力,即模型在新数据上的预测准确性;
  • 防止过拟合:模型在训练集上表现很好,但在测试集上表现不佳,通常是因为模型过于复杂,捕捉了训练数据中的噪声,使用测试集可以帮助检测过拟合现象,确保模型不仅仅是在记忆训练数据;
  • 模型选择和参数调优:通过在测试集上的表现来比较不同模型的优劣,测试集提供了一个客观的标准来选择最佳的超参数组合;

一、简单拆分

1、简单随机拆分

简单随机拆分法是将数据集随机分成训练集和测试集的一种方法,该方法直观,易于实现,不需要复杂的算法或技术;随机拆分可以减少样本选择偏差,确保训练集和测试集的代表性。

缺点:在划分数据集时,不能确保每个类别的样本按照其在总体中的比例被选入训练集和测试集(如下示例中,划分为测试集的两个类别数据有可能出现类别都为 0 或者类别都为 1)

我们一般将数据集按一定比例(如 80:20)随机拆分,可以使用 sklearn 库中的 train_test_split 函数来实现,示例如下

import pandas as pd
from sklearn.model_selection import train_test_split

# 创建一个示例数据集
data = {
    'feat1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'feat2': [10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
    'label': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
}
df = pd.DataFrame(data)

# 特征和标签
X = df[['feat1', 'feat2']]
y = df['label']

# 使用简单随机拆分法将数据集分为训练集和测试集
# 使用 stratify 确保每个类别的样本在训练集和测试集中都有代表
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 输出结果
print("训练集特征:")
print(X_train)
print("\n测试集特征:")
print(X_test)
print("\n训练集标签:")
print(y_train)
print("\n测试集标签:")
print(y_test)

数据集划分的结果如下

# 训练集特征:
#    feat1  feat2
# 4      5      6
# 3      4      7
# 7      8      3
# 5      6      5
# 1      2      9
# 2      3      8
# 9     10      1
# 6      7      4

# 测试集特征:
#    feat1  feat2
# 8      9      2
# 0      1     10

# 训练集标签:
# 4    0
# 3    1
# 7    1
# 5    1
# 1    1
# 2    0
# 9    1
# 6    0
# Name: label, dtype: int64

# 测试集标签:
# 8    0
# 0    0
# Name: label, dtype: int64
2、分层抽样

在划分数据集时,确保每个类别的样本按照其在总体中的比例被选入训练集和测试集,通过对每个类别进行单独抽样,确保各类别样本的代表性,其适用于类别不均衡的数据集,优点在于:

  • 保证每个类别在训练集和测试集中都有足够的样本,有助于提高模型的泛化能力。
  • 减少了由于类别不平衡导致的偏差

我们可以使用 sklearn 库中的 train_test_split 函数,并设置 stratify 参数来实现,示例如下

import pandas as pd
from sklearn.model_selection import train_test_split

# 创建一个示例数据集
data = {
    'feat1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'feat2': [10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
    'label': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
}
df = pd.DataFrame(data)

# 特征和标签
X = df[['feat1', 'feat2']]
y = df['label']

# 使用简单随机拆分法将数据集分为训练集和测试集
# 使用 stratify 确保每个类别的样本在训练集和测试集中都有代表
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)

# 输出结果
print("训练集特征:")
print(X_train)
print("\n测试集特征:")
print(X_test)
print("\n训练集标签:")
print(y_train)
print("\n测试集标签:")
print(y_test)

数据集划分的结果如下

# 训练集特征:
#    feat1  feat2
# 3      4      7
# 9     10      1
# 4      5      6
# 1      2      9
# 5      6      5
# 0      1     10
# 8      9      2
# 6      7      4

# 测试集特征:
#    feat1  feat2
# 7      8      3
# 2      3      8

# 训练集标签:
# 3    1
# 9    1
# 4    0
# 1    1
# 5    1
# 0    0
# 8    0
# 6    0
# Name: label, dtype: int64

# 测试集标签:
# 7    1
# 2    0
# Name: label, dtype: int64

二、交叉验证

1、K折交叉验验

将数据集随机分成 K 个子集(折),每次用 K-1 个子集进行训练,剩余的 1 个子集用于测试,重复 K 次,每次选择不同的子集作为测试集。通过计算所有测试结果的平均值作为模型的最终性能评估指标

优点

  • 充分利用数据:每个数据点都可以作为训练集和测试集的一部分,增加了数据的利用率
  • 更稳定的评估:提供了对模型性能的更稳定和可靠的估计,因为它考虑了多次划分的结果
  • 提高准确性:通过多次训练和测试,减少了因数据划分不同而导致的波动,提高模型的准确性
  • 灵活性:可以根据需要调整 K 的大小,以适应不同的数据集和计算资源

缺点

  • 计算成本高:需要训练 K 次模型,计算时间和资源消耗较大
  • 选择合适的 K 值困难:不同的 K 值可能导致不同的结果,选择不当可能影响模型评估的准确性
  • 数据分布问题:如果数据集不平衡,可能需要使用分层 K 折交叉验证来确保各类别的代表性

在 Python 中,可以使用 scikit-learn 库来实现 K 折交叉验证。以下是一个简单的实现示例:

from sklearn.model_selection import KFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# 加载数据集
data = load_iris()
X, y = data.data, data.target

# 定义模型
model = RandomForestClassifier()

# 定义 K 折交叉验证
k = 10
kf = KFold(n_splits=k, shuffle=True, random_state=42)

# 进行交叉验证
scores = cross_val_score(model, X, y, cv=kf)

# 输出结果
print(f"K折交叉验证的平均准确率: {scores.mean():.2f}")
print(f"每折的准确率: {np.round(scores, 2)}")

# -------- 输出 --------
# K折交叉验证的平均准确率: 0.96
# 每折的准确率: [1. 1. 1. 0.93 1. 0.87 0.87 1. 1. 0.93]
2、分层K折交叉验证

其它与 K 折交叉验证相同,但在划分子集时,确保每个子集中各类别的比例与整个数据集中的比例相同

优点

  • 充分利用数据:每个数据点都被用作训练和验证,增加了数据的利用率
  • 更稳定的评估:由于每个折的类别分布与整体数据集一致,能提供更稳定的性能评估
  • 适用于不平衡数据:确保每个折中包含足够的少数类样本,有助于在不平衡数据集上进行更准确的评估

缺点

  • 计算成本高:需要训练模型 K 次,对于大型数据集或复杂模型,计算成本较高
  • 不适用于某些回归任务分层策略主要针对分类任务,对于回归任务,分层的概念不太适用

在 Python 中,可以使用 scikit-learn 库的 StratifiedKFold 来实现分层 K 折交叉验证。以下是一个简单的实现示例,主要和  K 折交叉验证的区别是采用 StratifiedKFold 方法替换了 KFold 方法

from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# 加载数据集
data = load_iris()
X, y = data.data, data.target

# 定义模型
model = RandomForestClassifier()

# 定义 K 折交叉验证
k = 10
skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

# 进行交叉验证
scores = cross_val_score(model, X, y, cv=kf)

# 输出结果
print(f"K折交叉验证的平均准确率: {scores.mean():.2f}")
print(f"每折的准确率: {np.round(scores, 2)}")

# -------- 输出 --------
# K折交叉验证的平均准确率: 0.96
# 每折的准确率: [1. 1. 1. 0.93 1. 0.87 0.87 1. 1. 0.93]
3、留一交叉验证

留一交叉验证法(Leave-One-Out Cross-Validation, LOOCV)是一种特殊的交叉验证方法,其中每次迭代只用一个样本作为验证集,其余样本作为训练集。这种方法适用于小型数据集,因为它计算量较大

优点

  • 最大化数据使用:每次训练使用了几乎所有的数据,确保模型训练的充分性
  • 无偏估计:由于每个样本都被用作验证集,评估结果通常较为无偏
  • 适合小数据集:在样本数较少的情况下,能够充分利用数据进行模型评估

缺点

  • 计算成本高:需要训练模型 N 次(N 为样本数),对于大型数据集,计算代价非常高
  • 结果方差大:每次只用一个样本进行验证,导致评估结果的方差较大,可能不如 K 折交叉验证稳定
  • 不适合大型数据集:由于计算成本和时间要求,通常不适用于大型数据集
  • 过拟合风险:在某些情况下,可能导致模型过拟合,因为每次训练几乎使用了所有数据

在 Python 中,可以使用 scikit-learn 库的 LeaveOneOut 方法来实现 留一交叉验证。以下是一个简单的实现示例:

from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# 加载数据集
data = load_iris()
X, y = data.data, data.target

# 定义模型
model = RandomForestClassifier()

# 定义留一交叉验证
loo = LeaveOneOut()

# 进行交叉验证
scores = cross_val_score(model, X, y, cv=loo)

# 输出结果
print(f"留一交叉验证的平均准确率: {scores.mean():.2f}")
print(f"样本总数: {len(scores)}")

# -------- 输出 --------
# 留一交叉验证的平均准确率: 0.95
# 样本总数: 150
4、Bootstrap 自助法

Bootstrap 自助法是有放回的重复采样:在含有 m 个样本的数据集中,每次随机挑选一个样本, 将其作为训练样本,再将此样本放回到数据集中,这样有放回地抽样 m 次,生成一个与原数据集大小相同的数据集,这个新数据集就是训练集。这样有些样本可能在训练集中出现多次,有些则可能从未出现。原数据集中大概有 36.8% 的样本不会出现在新数据集中。因此,我们把这些未出现在新数据集中的样本作为验证集。

  • 优点:训练集的样本总数和原数据集一样都是 m个,并且仍有约 1/3 的数据不出现在训练集中,而可以作为验证集。
  • 缺点:这样产生的训练集的数据分布和原数据集的不一样了,会引入估计偏差。

用途:自助法在数据集较小,难以有效划分训练集/验证集时很有用;此外,自助法能从初始数据集中产生多个不同的训练集,这对集成学习等方法有很大的好处。

注意:由于其训练集有重复数据,这会改变数据的分布,因而导致训练结果有估计偏差,因此这种方法不是很常用,除非数据量真的很少。

以下是使用 Python 实现 Bootstrap 自助法的一个简单的实现示例:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# 加载数据集
data = load_iris()
X = data.data
y = data.target

# 设置参数
num_iterations = 100
n_samples = len(X)
accuracies = []

# Bootstrap 自助法
for _ in range(num_iterations):
    # 从原始数据集中有放回地抽样
    indices = np.random.choice(n_samples, size=n_samples, replace=True)
    X_bootstrap = X[indices]
    y_bootstrap = y[indices]

    # 训练模型
    model = LogisticRegression(max_iter=200)
    model.fit(X_bootstrap, y_bootstrap)

    # 预测
    y_pred = model.predict(X)

    # 计算准确率
    accuracy = accuracy_score(y, y_pred)
    accuracies.append(accuracy)

# 结果汇总
mean_accuracy = np.mean(accuracies)
std_accuracy = np.std(accuracies)

print(f"平均准确率: {mean_accuracy:.2f}")
print(f"准确率标准差: {std_accuracy:.2f}")

# -------- 输出 --------
# 平均准确率: 0.97
# 准确率标准差: 0.01
5、Subsampling 随机子集交叉验证

随机子集交叉验证(Subsampling)是一种模型评估方法,通过多次随机拆分数据集来验证模型的性能。与标准的 k 折交叉验证不同,随机子集交叉验证不一定将数据划分为固定的 k 个子集,而是每次随机选择训练集和测试集。以下是详细的步骤和代码示例

  • 优点:能够在不同的训练集和测试集上进行多次评估,提供对模型性能的更加稳健的估计。
  • 缺点:计算成本较高,需要多次重复采样和训练模型。
  • 使用场景:用于需要更稳健的性能评估的情况,或者对模型性能进行一致性验证。

以下是使用 Python 实现 Subsampling 随机子集交叉验证 的一个简单的实现示例:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# 加载数据集
data = load_iris()
X = data.data
y = data.target

# 设置参数
num_iterations = 100
test_size = 0.3
accuracies = []

# 随机子集交叉验证
for _ in range(num_iterations):
    # 随机拆分数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)
    
    # 训练模型
    model = LogisticRegression(max_iter=200)
    model.fit(X_train, y_train)
    
    # 预测
    y_pred = model.predict(X_test)
    
    # 计算准确率
    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)

# 结果汇总
mean_accuracy = np.mean(accuracies)
std_accuracy = np.std(accuracies)

print(f"平均准确率: {mean_accuracy:.2f}")
print(f"准确率标准差: {std_accuracy:.2f}")

# -------- 输出 --------
# 平均准确率: 0.96
# 准确率标准差: 0.03


如果你喜欢本文,欢迎点赞,并且关注我们的微信公众号:Python数据挖掘分析,我们会持续更新数据挖掘分析领域的好文章,让大家在数据挖掘分析领域持续精进提升,成为更好的自己!

添加本人微信(coder_0101),或者通过扫描下面二维码添加二维码,拉你进入行业技术交流群,进行技术交流~

扫描以下二维码,加入 Python数据挖掘分析 群,在群内与众多业界大牛互动,了解行业发展前沿~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187876.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ElasticSearch备考 -- Update by query

一、题目 有个索引task,里面的文档长这样 现在需要添加一个字段all,这个字段的值是以下 a、b、c、d字段的值连在一起 二、思考 需要把四个字段拼接到一起,组成一个新的字段,这个就需要脚本, 这里有两种方案&#xff…

geodatatool(地图资源下载工具)3.8更新

geodatatool(地图资源下载工具)3.8(新)修复更新,修复更新包括: 1.高德POI数据按行政区划下载功能完善。 2.修正高德POI数据类型重复问题。 3.对高德KEY数据访问量超过最大限制时,提示错误并终止…

RK3568平台(显示篇)车机图像显示偏白问题分析

一.显示偏白图片对比 正常图像: 偏白图像: 二.分析过程

手把手教你使用Tensorflow2.7完成人脸识别系统,web人脸识别(Flask框架)+pyqt界面,保姆级教程

目录 前言一、系统总流程设计二、环境安装1. 创建虚拟环境2.安装其他库 三、模型搭建1.采集数据集2. 数据预处理3.构建模型和训练 五、摄像头测试六、web界面搭建与pyqt界面搭建报错了并解决的方法总结 前言 随着人工智能的不断发展,机器学习和深度学习这门技术也越…

YOLO11改进|注意力机制篇|引入注意力与卷积混合的ACmix

目录 一、ACmix注意力机制1.1ACmix注意力介绍1.2ACmix核心代码 二、添加ACmix注意力机制2.1STEP12.2STEP22.3STEP32.4STEP4 三、yaml文件与运行3.1yaml文件3.2运行成功截图 一、ACmix注意力机制 1.1ACmix注意力介绍 ACmix设计为一个结合了卷积和自注意力机制优势的混合模块&am…

Redis: 集群测试和集群原理

集群测试 1 ) SET/GET 命令 测试 set 和 get 因为其他命令也基本相似,我们在 101 节点上尝试连接 103 $ /usr/local/redis/bin/redis-cli -c -a 123456 -h 192.168.10.103 -p 6376我们在插入或读取一个 key的时候,会对这个key做一个hash运算&#xff0c…

判断有向图是否为单连通图的算法

判断有向图是否为单连通图的算法 算法描述伪代码C语言实现解释在图论中,单连通图(singly connected graph)是指对于图中的任意两个顶点 m 和 v,如果存在从 m 到 v 的路径,则该路径是唯一的。为了判断一个有向图是否为单连通图,我们需要确保从任意顶点出发,到任意其他顶点…

开发能够抵御ICS对抗性攻击的边缘弹性机器学习集成

论文标题:《Development of an Edge Resilient ML Ensemble to Tolerate ICS Adversarial Attacks》 作者信息: Likai Yao, NSF Center for Cloud and Autonomic Computing, University of Arizona, Tucson, AZ 85721 USAQinxuan Shi, School of Elect…

【数据库差异研究】别名与表字段冲突,不同数据库在where中的处理行为

目录 ⚛️总结 ☪️1 问题描述 ☪️2 测试用例 ♋2.1 测试单层查询 ♏2.1.1 SQLITE数据库 ♐2.1.2 ORACLE数据库 ♑2.1.3 PG数据库 ♋2.2 测试嵌套查询 ♉2.2.1 SQLITE数据库 ♈2.2.2 ORACLE数据库 🔯2.2.3 PG数据库 ⚛️总结 单层查询 数据库类型别名…

字节终面问 Transformer,太难了。。。

最近已有不少大厂都在秋招宣讲了,也有一些在 Offer 发放阶段。 节前,我们邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。 针对新手如何入门算法岗、该如何准备面试攻略、面试常考点、大模型技术趋势、算法项目落地经验分享等热门话题进行了…

厦门网站设计的用户体验优化策略

厦门网站设计的用户体验优化策略 在信息化快速发展的今天,网站作为企业与用户沟通的重要桥梁,用户体验(UX)的优化显得尤为重要。尤其是在交通便利、旅游资源丰富的厦门,吸引了大量企业进驻。在这样竞争激烈的环境中&am…

后端向页面传数据(内容管理系统)

一、登录 首先&#xff0c;做一个登录页面。 在这里&#xff0c;注意 内容框里的提示信息用placeholder <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthd…

基于J2EE技术的高校社团综合服务系统

目录 毕设制作流程功能和技术介绍系统实现截图开发核心技术介绍&#xff1a;使用说明开发步骤编译运行代码执行流程核心代码部分展示可行性分析软件测试详细视频演示源码获取 毕设制作流程 &#xff08;1&#xff09;与指导老师确定系统主要功能&#xff1b; &#xff08;2&am…

Visual Studio AI插件推荐

声明&#xff1a;个人喜好&#xff0c;仅供参考。 1、AI插件 Fitten Code&#xff08;免费&#xff09; Fitten Code 是由非十大模型驱动的AI编程助手&#xff0c;支持多种编程语言&#xff0c;支持主流几乎所有的IDE开发工具。包括VS Code、Visual Studio、JetBrains系列I…

Visual Studio 小技巧记录

1、将行距设置成1.15跟舒服一些。 2、括号进行颜色对比。 效果&#xff1a; 3、显示参数内联提示。 效果&#xff1a; 4、保存时规范化代码。 配置文件&#xff1a; 5、将滚动条修改为缩略图 效果&#xff1a;

MongoDB 数据库服务搭建(单机)

下载地址 下载测试数据 作者&#xff1a;程序那点事儿 日期&#xff1a;2023/02/15 02:16 进入下载页&#xff0c;选择版本后&#xff0c;右键Download复制连接地址 下载安装包 ​ wget https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel70-5.0.14.tgz​ …

开放式耳机哪个品牌好?好用且高性价比的开放式蓝牙耳机推荐

相信很多经常运动的朋友都不是很喜欢佩戴入耳式耳机&#xff0c;因为入耳式耳机真的有很多缺点。 安全方面&#xff1a;在安全上就很容易存在隐患&#xff0c;戴上后难以听到周围环境声音&#xff0c;像汽车鸣笛、行人呼喊等&#xff0c;容易在运动中发生意外。 健康方面&…

智慧管控平台技术解决方案

1. 智慧管控平台概述 智慧管控平台采用先进的AI技术&#xff0c;围绕一个中心和四大应用构建&#xff0c;旨在打造一个智能、共享、高效的智慧运营管理环境&#xff0c;实现绿色节能和业务创新。 2. 平台架构设计 系统整体架构设计包括统一门户管理、IOT平台、大数据、视频云…

螺蛳壳里做道场:老破机搭建的私人数据中心---Centos下docker学习02(yum源切换及docker安装配置)

2 前期工作 2.1 切换yum源并更新 删除/etc/yum.repos.d/原有repo文件&#xff0c;将Centos-7.repo库文件拷贝到该目录下。 然后清楚原有缓存yum clean all 生成新的缓存yum makecache 更新yum update –y 然后再确认/etc/yum.repos.d/不会有其他库文件&#xff0c;只留下…

第十四章 I/O系统

一、I/O系统的分类 1.输入流&#xff1a;程序从输入流读取数据 输出流&#xff1a;程序向输出流写入数据 2.字节流&#xff1a;数据流中的最小的数据单元是字节 字符流&#xff1a;数据流中的最小单元是字符 3.节点流、处理流 二、I/O系统的四个抽象类 1.Java中提供的流类…