【机器学习-08】参数调优宝典:网格搜索与贝叶斯搜索等攻略

news2024/12/25 1:08:59

超参数是估计器的参数中不能通过学习得到的参数。在scikit-learn中,他们作为参数传递给估计器不同类的构造函数。典型的例子有支持向量分类器的参数C,kernel和gamma,Lasso的参数alpha等。

​ 在超参数集中搜索以获得最佳cross validation交叉验证分数的方法是可实现并且推荐的。

​ 当构建一个估计器是,任意参数的选取通过这种方式可能会获得最佳参数。尤其是,要想知道参数名字和其对应的当前值,使用:

estimator.get_params() 

​ 搜索包括:

  • 一个估计器(回归或分类,如sklearn.svm.SVC());
  • 一个参数空间;
  • 一个搜索的方法或可选参数集;
  • 一个交叉验证的方案;
  • 一个评分函数。

在scikit-learn中,超参数优化(optimization)或调优(tuning)是为学习算法选择一组最优超参数(Hyper-parameter),找到全局最小值;的方法一般包括以下四种:
在这里插入图片描述

  • 传统或手动调参
  • 网格搜索
  • 随机搜索
  • 贝叶斯搜索

一、 传统或手动调参

机器学习模型调优算法中的“传统或手动调参”是一种重要的优化方法,它主要依赖于模型训练者的经验和判断,对模型的超参数进行调整以达到最佳性能。以下是关于传统或手动调参的详细介绍:

原理:

手动调参的原理在于通过调整模型的超参数,来权衡模型的偏差和方差,从而提高模型的预测性能。超参数是机器学习算法在训练之前需要设定的参数,它们对模型的训练过程和结果有重要影响。

步骤:

  • 理解数据和模型: 首先,模型训练者需要深入了解数据的特性以及所使用的- 机器学习模型的原理。这有助于确定哪些超参数可能对模型的性能有重要影响。
  • 选择初始参数: 可以使用算法的默认参数作为起点,或者根据经验选择一组初始参数进行训练。
  • 调整重要参数: 根据模型的训练效果和性能评估指标(如准确率、召回率、F1分数等),逐步调整对模型性能影响较大的超参数。
  • 评估和调整: 使用验证集或交叉验证来评估不同超参数组合下的模型性能,并根据评估结果调整超参数。
  • **迭代优化:**重复上述步骤,直到找到一组较优的超参数组合,使得模型在验证集上的性能达到最佳。

优点:

  • 灵活性: 手动调参可以根据具体的数据和模型特点进行灵活调整,没有固定的规则限制。
  • 可解释性: 由于调参过程是基于人的经验和判断,因此可以更容易地理解和解释模型的性能变化。

缺点:

  • 依赖经验: 手动调参的效果很大程度上取决于模型训练者的经验和技能水平,对于初学者来说可能难以掌握。
  • 耗时费力: 手动调参通常需要尝试多种不同的超参数组合,并进行多次训练和评估,因此过程可能比较耗时和繁琐。
  • 容易陷入局部最优: 由于手动调参通常是基于当前状态的局部调整,因此可能容易陷入局部最优解,而非全局最优解。

让我们看看如下代码:

#importing required libraries
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold , cross_val_score
from sklearn.datasets import load_wine

wine = load_wine()
X = wine.data
y = wine.target

#splitting the data into train and test set
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = 0.3,random_state = 14)

#declaring parameters grid
k_value = list(range(2,11))
algorithm = ['auto','ball_tree','kd_tree','brute']
scores = []
best_comb = []
kfold = KFold(n_splits=5)

#hyperparameter tunning
for algo in algorithm:
  for k in k_value:
    knn = KNeighborsClassifier(n_neighbors=k,algorithm=algo)
    results = cross_val_score(knn,X_train,y_train,cv = kfold)

    print(f'Score:{round(results.mean(),4)} with algo = {algo} , K = {k}')
    scores.append(results.mean())
    best_comb.append((k,algo))

best_param = best_comb[scores.index(max(scores))]
print(f'\nThe Best Score : {max(scores)}')
print(f"['algorithm': {best_param[1]} ,'n_neighbors': {best_param[0]}]")

二、网格搜索

机器学习模型调优算法中的“网格搜索”是一种自动化调参方法,它通过穷举搜索的方式,在指定的参数空间内寻找最优的超参数组合,从而优化模型的性能。以下是对网格搜索的详细介绍:

原理:

网格搜索的原理是基于暴力穷尽搜索,它会在指定的参数范围内,按照设定的步长,生成一个参数网格。然后,网格搜索会遍历这个参数网格中的每一组参数组合,对模型进行训练和评估,以找到在验证集上性能最佳的超参数组合。

步骤:

  • 确定参数范围和步长: 首先,需要确定要搜索的超参数的范围和步长。这通常基于经验、算法的要求或数据的特性来确定。
  • 生成参数网格: 根据确定的参数范围和步长,生成一个参数网格,其中包含了所有可能的参数组合。
  • 遍历参数网格: 对于参数网格中的每一组参数组合,使用这些参数来训练模型,并在验证集上评估模型的性能。
  • 记录最佳参数组合: 在遍历过程中,记录每一组参数对应的模型性能评估结果,并找到性能最佳的一组参数。
  • 使用最佳参数组合: 最后,使用找到的最佳参数组合来重新训练模型,并在测试集上评估其性能。

优点:

  • 简单直观: 网格搜索是一种简单直观的调参方法,易于理解和实现。
  • 全局搜索: 通过遍历整个参数网格,网格搜索能够在全局范围内寻找最优参数组合,避免陷入局部最优。

缺点:

  • 计算成本高: 当参数空间较大或参数范围较广时,网格搜索需要遍历的参数组合数量会急剧增加,导致计算成本非常高昂,可能需要大量的时间和计算资源。
  • 可能不是最优解: 由于网格搜索是基于固定步长的搜索,可能无法找到真正的全局最优解,特别是当最优解位于两个搜索点之间时。
  • 缺乏灵活性: 网格搜索无法根据模型性能的变化动态调整搜索范围和步长,因此对于一些复杂的问题可能不太适用。

让我们来了解一下 sklearn 的 GridSearchCV 是如何工作的,

from sklearn.model_selection import GridSearchCV

knn = KNeighborsClassifier()
grid_param = { 'n_neighbors' : list(range(2,11)) , 
              'algorithm' : ['auto','ball_tree','kd_tree','brute'] }
              
grid = GridSearchCV(knn,grid_param,cv = 5)
grid.fit(X_train,y_train)

#best parameter combination
grid.best_params_

#Score achieved with best parameter combination
grid.best_score_

#all combinations of hyperparameters
grid.cv_results_['params']

#average scores of cross-validation
grid.cv_results_['mean_test_score']

三、随机搜索

机器学习模型调优算法中的“随机搜索”是一种高效且灵活的参数调优方法。它通过随机采样超参数空间中的参数组合,来寻找最优的超参数配置,从而优化模型的性能。以下是关于随机搜索的详细介绍:

原理:

随机搜索的原理在于,通过随机采样的方式,从整个超参数空间中挑选出一定数量的参数组合进行评估。与网格搜索不同,随机搜索并不遍历所有可能的参数组合,而是根据指定的采样策略(如均匀分布、正态分布等)在参数空间中进行随机抽样。这种方式可以在有限的计算资源下,快速找到性能较好的超参数组合。

步骤:

  • 定义参数空间: 首先,需要确定要搜索的超参数及其取值范围。这些超参数可以是学习率、正则化系数、树的深度等,具体取决于所使用的机器学习算法。
  • 设置采样策略: 选择合适的采样策略,确定如何从参数空间中随机抽取参数组合。常见的采样策略包括均匀采样、正态分布采样等。
  • 随机采样与评估: 根据采样策略,在参数空间中随机抽取一定数量的参数组合,并使用这些参数组合来训练模型。然后,在验证集上评估模型的性能,并记录每个参数组合对应的性能指标。
  • 选择最佳参数组合: 从所有评估过的参数组合中,选择性能最佳的一组作为最优超参数组合。
  • 重新训练与评估: 使用最优超参数组合重新训练模型,并在测试集上评估其性能。

优点:

  • 计算效率高: 由于随机搜索只需要评估一部分参数组合,因此相对于网格搜索来说,计算成本更低,能够在有限的时间内找到较好的超参数组合。
  • 灵活性强: 随机搜索可以根据问题的特点和计算资源的情况,灵活调整采样策略和采样数量,以适应不同的应用场景。
  • 全局搜索能力: 通过随机采样,随机搜索能够在一定程度上避免陷入局部最优解,具有更好的全局搜索能力。

缺点:

  • 结果不稳定: 由于随机搜索是基于随机采样的,因此每次运行的结果可能会有所不同,存在一定的不确定性。
  • 可能错过最优解: 虽然随机搜索可以在全局范围内搜索超参数组合,但由于其随机性,有时可能会错过真正的最优解。

让我们了解一下 sklearn 的 RandomizedSearchCV 是如何工作的,

from sklearn.model_selection import RandomizedSearchCV

knn = KNeighborsClassifier()

grid_param = { 'n_neighbors' : list(range(2,11)) , 
              'algorithm' : ['auto','ball_tree','kd_tree','brute'] }

rand_ser = RandomizedSearchCV(knn,grid_param,n_iter=10)
rand_ser.fit(X_train,y_train)

#best parameter combination
rand_ser.best_params_

#score achieved with best parameter combination
rand_ser.best_score_

#all combinations of hyperparameters
rand_ser.cv_results_['params']

#average scores of cross-validation
rand_ser.cv_results_['mean_test_score']

四、贝叶斯搜索

之前在博客中也介绍过贝叶斯搜索和python实现,具体参考本人博客贝叶斯优化(Bayesian Optimization)介绍和python实现

机器学习模型调优算法中的“贝叶斯搜索”是一种基于贝叶斯定理的自动化调参方法,它通过不断地更新目标函数的后验分布来寻找最优的超参数组合。以下是对贝叶斯搜索的详细介绍:

原理:

贝叶斯搜索的主要思想是,在给定优化的目标函数(广义的函数,只需指定输入和输出即可,无需知道内部结构以及数学性质)的情况下,通过不断地添加样本点来更新目标函数的后验分布(通常是高斯过程),直到后验分布基本贴合于真实分布。每一次添加样本点都会考虑上一次参数的信息,以便更好地调整当前的参数。贝叶斯搜索的核心过程包括先验函数(Prior Function, PF)与采集函数(Acquisition Function, AC)。PF主要利用高斯过程回归(或其他PF函数)来建模目标函数的分布,而AC则用于确定下一个采样点,以最大化信息增益或最小化预期损失。

步骤:

  • 定义先验分布: 基于问题的特性和经验知识,为超参数定义先验分布。
  • 采集函数与样本选择: 使用采集函数(如EI、PI、UCB等)来评估不同超参数组合的潜在价值,并选择下一个要评估的样本点。
  • 模型评估与更新: 在选定的超参数组合下训练模型,并在验证集上评估其性能。然后,使用这些评估结果来更新目标函数的后验分布。
  • 迭代优化: 重复步骤2和3,直到满足停止条件(如达到最大迭代次数、性能提升不再显著等)。
  • 选择最优参数: 从更新后的后验分布中选择性能最优的超参数组合。

优点:

  • 高效性: 贝叶斯搜索能够平衡探索(在全局尚未探索的区域寻找更好的解)和利用(利用已知信息优化当前解),从而在较少的迭代次数内找到较好的超参数组合。
  • 灵活性: 贝叶斯搜索可以适应各种复杂的目标函数和参数空间,对于非线性、非凸等问题也能表现出良好的性能。
  • 自适应性: 随着迭代次数的增加,贝叶斯搜索会逐渐聚焦于性能更好的参数区域,从而提高搜索效率。

缺点:

  • 计算成本: 尽管贝叶斯搜索相对于网格搜索和随机搜索更为高效,但在处理高维参数空间或需要精确建模的情况下,计算成本仍然可能较高。
  • 实现复杂性: 贝叶斯搜索的实现相对复杂,需要深入理解贝叶斯定理和高斯过程等概念,以及如何选择和使用合适的采集函数。
  • 对初始化的敏感性: 贝叶斯搜索的性能可能受到初始化参数和先验分布的影响,不合理的初始化可能导致搜索过程陷入局部最优解。

让我们用 scikit-optimize 的BayesSearchCV来理解这一点
安装: pip install scikit-optimize
实现贝叶斯搜索的另一个类似的库是 bayesian-optimization

from skopt import BayesSearchCV

import warnings
warnings.filterwarnings("ignore")

# parameter ranges are specified by one of below
from skopt.space import Real, Categorical, Integer

knn = KNeighborsClassifier()
#defining hyper-parameter grid
grid_param = { 'n_neighbors' : list(range(2,11)) , 
              'algorithm' : ['auto','ball_tree','kd_tree','brute'] }

#initializing Bayesian Search
Bayes = BayesSearchCV(knn , grid_param , n_iter=30 , random_state=14)
Bayes.fit(X_train,y_train)

#best parameter combination
Bayes.best_params_

#score achieved with best parameter combination
Bayes.best_score_

#all combinations of hyperparameters
Bayes.cv_results_['params']

#average scores of cross-validation
Bayes.cv_results_['mean_test_score']

五、总结

在机器学习模型调优的过程中,找到参数的最佳组合与所需的计算时间之间始终存在一个权衡。当面对超参数空间庞大、维度众多时,选择适当的优化方式显得尤为重要。以下是针对网格搜索、随机搜索、手动调参和贝叶斯搜索这四种优化方式的总结:

  • 网格搜索提供了一种全面而系统的搜索方法,通过遍历所有可能的参数组合来找到最优解。然而,当参数空间较大时,网格搜索的计算成本会急剧增加,可能导致优化过程耗时过长。

  • 随机搜索则通过随机抽样来减少计算量,同时保持一定的全局搜索能力。它能够在有限的计算资源下快速找到性能较好的参数组合,尤其适用于超参数空间较大或计算资源有限的情况。

  • 手动调参依赖于模型训练者的经验和判断,虽然灵活性较高,但耗时费力且容易陷入局部最优。它通常作为其他自动化调参方法的补充,用于对特定参数进行微调。

  • 贝叶斯搜索则利用贝叶斯定理和高斯过程来不断更新目标函数的后验分布,以找到最优的超参数组合。它在平衡探索和利用方面表现出色,能够在较少的迭代次数内找到较好的解。然而,贝叶斯搜索的实现相对复杂,需要深入理解相关概念和技术。

在实际应用中,我们可以根据问题的特性和计算资源的情况选择合适的优化方式。对于超参数空间较大的情况,可以先使用随机搜索快速找到潜在的参数组合,然后对这些组合进行局部的网格搜索以选择最优特征。这样既能保证一定的全局搜索能力,又能减少计算成本。同时,结合手动调参对特定参数进行微调,可以进一步提高模型的性能。最终,在找到参数的最佳组合的保证和计算时间之间取得一个合理的权衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1551101.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java开发过程中如何进行进制换换

最近由于工作上的需要,遇到进制转换的问题。涉及到的进制主要是十进制、十六进制、二进制转换。 1、十进制转十六进制、二进制 调用java自带的api,测试十进制转16进制、2进制 package com.kangning.common.utils.reflect;/*** 十进制 转 十六进制* 十进制 转 二进…

黑群晖Docker安装aria2-pro

前言 最近买了星际蜗牛C款当Nas,来满足我的存储需求,在之前我写过一篇docker安装aria2-pro的文章,既然买了nas那当然也要安装一个aria2-pro做下载器 1.安装 Container Manager 套件 可以在套件中心搜索docker找到 2.下载aria2-pro镜像 打…

力扣热门算法题 89. 格雷编码,92. 反转链表 II,93. 复原 IP 地址

89. 格雷编码,92. 反转链表 II,93. 复原 IP 地址,每题做详细思路梳理,配套Python&Java双语代码, 2024.03.24 可通过leetcode所有测试用例。 目录 89. 格雷编码 解题思路 完整代码 Python Java 92. 反转链表…

利用Tensor在jetson orin 上加速YOLOv5

一、第一种方法,需要下载各种包: 要用到一个大佬的开源,GitHub地址如下: https://github.com/wang-xinyu/tensorrtx/tree/master/yolov51. 安装pycuda,在线安装pycuda pip3 install pycuda 2. Windows操作&#xf…

Ubuntu Desktop 更改默认应用程序 (Videos -> SMPlayer)

Ubuntu Desktop 更改默认应用程序 [Videos -> SMPlayer] References System Settings -> Details -> Default Applications 概况、默认应用程序、可移动介质、法律声明 默认应用程序,窗口右侧列出了网络、邮件、日历、音乐、视频、照片操作的默认应用程序…

2024全行业数字化转型企业建设解决方案PPT合集(附下载)

精品推荐,2024全行业数字化转型企业建设解决方案PPT合集,精品PPT源格式共21份。 点击直达星球下载地址(文末领取优惠券):2024全行业数字化转型企业建设解决方案PPT合集 1.制造业数字化转型解决方案及应用.pptx 2.医院…

Java代码基础算法练习-求一个三位数的各位数字之和-2024.03.27

任务描述&#xff1a; 输入一个正整数n&#xff08;取值范围&#xff1a;100<n<1000&#xff09;&#xff0c;然后输出每位数字之和 任务要求&#xff1a; 代码示例&#xff1a; package M0317_0331;import java.util.Scanner;public class m240327 {public static voi…

langchin-chatchat部分开发笔记(持续更新)

大模型相关目录 大模型&#xff0c;包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步&#xff0c;扬帆起航。 大模型应用向开发路径及一点个人思考大模型应用开发实用开源项目汇总大模型问答项目…

Matlab进阶绘图第47期—气泡分组蝴蝶图

气泡分组蝴蝶图是分组蝴蝶图与气泡图的组合——在分组蝴蝶图每组柱子上方添加大小不同的气泡&#xff0c;用于表示另外一个数据变量&#xff08;如每组柱子值的和&#xff09;的大小。 本文利用自己制作的BubbleButterfly工具&#xff0c;进行气泡分组蝴蝶图的绘制&#xff0c…

从接口发现到文件上传getshell

0x01 信息收集 通过fofa&#xff0c;子域名收集等相关工具搜索域名 定位到站点&#xff1a;htps://xx..edu.cn/x/xx/ 0x02 寻找接口 通过f12寻找相关的js&#xff0c;发现有其他的页面 0x03 拼接路径 https://xx.xx.edu.cn/xx/xx/repairResgister 之后未授权获取到注册用户的页…

【群晖】解决docker容器启动出现 database is locked 错误

【群晖】解决docker容器启动出现 database is locked 错误 问题描述 升级DSM 7.2 V3版本后docker中的大量容器出现虽然显示启动状态&#xff0c;但是webStation中服务是禁用中。 因此选择手动重启容器&#xff0c;但是发现容器无法启动&#xff0c;提示了以下错误&#xff1…

盏多多生物现已加入2024第七届燕窝天然滋补品展

参展企业介绍 广东省盏多多生物科技有限公司是一家从事食品销售,食品销售,食品进出口等业务的公司&#xff0c;成立于2018年12月07日&#xff0c;公司坐落在广东省&#xff0c;详细地址为&#xff1a;惠州市东江三路45号悦榕湾27层05号&#xff08;仅限办公&#xff09;;经国家…

课堂练习:环境体验——Linux 文件操作命令

任务描述 第二个任务就是了解Linxu的文件查看命令&#xff0c;文件编辑基本命令。 相关知识 为了完成本关任务&#xff0c;你需要掌握&#xff1a; 1.文件查看命令。 2.文件编辑基本命令。 文件查看命令 我们要查看一些文本文件的内容时&#xff0c;要使用文本编辑器来查看…

解决NRF52832正常添加OTA代码后无法进入app一直运行在bootloader的问题!

问题现象描述&#xff1a; SDK版本17.1.0 在 mergehex工具 合并以下文件setting.hex bootloader.hex app.hex sortdevice.hex 之后烧录固件第一次运行 程序一直运行在bootloader&#xff0c;蓝牙名称显示 DFUTARG &#xff0c;必须要进行一次OTA才进入APP 注意&#xff1a;如…

补单系统平台第三方接口,电商平台数据市场接口api提供

补单系统平台第三方接口&#xff0c;电商平台数据市场接口api提供 部分数据参数

电缆故障测试仪的原理和组成部件分别是什么?

电缆故障测试仪是专为检测电缆线路中的各种故障而设计制造的精密电子设备&#xff0c;广泛应用于电力、通信、石油化工、航空航天等领域。这类仪器的工作原理和组成相对复杂&#xff0c;下面将详细阐述。 电缆故障测试仪的工作原理 电缆故障测试仪的核心原理通常涉及电磁波反…

STM32最小核心板使用HAL库ADC读取MCU温度(使用DMA通道)

STM32自带CPU的温度数据&#xff0c;需要使用ADC去读取。因此在MX创建项目时如图配置&#xff1a; 模块初始化代码如下&#xff1a; void MX_ADC1_Init(void) {/* USER CODE BEGIN ADC1_Init 0 *//* USER CODE END ADC1_Init 0 */ADC_ChannelConfTypeDef sConfig {0};/* USER…

验证码demo(简单实现)

前言 我们注意到我们登录网站的时候经常会用到网络验证码,今天我们就简单实现一个验证码的前后端交互问题,做一个小demo 准备 我们这里并不需要依靠原生的java来实现,而是只需要引入一个maven依赖,使用现成的封装好的即可,这是我使用的是hutool工具包 网址:Hutool&#x1f36c;…

Linux系统-----------MySQL 数据类型

目录 MySQL 数据类型 一、数值类型 二、日期和时间类型 三、字符串类型 &#xff08;1&#xff09;CHAR类型 &#xff08;2&#xff09;VARCHAR类型 &#xff08;3&#xff09;CHAR和VARACHAR的比较及其应用场景 MySQL 数据类型 MySQL 中定义数据字段的类型对你数据库的…

Nginx超详细讲解+实操

前言 nginx作为当今火爆的、高性能的http及反向代理服务&#xff0c;不管前端还是后端&#xff0c;都需要全面去了解&#xff0c;学习&#xff0c;实操。 nginx 介绍 为了有一个全面的认知&#xff0c;接下来我们先来看看nginx的架构以及一些特点。 nginx 特点 处理响应请…