机器学习探索计划——数据集划分

news2024/11/18 13:43:39

文章目录

  • 导包
  • 手写数据划分函数
  • 使用sklearn内置的划分数据函数
    • stratify=y理解举例

导包

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs

手写数据划分函数

x, y = make_blobs(
    n_samples = 300,
    n_features = 2,
    centers = 3,
    cluster_std = 1,
    center_box = (-10, 10),
    random_state = 666,
    return_centers = False
)

make_blobs:scikit-learn(sklearn)库中的一个函数,用于生成聚类任务中的合成数据集。它可以生成具有指定特征数和聚类中心数的随机数据集。

n_samples:生成的样本总数,本例中为 300。
n_features:生成的每个样本的特征数,本例中为 2。
centers:生成的簇的数量,本例中为 3。
cluster_std:每个簇中样本的标准差,本例中为 1。
center_box:每个簇中心的边界框(bounding box)范围,本例中为 (-10, 10)。
random_state:随机种子,用于控制数据的随机性,本例中为 666。
return_centers:是否返回生成的簇中心点,默认为 False,在本例中不返回。

plt.scatter(x[:, 0], x[:, 1], c = y, s = 15)
plt.show()

在这里插入图片描述

x[:, 0]:表示取 x 数据集中所有样本的第一个特征值。
x[:, 1]:表示取 x 数据集中所有样本的第二个特征值。
c=y:表示使用标签 y 对样本点进行颜色编码,即不同的标签值将使用不同的颜色进行展示。
s=15:表示散点的大小为 15,即每个样本点的显示大小。

index = np.arange(20)
np.random.shuffle(index)
index

output: array([12, 15, 7, 11, 14, 16, 6, 5, 0, 1, 2, 19, 13, 4, 18, 9, 8,
10, 3, 17])

np.random.permutation(20)

output: array([ 6, 4, 11, 13, 18, 1, 8, 3, 10, 9, 7, 0, 15, 17, 19, 16, 5,
2, 14, 12])

np.random.seed(666)
shuffle = np.random.permutation(len(x))
shuffle

output:
array([235, 169, 17, 92, 234, 15, 0, 152, 176, 243, 98, 260, 96,
123, 266, 220, 109, 286, 185, 177, 160, 11, 50, 246, 258, 254,
34, 229, 154, 66, 285, 214, 237, 95, 7, 205, 262, 281, 110,
64, 111, 87, 263, 38, 153, 129, 273, 255, 208, 56, 162, 106,
277, 224, 178, 265, 108, 104, 101, 158, 248, 29, 181, 62, 14,
75, 118, 201, 41, 150, 131, 183, 288, 291, 76, 293, 267, 1,
165, 12, 278, 53, 209, 114, 71, 135, 184, 206, 244, 61, 211,
213, 128, 3, 143, 296, 227, 242, 94, 251, 284, 253, 89, 49,
159, 35, 268, 249, 197, 55, 167, 146, 23, 283, 187, 173, 124,
68, 250, 189, 186, 5, 221, 65, 40, 119, 74, 22, 19, 59,
188, 231, 44, 137, 31, 256, 43, 85, 149, 134, 218, 120, 81,
67, 239, 195, 207, 240, 182, 179, 90, 216, 180, 47, 299, 30,
163, 193, 48, 245, 138, 28, 257, 125, 170, 157, 259, 290, 200,
203, 215, 238, 194, 121, 298, 73, 97, 8, 130, 105, 190, 6,
36, 27, 32, 144, 4, 117, 115, 171, 136, 84, 10, 113, 233,
247, 72, 292, 198, 252, 82, 228, 37, 39, 33, 280, 272, 79,
116, 172, 202, 226, 271, 145, 13, 78, 196, 274, 26, 297, 191,
232, 52, 20, 230, 18, 58, 294, 140, 132, 287, 217, 25, 133,
83, 99, 93, 21, 241, 168, 147, 275, 212, 127, 54, 199, 282,
107, 151, 289, 88, 100, 264, 45, 77, 295, 9, 166, 57, 80,
155, 279, 86, 219, 2, 269, 126, 102, 142, 192, 161, 103, 42,
261, 16, 175, 122, 174, 164, 112, 148, 24, 139, 276, 141, 204,
210, 69, 46, 63, 225, 270, 156, 223, 60, 51, 222, 91, 70,
236])

np.random.seed(666)使得随机数结果可复现

shuffle.shape

output: (300,)

train_size = 0.7
train_index = shuffle[:int(len(x) * train_size)]
test_index = shuffle[int(len(x) * train_size):]
train_index.shape, test_index.shape

output: ((210,), (90,))

x[train_index].shape, y[train_index].shape, x[test_index].shape, y[test_index].shape

output: ((210, 2), (210,), (90, 2), (90,))

def my_train_test_split(x, y, train_size = 0.7, random_state = None):
    if random_state:
        np.random.seed(random_state)
    shuffle = np.random.permutation(len(x))
    train_index = shuffle[:int(len(x) * train_size)]
    test_index = shuffle[int(len(x) * train_size):]
    return x[train_index], x[test_index], y[train_index], y[test_index]
x_train, x_test, y_train, y_test = my_train_test_split(x, y, train_size=0.7, random_state=233)
x_train.shape, x_test.shape, y_train.shape, y_test.shape

output: ((210, 2), (90, 2), (210,), (90,))

plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, s=15)  # y_train一样的,颜色相同
plt.show()

在这里插入图片描述

plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, s=15)
plt.show()

在这里插入图片描述

使用sklearn内置的划分数据函数

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=233)
x_train.shape, x_test.shape, y_train.shape, y_test.shape

output: ((210, 2), (90, 2), (210,), (90,))

from collections import Counter
Counter(y_test)

output: Counter({2: 34, 0: 29, 1: 27})

x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=666, stratify=y)

stratify=y: 使用标签 y 进行分层采样,确保训练集和测试集中的类别分布相对一致。
这样做的好处是,在训练过程中,模型可以接触到各个类别的样本,从而更好地学习每个类别的特征和模式,提高模型的泛化能力。

Counter(y_test)

output: Counter({1: 30, 0: 30, 2: 30})

stratify=y理解举例

x = np.random.randn(1000, 2)  # 1000个样本,2个特征
y = np.concatenate([np.zeros(800), np.ones(200)])  # 800个负样本,200个正样本

# 使用 stratify 进行分层采样
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.7, random_state=42, stratify=y)

# 打印训练集中正负样本的比例。通过使用 np.mean,我们可以方便地计算出比例或平均值,以了解数据集的分布情况或对模型性能进行评估。
print("训练集中正样本比例:", np.mean(y_train == 1))
print("训练集中负样本比例:", np.mean(y_train == 0))

# 打印测试集中正负样本的比例
print("测试集中正样本比例:", np.mean(y_test == 1))
print("测试集中负样本比例:", np.mean(y_test == 0))

output:
训练集中正样本比例: 0.2
训练集中负样本比例: 0.8
测试集中正样本比例: 0.2
测试集中负样本比例: 0.8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1249991.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux中vim的编译链接和gcc

gcc,g,gdb的安装 命令行写gcc,g,gdb根据提示安装:sudo apt install gcc/g/gdb gcc分布编译链接 (1)预编译: gcc -E main.c -o main.i (2)编译: gcc -S main.i -o main.s (3)汇编: gcc -c main.s -o main.o (4)链接 gcc main.o -o main 执行: ./main 或者:全路径/main 编译链…

讲述 什么是鸿蒙 为什么需要鸿蒙 为什么要学习鸿蒙

首先 我们为什么要学习鸿蒙开发? 因为 鸿蒙发展前景巨大 鸿蒙自发布依赖 一直受社会各界关注 强两百的 App厂商 大部分接受了与鸿蒙的合作 硬件也有非常多与鸿蒙合作的厂商 鸿蒙的合作企业基本已经覆盖整个互联网客户的主流需求 所以鸿蒙的崛起不过是早晚的问题 …

NX二次开发UF_CURVE_ask_line_arc_data 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_ask_line_arc_data Defined in: uf_curve.h int UF_CURVE_ask_line_arc_data(tag_t line_arc_feat_id, UF_CURVE_line_arc_t * line_arc_data ) overview 概述 Populates…

远程网络安全访问JumpServer:使用cpolar内网穿透搭建固定公网地址

文章目录 前言1. 安装Jump server2. 本地访问jump server3. 安装 cpolar内网穿透软件4. 配置Jump server公网访问地址5. 公网远程访问Jump server6. 固定Jump server公网地址 前言 JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运维安全审计系统。JumpS…

小程序中的大道理之三--对称性和耦合问题

再继续扒 继续 前一篇 的话题, 在那里, 提到了抽象, 耦合及 MVC, 现在继续探讨这些, 不过在此之前先说下第一篇里提到的对称性. 注: 以下讨论建立在前面的基础之上, 为控制篇幅起见, 这里将不再重复前面说到的部分, 如果您还没看过前两篇章, 阅读起来可能会有些困难. 这是第一…

电源控制系统架构(PCSA)之系统控制处理器组件

目录 6.4 系统控制处理器 6.4.1 SCP组件 SCP处理器Core SCP处理器Core选择 SCP处理器核内存 系统计数器和通用计时器 看门狗 电压调节器控制 时钟控制 系统控制 信息接口 电源策略单元 传感器控制 外设访问 系统访问 6.4 系统控制处理器 系统控制处理器(SCP)是…

【LeetCode】每日一题 2023_11_24 统计和小于目标的下标对数目(暴力/双指针)

文章目录 刷题前唠嗑题目:统计和小于目标的下标对数目题目描述代码与解题思路 结语 刷题前唠嗑 LeetCode?启动!!! 题目:统计和小于目标的下标对数目 题目链接:2824. 统计和小于目标的下标对数…

Kafka 集群如何实现数据同步

Kafka 介绍 Kafka 是一个高吞吐的分布式消息系统,不但像传统消息队列(RaabitMQ、RocketMQ等)那样能够【异步处理、流量消峰、服务解耦】 还能够把消息持久化到磁盘上,用于批量消费。除此之外由于 Kafka 被设计成分布式系统&…

Qt 软件调试(二)使用dump捕获崩溃信息

Qt应用程序异常崩溃该怎么办&#xff0c;生成dump文件再回溯分析&#xff0c;可以快速且准确的帮助我们定位到崩溃的点。那么&#xff0c;本章我们分享下如何在Qt中生成dump文件。 一、使用minudump捕获崩溃信息 #include <QCoreApplication> #include <QDir> #i…

判断序列Series中的值是否都不一样 PandasSeries中的方法:is_unique()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 判断序列Series中的值是否都不一样 PandasSeries中的方法&#xff1a; is_unique() 选择题 请问下列程序运行的的结果是&#xff1a; import pandas as pd s1 pd.Series([1,2,3]) print("…

[PyTorch][chapter 64][强化学习-DQN]

前言&#xff1a; DQN 就是结合了深度学习和强化学习的一种算法&#xff0c;最初是 DeepMind 在 NIPS 2013年提出&#xff0c;它的核心利润包括马尔科夫决策链以及贝尔曼公式。 Q-learning的核心在于Q表格&#xff0c;通过建立Q表格来为行动提供指引&#xff0c;但这适用于状态…

现货黄金区间交易的两个要点

在现货黄金市场中&#xff0c;我们常碰到横盘区间行情。有区间&#xff0c;就终究会出现突破&#xff0c;因为金价不可能缺乏方向而一直在区间内运行。那既然要突破&#xff0c;我们又应当如何应对和交易呢&#xff1f;下面我们就来讨论一下。 切忌在突破发生时马上跟随突破方向…

EI期刊完整程序:MEA-BP思维进化法优化BP神经网络的回归预测算法,可作为对比预测模型,丰富内容,直接运行,免费

适用平台&#xff1a;Matlab 2020及以上 本程序参考中文EI期刊《基于MEA⁃BP神经网络的建筑能耗预测模型》&#xff0c;程序注释清晰&#xff0c;干货满满&#xff0c;下面对文章和程序做简要介绍。 适用领域&#xff1a;风速预测、光伏功率预测、发电功率预测、碳价预测等多…

leetcode刷题详解三

2. 两数相加 思路&#xff1a;直接加&#xff0c;注意进位条件不要用if&#xff0c;核心代码在于sum l1->val l2->val carry; ListNode* addTwoNumbers(ListNode* l1, ListNode* l2) {ListNode* dummy new ListNode();ListNode* dummy_head dummy;int carry 0;int …

OSG编程指南<十>:OSG几何体的绘制

1、场景基本绘图类 在 OSG 中创建几何体的方法比较简单&#xff0c;通常有 3 种处理几何体的手段&#xff1a; 使用松散封装的OpenGL 绘图基元&#xff1b;使用 OSG 中的基本几何体&#xff1b;从文件中导入场景模型。 使用松散封装的OpenGL 绘图基元绘制几何体具有很强的灵活…

完美解决:vue.js:6 TypeError: Cannot read properties of undefined (reading ‘0‘)

Vue项目出现以下报错&#xff1a; 原因&#xff1a; 在渲染的时候&#xff0c;不满足某个条件而报错&#xff0c;或者某个属性丢失或后台没传过来导致 我这里出现的原因是后台给我传递过来的数组中&#xff0c;其中有一条少传了一个我在渲染时需要用的属性&#xff0c;没让后台…

毅速:3D打印随形透气钢为解决模具困气提供了新助力

在模具行业中&#xff0c;困气是一个较常见的问题。解决困气问题的方法有很多&#xff0c;透气钢就是其一。传统的制造的透气钢往往存在一些不足&#xff0c;如加工难度大、无法满足复杂形状的需求等。随着3D打印技术的发展&#xff0c;一种新型的随形透气钢技术逐渐崭露头角&a…

HBase数据模型杂谈

1.概述 HBase是一个稀疏、多维度、排序的映射表&#xff0c;这张表的索引是行键、列族、列限定符和时间戳。 每个值是一个未经解释的字符串&#xff0c;没有数据类型。用户在表中存储数据&#xff0c;每一行都有一个可排序的行键和任意多的列。表在水平方向由一个或者多个列族…

【Jenkins】jenkins发送邮件报错:Not sent to the following valid addresses:

jenkins报错未能发送邮件到指定邮箱 注意&#xff1a;这是在系统配置中修改 在系统配置》邮件通知中添加配置信息 注意&#xff1a;这个是在项目的配置下修改 配置完成后&#xff0c;重新执行发送邮件成功&#xff01;&#xff01;&#xff01;

AI动画制作 StableDiffusion

1.brew -v 2.安装爬虫项目包所必需的python和git等系列系统支持部件 brew install cmake protobuf rust python3.10 git wget pod --version brew link --overwrite cocoapods 3.从github网站克隆stable-diffusion-webui爬虫项目包至本地 ssh-add /Users/haijunyan/.ssh/id_rs…