决策树的生成与剪枝

news2025/1/23 13:05:43

决策树的生成与剪枝

  • 决策树的生成
    • 生成决策树的过程
    • 决策树的生成算法
  • 决策树的剪枝
    • 决策树的损失函数
    • 决策树的剪枝算法
  • 代码

决策树的生成

生成决策树的过程

为了方便分析描述,我们对上节课中的训练样本进行编号,每个样本加一个ID值,如图所示:
在这里插入图片描述
从根结点开始生成决策树,先将上述训练样本(1-9)全部放在根节点中。

然后选择信息增益或信息增益比最大的特征向下分裂,按照所选特征取值的不同将训练样本分发到不同的子结点中。

例如所选特征有3个取值,则分裂出3个子结点,然后在子结点中重复上述过程,直到所有特征的信息增益(比)很小或者没有特征可选为止,完成决策树的构建,如下图所示:
在这里插入图片描述
图中的决策树共有2个内部结点和3个叶子结点,每个结点旁边的编号代表训练样本的 ID 值。内部结点代表样本的特征,叶子结点代表样本的预测类别,我们将叶子节点中训练样本占比最大的类作为决策树的预测标记。

在使用构建好的决策树在测试数据上分类时,只需要从根节点开始依次测试内部结点代表的特征即可得到测试样本的预测分类。

决策树的生成算法

下面我们先来总结一下 ID3分类决策树的生成算法:

输入:训练数据集 D 、特征集合 A 的信息增益阈值 ;输出:决策树 T

  1. 若 D 中的训练样本属于同一类,则 T 为单结点树,返回 D 中任意样本的类别。

  2. 若 A 中的特征为空,则 T 为单结点树,返回 D 中数量最多的类别。

  3. 使用信息增益在 A 中进行特征选择,若所选特征 A_i 的信息增益小于设定的阈值,则 T 为单结点树,返回 D 中数量最多的类别。

  4. 否则根据 A_i 的每一个取值,将 D 分成若干子集 D_i,将 D_i 中数量最多的类作为标记值,构建子结点,返回 T。

  5. 以 D_i 为训练集,{A - A_i} 为特征集,递归地调用上述步骤,得到子树 T_i,返回 T。

使用 C4.5 算法进行决策树的生成只需要将信息增益改成信息增益比即可。

决策树的剪枝

决策树的损失函数

决策树的叶子节点越多,模型越复杂。决策树的损失函数考虑了模型复杂度,我们可以通过优化其损失函数来对决策树进行剪枝。决策树的损失函数计算过程如下:

  1. 计算叶子结点 t 的样本类别经验熵
    在这里插入图片描述
    对于叶子结点 t 来说,其样本类别的经验熵越小, t 中训练样本的分类误差就越小。当叶子结点 t 中的训练样本为同一类别时,经验熵为零,分类误差为零。

  2. 计算决策树 T 在所有训练样本上的损失之和 C(T)
    在这里插入图片描述
    对于叶子结点 t 中的每一个训练样本,其类别标记都是随机变量 Y 的一个取值,这个取值的不确定性用信息熵来衡量,且可以用经验熵来估计。由上文可知,经验熵在一定程度上可以反映决策树在该样本上的预测损失,累加所有叶子结点上的训练样本损失即上图中的计算公式。

  3. 计算考虑模型复杂度的的决策树损失函数

在这里插入图片描述
决策树的叶子结点个数表示模型的复杂度,通过最小化上面的损失函数,一方面可以减少模型在训练样本上的预测误差,另一方面可以控制模型的复杂度,保证模型的泛化能力。

决策树的剪枝算法

  1. 计算决策树中每个结点的样本类别经验熵:
    在这里插入图片描述
    如上图所示,对于本课示例中的决策树,需要计算 5 个结点的经验熵。

  2. 遍历非叶子结点,剪枝相当于去除其子结点,自身变为叶子结点:

在这里插入图片描述
对于图中的非叶子结点(有工作?),剪枝后变为叶子结点,并通过多数表决的方法确定其类别标记。

以上就是这节课的所有内容了,实际上还有一种决策树算法:分类与回归树(classification and regression tree,简称 CART),它既可以用于分类也可以用于回归,同样包含了特征选择、决策树的生成与剪枝算法。

关于 CART 算法的内容,我们将在最后一章 XGBoost 中进行学习,下面请你来做一道关于信息增益比的题目,顺便回顾一下前面所学的知识

代码

## 1. 创建数据集

import pandas as pd
data = [['yes', 'no', '青年', '同意贷款'],
        ['yes', 'no', '青年', '同意贷款'],
        ['yes', 'yes', '中年', '同意贷款'],
        ['no', 'no', '中年', '不同意贷款'],
        ['no', 'no', '青年', '不同意贷款'],
        ['no', 'no', '青年', '不同意贷款'],
        ['no', 'no', '中年', '不同意贷款'],
        ['no', 'no', '中年', '不同意贷款'],
        ['no', 'yes', '中年', '同意贷款']]

# 转为 dataframe 格式
df = pd.DataFrame(data)
# 设置列名
df.columns = ['有房?', '有工作?', '年龄', '类别']

## 2. 经验熵的实现

from math import log2
from collections import Counter
def H(y):
    '''
    y: 随机变量 y 的一组观测值,例如:[1,1,0,0,0]
    '''
    # 随机变量 y 取值的概率估计值
    probs = [n/len(y) for n in Counter(y).values()]
    # 经验熵:根据概率值计算信息量的数学期望
    return sum([-p*log2(p) for p in probs])
    
## 3. 经验条件熵的实现

def cond_H(a):
    '''
    a: 根据某个特征的取值分组后的 y 的观测值,例如:
       [[1,1,1,0],
        [0,0,1,1]]
       每一行表示特征 A=a_i 对应的样本子集
    '''
    # 计算样本总数
    sample_num = sum([len(y) for y in a])
    # 返回条件概率分布的熵对特征的数学期望
    return sum([(len(y)/sample_num)*H(y) for y in a])

## 4. 特征选择函数
def feature_select(df,feats,label):
    '''
    df:训练集数据,dataframe 类型
    feats:候选特征集
    label:df 中的样本标记名,字符串类型
    '''

    # 最佳的特征与对应的信息增益比
    best_feat,gainR_max = None,-1
    # 遍历每个特征
    for feat in feats:
        # 按照特征的取值对样本进行分组,并取分组后的样本标记数组
        group = df.groupby(feat)[label].apply(lambda x:x.tolist()).tolist()
        # 计算该特征的信息增益:经验熵-经验条件熵
        gain = H(df[label].values) - cond_H(group)
        # 计算该特征的信息增益比
        gainR = gain / H(df[feat].values)
       
        # 更新最大信息增益比和对应的特征
        if gainR > gainR_max:
            best_feat,gainR_max = feat,gainR
        
    return best_feat,gainR_max 

## 5. 决策树的生成函数
import pickle
def creat_tree(df,feats,label):
    '''
    df:训练集数据,dataframe 类型
    feats:候选特征集,字符串列表
    label:df 中的样本标记名,字符串类型
    '''
    # 当前候选的特征列表
    feat_list = feats.copy()
    
    # 若当前训练数据的样本标记值只有一种
    if df[label].nunique()==1:
        # 将数据中的任意样本标记返回,这里取第一个样本的标记值
        return df[label].values[0]
    # 若候选的特征列表为空时
    if len(feat_list)==0:
        # 返回当前数据样本标记中的众数,各类样本标记持平时取第一个
        return df[label].mode()[0]
    # 在候选特征集中进行特征选择
    feat,gain = feature_select(df,feat_list,label)
    # 若选择的特征信息增益太小,小于阈值 0.1
    if gain<0.1:
        # 返回当前数据样本标记中的众数
        return df[label].mode()[0]
    
    # 根据所选特征构建决策树,使用字典存储
    tree = {feat:{}}
    # 根据特征取值对训练样本进行分组
    g = df.groupby(feat)
    # 用过的特征要移除
    feat_list.remove(feat)
    # 遍历特征的每个取值 i
    for i in g.groups:
        # 获取分组数据,使用剩下的候选特征集创建子树
        tree[feat][i] = creat_tree(g.get_group(i),feat_list,label)
    
    # 存储决策树
    pickle.dump(tree,open('tree.model','wb'))
        
    return tree
    
# 6. 决策树的预测函数
def predict(tree,feats,x):
    '''
    tree:决策树,字典结构
    feats:特征集合,字符串列表
    x:测试样本特征向量,与 feats 对应
    '''
    # 获取决策树的根结点:对应样本特征
    root = next(iter(tree))
    # 获取该特征在测试样本 x 中的索引
    i = feats.index(root)
    # 遍历根结点分裂出的每条边:对应特征取值
    for edge in tree[root]:
        # 若测试样本的特征取值=当前边代表的特征取值
        if x[i]==edge:
            # 获取当前边指向的子结点
            child = tree[root][edge]
            # 若子结点是字典结构,说明是一颗子树
            if type(child)==dict:
                # 将测试样本划分到子树中,继续预测
                return predict(child,feats,x)
            # 否则子结点就是叶子节点
            else:
                # 返回叶子节点代表的样本预测值
                return child

## 7. 在样例数据上测试

# 获取特征名列表
feats = list(df.columns[:-1])
# 获取标记名
label = df.columns[-1]
# 创建决策树(此处使用信息增益比进行特征选择)
T = creat_tree(df,feats,label)
# 计算训练集上的预测结果
preds = [predict(T,feats,x) for x in df[feats].values]
# 计算准确率
acc = sum([int(i) for i in (df[label].values==preds)])/len(preds)
# 输出决策树和准确率
print(T,acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2262457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

51c嵌入式~单片机~合集2

我自己的原文哦~ https://blog.51cto.com/whaosoft/12362395 一、不同的电平信号的MCU怎么通信&#xff1f; 下面这个“电平转换”电路&#xff0c;理解后令人心情愉快。电路设计其实也可以很有趣。 先说一说这个电路的用途&#xff1a;当两个MCU在不同的工作电压下工作&a…

Kerberos实验

kdc&#xff1a;192.168.72.163 客户端&#xff08;机器账户win10&#xff09;&#xff1a;192.168.72.159 用户&#xff1a;administrator 抓包&#xff1a;开机登录win10&#xff0c;使用administrator域用户凭据登录。 生成 Kerberos 解密文件 抓取 krbtgt 用户和 win1…

AI一键分析小红书对标账号‼️

宝子们&#xff0c;AI小助手近期发现了一款宝藏AI工具&#xff0c;拥有对标账号AI分析功能&#xff0c;只需10秒就能全面掌握对标账号的运营情况&#xff0c;并且可以根据分析结果提供创作方向和灵感&#xff0c;轻松助力1:1复刻起号&#xff01; 功能亮点&#xff1a; &…

大腾智能CAD:国产云原生三维设计新选择

在快速发展的工业设计领域&#xff0c;CAD软件已成为不可或缺的核心工具。它通过强大的建模、分析、优化等功能&#xff0c;不仅显著提升了设计效率与精度&#xff0c;还促进了设计思维的创新与拓展&#xff0c;为产品从概念构想到实体制造的全过程提供了强有力的技术支持。然而…

VMware虚拟机 Ubuntu没有共享文件夹的问题

在虚拟机的Ubuntu系统中&#xff0c;共享文件目录存放在 mnt/hgfs 下面&#xff0c;但是我安装完系统并添加共享文件后发现&#xff0c;在mnt下连/hgfs目录都没有。 注意&#xff1a;使用共享文件目录需要已安装VMtools工具。 添加共享文件目录 一&#xff1a;在超级用户下 可…

OpenGL ES 01 渲染一个四边形

项目架构 着色器封装 vertex #version 300 es // 接收顶点数据 layout (location 0) in vec3 aPos; // 位置变量的属性位置值为0 layout (location 1) in vec4 aColors; // 位置变量的属性位置值为1 out vec4 vertexColor; // 为片段着色器指定一个颜色输出void main() {gl…

leetcode二叉搜索树部分笔记

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 二叉搜索树 1. 二叉搜索树的最小绝对差2. 二叉搜索树中第 K 小的元素3. 验证二叉搜索树 1. 二叉搜索树的最小绝对差 给你一个二叉搜索树的根节点 root &#xff0c;返回 树中…

推送本地仓库到远程git仓库

目录 推送本地仓库到远程git仓库1.1修改本地仓库用户名1.2 push 命令1.3远程分支查看 推送本地仓库到远程git仓库 删除之前的仓库中的所有内容&#xff0c;从新建库&#xff0c;同时创建一个 A.txt 文件 清空原有的远程仓库内容&#xff0c;重新创建一个新的仓库&#xff0c;…

暂停一下,给Next.js项目配置一下ESLint(Next+tailwind项目)

前提 之前开自己的GitHub项目&#xff0c;想着不是团队项目&#xff0c;偷懒没有配置eslint&#xff0c;后面发现还是不行。eslint的存在可以帮助我们规范代码格式&#xff0c;同时 ctrl s保存立即调整代码格式是真的很爽。 除此之外&#xff0c;团队使用eslint也是好处颇多…

基于微信小程序的小区疫情防控ssm+论文源码调试讲解

第2章 程序开发技术 2.1 Mysql数据库 为了更容易理解Mysql数据库&#xff0c;接下来就对其具备的主要特征进行描述。 &#xff08;1&#xff09;首选Mysql数据库也是为了节省开发资金&#xff0c;因为网络上对Mysql的源码都已进行了公开展示&#xff0c;开发者根据程序开发需…

Win11安装安卓子系统WSA

文章目录 简介一、启用Hyper-V二、安装WSA三、安装APKAPK商店参考文献 简介 WSA&#xff1a;Windows Subsystem For Android 一、启用Hyper-V 控制面板 → 程序和功能 → 启用或关闭 Windows 功能 → 勾选 Hyper-V 二、安装WSA 进入 Microsoft Store&#xff0c;下拉框改为 …

[面试题]--索引用了什么数据结构?有什么特点?

答&#xff1a;使用了B树&#xff1a; 时间复杂度&#xff1a;O(logN),可以有效控制树高 B树特点&#xff1a; 1.叶子节点之间有相互链接的作用&#xff0c;会指向下一个相近的兄弟节点。 MySQL在组织叶子节点使用的是双向链表 2.非叶子节点的值都保存在叶子节点当中 MySQL非叶…

Element plus 下拉框组件选中一个选项后显示的是 value 而不是 label

最近刚进行 Vue3 Element plus 项目实践&#xff0c;在进行表单二次封装的时候&#xff0c;表单元素 select 下拉框组件选中一个选项后显示的是 value 而不是 label&#xff0c;下面上代码&#xff1a; 原来的写法&#xff1a; <el-selectv-if"v.type select"…

bean创建源码

去字节面试&#xff0c;直接让人出门左拐&#xff1a;Bean 生命周期都不知道&#xff01; spring启动创建bean流程 下面就接上了 bean生命周期 doGetBean Object sharedInstance this.getSingleton(beanName); sharedInstance this.getSingleton(beanName, new ObjectF…

【C++】- 掌握STL List类:带你探索双向链表的魅力

文章目录 前言&#xff1a;一.list的介绍及使用1. list的介绍2. list的使用2.1 list的构造2.2 list iterator的使用2.3 list capacity2.4 list element access2.5 list modifiers2.6 list的迭代器失效 二.list的模拟实现1. list的节点2. list的成员变量3.list迭代器相关问题3.1…

泷羽sec学习打卡-brupsuite8伪造IP和爬虫审计

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都 与本人无关,切莫逾越法律红线,否则后果自负 关于brupsuite的那些事儿-Brup-FaskIP 伪造IP配置环境brupsuite导入配置1、扩展中先配置python环境2、安…

挑战一个月基本掌握C++(第五天)了解运算符,循环,判断

一 运算符 运算符是一种告诉编译器执行特定的数学或逻辑操作的符号。C 内置了丰富的运算符&#xff0c;并提供了以下类型的运算符&#xff1a; 算术运算符关系运算符逻辑运算符位运算符赋值运算符杂项运算符 1.1 算术运算符 假设变量 A 的值为 10&#xff0c;变量 B 的值为…

JAVA没有搞头了吗?

前言 今年的Java程序员群体似乎承受着前所未有的焦虑。投递简历无人问津&#xff0c;难得的面试机会也难以把握&#xff0c;即便成功入职&#xff0c;也往往难以长久。于是&#xff0c;不少程序员感叹&#xff1a;互联网的寒冬似乎又一次卷土重来&#xff0c;环境如此恶劣&…

Linux -- 线程控制相关的函数

目录 pthread_create -- 创建线程 参数 返回值 代码 -- 不传 args&#xff1a; 编译时带 -lpthread 运行结果 为什么输出混杂&#xff1f; 如何证明两个线程属于同一个进程&#xff1f; 如何证明是两个执行流&#xff1f; 什么是LWP&#xff1f; 代码 -- 传 args&a…