【机器学习】集成学习—Boosting—GBM(Gradient Boosting Machine)解析

news2024/11/18 11:49:17

【机器学习】集成学习—Boosting—GBM(Gradient Boosting Machine)解析

文章目录

  • 【机器学习】集成学习—Boosting—GBM(Gradient Boosting Machine)解析
    • 1. 介绍
    • 2. Boosting
      • 2.1 1. 强 / 弱学习器
      • 2.1.2 AdaBoost
    • 3. GBM
      • 3.1 GBM 特例
      • 3.2 梯度下降 - 参数空间
      • 3.3 梯度下降 - 函数空间
      • 3.4 整体思路
      • 3.5 损失函数
      • 3.6 缩减(shrinkage)
    • 4. GBDT
    • 参考

1. 介绍

前面我们在 https://blog.csdn.net/qq_51392112/article/details/130507112 了解了集成学习的相关知识。这一节我们讲什么是GBM,也就是Gradient Boosting Machine(梯度提升机)。

GBM 是一种集成算法(Boosting类)。常见的集成学习算法包括 Boosting,Bagging(也叫 Bootstrap Aggregating),Stacking 等。

  • 其中 Boosting 包含经典的 AdaBoost 和 依靠梯度提升的 GBM。
    • GBDT 就属于 GBM 这类,它是基学习器为树模型的 GBM。
  • 同类算法还有在 GBM 之上进行了全面优化的 XGBoost,以及进行了速度优化的 LightGBM 等。
  • 如果用一张图来表示它们之间的关系,就是这样:
    在这里插入图片描述
  • GBDT 是 GBM 的一个特例,它的效果非常好,所以一提 GBM,最先会想到 GBDT。
    • GBDT 以 CART 为基学习器(此处也叫弱学习器),通过抑制基学习器的复杂度,缓解基学习器的过拟合风险,提高泛化能力。
    • 同时 GBDT 对多个基学习器串行训练,通过结果相加来对基学习器集成,提高拟合能力。

2. Boosting

2.1 1. 强 / 弱学习器

提升方法(Boosting)的主要思想是:把多个高偏差的弱学习器组合利用起来,降低整体偏差,形成一个强学习器。

  • 弱学习器就是比随机分类稍好一点,比如随机分类正确率为 50%,错误率为 50%,那么弱学习器正确率就是刚刚超过 50% 一点,比如 55%,

  • 而强学习器则是正确率很高很高,比如 90%,如下图:
    在这里插入图片描述

  • 为什么 Boosting 要用弱学习器呢?

    • 因为弱学习器容易得到,难度低,成本低。而且20 世纪的时候,还没有深度学习,弱学习器很多,强学习器较少
    • 学者幻想一种能够让“三个臭皮匠顶个诸葛亮”的算法,
  • 因此, Boosting 这种思想就出来了。

2.1.2 AdaBoost

Boosting 的思想是:n 个弱学习器 -> 强学习器。它没有限定算法特点,能把弱变强即可,但大多数 Boosting 算法都会螺旋迭代式地训练弱学习器(串行,实现的时候可以并行),然后将结果加起来作为最终结果。形象点来说就是这样:
在这里插入图片描述
典型的boosting方法就是 AdaBoost(Adaptive Boosting),很多教材或博客讲 Boosting 的时候都会以 AdaBoost 为例。

  • 这是因为 AdaBoost 是第一个实现了 Boosting 效果的算法。
  • 上图所说的其实就是 AdaBoost 用于分类任务时的基本思想:
    • 先用弱学习器对样本分类;
    • 错分的样本在下一轮学习中提高权重;
    • 这样重复多次,得到多个弱学习器,并且每个弱学习器都学到一个权重 ;
    • 将多个弱学习器加权求和,得到想要的强学习器。

3. GBM

梯度提升机(Gradient Boosting Machine,GBM)是 Boosting 的另一种实现方式。

  • 前面提到的 AdaBoost 是依靠调整预测错误数据样本的权重来训练新的学习器,进而降低偏差;

  • 而 GBM 则是让新分类器拟合负梯度来降低偏差。
    在这里插入图片描述

  • 梯度提升机这个名字可能有一点迷惑性。

    • 我们都听过梯度下降算法,所以当听到梯度提升,可能会误以为这是让梯度提升的算法。
    • 然而并不是这样,提升 (Boosting)指的是让弱学习器变成强学习器,跟梯度没有半点关系。
  • 所以梯度提升机应该这样理解:使用了梯度下降的提升机

3.1 GBM 特例

为了便于理解,从特例讲起,然后再泛化到一般情况。

  • GBM 最好理解的特例是无缩减(shrinkage)、损失函数取平方误差的回归情况。
  • 假设我们想拟合一段正弦曲线,为此采了一系列的点,构成一个数据集:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

3.2 梯度下降 - 参数空间

  • 首先回顾下梯度下降。我们通常都是对模型参数进行梯度下降优化,比如神经网络。这个叫做参数空间的梯度下降。
    • 具体步骤是:
      • 1)首先用损失函数对各个参数求导,得到负梯度;
      • 2)然后用原始参数加上负梯度,实现一次迭代优化。公式如下:
        在这里插入图片描述
      • 算法从初始位置顺着负梯度方向,走到局部最优点。
        在这里插入图片描述

3.3 梯度下降 - 函数空间

与参数空间的梯度下降(针对 θ \theta θ )类似,GBM 是在函数空间(也就是针对 F(x) 这个函数)进行梯度下降。

  • 先定义下对于 单个样本的平方误差 损失函数:
    在这里插入图片描述
    我们的目标是降低整个训练集的损失值,它的计算公式是:
    在这里插入图片描述
    在这里插入图片描述
    这样,我们就得到了残差和负梯度的关系(损失取平方误差时)。
    在这里插入图片描述

3.4 整体思路

  • GBM 借鉴参数空间的梯度下降,得到了函数空间的梯度下降
  • 当损失函数为平方误差时,负梯度恰好呈现残差形式
  • 用一个新学习器去拟合残差,把新学习器结果加到原模型上,并反复加加加,慢慢逼近真实值。

再从公式角度小结下:
在这里插入图片描述

3.5 损失函数

平方误差算起来很快,但是他有一个很大的问题:受不了异常点(下面会进行解释)。先说两个对异常点比较鲁棒的损失函数,分别是:

  • 绝对值损失函数:
    在这里插入图片描述
  • Huber 损失函数:
    在这里插入图片描述
  • 三种损失函数可视化一下:(平方损失、绝对值损失、huber损失)
    在这里插入图片描述
  • 下面通过一张表格对比下他们对于异常情况的响应,最右边一列是异常样本。
    在这里插入图片描述
    分析:对于异常点 y i = 5 y_i = 5 yi=5,预测值正常为 1.7。
    • 这个时候平方损失为 5.445,最大,是 Huber 的 3.5倍;
    • 绝对值损失居中。绝对值误差在当前情况下,对于异常点的损失没有 Huber 小,说明鲁棒性没 Huber 好。
      那绝对值误差有什么优点呢?
      • 可以观察到,当误差小的时候,他的损失值相对更大,也就是说他对于拟合较好地样本也有很高地关注度。

3.6 缩减(shrinkage)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

算法实现过程如伪代码所写:

  • 1)先计算伪响应;
  • 2)根据伪响应拟合新学习器,得到新学习器参数;
  • 3)根据最速下降法的思想,求解最优步长;
  • 4)更新集成模型,进入下一轮梯度下降过程。

4. GBDT

梯度提升决策树(Gradient Boosting Decision Tree,GBDT)是 GBM + CART。CART 作为 GBM 的基模型,GBM 做为 CART 的集成方法。虽然 GBM 可以跟任何回归器结合,但通常用的都是 GBM 与 CART 的组合,因为它的效果总体来说最好。在 sklearn 库里有现成的 GBDT 类,分别是:

  • 分类器 GradientBoostingClassifier
  • 回归器 GradientBoostingRegressor

两个类的名字里虽然没有 tree,但都默认是用 tree 作为基学习器,而且不能修改。GBM 和 GBDT 的公式差不多,但是 GBDT 有一个计算过程的优化,能快一些。

参考

【1】https://zhuanlan.zhihu.com/p/361036526
【2】Boosting - 维基百科 https://en.wikipedia.org/wiki/Boosting_(machine_learning)
【3】Greedy function approximation: a gradient boosting machine https://projecteuclid.org/journals/annals-of-statistics/volume-29/issue-5/Greedy-function-approximation-A-gradient-boosting-machine/10.1214/aos/1013203451.full
【4】A gentle introduction to gradient boosting https://www.ccs.neu.edu/home/vip/teach/MLcourse/4_boosting/slides/gradient_boosting.pdf
【5】《统计学习方法》李航

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/496351.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何利用ChatGPT进行论文润色-ChatGPT润色文章怎么样

ChatGPT润色文章怎么样? ChatGPT可以润色文章,使用其润色功能可以为用户提供更加整洁、清晰、文采动人的文本。但需要注意以下几点: 需要保持文本的一致性和完整性。当使用ChatGPT进行润色时,需要注意保持文本的一致性和完整性。…

单调栈的学习

文章目录 单调栈的学习什么是单调栈?单调栈模板暴力解法单调栈解法 单调栈的简单变形1.[496. 下一个更大元素 I](https://leetcode.cn/problems/next-greater-element-i/)2.[739. 每日温度](https://leetcode.cn/problems/daily-temperatures/)3.[503. 下一个更大元…

Kali Linux 使用远程桌面连接——xrdpxfce

[笔者系统版本] [Kali]: Kali Linux 2023.1 [Kernel]: kernel 6.1.0 [Desktop]: Xfce 4.18.1 1. 前言 在 Windows 中我们会经常使用到远程桌面这样便利的工具,让我们随时随地都可以使用自己想要使用的电脑,或者同时使用多台设备,那么本文就将…

open3d 源码阅读image_processing.py

目录 1. open3d.geometry.Image和numpy互转 2. 对open3d.geometry.Image进行高斯过滤 3. 高斯金字塔过滤 4. sobel过滤 5. 可视化o3d.geometry.Image 1. open3d.geometry.Image和numpy互转 import numpy as np import matplotlib.pyplot as plt import matplotlib.image a…

Midjourney从入门到精通

前言 什么是AI绘画 AI 绘画,顾名思义就是利用人工智能进行绘画,是人工智能生成内容(AIGC)的一个应用场景。其主要原理就是收集大量已有作品数据,通过算法对它们进行解析,最后再生成新作品,而算…

vue框架快速入门

vue 1、第一个Vue程序1.1、什么是Vue程序1.2、为什么要使用MVVM1.3、Vue1.4、第一个vue程序 2、基础语法2.1、v-bind2.2、v-if, v-else2.3、v-for2.4、v-on 3、Vue表单双绑、组件3.1、什么是双向数据绑定3.2、在表单中使用双向数据绑定3.3、什么是组件 4、Axios异步…

NixOS Legacy Boot(MBR) VmwareWorkstation安装向导

NixOS & Legacy Boot(MBR) VmwareWorkstation安装向导 目录 NixOS & Legacy Boot(MBR) VmwareWorkstation安装向导1. 下载镜像2. 创建空白虚拟机3. 使用命令行安装 NixOS3.1 Legacy Boot(MBR)3.2 格式化 4. configration.nix 配置文件5. 部署NixOS6. 部分教育站镜像源集…

Maven 3.9.1下载安装配置一条龙(无压力)亲测

这里写自定义目录标题 前言一、下载 Apache Maven 3.9.11.1、请先检查自己的IDEA是否有这个条件,是否兼容1.2、Maven下载 二、Windows安装配置Maven2.1、解压2.2、新建 repository 本地仓库2.3、配置环境变量MAVEN_HOME 软件路径M2_HOME 本地仓库路径配置Path2.3.1新…

关于maven

一、maven是什么 一个java项目构建工具 二、maven的作用 (1)依赖管理 不同框架整合,互相依赖jar包版本不同,版本不一样,程序跑起来就会报错。用maven管理jar包。 (2)跨平台构建项目 linux服…

数字信号处理3:A/D、D/A转换

信号这个东西,我们是实际应用中用的大多都是模拟信号,比如说语音、地震、雷达、声纳信号,这些都是模拟信号,但是,计算机想要通过数学方法处理模拟信号,就要先将模拟信号转换成具有有限精度的数字序列&#…

L4公司进军辅助驾驶,放话无图也能跑遍中国

作者 | Amy 编辑 | 德新 高阶智能驾驶走向规模量产,高精地图成为关键的门槛之一。今年,多家车企和智驾公司都喊出「不依赖高精地图,快速大规模落地」的口号。 华为、小鹏、元戎以及毫末等,可能是最快在国内量产 无高精图智…

TCP/IP网络编程(一)

TCP/IP网络编程读书笔记 第1章 理解网络编程和套接字1.1 理解网络编程和套接字1.1.1 构建打电话套接字1.1.2 编写 Hello World 套接字程序 1.2 基于Linux的文件操作1.2.1 底层访问和文件描述符1.2.2 打开文件1.2.3 关闭文件1.2.4 将数据写入文件1.2.5 读取文件中的数据1.2.6 文…

AI仿写软件-仿写文章生成器

AI仿写软件:高效出色的营销利器 作为互联网时代的营销人员,我们不仅需要品牌意识,还必须深谙营销技巧。万恶的时限压力使得我们不得不在有限的时间内输出更多的文本内容,以便吸引更多的关注。那么,如何解决这个问题呢…

C++网络基础知识面试题2

目录 1、使用TCP的常见协议有哪些?使用UDP的常见协议有哪些?简单说几个 2、如何判断访问目标地址的网络是通的?如何简单地查看到目标地址的网络是否有丢包和抖动? 3、如果知道目标服务器的服务端口有没有开启? 4、…

【NodeJs】使用Express框架快速搭建一个web网站

如果电脑有安装使用Nodejs,用得次数少的话,忘了怎么弄,可以看看这个文章,按照步骤,能快速搭建一个web网站服务器, 首先,你需要保证电脑系统有安装了Node.js,然后可以用VsCode开发工…

Java多线程基础概述

简述多线程: 是指从软件或者硬件上实现多个线程并发执行的技术。 具有多线程能力的计算机因有硬件支持而能够在同一时间执行多个线程,提升性能。 正式着手代码前,需要先理清4个概念:并发,并行,进程&#…

ChatGPT带你领略自动驾驶技术

一、自动驾驶技术现概述 自动驾驶技术是指利用计算机、传感器和其他设备,使车辆能够在不需要人类干预的情况下自主行驶的技术。目前,自动驾驶技术已经在一些汽车厂商和科技公司中得到广泛应用,但仍然存在一些技术和法律上的挑战,需…

c++类友元函数理解(图、文、代码)

序: 1、初学c,理解阶段,一下为个人理解和案例,陆续更新 一、友元函数和普通函数区别 类的友元函数是函数,但是他可以调用类的私有变量,以下代码,Fun2是报错的,因为这个函数跟A没任…

基于SSM框架流浪动物救助及领养管理系统(spring+springmvc+mybatis+jsp+jquery+layui)

一、项目简介 本项目是一套基于SSM框架流浪动物救助及领养管理系统,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的Java学习者。 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试&…

java错题总结(28-30页)

------------------------------------------------------------------------------------------- ------------------------------------------------------------------------------------------- 不考虑类加载, --------------------------------------------…