【机器学习】分类算法 - 模型选择与调优GridSearchCV(网格搜索)

news2024/12/22 18:28:56

「作者主页」:士别三日wyx
「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者
「推荐专栏」:零基础快速入门人工智能《机器学习入门到精通》

模型选择与调优

  • 1、交叉验证
  • 2、网格搜索
  • 3、模型选择与调优API
  • 4、案例演示
    • 4.1、特征集获取划分
    • 4.2、特征标准化
    • 4.3、KNN算法处理
    • 4.4、参数调优

K-近邻算法的K是指邻居的个数,「K值」不同,算法的「准确率」也不同,我们需要不断调整K值,以提高算法的准确率。在「调整」过程中,我们需要用到「交叉验证」

1、交叉验证

交叉验证(Cross-Validation)是在机器学习建立模型和验证模型「参数」时常用的方法,用于「评估」机器模型的性能指标,从而进行「模型选择」

交叉验证的「基本思想」是,把原始数据分组,一部分当做训练集,另一部分作为验证集,先用训练集对算法模型进行训练,再用验证集测试训练得到的算法模型。

比如,把数据分成四份,先用第一份数据当验证集,把后面三份的训练结果与第一份做验证;再用第二份数据当验证集,把其他三份数据的训练结果和第二份做验证;以此类推。。。

在这里插入图片描述

交叉验证常配合网格搜索一同使用。

2、网格搜索

网格搜索也叫超「参数搜索」,比如K-近邻算法的K值需要手动指定参数,这种参数就叫超参数。网格搜索通过预设几组超参数组合,每组超参数都用交叉验证进行评估,从而选出「最优」的参数组合来建立模型。

sklearn 模块 GridSearchCV 很好的实现了网格搜索,它可以自动调参,只要把参数输进去,就能给出最优的结果和参数。


3、模型选择与调优API

sklearn.model_selection.GridSearchCV( estimator,param_grid,cv)

  • estimator:需要使用的分类器
  • param_grid:需要优化的参数,字典或列表格式{ "n_neighbors": [1, 3, 5] , }
  • cv:交叉验证次数

返回值属性

  • best_params_:(dict)最佳参数
  • best_score_ :(float)最佳结果
  • best_estimator_:(estimator)最佳分类器
  • cv_results_:(dict)交叉验证结果
  • best_index_:(int)最佳参数的索引
  • n_splits_:(int)交叉验证的次数

4、案例演示

接下来,我们使用 GridSearchCV 来选择 K-近邻算法的「最佳K值」

4.1、特征集获取划分

使用 sklearn 自带的的鸢尾花「数据集」,数据集划分为60%训练,40%测试。

from sklearn import datasets
from sklearn import model_selection

# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
# x_train:训练集特征,x_test:测试集特征,y_train:训练集目标,y_test:测试集目标
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
print('训练集特征:', len(x_train))
print('测试集特征:', len(x_test))
print('训练集目标:', len(y_train))
print('测试集特征:', len(y_test))

输出:

训练集特征: 112
测试集特征: 38
训练集目标: 112
测试集特征: 38

从输出结果可以看到,训练集和测试集的比例符合预期


4.2、特征标准化

接下来,对训练集特征和测试集特征进行「标准化」处理

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing

# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
# x_train:训练集特征,x_test:测试集特征,y_train:训练集目标,y_test:测试集目标
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
# 3、特征标准化
ss = preprocessing.StandardScaler()
x_train = ss.fit_transform(x_train)
x_test = ss.fit_transform(x_test)
print(x_train)

输出:

[[-0.18295405 -0.192639    0.25280554 -0.00578113]
 [-1.02176094  0.51091214 -1.32647368 -1.30075363]
 [-0.90193138  0.97994624 -1.32647368 -1.17125638]

从输出结果可以看到,特征已经标准化。


4.3、KNN算法处理

将训练特征集和测试特征集传给KNN,并查看「准确率」

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing
from sklearn import neighbors

# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
# x_train:训练集特征,x_test:测试集特征,y_train:训练集目标,y_test:测试集目标
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
# 3、特征标准化
ss = preprocessing.StandardScaler()
x_train = ss.fit_transform(x_train)
x_test = ss.fit_transform(x_test)
# 4、KNN算法处理
knn = neighbors.KNeighborsClassifier(n_neighbors=2)
knn.fit(x_train, y_train)
print(knn.score(x_test, y_test))

输出:

0.8947368421052632

从输出结果可以看到,准确率是89%,一般般。


4.4、参数调优

将不同的K值封装成字典,传给 GridSearchCV,计算「最优」的参数。

from sklearn import datasets
from sklearn import model_selection
from sklearn import preprocessing
from sklearn import neighbors

# 1、获取数据集
iris = datasets.load_iris()
# 2、划分数据集
# x_train:训练集特征,x_test:测试集特征,y_train:训练集目标,y_test:测试集目标
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target, random_state=6)
# 3、特征标准化
ss = preprocessing.StandardScaler()
x_train = ss.fit_transform(x_train)
x_test = ss.fit_transform(x_test)
# 4、KNN算法处理
knn = neighbors.KNeighborsClassifier(n_neighbors=2)
# 5、参数调优
params = {"n_neighbors": [1, 3, 5, 7]}
knn = model_selection.GridSearchCV(knn, param_grid=params, cv=10)
knn.fit(x_train, y_train)
print('最优参数:', knn.best_params_)
print('最优准确率:', knn.best_score_)
print('最优分类器:', knn.best_estimator_)

输出:

最优参数: {'n_neighbors': 5}
最优准确率: 0.9727272727272729
最优分类器: KNeighborsClassifier()

从输出结果可以看到,最优的K值参数是5,准确率达到了97%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/801226.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IntelliJ IDEA 2023.2 最新变化

主要更新 AI Assistant 限定访问 Ultimate 在此版本中,我们为 IntelliJ IDEA 引入了一项重要补充 – AI Assistant。 AI Assistant 当前具备一组由 AI 提供支持的初始功能,提供集成式 AI 聊天,可以完成一些任务,例如自动编写文档…

在win10上安装spinal hdl完全教程(一篇文章就够了)

一 参考文章 SpinalHDL 开发环境搭建一步到位(图文版) - 极术社区 - 连接开发者与智能计算生态 (aijishu.com)https://aijishu.com/a/1060000000255643SpinalHDL(一)——环境搭建 - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/146529005

Android安卓实战项目(4)---提供给阿尔兹海默症患者的APP(源码在文末)

Android安卓实战项目(4)—提供给阿尔兹海默症患者的APP(源码在文末) 一.项目运行介绍 1.大致浏览 (1)开机界面 (2)主界面 (3)Read Instructions界面 &…

运维:Windows11操作系统安装VMware16.1.1图文教程(附下载)

目录 一、VMware 16.1.1 介绍 二、安装教程 三、下载地址 一、VMware 16.1.1 介绍 VMware 16.1.1 是一款功能非常强大虚拟化软件,它允许用户在一台计算机上创建和运行多个虚拟机(Virtual Machine)相当于拥有多台服务器。这些虚拟机可以模拟…

JUC高并发编程(二)——Synchronized关键字

文章目录 前言为什么要用Synchronized关键字 并发编程中的三个问题可见性原子性有序性 Synchronized保证三大特性使用synchronized保证可见性使用synchronized保证原子性用synchronized保证有序性 Synchronized的特征可重入特征不可中断特征 前言 synchronized 关键字&#xff…

Python爬虫时遇到SSL证书验证错误解决办法汇总

在进行Python爬虫任务时,遇到SSL证书验证错误是常见的问题之一。SSL证书验证是为了确保与服务器建立的连接是安全和可信的,但有时候可能会由于证书过期、不匹配或未受信任等原因导致验证失败。为了解决这个问题,本文将提供一些实用的解决办法…

提高业务效率:利用手机号在网状态 API 进行智能筛选

引言 随着科技的不断发展,手机已成为现代人生活中不可或缺的工具。人们通过手机完成通信、娱乐、购物等各种活动,使得手机号成为了一个重要的个人标识。对于企业而言,了解手机号的在网状态对于业务发展和客户管理至关重要。为了提高业务效率…

https和http有什么区别

https和http有什么区别 简要 区别如下: ​ https的端口是443.而http的端口是80,且二者连接方式不同;http传输时明文,而https是用ssl进行加密的,https的安全性更高;https是需要申请证书的,而h…

Linux常用命令——dpkg-statoverride命令

在线Linux命令查询工具 dpkg-statoverride Debian Linux中覆盖文件的所有权和模式 补充说明 dpkg-statoverride命令用于Debian Linux中覆盖文件的所有权和模式,让dpkg于包安装时使得文件所有权与模式失效。 语法 dpkg-statoverride(选项)选项 -add&#xff1…

深度:解密数据库的诗与远方!

‍数据智能产业创新服务媒体 ——聚焦数智 改变商业 不同于历史上的黄金和石油,数据成为了我们新的宝藏,一个驱动社会进步、催生创新的无尽源泉。然而,这些形式各异、复杂纷繁的数据需要一个管理者,一个保险库,一个解…

【动态规划part09】| 198.打家劫舍、213.打家劫舍II、337.打家劫舍III

🎈LeetCode198.打家劫舍 链接:198.打家劫舍 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚上被小偷…

【数据结构】实验三:链表

实验三链表 一、实验目的与要求 1)熟悉链表的类型定义; 2)熟悉链表的基本操作; 3)灵活应用链表解决具体应用问题。 二、实验内容 1)请设计一个单链表的存储结构,并实现单链表中基本运算算…

基于ssm+mysql+jsp高校疫情防控出入信息管理系统

基于ssmmysqljsp高校疫情防控出入信息管理系统 一、系统介绍二、功能展示1.登陆2.教师管理3.学生管理4.打卡记录管理5.学生申请通行证6.通行证管理7.留言信息管理8.公告类型管理9.公告管理 四、获取源码 一、系统介绍 学生 : 个人中心、打卡记录管理、学生申请通行证、通行证管…

Java 8 Stream流:代码简洁之道

文章目录 前言一、filter二、map三、mapToInt、mapToLong、mapToDouble四、flatMap五、flatMapToInt、flatMapToLong、flatMapToDouble六、distinct七、sorted八、peek九、limit十、forEach十一、forEachOrdered十二、toArray十三、reduce十四、collect十五、min、max十六、cou…

mysql(二) 索引-基础知识

继续整理复习、我以我的理解和认知来整理 "索引" 会通过 文 和 图 来展示。 文: 基本概念知识(mysql 的索引分类、实现原理) 图: 画B树等 MySQL官方对索引的定义是:索引(Index)是帮…

记录--虚拟 DOM 和实际 DOM 有何不同?

这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 前言 本文我们会先聊聊 DOM 的一些缺陷,然后在此基础上介绍虚拟 DOM 是如何解决这些缺陷的,最后再站在双缓存和 MVC 的视角来聊聊虚拟 DOM。理解了这些会让你对目前的前端框架有…

第四章 HL7 架构和可用工具 - 查看数据结构

文章目录 第四章 HL7 架构和可用工具 - 查看数据结构查看数据结构查看代码表使用自定义架构编辑器 第四章 HL7 架构和可用工具 - 查看数据结构 查看数据结构 当单击“数据结构”列中的名称时,InterSystems 会显示该数据结构中的所有字段。这是 HL7 数据结构页面。…

影视行业案例 | 燕千云助力大地影院集团搭建智能一体化IT服务管理平台

影视行业过去三年受新冠肺炎疫情影响,经历了一定程度的冲击和调整,但也展现出了强大的韧性和潜力。2023年中国影视产业规模可能达到2600亿元左右,同比增长11%左右。影视行业的发展趋势主要表现在内容创新、模式创新和产业融合三个方面&#x…

第八章:将自下而上、自上而下和平滑性线索结合起来进行弱监督图像分割

0.摘要 本文解决了弱监督语义图像分割的问题。我们的目标是在仅给出与训练图像关联的图像级别对象标签的情况下,为新图像中的每个像素标记类别。我们的问题陈述与常见的语义分割有所不同,常规的语义分割假设在训练中可用像素级注释。我们提出了一种新颖的…

PSP - MMseqs2 编译最新版本源码 (14-7e284) 支持 MPI 功能 MSA 快速搜索

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/131966061 MPI (Message Passing Interface) 是用于并行计算的标准化和可移植的消息传递接口,可以在分布式内存的多台计算机上运行并行…