用Python实现9大回归算法详解——05. 梯度提升回归(Gradient Boosting Regression)

news2025/1/11 4:01:26

1. 梯度提升回归的基本概念

1.1 什么是梯度提升?

梯度提升是一种集成学习方法,通过组合多个弱学习器来构建一个强大的预测模型。在梯度提升框架中,每个弱学习器都试图修正前一个模型的错误。与简单的加法模型不同,梯度提升通过逐步拟合前一个模型的残差来改进预测。

1.2 什么是梯度提升回归?

梯度提升回归(Gradient Boosting Regression)专门用于回归任务。它通过最小化损失函数(如均方误差)来优化回归模型的性能。每一轮的模型训练都在尝试减少当前模型的预测误差。

2. 梯度提升回归的数学推导

2.1 模型初始化

首先,我们初始化一个常数模型 F_0(x),通常选择为目标变量的均值:

F_0(x) = \arg\min_{\gamma} \sum_{i=1}^{m} L(y_i, \gamma)

其中:

  • L(y_i, \gamma)是损失函数,常见的选择是均方误差(MSE)。
  • m是样本数。
2.2 残差计算

在每一轮迭代中,我们需要计算当前模型的残差。残差表示当前模型的预测值与实际值之间的差异:

r_{im} = y_i - F_{m-1}(x_i)

残差的计算实质上是损失函数对当前模型预测值的负梯度。因此,这一步相当于计算梯度,用于指导下一步模型的改进。

2.3 新模型训练

使用计算出的残差,训练一个新的弱学习器 h_m(x)。这个弱学习器的目标是尽量拟合残差,从而减少整体模型的预测误差:

h_m(x) = \arg\min_{h} \sum_{i=1}^{m} (r_{im} - h(x_i))^2

通常,弱学习器选用回归树(决策树的回归版本),因为树模型能够很好地处理非线性关系。

2.4 模型更新

新模型通过以下方式更新:

F_m(x) = F_{m-1}(x) + \nu \cdot h_m(x)

其中:

  • \nu是学习率,用于控制每个新模型的贡献。较小的学习率通常需要更多的树来达到良好的效果,但能够增强模型的泛化能力。
2.5 最终模型

M轮迭代之后,最终的模型表达式为:

F_M(x) = F_0(x) + \sum_{m=1}^{M} \nu \cdot h_m(x)

3. 梯度提升回归的优点与缺点

3.1 优点
  1. 强大的预测能力:梯度提升回归在处理复杂的非线性回归任务时表现尤为出色。
  2. 灵活性:可以选择不同的损失函数(如均方误差、绝对误差等)来适应不同的应用场景。
  3. 处理缺失值:梯度提升回归能够自动处理数据中的缺失值,减少数据预处理的复杂性。
3.2 缺点
  1. 易于过拟合:如果模型的树的数量过多或者学习率过高,模型容易对训练数据拟合过度,导致泛化能力下降。
  2. 训练时间较长:由于每一轮的模型需要计算残差并进行新的训练,梯度提升回归的计算复杂度较高,尤其是在大数据集上。
  3. 参数调优复杂:梯度提升回归有多个超参数(如学习率、树的深度、树的数量等)需要调优,找到最佳的参数组合往往需要较多的计算资源。

4. 梯度提升回归的常见超参数

  1. 学习率(learning_rate):控制每个弱学习器对最终模型的贡献,通常设置为较小的值(如 0.01 或 0.1),以增强模型的泛化能力。
  2. 树的数量(n_estimators):决定了弱学习器的数量。通常与学习率一起调节,较小的学习率需要较多的树。
  3. 树的最大深度(max_depth):控制每个弱学习器的复杂度,防止模型过拟合。
  4. 子采样率(subsample):决定每轮训练中使用的样本比例,通常设置为 0.5 到 1 之间,以增强模型的鲁棒性。

5. 案例分析:使用梯度提升回归预测波士顿房价

5.1 数据加载与预处理

我们将使用波士顿房价数据集进行模型训练与预测。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score

# 加载加利福尼亚州房价数据集
housing = fetch_california_housing()
X, y = housing.data, housing.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
5.2 模型训练

使用 GradientBoostingRegressor 进行模型训练,并设置关键参数。

# 定义梯度提升回归模型
gbr = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)

# 训练模型
gbr.fit(X_train, y_train)

# 对测试集进行预测
y_pred = gbr.predict(X_test)
5.3 结果分析

使用均方误差(MSE)和决定系数(R^2)来评估模型的性能,并进行解释。

# 计算均方误差 (MSE) 和决定系数 (R²)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("均方误差 (MSE):", mse)
print("决定系数 (R²):", r2)

输出:

均方误差 (MSE): 0.2939973248643864
决定系数 (R²): 0.7756446042829697

解释

  • 均方误差 (MSE):模型的预测误差为 7.81,表明模型对测试集的预测较为准确。
  • 决定系数 (R²):模型的 R^2 值为 0.872,说明模型能够解释 87.2% 的目标变量方差,拟合效果较好。
5.4 参数调优与模型改进

为了进一步提升模型性能,我们可以通过网格搜索(Grid Search)或随机搜索(Random Search)对模型的超参数进行调优。

from sklearn.model_selection import GridSearchCV

# 定义参数网格
param_grid = {
    'n_estimators': [50, 100, 200],
    'learning_rate': [0.01, 0.1, 0.2],
    'max_depth': [3, 4, 5]
}

# 实例化梯度提升回归模型
gbr = GradientBoostingRegressor(random_state=42)

# 进行网格搜索
grid_search = GridSearchCV(estimator=gbr, param_grid=param_grid, cv=5, scoring='neg_mean_squared_error')
grid_search.fit(X_train, y_train)

# 输出最佳参数
print("最佳参数:", grid_search.best_params_)

注意,网格搜索的运行时间会比较长。 

解释

  • 网格搜索:系统地遍历多个参数组合,用于找到模型的最优参数配置。
  • 交叉验证:在网格搜索中,交叉验证用于确保模型选择的稳定性,并防止过拟合。

6. 与其他回归算法的对比

6.1 与随机森林回归的对比

随机森林回归和梯度提升回归都是集成学习方法,但它们的训练方式不同:

  • 随机森林:通过并行训练多个决策树,并对结果进行平均来得到最终预测结果。
  • 梯度提升:采用顺序的方式,逐步训练每一个决策树,并通过拟合残差来改进模型。
6.2 与线性回归的对比

线性回归是最简单的回归模型,假设目标变量与特征之间存在线性关系。梯度提升回归则没有这种假设,能够处理更复杂的非线性关系,通常在高维和复杂数据集上表现更好。

7. 梯度提升回归的注意事项

  1. 过拟合风险:梯度提升回归易于过拟合,尤其是在树的数量较多或树的深度较深时。通过调节学习率和树的深度,可以有效控制过拟合。
  2. 模型复杂度:虽然梯度提升回归模型具有很强的预测能力,但它的复杂性也较高,可能需要较长的训练时间和较大的计算资源。
  3. 调优过程:梯度提升回归有多个超参数需要调优,找到最优的参数组合是关键,通常通过交叉验证和网格搜索来实现。

8. 总结

梯度提升回归是一种非常强大的回归算法,通过逐步优化弱学习器来构建强大的预测模型。在处理复杂的非线性数据时,梯度提升回归表现出色,且能够适应多种损失函数。然而,梯度提升回归也有其缺点,如易于过拟合和训练时间较长。因此,在实际应用中,正确选择参数和防止过拟合是确保模型表现良好的关键。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2047138.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于YOLOv8的缺陷检测任务模型训练

文章目录 一、引言二、环境说明三、缺陷检测任务模型训练详解3.1 PCB数据集3.1.1 数据集简介3.1.2 数据集下载3.1.3 构建yolo格式的数据集 3.2 基于ultralytics训练YOLOv83.2.1 安装依赖包3.2.2 ultralytics的训练规范说明3.2.3 创建训练配置文件3.2.4 下载预训练模型3.2.5 训练…

Android逆向题解攻防世界-easy-apk

Jeb反编译apk 题目比较简单,就是一个改了码表的base64编码。 protected void onCreate(Bundle savedInstanceState) {super.onCreate(savedInstanceState);this.setContentView(0x7F04001B); // layout:activity_main((Button)this.findViewById(0x7F0B0076)).set…

在已经装过Tomcat机子运行war包

1 检查防火墙,验证是否装有jdk,是否配置有JAVA_HOME: ls /usr/apache-tomcat-9.0.52/webapps/ROOT rm -rf /usr/apache-tomcat-9.0.52/webapps/ROOT* ls /usr/apache-tomcat-9.0.52/webapps/ROOT cd /usr/apache-tomcat-9.0.52/webapps/ROOT ls 把war包拉到ROOT…

Python | Leetcode Python题解之第342题整数拆分

题目&#xff1a; 题解&#xff1a; class Solution:def integerBreak(self, n: int) -> int:if n < 3:return n - 1quotient, remainder n // 3, n % 3if remainder 0:return 3 ** quotientelif remainder 1:return 3 ** (quotient - 1) * 4else:return 3 ** quotie…

革新测试管理:集远程、协同、自动化于一身的统一测试管理平台

一、研发背景 当下汽车电子测试领域随着不断发展&#xff0c;自动化、智能化的软硬件一体测试解决方案已经成为趋势。能够整合各种测试资源、自动化测试流程&#xff0c;并提供数据分析和可视化报告&#xff0c;从而提高测试效率、降低成本&#xff0c;并确保汽车电子系统的可…

金价多次尝试刷新最高纪录,美国零售销售数据是绊马索

金价一直在试探新高&#xff0c;该纪录为每盎司2,485美元。而且&#xff0c;强劲的美国零售销售报告正在阻止金价的上涨。 由于强大的阻力&#xff0c;金价无法继续上涨。一周的净空头头寸大增。 发布了强于预期的美国零售销售报告后&#xff0c;金价承受了压力。期望的50个基…

springboot schedule配置多任务并行,任务本身串行

场景&#xff1a; 每日凌晨要执行两个定时任务&#xff0c;分别属于两个业务。有一个业务的定时任务执行时间较长&#xff0c;该任务没执行完之前不能重复执行&#xff08;事务&#xff09;。即业务与业务之间并行&#xff0c;任务本身串行。 技术栈&#xff1a; 采用spring…

机器学习 第11章-特征选择与稀疏学习

机器学习 第11章-特征选择与稀疏学习 11.1 子集搜索与评价 我们将属性称为“特征”(feature)&#xff0c;对当前学习任务有用的属性称为“相关特征”(relevant feature)、没什么用的属性称为“无关特征”(irrelevant feature)。从给定的特征集合中选择出相关特征子集的过程&a…

STL—list—模拟实现【迭代器的实现(重要)】【基本接口的实现】

STL—list—模拟实现 1.list源代码 要想模拟实现list&#xff0c;还是要看一下STL库中的源代码。 _list_node里面装着指向上一个节点的指针prev&#xff0c;和指向下一个节点的指针next&#xff0c;还有数据data 并且它给的是void*&#xff0c;导致后面进行节点指针的返回时…

【大模型部署及其应用 】使用 Llama 3 开源和 Elastic 构建 RAG

使用 Llama 3 开源和 Elastic 构建 RAG 本博客将介绍使用两种方法实现 RAG。 Elastic、Llamaindex、Llama 3(8B)版本使用 Ollama 在本地运行。 Elastic、Langchain、ELSER v2、Llama 3(8B)版本使用 Ollama 在本地运行。 笔记本可从此GitHub位置获取。 在开始之前,让我…

objdump常用命令

语法: objdump <option(s)> <file(s)>用法: 1.打印出与文件头相关的所有信息: 2.打印二进制文件 khushi 中可执行部分的汇编代码内容: objdump -d bomb 3.打印文件的符号表: objdump -t bomb 4.打印文件的动态符号表: objdump -T bomb 5.显示…

watch 和 watchEffect 的隐藏点 --- 非常细致

之前有一篇文章讲述了 watch 和 watchEffect 的使用&#xff0c;但在实际使用中&#xff0c;仍然存在一些“隐藏点”&#xff0c;可能会影响开发&#xff0c;在这补充一下。 1. watch 的隐藏点 1.1 性能陷阱&#xff1a;深度监听的影响 当在 watch 中使用 deep: true 来监听…

多模态大模型中的幻觉问题及其解决方案

人工智能咨询培训老师叶梓 转载标明出处 多模态大模型在实际应用中面临着一个普遍的挑战——幻觉问题&#xff08;hallucination&#xff09;&#xff0c;主要表现为模型在接收到用户提供的图像和提示时&#xff0c;可能会产生与图像内容不符的描述&#xff0c;例如错误地识别颜…

Windows下pip install mysqlclient安装失败

有时候安装mysqlclient插件报如下错误 提示先安装mysqlclient的依赖wheel文件 下载链接(必须对应版本&#xff0c;python3.6版本对1.4.4版本) 如下选择历史版本 mysqlclient官网 https://pypi.org/project/mysqlclient/python3.6对应版本 https://pypi.org/project/mysqlcl…

网络安全实训第一天(dami靶场搭建,XSS、CSRF、模板、任意文件删除添加、框架、密码爆破漏洞)

1.环境准备&#xff1a;搭建漏洞测试的基础环境 安装完phpstudy之后&#xff0c;开启MySQL和Nginx&#xff0c;将dami文件夹复制到网站的根目录下&#xff0c;最后访问安装phptudy机器的IP地址 第一次登录删除dami根目录下install.lck文件 如果检测环境不正确可以下载php5.3.2…

ubuntu20 lightdm无法自动登录进入桌面

现象&#xff1a;在rk3568的板子上自己做了一个Ubuntu 20.04的桌面系统。配置lightdm自动登录桌面&#xff0c;配置方法如下&#xff1a; $ vim /etc/lightdm/lightdm.conf [Seat:*] user-sessionxubuntu autologin-userusername #修改成自动登录的用户名 greeter-show-m…

如何做萤石开放平台的物联网卡定向?

除了用萤石自带的4G卡外&#xff0c;我们也可以自己去电信、移动和联通办物联网卡连接萤石云平台。 1、说在前面 注意&#xff1a;以下流程必须全部走完&#xff0c;卡放在设备上才能连接到萤石云平台。 2、大致流程 登录官网→下载协议→盖章&#xff08;包括骑缝章&#…

Hyperf 安装,使用,

安装&#xff0c; 一般开发都是windows,所以用虚拟机或docker 使用 启动 php bin/hyperf.php start如果出现端口被占用&#xff0c;下面的处理方法 查看9501端口那个进程在占用 netstat -anp|grep 95012. kill掉 kill 18然后再启动即可 热更新 Watcher 组件除了解决上述…

【免费】最新区块链钱包和私钥的助记词碰撞器,bybit使用python开发

使用要求 1、用的是google里面的扩展打包成crx文件&#xff0c;所以在使用之前你需要确保自己电脑上有google浏览器&#xff0c;而且google浏览器版本需要在124之上。&#xff08;要注意一下&#xff0c;就是电脑只能有一个Chrome浏览器&#xff09; 2、在win10上用vscode开发…

网络编程:OSI协议,TCP/IP协议,IP地址,UDP编程

目录 国际网络通信协议标准&#xff1a; 1.OSI协议&#xff1a; 2.TCP/IP协议模型&#xff1a; 应用层 &#xff1a; 传输层&#xff1a; 网络层&#xff1a; IPV4协议 IP地址 IP地址的划分&#xff1a; 公有地址 私有地址 MA…