[Machine Learning] decision tree 决策树

news2025/1/20 16:28:56

(为了节约时间,后面关于机器学习和有关内容哦就是用中文进行书写了,如果有需要的话,我在目前手头项目交工以后,用英文重写一遍)

(祝,本文同时用于比赛学习笔记和机器学习基础课程)

俺前两天参加了一个ai类的比赛,其中用到了一种名为baseline的模型来进行一些数据的识别。而这个识别的底层原理就是决策树。正好原本的学习进度刚刚完成这部分,所以集成一个笔记了,本文中所有的截图绝大多数来自吴恩达老师的公开课程,为了方便理解,把相关的图片搬过来了)

决策树是什么

决策树是一种机器学习算法,在一个类似二叉树的结构上实现的分支判断算法。每个节点都视为一个“判断语句”,将一批数据划分成不同的部分。节点上(除了叶子)都要判断“是”/“否”。

 一个具体化以后的模型差不多长这样子:给出一堆宠物的数据,根据不同的特征(耳朵,脸型什么的),我们判断输入案例是狗还是猫猫。

如果还是不好理解,那么想象一下我们平时在写代码时候大量if else嵌套,展开以后也是一模一样的结构。去别在于可能if构成的判断树的后代可能多于决策树,决策树只能是二叉树,输出“是”“不是”这种问题,当面对多个离散的特征值的时候,我们还有别的技术可以使用.

简而言之,决策树是一种区别于神经网络的另一种判断算法,在一些数据的处理上可能比神经网络更快更有效,由于其结构类似二叉树,所以称之为决策树(decision tree).决策树的生成是要根据已经给出的数据案例创建的,数据有多少特征用于区分,就会有多少个节点进行分裂(split).

具体的训练过程和训练中遇到的问题会在下面解释

在训练之前要接触的一些名词

纯净(purity)/杂质(impurity):纯度和不纯是根据某个节点来说的,例如我们输入一堆宠物的数据(包括耳朵形状,毛发长度,脸型这些特征),在判断某个属性的节点上,我们会根据"符合"/"不符合"把已有的数据划分为两拨.比如这样子

 原型的部分中,有四个是猫猫,三个是狗子.对于这个节点来说,我们可以认为这个节点的纯度是(4/7)

同理,另一个节点的纯度视为(1/3)

(纯度是一个相对的概念,如果你判断的是狗子,那么纯度就要变了)

:这个熵不是化学中的概念,而是代表混乱程度,当纯度和为0.5的时候,代表两种东西对半开,也就是最混乱的情况.根据纯度,我们有相关的公式可以计算出纯度对应熵的大小(假设纯度为p)

H(p)=-p\log _{2}(p)-(1-p)\log _{2}(1-p)

整个函数的图像大概就是这样子

信息增益:信息增益也是根据某一个点来说的,这个数值是训练时候的重要依据,信息增益越大,代表整个节点进行的划分越有效,信息增益的计算方式为

Information\: gain=H(0.5)-W_ {left}H(p_{left})-W_ {right}H(p_{right})

0.5对应的熵,减去左侧的熵和右侧的熵的加权平均和即可.比如上面的图,我们可以计算为

H(0.5)-(\frac{7}{10}p_{left}+\frac{3}{10}p_{right})

决策树如何进行训练

决策树底层的训练原理其实很简单,首先我们需要给定一个数据集合,这个数据集合中的每个事物都有一些共同的特征,类似这样,通常我们可以把有效的特征组合起来形成一个表格.

 前面的特征为输入,而cat一列作为输出,决定这个宠物到底是不是猫,由此构成一系列符合监督学习要求的训练数据集合.

然后会从这些信息中,选择分裂时产生更小熵的特征,算法会基于某种标准(例如信息增益、基尼不纯度等)来评估每个可能的划分,并选择最优的划分特征。这些标准用于衡量数据的不纯度和分割后的纯度。这里我们使用上面讲到的信息增益来判断这个划分成都

 由此可见,以耳朵形状作为划分所产生的分裂节点,信息增益更大,纯度也更好.

接下来再根据其他的特征进行划分即可,当遇到以下几种情况的时候,我们可以认为这个节点不用再继续分裂了

  • 树的高度达到某些限制
  • 纯度已经是100%
  • 数据全部低于阈值
  • ........

 两个特殊情况

(1)分裂时候的数据不是二元的离散数值,而是一个连续的情况

这个很简单,设置一个阈值,比如0.5,0,7,....反正到最后还是二元的

(2)分裂的时候,可能数据是多元的离散数值,比如毛发可能是长发,短发,卷发这三种.我们总不能搞出三叉树来,所以这里我们把"是什么"转变为"是不是"的问题.比如这样一个特征,我们可以划分为"是不是长发,是不是短发,是不是卷毛"三个二元的特征

随机森林算法

给定一个数据集合,我们可以计算出一个决策树来进行一些判断,给定一个动物,决策树最红会给出我们这个是不是猫猫的答案.但是这有两个问题,节点不一定是纯净的(虽然大多数情况下,只要不超过我们的限定高度,是可以把一个决策树修炼到高度纯净的),造成判断结果不一定准确.

另一个问题就是,一些数据发生扰动以后,可能会影响决策树这个依托信息增益产生的精密系统.

最简单粗暴的方法就是,训练多个树,形成一个森林.但是一个数据集合练出来的树是一样的,没啥必要,所以我们产生了随机森林算法.

sampling with replacement(放回抽样)这东西我们在高中就学过,所以这里不加简述了.我们要做的就是确定一个规模,比如10,每次从原始数据集中抽取10个案例,然后用来训练一棵树.

如此循环多次,我们就能得到多个决策树,组成一个森林,这其中难免会有一些决策树是一样的,我们忽视掉它

这样我们计算结果的时候,要考虑到整个森林所有树木的输出效果,然后综合考虑我们怎样确定输出效果 

XGBoost算法和使用

在众多随机森林算法中,XGBoost是一种使用很广泛的随机森林算法,并且XGBoost也是一个开源库(不是放在tf或者pytorch的库中的).XGBoost非常像我们之前聊过的增强算法(啥,哦博客还没写出来,8好意思,尽快补上)

XGBoost算法和普通决策树的区别在于放回抽样的不疯魔,传统的决策树是平等地抽取,xgb算法则是会根据上一次,估计错了哪些数值,在本次抽取中优先提取上一次参与训练并且估计失败的数值案例.

比如

 构建某一次决策树的时候,2,6,8号数据估计错误,则下一次会优先提取出这些作为训练案例之一.

当然这些主要是底层实现了(注意对应的函数从xgboost包中导入,这个包需要提前下载)

下面来看一下具体的使用案例.

pip3 install xgboost
#xgboost算法 这里没有使用训练集合什么de
# 定义特征矩阵和标签
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([0, 0, 1, 1])

# 创建并训练模型
model = XGBClassifier()
model.fit(X, y)

# 预测一个数据
data_to_predict = np.array([[2, 3]])
prediction = model.predict(data_to_predict)

print(f"预测结果: {prediction}")
 
#xgboost算法 这里没有使用训练集合什么de
# 定义特征矩阵和标签
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([0, 0, 1, 1])

# 创建并训练模型
model = XGBClassifier()
model.fit(X, y)

# 预测一个数据
data_to_predict = np.array([[2, 3]])
prediction = model.predict(data_to_predict)

print(f"预测结果: {prediction}")

和神经网络有什么区别捏?

相比于神经网络来说,决策树和随机森林算法更适合一些有固定相似数据结构的数据集合.换句话说,更容易处理那种可以形成表格的数据.

而神经网络则用来处理一些非相似结构的数据,这一点就是他们的主要区别

决策树同样是一种很重要的监督学习算法.

关于baseline(未完待续)

baseline是一种基于决策树的大模型,适用于多重二元分析等操作,在竞赛和论文中应用很广泛.

(至少与我们之前用到tensorflow要广泛.....tf都快开摆了)

不过这个模型我现在也不是很熟悉,仅仅是停留在"用过"这个层面上,后面有机会我会继续在这里补充这个模型的使用和优缺点,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/895664.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【学习FreeRTOS】第12章——FreeRTOS时间管理

1.FreeRTOS系统时钟节拍 FreeRTOS的系统时钟节拍计数器是全局变量xTickCount,一般来源于系统的SysTick。在STM32F1中,SysTick的时钟源是72MHz/89MHz,如下代码,RELOAD 9MHz/1000-1 8999,所以时钟节拍是1ms。 portNV…

事物有哪些特性 ?MySQL 如何保证事物的四大特性 ?

目录 1.事物有哪些特性 2. MySQL 如何保证事物的四大特性 3. 事物的隔离级别 1.事物有哪些特性 1.1 何为事物 ? 事物就是把一件事情的多个步骤,多个操作,打包成一个步骤,一个操作。其中任意一个步骤执行失败,都会进…

隧道广播平面波扬声器的应用

隧道广播平面波扬声器是一款高清晰定向扬声器,采用稀土永磁磁性材料与声波相控阵技术,有效的解决了声音定向问题。是远距离定向声波发射装置是一种革命性的技术,它具有大功、率高清晰、远距离传声特点,可以将声音信息清晰地传输到…

【数据结构】 链表简介与单链表的实现

文章目录 ArrayList的缺陷链表链表的概念及结构链表的分类单向或者双向带头或者不带头循环或者非循环 单链表的实现创建单链表遍历链表得到单链表的长度查找是否包含关键字头插法尾插法任意位置插入删除第一次出现关键字为key的节点删除所有值为key的节点回收链表 总结 ArrayLi…

Aurora 8B/10B

目录 1. Overview2. Feature List2. Block Diagram3. Ports Description3.1. User InterfaceFraming InterfaceStreaming InterfaceUser Flow Control(UFC)Native Flow Control(NFC) 3.2. Status and Control Ports3.3. Transceiv…

基于python+django+mysql的校园影院售票系统(可做计算机毕设)

开发柚子校园影院,不仅可以改善用户查看信息难的局面,还可以提高管理效率,同时也可以增强系统的竞争力。利用柚子校园影院的可以有效地提高系统的人事的效率和信息化水平,快速了解信息更新及服务的进度。这既可以确保系统服务的品…

RuoYi 云服务器部署系统

一.为什么要部署 关于RuoYi-Vue是一个前后端分离的Web后台管理系统。部署在云服务器上让所有人都可以访问这是Web网站很正常的一个需求,只要我们将前端静态文件暴露在公网中,自然就部署好了。当然,要求是前端的静态资源可以访问到后端的接口,网站才会正常运行。 二.云服务器…

.netcore windows app启动webserver

创建controller: using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Logging; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Text.Json.Serialization; using System.Threading.Tasks;namespace MyWorker.…

【AI】《动手学-深度学习-PyTorch版》笔记(十九):卷积神经网络模型(GoogLeNet、ResNet、DenseNet)

AI学习目录汇总 1、GoogLeNet 1.1 介绍 发布时间:2014年 GoogLeNet的贡献是如何选择合适大小的卷积核,并将不同大小的卷积核组合使用。 之前介绍的网络结构都是串行的,GoogLeNet使用并行的网络块,称为“Inception块” “Inception块”前后进化了四次,论文链接: [1]ht…

x.view(a,b)及x = x.view(x.size(0), -1) 的理解说明

x.view()就是对tensor进行reshape: 我们在创建一个网络的时候,会在Foward函数内看到view的使用。 首先这里是一个简单的网络,有卷积和全连接组成。它的foward函数如下: class NET(nn.Module):def __init__(self,batch_size):sup…

大数据之几分钟处理完30亿个数据

写在前面 假定现在我们有一个10G的文件,存储的是17~70岁的年龄,每个年龄使用,分割,现需要找出出现次数最多的年龄,以及其出现的次数。 源码 。 1:数据准备 我们首先来准备一个10G大小的存储年龄信息的数据文件&#…

(五)、深度学习框架源码编译

1、源码构建与预构建: 源码构建: 源码构建是通过获取软件的源代码,然后在本地编译生成可执行程序或库文件的过程。这种方法允许根据特定需求进行配置和优化,但可能需要较长的时间和较大的资源来编译源代码。 预构建: 预…

JSP-学习笔记

文章目录 1.JSP介绍2 JSP快速入门3 JSP 脚本3.1 JSP脚本案例3.2 JSP缺点 4 EL表达式4.1 快速入门案例 5. JSTL标签6. MVC模式和三层架构6.1 MVC6.2 三层架构 7. 案例-基于MVC和三层架构实现商品表的增删改查 1.JSP介绍 概念 JSP(JavaServer Pages)是一种…

JVM——引言+JVM内存结构

引言 什么是JVM 定义: Java VirtualMachine -java 程序的运行环境 (ava 二进制字节码的运行环境) 好处: 一次编写,到处运行自动内存管理,垃圾回收功能数组下标越界检查,多态 比较: jvm jre jdk 学习jvm的作用 面试理解底层实现原理中…

idea设置忽略大小写

1.点击file 2.点击settings 3.点击Editor选项 4.点击general选项 5.点击code completion 6.点击左上角match case

根据宿主机PID获取容器运行实例

当宿主机的容器化方式部署更多的时候按照之前linux查看进程的命令基本很难获取到想要的信息,只能看到ps后的结果,长时间后我都不知道哪里出现这么多nginx的进程,能确定是容器部署的,但是不知道那些容器出现了这么多进程 1.根据相…

Window下部署使用Stable Diffusion AI开源项目绘图

Window下部署使用Stable Diffusion AI开源项目绘图 前言前提条件相关介绍Stable Diffusion AI绘图下载项目环境要求环境下载运行项目打开网址,即可体验文字生成图像(txt2img)庐山瀑布 参考 本文里面的风景图,均由Stable Diffusion…

用户新增预测——baseline学习笔记

一、赛题理解 1. 赛题名称 用户新增预测挑战赛 2. 赛题数据集 赛题数据由约62万条训练集、20万条测试集数据组成,共包含13个字段。其中uuid为样本唯一标识,eid为访问行为ID,udmap为行为属性,其中的key1到key9表示不同的行为属性…

项目管理敏捷管理流程,高效敏捷项目管理解决方案

Leangoo领歌是一款永久免费的专业敏捷研发管理工具,提供敏捷研发解决方案,解决研发痛点,打造成功产品。帮助团队实现需求、迭代、缺陷、任务、测试、发布等全方位研发管理。 敏捷产品路线图管理: 产品路线图是一个高层次的战略计…

服务器数据库中了360后缀勒索病毒怎么办?360后缀勒索病毒的加密形式

随着信息技术的发展,企业的计算机服务器数据库变得越来越重要。然而,在数字时代,网络上的威胁也日益增多。近期,我们收到很多企业的求助,企业的计算机服务器遭到了360后缀勒索病毒的攻击,导致服务器内的所有…