机器学习:基于心脏病数据集的XGBoost分类预测

news2025/1/11 18:34:31

目录

一、简介

原理:

二、实战演练

1.数据准备

2.数据读取/载入

3.数据预处理

4.可视化处理

 5.对离散变量进行编码

6.模型训练与预测

 7.特征选择

8.通过调整参数获得更好的效果

核心参数调优

 网格调参法


一、简介

XGBoost(eXtreme Gradient Boosting)是一种梯度提升决策树(Gradient Boosting Decision Tree,GBDT)的实现,是目前最流行的机器学习算法之一,被广泛应用于各种任务,如分类、回归和排序等。它由陈天奇在2016年开发,是Boosting算法家族的成员之一,可以通过增量的方式训练模型,逐步提高模型的准确性。

与传统的决策树不同,XGBoost采用的是一种优化算法,即梯度提升算法(Gradient Boosting)。梯度提升算法是一种串行的集成方法,通过逐步训练多个弱分类器(即决策树),使它们逐渐变得更强大。在每一轮迭代中,它会计算损失函数的负梯度,作为新的训练目标,再训练一个弱分类器来拟合这个目标。最终,将所有弱分类器组合起来,形成一个强分类器。

XGBoost的优势在于它的高效性和准确性。它能够处理大规模的数据集和高维度的特征空间,且在处理稀疏数据时也表现良好。此外,XGBoost的模型训练速度快,可以处理大规模的数据集,在比赛中多次获得第一名。

总之,XGBoost是一个强大且高效的机器学习算法,广泛应用于各种领域,特别是在竞赛中和实际业务中都有着重要的应用。

原理:

XGBoost底层实现了GBDT算法,并对GBDT算法做了一系列优化:

  1. 对目标函数进行了泰勒展示的二阶展开,可以更加高效拟合误差。
  2. 提出了一种估计分裂点的算法加速CART树的构建过程,同时可以处理稀疏数据。
  3. 提出了一种树的并行策略加速迭代。
  4. 为模型的分布式算法进行了底层优化。

XGBoost是基于CART树的集成模型,它的思想是串联多个决策树模型共同进行决策。

那么如何串联呢?XGBoost采用迭代预测误差的方法串联。举个通俗的例子,我们现在需要预测一辆车价值3000元。我们构建决策树1训练后预测为2600元,我们发现有400元的误差,那么决策树2的训练目标为400元,但决策树2的预测结果为350元,还存在50元的误差就交给第三棵树……以此类推,每一颗树用来估计之前所有树的误差,最后所有树预测结果的求和就是最终预测结果!

XGBoost的基模型是CART回归树,它有两个特点:(1)CART树,是一颗二叉树。(2)回归树,最后拟合结果是连续值。

具体来说,XGBoost使用决策树作为基分类器,每个决策树都是通过梯度提升算法来训练的。在训练过程中,XGBoost会计算损失函数的负梯度,并用这个负梯度来训练一个新的决策树,通过不断地迭代,最终得到一个具有很强泛化能力的强分类器。

为了防止过拟合,XGBoost引入了正则化技术,包括L1正则化和L2正则化。L1正则化可以使模型更加稀疏,而L2正则化可以防止模型权重过大,从而避免过拟合。

除此之外,XGBoost还采用了一些优化技术,如缓存访问技术、数据压缩技术、多线程并行计算等,使得XGBoost在训练和预测速度上具有很高的效率。

二、实战演练

1.数据准备

下载阿里云提供的一个天气数据集,在pycharm之类的跑以下代码下载保存(原文是基于天气预测,举一反三学习就用心脏病这个数据集)

import requests

url = 'https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/7XGBoost/train.csv'
response = requests.get(url)
with open('train.csv', 'wb') as f:
    f.write(response.content)

 最上面分别是:年龄、是否贫血、肌酸磷酸激酶、是否糖尿病、射血分数、是否高血压、血小板血清、creatine血清_钠、性别、是否吸烟、时间、是否死亡。

原文是预测是否明天下雨,这里就预测死亡了。

2.数据读取/载入

放同一目录下,直接读即可

##  基础函数库
import numpy as np 
import pandas as pd

## 绘图函数库
import matplotlib.pyplot as plt
import seaborn as sns

## 我们利用Pandas自带的read_csv函数读取并转化为DataFrame格式

data = pd.read_csv('heart.csv')

 可以打印查看下

## 利用.info()查看数据的整体信息
data.info()

 基本上都是整形和浮点型。

3.数据预处理

心脏病数据没啥问题这里不再演示,以下是说明:

简单查看数据,如果有缺少的(NaN)就用-1填补上。

## 进行简单的数据查看,我们可以利用 .head() 头部.tail()尾部
data.head()

data = data.fillna(-1)
data.tail()

如果数据集中的负样本数量远大于正样本数量,这种常见的问题叫做“数据不平衡”问题,在某些情况下需要进行一些特殊处理。(像我这个负样本死亡为96没死亡为203就不用处理)

print(pd.Series(data['DEATH_EVENT']).value_counts())

## 对于特征进行一些统计描述
data.describe()

4.可视化处理

为了方便,先纪录数字特征与非数字特征:

numerical_features = [x for x in data.columns if data[x].dtype == np.float]
category_features = [x for x in data.columns if data[x].dtype != np.float and x != 'DEATH_EVENT']
## 选取三个特征与标签组合的散点可视化
sns.pairplot(data=data[['age',
'creatinine_phosphokinase',
'ejection_fraction'] + ['DEATH_EVENT']], diag_kind='hist', hue= 'DEATH_EVENT')
plt.show()

 从上图可以发现,在2D情况下不同的特征组合对于心脏病人是否死亡的散点分布,以及大概的区分能力。我认为ejection_fraction与其他特征的组合更具有区分能力(不太会看其实)

for col in data[numerical_features].columns:
    if col != 'DEATH_EVENT':
        sns.boxplot(x='DEATH_EVENT', y=col, saturation=0.5, palette='pastel', data=data)
        plt.title(col)
        plt.show()

 打印箱型图

可以得到不同类别在不同特征上的分布差异情况。

可以进行数据分析,比如分析吸烟与死亡的关系

tlog = {}
for i in category_features:
    tlog[i] = data[data['DEATH_EVENT'] == 1][i].dropna().value_counts()

flog = {}
for i in category_features:
    flog[i] = data[data['DEATH_EVENT'] == 0][i].dropna().value_counts()



plt.figure(figsize=(10,2))
plt.subplot(1,2,1)
plt.title('DEATH')
sns.barplot(x = pd.DataFrame(tlog['smoking'][:2]).sort_index()['smoking'], y = pd.DataFrame(tlog['smoking'][:2]).sort_index().index, color = "red")
plt.subplot(1,2,2)
plt.title('Not DEATH')
sns.barplot(x = pd.DataFrame(flog['smoking'][:2]).sort_index()['smoking'], y = pd.DataFrame(flog['smoking'][:2]).sort_index().index, color = "blue")
plt.show()

 5.对离散变量进行编码

由于XGBoost无法处理字符串类型的数据,我们需要一些方法讲字符串数据转化为数据。一种最简单的方法是把所有的相同类别的特征编码成同一个值,例如女=0,男=1,狗狗=2,所以最后编码的特征值是在[0,特征数量−1]之间的整数。除此之外,还有独热编码、求和编码、留一法编码等等方法可以获得更好的效果。

代码如下,但本文用的心脏病数据集都是整形和浮点型,因此不用处理。

## 把所有的相同类别的特征编码为同一个值
def get_mapfunction(x):
    mapp = dict(zip(x.unique().tolist(),
         range(len(x.unique().tolist()))))
    def mapfunction(y):
        if y in mapp:
            return mapp[y]
        else:
            return -1
    return mapfunction
for i in category_features:
    data[i] = data[i].apply(get_mapfunction(data[i]))

6.模型训练与预测

## 为了正确评估模型性能,将数据划分为训练集和测试集,并在训练集上训练模型,在测试集上验证模型性能。
from sklearn.model_selection import train_test_split

## 选择其类别为0和1的样本 (不包括类别为2的样本)
data_target_part = data['RainTomorrow']
data_features_part = data[[x for x in data.columns if x != 'RainTomorrow']]

## 测试集大小为20%, 80%/20%分
x_train, x_test, y_train, y_test = train_test_split(data_features_part, data_target_part, test_size = 0.2, random_state = 2020)
#查看标签数据
print(y_train[0:2],y_test[0:2])

# 打印修改后的结果
print(y_train[0:2],y_test[0:2])

导入XGBoost模型

## 导入XGBoost模型
from xgboost.sklearn import XGBClassifier
## 定义 XGBoost模型 
clf = XGBClassifier(use_label_encoder=False)
# 在训练集上训练XGBoost模型
clf.fit(x_train, y_train)

注意:控制台导入下载的时候要关掉梯子!

否则就有这种报错:WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'ProxyError('Cannot connect to proxy.', timeout('_ssl.c:1112: The handshake operation timed out'))': /pypi/web/simple/xgboost/

## 在训练集和测试集上分布利用训练好的模型进行预测
train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)
from sklearn import metrics

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))

## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

 7.特征选择

XGBoost的特征选择属于特征选择中的嵌入式方法,在XGboost中可以用属性feature_importances_去查看特征的重要度。

plt.figure(figsize=(8, 6))
sns.barplot(y=data_features_part.columns, x=clf.feature_importances_)
plt.show()

 从图中我们可以发现得病时间是决定是否死亡最重要的因素。

初次之外,我们还可以使用XGBoost中的下列重要属性来评估特征的重要性。

  • weight:是以特征用到的次数来评价
  • gain:当利用特征做划分的时候的评价基尼指数
  • cover:利用一个覆盖样本的指标二阶导数(具体原理不清楚有待探究)平均值来划分。
  • total_gain:总基尼指数
  • total_cover:总覆盖

 acc= 0.7833333333333333

 这些图同样可以帮助我们更好的了解其他重要特征。

8.通过调整参数获得更好的效果

以下是几个重要的参数

1. learning_rate: 有时也叫作eta,系统默认值为0.3。每一步迭代的步长,很重要。太大了运行准确率不高,太小了运行速度慢。
2. subsample:系统默认为1。这个参数控制对于每棵树,随机采样的比例。减小这个参数的值,算法会更加保守,避免过拟合, 取值范围零到一。
3. colsample_bytree:系统默认值为1。我们一般设置成0.8左右。用来控制每棵随机采样的列数的占比(每一列是一个特征)。
4. max_depth: 系统默认值为6,我们常用3-10之间的数字。这个值为树的最大深度。这个值是用来控制过拟合的。max_depth越大,模型学习的更加具体。

核心参数调优

1.eta[默认0.3]
通过为每一颗树增加权重,提高模型的鲁棒性。
典型值为0.01-0.2。

2.min_child_weight[默认1]
决定最小叶子节点样本权重和。
这个参数可以避免过拟合。当它的值较大时,可以避免模型学习到局部的特殊样本。
但是如果这个值过高,则会导致模型拟合不充分。

3.max_depth[默认6]
这个值也是用来避免过拟合的。max_depth越大,模型会学到更具体更局部的样本。
典型值:3-10

4.max_leaf_nodes
树上最大的节点或叶子的数量。
可以替代max_depth的作用。
这个参数的定义会导致忽略max_depth参数。

5.gamma[默认0]
在节点分裂时,只有分裂后损失函数的值下降了,才会分裂这个节点。Gamma指定了节点分裂所需的最小损失函数下降值。
这个参数的值越大,算法越保守。这个参数的值和损失函数息息相关。

6.max_delta_step[默认0]
这参数限制每棵树权重改变的最大步长。如果这个参数的值为0,那就意味着没有约束。如果它被赋予了某个正值,那么它会让这个算法更加保守。
但是当各类别的样本十分不平衡时,它对分类问题是很有帮助的。

7.subsample[默认1]
这个参数控制对于每棵树,随机采样的比例。
减小这个参数的值,算法会更加保守,避免过拟合。但是,如果这个值设置得过小,它可能会导致欠拟合。
典型值:0.5-1

8.colsample_bytree[默认1]
用来控制每棵随机采样的列数的占比(每一列是一个特征)。
典型值:0.5-1

9.colsample_bylevel[默认1]
用来控制树的每一级的每一次分裂,对列数的采样的占比。
subsample参数和colsample_bytree参数可以起到相同的作用,一般用不到。

10.lambda[默认1]
权重的L2正则化项。(和Ridge regression类似)。
这个参数是用来控制XGBoost的正则化部分的。虽然大部分数据科学家很少用到这个参数,但是这个参数在减少过拟合上还是可以挖掘出更多用处的。

11.alpha[默认1]
权重的L1正则化项。(和Lasso regression类似)。
可以应用在很高维度的情况下,使得算法的速度更快。

12.scale_pos_weight[默认1]
在各类别样本十分不平衡时,把这个参数设定为一个正值,可以使算法更快收敛。

 网格调参法

调节模型参数的方法有贪心算法、网格调参、贝叶斯调参等。这里我们采用网格调参,它的基本思想是穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果

## 从sklearn库中导入网格调参函数
from sklearn.model_selection import GridSearchCV

## 定义参数取值范围
learning_rate = [0.1, 0.3,]
subsample = [0.8]
colsample_bytree = [0.6, 0.8]
max_depth = [3,5]

parameters = { 'learning_rate': learning_rate,
              'subsample': subsample,
              'colsample_bytree':colsample_bytree,
              'max_depth': max_depth}
model = XGBClassifier(n_estimators = 20)

## 进行网格搜索
clf = GridSearchCV(model, parameters, cv=3, scoring='accuracy',verbose=1,n_jobs=-1)

clf = clf.fit(x_train, y_train)
## 在训练集和测试集上分布利用最好的模型参数进行预测

## 定义带参数的 XGBoost模型 
clf = XGBClassifier(colsample_bytree = 0.6, learning_rate = 0.3, max_depth= 8, subsample = 0.9)
# 在训练集上训练XGBoost模型
clf.fit(x_train, y_train)

train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))

## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

 更多调参技巧请参考:【机器学习笔记】【随机森林】【乳腺癌数据上的调参】_n_estimators_桜キャンドル淵的博客-CSDN博客


原文:A.机器学习入门算法(六)基于天气数据集的XGBoost分类预测_汀、人工智能的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/436401.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VOS3000 AXB模块工作原理

VOS AXB 模块适用于语音市场直连运营商或虚拟运营商 X 号平台的业务需求 与 VOS 系统无缝集成,无需独立服务器部署,节约硬件,网络成本 单机支持不低于 2,000 并发 AXB 呼叫,性能是市面常见 AXB 产品的 2-3 倍 支持设定在呼叫接…

Java阶段二Day04

Java阶段二Day04 文章目录 Java阶段二Day04截至此版本可实现的流程图为V9BirdBootApplicationClientHandlerDispatcherServletHttpServletResponseHttpServletRequest V10DispatcherServletHttpServletResponseMETA-INF / mime.types V11EmptyRequestExceptionClientHandlerHtt…

使用Vue脚手架【Vue】

3. 使用 Vue 脚手架 3.1 初始化脚手架 3.1.1 说明 Vue脚手架是Vue官方提供的标准化开发工具(开发平台)最新的版本是4.x文档:https://cli.vuejs.org/zh/ 3.1.2 具体步骤 第一步(仅第一次执行):全局安装…

Foresight对话:刘韧对谈王建硕、曾映龙、Joy Xue

Foresight 2023论坛现场 自2022年11月上线以来,OpenAI研发的ChatGPT一度风靡全球。面对这波AI浪潮,有些人拥抱了新趋势,有些人则担心会被取代,另一些人发掘其中的创业机遇和价值。创业是大浪淘沙的过程,目前以ChatGPT为…

Spring Boot概述(二)

1.SpringBoot整合Junit 1.搭建SpringBoot工程 2.引入starter-test起步依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope> </dependency>…

第二章IDEA快速上传项目到码云

文章目录 下载Git并配置邮箱上传到Github配置Git配置Github账号创建要上传码云的项目 上传到Gitee下载插件配置我们的Gitee账号 我们的IDEA功能很强大&#xff0c;所以肯定集成了快速上传项目到码云的功能 实际的开发中&#xff0c;代码都是采用IDE进行开发&#xff0c;所以我们…

day31—选择题

文章目录 1.在单处理器系统中&#xff0c;如果同时存在有12个进程&#xff0c;则处于就绪队列中的进程数量最多为&#xff08;D&#xff09;2.以下关于多线程的叙述中错误的是&#xff08;C&#xff09;3. 整数0x12345678&#xff0c;在采用bigendian中内存的排序序列是&#x…

AutoGPT是什么?超简单安装使用教程

1.AutoGPT 最近几天当红炸子鸡的是AutoGPT&#xff0c;不得不说AI发展真快啊&#xff0c;几天出来一个新东西&#xff0c;都跟不上时代的脚步了。 AutoGPT是一个开源的应用程序&#xff0c;展示了GPT-4语言模型的能力。这个程序由GPT-4驱动&#xff0c;自主地开发和管理业务。…

WIN10-22H2专业版_电脑维修人员专用装机系统镜像【03.27更新】

WIN10-22H2专业版是由站长亲自封装的电脑维修人员专用装机系统镜像&#xff0c;系统干净无广告&#xff0c;稳定长效不卡顿&#xff0c;适合电脑维修店用来维修电脑重装系统。此版本是WIN10系统里非常稳定的正式版本之一&#xff0c;适合在维修电脑时重装系统或者大批量装机使用…

OpenCV图像处理之傅里叶变换

文章目录 OpenCV图像处理之傅里叶变换图像处理之傅里叶变换流程图OpenCv图像处理之傅里叶变换OpenCv傅里叶变换之低通滤波OpenCv傅里叶变换之高通滤波 OpenCV图像处理之傅里叶变换 傅里叶变换&#xff1a;目的就是得到图像的低频和高频&#xff0c;然后针对低频和高频进行不同…

yolov5训练自己的目标检测模型

yolov5训练自己的目标检测模型 1.克隆项目并配置环境 1.1克隆项目 进入GitHub下载yolov5源码 点此进入 选择分支v5.0&#xff0c;并下载源码 anaconda激活相应环境 activate pytorch进入项目存放的地址 E: cd yolov5-master1.2 yolov5项目结构 ├── data&#xff1a;主…

信号生成和可视化——周期性/非周期性波形

信号生成和可视化 此示例说明如何使用 Signal Processing Toolbox™ 中提供的函数生成广泛使用的周期和非周期性波形、扫频正弦波和脉冲序列。尝试此示例Copy Command Copy Code 周期性波形 除了 MATLAB 中的 sin 和 cos 函数外&#xff0c;Signal Processing Toolbox™ 还…

客快物流大数据项目(一百一十五):熔断器 Spring Cloud Hystrix

文章目录 熔断器 Spring Cloud Hystrix 一、​​​​​​​Hystrix 简介 二、什么是雪崩效应

如何使用 Linux find 命令查找文件?

在Linux系统中&#xff0c;find命令是一个非常强大的工具&#xff0c;可以帮助用户查找文件或目录。这篇教程将向您展示如何使用Linux find命令来查找您需要的文件。 基本语法 在使用Linux find命令之前&#xff0c;您需要了解其基本语法。Linux find命令的基本语法如下&…

初识Java:数据类型与变量、运算符

哈喽大家好&#xff0c;这篇文章我将为大家分享关于Java的数据类型与变量和运算符。 文章目录 数据类型与变量数据类型整型类型byte类型short类型int类型long类型 浮点型字符类型布尔类型 变量浮点型变量布尔型变量类型转换隐式转化显式转化 运算符算术运算符增量运算符自增/自…

CSDN 周赛 47 期

CSDN 周赛 47 期 判断题单选题12 填空题编程题1、题目名称&#xff1a;最小差值&#xff08;30分&#xff09;2、题目名称&#xff1a;风险投资&#xff08;45分&#xff09; 小结 判断题 中国古代就发现并证明了勾股定理&#xff0c;并在《周髀算经》中出现了“勾三股四弦五”…

Linux 这4个进程相关的命令,太好用!

当您在Linux系统中管理进程时&#xff0c;了解一些进程监控命令是非常重要的。这些命令可以帮助您了解当前正在运行的进程以及它们的状态&#xff0c;从而更好地管理系统资源。下面是一些常用的Linux进程监控命令及其示例&#xff1a; 1、ps命令 ps命令可以列出当前正在运行的…

验证码登录开发----手机验证码登录

手机验证码登录 需求分析 为了方便用户登录&#xff0c;移动端通常都会提供通过手机验证码登录的功能 手机验证码登录的优点&#xff1a; 方便快捷、无需注册&#xff0c;直接登录使用短信验证码作为登录凭证&#xff0c;无需记忆密码安全 登录流程&#xff1a; 输入手机…

06-文章搜索页面

文章搜索页面 6-1&#xff1a;开篇 再上一章中&#xff0c;我们完成了 热搜首页 的开发&#xff0c;虽然经历了 ”千辛万苦“ &#xff0c;但是对大家来说&#xff0c;应该也是收获满满。 那么在这一章节&#xff0c;我们将会进入新的篇章&#xff0c;来到 文章搜索 页面的开…

【C++】深度剖析string类的底层结构及其模拟实现

文章目录 前言1. string的结构2. 构造、析构2.1 无参构造2.2 带参构造2.3 问题发现及修改c_stroperator []析构 2.4 合二为一 ——全缺省 3. 拷贝构造3.1 浅拷贝的默认拷贝构造3.2 深拷贝拷贝构造的实现 4. 赋值重载4.1 浅拷贝的默认赋值重载4.2 深拷贝赋值重载的实现 5. strin…