交叉验证、网格搜索、模型选择与调优、鸢尾花案例增加K值调优与Facebook人造世界签到位置train.csv数据预测代码实现

news2025/1/18 16:58:29

一、交叉验证

交叉验证(cross validation):将拿到的训练数据分为训练和验证集,以下图为例,将数据分成4份,其中一份作为验证集,经过4次(组)的测试,每次都更换不同的验证集,即得到4组模型的结果,取平均值作为最终结果,又称4折交叉验证

交叉验证目的:为了让被评估的模型更加准确可信,交叉验证不能提高模型的准确率

为了让从训练得到模型结果更加准确,分割数据为

  • 训练集:训练集+验证集
  • 测试集:测试集

二、网格搜索

网格搜索(Grid Search):通常情况下,有很多参数需要手动指定的(如k-近邻算法中的K值),这种参数叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合,每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型 ,即把超参数的值通过字典形式传入,然后选择最优值

 交叉验证,网格搜索(模型选择与调优)API:

  • sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None):对估计器的指定参数值进行详尽搜索
    • estimator:估计器对象
    • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
    • cv:指定几折交叉验证
    • fit:输入训练数据
    • score:准确率
    • 结果分析:
      • bestscore__:在交叉验证中验证的最好结果
      • bestestimator:最好的参数模型
      • cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

三、鸢尾花案例增加K值调优

使用GridSearchCV构建估计器

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier  # 导入模块

# 1.获取数据集
iris = load_iris()
# 2.数据基本处理
x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)   # 划分数据集
# 3.特征工程:标准化
transfer = StandardScaler()   # 实例化一个转换器类
x_train = transfer.fit_transform(x_train)   # 调用fit_transform
x_test = transfer.transform(x_test)
# 4.KNN预估器流程
estimator = KNeighborsClassifier(n_neighbors=1)   # 实例化预估器类

# 模型选择与调优——网格搜索和交叉验证
param_dict = {"n_neighbors": [1, 3, 5, 7]}   # 准备要调的超参数
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=3)  # estimator:选择了哪个训练模型,cv为几折交叉验证
estimator.fit(x_train, y_train)   # fit数据进行训练
# 5.评估模型效果
# 方法a:比对预测结果和真实值
y_predict = estimator.predict(x_test)
print('预测值是:\n', y_predict)
print("比对预测结果和真实值:\n", y_predict == y_test)
# 方法b:直接计算准确率
score = estimator.score(x_test, y_test)
print("直接计算准确率:\n", score)
# 评估查看最终选择的结果和交叉验证的结果
print("最好的参数模型:\n", estimator.best_estimator_)
print("在交叉验证中验证的最好结果:\n", estimator.best_score_)
print("每次交叉验证后的准确率结果:\n", estimator.cv_results_)
----------------------------------------------------------------
输出:
预测值是:
 [0 2 1 2 1 1 1 1 1 0 2 1 2 2 0 2 1 1 1 1 0 2 0 1 2 0 2 2 2 2 0 0 1 1 1 0 0
 0]
比对预测结果和真实值:
 [ True  True  True  True  True  True  True False  True  True  True  True
  True  True  True  True  True  True False  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True]
直接计算准确率:
 0.9473684210526315
最好的参数模型:
 KNeighborsClassifier()
在交叉验证中验证的最好结果:
 0.9732100521574205
每次交叉验证后的准确率结果:
 {'mean_fit_time': array([0.00033259, 0.00031996, 0.00066495, 0.00060987]), 'std_fit_time': array([0.00047036, 0.00045249, 0.00047019, 0.00043647]), 'mean_score_time': array([0.00166217, 0.00200717, 0.00171733, 0.00133022]), 'std_score_time': array([4.70134086e-04, 1.72589729e-05, 5.13878713e-04, 4.70639818e-04]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7],
             mask=[False, False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}], 'split0_test_score': array([0.97368421, 0.97368421, 0.97368421, 0.97368421]), 'split1_test_score': array([0.97297297, 0.97297297, 0.97297297, 0.97297297]), 'split2_test_score': array([0.94594595, 0.89189189, 0.97297297, 0.97297297]), 'mean_test_score': array([0.96420104, 0.94618303, 0.97321005, 0.97321005]), 'std_test_score': array([0.01291157, 0.03839073, 0.00033528, 0.00033528]), 'rank_test_score': array([3, 4, 1, 1])}

四、Facebook人造世界数据预测

train.csv, test.csv数据获取官网:grid_knn | Kaggle

train.csv, test.csv数据文件内容如下

  • row_ id:登记事件的id        
  • x,y:坐标
  • accuracy: 定位准确性
  • time: 时间戳
  • place. _id:业务的id,这是预测目标
  • 基本步骤
    • 数据处理
      • 缩小数据集范围 DataFrame.query()
      • 选取有用的时间特征
      • 将签到位置少于n个用户的删除
    • 分割数据集
    • 标准化处理
    • k-近邻预测

数据介绍:根据用户的位置,准确性和时间戳预测用户正在查看的业务

 代码如下

import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier  # 导入模块

# 1、获取数据集
facebook = pd.read_csv("../data/train.csv")
# 2.基本数据处理
facebook_data = facebook.query("x>2.0 & x<2.5 & y>2.0 & y<2.5")  # 缩小数据范围
time = pd.to_datetime(facebook_data["time"], unit="s")   # 选择时间特征
time = pd.DatetimeIndex(time)
facebook_data["day"] = time.day
facebook_data["hour"] = time.hour
facebook_data["weekday"] = time.weekday
# 去掉签到较少的地方
place_count = facebook_data.groupby("place_id").count()
place_count = place_count[place_count["row_id"] > 3]
facebook_data = facebook_data[facebook_data["place_id"].isin(place_count.index)]
x = facebook_data[["x", "y", "accuracy", "day", "hour", "weekday"]]   # 确定特征值和目标值
y = facebook_data["place_id"]
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=22)   # 分割数据集

# 3.特征工程--特征预处理(标准化)
transfer = StandardScaler()  # 实例化一个转换器
x_train = transfer.fit_transform(x_train)   # 调用fit_transform
x_test = transfer.fit_transform(x_test)

# 4.机器学习--knn+cv
estimator = KNeighborsClassifier()   # 实例化一个估计器
param_grid = {"n_neighbors": [1, 3, 5, 7, 9]}  # 准备要调的超参数
estimator = GridSearchCV(estimator, param_grid=param_grid, cv=10)   # 调用gridsearchCV,10折交叉验证
estimator.fit(x_train, y_train)   # 模型训练
# 5.模型评估
# 基本评估方式
score = estimator.score(x_test, y_test)
print("最后预测的准确率为:\n", score)

y_predict = estimator.predict(x_test)
print("最后的预测值为:\n", y_predict)
print("预测值和真实值的对比情况:\n", y_predict == y_test)

# 使用交叉验证后的评估方式
print("在交叉验证中验证的最好结果:\n", estimator.best_score_)
print("最好的参数模型:\n", estimator.best_estimator_)
print("每次交叉验证后的验证集准确率结果和训练集准确率结果:\n", estimator.cv_results_)
----------------------------------------------------------
输出:
最后预测的准确率为:
 0.36804111804111805
最后的预测值为:
 [9983648790 5806536504 9674001925 ... 1247398579 3455925971 5100539171]
预测值和真实值的对比情况:
 24703810     True
19445902    False
18490063     True
7762709     False
6505956     False
            ...  
27632888    False
23367671    False
6692268     False
25834435    False
13319005    False
Name: place_id, Length: 17316, dtype: bool
在交叉验证中验证的最好结果:
 0.36276655562074106
最好的参数模型:
 KNeighborsClassifier()
每次交叉验证后的验证集准确率结果和训练集准确率结果:
 {'mean_fit_time': array([0.2586242 , 0.29684422, 0.27456429, 0.30488372, 0.31562543]), 'std_fit_time': array([0.08488032, 0.07275422, 0.10609217, 0.10391282, 0.0656314 ]), 'mean_score_time': array([0.51792595, 0.61323991, 0.72256181, 0.88158529, 0.97560654]), 'std_score_time': array([0.10981396, 0.11335467, 0.24410747, 0.13826019, 0.14417211]), 'param_n_neighbors': masked_array(data=[1, 3, 5, 7, 9],
             mask=[False, False, False, False, False],
       fill_value='?',
            dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}, {'n_neighbors': 7}, {'n_neighbors': 9}], 'split0_test_score': array([0.36650626, 0.35534167, 0.3653513 , 0.36496631, 0.35611165]), 'split1_test_score': array([0.36323388, 0.35630414, 0.35765159, 0.36034649, 0.3574591 ]), 'split2_test_score': array([0.3680462 , 0.35437921, 0.36284889, 0.36111646, 0.36419634]), 'split3_test_score': array([0.36169394, 0.35091434, 0.36438884, 0.36458133, 0.36188643]), 'split4_test_score': array([0.35842156, 0.34898941, 0.35380173, 0.35784408, 0.3520693 ]), 'split5_test_score': array([0.35399423, 0.34879692, 0.36111646, 0.35861405, 0.35668912]), 'split6_test_score': array([0.3599615 , 0.35303176, 0.36785371, 0.36458133, 0.35899904]), 'split7_test_score': array([0.3574591 , 0.35688162, 0.37208855, 0.37170356, 0.36053898]), 'split8_test_score': array([0.35444744, 0.34674625, 0.36176357, 0.35425491, 0.34867154]), 'split9_test_score': array([0.3600308 , 0.35252214, 0.36080092, 0.35829804, 0.35656527]), 'mean_test_score': array([0.36037949, 0.35239075, 0.36276656, 0.36163066, 0.35731868]), 'std_test_score': array([0.00441373, 0.00327258, 0.00486026, 0.0046996 , 0.00431428]), 'rank_test_score': array([3, 5, 1, 2, 4])}

学习导航:http://xqnav.top/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/69642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

One-shot就能做事件抽取?ChatGPT在信息抽取上的强大应用

One-shot就能做事件抽取&#xff1f;ChatGPT在信息抽取上的强大应用0. 前言1. 灵感2. 实验3. 结论0. 前言 近期&#xff0c;OpenAI发布的chat GPT可谓是各种刷屏&#xff0c;很多人都在关注这种模式是否可以应用于搜索引擎&#xff0c;这给做搜索的朋友们带来了很大的危机感。…

强大的VS插件DevExpress CodeRush v22.1 - 让代码编程更智能

DevExpress CodeRush是一个强大的Visual Studio .NET 插件&#xff0c;它利用整合技术&#xff0c;通过促进开发者和团队效率来提升开发者体验。为Visual Studio IDE增压、消除重复的代码并提高代码质量&#xff0c;可以快速思考、自动化测试、可视化调试和重构。 CodeRush v2…

vue学习笔记(一)-vue基础语法

视频教程&#xff1a;尚硅谷Vue2.0Vue3.0全套教程丨vuejs从入门到精通_哔哩哔哩_bilibili 相关文档&#xff1a;Vue核心 Vue简介 初识 (yuque.com) 兼容性 Vue 不支持 IE8 及以下版本&#xff0c;因为 Vue 使用了 IE8 无法模拟的 ECMAScript 5 特性。但它支持所有兼容 ECMAS…

RabbitMQ入门

1. 什么是MQ 消息队列(Message Queue&#xff0c;简称MQ)&#xff0c;从字面意思上看&#xff0c;本质是个队列&#xff0c;FIFO先入先出&#xff0c;只不过队列中存放的内容是message而已 作用&#xff1a;应用程序“对”应用程序的通信方法。 2. 应用场景 主要解决异步处理…

pixel 3xl 手机如何烧录自己编译的android 12代码

pixel 3xl 手机如何烧录自己编译的android 12代码 一.查看pixel 3xl手机支持的Android 12版本 通过浏览器访问android版本跟代号网页查看对应的pixel 3XL 手机支持的android 版本跟代号 可以看出&#xff0c;pixel 3XL手机支持Adnroid 12的有Android 12.0.0_r31, Android 12.…

华为机试 - 区间交叠问题

目录 题目描述 输入描述 输出描述 用例 题目解析 算法源码 题目描述 给定坐标轴上的一组线段&#xff0c;线段的起点和终点均为整数并且长度不小于1&#xff0c;请你从中找到最少数量的线段&#xff0c;这些线段可以覆盖柱所有线段。 输入描述 第一行输入为所有线段的数…

键盘输入保护器:KeyScrambler

创新技术屏蔽数字资产 KeyScrambler 开创性的击键加密技术可在 Windows 操作系统、所有浏览器和数百个关键应用程序中实时深入地保护用户键入的信息。 值得信赖的软件让用户安心 KeyScrambler 已经被世界各地的专家、博主和用户测试和使用了 16 年&#xff0c;并被证明对最阴险…

ANSYS_Dsigner仿真串扰

1、边沿RT的大小对串扰的影响 仿真电路如下图所示&#xff1a; V1为V_Pulse电压源&#xff0c;设置如图所示&#xff1a; A4为耦合微带线 这里一定要设置为9.6mil&#xff0c;因为介质厚度我设置的是4.8mil&#xff0c;如果没阻抗匹配会在串扰的基础上增加信号的反射&#xff…

【计算机视觉】完整版复习

计算机标定 齐次坐标 齐次坐标&#xff0c;将欧氏空间的无穷远点&#xff0c;与投影空间中有实际意义的消失点&#xff0c;建立起映射关系。 把齐次坐标转化为笛卡尔坐标的方法&#xff1a;是前面n-1个坐标分量分别除以最后一个分量即可 一些解释和性质&#xff1a; 比较好的…

idea远程debug

有时候我们需要进行远程的debug&#xff0c;本文研究如何进行远程debug&#xff0c;以及使用 IDEA 远程debug的过程中的细节。看完可以解决你的一些疑惑。 1.配置idea 如图&#xff0c;依次点击或者填写对应的ip和端口&#xff0c;需要debug的服务 2.修改启动命令 选择 jdk …

东郊到家、往约到家预约上门理疗按摩系统小程序模式讲解

东郊到家和往约到家都是做上门理疗按摩推拿等服务的线上预约平台&#xff0c;目前已经在全国很多一二线城市都开设了分站&#xff0c;今天我们就来对这两个程序进行讲解。 为什么这类上门服务平台能发展的这么迅速&#xff1f; 一是因为平台成本投入比较低&#xff0c;线上预…

微服务框架 SpringCloud微服务架构 22 DSL 查询语法 22.4 地理查询

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式&#xff0c;系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 SpringCloud微服务架构 文章目录微服务框架SpringCloud微服务架构22 DSL 查询语法22.4 地理查询22.4.1 地理查询22 DSL 查询语法 22.4 地理…

【强化学习论文】多智能体强化学习是一个序列建模问题

文献题目&#xff1a;Multi-Agent Reinforcement Learning is A Sequence Modeling Problem时间&#xff1a;2022代码&#xff1a;https://github.com/PKU-MARL/Multi-Agent-Transformer. 摘要 GPT 系列和 BERT 等大序列模型&#xff08;SM&#xff09;在自然语言处理、视觉和…

FL Studio免费升级21完整版新功能新插件介绍

万众期待的 FL Studio 21 版本正式发布上线&#xff0c;所有FL Studio的用户&#xff0c;都可以免费升级到21版&#xff01; 按照惯例&#xff0c;本次新版也会增加全新插件&#xff0c;来帮助大家更好地创作。今天先给大家分享一下&#xff0c;新增的4款插件简单介绍&#xf…

基于AT89S52单片机的蘑菇大棚环境监测系统论文(附录代码)

目 录 第1章 绪 论 1 1.1 研究背景和意义 1 1.2 国内外发展现状 2 1.3 设计内容和指标 4 第2章 系统设计方案 5 2.1 系统组成 5 2.1.1 总体结构 5 2.1.2 单片机的选型 5 2.1.3 温湿度传感器选型 6 2.1.4 二氧化碳传感器选型 6 2.1.5 PH值传感器选型 7 2.1.6 加热器选型 8 2.1.7…

HTTP协议分析 实验报告

实验名称&#xff1a; HTTP协议分析 一、实验预习 1、实验目的 利用抓包工具&#xff08;Wireshark/Windump/Sniffer&#xff09;抓取HTTP报文&#xff0c;以进一步熟悉和理解HTTP报文格式规范与HTTP协议的工作原理 2、实验内容&#xff08;…

《Linux-权限的理解、shell的理解和粘滞位》

目录 一、shell的理解 二、Linux权限 一、用户的引入 二、权限管理 一、什么是权限 二、Linux下的权限 三、视图展示 四、文件类型 五、为什么gcc编译器编译.txt后缀的文件有问题&#xff1f; 六、修改权限 一、chmod设置文件的访问权限 一、基本使用 二、八进制方案(访…

es的自动补全查询——DSL语句java代码实现

1、DSL语句 elasticsearch提供了Completion Suggester查询来实现自动补全功能。这个查询会匹配以用户输入内容开头的词条并返回。 为了提高补全查询的效率&#xff0c;对于文档中字段的类型有一些约束&#xff1a; 参与补全查询的字段必须是completion类型。 字段的内容一般…

SpringMVC的执行流程

文章目录1 初始化阶段2 匹配阶段3 执行阶段我们把整个流程分成三个阶段初始化阶段匹配阶段执行阶段 1 初始化阶段 在 Web 容器第一次用到 DispatcherServlet 的时候&#xff0c;会创建其对象并执行 init 方法 init 方法内会创建 Spring Web 容器&#xff0c;并调用容器 refre…

阿里十年技术沉淀|深度解析百PB级数据总线技术

云原生场景下数据总线需求场景及挑战 数据总线简介 数据总线作为大数据架构下的流量中枢&#xff0c;在不同的大数据组件之间承载着数据桥梁的作用。通过数据总线&#xff0c;可以实时接入来自服务器、K8s、APP、Web、IoT/移动端等产生的各类异构数据&#xff0c;进行统一数据…