《机器学习公式推导与代码实现》chapter14-CatBoost

news2025/1/6 17:14:55

《机器学习公式推导与代码实现》学习笔记,记录一下自己的学习过程,详细的内容请大家购买作者的书籍查阅。

CatBoost

CatBoost是俄罗斯搜索引擎巨头Yandex于2017年开源的一款GBDT计算框架,因能够高效处理数据中的类别特征而取名为CatBoost(Categorical Boosting)。

1 机器学习中类别特征的处理方法

CatBoost通过对常规的目标变量统计方法添加先验项来改进它们。除此之外CatBoost还考虑使用类别特征的不同组合来增加数据集特征维度。

对于特征取值数目较多的类别特征,一种折中的方法就是将类别数目重新归类,使其降到较少数目再进行one-hot编码。另一种常用的方法是目标变量统计target statistics, TS),TS计算每个类别对于目标变量的期望值并将类别特征转换为新的数值特征。CatBoost在常规TS方法上做了改进。

2 CatBoost理论基础

CatBoost算法框架的自身理论特色,包括用于处理类别变量的目标变量统计特征组合排序提升算法

2.1 目标变量统计

CatBoost算法的设计初衷是为了更好的处理GBDT特征中的categorical features。在处理 GBDT特征中的categorical features的时候,最简单的方法是用 categorical feature 对应的标签的平均值来替换。在决策树中,标签平均值将作为节点分裂的标准。这种方法被称为 Greedy Target-based Statistics , 简称 Greedy TS,用公式来表达就是:
x ^ k i = ∑ j = 1 n [ x j , k = x i , k ] Y i ∑ j = 1 n [ x j , k = x i , k ] \hat{x}_{k}^{i} =\frac{\sum_{j=1}^{n}\left [ x_{j,k} =x_{i,k} \right ]Y_{i}}{\sum_{j=1}^{n} \left [ x_{j,k} =x_{i,k} \right ]} x^ki=j=1n[xj,k=xi,k]j=1n[xj,k=xi,k]Yi
这种方法有一个显而易见的缺陷,就是通常特征比标签包含更多的信息,如果强行用标签的平均值来表示特征的话,当训练数据集和测试数据集数据结构和分布不一样的时候会出条件偏移问题。

一个标准的改进 Greedy TS的方式是添加先验分布项,这样可以减少噪声和低频率类别型数据对于数据分布的影响:
x ^ k i = ∑ j = 1 p − 1 [ x σ j , k = x σ p , k ] Y σ j + α p ∑ j = 1 p − 1 [ x σ j , k = x σ p , k ] + α \hat{x}_{k}^{i} =\frac{\sum_{j=1}^{p-1}\left [ x_{\sigma _{j,k} } =x_{\sigma _{p,k} } \right ]Y_{\sigma _{j}} + \alpha p}{\sum_{j=1}^{p-1} \left [ x_{\sigma _{j,k} } =x_{\sigma _{p,k} } \right ]+\alpha } x^ki=j=1p1[xσj,k=xσp,k]+αj=1p1[xσj,k=xσp,k]Yσj+αp
其中p是添加的先验项,α通常是大于0的权重系数。添加先验项是一个普遍做法,针对类别数较少的特征,它可以减少噪声数据。对于回归问题,一般情况下,先验项可取数据集label的均值。对于二分类,先验项是正例的先验概率。利用多个数据集排列也是有效的,但是,如果直接计算可能导致过拟合。

CatBoost利用了一个比较新颖的计算叶子节点值的方法,这种方式(oblivious trees,对称树)可以避免多个数据集排列中直接计算会出现过拟合的问题。

2.2 特征组合

值得注意的是几个类别型特征的任意组合都可视为新的特征。例如,在音乐推荐应用中,我们有两个类别型特征:用户ID和音乐流派。如果有些用户更喜欢摇滚乐,将用户ID和音乐流派转换为数字特征时,根据上述这些信息就会丢失。

结合这两个特征就可以解决这个问题,并且可以得到一个新的强大的特征。然而,组合的数量会随着数据集中类别型特征的数量成指数增长,因此不可能在算法中考虑所有组合。

为当前树构造新的分割点时,CatBoost会采用贪婪的策略考虑组合。对于树的第一次分割,不考虑任何组合。对于下一个分割,CatBoost将当前树的所有组合、类别型特征与数据集中的所有类别型特征相结合,并将新的组合类别型特征动态地转换为数值型特征。

2.3 排序提升算法

对于学习预测偏移的内容,我提出了两个问题:

  • 什么是预测偏移?
  • 用什么办法解决预测偏移问题?

预测偏移(Prediction shift)是由梯度偏差造成的。在GDBT的每一步迭代中, 损失函数使用相同的数据集求得当前模型的梯度, 然后训练得到基学习器, 但这会导致梯度估计偏差, 进而导致模型产生过拟合的问题。

CatBoost通过采用排序提升 (Ordered boosting) 的方式替换传统算法中梯度估计方法,进而减轻梯度估计的偏差,提高模型的泛化能力。

CatBoost采用对称树作为基分类器,对称意味着在树的同一层,分裂标准相同。对称树具有平衡、不易过拟合、能够大大缩短测试时间的特点。

3 CatBoost算法实现

作为与XGBoost和LightGBM齐名的Boosting算法,CatBoost有足够优秀的性能指标,尤其是对类别特征的处理。

import pandas as pd
data = pd.read_csv('./adult.data', header=None)
data

在这里插入图片描述

data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race',
                'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income'] # 变量重命名
data['income']
0         <=50K
1         <=50K
2         <=50K
3         <=50K
4         <=50K
          ...  
32556     <=50K
32557      >50K
32558     <=50K
32559     <=50K
32560      >50K
Name: income, Length: 32561, dtype: object
data['income'] = data['income'].astype('category').cat.codes
data['income'].unique()
array([0, 1], dtype=int8)
from sklearn.model_selection import train_test_split
import catboost as cb
from sklearn.metrics import accuracy_score
X_train, X_test, y_train, y_test = train_test_split(data.drop(['income'], axis=1), data['income'], random_state=10, test_size=0.3)
clf = cb.CatBoostClassifier(eval_metric='AUC', depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.1)
cat_features_index = [1, 3, 5, 6, 7, 8, 9, 13] # 设置分类特征的索引,以便 CatBoost 能够正确地识别这些特征
clf.fit(X_train, y_train, cat_features=cat_features_index)
y_pred = clf.predict(X_test)
print(accuracy_score(y_pred, y_test))
0:	total: 274ms	remaining: 2m 16s
1:	total: 337ms	remaining: 1m 23s
2:	total: 384ms	remaining: 1m 3s
3:	total: 434ms	remaining: 53.8s
4:	total: 485ms	remaining: 48s
5:	total: 558ms	remaining: 45.9s
6:	total: 596ms	remaining: 41.9s
7:	total: 642ms	remaining: 39.5s
8:	total: 676ms	remaining: 36.9s
9:	total: 712ms	remaining: 34.9s
10:	total: 748ms	remaining: 33.3s
11:	total: 782ms	remaining: 31.8s
12:	total: 816ms	remaining: 30.6s
13:	total: 854ms	remaining: 29.6s
14:	total: 896ms	remaining: 29s
15:	total: 941ms	remaining: 28.4s
16:	total: 981ms	remaining: 27.9s
17:	total: 1.02s	remaining: 27.3s
18:	total: 1.06s	remaining: 26.8s
19:	total: 1.1s	remaining: 26.4s
20:	total: 1.14s	remaining: 26s
21:	total: 1.18s	remaining: 25.6s
22:	total: 1.22s	remaining: 25.2s
23:	total: 1.25s	remaining: 24.8s
24:	total: 1.28s	remaining: 24.4s
...
497:	total: 18s	remaining: 72.4ms
498:	total: 18.1s	remaining: 36.2ms
499:	total: 18.1s	remaining: 0us
0.8721465861398301

笔记本_Github地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/674104.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一时兴起之matlab学习记录

是学习记录&#xff0c;会有错误的地方 安装的话看其他文章把 小操作 查看历史命令 在输入命令的地方&#xff0c;按下↑的方向键即可 变量 对大小写敏感若想以指定的类型存储就是 类型名(值),如 int16(4)&#xff0c;这个也可以强转变量名字有限制&#xff0c;键入namele…

10分钟搭建Stable Diffusion

前言 人工智能生成内容&#xff08;Artificial Intelligence Generated Content&#xff0c;简称 AIGC&#xff09;是当下最火的概念之一。AIGC 被认为是继专业生成内容&#xff08;Professional Generated Content, PGC&#xff09;和用户生成内容&#xff08;User Generated…

【MySQL】数据库基础 ③

上一章&#xff1a; 【MySQL】数据库基础 ② ✍临时表 说明&#xff1a; MySQL 临时表在我们需要保存一些临时数据时是非常有用的。临时表只在当前连接可见&#xff0c;…

局域网实验报告

计算机网络综合实训 实训报告一 所在院系 计算机与信息工程学院 学科专业名称 计算机科学与技术 导师及职称 柯宗武 教授 提交时间 2023.3.10 网络层实验报告 &#xff08;湖北师范大学计算机与信息工程学院 中国 黄石 435002&#xff09; 1 集线器与交换机的对比实验 1.1 背…

【Python】自动化构建项目结构样式

引言 在使用Python或者其它编程语言的项目时候&#xff0c;编写README.md 往往是不可或缺的&#xff1b; 而在README.md 中&#xff0c;关于项目结构的样式展示&#xff0c;这个是可选的。不展示也无伤大雅&#xff0c;但有展示的话&#xff0c;有以下优点&#xff1a; 提供…

第九章 总结及作业(4)【编译原理】

第九章 总结及作业&#xff08;4&#xff09;【编译原理】 前言推荐第九章 运行时存储空间组织9.1 目标程序运行时的活动9.1.1过程的活动9.1.2参数传递 9.2 运行时存储器的划分9.2.1运行时存储器的划分9.2.2 活动记录9.2.3 存储分配策咯 9.3 静态存储分配9.3.1数据区*9.3.2公用…

基于深度学习的人脸检测技术

用到环境 1、pycharm community edition 2022.3.2 2、Python 3.10 整篇内容都已上传至我的csdn资源中&#xff0c;想用的请移步。 流程 多任务级联卷积神经网络(Multi-task Cascaded Convolutional Networks, MTCNN)算法进行人脸检测 普通人脸检测 单人人脸检测 图1 单人人…

我最喜欢的编程语言是python,以及我的见解!!

这里写目录标题 我最喜欢的编程语言&#xff1a;1、我个人认为编程语言优劣的评选标准2、我对不同编程语言的优点与缺点的拙见**1. Java****2. Python****3. JavaScript****4. C语言&#xff1a;****5. C语言&#xff1a;** 3、对python编程语言未来发展的猜测和未来趋势 我最喜…

Vicuna-13B使用云服务器部署

Vicuna概述 Vicuna由一群主要来自加州大学伯克利分校的研究人员推出&#xff0c;仍然是熟悉的配方、熟悉的味道。Vicuna同样是基于Meta开源的LLaMA大模型微调而来&#xff0c;它的训练数据是来自ShareGPT上的7万多条数据&#xff08;ShareGPT一个分享ChatGPT对话的谷歌插件&am…

kerberos配置dolphinscheduler

kerberos配置dolphinscheduler 一、添加dolphin 用戶1.所有節點上執行如下命令&#xff1a; 二、DolphinScheduler集群模式部署1.集群规划2.前置准备工作3.解压DolphinScheduler安装包4. 创建元数据库及用户5. 配置一键部署脚本6 初始化数据库7.修改common配置文件8. 一键部署D…

华为、思科和瞻博网络三个厂商如何配置基本ACL和高级ACL?

今天给大家带来基本ACL和高级ACL的配置&#xff0c;主要会介绍三个厂商的配置&#xff1a; 其他厂商也可以参考&#xff0c;比如华三的可以参考华为的&#xff0c;锐捷的参考思科的。 1. 基本ACL配置 基本ACL&#xff08;Access Control List&#xff09;是一种简单的网络安全…

【Java高级语法】(十二)可变参数:Java中的“可变之美“,做好这些细节,你的程序强大又灵活~

Java高级语法详解之可变参数 &#x1f539; 前言1️⃣ 概念2️⃣ 优势和缺点3️⃣ 特征和应用场景3.1 特征3.2 应用场景 4️⃣ 使用和原理5️⃣ 使用技巧5.1 可变参数结合泛型5.2 使用元组或列表进行参数传递5.3 使用默认值5.4 缓存计算结果 6️⃣ 实战&#xff1a;构建动态日志…

【Vue3+Ts project】认识 Websocket 以及 socket.io 库

目录 Websocket socket.io Socket.iO 事件名总结&#xff1a; Socket.IO 方法总结 Websocket 作用&#xff1a; WebSocket 仍然提供实时的双向通信功能&#xff0c;使用Vue3 应用程序能够与服务器进行实时数据交换降低延迟和网络开销&#xff1a;相比传统的HTTP请求-响…

scratch lenet(7): C语言计算可学习参数数量和连接数量

scratch lenet(7): C语言计算可学习参数数量和连接数量 1. 目的 按照 LeNet-5 对应的原版论文 LeCun-98.pdf 的网络结构&#xff0c;算出符合原文数据的“网络每层可学习参数数量、连接数量”。 网络上很多人的 LeNet-5 实现仅仅是 “copy” 现有的别人的项目&#xff0c; 缺…

求2的N次幂(C++)解决高精度运算

​&#x1f47b;内容专栏&#xff1a;《C/C专栏》 &#x1f428;本文概括&#xff1a; 计算高精度的2的N次方数字。 &#x1f43c;本文作者&#xff1a;花 碟 &#x1f438;发布时间&#xff1a;2023.6.22 文章目录 ​前言求2的N次方&#xff0c;N ≤ 10000实现思路&#xff1a…

SpringBoot 如何使用 @PathVariable 进行数据校验

SpringBoot 如何使用 PathVariable 进行数据校验 在 SpringBoot 项目中&#xff0c;我们经常需要从 URL 中获取参数并进行相关的数据校验。而 PathVariable 注解就是一种非常方便的方式&#xff0c;可以让我们在方法参数中直接获取 URL 中的参数&#xff0c;并进行数据校验。本…

基于python开发实现数学中各种经典曲线的可视化

今天正好有点时间就想着把之前零星时间里面做的一点小东西整合一下梳理出来&#xff0c;本文的核心目的就是想要基于python来开发实现各种有趣的数学曲线的可视化展示。 笛卡尔心形线 笛卡尔心形线是一种二维平面曲线&#xff0c;由法国数学家笛卡尔在17世纪提出。它得名于其…

基于springboot+Redis的前后端分离项目(三)-【黑马点评】

&#x1f381;&#x1f381;资源文件分享 链接&#xff1a;https://pan.baidu.com/s/1189u6u4icQYHg_9_7ovWmA?pwdeh11 提取码&#xff1a;eh11 优惠券秒杀 优惠券秒杀1 -全局唯一ID2 -Redis实现全局唯一Id3 添加优惠卷4 实现秒杀下单5 库存超卖问题分析6 优惠券秒杀-一人一单…

Spring Boot 异常处理的主要特点

Spring Boot 异常处理的主要特点 在 Web 应用程序中&#xff0c;异常处理是非常重要的一部分。在 Spring Boot 中&#xff0c;异常处理是非常简单和灵活的。本文将介绍 Spring Boot 异常处理的主要特点&#xff0c;并提供一些示例代码来帮助您更好地理解。 异常处理的主要特点…

王道计算机网络学习笔记(1)——计算机网络基本知识

前言 文章中的内容来自B站王道考研计算机网络课程&#xff0c;想要完整学习的可以到B站官方看完整版。 一&#xff1a;计算机网络基本知识 1.1.1&#xff1a;认识计算机网络 计算机网络的功能 网络把许多计算机连接在一起&#xff0c;而互联网则将许多网络连接在一起&#x…