人工智能——梯度提升决策树算法

news2025/4/18 17:02:28

目录

摘要

14 梯度提升决策树

14.1 本章工作任务

14.2 本章技能目标

14.3 本章简介

14.4 编程实战

14.5 本章总结

14.6 本章作业

本章已完结!


摘要

本章实现的工作是:首先采用Python语言读取含有英语成绩、数学成绩以及学生所属类型的样本数据。然后将样本数据划分为训练集和测试集,接着采用GBDT算法,对训练集数据进行拟合,最后在输入更多学生的数学成绩和英语成绩后,使用已求解的最优模型去预测其分类结果。

本章掌握的技能是:1、使用NumPy包读取连续的样本数据。2、使用sklearn库model_selection模块中的model_selection函数实现训练集和测试集的划分。3、使用Matplotlib实现数据的可视化,绘制树状图。

14 梯度提升决策树

14.1 本章工作任务

 采用梯度提升决策树(GBDT)算法编写程序,根据700名学生的数学成绩和英语成绩对其进行分类,将其划分为文科生、理科生和综合生。1、算法的输入是:700名学生的数学和英语成绩以及相应的的学生类型。2、算法模型需要求解的是:N颗残差树(每颗残差树需要求解所有分支,每个分支节点需要求解该分支的属性及分支的阈值)。3、算法的结果是:待测样本中学生的分类。

14.2 本章技能目标

掌握GBDT原理。

使用Python读取样本数据,并划分为训练集和测试集。

使用Python实现GBDT的建模与求解。

掌握GBDT模型实现预测。

使用Python实现GBDT分类结果进行可视化展示。

14.3 本章简介

梯度提升决策树(GBDT)是指:一种集成算法,由多个子分类器(弱分类器)的分类结果进行累加,从而得到总分类器(强分类器)的分类结果。GBDT子分类器的特点是后一个子分类器是对前一个子分类器得到的分类结果与目标值之间的差值进行的拟合,即后一个子分类器是对前一个子分类器得到的残差值的矫正。

梯度提升决策树(GBDT)算法可以解决的实际应用问题是:已知N个样本数据,样本特征是学生的数学成绩和英语成绩,样本标签是学生的类型。通过建立GBDT模型对样本数据进行训练,找到样本特征和样本标签之间的关系,从而预测出T个新的样本(学生)的类型。

本章的重点是:梯度提升决策树方法的理解和使用。


14.4 编程实战

  步骤1 引入NumPy库,命名为np;引入pandas库,命名为pd;引入os.path.abspath()模块,命名为plt,用于绘制图像;引入os模块,用于处理文件和目录。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

步骤2 将当前文件所在目录的路径设置为Python的当前工作目录。os.path.abspath()用于将相对路径转化为绝对路径;os.chdir()用于改变当前工作目录;os.getcwd()用于获取当前工作目录(不同设备的绝对路径不同)。

thisFilePath = os.path.abspath('.')  # 获取当前文件的绝对路径
os.chdir(thisFilePath)  # 改变当前工作目录
os.getcwd()   # 获取当前工作目录



输出结果:
'D:\\MyPythonFiles'

步骤3 导入并读取数据。pd.read_csv()函数用于读取数据,usecols参数用于读取文件中指定的数据列。type()函数用于返回对象的类型。head()函数用于查看前几行数据,tail()函数用于查看后几行数据。

myData = pd.read_csv('DataForClassify(1).csv', usecols=['YingYu', 'ShuXue', 'Label'])
type(myData)
myData.head()

输出结果:

myData.tail()

输出结果:

步骤4 划分训练集和测试集。从sklearn包的model_selection模块中引入train_test_split函数,函数中的第一个参数表示所要划分的样本特征集;第二个参数表示所要划分的样本标签;第三个参数表示测试集占样本数据的比例,若为整数,则表示绝对数量;第四个参数表示随机数种子。

from sklearn.model_selection import train_test_split
trainSet_x, testSet_x, trainSet_y, testSet_y = train_test_split(
    myData.iloc[:, 0:2],
    myData.iloc[:, 2],
    test_size = 0.2,
    random_state=220
)
trainSet_x.shape




输出结果:
(560,2)

步骤5 拟合并验证梯度提升决策树(GBDT)模型。从sklearn包中引入ensemble,并构建模型。

从sklearn包的model_selection模块中引入cross_val_score函数,运用 K 折交叉验证对模型的稳定性进行验证。其中,第一个参数表示模型的名称,第二个参数表示样本特征集,第三个参数表示样本标签,第四个参数表示进行几个交叉验证。K 折交叉验证指把初始训练样本分成 K份,其中 K-1份被用作训练集,剩下一份被用作评估集,进行K次训练后得到K个训练结果,通过结果对比来验证模型的稳定性。

from sklearn import ensemble
import datetime
start_time = datetime.datetime.now()   # 获取函数的开始时间
GBDT_model = ensemble.GradientBoostingClassifier(n_estimators=10, min_samples_split=50)
GBDT_model = GBDT_model.fit(trainSet_x, trainSet_y)  # 导入训练集的数据进行模型拟合
from sklearn.model_selection import cross_val_score
print("%s Score: %0.2f" % ("GBDT", GBDT_model.score(testSet_x, testSet_y)))
scores = cross_val_score(GBDT_model, testSet_x, testSet_y, cv=5)
print("%s Cross Avg. Score: %0.2f(+/-%0.2f)" % ("GBDT", scores.mean(), scores.std()*2))
end_time = datetime.datetime.now()  # 求得函数结束的时间
time_spend = end_time - start_time  # 求得函数运行的时间
print("%s Time:%0.2f" % ("GBDT", time_spend.total_seconds()))  # 输出模型的训练时间





输出结果:

GBDT Score: 1.00
GBDT Cross Avg. Score: 1.00(+/-0.00)
GBDT Time:333.33

步骤6 对模型的输出结果进行可视化。

from sklearn import tree

estimator = GBDT_model.estimators_[0,1]

dot_data = tree.export_graphviz(
    estimator,
    out_file = None,
    feature_names = trainSet_x.columns.values,
    filled = True,
    impurity = False,
    rounded = True
)
import pydotplus
graph = pydotplus.graph_from_dot_data(dot_data)
graph.get_nodes()[1].set_fillcolor("#FFF2DD")
from IPython.display import Image
Image(graph.create_png())

输出结果:

步骤7 对样本数据的分类结果进行可视化。

import matplotlib as mpl

def plot(train_x, train_y, test_x, test_y, model):
    x1_min, x1_max = train_x.iloc[:, 0].min(), train_x.iloc[:, 0].max()  # YingYu的最低分和最高分
    x2_min, x2_max = train_x.iloc[:, 1].min(), train_x.iloc[:, 1].max()  # ShuXue的最低分和最高分
    x1, x2 = np.mgrid[x1_min:x1_max:80j, x2_min:x2_max:80j]  # 生成网格采样点
    grid_test = np.stack((x1.flat, x2.flat), axis = 1)
    grid_test_df = pd.DataFrame(grid_test, columns=train_x.columns)
    grid_hat = model.predict(grid_test_df)
    grid_hat = grid_hat.reshape(x1.shape)
    color = ['g', 'r', 'b']
    color_dark = ['darkgreen', 'darkred', 'darkblue']
    markers = ["o", "D", "^"]
    plt.pcolormesh(x1, x2, grid_hat)
    train_x_arr = np.array(train_x)
    test_x_arr  = np.array(test_x)
    for i, marker in enumerate(markers):
        plt.scatter(train_x_arr[train_y==i+1][:,0],train_x_arr[train_y==i+1][:,1],c=color[i],edgecolors='black',s=20,marker=marker)
        plt.scatter(test_x_arr[test_y==i+1][:,0],test_x_arr[test_y==i+1][:,1],c=color_dark[i],edgecolors='black',s=40,marker=marker)
        
    plt.xlabel('English', fontsize=13)
    plt.ylabel('Math', fontsize=13)
    plt.xlim(x1_min, x1_max)
    plt.ylim(x2_min, x2_max)
    plt.title('Students Grade', fontsize=15)
    plt.show()
plot(trainSet_x, trainSet_y, testSet_x, testSet_y, GBDT_model)

 输出结果:

14.5 本章总结

本章实现的工作是:首先采用Python语言读取含有英语成绩、数学成绩以及学生所属类型的样本数据。然后将样本数据划分为训练集和测试集,接着采用GBDT算法,对训练集数据进行拟合,最后在输入更多学生的数学成绩和英语成绩后,使用已求解的最优模型去预测其分类结果。

本章掌握的技能是:1、使用NumPy包读取连续的样本数据。2、使用sklearn库model_selection模块中的model_selection函数实现训练集和测试集的划分。3、使用Matplotlib实现数据的可视化,绘制树状图。

14.6 本章作业

1、实现本章的案例,即生成样本数据,实现梯度提升决策树模型的建模、参数调整、预测和数据可视化。

2、利用 Iris(鸢尾花)原始数据集,运用GBDT算法,实现根据鸢尾花的任意两个特征对其进行分类。

本章已完结!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2336663.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【leetcode hot 100 136】只出现一次的数字

解法一:(异或XOR)相同的数字出现两次则归零 class Solution {public int singleNumber(int[] nums) {int result 0;for(int num:nums){result ^ num;}return result;} }注意: 其他方法:HashList记录次数再查找数组&a…

QEMU学习之路(8)— ARM32通过u-boot 启动Linux

QEMU学习之路(8)— ARM32通过u-boot 启动Linux 一、前言 参考文章: Linux内核学习——内核的编译和启动 Linux 内核的编译和模拟执行 Linux内核运行——根文件系统 Linux 内核学习——使用 uboot 加载内核 二、构建Linux内核 1、获取Linu…

AgentOps - 帮助开发者构建、评估和监控 AI Agent

文章目录 一、关于 AgentOps二、关键集成 🔌三、快速开始 ⌨️2行代码中的Session replays 首类开发者体验 四、集成 🦾OpenAI Agents SDK 🖇️CrewAI 🛶AG2 🤖Camel AI 🐪Langchain 🦜&#x1…

leetcode 122. Best Time to Buy and Sell Stock II

题目描述 这道题可以用贪心思想解决。 本文介绍用动态规划解决。本题分析方法与第121题一样,详见leetcode 121. Best Time to Buy and Sell Stock 只有一点区别。第121题全程只能买入1次,因此如果第i天买入股票,买之前的金额肯定是初始金额…

【ROS】代价地图

【ROS】代价地图 前言代价地图(Costmap)概述代价地图的参数costmap_common_params.yaml 参数说明costmap_common_params.yaml 示例说明global_costmap.yaml 参数说明global_costmap.yaml 示例说明local_costmap.yaml 参数说明local_costmap.yaml 示例说明…

《Against The Achilles’ Heel: A Survey on Red Teaming for Generative Models》全文阅读

《Against The Achilles’ Heel: A Survey on Red Teaming for Generative Models》 突破阿基里斯之踵:生成模型红队对抗综述 摘要 生成模型正迅速流行并被整合到日常应用中,随着各种漏洞暴露,其安全使用引发担忧。鉴于此,红队…

datagrip连接mysql问题5.7.26

1.Case sensitivity: plainmixed, delimitedexac Remote host terminated the handshake. 区分大小写:plain混合,分隔exac 远程主机终止了握手。 原因:usessl 参数用于指定是否使用 SSL(Secure Sockets Layer)加密来保护数据传…

探索亮数据Web Unlocker API:让谷歌学术网页科研数据 “触手可及”

本文目录 一、引言二、Web Unlocker API 功能亮点三、Web Unlocker API 实战1.配置网页解锁器2.定位相关数据3.编写代码 四、Web Scraper API技术亮点 五、SERP API技术亮点 六、总结 一、引言 网页数据宛如一座蕴藏着无限价值的宝库,无论是企业洞察市场动态、制定…

【本地MinIO图床远程访问】Cpolar TCP隧道+PicGo插件,让MinIO图床一键触达

写在前面:本博客仅作记录学习之用,部分图片来自网络,如需引用请注明出处,同时如有侵犯您的权益,请联系删除! 文章目录 前言MinIO本地安装与配置cpolar 内网穿透PicGo 安装MinIO远程访问总结互动致谢参考目录…

Policy Gradient思想、REINFORCE算法,以及贪吃蛇小游戏(一)

文章目录 Policy Gradient思想论文REINFORCE算法论文Policy Gradient思想和REINFORCE算法的关系用一句人话解释什么是REINFORCE算法策略这个东西实在是太抽象了,它可以是一个什么我们能实际感受到的东西?你说的这个我理解了,但这个东西,我怎么优化?在一堆函数中,找到最优…

Profibus DP主站转modbusTCP网关与dp从站通讯案例

Profibus DP主站转modbusTCP网关与dp从站通讯案例 在当前工业自动化的浪潮中,不同协议之间的通讯转换成为了提升生产效率和实现设备互联的关键。Profibus DP作为一种广泛应用的现场总线技术,与Modbus TCP的结合,为工业自动化系统的集成带来了…

快速部署大模型 Openwebui + Ollama + deepSeek-R1模型

背景 本文主要快速部署一个带有web可交互界面的大模型的应用,主要用于开发测试节点,其中涉及到的三个组件为 open-webui Ollama deepSeek开放平台 首先 Ollama 是一个开源的本地化大模型部署工具,提供与OpenAI兼容的Api接口,可以快速的运…

H.265硬件视频编码器xk265代码阅读 - 帧内预测

源代码地址: https://github.com/openasic-org/xk265 帧内预测具体逻辑包含在代码xk265\rtl\rec\rec_intra\intra_pred.v 文件中。 module intra_pred() 看起来是每次计算某个4x4块的预测像素值。 以下代码用来算每个pred_angle的具体数值,每个mode_i对应…

Arcgis经纬线标注设置(英文、刻度显示)

在arcgis软件中绘制地图边框,添加经纬度度时常常面临经纬度出现中文,如下图所示: 解决方法,设置一下Arcgis的语言 点击高级--确认 这样Arcgis就转为英文版了,此时在来看经纬线刻度的标注,自动变成英文

Windows安装Ollama并指定安装路径(默认C盘)

手打不易,如果转摘,请注明出处! 注明原文:http://blog.csdn.net/q258523454/article/details/147289192 一、下载Ollama 访问Ollama官网 打开浏览器,访问Ollama的官方网站:https://ollama.ai/。 在官网首页…

Python自动化处理奖金分摊:基于连续空值的智能分配算法升级

Python自动化处理奖金分摊:基于连续空值的智能分配算法升级 原创 IT小本本 IT小本本 2025年04月04日 02:00 北京 引言 在企业薪酬管理中,团队奖金分配常涉及复杂的分摊规则。传统手工分摊不仅效率低下,还容易因人为疏漏导致分配不公。 本文…

AI工具箱源码+成品网站源码+springboot+vue

大家好,今天给大家分享一个靠AI广告赚钱的项目:AI工具箱成品网站源码,源码支持二开,但不允许转售!! 本人专门为小型企业和个人提供的解决方案。 不懂技术的也可以直接部署工具箱网站,成为站长&…

如何下载免费地图数据?

按照以下步骤下载免费地图数据。 1、安装GIS地图下载器 从GeoSaaS(.COM)官网下载“GIS地图下载器”软件:,安装完成后桌面上出现”GIS地图下载器“图标。 双击桌面图标打开”GIS地图下载器“ 2、下载地图数据 点击主界面底部的“…

IO 口作为外部中断输入

外部中断 1. NVIC2. EXTI 1. NVIC NVIC即嵌套向量中断控制器,它是内核的器件,M3/M4/M7 内核都是支持 256 个中断,其中包含了 16 个系统中断和 240 个外部中断,并且具有 256 级的可编程中断设置。然而芯片厂商一般不会把内核的这些…

《MySQL基础:了解MySQL周边概念》

1.登录选项的认识 -h:指明登录部署了mysql服务的主机,默认为127.0.0.1-P:指明要访问的端口号,默认为3306-u:指明登录用户-p:指明登录密码 2.什么是数据库 2.1认识数据库 第一点理解。 mysql是数据库的客户…