【人工智能与机器学习】决策树ID3及其python实现

news2024/9/25 1:24:29

文章目录

  • 1 决策树算法
    • 1.1 特征选择
    • 1.2 熵(entropy)
    • 1.3 信息增益
  • 2 ID3算法的python实现
  • 总结

1 决策树算法

决策树(Decision Tree)是一类常见的机器学习方法,是一种非常常用的分类方法,它是一种监督学习。常见的决策树算法有ID3,C4.5、C5.0和CART(classification and regression tree),CART的分类效果一般要优于其他决策树。

决策树是基于树状结构来进行决策的,一般地,一棵决策树包含一个根节点、若干个内部节点和若干个叶节点。

每个内部节点表示一个属性上的判断
每个分支代表一个判断结果的输出
每个叶节点代表一种分类结果。
根节点包含样本全集
决策树学习的目的是为了产生一棵泛化能力强,即处理未见示例能力强的决策树,其基本流程遵循简单且直观的“分而治之”(divide-and-conquer)策略。

本文主要介绍ID3算法,ID3算法的核心是根据信息增益来选择进行划分的特征,然后递归地构建决策树。

1.1 特征选择

特征选择也即选择最优划分属性,从当前数据的特征中选择一个特征作为当前节点的划分标准。 随着划分过程不断进行,希望决策树的分支节点所包含的样本尽可能属于同一类别,即节点的“纯度”越来越高。

1.2 熵(entropy)

熵表示事务不确定性的程度,也就是信息量的大小(一般说信息量大,就是指这个时候背后的不确定因素太多),熵的公式如下:

其中, p(xi)是分类 xi 出现的概率,n是分类的数目。可以看出,熵的大小只和变量的概率分布有关。
对于在X的条件下Y的条件熵,是指在X的信息之后,Y这个变量的信息量(不确定性)的大小,计算公式如下:

例如,当只有A类和B类的时候,p(A)=p(B)=0.5,熵的大小为:

当只有A类或只有B类时,

所以当Entropy最大为1的时候,是分类效果最差的状态,当它最小为0的时候,是完全分类的状态。因为熵等于零是理想状态,一般实际情况下,熵介于0和1之间 。

熵的不断最小化,实际上就是提高分类正确率的过程。

1.3 信息增益

信息增益:在划分数据集之前之后信息发生的变化,计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

定义属性A对数据集D的信息增益为infoGain(D|A),它等于D本身的熵,减去 给定A的条件下D的条件熵,即:

信息增益的意义:引入属性A后,原来数据集D的不确定性减少了多少。

计算每个属性引入后的信息增益,选择给D带来的信息增益最大的属性,即为最优划分属性。一般,信息增益越大,则意味着使用属性A来进行划分所得到的的“纯度提升”越大。

2 ID3算法的python实现

以西瓜数据集为例
·watermalon.csv·文件内容如下:

读取文件数据

import numpy as np
import pandas as pd
import math
data = pd.read_csv('work/watermalon.csv')
data

计算熵

def info(x,y):
    if x != y and x != 0:
        # 计算当前情况的熵
        return -(x/y)*math.log2(x/y) - ((y-x)/y)*math.log2((y-x)/y)
    if x == y or x == 0:
        # 纯度最大,熵值为0
        return 0
info_D = info(8,17)
info_D

结果为:
0.9975025463691153

计算信息增益

# 计算每种情况的熵
seze_black_entropy = -(4/6)*math.log2(4/6)-(2/6)*math.log2(2/6)
seze_green_entropy = -(3/6)*math.log2(3/6)*2
seze_white_entropy = -(1/5)*math.log2(1/5)-(4/5)*math.log2(4/5)

# 计算色泽特征色信息熵
seze_entropy = (6/17)*seze_black_entropy+(6/17)*seze_green_entropy+(5/17)*seze_white_entropy
print(seze_entropy)
# 计算信息增益
info_D - seze_entropy

结果为:
0.10812516526536531

查看每种根蒂中好坏瓜情况的分布情况

data.根蒂.value_counts()
# 查看每种根蒂中好坏瓜情况的分布情况
print(data[data.根蒂=='蜷缩'])
print(data[data.根蒂=='稍蜷'])
print(data[data.根蒂=='硬挺'])
gendi_entropy = (8/17)*info(5,8)+(7/17)*info(3,7)+(2/17)*info(0,2)
gain_col = info_D - gendi_entropy
gain_col

根蒂的信息增益为:0.142674959566793

查看每种敲声中好坏瓜情况的分布情况

data.敲声.value_counts()
# 查看每种敲声中好坏瓜情况的分布情况
print(data[data.敲声=='浊响'])
print(data[data.敲声=='沉闷'])
print(data[data.敲声=='清脆'])
qiaosheng_entropy = (10/17)*info(6,10)+(5/17)*info(2,5)+(2/17)*info(0,2)
info_gain = info_D - qiaosheng_entropy
info_gain

查看每种纹理中好坏瓜情况的分布情况

data.纹理.value_counts()
# 查看每种纹理中好坏瓜情况的分布情况
print(data[data.纹理=="清晰"])
print(data[data.纹理=="稍糊"])
print(data[data.纹理=="模糊"])
wenli_entropy = (9/17)*info(7,9)+(5/17)*info(1,5)+(3/17)*info(0,3)
info_gain = info_D - wenli_entropy
info_gain

同理查看其他列的分布情况,这里不做演示

绘制可视化树

import matplotlib.pylab as plt
import matplotlib

# 能够显示中文
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']

# 分叉节点,也就是决策节点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")

# 叶子节点
leafNode = dict(boxstyle="round4", fc="0.8")

# 箭头样式
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    绘制一个节点
    :param nodeTxt: 描述该节点的文本信息
    :param centerPt: 文本的坐标
    :param parentPt: 点的坐标,这里也是指父节点的坐标
    :param nodeType: 节点类型,分为叶子节点和决策节点
    :return:
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """
    获取叶节点的数目
    :param myTree:
    :return:
    """
    # 统计叶子节点的总数
    numLeafs = 0

    # 得到当前第一个key,也就是根节点
    firstStr = list(myTree.keys())[0]

    # 得到第一个key对应的内容
    secondDict = myTree[firstStr]

    # 递归遍历叶子节点
    for key in secondDict.keys():
        # 如果key对应的是一个字典,就递归调用
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        # 不是的话,说明此时是一个叶子节点
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """
    得到数的深度层数
    :param myTree:
    :return:
    """
    # 用来保存最大层数
    maxDepth = 0

    # 得到根节点
    firstStr = list(myTree.keys())[0]

    # 得到key对应的内容
    secondDic = myTree[firstStr]

    # 遍历所有子节点
    for key in secondDic.keys():
        # 如果该节点是字典,就递归调用
        if type(secondDic[key]).__name__ == 'dict':
            # 子节点的深度加1
            thisDepth = 1 + getTreeDepth(secondDic[key])

        # 说明此时是叶子节点
        else:
            thisDepth = 1

        # 替换最大层数
        if thisDepth > maxDepth:
            maxDepth = thisDepth

    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """
    计算出父节点和子节点的中间位置,填充信息
    :param cntrPt: 子节点坐标
    :param parentPt: 父节点坐标
    :param txtString: 填充的文本信息
    :return:
    """
    # 计算x轴的中间位置
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    # 计算y轴的中间位置
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    # 进行绘制
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """
    绘制出树的所有节点,递归绘制
    :param myTree: 树
    :param parentPt: 父节点的坐标
    :param nodeTxt: 节点的文本信息
    :return:
    """
    # 计算叶子节点数
    numLeafs = getNumLeafs(myTree=myTree)

    # 计算树的深度
    depth = getTreeDepth(myTree=myTree)

    # 得到根节点的信息内容
    firstStr = list(myTree.keys())[0]

    # 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

    # 绘制该节点与父节点的联系
    plotMidText(cntrPt, parentPt, nodeTxt)

    # 绘制该节点
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    # 得到当前根节点对应的子树
    secondDict = myTree[firstStr]

    # 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD

    # 循环遍历所有的key
    for key in secondDict.keys():
        # 如果当前的key是字典的话,代表还有子树,则递归遍历
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # 打开注释可以观察叶子节点的坐标变化
            # print((plotTree.xOff, plotTree.yOff), secondDict[key])
            # 绘制叶子节点
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # 绘制叶子节点和父节点的中间连线内容
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

    # 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def createPlot(inTree):
    """
    需要绘制的决策树
    :param inTree: 决策树字典
    :return:
    """
    # 创建一个图像
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 计算出决策树的总宽度
    plotTree.totalW = float(getNumLeafs(inTree))
    # 计算出决策树的总深度
    plotTree.totalD = float(getTreeDepth(inTree))
    # 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
    plotTree.xOff = -0.5/plotTree.totalW
    # 初始的y轴偏移量,每次向下或者向上移动1/D
    plotTree.yOff = 1.0
    # 调用函数进行绘制节点图像
    plotTree(inTree, (0.5, 1.0), '')
    # 绘制
    plt.show()


if __name__ == '__main__':
    createPlot(mytree)

总结

决策树ID3是一种经典的机器学习算法,用于解决分类问题。它通过在特征空间中构建树形结构来进行决策,并以信息增益作为划分标准。ID3算法的关键在于选择最佳的属性进行划分,以最大化信息增益。通过Python实现ID3算法,我们可以构建出一棵高效而准确的决策树模型,用于分类预测和决策分析。


参考
https://zhuanlan.zhihu.com/p/133846252
https://cuijiahua.com/blog/2017/11/ml_2_decision_tree_1.html
https://blog.csdn.net/tauvan/article/details/121028351

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/710238.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ModaHub魔搭社区:向量数据库MIlvus服务端配置(一)

目录 服务端配置 配置概述 Milvus 文件结构 配置修改 编辑配置文件 运行时修改 server_config.yaml 参数说明 cluster 区域 general 区域 network 区域 服务端配置 配置概述 以下配置说明可同时应用于单机或者分布式场景。 Milvus 文件结构 成功启动 Milvus 服务后…

加速优化WooCommerce跨境电商网站的15种简单方法

Neil Patel和 Google所做的研究表明&#xff0c;如果加载时间超过三秒&#xff0c;将近一半的用户会离开网站。页面加载时间每增加一秒&#xff08;最多5秒&#xff09;&#xff0c;您的收入可能就会减少。在本教程中&#xff0c;我们将学习如何优化加速WooCommerce商店。 目录…

【20220605】文献翻译:高维数据动态可视化研究综述

A Review of the State-of-the-Art on Tours for Dynamic Visualization of High-dimensional Data Visualization of High-dimensional Data) Lee, Stuart, et al. “A Review of the State-of-the-Art on Tours for Dynamic Visualization of High-dimensional Data.” arXiv…

【书评】一本Android系统性能优化的新书

Android性能优化&#xff0c;是一个合格的Android程序员必备的技能&#xff0c;现如今几乎所有的Android面试内容都会或多或少涉及性能优化方面的话题。 学习Android性能优化可以让我们在简历上展示自己的专业技能和项目经验&#xff0c;证明自己具备高效开发和优化Android应用…

java jwt生成token并在网关设置全局过滤器进行token的校验

1、首先引入jjwt的依赖 <dependency><groupId>io.jsonwebtoken</groupId><artifactId>jjwt</artifactId><version>0.9.1</version> </dependency>2、编写生成token的工具类 package com.jjw.result.util;import com.jjw.res…

【UnityDOTS 三】Component的理解

Component的理解 文章目录 Component的理解前言一、托管Component与非托管Component1.非托管Component2.托管Component 二、各功能的Component三、在Editor中的Component的区分总结 前言 Component作为ECS中承载数据的结构&#xff0c;了解他相关内容是非常必要的&#xff0c;…

基于Jsp+Servlet+Mysql学生信息管理系统

基于JspServletMysql学生信息管理系统 一、系统介绍二、功能展示1. 系统的部署2.导入数据库3. 系统介绍 四、其它1.其他系统实现五.获取源码 一、系统介绍 项目类型&#xff1a;Java web项目/Java EE项目/ 项目名称&#xff1a;基于sevelet的学生信息管理系统 当前版本&…

用Python制作一个简单时间、日期显示工具

Python是一款强大的编程软件&#xff0c;可以轻松实现我们的多种开发需求。今天我们拿Python中自带的tkinter来开发一个时钟显示器。如下图所示&#xff1a; 时间显示器 一、编程要求 用tkinter写一个漂亮、五彩的时间显示器&#xff0c;要求显示时、分、秒&#xff0c;即时变…

【JAVA】十分钟带你了解java的前世今生

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【初始JAVA】 文章目录 前言JAVA介绍诞生&#x1f52c;名字与图标&#x1f916;发展&#x1f6e9;️未来&#x1fa84; 前言 玩过我的世界的朋友想必对JAVA以及它的图标都很熟悉&#xff0c;在游戏开始画面…

Java程序所在机器性能监控

Java程序所在机器性能监控 背景 问题单&#xff1a;程序故障&#xff08;OOM、网络不通、操作卡顿&#xff09;问题单&#xff1a;服务连接不上需求 1、监控本地机器性能 告警日志UI2、监控服务接口服务 告警日志UI方案 固定间隔获取机器网络CPU内存数据设置阈值&#xff0c;告…

自定义starter实现接口或方法限流功能

本文的思路是利用AOP技术自定义注解实现对特定的方法或接口进行限流。目前通过查阅相关资料&#xff0c;整理出三种类型限流方法&#xff0c;分别为基于guava限流实现、基于sentinel限流实现、基于Semaphore的实现。 一、限流常用的算法 1.1令牌桶算法 令牌桶算法是目前应用…

OpenCV(视频加载与摄像头使用)

目录 1、VideoCapture类 2、视频属性get() 3、视屏文件保存 1、VideoCapture类 2、视频属性get() 3、视屏文件保存 //视频的读取保存 int test3() {VideoCapture video;//video.open("F:/testMap/lolTFT.mp4");//读取视频video.open(0);//读取摄像头if (!video.i…

linux docker安装

一、Linux安装docker 1.1 前提 要求Linux内核&#xff08;kernel&#xff09; 版本大于等于3.8。&#xff08;kernel version >3.8&#xff09;。 查看当前系统内核版本 uname -a | awk {split($3,arr,"-");print arr[1]} 1.2 linux 安装docker Centos安装doc…

【数据结构与算法】7、队列(Queue)的实现【用栈实现队列】

目录 一、队列介绍二、使用 LinkedList 实现队列三、LeetCode&#xff1a;用【栈】实现队列(1) 老师讲之前我自己的实现&#xff08;Correct&#xff09;(2) 实现思路(3) 代码实现 四、jdk 的 Queue五、双端队列&#xff08;Deque&#xff09;六、循环队列(1) 分析(2) 入队(3) …

Linux--运行指令的本质

本质&#xff1a; ①找到它 which的作用就是找到它 ②运行它 示例&#xff1a; ①告诉系统要运行的指令&#xff0c;然后系统去查找它的路径并运行它 ②自己告诉系统自己要运行的路径&#xff0c;然后系统运行它 注意&#xff1a;a.out不能运行&#xff0c;而./a.out能运行…

MES是如何帮助企业提高生产效率的

大多数提高制造生产效率的系统都是从详细分析公司的制造流程和运营开始的。这样做的目的是是为了消除浪费的不增值的流程&#xff0c;将有价值的流程系统化&#xff0c;实现生产自动化并增强增值操作。 在自动化流程方面&#xff0c;实施制造执行系统&#xff08;MES&#xff…

HTML5 游戏开发实战 | 俄罗斯方块

俄罗斯方块是一款风靡全球的电视游戏机和掌上游戏机游戏&#xff0c;它曾经造成的轰动与造成的经济价值可以说是游戏史上的一件大事。这款游戏看似简单但却变化无穷&#xff0c;游戏过程仅需要玩家将不断下落的各种形状的方块移动、翻转&#xff0c;如果某一行被方块充满了&…

发送邮箱验证码【spring boot】

⭐前言⭐ ※※※大家好&#xff01;我是同学〖森〗&#xff0c;一名计算机爱好者&#xff0c;今天让我们进入学习模式。若有错误&#xff0c;请多多指教。更多有趣的代码请移步Gitee &#x1f44d; 点赞 ⭐ 收藏 &#x1f4dd;留言 都是我创作的最大的动力&#xff01; 1. 思维…

Redis6之穿透、击穿、雪崩

大量的高并发的请求打在Redis上&#xff0c;但是发现Redis中并没有请求的数据&#xff0c;redis的命令率降低&#xff0c;所以这些请求就只能直接打在DB&#xff08;数据库服务器&#xff09;上&#xff0c;在大量的高并发的请求下就会导致DB直接卡死、宕机。 缓存穿透 当客户端…

一例Phorpiex僵尸网络样本分析

本文主要分析Phorpiex僵尸网络的一个变种&#xff0c;该样本通常NSIS打包&#xff0c;能够检测虚拟机和沙箱。病毒本体伪装为一个文件夹&#xff0c;通过U盘来传播&#xff0c;会隐藏系统中各盘符根目录下的文件夹&#xff0c;创建同名的lnk文件&#xff0c;诱导用户点击。 病…