GBDT精讲

news2024/12/28 20:25:40

GBDT算法的流程

首先GBDT是通过采用加法模型(即基函数的线性组合),以及不断减小训练过程产生的残差来达到将数据分类或回归的算法。

GBDT通过多轮迭代,每轮迭代产生一个弱分类器,每个分类器在上一轮分类器的梯度(如果损失函数是平方损失函数,则梯度就是残差值)基础上进行训练。对弱分类器的要求一般是足够简单,并且是低方差和高偏差的。因为训练的过程是通过降低偏差来不断提高最终分类器的精度。

弱分类器一般会选择为CART TREE(也就是分类回归树)。由于上述高偏差和简单的要求 每个分类回归树的深度不会很深。最终的总分类器 是将每轮训练得到的弱分类器加权求和得到的(也就是加法模型)。

模型最终可以描述为:

模型一共训练M轮,每轮产生一个弱分类器 �(�,��) 。弱分类器的损失函数

��−1(�) 为当前的模型,gbdt通过经验风险极小化来确定下一个弱分类器的参数。具体到损失函数本身的选择也就是L的选择,有平方损失函数、0-1损失函数、对数损失函数等等。如果我们选择平方损失函数,那么这个差值其实就是我们平常所说的残差。

  • 但是其实我们真正关注的,1.是希望损失函数能够不断的减小,2.是希望损失函数能够尽可能快的减小。所以如何尽可能快的减小呢?
  • 让损失函数沿着梯度方法的下降。这个就是gbdt的gb的核心了。利用损失函数的负梯度在当前模型的值作为回归问题提升树算法中的残差的近似值去拟合一个回归树。gbdt每轮迭代的时候,都去拟合损失函数在当前模型下的负梯度。
  • 这样每轮训练的时候都能够让损失函数尽可能快的减小,尽快的收敛达到局部最优解后者全局最优解。

GBDT算法原理好用资料记录

GBDT特征选择

GBDT中的弱分类器选择的是CART回归树。GBDT中特征的选择就是CART树的生成过程中特征属性的选择。而CART回归树的生成算法可以参照之前的文章。

回归树Regression Decision Tree

GBDT并不是很多分类树。决策树分为两大类,回归树和分类树。前者用于预测实数值,如明天的温度、用户的年龄、网页的相关程度;后者用于分类标签值,如晴天/阴天/、用户性别、网页是否是垃圾页面。这里要强调的是,前者的结果加减是有意义的,如10岁+5岁-3岁=12岁,后者则无意义,如男+男+女=到底是男是女?GBDT的核心在于累加所有树的结果作为最终结果。就像前面对年龄的累加,而分类树的结果显然没办法累加的。所以GBDT中的树都是回归树,不是分类树,这点对理解GBDT相当重要(尽管GBDT调整后也可用于分类但不代表GBDT的树是分类树)。那么回归树是如何工作的呢?

作为对比,先说分类树,我们知道C4.5分类树在每次分支时,是穷举每一个feature的每一个阈值,找到使得按照feature<=阈值,和feature>阈值分成的两个分支的熵最大的feature和阈值,按照该标准分支得到两个新节点,用同样方法继续分支知道所有人都被分入性别唯一的叶子节点,或达到预设的终止条件,若最终叶子节点中的性别不唯一,则以多数人的性别作为该叶子节点的性别。

回归树总体流程也是类似,不过在每个节点(不一定是叶子节点)都会得到一个预测值,以年龄为例,该预测值等于属于这个节点的所有人年龄的平均值。分支时穷举每一个feature的每个阈值找最好的分割点,但衡量最好的标准不再是最大熵,而是最小化均方差--即(每个人的年龄-预测年龄)^2的总和/N。这很好理解,被预测出错的人数越多,错的越离谱,均方差就越大,通过最小化均方差能够找到最靠谱的分支依据。分支直到每个叶子节点上人的年龄都唯一(这太难了)或者达到预设的终止条件(如叶子个数上限),若最终叶子节点上人的年龄不唯一,则以该节点上所有人的平均年龄作为该叶子节点的预测年龄。

GB:梯度迭代 Gradient Boosting

让损失函数沿着梯度方向的下降。这个就是gbdt的gb的核心。gbdt每轮迭代的时候,都去拟合损失函数在当前模型下的负梯度。(如果损失函数使用的时平方误差损失函数,则这个损失函数的负梯度就可以用残差来代替,一下所说的残差拟合,便是使用了平方误差损失函数)

Boosting,迭代,即通过迭代多棵树来共同决策。GBDT的核心就在于,每一棵树学的是之前所有树结论和的残差,这个残差就是一个加预测值后能得到真实值的累加量。比如A的真是年龄是18岁,但第一个棵树的预测年龄是12岁,差了6岁,即残差为6岁。那么在第二棵树里我们把A的年龄设为6岁去学习,如果第二棵树真的能把A分到6岁的叶子节点,那累加两棵树的结论就是A的真实年龄;如果第二棵树的结论是5岁,则A仍然存在1岁的残差,第三棵树里A的年龄就变成1岁,继续学。这就是Gradient Boosting在GBDT中的意义。

GBDT回归问题(预测)的例子

还是年龄预测,简单起见训练集只有4个人,A、B、C、D,他们的年龄分别是14,16,24,26.其中A、B分别是高一和高三学生;C、D分别是应届毕业生和工作两年的员工。如果用一颗传统的回归决策树来训练,会得到如图1所示的结果。

图1 传统回归决策树

现在我们使用GBD他来做这件事,由于数据太少,我们限定叶子节点最多有两个,即每棵树都只有一个分支,并且限定只学两棵树。我们得到如图2所示的结果。

图2 GBDT模型

第一棵树的第一个分支和图2一样,由于A、B年龄较为相近,C、D年龄较为相近,它们被分为两拨,每拨用平均年龄作为预测值。此时计算残差(残差的意思就是:A的预测值+ A的残差=A的实际值),所以A的残差就是16-15=1(注意,A的预测值是指前面所有树累加的和,这里前面只有一棵树所以直接是15,如果还有树则需要都累加起来作为A的预测值)。进而得到A、B、C、D的残差分别为-1,1,-1,1。然后我们拿残差替代A、B、C、D的原值,到第二棵树去学习,如果我们的预测值和它们的残差相等,则只需把第二棵树的结论累加到第一棵树上就能得到真实年龄了。根据这里的数据继续往下走,第二棵树只有两个值1和-1,直接分成两个节点。此时所有人的残差都是0,即每个人都得到了真实的预测值。

换句话说,现在A、B、C、D的预测值和真实年龄一致了。

A:14岁高一学生,购物较少,经常问学长问题;预测年龄A=15-1=14

B:16岁高三学生,购物较少,经常被学习问问题,预测年龄B=15+1=16

C:24岁应届毕业生,购物较多,经常问师兄问题,预测年龄C=25-1=24

D:26岁工作两年的员工,购物较多,经常被师弟问问题,预测年龄是D=25+1=26

讲到这里我们已经把GBDT最核心的概念、运算过程讲完了!讲到这里很容易发现两个问题:

1)既然图1和图2最终效果相同,为何还需要GBDT呢?

答案是过拟合。过拟合是指为了让训练集精度更高,学到了很多"仅在训练集上成立的规律",导致换一个数据集当前规律就不适用了。其实只要允许一棵树的叶子节点足够多,训练集总是能训练到100%的准确率的(大不了最后一个叶子上只有一个instance)。再训练精度和实际精度(或测试精度)之间,后者才是我们想要真正得到的。

我们发现图1为了达到100%精度使用了3个feature(上网时长、时段、网购金额),其中分支"上网时长>1.1h"很显然已经过拟合了,这个数据集上A、B也许恰好A每天上网1.09h,B上网1.05小时,但用上网时长是不是>1.1小时来判断所有人的年龄很显然是很有悖常识对的;

想对来说图2的boosting虽然用了两棵树,但其实只用了2个feature就搞定了,后一个feature是问答比例,显然图2的依据更靠谱。(当然这是造的数据,所以才会如此完全靠谱,实际中靠不靠谱总是相对的)。Boosting的最大好处在于,每一步的残差计算其实变相的增大了分错样本的权重,而已经分对的样本则都趋于0。这样后面的树就能越来越专注那些前面被分错的样本。就像我们做互联网,总是先解决60%用户的需求凑合着,再解决35%用户的需求,最周才关注那5%人的需求,这样就能逐渐把产品做好,因为不同类型用户需求可能完全不同,需要分别独立分析。如果反过来做,或者刚上来就一定要做到尽善尽美,往往竹篮打水一场空。

2)这不是boosting把?Adaboost可不是这么定义的。

这是boosting,但不是adaboost.GBDT 不是adaboost desicion tree.adaboost是另一种boosting方法,只能用于二分类,它按分类对错,分配不同的weight,计算cost function时使用这些weight,从而让"错分的样本权重越来越大,使它们更被重视"。

GBDT分类及分类问题例子

首先明确一点,gbdt无论用于分类还是回归一直都是使用的CART回归树。不会因为我们所选择的任务是分类任务就选用分类树,这里面的核心是因为GBDT每轮的训练是在上一轮的训练的残差基础之上进行训练的。这里的残差就是当前模型的负梯度值。这个要求每轮迭代的时候,弱分类器的输出的结果相减是有意义的。残差相减是有意义的。

如果选用的弱分类器是分类树,类别相减是没有意义的。上一轮输出的是样本属于A类,本轮训练输出的是样本x属于B类。A和B很多时候甚至都没有比较的意义,A类-B类是没有意义的。

我们具体到分类这个任务上面来,我们假设样本X总共有K类。来了一个样本x,我们需要使用gbdt来判断x属于样本的哪一类。

图3 gbdt多分类算法流程

第一步 我们在训练的时候,是针对样本X每个可能的类都训练一个分类回归树。举例说明,目前样本有三类,也就是K=3。样本x属于第二类。那么针对该样本x的分类结果,其实我们可以用一个三维向量[0,1,0]来表示,0表示样本不属于该类,1表示样本属于该类。由于样本已经属于第二类了,所以第二类对应的向量维度为1,其他位置为0.

针对样本有三类的情况,我们实质上是在每轮的训练的时候同时训练三棵树。第一棵树针对样本x的第一类,输入为(x,0)。第二棵树输入针对样本x的第二类,输入为(x,1)。第三棵树针对样本x的第三类,输入为(x,0)。

在这里每棵树的训练过程其实就是我们之前已经提到过的CART TREE的生成过程。在此处我们参照之前的生成树的程序既可以解出三棵树,以及三棵树对样本x类别的预测值f1(x), f2(x), f3(x)。那么在此类训练中,我们仿照多分类的逻辑回归,使用softmax来产生概率,则属于类别1的概率

并且我们可以针对类别1求出残差 �11(�)=0−�1(�) ;类别2求出残差 �22(�)=1−�2(�) ;类别3求出残差 �33(�)=0−�3(�) 。然后开始第二轮训练 针对第一类输入为 (�,�11(�)) ,针对第二类输入为 (�,�22(�)) ,针对第三类输入为 (�,�33(�)) 。继续训练出三棵树。一直迭代M轮。每轮构建3棵树。

所以当K =3.我们其实应该有三个式子:

当训练完毕以后,新来一个样本x1,我们需要预测该样本的类别的时候,便可以由这三个式子产生三个值,f1(x)、f2(x)、f3(x)。样本属于某个类别c的概率为

接下来我们用iris数据集中的六个数据作为例子,展示gbdt多分类的过程。

这是一个6个样本的三分类问题。我们需要根据这个花的花萼长度、花萼宽度、花瓣长度、花瓣宽度来判断这个花属于山鸢尾、杂色鸢尾、维吉尼亚鸢尾。具体应用到gbdt多分类算法上面。我们用一个三维向量来标志样本的label。[1,0,0]表示样本属于山鸢尾,[0,1,0]表示样本属于杂色鸢尾,[0,0,1]表示样本属于维吉尼亚鸢尾。

gbdt的多分类是针对每个类都独立训练一个cart tree。所以这里,我们将针对山鸢尾类别训练一个cart tree1。杂色鸢尾训练一个cart tree2。维吉尼亚鸢尾训练一个cart tree3,这三个树相互独立。

我们以样本1为例。针对cart tree1的训练样本是[5.1,3.5,1.4,0.2],label是1,最终输入到模型当中的为[5.1,3.5,1.4,0.2,1]。针对cart tree2的训练样本也是[5.1,3.5,1.4,0.2],但是label为0,最终输入到模型当中的为[5.1,3.5,1.4,0.2,0]。针对cart tree3的训练样本也是[5.1,3.5,1.4,0.2],label也为0,最终输入到模型当中的为[5.1,3.5,1.4,0.2,0]。

下面我们来看cart tree1是如何生成的,其他树cart tree2和cart tree3的生成方式是一样的。cart tree的成成过程是从四个特征中找一个特征作为cart tree1的节点,以及该特征的什么特征值作为切分点。即,生成的过程其实非常简单,问题1.是哪个特征最合适?2.是这个特征的什么特征值作为切分点?在这里我们的方式是遍历所有的可能性,找到一个最好的特征和它对应的最优特征值可以让当前式子(记为式子(1))的值最小。

我们以第一个特征的第一个特征值为例。R1为所有样本中花萼长度小于5.1cm的样本集合,R2为所有样本当中花萼长度大于等于5.1cm的样本集合。所以R1={2},R2={1,3,4,5,6},如下图所示

节点分裂示意图

y1为R1所有样本的label的均值1/1=1.y2为R2所有样本的label的均值(1+0+0+0+0)/5=0.2。

接下来我们计算这种划分方式下的损失值(其实就是每个样本的真实值与其预测值的均方差),(1-0.2)^2+(1-1)^2+(0-0.2)^2+(0-0.2)^2+(0-0.2)^2=0.8。因此按照第一个特征第一个特征值分裂损失值为0.8.依次类推,我们计算出每个特征的每个特征值作为切分点的损失值,一共有24种情况,4个特征*每个特征有6个特征值。在这里让式子(1)最小的切分点是特征为花萼长度,特征值为5.1cm,对应的损失值为0.8.

于是我们的预测函数此时可以得到:

此处 R1 = {2},R2 = {1,3,4,5,6},y1 = 1,y2 = 0.2。训练完以后的最终式子(记为式子(2))为:

这样我们就得到了第一轮的cart tree1,假设我们只训练一轮,那对于新样本,我们根据训练好的cart tree1,cart tree2, cart tree3计算该新样本在3个tree下的预测值(类似于式子(2)),分别记为f1(x),f2(x),f3(x)。那样本属于类别1的概率为

�1=���(�1(�))/∑�=13���(��(�))

附录

树模型优缺点

因为树要一个特征一个特征去计算对应的指标来选择合适的特征来作为切分点,所以特征数量很多的情况下 计算量太多,因此不适合处理高维稀疏数据。

泰勒公式

梯度下降法

牛顿法

GBDT的论文为:GREEDY FUNCTION APPROXIMATION: A GRADIENT BOOSTING MACHINE

实际上GBDT泛指所有梯度提升树算法,包括XGBoost,它也是GBDT的一种变种,这里为了区分它们,GBDT特指“Greedy Function Approximation:A Gradient Boosting Machine”里提出的算法,它只用了一阶导数信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/704766.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Quiz 4: Functions | Python for Everybody 配套练习_解题记录

文章目录 课程简介Quiz 4: Functions 单选题&#xff08;1-9&#xff09;编程题Exercise 4.6 课程简介 Python for Everybody 零基础程序设计&#xff08;Python 入门&#xff09; This course aims to teach everyone the basics of programming computers using Python. 本课…

JAVA2

文章目录 前言 前言 创建&#xff0c;编译java&#xff08;每4修改一次就要重新编译&#xff01;&#xff09; 第一个程序&#xff1a; 解决中文乱码问题&#xff1a; 效果&#xff1a; 总结&#xff1a;

管理类联考——英语——趣味篇——词根词汇——按频次分类——高频词汇——List1

优化原书记忆方法&#xff0c;轻松搞定考研单词 摒弃了传统的以字母顺序排序的方法&#xff0c;结合近20年考研真题&#xff0c;通过电脑搜索等方法对核心词进行科学统计&#xff0c;将核心词有机地分为高频词汇、常考词汇、中频词汇、低频词汇等4大部分&#xff0c;同时还补充…

一个三极管和稳压管构成的简易稳压电源

一个三极管和稳压管构成的简易稳压电源 先看电路 原理分析&#xff1a; 实际使用中可以加入合适的滤波电容。 上面的电路原理看着比较简单&#xff0c;但还是有不少要注意的地方。 来看看仿真电路的结果&#xff1a; 可以看到&#xff0c;输出的电压并不是我们想要的结果&am…

高压线路距离保护程序逻辑原理(五)

六、系统振荡的判断与振荡闭锁程序逻辑框图 &#xff08;一&#xff09;系统振荡概述 电力系统的振荡大致可以分为两种情况&#xff1a;一种是静稳破坏引起系统振荡&#xff0c;另一种是由于系统内故障切除时间过长&#xff0c;导致系统的两侧电源之间的不同步而引起的系统振…

【单片机】MSP430单片机,1.3寸 IIC OLED ,显示驱动

文章目录 main.coled.holedfont.h main.c #include <msp430.h> #include "OLED.h"int main( void ) {WDTCTL WDTPW WDTHOLD; /* Stop WDT */if ( CALBC1_8MHZ 0xFF ) /* If calibration constant erased */{while ( 1 ); /* do n…

C++ DAY4

1.思维导图 2.运算符重载 #include <iostream> using namespace std;class Person { private:int age;int *p; public://1.无参构造Person():p(new int(89)){age 18;}//2.有参构造Person(int age,int num){this->age age;this->pnew int(num);}//3.拷贝构造函数…

数据库中的日期函数DM和mysql都通用,计算年月日时分秒,获取日期之间相差的值

select MINUTE(date) from t_test; year month day hour minute second --对应年月日时分秒 select date from t_test select MINUTE(createtime),to_char(sysdate(),yyyy-MM-dd) select TIMESTAMPDIFF(minute,date,now()),date from t_test DateUtil.between(new Date(),ne…

小程序底层技术机制解读:版本更新与底层运行原理

&#x1f482; 个人网站:【海拥】【游戏大全】【神级源码资源网】&#x1f91f; 前端学习课程&#xff1a;&#x1f449;【28个案例趣学前端】【400个JS面试题】&#x1f485; 寻找学习交流、摸鱼划水的小伙伴&#xff0c;请点击【摸鱼学习交流群】 目录 前言小程序版本更新机制…

Day.3 LeetCode刷题练习(反转链表)

题目&#xff1a; 例子&#xff1a; 分析题目&#xff1a; 分析题目&#xff0c;因为是一个单链表所以不能找到尾后往前改变&#xff0c;所以不妨换个思路从前往后进行修改链表链接关系 用到三个指针 指针cur指向所要改变的节点链接关系、指针prev指向所要改变节点的前一个节点…

白盒测试入门概念

白盒测试的度量 根据待测产品的内部实现细节来设计测试用例白盒测试的执行手段是可以涵盖单元测试、集成测试使用代码覆盖率作为白盒测试的主要度量指标 代码覆盖率常见概念 语句覆盖&#xff1a;每行代码都要覆盖至少一次判定覆盖&#xff1a;判定表达式的真假至少覆盖一次…

CSS文字阴影渐变动画效果

最近项目中需要一些简单的动画效果&#xff0c;就写了一个文字渐变的动画效果,纯CSS动画.文字加了一点阴影效果,看起来有发光哒感觉~ 效果如下图 大家可以自己拷贝代码亲自试一试 代码看下面 html <div class"son"> Particia </div>css .son {color…

SpringBoot2+Vue2实战(七)springboot集成jwt

一、集成jwt JWT依赖 <!-- JWT --><dependency><groupId>com.auth0</groupId><artifactId>java-jwt</artifactId><version>3.10.3</version></dependency> UserDto import cn.hutool.core.annotation.Alias; import …

【数据库】mysql 管理员密码丢失解决方案

本次操作环境是mysql5.7.24版本 问题&#xff1a;由于各种原因&#xff0c;数据库管理员密码丢失&#xff0c;无法登陆数据库 解决方法&#xff1a; 1、进入my.cnf文件进行修改配置 在[mysqld]下添加 skip-grant-tables 2、重启mysql服务 service mysql stop service mysql …

Fiddler如何延迟接口响应时间

需求描述&#xff1a;通过延迟接口响应时间来mock响应超时的测试场景 解决方法&#xff1a; 使用fiddler模拟接口延时请求&#xff0c;fiddler设置参数如下&#xff1a;

淘宝再夺顶级技术比赛CVPRNTIRE冠军,背后是这些提升用户体验的内容技术

不知不觉间&#xff0c;内容电商似乎已经成为人们生活中不可或缺的存在&#xff1a;在闲暇时间&#xff0c;我们已经习惯于拿出手机&#xff0c;从电商平台的直播间随手下单自己心仪的商品。 尽管优质的货品、实惠的价格、精致的场景布置、有趣的内容输出都是非常关键的影响因…

【JavaScript】文档注释详解

文章目录 什么是文档注释为什么要写文档注释不使用文档注释存在的隐患使用函数成员时的书写问题调用函数时功能使用问题 文档注释官方标签函数参数标签 param参数类型 {}参数注释对象属性属性注释使用带有对象属性注释的参数 返回值标签 returns注释 作者标签 author许可证标签…

spring cloud 之 eureka

Eureka概述 Spring Cloud封装了Netflix 公司开发的Eureka模块来实现服务治理&#xff0c;SpringCloud将它集成在其子项目spring-cloud-netflix中 在服务注册与发现中&#xff0c;有一个注册中心。当服务器启动的时候&#xff0c;会把当前自己服务器的信息比如服务地址通讯地址…

ubuntu 显卡驱动/cuda/cudnn

显卡驱动 https://www.bilibili.com/video/BV1Zc41137tU/?spm_id_from333.999.0.0&vd_sourced75fca5b05d8be06d13cfffd2f4f7ab5 使用recommended的驱动&#xff0c;open和无open的区别在于无open更适合发挥NVIDIA显卡的全部功能和性能&#xff0c;特别是GPU加速计算等任…

K8S集群安装与部署(Linux系统)

一、环境说明&#xff1a;CentOS7、三台主机&#xff08;Master&#xff1a;10.0.0.132、Node1&#xff1a;10.0.0.133、Node2&#xff1a;10.0.0.134&#xff09; 二、准备环境&#xff1a; 映射 关闭防火墙 三、etcd集群配置 安装etcd&#xff08;Master&#xff09; [ro…