sklearn.model_selection模块介绍

news2024/12/26 12:37:23

数据集划分方法

train_test_split

train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)

参数包括:

  • test_size:可选参数,表示测试集的大小。可以是一个表示比例的浮点数(例如0.2表示20%的数据作为测试集),或者是一个表示样本数量的整数。默认为None。
  • train_size:可选参数,表示训练集的大小。可以是一个表示比例的浮点数(例如0.8表示80%的数据作为训练集),或者是一个表示样本数量的整数。默认为None,表示训练集的大小由测试集大小决定。
  • random_state:可选参数,表示随机数生成器的种子,用于随机划分数据集。设置一个整数值可以保证每次划分的结果一致。
  • shuffle:可选参数,表示是否在划分数据集之前对数据进行随机打乱。默认为True,即进行随机打乱。
  • stratify:可选参数,表示根据指定的标签数组进行分层划分。标签数组的长度必须与输入数据集的第一个维度相同。适用于分类问题中的类别不平衡情况。
from sklearn.model_selection import train_test_split

X, y = load_data()  # 加载特征数据 X 和标签数据 y
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

交叉验证方法

K折交叉验证

K折交叉验证将数据集划分为K个互不重叠的子集,称为折(Fold)。模型会进行K次训练和验证,每次使用K-1个折作为训练集,剩下的1个折作为验证集。K次训练和验证的结果会进行平均,得到最终的性能评估。K折交叉验证可以通过KFold类实现,具体用法如下

from sklearn.model_selection import KFold

X = np.arange(10)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    print(X_train, X_test)
    print("*"*20)

执行结果
在这里插入图片描述

留一交叉验证LeaveOneOut

留一交叉验证是一种特殊的K折交叉验证,其中K等于数据集的样本数量。每个样本都作为单独的验证集,而其余样本作为训练集。这种方法适用于数据集较小的情况。留一交叉验证可以通过LeaveOneOut类实现,具体用法如下

from sklearn.model_selection import LeaveOneOut

loo = LeaveOneOut()
X = np.arange(10)
for train_index, test_index in loo.split(X):
    X_train, X_test = X[train_index], X[test_index]
    print(X_train, X_test)
    print("*"*20)
    # 在训练集上训练模型,使用测试集进行评估

在这里插入图片描述

分组交叉验证GroupKFold

分组交叉验证是一种考虑数据集中样本之间分组关系的交叉验证方法。在某些任务中,样本可能彼此相关或存在依赖关系,例如在自然语言处理中的句子分类任务中,同一篇文章中的句子可能相互影响。为了确保模型在训练集和验证集中都包含相同分组的样本,可以使用GroupKFold类进行分组交叉验证。具体用法如下

from sklearn.model_selection import GroupKFold

gkf = GroupKFold(n_splits=3)
for train_index, test_index in gkf.split(X, y, groups):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    # 在训练集上训练模型,使用测试集进行评估

groups参数是一个表示样本分组的数组,长度与数据集的样本数相同。

随机重复K折交叉验证RepeatedKFold

随机重复K折交叉验证是K折交叉验证的扩展,通过多次重复执行K折交叉验证来更稳定地评估模型性能。可以使用RepeatedKFold类进行随机重复K折交叉验证。具体用法如下

from sklearn.model_selection import RepeatedKFold

rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)

X = np.arange(10)
i  = 0
for train_index, test_index in rkf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    print(X_train, X_test)
    print("*"*20)
    i +=1
print(i)

层次化交叉验证cross_val_score

层次化交叉验证是一种嵌套的交叉验证方法,用于在模型选择和性能评估中进行双重交叉验证。外层交叉验证用于评估不同的模型或模型参数,内层交叉验证用于在每个外层验证折上进行模型训练和验证。可以通过嵌套使用cross_val_score函数来实现层次化交叉验证。具体用法如下

from sklearn.model_selection import cross_val_score

scores = cross_val_score(estimator, X, y, cv=5)

分层K折交叉验证StratifiedKFold

分层K折交叉验证是K折交叉验证的一种变体,它在划分数据集时保持了每个类别的样本比例。这对于类别不平衡的分类问题非常重要。分层K折交叉验证可以通过StratifiedKFold类实现,具体用法与K折交叉验证类似。
StratifiedKFold的作用是确保每个折中的样本比例与原始数据集中的样本比例相同。这对于处理类别不平衡的分类问题非常重要,因为如果样本比例不平衡,模型在某些折上可能无法学习到少数类别的有效模式

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for train_index, test_index in skf.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    # 在训练集上训练模型,使用测试集进行评估

参数搜索和模型选择方法

网格搜索

网格搜索通过遍历指定的参数组合来寻找最佳的模型参数配置。它通过穷举搜索所有参数组合,并在交叉验证中评估每个组合的性能。GridSearchCV类实现了网格搜索的功能。我们需要指定要搜索的参数和其取值范围,并指定评估指标和交叉验证的折数。示例代码如下

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 定义模型和参数网格
model = SVC()
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}

# 执行网格搜索
grid_search = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X, y)

# 输出最佳参数配置和得分
print("Best parameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)

执行结果
在这里插入图片描述

随机搜索

随机搜索通过随机抽样一组参数组合来寻找最佳的模型参数配置。与网格搜索不同,随机搜索不遍历所有参数组合,而是在指定的参数空间中进行随机抽样,并在交叉验证中评估每个参数组合的性能。RandomizedSearchCV类实现了随机搜索的功能。示例代码如下:

from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 定义模型和参数分布
model = RandomForestClassifier()
param_dist = {'n_estimators': [10, 50, 100], 'max_depth': [None, 5, 10]}

# 执行随机搜索
random_search = RandomizedSearchCV(model, param_distributions=param_dist, cv=5)
random_search.fit(X, y)

# 输出最佳参数配置和得分
print("Best parameters: ", random_search.best_params_)
print("Best score: ", random_search.best_score_)

执行结果
在这里插入图片描述

交叉验证(Cross-Validation)

将数据集分成多个折(Fold),每次使用其中一部分作为验证集,剩余部分作为训练集进行模型训练和评估。使用cross_val_score函数进行交叉验证,并指定模型和评估指标。示例代码:

from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# 定义模型和数据集
model = DecisionTreeClassifier()
X, y = load_iris(return_X_y=True)

# 执行交叉验证
scores = cross_val_score(model, X, y, cv=5)

# 输出每折的得分和平均得分
print("Cross-validation scores: ", scores)
print("Average score: ", scores.mean())

执行结果
在这里插入图片描述

学习曲线

通过绘制不同训练集大小下模型的训练和验证得分曲线,评估模型的拟合能力和泛化能力。使用learning_curve函数生成学习曲线数据,并绘制曲线图。示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.linear_model import LogisticRegression

# 加载数据集
X, y = load_digits(return_X_y=True)

# 定义模型
model = LogisticRegression()

# 生成学习曲线数据
train_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=5)

# 绘制学习曲线图
plt.plot(train_sizes, np.mean(train_scores, axis=1), label='Training score')
plt.plot(train_sizes, np.mean(test_scores, axis=1), label='Validation score')
plt.xlabel('Training Set Size')
plt.ylabel('Score')
plt.title('Learning Curve')
plt.legend(loc='best')
plt.show()

执行结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/707788.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android:ViewPager2

简介 ViewPager2内部使用RecyclerView实现,并提供了增强功能 特性 支持水平、垂直方向布局 android:orientation “vertical” 支持从右到左 android:layoutDirection “rtl” 禁止滑动 setUserInputEnabled() 可修改Fragment集合 对可修改的Fragment集合进行分…

深入探究Bean生命周期的扩展点:Bean Post Processor

概要 在Spring框架中,Bean生命周期的管理是非常重要的一部分。在Bean的创建、初始化和销毁过程中,Spring提供了一系列的扩展点,使开发者能够在不破坏原有功能的基础上,对Bean的生命周期进行定制化操作。其中,Bean Post…

LLM记录202304-202306

RLHF RAFT RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment code RRHF RRHF: Rank Responses to Align Language Models with Human Feedback without tears code p i = ∑ t lo

English Learning - L3 作业打卡 Lesson7 Day53 2023.6.28 周三

English Learning - L3 作业打卡 Lesson7 Day53 2023.6.28 周三 引言🍉句1: It was this moment that I asked myself that life-defining question:成分划分同化连读爆破语调 🍉句2: If my life were a book and I were the author, how would I want t…

基于Web的小学学科数字教学资源管理系统

摘要 小学学科数字教学资源管理是一个典型的学习项目,从教学资源、教材信息的统计和分析,在过程中会产生大量的、各种各样的数据。本文以小学学科数字教学资源管理系统为目标,采用B/S模式,以Springboot为开发框架,java…

计算机网络面经之TCP三次握手和四次挥手的详解

常见问题 1.详细描述三次握手和四次挥手的过程。 2.三次握手可以变成两次握手吗? 3.简述 TCP 连接和关闭的状态转移。 4.简述TCP 四次挥手的 TIME_WAIT状态,以及为什么需要有这个状态 重要的字段定义与作用 (1)序号(sequence nu…

循环双链表

目录 双向循环链表结构体初始化函数添加数据头插删除数据显示函数示例程序一(简易版本):运行结果:示例程序二输出结果: 双向循环链表 结构图示: 结构体 typedef struct node {int data;struct node* pre; //指向前驱struct …

C++迭代器

目录 1.iterator 2.数组 1.iterator 迭代器就是个内置指针&#xff0c;可以 -- &#xff0c;可以解引用。 迭代器分两种类型 iterator 和const_iterator&#xff08;只读&#xff0c;不能修改&#xff09; 迭代器要用作用域限定类型 vector<int>::iterator it; 如果不限制…

Yarn的实现原理详解

概要 Yarn作为分布式集群的资源调度框架&#xff0c;它的出现伴随着Hadoop的发展&#xff0c;使Hadoop从一个单一的大数据计算引擎&#xff0c;成为一个集存储、计算、资源管理为一体的完整大数据平台&#xff0c;进而发展出自己的生态体系&#xff0c;成为大数据的代名词。 Ya…

C++11新特性 智能指针

智能指针 nuique_ptr特点不允许拷贝构造和赋值运算符重载-> () *unique_ptr 删除器仿写删除文件删除普通对象 shared_ptr特点示意图仿写shared_ptr删除器部分特化拷贝构造 移动构造 && 左值赋值 和移动赋值完整实现 weak_ptr特点weak_ptr 实现解决循环引用弱指针一个…

java: 警告: 源发行版 11 需要目标发行版 11解决方案

出现这样的问题首先检查一下自己的项目结构是否使用的对应的jdk 如果这里是正确的&#xff0c;之后查看一下自己的pom文件中是否指定了正确的jdk 这里的时候你改完运行就会发现还会报错&#xff0c;一定要记得刷新一下maven 再重新启动项目&#xff0c;即解决

剑指 Offer 63: 股票的最大利润

最标准答案 不可以有前一项的影响&#xff0c;只能用来对比并不叠加 这里max设置0就会导致先行进入大于max的判断语句&#xff01; 无语了&#xff0c;自己把问题想的太复杂了&#xff01; class Solution {public int maxProfit(int[] prices) {if(prices.length<2) retur…

十二个常用化学文献检索网站

一、Royal Society of Chemistry英国皇家化学学会 英国皇家化学学会&#xff08;Royal Society of Chemistry&#xff0c;简称RSC&#xff09;&#xff0c;是一个国际权威的学术机构&#xff0c;是化学信息的一个主要传播机构和出版商&#xff0c;其出版的期刊及资料库一向是化…

886. 可能的二分法

链接&#xff1a;886. 可能的二分法 题解&#xff1a; class Solution { public:bool possibleBipartition(int n, vector<vector<int>>& dislikes) {// -1&#xff0c;代表这个点没有访问过&#xff0c; 0&#xff0c;1代表两个染色的组std::vector<int&…

python机器学习——聚类评估方法 K-Means聚类 神经网络模型基础

目录 聚类模型的评价方法&#xff08;1&#xff09;轮廓系数&#xff1a;&#xff08;2&#xff09;评价分类模型 【聚类】K-Means聚类模型&#xff08;1&#xff09;聚类步骤&#xff1a;&#xff08;2&#xff09;sklearn参数解析&#xff08;3&#xff09;k-means算法特点 神…

GPT模型训练实践(3)-参数训练和代码实践

一、参数训练 GPT模型参数的训练过程宏观上有两个大环节&#xff0c;先从上往下进行推理&#xff0c;再从下往上进行训练&#xff0c;具体过程为&#xff1a; 1、模型初始化参数随机取得&#xff1b; 2、计算模型输出与真实数据的差距&#xff08;损失值和梯度&#xff09; …

VS2019的安装和简单使用

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题&#xff0c;有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

【数据结构与算法】学校运动会管理系统(C/C++)

这是一个完整的项目&#xff0c;若有需要整个项目的压缩包&#xff08;源代码、文档、md文件等&#xff09;可私聊发送"学校运动会管理系统"。 问题描述 在“学校运动会管理系统”中&#xff0c;设有n个单位参加运动会&#xff08;单位可是学院、系、年级等&#xf…

Java 实现支付宝支付、退款、订单查询

最在开发一款APP&#xff0c;需要实现支付宝支付&#xff0c;记录一下实现过程 流程整体交互图如下所示 一、引入pom依赖 <dependency><groupId>com.aliyun</groupId><artifactId>aliyun-java-sdk-core</artifactId><version>4.0.3<…

【Java可执行命令】(八)JWS应用程序启动工具 javaws:深入解析Java Web Start应用程序的启动工具javaws ~

Java可执行命令之javaws 1️⃣ 概念&#x1f50d;JNLP (Java Network Launch Protocol) &#xff1f; 2️⃣ 优势3️⃣ 使用3.1 语法3.1.1 运行选项&#xff1a;-Xnosplash3.1.2 运行选项&#xff1a;-wait3.1.3 控制选项&#xff1a;-import [导入选项] < jnlp-file> 4️…