XGBOOST、LightGBM、CATBoost

news2024/11/21 20:25:28

本文介绍几种不同的 GBDT 优化算法:

  • XGBoost
    XGBoost 对损失函数展开二阶导,使得提升树能逼近真是损失,增加正则项防止过拟合,XGBoost 公式:
    L( y i y_i yi, y ^ i \hat{y}_i y^i): 损失函数
    Ω ( f k ) \Omega(f_k) Ω(fk): 正则项
    在这里插入图片描述
    分类点增加了二阶导:
    G:一阶导数
    H:二阶导数
    在这里插入图片描述
# 安装依赖
pip install xgboost

import numpy as np
from cart import TreeNode, BinaryDecisionTree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from utils import cat_label_convert

### 准备数据
from sklearn import datasets
# 导入鸢尾花数据集
data = datasets.load_iris()
# 获取输入输出
X, y = data.data, data.target
# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43)  

import xgboost as xgb
from xgboost import plot_importance
from matplotlib import pyplot as plt

# 设置模型参数
params = {
    'booster': 'gbtree',
    'objective': 'multi:softmax',   
    'num_class': 3,     
    'gamma': 0.1,
    'max_depth': 2,
    'lambda': 2,
    'subsample': 0.7,
    'colsample_bytree': 0.7,
    'min_child_weight': 3,
    'eta': 0.001,
    'seed': 1000,
    'nthread': 4,
}


dtrain = xgb.DMatrix(X_train, y_train)
num_rounds = 200
model = xgb.train(params, dtrain, num_rounds)
# 对测试集进行预测
dtest = xgb.DMatrix(X_test)
y_pred = model.predict(dtest)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print ("Accuracy:", accuracy)
# 绘制特征重要性
plot_importance(model)
plt.show();

在这里插入图片描述

  • LightGBM
    XGBoost 需找最优分裂点的计算复杂度可以估计为:特征数 x 分裂点数量 x 样本量,LightGBM 对 XGBoost 算法通过这三方面进行优化。
  1. 直方图优化(Histogram-Based):按桶计算特征值的分裂点而不是去尝试每一个分裂点,每个桶中包含多个特征值。
  2. 互斥特征合并(Exclusive Feature Bundling):把多个互斥的特征进行合并,可以有效的减少特征数量。
  3. 叶子策略(Leaf-Wise):叶子生长策略相对于按层生长的策略,优势在于只保留有效降低损失值的节点,缺点是如果正则值设置的不合适,有可能产生过拟合。
# 安装依赖
pip install lightgbm
# 导入相关模块
import lightgbm as lgb
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# 导入iris数据集
iris = load_iris()
data = iris.data
target = iris.target
# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=43)
# 创建lightgbm分类模型
gbm = lgb.LGBMClassifier(objective='multiclass',
                         num_class=3,
                         num_leaves=31,
                         learning_rate=0.05,
                         n_estimators=20)
# 模型训练
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)])
# 预测测试集
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
# 模型评估
print('Accuracy of lightgbm:', accuracy_score(y_test, y_pred))
lgb.plot_importance(gbm)
plt.show();

在这里插入图片描述

  • CatBoost
    CatBoost 算法是使用类别特征的 Boost 框架,使用目标变量统计算法而不是 OneHot 编码,通过排序提升让后面的角色树只能前面的数据,而不能看到后面决策树所能看到的数据库,这个可以大大提升训练效果。
#安装依赖
pip install catboost
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import catboost as cb
from sklearn.metrics import f1_score

# 读取数据
data = pd.read_csv('./adult.data', header=None)
# 变量重命名
data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 
                'marital-status', 'occupation', 'relationship', 'race', 'sex', 
                'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income']
# 标签转换
data['income'] = data['income'].astype("category").cat.codes
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(data.drop(['income'], axis=1), data['income'],
                                                    random_state=10, test_size=0.3)
# 配置训练参数
clf = cb.CatBoostClassifier(eval_metric="AUC", depth=4, iterations=500, l2_leaf_reg=1,
                            learning_rate=0.1)
# 类别特征索引
cat_features_index = [1, 3, 5, 6, 7, 8, 9, 13]
# 训练
clf.fit(X_train, y_train, cat_features=cat_features_index)
# 预测
y_pred = clf.predict(X_test)
# 测试集f1得分
print(f1_score(y_test, y_pred))

在这里插入图片描述

总结

本文介绍了三种 GBDT 的优化算法,可以根据实际情况进行选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2244866.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文阅读 SimpleNet: A Simple Network for Image Anomaly Detection and Localization

SimpleNet: A Simple Network for Image Anomaly Detection and Localization 摘要: 该论文提出了一个简单且应用友好的网络(称为 SimpleNet)来检测和定位异常。SimpleNet 由四个组件组成:(1)一个预先训练的…

多线程4:线程池、并发、并行、综合案例-抢红包游戏

欢迎来到“雪碧聊技术”CSDN博客! 在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将…

Java数据库连接(Java Database Connectivity,JDBC)

1.JDBC介绍 Java数据库连接(Java Database Connectivity,JDBC)是SUN公司为了简化、统一对数据库的操作,定义的一套Java操作数据库的规范(接口)。这套接口由数据库厂商去实现,这样,开…

高亮变色显示文本中的关键字

效果 第一步:按如下所示代码创建一个用来高亮显示文本的工具类: public class KeywordUtil {/*** 单个关键字高亮变色* param color 变化的色值* param text 文字* param keyword 文字中的关键字* return*/public static SpannableString highLigh…

2024强化学习的结构化剪枝模型RL-Pruner原理及实践

[2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration 目录 [2024] RL-Pruner: Structured Pruning Using Reinforcement Learning for CNN Compression and Acceleration一、论文说明二、原理三、实验与分析1、环境配置在…

电脑超频是什么意思?超频的好处和坏处

嗨,亲爱的小伙伴!你是否曾经听说过电脑超频?在电脑爱好者的圈子里,这个词似乎非常熟悉,但对很多普通用户来说,它可能还是一个神秘而陌生的存在。 今天,我将带你揭开超频的神秘面纱,…

uniapp: vite配置rollup-plugin-visualizer进行小程序依赖可视化分析减少vender.js大小

一、前言 在之前文章《uniapp: 微信小程序包体积超过2M的优化方法(主包从2.7M优化到1.5M以内)》中,提到了6种优化小程序包体积的方法,但并没有涉及如何分析common/vender.js这个文件的优化,而这个文件的大小通常情况下…

SQL Server Management Studio 的JDBC驱动程序和IDEA 连接

一、数据库准备 (一)启用 TCP/IP 协议 操作入口 首先,我们要找到 SQL Server 配置管理器,操作路径为:通过 “此电脑” 右键选择 “管理”,在弹出的 “计算机管理” 窗口中,找到 “服务和应用程…

STM32F103系统时钟配置

时钟是单片机运行的基础,时钟信号推动单片机内各个部分执行相应的指令。时钟系统就是CPU的脉搏,决定CPU速率,像人的心跳一样 只有有了心跳,人才能做其他的事情,而单片机有了时钟,才能够运行执行指令&#x…

鸿蒙进阶篇-Math、Date

“在科技的浪潮中,鸿蒙操作系统宛如一颗璀璨的新星,引领着创新的方向。作为鸿蒙开天组,今天我们将一同踏上鸿蒙基础的探索之旅,为您揭开这一神奇系统的神秘面纱。” 各位小伙伴们我们又见面了,我就是鸿蒙开天组,下面让我们进入今…

RAID存储技术 详解

RAID(Redundant Array of Independent Disks,独立磁盘冗余阵列)是一种将多个物理硬盘组合为一个逻辑存储单元的技术。它通过分布数据、冗余校验和容错能力,提高存储系统的性能、可靠性和容量利用率。 以下从底层原理和源代码层面…

MTK主板定制_联发科主板_MTK8766/MTK8768/MTK8788安卓主板方案

主流市场上的MTK主板通常采用联发科的多种芯片平台,如MT8766、MT6765、MT6762、MT8768和MT8788等。这些芯片基于64位Cortex-A73/A53架构,提供四核或八核配置,主频可达2.1GHz,赋予设备卓越的计算与处理能力。芯片采用12纳米制程工艺…

免费微调自己的大模型(llama-factory微调llama3.1-8b)

目录 1. 名词/工具解释2. 微调过程3. 总结 本文主要介绍通过llama-factory框架,使用Lora微调方法,微调meta开源的llama3.1-8b模型,平台使用的是趋动云GPU算力资源。 微调已经经过预训练的大模型目的是,通过调整模型参数和不断优化…

MySQL 中 InnoDB 支持的四种事务隔离级别名称,以及逐级之间的区别?

MySQL中的InnoDB存储引擎支持四种事务隔离级别,这些级别定义了事务在并发环境中的行为和相互之间的可见性。以下是这四种隔离级别的名称以及它们之间的区别: 读未提交(Read Uncommitted) 特点:这是最低的隔离级别&…

【YOLOv10改进[注意力]】引入并行分块注意力PPA(2024.3.16) + 适于微小目标

本文将进行在YOLOv10中引入并行分块注意力PPA魔改v10 的实践,文中含全部代码、详细修改方式。助您轻松理解改进的方法。 一 HCF 论文题目:Hierarchica

共建智能软件开发联合实验室,怿星科技助力东风柳汽加速智能化技术创新

11月14日,以“奋进70载,智创新纪元”为主题的2024东风柳汽第二届科技周在柳州盛大开幕,吸引了来自全国的汽车行业嘉宾、技术专家齐聚一堂,共襄盛举,一同探寻如何凭借 “新技术、新实力” 这一关键契机,为新…

在ubuntu下,使用Python画图,无法显示中文怎么解决

1.首先需要下载中文字体,推荐simsun,即宋体,地址如下 https://www.freefonts.io/download/simsun/ 2.下载完要把字体文件放进字体目录,具体方法如下; a.创建字体目录:sudo mkdir -p /usr/share/fonts/truet…

鸿蒙实战:使用显式Want启动Ability

文章目录 1. 实战概述2. 实现步骤2.1 创建鸿蒙应用项目2.2 修改Index.ets代码2.3 创建SecondAbility2.4 创建Second.ets 3. 测试效果4. 实战总结5. 拓展练习 - 启动文件管理器5.1 创建鸿蒙应用项目5.2 修改Index.ets代码5.3 测试应用运行效果 1. 实战概述 本实战详细阐述了在 …

《Python浪漫的烟花表白特效》

一、背景介绍 烟花象征着浪漫与激情,将它与表白结合在一起,会创造出别具一格的惊喜效果。使用Python的turtle模块,我们可以轻松绘制出动态的烟花特效,再配合文字表白,打造一段专属的浪漫体验。 接下来,让…

springboot中设计基于Redisson的分布式锁注解

如何使用AOP设计一个分布式锁注解&#xff1f; 1、在pom.xml中配置依赖 <dependency><groupId>org.springframework</groupId><artifactId>spring-aspects</artifactId><version>5.3.26</version></dependency><dependenc…