交叉验证之KFold和StratifiedKFold的使用(附案例实战)

news2024/12/25 9:26:57

🤵‍♂️ 个人主页:@艾派森的个人主页

✍🏻作者简介:Python学习者
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+


 一、交叉验证简介

        交叉验证是在机器学习建立模型和验证模型参数时常用的办法。交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏。在此基础上可以得到多组不同的训练集和测试集,某次训练集中的某样本在下次可能成为测试集中的样本,即所谓“交叉”。

  那么什么时候才需要交叉验证呢?交叉验证用在数据不是很充足的时候。通常情况下,如果数据样本量小于一万条,我们就会采用交叉验证来训练优化选择模型。如果样本大于一万条的话,我们一般随机的把数据分成三份,一份为训练集(Training Set),一份为验证集(Validation Set),最后一份为测试集(Test Set)。用训练集来训练模型,用验证集来评估模型预测的好坏和选择模型及其对应的参数。把最终得到的模型再用于测试集,最终决定使用哪个模型以及对应参数。

        学习预测函数的参数,并在相同数据集上进行测试是一种错误的做法: 一个仅给出测试用例标签的模型将会获得极高的分数,但对于尚未出现过的数据它则无法预测出任何有用的信息。 这种情况称为 overfitting(过拟合).。为了避免这种情况,在进行机器学习实验时,通常取出部分可利用数据作为 test set(测试数据集) X_test, y_test。下面是模型训练中典型的交叉验证工作流流程图。通过网格搜索可以确定最佳参数。

         k-折交叉验证得出的性能指标是循环计算中每个值的平均值。 该方法虽然计算代价很高,但是它不会浪费太多的数据(如固定任意测试集的情况一样), 在处理样本数据集较少的问题(例如,逆向推理)时比较有优势。

k-折交叉验证步骤

  • 第一步,不重复抽样将原始数据随机分为 k 份。
  • 第二步,每一次挑选其中 1 份作为测试集,剩余 k-1 份作为训练集用于模型训练。
  • 第三步,重复第二步 k 次,这样每个子集都有一次机会作为测试集,其余机会作为训练集。
  • 在每个训练集上训练后得到一个模型,
  • 用这个模型在相应的测试集上测试,计算并保存模型的评估指标,
  • 第四步,计算 k 组测试结果的平均值作为模型精度的估计,并作为当前 k 折交叉验证下模型的性能指标。
     

例如:

十折交叉验证

  • 将训练集分成十份,轮流将其中9份作为训练数据,1份作为测试数据,进行试验。每次试验都会得出相应的正确率。
  • 10次的结果的正确率的平均值作为对算法精度的估计,一般还需要进行多次10折交叉验证(例如10次10折交叉验证),再求其均值,作为对算法准确性的估计
  • 模型训练过程的所有步骤,包括模型选择,特征选择等都是在单个折叠 fold 中独立执行的。
  • 此外:
    • 多次 k 折交叉验证再求均值,例如:10 次10 折交叉验证,以求更精确一点。
    • 数据量大时,k设置小一些 / 数据量小时,k设置大一些。
       

KFold和StratifiedKFold的使用

        StratifiedKFold用法类似Kfold,但是它是分层采样,确保训练集,测试集中各类别样本的比例与原始数据集中相同。这一区别在于当遇到非平衡数据时,StratifiedKFold() 各个类别的比例大致和完整数据集中相同,若数据集有4个类别,比例是2:3:3:2,则划分后的样本比例约是2:3:3:2;但是KFold可能存在一种情况:数据集有5类,抽取出来的也正好是按照类别划分的5类,也就是说第一折全是0类,第二折全是1类等等,这样的结果就会导致模型训练时没有学习到测试集中数据的特点,从而导致模型得分很低,甚至为0。

Parameters

  • n_splits : int, default=3   也就是K折中的k值,必须大于等于2
  • shuffle : boolean  True表示打乱顺序,False反之
  • random_state :int,default=None 随机种子,如果设置值了,shuffle必须为True
# KFold
from sklearn.model_selection import KFold
kfolds = KFold(n_splits=3)
for train_index, test_index in kfolds.split(X,y):
    print('X_train:%s ' % X[train_index])
    print('X_test: %s ' % X[test_index])

# StratifiedKFold
from sklearn.model_selection import StratifiedKFold
skfold = StratifiedKFold(n_splits=3)
for train_index, test_index in skfold.split(X,y):
    print('X_train:%s ' % X[train_index])
    print('X_test: %s ' % X[test_index])

KFold和StratifiedKFold实战案例

首先导入数据集,本数据集为员工离职数据,属于二分类任务

import pandas as pd
import warnings
warnings.filterwarnings('ignore')

data = pd.read_excel('data.xlsx')
data['薪资情况'].replace(to_replace={'低':0,'中':1,'高':2},inplace=True)
data.head()

 拆分数据集为训练集和测试集,测试集比例为0.2

from sklearn.model_selection import train_test_split
X = data.drop('是否离职',axis=1)
y = data['是否离职']
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)

初始化一个分类模型,这里用逻辑回归模型举例。方法1使用cross_val_score()可以直接得到k折训练的模型效果,比如下面使用3折进行训练,得分评估使用准确率,关于scoring这个参数我会在文末介绍。

# 初始化一个分类模型,比如逻辑回归
from sklearn.linear_model import LogisticRegression
lg = LogisticRegression()
# 方法1
from sklearn.model_selection import cross_val_score
scores = cross_val_score(lg,X_train,y_train,cv=3,scoring='accuracy')
print(scores)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

 接下来分别使用KFold和StratifiedKFold,其实两者代码非常类似,只是前面的方法不同。

KFold

# 方法2-KFold和StratifiedKFold
import numpy as np
from sklearn.model_selection import KFold,StratifiedKFold
from sklearn.metrics import accuracy_score,recall_score,f1_score
# KFold
kfolds = KFold(n_splits=3)
accuracy_score_list,recall_score_list,f1_score_list = [],[],[]
for train_index,test_index in kfolds.split(X_train,y_train):
    # 准备交叉验证的数据
    X_train_fold = X_train.iloc[train_index]
    y_train_fold = y_train.iloc[train_index]
    X_test_fold = X_train.iloc[test_index]
    y_test_fold = y_train.iloc[test_index]
    # 训练模型
    lg.fit(X_train_fold,y_train_fold)
    y_pred = lg.predict(X_test_fold)
    # 评估模型
    AccuracyScore = accuracy_score(y_test_fold,y_pred)
    RecallScore = recall_score(y_test_fold,y_pred)
    F1Score = f1_score(y_test_fold,y_pred)
    # 将评估指标存放对应的列表中
    accuracy_score_list.append(AccuracyScore)
    recall_score_list.append(RecallScore)
    f1_score_list.append(F1Score)
    # 打印每一次训练的正确率、召回率、F1值
    print('accuracy_score:',AccuracyScore,'recall_score:',RecallScore,'f1_score:',F1Score)
# 打印各指标的平均值和95%的置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (np.average(accuracy_score_list), np.std(accuracy_score_list) * 2))
print("Recall: %0.2f (+/- %0.2f)" % (np.average(recall_score_list), np.std(recall_score_list) * 2))
print("F1_score: %0.2f (+/- %0.2f)" % (np.average(f1_score_list), np.std(f1_score_list) * 2))

StratifiedKFold

# StratifiedKFold
skfolds = StratifiedKFold(n_splits=3)
accuracy_score_list,recall_score_list,f1_score_list = [],[],[]
for train_index,test_index in skfolds.split(X_train,y_train):
    # 准备交叉验证的数据
    X_train_fold = X_train.iloc[train_index]
    y_train_fold = y_train.iloc[train_index]
    X_test_fold = X_train.iloc[test_index]
    y_test_fold = y_train.iloc[test_index]
    # 训练模型
    lg.fit(X_train_fold,y_train_fold)
    y_pred = lg.predict(X_test_fold)
    # 评估模型
    AccuracyScore = accuracy_score(y_test_fold,y_pred)
    RecallScore = recall_score(y_test_fold,y_pred)
    F1Score = f1_score(y_test_fold,y_pred)
    # 将评估指标存放对应的列表中
    accuracy_score_list.append(AccuracyScore)
    recall_score_list.append(RecallScore)
    f1_score_list.append(F1Score)
    # 打印每一次训练的正确率、召回率、F1值
    print('accuracy_score:',AccuracyScore,'recall_score:',RecallScore,'f1_score:',F1Score)
# 打印各指标的平均值和95%的置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (np.average(accuracy_score_list), np.std(accuracy_score_list) * 2))
print("Recall: %0.2f (+/- %0.2f)" % (np.average(recall_score_list), np.std(recall_score_list) * 2))
print("F1_score: %0.2f (+/- %0.2f)" % (np.average(f1_score_list), np.std(f1_score_list) * 2))

补充

scoring 参数: 定义模型评估规则

Model selection (模型选择)和 evaluation (评估)使用工具,例如 model_selection.GridSearchCV 和 model_selection.cross_val_score ,采用 scoring 参数来控制它们对 estimators evaluated (评估的估计量)应用的指标。

常见场景: 预定义值

        对于最常见的用例, 可以使用 scoring 参数指定一个 scorer object (记分对象); 下表显示了所有可能的值。 所有 scorer objects (记分对象)遵循惯例 higher return values are better than lower return values(较高的返回值优于较低的返回值)。因此,测量模型和数据之间距离的 metrics (度量),如 metrics.mean_squared_error 可用作返回 metric (指数)的 negated value (否定值)的 neg_mean_squared_error 。

Scoring(得分)Function(函数)Comment(注解)
Classification(分类)
‘accuracy’metrics.accuracy_score
‘average_precision’metrics.average_precision_score
‘f1’metrics.f1_scorefor binary targets(用于二进制目标)
‘f1_micro’metrics.f1_scoremicro-averaged(微平均)
‘f1_macro’metrics.f1_scoremacro-averaged(宏平均)
‘f1_weighted’metrics.f1_scoreweighted average(加权平均)
‘f1_samples’metrics.f1_scoreby multilabel sample(通过 multilabel 样本)
‘neg_log_loss’metrics.log_lossrequires predict_proba support(需要 predict_proba 支持)
‘precision’ etc.metrics.precision_scoresuffixes apply as with ‘f1’(后缀适用于 ‘f1’)
‘recall’ etc.metrics.recall_scoresuffixes apply as with ‘f1’(后缀适用于 ‘f1’)
‘roc_auc’metrics.roc_auc_score
Clustering(聚类)
‘adjusted_mutual_info_score’metrics.adjusted_mutual_info_score
‘adjusted_rand_score’metrics.adjusted_rand_score
‘completeness_score’metrics.completeness_score
‘fowlkes_mallows_score’metrics.fowlkes_mallows_score
‘homogeneity_score’metrics.homogeneity_score
‘mutual_info_score’metrics.mutual_info_score
‘normalized_mutual_info_score’metrics.normalized_mutual_info_score
‘v_measure_score’metrics.v_measure_score
Regression(回归)
‘explained_variance’metrics.explained_variance_score
‘neg_mean_absolute_error’metrics.mean_absolute_error
‘neg_mean_squared_error’metrics.mean_squared_error
‘neg_mean_squared_log_error’metrics.mean_squared_log_error
‘neg_median_absolute_error’metrics.median_absolute_error
‘r2’metrics.r2_score

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/458796.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣---LeetCode88. 合并两个有序数组

文章目录 前言88. 合并两个有序数组链接:方法一:三指针(后插)1.2 代码:1.2 流程图:方法二:开辟新空间2.1 代码:2.2 流程图:2.3 注意: 总结 前言 “或许你并不熠熠生辉甚至有点木讷但…

POSTGRESQL COPY 命令原理与加速数据 导入提高速度200%以上

开头还是介绍一下群,如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题,有需求都可以加群群内有各大数据库行业大咖,CTO,可以解决你的问题。加群请联系 liuaustin3 ,在新加的朋友会分到2群(共…

vue2+vue3——107+

vue2vue3——107 vue2 Vuex工作原理图【23:54】vue2 搭建Vuex环境【26:40】插入 话题npm i vue3 store / index.js修改 vue2 求和案例_vuex版【22:39】vue2 vuex开发者工具的使用【23:21】vue2 getters配置项【07:55】vue2 mapState与mapGetters【25:20】vue2 mapActions与mapM…

egg3.0连接egg-mongoose操作数据库,删除一条数据、批量删除数据

删除一条数据 定义service app\service\role.js async delItem() {const { ctx } this;let results;await ctx.model.Role.deleteOne({ name: test-S3 }).then(res > {console.log(results-del-success, res);results res?.deletedCount > 0;}).catch(err > {con…

系统分析师之软件工程(十二)

目录 一、 软件开发生命周期 1.1 开发阶段工作细分 二、软件开发模型 2.1 瀑布模型 2.2 原型模型 2.3 增量模型与螺旋模型 2.4 V模型 2.5 喷泉模型 2.6 快速应用开发模型RAD 2.7 构件主装模型 2.8 统一过程 2.9 敏捷方法 三、逆向工程 四、净室软件工程 一、 软件…

为何C语言的函数调用要用到堆栈,而汇编却不需要自定义栈

一 ≠ 汇编不需要堆栈 汇编中一般不初始化,也就是直接使用系统的堆栈而已,自己定义堆栈还是要初始化的。 之前看了很多关于uboot的分析,其中就有说要为C语言的运行,准备好堆栈。 而自己在Uboot的start.S汇编代码中&#xff0c…

crm-day04 分页查询市场活动,刷新市场活动列表

分页插件 分页这个组件前端要写也很麻烦&#xff0c;而且与业务逻辑代码无关&#xff0c;因此我们引入一个分页查询的插件。 进行jsp测试 三大步骤&#xff1a; 1、引入相关的包 2、创建容器来保存插件的运行结果 容器是<input typetext/>或者div。 3、容器加载完成后&a…

猫猫与主人

时间限制&#xff1a;C/C 1秒&#xff0c;其他语言2秒 空间限制&#xff1a;C/C 262144K&#xff0c;其他语言524288K 64bit IO Format: %lld 对猫猫按照友善值进行排序 对主人按照期望友善值进行排序 就可以找到能收养猫猫的主人 对主人的友善值取一个max最后跟猫猫的期望友…

用大佬开发的模板做了“智慧水务”,终于可以和老板谈加薪喽!

为什么各个行业要进行数字化转型&#xff1f; 其实很好理解&#xff0c;这其中很大一部分属于传统行业&#xff0c;以往运营方式较为粗放&#xff0c;信息标准化程度偏低&#xff0c;但同时也意味着数字化的历史包袱轻&#xff0c;此时跟上潮流进行数字化转型&#xff0c;有利于…

美颜SDK的性能测试和优化方案

美颜SDK作为美颜相机、短视频等应用的核心技术之一&#xff0c;对于提升用户体验和增加应用商业价值起到了至关重要的作用。然而&#xff0c;如何对美颜SDK进行性能测试和优化&#xff0c;成为了广大应用开发者们所面临的一大难题。很多开发者也曾经向小编提起过应该如何着手优…

nodejs+python+php+springboot+vue 校园安全车辆人员出入安全管理系统

拟开发的校园安全管理系统通过测试,确保在最大负载的情况下稳定运转,各个模块工作正常,具有较高的可用性。系统整体界面简洁美观,用户使用简单,满足用户需要。在因特网发展迅猛的当今社会,校园安全管理系统必然会成为在数字信息化建设的一个重要方面。 本文阐述了开发的校园安全…

马斯克要告微软 拒绝AI训练“白嫖”数据

“现在是诉讼时间。”4月20日&#xff0c;推特被微软踢出其数字营销平台后&#xff0c;新掌门人马斯克立马发推回击称&#xff0c;微软用推特的数据做“非法训练”。这一怼&#xff0c;直接揭开了AI大模型开发商与数据源的利益之争。 此前&#xff0c;在线社区论坛Reddit与程序…

黑马redis实战篇-商铺缓存

目录 五、实战篇-商户查询缓存 5.1 什么是缓存 5.2 添加Redis缓存 1、不添加redis时&#xff0c;数据查询的作用模型&#xff1a; 2、添加redis时&#xff0c;数据查询的作用模型&#xff1a; 3、业务流程图&#xff1a;​编辑 4、代码实现 5、练习题 5.3 缓存更新策略…

【Android FrameWork (三)】- SystemServer

文章目录 知识回顾启动第一个流程initZygote的流程 前言源码分析1.system_server2.SystemServer.main3,startBootstrapServices4,startService 拓展知识LoadApkcontext 对于Android context 大家是怎么理解的&#xff1f;LocalServices.java: addServece方法中 ArrayMap和HashM…

Matlab 绘制双纵轴三纵轴图

三纵轴图 三坐标的图在前文中有所介绍&#xff1b;这次主要讲绘制双轴。 matlab 绘制三坐标&#xff08;轴&#xff09;图 绘制双纵轴图: yyaxis 简单用法 在MATLAB中&#xff0c;yyaxis可以用于绘制具有两个不同y轴的图形。以下是yyaxis的简单用法&#xff1a; 1.首先&am…

UG NX二次开发(C#)-UIStyler-找不到指定的Dlx文件的错误解决方法

1、项目场景: 在UG NX二次开发过程中,我们为了更好的操作,采用UI Styler设计了软件界面,然后按照UI Styler的编程流程成功的生成了dll,但是在采用Ctrl+U或者用“文件“->“执行”->"NX Open"执行dll时,遇到如下图所示的错误页面,提示内容为:找不到指定…

成就更强大的自己

每一次低谷&#xff0c;都会酝酿向上的力量。 每一次痛苦过后&#xff0c;都会洗涤掉心理深处的灰尘。 人生的路上&#xff0c;坎坷前行&#xff0c;只有保持积极向上的态度&#xff0c;才能把坎坷化为坦途。 走过一段路后&#xff0c;才发现&#xff0c;当内心强大、修养、爱…

Android之 颜色选择器

一&#xff0c;简介 1.1 计算机的颜色通常有两种表示方式&#xff1a; 光源模式RGB(Red红, Green绿, Blue蓝)&#xff0c;数值0-255 印刷模式CMYK(Cyan青, Magenta品红, Yellow黄, Black黑)&#xff0c;数值1-100 任何颜色都是由RGB或CMYK混合出来的&#xff0c;再加上透明度…

2023年产业基金研究报告

第一章 行业概况 1.1 概述 产业基金&#xff0c;又称为产业投资基金&#xff0c;是一种由政府、企业、金融机构等出资设立的&#xff0c;专门用于支持和促进特定产业发展的投资基金。产业基金通常以股权投资和长期投资为主&#xff0c;旨在推动产业结构升级、促进科技创新、提…

算法刷题|139.单词拆分、多重背包

单词拆分 题目&#xff1a;给你一个字符串 s 和一个字符串列表 wordDict 作为字典。请你判断是否可以利用字典中出现的单词拼接出 s 。 注意&#xff1a;不要求字典中出现的单词全部都使用&#xff0c;并且字典中的单词可以重复使用。 思路&#xff1a;字符串s就是我们的背包…