【机器学习】CatBoost 模型实践:回归与分类的全流程解析

news2025/2/25 8:25:19

一. 引言

本篇博客首发于掘金 https://juejin.cn/post/7441027173430018067。
PS:转载自己的文章也算原创吧。

在机器学习领域,CatBoost 是一款强大的梯度提升框架,特别适合处理带有类别特征的数据。本篇博客以脱敏后的保险数据集为例,展示如何利用 CatBoost 完成分类和回归任务,并以可视化的方式解析特征重要性与结果。

我们将完成以下任务:

  1. 回归任务:预测保险索赔金额。
  2. 分类任务:判断保险案件是否需要调查。
  3. 可视化分析:利用散点图与分割线展示结果。

二. CatBoost 模型简介

CatBoost 是由俄罗斯搜索巨头 Yandex 于 2017 年开源的机器学习库,其名称来源于 “Category” 和 “Boosting” 的组合,旨在高效处理类别特征的梯度提升算法。与其他模型(如 XGBoost 和 LightGBM)相比,CatBoost 具有以下优势:

  • 支持类别特征:无需对类别特征进行独热编码,直接处理类别数据,避免数据膨胀。
  • 对缺失值的鲁棒性:无需特殊预处理即可直接处理缺失值。
  • 防止过拟合:内置多种正则化手段,减少梯度偏差和预测偏移,提高模型的准确性和泛化能力。
  • 对称树结构:采用对称决策树(Oblivious Trees),在每个层级使用相同的特征和分割点,提升训练和预测效率。

三. 实战项目环境与数据准备

本项目使用了脱敏后的保险数据集,包含以下特征:

  • 类别特征:险种代码、出险原因、医疗责任类别等。
  • 数值特征:基本保额、索赔金额等。
  • 标签:是否需要调查(分类任务)。

所有数据均已脱敏,支持迁移至其他表格数据集。

因为不好分享,所以后续第七节补充了一个基于sklearn "California Housing"数据集的流程代码与说明。


四. 回归任务:预测保险索赔金额

数据预处理

在回归任务中,我们根据特征预测索赔金额。以下是数据清洗与预处理的关键步骤:

  1. 过滤无效数据:移除缺失或非法值的记录。
  2. 特征转换:将类别特征转为字符串类型。
  3. 分割数据集:按 80% 和 20% 的比例划分训练集与测试集。

4.1 模型训练与评估

我们使用 CatBoost 进行回归建模,模型参数包括:

  • 学习率:0.02
  • 深度:8
  • 迭代次数:10,000(支持提前停止)

以下是模型的关键代码:

from catboost import CatBoostRegressor

# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(
    iterations=10000,
    learning_rate=0.02,
    depth=8,
    eval_metric='RMSE',
    early_stopping_rounds=1500,
    random_seed=42
)

# 训练模型
cat_regressor.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=100
)

4.2 特征重要性分析

特征重要性是衡量特征对模型预测贡献程度的指标,可以帮助我们更好地理解模型。

# 获取特征重要性
feature_importances = cat_regressor.get_feature_importance()
feature_names = X_train.columns

# 可视化特征重要性
import matplotlib.pyplot as plt
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importances
}).sort_values(by='Importance', ascending=True)

plt.figure(figsize=(10, 6))
plt.barh(importance_df['Feature'], importance_df['Importance'], color='salmon')
plt.xlabel('特征重要性')
plt.ylabel('特征名称')
plt.title('CatBoost 特征重要性分析')
plt.show()

结果展示
在这里插入图片描述


4.3 模型评估

我们可以均方误差 (MSE) 以及 平均绝对误差 (MAE) 来评估模型在测试集上的回归性能,同时展示模型的学习曲线:

# 获取训练和测试集的 RMSE
evals_result = cat_regressor.get_evals_result()
train_rmse = evals_result['learn']['RMSE']
test_rmse = evals_result['validation']['RMSE']

# 绘制 RMSE 曲线
plt.figure(figsize=(10, 6))
plt.plot(train_rmse, label='训练集 RMSE')
plt.plot(test_rmse, label='测试集 RMSE')
plt.title('训练与测试集的 RMSE 学习曲线')
plt.xlabel('迭代次数')
plt.ylabel('RMSE')
plt.legend()
plt.show()

五. 分类任务:判别是否调查

5.1 数据标注与模型选择

分类任务以 是否调查 作为标签(1 表示需要调查,0 表示无需调查),特征包括所有数值和类别字段。

为了完成分类任务,我们选用 CatBoostClassifier。模型参数类似于回归模型,分类评估指标包括准确率、混淆矩阵和分类报告。


5.2 训练结果与模型评估

训练结果显示,分类准确率达 94.0%。以下是模型的分类报告:

分类报告 (训练集):
               precision    recall  f1-score   support

           0       0.96      0.98      0.97     13087
           1       0.74      0.57      0.64      1354

    accuracy                           0.94     14441
   macro avg       0.85      0.77      0.80     14441
weighted avg       0.94      0.94      0.94     14441
5.3 代码示例
from catboost import CatBoostClassifier

# 初始化分类器
cat_classifier = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.02,
    depth=8,
    eval_metric='Accuracy',
    early_stopping_rounds=150,
    random_seed=42
)

# 模型训练
cat_classifier.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=100
)

六. 可视化分析

为更直观地理解模型,我们利用散点图和分割线对预测结果进行展示:

  • 散点图:展示实际金额与预测金额的分布。
  • 分割线:通过 KMeans 聚类划分四个金额档次。

以下代码生成散点图与分割线:

# 使用 KMeans 聚类生成分割线
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=4, random_state=42)
df['cluster'] = kmeans.fit_predict(df[['预测金额']])

# 绘制散点图
plt.figure(figsize=(12, 8))
plt.scatter(df['预测金额'], df['是否调查'], c=df['cluster'], cmap='tab10')
plt.title("预测金额与是否调查的散点图")
plt.xlabel("预测金额")
plt.ylabel("是否调查")
plt.colorbar(label='Cluster')
plt.show()

散点图展示

在这里插入图片描述


七. 补充学习

7.1 基础数据集

California Housing 数据集包含加利福尼亚州 20,640 个街区的人口、住房和收入信息。目标是预测每个街区的房价中位数 MedHouseVal

数据特征

  1. MedInc:街区的收入中位数。
  2. HouseAge:街区住房的平均年龄。
  3. AveRooms:每个街区的平均房间数。
  4. AveBedrms:每个街区的平均卧室数。
  5. Population:街区的总人口。
  6. AveOccup:每户的平均人数。
  7. Latitude:街区的纬度。
  8. Longitude:街区的经度。

7.2 实践步骤

7.2.1 导入数据与预处理

我们使用 Scikit-learn 加载数据并进行预处理。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd

# 加载 California Housing 数据集
data = fetch_california_housing(as_frame=True)
df = data.frame

# 特征和目标变量
X = df.drop(columns="MedHouseVal")
y = df["MedHouseVal"]

# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

训练集大小: (16512, 8), 测试集大小: (4128, 8)


7.2.2 训练 CatBoost 回归模型

使用 CatBoost 对房价进行预测。

from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error

# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(
    iterations=1000,
    learning_rate=0.1,
    depth=6,
    eval_metric="RMSE",
    random_seed=42,
    verbose=100
)

# 模型训练
cat_regressor.fit(X_train, y_train, eval_set=(X_test, y_test), verbose=100, early_stopping_rounds=50)

# 模型预测
y_pred_train = cat_regressor.predict(X_train)
y_pred_test = cat_regressor.predict(X_test)

# 模型评估
mse_train = mean_squared_error(y_train, y_pred_train)
mse_test = mean_squared_error(y_test, y_pred_test)
mae_test = mean_absolute_error(y_test, y_pred_test)

print(f"训练集均方误差 (MSE): {mse_train}")
print(f"测试集均方误差 (MSE): {mse_test}")
print(f"测试集平均绝对误差 (MAE): {mae_test}")

输出如下:

0:	learn: 1.0934740	test: 1.0841841	best: 1.0841841 (0)	total: 1.24s	remaining: 20m 38s
100:	learn: 0.4867395	test: 0.5154868	best: 0.5154868 (100)	total: 1.54s	remaining: 13.7s
200:	learn: 0.4320149	test: 0.4798269	best: 0.4798269 (200)	total: 1.8s	remaining: 7.18s
300:	learn: 0.4020581	test: 0.4657293	best: 0.4657293 (300)	total: 2.07s	remaining: 4.8s
400:	learn: 0.3803801	test: 0.4582868	best: 0.4582868 (400)	total: 2.35s	remaining: 3.5s
500:	learn: 0.3633580	test: 0.4534430	best: 0.4534430 (500)	total: 2.61s	remaining: 2.6s
600:	learn: 0.3488402	test: 0.4491723	best: 0.4491723 (600)	total: 2.89s	remaining: 1.92s
700:	learn: 0.3358611	test: 0.4461323	best: 0.4461323 (700)	total: 3.17s	remaining: 1.35s
800:	learn: 0.3234759	test: 0.4431320	best: 0.4431320 (800)	total: 3.44s	remaining: 854ms
900:	learn: 0.3126821	test: 0.4403978	best: 0.4403978 (900)	total: 3.71s	remaining: 407ms
999:	learn: 0.3025414	test: 0.4386906	best: 0.4386902 (998)	total: 3.97s	remaining: 0us

bestTest = 0.438690174
bestIteration = 998

Shrink model to first 999 iterations.
训练集均方误差 (MSE): 0.09158491090576551
测试集均方误差 (MSE): 0.19244906768098075
测试集平均绝对误差 (MAE): 0.28701415230111493

7.2.3 可视化预测结果

展示预测值与实际值的对比,以及模型的特征重要性。

实际值与预测值对比
import matplotlib.pyplot as plt

# 对比测试集的预测值和实际值
plt.figure(figsize=(10, 6))
plt.scatter(range(len(y_test)), y_test, color="blue", label="真实值", alpha=0.6)
plt.scatter(range(len(y_pred_test)), y_pred_test, color="red", label="预测值", alpha=0.6)
plt.title("真实房价与预测房价对比")
plt.xlabel("样本索引")
plt.ylabel("房价中位数")
plt.legend()
plt.show()

特征重要性分析
# 特征重要性可视化
feature_importances = cat_regressor.get_feature_importance()
feature_names = data.feature_names

plt.figure(figsize=(10, 6))
plt.barh(feature_names, feature_importances, color="skyblue")
plt.title("CatBoost 特征重要性")
plt.xlabel("重要性得分")
plt.ylabel("特征名称")
plt.show()

在这里插入图片描述


7.3 数据结果

  • 模型评估结果:
    • 训练集均方误差 (MSE): 0.09158491090576551
    • 测试集均方误差 (MSE): 0.19244906768098075
    • 测试集平均绝对误差 (MAE): 0.28701415230111493
  • 特征重要性解读:
    根据特征重要性分析,MedInc(收入中位数)对预测房价的影响最大,而经纬度特征(Latitude 和 Longitude)也提供了显著的信息。

八. 总结

通过本项目,我们完成了基于 CatBoost 的回归与分类建模,并展示了预测结果的可视化。CatBoost 的强大功能和易用性使其在处理类别特征和缺失值的数据中表现优异。

希望本篇博客能为大家带来启发,助力实际项目的落地实现。如果对您有所帮助,也欢迎点赞与分享😊。

源码已上传到:https://github.com/YYForReal/ML-DL-RL-Learning/blob/main/ML-Learning/Catboost/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2253334.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

游戏引擎学习第25天

Git: https://gitee.com/mrxiao_com/2d_game 今天的计划 总结和复述: 这段时间的工作已经接近尾声,虽然每次编程的时间只有一个小时,但每一天的进展都带来不少收获。尽管看起来似乎花费了很多时间,实际上这些日积月累的时间并未…

AI开发:生成式对抗网络入门 模型训练和图像生成 -Python 机器学习

阶段1:GAN是个啥? 生成式对抗网络(Generative Adversarial Networks, GAN),名字听着就有点“对抗”的意思,没错!它其实是两个神经网络互相斗智斗勇的游戏: 生成器(Gene…

040集——CAD中放烟花(CAD—C#二次开发入门)

效果如下: 单一颜色的烟花: 渐变色的火花: namespace AcTools {public class HH{public static TransientManager tm TransientManager.CurrentTransientManager;public static Random rand new Random();public static Vector3D G new V…

JavaScript实现tab栏切换

JavaScript实现tab栏切换 代码功能概述 这段代码实现了一个简单的选项卡(Tab)切换功能。它通过操作 HTML 元素的类名(class)来控制哪些选项卡(Tab)和对应的内容板块显示,哪些隐藏。基本思路是先…

【天地图】HTML页面实现车辆轨迹、起始点标记和轨迹打点的完整功能

目录 一、功能演示 二、完整代码 三、参考文档 一、功能演示 运行以后完整的效果如下: 点击开始,小车会沿着轨迹进行移动,点击轨迹点会显示经纬度和时间: 二、完整代码 废话不多说,直接给完整代码,替换…

HCIA笔记6--路由基础与静态路由:浮动路由、缺省路由、迭代查找

文章目录 0. 概念1.路由器工作原理2. 跨网访问流程3. 静态路由配置4. 静态路由的应用场景4.1 路由备份4.2 浮动路由4.3 缺省路由 5. 迭代路由6 问题6.1 为什么路由表中有的下一跳的地址有接口?6.2 个人电脑的网关本质是什么? 0. 概念 自治系统&#xff…

20241129解决在Ubuntu20.04下编译中科创达的CM6125的Android10出现找不到库文件libncurses.so.5的问题

20241129解决在Ubuntu20.04下编译中科创达的CM6125的Android10出现找不到库文件libncurses.so.5的问题 2024/11/29 21:11 缘起:中科创达的高通CM6125开发板的Android10的编译环境需要。 vendor/qcom/proprietary/commonsys/securemsm/seccamera/service/jni/jni_if.…

Matlab搜索路径添加不上

发现无论是右键文件夹添加到路径,还是在“设置路径”中专门添加,我的路径始终添加不上,导致代码运行始终报错,后来将路径中的“”加号去掉后,就添加成功了,经过测试,路径中含有中文也可以添加成…

自由学习记录(28)

C# 中的流(Stream) 流(Stream)是用于读取和写入数据的抽象基类。 流表示从数据源读取或向数据源写入数据的矢量过程。 C# 中的流类是从 System.IO.Stream 基类派生的,提供了多种具体实现,每种实现都针对…

Redis3——线程模型与数据结构

Redis3——线程模型与数据结构 本文讲述了redis的单线程模型和IO多线程工作原理,以及几个主要数据结构的实现。 1. Redis的单线程模型 redis6.0之前,一个redis进程只有一个io线程,通过reactor模式可以连接大量客户端;redis6.0为了…

使用playwright自动化测试时,npx playwright test --ui打开图形化界面时报错

使用playwright自动化测试时,npx playwright test --ui打开图形化界面时报错 1、错误描述:2、解决办法3、注意符号的转义 1、错误描述: 在运行playwright的自动化测试项目时,使用npm run test无头模式运行正常,但使用…

深度学习模型:门控循环单元(GRU)详解

本文深入探讨了门控循环单元(GRU),它是一种简化版的长短期记忆网络(LSTM),在处理序列数据方面表现出色。文章详细介绍了 GRU 的基本原理、与 LSTM 的对比、在不同领域的应用以及相关的代码实现,…

用html+jq实现元素的拖动效果——js基础积累

用htmljq实现元素的拖动效果 效果图如下&#xff1a; 将【item10】拖动到【item1】前面 直接上代码&#xff1a; html部分 <ul id"sortableList"><li id"item1" class"w1" draggable"true">Item 1</li><li …

单片机学习笔记 12. 定时/计数器_定时

更多单片机学习笔记&#xff1a;单片机学习笔记 1. 点亮一个LED灯单片机学习笔记 2. LED灯闪烁单片机学习笔记 3. LED灯流水灯单片机学习笔记 4. 蜂鸣器滴~滴~滴~单片机学习笔记 5. 数码管静态显示单片机学习笔记 6. 数码管动态显示单片机学习笔记 7. 独立键盘单片机学习笔记 8…

【乐企文件生成工程】搭建docker环境,使用docker部署工程

1、自行下载docker 2、自行下载docker-compose 3、编写Dockerfile文件 # 使用官方的 OpenJDK 8 镜像 FROM openjdk:8-jdk-alpine# 设置工作目录 WORKDIR ./app# 复制 JAR 文件到容器 COPY ../lq-invoice/target/lq-invoice.jar app.jar # 暴露应用程序监听的端口 EXPOSE 1001…

React基础知识三 router路由全指南

现在最新版本是Router6和Router5有比较大的变化&#xff0c;Router5和Router4变化不大&#xff0c;本文以Router6的写法为主&#xff0c;也会对比和Router5的不同。比较全面。 安装路由 npm i react-router-dom基本使用 有两种Router&#xff0c;BrowserRouter和HashRouter&…

【C#】书籍信息的添加、修改、查询、删除

文章目录 一、简介二、程序功能2.1 Book类属性&#xff1a;方法&#xff1a; 2.2 Program 类 三、方法&#xff1a;四、用户界面流程&#xff1a;五、程序代码六、运行效果 一、简介 简单的C#控制台应用程序&#xff0c;用于管理书籍信息。这个程序将允许用户添加、编辑、查看…

打造去中心化交易平台:公链交易所开发全解析

公链交易所&#xff08;Public Blockchain Exchange&#xff09;是指基于公有链&#xff08;如以太坊、波场、币安智能链等&#xff09;建立的去中心化交易平台。与传统的中心化交易所&#xff08;CEX&#xff09;不同&#xff0c;公链交易所基于区块链技术实现资产交换的去中心…

CLIP模型也能处理点云信息

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

关于NXP开源的MCU_boot的项目心得

MCU的启动流程细查 注意MCU上电第一个函数运行的就是Reset_Handler函数&#xff0c;下图是表示了这个函数做了啥事情&#xff0c;注意加强一下对RAM空间的段的印象&#xff0c;从上到下是栈&#xff0c;堆&#xff0c;.bss段&#xff0c;.data段。 bootloader的难点 固件完…