机器学习之决策树原理详解、公式推导(手推)、面试问题、简单实例(python实现,sklearn调包)

news2025/1/16 17:03:09

在这里插入图片描述

目录

  • 1. 决策树原理
    • 1.1. 特性
    • 1.2. 思路
    • 1.3. 概念
      • 决策树概念
      • 信息论
  • 2. 公式推导
    • 2.1. 构造决策树
      • 2.1.1. ID3
        • 理论
        • 示例
        • 缺点
      • 2.1.2. C4.5
        • 理论
        • 示例
        • 缺点
      • 2.1.3. CART
        • 示例
      • 对比分析
    • 2.2. 剪枝
  • 3. 实例
    • 3.1. 数据集
    • 3.2. ID3
    • 3.3. C4.5
    • 3.4. CART
    • 3.5. sklearn实现
  • 4. 几个注意点(面试问题)
  • 5. 运行(可直接食用)

1. 决策树原理

1.1. 特性

决策树属于非参数监督学习方法,在学习数据特征推断出简单的决策规划创建一个预测目标变量值的模型。

优点

  1. 可解释型强,而且决策树很容易可视化。此外,决策树使用“白箱模型”,比起黑箱模型,解释性也更强。
  2. 能处理多输出问题。
  3. 即使其假设与生成数据的真实模型有一定程度的违背,也能很好地执行。
  4. 使用树(即预测数据)的成本在用于训练树的数据点的数量上是对数的。

缺点

  1. 很容易就过拟合了。所以才需要剪枝等技巧或者设置叶节点所需最小样本树或者设置树最大深度。
  2. 不稳定,因为时递归出来的,很微小的变化可能导致整个树结构大变。
  3. 不擅长泛化和外推。因为决策树的预测不是平滑的也不是连续的,所以中间的空挡很难处理,所以在外推方面表现差。
  4. 决策树基于的启发式算法可能会导致产生的决策树并非最优决策树,比如在使用贪心算法的时候。因此才会有人提出特征和样本需要被随机抽样替换。
  5. 在一些问题上先天不足。如:XOR、奇偶校验或多路复用器问题。
  6. 不支持缺失值,这点在随机森林中也有所提现,所以采取本方法时注意缺失值处理或填补。

1.2. 思路

在计算机中,决策树更多的时用于分类。有点类似于用 if else 语句去穷举每一种可能,但是面对多个判断的情况时,每个判断出现的位置还不确定,我们需要通过某种方法确定各个条件出现顺序(谁在树根谁在下面),这个方法是决策树的一个难点。此外,面对大量数据时,建立的决策树模型可能会特别特别大,深度深,叶子节点多且节点中的样例少等可能性,优化这个问题需要的方法也是一个难题。

所有的决策树的步骤的最终目的都是希望通过学习产生一个泛化能力强的决策树,即处理未见示例能力强的决策树。

1.3. 概念

决策树概念

在这里插入图片描述

根据西瓜书的定义:⼀棵决策树包含⼀个根节点,若⼲个内部节点和若⼲个叶⼦节点。叶结点对应决策树决策结果,其他每个结点则对应于⼀个属性测试。

信息论

决策树这里用到了信息论中间很基础的知识,我们希望决策树的分支节点所包含的样本尽可能都属于一个类别,那么这么衡量这个?这需要用到纯度。直观上讲就是分类纯度越低,里面的标签就越杂 。

那么纯度怎么表示?

可以用信息熵表示。假定当前样本集合 D 中第 k 类样本所占的比例为 P k P_k Pk (k =1, 2,. . . , ∣ γ ∣ |\gamma| γ) ,则D的信息熵定义为:
E n t ( D ) = − ∑ k = 1 ∣ γ ∣ p k l o g 2 p k Ent(D)=-\sum^{|\gamma|}_{k=1}{p_klog_2p_k} Ent(D)=k=1γpklog2pk

Ent(D)的值越小,则D的纯度越⾼
Ent(D)的值越大,则D的不纯度越大,则D的不确定性越大

2. 公式推导

决策树模型构建一般分为三步,

  1. 特征选择
  2. 决策树生成
  3. 决策树修剪

特征选择更多的是取决于数据集,在生活中的问题里可能更加重要,一般来说没有什么特定的标准,需要注意的点比较琐碎,这边不细讲。

决策树生成和修剪的难度更大,这里做出总结。

2.1. 构造决策树

如何构造决策树其实就是如何选取每层节点的问题,西瓜书上介绍了三种方法,这边也按照这三种方法总结,其他的方法我将在面试题中提到。

2.1.1. ID3

理论

信息增益区分。

我们知道信息熵可以确定经过一个节点划分后每个分支节点的纯度,但是我们这里需要确定节点的顺序,所以单单知道划分后每个分支节点的纯度还不够,应该通过某种方法利用分支节点的纯度反映到划分节点的优劣。

ID3就是使用信息增益来判断每个节点的划分能力的。信息增益公式如下:

G a i n ( D , α ) = E n t ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D, \alpha)=Ent(D)-\sum^{V}_{v=1}{\frac{|D^v|}{|D|}Ent(D^v)} Gain(D,α)=Ent(D)v=1VDDvEnt(Dv)

∣ D v ∣ ∣ D ∣ \frac{|D^v|}{|D|} DDv就是每个分支节点的划分权重,所以样本数量越多划分权重越大。 G a i n ( D , α ) Gain(D, \alpha) Gain(D,α)就是使用 α \alpha α属性划分对整体纯度提升的能力。因此我们有 ID3 的逻辑顺序:

  • 从根节点开始,会节点计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征。
  • 由该特征的不同取值建立子节点,在对子节点递归地调用上面方法,构建决策树,知道所有的特征信息增益都很小或者没有特征可以选为止。
  • 得到决策树模型。

示例

西瓜数据集

编号色泽根蒂敲声纹理脐部触感密度含糖率好瓜
10000000.6970.461
21010000.7740.3761
31000000.6340.2641
40010000.6080.3181
52000000.5560.2151
60100110.4030.2371
71101110.4810.1491
81100100.4370.2111
91111100.6660.0910
100220210.2430.2670
112222200.2450.0570
122002210.3430.0990
130101000.6390.1610
142111000.6570.1980
151100110.360.370
162002200.5930.0420
170011100.7190.1030

计算根蒂的信息增益:

在这里插入图片描述

缺点

对种类多的属性有偏好,比如编号属性。这个属性的信息增益会特别高,但是实际上它所包含的信息没多少。

此时就需要另一个算法 C4.5

2.1.2. C4.5

理论

增益率区分。

既然考虑到信息增益对种类多的属性的偏好,增益率应该削弱这类属性的权重。于是提出增益的权重,也称为属性的固有值。

对信息增益提出了惩罚,特征可取值数量较多时,惩罚参数大;特征可取值数量较少时,惩罚参数小。

I V ( α ) = − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ l o g 2 ∣ D v ∣ ∣ D ∣ IV(\alpha)=-\sum^{V}_{v=1}{\frac{|D^v|}{|D|}}log_2\frac{|D^v|}{|D|} IV(α)=v=1VDDvlog2DDv
G a i n _ r a t i o = G a i n ( D , α ) I V ( α ) Gain\_ratio=\frac{Gain(D,\alpha)}{IV(\alpha)} Gain_ratio=IV(α)Gain(D,α)

底层思路是讲连续特征离散化,所有也可以用来处理连续型数据。

示例

西瓜数据集

编号色泽根蒂敲声纹理脐部触感密度含糖率好瓜
10000000.6970.461
21010000.7740.3761
31000000.6340.2641
40010000.6080.3181
52000000.5560.2151
60100110.4030.2371
71101110.4810.1491
81100100.4370.2111
91111100.6660.0910
100220210.2430.2670
112222200.2450.0570
122002210.3430.0990
130101000.6390.1610
142111000.6570.1980
151100110.360.370
162002200.5930.0420
170011100.7190.1030

计算根蒂的信息增益率:

缺点

缺点也很明显,就是矫枉过正。面对可取数值比较少的属性时,增益率存在偏好。因此有了基指系数

2.1.3. CART

基尼指数划分。

我们可以理解为随机抽取两个样本,其类别不一样的概率的加权和。这个权重跟 C4.5 的权重划分方法相同。

这个本质上是一个回归树,对于计算机来说虽然很多时候结果跟信息熵算出来的差不多,但是计算会快很多,因为没有对数。

数据集的基尼指数:

G i n i ( D ) = ∑ k = 1 ∣ γ ∣ ∑ k ′ ≠ k p k p k ′ = 1 − ∑ k = 1 ∣ γ ∣ p k 2 = 1 − ∑ k = 1 ∣ γ ∣ ( ∣ D i ∣ ∣ D ∣ ) 2 Gini(D)=\sum^{|\gamma|}_{k=1}\sum_{k^{'} \ne k}p_kp_{k^{'}}=1-\sum^{|\gamma|}_{k=1}p_{k}^{2}=1-\sum^{|\gamma|}_{k=1}{(\frac{|D_i|}{|D|})^2} Gini(D)=k=1γk=kpkpk=1k=1γpk2=1k=1γ(DDi)2

数据集纯度越高 ∑ k = 1 ∣ γ ∣ p k 2 \sum^{|\gamma|}_{k=1}p_{k}^{2} k=1γpk2 越大,Gini 越小

属性的基尼指数:

G i n i _ i n d e x ( D , α ) = ∑ v = 1 V ∣ D v ∣ ∣ D ∣ G i n i ( D v ) Gini\_index(D, \alpha )=\sum^{V}_{v=1}{\frac{|D^v|}{|D|}Gini(D^{v})} Gini_index(D,α)=v=1VDDvGini(Dv)

最优划分属性就是 G i n i _ i n d e x ( D , α ) Gini\_index(D, \alpha ) Gini_index(D,α) 最小的属性。

示例

西瓜数据集

编号色泽根蒂敲声纹理脐部触感密度含糖率好瓜
10000000.6970.461
21010000.7740.3761
31000000.6340.2641
40010000.6080.3181
52000000.5560.2151
60100110.4030.2371
71101110.4810.1491
81100100.4370.2111
91111100.6660.0910
100220210.2430.2670
112222200.2450.0570
122002210.3430.0990
130101000.6390.1610
142111000.6570.1980
151100110.360.370
162002200.5930.0420
170011100.7190.1030

计算根蒂的基尼指数:

在这里插入图片描述

对比分析

在这里插入图片描述

2.2. 剪枝

剪枝 (pruning)是决策树学习算法对付"过拟合"的主要⼿段。在决策树学习中,为了尽可能正确分类训练样本,结点划分过程将不断重复,有时会造成决策树分支过多,这时就可能因训练样本学得"太好"了,以致于把训练集自身的⼀些特点当作所有数据都具有的⼀般性质⽽导致过拟合。

决策树剪枝的基本策略有“预剪枝” (prepruning)和“后剪枝 ”(post pruning) [Quinlan, 1993]. 预剪枝是指在决策树⽣成过程中,对每个结点划分前先进⾏估计,若当前结点的划分不能带来决策树泛化性能提升,则停⽌划分并将当前结点标记为叶结点;后剪枝则是先从训练集⽣成⼀棵完整的决策树, 然后自底向上地对非叶结点进⾏考察,若将该结点对应的⼦树替换为叶结点能带来决策树泛化性能提升,则将该⼦树替换为叶结点

3. 实例

3.1. 数据集

watermelon.csv

x1,x2,x3,x4,x5,x6,label
0,0,0,0,0,0,1
1,0,1,0,0,0,1
1,0,0,0,0,0,1
0,0,1,0,0,0,1
2,0,0,0,0,0,1
0,1,0,0,1,1,1
1,1,0,1,1,1,1
1,1,0,0,1,0,1
1,1,1,1,1,0,0
0,2,2,0,2,1,0
2,2,2,2,2,0,0
2,0,0,2,2,1,0
0,1,0,1,0,0,0
2,1,1,1,0,0,0
1,1,0,0,1,1,0
2,0,0,2,2,0,0
0,0,1,1,1,0,0

3.2. ID3

# 计算信息熵
def c_entropy(label):
    # 每类有几个
    count_class = Counter(label)
    ret = 0
    for classes in count_class.keys():
        ret += count_class[classes]/len(label)*np.log2(count_class[classes]/len(label))
    return -ret


# 计算单个属性信息增益
def c_entropy_gain(data, label):
    ent = c_entropy(label)
    # 几个值循环几次
    count_value = Counter(data)
    sum_entropy = 0
    for value_a in count_value.keys():
        sum_entropy += count_value[value_a] / len(label) * c_entropy(label[data == value_a])
    return ent - sum_entropy


# 计算计算每个属性的信息增益并求最大的
def find_entropy_index(data, label):
    label = label.flatten()
    feature_count = data.shape[1]
    max_list = []
    for i in range(feature_count):
        max_list.append(c_entropy_gain(data[:, i], label))
    max_value = max(max_list)  # 求列表最大值
    max_idx = max_list.index(max_value)  # 求最大值对应索引
    return max_idx

在这里插入图片描述

3.3. C4.5

# 计算信息熵
def c_entropy(label):
    # 每类有几个
    count_class = Counter(label)
    ret = 0
    for classes in count_class.keys():
        ret += count_class[classes]/len(label)*np.log2(count_class[classes]/len(label))
    return -ret


# 计算单个属性信息增益率
def c_entropy_gain_rate(data, label):
    ent = c_entropy(label)
    # 几个值循环几次
    count_value = Counter(data)
    sum_entropy = 0
    sum_iv = 0
    for value_a in count_value.keys():
        d = count_value[value_a] / len(label)
        sum_entropy += d * c_entropy(label[data == value_a])
        sum_iv -= d * np.log2(d)
    return (ent - sum_entropy) / sum_iv


# 计算计算每个属性的信息增益率并求最大的
def find_entropy_ratio(data, label):
    # 返回大于中位数的下标
    def get_mid_index(list):
        mid = int(len(list)/2)
        b = sorted(enumerate(list), key=lambda x: x[1])  # x[1]是因为在enumerate(a)中,a数值在第1位
        c = [x[0] for x in b]  # 获取排序好后b坐标,下标在第0位
        return c[mid:]


    label = label.flatten()
    feature_count = data.shape[1]
    max_list_e = []
    max_list = []
    for i in range(feature_count):
        max_list_e.append(c_entropy_gain(data[:, i], label))
        max_list.append(c_entropy_gain_rate(data[:, i], label))
    # 只保留信息增益大于平均水平的
    max_list_mid = [max_list[i] for i in get_mid_index(max_list_e)]
    max_value = max(max_list_mid)  # 求列表最大值
    max_idx = max_list_mid.index(max_value)  # 求最大值对应索引
    return max_idx

在这里插入图片描述

3.4. CART

# 计算单个属性基尼指数
def c_gini(data, label):
    # 几个值循环几次
    count_value = Counter(data)

    gini = 0
    for value_a in count_value.keys():
        d = count_value[value_a] / len(label)
        class_gini = 1
        # 取出值一样的data的value
        same_value_class = label[data == value_a]
        count_labels = Counter(same_value_class)
        for i in count_labels.values():
            class_gini -= (i / count_value[value_a])**2
        gini += d * class_gini
    return gini


# 计算计算每个属性的基尼指数并求最小的
def find_gini(data, label):
    label = label.flatten()
    feature_count = data.shape[1]
    min_list = []
    for i in range(feature_count):
        min_list.append(c_gini(data[:, i], label))
    min_value = min(min_list)  # 求列表最小值
    min_idx = min_list.index(min_value)  # 求最小值对应索引
    return min_idx

在这里插入图片描述

3.5. sklearn实现


def sk(data_train, label_train, data_test, label_test):
    clf = tree.DecisionTreeClassifier(criterion="entropy")
    clf = clf.fit(data_train, label_train)
    score = clf.score(data_test, label_test)  # 返回精确度
    print("准确率为:", score * 100, "%")

    # 画决策树
    feature_name = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
    import matplotlib as mpl
    mpl.rcParams['font.sans-serif'] = ['FangSong']  # 指定中文字体
    mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
    tree.plot_tree(clf,
                feature_names=feature_name,
                class_names=["好瓜", "坏瓜"],
                filled=True,  # 填充颜色,颜色越深,不纯度越低
                rounded=True  # 框的形状
                )
    plt.show()

结果如下

在这里插入图片描述

可视化一下:

在这里插入图片描述

4. 几个注意点(面试问题)

  决策树有几个注意点,可能会在面试中被提到,还是比较能体现被面试者对这个算法的理解的。当然我们也不一定就是为了面试,搞清楚这些问题对帮助我们理解这个算法还是很有好处的。其中部分答案是博主自己的理解,如果有问题麻烦路过的大佬评论区指正。

  1. 决策树常使用那些常用的启发函数?

答:

  • ID3 => 最大信息增益
  • C4.5 => 最大信息增益比
  • CART => 最大基尼指数

这仨还有一些特点:

  • ID3 泛化能力弱,C4.5一定程度上惩罚了取值较多的特征,避免ID3的过拟合,提升泛化能力。

  • ID3 只能处理离散数据,其他俩都可以处理连续型变量,C4.5靠排序后找分割线把连续型变量转化为布尔型,CART靠对特征二值划分。

  • ID3 C4.5 只能做分类任务,CART可以分类也可以回归。

  • ID3 对缺失值特别敏感。

  • ID3 C4.5 每个节点可以产生多分支,CART只会产生二分支。

  1. 决策树是如何进行剪枝的?

答:

老生常谈的问题了,前面也有讲。这里就重复讲下。

预剪枝

在树长到一定深度后停止生长,当树的节点样本数小于某个阈值时停止生长,每次分裂树带来的提升小于某个阈值时停止生长。

可能导致欠拟合。不过比较适合大规模的问题,在分类后期准确率会有显著增长。需要有经验的人确定相关参数(节点样本数、最大深度等)

后剪枝

后剪枝是在树建立后,从下往上计算是否需要剪掉一部分子树并用叶子代替。比较难的地方是几个后剪枝的方法(错误率降低后剪枝、悲观剪枝、代价复杂度剪枝、最小误差剪枝、CVP、OPP 等)。这个比较复杂,不详细介绍。

  1. 决策树递归时导致递归返回的三种条件?

答:

  • 当前节点包含的样本全部属于统同一类别,无需划分。
  • 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分。
  • 当前节点包含的样本集合为空,不能划分。
  1. 计算恋爱数据集。数据集如下:
样本颜值(A)身材(B)性格(C)收入(D)学历(E)交往(R)
1101111
2110000
3111111
4101010
5110100
6010110
7110000
8101111
9001110
10111111
11101000
12001110
13111011
14101111
15111111

答:
暂时就算一个颜值和身材哈~~

在这里插入图片描述
在这里插入图片描述

5. 运行(可直接食用)

import random
from collections import Counter
import numpy as np
import pandas as pd
import warnings
from matplotlib import pyplot as plt
warnings.filterwarnings("ignore")
from sklearn import tree


class Node(object):
    def __init__(self, son=None, data=None, label=None, node_label=None):
        self.son = son      # 节点的子节点
        self.data = data    # 节点包含的数据
        self.label = label  # 节点包含的标签
        self.node_label = node_label    # 按照哪个属性分类
        self.if_leaf = False        # 默认不是子节点


def build_tree(data, label, choice):
    mytree = Node()
    # 全部属于一个类别
    if np.unique(label.flatten()).shape[0] == 1:
        mytree.if_leaf = True
        mytree.label = label[0]
        return mytree
    # 当前属性都属于一个类别
    if np.unique(data, axis=0).shape[0] == 1:
        mytree.if_leaf = True
        # 听天由命,直接选第一个
        mytree.label = label[0]
        return mytree
    # 样本或属性集合是空的
    if label.shape[0] == 0:
        mytree.if_leaf = True
        # 从来没碰到这种情况,默认0
        mytree.label = 0
        return mytree
    # 判断
    if choice == 1:
        mytree = Node(data=data, label=label, node_label=find_entropy_index(data, label))
    if choice == 2:
        mytree = Node(data=data, label=label, node_label=find_entropy_ratio(data, label))
    if choice == 3:
        mytree = Node(data=data, label=label, node_label=find_gini(data, label))
    mytree.son_label = Counter(data[:, mytree.node_label]).keys()
    # 递归
    # for feature_class in Counter(data[:, mytree.node_label]).keys():
    #     print(data[data[:, mytree.node_label] == feature_class], "\n", label[data[:, mytree.node_label] == feature_class],"\n\n")

    mytree.son = {feature_class: build_tree(data[data[:, mytree.node_label] == feature_class], label[data[:, mytree.node_label] == feature_class], choice) for feature_class in Counter(data[:, mytree.node_label]).keys()}
    return mytree



# 计算信息熵
def c_entropy(label):
    # 每类有几个
    count_class = Counter(label)
    ret = 0
    for classes in count_class.keys():
        ret += count_class[classes]/len(label)*np.log2(count_class[classes]/len(label))
    return -ret


# 计算单个属性信息增益
def c_entropy_gain(data, label):
    ent = c_entropy(label)
    # 几个值循环几次
    count_value = Counter(data)
    sum_entropy = 0
    for value_a in count_value.keys():
        sum_entropy += count_value[value_a] / len(label) * c_entropy(label[data == value_a])
    return ent - sum_entropy



# 计算计算每个属性的信息增益并求最大的
def find_entropy_index(data, label):
    label = label.flatten()
    feature_count = data.shape[1]
    max_list = []
    for i in range(feature_count):
        max_list.append(c_entropy_gain(data[:, i], label))
    max_value = max(max_list)  # 求列表最大值
    max_idx = max_list.index(max_value)  # 求最大值对应索引
    return max_idx


# 计算单个属性信息增益率
def c_entropy_gain_rate(data, label):
    ent = c_entropy(label)
    # 几个值循环几次
    count_value = Counter(data)
    sum_entropy = 0
    sum_iv = 0
    for value_a in count_value.keys():
        d = count_value[value_a] / len(label)
        sum_entropy += d * c_entropy(label[data == value_a])
        sum_iv -= d * np.log2(d)
    return (ent - sum_entropy) / sum_iv


# 计算计算每个属性的信息增益率并求最大的
def find_entropy_ratio(data, label):
    # 返回大于中位数的下标
    def get_mid_index(list):
        mid = int(len(list)/2)
        b = sorted(enumerate(list), key=lambda x: x[1])  # x[1]是因为在enumerate(a)中,a数值在第1位
        c = [x[0] for x in b]  # 获取排序好后b坐标,下标在第0位
        return c[mid:]


    label = label.flatten()
    feature_count = data.shape[1]
    max_list_e = []
    max_list = []
    for i in range(feature_count):
        max_list_e.append(c_entropy_gain(data[:, i], label))
        max_list.append(c_entropy_gain_rate(data[:, i], label))
    # 只保留信息增益大于平均水平的
    max_list_mid = [max_list[i] for i in get_mid_index(max_list_e)]
    max_value = max(max_list_mid)  # 求列表最大值
    max_idx = max_list_mid.index(max_value)  # 求最大值对应索引
    return max_idx


# 计算单个属性基尼指数
def c_gini(data, label):
    # 几个值循环几次
    count_value = Counter(data)

    gini = 0
    for value_a in count_value.keys():
        d = count_value[value_a] / len(label)
        class_gini = 1
        # 取出值一样的data的value
        same_value_class = label[data == value_a]
        count_labels = Counter(same_value_class)
        for i in count_labels.values():
            class_gini -= (i / count_value[value_a])**2
        gini += d * class_gini
    return gini


# 计算计算每个属性的基尼指数并求最小的
def find_gini(data, label):
    label = label.flatten()
    feature_count = data.shape[1]
    min_list = []
    for i in range(feature_count):
        min_list.append(c_gini(data[:, i], label))
    min_value = min(min_list)  # 求列表最小值
    min_idx = min_list.index(min_value)  # 求最小值对应索引
    return min_idx


def sk(data_train, label_train, data_test, label_test):
    clf = tree.DecisionTreeClassifier(criterion="entropy")
    clf = clf.fit(data_train, label_train)
    score = clf.score(data_test, label_test)  # 返回精确度
    print("准确率为:", score * 100, "%")

    # 画决策树
    feature_name = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
    import matplotlib as mpl
    mpl.rcParams['font.sans-serif'] = ['FangSong']  # 指定中文字体
    mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号
    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
    tree.plot_tree(clf,
                feature_names=feature_name,
                class_names=["好瓜", "坏瓜"],
                filled=True,  # 填充颜色,颜色越深,不纯度越低
                rounded=True  # 框的形状
                )
    plt.show()


def pridect(data, mytree):

    # 找下个节点
    def find_result(mytree, row):
        # 如果是叶子节点
        if mytree.if_leaf:
            return mytree.label
        # 不是叶子节点就继续找
        row_the_value = row[mytree.node_label]
        return find_result(mytree.son[row_the_value], row)

    ret = np.array([find_result(mytree, row) for row in data])
    return ret


def eveluate(predict, result):
    correct = 0
    for i in range(predict.shape[0]):
        correct += (int(predict[i][0]+0.5) == result[0])
    print("准确率为:", correct / predict.shape[0] * 100, "%")

if __name__ == '__main__':
    random.seed(1129)
    data = pd.read_csv("watermelonData.csv").sample(frac=1, random_state=1129)
    # 因为西瓜数据集每列都是0到2的,所以这里就不进行标准化了

    labels = data["label"]
    data_shuffled = np.array(data[data.columns[:-1]])
    # 划分训练集测试集
    data_train = data_shuffled[:int(data_shuffled.shape[0]*0.6), :]
    label_train = np.array(labels[:data_train.shape[0]]).reshape(-1, 1)
    data_test = data_shuffled[data_train.shape[0]:, :]
    label_test = np.array(labels[data_train.shape[0]:]).reshape(-1, 1)

    choice = 0
    while choice != 5:
        print("1. ID3求解\n2. C45求解\n3. CART求解\n4. sklearn求解\n5. 退出")
        try:
            choice = int(input())
        except:
            break
        if choice == 1:
            print("ID3求解中...")
            mytree = build_tree(data_train, label_train, choice)
            eveluate(pridect(data_test, mytree), label_test)
            # find_entropy_index(data_train, label_train)
        elif choice == 2:
            print("C45求解中...")
            mytree = build_tree(data_train, label_train, choice)
            eveluate(pridect(data_test, mytree), label_test)
            # find_entropy_ratio(data_train, label_train)
        elif choice == 3:
            print("CART求解中...")
            mytree = build_tree(data_train, label_train, choice)
            eveluate(pridect(data_test, mytree), label_test)
            # find_gini(data_train, label_train)
        elif choice == 4:
            print('sklearn yyds')
            sk(data_train, label_train, data_test, label_test)
        else:
            print("退出成功")
            break

参考
吴恩达《机器学习》
sklearn官网
《百面机器学习》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/335467.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高性能MySQL -- 查询性能优化

一般来说一个好的程序:查询优化,索引优化,库表结构要同时进行优化。今天我们来讲一下查询优化。 我们需要对MySQL的架构有基本认知,所以这里贴一张图大家看看: 图片来自于《小林coding》 为什么从查询会慢&#xff1…

点云深度学习系列博客(四): 注意力机制原理概述

目录 1. 注意力机制由来 2. Nadaraya-Watson核回归 3. 多头注意力与自注意力 4. Transformer模型 Reference 随着Transformer模型在NLP,CV甚至CG领域的流行,注意力机制(Attention Mechanism)被越来越多的学者所注意,将…

九、Linux文件 - fopen函数和fclose函数讲解

目录 1.fopen函数 2.fclose函数 3.fopen函数和fclose实战 1.fopen函数 fopen fwrite fread fclose ...属于标准C库 include <stdio.h> standard io lib open close write read 属于Linux系统调用 可移植型&#xff1a;fopen > open&#xff08;open函数只在嵌入…

ES6的代理Proxy和反射Reflect的使用

一、Proxy使用 作用&#xff1a;Proxy是ES6为了操作对象而引入的API&#xff0c;不直接作用于对象&#xff0c;而是通过类似媒介的方式进行对象的操作使用 /*** target&#xff1a;需要proxy处理的对象* handler&#xff1a;对对象进行处理的方法 */ let proxy new Proxy(ta…

ARM uboot源码分析2-启动第一阶段

一、start.S 解析5 注释的中文含义&#xff1a; 当我们已经在 RAM 中运行时&#xff0c;我们不需要重新定位 U-Boot。实际上&#xff0c;在 U-Boot 在 RAM 中运行之前&#xff0c;必须配置内存控制器。 1、判断当前代码执行位置 (1) lowlevel_init.S 的 110-115 行。 (2) 这几…

5年经验之谈:月薪3000到30000,测试工程师的变“行”记!

自我介绍下&#xff0c;我是一名转IT测试人&#xff0c;我的专业是化学&#xff0c;去化工厂实习才发现这专业的坑人之处&#xff0c;化学试剂害人不浅&#xff0c;有毒&#xff0c;易燃易爆&#xff0c;实验室经常用丙酮&#xff0c;甲醇&#xff0c;四氯化碳&#xff0c;接触…

ESP32 Arduino(十二)lvgl移植使用

一、简介LVGL全程LittleVGL&#xff0c;是一个轻量化的&#xff0c;开源的&#xff0c;用于嵌入式GUI设计的图形库。并且配合LVGL模拟器&#xff0c;可以在电脑对界面进行编辑显示&#xff0c;测试通过后再移植进嵌入式设备中&#xff0c;实现高效的项目开发。SquareLine Studi…

RMI攻击中的ServerClient相互攻击反制

前言 前文中&#xff0c;我们分析了攻击Registry的两种方式&#xff0c;这里我们接着前面的内容&#xff0c;分析Server和Client的相互攻击方式。 Attacked Server Attacked By Client 首先我们搭建个示例&#xff0c;这里直接注册端和服务端放置在一起。 package pers.rm…

JS 实现抛物线动画案例

相信大家都有浏览过&#xff0c;很多购物网站购物车的添加商品动画&#xff0c;今天&#xff0c;我们就手写一个简单的抛物线动画&#xff0c;先上案例&#xff1a; 一、绘制页面 我们这里简单实现&#xff0c;一个按钮&#xff0c;一个购物车图标&#xff0c;样式这里直接跳过…

GNN图神经网络原理解析

一、GNN基本概念 1. 图的基本组成 图神经网络的核心就是进行图模型搭建&#xff0c;图是由点和边组成的。在计算机处理时&#xff0c;通常将数据以向量的形式进行存储。因此&#xff0c;在存储图时&#xff0c;就会有点的向量&#xff0c;点与点之间边的向量&#xff0c;全局…

Acwing---1235. 付账问题

付账问题1.题目2.基本思想3.代码实现1.题目 几个人一起出去吃饭是常有的事。 但在结帐的时候&#xff0c;常常会出现一些争执。 现在有 nnn个人出去吃饭&#xff0c;他们总共消费了 SSS元。 其中第 iii 个人带了 aiaiai元。 幸运的是&#xff0c;所有人带的钱的总数是足够…

vue解决跨域问题-反向代理

浏览器有同源策略&#xff0c;限制同协议、同域名、同端口&#xff0c;只要有一项不一致&#xff0c;就是跨域。&#xff08;不同源则跨域&#xff09; 解决方案&#xff1a; 后端 、cors 、 jsonp、 反向代理 同源下&#xff1a;浏览器向服务器请求数据&#xff0c;服务器响应…

jenkins +docker+python接口自动化之docker下安装jenkins(一)

jenkins dockerpython接口自动化之docker下安装jenkins&#xff08;一&#xff09; 目录&#xff1a;导读 1、下载jenkins 2、启动jenkins 3、访问jenkins 4.浏览器直接访问http://ip/:8080 5.然后粘贴到输入框中,之后新手入门中先安装默认的插件即可&#xff0c;完成后出…

俄罗斯方块游戏代码

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a;小刘主页 ♥️每天分享云计算网络运维课堂笔记&#xff0c;努力不一定有收获&#xff0c;但一定会有收获加油&#xff01;一起努力&#xff0c;共赴美好人生&#xff01; ♥️夕阳下&#xff0c;是最美的&#xff0c;绽…

图表控件LightningChart .NET再破世界纪录,支持实时可视化 1 万亿个数据点

LightningChart.NET SDK 是一款高性能数据可视化插件工具&#xff0c;由数据可视化软件组件和工具类组成&#xff0c;可支持基于 Windows 的用户界面框架&#xff08;Windows Presentation Foundation&#xff09;、Windows 通用应用平台&#xff08;Universal Windows Platfor…

数据可视化大屏百度地图绘制行政区域标注实战案例解析(个性化地图、标注、视频、控件、定位、检索)

百度地图开发系列目录 数据可视化大屏应急管理综合指挥调度系统完整案例详解&#xff08;PHP-API、Echarts、百度地图&#xff09;数据可视化大屏百度地图API开发&#xff1a;停车场分布标注和检索静态版百度地图高级开发&#xff1a;map.getDistance计算多点之间的距离并输入…

元宵晚会节目预告没有岳云鹏,是不敢透露还是另有隐情

在刚刚结束的元宵节晚会上&#xff0c;德云社的岳云鹏&#xff0c;再一次参加并引起轰动&#xff0c;并获得了观众朋友们的一致好评。 不过有细心的网友发现&#xff0c;早前央视元宵晚会节目预告&#xff0c;并没有看到小岳岳&#xff0c;难道是不敢提前透露&#xff0c;怕公布…

TCP 三次握手和四次挥手

✏️作者&#xff1a;银河罐头 &#x1f4cb;系列专栏&#xff1a;JavaEE &#x1f332;“种一棵树最好的时间是十年前&#xff0c;其次是现在” 目录TCP 建立连接(三次握手)为啥不能是 4 次&#xff1f;为啥不能是 2 次&#xff1f;三次握手的意义&#xff1a;TCP 断开连接(四…

前端报表如何实现无预览打印解决方案或静默打印

在前端开发中&#xff0c;除了将数据呈现后&#xff0c;我们往往需要为用户提供&#xff0c;打印&#xff0c;导出等能力&#xff0c;导出是为了存档或是二次分析&#xff0c;而打印则因为很多单据需要打印出来作为主要的单据来进行下一环节的票据支撑&#xff0c; 而前端打印可…

Android Binder机制之一(简介)

目录 前言 一、Android 进程间通信方式 二、Binder架构图 三、Binder涉及角色 3.1 Binder驱动 3.2 Binder实体 3.3 Binder引用 3.4 远程服务 3.5 ServiceManager守护进程 四、涉及源码 前言 这是本人第N次看Binder 相关知识了&#xff0c;其实每次看都有新的收获&…