机器学习实践(2.2)LightGBM回归任务

news2025/2/23 6:56:16

前言

LightGBM也属于Boosting集成学习模型(还有前面文章的XGBoost),LightGBM和XGBoost同为机器学习的集大成者。相比越来越流行的深度神经网络,LightGBM和XGBoost能更好的处理表格数据,并具有更强的可解释性,还具有易于调参、输入数据不变性等优势。

机器学习实践(1.2)XGBoost回归任务

机器学习实践(2.1)LightGBM分类任务

❤️ 本文完整脚本点此链接百度网盘链接获取 ❤️

一.轻松实现回归任务

1.1导入第三方库、数据集

"""第三方库导入"""
from lightgbm import LGBMRegressor
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import r2_score, mean_squared_error
import lightgbm as lgb

"""波士顿房价数据集导入"""
data = datasets.load_boston()
# print(data)

"""训练集 验证集构建"""
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2,
                                                    random_state=42)

sklearn的波士顿房价数据集共506个数据样本,8:2切分后,训练集404个数据样本,验证集102个数据样本。数据集中包括 样本特征data(13个特征)、特征名称feature_names样本标签target(MEDV)、以及数据集位置filename(~~~\anaconda\lib\site-packages\sklearn\datasets\data\boston_house_prices.csv)

特征名称和标签解释如下:

- CRIM     per capita crime rate by town\n      # 按城镇划分的犯罪率  
- ZN       proportion of residential land zoned for lots over 25,000 sq.ft.\n  # 划分为25000平方英尺以上地块的住宅用地比例        
- INDUS    proportion of non-retail business acres per town\n     # 每每个城镇的非零售商业用地比例
- CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)\n        # 靠近查尔斯河,则为1;否则为0
- NOX      nitric oxides concentration (parts per 10 million)\n      # 一氧化氮浓度(百万分之一)
- RM       average number of rooms per dwelling\n  # 每个住宅的平均房间数      
- AGE      proportion of owner-occupied units built prior to 1940\n     # 1940年之前建造的自住单位比例  
- DIS      weighted distances to five Boston employment centres\n     # 到波士顿五个就业中心的加权距离
- RAD      index of accessibility to radial highways\n    # 辐射状公路可达性指数    
- TAX      full-value property-tax rate per $10,000\n   # 每10000美元的全额财产税税率
- PTRATIO  pupil-teacher ratio by town\n    # 按城镇划分的师生比例    
- B        1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town\n  # 1000(Bk-0.63)^2其中Bk是按城镇划分的黑人比例       
- LSTAT    % lower status of the population\n    # 人口密度   
- MEDV     Median value of owner-occupied homes in $1000's\n # 住房屋的中值(单位:1000美元)

1.2模型训练

"""模型训练"""
model = LGBMRegressor()

# r2_score:0.886290371005902
# mse:8.338757275893952

# 给定参数
# model = LGBMRegressor(boosting='gbdt',  # gbdt \ dart
#                       n_estimators=300,  # 迭代次数
#                       learning_rate=0.1,  # 步长
#                       max_depth=10,  # 树的最大深度
#                       seed=42,  # 指定随机种子,为了复现结果
#                       )

# r2_score:0.9043019057586194
# mse:7.017903291953562

model.fit(X_train, y_train)

LGBMRegressor()是没有指定参数,模型使用默认参数如下。也可以指定参数例如指定boosting='dart'等。训练的模型参数如下:

parameters:
[boosting: gbdt]
[objective: regression]
[metric: l2]
[tree_learner: serial]
[device_type: cpu]
[data: ]
[valid: ]
[num_iterations: 100]
[learning_rate: 0.1]
[num_leaves: 31]
[num_threads: -1]
[deterministic: 0]
[force_col_wise: 0]
[force_row_wise: 0]
[histogram_pool_size: -1]
[max_depth: -1]
......

1.3模型验证

模型效果的验证,简单直接的可以通过验证集来实现。实际项目中通常将整个数据集按照7:3:1比例划分为训练集、验证集、测试集。本例使用验证集验证模型准确性。
回归任务的评估指标只要有 r2_score 和 mse,其中 r2_score 越趋近于1越好,mse 越小越好。
R2 = 1 - (SSE / TSS),其中,SSE(sum of squared errors )是模型预测值与实际观测值之间差异的平方和,TSS(total sum of squares)是所有观测值与其均值差异的平方和。

y_pred = model.predict(X_test)
# print(y_pred)

for m, n in zip(y_pred, y_test):
    if m / n - 1 > 0.2:
        print('预测值为{0}, 真是结果为{1}, 预测结果偏差大于20%'.format(m, n))


def metrics_sklearn(y_valid, y_pred_):
    """模型效果评估"""
    r2 = r2_score(y_valid, y_pred_)
    print('r2_score:{0}'.format(r2))

    mse = mean_squared_error(y_valid, y_pred_)
    print('mse:{0}'.format(mse))


"""模型效果评估"""
metrics_sklearn(y_test, y_pred)

结果中仅打印了预测误差在20%以上的预测数据。
在这里插入图片描述

二.模型调参

def adj_params():
    """模型调参"""
    params = {
        'n_estimators': [100, 200, 300, 400],
        # 'learning_rate': [0.01, 0.03, 0.05, 0.1],
        'max_depth': [5, 8, 10, 12]
    }

    other_params = {'learning_rate': 0.1, 'seed': 42}
    model_adj = LGBMRegressor(**other_params)

    # sklearn提供的调参工具,训练集k折交叉验证(消除数据切分产生数据分布不均匀的影响)
    optimized_param = GridSearchCV(estimator=model_adj, param_grid=params, scoring='r2', cv=5, verbose=1)
    # 模型训练
    optimized_param.fit(X_train, y_train)

    # 对应参数的k折交叉验证平均得分
    means = optimized_param.cv_results_['mean_test_score']
    params = optimized_param.cv_results_['params']
    for mean, param in zip(means, params):
        print("mean_score: %f,  params: %r" % (mean, param))
    # 最佳模型参数
    print('参数的最佳取值:{0}'.format(optimized_param.best_params_))
    # 最佳参数模型得分
    print('最佳模型得分:{0}'.format(optimized_param.best_score_))


adj_params()

2.1网格搜索调参

params = {
        'n_estimators': [100, 200, 300, 400],
        # 'learning_rate': [0.01, 0.03, 0.05, 0.1],
        'max_depth': [5, 8, 10, 12]
    }

other_params = {'learning_rate': 0.1, 'seed': 42}

调参内容不是很多时,例如本次调参训练80次,两个参数值可以一起调整,结果如下图:
在这里插入图片描述

调参是个无穷无尽的过程,适可而止,切误沉溺其中本末倒置,真正决定模型效果上限的还是数据质量

2.2调参结果入模

model = LGBMRegressor(boosting='gbdt',  # gbdt \ dart
                      n_estimators=200,  # 迭代次数
                      learning_rate=0.1,  # 步长
                      max_depth=5,  # 树的最大深度
                      seed=42,  # 指定随机种子,为了复现结果
                      )

model.fit(X_train, y_train)

基础模型boosting='gbdt',最大深度max_depth=5, 迭代次数n_estimators=200 参数入模,fit()训练带参的模型,模型的参数和评估见下方(三.模型保存、加载、调用预测)
在这里插入图片描述

三.模型保存、加载、调用预测

3.1模型保存、加载、调用预测

"""模型保存"""
model.booster_.save_model('lgb_regressor_boston.txt')

"""模型加载"""
rgs = lgb.Booster(model_file='lgb_regressor_boston.txt')

"""模型参数打印"""
print('模型参数值-开始'.center(20, '='))
# lightgbm模型参数直接打开模型文件查看更为方便
model_params = rgs.dump_model()
print(model_params)
print('模型参数值-结束'.center(20, '='))

"""预测验证数据"""
y_pred = rgs.predict(X_test)

"""模型效果评估"""
metrics_sklearn(y_test, y_pred)

模型参数打印和预测评估结果如图,不再赘述。
在这里插入图片描述

3.2模型参数

经过上面脚本’lgb_regressor_boston.txt’是已经保存到本地的模型文件,可以打开文件查看参数,其中
最开始是 树tree 的信息;

tree
version=v3
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=12
objective=regression
feature_names=Column_0 Column_1 Column_2 Column_3 Column_4 Column_5 Column_6 Column_7 Column_8 Column_9 Column_10 Column_11 Column_12
feature_infos=[

每个tree信息都会存在,trees之后是feature_importances特征重要性信息,几乎在文件末尾;

......
end of trees

feature_importances:
Column_12=289
Column_7=237
Column_5=230
Column_6=174
Column_0=155
Column_11=140
Column_4=138
Column_9=97
Column_10=68
Column_2=38
Column_3=34
Column_1=21
Column_8=16

结尾是 parameters 参数和 pandas_categorical pandas经过虚拟化的类别信息。

parameters:
[boosting: gbdt]
[objective: regression]
[metric: l2]
[tree_learner: serial]
[device_type: cpu]
[data: ]
[valid: ]
[num_iterations: 200]
[learning_rate: 0.1]
[num_leaves: 31]
[num_threads: -1]
[deterministic: 0]
[force_col_wise: 0]
[force_row_wise: 0]
[histogram_pool_size: -1]
[max_depth: 5]
[min_data_in_leaf: 20]
......

end of parameters

pandas_categorical:null

附加——深入学习XGBoost

附加1.模型调参、训练、保存、评估和预测

见《XGBoost模型调参、训练、评估、保存和预测》 ,包含模型脚本文件

附加2.算法原理

见《XGBoost算法原理及基础知识》 ,包括集成学习方法,XGBoost模型、目标函数、算法,公式推导等

附加3.分类任务的评估指标值详解

见《分类任务评估1——推导sklearn分类任务评估指标》,其中包含了详细的推理过程;
见《分类任务评估2——推导ROC曲线、P-R曲线和K-S曲线》,其中包含ROC曲线、P-R曲线和K-S曲线的推导与绘制;

附加4.模型中树的绘制和模型理解

见《Graphviz绘制模型树1——软件配置与XGBoost树的绘制》,包含Graphviz软件的安装和配置,以及to_graphviz()和plot_trees()两个画图函数的部分使用细节;
见《Graphviz绘制模型树2——XGBoost模型的可解释性》,从模型中的树着手解释XGBoost模型,并用EXCEL构建出模型。

附加5.XGBoost实践

见机器学习实践(1.1)XGBoost分类任务,包含二分类、多分类任务以及多分类的评估方法。
见机器学习实践(1.2)XGBoost回归任务,包含回归任务模型训练、评估(R2、MSE)
见机器学习实践(2.1)LightGBM分类任务,包含LightGBM二分类、多分类任务及评估方法。

❤️ 机器学习内容持续更新中… ❤️


声明:本文所载信息不保证准确性和完整性。文中所述内容和意见仅供参考,不构成实际商业建议,可收藏可转发但请勿转载,如有雷同纯属巧合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/755373.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

青岛大学_王卓老师【数据结构与算法】Week05_10_顺序栈的操作3_学习笔记

本文是个人学习笔记,素材来自青岛大学王卓老师的教学视频。 一方面用于学习记录与分享, 另一方面是想让更多的人看到这么好的《数据结构与算法》的学习视频。 如有侵权,请留言作删文处理。 课程视频链接: 数据结构与算法基础…

OpenTelemetry

OpenTelemetry(简称为Otel)是一个开源项目,旨在为分布式系统提供可观测性(observability)。它提供了一组标准化的API、工具和库,用于生成、收集和分析分布式应用程序的跟踪(tracing)…

Redis报错-CROSSSLOT keys in request don‘t hash in the same slot

背景 问题涉及:spring security、spring session、redis 问题描述 springbootspringsecurityspringsessionantd 登录功能的时候,在源码中使用到了redis的rename命令(如下图所示) 在这里就会报错 CROSSSLOT keys in request d…

微信小程序安装和使用 Vant Weapp 组件库

微信小程序安装和使用 Vant Weapp 组件库 1. Vant Weapp 介绍2. Vant Weapp 的 安装2.1. 通过npm安装2.2. 构建npm2.3. 修改 app.json2.4. 修改 project.congfig.json2.5. 测试一下,使用Vant Weapp提供的组件 1. Vant Weapp 介绍 Vant 是一个轻量、可靠的移动端组件…

字符函数和内存函数(二)

目录 一、strtok函数 二、strerror函数 三、memcpy函数 3.1memcpy函数的认识 3.2memcpy函数的模拟实现 四、memmove函数 4.1memmove函数的认识 4.2memmove函数的模拟实现 五、memcmp函数 5.1memcmp函数的认识 5.2memcmp函数的模拟实现 六、memset函数 七、字符分类函…

24 张图搞定 ICMP :最常用的网络命令 ping 和 tracert

ICMP IP 是尽力传输的网络协议,提供的数据传输服务是不可靠的、无连接的,不能保证数据包能成功到达目的地。那么问题来了:如何确定数据包成功到达目的地? 这需要一个网络层协议,提供错误检测功能和报告机制功能&#x…

Python爬虫——urllib_post请求百度翻译

post请求: post的请求参数,是不会拼接在url后面的,而是需要放在请求对象定制的参数中 post请求的参数需要进行两次编码,第一次urlencode:对字典参数进行Unicode编码转成字符串,第二次encode:将字…

GNN环境安装

参考: torch_geometric踩坑实战–安装与运行 亲测有效!! https://blog.csdn.net/m0_55245520/article/details/130424828pytorch 查看gpu cuda版本 https://blog.csdn.net/jacke121/article/details/93592487 x.1 安装 x.1.1 镜像信息补充…

【LeetCode】594. 最长和谐子序列

594. 最长和谐子序列(简单) 方法:哈希表计数 思路 题目规定的「和谐子序列」中的最值差值正好为 1,因而子序列排序后必然符合[a,a,.., a 1,a1]形式,即符合条件的和谐子序列长度为相邻两数(差值为 1)的出现次数之和。…

element中icon字体图标的使用

效果图 官方提供的图标 icon字体图标 安装 安装依赖 cnpm install element-plus/icons-vue 编写src/plugins/icons.js import * as components from "element-plus/icons-vue";export default {install: (app) > {for (const key in components) {const comp…

c++智能指针简单示例

代码 #include<iostream> using namespace std; #include<memory> // 头文件class TestClass { private:int Value;public:TestClass(int value) :Value(value) {cout << "构造函数调用" << endl;}~TestClass() {cout << "析构函…

如何有效阅读文献

作为研究生要保持看文献的能力&#xff0c;以《面向大规模图像定位的高效优先匹配&#xff08;Efficient & Effective Prioritized Matching for Large-Scale Image-Based Localization&#xff09;》文献为例&#xff0c;本文记录了自己在学习过程中如何阅读文献技巧。 文…

第八节 学生管理系统 (阶段案例)

学生管理系统 1.1 设计背景 管理系统&#xff0c;主要任务就是使用计算机对学生的各种信息进行日常管理&#xff0c;如&#xff1a; 添加删除修改查询退出系统 程序设计思路 1.2 需求设计分析 打印 “学生管理系统” 的功能菜单,提示用户选择功能序号; print_menu() 打印函…

WebSocket协议基础

文章目录 什么是websocketwebsocet 特点 一、websocket 建立连接流程二、websocket 握手流程客户端握手包2.服务端握手包 三、websocket数据总结参考 什么是websocket WebSOcket 是基于TCP的应用层协议。该协议和http或https 相似&#xff0c;但是却区别于http的一种新的协议。…

AD Class 、设计参数、规则的创建

设计 生产 线宽 间距 过孔 根据生产的要求进行桥接 Class 电源走线 和 信号走线 设计—》类里有 将所有的电源都添加进电源类里 新建的类别可以在Panls的PCB中看到 并且可以在这里面改变线的颜色 区分电源 对于走线的宽度,电源主要是用来载流的&#xff0c;信号主要是用来做信…

彻底理解Handler的设计之传送带模型

作者&#xff1a;彭泰强 0 这篇文章的目的 有时候在Handler相关的文章中可以看到&#xff0c;会把Handler机制的几个角色类比成一个传送带场景来理解。 例如&#xff0c;这篇文章中写到&#xff1a; 我们可以把传送带上的货物看做是一个个的Message&#xff0c;而承载这些货物…

6.2.8 网络基本服务----万维网(www)

6.2.8 网络基本服务----万维网&#xff08;www&#xff09; 万维网即www&#xff08;World Wide Web&#xff09;是开源的信息空间&#xff0c;使用URL也就是统一资源标识符标识文档和Web资源&#xff0c;使用超文本链接互相连接资源&#xff0c;万维网并非某种特殊的计算机网…

力扣 198.打家劫舍【中等】

198.打家劫舍 1 题目2 思路3 代码4 结果 1 题目 题目来源&#xff1a;力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/house-robber 题目&#xff1a;你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃…

【Vue3】初始化和Composition API(组合式)

Vue3 创建Vue3.0工程查看自己的vue/cli版本&#xff0c;使用Vue/cli创建使用vite创建 查看Vue3.0工程vue.config.js中&#xff0c;关闭语法检查&#xff0c;main.js讲解app.vue讲解 常用的Composition API&#xff08;组合式&#xff09;1.拉开序幕的setup返回对象返回渲染函数…

降级npm后,出现xxx 不是内部或外部命令解决方法

比如我安装了anyproxy npm install anyproxy -g 之后在cmd中输入anyproxy 发现 anyproxy 不是内部或外部命令解决方法. 一般出现这样的问题原因是npm安装出现了问题&#xff0c;全局模块目录没有被添加到系统环境变量。 Windows用户检查下npm的目录是否加入了系统变量P…