LightGBM——提升机器算法详细介绍(附代码)

news2024/11/28 19:24:26

LightGBM——提升机器算法

前言

LightGBM是个快速的,分布式的,高性能的基于决策树算法的梯度提升框架。可用于排序,分类,回归以及很多其他的机器学习任务中。

在竞赛题中,我们知道XGBoost算法非常热门,它是一种优秀的拉动框架,但是在使用过程中,其训练耗时很长,内存占用比较大。在2017年年1月微软在GitHub的上开源了一个新的升压工具--LightGBM。在不降低准确率的前提下,速度提升了10倍左右,占用内存下降了3倍左右。因为他是基于决策树算法的,它采用最优的叶明智策略分裂叶子节点,然而其它的提升算法分裂树一般采用的是深度方向或者水平明智而不是叶,明智的。因此,在LightGBM算法中,当增长到相同的叶子节点,叶明智算法比水平-wise算法减少更多的损失。因此导致更高的精度,而其他的任何已存在的提升算法都不能够达。与此同时,它的速度也让人感到震惊,这就是该算法名字    的原因。

  • 2014年3月,XGBOOST最早作为研究项目,由陈天奇提出    (XGBOOST的部分在我的另一篇博客里:

  • 2017年1月,微软发布首个稳定版LightGBM 在微软亚洲研究院AI头条分享中的「LightGBM简介」中,机器学习组的主管研究员王太峰提到:微软DMTK团队在github上开源了性能超越其它推动决策树工具LightGBM后,三天之内星了1000+次,叉了超过200次。知乎上有近千人关注“如何看待微软开源的LightGBM?”问题,被评价为“速度惊人”,“非常有启发”,“支持分布式” “代码清晰易懂”,“占用内存小”等。以下是微软官方提到的LightGBM的各种优点,以及该项目的开源地址。

 

一、"What We Do in LightGBM?"

下面这个表格给出了XGBoost和LightGBM之间更加细致的性能对比,包括了树的生长方式,LightGBM是直接去选择获得最大收益的结点来展开,而XGBoost是通过按层增长的方式来做,这样呢LightGBM能够在更小的计算代价上建立我们需要的决策树。当然在这样的算法中我们也需要控制树的深度和每个叶子结点的最小数据量,从而减少过拟合。

小小翻译一下,有问题还望指出

 **XGBoost****LightGBM** 树木生长算法 **按层生长的方式** 有利于工程优化,但对学习模型效率不高  直接**选择最大收益的节点**来展开,在更小的计算代价上去选择我们需要的决策树 控制树的深度和每个叶子节点的数据量,能减少过拟合

有利于工程优化,但对学习模型效率不高

控制树的深度和每个叶子节点的数据量,能减少过拟合

划分点搜索算 法对特征预排序的方法直方图算法:将特征值分成许多小筒,进而在筒上搜索分裂点,减少了计算代价和存储代价,得到更好的性能。另外数据结构的变化使得在细节处的变化理上效率会不同 内存开销8个字节1个字节 划分的计算增益数据特征容器特征 高速缓存优化无在Higgs数据集上加速40% 类别特征处理无在Expo数据集上速度快了8倍

 

二、在不同数据集上的对比

higgs和expo都是分类数据,yahoo ltr和msltr都是排序数据,在这些数据中,LightGBM都有更好的准确率和更强的内存使用量。

准确率

内存使用情况

计算速度的对比,完成相同的训练量XGBoost通常耗费的时间是LightGBM的数倍之上,在higgs数据集上,它们的差距更是达到了15倍以上。

三、LightGBM的细节技术

1、直方图优化

XGBoost中采用预排序的方法,计算过程当中是按照value的排序,逐个数据样本来计算划分收益,这样的算法能够精确的找到最佳划分值,但是代价比较大同时也没有较好的推广性。

在LightGBM中没有使用传统的预排序的思路,而是将这些精确的连续的每一个value划分到一系列离散的域中,也就是筒子里。以浮点型数据来举例,一个区间的值会被作为一个筒,然后以这些筒为精度单位的直方图来做。这样一来,数据的表达变得更加简化,减少了内存的使用,而且直方图带来了一定的正则化的效果,能够使我们做出来的模型避免过拟合且具有更好的推广性。

看下直方图优化的细节处理

可以看到,这是按照bin来索引“直方图”,所以不用按照每个“特征”来排序,也不用一一去对比不同“特征”的值,大大的减少了运算量。

2、存储记忆优化

当我们用数据的bin描述数据特征的时候带来的变化:首先是不需要像预排序算法那样去存储每一个排序后数据的序列,也就是下图灰色的表,在LightGBM中,这部分的计算代价是0;第二个,一般bin会控制在一个比较小的范围,所以我们可以用更小的内存来存储

3、深度限制的节点展开方法

LightGBM使用了带有深度限制的节点展开方法(Leaf-wise)来提高模型精度,这是比XGBoost中Level-wise更高效的方法。它可以降低训练误差得到更好的精度。但是单纯的使用Leaf-wise可能会生长出比较深的树,在小数据集上可能会造成过拟合,因此在Leaf-wise之上多加一个深度限制

4、直方图做差优化

直方图做差优化可以达到两倍的加速,可以观察到一个叶子节点上的直方图,可以由它的父亲节点直方图减去它兄弟节点的直方图来得到。根据这一点我们可以构造出来数据量比较小的叶子节点上的直方图,然后用直方图做差来得到数据量比较大的叶子节点上的直方图,从而达到加速的效果。

5、顺序访问梯度

预排序算法中有两个频繁的操作会导致cache-miss,也就是缓存消失(对速度的影响很大,特别是数据量很大的时候,顺序访问比随机访问的速度快4倍以上  )。

  • 对梯度的访问:在计算增益的时候需要利用梯度,对于不同的特征,访问梯度的顺序是不一样的,并且是随机的- 对于索引表的访问:预排序算法使用了行号和叶子节点号的索引表,防止数据切分的时候对所有的特征进行切分。同访问梯度一样,所有的特征都要通过访问这个索引表来索引。这两个操作都是随机的访问,会给系统性能带来非常大的下降。

LightGBM使用的直方图算法能很好的解决这类问题。首先。对梯度的访问,因为不用对特征进行排序,同时,所有的特征都用同样的方式来访问,所以只需要对梯度访问的顺序进行重新排序,所有的特征都能连续的访问梯度。并且直方图算法不需要把数据id到叶子节点号上(不需要这个索引表,没有这个缓存消失问题)

6、支持类别特征

传统的机器学习一般不能支持直接输入类别特征,需要先转化成多维的0-1特征,这样无论在空间上还是时间上效率都不高。LightGBM通过更改决策树算法的决策规则,直接原生支持类别特征,不需要转化,提高了近8倍的速度

7、支持并行学习

LightGBM原生支持并行学习,目前支持特征并行(Featrue Parallelization)数据并行(Data Parallelization)两种,还有一种是基于投票的数据并行(Voting Parallelization)

  • 特征并行的主要思想是在不同机器、在不同的特征集合上分别寻找最优的分割点,然后在机器间同步最优的分割点。- 数据并行则是让不同的机器先在本地构造直方图,然后进行全局的合并,最后在合并的直方图上面寻找最优分割点。LightGBM针对这两种并行方法都做了优化。

  • 特征并行算法中,通过在本地保存全部数据避免对数据切分结果的通信。- 数据并行中使用分散规约 (Reduce scatter) 把直方图合并的任务分摊到不同的机器,降低通信和计算,并利用直方图做差,进一步减少了一半的通信量。- **基于投票的数据并行(Voting Parallelization)**则进一步优化数据并行中的通信代价,使通信代价变成常数级别。在数据量很大的时候,使用投票并行可以得到非常好的加速效果。下图更好的说明了以上这三种并行学习的整体流程:

在直方图合并的时候,通信代价比较大,基于投票的数据并行能够很好的解决这一点。

四、MacOS安装LightGBM

#先安装cmake和gcc,安装过的直接跳过前两步
brew install cmake
brew install gcc

git clone --recursive https://github.com/Microsoft/LightGBM 
cd LightGBM

#在cmake之前有一步添加环境变量
export CXX=g++-7 CC=gcc-7
mkdir build ; cd build

cmake ..
make -j4
cd ../python-package
sudo python setup.py install

来测试一下:

大功告成!

值得注意的是:pip list里面没有lightgbm,以后使用lightgbm需要到特定的文件夹中运行。我的地址是:

/Users/ fengxianhe / LightGBM /python-package

 

五,用python实现LightGBM算法

为了演示LightGBM在蟒蛇中的用法,本代码以sklearn包中自带的鸢尾花数据集为例,用lightgbm算法实现鸢尾花种类的分类任务。

# coding: utf-8
# pylint: disable = invalid-name, C0111

# 函数的更多使用方法参见LightGBM官方文档:http://lightgbm.readthedocs.io/en/latest/Python-Intro.html

import json
import lightgbm as lgb
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.datasets import  make_classification

iris = load_iris()   # 载入鸢尾花数据集
data=iris.data
target = iris.target
X_train,X_test,y_train,y_test =train_test_split(data,target,test_size=0.2)


# 加载你的数据
# print('Load data...')
# df_train = pd.read_csv('../regression/regression.train', header=None, sep='\t')
# df_test = pd.read_csv('../regression/regression.test', header=None, sep='\t')
#
# y_train = df_train[0].values
# y_test = df_test[0].values
# X_train = df_train.drop(0, axis=1).values
# X_test = df_test.drop(0, axis=1).values

# 创建成lgb特征的数据集格式
lgb_train = lgb.Dataset(X_train, y_train) # 将数据保存到LightGBM二进制文件将使加载更快
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)  # 创建验证数据

# 将参数写成字典下形式
params = {
    'task': 'train',
    'boosting_type': 'gbdt',  # 设置提升类型
    'objective': 'regression', # 目标函数
    'metric': {'l2', 'auc'},  # 评估函数
    'num_leaves': 31,   # 叶子节点数
    'learning_rate': 0.05,  # 学习速率
    'feature_fraction': 0.9, # 建树的特征选择比例
    'bagging_fraction': 0.8, # 建树的样本采样比例
    'bagging_freq': 5,  # k 意味着每 k 次迭代执行bagging
    'verbose': 1 # <0 显示致命的, =0 显示错误 (警告), >0 显示信息
}

print('Start training...')
# 训练 cv and train
gbm = lgb.train(params,lgb_train,num_boost_round=20,valid_sets=lgb_eval,early_stopping_rounds=5) # 训练数据需要参数列表和数据集

print('Save model...') 

gbm.save_model('model.txt')   # 训练后保存模型到文件

print('Start predicting...')
# 预测数据集
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration) #如果在训练期间启用了早期停止,可以通过best_iteration方式从最佳迭代中获得预测
# 评估模型
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5) # 计算真实值和预测值之间的均方根误差

 输出结果:

可以看到预测值和真实值之间的均方根误差为0.722972。

参考资料:

【1】LightGBM——提升机器算法(图解+理论+安装方法+python代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/416976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL:安装 MySQL、Navicat、使用 Navicat 连接 MySQL

文章目录Day 01:一、概念1. 数据库 DB2. 数据库管理系统 DBMS3. MySQL二、安装 MySQL三、安装 Navicat Premium 16四、使用 Navicat 连接 MySQL注意:Day 01: 一、概念 1. 数据库 DB 数据库:DB (Database) 数据仓库,…

NumPy 秘籍中文第二版:四、将 NumPy 与世界的其他地方连接

原文:NumPy Cookbook - Second Edition 协议:CC BY-NC-SA 4.0 译者:飞龙 在本章中,我们将介绍以下秘籍: 使用缓冲区协议使用数组接口与 MATLAB 和 Octave 交换数据安装 RPy2与 R 交互安装 JPype将 NumPy 数组发送到 J…

脑电信号分析

导读 EEG信号的分析过程是为了获得能够突出信号本身特定特性的值,从而对其进行表征。同时,也需要将所获得的值通过准确的绘图技术来进行正确地显示,以使这些值对用户有用且清晰易读。目前,已有许多不同的脑电信号分析和显示技术&…

MVCC

MVCC基本概念 当前读 当前读 : 读取的是记录的最新版本,读取时还要保证其他并发事务不能修改当前记录,会对读取的记录进行加锁. 对于我们日常的操作. 如 : select....lock in share mode(共享锁) , select * for update , update ,insert,delete(排他锁) 都是一种当前读. 快…

「Cpolar」使用Typecho搭建个人博客网站【内网穿透实现公网访问】

💂作者简介: THUNDER王,一名热爱财税和SAP ABAP编程以及热爱分享的博主。目前于江西师范大学本科在读,同时任汉硕云(广东)科技有限公司ABAP开发顾问。在学习工作中,我通常使用偏后端的开发语言A…

Spring学习小结

文章目录1 BeanFactory与ApplicationContext的关系2 Spring基础环境下,常用的三个ApplicationContext3 Spring开发中Bean的配置4 Bean的初始化和销毁方法配置5 Bean的实例化配置6 Bean的依赖注入之自动装配7 Spring 的 xml 标签(默认、自定义&#xff09…

硬件语言Verilog HDL牛客刷题 day09 哲K部分

1.VL59 根据RTL图编写Verilog程序 1.题目: 根据以下RTL图,使用 Verilog HDL语言编写代码,实现相同的功能,并编写testbench验证功能 2.解题思路 2.1 了解D触发器的知识 (在时钟是上升沿的时候, 输入是什么…

UE “体积”的简单介绍

目录 一、阻挡体积 二、摄像机阻挡体积 三、销毁Z体积 四、后期处理体积 一、阻挡体积 你可以在静态网格体上使用阻挡体积替代碰撞表面,比如建筑物墙壁。这可以增强场景的可预测性,因为物理对象不会与地面和墙壁上的凸起细节相互作用。它还能降低物理模…

visio的使用技巧

一、调节箭头方向 1.打开你要修改的Microsoft Visio文件 2.选中你要修改的箭头,在上方的开始工具栏中找到“线条”选项,鼠标左键单击打开; 3.在下面找到“箭头”这个选项,鼠标移到上面去,就会展开;带阴影的…

Linux网络编程 第七天

目录 网络编程阶段项目 项目目标 Web服务器开发准备 Html语言基础 Html简介 Html标签介绍 题目标签 文本标签 列表标签 图片标签 超链接标签 http请求消息 请求类型 http响应消息 http常见状态码 http常见文件类型分…

“万物智联·共数未来”2023年移远通信物联网生态大会圆满落幕

4月12日,以“万物智联共数未来”为主题的2023年移远通信物联网生态大会在深圳前海华侨城JW万豪酒店隆重举办。 大会邀请到来自运营商、主流芯片商、行业客户、产业协会、标准联盟、媒体等产业链合作伙伴的40多位行业大咖,共话物联网产业的现在和未来。参…

node开通阿里云短信验证服务,代码演示 超级详细

阿里云官网步骤:Node.js SDK (aliyun.com) 首先先搭建一个node项目:app.js const express require(express); // 引入 Express 框架const app express(); app.use(express.json()); // 解析请求中的 JSON 数据const PORT process.env.PORT || 3000; …

URL 以及 URLConnection 类的使用

1. 概述 java 提供了两个类,在这两个类里封装了大部分 Web 相关的各种操作。这两个类是 URL 类 和 URLConnection 类。2. URL 类 java.net.URL 类定义了一个统一的资源定位器,它是指向互联网“资源”的指针。可以定 位互联网上的资源。并且…

LInux一天10题 day1

su(switch user) 命令用于更改其他使用者身份, usermod -l 修改账号名称,使用格式:usermod -l new_name old_name 修改用户权限: 方法1 1、先切换到root权限的用户登录下,修改 /etc/sudoers 文件,找…

games103——作业1

实验一主要实现简单的刚体动画模拟(一只兔子),包括 impulse 的碰撞检测与响应,以及 Shape Matching方法。 完整项目已上传至github。 文章目录简单刚体模拟(不考虑碰撞)平移运动旋转运动粒子碰撞检测与响应碰撞检测碰撞响应Penalty MethodsQuadratic Pen…

如何安全的从硬盘恢复文件?

可以从硬盘中恢复永久删除的文件吗? “我删除了一些看起来不重要的文件夹,并清空了回收站。但在几天后,我意识到删除的文件夹里有些重要的数据。如何恢复永久删除的文件?谢谢!” 随着科技的进步与发展&#xff0c…

LeetCode 2404. 出现最频繁的偶数元素

原题链接:2404. 出现最频繁的偶数元素 给你一个整数数组 nums ,返回出现最频繁的偶数元素。 如果存在多个满足条件的元素,只需要返回 最小 的一个。如果不存在这样的元素,返回 -1 。 示例 1: 输入:nums …

OpenAI团队抢着用的编程语言?

作为一名合格的(准)程序员,必做的一件事是关注编程语言的热度,编程榜代表了编程语言的市场占比变化,它的变化更预示着未来的科技风向和机会! 快一起看看本月排行有何看点: 4月Tiobe排行榜前15…

如何学习智能交通?

AI的专业领域知识是指AI与具体应用领域相结合时所需要的该应用领域的知识。AI的应用领域非常广泛,例如计算机视觉、智能交通、智能制造、智慧金融、智慧教育、智慧农业、智慧能源、智能通信、智能芯片等。本文主要介绍智能交通的基本概念、发展历程、主要研究内容、…

DAMA数据治理认证,一定要先考CDGA才能考CDGP吗?

DAMA认证为数据管理专业人士提供职业目标晋升规划,彰显了职业发展里程碑及发展阶梯定义,帮助数据管理从业人士获得企业数字化转型战略下的必备职业能力,促进开展工作实践应用及实际问题解决,形成企业所需的新数字经济下的核心职业…