机器学习:训练集与测试集分割train_test_split

news2025/1/17 0:22:37

1 引言

在使用机器学习训练模型算法的过程中,为提高模型的泛化能力、防止过拟合等目的,需要将整体数据划分为训练集和测试集两部分,训练集用于模型训练,测试集用于模型的验证。此时,使用train_test_split函数可便捷高效的实现数据训练集与测试集的划分。

2 train_test_split介绍

train_test_split函数来自scikit-learn库(也称为sklearn),安装命令:

pip install sklearn

函数的导入:

from sklearn.model_selection import train_test_split

1.1 函数定义

def train_test_split(*arrays,test_size=None,train_size=None,random_state=None,
    shuffle=True,stratify=None,):

1.2 参数说明

  • *arrays: 单个数组或元组,表示需要划分的数据集。如果传入多个数组,则必须保证每个数组的第一维大小相同。
  • test_size: 测试集的大小(占总数据集的比例,值为0.0-1.0,表示测试集占总样本比例)。默认值为0.25,即将传入数据的25%作为测试集。
  • train_size: 训练集的大小(占总数据集的比例,值为0.0-1.0,表示训练集占总样本比例)。默认值为None,此时和test_size互补,即训练集的大小为(1-test_size)。
  • random_state: 随机数种子。可以设置一个整数,用于复现结果。默认为None。其实是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。(比如每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。)
  • shuffle: 是否随机打乱数据。默认为True。
  • stratify: 可选参数,用于进行分层抽样。传入标签数组,保证划分后的训练集和测试集中各类别样本比例与原始数据集相同。默认为None,即普通的随机划分。(此参数作用是保持测试集与整个数据集里的数据分类比例一致,比如有1000个数据,800个属于A类,200个属于B类。设置stratify = y_lable,test_size=0.25,split之后数据组成如下:training: 750个数据,其中600个属于A类,150个属于B类;testing: 250个数据,其中200个属于A类,50个属于B类)

1.3 返回值

该函数返回一个元组(X_train, X_test, y_train, y_test),其中X_train表示训练集的特征数据,X_test表示测试集的特征数据,y_train表示训练集的标签数据,y_test表示测试集的标签数据。

1.4 注意事项

  • test_sizetrain_size必须至少有一个设置为非None
  • 当传入多个数组时,请确保每个数组的第一维大小相同。
  • random_state要设置一个整数值,从而保证每次获取相同的训练集和测试集
  • 当使用分层抽样时,请确保传入的标签数组是正确的。

3 train_test_split使用

3.1 使用train_test_split分割Iris数据

from sklearn import datasets
from sklearn.model_selection import train_test_split

# 加载Iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
print(X_train)
print(X_test)

结果展示:

X_train=[[6.5 2.8 4.6 1.5]
 [6.7 2.5 5.8 1.8]
 [6.8 3.  5.5 2.1]
 [5.1 3.5 1.4 0.3]
 [6.  2.2 5.  1.5]
 ......此处数据省略
 [4.9 3.6 1.4 0.1]]
X_test=[[5.8 4.  1.2 0.2]
 [5.1 2.5 3.  1.1]
 [6.6 3.  4.4 1.4]
 [5.4 3.9 1.3 0.4]
 [7.9 3.8 6.4 2. ]
 ......此处数据省略
 [5.2 3.4 1.4 0.2]]

3.2 使用train_test_split分割水果识别数据

在/opt/dataset下存放着水果图片的分类数据文件夹(文件夹名称为标签),每个文件夹下存储着多张对应标签的水果图片,如下所示:

以apple文件夹为例,图片内容如下:

数据加载和分割数据集的代码如下:

from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split

# 图像变换
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                     mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5]
                                ), ])
# 加载数据集
dataset = ImageFolder('/opt/dataset', transform=transform)

# 划分训练集与测试集
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=10)

batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/822134.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go语言性能优化建议与pprof性能调优详解——结合博客项目实战

文章目录 性能优化建议Benchmark的使用slice优化预分配内存大内存未释放 map优化字符串处理优化结构体优化atomic包小结 pprof性能调优采集性能数据服务型应用go tool pprof命令项目调优分析修改main.go安装go-wrk命令行交互界面图形化火焰图 性能优化建议 简介: …

python用来做什么的,python用来干什么的

大家好,小编为大家解答python用来干什么的的问题。很多人还不知道python用来做什么的,现在让我们一起来看看吧! 编程语言python是用来干什么的? python的作用: 1、系统编程:提供API(ApplicationProgrammin…

【算法提高:动态规划】1.3 背包模型 TODO

文章目录 例题列表423. 采药(01背包)1024. 装箱问题(大小和价值相等的01背包)1022. 宠物小精灵之收服(二维费用的背包问题)补充:相关题目——8. 二维费用的背包问题 278. 数字组合(0…

阿里云负载均衡SLB网络型NLB负载均衡架构性能详解

阿里云网络型负载均衡NLB是阿里云推出的新一代四层负载均衡,支持超高性能和自动弹性能力,单实例可以达到1亿并发连接,帮您轻松应对高并发业务。网络型负载均衡NLB具有超强性能、自动弹性伸缩、高可用、TCPSSL卸载、多场景流量分发和丰富的高级…

【初阶C语言】数组

目录 一、一维数组 1.一维数组的创建和初始化 2.一维数组的使用 3.一维数组在内存中的存储 二、二维数组 1.二维数组的创建 2.二维数组的初始化 3.二维数组的使用 4.二维数组在内存中的存储 三、数组的越界问题 四、数组传参 前言: 数组在C语言中是一个…

express学习笔记6 - 用户模块

新建router/user.js const express require(express) const routerexpress.Router() router.get(/login, function(req, res, next) {console.log(/user/login, req.body)res.json({code: 0,msg: 登录成功})})module.exportsrouter 在router/user.js引入并使用 const us…

一起学算法(链表篇)

1.链表的概念 对于顺序存储的结构最大的缺点就是插入和排序的时候需要移动大量的元素&#xff0c;所以链表的出生由此而来 先上代码&#xff1a; // 链表 public class LinkedList<T extends Comparable> {// 结点类class Node {T ele; // 当前结点上的元素内容Node ne…

java学习路程之篇四、进阶知识、石头迷阵游戏、绘制界面、打乱石头方块、移动业务、游戏判定胜利、统计步数、重新游戏

文章目录 1、绘制界面2、打乱石头方块3、移动业务4、游戏判定胜利5、统计步数6、重新游戏7、完整代码 1、绘制界面 2、打乱石头方块 3、移动业务 4、游戏判定胜利 5、统计步数 6、重新游戏 7、完整代码 java之石头迷阵单击游戏、继承、接口、窗体、事件、组件、按钮、图片

【Spring】Spring 中事务的实现

目录 1.编程式事务&#xff08;手动编写代码&#xff09;2.声明式事务&#xff08;利用注解&#xff09;2.1 Transactional作用范围2.2 Transactional参数说明2.3 Transactional工作原理 3.Spring 中设置事务隔离级别3.1 事务四大特性ACID3.2 事务的隔离级别3.2 Spring中设置事…

(13) Qt事件系统(two)

目录 事件分发函数 无边框窗口拖动 自定义事件 发送事件的函数 自定义事件 系统定义的事件号 自定义事件号 自定义事件类 发送和处理事件 sendEvent与postEvent的区别 栈区对象 堆区对象 事件传播机制 事件传播的过程 事件传播到父组件 鼠标单击事件与按钮单击信…

【STM32零基础入门教程03】GPIO输入输出之GPIO框图分析

本章节主要讲解点亮LED的基本原理&#xff0c;以及GPIO框图的讲解。 如何点亮LED&#xff08;输出&#xff09; 首先我们查看原理图&#xff0c;观察电路图中LED的连接情况&#xff0c;如下图可以看出我们的板子中LED一端通过限流电阻连接的PB0另一端连接的是高电平VCC&#xf…

30. 利用linprog 解决 生产决策问题(matlab程序)

1.简述 线线规划的几个基本性质&#xff1a;【文献[1]第46页】 (1)线性规划问题的可行域如果非空&#xff0c;则是一个凸集-凸多面体&#xff1b; (2)如果线性规划问题有最优解&#xff0c;那么最优解可在可行域的顶点中确定&#xff1b; (3)如果可行域有界&#xff0c;且可行域…

【数据中台】DataX源码进行二开插件

参考官方 使用的离线数据同步工具/平台&#xff0c;实现不同数据库等各种异构数据源之间高效的数据同步功能 工具部署 https://github.com/alibaba/DataX/blob/master/userGuid.md 拉取下来的代码&#xff0c;pom.xml里面注释 <!--<module>tsdbreader</module&g…

大整数截取解决方法(java代码)

大整数截取解决方法&#xff08;java代码&#xff09; 描述输入描述输出描述输入示例输出示例前置知识&#xff1a;代码 解题思路来自这个博客&#xff1a;简单^不简单 https://blog.csdn.net/younger_china/article/details/126376374 描述 花花有一个很珍贵的数字串&#xf…

P4053 [JSOI2007] 建筑抢修(贪心)(内附封面)

[JSOI2007] 建筑抢修 题目描述 小刚在玩 JSOI 提供的一个称之为“建筑抢修”的电脑游戏&#xff1a;经过了一场激烈的战斗&#xff0c;T 部落消灭了所有 Z 部落的入侵者。但是 T 部落的基地里已经有 N N N 个建筑设施受到了严重的损伤&#xff0c;如果不尽快修复的话&#x…

python项目开发案例集锦,python开发程序流程

大家好&#xff0c;给大家分享一下python项目开发案例集锦 源码&#xff0c;很多人还不知道这一点。下面详细解释一下。现在让我们来看看&#xff01; 今天任务 1.创建Python项目为pythontest1以及test1.py文件 2.修改字号 3.输入九九乘法表程序&#xff0c;编译调试执行 4.配置…

Python selenium对应的浏览器chromedriver版本不一致

1、chrome和chromedriver版本不一致导致的&#xff0c;我们只需要升级下chromedriver的版本即可 浏览器版本查看 //打开google浏览器直接访问&#xff0c;查看浏览器版本 chrome://version/ 查看chromedriver的版本 //查看驱动版本 chromedriver chromedriver下载 可看到浏…

基于 Debian GNU/Linux 12 “书虫 “的Neptune 8.0 “Juna “来了

导读Neptune Linux 发行版背后的团队发布了 Neptune 8.0&#xff0c;作为这个基于 Debian 的 GNU/Linux 发行版的重大更新&#xff0c;它围绕最新的 KDE Plasma 桌面环境构建。 Neptune 8.0 被命名为 “Juna”&#xff0c;是在Neptune 7.5 发布 11 个月后发布的&#xff0c;也是…

2.1 密码学基础

数据参考&#xff1a;CISP官方 目录 密码学基本概念对称密码算法非对称密码算法哈希函数与数字签名公钥基础设施 一、密码学基本概念 1、密码学形成与发展 发展历程 古典密码学 (1949年之前) 主要特点&#xff1a;数据的安全基于算法的保密 近代密码学 (1949~1975年…

第4章 案例研究:JavaScript图片库

案例 html部分 <h1 id"title">图片1</h1> <ul><li><!-- onclick绑定点击事件&#xff0c;this为触发dom&#xff0c;return false阻止默认行为 --><a onclick"show_img(this); return false" title"图片1" h…