花了3周理解的xgboost算法原理

news2024/11/17 17:53:22

文章目录

  • 算法流程
  • CART树
  • 最佳节点值
  • 最佳树结构

算法流程

先学决策树,再学随机森林,最后才来到xgboost。本以为如此平滑地过渡过来,会容易一些,没想到还是这么艰难。零零散散花了3周多的时间,看了好多文章的解释和阐述,才勉强get到xgboost的算法原理。

这其中,有两个参考来源,我个人比较推荐:一个是文字,一个是视频。

此处直接给出xgboost的算法流程图,方便我们直观认识xgboost。针对一个训练集,xgboost首先使用CART树训练得到一个模型,这样针对每个样本都会产生一个偏差值;然后将样本偏差值作为新的训练集,继续使用CART树训练得到一个新模型;以此重复,直至达到某个退出条件为止。

最终的xgboost模型就是将上述所有模型进行加和。假设一共有M个模型,每个模型的输出被定义为 f i f_i fi,那么xgboost模型的最终输出 y ^ i \hat y_i y^i
y ^ i = ∑ i = 1 M f i ( x i ) \hat y_i=\sum_{i=1}^Mf_i(x_i) y^i=i=1Mfi(xi)

显然,xgboost和随机森林一样,也是多个模型的集成,但是它们之间还存在诸多不同。用专业术语来描述的话,那就是xgboost是Boosting的方式,核心特点在于降低偏差,逻辑是串行的;而随机森林是Bagging的方式,核心特点在于降低方差,逻辑是并行的。

弄清楚了算法流程,我们还需要关注流程中的细节。需要搞明白的事情有两个:(1)什么是CART树?(2)CART树模型如何得到?

CART树

简单理解,CART树首先是一个树;在此基础上,每个叶子节点 j j j都会被赋予一个节点值 w j w_j wj

假设第 i i i个样本 x i x_i xi和第 j j j个节点之间的映射关系为
j = q ( x i ) j=q(x_i) j=q(xi)
以下图为例: q ( x 1 ) = 2 , q ( x 2 ) = 1 , q ( x 3 ) = 2 q(x_1)=2,q(x_2)=1,q(x_3)=2 q(x1)=2,q(x2)=1,q(x3)=2

那么第 i i i个样本的节点值为
w q ( x i ) w_{q(x_i)} wq(xi)
反过来说的话,第 j j j个节点的样本集合 I j I_j Ij可以表达为
I j = { x i ∣ q ( x i ) = j } I_j= \{x_i|q(x_i)=j \} Ij={xiq(xi)=j}
还是以上图为例: I 1 = { 2 } , I 2 = { 1 , 3 } I_1=\{2\}, I_2=\{1,3\} I1={2},I2={1,3}

从CART树的概念上,我们很容易发现,确定一颗CART树需要两类数据:树结构和节点值。接下来,我们先研究清楚,树结构固定时的最佳节点值优化策略;在此基础上,再返回来尝试确定最佳的树结构。

最佳节点值

当树结构固定后,每个样本落入的节点随即确定。此时,定义如下的目标函数来衡量xgboost对整体样本的误差
obj = ∑ i = 1 n l ( y i , y ^ i ( M ) ) + ∑ m = 1 M Ω ( f m ) \text{obj}=\sum_{i=1}^nl(y_i,\hat y_i^{(M)})+\sum_{m=1}^M\Omega{(f_m)} obj=i=1nl(yi,y^i(M))+m=1MΩ(fm)
其中,第一项的含义是样本自身误差, n n n表示样本数量, l ( ⋅ ) l(·) l()是衡量样本误差的函数,如MSE等, y i y_i yi为第 i i i个样本的真值, y ^ i ( M ) \hat y_i^{(M)} y^i(M)为第 i i i个样本的预测值;第二项是正则项,正则项之前在线性模型优化:岭回归和Lasso回归中用到过,主要目的是降低过拟合风险,此处的目标同样是降低过拟合的风险。

第二项简单,先处理一下
∑ m = 1 M Ω ( f m ) = ∑ m = 1 M − 1 Ω ( f m ) + Ω ( f M ) \sum_{m=1}^M\Omega{(f_m)}=\sum_{m=1}^{M-1}\Omega{(f_m)}+\Omega{(f_M)} m=1MΩ(fm)=m=1M1Ω(fm)+Ω(fM)
从算法流程可知,CART树是一颗一颗建立起来的,当我们要优化第 M M M颗树的时候,前 M − 1 M-1 M1颗树已经完成了计算,即这些树的正则项值已经确定,所以我们在优化时可以不考虑;对于第 M M M颗树, Ω \Omega Ω被定义为
Ω ( f M ) = γ T + 1 2 λ ∑ j = 1 T w j 2 \Omega(f_M)=\gamma T+\frac{1}{2}\lambda\sum_{j=1}^Tw_j^2 Ω(fM)=γT+21λj=1Twj2
该定义可以如此理解:一方面, T T T值大,说明树的深度比较深,过拟合的概率就会变高,所以使用 γ \gamma γ进行惩罚;另一方面, w w w值大,说明该树在整个模型中会占据较大的比重,即预测结果主要依赖该树,此时过拟合风险也会变高,所以需要再使用 λ \lambda λ进行惩罚。

现在回到第一项。有了第二项的具体表达式后,如果再给定 l ( ⋅ ) l(·) l()的表达式,看起来好像可以直接通过梯度求导得到极值。但实际上该方式是不可行的。这主要是因为 y ^ \hat y y^是通过树模型得到的,所以该值并不连续,所以不可导。

既然如此,我们就需要其它的解决方案。好在,因为前 M − 1 M-1 M1颗CART树已经确定,所以只需要关注第 M M M颗树的节点值即可,所以第一项可以转化为
∑ i = 1 n l ( y i , y ^ i ( M ) ) = ∑ j = 1 T ∑ i ∈ I j l ( y i , y ^ i ( M ) ) \sum_{i=1}^nl(y_i,\hat y_i^{(M)})=\sum_{j=1}^T\sum_{i\in I_j}l(y_i,\hat y_i^{(M)}) i=1nl(yi,y^i(M))=j=1TiIjl(yi,y^i(M))
该转化的价值在于将误差的统计逻辑从样本的加和转化为节点的加和,这样就可以和 Ω \Omega Ω使用相同的变量,便于表达式的合并操作。

为了处理该项,我们将对其进行泰勒展开,并保留至二阶项。

先回顾泰勒公式
f ( x + δ x ) = f ( x ) + f ′ ( x ) δ x + 1 2 f ′ ′ ( x ) δ x 2 f(x+\delta x)=f(x)+f'(x)\delta x+\frac{1}{2}f''(x)\delta x^2 f(x+δx)=f(x)+f(x)δx+21f′′(x)δx2
此处把 y ^ i ( M − 1 ) \hat y_i^{(M-1)} y^i(M1)定义为 x x x,那么 δ x \delta x δx就是 w j w_j wj,我们照葫芦画瓢进行泰勒展开
l ( y i , y ^ i ( M ) ) = l ( y i , y ^ i ( M − 1 ) + w j ) = l ( y i , y ^ i ( M − 1 ) ) + l ′ ( y i , y ^ i ( M − 1 ) ) w j + 1 2 l ′ ′ ( y i , y ^ i ( M − 1 ) ) w j 2 l(y_i, \hat y_i^{(M)})=l(y_i, \hat y_i^{(M-1)}+w_j)=l(y_i, \hat y_i^{(M-1)})+l'(y_i, \hat y_i^{(M-1)})w_j+\frac{1}{2}l''(y_i, \hat y_i^{(M-1)})w_j^2 l(yi,y^i(M))=l(yi,y^i(M1)+wj)=l(yi,y^i(M1))+l(yi,y^i(M1))wj+21l′′(yi,y^i(M1))wj2
第一个为常量,也可以不考虑。令 g i = l ′ ( y i , y ^ i ( M − 1 ) ) g_i=l'(y_i, \hat y_i^{(M-1)}) gi=l(yi,y^i(M1)) h i = l ′ ′ ( y i , y ^ i ( M − 1 ) ) h_i=l''(y_i, \hat y_i^{(M-1)}) hi=l′′(yi,y^i(M1))

总的目标函数变为

obj = ∑ j = 1 T ∑ i ∈ I j [ g i w j + 1 2 h i w j 2 ] + γ T + 1 2 λ ∑ j = 1 T w j 2 \text{obj}=\sum_{j=1}^T\sum_{i\in I_j}[g_iw_j+\frac{1}{2}h_iw_j^2]+\gamma T+\frac{1}{2}\lambda\sum_{j=1}^Tw_j^2 obj=j=1TiIj[giwj+21hiwj2]+γT+21λj=1Twj2
由于 w j w_j wj i i i无关,所以可以调整为
obj = ∑ j = 1 T [ w j ∑ i ∈ I j g i + 1 2 w j 2 ∑ i ∈ I j h i ] + γ T + 1 2 λ ∑ j = 1 T w j 2 \text{obj}=\sum_{j=1}^T[w_j\sum_{i\in I_j}g_i+\frac{1}{2}w_j^2\sum_{i\in I_j}h_i]+\gamma T+\frac{1}{2}\lambda\sum_{j=1}^Tw_j^2 obj=j=1T[wjiIjgi+21wj2iIjhi]+γT+21λj=1Twj2
合并 w j 2 w_j^2 wj2项,得到
obj = ∑ j = 1 T [ w j ∑ i ∈ I j g i + 1 2 w j 2 ( λ + ∑ i ∈ I j h i ) ] + γ T \text{obj}=\sum_{j=1}^T[w_j\sum_{i\in I_j}g_i+\frac{1}{2}w_j^2(\lambda+\sum_{i\in I_j}h_i)]+\gamma T obj=j=1T[wjiIjgi+21wj2(λ+iIjhi)]+γT

G j = ∑ i ∈ I j g i G_j=\sum_{i\in I_j}g_i Gj=iIjgi H j = ∑ i ∈ I j h i H_j=\sum_{i\in I_j}h_i Hj=iIjhi,上式可以简化为
obj = ∑ j = 1 T [ w j G j + 1 2 w j 2 ( λ + H j ) ] + γ T \text{obj}=\sum_{j=1}^T[w_jG_j+\frac{1}{2}w_j^2(\lambda+H_j)]+\gamma T obj=j=1T[wjGj+21wj2(λ+Hj)]+γT
这是个二元一次表达式,最优解为
w j ∗ = − G j λ + H j w_j^*=-\frac{G_j}{\lambda+H_j} wj=λ+HjGj
对应的最优目标函数值为
obj ∗ = − 1 2 ∑ j = 1 T G j 2 λ + H j + γ T \text{obj}^*=-\frac{1}{2}\sum_{j=1}^T\frac{G_j^2}{\lambda+H_j}+\gamma T obj=21j=1Tλ+HjGj2+γT

此处还需要描述一下 g i g_i gi h i h_i hi具体是如何计算的。举个例子,假设我们正在优化第11棵CART树,也就是说前10棵 CART树已经确定了。这10棵树对样本( x i , y i = 1 x_i,y_i=1 xi,yi=1)的预测值是 y i = − 1 y_i=-1 yi=1,假设我们现在是做分类,我们的损失函数是
L ( θ ) = ∑ i y i l n ( 1 + e − y ^ i ) + ( 1 − y i ) l n ( 1 + e y ^ i ) L(\theta)=\sum_i y_iln(1+e^{-\hat y_i})+(1-y_i)ln(1+e^{\hat y_i}) L(θ)=iyiln(1+ey^i)+(1yi)ln(1+ey^i)
由于 y i = 1 y_i=1 yi=1,损失函数变为
L ( θ ) = l n ( 1 + e − y ^ i ) L(\theta)=ln(1+e^{-\hat y_i}) L(θ)=ln(1+ey^i)
求梯度,结果为
e − y ^ i 1 − e − y ^ i \frac{e^{-\hat y_i}}{1-e^{-\hat y_i}} 1ey^iey^i
y ^ i = − 1 \hat y_i=-1 y^i=1带入梯度表达式,便得到 g 11 = − 0.27 g_{11}=-0.27 g11=0.27

针对梯度表达式继续求导,得到二阶导数表达式后,再带入 y ^ i \hat y_i y^i的值,便可得到 h 11 h_{11} h11

再理解一下 w j ∗ w_j^* wj。假设节点处只有一个样本,此时
w j ∗ = ( 1 λ + h j ) ( − g j ) w_j^*=(\frac{1}{\lambda+h_j})(-g_j) wj=(λ+hj1)(gj)
− g j -g_j gj代表的是负梯度方向,即目标函数值下降最快的方向,这是符合我们认知的; 1 λ + h j \frac{1}{\lambda+h_j} λ+hj1可以看作是学习率,如果 h j h_j hj值大,表明梯度变化比较大,即微小的扰动都会带来目标函数的极大变化,此时应该让学习率小一些,所以也是符合我们的认知的。

综上,只要树结构确定了,那么就可以通过以上的方法得到每个节点的最优值 w i w_i wi,从而使得目标函数值最小。现在,我们只剩最佳树结构了。

最佳树结构

为了确定最佳的树结构,本节介绍一种常用的方法:贪心算法。

如下图所示。针对当前节点,有A、B和C三个样本,可以计算得到该节点处的最优目标函数值为
obj 0 = γ − 1 2 ( G A + G B + G C ) 2 H A + H B + H C + λ \text{obj}_0=\gamma-\frac{1}{2}\frac{(G_A+G_B+G_C)^2}{H_A+H_B+H_C+\lambda} obj0=γ21HA+HB+HC+λ(GA+GB+GC)2

尝试将该节点拆分,假设可能出现三种情况,[A, BC]、[C, AB]和[B, AC],对应的最优目标函数值分别为
obj 1 = 2 γ − 1 2 G A 2 H A + λ − 1 2 ( G B + G C ) 2 H B + H C + λ \text{obj}_1=2\gamma-\frac{1}{2}\frac{G_A^2}{H_A+\lambda}-\frac{1}{2}\frac{(G_B+G_C)^2}{H_B+H_C+\lambda} obj1=2γ21HA+λGA221HB+HC+λ(GB+GC)2
obj 2 = 2 γ − 1 2 ( G A + G B ) 2 H A + H B + λ − 1 2 ( G C ) 2 H C + λ \text{obj}_2=2\gamma-\frac{1}{2}\frac{(G_A+G_B)^2}{H_A+H_B+\lambda}-\frac{1}{2}\frac{(G_C)^2}{H_C+\lambda} obj2=2γ21HA+HB+λ(GA+GB)221HC+λ(GC)2
obj 3 = 2 γ − 1 2 ( G A + G C ) 2 H A + H C + λ − 1 2 ( G B ) 2 H B + λ \text{obj}_3=2\gamma-\frac{1}{2}\frac{(G_A+G_C)^2}{H_A+H_C+\lambda}-\frac{1}{2}\frac{(G_B)^2}{H_B+\lambda} obj3=2γ21HA+HC+λ(GA+GC)221HB+λ(GB)2

分别计算3种情况下的目标函数变化值,即: obj 1 − obj 0 \text{obj}_1-\text{obj}_0 obj1obj0 obj 2 − obj 0 \text{obj}_2-\text{obj}_0 obj2obj0 obj 3 − obj 0 \text{obj}_3-\text{obj}_0 obj3obj0。然后取最大的变化值所对应的拆分结果作为下次该节点的最佳拆分方式。

所以确定最佳树结构的流程是:
从树的深度为0开始:
(1)对每个叶节点枚举所有的可用特征;
(2)针对每个特征,把属于该节点的训练样本根据该特征值进行升序排列,通过以上的贪心逻辑来决定该特征的最佳拆分点,并记录该特征的拆分收益;
(3)选择收益最大的特征作为拆分特征,用该特征的最佳拆分点作为拆分位置,在该节点上拆分出左右两个新的叶节点,并为每个新节点关联对应的样本集;
(4)回到第1步,重复执行直到满足特定条件为止;

至此,总算是搞明白了xgboost的算法原理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/682231.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据从0到1的完美落地之Flume案例2

案例演示 实时采集(监听目录):Spool File HDFS Spool 是Source来源于目录,有文件进入目录就摄取,File Channel将它暂存到磁盘,最终目的地是HDFS 即只要某个目录不断有文件,HDFS上也会同步到所有数据。 配置方案 [ro…

机器学习:基于逻辑回归对航空公司乘客满意度的因素分析

机器学习:基于逻辑回归对航空公司乘客满意度的因素分析 作者:i阿极 作者简介:数据分析领域优质创作者、多项比赛获奖者:博主个人首页 😊😊😊如果觉得文章不错或能帮助到你学习,可以点…

前端Vue仿京东加入购物车弹框立即购买弹框shopDialog自定义弹框内容

前端Vue仿京东加入购物车弹框立即购买弹框shopDialog自定义弹框内容, 下载完整代码请访问uni-app插件市场地址:https://ext.dcloud.net.cn/plugin?id13183 效果图如下: # cc-shopDialog #### 使用方法 使用注意: 该插件需引用…

【软件基础】面向对象编程知识总结

文章目录 前言一、面向对象要解决的问题1、 软件重用性差2、软件可维护性差3、不能满足用户需求 二、面向对象的基本概念三、面向对象的特征四、面向对象的要素五、面向对象的开发方法六、面向对象的模型1、对象模型2、动态模型3、功能模型 总结1、鸭子抽象类2、鸭子类3、鸭子动…

【Tableau案例】神奇宝贝属性及实力强弱|数据可视化

提前声明:神奇宝贝的数据分析仅供参考,不涉及对于神奇宝贝的各种评价,另外我是初学tableau,涉及到使用的tableau操作可能很简单,复杂的还掌握不熟练,之后会拿时间系统学习tabelau。 数据预处理 该数据集有…

电容笔和触控笔哪个好用?推荐平价好用的电容笔

实际上,电容笔和触控笔这两款笔最大的区别,就在于它的应用范围,一个是适用电容型屏幕,一个是适用电阻型屏幕。如果你想要一个与IPAD相匹配的电容笔,苹果的Pencil将会是一个很好的选择。实际上,平替的电容笔…

Java-API简析_java.lang.Throwable类(基于 Latest JDK)(浅析源码)

【版权声明】未经博主同意,谢绝转载!(请尊重原创,博主保留追究权) https://blog.csdn.net/m0_69908381/article/details/131367906 出自【进步*于辰的博客】 其实我的【Java-API】专栏内的博文对大家来说意义是不大的。…

成都爱尔林江院长解析看懂验光单,掌握配镜“秘密”

想要更了解自己的“数据”, 想知道自己近视有没有增长, 该如何知道自己的度数呢? 到医院进行验光, 验光后得到的验光单,自己有仔细看过吗? 一串字母与数字,知道都代表着什么吗?…

警惕度量指标陷阱

本文首发于个人网站「BY林子」,转载请参考版权声明。 近日,某群有人发了领导制定的绩效考核指标: 对测试人员的工作成效进行考核,指标是发现的 Bug 的情况,甚至有参考指标细到每个小时要求发现多少 Bug,同时…

VUE L ClassStyle ⑦

目录 文章有误请指正,如果觉得对你有用,请点三连一波,蟹蟹支持✨ V u e j s Vuejs Vuejs C l a s s Class Class与 S t y l e Style Style绑定总结 文章有误请指正,如果觉得对你有用,请点三连一波,蟹蟹支持…

scratch绘制正方形 少儿编程 电子学会图形化编程scratch编程等级考试二级真题和答案解析2023年5月

目录 scratch绘制正方形 一、题目要求 1、准备工作 2、功能实现 二、案例分析</

动态规划之下降路径最小和

1. 题目分析 题目链接选自力扣 : 下降路径最小和 如果光看这个题目说明的话, 是有点抽象的. 我们结合实例 1 来看 : 总的来说就是, 起始点是第一行中的任意一点, 每个点只有三个方向可以走即向下, 左下, 右下. 当到达最后一行的任意一点即算作到达终点. 期间不同的路径上不同…

mysql单机安装

准备工作 检测项 检测命令 标配值 服务器内存 free -m 32G 硬盘 df -h 1T seLinux getenforce Disabled&#xff08;disabled指关闭&#xff0c;Enforcing指开启 文件描述符大小 ulimit -n 65535 其他优化 Other Other 清理环境 卸载服务器自带…

佩戴比较舒适的蓝牙耳机有哪些?长久佩戴舒适的蓝牙耳机推荐

​听歌、刷剧、游戏&#xff0c;运动、吃饭、睡觉等&#xff0c;要说现在年轻人除了离不开手机之外&#xff0c;还有就是蓝牙耳机了&#xff01;当然&#xff0c;随着蓝牙耳机的快速发展&#xff0c;各种各样的蓝牙耳机都有&#xff0c;导致很多人不知道耳机怎么选了&#xff0…

四大因素解析:常规阻抗控制为什么只能是10%?

随着高速信号传输&#xff0c;对高速PCB设计提出了更高的要求&#xff0c;阻抗控制是高速PCB设计常规设计&#xff0c;PCB加工十几道工序会存在加工误差&#xff0c;当前常规板厂阻抗控制都是在10%的误差。理论上&#xff0c;这个数值是越小越好&#xff0c;为什么是10%&#x…

Git进阶系列 | 7. Git中的Cherry-pick提交

Git是最流行的代码版本控制系统&#xff0c;这一系列文章介绍了一些Git的高阶使用方式&#xff0c;从而帮助我们可以更好的利用Git的能力。本系列一共8篇文章&#xff0c;这是第7篇。原文&#xff1a;Cherry-Picking Commits in Git[1] 在本系列的第5部分中&#xff0c;讨论了r…

Facebook如何与品牌合作,提升用户体验?

Facebook是全球最大的社交媒体平台之一&#xff0c;每天有数亿用户在上面发布内容、互动交流。对于品牌来说&#xff0c;与Facebook合作可以帮助它们扩大影响力、吸引更多潜在客户。 但是&#xff0c;与Facebook合作不仅仅是在平台上发布广告&#xff0c;还需要更深入的合作来…

Ramnit病毒分析

概述 Ramnit病毒是一个相对古老的病毒&#xff0c;使用会感染系统内的exe和html文件&#xff0c;通过文件分发和U盘传播。 样本的基本信息 Verified: Unsigned Link date: 19:02 2008/2/12 Company: SOFTWIN S.R.L. Description: BitDefender Management Console MachineTyp…

王道操作系统学习笔记(3)——内存管理

前言 本文介绍了操作系统中的内存管理&#xff0c;文章中的内容来自B站王道考研操作系统课程&#xff0c;想要完整学习的可以到B站官方看完整版。 3.1.1&#xff1a;内存基本知识&#xff08;指令工作原理、编译、链接、逻辑地址到物理地址的转换&#xff09; 内存可存放数据…

【yocto1】利用yocto工具构建嵌入式Linux系统

文章目录 1.获取Yocto软件源码2.初始化Yocto构建目录2.1 imx-setup-release.sh脚本运行2.2 imx-setup-release.sh脚本解析2.2.1 setup-environment脚本解析 3.构建嵌入式Linux系统3.1 BitBake构建系统3.2 BitBake构建系统过程简要解析3.2.1 解析Metadata基本配置Metadatarecipe…