16- 梯度提升分类树GBDT (梯度下降优化) (算法)

news2024/11/26 5:11:59
  • 梯度提升算法

from sklearn.ensemble import GradientBoostingClassifier
clf = GradientBoostingClassifier(subsample=0.8,learning_rate = 0.005)
clf.fit(X_train,y_train)


1、交叉熵

1.1、信息熵

  • 构建好一颗树,数据变的有顺序了(构建前,一堆数据,杂乱无章;构建一颗,整整齐齐,顺序),用什么度量衡表示,数据是否有顺序:信息熵
  • 物理学,热力学第二定律(熵),描述的是封闭系统的混乱程度
  • 信息熵,和物理学中熵类似的

        \bg_white \small H(x) = -\sum\limits_{i = 1}^n p(x)log_2p(x)

        H(x) = \sum\limits_{i = 1}^n p(x)log_2\frac{1}{p(x)}

1.2、交叉熵

信息熵可以引出交叉熵

小明在学校玩王者荣耀被发现了,爸爸被叫去开家长会,心里悲屈的很,就想法子惩罚小明。到家后,爸爸跟小明说:既然你犯错了,就要接受惩罚,但惩罚的程度就看你聪不聪明了。这样吧,我们俩玩猜球游戏,我拿一个球,你猜球的颜色,我可以回答你任何问题,你每猜一次,不管对错,你就一个星期不能玩王者荣耀,当然,猜对,游戏停止,否则继续猜。当然,当答案只剩下两种选择时,此次猜测结束后,无论猜对猜错都能100%确定答案,无需再猜一次,此时游戏停止

1.2.1、题目一

爸爸拿来一个箱子,跟小明说:里面有橙、紫、蓝及青四种颜色的小球任意个,各颜色小球的占比不清楚,现在我从中拿出一个小球,你猜我手中的小球是什么颜色?

为了使被罚时间最短,小明发挥出最强王者的智商,瞬间就想到了以最小的代价猜出答案,简称策略1,小明的想法是这样的。

1.2.2、题目二

爸爸还是拿来一个箱子,跟小明说:箱子里面有小球任意个,但其中1/2是橙色球,1/4是紫色球,1/8是蓝色球及1/8是青色球。我从中拿出一个球,你猜我手中的球是什么颜色的?

小明毕竟是最强王者,仍然很快得想到了答案,简称策略2,他的答案是这样的。

这就需要引入交叉熵,其用来衡量在给定的真实分布下,使用非真实分布所指定的策略消除系统的不确定性所需要付出的努力的大小。

1.3、sigmoid

f(x) = \frac{1}{1 + e^{-x}}

f'(x) = \frac{e^{-x}}{(1 + e^{-x})^2} =f(x) * \frac{1 + e^{-x} - 1}{1 + e^{-x}} = f(x) * (1 - f(x))

后面算法推导过程中都会使用到上面的基本方程,因此先对以上概念公式,有基本了解!

2、GBDT分类树

2.1、梯度提升分类树概述

GBDT分类树 sigmoid + 决策回归树 一一> 概率问题!

  • 损失函数是交叉熵

  • 概率计算使用sigmoid

  • 使用 mse 作为分裂标准(同梯度提升回归树)

2.2、梯度提升分类树应用

1、加载数据

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier

X,y = datasets.load_iris(return_X_y = True)
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state = 1124)

2、普通决策树表现

model = DecisionTreeClassifier()
model.fit(X_train,y_train)
model.score(X_test,y_test)      # 输出:0.8421052631578947

3、梯度提升分类树表现

from sklearn.ensemble import GradientBoostingClassifier
clf = GradientBoostingClassifier(subsample=0.8,learning_rate = 0.005)
clf.fit(X_train,y_train)
clf.score(X_test,y_test)     # 输出:0.9473684210526315

3、GBDT分类树算例演示

3.1、算法公式

  • 概率计算(sigmoid函数)

    p = \frac{1}{1 + exp(-F(x))}

  • 函数初始值(这个函数即是sigmoid分母中的F(x),用于计算概率)

    逻辑回归中的函数是线性函数,GBDT中的函数不是线性函数,但是作用类似!

    F_0(x) = log\frac{\sum\limits_{i=1}^Ny_i}{\sum\limits_{i=1}^N(1 -y_i)}

  • 计算残差公式

    residual = \widetilde{y}= y - \frac{1}{1+exp(-F(x))}

  • 均方误差(根据均方误差,筛选最佳裂分条件)

    mse = ((residual - residual.mean())^2).mean()

  • 决策树叶节点预测值(相当于负梯度)

    ​​​​​\gamma_{mj} = \frac{\sum\limits_{x_i \in R_{mj}}\widetilde{y}_i}{\sum\limits_{x_i \in R_{mj}}(y_i - \widetilde{y}_i)(1 - y_i + \widetilde{y}_i)}

  • 梯度提升

    F_1 = F_0 + \gamma * learning\_rate

  • 根据以上公式,即可进行代码演算了~

3.2、算例演示

3.2.1、创建数据

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import tree
import graphviz
X = np.arange(1,11).reshape(-1,1)
y = np.array([0,0,0,1,1]*2)
display(X,y)

3.2.2、构造GBDT训练预测

# 默认情况下,损失函数就是Log-loss == 交叉熵!
clf = GradientBoostingClassifier(n_estimators=3,learning_rate=0.1,max_depth=1)
clf.fit(X,y)
y_ = clf.predict(X)
print('真实的类别:',y)
print('算法的预测:',y_)
proba_ = clf.predict_proba(X)
print('预测概率是:\n',proba_)

3.2.3、GBDT可视化

第一棵树

dot_data = tree.export_graphviz(clf[0,0],filled = True)
graph = graphviz.Source(dot_data)
graph

第二棵树

dot_data = tree.export_graphviz(clf[1,0],filled = True)
graph = graphviz.Source(dot_data)
graph

第三棵树

dot_data = tree.export_graphviz(clf[2,0],filled = True)
graph = graphviz.Source(dot_data)
graph

每棵树,根据属性进行了划分,每棵树的叶节点都有预测值,这些具体都是如何计算的呢?且看,下面详细的计算工程~

3.2.4、计算步骤

首先,计算初始值 :

F_0(x)= log\frac{\sum\limits_{i=1}^Ny_i}{\sum\limits_{i=1}^N(1 -y_i)}

F0 = np.log(y.sum()/(1-y).sum())
F0       # 输出结果:-0.40546510810816444
# 此时未裂分,所有的数据都是F0 
F0 = np.array([F0]*10)
# 然后,计算残差
# 残差,F0带入sigmoid计算的即是初始概率
residual0 = y - 1/(1 + np.exp(-F0))
residual0
# 输出:array([-0.4, -0.4, -0.4,  0.6,  0.6, -0.4, -0.4, -0.4,  0.6,  0.6])

3.2.5、拟合第一棵树

根据残差的mse,计算最佳分裂条件

mse = ((residual - residual.mean())^2).mean()

lower_mse = ((residual0 - residual0.mean())**2).mean()
best_split = {}
# 分裂标准 mse
for i in range(0,10):
    if i == 9:
        mse = ((residual0 - residual0.mean())**2).mean()
    else:
        left_mse = ((residual0[:i+1] - residual0[:i+1].mean())**2).mean()
        right_mse = ((residual0[i+1:] - residual0[i+1:].mean())**2).mean()
        mse = left_mse*(i+1)/10 + right_mse*(10-i-1)/10
    if lower_mse > mse:
        lower_mse = mse
        best_split.clear()
        best_split['X[0] <= '] = X[i:i + 2].mean() 
    print('从第%d个进行分裂'%(i + 1),np.round(mse,4))
# 从第八个样本这里进行分类,最优的选择,和算法第一颗画图的结果一致
print('最小的mse是:',lower_mse)
print('最佳裂分条件是:',best_split)

现在我们知道了,分裂条件是:X[0] <= 8.5!然后计算决策树叶节点预测值(相当于负梯度),其中的 ​​ 就是​​残差residual0

3.2.6、拟合第二棵树

第一棵树的负梯度(预测值)

# 第一棵预测的结果,负梯度
gamma = np.array([gamma1]*8 + [gamma2]*2)
gamma    '''输出:array([-0.625, -0.625, -0.625, -0.625, -0.625, -0.625,
 -0.625, -0.625, 2.5  ,  2.5  ])'''

梯度提升

F_1 = F_0 + \gamma * learning\_rate

# F(x) 随着梯度提升树,提升,发生变化
learning_rate = 0.1
F1 = F0 + gamma*learning_rate
F1    ''' 输出  array([-0.46796511, -0.46796511, -0.46796511, -0.46796511, 
-0.46796511, -0.46796511, -0.46796511, -0.46796511, -0.15546511, -0.15546511])'''

根据 F1 计算残差

residual = \widetilde{y}= y - \frac{1}{1+exp(-F(x))}

residual1 = y - 1/(1 + np.exp(-F1))
residual1    '''array([-0.38509799, -0.38509799, -0.38509799,  0.61490201,  
0.61490201, -0.38509799, -0.38509799, -0.38509799,  0.53878818,  0.53878818])'''

根据新的残差residual1的mse,计算最佳分裂条件

lower_mse = ((residual1 - residual1.mean())**2).mean()
best_split = {}
# 分裂标准 mse
for i in range(0,10):
    if i == 9:
        mse = ((residual1 - residual1.mean())**2).mean()
    else:
        left_mse = ((residual1[:i+1] - residual1[:i+1].mean())**2).mean()
        right_mse = ((residual1[i+1:] - residual1[i+1:].mean())**2).mean()
        mse = left_mse*(i+1)/10 + right_mse*(10-i-1)/10
    if lower_mse > mse:
        lower_mse = mse
        best_split.clear()
        best_split['X[0] <= '] = X[i:i + 2].mean() 
    print('从第%d个进行分裂'%(i + 1),np.round(mse,4))
# 从第八个样本这里进行分类,最优的选择,和算法第一颗画图的结果一致
print('最小的mse是:',lower_mse)
print('最佳裂分条件是:',best_split)

现在我们知道了,第二棵树分裂条件是:X[0] <= 8.5 !然后计算决策树叶节点预测值(相当于负梯度),其中的 ​​​ 就是​​残差residual1

3.2.7、拟合第三棵树

第二棵树的负梯度

# 第二棵树预测值
gamma = np.array([gamma1]*8 + [gamma2]*2)
gamma

梯度提升

# F(x) 随着梯度提升树,提升,发生变化
learning_rate = 0.1
F2 = F1 + gamma*learning_rate 
F2

根据 F2 计算残差

residual2 = y - 1/(1 + np.exp(-F2))
residual2

根据新的残差residual2的 mse,计算最佳分裂条件

lower_mse = ((residual2 - residual2.mean())**2).mean()
best_split = {}
# 分裂标准 mse
for i in range(0,10):
    if i == 9:
        mse = ((residual2 - residual2.mean())**2).mean()
    else:
        left_mse = ((residual2[:i+1] - residual2[:i+1].mean())**2).mean()
        right_mse = ((residual2[i+1:] - residual2[i+1:].mean())**2).mean()
        mse = left_mse*(i+1)/10 + right_mse*(10-i-1)/10
    if lower_mse > mse:
        lower_mse = mse
        best_split.clear()
        best_split['X[0] <= '] = X[i:i + 2].mean() 
    print('从第%d个进行分裂'%(i + 1),np.round(mse,4))
# 从第八个样本这里进行分类,最优的选择,和算法第一颗画图的结果一致
print('最小的mse是:',lower_mse)
print('最佳裂分条件是:',best_split)

现在我们知道了,第三棵树分裂条件是:X[0] <= 3.5!然后计算决策树叶节点预测值(相当于负梯度),其中的 ​​​​ 就是​​残差residual2

#  计算第三颗树的预测值
# 前三个是一类
# 后七个是一类
# 左边分支
gamma1 = residual2[:3].sum()/((y[:3] - residual2[:3])*(1 - y[:3] + 
residual2[:3])).sum()
print('第三棵树左边决策树分支,预测值:',gamma1)

# 右边分支
gamma2 =residual2[3:].sum()/((y[3:] - residual2[3:])*(1 - y[3:] + 
residual2[3:])).sum()
print('第三棵树右边决策树分支,预测值:',gamma2)

3.2.8、预测概率计算

计算第三棵树的F3(x)

# 第三棵树预测值
gamma = np.array([gamma1]*3 + [gamma2]*7)
# F(x) 随着梯度提升树,提升,发生变化
learning_rate = 0.1
F3 = F2 + gamma*learning_rate

概率公式如下:

proba = 1/(1 + np.exp(-F3))
# 类别:0,1,如果这个概率大于等于0.5类别1,小于0.5类别0
display(proba)
# 进行转换,类别0,1的概率都展示
np.column_stack([1- proba,proba])
# 算法预测概率 
clf.predict_proba(X)

结论:

  • 手动计算的概率和算法预测的概率完全一样!

  • GBDT分类树,计算过程原理如上

4、GBDT分类树原理推导

4.1、损失函数:

  • 定义交叉熵为函数​  \psi(y,F(x))

        \psi(y,F(x)) = -yln(p) - (1-y)ln(1-p)

        其中​​ p = \frac{1}{1 + exp(-F(x))},即sigmoid函数

  • F(x) 表示决策回归树 DecisionTreeRegressor F(x) 表示每一轮决策树的value,即负梯度

4.2、损失函数化简

  • 损失函数化简:

  • ​​\psi(y,F(x)) = -yF(x) + ln(1 + exp(F(x)))

  • 化简过程

4.3、损失函数求导

将F(x)看成整体变量,进行求导

一阶导数:

\begin{aligned}\psi'(y,F(x)) & = -y + \frac{exp(F(x))}{1 + exp(F(x))}\\\\&=-y + \frac{1}{1 + exp(-F(x))}\\\\&= -y + \sigma(F(x))\end{aligned}

4.4、初始值 F_0(x)计算

F_0(x) = log\frac{\sum\limits_{i=1}^Ny_i}{\sum\limits_{i=1}^N(1 -y_i)}

4.4.1、初始值方程构建

之前的GBDT回归树,初始值是多少:平均值

现在的GBDT分类树 ,计算初始值 F_0(x)​​ ,令​​ \rho = F_0(x)

​​

5、GBDT二分类步骤总结

Step - 1:

        F_0(x) = log\frac{\sum\limits_{i=1}^Ny_i}{\sum\limits_{i=1}^N(1 -y_i)}​​

Step - 2:for i in range(M):

a. ​​\widetilde{y}_i = -\left[\frac{\partial \psi(y_i,F(x_i))}{\partial F(x_i)}\right]_{F(x_i) = F_{m-1}(x_i)} = y_i - \frac{1}{1 + exp(-F_{m-1}(x_i))}

b. 根据残差 ​​​\widetilde{y} ,寻找最小 mse 裂分条件

c. \gamma_{mj} = \frac{\sum\limits_{x_i \in R_{mj}}\widetilde{y_i}}{\sum\limits_{x_i \in R_{mj}}(y_i - \widetilde{y_i})(1 - y_i + \widetilde{y_i})}​​

d. ​​F_m(x) = F_{m-1}(x) + \gamma * learning\_rate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/345743.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jvm对象创建与内存解析

1.类加载检查虚拟机遇到一条new指令时&#xff0c;首先将去检查这个指令的参数是否能在常量池中定位到一个类的符号引用&#xff0c;并且检查这个符号引用代表的类是否已被加载、解析和初始化过。如果没有&#xff0c;那必须先执行相应的类加载过程。new指令对应到语言层面上讲…

【node.js】node.js的安装和配置

文章目录前言下载和安装Path环境变量测试推荐插件总结前言 Node.js是一个在服务器端可以解析和执行JavaScript代码的运行环境&#xff0c;也可以说是一个运行时平台&#xff0c;仍然使用JavaScript作为开发语言&#xff0c;但是提供了一些功能性的API。 下载和安装 Node.js的官…

linux篇【14】:网络https协议

目录 一.HTTPS介绍 1.HTTPS 定义 2.HTTP与HTTPS &#xff08;1&#xff09;端口不同&#xff0c;是两套服务 &#xff08;2&#xff09;HTTP效率更高&#xff0c;HTTPS更安全 3.加密&#xff0c;解密&#xff0c;密钥 概念 4.为什么要加密&#xff1f; 5.常见的加密方式…

裸辞5个月,面试了37家公司,终于找到理想工作了

上半年裁员&#xff0c;下半年裸辞&#xff0c;有不少人高呼裸辞后躺平真的好快乐&#xff01;但也有很多人&#xff0c;裸辞后的生活五味杂陈。 面试37次终于找到心仪工作 因为工作压力大、领导PUA等各种原因&#xff0c;今年2月下旬我从一家互联网小厂裸辞&#xff0c;没想…

linux高级命令之用户相关操作

用户相关操作学习目标能够知道创建用户的命令1. 创建用户命令说明useradd创建(添加)用户useradd命令选项:选项说明-m自动创建用户主目录,主目录的名字就是用户名-g指定用户所属的用户组&#xff0c;默认不指定会自动创建一个同名的用户组创建用户效果图:查看所有用户信息的文件…

nginx-host绕过实例复现

绕过Nginx Host限制第一种处理方法Nginx在处理Host的时候&#xff0c;会将Host用冒号分割成hostname和port&#xff0c;port部分被丢弃。所以&#xff0c;我们可以设置Host的值为2023.mhz.pw:xxx"example.com&#xff0c;这样就能访问到目标Server块&#xff1a;第二种处理…

SpringBoot的定时任务实现--SpringTask

SpringTask是Spring自带的功能。实现起来比较简单。 使用SpringTask实现定时任务有两种方式&#xff1a; 1.注解方式 基于注解 Scheduled Scheduled(cron "*/1 * * * * ?")public void up(){System.out.println("定时任务开启&#xff1a;"System.cu…

想做好项目经理,一定要知道这10句话

早上好&#xff0c;我是老原。有句话说过&#xff1a;“你是怎么过好一天的&#xff0c;就是怎么过好一生的。”这句话&#xff0c;我刚毕业那会没什么感觉&#xff0c;但工作越久&#xff0c;体会越深。你会发现优秀的人有些特质和习惯千篇一律&#xff0c;而普通人&#xff0…

深圳80后男子朋友圈晒情人节,一天收三个不同女子巧克力红包

每年情人节到来的时候&#xff0c;对于广大男同胞来说&#xff0c;都是倍受煎熬的日子&#xff0c;因为不论你怎么去做&#xff0c;都不会落到好处。如果你还没有对象&#xff0c;这个情人节就尴尬了&#xff0c;眼看着别人出入成双成对&#xff0c;自己却落得个孤家寡人。 如果…

微信Android架构历史——模块化架构重构实践

微信Android诞生之初&#xff0c;用的是常见的分层结构设计。这种架构简单、清晰并一直沿袭至今。这是微信架构的v1.x时代。 图1-架构演进 到了微信架构的v2.x时代&#xff0c;随着业务的快速发展&#xff0c;消息通知不及时和Android 2.3版本之前webview内存泄露问题开始突显…

java基于springboot+vue微信小程序的医疗监督反馈小程序

医疗监督反馈行业是一个传统的行业。根据当前发展现状,网络信息时代的全面普及,医疗监督反馈行业也在发生着变化,单就下单这一方面,利用手机下单正在逐步进入人们的生活。 传统的下单方式,不仅会耗费大量的人力、时间,有时候还会出错。小程序系统伴随智能手机为我们提供了新的方…

【贝叶斯方法】无论您是数据统计分析初学者,还是有一定基础

包括回归及结构方程模型概述及数据探索&#xff1b;R和Rstudio简介及入门和作图基础&#xff1b;R语言数据清洗-tidyverse包&#xff1b;贝叶斯回归与混合效应模型&#xff1b;贝叶斯空间自相关、时间自相关及系统发育相关数据分析&#xff1b;贝叶斯非线性数据分析;贝叶斯结构…

API数据是什么?举例说明,它是电商平台发展的领航者

API接口&#xff1a; API接口是什么&#xff1f; API全称是&#xff1a;Application Programming Interface&#xff0c;即&#xff1a;应用程序接口。开发人员可以使用这些API接口进行编程开发&#xff0c;而又无需访问源码&#xff0c;或理解内部工作机制的细节。 比较常见…

外包公司“混”了2年,我只认真做了5件事,如今顺利拿到腾讯Offer。

前言 是的&#xff0c;我一家外包公司工作了整整两年时间&#xff0c;在入职这家公司前&#xff0c;也就是两年前&#xff0c;我就开始规划了我自己的人生&#xff0c;所以在两年时间里&#xff0c;我并未懈怠。 现如今&#xff0c;我已经跳槽到了腾讯&#xff0c;顺利拿下 o…

项目(今日指数之登录功能)

今日目标1. 完善基于前后端分用户验证码登录功能; 2. 理解验证码生成流程,并使用postman测试; 3. 理解并实现国内大盘数据展示功能; 4. 理解并实现国内板块数据展示功能; 5. 理解后端接口调试和前后端联调的概念;1.验证码登录功能1.1 验证码功能分析1&#xff09;前后端分离架构…

【JAVA】jdk8 Stream 排序精通

背景 jdk8的stream流能方便的排序&#xff0c;但是每次都要查资料&#xff0c;非常不方便&#xff0c;不确定&#xff0c;所以这次直接弄懂&#xff0c;不再迷茫。 转载请注明来源&#xff0c;创作不易&#xff0c;请多多支持。 基础排序 stream流 大家应该都比较熟悉了&…

react-01-jsx语法与react实例三大属性与react生命周期

英文官网: https://reactjs.org/ 中文官网:https://react.docschina.org/ 基本知识 1、jsx语法 标签中使用js表达式用{} jsx中样式叫className 内联样式使用style{{key:value}}去写 只有一个根标签 标签必须闭合 标签首字母 (1).若小写字母开头&#xff0c;则将该标签…

网络安全领域中CISP证书八大类都有什么

CISP​注册信息安全专业人员 注册信息安全专业人员&#xff08;Certified Information Security Professional&#xff09;&#xff0c;是经中国信息安全产品测评认证中心实施的国家认证&#xff0c;对信息安全人员执业资质的认可。该证书是面向信息安全企业、信息安全咨询服务…

P1217 [USACO1.5]回文质数 Prime Palindromes

[USACO1.5]回文质数 Prime Palindromes 题目描述 因为 151151151 既是一个质数又是一个回文数&#xff08;从左到右和从右到左是看一样的&#xff09;&#xff0c;所以 151151151 是回文质数。 写一个程序来找出范围 [a,b](5≤a<b≤100,000,000)[a,b] (5 \le a < b \l…

idea使用本地代码远程调试线上运行代码---windows环境

场景&#xff1a; 今天在书上看了一个代码远程调试的方法&#xff0c;自己本地验证了一下感觉十分不错&#xff01;&#xff01; windows环境&#xff1a; 启动测试jar包&#xff1a;platform-multiappcenter-base-app-1.0.0-SNAPSHOT.jar 测试工具&#xff1a;postman,idea 应…