模型优化_XGBOOST学习曲线及改进,泛化误差

news2024/9/23 13:21:01

代码

from xgboost import XGBRegressor as XGBR
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.linear_model import LinearRegression as LR
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split,cross_val_score as CV,KFold
from sklearn.metrics import mean_squared_error as MSE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from time import time
import datetime

#加载数据
data=load_boston()
X=data.data
y=data.target

#划分数据集
Xtrain,Xtest,ytrain,ytest=train_test_split(X,y,test_size=0.3,random_state=420)

#定位模型,进行fit
reg=XGBR(n_estimators=100).fit(Xtrain,ytrain)

#进行预测
reg.predict(Xtest)
reg.score(Xtest,ytest)#返回的是R平方
MSE(ytest,reg.predict(Xtest))
reg.feature_importances_
#查看SKLEARN中所有的模型评估指标
import sklearn
sorted(sklearn.metrics.SCORERS.keys())



# ======================================
#交叉验证,与线性回归随机森林进行结果比对
reg=XGBR(n_estimators=100)

from sklearn.model_selection import train_test_split,cross_val_score
cross_val_score(reg,Xtrain,ytrain,cv=5).mean()#

#交叉验证既可以解决数据集的数据量不够大问题,也可以解决参数调优的问题。
#这块主要有三种方式:简单交叉验证(HoldOut检验)、k折交叉验证(k-fold交叉验证)
cross_val_score(reg,Xtrain,ytrain,cv=5,scoring="neg_mean_squared_error").mean()

#绘制学习曲线
def plot_learning_curve(estimator,title,X,y,
                       ax=None,#选择子图
                       ylim=None,#设置纵坐标的取值范围
                       cv=None,#交叉验证
                       n_jobs=None):
    from sklearn.model_selection import learning_curve
    train_sizes,train_scores,test_scores=learning_curve(estimator,X,y,shuffle=True,
                                                       cv=cv,random_state=420,n_jobs=n_jobs)
    if ax==None:
        ax=plt.gca()
    else:
        ax=plt.figure()
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel("Traing example")
    ax.set_ylabel("Score")
    ax.grid()#绘制网格
    ax.plot(train_sizes,np.mean(train_scores,axis=1),"o-",color="r",label="traing score")
    ax.plot(train_sizes,np.mean(test_scores,axis=1),"o-",color="g",label="test.py score")
    ax.legend(loc="best")
    return ax

#学习曲线的绘制
cv=KFold(n_splits=5,shuffle=True,random_state=42)
plot_learning_curve(XGBR(n_estimators=100,random_state=420),"XGB",Xtrain,ytrain,ax=None,cv=cv)

                        

#绘制学习曲线,查看n_estimators对模型的影响 

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,50,1)
rs=[]
for i in axis:
    reg=XGBR(n_estimators=i)
    cv1=cross_val_score(reg,Xtrain,ytrain,cv=5).mean()
    rs.append(cv1)
print(axis[rs.index(max(rs))],max(rs))
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.legend()
plt.show()

泛化误差:用来衡量模型在未知数据集上的准确率

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,50,1)
rs=[]#偏差,衡量的是准确率
var=[]#方差,衡量的是稳定性
ge=[]#泛化误差的可控部门
for i in axis:
    reg=XGBR(n_estimators=i)
    cv1=cross_val_score(reg,Xtrain,ytrain,cv=5)
    rs.append(cv1.mean())#记录偏差,返回的R平方就是偏差部门,衡量的是准确率
    var.append(cv1.var())
    ge.append((1-cv1.mean())**2+cv1.var())
print(axis[rs.index(max(rs))],max(rs),var[rs.index(max(rs))])
print(axis[var.index(min(var))],rs[var.index(min(var))],min(var))
print(axis[ge.index(min(ge))],rs[ge.index(min(ge))],var[ge.index(min(ge))],min(ge))
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.legend()
plt.show()

 

#绘制学习曲线,查看n_estimators对模型的影响

#绘制学习曲线,查看n_estimators对模型的影响
axis=range(10,30,1)
rs=[]#偏差,衡量的是准确率
var=[]#方差,衡量的是稳定性
ge=[]#泛化误差的可控部门
for i in axis:
    reg=XGBR(n_estimators=i)
    cv1=cross_val_score(reg,Xtrain,ytrain,cv=5)
    rs.append(cv1.mean())#记录偏差,返回的R平方就是偏差部门,衡量的是准确率
    var.append(cv1.var())
    ge.append((1-cv1.mean())**2+cv1.var())
print(axis[rs.index(max(rs))],max(rs),var[rs.index(max(rs))])
print(axis[var.index(min(var))],rs[var.index(min(var))],min(var))
print(axis[ge.index(min(ge))],rs[ge.index(min(ge))],var[ge.index(min(ge))],min(ge))
#添加方差线条
rs=np.array(rs)
var=np.array(var)#源代码这里*0.01
plt.figure(figsize=(20,5))
plt.plot(axis,rs,c='red',label="XGB")
plt.plot(axis,rs+var,c="black",linestyle="-.")
plt.plot(axis,rs-var,c="black",linestyle="-.")
plt.legend()
plt.show()

#看看泛化误差的可控部分如何

plt.figure(figsize=(20,5))
plt.plot(axis,ge,c='red',label="XGB")
plt.legend()
plt.show()
 

从这个过程中观察n_estimators参数对模型的影响,我们可以得出以下结论:
首先,XGB中的树的数量决定了模型的学习能力,树的数量越多,模型的学习能力越强。只要XGB中树的数量足够
了,即便只有很少的数据, 模型也能够学到训练数据100%的信息,所以XGB也是天生过拟合的模型。但在这种情况
下,模型会变得非常不稳定。
第二,XGB中树的数量很少的时候,对模型的影响较大,当树的数量已经很多的时候,对模型的影响比较小,只能有
微弱的变化。当数据本身就处于过拟合的时候,再使用过多的树能达到的效果甚微,反而浪费计算资源。当唯一指标
或者准确率给出的n_estimators看起来不太可靠的时候,我们可以改造学习曲线来帮助我们。
第三,树的数量提升对模型的影响有极限,最开始,模型的表现会随着XGB的树的数量一起提升,但到达某个点之
后,树的数量越多,模型的效果会逐步下降,这也说明了暴力增加n_estimators不一定有效果。
这些都和随机森林中的参数n_estimators表现出一致的状态。在随机森林中我们总是先调整n_estimators,当
n_estimators的极限已达到,我们才考虑其他参数,但XGB中的状况明显更加复杂,当数据集不太寻常的时候会更加
复杂。这是我们要给出的第一个超参数,因此还是建议优先调整n_estimators,一般都不会建议一个太大的数目,
300以下为佳。

参考:

XGBOOST学习曲线及改进,泛化误差-CSDN博客

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1490341.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何添加极狐GitLab Runner 信任域名证书

本文作者 徐晓伟 极狐Gitlab Runner 信任实例域名证书,用于注册注册极狐 GitLab Runner。 问题 参见 极狐gitlab-runner-host.md 说明 解决方案是使用颁发给域名 gitlab.test.helm.xuxiaowei.cn 的证书,可以使用自己的域名去各大云厂商免费申请&#…

重学SpringBoot3-yaml文件配置

重学SpringBoot3-yaml文件配置 引言YAML 基本语法YAML 数据类型YAML 对象YAML 数组复合结构标量引用 YAML 文件结构Spring Boot 中的 YAML 配置注意事项总结参考 引言 YAML(YAML Ain’t Markup Language)是一种常用于配置文件的数据序列化格式&#xff…

Unity3D

一、C# 输入输出 二、三维数学

CSS变量和@property

CSS变量 var() CSS 变量是由CSS作者定义的实体,其中包含要在整个文档中重复使用的特定值。使用自定义属性来设置变量名,并使用特定的 var() 来访问。(比如 color: var(--main-color);)。 基本用法 CSS变量定义的作用域只在定义该…

搞定国科金 必备王炸新技术!凌恩生物重磅推出微生物单细胞测序产品

单细胞异质性研究如火如荼,但原核生物研究却是个“坎”。 现有常规的原核生物研究,都集中于单菌群落或微生态大群体,只能从宏观角度研究群体状态,而经典的单细胞RNA测序技术无法应用于细菌。 单细胞技术应用于原核生物的几点障碍…

window10 安装配置docker

前言(重要):确认window10版本已经更新到最新版 随着时间推移,docker对window版本的支持也在变,截至2024年3月份,支持win10最低版本号:22H2,操作系统最低版本:19045.2965&#xff0c…

学编程怎么样才能更快入手,编程怎么简单易学

学编程怎么样才能更快入手,编程怎么简单易学 一、前言 对于初学编程建议先从简单入手,然后再学习其他复杂的编程语言。 今天给大家分享的中文编程开发语言工具 进度条构件的用法。 编程入门视频教程链接 https://edu.csdn.net/course/detail/39036 …

26、Qt调用.py文件中的函数

一、开发环境 Qt5.12.0 Python3.7.8 64bit 二、使用 新建一个Qt项目,右击项目名称,选择“添加库” 选择“外部库”,点击“下一步” 点击“浏览”,选择Python安装目录下的libs文件夹中的“python37.lib”文件,点击“下…

【Python如何输入工资,五险一金,专项扣除后得出个税和到手工资(2024年最新)】

最近综合所得年度汇算,正好心血来潮算一下到手工资对不对,有些朋友年综合收入也才几万块,结果年综报税时还要补一两万的个税,这主要是因为跳槽后,上家公司的年薪全平均移到了新的公司每个月中,系统的缺陷导…

第16课:如何出版人生第一本书

机会是留给有准备的人的,在网上多输出文章,就会有更好的曝光机会,有可能被潜在的机会捕捉到。 除了不断的写文章,还可以通过书籍封面的投稿信息进行文章投稿,投稿的文章一定要符合要求。 出书从来不是一件简单的事&am…

Spring MVC源码中设计模式——适配器模式

适配器模式介绍 适配器模式(Adapter Pattern)是作为两个不兼容的接口之间的桥梁。这种类型的设计模式属于结构型模式,它结合了两个独立接口的功能。 应用场景: 1、系统需要使用现有的类,而此类的接口不符合系统的需要…

EdgeX Foundry 边缘物联网中间件平台

文章目录 1.EdgeX Foundry2.平台架构3.平台服务3.1.设备服务3.2.核心服务3.3.支持服务3.4.应用服务3.5.安全服务3.6.管理服务 EdgeX Foundry # EdgeX Foundryhttps://iothub.org.cn/docs/edgex/ https://iothub.org.cn/docs/edgex/edgex-foundry/1.EdgeX Foundry EdgeX Found…

JavaScript快速入门+文档查询【详解】

目录 1. js简介 2.js引入方式 3. JS基础语法(ECMAScript) 4. js函数和事件【js的核心】 5.js对象 6.BOM对象 7.DOM对象 8.案例全选全消 1. js简介 1.什么是js JavaScript,简称js,是web开发中不可缺少的脚本语言,不需要编译就能…

Spring Test 常见错误

前面我们介绍了许多 Spring 常用知识点上的常见应用错误。当然或许这些所谓的常用,你仍然没有使用,例如对于 Spring Data 的使用,,有的项目确实用不到。那么这一讲,我们聊聊 Spring Test,相信你肯定绕不开对…

IDEA自动导入provided的依赖

最近在学习flink 流程序&#xff0c;在写demo程序的时候依赖flink依赖&#xff0c;依赖的包在flink集群里面是自己已经提供了的&#xff0c;在导入的时候配置为provided&#xff0c;像下面这样&#xff0c;以使打包的时候不用打到最终的程序包里面。 <dependency><gro…

STM32USART串口数据包

文章目录 前言一、介绍部分数据包两种包装方式&#xff08;分割数据&#xff09;HEX数据包文本数据包 数据包的收发流程数据包的发送数据包的接收固定包长的hex数据包接收可变包长的文本数据包接收 二、实例部分固定包长的hex数据包接收连接线路代码实现 可变包长的文本数据包接…

AWS的RDS数据库开启慢查询日志

#开启慢日志两个参数 slow_query_log 1 设置为1&#xff0c;来启用慢查询日志 long_query_time 5 &#xff08;单位秒&#xff09; sql执行多长时间被定义为慢日志1. 点击RDS然后点击参数组&#xff0c;选择slow_query_log&#xff0c;设置为1【表示开启慢日志】点击保存…

[cg] Games 202 - NPR 非真实感渲染

NPR特性&#xff08;基于真实感渲染&#xff09; 真实感--》翻译成非真实感的过程 NPR风格 需要转换为渲染中的操作 1.描边 B-->普通边界&#xff08;不是下面几种的&#xff09; C-->折痕 M-->材质边界 S-->需要在物体外面一圈上&#xff0c;并且是多个面共享…

使用GitHub API 查询开源项目信息

一、GitHub API介绍 GitHub API 是一组 RESTful API 接口&#xff0c;用于与 GitHub 平台进行交互。通过使用 GitHub API&#xff0c;开发人员可以访问和操作 GitHub 平台上的各种资源&#xff0c;如仓库、提交记录、问题等。 GitHub API 提供了多种功能和端点&#xff0c;以…

gin gorm学习笔记

代码仓库 https://gitee.com/zhupeng911/go-advanced.git https://gitee.com/zhupeng911/go-project.git 1. gin介绍 Gin 是使用纯 Golang 语言实现的 HTTP Web框架&#xff0c;Gin接口设计简洁&#xff0c;提供类似Martini的API&#xff0c;性能极高&#xff0c;现在被广泛使用…