5 CatBoost模型

news2025/1/23 0:55:26

目录

1 背景

2 原理

2.1 类别特征处理

2.1.1 传统目标编码: TS

2.1.2 Greedy TS

2.1.3 ordered TS编码

2.1.4 CatBoost处理Categorical features总结

2.2.预测偏移处理

2.2.1 梯度无偏估计

2.3 树的构建​​​​​​​

3 优缺点

优点

4 代码


1 背景

        终于到了CatBoost,这个模型我在打比赛的时候相对于lightGBM用的少一些,但是我一般都会进行尝试,尤其是类别型特征很多的时候。曾依赖这个模型单模干到了top8,那时候还不懂的继承的优雅。我们开始这个模型介绍吧。

        CatBoost是一种基于对称决策树(oblivious trees)为基学习器实现的参数较少、支持类别型变量和高准确性的GBDT框架,主要解决的痛点是高效合理地处理类别型特征,这一点从它的名字中可以看出来,CatBoost是由Categorical和Boosting组成。此外,CatBoost还解决了梯度偏差(Gradient Bias)以及预测偏移(Prediction shift)的问题,从而减少过拟合的发生,进而提高算法的准确性和泛化能力。

        上面做了一个概述,那么黑体的名字如何理解呢?

        另外,与其提升算法不同,CatBoost使用对称全二叉树(这种树的特点是每一层使用相同的分割特征)。这样一来,树是更简单的结构,我们也就避免了过度拟合的危险。此外,由于我们的基础模型结构简单,我们有更快的预测器。

                           

2 原理

2.1 类别特征处理

        CatBoost算法的设计初衷是为了更好的处理GBDT特征中的categorical features(比如性别【男,女】)。在处理 GBDT特征中的categorical features的时候,最简单的方法是用 categorical feature 对应的标签的平均值来替换(target encoding, 这个在比赛中我也是常用,但是存在问题)。在决策树中,标签平均值将作为节点分裂的标准。这种方法被称为 Greedy Target-based Statistics , 简称 Greedy TS;

2.1.1 传统目标编码: TS

用公式来表达就是:

x_{i,k} = \frac{ \sum_{j=1}^{n} [x_{j,k}=x_{i,k}] Y_j}{\sum_{j=1}^{n} [x_{j,k}=x_{i,k}]},  =>  groupby('cat')[label].mean()

        TS编码有一个缺点,极端情况下当训练集中某类取值只有一个样本、或者没有样本时,计算的编码值就失真了,也就是容易受噪声数据影响。

        如果强行用标签的平均值来表示特征的话,当训练数据集和测试数据集数据结构和分布不一样的时候会出条件偏移问题。

2.1.2 Greedy TS

        一个标准的改进 TS的方式是添加先验分布项,这样可以减少噪声和低频率类别型数据对于数据分布的影响

直接上公式:

x_{i,k} = \frac{ \sum_{j=1}^{n} [x_{j,k}=x_{i,k}] Y_j + ap}{\sum_{j=1}^{n} [x_{j,k}=x_{i,k}]+a}

        其中p是添加的先验项,a通常是大于0的权重系数。添加先验项是一个普遍做法,针对类别数较少的特征,它可以减少噪声数据。对于回归问题,一般情况下,先验项可取数据集label的均值。对于二分类,先验项是正例的先验概率。

        Greedy TS编码也存在一个问题,即目标泄露。也需要训练预测集合数据分布一致;

2.1.3 ordered TS编码

        它是catboost的主要思想,依赖于排序,受online learning algorithms的启发得到,对于某一个样本,TS的值依赖于观测历史,为了在离线的数据上应用该思想,我们将数据随机排序,对于每一个样本,利用该样本之前数据计算该样本类别值的TS值。如果仅仅使用一个随机序列,那么计算得到值会有较大的方差,因此我们使用不同的随机序列来计算。

        在某种排序状态 𝜎 下,样本 x_i 在分类特征 𝑘 下的值为 x_{i}^{k}x_{i}^{k}的ordered TS编码值是基于排在其前面的样本D_{\sigma }计算的,在D_{\sigma }中计算分类特征 𝑘 下取值与x_{i}^{k}相同的样本的Greedy TS编码值,该值即为的ordered TS编码值。举例说明

在上图中,经过样本排序后,样本的排序情况为{4,3,7,2,6,1,5},计算样本4的ordered TS编码值时,由于没有样本排在其前面,因此其ordered TS编码值计算方式为

T(y=1|x=D) = \frac{0+ap}{0+a} = p

计算样本6的ordered TS编码值时,排在其前面的样本为{4,3,7,2},在这4个样本中,特征取值为D的只有样本4,因此其ordered TS编码值计算方式为:

T(y=1|x=D) = \frac{0+ap}{0+a} = p

2.1.4 CatBoost处理Categorical features总结

  • 首先会计算一些数据的statistics。计算某个category出现的频率,加上超参数,生成新的numerical features。这一策略要求同一标签数据不能排列在一起(即先全是之后全是这种方式),训练之前需要打乱数据集。
  • 第二,使用数据的不同排列。在每一轮建立树之前,先扔一轮骰子,决定使用哪个排列来生成树。
  • 第三,考虑使用categorical features的不同组合。例如颜色和种类组合起来,可以构成类似于blue dog这样的特征。当需要组合的categorical features变多时,CatBoost只考虑一部分combinations。在选择第一个节点时,只考虑选择一个特征,例如A。在生成第二个节点时,考虑A和任意一个categorical feature的组合,选择其中最好的。就这样使用贪心算法生成combinations。
  • 第四,除非向gender这种维数很小的情况,不建议自己生成One-hot编码向量,最好交给算法来处理。

2.2.预测偏移处理

        在GBDT算法中,每一棵树都是为了拟合前一棵树上的梯度,构造树时所有的样本都参与了,一个样本参与了建树,然后又用这棵树去估计样本值,这样的估计就不是无偏估计,当测试集和训练集上的样本分布不一致时,模型就会因过拟合而性能不佳,即在测试集上产生了预测偏移。

2.2.1 梯度无偏估计

        对于样本 𝑥𝑖 ,如果用一个不包含它的模型去估计它的梯度,估计结果可以视为无偏估计。基于这种思路,CatBoost算法中采用了如下策略:在每一轮迭代时,将样本集排序,然后训练 𝑛 个模型 𝑀𝑖,𝑖=1,...,𝑛 ,𝑛为样本数量,其中 𝑀𝑖 是由前 𝑖 个样本训练得到(基于本轮样本的排序,包含样本𝑖),然后估计样本𝑖的梯度时,使用模型𝑀𝑖−1 来估计,因为𝑀𝑖−1是由不包含样本𝑖的样本训练得到,因此该估计结果是无偏估计。

                        

        这种方式尽管得到了无偏估计,但是对于排序靠前的样本,它的梯度估计结果可能并不太准确,具有较大的方差,因为估计它的模型是由较少的样本训练的,而且是基于本轮迭代中的样本排序计算的,为了减少预测方差,CatBoost在每轮迭代中都会对样本进行重新排序,然后按照相同的思路估计本轮中样本的梯度,这样多轮迭代的最终结果就可以获取一个较小的方差。

        当然,如果每轮迭代都要训练𝑛 个模型,那是一个比较大的工作量,CatBoost算法对这个过程做了些简化来提升训练速度,详情后面分析树的创建过程时再说明。

2.3 树的构建

        得到树结构,也是节点分裂的过程。在GBDT、XGBoost、LightGBM等算法中,节点分类时需要遍历所有候选特征及分裂阈值,CatBoost算法也采用了这种策略,但有两点不同:

  • 分类型特征处理;
  • 数值型特征的空值处理。在CatBoost算法中,将空值全部转换为最小值(默认),或者最大值;

        每一轮迭代都会创建 𝑛 个模型,也就是创建𝑛棵树(在CatBoost算法中,并不是𝑛棵树,而是 [𝑙𝑜𝑔2𝑛] ,这个只是为了减少计算量,因此在阐述CatBoost算法的原理时仍然使用𝑛),训练树是比较耗时的过程,在整个模型训练的时间中占很大的比重,如果每一轮都要重新训练𝑛棵树,那将会非常耗时,CatBoost算法在这点上做了调整,具体做是:

1 在第一轮迭代中,在选定样本排序状态 𝜎 下,分别用前 𝑗 个样本训练模型 𝑀𝜎,𝑗 ,也就是得到 𝑛 棵树。一个树的训练过程包含两部分:第一步,得到树结构;第二步,计算叶子节点的值。第一步中得到的树结构,也即是每一层选用什么分裂特征,分裂阈值是多少

2 得到模型𝑀𝜎,𝑗的树结构后,CatBoost会使将该树结构复用到后续的所有迭代过程中。例如在第一轮迭代中,由前100个样本训练得到了模型 𝑀𝜎,100 的树结构,后续的迭代过程中会直接使用该树结构,然后将对应排序状态 𝜎′ 下的前100个样本直接应用该树结构,将样本划分到对应的叶子节点上,得到完整的模型,而不用再重新遍历样本的特征来寻找最佳划分特征和划分阈值。

        ​​​​​​​        ​​​​​​​        

那么在每一层中,如何评判用哪种特征分裂最好呢?CatBoost算法采用这样的策略:

  1. 基本本轮样本排序状态,得到每一棵树上的样本,首先依据前面迭代轮次的结果计算每个样本的梯度,也就是得到根节点中每个样本的梯度向量 𝐺 。
  2. 遍历候选特征来分裂根节点,得到多种分裂结果,分别计算每个样本的叶子值增量 Δ(𝑖) ,其中 𝑖 指第 𝑖 个样本(这个增量值的含义原论文中也没有明确说明,只是说是一个增量值,这里也参照原文进行说明),Δ(𝑖)的计算方式为:在样本 𝑖 所属的叶子节点上,计算排在样本 𝑖 前面的样本的梯度的平均值,该平均值即为Δ(𝑖)。(这里要说明一下,上面说的叶子节点,是相对于上一层的节点来说的,并不是整棵树的叶子节点),这样就得到每个样本的增量值Δ(𝑖),这些增量值可以组成一个向量 Δ;

  3. 采用余弦相似度方法,计算各种候选特特征下节点分裂的损失 𝑙𝑜𝑠𝑠(𝐺,Δ) ,选择损失最小的方式来分裂树。余弦相似度的计算方式为:

  4.                               

  5. 其中 𝜔𝑖 标识样本𝑖的权重,这个权重是CatBoost算法中为每个样本随机赋予的,起到样本抽样的效果,目的是为了减少过拟合; 𝑎𝑖 标识样本在树上的输出值,这里也就是Δ(𝑖)值; 𝑔𝑖 即第一步中计算的样本的梯度值。

3 优缺点

优点

  • 性能卓越: 在性能方面可以匹敌任何先进的机器学习算法;
  • 鲁棒性/强健性: 无需调参即可获得较高的模型质量,采用默认参数就可以获得非常好的结果,减少在调参上面花的时间,减少了对很多超参数调优的需求
  • 易于使用: 提供与scikit集成的Python接口,以及R和命令行界面;
  • 实用: 可以处理类别型、数值型特征,支持类别型变量,无需对非数值型特征进行预处理
  • 可扩展: 支持自定义损失函数;
  • 快速、可扩展的GPU版本,可以用基于GPU的梯度提升算法实现来训练你的模型,支持多卡并行提高准确性,
  • 快速预测:即便应对延时非常苛刻的任务也能够快速高效部署模型;

缺点:

  • 对于类别型特征的处理需要大量的内存和时间;
  • 不同随机数的设定对于模型预测结果有一定的影响;

4 代码

import re
import os
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve,roc_auc_score
import matplotlib.pyplot as plt
import gc
from bayes_opt import BayesianOptimization
from catboost import Pool, cv

n_fold = 5
folds = KFold(n_splits=n_fold, shuffle=True, random_state=1314)
oof = np.zeros(len(train_df))
prediction = np.zeros(len(test_df))
for fold_n, (train_index, valid_index) in enumerate(folds.split(train_df)):
    X_train, X_valid = train_df[features].iloc[train_index], train_df[features].iloc[valid_index]
    y_train, y_valid = train_df[label].iloc[train_index], train_df[label].iloc[valid_index]
    cate_features=[]
#     +['corss_卧室_床的数量', 'corss_床的类型_床的数量',
#              'corss_房产类型_卧室数量', 'corss_房产类型_洗手间数量']
    train_pool = Pool(X_train, y_train, cat_features=cate_features)
    eval_pool = Pool(X_valid, y_valid, cat_features=cate_features)
    cbt_model = catboost.CatBoostRegressor(iterations=600, # 注:baseline 提到的分数是用 iterations=60000 得到的,但运行时间有点久
                           learning_rate=0.1, # 注:事实上好几个 property 在 lr=0.1 时收敛巨慢。后面可以考虑调大
                           eval_metric='mse',
                         # n_estimators=3000,
                           # reg_lambda=5,
                           use_best_model=True,
                           random_seed=42,
                           logging_level='Verbose',
                           #task_type='GPU',
                           devices='0',
                           gpu_ram_part=0.5)
    
    cbt_model.fit(train_pool,
              eval_set=eval_pool,
              verbose=1000)
    
    y_pred_valid = cbt_model.predict(X_valid)
    print("valid RMSE")
    print(print(np.sqrt(np.mean(np.square(y_pred_valid - train_df.loc[valid_index,label])))))
    y_pred = cbt_model.predict(test_df[features])
    oof[valid_index] = y_pred_valid.reshape(-1, )
    prediction += y_pred
prediction /= n_fold

from sklearn.metrics import mean_squared_error
score = mean_squared_error(oof, train_df[label].values)
print(score)


#test['价格'] = prediction
#test[['数据ID', '价格']].to_csv('./{}_sub_cat.csv'.format(np.sqrt(score)), index=None)

ref:

 安全验证 - 知乎

安全验证 - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关系图卷积神经网络

异质图和知识图谱 同质图与异质图 同质图指的是图中的节点类型和关系类型都仅有一种 异质图是指图中的节点类型或关系类型多于一种 知识图谱 知识图谱包含实体和实体之间的关系&#xff0c;并以三元组的形式存储&#xff08;<头实体, 关系, 尾实体>&#xff0c;即异…

Three.js——聚光灯、环境光、点光源、平行光、半球光

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

阿里云服务器怎么更换暴露的IP

很多客户阿里云服务器被攻击IP暴露&#xff0c;又不想迁移数据换服务器&#xff0c;其实阿里云服务器可以更换IP&#xff0c;今天就来和大家说说流程&#xff0c;云服务器创建成功后6小时内可以免费更换公网IP地址三次&#xff0c;超过6小时候就只能通过换绑弹性公网IP的方式来…

探索人工智能绘图的奇妙世界

探索人工智能绘图的奇妙世界 人工智能绘图的基本原理机器之美&#xff1a;AI绘图作品AI绘图对艺术创作的影响未来展望与挑战图书推荐&#x1f449;AI绘画教程&#xff1a;Midjourney使用方法与技巧从入门到精通内容简介获取方式&#x1f449;搜索之道&#xff1a;信息素养与终身…

访问云平台中linux系统图形化界面,登录就出现黑屏的问题解决(ubuntu图形界面)

目录 一、问题-图形化界面访问黑屏 二、系统环境 &#xff08;一&#xff09;网络结构示意图 &#xff08;二&#xff09;内部机器版本 三、分析 四、解决过程 &#xff08;一&#xff09;通过MobaXterm远程访问图形化界面(未成功) 1、连接方法 2、连接结果 &#xf…

【新版】系统架构设计师 - 知识点 - 结构化开发方法

个人总结&#xff0c;仅供参考&#xff0c;欢迎加好友一起讨论 文章目录 架构 - 知识点 - 结构化开发方法结构化开发方法结构化分析结构化设计 数据流图和数据字典模块内聚类型与耦合类型 架构 - 知识点 - 结构化开发方法 结构化开发方法 分析阶段 工具&#xff1a;数据流图、…

VUE项目使用.env配置多种环境以及如何加载环境

第一步&#xff0c;创建多个环境配置文件 Vue CLI 项目默认使用 .env 文件来定义环境变量。你可以通过创建不同的 .env 文件来为不同环境设置不同的环境变量&#xff0c;例如&#xff1a; .env —— 所有模式共用.env.local —— 所有模式共用&#xff0c;但不会被 git 提交&…

算法模板-线段树+懒标记

视频连接&#xff1a;C02【模板】线段树懒标记 Luogu P3372 线段树 1_哔哩哔哩_bilibili 题目链接&#xff1a;P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) P3374 【模板】树状数组 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 算法思路 递…

四大战略合作重磅签署,九章云极DataCanvas公司为全球智算生态注能

4月18日&#xff0c;备受瞩目的“2024九章云极DataCanvas智算操作系统新品发布会”上&#xff0c;九章云极DataCanvas公司携手新华出版社、曙光信息产业股份有限公司&#xff08;简称“中科曙光”&#xff09;、黄山旅游发展股份有限公司&#xff08;简称“黄山旅游”&#xff…

51单片机串口输出问题(第一个字符重复,自动循环输出第一个字符)

遇到的问题描述 51单片机使用串口发送数据时出现只循环发送字符串的第一个字符的情况。就算发送的是第一个字符也有时候一直发送。 串口函数代码 参考串口发送注意 #include <reg52.h> //此文件中定义了单片机的一些特殊功能寄存器void UsartInit() {SCON0X50; /…

在Spring Boot实战中碰到的拦截器与过滤器是什么?

在Spring Boot实战中&#xff0c;拦截器&#xff08;Interceptors&#xff09;和过滤器&#xff08;Filters&#xff09;是两个常用的概念&#xff0c;它们用于在应用程序中实现一些通用的逻辑&#xff0c;如日志记录、权限验证、请求参数处理等。虽然它们都可以用于对请求进行…

Navicat连接postgresql时出现‘datlastsysoid does not exist‘报错的问题

连接报错 解决方案 解决方法1&#xff1a;升级navicat 解决方法2&#xff1a;降级pgsql 解决方法3&#xff1a;修改dll 使用3解决 实操演示 1、 打开 Navicat 安装目录&#xff0c;找到libcc.dll文件 2、备份libcc.dll文件&#xff0c;将其复制并粘贴或者修改副本为任何其他名…

L2-045 堆宝塔

L2-045 堆宝塔 分数 25 全屏浏览 切换布局 作者 陈越 单位 浙江大学 堆宝塔游戏是让小朋友根据抓到的彩虹圈的直径大小&#xff0c;按照从大到小的顺序堆起宝塔。但彩虹圈不一定是按照直径的大小顺序抓到的。聪明宝宝采取的策略如下&#xff1a; 首先准备两根柱子&#xff…

Linux:进程调度

Linux&#xff1a;进程调度 进程优先级查看优先级调整优先级 Linux 2.6 内核进程调度队列 进程优先级 查看优先级 在Linux中&#xff0c;进程是有优先级的&#xff0c;我们可以通过指令ps -la来查看&#xff1a; 其中PRI表示priority优先级&#xff0c;在Linux中&#xff0c;…

[openGL] 高级光照-Gamma矫正与衰减

目录 一 衰减 二 衰减公式 三 使用场景 四 代码实现 4.1 部分代码 4.2 未校验的效果 4.3 Gamma校验后的效果 4.4 总结 本章节源码 点击此处 一 衰减 在之前平行光和投光物的部分中&#xff0c;了解了光源的衰减&#xff0c;对于平行光来说是不需要衰减的&#xff0c…

中霖教育:二建考试中六个专业分别有什么特点?

建筑实务 《建筑实务》技术部分多以选择题为主&#xff0c;主要是对各种数据的考查;管理部分以案例题为主&#xff0c;旨在考查大家的综合能力&#xff0c;也是分值占比比较多的部分。进度控制的网络图和流水施工每年必考其一;质量管理主要结合技术部分命题;安全管理和合同管理…

正式发布的Spring AI,能让Java喝上AI赛道的汤吗

作者:鱼仔 博客首页: https://codeease.top 公众号:Java鱼仔 前言 最近几年AI发展实在太快了&#xff0c;仿佛只要半年没关注&#xff0c;一个新的大模型所产生的效果就能超越你的想象。Java在AI这条路上一直没什么好的发展&#xff0c;不过Spring最近出来了一个新的模块叫做S…

高可用集群——keepalived

目录 1 高可用的概念 2 心跳监测与漂移 IP 地址 3 Keepalived服务介绍 4 Keepalived故障切换转移原理介绍 5 Keepalived 实现 Nginx 的高可用集群 5.1 项目背景 5.2 项目环境 5.3 项目部署 5.3.1 web01\web02配置&#xff1a; 5.3.2nginx负载均衡配置 5.3.3 主调度服…

全开源小狐狸Ai系统 小狐狸ai付费创作系统 ChatGPT智能机器人2.7.6免授权版

内容目录 一、详细介绍二、效果展示1.部分代码2.效果图展示 三、学习资料下载 一、详细介绍 测试环境&#xff1a;Linux系统CentOS7.6、宝塔、PHP7.4、MySQL5.6&#xff0c;根目录public&#xff0c;伪静态thinkPHP&#xff0c;开启ssl证书 具有文章改写、广告营销文案、编程…

商务品牌解决方案企业网站模板 Bootstrap5

目录 一.前言 二.展示 三.下载链接 一.前言 这个网站包含以下内容&#xff1a; 导航栏&#xff1a;主页&#xff08;Home&#xff09;、关于&#xff08;About&#xff09;、服务&#xff08;Services&#xff09;、博客&#xff08;Blog&#xff09;等页面链接。主页部分…