Spark 3.0 - 12.ML GBDT 梯度提升树理论与实战

news2025/1/21 21:53:56

目录

一.引言

二.GBDT 理论

1.集成学习

2.分类 & 回归问题

3.梯度提升

4.GBDT 生成

三.GBDT 实战

1.数据准备

2.构建 GBDT Pipeline

3.预测与评估

四.总结


一.引言

关于决策树前面已经介绍了常规决策树与随机森林两种类型的知识,本文主要介绍梯度提升树 Gradient Boosting Decision Tree 即常说的 GBDT,其实一种使用决策树集成的流行分类和回归方法。梯度提升算法的思想类似于随机梯度下降。该算法中模型由若干个 F(x) 即基学习器构成,每个 F(x) 都拥有一个权重 Weight,初始化时各个权重相同,之后不断地将模型计算结果与真实结果进行比较,如果出错则增加错误样本的权重并基于新权重样本,让模型朝着损失减少最快的负梯度方法进行优化。其整体可以看做是 Bossting 方法,主要思想是每一次建立模型都是在之前建立模型损失函数的梯度下降方向,即"每次沿着当前位置最陡峭,损失下降最快的方向移动"。

二.GBDT 理论

决策树相对来说很直观形象,同学们也很好理解,但是到了梯度提升树,负梯度、最小化残差等概念的出现容易找不到方向,其次为什么0-1分类问题也有梯度等等疑问也随之而来,在使用 Spark 3.0 ML 介绍梯度提升树的使用之前,我们先熟悉一些 GBDT 的基础数学概念,做到理论实践相结合。

1.集成学习

上一文随机森林就是一种集成学习的方法,其一般都有一个基学习器 Tk,对于随机森林、梯度提升树而言,基学习器 Tk 就是我们常规的 DT 决策树。针对常见的分类与回归问题,我们的问题都可以转化为下述数学语言,构造一个函数 y = f(x),训练模型使得 f(x) 尽量与真实值相同:

y = f(x)

而实际运行中,我们的模型很难做到百发百中的精准预测,往往预测值与真实值之前存在一定偏差:

y = f(x) + ResidualError

这个 residual error 就是我们常说的残差,即 y - f(x) 的差值即真实值与预测值之前的差异。实践场景下我们一般对残差进行如下度量:

A.偏差 - 与真实值分布的偏差大小,体现模型的预测能力,越小模型预测越准

B.方差 - 与真实值分布的偏差均值方差,体现模型的预测稳定性,越小模型越稳定

集成学习中,基学习器一般为简单的算法模型、例如 LR、DT,因此其单一学习器的预测能力有限,从而通过集成学习将多个基学期的组合,针对残差进行拟合,进而降低模型的偏差与方差,提高模型整体的预测能力与稳定性,达到 "N个臭皮匠 Tk,能顶一个诸葛亮" 的思想。虽然 RF 和 GBDT 都是基于树的集成学习,但是二者亦有不同,前者 Tk 可以并行生成,是典型的 bagging,而 GBDT 的 Tk 是串行生成,类似于 Adaboost。

2.分类 & 回归问题

回归问题我们在 LR 中就遇到过,下面简单复述下,给定数据集 X:

D={(x_1,y_1), (x_2,y_2),\cdots ,(x_n,y_n)}

其中 x 为 K 维特征 (K >=1):

x_n=(x_{n1},x_{n2}, \cdots,x_{nk})

其中 y 为真实输出值,分类任务对应 0-1,回归任务对应预测值,我们的目的就是构建一个模型:

F(x_n)

去尽可能的逼近每一个真实值 y。

A.分类问题损失函数:(常见的指数损失函数)

loss(y,F(x))=-y_iF(x)+log(1+exp(F(x)))

B.回归问题损失函数:(常见的 MSE 均方误差损失函数)

loss(y,F(x))=E(y - F(x))^2

3.梯度提升

Gradient Boosting Decision Tree,简单分词可以得到两个主体,分别为 Gradient Boosting 与 Decision Tree,所以我们把这两个东西搞差不多,GBDT 我们也就搞差不多了,DT 可以参考 决策树原理与实战。此时基学习器 Tk 为 DT 决策树,假设当前:

F_{k}(x)=\sum_{i=1}^{k}T_i{x}

前 K 个基学习器的预测值为 Fk(x),可以看到 GBDT 是一种加法模型,它把所有基础模型的预测值累加起来作为最终的预测值。由于 GBDT 采用串行的生成方式生成新的基学习器,所以我们将上面的公式修改为递推形式:

F_k(x)=F_{k-1}(x) + T_k(x)

在训练第 K 个 T(x) 时,我们需要最小化如下目标函数:

J=\sum_{n=1}^{N}L(y_n,F_k(x_n)) = \sum_{n=1}^{N}L(y_n, F_{k-1}(x_n)+T_k(x))

此处我们需要使用梯度下降的方法,让目标函数的取值朝着最快的下降方向前进。以 MSE 交叉熵损失函数为例:

J=\sum_{n=1}^{N}L(y_n,F_k(x_n))=\sum_{n=1}^{N} \frac{1}{2}(y_n - F_k(x_n))^2

对 F(x) 求导可得:

\sum _{n=1}^{N}\frac{\partial L(y,F(x))}{\partial F(x)}=\sum _{n=1}^{N}\frac{\partial (\frac{1}{2}y_i-F_k(x_i))^2}{\partial F_k(x_i)}=\sum_{n=1}^{N}F_k(x_i) - y_i

后面得到的结果就是我们集成学习部分提到的负残差。由随机梯度下降更新公式可知,这里可以参考简易的 牛顿法参数更新,其中 α 为学习率:

F_k(x) = F_{k-1}(x) - \alpha \cdot \frac{\partial J}{\partial F_{k-1}(x)}

后面的求导结果为负残差,所以移项可得 (此处忽略 α):

F_{k}(x) - F_{k-1}(x) = T_k(x) = -1 \cdot \frac{\partial J}{\partial F} = \sum y_n - F_{k-1}(x_n)

所以可以看到每次新增的基学习器 Tk 都用于拟合之前所有 Ti 与当前真实值之间的残差时,导数梯度下降最快,从而模型拟合效果更好。

4.GBDT 生成

GBDT 串行生成,假设我们的第一个基学习器是:

T_1(x)

此时对应残差为,第二个学习器 T2 负责拟合 T1 与 y 之间的残差:

ResidualError_1 = T_2(x) = y - T_1(x)

根据残差拟合 T2 并串行增加到 T1 后面,得到最新的 GBDT 模型:

\hat{y} = F(x) = T_1(x) + T_2(x)

依次类推,不断在新函数的基础上求得残差,并通过残差拟合新的 Tk,直到达到我们预定的精度要求或者树要求,即代表 GBDT 模型生成完毕:

\hat{y} = F(x) = \sum_{i=1}^{K}T_k(x)

实际操作中,有时还会根据上一轮的误差修改新一轮样本 X 的权重,从而使得新增的 Tk 对于之前集成学习器分类错误的样本能够拥有更好的分类结果,从而提升整体集成学习器的预测能力。

三.GBDT 实战

1.数据准备

spark.read.format 对 LIBSVM 数据进行读取加载

LabelIndexer 对预测值进行重新编码映射

featureIndexer 对特征进行离散与连续的区分

randomSplit 将数据按比例分为训练、测试数据

    val spark = SparkSession
      .builder//创建spark会话
      .master("local")
      .appName("GradientBoostedTreeClassifierExample")//设置名称
      .getOrCreate() //创建会话变量

    // 读取文件,装载数据到spark dataframe 格式中
    val data = spark.read.format("libsvm").load("/Users/xudong11/sparkV3/src/main/scala/org/example/RandomForest/sample_libsvm_data.txt")

    // 搜索标签,添加元数据到标签列
    // 对整个数据集包括索引的全部标签都要适应拟合
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)
    // 自动识别分类特征,并对其进行索引
    // 设置MaxCategories以便大于4个不同值的特性被视为连续的。
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data)

    // 按照7:3的比例进行拆分数据,70%作为训练集,30%作为测试集
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

Tips:

样本为 libsvm 格式,特征维度 692,标签为二分类

2.构建 GBDT Pipeline

gbt 构造 GBDT 分类器

labelConverter 将上述标签转换的标签再映射回来

pipeline 将上述 Stage 拼接得到最终的 Estimator

pipeline.fit 训练模型,获取预测的 transformer

    // 建立一个决策树分类器,并设置MaxIter最大迭代次数为10
    val gbt = new GBTClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setMaxIter(10)
      .setFeatureSubsetStrategy("auto")

    // 将索引标签转换回原始标签
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labelsArray(0))

    // 把索引和决策树链接(组合)到一个管道(工作流)之中
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))

    // 载入训练集数据正式训练模型
    val model = pipeline.fit(trainingData)

Tips:

这里 FeatureSubsetStrategy 是属性在每个节点中计算的数目,即用作在每个树节点进行分割的候选特征数量,该数字被指定为总特征数量的分数或函数。减少这个数字会加快训练速度,但是太低的话会影响性能,这里建议使用 auto 参数让 ML 内核自动决定每个节点的属性数。

3.预测与评估

model.transform 用上一步得到的 transformer 对测试集数据预测

evaluator 计算预测样本的 Accuracy

toDebugString 获取本次训练的 GBDT 树的简介

    // 使用测试集作预测
    val predictions = model.transform(testData)

    // 选择一些样例进行显示
    predictions.select("predictedLabel", "label", "features").show(5)

    // 计算测试误差
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${1.0 - accuracy}")

    val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
    println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}")

    spark.stop()

Tips:

由于篇幅长度,这里我们只展示前 4 棵树,我们的模型共拥有 10 棵树,对应问题为 2 分类问题,全部特征为 692 维度。将全部预测值与权重加权求和,再经过 sigmoid 函数即可得到对应标签类型,如果是回归问题,则不需要 sigmoid 函数。这里 GBDT 处理二分类也借鉴了 LR,通过 sigmoid 函数将分类的非线性问题转化到 y = wx + b 的线性函数。

 

四.总结

GBDT 增加学习器意在让模型的损失函数持续下降,其中最好的方式就是让损失函数在梯度方向下降,此时优化速度最快。Boosting 算法是一种继承学习方法,每一轮训练样本都是固定的,改变的是每个样本的权重,根据错误率调整样本权重,错误率越大的样本权重越大。各个预测函数只能顺序生成,因为后一个模型需要用到上一个模型的结果。通过加法模型不断减小训练产生的残差,实现数据的分类与回归。在 Gradient Boosting 中,每个新基学习器的建立都是为了使之前的模型残差往梯度方向减少。啰嗦了这么多,下面我们简单总结一下:

A.训练阶段,GBDT 的基学习器只能串行生成,但是预测阶段可以通过并行计算提高效率

B.GBDT 分类问题支持 LogLoss,回归问题支持 MSE、MAE,SPARK ML 默认为 L2 MSE。

C.Iter 参数为迭代次数,每增加1都会新增一棵树,预测的准确性也随之增加

D.适量的增加树可以提高模型准确能力,但也会带来过拟合风险,可以添加 Reg 正则化参数

E.GBDT 可以自动筛选重要特征,可以与其他模型配合使用,例如最常见的 GBDT + LR

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/92663.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小游戏赛道如何加速流量增长?

小游戏是指设计极简的轻量级游戏。它构造简单,但却给人带来了娱乐性和重复参与的欲望。 近年来,小游戏在抖音、微信小游戏等平台拥有着疯狂裂变的可能性,出现了例如“羊了个羊”“跳一跳”、“合成大西瓜”等风靡一时的小游戏。 这些爆火的小…

「微服务系列」统一网关Gateway

为什么需要网关 网关功能: 身份认证和权限校验服务路由、负载均衡请求限流在SpringCloud中网关的实现包括两种: Zuul:基于Servlet的实现,属于阻塞式编程。SpringCloudGateway:是基于Spring5中提供的WebFlux&#xf…

关注渐冻症|菌群助力探索其发病机理及相关干预措施

最杰出的物理学家之一的斯蒂芬威廉霍金想必大家都知道,以及曾经风靡全网的“冰桶挑战”,它们都与一种罕见疾病有关,那就是渐冻症。 媒体的宣传让渐冻症成为了较为“知名”罕见病之一;2000年丹麦举行的国际病友大会上正式确定6月21…

【Redis】数据类型操作二 (Set/Hash/Zset)

文章目录3、Redis集合(Set)4、 Redis哈希(Hash)5、Redis有序集合Zset(sorted set)实操3、Redis集合(Set)4、 Redis哈希(Hash)5、Redis有序集合Zset(sorted set)3、Redis集合(Set) Redis Set 是String类型的无序集合。一个key集合可以对应多个value元素。Redis Set 可以自动排重…

[附源码]Python计算机毕业设计高校篮球训练管理系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等…

Python如何做自动化测试?

众做周知,自动测试的优势是显而易见的,它可以大大节省我们的时间,提高我们的工作效率。那么Python如何做自动化测试呢?本文将用Python编写一个简单的测试用例,并指导大家写做自动化测试的代码。如果大家对这个内容感兴…

基于java+springboot+mybatis+vue+mysql的会员制医疗预约服务管理信息系统

项目介绍 会员制医疗预约服务管理信息系统是针对会员制医疗预约服务管理方面必不可少的一个部分。在会员制医疗预约服务管理的整个过程中,会员制医疗预约服务管理系统担负着最重要的角色。为满足如今日益复杂的管理需求,各类的管理系统也在不断改进。系…

[计算机网络微课]第三章 数据链路层

数据链路层 概述 数据链路层在网络体系结构中的地位 主机 H1 给主机 H2 发送数据,中间要经过 3 个路由器和电话网、局域网以及广域网等多种网络。从五层协议原理体系结构角度来看 为了专注数据链路层内容,这里我们只考虑数据链路层,而不考…

体外诊断丨艾美捷游离维多珠单抗ADA水平检测试剂盒

introduction: Crohns disease in patients with moderate to severe active ulcerative colitis, routine treatment or tumor necrosis factor α (TNF α) Antagonists can also be treated with vidolizumab. Vedolizumab is a humanized monoclona…

并查集引入

目的 主要是处理一些不相交集合的合并问题,比如:求连通子图,求最小生成树的克鲁斯卡尔算法以及最近公共祖先(LCA)等 简单应用就是连通图,将元素进行合并,如果要优化路径的话可以利用数据压缩 …

大学生简单抗击疫情静态HTML网页设计作品 DIV布局疫情感动人物介绍网页模板代码 DW学生抗疫逆行者网站制作成品下载

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

vTESTstudio入门到精通 - vTESTstudio工具栏介绍_Home

继上篇介绍File功能模块之后,今天我们来介绍vTESTstudio工程使用过程的种的另外一个重要的工具栏Home,这块将是我们使用vTESTstudio编程中使用最多的一个功能模块。话不多说,下面我们就来一一介绍该功能栏能在我们编程的时候做哪些事情。 2、…

网关服务限流熔断降级【Gateway+Sentinel】

目录 第一步:启动sentinel-dashboard控制台 第二步:在网关服务中引入sentinel依赖 第三步:在网关服务application.yml中配置sentinel 第四步:通过网关进入服务 再进入sentinel控制台查看链路情况 第一步:启动sen…

一个简单的dw网页制作作业,学生个人html静态网页制作成品代码——怪盗基德动漫主题网页成品(15页)

HTML实例网页代码, 本实例适合于初学HTML的同学。该实例里面有设置了css的样式设置,有div的样式格局,这个实例比较全面,有助于同学的学习,本文将介绍如何通过从头开始设计个人网站并将其转换为代码的过程来实践设计。 ⚽精彩专栏推荐&#x1…

【mmdetection系列】mmdetection之evaluate评测

1.configs 还是以yolox为例,配置有一项evaluation。用于配置评估是用什么评价指标评估。 https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py#L151 max_epochs 300 num_last_epochs 15 interval 10evaluation di…

LVS 负载均衡

LVS 负载均衡 本篇主要介绍一下 lvs 是什么 以及它的 nat 模式的搭建 配合nginx来演示 1.概述 LVS 是 Linux Virtual Server 的简写 (Linux 虚拟服务器 ), 是由章文嵩博士主导, 它虚拟出一个服务器集群,然后进行负载均衡的项目, 目前LVS 已经被集成到Linux内核模块中了, 外部请…

直播弹幕系统(三)- 直播在线人数统计

直播弹幕系统(三)- 直播在线人数统计前言一. 在线人数统计功能实现1.1 Redis整合1.2 在线人数更新1.3 演示前言 上一篇文章整合RabbitMQ进行消息广播和异步处理 写完了消息的广播、削峰、异步处理业务逻辑等操作。完成了实时共享功能。 不过写到后面发…

Netcat介绍及安装使用

目录 介绍 Linux 安装 Windows安装 1.下载安装包 2.解压安装包 3.安装路径加入系统变量 Netcat命令参数 使用Netcat互相通信 1.创建一个服务端 2.创建一个客户端(连接服务端) 介绍 Netcat 是一款简单的Unix工具,使用UDP和TCP协议。…

七、Docker 安装Tomcat(流程、注意点、实操)

1、从中央仓库搜索tomcat 命令:docker search tomcat 也可以从官网查找,地址:Docker Hub 2、从中央仓库拉取tomcat 命令:docker pull tomcat:8.0 这里我们选择8.0 版本tomcat 3、查看镜像 命令:docker images 4、运行镜像 命令:docker run -d

如何从内存卡恢复丢失的数据?简单内存卡(SD卡)数据恢复方法分享

SD卡,也就是内存卡,在日常使用中有着体积小、存储量大的优点,被我们用来存储一些重要的数据。相机是使用SD卡的场景之一。目前大多数相机都使用SD卡来存储相关数据,这不仅是因为SD容量的优势,而且其运行速度也比较快&a…