【机器学习】梯度提升和随机森林的概念、两者在python中的实例以及梯度提升和随机森林的区别

news2024/9/21 21:22:35

引言

梯度提升(Gradient Boosting)是一种强大的机器学习技术,它通过迭代地训练决策树来最小化损失函数,以提高模型的预测性能
随机森林(Random Forest)是一种基于树的集成学习算法,它通过组合多个决策树来提高预测的准确性和稳定性

文章目录

  • 引言
  • 一、梯度提升
    • 1.1 基本原理
      • 1.1.1 初始化模型
      • 1.1.2 迭代优化
      • 1.1.3 梯度计算
      • 1.1.4模型更新
    • 1.2 关键步骤
    • 1.3 梯度提升树(GBDT)
    • 1.4 常用库
    • 1.5 总结
  • 二、梯度提升在python中的实例
    • 2.1 代码
    • 2.2 代码解释
  • 三、随机森林
    • 3.1 关键特点
      • 3.1.1 集成学习
      • 3.1.2 数据样本的随机性
      • 3.1.3 特征选择的随机性
      • 3.1.4 不需要大量参数调整
      • 3.1.5 抗过拟合能力
    • 3.2 实现步骤
  • 四、随机森林在python中的实例
    • 4.1 代码
    • 4.2 代码解释
  • 五、随机森林和梯度提升的区别
    • 5.1 训练过程
    • 5.2 树的权重和组合
    • 5.3 特征选择
    • 5.4 泛化能力和过拟合
    • 5.5 计算复杂度
    • 5.6 应用场景
    • 5.7 总结

一、梯度提升

在这里插入图片描述

1.1 基本原理

1.1.1 初始化模型

梯度提升算法从一个简单的模型开始,例如一个常数预测器

1.1.2 迭代优化

在每一轮迭代中,算法会训练一个新的模型来拟合残差(实际值与当前模型预测值之间的差异)。通过这种方式,新模型专注于纠正前一个模型的错误

1.1.3 梯度计算

在每一轮迭代中,算法计算损失函数的梯度,这表示损失函数在当前模型预测值处的斜率。梯度指向损失增加最快的方向

1.1.4模型更新

新训练的模型用于更新当前模型,使其在梯度方向上迈出一步,从而减少损失

1.2 关键步骤

  1. 损失函数:选择一个合适的损失函数,例如平方损失(用于回归问题)或对数损失(用于分类问题)
  2. 决策树:梯度提升通常使用决策树作为基学习器。决策树的深度通常较小,以防止过拟合
  3. 负梯度:计算当前模型的负梯度,这表示损失函数下降最快的方向
  4. 拟合残差:使用决策树拟合负梯度,得到一个新模型
  5. 学习率(Shrinkage):对新模型的贡献进行缩放,以防止过拟合。学习率是一个超参数,通常需要通过交叉验证来调整
  6. 模型更新:将新模型添加到当前模型中,以更新预测
  7. 迭代:重复上述步骤,直到达到预定的迭代次数或损失不再显著下降

1.3 梯度提升树(GBDT)

梯度提升树(Gradient Boosting Decision Tree,GBDT)是梯度提升的一种实现,它使用决策树作为基学习器。GBDT在许多机器学习任务中表现出色,尤其是在结构化数据上

1.4 常用库

在Python中,常用的梯度提升库有:

  • XGBoost
  • LightGBM
  • CatBoost
    这些库提供了高效的梯度提升算法实现,并且具有许多优化和特性,使得模型训练更加快速和准确。

1.5 总结

梯度提升是一种强大的机器学习技术,通过迭代地优化模型来提高预测性能。在实际应用中,合理调整超参数和使用先进的梯度提升库可以帮助我们构建高效、准确的模型

二、梯度提升在python中的实例

可以使用Python中的scikit-learn库来实现梯度提升(Gradient Boosting)。我们将使用梯度提升回归器(Gradient Boosting Regressor)来训练一个模型,并用它来预测一些数据

2.1 代码

以下是一个完整的例子,包括数据生成、模型训练和预测:

# 导入所需的库
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
# 生成模拟数据
X, y = make_regression(n_samples=1000, n_features=20, noise=0.1, random_state=42)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化梯度提升回归器
gb_regressor = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
# 训练模型
gb_regressor.fit(X_train, y_train)
# 进行预测
y_pred = gb_regressor.predict(X_test)
# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"均方误差: {mse}")
# 打印特征重要性
feature_importances = gb_regressor.feature_importances_
print(f"特征重要性: {feature_importances}")

输出结果:
在这里插入图片描述

2.2 代码解释

  • 首先生成了一个包含1000个样本和20个特征的回归数据集
  • 然后将数据集划分为训练集和测试集,其中测试集占20%
  • 接着创建了一个GradientBoostingRegressor对象,并设置了树的数(n_estimators)、学习率(learning_rate)和树的最大深度(max_depth
  • 使用训练集数据训练模型
  • 使用训练好的模型对测试集进行预测
  • 最后,计算了模型的均方误差,并打印了特征的重要性

三、随机森林

在这里插入图片描述

随机森林能够用于分类和回归任务,并且在许多实际应用中表现出色

3.1 关键特点

3.1.1 集成学习

随机森林是由多个决策树组成的集合,每个树都对数据进行投票(分类任务)或取平均值(回归任务)以产生最终的预测

3.1.2 数据样本的随机性

在构建每棵树时,随机森林从原始数据集中随机抽取一个子集进行训练。这种抽样称为“装袋”(Bagging)

3.1.3 特征选择的随机性

在树的每个节点上,随机森林会从所有特征中随机选择一个子集来决定最佳分割点。这增加了树之间的多样性,有助于提高模型的泛化能力

3.1.4 不需要大量参数调整

随机森林通常不需要复杂的参数调整,这使得它成为一个易于使用且效果不错的算法

3.1.5 抗过拟合能力

由于随机森林结合了多个决策树,每个树都在不同的数据子集上训练,因此它通常能够避免过拟合

3.2 实现步骤

  1. 数据抽样:从原始数据集中进行有放回的随机抽样,得到多个训练子集
  2. 树构建:对于每个训练子集,构建一个决策树。在每个节点上,随机选择特征子集,并找到最佳分割点
  3. 树集成:将所有决策树的预测结果进行汇总。对于分类问题,通常采用多数投票;对于回归问题,通常取平均值

四、随机森林在python中的实例

4.1 代码

以下是一个使用scikit-learn库实现随机森林的简单例子

# 导入所需的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 初始化随机森林分类器
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
# 训练模型
rf_classifier.fit(X_train, y_train)
# 进行预测
y_pred = rf_classifier.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"准确率: {accuracy}")
# 打印特征重要性
feature_importances = rf_classifier.feature_importances_
print(f"特征重要性: {feature_importances}")

输出结果:
在这里插入图片描述

4.2 代码解释

  • 首先加载了Iris数据集
  • 然后将其划分为训练集和测试集
  • 接着,我们创建了一个RandomForestClassifier对象,并用训练集数据训练了模型
  • 最后,我们评估了模型的准确率并打印了特征的重要性

五、随机森林和梯度提升的区别

梯度提升(Gradient Boosting)和随机森林(Random Forest)都是基于决策树的集成学习算法,但它们在构建集成模型的方式和原理上有显著的不同

5.1 训练过程

  • 梯度提升
    • 采用串行训练方式,每一棵树都是为了纠正前一棵树的错误而训练的
    • 每棵树都是基于残差(实际值与当前模型预测值之间的差异)进行训练的
    • 通过梯度下降在损失函数上迭代优化,逐步构建模型
  • 随机森林
    • 采用并行训练方式,每棵树都是独立地从原始数据集中抽取的子集上进行训练
    • 每棵树的训练不依赖于其他树,它们之间是相互独立的
    • 通过随机选择特征和样本来增加模型的多样性,减少过拟合

5.2 树的权重和组合

  • 梯度提升
    • 每棵树都有不同的权重,这些权重是基于它们减少损失的能力来确定的
    • 最终的预测是所有树预测的加权和
  • 随机森林
    • 所有树在最终预测中的权重是相同的
    • 对于分类问题,通常采用多数投票来决定最终的类别;对于回归问题,通常取所有树预测的平均值

5.3 特征选择

  • 梯度提升
    • 在每个分割点考虑所有特征,选择最佳分割
  • 随机森林
    • 在每个分割点随机选择一个特征子集,并从中选择最佳分割

5.4 泛化能力和过拟合

  • 梯度提升
    • 由于梯度提升专注于减少残差,它可能会对训练数据过度拟合,特别是如果没有适当的正则化或早停机制
  • 随机森林
    • 由于其随机性和独立性,随机森林通常具有较好的泛化能力,对过拟合有一定的抵抗力

5.5 计算复杂度

  • 梯度提升
    • 通常计算成本较高,因为它需要连续地训练多棵树,并且每棵树都要与前一棵树的结果相配合
  • 随机森林
    • 计算成本相对较低,因为树是并行训练的,并且每棵树的训练可以并行化

5.6 应用场景

  • 梯度提升
    • 通常用于需要高预测精度的任务,如广告点击率预测、信用评分等
  • 随机森林
    • 适用于需要快速、稳定预测的场景,如分类问题、特征选择等

5.7 总结

梯度提升和随机森林都是强大的机器学习工具,但它们在模型构建、泛化能力、计算复杂度和适用场景上有所不同。选择哪个算法取决于具体问题的需求、数据特性和性能要求

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2107280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网关桥梁:modbus 转 profinet 网关中频加热机的智能融合之旅

一、项目序章:金属热处理的智慧曙光在金属锻造的辉煌舞台上,中频感应加热电源以其高效节能、精准控温的卓越才艺,成为了热处理、焊接与成型艺术中不可或缺的幕后英雄。然而,随着工业自动化的浪潮汹涌而至,如何让这位英…

ig运营事半功倍千万做到这“四不要”

在运营品牌Ins的时候,想要把账号做活跃,就不能做单一的内容,一定要多元化分配内容,下面这4个不要做,一定请记住! 1. 不要只是介绍您的产品。否则,你的内容就会变得单调、乏味。观众喜欢阅读故事…

Java中的TCP/IP与UDP协议Socket入门

Socket: 简单地说Socket就相当于是一家快递公司包括: 寄件人: 1.包裹放快递盒里(数据打包:DatagramSocket) 2.运输快递(发送数据) 3.付钱回家(释放资源)…

1.初识ChatGPT:AI聊天机器人的革命(1/10)

引言 在当今的数字化世界中,人工智能(AI)正以其独特的方式重塑我们的生活和工作。其中,AI聊天机器人作为人机交互的前沿技术,已经成为企业与客户沟通、提供个性化服务的重要工具。这些机器人通过模拟人类的对话方式&a…

Android 存储之 SharedPreferences 框架体系编码模板

一、SharedPreferences 框架体系 1、SharedPreferences 基本介绍 SharedPreferences 是 Android 的一个轻量级存储工具,它采用 key - value 的键值对方式进行存储 它允许保存和读取应用中的基本数据类型,例如,String、int、float、boolean …

RKNPU2项目实战【1】 ---- YOLOv5实时目标分类

目录 目标 一、python接口下实现yolov5模型在开发板上的部署 1.1 在rknntoolkit2环境下模拟实现yolov5模型在RK3588开发板上的推理测试 1.2 在rknntoolkit2环境下实现模型在RK3588开发板上的连板推理测试(模型运行在NPU上) 1.3 在rknntoolkitlite2环…

使用llamaindexLLM大模型构建一个可离线可在线可异步扩展信息的RAG智能问答系统

之前对一件事很好奇,为什么去年训练的大模型可以回答今天的新闻内容。答案是使用了知识扩展系统。基本原理是把参考答案和问题一同提给大模型,给他充分的参考信息做回复编辑。 本文教你完成离线版本的智能问答系统搭建。 最近在疯狂找下家,本人精通图形渲染和ai,求捞啊! …

没参加会议,还要 30000 字的会议材料写总结?用好 AI工具,30 分钟堵住领导的嘴

前段时间本来要参加总公司的重要会议,但由于临时出差错过了。 分公司老总,给了我 10 份会议材料内容,让我学习,并在节后梳理出要点。 结果,一过节就全都给忘记了,咋办?听说最近Kimi出了新玩法…

k8s 部署 jenkins【详细步骤】

文章目录 部署介绍部署步骤第 1 步:创建 namespace第 2 步:创建 ServiceAccount第 3 步:创建持久卷第 4 步:创建 Deployment第 5 步:创建 Service第 6 步:浏览器访问 Jenkins第 7 步:修改默认时区参考⭐ 本文目标:在 k8s 集群中部署一个 jenkins。 部署介绍 🚀 在 K…

内推|京东|后端开发|运维|算法...|北京 更多岗位扫内推码了解,直接投递,跟踪进度

热招岗位 更多岗位欢迎扫描末尾二维码,小程序直接提交简历等面试。实时帮你查询面试进程。 安全运营中心研发工程师 岗位要求 1、本科及以上学历,3年以上的安全相关工作经验; 2、熟悉c/c、go编程语言之一、熟悉linux网络编程和系统编程 3、…

coze 的插件输入飞书多维表格 app_token 后一直显示错误,如何解决?

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

为什么企业数据资产入表实践少?

​在财政部会计司发布暂行规定之后,数据资产化或者数据要素市场进入加速期。 会计司在22年12月1号首先发布征求意见,经过半年时间迅速迭代后,正式发布了暂行规定。文件第8条规定了国家实行统一的会计制度,任何相关的企业组织都必须…

【网络安全】调试模式获取敏感数据

未经许可,不得转载。 文章目录 漏洞原因步骤PHPPythonASPNode.js漏洞原因 当开发者忘记在生产环境中禁用调试模式,应用在发生错误时,可能会输出详细的错误信息。这些错误信息(比如“error title”或堆栈跟踪)通常包含了应用程序的内部结构、配置甚至数据库连接信息等敏感…

Windows自动化程序开发指南

自动化程序的概念 “自动化程序”指的是通过电脑编程来代替人类手工操作的一类程序或软件。这类程序具有智能性高、应用范围广的优点,但是自动化程序的开发难度大、所用技术杂。 本文对自动化程序开发的各个方面进行讲解。 常见的处理对象 自动化程序要处理的对…

公认最好的跑步耳机分享,选购骨传导运动耳机需注意的五大陷阱!

跑步,不仅是一种锻炼身体的方式,更是一种生活态度的体现。它让我们在汗水中释放压力,在节奏中感受生命的律动。而音乐,作为跑步时的完美伴侣,能够激发我们的运动潜能,让我们的跑步之旅更加愉悦。因此&#…

软件测试自动化面试题(含答案)

1.如何把自动化测试在公司中实施并推广起来的? 选择长期的有稳定模块的项目 项目组调研选择自动化工具并开会演示demo案例,我们主要是演示selenium和robot framework两种。 搭建自动化测试框架,在项目中逐步开展自动化。 把该项目的自动化…

示波器基础知识汇总(2)

系列文章目录 1.元件基础 2.电路设计 3.PCB设计 4.元件焊接 5.板子调试 6.程序设计 7.算法学习 8.编写exe 9.检测标准 10.项目举例 11.职业规划 送给大学毕业后找不到奋斗方向的你(每周不定时更新) 中国计算机技术职业资格网 上海市工程系列计算机专…

SD-WAN解决企业远程服务难题

在当今数字化和全球化的商业环境中,企业不再受限于地理位置。远程工作和分布式团队已成为常态,但随之而来的是对网络连接的更高需求。本文将讨论企业远程服务中的挑战,并介绍一个解决这些挑战的有效方案——SD-WAN。 随着远程工作的增加&…

Buzzer:一款针对eBPF的安全检测与模糊测试工具

关于Buzzer Buzzer是一款功能强大的模糊测试工具链,该工具基于Go语言开发,可以帮助广大研究人员简单高效地开发针对eBPF的模糊测试策略。 功能介绍 下面给出的是当前版本的Buzzer整体架构: 元素解析: 1、ControlUnit&#xff1a…

查看元神操作系统的版本

1. 背景 本文通过元神操作系统的API调用来获取元神系统的版本,并显示在屏幕上。 2. 方法 (1)编写程序 本例先设置系统调用的参数:第一个参数设置为API_OS_VER,表示获取元神操作系统的版本号;第二个参数…