一文速学-GBDT模型算法原理以及实现+Python项目实战

news2024/11/25 9:41:37

目录

前言

一、GBDT算法概述

1.决策树

2.Boosting

3.梯度提升

 使用梯度上升找到最佳参数

二、GBDT算法原理

1.计算原理

2.预测原理

三、实例算法实现

 1.模型训练阶段

1)初始化弱学习器

 2)对于建立M棵分类回归树​:

 四、Python实现

1.原始决策树累积

2.sklearn


前言

上篇文章内容已经将Adaboost模型算法原理以及实现详细讲述实践了一遍,但是只是将了Adaboost模型分类功能,还有回归模型没有展示,下一篇我将展示如何使用Adaboost模型进行回归算法训练。首先还是先回到梯度提升决策树GBDT算法模型上面来,GBDT模型衍生的模型在其他论文研究以及数学建模比赛中十分常见,例如XGBoost,LighGBM,catboost。其实将这些算法重要的点拿出来就更容易理解了,主要是五个方向的变动改进:

算法差异点GBDTXGBoostLightGBMCatBoost
弱学习器CART回归树

1.CART回归树

2.线性学习器

3.Dart树

Leaf-wise树对称树
寻找分裂点贪心算法近似算法直方图算法预排序算法
稀疏值处理稀疏感知算法EFB(互斥特征捆绑)
类别特征不直接支持,可自行编码后输入模型同GBDT直接支持,GS编码直接支持,Ordered TS编码
并行支持不可以可以可以可以

本篇主讲GBDT算法模型以及应用,先把大体框架熟悉,之后的算法只需要填补功能就好了。本篇并不会提及太多专业公式以及推论公式,数学基础薄弱的不用担心,大家可以放心学习,我会尽可能简单易懂的讲明白算法原理,主要是实战以及运用和相关代码的使用。


一、GBDT算法概述

在开篇Boosting算法中有过讲到,回顾下Adaboost,我们是利用前一轮迭代弱学习器的误差率来更新训练集的权重,这样一轮轮的迭代下去。GBDT也是迭代,使用了前向分布算法,但是弱学习器限定了只能使用CART回归树模型,同时迭代思路和Adaboost也有所不同。

GBDT的思想可以用一个通俗的例子解释,假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了。如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小。

1.决策树

那么GBDT算法肯定有其对应的弱学习器,也就是CART回归树。

这里如果大家之前并没有了解过决策树的概念,可以去看我的这篇文章:

一文速学数模-分类模型(二)决策树(Decision Tree)算法详解及python实现

那么这个CART指的是(Classification and Regression Tree)的意思, 这里我大体讲述一下该决策树算法:,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。

2.Boosting

Boosting的思路则是串行的,每一次训练一个模型都是建立在前一个模型的学习基础上,不断去通过新模型去减少之前的错误。

这点思路在讲AdaBoosting算法模型时候已经讲的很明确了,一图就可了解:

3.梯度提升

在梯度提升中,每个弱学习器的训练都是基于前一个弱学习器的预测误差,通过梯度下降的方式来最小化误差。具体来说,对于回归问题,我们可以选择平方损失函数作为损失函数。

关于梯度提升算法我之前在Logistic原理详解和遗传算法里面也有详解讲过,此类最优算法最核心的一点就是对于残差的使用。而损失函数就是衡量调整每一次迭代模型算法的权重的参考功能。

损失函数(loss function):为了评估模型拟合的好坏,通常用损失函数来度量拟合的程度。损失函数极小化,意味着拟合程度最好,对应的模型参数即为最优参数。在线性回归中,损失函数通常为样本输出和假设函数的差取平方。比如对于样本(x_{i},y_{i}) (i=1,2,...n),采用线性回归,损失函数为:

对于分类问题,则可以选择交叉熵损失函数。在每次迭代中,我们都会训练一个新的弱学习器,使得它能够最大程度地减少当前模型的误差。然后将这个新的学习器加入到当前的模型中,从而不断提升整个模型的预测能力。

 使用梯度上升找到最佳参数

使用梯度上升找到最佳参数可以假设为爬山运动,我们总是往向着山顶的方向攀爬,当爬到一定角度以后也会驻足停留下观察自身角度是否是朝着山顶的角度上攀爬。并且我们需要总是指向攀爬速度最快的方向爬。

要找到某函数的最大值,最好的方法就是沿着该函数的梯度方向搜寻。我们假设步长为\alpha,用向量来表示的话,梯度上升算法的迭代公式如下:

w:=w+\alpha gard_{w}(f(w))。该公式停止的条件是迭代次数达到某个指定值或者算法达到某个允许的误差范围。

梯度提升的一个重要特点是它可以应用于各种类型的弱学习器,例如决策树、线性模型、神经网络等。然而,决策树是梯度提升中最常用的弱学习器之一,因为它们可以很好地处理非线性特征和交互作用,同时也可以通过剪枝等技术来避免过拟合。

二、GBDT算法原理

1.计算原理

GBDT算法的原理如下:

  1. 初始化。将所有样本的权重设置为相等的值,建立一个初始模型作为基准模型,可以设置为简单的平均值或者是中位数。例如建立一个弱分类器F0(x)=argmin_{c}\sum_{i=1}^{N}L(y_{i},c),c即为平均值。

  2. 迭代训练。在每一轮迭代中,GBDT算法会先根据当前模型的预测结果计算每个样本的残差。对于回归问题,残差就是实际输出值与模型预测值之间的差异,对于分类问题,残差就是样本的实际类别与模型预测类别之间的差异。然后,GBDT会训练一个新的决策树模型,来学习如何预测这些残差。对于建立M棵CART树m=1,2,...M:

    1. 对i=1,2,...,N, 计算第m棵树对应的响应值(损失函数的负梯度):

    2. 对于i = 1,2,...N,利用CART回归树拟合数据,得到第m棵回归树,其对应的叶子节点区域为R_{m,j},其中j=1,2,...,j_{m},且j_{m}为第m棵回归叶子节点的个数。

    3. 对于j_{m}个叶子节点区域j=1,2,...,j_{m},计算最佳拟合值

      c_{m,j}=arg min_{c}=\sum _{x_{i}\epsilon R_{m,j}}L(y_{i},F_(m-1)(x_{i})+c)

    4. 更新强学习器F_{m}(x):F_{m}(x)=F_{m-1}(x)+\alpha \sum_{j=1}^{j_{m}}c_{m,j}I(x\in R_{m,j})

       

  3. 添加新模型。新模型的预测结果会被加入到当前模型的输出中,使得模型的预测结果逐步趋近于真实值。可以将每个模型的输出进行加权求和,得到最终模型的输出。

  4. 终止条件。当模型的准确率达到一定阈值,或者迭代次数达到预设的最大值时,算法停止迭代。最后得到强学习器表达式:F_{M}(x)=F_{0}(x)+\alpha \sum_{m=1}^{M}\sum_{j=1}^{J_{M}}c_{m,j}I(x\in R_{m,j})

GBDT算法通过不断训练新的决策树模型,并将它们的预测结果累加到当前模型的输出中,来逐步提升整个模型的预测能力。与传统的决策树算法相比,GBDT算法可以减少过拟合的风险,并且具有较强的鲁棒性。

2.预测原理

上述模型生成原理的数学推论和公式是绕不开的,其他算法模型也是一样,在所有的机器学习以及其他算法模型中来说,没有不存在数学公式的模型。但是预测原理我们可以尽可能简化,这里参考GBDT的原理和应用的举例比较形象:

假设我们要预测一个人是否会喜欢电脑游戏,特征包括年龄,性别是否为男,是否每天使用电脑。标记(label)为是否喜欢电脑游戏,假设训练出如下模型:

 该模型又两棵树组成, tree1使用 age < 15 和 is male 作为内节点,叶子节点是输出的分数。 tree2使用是否每日使用电脑作为根节点。假设测试样本如下:

 最后对某样本累加它所在的叶子节点的输出值,例如:

 单独的使用GBDT模型,容易出现过拟合,在实际应用中往往使用 GBDT+LR的方式做模型训练。

三、实例算法实现

首先我们以一组数据作为训练集:

编号车辆速度道路等级拥堵状态
02015
13024
26032
37042

 测试数据如下表所示:

编号车辆速度道路等级拥堵状态
0503

 1.模型训练阶段

参数设置:

  • 学习率:learning_rate = 0.3

  • 迭代次数:n_trees = 6

  • 树的深度:max_depth = 3

1)初始化弱学习器

F_{0}(x)=arg_{c}min\sum_{i=1}^{N}L(y_{i},c)

 损失函数为平方损失,因为平方损失函数是一个凸函数,可以直接求导,令导数等于零,得到c:

 令导数等于0:

 所以初始化时,2a7f1bd700677edcb4059c9751690e30.png取值为所有训练样本标签值的均值。

 c=(5+4+2+2)/4=3.25,此时得到的初始化学习器为F_{0}(x)=c=3.25

 2)对于建立M棵分类回归树500b09879826c48c11e67aeb7191f6db.png

 由于我们设置了迭代次数:n_trees=6,这就是设置了M=6。

首先计算负梯度,根据上文损失函数为平方损失时,负梯度就是残差,也就是856cff497381e535f4bd4339fc8d0462.png与上一轮得到的学习器fb7a1955d5b91d2abfd5524ea28e384a.png的差值:

 现将残差的计算结果列表如下:

编号真实值F_{0}(x)残差
053.251.75
143.250.75
223.25-1.25
323.25-1.25

 此时将残差作为样本的真实值来训练弱学习器feb1c5e538dee84225d547fae9b42d41.png,即下表数据:

编号车辆速度道路等级拥堵状态
02011.75
13020.75
2603-0.25
3704-1.25

 遍历每个特征的每个可能取值。从车辆速度为20开始,到道路等级特征为4结束,分别计算分裂后两组数据的平方损失(Square Error), dbd5d089109540cf778dcc1ff7cfedef.png为左节点的平方损失,6b56733a95db16464f19520cb8f778e3.png 为右节点的平方损失,找到使平方损失和 7e81aabd7a6ee6fd33211e6f9a70ce18.png最小的那个划分节点,即为最佳划分节点。

 例如:以车辆速度为30划分节点,将小于30的样本划分为左节点,大于等于30的样本划分为右节点。

划分点小于划分点的样本大于等于划分点的样本SE_{l}SE_{r}SE_{SUM}
车辆速度20/0,1,2,305.255.25
车辆速度3001,2,302.18752.1875
车辆速度600,12,3...
车辆速度700,1,23
道路等级1/0,1,2,3
道路等级201,2,3
道路等级30,12,3
道路等级40,1,233.67562503.675625

 以上划分点的总平方损失最小有两个划分点:车辆速度30和道路等级3.所以随机选一个作为划分点,这里我们选车辆速度30:

我们设置的参数中树的深度max_depth=3,现在树的深度只有2,需要再进行一次划分,这次划分要对左右两个节点分别进行划分: 

 此时我们的树深度满足了设置,还需要做一件事情,给这每个叶子节点分别赋一个参数5dd65e83175d3a82381904a723f5df60.png,来拟合残差。

 这里其实和上面初始化弱学习器是一样的,对平方损失函数求导,令导数等于零,化简之后得到每个叶子节点的参数55835c31efcef3f5004d4a89b886558b.png,其实就是标签值的均值。这个地方的标签值不是原始的cd8555e2349602e2e6b6444d509090fb.png,而是本轮要拟合的标残差3698080b4296ac347839c0ea8614f1b3.png

此时可更新强学习器,需要用到参数学习率:learning_rate=0.1,用65424a6ab8dcfd3d1b74aa4629508fc2.png表示。更新公式为:

 为什么要用学习率呢?这是Shrinkage的思想,如果每次都全部加上拟合值 7313fa5ea2e64c990352475698b0f439.png,即学习率为1,很容易一步学到位导致GBDT过拟合。

重复此步骤,最后生成5棵树。

得到最后的强学习器:

 四、Python实现

1.原始决策树累积

如果安装我们上一步这样原生计算推论的话,那么代码应该这样写:

from sklearn.tree import DecisionTreeRegressor
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
import pandas as pd
import pydotplus
from pydotplus import graph_from_dot_data
from sklearn.tree import export_graphviz
import os
os.environ["Path"] += os.pathsep + 'D:/Graphviz/bin'

data_1=[[20,1,5],[30,2,4],[60,3,2],[70,4,2]]
data=pd.DataFrame(data_1,columns=['speed','kind','state'])

X=np.array(data.iloc[:,:-1]).reshape((-1,2))
y=np.array(data.iloc[:,-1]).reshape((-1,1))
tree_reg1 = DecisionTreeRegressor(max_depth=4,random_state=10)
tree_reg1.fit(X, y)
y2 = y - np.array([3.25]*4).reshape((-1,1))
tree_reg2 = DecisionTreeRegressor(max_depth=4,random_state=10)
tree_reg2.fit(X, y2)
y3 = y2 - 0.1*np.array(tree_reg2.predict(X)).reshape((-1,1))
tree_reg3 = DecisionTreeRegressor(max_depth=4,random_state=10)
tree_reg3.fit(X, y3)
y4 = y3 - 0.1*np.array(tree_reg3.predict(X)).reshape((-1,1))
tree_reg4 = DecisionTreeRegressor(max_depth=4,random_state=10)
tree_reg4.fit(X, y4)
y5 = y4 - 0.1*np.array(tree_reg4.predict(X)).reshape((-1,1))
tree_reg5 = DecisionTreeRegressor(max_depth=4,random_state=10)
tree_reg5.fit(X, y5)
y6 = y5 - 0.1*np.array(tree_reg5.predict(X)).reshape((-1,1))
tree_reg6 = DecisionTreeRegressor(max_depth=4,random_state=10)
tree_reg6.fit(X, y6)

 

2.sklearn

 使用sklearn的话:

estimator=GradientBoostingRegressor(random_state=10)
estimator.fit(data.iloc[:,:-1],data.iloc[:,-1])
dot_data = export_graphviz(estimator.estimators_[5,0], out_file=None, filled=True, rounded=True, special_characters=True, precision=4)
graph = pydotplus.graph_from_dot_data(dot_data)

 二者树不同是因为参数学习率以及树的深度,迭代次数不一致导致,无碍。

那么我们现在拿预测样本来使用:

predict_data=pd.DataFrame({'speed':50,'kind':3},index=[0])
estimator.predict(predict_data)

 至此模型建立完毕,那么让我们总结一下GBDT模型特性:

AdaBoost和GBDT都是重复选择一个表现一般的模型并且每次基于先前模型的表现进行调整。不同的是,AdaBoost是通过调整错分数据点的权重来改进模型,GBDT是通过计算负梯度来改进模型。因此,相比AdaBoost, GBDT可以使用更多种类的目标函数,而当目标函数是均方误差时,计算损失函数的负梯度值在当前模型的值即为残差。

GBDT的求解过程就是梯度下降在函数空间中的优化过程。在函数空间中优化,每次得到增量函数,这个函数就是GBDT中一个个决策树,负梯度会拟合这个函数。要得到最终的GBDT模型,只需要把初始值或者初始的函数加上每次的增量即可。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/393885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring_让Spring 依赖注入彻底废掉

在Spring之基于注解方式实例化BeanDefinition&#xff08;1&#xff09;_chen_yao_kerr的博客-CSDN博客中&#xff0c;我们在末尾处分享了一个甜点&#xff0c;就是关于实现了BeanDefinitionRegistryPostProcessor也可以实例化bean的操作&#xff0c;首先需要去了解一下那篇博客…

宝塔(二):升级JDK版本

目录 背景 一、下载JDK17 二、配置环境变量 三、配置新的JDK路径 背景 宝塔的软件商店只有JDK8&#xff0c;不满足我当前项目所需的JDK版本&#xff0c;因此想对JDK版本进行升级&#xff0c;升级为JDK17。 一、下载JDK17 先进入 /usr/lib/jvm 目录 点击终端&#xff0c;进…

OpenCV——line、circle、rectangle、ellipse、polylines函数的使用和绘制文本putText函数以及绘制中文的方法。

学习OpenCV的过程中&#xff0c;画图是不可避免的&#xff0c;本篇文章旨在介绍OpenCV中与画图相关的基础函数。 1、画线条——line()函数 介绍&#xff1a; cv2.line(image, start_point, end_point, color, thickness)参数&#xff1a; image: 图像start_point&#xff1a…

拉链表(小记)

拉链表创建外部表将编写的orders.txt上传到hdfs创建一个增减分区表将orders表的数据传入ods_orders_inc查看分区创建历史表插入数据操作创建外部表 create database lalian; use lalian;create external table orders(orderId int,createDate string,modifiedTime string,stat…

Redis集群方案应该怎么做?

今天我们来跟大家唠一唠JAVA核心技术-RedisRedis是一款流行的内存数据库&#xff0c;适用于高性能的数据缓存和实时数据处理。当需要处理大量数据时&#xff0c;可以使用Redis集群来提高性能和可用性。Redis在单节点模式下&#xff0c;虽然可以支持高并发、快速读写、丰富的数据…

sizeof与一维数组和二维数组

&#x1f355;博客主页&#xff1a;️自信不孤单 &#x1f36c;文章专栏&#xff1a;C语言 &#x1f35a;代码仓库&#xff1a;破浪晓梦 &#x1f36d;欢迎关注&#xff1a;欢迎大家点赞收藏关注 sizeof与一维数组和二维数组 文章目录sizeof与一维数组和二维数组前言1. sizeof与…

专业版即将支持自定义场景测试

物联网 MQTT 测试云服务 XMeter Cloud 专业版于 2022 年底上线后&#xff0c;已有不少用户试用&#xff0c;对数千甚至上万规模的 MQTT 并发连接和消息吞吐场景进行测试。同时我们也收到了希望支持更多物联网协议测试的需求反馈。 新年伊始&#xff0c;XMeter 团队全力聚焦于 …

搭建Gerrit环境Ubuntu

搭建Gerrit环境 1.安装apache sudo apt-get install apache2 注意:To run Gerrit behind an Apache server using mod_proxy, enable the necessary Apache2 modules: 执行:sudo a2enmod proxy_http 执行:sudo a2enmod ssl 使新的配置生效&#xff0c;需要执行如下命令:serv…

ctfshow【菜狗杯】wp

文章目录webweb签到web2 c0me_t0_s1gn我的眼里只有$抽老婆一言既出驷马难追TapTapTapWebshell化零为整无一幸免无一幸免_FIXED传说之下&#xff08;雾&#xff09;算力超群算力升级easyPytHon_P遍地飘零茶歇区小舔田&#xff1f;LSB探姬Is_Not_Obfuscateweb web签到 <?ph…

在社交媒体上行之有效的个人IP趋势

如果您认为无论是获得一份工作、建立一家企业还是推动个人职业发展&#xff0c;社交媒体都是帮助您实现目标的可靠工具&#xff0c;那么个人IP就是推动这一工具前进的燃料。个人IP反映了您是谁&#xff0c;您在所处领域的专业程度&#xff0c;以及您与他人的区别。社交媒体将有…

打破原来软件开发模式的无代码开发平台

前言传统的系统开发是需要大量的时间和成本的&#xff0c;如今无代码开发平台的出现就改变了这种状况。那么你知道什么是无代码开发平台?无代码开发对企业来说有什么特殊的优势么?什么是无代码平台无代码平台指的是&#xff1a;使用者无需懂代码或手写代码&#xff0c;只需通…

代码分享:gprMax钻孔地质雷达波场模拟

代码分享&#xff1a;gprMax钻孔地质雷达波场模拟 前言 gprMax模拟地面地质雷达被广泛使用&#xff0c;但是在钻孔内进行地质雷达的模拟较少。本博文尝试利用gprMax进行钻孔地质雷达的模拟&#xff0c;代码仅供大家借鉴。 文章目录代码分享&#xff1a;gprMax钻孔地质雷达波场…

【数据结构】链表练习题(1)

练习题1.移除链表元素(LeetCode203)2.链表的中间结点(LeetCode876)3.链表的倒数第k个结点(剑指offer)4.反转链表(LeetCode206)5.合并两个有序链表(LeetCode21)6.链表分割(牛客)7.链表的回文结构(牛客)1.移除链表元素(LeetCode203) 给你一个链表的头结点 head 和一个整数 val &…

第十四届蓝桥杯三月真题刷题训练——第 4 天

目录 题目 1 &#xff1a;九数算式_dfs回溯(全排列) 题目描述 运行限制 代码&#xff1a; 题目2&#xff1a;完全平方数 问题描述 输入格式 输出格式 样例输入 1 样例输出 1 样例输入 2 样例输出 2 评测用例规模与约定 运行限制 代码&#xff1a; 题目 1 &am…

数据结构刷题(十九):77组合、216组合总和III

1.组合题目链接过程图&#xff1a;先从集合中取一个数&#xff0c;再依次从剩余数中取k-1个数。思路&#xff1a;回溯算法。使用回溯三部曲进行解题&#xff1a;递归函数的返回值以及参数&#xff1a;n&#xff0c;k&#xff0c;startIndex(记录每次循环集合从哪里开始遍历的位…

场景式消费激发春日经济,这些电商品类迎来消费热潮

春日越临近&#xff0c;商机越浓郁。随着气温渐升&#xff0c;春日经济已经潜伏在大众身边。“春菜”、“春装”、“春游”、“春季养生”等春日场景式消费走热。 下面&#xff0c;鲸参谋为大家盘点几个与春日经济紧密相关的行业。 •春日仪式之春游踏青 ——户外装备全面开花…

查看 WiFi 密码的两种方法

查看 WiFi 密码的两种方法1. 概述2. 在控制面板中查看 WiFi 密码3. 使用 CMD 查看 WiFi 密码结束语1. 概述 突然忘记 WiFi 密码怎么办&#xff1f; 想连上某个使用过的 WiFi&#xff0c;但有不知道 WiFi 密码怎么办&#xff1f; 使用电脑如何查询 WiFi 密码&#xff1f; 以下是…

zabbix4.0 网络发现-自动添加主机-自动注册

zabbix的网络发现 网络发现的好处&#xff1a; 加快zabbix部署 简化管理 无需过多管理就能在快速变化的环境中使用zabbix zabbix网络发现给予以下信息 IP范围 可用的外部服务&#xff08;FTP&#xff0c;SSH&#xff0c;WEB&#xff0c;POP3&#xff0c;IMAP&#xff0c;TCP等&…

一篇深入解析BTF 实践指南

BPF 是 Linux 内核中基于寄存器的虚拟机&#xff0c;可安全、高效和事件驱动的方式执行加载至内核的字节码。与内核模块不同&#xff0c;BPF 程序经过验证以确保它们终止并且不包含任何可能锁定内核的循环。BPF 程序允许调用的内核函数也受到限制&#xff0c;以确保最大的安全性…

FPGA使用GTX实现SFP光纤收发SDI视频 全网首创略显高端 提供工程源码和技术支持

目录1、前言2、设计思路和框架3、vivado工程详解4、上板调试验证并演示5、福利&#xff1a;工程代码的获取1、前言 FPGA实现SDI视频编解码目前有两种方案&#xff1a; 一是使用专用编解码芯片&#xff0c;比如典型的接收器GS2971&#xff0c;发送器GS2972&#xff0c;优点是简…