机器学习 sklearn 中的超参数搜索方法

news2024/11/24 13:20:07

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 超参数搜索
    • 默认参数
    • GridSearchCV
    • RandomizedSearchCV
    • HalvingGridSearchCV
    • HalvingRandomSearchCV


超参数搜索

在建模时模型的超参数往往会对精度造成一定影响,而设置和调整超参数的取值,往往称为调参

在实践中调参往往依赖人工来进行设置调整范围,然后使用机器在超参数范围内进行搜索,找到最优的超参数组合。

在 sklearn 中,提供了四种超参数搜索方法:

  • GridSearchCV
  • RandomizedSearchCV
  • HalvingGridSearchCV
  • HalvingRandomSearchCV

默认参数

为了方便起见,我们先定义一个默认参数的模型,用于后续的超参数搜索。

# 导入相关库
import random
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import (
    train_test_split,
    GridSearchCV,
    RandomizedSearchCV,
    HalvingGridSearchCV,
    HalvingRandomSearchCV,
)
from sklearn.ensemble import RandomForestRegressor

# 设置随机种子
seed = 1
random.seed(seed)
np.random.seed(seed)

# 加载数据集
data = datasets.load_diabetes()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)

# 使用默认参数模型进行分类并评分
reg = RandomForestRegressor(random_state=seed)
reg.fit(X_train, y_train)
print(round(reg.score(X_test, y_test), 6))

最终默认参数的模型在测试集上的 R 2 R^2 R2 分数约为 0.269413

GridSearchCV

GridSearchCV 是一种网格搜索超参数的方法,它会遍历所有的超参数组合,然后评估模型的性能,最终选择性能最好的一组超参数。

# 设置超参数搜索范围
param_grid = {
    "max_depth": [2, 4, 5, 6, 7],
    "min_samples_leaf": [1, 2, 3],
    "min_weight_fraction_leaf": [0, 0.1],
    "min_impurity_decrease": [0, 0.1, 0.2]
}

# 使用 GridSearchCV 进行超参数搜索
reg = GridSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

在这个超参数搜索空间中,一共有 5 * 3 * 2 * 3 = 90 种超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{
    'max_depth': 4,
    'min_impurity_decrease': 0.2,
    'min_samples_leaf': 1,
    'min_weight_fraction_leaf': 0
}

RandomizedSearchCV

RandomizedSearchCV 是一种随机搜索超参数的方法,它的使用方法与 GridSearchCV 类似,但是它不会遍历所有的超参数组合,而是在超参数的取值范围内随机选择一组超参数进行训练,然后评估模型的性能,最终选择性能最好的一组超参数。

# 使用 RandomizedSearchCV 进行超参数搜索
reg = RandomizedSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    n_iter=20,  # 设置迭代次数
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

RandomizedSearchCV 一共进行了 20 次迭代,即尝试了 20 组超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{
    'min_weight_fraction_leaf': 0,
    'min_samples_leaf': 1,
    'min_impurity_decrease': 0.1,
    'max_depth': 6
}

HalvingGridSearchCV

HalvingGridSearchCVGridSearchCV 类似,但在迭代的过程中采用减半超参数搜索空间的方法,以此来减少超参数搜索的时间。

在搜索的最开始,HalvingGridSearchCV 使用很少的数据样本来在完整的超参数搜索空间中进行搜索,筛选其中最优的超参数,之后再增加数据进行进一步筛选。

# 使用 HalvingGridSearchCV 进行超参数搜索
reg = HalvingGridSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{
    'max_depth': 4,
    'min_impurity_decrease': 0.2,
    'min_samples_leaf': 1,
    'min_weight_fraction_leaf': 0
}

可以看到,HalvingGridSearchCV 得到的最优超参数组合与 GridSearchCV 得到的最优超参数组合相同。

HalvingRandomSearchCV

HalvingRandomSearchCVHalvingGridSearchCV 类似,都是逐步增加样本数量,减少超参数组合,但是 HalvingRandomSearchCV 每次生成的超参数组合是随机的。

# 使用 HalvingRandomSearchCV 进行超参数搜索
reg = HalvingRandomSearchCV(
    RandomForestRegressor(random_state=seed),
    param_grid,
    cv=5,
    n_jobs=-1,
    random_state=seed,
)
reg.fit(X_train, y_train)

# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))

# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{
    'min_weight_fraction_leaf': 0,
    'min_samples_leaf': 1,
    'min_impurity_decrease': 0.1,
    'max_depth': 6
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1292674.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软著项目推荐 深度学习实现语义分割算法系统 - 机器视觉

文章目录 1 前言2 概念介绍2.1 什么是图像语义分割 3 条件随机场的深度学习模型3\. 1 多尺度特征融合 4 语义分割开发过程4.1 建立4.2 下载CamVid数据集4.3 加载CamVid图像4.4 加载CamVid像素标签图像 5 PyTorch 实现语义分割5.1 数据集准备5.2 训练基准模型5.3 损失函数5.4 归…

探索SpringBoot发展历程

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: 循序渐进学SpringBoot ✨特色专栏&…

《消息队列MyMQ》——参考RabbitMQ实现

目录 一、什么是消息队列? 二、需求分析 1)核心概念 2)核心API 3)交换机类型 4)持久化 5)网络通信 ​编辑 6)消息应答 三、 模块划分 四、创建核心类 1.ExChange 2.MSGQueue 3.Bind…

系列学习前端之第 4 章:一文精通 JavaScript

全套学习 HTMLCSSJavaScript 代码和笔记请下载网盘的资料&#xff1a; 链接: 百度网盘 请输入提取码 提取码: 6666 1、JavaScript 格式 一般放在 html 的 <head> 标签中。type&#xff1a;默认值text/javascript可以不写&#xff0c;不写也是这个值。 <script typ…

C++新经典模板与泛型编程:用成员函数重载实现std::is_class

用成员函数重载实现is_class std::is_class功能&#xff0c;是一个C11标准中用于判断某个类型是否为一个类类型&#xff08;但不是联合类型&#xff09;的类模板。当时在讲解的时候并没有涉及std::is_class的实现代码&#xff0c;在这里实现一下。简单地书写一个IsClass类模板…

【微服务】springboot整合quartz使用详解

目录 一、前言 二、quartz介绍 2.1 quartz概述 2.2 quartz优缺点 2.3 quartz核心概念 2.3.1 Scheduler 2.3.2 Trigger 2.3.3 Job 2.3.4 JobDetail 2.4 Quartz作业存储类型 2.5 适用场景 三、Cron表达式 3.1 Cron表达式语法 3.2 Cron表达式各元素说明 3.3 Cron表达…

从文字到使用,一文读懂Kafka服务使用

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

【改进YOLOv8】融合高效网络架构 CloAtt的焊缝识别系统

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义 近年来&#xff0c;随着人工智能技术的快速发展&#xff0c;计算机视觉领域取得了显著的进展。其中&#xff0c;目标检测是计算机视觉领域的一个重要研究方向&am…

“消费增值:改变你的购物方式,让每一笔消费都变得更有价值“

你是否厌倦了仅仅购买物品或享受服务后便一无所有的消费方式&#xff1f;现在&#xff0c;消费增值的概念将彻底改变你的消费观念&#xff01;通过参与消费增值&#xff0c;你的每一笔消费都将变得更有价值&#xff01; 消费增值是一种全新的消费理念&#xff0c;它让你在购物的…

[BJDCTF2020]EzPHP 许多的特性

这道题可以学到很多东西 静下心来慢慢通过本地知道是干嘛用的就可以学会了 BJDctf2020 Ezphp_[bjdctf2020]ezphp-CSDN博客 这里开始 一部分一部分看 $_SERVER[QUERY_SRING]的漏洞 if($_SERVER) { if (preg_match(/shana|debu|aqua|cute|arg|code|flag|system|exec|passwd|…

SSM项目实战-前端-添加分页控件-调正页面布局

1、Index.vue <template><div class"common-layout"><el-container><el-header><el-row><el-col :span"24"><el-button type"primary" plain click"toAdd">新增</el-button></el-…

三坐标测量机如何精确测量产品的高度差?

三坐标测量机通过测量物体的三维坐标来实现精确的尺寸测量&#xff0c;不仅直观且又方便&#xff0c;测量结果精度高&#xff0c;并且重复性好。 三坐标测量机基于三个坐标轴&#xff1a;X轴、Y轴和Z轴&#xff0c;通过控制测针在三个方向上的移动来实现测量。而在测量产品高度…

1.10 C语言之外部变量与作用域

1.10 C语言之外部变量与作用域 一、外部变量概述二、练习 一、外部变量概述 我们说&#xff0c;函数&#xff08;不管是main函数还是其他函数&#xff09;内部定义的变量&#xff0c;其作用范围都只在函数内部&#xff0c;我们把这些变量叫做自动变量或者局部变量。除了局部变…

ingress介绍和ingress通过LoadBalancer暴露服务配置

目录 一.ingress基本原理介绍 1.将原有用于暴露服务和负载均衡的服务的三四层负载均衡变为一个七层负载均衡 2.controller和ingress 3.通过下面这个图可能会有更直观的理解 二.为什么会出现ingress 1.NodePort存在缺点 2.LoadBalancer存在缺点 三.ingress三种暴露服务的…

docker-compose部署sonarqube 8.9 版本

官方部署文档注意需求版本 所以选择8.9版本 一、准备部署配置 1、持久化目录 rootlocalhost:/root# mkdir -p /data/sonar/postgres /data/sonar/sonarqube/data /data/sonar/sonarqube/logs /data/sonar/sonarqube/extensions rootlocalhost:/root# chmod 777 /data/sona…

数据结构 | 查漏补缺之求叶子结点,分离链接法、最小生成树、DFS、BFS

求叶子结点的个数 参考博文&#xff1a; 树中的叶子结点的个数 计算方法_求树的叶子节点个数-CSDN博客 分离链接法 参考博文 数据结构和算法——哈希查找冲突处理方法&#xff08;开放地址法-线性探测、平方探测、双散列探测、再散列&#xff0c;分离链接法&#xff09;_线性…

用户案例|Milvus 助力 Credal.AI 实现 GenAI 安全与可控

AIGC 时代&#xff0c;企业流程中是否整合人工智能&#xff08;AI&#xff09;对于的企业竞争力至关重要。然而&#xff0c;随着 AI 不断发展演进&#xff0c;企业也在此过程中面临数据安全管理、访问权限、数据隐私等方面的挑战。 为了更好地解决上述问题&#xff0c;Credal.A…

双目光波导AR眼镜_AR智能眼镜主板PCB定制开发

AR眼镜方案的未来发展潜力非常巨大。随着技术的进步&#xff0c;AR眼镜的光学模块将变得更小巧&#xff0c;像素密度也会增加&#xff0c;实现更高分辨率的画面&#xff0c;甚至能够达到1080P、2K和4K级别的清晰度&#xff0c;从而提升用户的视觉体验。 AR智能眼镜的硬件方面&a…

【Android Studio】【入门】helloworld和工程的各个文件的作用

这里写目录标题 可以开发的app类型注意点 搞一个helloworld玩玩各个文件的作用 可以开发的app类型 Phone and Tablet&#xff1a;开发手机和平板的app&#xff1b;Wear OS&#xff1a;穿戴系统&#xff1b;TV&#xff1a;电视app&#xff1b;Android Auto&#xff1a;汽车上的…

了解Linux网络配置

本章主要介绍网络配置的方法。 网络基础知识 查看网络信息 图形化界面修改 通过配置文件修改 命令行管理 11.1 网络基础知识 一台主机需要配置必要的网络信息&#xff0c;才可以连接到互联网。需要的配置网络信息包括IP、 子网掩码、网关和 DNS。 11.1.1 IP 地址 在计算机…