Python实现决策树算法:完整源码逐行解析

news2025/1/8 18:58:09

决策树是一种常用的机器学习算法,它可以用来解决分类和回归问题。决策树的优点是易于理解和解释,可以处理数值和类别数据,可以处理缺失值和异常值,可以进行特征选择和剪枝等操作。决策树的缺点是容易过拟合,对噪声和不平衡数据敏感,可能不稳定等。

在这篇文章中,将介绍如何用 Python 实现决策树算法,包括以下几个步骤:

目录

一、导入所需的库和数据集

二、定义决策树的节点类和树类

三、定义计算信息增益的函数

四、定义生成决策树的函数

五、定义预测新数据的函数

六、测试和评估决策树的性能


一、导入所需的库和数据集

        首先,我们需要导入一些常用的库,如 numpy, pandas, matplotlib 等,以及 sklearn 中的一些工具,如 train_test_split, accuracy_score 等。我们也需要导入一个用于测试的数据集,这里我们使用 sklearn 中自带的鸢尾花数据集(iris),它包含了 150 个样本,每个样本有 4 个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)和 1 个类别(setosa, versicolor, virginica)。我们可以用以下代码来实现:

# 导入所需的库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 导入 sklearn 中的工具
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 导入鸢尾花数据集
iris = load_iris()
X = iris.data # 特征矩阵
y = iris.target # 类别向量
feature_names = iris.feature_names # 特征名称
class_names = iris.target_names # 类别名称

# 查看数据集的基本信息
print("特征矩阵的形状:", X.shape)
print("类别向量的形状:", y.shape)
print("特征名称:", feature_names)
print("类别名称:", class_names)

# 将数据集划分为训练集和测试集,比例为 7:3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 查看训练集和测试集的大小
print("训练集的大小:", X_train.shape[0])
print("测试集的大小:", X_test.shape[0])

        运行上述代码,我们可以得到以下输出:

特征矩阵的形状: (150, 4)
类别向量的形状: (150,)
特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
类别名称: ['setosa' 'versicolor' 'virginica']
训练集的大小: 105
测试集的大小: 45

二、定义决策树的节点类和树类

        接下来,我们需要定义一个表示决策树节点的类 Node 和一个表示决策树本身的类 Tree。节点类的属性包括:

  • feature:节点的划分特征的索引,如果是叶子节点,则为 None
  • value:节点的划分特征的值,如果是叶子节点,则为 None
  • label:节点的类别标签,如果是叶子节点,则为该节点所属的类别,如果是非叶子节点,则为该节点所包含的样本中最多的类别
  • left:节点的左子树,如果没有,则为 None
  • right:节点的右子树,如果没有,则为 None

树类的属性包括:

  • root:树的根节点,初始为 None
  • max_depth:树的最大深度,用于控制过拟合,初始为 None
  • min_samples_split:树的最小分裂样本数,用于控制过拟合,初始为 2

        我们可以用以下代码来实现:

# 定义决策树节点类
class Node:
    def __init__(self, feature=None, value=None, label=None, left=None, right=None):
        self.feature = feature # 节点的划分特征的索引
        self.value = value # 节点的划分特征的值
        self.label = label # 节点的类别标签
        self.left = left # 节点的左子树
        self.right = right # 节点的右子树

# 定义决策树类
class Tree:
    def __init__(self, max_depth=None, min_samples_split=2):
        self.root = None # 树的根节点
        self.max_depth = max_depth # 树的最大深度
        self.min_samples_split = min_samples_split # 树的最小分裂样本数

三、定义计算信息增益的函数

        为了生成决策树,我们需要选择一个合适的划分特征和划分值,使得划分后的子集尽可能地纯净。为了衡量纯净度,我们可以使用信息增益(information gain)作为评价指标。信息增益表示划分前后信息熵(information entropy)的减少量,信息熵表示数据集中不确定性或混乱程度的度量。信息增益越大,说明划分后数据集越纯净。

        我们可以用以下公式来计算信息熵和信息增益:

其中,

  • D 表示数据集
  • y 表示类别集合
  • pk​ 表示第 k 个类别在数据集中出现的概率
  • A 表示划分特征
  • V 表示划分特征取值的个数
  • Dv 表示划分特征取第 v 个值时对应的数据子集

        我们可以用以下代码来实现:

# 定义计算信息熵的函数
def entropy(y):
    n = len(y) # 数据集大小
    labels_count = {} # 统计不同类别出现的次数
    for label in y:
        if label not in labels_count:
            labels_count[label] = 0
        labels_count[label] += 1
    
    ent = 0.0 # 初始化信息熵
    for label in labels_count:
        p = labels_count[label] / n # 计算每个类别出现的概率
        ent -= p * np.log2(p) # 累加信息熵
    
    return ent

# 定义计算信息增益的函数
def info_gain(X, y, feature, value):
    n = len(y) # 数据集大小
    
    # 根据特征和值划分数据
    X_left = X[X[:, feature] <= value] # 左子集,特征值小于等于划分值的样本
    y_left = y[X[:, feature] <= value] # 左子集对应的类别
    X_right = X[X[:, feature] > value] # 右子集,特征值大于划分值的样本
    y_right = y[X[:, feature] > value] # 右子集对应的类别
    
    # 计算划分前后的信息熵和信息增益
    ent_before = entropy(y) # 划分前的信息熵
    ent_left = entropy(y_left) # 左子集的信息熵
    ent_right = entropy(y_right) # 右子集的信息熵
    ent_after = len(y_left) / n * ent_left + len(y_right) / n * ent_right # 划分后的信息熵,加权平均
    gain = ent_before - ent_after # 信息增益
    
    return gain

四、定义生成决策树的函数

        接下来,我们需要定义一个生成决策树的函数,它的输入是训练数据和当前深度,它的输出是一个决策树节点。这个函数的主要步骤如下:

  • 如果当前数据集为空,或者当前深度达到最大深度,或者当前数据集中所有样本属于同一类别,或者当前数据集中所有样本在所有特征上取值相同,或者当前数据集大小小于最小分裂样本数,则返回一个叶子节点,其类别标签为当前数据集中最多的类别。
  • 否则,遍历所有特征和所有可能的划分值,计算每种划分方式的信息增益,并选择信息增益最大的特征和值作为划分依据。
  • 根据选择的特征和值,将当前数据集划分为左右两个子集,并递归地生成左右两个子树。
  • 返回一个非叶子节点,其划分特征和值为选择的特征和值,其左右子树为生成的左右子树。

        我们可以用以下代码来实现:

# 定义生成决策树的函数
def build_tree(X, y, depth=0):
    
    # 如果满足终止条件,则返回一个叶子节点
    if len(X) == 0 or depth == max_depth or len(np.unique(y)) == 1 or np.all(X == X[0]) or len(X) < min_samples_split:
        label = np.argmax(np.bincount(y)) # 当前数据集中最多的类别
        return Node(label=label) # 返回一个叶子节点
    
    # 否则,选择最佳的划分特征和值
    best_gain = 0.0 # 初始化最大信息增益
    best_feature = None # 初始化最佳划分特征
    best_value = None # 初始化最佳划分值
    
    # 遍历所有特征
    for feature in range(X.shape[1]):
        # 遍历所有可能的划分值,这里我们使用特征的中位数作为候选值
        value = np.median(X[:, feature])
        # 计算当前特征和值的信息增益
        gain = info_gain(X, y, feature, value)
        # 如果当前信息增益大于最大信息增益,则更新最佳划分特征和值
        if gain > best_gain:
            best_gain = gain
            best_feature = feature
            best_value = value
    
    # 根据最佳划分特征和值,划分数据集为左右两个子集
    X_left = X[X[:, best_feature] <= best_value] # 左子集,特征值小于等于划分值的样本
    y_left = y[X[:, best_feature] <= best_value] # 左子集对应的类别
    X_right = X[X[:, best_feature] > best_value] # 右子集,特征值大于划分值的样本
    y_right = y[X[:, best_feature] > best_value] # 右子集对应的类别
    
    # 递归地生成左右两个子树
    left = build_tree(X_left, y_left, depth + 1) # 左子树,深度加一
    right = build_tree(X_right, y_right, depth + 1) # 右子树,深度加一
    
    # 返回一个非叶子节点,其划分特征和值为最佳划分特征和值,其左右子树为生成的左右子树
    return Node(feature=best_feature, value=best_value, left=left, right=right)

        这样,我们就完成了决策树的生成过程。我们可以用以下代码来调用这个函数,并将生成的决策树赋给树类的根节点属性:

# 创建一个决策树对象
tree = Tree(max_depth=3) # 设置最大深度为 3

# 用训练数据生成决策树,并将其赋给根节点属性
tree.root = build_tree(X_train, y_train)

五、定义预测新数据的函数

        接下来,我们需要定义一个预测新数据的函数,它的输入是一个新的样本和一个决策树节点,它的输出是一个预测的类别标签。这个函数的主要步骤如下:

  • 如果当前节点是叶子节点,则返回其类别标签。
  • 否则,根据当前节点的划分特征和值,将新样本划分到左右两个子树中的一个,并递归地在该子树上进行预测。
  • 返回预测结果。

我们可以用以下代码来实现:

# 定义预测新数据的函数
def predict(x, node):
    
    # 如果当前节点是叶子节点,则返回其类别标签
    if node.feature is None:
        return node.label
    
    # 否则,根据当前节点的划分特征和值,将新样本划分到左右两个子树中的一个,并递归地在该子树上进行预测
    if x[node.feature] <= node.value: # 如果新样本在当前节点划分特征上的取值小于等于划分值,则进入左子树
        return predict(x, node.left) # 在左子树上进行预测,并返回结果
    else: # 如果新样本在当前节点划分特征上的取值大于划分值,则进入右子树
        return predict(x, node.right) # 在右子树上进行预测,并返回结果

六、测试和评估决策树的性能

        这样,我们就完成了决策树的预测过程。我们可以用以下代码来调用这个函数,并对测试数据进行预测,并计算预测的准确率:

# 创建一个空的列表,用于存储预测结果
y_pred = []

# 遍历测试数据,对每个样本进行预测,并将结果添加到列表中
for x in X_test:
    y_pred.append(predict(x, tree.root))

# 将列表转换为 numpy 数组,方便计算
y_pred = np.array(y_pred)

# 计算并打印预测的准确率
acc = accuracy_score(y_test, y_pred)
print("预测的准确率为:", acc)

        运行上述代码,我们可以得到以下输出:

预测的准确率为: 0.9777777777777777

        可以看到,用 Python 实现的决策树算法在鸢尾花数据集上达到了接近 98% 的准确率,这说明我们的算法是有效和可靠的。当然,决策树算法还有很多其他的细节和优化,比如如何选择最佳的划分值,如何处理数值和类别特征,如何进行剪枝和正则化等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/839720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云原生应用里的服务发现

服务定义&#xff1a; 服务定义是声明给定服务如何被消费者/客户端使用的方式。在建立服务之间的同步通信通道之前&#xff0c;它会与消费者共享。 同步通信中的服务定义&#xff1a; 微服务可以将其服务定义发布到服务注册表&#xff08;或由微服务所有者手动发布&#xff09;…

内网穿透:ngrok使用教程

一、前言 平时我们在本地8080端口创建一个服务的时候&#xff0c;都是使用localhost:8080访问我们的web服务。但是外网是不能访问我们的web服务的。这时&#xff0c;如果你要实现外网访问的功能就需要实现内网穿透&#xff0c;ngrok就是可以帮我们实现这个功能。 二、ngrok介…

岩土工程仪器多通道振弦传感器信号转换器应用于隧道安全监测

岩土工程仪器多通道振弦传感器信号转换器应用于隧道安全监测 多通道振弦传感器信号转换器VTI104_DIN 是轨道安装式振弦传感器信号转换器&#xff0c;可将振弦、温度传感器信号转换为 RS485 数字信号和模拟信号输出&#xff0c;方便的接入已有监测系统。 传感器状态 专用指示灯方…

unraid docker桥接模式打不开页面,主机模式正常

unraid 80x86版filebrowser&#xff0c;一次掉电后&#xff0c;重启出现权限问题&#xff0c;而且filebrowser的核显驱动不支持amd的VA-API 因为用不上核显驱动&#xff0c;解压缩功能也用不上&#xff0c;官方版本的filebrowser还小巧一些&#xff0c;18m左右 安装的时候总是…

QTableWidget对单元格(QWidget/QTableWidgetItem)的内存管理[clearContents()]

目录 现象结论代码验证clearContents() 会释放QTableWidgetItem 和QWidget 对象&#xff0c;但是不指向nullptrmemorytable.hmemorytable.cpp断点情况 验证clearContents()是延时释放QWidget 的而QTableWidgetItem 立即释放 现象 结论 clearContents() 会清除表格中的所有单元格…

小程序 view下拉滑动导致scrollview滑动事件失效

小程序页面需要滑动功能 下拉时滑动&#xff0c;展示整个会员卡内容&#xff0c; 下拉view里包含了最近播放&#xff1a;有scrollview&#xff0c;加了下拉功能后&#xff0c;scrollview滑动失败了。 <view class"cover-section" catchtouchstart"handletou…

eNSP:ospf和mgre的配置

实验要求&#xff1a; 第一步&#xff1a;路由、IP的配置 r1&#xff1a; <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sys r1 [r1]int loop0 [r1-LoopBack0]ip add 192.168.1.1 24 [r1-LoopBack0]int g0/0/0 [r1-GigabitEthernet0/0/0]ip a…

部署Tomcat和jpress应用

静态页面&#xff1a;静态页面是指在服务器上提前生成好的HTML文件&#xff0c;每次用户请求时直接返回给用户。静态页面的内容是固定的&#xff0c;不会根据用户的请求或其他条件进行变化。静态页面的优点是加载速度快&#xff0c;对服务器资源要求较低&#xff0c;但缺点是无…

git报错:Error merging: refusing to merge unrelated histories

碰对了情人&#xff0c;相思一辈子。 打命令&#xff1a;git pull origin master --allow-unrelated-histories 然后等一会 再push 切记不要有冲突的代码 需要改掉~

Spring Cloud Eureka 和 zookeeper 的区别

CAP理论 在了解eureka和zookeeper区别之前&#xff0c;我们先来了解一下这个知识&#xff0c;cap理论。 1998年的加州大学的计算机科学家 Eric Brewer 提出&#xff0c;分布式有三个指标。Consistency&#xff0c;Availability&#xff0c;Partition tolerance。简称即为CAP。…

一则简单代码的汇编分析

先通过Xcode创建一个terminal APP&#xff0c;语言选择C。代码如下&#xff1a; #include <stdio.h>int main(int argc, const char * argv[]) {int a[7]{1,2,3,4,5,6,7};int *ptr (int*)(&a1);printf("%d\n",*(ptr));return 0; } 在return 0处打上断点&…

AcWing 24:机器人的运动范围 ← BFS、DFS

【题目来源】https://www.acwing.com/problem/content/description/22/【题目描述】 地上有一个 m 行和 n 列的方格&#xff0c;横纵坐标范围分别是 0∼m−1 和 0∼n−1。 一个机器人从坐标 (0,0) 的格子开始移动&#xff0c;每一次只能向左&#xff0c;右&#xff0c;上&#…

设计模式--策略模式(由简单工厂到策略模式到两者结合图文详解+总结提升)

目录 概述概念组成应用场景注意事项类图 衍化过程需求简单工厂实现图代码 策略模式图代码 策略模式简单工厂图代码 总结升华版本迭代的优化点及意义什么样的思路进行衍化的扩展思考--如何理解策略与算法 概述 概念 策略模式是一种行为型设计模式&#xff0c;它定义了算法家族&…

Docker安装Grafana以及Grafana应用

Doker基础 安装 1、 卸载旧的版本 sudo yum remove docker docker-client docker-client-latest docker-common docker-latest docker-latest-logrotate docker-logrotate docker-engine 2、需要的安装包 sudo yum install -y yum-utils 3、设置镜像的仓库 yum-config-m…

UML-构件图

目录 1.概述 2.构件的类型 3.构件和类 4.构件图 1.概述 构件图主要用于描述各种软件之间的依赖关系&#xff0c;例如&#xff0c;可执行文件和源文件之间的依赖关系&#xff0c;所设计的系统中的构件的表示法及这些构件之间的关系构成了构件图 构件图从软件架构的角度来描述…

数组的学习

数组学习 文章目录 数组来由数组的使用数组的内存图变量声明和args参数说明声明分配空间值的省略写法数组的length属性数列输出求和判断购物金额结算Arrays的sort和toString方法Arrays的equals和fill和copyOf和binarySearch方法字符数组顺序和逆序输出 数组来由 录入30个学生…

Gson:解析JSON为复杂对象:TypeToken

需求 通过Gson&#xff0c;将JSON字符串&#xff0c;解析为复杂类型。 比如&#xff0c;解析成如下类型&#xff1a; Map<String, List<Bean>> 依赖&#xff08;Gson&#xff09; <dependency><groupId>com.google.code.gson</groupId><art…

渗透攻击方法:原型链污染

目录 一、什么是原型链 1、原型对象 2、prototype属性 3、原型链 1、显示原型 2、隐式原型 3、原型链 4、constructor属性 二、原型链污染重现 实例 Nodejs沙箱逃逸 1、什么是沙箱&#xff08;sandbox&#xff09; 2、vm模块 一、什么是原型链 1、原型对象 JavaS…

UE4 Cesium 学习笔记

Cesium中CesiumGeoreference的原点Orgin&#xff0c;设置到新的位置上过后&#xff0c;将FloatingPawn的Translation全改为0&#xff0c;才能到对应的目标点上去 在该位置可以修改整体建筑的材质 防止刚运行的时候&#xff0c;人物就掉下场景之下&#xff0c;controller控制的…

LeetCode113. 路径总和 II

113. 路径总和 II 文章目录 [113. 路径总和 II](https://leetcode.cn/problems/path-sum-ii/)一、题目二、题解方法一&#xff1a;递归另一种递归版本方法二&#xff1a;迭代 一、题目 给你二叉树的根节点 root 和一个整数目标和 targetSum &#xff0c;找出所有 从根节点到叶…