Python机器学习17——Xgboost和Lightgbm结合分位数回归(机器学习与传统统计学结合)

news2024/12/28 5:36:40

最近XGboost支持分位数回归了,我看了一下,就做了个小的代码案例。毕竟学术市场上做这种新颖的机器学习和传统统计学结合的方法还是不多,算的上创新,找个好数据集可以发论文。


代码实现

导入包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error,r2_score
import xgboost as xgb
import lightgbm as lgb
import statsmodels.api as sm
from statsmodels.regression.quantile_regression import QuantReg

xgboost和lightgbm都需要安装的,他们和sklearn库的机器学习方法不是一个库的。怎么安装看我《实用的机器学习》这个栏目的xgb那篇文章。


模拟数据进行分位数回归

先制作一个模拟数据集

def f(x: np.ndarray) -> np.ndarray:
    return x * np.sin(x)

rng = np.random.RandomState(2023)
X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T
expected_y = f(X).ravel()
sigma = 0.5 + X.ravel() / 10.0
noise = rng.lognormal(sigma=sigma) - np.exp(sigma**2.0 / 2.0)
y = expected_y + noise

print(X.shape,y.shape)

然后画图看看:

plt.figure(figsize=(6,2),dpi=100)
plt.scatter(X,y,s=1)
plt.show()

#划分训练集和测试集


X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
print(f"Training data shape: {X_train.shape}, Testing data shape: {X_test.shape}")

这里采用三种模型进行拟合预测对比,分别是线性分位数回归,XGB结合分位数,LightGBM结合分位数:

alphas = np.arange(5, 100, 5) / 100.0
print(alphas)
mse_qr, mse_xgb, mse_lgb = [], [], []
r2_qr, r2_xgb, r2_lgb = [], [], []
qr_pred,xgb_pred,lgb_pred={},{},{}

# Train and evaluate
for alpha in alphas:
    # Quantile Regression
    model_qr = QuantReg(y_train, sm.add_constant(X_train)).fit(q=alpha)
    model_pred=model_qr.predict(sm.add_constant(X_test))
    mse_qr.append(mean_squared_error(y_test,model_pred ))
    r2_qr.append(r2_score(y_test,model_pred))
    
    # XGBoost
    model_xgb = xgb.train({"objective": "reg:quantileerror", 'quantile_alpha': alpha}, 
                          xgb.QuantileDMatrix(X_train, y_train), num_boost_round=100)
    model_pred=model_xgb.predict(xgb.DMatrix(X_test))
    mse_xgb.append(mean_squared_error(y_test,model_pred ))
    r2_xgb.append(r2_score(y_test,model_pred))
    
    # LightGBM
    model_lgb = lgb.train({'objective': 'quantile', 'alpha': alpha,'force_col_wise': True,}, 
                          lgb.Dataset(X_train, y_train), num_boost_round=100)
    
    model_pred=model_lgb.predict(X_test)
    mse_lgb.append(mean_squared_error(y_test,model_pred))
    r2_lgb.append(r2_score(y_test,model_pred))
    
    if alpha in [0.1,0.5,0.9]:
        qr_pred[alpha]=model_qr.predict(sm.add_constant(X_test))
        xgb_pred[alpha]=model_xgb.predict(xgb.DMatrix(X_test))
        lgb_pred[alpha]=model_lgb.predict(X_test)

分位点为0.1,0.5,0.9时记录一下,方便画图查看。

然后画出三种模型在不同分位点下的误差和拟合优度对比:

plt.figure(figsize=(7, 5),dpi=128)
plt.subplot(211)
plt.plot(alphas, mse_qr, label='Quantile Regression')
plt.plot(alphas, mse_xgb, label='XGBoost')
plt.plot(alphas, mse_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('MSE')
plt.title('MSE across different quantiles')

plt.subplot(212)
plt.plot(alphas, r2_qr, label='Quantile Regression')
plt.plot(alphas, r2_xgb, label='XGBoost')
plt.plot(alphas, r2_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('$R^2$')
plt.title('$R^2$ across different quantiles')
plt.tight_layout()
plt.show()

可以看到在分位点为0.5附件,模型的误差都比较小。因为这个数据集没有很多的异常值。然后模型表现上,LGBM>XGB>线性QR。线性模型对于一个非线性的函数关系拟合在这里当然不行。

画出拟合图:
 

name=['QR','XGB-QR','LGB-QR']
plt.figure(figsize=(7, 6),dpi=128)
for k,model in enumerate([qr_pred,xgb_pred,lgb_pred]):
    n=int(str('31')+str(k+1))
    plt.subplot(n)
    plt.scatter(X_test,y_test,c='k',s=2)
    for i,alpha in enumerate([0.1,0.5,0.9]):
        sort_order = np.argsort(X_test, axis=0).ravel()
        X_test_sorted = np.array(X_test)[sort_order]
        #print(np.array(model[alpha]))
        predictions_sorted = np.array(model[alpha])[sort_order]
        plt.plot(X_test_sorted,predictions_sorted,label=fr"$\tau$={alpha}",lw=0.8)
    plt.legend()
    plt.title(f'{name[k]}')
plt.tight_layout()
plt.show()

可以看到分位数回归的明显的区间特点。

还有非参数非线性方法的优势,明显XGB和LGBM拟合得更好。


波士顿数据集

上面是人工数据,下面采用真实的数据集进行对比,就用回归最常用的波士顿房价数据集吧:

data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
column_names = ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO',  'B','LSTAT', 'MEDV']
boston=pd.DataFrame(np.hstack([data,target.reshape(-1,1)]),columns= column_names)

取出X和y,划分测试集和训练集

X = boston.iloc[:,:-1]
y = boston.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

拟合预测,对比

alphas = np.arange(0.1, 1, 0.1)
mse_qr, mse_xgb, mse_lgb = [], [], []
r2_qr, r2_xgb, r2_lgb = [], [], []
qr_pred,xgb_pred,lgb_pred={},{},{}
# Train and evaluate
for alpha in alphas:
    # Quantile Regression
    model_qr = QuantReg(y_train, sm.add_constant(X_train)).fit(q=alpha)
    model_pred=model_qr.predict(sm.add_constant(X_test))
    mse_qr.append(mean_squared_error(y_test,model_pred ))
    r2_qr.append(r2_score(y_test,model_pred))
    
    # XGBoost
    model_xgb = xgb.train({"objective": "reg:quantileerror", 'quantile_alpha': alpha}, 
                          xgb.QuantileDMatrix(X_train, y_train), num_boost_round=100)
    model_pred=model_xgb.predict(xgb.DMatrix(X_test))
    mse_xgb.append(mean_squared_error(y_test,model_pred ))
    r2_xgb.append(r2_score(y_test,model_pred))
    
    # LightGBM
    model_lgb = lgb.train({'objective': 'quantile', 'alpha': alpha,'force_col_wise': True,}, 
                          lgb.Dataset(X_train, y_train), num_boost_round=100)
    
    model_pred=model_lgb.predict(X_test)
    mse_lgb.append(mean_squared_error(y_test,model_pred))
    r2_lgb.append(r2_score(y_test,model_pred))
    
    if alpha in [0.1,0.5,0.9]:
        qr_pred[alpha]=model_qr.predict(sm.add_constant(X_test))
        xgb_pred[alpha]=model_xgb.predict(xgb.DMatrix(X_test))
        lgb_pred[alpha]=model_lgb.predict(X_test)

画图查看不同分位点的不同模型的误差和拟合优度:

plt.figure(figsize=(8, 5),dpi=128)
plt.subplot(211)
plt.plot(alphas, mse_qr, label='Quantile Regression')
plt.plot(alphas, mse_xgb, label='XGBoost')
plt.plot(alphas, mse_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('MSE')
plt.title('MSE across different quantiles')

plt.subplot(212)
plt.plot(alphas, r2_qr, label='Quantile Regression')
plt.plot(alphas, r2_xgb, label='XGBoost')
plt.plot(alphas, r2_lgb, label='LightGBM')
plt.legend()
plt.xlabel('Quantile')
plt.ylabel('$R^2$')
plt.title('$R^2$ across different quantiles')
plt.tight_layout()
plt.show()

可以看到在分位点为0.6附件三个模型表现效果都比较好,然后模型表现来看,XGB>LGBM>QR,还是两个机器学习模型更厉害。


分位数损失函数和平方和损失函数对比

上面我们得到在分位点为0.6的时候,模型效果表现好,那么分位数模型和普通的MSE损失函数的效果比起来怎么样呢?我们继续对比:

# 定义alpha值
alpha = 0.5

# 分位数回归模型
model_qr = sm.regression.quantile_regression.QuantReg(y_train, sm.add_constant(X_train)).fit(q=alpha)
qr_pred = model_qr.predict(sm.add_constant(X_test))

# XGBoost分位数回归
model_xgb = xgb.train({"objective": "reg:quantileerror", 'quantile_alpha': alpha}, 
                      xgb.DMatrix(X_train, label=y_train), num_boost_round=100)
xgb_q_pred = model_xgb.predict(xgb.DMatrix(X_test))

# LightGBM分位数回归
model_lgb = lgb.train({'objective': 'quantile', 'alpha': alpha,'force_col_wise': True}, 
                      lgb.Dataset(X_train, label=y_train), num_boost_round=100)
lgb_q_pred = model_lgb.predict(X_test)

# 普通的最小二乘法线性回归
model_lr = LinearRegression()
model_lr.fit(X_train, y_train)
lr_pred = model_lr.predict(X_test)

# 普通的XGBoost
model_xgb_reg = xgb.train({"objective": "reg:squarederror"}, xgb.DMatrix(X_train, label=y_train), num_boost_round=100)
xgb_pred = model_xgb_reg.predict(xgb.DMatrix(X_test))

# 普通的LightGBM
model_lgb_reg = lgb.train({'objective': 'regression', 'force_col_wise': True}, lgb.Dataset(X_train, label=y_train), num_boost_round=100)
lgb_pred = model_lgb_reg.predict(X_test)

上面是六个模型,非别是基于分位数回归的XGB,LGBM,线性分位数回归。还有三个基于最普通的MSE损失函数的普通XGB,LGBM和最小二乘线性回归。

# 计算6个模型的MSE和R^2 


models = ['QR', 'XGB Quantile', 'LightGBM Quantile', 'Linear Reg', 'XGBoost', 'LightGBM']
preds = [qr_pred, xgb_q_pred, lgb_q_pred, lr_pred, xgb_pred, lgb_pred]
mse_scores = [mean_squared_error(y_test, pred) for pred in preds]
r2_scores = [r2_score(y_test, pred) for pred in preds]

 画柱状图查看:

colors = sns.color_palette("muted", len(models))
fig, axs = plt.subplots(2, 1, figsize=(9,7))
axs[0].bar(models, mse_scores, color=colors)
axs[0].set_title('MSE Comparison')
axs[0].set_ylabel('MSE')
axs[1].bar(models, r2_scores, color=colors)
axs[1].set_title(r'$R^{2}$ Comparison')
axs[1].set_ylabel(r'$R^{2}$')
plt.tight_layout()
plt.show()

可以看到模型效果来看,XGboost由于Lightgbm优于线性模型。但是分位数回归效果没有MSE损失好,说明在这个数据集表现上,就采用最经典的MSE损失的普通的模型效果会更好。。。

确实是这样的,很多学术创新和改进都不一定比最经典和最常见的方法的效果好。

如果是那种异常值很多的数据,具有异方差的数据 ,可能损失函数改用分位数的会更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1128868.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【单例模式】饿汉式,懒汉式?JAVA如何实现单例?线程安全吗?

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ 单例设计模式 Java单例设计模式 Java单例设计模…

微信消息弹窗升级优化了,在微信打开时也能收到新消息显示。

最近,微信又更新了。微信对消息弹窗进行了升级优化,在微信打开时也能收到新消息显示。 点击「我」-「设置」-「消息通知」,可以看到新增了「横幅显示内容」选项。 有3种内容显示形式,分别为:仅显示你收到1条消息&#…

『 基础算法题解 』之双指针(上)

双指针 文章目录 双指针移动零题目解析算法原理代码拓展 复写零题目解析算法原理代码 快乐数题目解析算法解析拓展 代码 盛最多水的容器题目解析算法解析代码 有效的三角形个数题目解析算法原理代码 移动零 题目解析 【题目链接】 算法原理 该种题目可以归为一类题数组分块\…

想要精通算法和SQL的成长之路 - 最小高度树

想要精通算法和SQL的成长之路 - 最小高度树 前言一. 最小高度树1.1 邻接表的构建1.2 入度为1的先入队1.3 BFS遍历 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 最小高度树 原题链接 从题目的含义中我们可以发现: 题目的树是一颗多叉树。叶子节点的度为1&a…

你的支付环境是否安全?

1、平台支付逻辑全流程分析分析 2、平台支付漏洞如何利用?买东西还送钱? 3、BURP抓包分析修改支付金额,伪造交易状态? 4、修改购物车参数实现底价购买商品 5、SRC、CTF、HW项目月入10W副业之路 6、如何构建最适合自己的网安学习路…

【项目经理】目标管理工具

目标管理工具 1. WBS 任务分解法👊原则方法标准 2. 6W2H法WhatwhyWhowhen⏲️WhereWhichHowHow much 3. SWOT分析法strengths-优势Weaknesses-劣势Opportunities-机会Threats-威胁 4. 二八原则法巴列特定律准则例子 5. SMART原则SpecificMeasurableAttainableReleva…

处于十字路口的CIO:继续进化还是走进死胡同

2023年初Forrester研究给出的一个坏消息表明,有很多CIO尚未准备好满足这些新的需求。大多数CIO(58%)仍处于Forrester所说的传统IT领导模式;有37%的CIO被认为是“现代的”,但只有6%的CIO是“适合未来的”,具…

YOLOv8优化:独家创新(SC_C_Detect)检测头结构创新,实现涨点 | 检测头新颖创新系列

💡💡💡本文独家改进:独家创新(SC_C_Detect)检测头结构创新,适合科研创新度十足,强烈推荐 SC_C_Detect | 亲测在多个数据集能够实现大幅涨点 💡💡💡Yolov8魔术师,独家首发创新(原创),适用于Yolov5、Yolov7、Yolov8等各个Yolo系列,专栏文章提供每一步步…

面试了上百位性能测试后,我发现了一个令人不安的事实

在企业中负责技术招聘的同学,肯定都有一个苦恼,那就是招一个合适的测试太难了!若要问起招哪种类型的测试最难时,相信很多人都会说出“性能测试”这个答案。 每当发布一个性能测试岗位,不一会就能收到上百份简历&#x…

开发者版 ONLYOFFICE 文档 7.5:API 和文档生成器更新

随着版本 7.5 中新功能的发布,我们更新了编辑器、文档生成器、插件和桌面应用程序的 API。阅读本文查看所有详细信息。 用于处理表单的 API 隐藏/显示提交表单按钮:使用 editorConfig.customization.submitForm 参数,可以定义 OFORM 文件的顶…

【CV】图像分割详解!

图像分割是计算机视觉研究中的一个经典难题,已经成为图像理解领域关注的一个热点,图像分割是图像分析的第一步,是计算机视觉的基础,是图像理解的重要组成部分,同时也是图像处理中最困难的问题之一。所谓图像分割是指根…

【量化交易笔记】12.海龟交易策略

引言 海龟交易法则是一种著名的趋势跟踪交易策略,适用于中长线投资者。 海龟交易策略(Turtle Trading)起源于美国,由著名的交易员理查德丹尼斯(Richard Dennis)创立。这种交易策略属于趋势跟踪策略&#…

Quirks(怪癖)模式是什么?它和 Standards(标准)模式有什么区别?

目录 前言: 用法: 代码: Quirks模式示例: Standards模式示例: 理解: Quirks模式: Standards模式: 高质量讨论: 前言: "Quirks模式"和"Standards模式"是与HTML文档渲染模式相关的两种模式。它们影响着浏览器如何解释和渲染HT…

华夏版-超功能记事本 Ⅲ 8.8易语言源码

华夏版-超功能记事本 Ⅲ 8.8易语言源码 下载地址:https://user.qzone.qq.com/512526231

VisualStudio[WPF/.NET]基于CommunityToolkit.Mvvm架构开发

一、创建 "WPF应用程序" 新项目 项目模板选择如下&#xff1a; 暂时随机填一个目标框架&#xff0c;待会改&#xff1a; 二、修改“目标框架” 双击“解决方案资源管理器”中<项目>CU-APP, 打开<项目工程文件>CU-APP.csproj, 修改目标框架TargetFramew…

windows开机自启动和忘记密码-备忘

windows开机自启动和忘记密码-备忘 文章目录 windows开机自启动和忘记密码-备忘1.自启动网址定时任务方式 2.忘记windows用户密码 1.自启动 网址 参考博文&#xff1a;https://blog.csdn.net/wwzmvp/article/details/113656544&#xff0c;感谢博主。 定时任务方式 如图&#…

uniapp如何跳转系统授权管理页?

如何跳转系统授权管理页&#xff1f; 跳转APP应用授权设置页面 文章目录 如何跳转系统授权管理页&#xff1f;效果图打开系统App的权限设置界面 效果图 例&#xff1a;Android 打开系统App的权限设置界面 App端&#xff1a;打开系统App的权限设置界面微信小程序&#xff1a;打开…

20231024后端研发面经整理

1.如何在单链表O(1)删除节点&#xff1f; 狸猫换太子 2.redis中的key如何找到对应的内存位置&#xff1f; 哈希碰撞的话用链表存 3.线性探测哈希法的插入&#xff0c;查找和删除 插入&#xff1a;一个个挨着后面找&#xff0c;知道有空位 查找&#xff1a;一个个挨着后面找…

express session

了解 Session 认证的局限性 Session 认证机制需要配合 cookie 才能实现。由于 Cookie 默认不支持跨域访问&#xff0c;所以&#xff0c;当涉及到前端跨域请求后端接口的时候&#xff0c;需要做很多额外的配置&#xff0c;才能实现跨域 Session 认证。 注意&#xff1a; 当前端…

Unity实现方圆多少米范围随机生成怪物

using System.Collections; using System.Collections.Generic; using UnityEngine;public class CreatMonster : MonoBehaviour {// S这个脚本间隔一点时间生成怪物/*1.程序逻辑* 1. 设计一个计时器* 2.间隔一段时间3s执行一下 * */float SaveTime 0f;public GameObject …