Boosting三巨头:XGBoost、LightGBM和CatBoost(发展、原理、区别和联系,附代码和案例)

news2024/9/24 13:21:45

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

GBDT

(封面图由ERNIE-ViLG AI 作画大模型生成)

Boosting三巨头:XGBoost、LightGBM和CatBoost(发展、原理、区别和联系,附代码和案例)

机器学习中,提高模型精度是研究的重点之一,而模型融合技术中,Boosting算法是一种常用的方法。在Boosting算法中,XGBoost、LightGBM和CatBoost是三个最为流行的框架。它们在实际使用中有各自的优势和适用场景,下面将会介绍它们的区别与联系。

1. 算法原理

1.1 XGBoost

XGBoost是由陈天奇等人提出的一个优化的Gradient Boosting算法,以其出色的表现和可扩展性而受到广泛关注。XGBoost使用了C++实现,可以运行在多个平台上,并支持多种编程语言,如Python、R、Java等。其原理可以概括为将弱学习器依次加入到一个全局的加权模型中,每一轮迭代都在损失函数的梯度方向上优化模型。它在原有GBDT的基础上,添加了正则化项和缺失值处理,使得模型更加稳定和准确。其原理如下:

首先,假设有n个训练样本 ( x i , y i ) (x_{i}, y_{i}) (xi,yi),其中 x i x_{i} xi为输入特征, y i y_{i} yi为输出值。那么,目标就是找到一个函数f(x),使得 f ( x ) f(x) f(x)可以预测 y y y的值。

其次,定义损失函数 L ( y , f ( x ) ) L(y, f(x)) L(y,f(x)),用来度量 f ( x ) f(x) f(x)的预测值与实际值之间的误差。

再次,我们使用Boosting算法来不断迭代提高模型精度。假设现在已经有了一个弱分类器 f m − 1 ( x ) f_{m-1}(x) fm1(x),那么我们希望找到一个新的弱分类器 f m ( x ) f_{m}(x) fm(x)来减少 L ( y , f ( x ) ) L(y, f(x)) L(y,f(x))的值。于是我们在已有的弱分类器 f m − 1 ( x ) f_{m-1}(x) fm1(x)基础上,加上一个新的弱分类器 f m ( x ) f_{m}(x) fm(x),最终得到新的分类器 f m ( x ) = f m − 1 ( x ) + γ h m ( x ) f_{m}(x)=f_{m-1}(x)+\gamma h_{m}(x) fm(x)=fm1(x)+γhm(x),其中 γ \gamma γ为学习率, h m ( x ) h_{m}(x) hm(x)为新的弱分类器。

最后,由于XGBoost使用了正则化项来控制模型的复杂度,并采用了特殊的梯度下降方法进行训练,使得其在处理高维稀疏数据时,具有较好的效果。

1.2 LightGBM

LightGBM是由微软提出的一种基于Histogram算法的Gradient Boosting框架。它通过对样本特征值进行离散化,将连续特征离散化为有限个整数,从而将高维稀疏数据转化为低维稠密数据,从而加速了训练速度。相比XGBoost,LightGBM的最大优势在于其快速的训练速度和较小的内存占用,这主要得益于其采用了基于直方图的决策树算法和局部优化等技术。LightGBM的核心思想是在构造决策树时,将连续特征离散化为若干个桶,然后将每个桶作为一个离散特征对待,从而加速树的构建和训练过程。其原理如下:

首先,对于每个特征,我们需要将其离散化为一些桶,每个桶中包含一些连续的特征值。在训练时,我们只需要计算每个桶中的样本的统计信息(如平均值和方差),而不需要计算每个样本的特征值,从而减少了计算量。

其次,对于每个样本,我们根据离散化后的特征值,将其归入对应的桶中,然后计算桶中样本的统计信息。接着,我们通过梯度单边采样(GOSS)算法,选择一部分样本进行训练,这些样本中包含了大部分的梯度信息,从而保证了训练的准确性和效率。

最后,LightGBM还使用了基于直方图的决策树算法,使得在处理高维稀疏数据时,具有较好的效果。

1.3 CatBoost

CatBoost是由Yandex提出的一种基于梯度提升算法的开源机器学习框架。它在处理分类问题时,可以自动处理类别特征,无需手动进行特征编码。CatBoost的原理与XGBoost和LightGBM类似,同样是通过将多个弱学习器组合成一个强学习器。不同之处在于,CatBoost使用了一种新的损失函数,即加权交叉熵损失函数,可以有效地处理类别不平衡问题。其原理如下:

首先,CatBoost使用了一种称为Ordered Boosting的算法来提高模型精度。Ordered Boosting可以看做是一种特殊的特征选择方法,它将训练样本按照特征值大小排序,然后使用分段线性模型拟合每一段特征值的梯度,从而提高了模型的拟合能力。

其次,CatBoost在处理分类问题时,可以自动处理类别特征。它使用了一种称为Target Encoding的方法,将类别特征转化为一组实数值,从而避免了手动进行特征编码的麻烦。

最后,CatBoost还使用了基于对称树的决策树算法,使得在处理高维稀疏数据时,具有较好的效果。

2. 发展前景和应用场景

XGBoost、LightGBM和CatBoost作为目前最先进的梯度提升算法,在许多数据科学竞赛和实际应用中都取得了很好的效果。随着大数据时代的到来,这三种算法的应用场景也越来越广泛。

其中,XGBoost在传统机器学习领域仍然是最常用的算法之一,特别是在结构化数据的分类、回归和排序任务中表现突出。LightGBM在大规模数据集和高维度数据上表现更佳,适用于处理文本分类、图像分类、推荐系统等领域的数据。CatBoost在处理类别特征和缺失值方面表现出色,适用于电商推荐、医疗预测、金融风控等领域的数据。

总的来说,XGBoost、LightGBM和CatBoost作为梯度提升算法的代表,都具有自身的优势和适用场景,随着数据和计算能力的不断提升,它们的应用前景也会越来越广阔。

3. 使用案例

3.1 XGBoost

XGBoost可以应用于多种场景,如回归、分类、排序等。下面以Kaggle竞赛中的房价预测问题为例,展示如何使用XGBoost进行模型训练和预测。

首先,我们使用Pandas读取数据集,并将其划分为训练集和测试集。

import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_csv('train.csv')
X = df.drop('SalePrice', axis=1)
y = df['SalePrice']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

接着,我们使用XGBoost的Python接口进行模型训练和预测。

import xgboost as xgb
from sklearn.metrics import mean_squared_error

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)

params = {'max_depth': 3, 'learning_rate': 0.1, 'objective': 'reg:squarederror'}
num_rounds = 100

model = xgb.train(params, dtrain, num_rounds)
y_pred = model.predict(dtest)

mse = mean_squared_error(y_test, y_pred)
print("Mean Squared Error:", mse)

上述代码中,我们首先使用xgb.DMatrix将训练数据转化为DMatrix格式。接着,我们定义了模型参数,并设置了迭代次数为100。然后,我们使用xgb.train函数进行模型训练,并使用model.predict函数进行模型预测。最后,我们使用sklearn.metrics.mean_squared_error函数计算了模型的均方误差。

3.2 LightGBM

LightGBM可以应用于多种场景,如回归、分类、排序等。下面以Kaggle竞赛中的鸢尾花分类问题为例,展示如何使用LightGBM进行模型训练和预测。

首先,我们使用Pandas读取数据集,并将其划分为训练集和测试集。

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target

X = df.drop('target', axis=1)
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

接着,我们使用LightGBM的Python接口进行模型训练和预测。

import lightgbm as lgb
from sklearn.metrics import accuracy_score

dtrain = lgb.Dataset(X_train, label=y_train)
dtest = lgb.Dataset(X_test, label=y_test)

params = {'objective': 'multiclass', 'num_class': 3, 'metric': 'multi_logloss'}
num_rounds = 100

model = lgb.train(params, dtrain, num_rounds)
y_pred = model.predict(X_test)

y_pred = [list(x).index(max(x)) for x in y_pred]
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

上述代码中,我们首先使用lgb.Dataset将训练数据转化为Dataset格式。接着,我们定义了模型参数,并设置了迭代次数为100。然后,我们使用lgb.train函数进行模型训练,并使用model.predict函数进行模型预测。最后,我们使用sklearn.metrics.accuracy_score函数计算了模型的准确率。

3.3 CatBoost

import catboost as cb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)

train_data = cb.Pool(X_train, label=y_train)
test_data = cb.Pool(X_test, label=y_test)

params = {'loss_function': 'MultiClass', 'num_class': 3, 'eval_metric': 'Accuracy'}

num_rounds = 20
bst = cb.train(params, train_data, num_rounds)

preds = bst.predict(X_test)
y_pred = [np.argmax(pred) for pred in preds]

acc = accuracy_score(y_test, y_pred)
print("Accuracy:", acc)

综合展示

为了更好地展示XGBoost、LightGBM和CatBoost的应用场景和效果,我们以波士顿房价预测数据集为例进行实验。

首先,我们使用sklearn库中的load_boston函数加载数据集,并对数据进行划分,80%用于训练,20%用于测试。

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

data = load_boston()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)

然后,我们依次使用XGBoost、LightGBM和CatBoost训练模型,并对模型进行评估。

import xgboost as xgb
from lightgbm import LGBMRegressor
from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error

# XGBoost
xgb_model = xgb.XGBRegressor(n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42)
xgb_model.fit(X_train, y_train)
xgb_pred = xgb_model.predict(X_test)
xgb_rmse = mean_squared_error(y_test, xgb_pred, squared=False)

# LightGBM
lgb_model = LGBMRegressor(n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42)
lgb_model.fit(X_train, y_train)
lgb_pred = lgb_model.predict(X_test)
lgb_rmse = mean_squared_error(y_test, lgb_pred, squared=False)

# CatBoost
cat_model = CatBoostRegressor(n_estimators=100, max_depth=5, learning_rate=0.1, random_seed=42, silent=True)
cat_model.fit(X_train, y_train)
cat_pred = cat_model.predict(X_test)
cat_rmse = mean_squared_error(y_test, cat_pred, squared=False)

print("XGBoost RMSE: {:.2f}".format(xgb_rmse))
print("LightGBM RMSE: {:.2f}".format(lgb_rmse))
print("CatBoost RMSE: {:.2f}".format(cat_rmse))

参考文献

[1] Chen, T., & Guestrin, C. (2016). Xgboost: A scalable tree boosting system. In Proceedings of the 22nd acm sigkdd international conference on knowledge discovery and data mining (pp. 785-794).
[2] Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., … & Liu,
G. (2017). Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information Processing Systems (pp. 3146-3154).
[3] Prokhorenkova, L., Gusev, G., Vorobev, A., Dorogush, A. V., & Gulin, A. (2018). CatBoost: unbiased boosting with categorical features. In Advances in neural information processing systems (pp. 6638-6648).
[4] XGBoost官方文档:https://xgboost.readthedocs.io/en/latest/
[5] LightGBM官方文档:https://lightgbm.readthedocs.io/en/latest/
[6] CatBoost官方文档:https://catboost.ai/docs/
[7] 《Python机器学习基础教程》(吴斌):介绍了XGBoost、LightGBM和CatBoost的使用方法和实例。
[8] 《Applied Machine Learning》(Kelleher, John D.):介绍了各种机器学习算法,其中也包括了梯度提升算法和其变种。
[9] 《Hands-On Gradient Boosting with XGBoost and scikit-learn》(Villalba, Benjamin):详细介绍了XGBoost和scikit-learn库的梯度提升实现。
[10] 《Gradient Boosting》(Friedman, Jerome H.):介绍了梯度提升算法的基本思想和实现原理。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/400253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Binder ——binder的jni注册和binder驱动

环境:Android 11源码Android 11 内核源码源码阅读器 sublime textbinder的jni方法注册zygote启动1-1、启动zygote进程zygote是由init进程通过解析init.zygote.rc文件而创建的,zygote所对应的可执行程序是app_process,所对应的源文件是app_mai…

因果推断12--dragonnet论文和代码学习

目录 论文 dragonnet 1介绍 2 Dragonnet 3定向正则化 4相关工作 5实验 6讨论 NN-Based的模型 dragonnet 如何更新参数 dragonnet的损失函数 CausalML Dragonnet类 论文代码 论文 dragonnet Adapting Neural Networks for the Estimation of Treatment Effects 应…

二叉搜索树的实现

什么是二叉搜索树1.若它的左子树不为空,那么左子树上所有节点都小于根节点2.若它的右子树不为空,那么右子树上所有节点都小于根节点3.它的左右子树也分别是二叉搜索树4.使用中序遍历结果是从小到大定义节点,使用静态内部类static class TreeN…

http组成及状态及参数传递

http组成及状态及参数传递 早期的网页都是通过后端渲染来完成的:服务器端渲染(SSR,server side render): 客户端发出请求 -> 服务端接收请求并返回相应HTML文档 -> 页面刷新,客户端加载新的HTML文档&…

7综合项目 旅游网 【7.精选分类】

精选旅游人气旅游→收藏次数最高最新旅游→日期最新主题旅游→主题关键字相同在首页将精选的内容动态展示的实现分析首页中的精选包含“人气旅游”、“最新旅游”、“主题旅游”三个部分index.html//页面加载完成,发送ajax请求根据点击不同分类展示不同内容人气旅游→收藏次数最…

分享17个提升开发效率的工具“轮子”

本文是向大家介绍平时在开发中经常用到的小工具,它能够极大得提升我们的开发效率,能够解决平时开发中遇到的问题。前言在java的庞大体系中,其实有很多不错的小工具,也就是我们平常说的“轮子“。如果在我们的日常工作当中&#xf…

数据结构课程设计:高铁信息管理系统(C++实现)

目录 简介实验输出实验要求代码运行环境结语简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~ ଘ(੭ˊᵕˋ)੭ 昵称:海轰 标签:程序猿|C++选手|学生 简介:因C语言结识编程,随后转入计算机专业,获得过国家奖学金,有幸在竞赛中拿过一些国奖…

嵌入式Linux驱动开发(一)chrdevbase虚拟字符设备

Linux下三大驱动:字符设备,块设备,网络设备。一个硬件可以从属于不同的设备分类。 0. Linux应用程序对驱动程序的调用流程 驱动加载成功后会在/dev目录下生成一个文件,对该文件的操作就是对设备的操作。当我们在用户态调用一个函…

Element-UI实现复杂table表格结构

Element-UI组件el-table用于展示多条结构类似的数据,可对数据进行排序、筛选、对比或其他自定义操作。将使用到以下两项,来完成今天demo演示:多级表头:数据结构比较复杂的时候,可使用多级表头来展现数据的层次关系。合…

Web3中文|Web3CN加速器第二期「Web3项目征集」火热报名

Web3CN加速器第二期「Web3项目征集」火热征集中,本次征集活动是由Web3CN加速器联合专业web3媒体Web3CN、VC机构Tiger VC DAO核心发起,数百家加密VC机构、加密社区等联合发起的,为早期Web3创新创业项目提供加速服务。如果你正在进行web3相关的…

VC常见问题(.obj : error LNK2019、fatal error C1083、编译64位Detours)

VC常用问题VC常见问题*.obj : error LNK2019: 无法解析的外部符号 __imp_FindWindow ,该符号在函数 YAWindows环境下用nmake编译常见问题fatal error C1083: 无法打开包括文件:“excpt.h”vs2012编译64位Detours(其他vs版本同理)vs项目设置选项编译使用了…

Java基础面试题(三)

Java基础面试题 一、JavaWeb专题 1.HTTP响应码有哪些 1、1xx(临时响应) 2、2xx(成功) 3、3xx(重定向):表示要完成请求需要进一步操作 4、4xx(错误):表示请…

Nuxt实战教程基础-Day01

Nuxt实战教程基础-Day01Nuxt是什么?Nuxt.js框架是如何运作的?Nuxt特性流程图服务端渲染(通过 SSR)单页应用程序 (SPA)静态化 (预渲染)Nuxt优缺点优点缺点安装运行项目总结前言:本教程基于Nuxt2,作为教程的第一天,我们先…

BUUCTF-[RoarCTF2019]polyre

题目下载:下载 这道题目是一个关于控制流平坦化和虚假流程。 首先了解一下控制流平坦化:利用符号执行去除控制流平坦化 - 博客 - 腾讯安全应急响应中心https://www.cnblogs.com/zhwer/p/14081454.htmlbuuctf RoarCTF2019 polyre writeup - 『脱壳破解区…

单点登录的几种实现方式探讨

单点登录(Single Sign On),简称为 SSO,是解决企业内部的一系列产品登录问题的方案。SSO 的定义是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统,用于减少用户重复的登录操作&#…

PyTorch的自动微分(autograd)

PyTorch的自动微分(autograd) 计算图 计算图是用来描述运算的有向无环图 计算图有两个主要元素:结点(Node)和边(Edge) 结点表示数据,如向量、矩阵、张量 边表示运算,如加减乘除卷积等 用计算…

共话开源 | 开放原子开源基金会专题调研openKylin社区!

3月8日,开放原子开源基金会秘书长冯冠霖、运营部部长李博、业务发展部部长朱其罡、研发部副部长周济一行莅临openKylin社区调研交流,麒麟软件高级副总经理韩乃平、副总裁董军平、终端研发部副总经理陆展、产品规划部经理常亚武、市场与政府事务部高级经理…

力扣sql简单篇练习(二十五)

力扣sql简单篇练习(二十五) 1 无效的推文 1.1 题目内容 1.1.1 基本题目信息 1.1.2 示例输入输出 1.2 示例sql语句 # Write your MySQL query statement below SELECT tweet_id FROM Tweets WHERE CHAR_LENGTH(content)>151.3 运行截图 2 求关注者的数量 2.1 基本题目内…

【Linux实战篇】二、在Linux上部署各类软件

一、实战章节:在Linux上部署各类软件 二、MySQL数据库管理系统安装部署【简单】 简介 MySQL数据库管理系统(后续简称MySQL),是一款知名的数据库系统,其特点是:轻量、简单、功能丰富。 MySQL数据库可谓是…

在矩池云运行 Stable Diffusion web UI,使用v1.5模型和 ControlNet 插件

今天给大家介绍下如何在矩池云使用 Stable Diffusion web UI v1.5 模型和 Stable Diffusion ControlNet 插件。 租用机器 租用机器需要选择内存大于8G的机器,比如 A2000,不然 Stable Diffusion web UI 启动加载模型会失败。(Killed 内存不足…