StratifiedKFold解释和代码实现

news2025/1/14 18:36:21

StratifiedKFold解释和代码实现

文章目录

  • 一、StratifiedKFold是什么?
  • 二、 实验数据设置
    • 2.1 实验数据生成代码
    • 2.2 代码结果
  • 三、实验代码
    • 3.1 实验代码
    • 3.2 实验结果
    • 3.3 结果解释
    • 3.4 数据打乱对这种交叉验证的影响。
  • 四、总结


一、StratifiedKFold是什么?

0,1,2,3:每一行表示测试集和训练集的划分的一种方式。
class:表示类别的个数(下图显示的是3类),有些交叉验证根据类别的比例划分测试集和训练集(例三)。
group:表示从不同的组采集到的样本,颜色的个数表示组的个数(有些时候我们关注在一组特定组上训练的模型是否能很好地泛化到看不见的组)。举个例子(解释“组”的意思):我们有10个人,我们想要希望训练集上所用的数据来自(1,2,3,4,5,6,7,8),测试集上的数据来自(9,10),也就是说我们不希望测试集上的数据和训练集上的数据来自同一个人(如果来自同一个人的话,训练集上的信息泄漏到测试集上了,模型的泛化性能会降低,测试结果会偏好)。
在这里插入图片描述

二、 实验数据设置

2.1 实验数据生成代码

X, y = np.arange(0,60).reshape((30,2)), np.hstack(([0] * 3, [1] * 9, [2] * 18))
print("数据:", end=" ")
for l in X:
    print(l, end=' ')
print("")
print("标签:", y)

2.2 代码结果

数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] 
       [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] 
       [54 55] [56 57] [58 59] 
标签: [0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]

数据个数、标签个数:30个
类别个数:3个(分别是0,1,2,比例是0.1:0.3:0.6和class每类对应),StratifiedKFold
组别(group):由于StratifiedKFold交叉验证结果和group无关,所以这里不再设置。

三、实验代码

3.1 实验代码

代码如下:

from sklearn.model_selection import StratifiedKFold
import numpy as np
# X, y = np.ones((30, 1)), np.hstack(([0] * 20, [1] * 10))
# print(np.arange(0,30).reshape((30,1)))
X, y = np.arange(0,60).reshape((30,2)), np.hstack(([0] * 3, [1] * 9, [2] * 18))
print("数据:", end=" ")
for l in X:
    print(l, end=' ')
print("")
print("标签:", y)
skf = StratifiedKFold(n_splits=3)
for i,(train, test) in enumerate(skf.split(X, y)):
    print("=================StratifiedKFold 第%d折叠 ===================="% (i+1))
    print('train -  {}'.format(np.bincount(y[train])))
    print("  训练集索引:%s" % train)
    print("  训练集标签:", y[train])
    print("  训练集数据:", end=" ")
    for l in X[train]:
        print(l, end=' ')
    print("")
    # print("  训练集数据:", X[train])
    print("test  -  {}".format(np.bincount(y[test])))
    print("  测试集索引:%s" % test)
    print("  测试集标签:", y[test])
    print("  测试集数据:", end=" ")
    for l in X[test]:
        print(l, end=' ')
    print("")
    # print("  测试集数据:", X[test])
    print("=============================================================")

3.2 实验结果

结果如下:

数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
标签: [0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
=================StratifiedKFold 第1折叠 ====================
train -  [ 2  6 12]
  训练集索引:[ 1  2  6  7  8  9 10 11 18 19 20 21 22 23 24 25 26 27 28 29]
  训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]
  训练集数据: [2 3] [4 5] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
test  -  [1 3 6]
  测试集索引:[ 0  3  4  5 12 13 14 15 16 17]
  测试集标签: [0 1 1 1 2 2 2 2 2 2]
  测试集数据: [0 1] [6 7] [8 9] [10 11] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] 
=============================================================
=================StratifiedKFold 第2折叠 ====================
train -  [ 2  6 12]
  训练集索引:[ 0  2  3  4  5  9 10 11 12 13 14 15 16 17 24 25 26 27 28 29]
  训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]
  训练集数据: [0 1] [4 5] [6 7] [8 9] [10 11] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
test  -  [1 3 6]
  测试集索引:[ 1  6  7  8 18 19 20 21 22 23]
  测试集标签: [0 1 1 1 2 2 2 2 2 2]
  测试集数据: [2 3] [12 13] [14 15] [16 17] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] 
=============================================================
=================StratifiedKFold 第3折叠 ====================
train -  [ 2  6 12]
  训练集索引:[ 0  1  3  4  5  6  7  8 12 13 14 15 16 17 18 19 20 21 22 23]
  训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]
  训练集数据: [0 1] [2 3] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] 
test  -  [1 3 6]
  测试集索引:[ 2  9 10 11 24 25 26 27 28 29]
  测试集标签: [0 1 1 1 2 2 2 2 2 2]
  测试集数据: [4 5] [18 19] [20 21] [22 23] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
=============================================================

进程已结束,退出代码 0

3.3 结果解释

可以看到测试集和训练集划分是根据折叠数和标签的比例。例如:这里的折叠数是3,标签的比例是1:3:6,所以在第一折叠处测试集标签0的个数是1/3(折叠数)*0.1(标签比例)*30(样本数)=1个。剩余的分析同理。

=================StratifiedKFold 第1折叠 ====================
train -  [ 2  6 12]
  训练集索引:[ 1  2  6  7  8  9 10 11 18 19 20 21 22 23 24 25 26 27 28 29]
  训练集标签: [0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]
  训练集数据: [2 3] [4 5] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
test  -  [1 3 6]
  测试集索引:[ 0  3  4  5 12 13 14 15 16 17]
  测试集标签: [0 1 1 1 2 2 2 2 2 2]
  测试集数据: [0 1] [6 7] [8 9] [10 11] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] 
=============================================================

3.4 数据打乱对这种交叉验证的影响。

X, y = np.arange(0,60).reshape((30,2)), np.hstack(([0] * 3, [1] * 9, [2] * 18))

改为下面的代码

arr = np.hstack(([0] * 3, [1] * 9, [2] * 18))
print("原始标签:", arr)
# 使用np.random.shuffle函数将数组打乱
np.random.shuffle(arr)
X, y = np.arange(0,60).reshape((30,2)), arr

可以看出划分和标签的先后顺序有一定的关系。

四、总结

StratifiedKFold:考虑了标签(class),但没考虑组(group)的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1349133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Serverless Framework:开发无服务器应用的最佳工具 | 开源日报 No.133

serverless/serverless Stars: 45.6k License: MIT 该项目是 Serverless Framework,它是一个命令行工具,使用简单易懂的 YAML 语法部署代码和云基础设施以满足各种无服务器应用程序需求。支持 Node.js、Typescript、Python、Go 等多种编程语言&#xff…

Mysql体系结构一次讲清

Mysql进阶 Mysql体系结构 大体来说,MySQL 可以分为 Server 层和存储引擎层两部分。 Server层 主要包括连接器、查询缓存、分析器、优化器、执行器等,涵盖 MySQL 的大多数核心服务功能,以及所有的内 置函数(如日期、时间、数…

我的512天创作者纪念日总结:高效、高现

文章目录 512天创作者纪念日:2023年的12月31日CSDN的512天消息提醒第一篇文章,最后一篇文章总计847篇文章,每月发文分布512天,各专栏文章统计512天,互动总成绩 512天创作者纪念日:2023年的12月31日 2023年…

概念解析 | Shapley值及其在深度学习中的应用

注1:本文系“概念解析”系列之一,致力于简洁清晰地解释、辨析复杂而专业的概念。本次辨析的概念是:Shapley值及其在深度学习中的应用。 1 背景介绍 在机器学习和数据分析中,理解模型的预测是非常重要的。尤其是在深度学习黑盒模型中,我们往往难以直观地理解模型的预测行为。为…

Filezilla使用

服务端 点击安装包 点击我接受 点击下一步 点击下一步 点击下一步 点击安装即可 配置用户组,点击编辑,出现组点击 点击添加,点击确定即可 配置用户,点击编辑点击用户 点击添加,设置用户名&#xff…

deepfacelive实时换脸教程(2024最新版)

deepfacelive其实操作用法很简单,难的是模型的制作。本帖主要讲deepfacelive(下文简称dflive)软件本身的操作,以及模型怎么从dfl转格式过来,至于模型如何训练才能效果好,请移步教程区,看deepfac…

uniapp 新建组件

1. 新建文件夹 components 文件夹名称必须是 components &#xff0c;否则组件无法自动导入 2. 新建组件 3. 编辑组件 components/logo/logo.vue <template><img src"https://img.alicdn.com/imgextra/i1/O1CN01EI93PS1xWbnJ87dXX_!!6000000006451-2-tps-150-15…

全局异常和自定义异常处理

全局异常GlobalException.java&#xff0c;basePackages&#xff1a;controller层所在的包全路径 import com.guet.score_management_system.common.domian.AjaxResult; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.bi…

软件测试/测试开发丨Python 封装 学习笔记

封装的概念 封装&#xff08;Encapsulation&#xff09; 隐藏&#xff1a;属性和实现细节&#xff0c;不允许外部直接访问暴露&#xff1a;公开方法&#xff0c;实现对内部信息的操作和访问 封装的作用 限制安全的访问和操作&#xff0c;提高数据安全性可进行数据检查&#x…

进程互斥的软件实现方法-第二十六天

目录 前言 单标志法&#xff08;谦让&#xff09; 双标志先检查法&#xff08;表达意愿&#xff09; 双标志后检查法&#xff08;表达意愿&#xff09; Peterson算法&#xff08;集大成者&#xff09; 本节思维导图 前言 单标志法、双标志的先后检查法中&#xff0c;如果…

服务器的TCP连接限制:如何优化并提高服务器的并发连接数?

&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308; 欢迎关注公众号&#xff08;通过文章导读关注&#xff09;&#xff0c;发送【资料】可领取 深入理解 Redis 系列文章结合电商场景讲解 Redis 使用场景、中间件系列…

Unity坦克大战开发全流程——1)需求分析

实践项目&#xff1a;需求分析 该游戏共有三个主要部分&#xff1a;UI、数据储存、核心游戏逻辑&#xff0c;下面我们将从开始场景、游戏场景、结束场景三个角度切入进行分析。

010、切片

除了引用&#xff0c;Rust还有另外一种不持有所有权的数据类型&#xff1a;切片&#xff08;slice&#xff09;。切片允许我们引用集合中某一段连续的元素序列&#xff0c;而不是整个集合。 考虑这样一个小问题&#xff1a;编写一个搜索函数&#xff0c;它接收字符串作为参数&a…

【代码解析】代码解析之生成token(1)

本篇文章主要解析上一篇&#xff1a;代码解析之登录&#xff08;1&#xff09;里的第8行代码调用 TokenUtils 类里的genToken 方法 https://blog.csdn.net/m0_67930426/article/details/135327553?spm1001.2014.3001.5501 genToken方法代码如下&#xff1a; public static S…

有了向量数据库,我们还需 SQL 数据库吗?

“除了向量数据库外&#xff0c;我是否还需要一个普通的 SQL 数据库&#xff1f;” 这是我们经常被问到的一个问题。如果除了向量数据以外&#xff0c;用户还有其他标量数据信息&#xff0c;那么其业务可能需要在进行语义相似性搜索前先根据某种条件过滤数据&#xff0c;例如&a…

spring创建与使用

spring创建与使用 创建 Spring 项⽬创建⼀个 Maven 项⽬添加 Spring 框架⽀持添加启动类 存储 Bean 对象创建 Bean将 Bean 注册到容器 获取并使⽤ Bean 对象创建 Spring 上下⽂获取指定的 Bean 对象获取bean对象的方法 使⽤ Bean 总结 创建 Spring 项⽬ 接下来使⽤ Maven ⽅式…

Linux:apache优化(6)—— apache的ab压力测试

要对所购买的物理机(二手)进行烧机在产品上线之前&#xff0c;对应用的一个压力测试对产品本身的压力测试 作用&#xff1a;Apache 附带了压力测试工具 ab&#xff0c;非常容易使用&#xff0c;并且完全可以模拟各种条件对 Web 服务器发起测试请求。在进行性能调整优化过程中&a…

国家开放大学形成性考核 统一考试 资料参考

试卷代号&#xff1a;11141 工程经济与管理 参考试题 一、单项选择题&#xff08;每题2分&#xff0c;共20分&#xff09; 1.资金的时间价值&#xff08; &#xff09;。 A.现在拥有的资金在将来投资时所能获得的利益 B.现在拥有的资金在将来消费时所付出的福利损失 C.…

C练习——银行存款

题目&#xff1a;设银行定期存款的年利率 rate为2.25%,已知存款期为n年&#xff0c;存款本金为capital 元,试编程计算并输出n年后本利之和deposit。 解析&#xff1a;利息本金*利率&#xff0c;下一年的本金又是是今年的本利之和 逻辑&#xff1a;注意浮点数&#xff0c;导入…

计算机操作系统(OS)——P5设备管理

1、I/O设备的概念和分类 什么是I/O设备 I/O就是输入/输出&#xff08;Input/Output&#xff09;。 I/O设备就是可以将数据输入到计算机&#xff0c;或者可以接收计算机输出数据的外部设备&#xff0c;属于计算机中的硬件部件。 UNIX系统将外部设备抽象为一种特殊的文件&#x…