机器学习作业3____决策树(CART算法)

news2025/1/14 4:09:11

目录

一、简介

 二、具体步骤

样例:

三、代码

四、结果

五、问题与解决


一、简介

CART(Classification and Regression Trees)是一种常用的决策树算法,可用于分类和回归任务。这个算法由Breiman等人于1984年提出,它的主要思想是通过递归地将数据集划分为两个子集,然后在每个子集上继续划分,直到满足某个停止条件为止。

CART算法在分类和回归问题上表现良好,并且能够处理多种数据类型(包括离散型和连续型特征)。由于其简单、易于理解和实现,以及在一些应用中的良好性能,CART算法被广泛应用于实践中。

 二、具体步骤

  1. 计算整体数据集的基尼指数:

    • 首先,计算整个数据集的基尼指数。基尼指数表示从数据集中随机选择两个样本,它们属于不同类别的概率。对于每个节点,基尼指数可以通过以下公式计算: \text{Gini}(D) = 1 - \sum_{i=1}^k (p_i)^2 ( D ) 是当前节点的数据集,( k ) 是类别的数量,( p_i ) 是第 ( i ) 类样本在数据集 ( D ) 中的频率。
  2. 选择最佳特征和切分点:

    • 对于每个特征,遍历其所有可能的取值作为切分点。
    • 对于每个切分点,将数据集分为两个子集:左子集(特征值小于等于切分点)和右子集(特征值大于切分点)。
    • 计算基尼指数来衡量使用当前特征和切分点进行划分后的加权基尼指数。数学上,对于一个特征 ( A ) 的某个切分点 ( t ),其左子集和右子集的基尼指数可以计算为: \text{Gini}(A, t) = \frac{|D{\text{l}}|}{|D|} \times \text{Gini}(D_{\text{l}}) + \frac{|D_{\text{r}}|}{|D|} \times \text{Gini}(D_{\text{r}})
    • 选择使得基尼指数最小的特征和切分点作为当前节点的划分依据。
  3. 递归划分子集:

    • 根据选择的最佳特征和切分点,将当前节点的数据集划分为两个子集:左子集和右子集。
    • 对每个子集递归地重复步骤 1 和步骤 2,直到达到停止条件,例如达到最大深度、节点样本数量小于预设阈值等。
  4. 停止条件:

    • 决策树构建过程中,需要设定停止条件,以防止过度拟合或无限生长。常见的停止条件包括:达到最大深度、节点样本数量小于预设阈值、节点基尼指数低于阈值等。
  5. 剪枝:

    • 在决策树生长完成后,可以应用剪枝来降低树的复杂度和提高泛化能力。剪枝的目标是通过移除部分节点或子树来减小模型的复杂度,常见的剪枝方法有预剪枝和后剪枝。
    • 预剪枝:预剪枝是在决策树构建过程中,在决策树生长过程中进行判断并提前终止树的生长。在预剪枝中,可以设置一些停止生长的条件,例如限制树的最大深度、节点中最小样本数、基尼不纯度的阈值等。当达到任何一个预设条件时,就停止分裂节点并将该节点标记为叶子节点,不再继续向下生长,从而避免过拟合。
    • 后剪枝:后剪枝是在决策树构建完成之后,对已生成的决策树进行修剪来减少过拟合。后剪枝通过剪掉一些子树或者将子树替换为叶子节点来减少树的复杂度,从而提高泛化能力。后剪枝的过程通常是自底向上地遍历决策树,然后对每个内部节点尝试剪枝,判断剪枝后的决策树性能是否提升,如果提升则进行剪枝操作。

    • 预剪枝和后剪枝各有优劣势:预剪枝可以在构建树的过程中避免过拟合,但可能会导致欠拟合,因为它在生长时就限制了树的复杂度。后剪枝在构建完整个树后进行修剪,更容易实现,但可能会由于过拟合而导致剪枝效果不佳。

  6. 预测:

    • 使用生成的决策树对新样本进行分类预测。
    • 从根节点开始,根据特征值逐步向下遍历树的分支,直到到达叶子节点,然后将叶子节点的预测值作为样本的预测结果。

样例:

一些解释:

分裂阈值是决策树算法中用来划分数据集的一个值,它决定了将数据集分成两部分的标准。在每个节点上,决策树算法会选择一个特征和一个分裂阈值,将数据集分为两部分,使得分裂后的子集尽可能地纯净(即属于同一类别)。

假设有一个二维数据集,包含两个特征和一个类别:

X1X2Y
1.02.00
2.03.00
2.02.01
3.04.01
3.03.00

在构建过程中,需要选择一个特征和一个分裂阈值来将数据集划分为左右两个子集。假设现在可以选择特征X1,并将分裂阈值设为2.5。所有X1小于2.5的样本将被划分到左子集,而X1大于等于2.5的样本将被划分到右子集。

首先,我们根据选定的特征和分裂阈值将数据集划分成两个子集。

左子集(X1 < 2.5):

X1X2Y
1.02.00
2.03.00
2.02.01

右子集(X1 >= 2.5):

X1X2Y
3.04.01
3.03.00

对于左子集:

  • 类别0的频率:p_0 = \frac{2}{3} = 0.67
  • 类别1的频率:p_1 = \frac{1}{3} = 0.33

左子集的基尼指数为:

Gini_{left} = 1 - (0.67^2 + 0.33^2)= 0.4422

对于右子集:

  • 类别0的频率:p_0 = \frac{1}{2} = 0.5
  • 类别1的频率:p_1 = \frac{1}{2} = 0.5

右子集的基尼指数为:

Gini_{right} = 1 - (0.5^2 + 0.5^2) = 0.5

计算加权基尼指数

Weighted Gini Index = \frac{3}{5} \times Gini_{left} + \frac{2}{5} \times Gini_{right}= \frac{3}{5} \times 0.4422 + \frac{2}{5} \times 0.5 = 0.66532

在这个例子中,选定特征X1和分裂阈值2.5的加权基尼指数为约0.66532。

三、代码

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score



class DecisionTreeClassifier:
    def __init__(self, max_depth=None, min_samples_split=2):
        """
        初始化决策树分类器
        
        参数:
        - max_depth: 决策树的最大深度,控制树的生长。默认为None,表示不限制深度。
        - min_samples_split: 内部节点再划分所需的最小样本数。默认为2。
        """
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
    
    def fit(self, X, y):
        """
        根据训练数据拟合模型
        
        参数:
        - X: 训练数据的特征数组。
        - y: 训练数据的标签数组。
        """
        self.n_classes = len(np.unique(y))
        self.n_features = X.shape[1]
        self.tree_ = self._grow_tree(X, y)
    
    def _grow_tree(self, X, y, depth=0):
        """
        递归地构建决策树
        
        参数:
        - X: 当前节点的特征数组。
        - y: 当前节点的标签数组。
        - depth: 当前节点的深度。
        """
        # 计算每个类别的样本数
        n_samples_per_class = [np.sum(y == i) for i in range(self.n_classes)]
        # 预测当前节点的类别为样本数最多的类别
        predicted_class = np.argmax(n_samples_per_class)
        
        if depth < self.max_depth and X.shape[0] >= self.min_samples_split:
            best_gini = float('inf')
            best_feature = None
            best_threshold = None
            
            # 遍历每个特征
            for feature in range(self.n_features):
                unique_values = np.unique(X[:, feature])
                # 遍历每个特征值作为分裂阈值
                for threshold in unique_values:
                    y_left = y[X[:, feature] < threshold]
                    y_right = y[X[:, feature] >= threshold]
                    
                    if len(y_left) == 0 or len(y_right) == 0:
                        continue
                    
                    # 计算基尼不纯度
                    gini = self._gini_impurity(y_left, y_right)
                    
                    # 选择最小基尼不纯度对应的特征和阈值
                    if gini < best_gini:
                        best_gini = gini
                        best_feature = feature
                        best_threshold = threshold
            
            # 如果存在可以降低基尼不纯度的分裂,则继续构建子树
            if best_gini < float('inf'):
                left_indices = X[:, best_feature] < best_threshold
                X_left, y_left = X[left_indices], y[left_indices]
                X_right, y_right = X[~left_indices], y[~left_indices]
                
                left_subtree = self._grow_tree(X_left, y_left, depth + 1)
                right_subtree = self._grow_tree(X_right, y_right, depth + 1)
                
                return {'feature': best_feature, 'threshold': best_threshold,
                        'left': left_subtree, 'right': right_subtree}
        
        # 当无法继续分裂时,返回当前节点的预测类别
        return {'predicted_class': predicted_class}
    
    def _gini_impurity(self, y_left, y_right):
        """
        计算基尼不纯度
        
        参数:
        - y_left: 左子节点的标签数组。
        - y_right: 右子节点的标签数组。
        """
        n_left, n_right = len(y_left), len(y_right)
        n_total = n_left + n_right
        
        p_left = np.array([np.sum(y_left == c) / n_left for c in range(self.n_classes)])
        p_right = np.array([np.sum(y_right == c) / n_right for c in range(self.n_classes)])
        
        # 计算左右节点的基尼不纯度
        gini_left = 1.0 - np.sum(p_left ** 2)
        gini_right = 1.0 - np.sum(p_right ** 2)
        
        # 计算加权基尼不纯度
        gini = (n_left / n_total) * gini_left + (n_right / n_total) * gini_right
        return gini
    
    def predict(self, X):
        """
        对输入数据进行预测
        
        参数:
        - X: 待预测数据的特征数组。
        
        返回:
        - 预测的标签数组。
        """
        return np.array([self._predict(inputs) for inputs in X])
    
    def _predict(self, inputs):
        """
        递归地预测单个样本的标签
        
        参数:
        - inputs: 单个样本的特征数组。
        
        返回:
        - 预测的标签。
        """
        node = self.tree_
        while 'predicted_class' not in node:
            feature_value = inputs[node['feature']]
            # 根据特征值和阈值判断进入左子树还是右子树
            if feature_value < node['threshold']:
                node = node['left']
            else:
                node = node['right']
        return node['predicted_class']

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化决策树分类器
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(X_train, y_train)
print("test:"+str(X_test)+"\n         pre:"+str(y_test))
print("Predictions:", clf.predict(X_test))

accuracy = accuracy_score(y_test, clf.predict(X_test))
print("准确率:"+str(accuracy*100)+'%')


四、结果

本次实验采用的是python自带的鸢尾花数据集,将数据集8:2分为训练集和测试集,将树的最大深度设置为2,得到的结果如下:

可以看到只有一个点出现了错误,预测的效果不错。

五、问题与解决

问题1.

  1. 过拟合:决策树容易在训练数据上过拟合,即模型过于复杂,过度拟合训练数据中的噪声或特定样本,导致在测试数据上表现不佳。欠拟合:与过拟合相反,如果决策树过于简单,可能无法捕捉数据中的复杂关系,导致在训练和测试数据上都表现不佳。

  2.  解决减缓过拟合:降低模型复杂度、增加训练数据量、使用正则化技术、特征选择等。减缓欠拟合:增加模型复杂度、添加更多特征、使用更复杂的模型等

        

 问题2.

  1. 内存消耗问题:如果数据集过大或者决策树过深,可能导致内存消耗过大,甚至导致程序崩溃或者运行缓慢。

  2. 解决:限制最大深度:通过设置决策树的最大深度来限制树的复杂度,从而减少内存消耗。限制叶子节点数量:设置叶子节点的最小样本数,以避免树过于深。使用剪枝:在树的训练过程中或者之后对树进行剪枝,去掉不必要的分支和节点。分批处理数据:将数据集划分为小批次,并逐批进行训练和预测,以减少一次性处理的内存需求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1623437.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Clion连接MySQL数据库:实现C/C++语言与MySQL交互

确保你的电脑里已经有了MySQL。 1、找到MySQL的目录 2、进入lib目录 3、复制libmysql.dll和libmysql.lib文件 4、将这俩文件粘贴到你的clion项目的cmake-build-debug目录下 如果不是在这个目录下&#xff0c;运行时会出以下错误报错&#xff1a; 进程已结束&#xff0c;退…

火绒安全的应用介绍

火绒安全软件是一款集成了杀毒、防御和管控功能的安全软件&#xff0c;旨在为用户提供全面的计算机安全保障。以下是火绒安全软件的一些详细介绍&#xff1a; 系统兼容性强&#xff1a;该软件支持多种操作系统&#xff0c;包括Windows 11、Windows 10、Windows 8、Windows 7、…

AI预测福彩3D第9套算法实战化测试第3弹2024年4月25日第3次测试

今天继续进行新算法的测试&#xff0c;今天是第3次测试。好了&#xff0c;废话不多说了&#xff0c;直接上图上结果。 2024年4月25日福彩3D预测结果 6码定位方案如下&#xff1a; 百位&#xff1a;6、4、3、7、2、8 十位&#xff1a;8、4、9、3、1、0 个位&#xff1a;7、6、9、…

Linux进程间通信 管道系列: 利用管道实现进程池(匿名和命名两个版本)

Linux进程间通信 管道系列: 利用管道实现进程池[匿名和命名两个版本] 一.匿名管道实现进程池1.池化技术2.搭架子3.代码编写1.创建子进程1.利用命令行参数传入创建几个子进程2.创建管道和子进程(封装Channel类)1.先描述2.在组织3.开始创建 2.封装MainProcess类3.控制子进程1.封装…

无限滚动分页加载与下拉刷新技术探析:原理深度解读与实战应用详述

滚动分页加载&#xff08;也称为无限滚动加载、滚动分页等&#xff09;是一种常见的Web和移动端应用界面设计模式&#xff0c;用于在用户滚动到底部时自动加载下一页内容&#xff0c;而无需点击传统的分页按钮。这种设计旨在提供更加流畅、连续的浏览体验&#xff0c;减少用户交…

人耳的七个效应

1、掩蔽效应 • 人们在安静环境中能够分辨出轻微的声音&#xff0c;即人耳对这个声音的听域很低&#xff0c;但在嘈杂的环境中轻微的声音就会被淹没掉&#xff0c;这时将轻微的声音增强才能听到。 • 这种在聆听时&#xff0c;一个声音的听阈因另一声音的出现而提高的现象&…

ThinkPad E14 Gen 4,R14 Gen 4,E15 Gen 4(21E3,21E4,21E5,21E6,21E7)原厂Win11系统恢复镜像下载

lenovo联想ThinkPad笔记本电脑原装出厂Windows11系统安装包&#xff0c;恢复出厂开箱状态一模一样 适用型号&#xff1a;ThinkPad E14 Gen 4,ThinkPad R14 Gen 4,ThinkPad E15 Gen 4 (21E3,21E4,21E5,21E6,21E7) 链接&#xff1a;https://pan.baidu.com/s/1QRHlg2yT_RFQ81Tg…

服务部署后出错怎么快速调试?试试JDWP协议

前言 原文链接&#xff1a;教你使用 JDWP 远程调试服务 在我们日常开发工作中&#xff0c;经常会遇到写好的代码线上出了问题&#xff0c;但是本地又无法复现&#xff0c;看着控制台输出的日志恨自己当初没有多打几条日志&#xff0c;然后追着日志一条一条查&#xff0c;不说…

安装 Nginx 的三种方式

通过 Nginx 源码安装需要提前准备的内容&#xff1a; GCC 编译器 Nginx 是使用 C 语言编写的程序&#xff0c;因此想要运行 Nginx 就需要安装一个编译工具 GCC 就是一个开源的编译器集合&#xff0c;用于处理各种各样的语言&#xff0c;其中就包含了 C 语言 使用命令 yum i…

python基础语法--列表

一、列表的概念 列表&#xff08;List&#xff09;是一种有序、可变、允许重复元素的数据结构。列表用于存储一组相关的元素&#xff0c;并且可以根据需要动态地进行增加、删除、修改和访问。以下是列表的主要特点和操作&#xff1a; 有序性&#xff1a; 列表中的元素是按照它…

工作与生活,如何找到平衡点,实现双赢?(2个简单工具答案一目了然)

前言 很多 35岁左右上有老下有小的程序员会陷入一个瓶颈期&#xff0c;在工作上想努力多赚钱&#xff0c;但是每天回到家 23 点&#xff0c;老婆孩子早已熟睡。好不容易周末有点休息时间&#xff0c;但是一个电话接一个&#xff0c;由于是生产问题还不得不接。 那么职场人应该如…

激活IDM下载器并配置百度网盘

前言&#xff1a; 最近想下载一些软件&#xff0c;奈何不充钱的百度网盘的速度实在太慢了&#xff0c;不到一个G的文件夹奈何下了一晚上&#xff0c;只能重新找一下idm的下载了。 但是idm的正版是需要收费的&#xff0c;所以有白嫖党的破解版就横空出世了。 正文&#xff1a…

【目标跟踪】ByteTrack详解与代码细节

文章目录 一、前言二、代码详解2.1、新起航迹2.2、预测2.3、匹配2.4、结果发布2.5、总结 三、流程图四、部署 一、前言 论文地址&#xff1a;https://arxiv.org/pdf/2110.06864.pdf git地址&#xff1a;https://github.com/ifzhang/ByteTrack ByteTrack 在是在 2021 年 10 月…

OpenAIGPT-4.5提前曝光?

OpenAI GPT-4.5的神秘面纱&#xff1a;科技界的震撼新篇章 在人工智能的世界里&#xff0c;每一次技术的飞跃都不仅仅是一次更新&#xff0c;而是对未来无限可能的探索。近日&#xff0c;科技巨头OpenAI似乎再次站在了这场革命的前沿&#xff0c;其潜在的新产品——GPT-4.5 Tur…

Https协议原理剖析【计算机网络】【三种加密方法 | CA证书 】

目录 一&#xff0c;fidler工具 前提知识 二&#xff0c;Https原理解析 1. 中间人攻击 2. 常见的加密方式 1&#xff09;. 对称加密 2&#xff09;. 非对称加密 对称加密 4&#xff09;. CA证书 1. 数据摘要 3. 数字签名 CA证书 理解数据签名 存在的安全疑问&am…

ubuntu ROS1 C++下使用免安装eigen库的方法

1、eigen库的定义及头文件介绍 Eigen是一个高层次的C 库&#xff0c;有效支持线性代数&#xff0c;矩阵和矢量运算&#xff0c;数值分析及其相关的算法。 2、获取eigen库安装包 下载地址&#xff1a;eigen库官网 &#xff0c;如下图所示&#xff1a; 下载最新版tar.bz2即可&…

嵌入式Linux driver开发实操(二十三):ASOC

ASoC的结构及嵌入到Linux音频框架 ALSA片上系统(ASoC)层的总体项目目标是为嵌入式片上系统处理器(如pxa2xx、au1x00、iMX等)和便携式音频编解码器提供更好的ALSA支持。在ASoC子系统之前,内核中对SoC音频有一些支持,但它有一些局限性: ->编解码器驱动程序通常与底层So…

甘特图是什么?如何利用其优化项目管理流程?

甘特图是项目管理软件中十分常见的功能&#xff0c;可以说每一个项目经理都要学会使用甘特图才能更好的交付项目。什么是甘特图&#xff1f;甘特图用来做什么&#xff1f;简单来说一种将项目任务与时间关系直观表示的图表&#xff0c;直观地展示了任务进度和持续时间。 一、甘特…

博睿数据亮相GOPS全球运维大会,Bonree ONE 2024春季正式版发布!

2024年4月25日&#xff0c;博睿数据 Bonree ONE 2024 春季正式版焕新发布。同时&#xff0c;博睿数据AIOps首席专家兼产品总监贺安辉携核心产品新一代一体化智能可观测平台 Bonree ONE 亮相第二十二届 GOPS 全球运维大会深圳站。 Bonree ONE 2024 春季版产品重点升级数据采集、…

7-30 字符串的冒泡排序

题目链接&#xff1a;7-30 字符串的冒泡排序 一. 题目 1. 题目 2. 输入输出样例 3. 限制 二、代码 1. 代码实现 #include <stdio.h> #include <string.h> #include <stdlib.h>// 获取输入的字符串 char **arrayGet(int len) {char **array;array malloc…