机器学习——KNN数据集划分

news2025/3/25 10:05:09

一、主要函数

sklearn.datasets.my_train_test_split()

该函数为Scikit-learn 中用于将数据集划分为训练集和测试集的函数,适用于机器学习模型的训练和验证。以下是详细解释:


​1、函数签名

train_test_split(
    *arrays,                  # 输入的数据集(可以是多个数组,如 X, y)
    test_size=None,           # 测试集的比例或样本数
    train_size=None,          # 训练集的比例或样本数
    random_state=None,        # 随机种子(确保结果可复现)
    shuffle=True,             # 是否打乱数据顺序
    stratify=None             # 是否分层抽样(保持类别比例)
)
返回值
  • 分割后的数据集,例如 X_train, X_test, y_train, y_test(顺序与输入一致)。

​2、参数详解

1. ​***arrays**​(必填)
  • 输入的数据集,可以是多个数组(如特征矩阵 X 和标签 y),支持同时拆分多个数据集。
  • 例如:X_train, X_test, y_train, y_test = train_test_split(X, y)
2. ​**test_size**​(默认 None
  • 浮点数:表示测试集占总数据的比例(如 test_size=0.2 表示 20% 作为测试集)。
  • 整数:表示测试集的绝对样本数(如 test_size=100)。
  • 如果 test_size 和 train_size 均为 None,默认 test_size=0.25
3. ​**train_size**​(默认 None
  • 类似 test_size,但指定训练集的比例或样本数。
  • 通常只需指定 test_size 或 train_size 中的一个。
4. ​**random_state**​(默认 None
  • 随机种子,保证每次分割结果一致。
  • 例如:random_state=42 使结果可复现。
5. ​**shuffle**​(默认 True
  • 是否在分割前打乱数据顺序。
  • 时间序列数据 需设置为 shuffle=False,避免破坏时间依赖性。
6. ​**stratify**​(默认 None
  • 指定分层抽样的参考标签,保持训练集和测试集的类别分布与原始数据一致。
  • 适用于分类任务中类别不平衡的数据。
  • 例如:stratify=y 会根据 y 的类别比例分割数据。

二、手动实现数据集划分

1.按比例计算

import numpy as np
import matplotlib.pyplot as plt #绘图模块
from sklearn.datasets import make_blobs  #聚类划分模块

#300样本,2个标签,3个聚类
x, y = make_blobs(
    n_samples = 300,
    n_features = 2,
    centers = 3,
    cluster_std = 1,
    center_box = (-10, 10),
    random_state = 233,
    return_centers = False
)

#打印显示
plt.scatter(x[:,0], x[:,1], c = y,s = 15)
plt.show()

#设置随机种子,将x打乱,并返回索引值
np.random.seed(233)
shuffle = np.random.permutation(len(x))

#设置划分比例
train_size=0.7  
test_size =0.3

#得到对应比例的数据索引
train_index = shuffle[:int(len(x) * train_size)]
test_index = shuffle[int(len(x) * train_size):]

#得到数据集
x[train_index].shape, y[train_index].shape #结果:((210, 2), (210,))
x[test_index].shape, y[test_index].shape  #结果:((90, 2), (90,))

2、上述过程封装

import numpy as np
from matplotlib import pyplot as plt

#函数封装
def my_train_test_split(x, y, train_size = 0.7, random_state = None):
    if random_state:
        np.random.seed(random_state)
    shuffle = np.random.permutation(len(x))
    train_index = shuffle[:int(len(x) * train_size)]
    test_index = shuffle[int(len(x) * train_size):]
    return x[train_index], x[test_index], y[train_index], y[test_index]

#调用
x_train, x_test, y_train, y_test = my_train_test_split(x, y, train_size = 0.7, random_state = 233)
x_train.shape, x_test.shape, y_train.shape, y_test.shape #结果:((210, 2), (90, 2), (210,), (90,))

#显示结果(训练集)
plt.scatter(x_train[:, 0], x_train[:, 1], c = y_train, s = 15)
plt.show()
#显示结果(测试集)
plt.scatter(x_test[:, 0], x_test[:, 1], c = y_test, s = 15)
plt.show()

三、KNN方法实现数据集快速划分

from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs

#制作数据集
x, y = make_blobs(
    n_samples = 300,
    n_features = 2,
    centers = 3,
    cluster_std = 1,
    center_box = (-10, 10),
    random_state = 233,
    return_centers = False
)


#调用
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size = 0.7, random_state = 233)

#显示形状
x_train.shape, x_test.shape, y_train.shape, y_test.shape #结果:((210, 2), (90, 2), (210,), (90,))

#统计y_test标签
from collections import Counter
Counter(y_test) #结果:Counter({2: 34, 0: 25, 1: 31}),发现标签并不均匀

#加stratify = y限制,使其划分和标签类型一样
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size = 0.7, random_state = 233, stratify = y) 
print(Counter(y_test))  # 结果:Counter({2: 30, 0: 30, 1: 30})
print(Counter(y_train)) # 结果:Counter({0: 70, 2: 70, 1: 70})


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2320295.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring AI Alibaba ChatModel使用

一、对话模型(Chat Model)简介 1、对话模型(Chat Model) 对话模型(Chat Model)接收一系列消息(Message)作为输入,与模型 LLM 服务进行交互,并接收返回的聊天…

基于FPGA频率、幅度、相位可调的任意函数发生器(DDS)实现

基于FPGA实现频率、幅度、相位可调的DDS 1 摘要 直接数字合成器( DDS ) 是一种通过生成数字形式的时变信号并进行数模转换来产生模拟波形(通常为正弦波)的方法,它通过数字方式直接合成信号,而不是通过模拟信号生成技术。DDS主要被应用于信号生成、通信系统中的本振、函…

k8s高可用集群安装

一、安装负载均衡器 k8s负载均衡器 官方指南 1、准备三台机器 节点名称IPmaster-1192.168.1.11master-2192.168.1.12master-3192.168.1.13 2、在这三台机器分别安装haproxy和keepalived作为负载均衡器 # 安装haproxy sudo dnf install haproxy -y# 安装Keepalived sudo yum …

3DMAX曲线生成器插件CurveGenerator使用方法

1. 脚本功能简介 3DMAX曲线生成器插件CurveGenerator是一个用于 3ds Max 的样条线生成工具,用户可以通过简单的UI界面输入参数,快速生成多条样条线。每条样条线的高度值随机生成,且可以自定义以下参数: 顶点数量:每条…

六十天前端强化训练之第二十六天之Vue Router 动态路由参数大师级详解

欢迎来到编程星辰海的博客讲解 看完可以给一个免费的三连吗,谢谢大佬! 目录 一、知识讲解 1. Vue Router 核心概念 2. 动态路由参数原理 3. 参数传递方案对比 二、核心代码示例 1. 完整路由配置 2. 参数接收组件 3. 导航操作示例 三、实现效果示…

Model Context Protocol:下一代AI系统集成范式革命

在2023年全球AI工程化报告中,开发者面临的核心痛点排名前三的分别是:模型与业务系统集成复杂度(58%)、上下文管理碎片化(42%)、工具调用标准化缺失(37%)。传统API集成模式在对接大语言模型时暴露明显短板:RESTful接口无法承载动态上下文,GraphQL缺乏工具编排能力,gR…

Java多线程与高并发专题——Future 是什么?

引入 在上一篇Callable 和 Runnable 的不同?的最后,我们有提到和 Callable 配合的有一个 Future 类,通过 Future 可以了解任务执行情况,或者取消任务的执行,还可获取任务执行的结果,这些功能都是 Runnable…

DeepSeek本地搭建

1. 软件下载安装 Miniconda Miniconda下载地址 选择对应的版本下载,此处下载如下版本 Python 3.10 conda 25.1.1 安装完成后,配置环境变量,打开cmd命令窗口验证 Python Python的版本为 3.10 PyTorch PyTorch下载地址 后面通过命令下…

维普AIGC降重方法有哪些?

在学术写作和论文创作中,重复率过高是许多人面临的一大难题。随着科技的发展,维普 AIGC 为我们提供了一系列有效的降重方法。那么,维普AIGC降重方法有哪些呢?接下来就为大家详细介绍。 语义理解与改写 维普 AIGC 具备强大的语义理…

南京审计大学:《 面向工程审计行业的DeepSeek大模型应用指南》.pdf(免费下载)

大家好,我是吾鳴。 今天吾鳴要给大家分享的是由南京审计大学出品的《面向工程审计行业的DeepSeek大模型应用指南》,这份报告与《面向审计行业DeepSeek大模型操作指南》不同,这份报告更多的讲述DeepSeek怎么与工程审计行业结合,应该…

【前端】Canvas画布实现在线的唇膏换色功能

【前端】Canvas画布实现在线的唇膏换色功能 推荐超级课程: 本地离线DeepSeek AI方案部署实战教程【完全版】Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速入门实战目录 【前端】Canvas画布实现在线的唇膏换色功能背景概述以下是我们的实现方法!第一步 — 找…

arcgispro加载在线地图

World_Imagery (MapServer)https://services.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer添加arcgis server WMTS 服务 by xdcxdc.at xdc的个人站点。博客请转至 http://i.xdc.at/ http://xdc.at/map/wmts 添加WMTS服务器

华为网路设备学习-16 虚拟路由器冗余协议(VRRP)

VRRP是针对干线上三层网络设备(如:路由器、防火墙等)的网络虚拟化技术,提供冗余和状态监测等功能。确保在网络中的单点故障发生时,能够快速切换到备份设备,从而保证网络通信的连续性和可靠性。‌ VRRP通过…

封装一个分割线组件

最终样式 Vue2代码 <template><div class"sep-line"><div class"sep-label"><span class"sep-box-text"><slot>{{ title }}</slot> <!-- 默认插槽内容&#xff0c;如果没有传递内容则使用title -->&…

网络HTTPS协议

Https HTTPS&#xff08;Hypertext Transfer Protocol Secure&#xff09;是 HTTP 协议的加密版本&#xff0c;它使用 SSL/TLS 协议来加密客户端和服务器之间的通信。具体来说&#xff1a; • 加密通信&#xff1a;在用户请求访问一个 HTTPS 网站时&#xff0c;客户端&#x…

OSASIS(One-Shot Structure-Aware Stylized Image Synthesis)

文章目录 摘要abstract论文摘要方法损失函数实验结论 总结 摘要 本周阅读了一篇关于新型图像风格化的论文《One-Shot Structure-Aware Stylized Image Synthesis》&#xff0c;旨在解决现有GAN模型在风格化过程中难以保持输入图像结构的问题。通过分离图像的结构和语义信息&am…

C++学习之网盘项目单例模式

目录 1.知识点概述 2.单例介绍 3.单例饿汉模式 4.饿汉模式四个版本 5.单例类的使用 6.关于token的作用和存储 7.样式表使用方法 8.qss文件中选择器介绍 9.qss文件样式讲解和测试 10.qss美化登录界面补充 11.QHTTPMULTIPART类的使用 12.文件上传协议 13.文件上传协议…

Apache Flink技术原理深入解析:任务执行流程全景图

前言 本文隶属于专栏《大数据技术体系》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见大数据技术体系 思维导图 📌 引言 Apache Flink 作为一款高性能的分布式流处理引擎,其内部执行机制精妙而复杂。本文将…

RAG(Retrieval-Augmented Generation)基建之PDF解析的“魔法”与“陷阱”

嘿&#xff0c;亲爱的算法工程师们&#xff01;今天咱们聊一聊PDF解析的那些事儿&#xff0c;简直就像是在玩一场“信息捉迷藏”游戏&#xff01;PDF文档就像是个调皮的小精灵&#xff0c;表面上看起来规规矩矩&#xff0c;但当你想要从它那里提取信息时&#xff0c;它就开始跟…

C语言【文件操作】详解中(会使用fgetc,fputc,fgets,fputs,fscanf,fprintf,fread,fwrite函数)

引言 介绍和文件操作中文件的顺序读写相关的函数 看这篇博文前&#xff0c;希望您先仔细看一下这篇博文&#xff0c;理解一下文件指针和流的概念&#xff1a;C语言【文件操作】详解上-CSDN博客文章浏览阅读606次&#xff0c;点赞26次&#xff0c;收藏4次。先整体认识一下文件是…