排序算法经典模型: 梯度提升决策树(GBDT)的应用实战

news2024/12/23 20:06:49

目录

一、Boosting训练与预测

二、梯度增强的思想核心

三、如何构造弱学习器和加权平均的权重

四、损失函数

五、梯度增强决策树

六、GBDT生成新特征

主要思想

构造流程

七、梯度增强决策树以及在搜索的应用

7.1 GDBT模型调参

7.1.1 框架层面参数

n_estimators

subsample

7.1.2 分类/回归树层面参数

最大特征数max_features

决策树最大深度max_depth

部节点再划分所需最小样本数min_samples_split

叶子节点最少样本数min_samples_leaf

7.2 K折交叉验证找到最佳超参数

交叉验证的优点

交叉验证的缺点

基于k折交叉验证的网格搜索法

7.3  GBDT在推荐系统中的排序算法示例


一、Boosting训练与预测

Boosting训练与预测

Boosting训练过程为串型,基模型按次序一一进行训练,基模型的训练集按照某种策略每次都进行一定的更新。对所有基模型预测的结果进行线性综合产生最终的预测结果

GBDT是将梯度下降和 Boosting 方法结合的算法。它采用决策树模型,并定义一个损失函数,通过梯度下降来优化模型。

二、梯度增强的思想核心

梯度增强首先还是增强算法的一个扩展,也是希望能用一系列的弱学习器来达到一个强学习器的效果,从而逼近目标变量的值,也就是我们常说的标签值。而根据加性模型的假设,这种逼近效果是这些弱学习器的一个加权平均。也就是说,最终的预测效果,是所有单个弱学习器的一个平均效果,只不过这个平均不是简单的平均,而是一个加权的效果。

三、如何构造弱学习器和加权平均的权重

梯度增强采用了一个统计学或者说是优化理论的视角,使得构造这些部分变得更加直观。

梯度增强的作者们意识到,如果使用“梯度下降”(Gradient Descent)来优化一个目标函数,最后的预测式可以写成一个加和的形式。也就是,每一轮梯度的值和一个叫“学习速率”(Learning Rate)的参数共同叠加起来形成了最后的预测结果。这个观察非常重要,如果把这个观察和我们的目标,也就是构造弱学习器的加权平均联系起来看,我们就会发现,其实每个梯度的值就可以认为是一个弱学习器,而学习速率就可以看作是某种意义上的权重

首先,这是一个迭代算法。每一轮迭代,我们把当前所有学习器的加权平均结果当作这一轮的函数值,然后求得针对某一个损失函数对于当前所有学习器的参数的一个梯度。然后,我们利用某一个弱学习器算法,可以是线性回归模型(Linear Regression)、对数几率模型(Logistic Regression)等来拟合这个梯度。最后,我们利用“线查找”(Line Search)的方式找到权重。说得更直白一些,那就是我们尝试利用一些简单的模型来拟合不同迭代轮数的梯度。

四、损失函数

损失函数是用来量化模型预测值与实际值之间差异的函数。在训练模型时,损失函数的值被用来通过优化算法(如梯度下降)调整模型参数,目标是最小化这个损失值

常见的损失函数

对于GBDT来说,如果是用于回归问题,那么通常选择平Squared Error Loss(方误差损失);如果是用于分类问题,尤其是二分类问题,通常选择Logistic Regression Loss(逻辑回归损失)。请注意,GBDT用于多分类问题时,会使用对数损失的多分类版本。

五、梯度增强决策树

梯度增强决策树就是利用决策树,这种最基本的学习器来当作弱学习器,去拟合梯度增强过程中的梯度。然后融合到整个梯度增强的过程中,最终,梯度增强决策树其实就是每一轮迭代都拟合一个新的决策树用来表达当前的梯度,然后跟前面已经有的决策树进行叠加。在整个过程中,决策树的形状,比如有多少层、总共有多少节点等,都是可以调整的或者学习的超参数。而总共有多少棵决策树,也就是有多少轮迭代是重要的调节参数,也是防止整个学习过程过拟合的重要手段。

六、GBDT生成新特征

主要思想

GBDT每棵树的路径直接作为LR输入特征使用

构造流程

用已有特征训练GBDT模型,然后利用GBDT模型学习到的树来构造新特征。构造的新特征向量是取值0/1的,向量的每个元素对应于GBDT模型中树的叶子结点。当一个样本点通过某棵树最终落在这棵树的一个叶子结点上,那么在新特征向量中这个叶子结点对应的元素值为1,而这棵树的其他叶子结点对应的元素值为0。新特征向量的长度等于GBDT模型里所有树包含的叶子结点数之和。 

新特征向量反映了数据点在所有决策树中的路径信息,可以帮助线性模型(如逻辑回归)更好地捕捉数据点的复杂结构和模式,因为决策树能够捕捉非线性关系,而这种关系现在被编码到新特征向量中供线性模型使用。这样的技术常被用于提高模型在各种任务中的表现,尤其是在那些线性模型不足以捕捉数据复杂性的场景中。

七、梯度增强决策树以及在搜索的应用

7.1 GDBT模型调参

7.1.1 框架层面参数

n_estimators

弱学习器的最大迭代次数,或者说最大的弱学习器的个数。一般来说取值太小容易欠拟合;太大又容易过拟合,一般选择一个适中的数值。

subsample

即子采样,取值为(0,1]。注意这里的子采样和随机森林不一样,随机森林使用的是放回抽样,而这里是不放回抽样。如果取值为1,则全部样本都使用;如果取值小于1,则只有一部分样本会去做GBDT的决策树拟合。选择小于1的比例可以减少方差,即防止过拟合,但是会增加样本拟合的偏差,因此取值不能太低。推荐在[0.5, 0.8]之间,默认是1.0,即不使用子采样。

7.1.2 分类/回归树层面参数

最大特征数max_features

默认是“None”,即 考虑所有的特征数。如果是整数,代表考虑的特征绝对数。如果是浮点数,代表考虑特征百分比。一般来说,如果样本特征数不多,比如小于50,可以用默认的“None”,如果特征数非常多,需要进行网格搜索。

决策树最大深度max_depth

默认可以不输入,此时决策树在建立子树的时候不会限制子树的深度。一般来说,数据少或者特征少的时候可以不管这个值。如果模型样本量多,特征也多则需要限制最大深度,取值取决于数据的分布。常用的可以取值10-100之间。

部节点再划分所需最小样本数min_samples_split

如果某节点的样本数少于min_samples_split,则不会继续再进行划分。 默认是2.
如果样本量不大,不需要调节这个值。如果样本量数量级非常大,则推荐增大这个值。

叶子节点最少样本数min_samples_leaf

如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝。 默认是1,可以输入最少的样本数的整数,或者最少样本数占样本总数的百分比。如果样本量不大,不需要调节这个值。如果样本量数量级非常大,则推荐增大这个值。

7.2 K折交叉验证找到最佳超参数

  • 选择K的值(比如10),将数据集分成K等份
  • 使用其中的K-1份数据作为训练数据进行模型的训练
  • 使用一种度量测度在另外一份数据(作为验证数据)衡量模型的预测性能

交叉验证的优点

  • 交叉验证通过降低模型在一次数据分割中性能表现上的方差来保证模型性能的稳定性
  • 交叉验证可以用于选择调节参数、比较模型性能差别、选择特征

交叉验证的缺点

交叉验证带来一定的计算代价,尤其是当数据集很大的时候,导致计算过程会变得很慢

基于k折交叉验证的网格搜索法

GridSearchCV,其作用是自动调参。将每个参数所有可能的取值输入后可以给出最优化的结果和参数。但是该方法适合于小数据集,对于大样本很难得出结果

此时可以使用基于贪心算法的坐标下降进行快速调优:先拿当前对模型影响最大的参数调优,直到最优化,再拿下一个影响最大的参数调优,如此下去,直到所有的参数调整完毕。这个方法的缺点就是可能会调到局部最优而不是全局最优,时间效率较高。

以下是一个使用 GridSearchCV 优化 GBDT 参数的简单例子

from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import GradientBoostingClassifier

# 定义 GBDT 模型
gbdt = GradientBoostingClassifier()

# 定义参数网格
param_grid = {
    'n_estimators': [100, 200, 300],
    'learning_rate': [0.01, 0.1, 0.2],
    'max_depth': [3, 4, 5]
}

# 创建 GridSearchCV 对象
grid_search = GridSearchCV(gbdt, param_grid, cv=5, scoring='accuracy')

# 拟合/训练模型
grid_search.fit(X_train, y_train)

# 获取最佳参数组合和模型
best_parameters = grid_search.best_params_
best_model = grid_search.best_estimator_

在这个例子中,GridSearchCV 用于找到 GBDT 模型的最佳超参数,如 n_estimatorslearning_ratemax_depth。通过这种方式,GridSearchCV 提供了一个方便的方法来提高 GBDT 模型的性能。

7.3  GBDT在推荐系统中的排序算法示例

import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split

# 生成模拟的推荐系统数据集
# 假设有10个特征,这些特征描述了用户和物品的各种属性和交互
X = np.random.rand(1000, 10)  # 特征矩阵,每行代表用户-物品对的特征
y = np.random.rand(1000)  # 目标变量,代表用户对物品的评分或偏好

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化 GBDT 模型
gbdt = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)

# 训练模型
gbdt.fit(X_train, y_train)

# 使用模型进行预测
predictions = gbdt.predict(X_test)

# 将预测结果和测试集的特征合并,模拟实际推荐场景
recommended_items = pd.DataFrame(X_test, columns=[f'feature_{i}' for i in range(X_test.shape[1])])
recommended_items['predicted_rating'] = predictions

# 对每个用户推荐评分最高的物品
# 在实际应用中,你需要一个用户ID来分组,这里我们简化为按行分组
recommended_items.sort_values(by='predicted_rating', ascending=False, inplace=True)

# 展示每个用户推荐的最高评分物品
recommended_items_per_user = recommended_items.groupby('feature_0', as_index=False).first()
print(recommended_items_per_user)

在上述代码中,X 表示特征矩阵,y 表示目标变量。我们首先使用 train_test_split 分割数据集,然后创建和训练一个GBDT模型。这个模型用来预测用户对物品的评分。最后,我们展示了如何使用这个模型来为用户推荐评分最高的物品。

请注意,这是一个高度简化的示例。在实际推荐系统中,你会有一个用户ID与物品ID,并且你会根据这些ID来构建特征,然后进行排序和推荐。特征可能包括用户的历史行为,物品的内容特征,用户和物品的交互历史等。另外,排序模型的评估可能会使用更复杂的指标,如平均准确率均值(Mean Average Precision)或归一化折扣累积增益(NDCG)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1409376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode:三数之和---双指针

问题: 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请 你返回所有和为 0 且不重复的三元组。 注意:答案中不可以包含重复…

阿里Animate Anyone:任何静态图像都能动起来,让C罗、梅西、内马尔一起跳科目三!

目录 前言 相关链接 摘要 方法 效果展示 为各种角色制作动画 比较 更多应用 前言 2024年一开年,社交媒体和朋友圈就被一系列惊艳舞蹈视频占领了。钢铁侠跳起了科目三,马斯克也在摆着网红舞步,这些大约10秒的视频都是借助大模…

C# Socket通信从入门到精通(16)——单个同步UDP服务器监听多个客户端C#代码实现

前言: 我们在开发UDP通信程序时,有时候我们也需要开发UDP服务器程序,这个服务器只需要和一个客户端实现通信,比如这篇博文C# Socket通信从入门到精通(15)——单个同步UDP服务器监听一个客户端C#代码实现,但是在实际项目中有的时候需要和多个客户端进行通信,这时和一个…

Pandas实践指南:从基础到高级数据分析

Pandas实践指南:从基础到高级数据分析 引言Pandas基础1. 安装和基本配置2. DataFrame和Series的基础3. 基础数据操作 数据清洗与预处理1. 缺失值处理2. 数据转换3. 数据过滤 数据分析与操作1. 数据聚合和分组操作2. 时间序列数据处理3. 条件逻辑和数据分割 高级数据…

将AWS iot消息数据发送S3

观看此文章之前,请先学习AWS iot的数据收集: 使用Linux SDK客户端向AWS Iot发送数据-CSDN博客 上述的文章向大家展示了如何从客户端向AWS iot发送数据,那么数据收到之后,我们如何通过AWS的服务进行数据处理或者保存呢&#xff1…

Unity - gamma space下还原linear space效果

文章目录 环境目的环境问题实践结果处理要点处理细节【OnPostProcessTexture 实现 sRGB 2 Linear 编码】 - 预处理【封装个简单的 *.cginc】 - shader runtime【shader需要gamma space下还原记得 #define _RECOVERY_LINEAR_IN_GAMMA】【颜色参数应用前 和 颜色贴图采样后】【灯…

【C#】基础巩固

最近写代码的时候各种灵感勃发,有了灵感,就该实现了,可是,实现起来有些不流畅,总是有这样,那样的卡壳,总结下来发现了几个问题。 1、C#基础内容不是特别牢靠,理解的不到位&#xff…

vivo 海量基础数据计算架构应用实践

作者:来自 vivo 互联网大数据团队 本文根据刘开周老师在“2023 vivo开发者大会"现场演讲内容整理而成。公众号回复【2023 VDC】获取互联网技术分会场议题相关资料。 本文介绍了vivo在万亿级数据增长驱动下,基础数据架构建设的演进过程,…

如何创建以业务为中心的AI?

AI是企业的未来,这一趋势越来越明显。各种AI模型可以帮助企业节省时间、提高效率并增加收入。随着越来越多的企业采用AI,AI很快就不再是一种可有可无的能力,而是企业参与市场竞争的必备能力。 然而,作为一名业务决策者&#xff0c…

【jetson笔记】torchaudio报错

原因是因为pip安装的包与jetson不兼容导致 自己安装或者cmake编译也会报错 需要拉取官方配置好的docker镜像 拉取docker镜像 具体容器可以看官网,按照自己需求拉取即可 https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-ml 如果其他包不需要只需要torc…

【学习笔记】遥感影像分类相关精度指标

文章目录 0.混淆矩阵1. 精度名词解释2. Kappa系数3.举个栗子参考资料 0.混淆矩阵 混淆矩阵是分类精度的评定指标。是一个用于表示分为某一类别的像元个数与地面检验为该类别数的比较阵列。 对检核分类精度的样区内所有的像元,统计其分类图中的类别与实际类别之间的…

来自世坤!寻找Alpha 构建交易策略的量化方法

问:常常看到有人说Alpha seeking,这究竟是什么意思? 推荐这本《Finding Alphas: A Quantitative Approach to Building Trading Strategies》。我拿到的PDF是2019年的第二版。来自WorldQuant(世坤)的Igor Tulchinshky…

【数据结构与算法】栈(Stack)之 浅谈数组和链表实现栈各自的优缺点

文章目录 1.栈介绍2. 哪种结构实现栈会更优?3.栈代码实现(C语言) 往期相关文章: 线性表之顺序表线性表之链表 1.栈介绍 栈是一种特殊的线性表,只允许在栈顶(Top)进行插入和删除元素操作&#…

Toolbar

记录一下遇到的问题 Toolbal 使用过程中左右出现间隙 代码&#xff1a; <com.google.android.material.appbar.AppBarLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"http://schemas.android.com/apk/res-auto"xmlns:t…

SAP 消息编号 KI235

在执行AFAB折旧运行的时候&#xff0c;折旧没有运行出来 通过AFBP查询&#xff0c;出现一下报错 原因是因为在ASCET当中没有配置科目分配对象&#xff0c;所以系统无法把折旧费和CO&#xff08;成本中心&#xff09;关联起来 “科目设置”必选勾选 重新运行AFAB &#xff0c;就…

【新书推荐】2.4节 数据宽度

本节内容&#xff1a;计算机受制于物理器件的制约&#xff0c;存储或读写数据的宽度是有长度限制的&#xff0c;通常我们使用数据位的位数来表示数据宽度&#xff0c;如8位、16位、32位、64位等。 ■计算机计数与数学计数的区别&#xff1a;数学中的数据可以是无穷大或无穷小&a…

01.领域驱动设计:微服务设计为什么要选择DDD学习总结

目录 1、前言 2、软件架构模式的演进 3、微服务设计和拆分的困境 4、为什么 DDD适合微服务 5、DDD与微服务的关系 6、总结 1、前言 我们知道&#xff0c;微服务设计过程中往往会面临边界如何划定的问题&#xff0c;不同的人会根据自己对微服务的理 解而拆分出不同的微服…

解读IP风险画像标签:深度洞察网络安全

在当今数字化的世界中&#xff0c;网络安全成为企业和个人关注的焦点。IP风险画像标签作为网络安全的利器&#xff0c;扮演着深度洞察网络风险的角色。本文将深入解读IP风险画像标签&#xff0c;揭示其在网络安全领域的重要性和功能。 1. IP风险画像标签是什么&#xff1f; I…

Kubernetes/k8s之安全机制:

k8s当中的安全机制 核心是分布式集群管理工具&#xff0c;容器编排&#xff0c;安全机制核心是:API SERVER作为整个集群内部通信的中介&#xff0c;也是外部控制的入口&#xff0c;所有的安全机制都是围绕api server开设计的。 请求api资源 1、认证 2、鉴权 3、准入机制 三…

Java设计模式-装饰器模式(10)

大家好,我是馆长!今天开始我们讲的是结构型模式中的装饰器模式。老规矩,讲解之前再次熟悉下结构型模式包含:代理模式、适配器模式、桥接模式、装饰器模式、外观模式、享元模式、组合模式,共7种设计模式。。 装饰器模式(Decorator Pattern) 定义 装饰(Decorator)模式…