XGBoost模型的python实现

news2025/1/13 10:30:41

文章目录

  • 函数介绍
  • 实例
    • 二分类问题
    • 多分类问题

作者:李雪茸

函数介绍

实现 XGBoost 分类算法使用的是 xgboost 库的 XGBClassifier,具体参数如下:

  • 1、max_depth:给定树的深度,默认为3

  • 2、learning_rate:每一步迭代的步长,很重要。太大了运行准确率不高,太小了运行速度慢。我们一般使用比默认值小一点,0.1左右就好

  • 3、n_estimators:这是生成的最大树的数目,默认为100

  • 4、objective:给定损失函数,常用的有:
    – reg:linear– 线性回归
    – reg:logistic – 逻辑回归
    – binary:logistic – 二分类逻辑回归
    – binary:logitraw – 二分类逻辑回归
    – count:poisson – 计数问题的poisson回归

  • 5、booster:给定模型的求解方式,默认为:gbtree;可选参数:gbtree、gblinear,gbtree是采用树的结构来运行数据,而gblinear是基于线性模型

  • 6、gamma:指定了节点分裂所需的最小损失函数下降值。这个参数的值越大,算法越保守。范围: [0,∞]

  • 7、alpha:L1正则项的权重,推荐的候选值为:[0, 0.01~0.1, 1]

  • 8、lambda:L2正则项的权重,推荐的候选值为:[0, 0.1, 0.5, 1]

  • 9、num_class:用于设置多分类问题的类别个数

  • 10、min_child_weight:指定子节点中最小的样本权重和,如果一个叶子节点的样本权重和小于min_child_weight则拆分过程结束,默认值为1。

  • 11、subsample:默认值1,指定采样出 subsample * n_samples 个样本用于训练弱学习器。取值在(0, 1)之间,设置为1表示使用所有数据训练弱学习器。

  • 12、colsample_bytree:构建弱学习器时,对特征随机采样的比例,默认值为1

实例

from xgboost import XGBClassifier
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

二分类问题

# 举例(二分类)
cancer = load_breast_cancer()
x = cancer.data
y = cancer.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.333, random_state=0)  # 分训练集和验证集
model = XGBClassifier(max_depth=10,
                        learning_rate=0.01,
                        n_estimators=2000,
                        objective='binary:logistic',
                        nthread=-1,
                        gamma=0,
                        min_child_weight=1,
                        max_delta_step=0,
                        subsample=0.85,
                        colsample_bytree=0.7,
                        colsample_bylevel=1,
                        reg_alpha=0,
                        reg_lambda=1,
                        scale_pos_weight=1,
                        seed=1440)
model.fit(x_train, y_train,eval_metric='auc')# 'rmse’:用于回归任务 ;'mlogloss’,用于多分类任务;
                                             # 'error’,用于二分类任务; 'auc’,用于二分类任务
# 对测试集进行预测
y_pred = model.predict(x_test)
predictions = [round(value) for value in y_pred]
#计算准确率
accuracy = accuracy_score(y_test, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
print(f"\nXGBoost模型混淆矩阵为:\n{metrics.confusion_matrix(y_test,y_pred)}")

####绘制ROC曲线
fpr1,tpr1,threshold1 = roc_curve(y_test,y_pred)
roc_auc1 = auc(fpr1, tpr1)
lw = 2
plt.figure(figsize=(8, 5))
plt.plot(fpr1, tpr1, color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc) 
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('XGBoost ROC')
plt.legend(loc="lower right")
plt.show()
print(f"\nXGBoost模型AUC值为:\n{roc_auc_score(y_test,y_pred)}")

在这里插入图片描述
在这里插入图片描述

多分类问题

###举例 (多分类)
# 加载样本数据集
iris = load_iris()
X,y = iris.data,iris.target
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2,random_state=12343)
 
model = XGBClassifier(
    max_depth=3,
    learning_rate=0.1,
    n_estimators=100, # 使用多少个弱分类器
    num_class=3,
    booster='gbtree',
    gamma=0,
    min_child_weight=1,
    max_delta_step=0,
    subsample=1,
    colsample_bytree=1,
    reg_alpha=0,
    reg_lambda=1,
    seed=0 # 随机数种子
)
model.fit(X_train,y_train,eval_metric='mlogloss')# 'rmse’:用于回归任务 ;'mlogloss’,用于多分类任务;
                                                 # 'error’,用于二分类任务; 'auc’,用于二分类任务
 
# 对测试集进行预测
y_pred = model.predict(X_test)
predictions = [round(value) for value in y_pred]
#计算准确率
accuracy = accuracy_score(y_test, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
print(f"\nXGBoost模型混淆矩阵为:\n{metrics.confusion_matrix(y_test,y_pred)}")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/129336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot整合TKMyBatis实现增删改查

文章目录什么是TKMybatis?SpringBoot整合TKMybatis实体类注解TKMapper接口如何使用基本增删改操作批量查询和删除批量添加自定义查询条件ExampleExample 条件设置Example 使用什么是TKMybatis? TKMybatis 是基于Mybatis 框架开发的一个工具,…

[4]MQTT协议基础--下

1.QoS服务质量等级 MQTT服务质量(Quality of Service 缩写 QoS)正是用于告知物联网系统,哪些信息是重要信息需要准确无误的传输,而哪些信息不那么重要,即使丢失也没有问题。 MQTT协议有三种服务质量级别: QoS 0 – 最多发一次…

公司jmeter分享

一、数据库压测组件功能说明 1.JDBC Connection Configuration:jdbc连接配置(一个测试计划可以有多个 JDBC Connection) 2.Variable Name for created pool: 创建池的变量名 连接绑定的变量名,JMeter可以使用多个连接,每个连接绑定到不同的变量;通过引用不同的绑定变量…

安全防范语音通知实现方案

语音通知作为一种强提醒的信息通知方式,非常适合使用在安全防范语音通知场景中,可以有效避免用户错过重要信息。那安全防范语音通知怎么实现?这里互亿无线小编为大家做个详细介绍: 一、如何发送安全防范语音通知信息 互亿无线语…

本地事务、分布式事务、CAP 定理与 BASE 理论、分布式事务几种方案、Linux 安装 Seata、Seata的使用-56

一:本地事务 1.1 事务的基本性质 1.数据库事务的几个特性:原子性(Atomicity )、一致性( Consistency )、隔离性或独立性( Isolation)和持久性(Durabilily),简称就是 ACID; 原子性:一系列的操作整体不可拆分&#xf…

LVGL学习笔记8 - 字体

目录 1. 修改默认字体 2. 修改字体 3. 特殊字体 3.1 SUBPX字体 3.2 28像素压缩字体 3.3 16像素希伯来语/阿拉伯语/Perisan字母 3.4 16像素中文字体 3.5 8像素Ascii字体 3.6 16像素Ascii字体 3.7 内置图标 4. 超大字体 5. 编码方式 6. 添加字体 6.1 在线字体转换器 …

【微服务笔记01】微服务组件之Eureka注册中心的介绍及其基础环境的搭建

这篇文章,主要介绍微服务中的注册中心Eureka及其基础环境的搭建【源代码地址】。 目录 一、Eureka注册中心 1.1、什么是注册中心 1.2、注册中心原理 二、搭建Eureka注册中心环境 2.1、创建父工程,引入微服务依赖 2.2、创建Eureka服务端工程 &…

全球公开的DEM数据产品

1 简介 全球公开版地形数据包括:GTOPO30-DEM、ASTER-GDEM、SRTM90、ALOS-AW3D30等,其他的诸如World DEM及ALOS-AW3D (5m分辨率)等全球地形数据不能免费获得。 SRTM:由NASA 及国家地理空间情报局NGA采用2000年2月发射的“奋进号”…

word文件损坏打不开如何修复?文件丢失怎么办?

我们日常办公中,经常用到Word文档。但是有时会遇到word文件损坏、无法打开的情况。这时该怎么办?接着往下看,小编在这里就给大家带来Word文件修复的方法,以及Word文件丢失如何恢复的方法! 一、Word文件损坏怎么办 部分…

【vsan数据恢复】磁盘离线导致分布式存储瘫痪的数据恢复案例

vsan数据恢复环境: 一组4台服务器搭建vsan集群; 每台服务器配置有2组分别由6块硬盘组成的磁盘阵列,上层是虚拟机文件。 vsan故障: 在运行过程中,某一个节点的一块硬盘离线,vsan安全机制启动,开始…

梦想云图Node.JS服务 ( 最近更新时间:2022-12-30 10:04:50 )

说明 后台提供梦想Node.JS服务,方便调用控件后台功能,Windows服务程序所在目录:Bin\MxDrawServer\Windows,Linux服务程序所在目录:Bin\Linux\MxDrawServer 梦想云图Node.JS服务 ( 最近更新时间:2022-12-30 10:04:50 …

第三个脚本——时间加速and视频倍速

目录 本文主要内容 granr属性介绍 run-at属性 时间加速原理 视频倍速原理 完整示例 本文主要内容 介绍grant属性,run-at属性以及时间加速,视频倍速原理 granr属性介绍 相关函数四个: GM_setValue GM_getValue GM_listValues GM_del…

json基本使用与简介

一、简介 二、json两种构造结构 三、js解析JSON 1、JSON2解析JSON 2.用eval()方法把JSON字符串转化成JSON对象. 3.使用JSON2中的JSON对象的parser()方法解析JSON字符串 4. 使用JSON2中的JSON对象的stringify ()方法把JSON对象转换成字符串 5、案例 四、Java解…

【模型部署】人脸检测模型DBFace C++ ONNXRuntime推理部署(1)

系列文章目录 【模型部署】人脸检测模型DBFace C ONNXRuntime推理部署(0) 【模型部署】人脸检测模型DBFace C ONNXRuntime推理部署(1) 【模型部署】人脸检测模型DBFace C ONNXRuntime推理部署(2) 文章目录…

深度学习训练营之灵笼人物识别

深度学习训练营之灵笼人物识别原文链接环境介绍前置工作设置GPU导入数据数据查看数据预处理加载数据可视化数据检查数据配置数据集prefetch()功能详细介绍:归一化查看归一化后的数据构建VGG-19网络VGG优点VGG缺点利用官方给到的网络网络结构编译模型训练结果可视化预…

第03讲:GitHub的使用

一、创建远程仓库 访问GitHub官方网站,并创建账号,然后按照以下图示创建项目 复制仓库地址 二、远程仓库的操作 命令作用git remote -v查看当前所有远程地址别名git remote add 别名 远程地址起别名git push 别名 分支推送本地分支上的内容到远程仓库…

vscode+opencv+mingw+cmake配置vscode下的opencv环境

目录介绍安装VsCode安装mingw安装cmake安装opencv,以及其扩展库 opencv_contrib安装python利用cmake生成opencv的Makefile文件cmake命令进行编译,安装配置opencv环境变量配置VSCODE测试DEMO介绍 参考链接:https://www.cnblogs.com/czlhxm/p/…

教育行业回访话术

近些年来,随着知识经济的快速发展,教育市场呈现良好的增长态势。越来越多的人开始通过参加各种培训来提升自己,教育行业竞争十分激烈。 前言 近些年来,随着知识经济的快速发展,教育市场呈现良好的增长态势。而且由于受…

人力资源软件对中小企业的七点重要性

对于中小企业(SMB)来说,员工就意味着一切。你的员工几乎掌握着企业的整体增长和发展,他们可以成就企业,但也能破坏企业的发展。为了提高员工效率,中小型企业需要出色的人力资源管理。员工只有在受到重视和培…

被新手忽视的 自谐振频率点

在MHz的DCDC和RF LNA电路中,被新手忽视的 自谐振频率点Self-Resonant Frequency 计算公式为 下图显示了 1μF,封装为 1206 的陶瓷电容器的阻抗(MLCC有经典的V型阻抗-频率曲线。随着频率升高,寄生电感的影响开始凸显,阻…