【Pytroch】基于决策树算法的数据分类预测(Excel可直接替换数据)

news2025/1/11 14:05:57

【Pytroch】基于决策树算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

决策树是一种常用的机器学习算法,用于分类和回归任务。它通过树状结构表示数据的决策过程,将数据集划分为不同的子集,并在每个子集上进行决策。决策树分类的原理可以概括如下:

  1. 特征选择: 决策树的分类过程始于根节点,根据某个特征将数据集划分为不同的子集。选择哪个特征用来进行划分是决策树算法的关键之一。常见的特征选择标准包括信息增益、基尼不纯度、方差等。信息增益表示划分前后熵的差异,基尼不纯度测量随机选择两个样本,其类别不一致的概率。

  2. 递归划分: 一旦选择了划分特征,数据集就会被分为多个子集,每个子集对应于划分特征的不同取值。对每个子集,都会重复进行相同的特征选择和划分过程,直到满足终止条件。终止条件可能包括以下情况:

    • 所有数据点属于同一类别。
    • 到达了事先设定的树的深度。
    • 某个节点上的数据点数目低于某个阈值。
  3. 叶节点分类: 当达到终止条件时,形成了决策树的叶节点。每个叶节点对应于一个分类标签,即决策树预测的输出。该标签可以是子集中样本最常见的标签,或者由其他方法确定。

  4. 预测: 对于新的未见样本,决策树通过从根节点开始,根据特征值依次沿着树的分支进行遍历,直到到达叶节点。在叶节点上的标签即为预测的类别。

  5. 剪枝(可选): 构建决策树时,可能会出现过拟合问题,即模型在训练数据上表现良好,但在新数据上表现较差。为了避免过拟合,可以进行剪枝操作,即通过移除某些分支来简化树的结构。

决策树的优点包括易于理解和解释、可处理数值型和分类型数据、能够处理缺失值等。然而,它也存在一些局限性,如容易过拟合、对数据的小变化敏感等。为了应对这些问题,通常会使用集成方法(如随机森林和梯度提升树)来进一步提升决策树的性能。

总的来说,决策树分类是一种基于树状结构的机器学习方法,通过特征选择和递归划分,将数据映射到不同的类别。其简单性和可解释性使其成为许多机器学习问题的有力工具。

2.数学公式

当涉及决策树算法的数学公式时,我们主要关注以下几个方面:信息增益、基尼不纯度以及决策树的构建和预测过程。

  1. 信息增益(Information Gain): 信息增益用于衡量特征选择的好坏,它表示在给定特征下,数据集的不确定性减少了多少。信息增益越大,意味着使用该特征进行划分可以获得更多的信息。

    信息增益 = 划分前的熵 − 加权划分后的熵 \text{信息增益} = \text{划分前的熵} - \text{加权划分后的熵} 信息增益=划分前的熵加权划分后的熵

    划分前的熵:

    H ( D ) = − ∑ i = 1 c p i log ⁡ 2 ( p i ) H(D) = - \sum_{i=1}^{c} p_i \log_2(p_i) H(D)=i=1cpilog2(pi)

    划分后的加权熵:

    H ( D ∣ A ) = ∑ i = 1 v ∣ D i ∣ ∣ D ∣ ⋅ H ( D i ) H(D|A) = \sum_{i=1}^{v} \frac{|D_i|}{|D|} \cdot H(D_i) H(DA)=i=1vDDiH(Di)

    其中, H ( D ) H(D) H(D) 是数据集的熵, p i p_i pi 是数据集中类别 i i i 的概率, H ( D ∣ A ) H(D|A) H(DA) 是在特征 A A A 下的条件熵, v v v 是特征 A A A 的取值数量, ∣ D i ∣ |D_i| Di 是属于取值 i i i 的样本数量。

  2. 基尼不纯度(Gini Impurity): 基尼不纯度用于衡量在数据集中随机选择两个样本,其类别不一致的概率。

    基尼不纯度 = 1 − ∑ i = 1 c p i 2 \text{基尼不纯度} = 1 - \sum_{i=1}^{c} p_i^2 基尼不纯度=1i=1cpi2

    其中, p i p_i pi 是数据集中类别 i i i 的概率。

  3. 决策树的构建过程: 决策树的构建过程基于递归,以下是一个简化的伪代码表示:

    function BuildDecisionTree(data, target, features):
        if all samples in target belong to the same class:
            return a leaf node with the class label
        if features is empty:
            return a leaf node with the most common class label in target
        choose the best feature based on information gain or Gini impurity
        create a decision node for the best feature
        for each unique value v of the best feature:
            create a branch from the decision node for value v
            recursively call BuildDecisionTree with subset data, target, and updated features
        return the decision node
    
    decision_tree = BuildDecisionTree(training_data, training_target, all_features)
    
  4. 决策树的预测过程: 决策树的预测过程通过从根节点开始沿着树的分支进行遍历,直到到达叶节点为止。

    function PredictSample(tree, sample):
        if tree is a leaf node:
            return the class label of the leaf node
        else:
            get the feature value from the sample
            if the value leads to a known branch:
                recursively call PredictSample on the branch
            else:
                return the most common class label among the leaf nodes' labels
    

这些数学公式和伪代码描述了决策树分类的核心原理,涉及到特征选择、信息增益、基尼不纯度以及决策树的构建和预测过程。实际实现中,可能还会涉及到剪枝等细节,但以上内容已经涵盖了决策树分类的基本原理。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- 资源下载地址

6.完整代码

import torch
import pandas as pd
from sklearn.model_selection import train_test_split  # Add this line
import numpy as np
import matplotlib.pyplot as plt

class DecisionTreeNode:
    def __init__(self, feature_index=None, label=None):
        self.feature_index = feature_index  # 特征索引
        self.label = label  # 叶节点标签
        self.children = {}  # 子节点字典

def entropy(labels):
    _, counts = torch.unique(labels, return_counts=True)
    probabilities = counts.float() / len(labels)
    return torch.sum(-probabilities * torch.log2(probabilities))

def information_gain(data, feature_index, target):
    feature_values = data[:, feature_index]
    unique_values = torch.unique(feature_values)
    total_entropy = entropy(target)

    weighted_entropy = 0.0
    for value in unique_values:
        mask = feature_values == value
        subset_target = target[mask]
        subset_entropy = entropy(subset_target)
        weighted_entropy += (len(subset_target) / len(target)) * subset_entropy

    return total_entropy - weighted_entropy

def build_decision_tree(data, target, features, max_depth=None, min_samples_split=2, current_depth=0):
    # 若样本中所有标签相同,则返回叶节点
    if torch.unique(target).size(0) == 1:
        return DecisionTreeNode(label=target[0].item())

    # 若达到最大深度或样本数不足以继续划分,则返回叶节点,将标签设为样本中最常见的标签
    if current_depth == max_depth or len(target) < min_samples_split:
        label = torch.mode(target).values.item()
        return DecisionTreeNode(label=label)

    # 若没有特征可用,则返回叶节点,将标签设为样本中最常见的标签
    if features.size(0) == 0:
        label = torch.mode(target).values.item()
        return DecisionTreeNode(label=label)

    best_feature_index = None
    best_info_gain = -1.0

    for feature_index in range(features.size(0)):
        info_gain = information_gain(data, feature_index, target)
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature_index = feature_index

    if best_info_gain == 0:
        label = torch.mode(target).values.item()
        return DecisionTreeNode(label=label)

    best_feature = features[best_feature_index]
    decision_tree = DecisionTreeNode(feature_index=best_feature_index)

    unique_values = torch.unique(data[:, best_feature_index])
    for value in unique_values:
        mask = data[:, best_feature_index] == value
        subset_data = data[mask]
        subset_target = target[mask]
        subset_features = features[features != best_feature]
        decision_tree.children[value.item()] = build_decision_tree(subset_data, subset_target, subset_features, max_depth, min_samples_split, current_depth + 1)

    return decision_tree

def predict_sample(tree, sample):
    if tree.label is not None:
        return tree.label

    feature_value = sample[tree.feature_index]
    if feature_value in tree.children:
        return predict_sample(tree.children[feature_value], sample)
    else:
        # 若在训练时未见过该特征值,则返回叶节点中最常见的标签
        labels = [child.label for child in tree.children.values()]
        return max(set(labels), key=labels.count)

def predict(tree, data):
    predictions = []
    for sample in data:
        prediction = predict_sample(tree, sample)
        predictions.append(prediction)
    return torch.tensor(predictions)

def accuracy(y_true, y_pred):
    return torch.sum(y_true == y_pred).item() / len(y_true)

def plot_confusion_matrix(conf_matrix, classes):
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_predictions_vs_true(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.plot(y_true, 'go', label='True Labels')
    plt.plot(y_pred, 'rx', label='Predicted Labels')
    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Class Label")
    plt.legend()
    plt.show()

def main():
    # 读取Data.xlsx文件并加载数据
    data = pd.read_excel("iris.xlsx")

    # 划分特征值和标签
    features = torch.tensor(data.iloc[:, :-1].values, dtype=torch.float32)
    labels = torch.tensor(data.iloc[:, -1].values, dtype=torch.long)

    # 将数据集拆分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

    # 构建决策树
    # 构建决策树,限制最大深度为5,叶节点至少包含5个样本
    decision_tree = build_decision_tree(X_train, y_train, torch.arange(X_train.size(1)), max_depth=5,
                                        min_samples_split=5)

    # 进行预测
    y_pred = predict(decision_tree, X_test)

    # 计算准确率
    acc = accuracy(y_test, y_pred)
    print("模型的准确率:", acc)
    # 绘制混淆矩阵
    from sklearn.metrics import confusion_matrix
    conf_matrix = confusion_matrix(y_test, y_pred)
    classes = ['class_0', 'class_1', 'class_2']
    plot_confusion_matrix(conf_matrix, classes)

    # 绘制真实标签与预测标签对比图
    plot_predictions_vs_true(y_test, y_pred)

if __name__ == "__main__":
    main()

7.运行结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/872816.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Opencv4基于C++的 实时人脸检测

文章目录: 一&#xff1a;环境配置搭建(VS2015Opencv4.6) 二&#xff1a;下资源文件 第一种&#xff1a;本地生成 第二种 直接下载 三&#xff1a;代码展示 窗口布局 main.cpp test.h test.cpp 效果图◕‿◕✌✌✌&#xff1a;opencv人脸识别效果图(请叫我真爱粉) 一&…

运算器组成实验

1.实验目的及要求 实验目的 1、熟悉双端口通用寄存器组的读写操作。 2、熟悉运算器的数据传送通路。 3、验证运算器74LS181的算术逻辑功能。 4、按给定数据&#xff0c;完成指定的算术、逻辑运算。 实验要求 1、做好实验预习。掌握运算器的数据传送通路和ALU的功能特性&…

7.3.tensorRT高级(2)-future、promise、condition_variable

目录 前言1. 生产者消费者模式2. 问答环节总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程&#xff0c;之前有看过一遍&#xff0c;但是没有做笔记&#xff0c;很多东西也忘了。这次重新撸一遍&#xff0c;顺便记记笔记。 本次课程学习 tensorRT 高级-future、promise…

【算法——双指针】LeetCode 1089 复写零

千万不要被这道题标注着“简单”迷惑了&#xff0c;实际上需要注意的细节很多。 题目描述&#xff1a; 解题思路&#xff1a; 正序遍历&#xff0c;确定结果数组的最后一个元素所在的位置&#xff1b;知道最后一个元素的位置后倒序进行填充。 先找到最后一个需要复写的数 先…

C++ 泛型编程:函数模板

文章目录 前言一、什么是泛型编程二、函数模板三、函数模板的使用四、多参数函数模板五&#xff0c;示例代码&#xff1a;总结 前言 当需要编写通用的代码以处理不同类型的数据时&#xff0c;C 中的函数模板是一个很有用的工具。函数模板允许我们编写一个通用的函数定义&#…

php从静态资源到动态内容

1、从HTML到PHP demo.php:后缀由html直接改为php,实际上当前页面已经变成了动态的php应用程序脚本 demo.php: 允许通过<?php ... ?>标签,添加php代码到当前脚本中 php标签内部代码由php.exe解释, php标签之外的代码原样输出,仍由web服务器解析 <!DOCTYPE html>…

Qt中将信号封装在一个继承类中的方法

QLabel标签类对应的信号如下&#xff1a; Qt中标签是没有双击&#xff08;double Click&#xff09;这个信号的&#xff1b; 需求一&#xff1a;若想双击标签使其能够改变标签中文字的内容&#xff0c;那么就需要自定义一个“双击”信号&#xff0c;并将其封装在QLabel类的派生…

使用vscode写vue文件代码有时不提示

背景&#xff1a; 安装了volar插件&#xff0c;但是在vue文件中导入js文件代码不提示&#xff0c;准确来说是有时提示有时不提示 解决方案&#xff1a; 插件冲突&#xff0c;卸载 JavaScript (ES6) code snippets 插件&#xff0c;这个插件在vue文件中适配不是很好。 很有可能…

【正版系统】2023热门短剧SAAS版开源 | 小程序+APP+公众号H5

当我们在刷百度、D音、K手等各种新闻或短视频时经常会刷到剧情很有吸引力的短剧广告&#xff0c;我们点击广告链接即可进入短剧小程序&#xff0c;小程序运营者通过先免费看几集为诱耳然后在情节高潮时弹出充值或开VIP会员才能继续看的模式来赚钱&#xff0c;以超级赘婿、乡村小…

电影数据可视化综合分析

数据可视化&分析实战 1.1 沈腾参演电影数据获取 1.2 电影数据可视化分析 目录 数据可视化&分析实战前言1. 数据认知2. 数据可视化2.1 解决matplotlib不能绘制中文字符的问题2.2 折线图2.3 柱状图绘制2.4 箱线图绘制2.5 饼图 3. Na值处理及相关性分析3.1 相关性分析3.2…

2023.08.13 学习周报

文章目录 摘要文献阅读1.题目2.要点3.问题4.解决方案5.本文贡献6.方法6.1 特征选择6.2 时间序列平稳性检测与数据分解6.3 基于GRU神经网络的PM2.5浓度预测 7.实验7.1 网络参数7.2 实验结果7.3 对比实验 8.讨论9.结论10.展望 PINNS模型1.自动微分2.全连接神经网络3.PINNs模型的P…

NavMeshPlus 2D寻路插件

插件地址:h8man/NavMeshPlus&#xff1a; Unity NavMesh 2D Pathfinding (github.com) 我对Unity官方是深恶痛觉,一个2D寻路至今都没想解决,这破引擎早点倒闭算了. 这插件是githun的开源项目,我本身是有写jps寻路的,但是无法解决多个单位互相阻挡的问题(可以解决但是有性能问…

Yolov5(一)VOC划分数据集、VOC转YOLO数据集

代码使用方法注意修改一下路径、验证集比例、类别名称&#xff0c;其他均不需要改动&#xff0c;自动划分训练集、验证集、建好全部文件夹、一键自动生成Yolo格式数据集在当前目录下&#xff0c;大家可以直接修改相应的配置文件进行训练。 目录 使用方法&#xff1a; 全部代码…

Window停止更新操作

在这里插入图片描述

Android平台RTMP推送或GB28181设备接入端如何实现采集audio音量放大?

我们在做Android平台RTMP推送和GB28181设备对接的时候&#xff0c;遇到这样的问题&#xff0c;有的设备&#xff0c;麦克风采集出来的audio&#xff0c;音量过高或过低&#xff0c;特别是有些设备&#xff0c;采集到的麦克风声音过低&#xff0c;导致播放端听不清前端采集的aud…

1216. 验证回文字符串 III;764. 最大加号标志;1135. 最低成本联通所有城市

1216. 验证回文字符串 III 核心思想&#xff1a;动态规划&#xff0c;这题需要一个思路的转换&#xff0c;删除最多k个字符判断是否为回文串&#xff0c;就相当于问你子序列中最长的回文串的长度是否比n-k长,就将这题转换为了最长回文子序列。 764. 最大加号标志 核心思想&am…

前后端分离------后端创建笔记(03)前后端对接(下)

本文章转载于【SpringBootVue】全网最简单但实用的前后端分离项目实战笔记 - 前端_大菜007的博客-CSDN博客 仅用于学习和讨论&#xff0c;如有侵权请联系 源码&#xff1a;https://gitee.com/green_vegetables/x-admin-project.git 素材&#xff1a;https://pan.baidu.com/s/…

Shader 编程:三角形、矩形等多边形绘制

该原创文章首发于微信公众号&#xff1a;字节流动 未经作者&#xff08;微信ID&#xff1a;Byte-Flow&#xff09;允许&#xff0c;禁止转载 SDF 有向距离场 上节其实牵扯到 SDF 算法&#xff0c;因为后面涉及高级特效的时候会经常用到&#xff0c;这里先提前对它做个简单的介…

注意:阿里云服务器随机分配可用区说明

阿里云服务器如有ICP备案需求请勿选择随机可用区&#xff0c;因为当前地域下的可用区可能不支持备案&#xff0c;阿里云百科分享提醒大家&#xff0c;如果你的购买的云服务器搭建网站应用&#xff0c;网站域名需要使用这台云服务器备案的话&#xff0c;不要随机分配可用区&…

从源码分析常见集合的区别之List接口

说到Java集合&#xff0c;共有两大类分别是Collection和Map。今天就详细聊聊大家耳熟能详的List吧。 List接口实现自Collection接口&#xff0c;是Java的集合框架中的一员&#xff0c;List接口下又有ArrayList、LinkedList和线程安全的Vector&#xff0c;今天就简单分析一下Ar…