机器学习实战:Python基于GBM梯度提升机进行预测(十四)

news2024/11/24 12:27:44

这篇干货很硬,喜欢的小伙伴点个赞/收藏,持续更新!

文章目录

    • 1.前言
      • 1.1 GBM的介绍
      • 1.2 GBM的应用
    • 2. scikit-learn实战演示
      • 2.1 分类问题
      • 2.2 回归问题
    • 3. GBM超参数
      • 3.1 决策树数量(n_estimators)
      • 3.2 样本数量(subsample)
      • 3.3 特征数量(max_features)
      • 3.4 学习率(learning_rate)
      • 3.5 决策树深度(max_depth)
    • 4.讨论

1.前言

1.1 GBM的介绍

梯度提升机(Gradient Boosting Machine,简称GBM)是一种强大的机器学习算法,它是集成学习的一种形式。GBM在解决分类和回归问题上表现优异,是数据科学领域中常用的算法之一。GBM通过组合多个弱学习器(通常是决策树)来构建一个强大的预测模型。训练过程采用梯度提升技术,逐步改进模型的预测能力。每一轮迭代中,新的弱学习器被训练来纠正前一轮模型的错误,以尽可能减少模型对数据的残差。最终,所有弱学习器的结果加权融合,得到最终的预测结果。

优点

  • 高预测准确性:GBM在许多数据集上表现出色,通常可以获得较高的预测准确性。

  • 处理非线性关系:GBM能够很好地捕捉复杂数据中的非线性关系。

  • 鲁棒性:GBM对于噪声和异常值相对较为鲁棒,它可以通过组合多个模型来减少单个模型的过拟合风险。

  • 特征重要性评估:GBM可以提供每个特征的重要性评估,帮助理解数据中哪些特征对预测影响最大。

  • 可扩展性:GBM可以适用于大规模数据集,并且在现代计算平台上可以高效实现。

缺点

  • 训练时间较长:相比于一些简单的线性模型,GBM的训练时间可能较长,特别是在复杂模型和大规模数据集上。

  • 超参数调整:GBM的性能对于一些超参数(如学习率、树的数量等)较为敏感,需要仔细的调整和优化。

  • 可能出现过拟合:如果不谨慎设置超参数或者训练集数据较小,GBM有可能出现过拟合,导致在新数据上表现不佳。

  • 不适合高维稀疏数据:对于高维稀疏数据,GBM的表现可能不如一些专门针对此类数据设计的算法。

1.2 GBM的应用

  1. 金融领域:在金融领域,GBM广泛用于信用评分模型。通过分析客户的历史信用信息和其他相关特征,GBM可以预测客户的信用风险,帮助金融机构做出更明智的信贷决策。此外,GBM还可以用于股票价格预测,利用历史股票价格数据和市场指标,来预测未来股票价格的走势。

  2. 保险业:在保险行业,GBM可以用于风险评估和理赔预测。通过分析客户的个人信息、保险历史和其他风险因素,GBM可以评估客户的风险水平,并帮助保险公司制定适当的保险政策和定价。此外,GBM还可以预测保险索赔的概率,帮助保险公司更好地管理理赔流程。

  3. 电子商务:在电子商务中,GBM常用于推荐系统。通过分析用户的购买历史、浏览行为和其他行为特征,GBM可以为用户推荐个性化的产品或服务,提高用户的购买转化率和满意度。此外,GBM还可以预测广告点击率,帮助广告商优化广告投放策略。

  4. 医疗和生物科学:在医疗领域,GBM可以应用于疾病预测和诊断支持。通过分析患者的临床数据、医学影像和基因信息,GBM可以预测患者是否患有某种疾病,并提供辅助诊断的信息。此外,GBM还可以用于药物发现,通过分析药物分子和生物活性数据,来预测潜在的新药物候选。

  5. 社交网络:在社交网络中,GBM可以用于用户行为建模和社交关系分析。通过分析用户在社交网络中的互动行为和内容分享,GBM可以预测用户的兴趣和行为,为社交网络平台提供更个性化的服务和推荐。

  6. 自然语言处理:在自然语言处理领域,GBM可以用于情感分析和文本分类。通过分析文本中的情感和语义信息,GBM可以判断文本的情感倾向或者将文本分类到不同的类别中,例如新闻分类、垃圾邮件识别等。

  7. 工业和制造:在工业领域,GBM可以应用于设备故障预测和质量控制。通过分析设备的传感器数据和工艺参数,GBM可以预测设备是否会出现故障,并提前采取维修措施。此外,GBM还可以用于质量控制,帮助检测生产过程中的缺陷和不良品。

  8. 能源和环境:在能源和环境领域,GBM可以应用于能源消耗预测和气候模式预测。通过分析能源消耗的历史数据和影响因素,GBM可以预测未来的能源需求,有助于能源规划和管理。此外,GBM还可以用于气候模式预测,帮助预测天气变化和自然灾害风险。

  9. 交通和物流:在交通和物流行业,GBM可用于交通流量预测和路况评估。通过分析历史交通数据和道路条件,GBM可以预测未来的交通流量和路况,帮助交通管理部门优化交通调度和规划。

2. scikit-learn实战演示

GradientBoostingRegressorGradientBoostingClassifier是Python中scikit-learn库中梯度提升机(Gradient Boosting MachineGBM)的两个主要类。

GradientBoostingRegressor是scikit-learn中用于回归问题的梯度提升机类。它用于解决连续型目标变量的预测问题。通过拟合一系列的回归树(决策树用于回归任务),GBM可以对数据中的非线性关系进行建模,并产生连续型的预测结果。例如,可以使用GradientBoostingRegressor来预测房屋价格、销售量等连续性目标变量。

GradientBoostingClassifier是scikit-learn中用于分类问题的梯度提升机类。它用于解决离散型目标变量的预测问题,将样本划分到不同的类别中。在训练过程中,GBM使用决策树作为弱学习器,来对类别进行预测。例如,可以使用GradientBoostingClassifier来进行垃圾邮件识别、肿瘤类型分类等分类任务。

2.1 分类问题

创建数据集,生成一个包含1000个样本和20个特征的合成分类数据集。其中有15个特征是有信息量的特征,5个特征是冗余的特征。

from sklearn.datasets import make_classification
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=7)
print(X.shape, y.shape)

# (1000, 20) (1000,)

构建模型,10折交叉,重复3次

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, RepeatedStratifiedKFold
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=7)
model = GradientBoostingClassifier()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
n_scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
print('Mean Accuracy: %.3f (%.3f)' % (np.mean(n_scores), np.std(n_scores)))

# Mean Accuracy: 0.899 (0.030)

二分类数据预测

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=7)
model = GradientBoostingClassifier()
model.fit(X, y)
row = [0.2929949, -4.21223056, -1.288332, -2.17849815, -0.64527665, 2.58097719, 0.28422388, -7.1827928, -1.91211104, 2.73729512, 0.81395695, 3.96973717, -2.66939799, 3.34692332, 4.19791821, 0.99990998, -0.30201875, -4.43170633, -2.82646737, 0.44916808]
yhat = model.predict([row])

print('Predicted Class: %d' % yhat[0])

# Predicted Class: 1

2.2 回归问题

创建数据集

from sklearn.datasets import make_regression
X, y = make_regression(n_samples=1000, n_features=20, n_informative=15, noise=0.1, random_state=7)
print(X.shape, y.shape)

# (1000, 20) (1000,)

构建模型,10折交叉,重复3次

from numpy import mean, std
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_score, RepeatedKFold
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=1000, n_features=20, n_informative=15, noise=0.1, random_state=7)
model = GradientBoostingRegressor()
cv = RepeatedKFold(n_splits=10, n_repeats=3, random_state=1)
n_scores = cross_val_score(model, X, y, scoring='neg_mean_absolute_error', cv=cv, n_jobs=-1)
print('MAE: %.3f (%.3f)' % (mean(n_scores), std(n_scores)))

# MAE: -62.440 (3.259)

模型预测

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=1000, n_features=20, n_informative=15, noise=0.1, random_state=7)
model = GradientBoostingRegressor()
model.fit(X, y)
row = [0.20543991, -0.97049844, -0.81403429, -0.23842689, -0.60704084, -0.48541492, 0.53113006, 2.01834338, -0.90745243, -1.85859731, -1.02334791, -0.6877744, 0.60984819, -0.70630121, -1.29161497, 1.32385441, 1.42150747, 1.26567231, 2.56569098, -0.11154792]
yhat = model.predict([row])
print('Prediction: %d' % yhat[0])

# Prediction: 37

3. GBM超参数

梯度提升(Gradient Boosting)是通过组合多个弱学习器(通常是决策树)来构建一个强大的预测模型。在梯度提升算法中,有一些关键的超参数需要设置,以影响模型的性能和训练过程。以下是梯度提升算法的一些重要超参数,这里只演示最常用的5个。

超参数描述
n_estimators集成中弱学习器(通常是决策树)的数量。更多的学习器可能导致过拟合,较少的学习器可能导致欠拟合。
learning_rate控制每个弱学习器对模型的贡献。较小的学习率使得每个学习器的权重调整更小,有助于提高模型的鲁棒性,但可能需要更多的迭代。较大的学习率可能导致过拟合。
max_depth决策树的最大深度,用于控制弱学习器(决策树)的复杂度。较大的最大深度允许决策树更多的分支,可能导致过拟合。较小的最大深度有助于防止过拟合。
min_samples_split决策树节点分裂所需的最小样本数。增加该值可降低模型复杂度,减少过拟合的可能性。
min_samples_leaf叶子节点所需的最小样本数。增加该值可降低模型复杂度,减少过拟合的可能性。
subsample控制每个弱学习器在训练时所使用的样本比例。取值小于1.0时,表示使用部分样本进行训练,有助于降低方差,增加模型的泛化能力。
max_features决定每个决策树节点在分裂时考虑的特征数量。

这些超参数在梯度提升机中非常重要,影响模型的复杂度、拟合性能以及泛化能力。在调参时,需要根据具体的数据集和问题选择合适的超参数取值,以获得最优的模型性能。

### 总模版,后续
from numpy import mean, std, arange
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, RepeatedStratifiedKFold
from sklearn.ensemble import GradientBoostingClassifier
from matplotlib import pyplot

# 获取数据集
def get_dataset():
    X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=7)
    return X, y

# 获取待评估的模型列表
# ——————————————————
# 替换内容,直接复制每小段超参数建模示例进来运行即可
# ——————————————————

# 使用交叉验证评估给定模型
def evaluate_model(model, X, y):
    # 定义评估过程
    cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
    # 评估模型并收集结果
    scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
    return scores

# 获取数据集
X, y = get_dataset()
# 获取待评估的模型列表
models = get_models()
# 评估模型并存储结果
results, names = list(), list()
for name, model in models.items():
    # 评估模型
    scores = evaluate_model(model, X, y)
    # 存储结果
    results.append(scores)
    names.append(name)
    # 输出性能评估结果
    print('>%s %.3f (%.3f)' % (name, mean(scores), std(scores)))

# 绘制模型性能箱线图进行对比
pyplot.boxplot(results, labels=names, showmeans=True)
pyplot.xlabel('Sample Ratio')
pyplot.ylabel('Accuracy')
pyplot.title('Gradient Boosting Performance Comparison')
pyplot.show()

3.1 决策树数量(n_estimators)

指定构建的弱学习器的数量,也称为迭代次数。更多的迭代次数会增加模型的复杂度,但可能会导致过拟合。该参数默认100,这里演示10 到 5,000 的变化。

# 获取待评估的模型列表
def get_models():
    models = dict()
    # 定义不同树数量
    n_trees = [10, 50, 100, 500, 1000, 5000]
    for n in n_trees:
        models[str(n)] = GradientBoostingClassifier(n_estimators=n)
    return models

3.2 样本数量(subsample)

控制每个弱学习器在训练过程中所使用的样本比例。取值小于1.0时,表示使用部分样本进行训练,这样可以降低方差,增加模型的泛化能力。

# 获取待评估的模型列表
def get_models():
    models = dict()
    # 探索不同的样本比例,从10%到100%,步长为10%
    for i in arange(0.1, 1.1, 0.1):
        key = '%.1f' % i
        models[key] = GradientBoostingClassifier(subsample=i)
    return models

3.3 特征数量(max_features)

决定每个决策树节点在分裂时考虑的特征数量。它可以影响模型的随机性,较小的值可能有助于增加模型的多样性,较大的值可能增加拟合性能。

# 获取要评估的模型列表
def get_models():
    models = dict()
    # 探索特征数量从1到20
    for i in range(1, 21):
        models[str(i)] = GradientBoostingClassifier(max_features=i)
    return models

3.4 学习率(learning_rate)

每个弱学习器对模型的贡献,也称为学习率。较小的学习率会使得每个学习器的权重调整较小,有助于提高模型的鲁棒性,但可能需要更多的迭代次数。

# 获取要评估的模型列表
def get_models():
    models = dict()
    # 定义要探索的学习率
    learning_rates = [0.0001, 0.001, 0.01, 0.1, 1.0]
    for rate in learning_rates:
        key = '%.4f' % rate
        models[key] = GradientBoostingClassifier(learning_rate=rate)
    return models

3.5 决策树深度(max_depth)

决策树的最大深度,用于控制弱学习器(决策树)的复杂度,默认为3。较大的最大深度允许决策树更多的分支,可能导致过拟合。较小的最大深度有助于防止过拟合。

# 获取要评估的模型列表
def get_models():
    models = dict()
    # 定义要探索的树深度范围
    for depth in range(1, 11):
        models[str(depth)] = GradientBoostingClassifier(max_depth=depth)
    return models

4.讨论

这里将n_estimatorssubsamplelearning_ratemax_depth四个参数组合做了个热图,也能大致了解哪个指标对结果的影响更大。

梯度提升是指一类可用于分类或回归预测建模问题的集成机器学习算法。集成是根据决策树模型构建的。一次将一棵树添加到集合中,并进行拟合以纠正先前模型产生的预测错误。这是一种称为 boosting 的集成机器学习模型。使用任意可微损失函数和梯度下降优化算法来拟合模型。这使得该技术被称为“梯度增强”,因为随着模型的拟合,损失梯度被最小化,就像神经网络一样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/786505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络设备中的配置文件管理

建立强大网络的第一步是为灾难和网络中断做好准备,许多企业在中断期间遭受损失,因为他们缺乏备份计划并且配置管理不达标,使用配置文件管理工具进行适当的配置文件管理不仅有助于处理网络中断,还有助于优化网络性能。 使用配置文…

Cilium

Cilium是一个开源的、面向Kubernetes和容器环境的网络插件,用于提供高级的网络和安全功能。它是一个用于容器网络和网络层四、七层安全的项目,旨在简化网络和安全层的管理,并提供高性能和低延迟的数据包处理。Cilium通过BPF(Berke…

Windows 使用批处理脚本 kill 进程

使用 jenkins 构建 SpringBoot 项目时,需要增加 kill 进程的功能,否则再次启动时会失败,提示端口被占用。 Windows 平台的脚本 命令 for /f "tokens5" %%p in (netstat /anop tcp ^| findstr /i 8007 ^| findstr /i listening) do…

2023 7.24~7.30 周报 (VelocityGAN)

目录哟 0 上周回顾1 本周计划1.1 论文背景 2 完成情况2.1 背景简述2.2 网络结构: 生成器 (Generator)2.3 网络结构: 判别器 (Discriminator)2.4 损失函数: Loss2.5 OpenFWI中的VelocityGAN与它的核心代码2.5.1 判别器loss2.5.2 生成器loss2.5.3 训练的配置顺序 2.6 复现结果 3 …

【lesson3】Linux基本指令2

文章目录 echo重定向输出重定向>(输出重定向)>>(追加重定向) <(输入重定向) 生成10000行内容到file.txt命令行moremoremore -n(行数) lesslessless / 字符串 ctrl cheadheadhead -n tailtailtial -n |(管道)wcdatedatedate 其它命令 calcalcal 年份cal -1cal …

Python爬虫基础知识点有哪些

目录 Python爬虫基础知识点 Requests库 Beautiful Soup库 正则表达式 数据存储 防止被反爬虫策略 爬虫调度和任务管理 认识robots.txt文件 反爬虫法律与道德 示例代码 Requests库 Beautiful Soup库 正则表达式 数据存储 防止被反爬虫策略 结语 网络世界中信息的…

Ant Design Vue Modal 模态框位置调整

问题描述 有一个功能已经实现的需求&#xff0c;是点击了一个按钮&#xff0c;弹出了如下模态框&#xff1a; 这里看到的就是点击按钮之后用户看到的效果&#xff0c;了保持模态框在用户视野范围内&#xff0c;我钻研如何调整显示的位置。 实现步骤 Ant Design of Vue的官方…

音视频——封装格式原理

视频解码基础 一、封裝格式 ​ 我们播放的视频文件一般都是用一种封装格式封装起来的&#xff0c;封装格式的作用是什么呢&#xff1f;一般视频文件里不光有视频&#xff0c;还有音频&#xff0c;封装格式的作用就是把视频和音频打包起来。 所以我们先要解封装格式&#xff0…

【Android Framework系列】第8章 事件分发你真了解吗?

1 事件分发基本认知 1.1 事件分发的”事件“是指什么 1.2 事件处理中涉及到的点 1.3 Android 事件处理的三个流程 在Android中&#xff0c;Touch事件的分发分服务端和应用端。在服务端由WindowManagerService&#xff08;借助InputManagerService&#xff09;负责采集和分发的…

高校vr元宇宙虚拟禁毒体验推动社会戒毒工作的深入开展

元宇宙是指一个虚拟的、全球性的、可交互的虚拟世界&#xff0c;深度融合了VR虚拟现实、AR增强现实和ai等技术。将元宇宙应用于戒毒安全教育平台&#xff0c;具有以下现实意义&#xff1a; 创造安全的学习环境 戒毒安全教育需要让人们了解毒品的危害和如何预防&#xff0c;但直…

水环境综合治理监测系统:筑牢城市水生态安全屏障

水是生命之源&#xff0c;是人类赖以生存的基础。然而&#xff0c;随着工业化、城市化的快速发展&#xff0c;水污染问题日益凸显&#xff0c;给居民的环境卫生以及用水安全带来了巨大的威胁。因此&#xff0c;加强水环境综合治理&#xff0c;保护水资源和维护生态平衡&#xf…

vue之ReadIdcard(身份证读取组件)

组件功能 读取二代身份证信息组件,包含无效身份证验证,过期身份证验证,是否满16周岁验证 #界面 #界面输入项 序号输入项输入形式是否必输是否可配置备注1

CentOS 安装Oracle11g

一、方式一&#xff08;亲测&#xff09; https://blog.csdn.net/zw521cx/article/details/108550215 遇到问题解决&#xff1a; 1.执行 dbca -silent -responseFile /home/oracle/response/dbca.rsp 报错 解决办法&#xff1a; a.全局查找 [rootVM-0-8-centos ~]# locate S…

Vue全局事件总线

main.js //引入Vue import Vue from vue //引入App import App from "./App"; //关闭Vue的生产提示 Vue.config.productionTip false // const Demo Vue.extend({}) // const d new Demo() // Vue.prototype.x d//创建vm new Vue({el:#app,render:h>h(App),b…

MySQL 修改时区的方法

文章目录 什么是MySQL时区&#xff1f;通过MySQL命令模式下修改首先查看MySQL当前的时间进行修改 不方便重启MySQL&#xff0c;临时解决时区问题通过修改配置文件mysql.cnf(my.ini)来进行修改总结 环境&#xff1a;Windows10系统&#xff0c;MySQL5.7版本 mysql修改时区的方法&…

音视频——压缩原理

H264视频压缩算法现在无疑是所有视频压缩技术中使用最广泛&#xff0c; 最流行的。随着 x264/openh264以及ffmpeg等开源库的推出&#xff0c;大多数使用者无需再对H264的细节做过多的研究&#xff0c;这大降低了人们使用H264的成本。 但为了用好H264&#xff0c;我们还是要对…

X - Transformer

回顾 Transformer 的发展 Transformer 最初是作为机器翻译的序列到序列模型提出的&#xff0c;而后来的研究表明&#xff0c;基于 Transformer 的预训练模型&#xff08;PTM&#xff09; 在各项任务中都有最优的表现。因此&#xff0c;Transformer 已成为 NLP 领域的首选架构&…

面试题——当实体类中的属性名和表中的字段名不一样,如何将查询的结果封装到指定 pojo?

在使用MyBatis的时候&#xff0c;应该注意实体类的属性名尽量和表的字段名尽量相同&#xff0c;如果不同将会导致MyBatis无法完成数据的封装&#xff0c;但是在软件开发过程中&#xff0c;数据库的创建和软件环境的搭建不可能是同一个人&#xff0c;实体类属性名和数据库的字段…

真正理解红黑树,真正的(Linux内核里大量用到的数据结构

作为一种数据结构&#xff0c;红黑树可谓不算朴素&#xff0c;因为各种宣传让它过于神秘&#xff0c;网上搜罗了一大堆的关于红黑树的文章&#xff0c;不外乎千篇一律&#xff0c;介绍概念&#xff0c;分析性能&#xff0c;贴上代码&#xff0c;然后给上罪恶的一句话&#xff0…

芯片制造详解.从沙子到晶圆.学习笔记(一)

刚入行半导体行业&#xff0c;很多知识需要系统的学习&#xff0c;想从入门通俗易懂的知识开始学起&#xff0c;于是在导师的帮助下&#xff0c;找到了这门课程&#xff0c;那就从这门课程开始打开我的半导体之旅吧。 我只是对视频内容的提炼&#xff0c;和自己的学习心得&…