梯度提升算法决策过程的逐步可视化

news2025/1/16 5:02:03

梯度提升算法是最常用的集成机器学习技术之一,该模型使用弱决策树序列来构建强学习器。这也是XGBoost和LightGBM模型的理论基础,所以在这篇文章中,我们将从头开始构建一个梯度增强模型并将其可视化。

梯度提升算法介绍

梯度提升算法(Gradient Boosting)是一种集成学习算法,它通过构建多个弱分类器,然后将它们组合成一个强分类器来提高模型的预测准确率。

梯度提升算法的原理可以分为以下几个步骤:

  1. 初始化模型:一般来说,我们可以使用一个简单的模型(比如说决策树)作为初始的分类器。
  2. 计算损失函数的负梯度:计算出每个样本点在当前模型下的损失函数的负梯度。这相当于是让新的分类器去拟合当前模型下的误差。
  3. 训练新的分类器:用这些负梯度作为目标变量,训练一个新的弱分类器。这个弱分类器可以是任意的分类器,比如说决策树、线性模型等。
  4. 更新模型:将新的分类器加入到原来的模型中,可以用加权平均或者其他方法将它们组合起来。
  5. 重复迭代:重复上述步骤,直到达到预设的迭代次数或者达到预设的准确率。

由于梯度提升算法是一种串行算法,所以它的训练速度可能会比较慢,我们以一个实际的例子来介绍:

假设我们有一个特征集Xi和值Yi,要计算y的最佳估计

我们从y的平均值开始

每一步我们都想让F_m(x)更接近y|x。

在每一步中,我们都想要F_m(x)一个更好的y给定x的近似。

首先,我们定义一个损失函数

然后,我们向损失函数相对于学习者Fm下降最快的方向前进:

因为我们不能为每个x计算y,所以不知道这个梯度的确切值,但是对于训练数据中的每一个x_i,梯度完全等于步骤m的残差:r_i!

所以我们可以用弱回归树h_m来近似梯度函数g_m,对残差进行训练:

然后,我们更新学习器

这就是梯度提升,我们不是使用损失函数相对于当前学习器的真实梯度g_m来更新当前学习器F_{m},而是使用弱回归树h_m来更新它。

也就是重复下面的步骤

1、计算残差:

2、将回归树h_m拟合到训练样本及其残差(x_i, r_i)上

3、用步长\alpha更新模型

看着很复杂对吧,下面我们可视化一下这个过程就会变得非常清晰了

决策过程可视化

这里我们使用sklearn的moons 数据集,因为这是一个经典的非线性分类数据

 import numpy as np
 import sklearn.datasets as ds
 import pandas as pd
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 
 from sklearn import tree
 from itertools import product,islice
 import seaborn as snsmoonDS = ds.make_moons(200, noise = 0.15, random_state=16)
 moon = moonDS[0]
 color = -1*(moonDS[1]*2-1)
 
 df =pd.DataFrame(moon, columns = ['x','y'])
 df['z'] = color
 df['f0'] =df.y.mean()
 df['r0'] = df['z'] - df['f0']
 df.head(10)

让我们可视化数据:

下图可以看到,该数据集是可以明显的区分出分类的边界的,但是因为他是非线性的,所以使用线性算法进行分类时会遇到很大的困难。

那么我们先编写一个简单的梯度增强模型:

 def makeiteration(i:int):
     """Takes the dataframe ith f_i and r_i and approximated r_i from the features, then computes f_i+1 and r_i+1"""
     clf = tree.DecisionTreeRegressor(max_depth=1)
     clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}'])
     df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)
     
     eta = 0.9
     df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']
     df[f'r{i}'] = df['z'] - df[f'f{i}']
     rmse = (df[f'r{i}']**2).sum()
     clfs.append(clf)
     rmses.append(rmse)

上面代码执行3个简单步骤:

将决策树与残差进行拟合:

 clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}'])
 df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)

然后,我们将这个近似的梯度与之前的学习器相加:

 df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']

最后重新计算残差:

 df[f'r{i}'] = df['z'] - df[f'f{i}']

步骤就是这样简单,下面我们来一步一步执行这个过程。

第1次决策

Tree Split for 0 and level 1.563690960407257

第2次决策

Tree Split for 1 and level 0.5143677890300751

第3次决策

Tree Split for 0 and level -0.6523728966712952

第4次决策

Tree Split for 0 and level 0.3370491564273834

第5次决策

Tree Split for 0 and level 0.3370491564273834

第6次决策

Tree Split for 1 and level 0.022058885544538498

第7次决策

Tree Split for 0 and level -0.3030575215816498

第8次决策

Tree Split for 0 and level 0.6119407713413239

第9次决策

可以看到通过9次的计算,基本上已经把上面的分类进行了区分

我们这里的学习器都是非常简单的决策树,只沿着一个特征分裂!但整体模型在每次决策后边的越来越复杂,并且整体误差逐渐减小。

 plt.plot(rmses)

这也就是上图中我们看到的能够正确区分出了大部分的分类

如果你感兴趣可以使用下面代码自行实验:

https://avoid.overfit.cn/post/533a0736b7554ef6b8464a5d8ba964ab

作者:Tanguy Renaudie

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/387581.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【VC 7/8】vCenter Server 基于文件的备份和还原Ⅰ——基于文件的备份和还原的注意事项和限制

目录1.1 协议1.2 还原后配置说明1.3 Storage DRS1.4 分布式电源管理1.5 分布式虚拟交换机1.6 内容库1.7 虚拟机生命周期操作1.8 vSphere High Availability1.9 基于存储策略的管理1.10 其它注意事项虚拟存储区域网络修补关联博文[图片来源]:https://www.vmignite.co…

ARM uboot 源码分析9 - uboot的硬件驱动部分

一、uboot 与 linux 驱动 1、uboot 本身是裸机程序 (1) 裸机本来是没有驱动的概念的(狭义的驱动的概念就是,操作系统中用来具体操控硬件的那部分代码叫驱动) (2) 裸机程序中是直接操控硬件的,操作系统中必须通过驱动来操控硬件…

Java 8 新特性之Stream流(二)关键

继续探索流API的高级功能之前,我们先从接口级别全面了解一下流API,这个对于我们来说是至关重要的。下面是一张流API关键知识点的UML图。 流API UML 流API定义的几个接口,都是在java.util.stream包中的.其中上图中的BaseStream接口是最基础的…

每日记录自己的Android项目(二)—Viewbinding,WebView,Navigation

今日想法今天是想把做一个跳转页面的时候调到H5页面去,但是这个页面我用app来承载,不要调到浏览器去。所以用到了下方三个东西。Viewbindingbuild.gradle配置首先在app模块的build.gradle里添加一下代码默认情况下,每一个布局xml文件都会生成…

【Linux学习】基础IO——理解缓冲区 | 理解文件系统

🐱作者:一只大喵咪1201 🐱专栏:《Linux学习》 🔥格言:你只管努力,剩下的交给时间! 基础IO☕理解缓冲区🧃缓冲区的共识🧃缓冲区的位置🧃缓冲区的刷…

Spring Boot+Vue前后端分离项目练习03之网盘项目文件夹创建及文件查询接口开发

1.集成Swagger 3接口文档 在前后端分离的项目中,接口文档的存在十分重要。swagger 是一个自动生成接口文档的工具,在需求变更十分频繁的情况下,手写接口文档是效率十分低下,这时swagger自动生生文档的的作用就体现出来了&#xf…

【uni-app教程】UniAPP 常用组件和 常用 API 简介# 知心姐姐聊天案例

五、UniAPP 常用组件简介 uni-app 为开发者提供了一系列基础组件,类似 HTML 里的基础标签元素,但 uni-app 的组件与 HTML 不同,而是与小程序相同,更适合手机端使用。 虽然不推荐使用 HTML 标签,但实际上如果开发者写了…

华为机试题:HJ105 记负均正II(python)

文章目录(1)题目描述(2)Python3实现(3)知识点详解1、input():获取控制台(任意形式)的输入。输出均为字符串类型。1.1、input() 与 list(input()) 的区别、及其相互转换方…

【Kubernetes】第十七篇 - ECS 服务停机和环境修复

一,前言 上一篇,介绍了 Secret 镜像的使用; 三台服务每天大概 15 块钱的支出,用一个月也是不少钱; 闲时可以停掉,这样每天只有 4 块钱支出,剩下一大笔; ECS 服务停机后公网 IP 会…

移除元素(每日一题)

目录 一、题目描述 二、题目分析 2.1 方法一 2.1.1 思路 2.1.2 代码 2.2 方法二 2.2.1 思路 2.2.2 代码 一、题目描述 题目链接:27. 移除元素 - 力扣(LeetCode) 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数…

【Maven】P1 Maven 基础知识

Maven 基础知识Maven基础仓库坐标快速坐标生成网站国内镜像仓库前言 本节:Maven第一节内容,记录maven是什么,解决了什么问题,进而推出他的作用;然后介绍maven中两个重要概念,仓库与坐标。 下一节&#xff1…

TIA博途中使用SCL语言实现选择排序算法并封装成FC全局库

TIA博途中使用SCL语言实现选择排序算法并封装成FC全局库 选择排序算法包括升序和降序2种: 升序排列: 第一轮从数据源中找到最小值排在第一位,第二轮从剩下的数据中寻找最小值排在第二位,依次类推,直到所有数据完成遍历;降序排列: 第一轮从数据源中找到最大值排在第一位,…

centOS 编译strongswan

安装编译环境 yum groupinstall "Development Tools" 编译strongswan Download strongSwan: wget https://download.strongswan.org/strongswan-x.x.x.tar.bz2 Unpack the tarball and navigate into the directory: tar xjf strongswan-x.x.x.tar.bz2; cd strong…

Editor.md 的使用方法及图片处理

目录1. 资源下载2. 生成页面2.1 编辑和预览页面2.2 文本渲染页面3. 图片上传3.1 前端配置3.2 后端接口4. 图片粘贴1. 资源下载 官网下载 gitee 下载 2. 生成页面 2.1 编辑和预览页面 将资源(精简后 Editor.md 资源1)导入项目: 按照官方教…

nvidia Jetson nano Linux内核编译

今天编译了nvidia 的jetson nano的内核。在网上找到的资料都比较老了。现在官网的最新版本是35.1.结合之前看到的博客的内容。关键是内核源码和交叉编译器的下载。找到官方文档后,编译成功!并且官方的文档是有一个编译脚本的。看之前的资料都是给出的命令,不知道这个nvbuild…

库函数qsort的使用以及模拟实现

首先&#xff0c;qsort函数是个库函数 那么就有头文件 #include<stdlib.h>这个函数的实现是利用快速排序的方法实现的 下面是该函数的参数//void qsort(void* base, //指向了待排序数组的第一个元素 // size_t num, //待排序的元素个数 // size_t …

FLoyd算法的入门与应用

目录 一、前言 二、FLoyd算法 1、最短路问题 2、Floyd算法 3、Floyd的特点 4、Floyd算法思想&#xff1a;动态规划 三、例题 1、蓝桥公园&#xff08;lanqiaoOJ题号1121&#xff09; 2、路径&#xff08;2021年初赛 lanqiaoOJ题号1460&#xff09; 一、前言 本文主要…

Cannot start compiler The output path is not specified for module mystatic(已解决)

1.背景&#xff1a;今天在idea上写了一些代码&#xff0c;右键run竟然跑不起来了&#xff0c;而且右下角的Event Log还报错。报错内容如下图&#xff1a;2.报错原因&#xff1a;项目代码和编译器的输出路径不在一块&#xff0c;导致idea无法找到模块的output path&#xff08;输…

Docker--(六)--Docker资源限制

前言系统压力测试Cpu资源限制Mem资源限制IO 资源限制【扩展】 1.前言 在使用 Docker 运行容器时&#xff0c;一台主机上可能会运行几百个容器&#xff0c;这些容器虽然互相隔离&#xff0c;但是底层却使用着相同的 CPU、内存和磁盘资源。如果不对容器使用的资源进行限制&#x…

VUE中给对象添加新属性时,界面不刷新怎么办

一、直接添加属性的问题 举例&#xff1a; 定义一个p标签&#xff0c;通过v-for指令进行遍历 然后给botton标签绑定点击事件&#xff0c;我们预期点击按钮时&#xff0c;数据新增一个属性&#xff0c;界面也 新增一行。 <p v-for"(value,key) in item" :key&qu…