CatBoost模型Python代码——用CatBoost模型实现机器学习

news2024/11/14 2:55:29

一、CatBoost模型简介

1.1适用范围

CatBoost(Categorical Boosting)是一种基于梯度提升的机器学习算法,特别适用于处理具有类别特征的数据集。它可以用于分类、回归和排序任务,并且在处理具有大量类别特征的数据时表现优异。典型应用包括但不限于:

  • 电子商务中的推荐系统
  • 客户行为分析
  • 财务风险评估
  • 医疗数据分析
1.2原理

CatBoost使用梯度提升决策树(GBDT)作为其核心算法。其主要特点包括:

  1. 处理类别特征:CatBoost原生支持类别特征,并在内部使用目标编码(target encoding)来处理它们,从而减少了类别变量处理的复杂性。
  2. 顺序增强(Ordered Boosting):在构建每棵树时,CatBoost通过引入一种新的顺序提升方法来避免传统梯度提升中的预测偏差问题。
  3. 随机分片:为了进一步减少过拟合,CatBoost在每次树构建时随机分割数据集。
1.3优点
  • 高效处理类别特征:无需复杂的预处理步骤。
  • 减少过拟合:通过顺序增强和随机分片技术。
  • 易于使用:内置了许多默认的优化参数,适合初学者和快速原型开发。
  • 高性能:在许多实际应用中表现优于其他GBDT算法(如XGBoost和LightGBM)。
1.4缺点
  • 模型训练时间较长:尽管有许多优化,训练时间可能比其他简单模型更长。
  • 内存占用较高:在处理大规模数据时,内存需求较大。

二、实现CatBoost模型的Python代码

下面是一个使用CatBoost进行分类任务的完整Python代码示例,包含详细注释。

2.1导入必要的包和测试数据
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns

# 加载Titanic数据集
url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'
data = pd.read_csv(url)

# 查看数据集的列名
print("Columns in the dataset:", data.columns)
2.2简单的数据预处理
# 简单的数据预处理
# 填充缺失值
# data['Age'].fillna(data['Age'].median(), inplace=True)
# data['Embarked'].fillna(data['Embarked'].mode()[0], inplace=True)

# 将Sex和Embarked转换为类别型特征
data['Sex'] = data['Sex'].astype('category')
# data['Pclass'] = data['Pclass'].astype('Pclass')

# 选择特征和目标
features = ['Pclass', 'Sex', 'Age', 'Siblings/Spouses Aboard', 'Parents/Children Aboard', 'Fare']
target = 'Survived'

X = data[features]
y = data[target]
2.3构建CatBoost模型
# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建CatBoost数据池
categorical_features = ['Sex', 'Pclass']
train_pool = Pool(X_train, y_train, cat_features=categorical_features)
test_pool = Pool(X_test, y_test, cat_features=categorical_features)

# 初始化并训练CatBoost分类器
model = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.1,
    depth=6,
    loss_function='Logloss',  # 二分类任务使用'Logloss'
    verbose=100  # 每100次迭代打印一次信息
)

# 训练模型
model.fit(train_pool)

# 在测试集上进行预测
y_pred = model.predict(test_pool)
y_pred_proba = model.predict_proba(test_pool)[:, 1]
2.4模型评估
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred))

模型评估输出结果如下 :

0:	learn: 0.6538633	total: 159ms	remaining: 2m 39s
100:	learn: 0.2814504	total: 891ms	remaining: 7.93s
200:	learn: 0.2007734	total: 1.68s	remaining: 6.68s
300:	learn: 0.1536222	total: 2.45s	remaining: 5.69s
400:	learn: 0.1220845	total: 3.19s	remaining: 4.77s
500:	learn: 0.0961718	total: 3.95s	remaining: 3.93s
600:	learn: 0.0810769	total: 4.7s	remaining: 3.12s
700:	learn: 0.0694396	total: 5.45s	remaining: 2.33s
800:	learn: 0.0598153	total: 6.2s	remaining: 1.54s
900:	learn: 0.0527771	total: 6.93s	remaining: 761ms
999:	learn: 0.0474017	total: 7.67s	remaining: 0us
Accuracy: 0.8033707865168539
              precision    recall  f1-score   support

           0       0.84      0.85      0.84       111
           1       0.74      0.73      0.74        67

    accuracy                           0.80       178
   macro avg       0.79      0.79      0.79       178
weighted avg       0.80      0.80      0.80       178

Feature: Pclass, Importance: 16.480181005946406
Feature: Sex, Importance: 24.322199798316337
Feature: Age, Importance: 27.28642174968946
Feature: Siblings/Spouses Aboard, Importance: 5.125530737270014
Feature: Parents/Children Aboard, Importance: 3.006729091175773
Feature: Fare, Importance: 23.77893761760206
2.5可视化特征重要性(可选)
# 可视化特征重要性(可选)
plt.figure(figsize=(10, 6))
plt.barh(X.columns, feature_importances)
plt.xlabel('Feature Importance')
plt.title('CatBoost Feature Importances')
plt.show()

特征重要性输出结果如下:

 2.6绘制混淆矩阵
# 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

绘制混淆矩阵输出结果如下:

2.7绘制ROC曲线
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()

绘制ROC曲线输出结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FPGA:3-8译码器的设计

1、什么是3-8译码器? 3-8译码器,顾名思义,三个输入,八个输出,构成3-8译码器。根据二进制特性,三位二进制数有八种可能,对应的真值表如下所示(该译码器输出低电平有效): 3-8译码器(…

kail-linux如何使用NAT连接修改静态IP

1、Contos修改静态IP vi /etc/sysconfig/network-scripts/ifcfg-ens33, 标记红色处可能序号会变动 参考linux配置网络不通解决方案_kylinv10sp2 网关不通-CSDN博客https://tanrt06.blog.csdn.net/article/details/132430485?spm1001.2014.3001.5502 Kail时候NAT连…

Linux - 进程的概念、状态、僵尸进程、孤儿进程及进程优先级

进程基本概念 课本概念:在编程或软件工程的上下文中,进程通常被视为正在执行的程序的实例。当你启动一个应用程序时,操作系统会为这个程序创建一个进程。每个进程都有自己的独立内存空间,可以运行自己的指令序列,并可能…

自然学习法和科学学习法

一、自然学习法 自然学习法:什么事自然学习法,特意让kimi来回答了一下。所谓的自然学习法说的俗一点就是野路子学习方法。这种学习方法的特点是“慢”“没有系统性”,学完之后感觉都会了,但是又感觉什么都不会。 二、科学学习法 …

FastAPI(六十七)实战开发《在线课程学习系统》接口开发--用户登陆接口开发

源码见:"fastapi_study_road-learning_system_online_courses: fastapi框架实战之--在线课程学习系统" 接上一篇文章FastAPI(六十六)实战开发《在线课程学习系统》接口开发--用户注册接口开发。这次我们分享实际开发--用户登陆接口…

中望CAD 专业 v2024 解锁版下载与安装教程 (CAD三维制图)

前言 中望CAD软件(ZWCAD)是一款源自国内的自主研发CAD制图软件,提供二三维CAD功能,专注于机械设计制图领域。其最新版本,中望CAD采用了国际领先的CAD核心技术,不断优化软件性能和用户体验,并加…

.netcore TSC打印机打印

此文章给出两种打印案例, 第一种是单列打印,第二种是双列打印 需要注意打印机名称的设置,程序中使用的打印机名称为999,电脑中安装打印机时名称也要为999。 以下是我在使用过程中总结的一些问题: 一 TSC打印机使用使…

谷粒商城实战笔记-跨域问题

一,When allowCredentials is true, allowedOrigins cannot contain the special value “*” since that cannot be set on the “Access-Control-Allow-Origin” response header. To allow credentials to a set of origins, list them explicitly or consider u…

PostgreSQL 中如何处理数据的唯一性约束?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会!📚领书:PostgreSQL 入门到精通.pdf 文章目录 PostgreSQL 中如何处理数据的唯一性约束?一、什么是唯一性约束二、为什么要设置唯一性约束…

基于A律压缩的PCM脉冲编码调制通信系统simulink建模与仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1A律压缩的原理 4.2 PCM编码过程 4.3 量化噪声与信噪比 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 matlab2022a 3.部分核心程序 &#…

Atom - hackmyvm

简介 靶机名称:Atom 难度:简单 靶场地址:https://hackmyvm.eu/machines/machine.php?vmAtom 本地环境 虚拟机:vitual box 靶场IP(Atom):192.168.56.101 跳板机IP(windows 11)&#xff1…

MySQL面试篇章——MySQL索引

文章目录 MySQL 索引索引分类索引创建和删除索引的执行过程explain 查看执行计划explain 结果字段分析 索引的底层实现原理B-树B树哈希索引 聚集和非聚集索引MyISAM(\*.MYD,*.MYI)主键索引辅助索引(二级索引) InnoDB&a…

线程的中互斥锁和条件变量的运用

第一题&#xff1a;使用互斥锁或者信号量&#xff0c;实现一个简单的生产者消费者模型 一个线程每秒生产3个苹果&#xff0c;另一个线程每秒消费8个苹果 #include <myhead.h>pthread_mutex_t m1,m2;int apple 0; void* usrapp(void* data) {while(1){pthread_mutex_lock…

旋转差分,以及曼哈顿距离转换切比雪夫距离

拿到这个问题我们要怎么去想呢&#xff0c;如果是暴力的修改的话&#xff0c;我们的复杂度为 m * 2r*r 的复杂度&#xff0c;这也太暴力了&#xff0c;我们要怎么办呢&#xff0c;我们能不能用差分数组来实现呢&#xff1f; 我们首先要看如何实现公式的转换 很显然我们可以利用…

<数据集>pcb板缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;693张 标注数量(xml文件个数)&#xff1a;693 标注数量(txt文件个数)&#xff1a;693 标注类别数&#xff1a;6 标注类别名称&#xff1a;[missing_hole, mouse_bite, open_circuit, short, spurious_copper, spur…

物联网与区块链技术的跨界融合:智能城市的建设与管理

随着科技的迅猛发展&#xff0c;物联网&#xff08;IoT&#xff09;和区块链技术逐渐成为推动智能城市发展的重要技术支柱。本文将探讨物联网和区块链技术在智能城市建设与管理中的跨界融合&#xff0c;分析其应用场景和潜力。 什么是智能城市&#xff1f; 智能城市利用先进的…

(35)远程识别(又称无人机识别)(一)

文章目录 前言 1 更改 2 可用的设备 3 开放式无人机ID 4 ArduRemoteID 5 终端用户数据的设置和使用 6 测试 7 为OEMs添加远程ID到ArduPilot系统的视频教程 前言 在一些国家&#xff0c;远程 ID 正在成为一项法律要求。以下是与 ArduPilot 兼容的设备列表。这里(here)有…

深度刨析C语言中的动态内存管理

文章目录 1.为什么会存在动态内存分配2.动态内存函数介绍2.1 [malloc](https://legacy.cplusplus.com/reference/cstdlib/malloc/?kwmalloc)与[free](https://legacy.cplusplus.com/reference/cstdlib/free/?kwfree)2.2 [calloc](https://legacy.cplusplus.com/reference/cst…

Redis - SpringDataRedis - RedisTemplate

目录 概述 创建项目 引入依赖 配置文件 测试代码 测试结果 数据序列化器 自定义RedisTemplate的序列化方式 测试报错 添加依赖后测试 存入一个 String 类型的数据 测试存入一个对象 优化 -- 手动序列化 测试存入一个Hash 总结&#xff1a; 概述 SpringData 是 S…

浏览器【WebKit内核】渲染原理【QUESTION-1】

浏览器【WebKit内核】渲染原理【QUESTION】 1.浏览器输入一个网址&#xff08;域名之后&#xff09;,浏览器会呈现一个新的页面&#xff0c;中间的过程是怎么实现的&#xff1f; 输入一个网址之后&#xff0c;首先DNS服务器会解析这个域名&#xff0c;将这个域名解析成IP地址&…