一个必会算法模型,XGBoost !!

news2024/9/27 21:19:29

大家好,今天咱们来聊聊XGBoost ~

XGBoost(Extreme Gradient Boosting)是一种集成学习算法,是梯度提升树的一种改进。它通过结合多个弱学习器(通常是决策树)来构建一个强大的集成模型。

XGBoost 的核心原理涉及到损失函数的优化和树模型的构建。

核心原理

1. 损失函数(Loss Function)

假设我们有一个由  个样本组成的训练数据集,,其中  是特征向量, 是对应的标签。

XGBoost 使用泰勒展开式对损失函数进行近似。对于一般的损失函数 ,泰勒展开式可以写作:

其中, 是当前模型的预测值, 和  分别是损失函数关于预测值的一阶导数(梯度)和二阶导数(海森矩阵)。这里的  表示第  个样本。

2. 正则化项(Regularization Term)

为了防止过拟合,XGBoost 引入了正则化项。正则化项包含了树模型的复杂度,可以写作:

其中, 是叶子节点的数量, 是叶子节点的分数, 和  是正则化参数。

3. 目标函数(Objective Function)

XGBoost 的目标函数是损失函数和正则化项的加权和。假设我们有  个树模型,每个树模型表示为 ,目标函数可以写作:

4. 模型更新(Model Update)

XGBoost 采用贪婪算法逐步构建树模型。每一步迭代,都会学习一个新的树模型,以减小目标函数。模型更新分为两个步骤:叶子节点分裂(Leaf Split)和叶子节点权重(Leaf Weight)的更新。

对于叶子节点 ,其分数  可以通过以下公式计算:

其中, 是叶子节点  上所有样本的一阶梯度之和, 是叶子节点  上所有样本的二阶梯度之和。而且,当叶子节点的分数确定后,可以使用优化算法(如近似贪婪算法)来选择最佳的分裂点。

5. 预测(Prediction)

最终模型的预测结果可以通过将所有树的输出值相加来获得:

这里的  表示第  个树对样本  的预测输出。

综上所述,XGBoost 通过优化目标函数,迭代地构建树模型,并通过贪婪算法对树的结构进行优化,从而得到强大的集成模型。

特点和适用场景

XGBoost作为一种高效的集成学习算法,这里给大家总结7个显著的特点:

1. 高效的并行化处理

  • XGBoost 能够有效地利用多核处理器进行并行计算,加速模型训练过程。

  • 它采用了一种分布式计算框架,使得在大规模数据集上的训练也能够快速完成。

2. 高度优化的损失函数

  • XGBoost 使用了泰勒展开式对损失函数进行近似,这样做能够更好地理解数据,从而更快地收敛到最优解。

  • 通过一阶和二阶导数信息,XGBoost 能够更加精确地估计每个样本的损失。

3. 正则化和剪枝

  • XGBoost 通过正则化项来控制模型的复杂度,防止过拟合。

  • 它采用了剪枝技术来减小树的规模,降低模型的复杂度,提高泛化能力。

4. 可扩展性和灵活性

  • XGBoost 可以与多种编程语言和数据处理框架(如Python、R、Spark)无缝集成。

  • 它支持自定义损失函数和评估指标,可以适应各种不同的任务和需求。

5. 特征重要性评估

  • XGBoost 提供了一种直观的方法来评估特征的重要性,可以帮助用户进行特征选择和模型解释。

6. 处理缺失值

  • XGBoost 能够自动处理缺失值,不需要对缺失值进行额外的处理或填充。

7. 支持多种目标函数

  • XGBoost 支持分类、回归、排序等多种类型的任务,可以灵活应对不同的问题。

XGBoost 最能解决的问题包括但不限于:

  • 分类问题:XGBoost 在处理分类问题时表现优异,能够有效地处理高维度特征和大规模数据集。

  • 回归问题:对于回归问题,XGBoost 能够提供精确的预测和较小的泛化误差。

  • 排序问题:在搜索引擎、推荐系统等需要排序的场景中,XGBoost 能够学习到有效的排序模型。

  • 异常检测:XGBoost 可以通过学习异常模式来进行异常检测,适用于金融欺诈检测、工业生产中的异常监测等场景。

  • 特征工程:XGBoost 能够自动处理缺失值和异常值,减少了特征工程的工作量。

  • 模型解释:XGBoost 提供了直观的特征重要性评估,可以帮助解释模型的预测结果。

完整案例

下面,是一个使用XGBoost算法进行二分类的完整案例,包括数据集、Python代码和结果可视化。我们使用鸢尾花数据集作为示例数据集,该数据集包含四个特征和三个类别。

案例流程
  1. 数据加载与预处理。

  2. 特征工程与数据分割。

  3. 使用XGBoost进行模型训练。

  4. 模型评估与可视化。

数据集

我们将使用鸢尾花数据集,该数据集包含150个样本,每个样本有四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度,以及一个目标变量,代表鸢尾花的类别(Setosa、Versicolor和Virginica)。

代码
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D


if __name__ == '__main__':
    # 加载鸢尾花数据集
    iris = load_iris()
    X = iris.data
    y = iris.target

    # 数据分割
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # 设置XGBoost参数
    param = {'max_depth': 3, 'eta': 0.3, 'objective': 'multi:softmax', 'num_class': 3}
    num_round = 20

    # 训练模型
    dtrain = xgb.core.DMatrix(X_train, label=y_train)
    dtest = xgb.core.DMatrix(X_test, label=y_test)
    bst = xgb.train(param, dtrain, num_round)

    # 在测试集上进行预测
    y_pred = bst.predict(dtest)

    # 计算准确率
    accuracy = accuracy_score(y_test, y_pred)
    print("Accuracy:", accuracy)

    # 绘制混淆矩阵
    labels = ['Setosa', 'Versicolor', 'Virginica']
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='coolwarm', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    # 3D 可视化
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    colors = ['r', 'g', 'b']
    for i in range(3):
        ax.scatter(X_test[y_test == i, 0], X_test[y_test == i, 1], X_test[y_test == i, 2], c=colors[i], label=labels[i])

    ax.set_xlabel('Sepal Length')
    ax.set_ylabel('Sepal Width')
    ax.set_zlabel('Petal Length')
    ax.set_title('Iris Data 3D Visualization')

    plt.legend()
    plt.show()

在实际应用中,可以调整XGBoost的参数以获得更好的性能,例如使用交叉验证来选择最佳的参数组合。

其中,

scikit-learn==1.5.2
matplotlib==3.9.0
seaborn==0.13.2
xgboost==2.1.1

绘制混淆矩阵

图片

可视化

图片

最后

XGBoost是一种集成学习算法,基于决策树构建强大的预测模型。它通过迭代训练多个决策树模型,利用梯度提升技术不断优化模型性能。XGBoost在各种数据集上都表现出色,并且被广泛应用于分类和回归问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167659.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构练习题————(二叉树)——考前必备合集!

今天在牛客网和力扣上带来了数据结构中二叉树的进阶练习题 1.二叉搜索树与双向链表———二叉搜索树与双向链表_牛客题霸_牛客网 (nowcoder.com) 2.二叉树遍历————二叉树遍历_牛客题霸_牛客网 (nowcoder.com) 3.二叉树的层序遍历————102. 二叉树的层序遍历 - 力扣&am…

C# 一键清空控件值

_场景:_在任何一个Form表单的操作页面或者数据台账的查询页面,基本都会看到一个清除的按钮,其功能就是用来清除我们需要抛弃的已经写入到控件内的数据。如果一个个控件来处理的话,想必会非常麻烦,而且系统不单单只是一…

杭州算力小镇:AI泛化解锁新机遇,探寻AI Agent 迭代新路径

人工智能技术不断迭代,重点围绕着两个事情,一是数据,二是算力。 算法的迭代推动着AI朝向多模态的方向发展,使之能够灵活应对不同领域的不同任务,模型的任务执行能力大大提升,人工智能泛化能力被推上高潮。…

scala 2.12 error: value foreach is not a member of Object

如图所示:在scala 2.11的时候下面的不报错,但是在2.12下报错了 在scala2.12环境下错误如下: 经过查找资料得到:df 后面加上rdd 即可

kafka下载配置

下载安装 参开kafka社区 zookeeperkafka消息队列群集部署https://apache.csdn.net/66c958fb10164416336632c3.html 下载 kafka_2.12-3.2.0安装包快速下载地址分享 官网下载链接地址: 官网下载地址:https://kafka.apache.org/downloads 官网呢下载慢…

详细解读,F5服务器负载均衡的技术优势

在现代大规模、高流量的网络使用场景中,为应对高并发和海量数据的挑战,服务器负载均衡技术应运而生。但凡知道服务器负载均衡这一名词的,基本都对F5有所耳闻,因为负载均衡正是F5的代表作,换句通俗易懂的话来说&#xf…

需求: 通过后台生成的树形结构,返回给前台用于动态生成表格标题,并将对应标题下面的信息对应起来

1. 如图所以&#xff0c;完成以下内容对应 2. 代码示例如下&#xff0c; 动态生成树形结构列名称&#xff0c;并将表格中存在的值与其对应起来 /*** 查询资源计划列表** param resourcePlan 资源计划* return 资源计划*/Overridepublic Map<String, Object> selectResour…

程序编译的四个阶段

程序编译的四个阶段 #include <stdio.h>int main(){printf("Hello World~");return 0; } hello.c程序的生命周期从一个高级C语言程序开始&#xff0c;这种形式容易被人读懂。 但这无法直接被计算机读懂。为了在系统上运行hello.c程序&#xff0c;每条C语言都…

科研绘图系列:R语言组合多个图形

文章目录 介绍加载R包画图介绍 通过patchworkR包组合多个ggplot数据图形对象。 加载R包 library(ggplot2) library(patchwork)画图 画图theme_set(theme_bw() +theme(

2024年408真题计算机网络篇

1 https://zhuanlan.zhihu.com/p/721169467。最小割可以看作是切断水流的最薄弱环节——通过切断这些关键的“水管”&#xff0c;就可以完全阻止水从源点流到汇点。 在下列二进制数字调制方法中&#xff0c;需要2个不同频率载波 的是 A. ASK B. PSK C. FSK D. DPSK 解答…

Linux终端简介

Linux终端简介 导语基本终端交互终端读写标准/非标准模式重定向处理 终端对话 termios结构模式相关输入模式输出模式控制模式本地模式特殊控制字符终端速度其他函数 终端输出终端类型terminfo 击键动作检测虚拟控制台&伪终端总结参考文献 导语 本章基本是以一个简单的用户…

PlayerPerfs-不同平台的存储位置

一 .PlayerPrefs存储的数据存在哪里 不同平台存储位置不一样 Windows PlayerPrefs 存储在 HKCU\Software\[公司名称]\[产品名称] 项下的注册表中 其中公司和产品名称是 在“Project Settings”中设置的名称。 查看方法&#xff1a; 运行 regedit HKEY…

手游和应用出海资讯:三七新游首月收入突破700万元;领英尝试推出游戏功能以增加用户使用时长

NetMarvel帮助游戏和应用广告主洞察全球市场、获取行业信息&#xff0c;以下为9月第四周资讯&#xff1a; ● 《AFK Journey》收入突破 1.5 亿美元 ● 《黑神话&#xff1a;悟空》IGN年度游戏投票第一掉至第三 ● 三七发布新游首月收入突破700万元 ● 开罗游戏《哆啦A梦的铜锣烧…

Ubuntu22.04安装paddle

查看系统版本信息 使用命令lsb_release -a查看系统版本 rootLAIS01:~# lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 22.04.5 LTS Release: 22.04 Codename: jammy查看系统支持的cuda版本&#xff0c;使用命令nvidia-smi&#…

Llama 3.2来了,多模态且开源!AR眼镜黄仁勋首批体验,Quest 3S头显价格低到离谱

如果说 OpenAI 的 ChatGPT 拉开了「百模大战」的序幕&#xff0c;那 Meta 的 Ray-Ban Meta 智能眼镜无疑是触发「百镜大战」的导火索。自去年 9 月在 Meta Connect 2023 开发者大会上首次亮相&#xff0c;短短数月&#xff0c;Ray-Ban Meta 就突破百万销量&#xff0c;不仅让马…

HT6872 4.7W防削顶单声道D类音频功率放大器

■ 特点 防削顶失真功能(Anti-Clipping Function,ACF) 优异的全带宽EMI抑制性能 免滤波器数字调制&#xff0c;直接驱动扬声器 输出功率 1.40W(VDD3.6V,RL4Ω,THDN10%) 2.80W(VDD5.0V,RL4Ω,THDN10%) 4.70W(VDD6.5V,RL4Ω,THDN10%) 高信噪比SNR:95dB(VDD6.5V,Av24dB. THDN1%) 低…

监控IDS和IPS增强网络安全性

入侵检测系统&#xff08;IDS&#xff09;和入侵防御系统&#xff08;IPS&#xff09;是当今使用的最复杂的网络安全设备之一&#xff0c;它们检查网络数据包并阻止可疑数据包&#xff0c;并提醒管理员有关攻击企图的信息。 在当今威胁不断变化的网络环境中&#xff0c;防火墙…

学习threejs,添加环境光和点光源

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言二、&#x1f340;绘制任意字体模型…

人工智能之机器学习常见算法

摘要 之前一直对机器学习很感兴趣,一直没时间去研究,今天刚好是周末,有时间去各大技术论坛看看,刚好看到一篇关于机器学习不错的文章,在这里就分享给大家了. 机器学习无疑是当前数据分析领域的一个热点内容。很多人在平时的工作中都或多或少会用到机器学习的算法。这里IT经理网…

【Linux实践】实验六:LINUX系统管理

【Linux实践】实验六&#xff1a;LINUX系统管理 实验目的实验内容实验步骤及结果1. 包管理工具2. VMware Tools3. 修改主机名4. 网络配置① 临时修改② 永久修改 5. 查找文件6. 前后台执行7. 查看进程8. 结束进程 实验目的 4、掌握Linux下软件包管理&#xff0c;包括命令rpm、…