山东大学机器学习实验lab9 决策树

news2025/1/24 17:58:38

山东大学机器学习实验lab9 决策树

  • 所有公式来源于<<机器学习>>周志华
  • github上有.ipynb源文件

修改:

  • 2024 5.15 添加了一些Node属性,用于标记每个Node使用的划分feature的名称,修改后的版本见 github
Node
  • 构造函数 初始化决策树的节点,包含节点ID、特征数据、特征ID、划分阈值、标签、父节点ID、子节点列表以及节点所属的类别
  • set_sons 设置节点的子节点列表
  • judge_stop 判断是否停止继续生成节点条件包括
    • 所有样本属于同一类别
    • 所有样本在所有属性上取值相同
    • 节点对应的样本集合为空
DecisionTree
  • 构造函数 初始化决策树,包括根节点列表、节点ID计数器
  • ent 计算信息熵,衡量数据集纯度
  • gain_single_feature 对于连续属性,寻找最佳划分点以最大化信息增益
  • gain_rate 计算所有特征的信息增益率,选择信息增益率最高的特征作为分裂依据
  • generate_node 递归生成子节点,根据特征的最佳分裂点划分数据集,直到满足停止条件
  • train 从根节点开始构建决策树
  • predict 给定数据,根据决策树进行分类预测
Data
  • 构造函数 加载数据集,初始化特征和标签
  • get_k_exmaine 实现K折交叉验证的数据切分,返回K个特征和标签组

整体流程

  1. 数据预处理:使用Data类读取数据文件,进行K折交叉验证的数据切分
  2. 模型训练与评估:对每一份测试数据,结合其余K-1份数据进行训练,使用DecisionTree类构建决策树模型,然后对测试集进行预测,计算正确率
  3. 结果展示:收集每一次交叉验证的正确率,最后计算并输出平均正确率

代码执行流程

  • 首先,Data类加载数据并进行K折交叉验证数据分割
  • 接着,对于每个验证折,训练数据被合并以训练决策树模型,然后在对应的测试数据上进行预测
  • 对每个测试样本,通过遍历决策树找到其所属类别,并与实际标签对比,累计正确预测的数量
  • 计算并打印每次验证的正确率,最后计算并输出所有折的平均正确率,以此评估模型的泛化能力

代码以及运行结果展示

import numpy as np 
import matplotlib.pyplot as plt 
import math 
class Node():
    def __init__(self,id_,features_,feature_id_,divide_,labels_,father_,sons_=[]):
        self.divide=divide_
        self.feature_id=feature_id_
        self.id=id_
        self.feature=features_
        self.labels=labels_ 
        self.father=father_ 
        self.sons=sons_
        self.clas='None'
    def set_sons(self,sons_):
        self.sons=sons_ 
    def judge_stop(self):
        labels=np.unique(self.labels)
        #如果节点样本属于同一类别
        if(labels.shape[0]==1):
            self.clas=labels[0]
            return True
        features=np.unique(self.feature)
        #如果所有样本在所有属性上取值相同
        if(features.shape[0]==1):
            unique_values, counts = np.unique(labels, return_counts=True)
            self.clas = unique_values[counts.argmax()]
            return True 
        #如果对应的样本集合为空
        if(self.feature.shape[0]==0 or self.feature.shape[1]==0):
            self.clas=1
            return True
        return False
class DecisionTree():
    def __init__(self):
        self.tree=[]
        self.id=0
        pass 
    #计算信息熵
    def ent(self,labels):
        labels_s=list(set([labels[i,0] for i in range(labels.shape[0])]))
        ans=0
        for label in labels_s:
            num=np.sum(labels==label)
            p=num/labels.shape[0]
            ans-=p*math.log(p,2)
        return ans 
    #计算一个标签对应的最佳分界(连续值)
    def gain_single_feature(self,feature,labels):
        origin_ent=self.ent(labels)
        divide_edge=[]
        feature=list(set(feature))
        feature=np.sort(feature)
        divide_edge=[(feature[i]+feature[i+1])/2.0 for i in range(feature.shape[0]-1)]
        best_ent=0
        best_divide=0
        l1=l2=np.array([[]])
        for condition in divide_edge:
            labels1=np.array([labels[i] for i in range(feature.shape[0]) if feature[i]<=condition])
            labels2=np.array([labels[i] for i in range(feature.shape[0]) if feature[i]>condition])
            ent1=self.ent(labels1)
            ent2=self.ent(labels2)
            ans=origin_ent-((labels1.shape[0]/labels.shape[0])*ent1+(labels2.shape[0]/labels.shape[0])*ent2)
            if(ans>=best_ent):
                best_divide=condition
                l1=labels1
                l2=labels2
                best_ent=ans 
        return best_divide,l1,l2,best_ent
    #计算信息增益
    def gain_rate(self,features,labels):
        origin_ent=self.ent(labels)
        gain_rate=0
        feature_id=-1
        divide=-1
        l=labels.shape[0]
        for id in range(features.shape[1]):
            divide1,labels1,labels2,th_gain=self.gain_single_feature(features[:,id],labels)
            l1=labels1.shape[0]
            l2=labels2.shape[0]
            iv=-1*((l1/l)*math.log(l1/l,2)+(l2/l)*math.log(l2/l,2))
            if iv!=0:
                rate=th_gain/iv
            else:
                rate=0
            if(rate>=gain_rate):
                gain_rate=rate
                divide=divide1
                feature_id=id
        return feature_id,divide
    def generate_node(self,node:Node):
        a=1
        features1_id=np.array([i for i in range(node.feature.shape[0]) if node.feature[i,node.feature_id]>=node.divide])
        features2_id=np.array([i for i in range(node.feature.shape[0]) if node.feature[i,node.feature_id]<node.divide])
        features1=node.feature[features1_id]
        features2=node.feature[features2_id]
        labels1=node.labels[features1_id]
        labels2=node.labels[features2_id]
        features1=np.delete(features1,node.feature_id,axis=1)
        features2=np.delete(features2,node.feature_id,axis=1)
        features_id1,divide1=self.gain_rate(features1,labels1)
        features_id2,divide2=self.gain_rate(features2,labels2)
        tmp=0
        if(features_id1!=-1):
            tmp+=1
            node1=Node(self.id+tmp,features1,features_id1,divide1,labels1,node.id,[])
            node1.father=node.id
            self.tree.append(node1)
            node.sons.append(self.id+tmp)
        if(features_id2!=-1):
            tmp+=1
            node2=Node(self.id+tmp,features2,features_id2,divide2,labels2,node.id,[])
            node2.father=node.id
            self.tree.append(node2)
            node.sons.append(self.id+tmp)
        self.id+=tmp
        if(tmp==0):
            unique_values, counts = np.unique(node.labels, return_counts=True)
            node.clas = 0 if counts[0]>counts[1] else 1
            return 
        for n in [self.tree[i] for i in node.sons]:
            if(n.judge_stop()):
                continue
            else:
                self.generate_node(n)
    def train(self,features,labels):
        feature_id,divide=self.gain_rate(features,labels)
        root=Node(0,features,feature_id,divide,labels,-1,[])
        self.tree.append(root)
        self.generate_node(root)
    def predict(self,features):
        re=[]
        for feature in features:
            node=self.tree[0]
            while(node.clas=='None'):
                th_feature=feature[node.feature_id]
                feature=np.delete(feature,node.feature_id,axis=0)
                th_divide=node.divide
                if(node.clas!='None'):
                    break 
                if(th_feature<th_divide):
                    node=self.tree[node.sons[len(node.sons)-1]]
                else:
                    node=self.tree[node.sons[0]]
            re.append(node.clas)
        return re 
class Data():
    def __init__(self):
        self.data=np.loadtxt('/home/wangxv/Files/hw/ml/lab9/data/ex6Data/ex6Data.txt',delimiter=',')
        self.data_num=self.data.shape[0]
        self.features=self.data[:,:-1]
        self.labels=self.data[:,-1:]
    def get_k_exmaine(self,k:int):
        num=int(self.data_num/k)
        data=self.data
        np.random.shuffle(self.data)
        features=data[:,:-1]
        labels=self.data[:,-1:]
        feature_groups=[features[i:i+num-1] for i in np.linspace(0,self.data_num,k+1,dtype=int)[:-1]]
        labels_groups=[labels[i:i+num-1] for i in np.linspace(0,self.data_num,k+1,dtype=int)[:-1]]
        return feature_groups,labels_groups
data=Data() 
feature_groups,label_groups=data.get_k_exmaine(10)
rate_set=[]
for ind in range(10):
    dt=DecisionTree()
    features_=[feature_groups[i] for i in range(10) if i!=ind]
    labels_=[label_groups[i] for i in range(10) if i!=ind]
    train_features=features_[0]
    train_labels=labels_[0]
    for feature,label in zip(features_[1:],labels_[1:]):
        train_features=np.vstack((train_features,feature))
        train_labels=np.vstack((train_labels,label))
    test_features=feature_groups[ind]
    test_labels=label_groups[ind]
    dt.train(train_features,train_labels)
    pred_re=dt.predict(test_features)
    right_num=0
    for i in range(len(pred_re)): 
        if pred_re[i]==test_labels[i][0]:
            right_num+=1
    right_rate=right_num/len(pred_re)
    print(str(ind+1)+'  correct_rate : '+str(right_rate))
    rate_set.append(right_rate)
print("average_rate : "+str(np.mean(np.array(rate_set))))
1  correct_rate : 0.7930327868852459
2  correct_rate : 0.8032786885245902
3  correct_rate : 0.8012295081967213
4  correct_rate : 0.7848360655737705
5  correct_rate : 0.7889344262295082
6  correct_rate : 0.7909836065573771
7  correct_rate : 0.7827868852459017
8  correct_rate : 0.7766393442622951
9  correct_rate : 0.8012295081967213
10  correct_rate : 0.7725409836065574
average_rate : 0.7895491803278689

最终得到的平均正确率为0.7895>0.78符合实验书要求

可视化

  • 需要将自定义的Node类转换为可识别的数据格式
from graphviz import Digraph

def show_tree(root_node, dot=None):
    if dot is None:
        dot = Digraph(comment='decision_tree')
    node_label = f"{root_node.id} [label=\"feature {root_node.feature_id}: {root_node.divide} | class: {root_node.clas}\"]"
    dot.node(str(root_node.id), node_label)
    if root_node.sons:
        for son_id in root_node.sons:
            dot.edge(str(root_node.id), str(son_id))
            visualize_tree(dt.tree[son_id], dot)
    return dot
root_node=dt.tree[0]
dot = show_tree(root_node)
dot.render('decision_tree', view=True,format='png')
from PIL import Image
Image.open('./decision_tree.png')
  • 这个决策树图有点过于大了,凑合看吧
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1681174.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Flask-RESTful构建RESTful API

文章目录 安装Flask-RESTful导入模块和类创建一个资源类运行应用测试API总结 Flask是一个轻量级的Python web开发框架&#xff0c;而Flask-RESTful是一个基于Flask的扩展&#xff0c;专门用于构建RESTful API。它提供了一些帮助类和方法&#xff0c;使构建API变得更加简单和高效…

文档分类DPCNN简介(pytorch实现)

文档分类DPCNN简介 DPCNN简介 模型结构区域嵌入等长卷积1/2池化DPCNN模型代码实现 DPCNN简介 论文中提出了一种基于 word-level 级别的网络-DPCNN&#xff0c;由于 TextCNN 不能通过卷积获得文本的长距离依赖关系&#xff0c;而论文中 DPCNN 通过不断加深网络&#xff0c;可以…

家用充电桩远程监控安全管理系统解决方案

家用充电桩远程监控安全管理系统解决方案 在当今电动汽车日益普及的背景下&#xff0c;家用充电桩的安全管理成为了广大车主关注的重点问题。为了实现对充电桩的高效、精准、远程监控&#xff0c;一套完善的家用充电桩远程监控安全管理系统解决方案应运而生。本方案旨在通过先…

6大部分,20 个机器学习算法全面汇总!!建议收藏!(上篇)

前两天有小伙伴说想要把常见算法的原理 公式汇集起来。 这样非常非常方便查看&#xff01;分为上下两篇&#xff0c;下篇地址&#xff1a; 本次文章分别从下面6个方面&#xff0c;涉及到20个算法知识点&#xff1a; 监督学习算法 无监督学习算法 半监督学习算法 强化学习…

【阿里云】云服务器ECS运行node服务

本文介绍如何在&#xff08;CentOS 7.9 64位&#xff09;操作系统的ECS实例上&#xff0c;安装Node.js并部署测试项目。 使用工具&#xff1a;FinalShell4.3.10 目录 步骤一&#xff1a;部署Node.js环境 1.远程连接已创建的ECS实例。 2.部署Node.js环境。 a.安装分布式版本管…

人机协同中的比较、调整与反转

人机协同是指人与机器之间的合作关系&#xff0c;通过共同努力实现特定任务的目标。在人机协同中&#xff0c;存在着比较与调整的过程&#xff0c;这是为了实现更好的合作效果和任务完成质量。 比较是指人与机器在任务执行过程中对彼此的表现进行评估和比较。这可以通过对机器的…

如何下载小米壁纸到本地分享给他人

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 操作方法 📒🚥 注意事项⚓️ 相关链接 ⚓️📖 介绍 📖 你是否曾被小米主题壁纸软件中的精美壁纸所吸引,却苦于无法将其下载到本地或与朋友分享?本文将为你揭晓如何将小米壁纸下载到本地分享给他人! 🏡 演示环境 �…

图文教程 | 2024年最新VSCode下载和安装教程c/c++环境配置,json文件详解,实用插件分享

前言 &#x1f4e2;博客主页&#xff1a;程序源⠀-CSDN博客 &#x1f4e2;欢迎点赞&#x1f44d;收藏⭐留言&#x1f4dd;如有错误敬请指正&#xff01; 由于重装电脑&#xff0c;需要重新安装VsCode&#xff0c;记录安装配置过程。 一、VSCode下载 官网地址&#xff1a; Vis…

Spring Security实现用户认证二:前后端分离时自定义返回Json内容

Spring Security实现用户认证二&#xff1a;前后端分离时自定义返回Json内容 1 前后端分离2 准备工作依赖WebSecurityConfig配置类 2 自定义登录页面2.1 Spring Security的默认登录页面2.2 自定义配置formLogin 3 自定义登录成功处理器4 自定义登录失败处理器5 自定义登出处理器…

5万字带你一文看懂自动驾驶之高精度地图前世今生

在讲解高精度地图之前&#xff0c;我们先把定位这个事情弄清楚&#xff0c;想明白&#xff0c;后面的事情就会清晰很多&#xff0c;自古哲学里面讨论的人生终极问题&#xff0c;无非就三个&#xff0c;我是谁&#xff0c;我从哪里来&#xff0c;我要去哪里&#xff0c;这里的位…

语言模型测试系列【8】

语言模型 文心一言星火认知大模型通义千问豆包360智脑百川大模型腾讯混元助手Kimi Chat商量C知道 这次的测试比较有针对性&#xff0c;是在使用钉钉新推出的AI助理功能之后发现的问题&#xff0c;即创建AI助理绑定自己钉钉的知识库进行问答&#xff0c;其中对于表结构的文档学…

Vue3商城后台管理实战-用户登录界面设计

界面设计 此时界面的预览效果如下&#xff1a; 登录界面的完整代码如下&#xff1a; <script setup> import {reactive} from "vue/reactivity";const form reactive({username: "",password: "", })const onSubmit () > {} <…

多点 Dmall x TiDB:出海多云多活架构下的 TiDB 运维实战

作者&#xff1a;多点&#xff0c;唐万民 导读 时隔 2 年&#xff0c; 在 TiDB 社区成都地区组织者冯光普老师的协助下&#xff0c;TiDB 社区线下地区活动再次来到成都。来自多点 Dmall 的国内数据库负责人唐万民老师&#xff0c;在《出海多云架构&#xff0c;多点 TiDB 运维…

【class9】人工智能初步(处理单张图片)

Class9的任务&#xff1a;处理单张图像 为了更高效地学习&#xff0c;我们将“处理单张图像”拆分成以下几步完成&#xff1a; 1. 读取图像文件 2. 调用通用物体识别 3. 提取图像分类信息 4. 对应分类文件夹还未创建时&#xff0c;创建文件夹 5. 移动图像到对应文件夹 0.获取…

Qt---TCP文件传输服务器

文件传输流程&#xff1a; 服务器端&#xff1a; serverwidget.ui serverwidget.h #ifndef SERVERWIDGET_H #define SERVERWIDGET_H#include <QWidget> #include<QTcpServer>//监听套接字 #include<QTcpSocket>//通信套接字 #include<QFile> #includ…

查看Linux系统是Ubuntu还是CentOS

要查看Linux系统是Ubuntu还是CentOS&#xff0c;可以通过多种方式进行确认&#xff1a; 查看/etc/os-release文件&#xff1a; 在终端中执行以下命令&#xff1a; cat /etc/os-release 如果输出中包含"IDubuntu"&#xff0c;则表示系统是Ubuntu&#xff1b;如果输出中…

构建智能电子商务系统:数字化引领未来商业发展

随着互联网技术的飞速发展和消费者行为的变革&#xff0c;电子商务系统的重要性日益凸显。在这一背景下&#xff0c;构建智能电子商务系统成为推动商业数字化转型的关键举措。本文将深入探讨智能电子商务系统的构建与优势&#xff0c;助力企业把握数字化转型的主动权。 ### 智…

【Linux】19. 习题②

2022-11-12_Linux环境变量 1. 分页存储(了解) 一个分页存储管理系统中&#xff0c;地址长度为 32 位&#xff0c;其中页号占 8 位&#xff0c;则页表长度是__。 A.2的8次方 B.2的16次方 C.2的24次方 D.2的32次方 【答案解析】A 页号即页表项的序号&#xff0c;总共占8个二进制…

数字化智能:Web3时代的物联网创新之路

引言 随着科技的不断发展&#xff0c;物联网&#xff08;IoT&#xff09;技术正在迅速普及和应用。而随着Web3时代的到来&#xff0c;物联网将迎来新的发展机遇和挑战。本文将探讨Web3时代的物联网创新之路&#xff0c;深入分析其核心技术、应用场景以及未来发展趋势。 Web3时…

C语言性能深度剖析:从底层优化到高级技巧及实战案例分析

C语言以其接近硬件的特性、卓越的性能和灵活性&#xff0c;在系统编程、嵌入式开发和高性能计算等领域中占据着举足轻重的地位。本文将深入探讨C语言性能优化的各个方面&#xff0c;包括底层原理、编译器优化、内存管理和高级编程技巧&#xff0c;并结合多个代码案例来具体分析…