【chatgpt】train_split_test的random_state

news2025/1/11 7:44:42

在使用train_test_split函数划分数据集时,random_state参数用于控制随机数生成器的种子,以确保划分结果的可重复性。这样,无论你运行多少次代码,只要使用相同的random_state值,得到的训练集和测试集划分就会是一样的。

使用 train_test_split 示例

以下是一个示例,展示如何使用train_test_split函数进行数据集划分,并设置random_state参数:
程序输出结果
Training set shape: (80, 10), (80,)
Test set shape: (20, 10), (20,)

import numpy as np
from sklearn.model_selection import train_test_split

# 假设我们有一些数据
X = np.random.rand(100, 10)  # 100个样本,每个样本10个特征
y = np.random.randint(0, 2, 100)  # 100个样本的标签(0或1)

# 使用train_test_split进行数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 打印划分后的数据集形状
print(f'Training set shape: {X_train.shape}, {y_train.shape}')
print(f'Test set shape: {X_test.shape}, {y_test.shape}')

在这个示例中:

  • X 是特征矩阵,包含100个样本,每个样本有10个特征。
  • y 是标签数组,包含100个样本的标签。
  • test_size=0.2 表示将数据集的20%用作测试集,剩下的80%用作训练集。
  • random_state=42 用于确保划分的可重复性。

为什么使用 random_state

使用 random_state 可以确保在多次运行代码时,得到的训练集和测试集划分是一致的,这在以下情况下特别有用:

  1. 调试和开发: 在开发和调试模型时,使用相同的 random_state 可以确保数据划分的一致性,从而使得调试更加容易。
  2. 实验的可重复性: 在进行实验时,使用相同的 random_state 可以确保实验结果的可重复性,使得其他人可以验证你的结果。
  3. 比较模型性能: 在比较不同模型的性能时,使用相同的 random_state 可以确保每个模型都使用相同的训练集和测试集,从而使比较更加公平。
    在这里插入图片描述

示例:比较大数据集和小数据集的模型性能

假设我们有一个大数据集和一个小数据集,我们想要比较它们在同一模型上的性能。我们可以使用 train_test_split 进行数据集划分,并设置 random_state 以确保划分的可重复性。

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
# 创建大数据集和小数据集
X_large = np.random.rand(1000, 10)
y_large = np.random.rand(1000, 1)

X_small = np.random.rand(100, 10)
y_small = np.random.rand(100, 1)

# 使用train_test_split进行数据集划分
X_train_large, X_test_large, y_train_large, y_test_large = train_test_split(X_large, y_large, test_size=0.2, random_state=42)
X_train_small, X_test_small, y_train_small, y_test_small = train_test_split(X_small, y_small, test_size=0.2, random_state=42)

# 转换为张量
X_train_large = torch.tensor(X_train_large, dtype=torch.float32)
y_train_large = torch.tensor(y_train_large, dtype=torch.float32)
X_test_large = torch.tensor(X_test_large, dtype=torch.float32)
y_test_large = torch.tensor(y_test_large, dtype=torch.float32)

X_train_small = torch.tensor(X_train_small, dtype=torch.float32)
y_train_small = torch.tensor(y_train_small, dtype=torch.float32)
X_test_small = torch.tensor(X_test_small, dtype=torch.float32)
y_test_small = torch.tensor(y_test_small, dtype=torch.float32)

# 创建数据加载器
train_loader_large = DataLoader(TensorDataset(X_train_large, y_train_large), batch_size=32, shuffle=True)
test_loader_large = DataLoader(TensorDataset(X_test_large, y_test_large), batch_size=32, shuffle=False)

train_loader_small = DataLoader(TensorDataset(X_train_small, y_train_small), batch_size=32, shuffle=True)
test_loader_small = DataLoader(TensorDataset(X_test_small, y_test_small), batch_size=32, shuffle=False)

# 定义简单的线性模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.linear(x)

# 训练模型的通用函数
def train_model(train_loader, num_epochs=50, learning_rate=0.01):
    model = SimpleModel()
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    train_losses = []

    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        for batch_x, batch_y in train_loader:
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()
        epoch_train_loss /= len(train_loader)
        train_losses.append(epoch_train_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}')

    return model, train_losses

# 训练大数据集的模型
print("Training on large dataset")
model_large, train_losses_large = train_model(train_loader_large)

# 训练小数据集的模型
print("\nTraining on small dataset")
model_small, train_losses_small = train_model(train_loader_small)

# 绘制训练损失曲线
plt.figure(figsize=(12, 6))
plt.plot(range(1, len(train_losses_large) + 1), train_losses_large, label='Large Dataset Train Loss')
plt.plot(range(1, len(train_losses_small) + 1), train_losses_small, label='Small Dataset Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss Comparison')
plt.savefig("test")

# 在测试集上计算最终的评估指标(例如均方误差)
def evaluate_model(model, test_loader):
    model.eval()
    test_loss = 0.0
    criterion = nn.MSELoss()
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            test_loss += loss.item()
    test_loss /= len(test_loader)
    return test_loss

# 评估大数据集的模型
final_test_loss_large = evaluate_model(model_large, test_loader_large)

# 评估小数据集的模型
final_test_loss_small = evaluate_model(model_small, test_loader_small)

print(f'Final Test Loss on Large Dataset: {final_test_loss_large:.4f}')
print(f'Final Test Loss on Small Dataset: {final_test_loss_small:.4f}')

结果分析

通过上述代码,可以得到大数据集和小数据集在训练过程中的损失曲线以及最终的测试损失。根据这些信息,可以比较它们的收敛情况和性能。

  • 损失曲线: 通过观察损失曲线,判断模型在两个数据集上的收敛速度和稳定性。如果两者曲线形状相似,并且在同一水平上趋于平稳,可以认为它们收敛到了相似的程度。

  • 最终测试损失: 最终测试损失值可以用于直接比较两个模型的性能。如果两者最终测试损失值接近,则可以认为它们的模型性能相当。

通过使用相同的 random_state 值,确保数据集划分的一致性,从而使得比较结果更加公平和具有可重复性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1854719.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vision Pro的3D跟踪能力:B端应用的工作流、使用教程和经验总结

Vision Pro的最新3D跟踪能力为工业、文博、营销等多个B端领域带来了革命性的交互体验。本文将详细介绍这一功能的工作流、使用教程,并结合实际经验进行总结。 第一部分:工作流详解 一、对象扫描 使用Reality Composer iPhone应用程序对目标对象进行3D扫描,如吉他或雕塑,…

粉笔1000题——判断推理

目录 一、图形推理1. 位置规律平移旋转、翻转 二、定义判断三、类比推理四、逻辑判断 一、图形推理 1. 位置规律 平移 旋转、翻转 二、定义判断 三、类比推理 四、逻辑判断

红队内网攻防渗透:内网渗透之内网对抗:横向移动篇PTH哈希PTT票据PTK密匙Kerberoast攻击点TGTNTLM爆破

红队内网攻防渗透 1. 内网横向移动1.1 首要知识点1.2 PTH1.2.1 利用思路第1种:利用直接的Hash传递1.2.1.1、Mimikatz1.2.2 利用思路第2种:利用hash转成ptt传递1.2.3 利用思路第3种:利用hash进行暴力猜解明文1.2.4 利用思路第4种:修改注册表重启进行获取明文1.3 PTT1.3.1、漏…

养殖自动化温控系统:现代养殖场的智能守护神

现代农业养殖业中,养殖自动化温控系统已经成为提高生产效率和保障动物福利的关键技术之一。本篇文章将深入介绍养殖自动化温控系统的原理、组成、优势及其在不同类型养殖场中的应用实例,并展望该技术的未来发展。 一、养殖自动化温控系统概述 养殖自动…

LabVIEW编程控制ABB机械臂

使用LabVIEW编程控制ABB机械臂是一项复杂但十分有价值的任务。通过LabVIEW,可以实现对机械臂的精确控制和监控,提升自动化水平和操作效率。 1. 项目规划和硬件选型 1.1 确定系统需求 运动控制:确定机械臂需要执行的任务,如抓取、…

易优cms内核简洁文章资讯作文范文网站模板源码(带手机版)

易优cms内核简洁文章资讯作文范文网站模板源码 带手机版 适用于博客、文章、资讯类网站使用 界面预览 易优cms内核简洁文章资讯作文范文网站模板源码

Python题目

实例 3.1 兔子繁殖问题(斐波那契数列) 兔子从出生后的第三个月开始,每月都会生一对兔子,小兔子成长到第三个月后也会生一对独自。初始有一对兔子,假如兔子都不死,那么计算并输出1-n个月兔子的数量 n int…

element-plus 表单组件 之element-form

elment-plus的表单组件的标签有el-form,el-form-item。 单个el-form标签内包裹若干个el-form-item,el-form-item包裹具体的表单组件,如输入框组件,多选组件,日期组件等。 el-form组件的主要作用是:提供统一的布局给其他表单组件&…

FPGA学习网站推荐

FPGA学习网站推荐 本文首发于公众号:FPGA开源工坊 引言 FPGA的学习主要分为以下两部分 语法领域内知识 做FPGA开发肯定要首先去学习相应的编程语言,FPGA开发目前在国内采用最多的就是使用Verilog做开发,其次还有一些遗留下来的项目会采用…

长亭谛听教程部署和详细教程

PPT 图片先挂着 挺概念的 谛听的能力 hw的时候可能会问你用过的安全产品能力能加分挺重要 溯源反制 反制很重要感觉很厉害 取证分析 诱捕牵制 其实就是蜜罐 有模板直接爬取某些网页模板进行伪装 部署要求 挺低的 对linux内核版本有要求 需要root 还有系统配置也要修改 …

leetcode刷题日记

题目描述 解题思路 基本思想,将数组复制一份,按照位置取余,确实做出来了,但是这样时间和空间上的资源比较多。看到切片法,感觉到很新,思路很好,用来记录。 代码 python class Solution:def ro…

springboot + Vue前后端项目(第十八记)

项目实战第十八记 写在前面1. 前台页面搭建(Front.vue)2. 路由3.改动登录页面Login.vue4. 前台主页面搭建Home.vue总结写在最后 写在前面 本篇主要讲解系统前台搭建,通常较大的项目都会搭建前台 1. 普通用户登录成功后前台页面效果&#xf…

Flutter 如何发布安卓应用?

android:hardwareAccelerated“true” android:windowSoftInputMode“adjustResize”> <meta-data android:name“flutterEmbedding” android:value“2” /> Flutter生成的文件建议是大部分内容可以保留不动&#xff0c;但是可以根据需要进行修改。 具体可能要修…

一款有趣的Python库绘制风向图,小白容易上手

利用 Python 绘制风向图 绘制风向图通常使用 matplotlib 库的 Barbs 类来实现.这个类用于绘制风向和风速的矢量场,可以实现不同的风向图风格. 安装 ## 命令安装 matplotlib 库&#xff1a;pip install matplotlib用法 下面是一个简单的示例代码,绘制风向图&#xff1a; 使…

分布式,容错:10台电脑坏了2台

由10台电脑组成的分布式系统&#xff0c;随机、任意坏了2台&#xff0c;剩下的8台电脑仍然储存着全部信息&#xff0c;可以继续服务。这是怎么做到的&#xff1f; 设N台电脑&#xff0c;坏了H台&#xff0c;要保证上述性质&#xff0c;需要有冗余&#xff0c;总的存储量降低为…

路由

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 客户端&#xff08;例如浏览器&#xff09;把请求发送给 Web 服务器&#xff0c;Web 服务器再把请求发送给 Flask程序实例。程序实例需要知道对每个U…

图像超分辨率重建

一、什么是图像超分辨 图像超分辨是一种技术&#xff0c;旨在通过硬件或软件的方法提高原有图像的分辨率。这一过程涉及从一系列低分辨率的图像中获取一幅高分辨率的图像&#xff0c;实现了时间分辨率向空间分辨率的转换。超分辨率重建的核心思想是利用多帧图像序列的时间带宽来…

压力测试Monkey命令参数和报告分析

目录 常用参数 -p <测试的包名列表> -v 显示日志详细程度 -s 伪随机数生成器的种子值 --throttle < 毫秒> --ignore-crashes 忽略崩溃 --ignore-timeouts 忽略超时 --monitor-native-crashes 监视本地崩溃代码 --ignore-security-exceptions 忽略安全异常 …

【vue3|第13期】深入了解Vue3生命周期:管理组件的诞生、成长与消亡

日期&#xff1a;2024年6月22日 作者&#xff1a;Commas 签名&#xff1a;(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释&#xff1a;如果您觉得有所帮助&#xff0c;帮忙点个赞&#xff0c;也可以关注我&#xff0c;我们一起成长&#xff1b;如果有不对的地方&#xf…

夏季城市内涝防治:视频汇聚系统智能AI技术助力城市自然灾害应急管理

据新闻报道&#xff0c;6月19日至20日&#xff0c;受强降雨影响&#xff0c;广西桂林城区及周边等地出现今年入汛以来持续时间最长、累计降水量最大、影响范围最广、致灾风险最高的暴雨天气过程&#xff0c;导致桂林市区多处发生洪水内涝&#xff0c;房屋被淹、道路受阻、人员被…