Xgboost分类模型的完整示例

news2025/1/17 17:42:03

在这里插入图片描述

往期精彩推荐

  • 数据科学知识库
  • 机器学习算法应用场景与评价指标
  • 机器学习算法—分类
  • 机器学习算法—回归
  • PySpark大数据处理详细教程

定义问题

UCI的蘑菇数据集的主要目的是为了分类任务,特别是区分蘑菇是可食用还是有毒。这个数据集包含了蘑菇的各种特征,如帽形、颜色、气味等,以及一个标签表示蘑菇是否有毒。通过对这些特征的分析,可以构建分类模型来预测任何一个蘑菇样本是否有毒。这种类型的任务对于练习数据科学和机器学习技能,尤其是分类算法的应用和理解,非常有帮助。

导入相关库

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report
from xgboost import XGBClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import shap

加载数据集

# 数据集的URL
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"

# 定义列名
column_names = ["class", "cap-shape", "cap-surface", "cap-color", "bruises", "odor",
                "gill-attachment", "gill-spacing", "gill-size", "gill-color", "stalk-shape",
                "stalk-root", "stalk-surface-above-ring", "stalk-surface-below-ring",
                "stalk-color-above-ring", "stalk-color-below-ring", "veil-type", "veil-color",
                "ring-number", "ring-type", "spore-print-color", "population", "habitat"]

# 加载数据集
mushroom_data = pd.read_csv(url, names=column_names)

# 查看数据集的前几行
mushroom_data.head(5)

在这里插入图片描述

数据探索

统计数据可以帮助您快速了解数据的分布情况、中心趋势和离散程度。在数据分析和机器学习的前期阶段,这是一个常用的探索性数据分析(EDA)步骤。

mushroom_data.describe()

在这里插入图片描述

数据预处理

去除异常值

# 可以根据均值标准差来定义异常值
均值 ± 3倍标准差  之外的定义为异常值

缺失值填充

mushroom_data = mushroom_data.fillna(0)

数值类型转换

# 将分类数据转换为数值
label_encoder = LabelEncoder()
for column in mushroom_data.columns:
    mushroom_data[column] = label_encoder.fit_transform(mushroom_data[column])

划分训练集与测试集

# 划分训练集和测试集
X = mushroom_data.drop('class', axis=1)
y = mushroom_data['class']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

均衡采样与模型定义

# 计算正负样本比例,在XGBoost模型中设置scale_pos_weight
scale_pos_weight = sum(y_train == 0) / sum(y_train == 1)

# 定义模型
xgb_model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', scale_pos_weight=scale_pos_weight)

随机搜索选参

定义XGBoost的参数搜索范围,并使用RandomizedSearchCV进行随机搜索,以找到最佳的超参数。

param_distributions = {
    'n_estimators': [100, 300, 500, 800, 1000],  # 表示树的个数。增加树的数量可以提高模型的复杂度,但也可能导致过拟合。
    'learning_rate': [0.01, 0.05, 0.1, 0.15, 0.2],  # 学习率,用于控制每棵树对最终结果的影响。较低的学习率意味着模型需要更多的树来进行训练。
    'max_depth': [3, 4, 5, 6, 7, 8],  # 树的最大深度。更深的树会增加模型的复杂度,但也可能导致过拟合。
    'min_child_weight': [1, 2, 3, 4],  # 决定最小叶子节点样本权重和。较大的值可以防止模型过于复杂,从而避免过拟合。
    'subsample': [0.6, 0.7, 0.8, 0.9, 1.0],  # 用于控制每棵树随机采样的比例,减少这个参数的值可以使模型更加保守,防止过拟合。
    'colsample_bytree': [0.6, 0.7, 0.8, 0.9, 1.0],  # 用于每棵树的训练时,随机采样的特征的比例。减少这个参数的值同样可以防止模型过于复杂。
    'gamma': [0, 0.1, 0.2, 0.3, 0.4]  # 后剪枝时,作为节点分裂所需的最小损失函数下降值。该参数值越大,算法越保守。
}

random_search = RandomizedSearchCV(
    xgb_model, param_distributions, n_iter=50, cv=5, random_state=42
)
random_search.fit(X_train, y_train)

best_params = random_search.best_params_
xgb_model.set_params(**best_params)

模型评估

使用找到的最佳参数训练XGBoost模型,然后在测试集上进行评估,计算性能指标如准确率、精确率、召回率和F1分数。

xgb_model.fit(X_train, y_train)
y_pred = xgb_model.predict(X_test)
# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
tn, fp, fn, tp = cm.ravel()
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

performance_df = pd.DataFrame({
        'TN': [tn], 
        'FP': [fp], 
        'FN': [fn], 
        'TP': [tp],
        'Accuracy': [accuracy_score(y_test, y_pred)],
        'Precision': [precision_score(y_test, y_pred)],
        'Recall': [recall_score(y_test, y_pred)],
        'F1 Score': [f1_score(y_test, y_pred)]
    })
# 展示模型评估结果
performance_df.head()

标准数据集,结果太完美了,实际数据集就差强人意了

在这里插入图片描述

模型特征重要性展示

# 获取特征重要性
feature_importances = xgb_model.feature_importances_

# 可视化特征重要性
plt.style.use("ggplot")
plt.barh(range(len(feature_importances)), feature_importances)
plt.yticks(range(len(X.columns)), X.columns)
plt.xlabel('Feature Importance')
plt.title('Feature Importance in XGBoost Model')
plt.show()

在这里插入图片描述

SHAP负例分析

SHAP库提供了多种可视化工具,可以帮助您更深入地了解模型的行为。使用SHAP库计算XGBoost模型的SHAP值,分析被模型错误分类为负例的情况,并通过可视化来理解影响这些预测的关键特征在进行负例分析时,您可以专注于那些被模型错误分类的样本,并使用SHAP值来探究背后的原因。

# 计算SHAP值
explainer = shap.Explainer(xgb_model, X_train)
shap_values = explainer(X_test)

# 可视化:展示单个预测的SHAP值
shap.initjs()
# shap.force_plot(explainer.expected_value, shap_values[0,:], X_test[0,:])

# 可视化:展示所有测试数据的SHAP值
shap.summary_plot(shap_values, X_test)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1348467.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WorkQueue模型

WorkQueues,也被称为任务队列模型。当消息处理比较耗时的时候,可能生产消息的速度会远远大于消息的消费速度。长此以往,消息就会堆积越来越多,无法及时的处理。此时就可以使用work模型:让多个消费者绑定到一个队列&…

IDEA错误: 找不到或无法加载主类 com.atguigu.springcloud.EurekaServer7001_App

第一种方法&#xff1a; 可以手动点击maven中的compile编译一下&#xff0c;如下图&#xff1a; 第二种方法&#xff1a; 在pom.xml文件中加入编译插件&#xff1a; <build><plugins><!-- 编译插件 --><plugin><artifactId>maven-compiler-plu…

matlab概率论例子

高斯概率模型&#xff1a; [f,xi] ksdensity(x): returns a probability density estimate, f, for the sample in the vector x. The estimate is based on a normal kernel function, and is evaluated at 100 equally spaced points, xi, that cover the range of the da…

如何在Linux系统中安装Redis

原本Redis官网提供了Windows和Linux两个版本&#xff0c;但从 2011-12-29 以后不再更新Windows版本&#xff08;https://github.com/dmajkic/redis/downloads&#xff09;&#xff0c;加之企业生产环境通常使用Linux系统&#xff0c;所以这里在Linux系统中演示如何安装Redis。 …

typescript,eslint,prettier的引入

typescript 首先用npm安装typescript&#xff0c;cnpm i typescript 然后再tsc --init生成tsconfig.json配置文件&#xff0c;这个文件在package.json同级目录下 最后在tsconfig.json添加includes配置项&#xff0c;在该配置项中的目录下&#xff0c;所有的d.ts中的类型可以在…

11 HAL库的硬件I2C驱动SI7006和AP3216C

引言&#xff1a; 本片文章想给大家分享一下使用HAL库驱动SI7006和AP3216C&#xff0c; 这两款常见的芯片的手册会在文章的末尾提供给大家。 一、SI7006和AP3216C简介 SI7006 SI7006是一款数字湿度和温度传感器&#xff0c;由Silicon Labs&#xff08;全称Silicon Laboratories…

【AI视频领域展望】未来视频行业:人工智能、5G和VR技术将如何改变视频制作和观看方式?

5G技术 5G技术的商用将会进一步推动物联网和视频行业的融合。通过5G技术&#xff0c;可以实现高清视频的实时传输和播放&#xff0c;为用户提供更加流畅和快速的观看体验。 5G视频的优势主要体现在以下几个方面&#xff1a; 更低的延迟&#xff1a;5G网络的延迟时间相比4G降低…

Plantuml之EBNF语法介绍(二十七)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

App.vue中引入自定义组件

components目录中定义组件&#xff1a;Person.vue 目录截图&#xff1a; Person.vue文件中内容&#xff1a; <template><div class"person"><h2>姓名&#xff1a;{{name}}</h2><h2>年龄&#xff1a;{{age}}</h2><!--定义了…

OSCHINA Gitee 联合呈现,《2023 中国开源开发者报告》正式发布,总结分非常帮,可以免费看的报告!

《2023 中国开源开发者报告》 详细地址&#xff1a; https://talk.gitee.com/report/china-open-source-2023-annual-report.pdf 不需要收费下载&#xff01;&#xff01; 其中大模型的部分总结的非常棒 gietee 也支持 AI 模型托管了 如何在 Gitee 上托管 AI 模型 https://…

使用WAZUH检测LD_PRELAOD劫持、SQL注入、主动响应防御

目录 1、检查后门 使用工具检测后门 1.chkrootkit 2.rkhunter 手动检查文件 检查ld.so.preload文件 2、检测LD_PRELOAD ubuntu配置 wazuh配置 3、检测SQL注入 ubuntu配置 攻击模拟 4、主动响应 wauzh的安装以及设置代理可以参考本篇&#xff1a;WAZUH的安装、设置…

【Qt之Quick模块】6. QML语法详解_3 QML对象特性

概述 每一个QML对象类型都包含一组已定义的特性。当进行实例时都会包含一组特性&#xff0c;这些特性是在对象类型中定义的。 一个QML文档中的对象类型声明了一个新的类型&#xff0c;即实例出一个类型。 其中包含以下特性。 the id attribute &#xff1a; id特性property a…

《教育观察》是什么级别的期刊?是正规期刊吗?能评职称吗?

教育类&#xff5c;《教育观察》知网收录 《教育观察》始终秉持“ 立足教育实践&#xff0c;展望教育未来”&#xff0c;致力于在教育实践中以“观察”为方法&#xff0c;以“观察者”为主体&#xff0c;以“新观察”为旨趣&#xff0c;打造从教育实践中洞察教育未来的教育研究…

深入浅出图解C#堆与栈 C# Heap(ing) VS Stack(ing) 第四节 参数传递对堆栈的影响 2

深入浅出图解C#堆与栈 C# Heaping VS Stacking 第四节 参数传递对堆栈的影响 2 [深入浅出图解C#堆与栈 C# Heap(ing) VS Stack(ing) 第一节 理解堆与栈](https://mp.csdn.net/mdeditor/101021023)[深入浅出图解C#堆与栈 C# Heap(ing) VS Stack(ing) 第二节 栈基本工作原理](htt…

c语言:打印平行四边形|练习题

一、题目 输入平行四边形的边数&#xff0c;用星号打印平行四边形 如图&#xff1a; 二、思路分析 图形分为两部分 1、左边的空格 2、右边的星号 因此&#xff0c;把空格和星号合起来&#xff0c;就是要求的图形 三、代码图片【带注释】 四、源代码【带注释】 #include <s…

你逛过凌晨四点的校园吗?--大四毕业生的年终总结

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 又是一年的年终总结&#xff0c;我也迎来了自己的毕业季&#xff0c;没错&#xff0c;我马上要毕业啦&#xff01;不知道大家是什么时候认识我的呢&#xff0c;又或者是第一次发现我~这一年&#xff0c;迎接过朝阳、拍下过…

手摸手系列之SpringBoot+Vue2项目整合高德地图实现车辆实时定位功能

前言 最近在做一个物流内陆运输的项目&#xff0c;其中的一个关键功能是根据车辆的GPS数据在页面上实时显示车辆位置信息。由于我们已经获得了第三方提供的GPS数据&#xff0c;所以接下来的任务是将这些数据整合到我们的系统中&#xff0c;并利用高德地图API来展示车辆的实时位…

机器学习分类

1. 监督学习 监督学习指的是人们给机器一大堆标记好的数据&#xff0c;比如&#xff1a; 一大堆照片&#xff0c;标记出哪些是猫的照片&#xff0c;哪些是狗的照片 让机器自己学习归纳出算法或模型 使用该算法或模型判断出其他没有标记的照片是否是猫或狗 上述流程如下图所…

解决windows系统找不到msvcr100.dll问题,vcomp100.dll缺失的5个解决方法

在日常使用计算机的过程中&#xff0c;我们可能会遇到一些错误提示&#xff0c;其中之一就是“找不到vcomp100.dll”的错误。那么&#xff0c;vcomp100.dll究竟是什么文件&#xff1f;为什么会出现丢失的情况&#xff1f;本文将为您详细解析vcomp100.dll的作用、丢失原因以及提…

C++的面向对象学习(9):文件操作

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、类的封装的多文件实现回顾二、文件操作1.对文件进行操作需要头文件<fstream>2.操作文件的三大类方法&#xff1a;读、写、读写 三、实现文本文件的读、写…