GBDT算法原理及实战

news2025/1/13 6:03:41

1.什么是GBDT算法

  GBDT(Gradient Boosting Decision Tree),全名叫梯度提升决策树,是一种迭代的决策树算法,又叫 MART(Multiple Additive Regression Tree),它通过构造一组弱的学习器(树),并把多棵决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。

在这里插入图片描述

  GBDT主要由三个概念组成:Regression Decistion Tree(即DT),Gradient Boosting(即GB),Shrinkage (算法的一个重要演进分枝,目前大部分源码都按该版本实现)。

  • DT:GBDT中的树都是回归树,不是分类树;将所有树的结果累加起来作为最终的结果。
  • GB:沿着梯度方向,构造一系列的弱分类器函数,并以一定的权重组合起来,形成最终决策的强分类器。
  • Shrinkage:每次走一小步逐渐逼近结果,要比每次迈一大步很快逼近结果的方式更容易避免过拟合。

2.Boosting核心思想

  Boosting方法训练基分类器时采用串行的方式,各个基分类器之间有依赖。它的基本思路是将基分类器(过拟合)层层叠加,每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重。测试时,根据各层分类器的结果的加权得到最终结果。

在这里插入图片描述

  Bagging 与 Boosting 的串行训练方式不同,Bagging 方法在训练过程中,各基分类器(欠拟合)之间无强依赖,可以进行并行训练。

3.GBDT原理详解

  • 所有弱分类器的结果相加等于预测值。
  • 每次都以当前预测为基准,下一个弱分类器去拟合误差函数对预测值的残差(预测值与真实值之间的误差)。
  • GBDT的弱分类器使用的是树模型。

在这里插入图片描述

上图是一个非常简单帮助理解的示例,用 GBDT 去预测年龄:

  • 第一个弱分类器(第一棵树)预测一个年龄(如 20 20 20 岁),计算发现误差有 10 10 10 岁;
  • 第二棵树预测拟合残差,预测值 ,计算发现差距还有 4 4 4 岁;
  • 第三棵树继续预测拟合残差,预测值 3 3 3,发现差距只有 1 1 1 岁了;
  • 第四课树用 1 1 1 岁拟合剩下的残差,完成。

  最终,四棵树的结论加起来,得到 30 30 30 岁这个标注答案(实际工程实现里,GBDT 是计算负梯度,用负梯度近似残差)。

3.1 GBDT与负梯度近似残差

  回归任务下,GBDT在每一轮的迭代时对每个样本都会有一个预测值,此时的损失函数为均方差损失函数
l ( y i , y ^ i ) = 1 2 ( y i − y ^ i ) 2 l(y_i,\hat y_i) = \frac{1}{2}(y_i-\hat y_i)^2 l(yi,y^i)=21(yiy^i)2
损失函数的负梯度计算如下:
− [ ∂ l ( y i − y ^ i ) ∂ y ^ i ] = ( y i − y ^ i ) -[\frac{\partial l(y_i-\hat y_i)}{\partial\hat y_i}]=(y_i-\hat y_i) [y^il(yiy^i)]=(yiy^i)
在这里插入图片描述

可以看出,当损失函数选用「均方误差损失」时,每一次拟合的值就是(真实值-预测值),即残差。

3.2 GBDT训练过程

  假定训练集只有 4 4 4 个人 ( A , B , C , D ) (A,B,C,D) (A,B,C,D),他们的年龄分别是 ( 14 , 16 , 24 , 26 ) (14,16,24,26) (14,16,24,26)。其中 A 、 B A、B AB 分别是高一和高三学生;$ C、D$ 分别是应届毕业生和工作两年的员工。

先看看用回归树来训练,得到的结果如下图所示:

在这里插入图片描述

  接下来使用GBDT训练,由于样本数据少,我们限定叶子节点最多为 2 2 2 (即每棵树都只有一个分枝),并且限定树的棵树为 2 2 2。最终训练得到的结果如下图所示:

在这里插入图片描述

从上图可知: A 、 B A、B AB年龄较为相近, C 、 D C、D CD年龄较为相近,被分为左右两支,每支用平均年龄作为预测值。

  • 计算残差(即「实际值」-「预测值」),所以 A A A 的残差 14 − 15 = − 1 14-15=-1 1415=1
  • 这里 A A A 的「预测值」是指前面所有树预测结果累加的和,在当前情形下前序只有一棵树,所以直接是 15 15 15,其他多树的复杂场景下需要累加计算作为 A A A 的预测值

在这里插入图片描述

上图中的树就是残差学习的过程了:

  • A 、 B 、 C 、 D A、B、C、D ABCD的值换作残差 − 1 、 1 、 − 1 、 1 -1、1、-1、1 1111,再构建一棵树学习,这棵树只有两个值 1 1 1 − 1 -1 1,直接分成两个节点: A 、 C A、C AC 在左边, B 、 D B、D BD 在右边。
  • 这棵树学习残差,在我们当前这个简单的场景下,已经能保证预测值和实际值(上一轮残差)相等了。
  • 我们把这棵树的预测值累加到第一棵树上的预测结果上,就能得到真实年龄,这个简单例子中每个人都完美匹配,得到了真实的预测值。

在这里插入图片描述

最终的预测过程是这样的:

A A A: 14 14 14 岁高一学生,购物较少,经常问学长问题;预测年龄 A = 15 – 1 = 14 A = 15 – 1 = 14 A=15–1=14

B B B: 16 16 16 岁高三学生;购物较少,经常被学弟问问题;预测年龄 B = 15 + 1 = 16 B = 15 + 1 = 16 B=15+1=16

C C C: 24 24 24 岁应届毕业生;购物较多,经常问师兄问题;预测年龄 C = 25 – 1 = 24 C = 25 – 1 = 24 C=25–1=24

D D D: 26 26 26 岁工作两年员工;购物较多,经常被师弟问问题;预测年龄 D = 25 + 1 = 26 D = 25 + 1 = 26 D=25+1=26

综上,GBDT 需要将多棵树的得分累加得到最终的预测得分,且每轮迭代,都是在现有树的基础上,增加一棵新的树去拟合前面树的预测值与真实值之间的残差。

3.3 思考

Q:回归树与GBDT得到的结果是相同的,那么为啥还要使用GBDT?

答:防止模型过拟合

在训练精度和实际精度(或测试精度)之间,后者才是我们想要真正得到的。

在上述的实例中,回归树算法为了达到 100 % 100\% 100% 精度使用了 3 3 3 个 feature(上网时长、时段、网购金额),其中分枝“上网时长 > 1.1 h 1.1h 1.1h ” 很显然已经过拟合了,在这个数据集上 A , B A,B A,B 也许恰好 A A A 每天上网 1.09 h 1.09h 1.09h , B B B 上网 1.05 h 1.05h 1.05h,但用上网时间是不是 > 1.1 h 1.1h 1.1h 来判断所有人的年龄很显然是有悖常识的。而GBDT算法只用了 2 2 2 个feature就搞定了,后一个feature是问答比例,显然GBDT算法更靠谱。Boosting的最大好处在于,每一步的残差计算其实变相地增大了分错instance的权重,而已经分对的instance则都趋向于 0 0 0。这样后面的树就能越来越专注那些前面被分错的instance。

4.梯度提升 vs 梯度下降

  对比一下「梯度提升」与「梯度下降」。这两种迭代优化算法,都是在每1轮迭代中,利用损失函数负梯度方向的信息,更新当前模型,只不过:

  • 梯度下降中,模型是以参数化形式表示,从而模型的更新等价于参数的更新。

w t = w t − 1 − ρ t ∇ w L ∣ w = w t − 1 L = ∑ i l ( y i , f w ( w i ) ) w_t=w_{t-1}-\rho_t\nabla_wL|_{w=w_{t-1}} \\[2ex] L= \sum \limits_il(y_i,f_w(w_i)) wt=wt1ρtwLw=wt1L=il(yi,fw(wi))

  • 梯度提升中,模型并不需要进行参数化表示,而是直接定义在函数空间中,从而大大扩展了可以使用的模型种类。
    F t = F t − 1 − ρ t ∇ F L ∣ F = F t − 1 L = ∑ i l ( y i , F ( x i ) ) F_t=F_{t-1}-\rho_t\nabla_FL|_{F=F_{t-1}} \\[2ex] L= \sum \limits_il(y_i,F(x_i)) Ft=Ft1ρtFLF=Ft1L=il(yi,F(xi))

4.1 GBDT中的梯度

  基于Boosting的集成学习是通过一系列的弱学习器,进而通过不同的组合策略得到相应的强学习器。在GBDT的迭代中,假设前一轮的得到的强学习器为 f t − 1 ( x ) f_{t-1}(x) ft1(x),对应的损失函数则为 L ( y , f t − 1 ( x ) ) L(y, f_{t-1}(x)) L(y,ft1(x))。因此在新一轮迭代的目的就是找到了一个弱学习器 h t ( x ) h_t(x) ht(x),使得损失函数 L ( y , f t − 1 ( x ) + h t ( x ) ) L(y, f_{t-1}(x)+h_t(x)) L(y,ft1(x)+ht(x))达到最小。难点在于损失函数如何度量?

梯度提升算法:利用最速下降的近似方法,即利用损失函数的负梯度在当前模型的值,作为回归问题中提升树算法的残差的近似值,拟合一个回归树。第 t t t 轮的第 i i i 个样本的负梯度表示为: r t i = − [ ∂ L ( y i , f ( x i ) ) ∂ f ( x i ) ] f ( x ) = f t − 1 ( x ) r_{ti}=-\left[\frac{\partial L(y_i,f(x_i))}{\partial f(x_i)}\right]_{f(x)=f_{t-1}(x)} rti=[f(xi)L(yi,f(xi))]f(x)=ft1(x)

4.2 GBDT回归算法基本流程

输入: 训练数据集 T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . ( x N , y N ) } , x i ∈ X ⊆ R n , y i ∈ Y ⊆ R , i = 1 , 2 , . . , N T=\{(x_1,y_1),(x_2,y_2),...(x_N,y_N)\},x_i \in \mathcal X \subseteq R^n,y_i \in \mathcal Y \subseteq R,i=1,2,..,N T={(x1,y1),(x2,y2),...(xN,yN)},xiXRn,yiYR,i=1,2,..,N,最大迭代次数 M M M ,损失函数 L ( y i , f ( x i ) ) L(y_i,f(x_i)) L(yi,f(xi))

输出: 回归树 f ^ ( x ) \hat f\left(x\right) f^(x)

1.初始化 f 0 ( x ) = a r g min ⁡ c ∑ i = 1 N L ( y i , c ) f_0(x)= arg \min\limits_c\sum\limits_{i=1}\limits^{N}L(y_i,c) f0(x)=argcmini=1NL(yi,c)

2.For m = 1 m=1 m=1 to M M M:

​ 2.1 For i = 1 i=1 i=1 to N N N compute: r m i = − [ ∂ L ( y i , f ( x i ) ) ∂ f ( x i ) ] f ( x ) = f m − 1 ( x ) r_{mi}=-\left[\frac{\partial L(y_i,f(x_i))}{\partial f(x_i)}\right]_{f(x)=f_{m-1}(x)} rmi=[f(xi)L(yi,f(xi))]f(x)=fm1(x)

​ 2.2 对 r m i r_{mi} rmi 拟合回归树,得到第 m m m 棵树的叶子结点区域 R m j , j = 1 , 2 , . . . , J m R_{mj},j=1,2,...,J_m Rmj,j=1,2,...,Jm

​ 2.3 For j = 1 j=1 j=1 to J m J_m Jm compute: c m j = a r g min ⁡ c ∑ x i ∈ R m j L ( y i , f t − 1 ( x i ) + c ) c_{mj}=arg \min\limits_c\sum\limits_{x_i \in R_{mj}}L(y_i,f_{t-1}(x_i)+c) cmj=argcminxiRmjL(yi,ft1(xi)+c)

​ 2.4 更新 f m ( x ) = f m − 1 ( x ) + ∑ j = 1 J m c m j I ( x ∈ R m j ) f_m(x)=f_{m-1}(x)+\sum\limits_{j=1}\limits^{J_m}c_{mj}I(x\in R_{mj}) fm(x)=fm1(x)+j=1JmcmjI(xRmj)

3.输出回归树: f ^ ( x ) = f M ( x ) \hat f(x)=f_M(x) f^(x)=fM(x)

5.GBDT优缺点

5.1 优点

  • 预测阶段,因为每棵树的结构都已确定,可并行化计算,计算速度快。
  • 适用稠密数据,泛化能力和表达能力都不错,数据科学竞赛榜首常见模型。
  • 可解释性不错,鲁棒性亦可,能够自动发现特征间的高阶关系。

5.2 缺点

  • GBDT 在高维稀疏的数据集上,效率较差,且效果表现不如 SVM 或神经网络。
  • 适合数值型特征,在 NLP 或文本特征上表现弱。
  • 训练过程无法并行,工程加速只能体现在单棵树构建过程中。

6.随机森林 vs GBDT

6.1 相同点

  • 都是集成模型,由多棵树组构成,最终的结果都是由多棵树一起决定。
  • RFGBDT 在使用 CART 树时,可以是分类树或者回归树。

6.2 不同点

  • 训练过程中,随机森林的树可以并行生成,而 GBDT 只能串行生成。
  • 随机森林的结果是多数表决表决的,而 GBDT 则是多棵树累加之。
  • 随机森林对异常值不敏感,而 GBDT 对异常值比较敏感。
  • 随机森林降低模型的方差,而 GBDT 是降低模型的偏差。

7.GBDT实践

7.1 数据说明

  新能源汽车充电桩的故障检测问题,提供 85500 85500 85500 条训练数据(标签: 0 0 0 代表充电桩正常, 1 1 1 代表充电桩有故障),参赛者需对 36644 36644 36644 条测试数据进行预测。

7.2 数据

训练数据: data_train.csv

在这里插入图片描述

测试数据:data_test.csv
在这里插入图片描述

7.3 代码实现

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
import joblib

# 读取数据
data = pd.read_csv(r"./data/data_train.csv", sep=',' )
# ['K1K2驱动信号','电子锁驱动信号','急停信号','门禁信号','THDV-M','THDI-M']
x_columns = []
# 读取文件,去除id和label标签项,并把数据分成训练集和验证集
for x in data.columns:
    if x not in ['id', 'label']:
        x_columns.append(x)

X = data[x_columns]
y = data['label']
# 采用默认划分比例,即75%数据作为训练集,25%作为预测集。
x_train, x_test, y_train, y_test = train_test_split(X, y)

# 模型训练,使用GBDT算法
gbr = GradientBoostingClassifier(n_estimators=3000, max_depth=2, min_samples_split=2, learning_rate=0.1)
gbr.fit(x_train, y_train.ravel())
joblib.dump(gbr, './model/train_model_result4.m')  # 保存模型

y_gbr = gbr.predict(x_train)
y_gbr1 = gbr.predict(x_test)
acc_train = gbr.score(x_train, y_train)
acc_test = gbr.score(x_test, y_test)
print(acc_train)
print(acc_test)

模型预测

# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
import joblib

# 加载模型并预测
gbr = joblib.load('./model/train_model_result4.m')  # 加载模型
test_data = pd.read_csv(r"./data/data_test.csv")
testx_columns = []
for xx in test_data.columns:
    if xx not in ['id', 'label']:
        testx_columns.append(xx)
test_x = test_data[testx_columns]
test_y = gbr.predict(test_x)
test_y = np.reshape(test_y, (36644, 1))

# 保存预测结果
df = pd.DataFrame()
df['id'] = test_data['id']
df['label'] = test_y
df.to_csv("./result/data_predict.csv", header=None, index=False)

8.总结

  • GBDT是一种基于Boosting思想的迭代决策树算法,通过构造一组弱的学习器(树),并把多棵决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。
  • 集成学习Boosting(降低偏差)和Bagging(降低方差)的理解以及区别。
  • GBDT算法的思想:所有弱分类器的结果相加等于预测值;每次都以当前预测为基准,下一个弱分类器去拟合误差函数对预测值的残差(预测值与真实值之间的误差)。
  • GBDT的理解:残差(每棵树学习的都是前一棵树的残差-全局最优)和梯度(每一棵回归树通过梯度下降法学习之前输的梯度下降值–局部最优)

本文仅仅作为个人学习记录所用,不作为商业用途,谢谢理解。

参考:

1.https://www.jianshu.com/p/47e73a985ba1

2.https://www.showmeai.tech/tutorials/34?articleId=193

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/436021.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手把手教你实现控制数组某一个属性之和不能超过某一个数值变量

大家好啊,最近有个小任务,就是我表格多选后,某一项关于栏目数量之和不能超过其他变量 先看图: 代码就是: 这里有一个点就是我需要累加数量之和,其实遍历循环累加也可以 我这里用的是reduce方法 0代表设置…

机器学习实战:Python基于LDA线性判别模型进行分类预测(五)

文章目录 1 前言1.1 线性判别模型的介绍1.2 线性判别模型的应用 2 demo数据演示2.1 导入函数2.2 训练模型2.3 预测模型 3 LDA手写数字数据演示3.1 导入函数3.2 导入数据3.3 输出图像3.4 建立模型3.5 预测模型 4 讨论 1 前言 1.1 线性判别模型的介绍 线性判别模型(…

vue2使用sync修饰符父子组件的值双向绑定

1、使用场景 当我需要对一个 prop 进行“双向绑定的时候,通常用在封装弹窗组件的时候来进行使用,当然也会有其他的使用场景,只要涉及到父子组件之间需要传参的都可以使用,尽量不要使用watch监听来进行修改值,也不要尝试…

GCC编译器的使用

源文件需要经过编译才能生成可执行文件。GCC是一款强大的程序编译软件,能够在多个平台中使用。 1. GCC编译过程 主要分为四个过程:预处理、编译、汇编、链接。 1.1 预处理 主要处理源代码文件中以#开头的预编译指令。 处理规则有: &…

怎么使用midjourney?9个步骤教你学会AI创作

人工智能生成艺术作品的时代已经来临,互联网上到处都是试图创造完美提示的用户,以引导人工智能创造出正确的图像——有时甚至是错误的图像。听起来很有趣?Midjourney 是一种更常见的 AI 工具,人们用它只用几句话就能创造出梦幻般的…

【Linux系统编程】15.fcntl、lseek、truncate

目录 fcntl lseek 参数fd 参数offset 参数whence 返回值 应用场景 测试代码1 测试结果 测试代码2 测试结果 查看文件方式 truncate 参数path 参数length 测试代码3 测试结果 fcntl 获取文件属性、修改文件属性。 int flgsfcntl(fd,F_GETFL); //获取 flgs|…

微服务架构是什么?

一、微服务 1、什么是微服务? 微服务架构(通常简称为微服务)是指开发应用所用的一种架构形式。通过微服务,可将大型应用分解成多个独立的组件,其中每个组件都有各自的责任领域。在处理一个用户请求时,基于…

DOM事件流

DOM事件流 1. 常用事件绑定方式1.1 对象属性绑定1.2 addEventListener()绑定1.3 两种方式区别 2. 事件流2.1 概念2.2 事件顺序2.2.1 捕获阶段2.2.2 目标阶段2.2.3 冒泡阶段 3. 阻止事件冒泡3.1 event.stopPropagation()3.2 stopPropagation与stopImmediatePropagation区别 4. 事…

“科技助力财富增值 京华四季伴您一生”,北银理财深化线下线上客户交流互动

2023年4月12日,北银理财有限责任公司(以下简称“北银理财”)携手东方财富网启动北银理财财富号,首次采用线上直播及线下主题演讲相结合的方式,在上海举办以“科技助力财富增值,京华四季伴您一生”为主题的机…

6、springboot快速使用

文章目录 1、最佳实践1.1、引入场景依赖1.2、查看自动配置了哪些(选做)1.3、是否需要修改配置1、修改配置2、自定义加入或者替换组件3、自定义器 XXXXXCustomizer 2、开发小技巧2.1、Lombok1、引入坐标2、在IDEA中安装lombok插件(新版默认安装…

趣说数据结构 —— 前言

趣说数据结构 —— 前言 一次偶然的机会,翻到当初自己读大学的时候教材,看着自己当初的勾勾画画,一时感触良多。 很值得一提的是,我在封面后第一页,写着自己的专业和名字的地方下面,写着几行这样的字&…

leetcode刷题(6)

各位朋友们大家好,今天是我的leetcode刷题系列的第六篇。这篇文章将与队列方面的知识相关,因为这些知识用C语言实现较为复杂,所以我们就只使用Java来实现。 文章目录 设计循环队列题目要求用例输入提示做题思路代码实现 用栈实现队列题目要求…

Vue2-黑马(七)

目录: (1)router-路由嵌套 (2)router-路由跳转 (3)router-导航菜单 (1)router-路由嵌套 我们有这样的需求,我们已经显示了主页,但是主页里面有&…

SpringBoot数据库换源

文章目录 前言一. baomidou提供换源注解 DS二. 手动数据源切换三. AOP自动换源 前言 笔者知道有三种方式: baomidou提供的DS自定义AOP自动换源实现AbstractRoutingDataSource手动换源 一. baomidou提供换源注解 DS 注意 1.不能使用事务,否则数据源不会切换&…

云原生入门

云原生入门. 云原生是一种设计和构建应用程序的方法,它充分利用了云计算的优势,如弹性、可扩展性、自动化和敏捷性。云原生应用程序不仅可以在云中运行,而且是为云而生的,它们采用了一些新式的技术和架构模式,使得应用…

零基础入门python好学么

python对于零基础的小伙伴算是非常友好的了~ python以简单易学著称~ Python简洁,高效的特点,大大提升了程序员的编码速度,极大的提高了程序员的办公效率,比如用其他编程语言5、6行代码才能整明白的,用Python可能1-2行就…

不应使用Excel进行项目资源规划的 7 个原因

项目资源规划早期仅限于基本分配和调度。因此,企业使用自制工具或excel表来执行这一简单功能。然而,随着技术和业务流程的发展,资源规划变得复杂,并包括其他组成部分,如预测和容量规划,优化等。 由于传统…

1.BootstrapTable组件

1.先在页面声明一个表格对象 <table id"table" class"table table-striped"></table> 2.生成表格JS代码如下 var url /log/;var columns [{checkbox: true,visible: true //是否显示复选框},{field: id,title: 序号,width…

若依框架—基于AmazonS3实现OSS对象存储服务

若依框架—基于AmazonS3实现OSS对象存储&#xff0c;其他也适用 文章目录 若依框架—基于AmazonS3实现OSS对象存储&#xff0c;其他也适用上一篇[若依mybatis升级mybatis-plus&#xff0c;其他也适用](https://blog.csdn.net/omnipotent_wang/article/details/128635654?spm10…

MYSQL:查询数据

一、学习目标 了解基本查询语句掌握表单查询的方法掌握如何使用几何函数的查询掌握连接查询的方法掌握如何使用子查询熟悉合并查询结果熟悉如何为表和字段取别名掌握如何使用正则表达式查询掌握数据表的查询操作技巧和方法 二、实验内容 根据不同条件对表进行查询操作&#…