机器学习实践(1.1)XGBoost分类任务

news2025/1/19 18:29:58

前言

XGBoost属于Boosting集成学习模型,由华盛顿大学陈天齐博士提出,因在机器学习挑战赛中大放异彩而被业界所熟知。相比越来越流行的深度神经网络,XGBoost能更好的处理表格数据,并具有更强的可解释性,还具有易于调参、输入数据不变性等优势。本文只做XGBoost分类任务的脚本实现,更多XGBoost内容请查看文末 附加——深入学习XGBoost

❤️ 本文完整脚本点此链接 百度网盘链接 获取 ❤️

结论先行
当使用 from xgboost import XGBClassifier 的模型进行训练时,使用的是sklearn中的XGBClassifier类,该方法中无需特意指定分类类别,方法自带类别数量n_classes_的计算,并根据数量指定了objective参数,简而言之:该方法会自动判别是多分类还是二分类任务,无需特殊说明。

之所以特殊强调from xgboost import XGBClassifier 是要区别于import xgboost as xgb 的调用,本人更建议使用前者。

下方代码取自 sklearn.py 的 class XGBClassifier(XGBModel, XGBClassifierBase)…

# 计算类别数量
import cupy as cp  # pylint: disable=E0401

self.classes_ = cp.unique(y.values)
self.n_classes_ = len(self.classes_)

......

# objective参数的选择
if callable(self.objective):
   obj: Optional[
       Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
   ] = _objective_decorator(self.objective)
   # Use default value. Is it really not used ?
   params["objective"] = "binary:logistic"
else:
   obj = None

if self.n_classes_ > 2:
   # Switch to using a multiclass objective in the underlying XGB instance
   if params.get("objective", None) != "multi:softmax":
       params["objective"] = "multi:softprob"
   params["num_class"] = self.n_classes_

一.轻松实现多分类

1.1导入第三方库、数据集

# 导入第三方库,包括分类模型、数据集、数据集分割方法、评估方法
from xgboost import XGBClassifier  # 分类模型
from sklearn import datasets  # 数据集
from sklearn.model_selection import train_test_split  # 数据集分割方法
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score, \
    classification_report  # 评估方法
import xgboost as xgb

# 导入sklearn的鸢尾花卉数据集,作为模型的训练和验证数据
data = datasets.load_iris()

# 数据划分,按照7 3分切割数据集为训练集和验证集,其中最终4个结果依次为训练数据、验证数据、训练数据的标签、验证数据的标签
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.3,random_state=123)

sklearn的鸢尾花卉数据集共150个数据样本,73切分后,训练集105个数据样本,验证集45个数据样本。数据集中包括 样本特征data(4个特征)、样本标签target(3类标签)、标签名称target_names([‘setosa’, ‘versicolor’, ‘virginica’])、特征名称feature_names([‘sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’])、以及数据集位置filename(~~~\anaconda\lib\site-packages\sklearn\datasets\data\iris.csv)

数据集的 部分特征数据 如下:
在这里插入图片描述
数据集的 标签数据 及 标签名称 如下:在这里插入图片描述
数据集文件所在本地地址,其他波士顿房价等数据也在此文件夹下
在这里插入图片描述

1.2模型训练、验证

# 默认参数的模型
model = XGBClassifier()

# 调参见 附加1 的文章内容
# model = XGBClassifier(booster='gbtree',
#                       n_estimators=20,  # 迭代次数
#                       learning_rate=0.1,  # 步长
#                       max_depth=5,  # 树的最大深度
#                       min_child_weight=1,  # 决定最小叶子节点样本权重和
#                       subsample=0.8,  # 每个决策树所用的子样本占总样本的比例(作用于样本)
#                       colsample_bytree=0.8,  # 建立树时对特征随机采样的比例(作用于特征)典型值:0.5-1
#                       nthread=4,
#                       seed=27,  # 指定随机种子,为了复现结果
#                       # num_class=4,  # 标签类别数
#                       # objective='multi:softmax',  # 多分类
#                       )

# 模型训练
model.fit(X_train, y_train, verbose=True)

# 模型对验证数据做预测 y_pred 预测结果,y_proba 预测各类别概率,y_pred 是softmax(y_proba) 的结果 
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)

# 为了便于观察验证数据的预测结果,写个循环传个参
for m, n, p in zip(y_proba, y_pred, y_test):
    if n == p:
        q = '预测正确'
    else:
        q = '预测错误'
        print('预测概率为{0}, 预测概率为{1}, 真是结果为{2}, {3}'.format(m, n, p, q))

# 准确率
accuracy = accuracy_score(y_test, y_pred)
print('Accuracy:%.2f%%' % (accuracy * 100))

打印运行结果如下
在这里插入图片描述

二.实现二分类同样简单

2.1导入数据集、训练、验证

导入数据集、训练、验证 与多分类完全一样,唯一需要改变的是将数据标签由n_classes>2转为n_classes=2

# 数据划分
X_train, X_test, y_train, y_test = ......  
# 接

# 训练集和验证集的标签都改成0和1
y_train = [1 if y > 0 else 0 for y in y_train]

y_test = [1 if y > 0 else 0 for y in y_test]

# 计算训练数据类别
n_classes = len(set(y_train))

# 接
# 默认参数的模型
model = XGBClassifier()
......

2.2再做模型训练和验证

内容与1.2完全一致,不再赘述,直接看结果
二分类任务预测概率有两个,预测概率通过sigmoid做出预测结果。多分类预测概率有多个,预测概率通过softmax做出预测结果。
在这里插入图片描述

2.3重要结论映证

"""模型参数打印"""
bst = xgb.Booster(model_file='xgb_classifier_model.model')
# print(bst.attributes())

print('模型参数值-开始'.center(20, '='))
for attr_name, attr_value in bst.attributes().items():
    # scikit_learn 的参数逐一解析
    if attr_name == 'scikit_learn':
        import json
        dict_attr = json.loads(attr_value)
        # 打印 模型 scikit_learn 参数
        for sl_name, sl_value in dict_attr.items():
            if sl_value is not None:
                print(f"{sl_name}:{sl_value}")
    else:
        print(f"{attr_name}:{attr_value}")
print('模型参数值-结束'.center(20, '='))

下图展示model = XGBClassifier()未指定参数情况下,模型的参数情况打印,其中objectiveclasses_参数映证了最前面的结论。
在这里插入图片描述

三.模型评估

评估方法更详细解释,见附加3的文章内容

def metrics_sklearn(class_num, y_valid, y_pred_, y_prob):
    """模型效果评估"""
    # 准确率
    # 准确度 accuracy_score:分类正确率分数,函数返回一个分数,这个分数或是正确的比例,或是正确的个数,不考虑正例负例的问题,区别于 precision_score
    # 一般不直接使用准确率,主要是因为类别不平衡问题,如果大部分是negative的 而且大部分模型都很容易判别出来,那准确率都很高, 没有区分度,也没有实际意义(因为negative不是我们感兴趣的)
    accuracy = accuracy_score(y_valid, y_pred_)
    print('Accuracy:%.2f%%' % (accuracy * 100))

    # 精准率
    if class_num == 2:
        precision = precision_score(y_valid, y_pred_)
    else:
        precision = precision_score(y_valid, y_pred_, average='macro')
    print('Precision:%.2f%%' % (precision * 100))

    # 召回率
    # 召回率/查全率R recall_score:预测正确的正样本占预测正样本的比例, TPR 真正率
    # 在二分类任务中,召回率表示被分为正例的个数占所有正例个数的比例;如果是多分类的话,就是每一类的平均召回率
    if class_num == 2:
        recall = recall_score(y_valid, y_pred_)
    else:
        recall = recall_score(y_valid, y_pred_, average='macro')
    print('Recall:%.2f%%' % (recall * 100))

    # F1值
    if class_num == 2:
        f1 = f1_score(y_valid, y_pred_)
    else:
        f1 = f1_score(y_valid, y_pred_, average='macro')
    print('F1:%.2f%%' % (f1 * 100))

    # auc曲线下面积
    # 曲线下面积 roc_auc_score 计算AUC的值,即输出的AUC(二分类任务中 auc 和 roc_auc_score 数值相等)
    # 计算auc,auc就是曲线roc下面积,这个数值越高,则分类器越优秀。这个曲线roc所在坐标轴的横轴是FPR,纵轴是TPR。
    # 真正率:TPR = TP/P = TP/(TP+FN)、假正率:FPR = FP/N = FP/(FP+TN)
    # auc:不支持多分类任务 计算ROC曲线下的面积
    # 二分类问题直接用预测值与标签值计算:auc = roc_auc_score(Y_test, Y_pred)
    # 多分类问题概率分数 y_prob:auc = roc_auc_score(Y_test, Y_pred_prob, multi_class='ovo') 其中multi_class必选
    if class_num == 2:
        auc = roc_auc_score(y_valid, y_pred_)
    else:
        auc = roc_auc_score(y_valid, y_prob, multi_class='ovo')
        # auc = roc_auc_score(y_valid, y_prob, multi_class='ovr')
    print('AUC:%.2f%%' % (auc * 100))

    # 评估效果报告
    print(classification_report(y_test, y_pred, target_names=['0:setosa', '1:versicolor', '2:virginica']))


"""模型效果评估"""
n_classes = len(set(y_train))
metrics_sklearn(n_classes, y_test, y_pred, y_proba)

多分类模型效果评估结果入下图所示
在这里插入图片描述
本文完整脚本可通过百度网盘链接 获取

附加——深入学习XGBoost

附加1.模型调参、训练、保存、评估和预测

见《XGBoost模型调参、训练、评估、保存和预测》 ,包含模型脚本文件

附加2.算法原理

见《XGBoost算法原理及基础知识》 ,包括集成学习方法,XGBoost模型、目标函数、算法,公式推导等

附加3.分类任务的评估指标值详解

见《分类任务评估1——推导sklearn分类任务评估指标》,其中包含了详细的推理过程;
见《分类任务评估2——推导ROC曲线、P-R曲线和K-S曲线》,其中包含ROC曲线、P-R曲线和K-S曲线的推导与绘制;

附加4.模型中树的绘制和模型理解

见《Graphviz绘制模型树1——软件配置与XGBoost树的绘制》,包含Graphviz软件的安装和配置,以及to_graphviz()和plot_trees()两个画图函数的部分使用细节;
见《Graphviz绘制模型树2——XGBoost模型的可解释性》,从模型中的树着手解释XGBoost模型,并用EXCEL构建出模型。

❤️ 机器学习内容持续更新中… ❤️


声明:本文所载信息不保证准确性和完整性。文中所述内容和意见仅供参考,不构成实际商业建议,可收藏可转发但请勿转载,如有雷同纯属巧合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/653654.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

若依微服务 + seata1.5.2版本分布式事务(安装配置nacos+部署)

若依官方使用的1.4.0版本seata,版本较低配置相对更麻烦一些 一、seata服务端下载,下载方式介绍两种入口,如下: 1、找到对应版本,下载 binary 即可。 下载包名为:seata-server-1.5.2.zip 2. github上下载 …

WWDC2023 Metal swift 头显ARKit支持c c++ 开发

1 今年WWDC,我们看见了苹果的空间计算设备,visionOS也支持了c c API. 这有什么好处呢,不是说能够吸引更多c c开发者加入苹果开发者阵营,而是我们过去的很多软件,可以轻松对接到苹果的头显设备,让我们的软件…

2023年协议转让研究报告

第一章 概述 协议转让作为中国企业破产法中的重要程序之一,已经在实践中得到广泛应用。在协议转让过程中,债务人与债权人或其他相关方通过协商达成一致,将特定的资产或权益进行转让,以实现债务清偿或债务人的破产清算。协议转让的…

RRC重建比率高问题分析和优化方法

PART01 1、重建概述 RRC重建(RRC connection re-establishment)是UE处于RRC_CONNECTED状态,因为一些移动性管理或底层链路故障,导致连接中断,UE发起的空口资源重新建立的过程,以继续空口的RRC连接。重建是…

[CSP-S 2021] 回文

[CSP-S 2021] 回文 题目描述: 给定正整数 n 和整数序列 a1​,a2​,…,a2n​,在这 2n 个数中,1,2,…,n 分别各出现恰好 2 次。现在进行 2n 次操作,目标是创建一个长度同样为 2n 的序列 b1​,b2​,…,b2n​,初始时 b 为空序列&…

【SpringCloud入门】-- SpringCloud优质组件介绍

目录 1. SpringCloud优质项目 2. 介绍SpringCloud优质项目 SpringCloudConfig(Spring) SpringCloudBus Eureka Hystrix Zuul Archaius Consul SpringCloudForCloudFoundry SpringCloudSleuth SpringCloudDataFlow SpringCloudSecurity SpringCloudZookeeper Spr…

【Redis】孔夫子旧书网爬虫接入芝麻代理IP:代理IP利用效率最大化

背景: 之前用过芝麻IP,写过这几篇文章 《【Python】芝麻HTTP代理系列保姆级全套攻略(对接教程自动领取每日IPIP最优算法)》 《【Python】记录抓包分析自动领取芝麻HTTP每日免费IP(成品教程)》 《爬虫增加代理池:使用稳…

ICC2:自定义快捷键和菜单

把一些常用的功能放在一个菜单里是什么体验?直接放在工具栏里是不是更方便?那设置成快捷键呢? gui_create_menu 自定义菜单可以把工具常用的功能放到一个菜单里,用户也可以把“执行脚本操作”加到菜单里。 举例来说: 1)把Editor Toolbox放到Favorite菜单里,floorplan 操…

行业报告 | AIGC发展研究

原创 | 文 BFT机器人 01 技术篇 深度学习进化史:知识变轨 风起云涌 已发生的关键步骤: 人工神经网络的诞生 反向传播算法的提出 GPU的使用 大数据的出现 预训练和迁移学习 生成对抗网络 (GAN) 的发明 强化学习的成功应用 自然语言处理的突破 即将发生的关键…

MinGW-w64安装和使用_亲测有效

MinGW-w64 是什么!? MinGW-w64 是一个在 Windows 系统上运行的 GNU 编译器套件,支持 C 和 C 语言的编译。它包括了 GCC 编译器、GNU Binutils 和一些其他的工具。在 MinGW-w64 中 各个版本的参数含义如下: x86_64:表…

1.ORB-SLAM3系统概述

1.内容简介 本系列文章主要基于ORB-SLAM3代码、论文以及相关博客,对算法原理进行总结和梳理。 ORB-SLAM系列整体架构是不变的,都包含Tracking、LocalMapping和LoopClosing三个核心线程,中间伴随着优化过程。在ORB-SLAM3算法中比较突出的改进…

腾讯安全董志强:四大关键步骤促进数据安全治理闭环,提升企业免疫力

高速发展的数字时代,数据已成为推动产业发展的最重要生产要素之一,真正成为了创造经济财富的数字能源,守护数据资产的安全成为企业高质量发展不可回避的重要命题。 6月13日,腾讯安全联合IDC发布“数字安全免疫力”模型框架&#…

我被一家无货源电商培训公司骗了怎么办?

我是卢松松,点点上面的头像,欢迎关注我哦! 最近,一位被无货源电商培训骗的人找到了卢松松,她说: 老师,你好,我是被无货源电商课程骗了的受害人,走投无路了,想…

5个超好用的开源工具库分享~

在实际项目开发中,从稳定性和效率的角度考虑,重复造轮子是不被提倡的。但是,自己在学习过程中造轮子绝对是对自己百利而无一害的,造轮子是一种特别能够提高自己系统编程能力的手段。 今天分享几个我常用的开源工具库:…

大佬们都是如何编写测试方案的?

目录 1、背景 2、编写的方式 2.1 第一阶段:在需求评审开始前 2.2 第二阶段:在需求评审开始后,技术方案设计中 2.3 第三阶段:技术方案设计后 2.4 第四阶段:测试方案评审前 2.5 第五阶段:测试方案评审…

Opencv-C++笔记 (7) : opencv-文件操作XML和YMAL文件

文章目录 一、概述二、文件操作三、打开文件四、写入五、读写个人源码 一、概述 除了图像数据之外,有时程序中的尺寸较小的Mat类矩阵、字符串、数组等 数据也需要进行保存,这些数据通常保存成XML文件或者YAML文件。本小节中将介绍如何利用OpenCV 4中的函…

前端实现消息推送、即时通信、http简介

信息推送 服务端主动向客户端推送消息,使客户端能够即时接收到信息。 场景 页面接收到点赞,消息提醒聊天功能弹幕功能实时更新数据功能 实现即时通讯方式 短轮询 浏览器(客户端)每隔一段时间向服务器发送http请求,…

Google为TensorFlow设计的专用集成电路TPU3.0图片

Widrow也是在Minsky的影响下进入AI领域的,后来加入斯坦福大学任教。他在1960年提出了自适应线性单元(Adaline),一种和感知器类似的单层神经网络,用求导数方法来调整权重,所以说有“三十年神经网络经验”并不…

CI/CD 流水线 (FREE)

流水线是持续集成、交付和部署的顶级组件。 流水线包括: 工作,定义做什么。例如,编译或测试代码的作业。阶段,定义何时运行作业。例如,在编译代码的阶段之后运行测试的阶段。 作业由 runners 执行。如果有足够多的并…

Qt编写视频监控系统79-四种界面导航栏的设计

一、前言 最初视频监控系统按照二级菜单的设计思路,顶部标题栏一级菜单,左侧对应二级菜单,最初采用图片在上面,文字在下面的按钮方式展示,随着功能的增加,二级菜单越来越多,如果都是这个图文上…