【Python决策树】ID3方法建立决策树为字典格式,并调用 treelib 显示

news2025/1/15 6:25:08

首先,我们使用 treelib 库来显示树结构 :
ps : 如果 treelib 输出一堆乱码, 可以点进Tree修改 tree.py 大概 930 行左右的部分(去掉encode就行了)

        if stdout:
            print(self._reader)   # print(self._reader.encode("utf-8"))
        else:
            return self._reader

将字典转换为treelib 中可显示的 map的python 函数编写如下:

import copy
from treelib import Tree

def dict2map(dic):
    if not isinstance(dic, dict):
        raise TypeError("input should be dict")
    map = {}
    _dict2map_cb(dic if len(dic) <= 1 else {'root':dic}, map, parent = [])
    return map

def _dict2map_cb(dic, map, parent=[]):
    """
    Create map object in treelib by dict
    :param dic: Python dict
    :param map: Use {} as map
    :param node_name: If None, use the first key of dic as root
                      but when multiple items in json, pass "root"
    :param parent: Parent node array
    :return:
    """
    for key, val in dic.items():
        node_name_new = '-'.join(parent)
        root_name = '-'.join(parent + [key])
        if isinstance(val, dict):
            map[root_name] = node_name_new if parent!=[] else None   # when root node,use None
            _dict2map_cb(val, map, parent=parent + [key])
        else:
            map[root_name + " : " + str(val)] = node_name_new if parent!=[] else None

if __name__ == "__main__":
    a = {"hello": {"word": 2}}
    b = {'decision 0': {'target 1': 256, 'decision 3': {'target 0': 128, 'target 1': 256}, 'decision 2': {'target 0': 256, 'target 1': 128}}}
    c = {'hi': {"w": 3}, 'this':{'e':4}}
    Tree.from_map(dict2map(a)).show()
    Tree.from_map(dict2map(b)).show(line_type="ascii-em")
    Tree.from_map(dict2map(c)).show(line_type="ascii-em")

决策树的参考文章是 《机器学习苏娜发原理与编程实践》郑捷著, 具体是分类如下的问题 :

计数年龄收入学生信誉是否购买
64不买
64不买
128
60
64
64不买
64
128不买
64
132
64
32
32
64不买

显然容易构建出如下的决策树, 但这个决策树不是最优的
在这里插入图片描述
决策树的原理在这里不进行讲解, 将上面的表格保存为seals_data.xlsx, 并将转换的脚本保存为 dict_to_map.py, 即可直接运行下面的ID3决策树代码:

import numpy
import numpy as np
import copy
import pandas as pd
from sklearn.preprocessing import LabelEncoder  # encoder labels
from treelib import Tree, Node
from sklearn.datasets._base import Bunch
from dict_to_map import dict2map

class ID3_Tree():
    """ ID3 decision Tree Algorithm """
    def __init__(self, counts = None, data = None, target = None, label_encoder = None):
        if (counts == None or data == None or target == None):
            self.load_data()
        else:
            self.train_data = Bunch(counts = counts, data = data, target = target)
            self.label_encoder = label_encoder()

        self.__Init_Tree(self.train_data.counts, self.train_data.data, self.train_data.target)
        self.build_Tree()

    def to_dict(self, *args, **kwargs):
        return self.tree

    def load_data (self):
        """ arrange the data into the correct shapes """
        data_raw = pd.DataFrame(pd.read_excel("seals_data.xlsx"))
        label_encoder = LabelEncoder()
        # eliminate all the white space
        data_raw = data_raw.map(func=lambda x:x.strip() if isinstance(x, str) else x)
        data_proceed = pd.DataFrame()
        data_proceed['计数'] = data_raw.iloc[:, 0]
        for column in data_raw.columns[1:]:
            data_proceed[str(column).strip()] = label_encoder.fit_transform(data_raw[column])

        """ split data into 3 part : counts, data and target """
        counts = np.array(copy.deepcopy(data_proceed.iloc[:,0]))
        data = np.array(copy.deepcopy(data_proceed.iloc[:,1:-1]))
        target = np.array(copy.deepcopy(data_proceed.iloc[:,-1]))
        labels = [str(column).strip() for column in data_raw.columns]
        self.data_raw = data_raw
        self.train_data = Bunch(counts = counts, data = data, target = target ,labels = labels) # target = data_raw.iloc[:, 0])
        self.label_encoder = label_encoder

    def __Init_Tree(self, counts:np.ndarray, data:np.ndarray, target:np.ndarray):
        self.__check_param(counts, data, target)
        self.nums   = counts.shape[0]              # number of the type of the samples
        self.target_num = len(np.unique(target))   # the number of classes (C_i) , i = 1... m
        self.decision_num = data.shape[1]          # number of decision attributes (D)
        self.total_num = counts.sum()              # total number of samples (N)
        self.tree = {}                             # init the tree node

        # calculate the infomation entropy of the entire dataset
        targets =  np.unique(target)
        cls_cnt =  np.array([counts[target == targets[i]].sum() for i in range(targets.size)])
        cls_prop=  cls_cnt/cls_cnt.sum()
        self.base_entropy = -np.sum(cls_prop * np.log2(np.where(cls_prop == 0,1e-10, cls_prop)))

    def build_Tree(self):
        if (self.target_num <= 1):
            # only 1 class, stop split and return the empty tree
            self.tree["root"] = self.train_data.target[0]
            return self.tree

        # initialize the node decision range and target range, we use whole data set to calculate the entropy of root at first
        dec_range = np.arange(self.decision_num)   # decision attributes
        tar_range = np.arange(self.nums)           # targets on data
    
        # recursive call the calc_entropy_mat function until the class is purely classified.
        self.tree = self.__build_tree_node(dec_range, tar_range)

    def __check_param(self, counts, data, target):
        if len(counts.shape)!=1 or len(data.shape)!=2 or len(target.shape)!=1:
            raise ValueError("The input data is not in the correct shape")
        elif counts.shape[0]!= data.shape[0] or counts.shape[0] != target.shape[0]:
            raise ValueError("The input data is not in the correct shape")

    def show(self):
        map = dict2map(self.tree)
        tree = Tree.from_map(map)
        tree.show(line_type="ascii-em", sorting=False)

    def __build_tree_node(self, dec_range, tar_range) -> dict:
        """
        recursive function,
        counts, data, target, node_dec_range, node_tar_range
        :param dec_range : decision range (in direction 1 or y)
        :param tar_range : target range (in direction 0 or x)
        :return: root (name of the root node is defined by decision)
        """

        counts = self.train_data.counts[tar_range]
        data =   self.train_data.data[tar_range][:, dec_range]
        target = self.train_data.target[tar_range]
        self.__check_param(counts, data, target)

        gain_list = [self.__get_node_info_gain(counts, data[:, i], target) for i in range(dec_range.size)]
        best_dec_idx = np.argmax(gain_list)  # best decision (note : relevant to dec_range, not self.train_data)

        best_decision = data[:, best_dec_idx]
        features = np.unique(best_decision)
        cnt_mat = self.__get_count_matrix(counts, best_decision, target)

        # use best decision as root node -> delete it from dec_range
        root_name = "decision" + str(dec_range[best_dec_idx])   # get the location of best decision
        root = {}

        d2 = numpy.delete(dec_range, best_dec_idx)  # create new dec_range object
        for i in range(len(features)):
            feature = features[i]
            cnt_arr = np.array(cnt_mat[:, i]).squeeze(1)   # change to array and squeeeze to 1 dim

            # record the classification result:
            left_tar_idx = np.nonzero(cnt_arr)[0]          # choice leave in this feature

            if (len(left_tar_idx) == 0):
                raise ValueError("left choices is not zero here!")

            if left_tar_idx.size == 1:
                # left number can be calculated by cnt_arr[left_choices], t2 = []
                root["target" + str(left_tar_idx[0]) ] = np.sum(cnt_arr[left_tar_idx])
            elif d2.size == 0 :  # no available decision left
                root["dummy"] = np.sum(cnt_arr[left_tar_idx]) # create dummy node
            else:
                t2 = tar_range[best_decision == feature]
                sub_tree = self.__build_tree_node(d2, t2);
                sub_rootname, sub_root = next(iter(sub_tree.items()))
                root[sub_rootname] = sub_root
        return {root_name : root}

    def __get_node_info_gain(self,counts_arr, decision_arr, target_arr):
        """
        :param counts_arr: the counts array of the current node
        :param decision_arr: the decision attributes of the current node
        :param target_arr: target array to be classified of the current node
        :return: gain : scalar, infomation gain of the current node
        """
        cnt_mat = self.__get_count_matrix(counts_arr, decision_arr, target_arr)
        # calculate the information gain of the current node by cnt_mat
        prob_mat = np.mat(cnt_mat / np.sum(cnt_mat, axis=0))  # calculate probability matrix
        prob_mat = np.where(prob_mat == 0, 1e-10, prob_mat)   # substitute 0 with 1e-10 to avoid log calculation error
        node_entropy = -np.sum(np.multiply(prob_mat, np.log2(prob_mat)), axis=0)
        node_wt = cnt_mat.sum(axis=0) / np.sum(cnt_mat)
        gain = self.base_entropy - np.multiply(node_wt, node_entropy).sum()
        return gain
    
    def __get_count_matrix(self, counts_arr, decision_arr, target_arr):
        """
        calculate the count matrix of the node
        :param counts_arr:   the counts array of the current node
        :param decision_arr: the decision attributes of the current node
        :param target_arr:   target array to be classified of the current node
        :return: cnt_mat : np.matrix
        """
        features = np.unique(decision_arr)
        targets = np.unique(target_arr)
        cnt_mat = np.array([
            [np.sum(counts_arr[(decision_arr == dec) & (target_arr == tar)]) for dec in features]
            for tar in targets
        ])
        return np.matrix(cnt_mat)


if __name__ == "__main__":
    tree = ID3_Tree()
    print(tree.to_dict())
    tree.show()

需要说明的是, 上述的代码没有按参考原文的代码, 算法的主要思路是一致的, 代码中将不同的选择称为 decision , 而每个选择的不同分支 称为 feature, 便于编程, 最终通过概率矩阵 prep_mat 计算出对应的节点熵。选取最适合用于分类的节点。

在这里插入图片描述
上面程序的运行结果如下 :

{'decision0': {'target1': 256, 'decision3': {'target0': 128, 'target1': 256}, 'decision2': {'target0': 256, 'target1': 128}}}
decision0
╠══ decision0-target1 : 256
╠══ decision0-decision3
║   ╠══ decision0-decision3-target0 : 128
║   ╚══ decision0-decision3-target1 : 256
╚══ decision0-decision2
    ╠══ decision0-decision2-target0 : 256
    ╚══ decision0-decision2-target1 : 128

上述树按照 feature 的先后顺序划分 (例如decision0(年龄)-feature0(中年)) 对应的是 decision0下的第一条, 以此类推。同时为了防止treelib自动排序,需要设置参数 tree.show(line_type="ascii-em", sorting=False)

需要说明的是, decision0-decision3 分别对应的是年龄 , 收入, 学生, 信誉, 具体可以拿这个具体去看:

printf(self.data_raw)
printf(self.train_data.counts)
printf(self.train_data.data)
printf(self.train_data.target)

即最终成功建立了如下的决策树:
请添加图片描述

另外, 利用 可以方便地采用 c4.5 算法很方便地建立这个决策树:
库下载: pip install c45-decision-tree

建立该树只需采用如下代码 :

from C45 import C45Classifier
import pandas as pd
import graphviz

data_raw = pd.DataFrame(pd.read_excel("seals_data.xlsx"))
data_raw = data_raw.map(func=lambda x:x.strip() if isinstance(x, str) else x)

counts = data_raw.iloc[:,0]
data   = data_raw.iloc[:,1:-1]
target = data_raw.iloc[:,-1]

data_new = []
target_new = []
for i in range(counts.size):
    for j in range(counts[i]):
        data_new.append(list(data.iloc[i]))
        target_new.append(target.iloc[i])

model = C45Classifier()
model.fit(data_new, target_new)
tree_diagram = model.generate_tree_diagram(graphviz, "tree_diagram")
graphviz.view(tree_diagram)

另外, 如果出现中文显示乱码问题, 可以跳转到 generate_tree_diagram 源码中, 添加 dot.attr(encoding='utf-8') # Ensure UTF-8 encoding 部分和 fontname="SimHei" 三个部分 :

    def generate_tree_diagram(self, graphviz, filename):
        # Menghasilkan diagram pohon keputusan menggunakan modul graphviz
        dot = graphviz.Digraph()
        def build_tree(node, parent_node=None, edge_label=None):
            if isinstance(node, _DecisionNode):
                current_node_label = str(node.attribute)

                dot.node(str(id(node)), label=current_node_label)
                if parent_node:
                    dot.edge(str(id(parent_node)), str(id(node)), label=edge_label, fontname="SimHei")

                for value, child_node in node.children.items():
                    build_tree(child_node, node, value)
            elif isinstance(node, _LeafNode):
                current_node_label = f"Class: {node.label}, Weight: {node.weight}"
                dot.node(str(id(node)), label=current_node_label, shape="box", fontname="SimHei")

                if parent_node:
                    dot.edge(str(id(parent_node)), str(id(node)), label=edge_label, fontname="SimHei")

        build_tree(self.tree)
        dot.format = 'png'
        dot.attr(encoding='utf-8')  # Ensure UTF-8 encoding
        return dot.render(filename, view=False)

绘制出的决策树如图所示:
请添加图片描述

附注 : 由于本人水平所限,上面的代码部分也可能有一些错误, 如果读者发现也希望能在评论区指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2133241.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Python的B站热门视频可视化分析与挖掘系统

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 随着互联网视频平台的迅猛发展&#xff0c;如何从海量的数据中提炼出有价值的信息成为了内容创作者们关注的重点之一。B站&#xff08;哔哩哔哩&#xff09;作为国内领先的年轻人文化社区&#xf…

应用层协议 —— https

目录 http的缺点 https 安全与加密 运营商挟持 常见的加密方式 对称加密 非对称加密 数据摘要&#xff08;数据指纹&#xff09; 不安全加密策略 1 只使用对称加密 2 只使用非对称加密 3 双方都是用非对称加密 4 对称加密和非对称加密 解决方案 CA证书 http的缺点 我们可…

基于鸿蒙API10的RTSP播放器(八:音量和亮度调节功能的整合)

一、前言&#xff1a; 笔者在前面第六、七节文章当中&#xff0c;分别指出了音量和屏幕亮度的前置知识&#xff0c;在本节当中&#xff0c;我们将一并实现这两个功能&#xff0c;从而接续第五节内容。本文的逻辑分三大部分&#xff0c;先说用到的变量&#xff0c;再说界面&…

智慧环保平台建设方案

智慧环保平台建设方案摘要 政策导向与建设背景 背景&#xff1a;全国生态环境保护大会提出坚决打好污染防治攻坚战&#xff0c;推动生态文明建设&#xff0c;目标是在2035年实现生态环境质量根本好转。构建生态文明体系&#xff0c;包括生态文化、生态经济、目标责任、生态文明…

表格标记<table>

一.表格标记、 1table&#xff1a;表格标记 2.caption:表单标题标记 3.tr:表格行标记 4.td:表格中数据单元格标记 5.th:标题单元格 table标记是表格中最外层标记&#xff0c;tr表示表格中的行标记&#xff0c;一对<tr>表示表格中的一行&#xff0c;在<tr>中可…

Excel数据转置|Excel数据旋转90°

Excel数据转置|Excel数据旋转90 将需要转置的数据复制在旁边空格处点击鼠标右键&#xff0c;选择图中转置按钮&#xff0c;即可完成数据的转置。&#xff01;&#xff01;&#xff01;&#xff01;非常有用啊啊啊&#xff01;&#xff01;&#xff01;

嵌入式Linux学习笔记(2)-C语言编译过程

c语言的编译分为4个过程&#xff0c;分别是预处理&#xff0c;编译&#xff0c;汇编&#xff0c;链接。 一、预处理 预处理是c语言编译的第一个阶段&#xff0c;该任务主要由预处理器完成。预处理器会根据预处理指令对源代码进行处理&#xff0c;将预处理指令替换为相应的内容…

游戏各个知识小点汇总

抗锯齿原理记录 SSAA&#xff1a;把成像的图片放大N倍&#xff0c;然后每N个点进行平均值计算。一般N为2的倍数。比如原始尺寸是1000x1000&#xff0c;长宽各放大2倍变成2000x2000。 举例&#xff1a; 原始尺寸&#xff1a; 放大2倍后 最后平均值计算成像&#xff1a; MSAA&…

基于SpringBoot+Vue的网上蛋糕销售系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的…

踩坑记:Poco库,MySql,解析大文本的bug

这两天在调试一个小功能&#xff0c;使用c,读取MySql。使用的是Poco库。按照官网的写法&#xff1a; std::cout << "read normal data by poco recordset "<<std::endl;Poco::Data::MySQL::Connector::registerConnector();Poco::Data::Session session(…

.NET 6.0 + WPF 使用 Prism 框架实现导航

合集 - .NET 基础知识(3) 1..NET 9 优化&#xff0c;抢先体验 C# 13 新特性08-202.《黑神话&#xff1a;悟空》神话再现&#xff0c;虚幻引擎与Unity/C#谁更强&#xff1f;08-21 3..NET 6.0 WPF 使用 Prism 框架实现导航09-11 收起 阅读目录 前言什么是Prism?安装 Prism使…

【卷起来】VUE3.0教程-09-整合Element-plus

最后一次课了&#xff0c;给个关注和赞呗 &#x1f332; 简介 Element Plus 是一个基于 Vue 3 的高质量 UI 组件库。它包含了丰富的组件和扩展功能&#xff0c;例如表格、表单、按钮、导航、通知等&#xff0c;让开发者能够快速构建高质量的 Web 应用。Element Plus 的设计理念…

洛谷 P4683 [IOI2008] Type Printer

原题点这里 题目来源于&#xff1a;洛谷 题目本质&#xff1a;深搜&#xff0c;字典树Trie 当时想法&#xff1a;当时看了题目标签&#xff0c;就有思路了&#xff08;见代码注释&#xff09;&#xff0c;但一直REWA最后只剩下RE。 正确思路&#xff1a; 我们使用字典树来完…

【机器学习】任务四:使用贝叶斯算法识别葡萄酒类别和使用三种不同的决策树方法(ID3,C4.5,CART)对鸢尾花数据进行分类

目录 1.基础知识 1.1 高斯贝叶斯&#xff08;Gaussian Naive Bayes&#xff09; 1.2 决策树&#xff08;Decision Tree&#xff09; 1.3 模型评价&#xff08;Model Evaluation&#xff09; 1.3.1 评价维度&#xff1a; 1.3.2 评价方法&#xff1a; 2.使用贝叶斯算法识别…

android 删除系统原有的debug.keystore,系统运行的时候,重新生成新的debug.keystore,来完成App的运行。

1、先上一个图&#xff1a;这个是keystore无效的原因 之前在安装这个旧版本android studio的时候呢&#xff0c;安装过一版最新的android studio&#xff0c;然后通过模拟器跑过测试的demo。 2、运行旧的项目到模拟器的时候&#xff0c;就报错了&#xff1a; Execution failed…

proteus+51单片机+AD/DA学习5

目录 1.DA转换原理 1.1基本概念 1.1.1DA的简介 1.1.2DA0832芯片 1.1.3PCF8591芯片 1.2代码 1.2.1DAC8053的代码 1.2.2PCF8951的代码 1.3仿真 1.3.1DAC0832的仿真 1.3.2PFC8951的仿真 2.AD转换原理 2.1AD的基本概念 2.1.1AD的简介 2.1.2ADC0809的介绍 2.1.3XPT2…

双指针算法专题(1)

找往期文章包括但不限于本期文章中不懂的知识点&#xff1a; 个人主页&#xff1a;我要学编程(ಥ_ಥ)-CSDN博客 所属专栏&#xff1a; 优选算法专题 目录 双指针算法的介绍 283. 移动零 1089. 复写零 202. 快乐数 11.盛最多水的容器 双指针算法的介绍 在正式做题之前&a…

C++ 获取文件夹下的全部文件及指定文件(代码)

文章目录 1.&#xff08;C17&#xff09;获得指定目录下的所有文件&#xff08;不搜索子文件夹&#xff09;2.&#xff08;C11&#xff09;获得指定目录下的所有文件&#xff08;不搜索子文件夹&#xff09;3.&#xff08;C11&#xff09;获取目录下指定格式的所有文件&#xf…

工厂安灯系统在优化生产流程上的优势

工厂安灯系统通过可视化的方式&#xff0c;帮助工厂管理者和操作工人及时了解生产状态&#xff0c;快速响应问题&#xff0c;从而优化生产流程。 一、安灯系统实时监控与反馈 安灯系统的核心功能是实时监控生产线的状态。通过在生产现场设置灯光、显示屏等设备&#xff0c;工人…

OpenGL(四) 纹理贴图

几何模型&材质&纹理 渲染一个物体需要&#xff1a; 几何模型&#xff1a;决定了物体的形状材质&#xff1a;绝对了当灯光照到上面时的作用效果纹理&#xff1a;决定了物体的外观 纹理对象 纹理有2D的&#xff0c;有3D的。2D图像就是一张图片&#xff0c;3D图像是在…