深度神经网络——决策树的实现与剪枝

news2024/12/23 13:00:07

概述

决策树 是一种有用的机器学习算法,用于回归和分类任务。 “决策树”这个名字来源于这样一个事实:算法不断地将数据集划分为越来越小的部分,直到数据被划分为单个实例,然后对实例进行分类。如果您要可视化算法的结果,类别的划分方式将类似于一棵树和许多叶子。

这是决策树的快速定义,但让我们深入了解决策树的工作原理。 更好地了解决策树的运作方式及其用例,将帮助您了解何时在机器学习项目中使用它们。

决策树的结构

决策树的结构类似于流程图,从一个起点或根节点开始,根据过滤条件的判断结果,逐级分支,直至达到树的末端,即叶子节点。每个内部节点代表一个特征的测试条件,而叶子节点则代表数据点的分类标签。
在这里插入图片描述
决策树是一种层次化的决策模型,它通过一系列的问题将数据分类。以下是决策树结构的关键组成部分和特性:

  1. 根节点(Root Node)

    • 决策树的起点,代表整个数据集。
  2. 内部节点(Internal Nodes)

    • 表示决策问题或属性测试。每个内部节点对应一个特征(或属性)的分割点。
  3. 分支(Branches)

    • 从每个内部节点延伸出来,代表测试的不同结果。分支的数量取决于该节点特征的可能值。
  4. 叶子节点(Leaf Nodes)

    • 树的末端,代表最终决策或分类结果。在分类问题中,叶子节点通常包含类别标签;在回归问题中,它们包含预测值。
  5. 路径(Path)

    • 从根节点到任一叶子节点的连接序列,代表一系列决策规则。
  6. 分割(Split)

    • 在内部节点处,根据特征值将数据集分割成子集的过程。
  7. 特征(Feature)

    • 用于分割数据的特征或属性。
  8. 阈值(Threshold)

    • 用于确定数据点是否沿着特定分支的值。
  9. 纯度(Purity)

    • 衡量节点中数据点是否属于同一类别的指标。高纯度意味着节点中的数据点属于同一类别。
  10. 深度(Depth)

    • 从根节点到树中任意节点的最长路径长度。
  11. 宽度(Width)

    • 树中叶子节点的最大数量。
  12. 树高(Tree Height)

    • 从根节点到最远叶子节点的边数。
  13. 基尼指数(Gini Index)

    • 用于分类树的内部节点评估,衡量节点不纯度的指标。
  14. 熵(Entropy)

    • 另一种衡量节点不纯度的指标,常用于构建分类树。
  15. 信息增益(Information Gain)

    • 通过分割获得的信息量,用于选择最佳分割点。
  16. 决策规则(Decision Rules)

    • 从根到叶的路径上的一系列决策,用于对数据点进行分类。

决策树的结构使得模型不仅能够进行预测,还能够解释预测背后的逻辑。这种可解释性使得决策树在需要模型透明度的应用中非常有用。然而,决策树也容易过拟合,特别是当树变得非常深和复杂时。因此,剪枝技术通常用于简化决策树,提高其泛化能力。

决策树算法

决策树的构建过程采用递归二元分割算法,该算法通过评估不同特征对数据集进行分割的效果,选择最佳分割点。分割的目的是使得每个子集尽可能地“纯”,即包含的数据点属于同一类别或具有相似的响应值。

分割成本的确定

决策树是一种常用用于分类和回归任务。在回归问题中,决策树的目标是预测一个连续的输出值。如果你使用决策树进行回归预测,并希望计算预测误差,你可以使用均方误差(Mean Squared Error, MSE)作为评估指标。MSE 衡量的是模型预测值与实际值之间差异的平方的平均值。

对于决策树来说,计算 MSE 的过程如下:

  1. 使用决策树模型进行预测:给定一个训练好的决策树模型,对于每个数据点,使用模型进行预测,得到预测值 prediction_i

  2. 计算误差:对于每个数据点,计算其实际值 y_i 与预测值 prediction_i 之间的差异,然后计算这个差异的平方。

  3. 求和:将所有数据点的误差平方求和。

  4. 平均:将求和结果除以数据点的总数 n,得到 MSE。

数学公式表示为:

M S E = 1 n ∑ i = 1 n ( y i − prediction i ) 2 {MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \text{prediction}_i)^2 MSE=n1i=1n(yipredictioni)2

其中:

  • n n n 是数据集中的样本数量。
  • y i y_i yi是第i` 个样本的实际值。
  • p r e d i c t i o n i {prediction}_i predictioni 是模型对第 i 个样本的预测值。

在 Python 中,如果使用 scikit-learn 库,可以很容易地计算决策树模型的 MSE。以下是一个简单的例子:

from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np

# 假设 X 是特征数据,y 是目标变量
X = ...  # 特征数据
y = ...  # 目标变量

# 创建决策树回归模型
tree_reg = DecisionTreeRegressor()

# 训练模型
tree_reg.fit(X, y)

# 进行预测
y_pred = tree_reg.predict(X)

# 计算 MSE
mse = mean_squared_error(y, y_pred)
print(f"Mean Squared Error: {mse}")

MSE 仅适用于回归问题。如果你在处理分类问题,可能需要考虑其他指标,如准确率、召回率、F1 分数等。此外,MSE 对异常值敏感,因此在某些情况下,你可能还想使用其他指标,如平均绝对误差(Mean Absolute Error, MAE)来评估模型性能。

决策树的剪枝

决策树的剪枝是防止模型过拟合的重要技术。过拟合的决策树可能会在训练数据上表现良好,但在未见过的数据上泛化能力差。剪枝通过移除树中的一些分支来简化模型,从而提高其在新数据上的预测性能。以下是几种常见的决策树剪枝方法:

  1. 预剪枝(Pre-pruning)

    • 在构建决策树的过程中,预剪枝会在树生长的每个阶段评估是否应该停止分裂。如果某个节点的分裂不能显著提高模型的性能,那么这个节点将被标记为叶子节点,不再进一步分裂。
  2. 后剪枝(Post-pruning)

    • 后剪枝是在决策树完全生长完成后进行的。它从树的叶子节点开始,评估移除节点对模型性能的影响。如果移除某个节点后的模型性能没有显著下降,那么这个节点将被删除。
  3. 错误率降低剪枝(Reduced-Error Pruning)

    • 这种方法是在后剪枝的基础上,通过比较剪枝前后的错误率来决定是否剪枝。如果剪枝后的模型在交叉验证集上的错误率没有增加,或者增加的幅度在可接受范围内,那么剪枝是成功的。
  4. 代价复杂性剪枝(Cost-Complexity Pruning)

    • 代价复杂性剪枝是一种后剪枝技术,它通过引入一个参数来平衡模型的复杂度和预测误差。这种方法允许模型在剪枝过程中保持一定程度的复杂性,同时减少过拟合的风险。
  5. 最小描述长度剪枝(Minimum Description Length Pruning)

    • 这种方法基于信息论原理,试图找到能够最小化描述模型和数据所需的信息量(即描述长度)的树。它考虑了模型的复杂性和预测误差,以找到最佳的剪枝点。
  6. 基于规则的剪枝

    • 在某些情况下,可以使用领域知识来定义规则,以指导剪枝过程。例如,如果某个特征在数据集中的分布非常不均匀,可以考虑剪枝掉依赖于该特征的分支。

使用决策树的注意事项

决策树在需要快速分类且计算时间受限的场景下非常有用。它们能够清晰地展示数据集中哪些特征最具预测力,并且与许多其他机器学习算法相比,决策树的规则更易于解释。此外,决策树能够处理分类变量和连续变量,减少了预处理的需求。

然而,决策树在预测连续属性值时可能表现不佳,且在类别众多而训练样本较少的情况下,分类准确性可能降低。

通过深入理解决策树的工作原理和特性,我们可以更好地判断在机器学习项目中何时使用它们,以及如何优化它们的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1871549.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL注入和防御方法

SQL注入是一种攻击手段,通过在SQL查询中插入恶意SQL代码片段,欺骗数据库服务器执行非授权的数据库操作。这种攻击可能导致数据泄露、篡改或丢失。为了防范SQL注入,可以采取以下几种策略: 1.使用预编译语句(Prepared St…

OBD诊断(ISO15031) 01服务

文章目录 功能简介PID 的功能请求和响应1、read-supported PIDs1.1、请求1.2、肯定响应 2、read PID value1.1、请求1.2、肯定响应 3、同时请求多个PID3、同时读取多个PID数据 Parameter definition报文示例1、单个PID请求和读取2、多个PID请求和读取 功能简介 01服务&#xf…

Linux双网卡默认路由的metric设置不正确,导致SSH连接失败问题定位

测试环境 VMware虚拟机 RockyLinux 9 x86_64 双网卡:eth0(访问外网): 10.206.216.92/24; eth1(访问内网) 192.168.1.4/24 问题描述 虚拟机重启后,SSH连接失败,提示"Connection time out",重启之前SSH连接还是正常的…

2.Android逆向协议-了解常用的逆向工具

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于:微尘网校 上一个内容:1.Android逆向协议-环境搭建 常用的工具:AndroidKiller、jadx、JEB、IDA AndroidKiller…

华为云安全防护,九河云综合分解优劣势分析

随着全球化的发展,越来越多的企业开始寻求在国际市场上扩展业务,这一趋势被称为企业出海。然而,企业在海外扩张面临诸多隐患与安全挑战,其中因为地域的不同,在安全性方面与国内相比会变得薄弱,从而导致被黑…

Redis集群(Clustering in Redis)工作机制详解

Redis集群工作机制详解 Redis 集群是用于提高 Redis 可扩展性和高可用性的解决方案。 维基百科:Scalability is the property of a system to handle a growing amount of work by adding resources to the system. 可扩展性是系统的一种允许通过增加系统资源来处…

《征服数据结构》字典树(Trie树)

摘要: 1,字典树的介绍 2,字典树的插入 3,字典树的查询 4,字典树排序 5,字典树的删除 6,字典树的用途 1,字典树的介绍 字典树又称 Trie 树 ,单词查找树,前缀树…

花卉寄售系统

摘 要 随着互联网的快速发展和普及,电子商务已经成为人们日常生活中不可或缺的一部分。在电子商务领域,花卉行业也逐渐崭露头角,成为一个具有巨大潜力的市场。传统的花卉销售模式通常是通过实体店面进行销售,这种模式存在着许多问…

Python | Leetcode Python题解之第202题快乐数

题目: 题解: def isHappy(self, n: int) -> bool:cycle_members {4, 16, 37, 58, 89, 145, 42, 20}def get_next(number):total_sum 0while number > 0:number, digit divmod(number, 10)total_sum digit ** 2return total_sumwhile n ! 1 an…

RISC-V知识总结 —— 向量(扩展)指令集

资源1:晏明 - RISC-V向量扩展指令架构及LLVM自动向量化支持 - 202112118 - 第13届开源开发工具大会(OSDTConf2021)_哔哩哔哩_bilibili资源2:张先轶 - 基于RISC-V向量指令集优化基础计算软件生态【第12届开源开发工具大会(OSDT2020&#xff09…

AI加持,商业智能与分析软件市场释放更大潜能

根据IDC最新发布的《中国商业智能和分析软件市场跟踪报告,2023H2》显示,2023下半年,中国商业智能与分析软件市场规模为5.2亿美元,同比增长为3.7%。其中,本地部署收入占比为89.3%,同比增长1.7%;公…

密码学及其应用 —— 对称加密技术

1. 对称加密、流加密和块加密 1.1 对称加密 对称加密(也称为密钥加密)是一种加密方式,其中加密和解密使用相同的密钥。这种加密方法基于二进制层面的操作,如XOR(异或)、SHIFT(位移)…

Linux 搭建 kafka 流程

优质博文:IT-BLOG-CN 一、安装环境 【1】CenOS7虚拟机三台 【2】已经搭建好的zookeeper集群。 【3】软件版本:kafka_2.11-1.0.0 二、创建目录并下载安装软件 【1】创建目录 cd /opt mkdir kafka #创建项目目录 cd kafka mkdir kafkalogs #创建kafk…

Transformers 安装及 google-t5/t5-small 机器翻译示例

文章目录 Github文档推荐文章简介安装官方示例google-t5/t5-small使用脚本进行训练Pytorch 机器翻译数据集下载数据集格式转换 Github https://github.com/huggingface/transformers 文档 https://huggingface.co/docs/transformers/indexhttps://github.com/huggingface/tr…

《昇思25天学习打卡营第1天|基本介绍》

文章目录 前言:今日所学: 前言: 今天非常荣幸的收到了昇思25天学习打卡营的邀请。昇思MindSpore作为华为昇腾AI全栈的重要一员,他支持端、边、云独立的和协同的统一训练和推理框架,有着易于开发、执行效率高、全场景框…

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTraceAnalysis分析性能瓶颈 1.参考链接:2.性能对比3.相关依赖或命令4.测试代码5.HolisticTraceAnalysis代码6.可视化A.优化前B.优化后 以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTra…

深入剖析 Android 网络开源库 Retrofit 的源码详解

文章目录 概述一、Retrofit 简介Android主流网络请求库 二、Retrofit 源码剖析1. Retrofit 网络请求过程2. Retrofit 实例构建2.1 Retrofit.java2.2 Retrofit.Builder()2.2.1 Platform.get()2.2.2 Android 平台 2.3 Retrofit.Builder().baseUrl()2.4 Retrofit.Builder.client()…

Windows的内核对象

内核对象句柄特定于进程。 也就是说,进程必须创建 对象或打开现有对象以获取内核对象句柄。 内核句柄上的每个进程限制为 2^24。 但是,句柄存储在分页池中,因此可以创建的实际句柄数取决于可用内存。 可以在 32 位 Windows 上创建的句柄数明显低于 2^24。 任何进程都可以为…

Golang | Leetcode Golang题解之第201题数字范围按位与

题目&#xff1a; 题解&#xff1a; func rangeBitwiseAnd(m int, n int) int {for m < n {n & (n - 1)}return n }