机器学习实验二-----决策树构建

news2025/1/16 0:34:09

决策树是机器学习中一种基本的分类和回归算法,是依托于策略抉择而建立起来的树。本文学习的是决策树的分类

1. 构建决策树流程

  1. 选择算法:常用的算法包括ID3、C4.5、CART等。

  2. 划分节点:根据数据特征和算法选择,递归地划分节点,直到满足停止条件。

  3. 决策树剪枝:对决策树进行剪枝操作,减少决策树的复杂度,提高泛化能力。

  4. 决策树评估:使用测试数据集评估决策树模型的性能,通常使用准确率、召回率、F1值等。

2.常用的三个算法

2.1 ID3

D3采用信息增益来划分属性。

2.12 信息熵

用来衡量数据集的混乱程度,信息熵越大,表明数据集的混乱程度越大,不确定性越大。

公式:H(X=xi)=-\sum_{1}^{n}pi log pi

其中pi表示的是分类为xi这个样本在中的占比。

2.12信息增益

 划分数据集之前之后信息发生的变化      

公式:Gain(D,a)=Ent(D)-\sum_{v=1}^{V}\frac{\left | D^{v} \right |}{\left | D \right |}Ent(D^{v})

信息增益越大,则意味着采用该属性a划分节点获得的纯度提升更大。在每次划分中采用信息增益最大的划分。

信息增益实际上就是数据集整体的信息熵减去使用特征 a进行划分后各子集的加权平均信息熵,即子集的信息熵的期望值。当信息增益越大时,意味着子集的信息熵的减少量越大,即数据集的不确定性减少的程度更大,信息熵变小。

2.2   C4.5

C4.5算法在ID3算法上做了提升,使用信息增益比来构造决策树,且有剪枝功能防止过拟合。

信息增益比:特征a对训练集D的信息增益比定义为特征a的信息增益与训练集D对于a的信息熵之比, 同样是信息增益比越大越好。

公式:Gain_radio(D,a)=\frac{Gain(D,a)}{H(D,a)}

先剪枝:提前停止树的构建而对树”剪枝“,提前停止的策略有定义一个树的深度,到达指定深度自动停止构造;

后剪枝:先构造完整的子树,对于决策树中信息增益比较低的子树用叶子节点代替。

2.3 CART基尼指数

基尼指数是衡量数据集纯度或不确定性的一种指标,常用于决策树算法中的特征选择和节点划分。

公式:\text{Gini}(D) = 1 - \sum_{k=1}^{K} (p_k)^2

基尼指数越小越好。

3.划分节点

划分节点就是根据我们选择的算法来进行划分的,我们这边拿C4.5算法来举例一下。拿鸢尾花来举例一下,我们有花萼长度,花萼宽度,花瓣长度,花瓣宽度四个特征值。我们分别计算一下每一个的熵,根据公式计算出信息增益比,选择按照信息增益比大小排序的特征来当划分的依据。这边假设我们排序就是花萼长度,花萼宽度,花瓣长度,花瓣宽度,那我们先按照花萼长度把根节点划分为左右子树,子树再根据花萼宽度把子树再继续划分,一直用递归来划分每一个节点。最后就得到整棵树。

4. 决策树的剪枝

决策树生成算法递归的产生决策树,直到不能继续下去为止,这样产生的树往往对训练数据的分类很准确,但对未知测试数据的分类缺没有那么精确,即会出现过拟合现象。

过拟合产生的原因在于在学习时过多的考虑如何提高对训练数据的正确分类,从而构建出过于复杂的决策树,解决方法是考虑决策树的复杂度,对已经生成的树进行简化。

剪枝:从已经生成的树上裁掉一些子树或叶节点,并将其根节点或父节点作为新的叶子节点,从而简化分类树模型。

剪枝分为预剪枝与后剪枝。

预剪枝是指在决策树的生成过程中,对每个节点在划分前先进行评估,若当前的划分不能带来泛化性能的提升,则停止划分,并将当前节点标记为叶节点。

后剪枝是指先从训练集生成一颗完整的决策树,然后自底向上对非叶节点进行考察,若将该节点对应的子树替换为叶节点,能带来泛化性能的提升,则将该子树替换为叶节点。

5.决策树构建

这个函数直接依据我们所给的数据用C4.5创建了决策树

clf.fit(X_train, y_train)

6.决策树可视化

使用了export_graphviz函数从训练好的决策树模型中生成一个Graphviz格式的文本文件,然后使用graphviz.Source将这个文本文件转换为一个Graphviz对象,最后使用render方法将这个对象渲染为图形文件。

dot_data = export_graphviz(clf, out_file=None,
                     feature_names=["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"],
                     class_names=["setosa", "versicolor", "virginica"],
                     filled=True, rounded=True,
                     special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("iris_decision_tree") 

graph.view()

 运行这个代码得到创建的决策树会包含我们数据集内的所有东西,我的鸢尾花数据集中每个特征的信息增益比在不同的条件下就得到下面这棵树的划分标准。

完整代码展现:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def load_data(train_file, test_file):
    train_data = pd.read_csv(train_file, sep='\s+')
    test_data = pd.read_csv(test_file, sep='\s+')

    X_train = train_data[["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"]].values
    y_train = train_data["Species"].values
    X_test = test_data[["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"]].values
    y_test = test_data["Species"].values

    return X_train, y_train, X_test, y_test


clf = DecisionTreeClassifier()

train_file = "C:\\Users\\李烨\\Desktop\\新建文件夹\\6\\iris.txt"
test_file = "C:\\Users\\李烨\\Desktop\\新建文件夹\\6\\iristest.txt"
X_train, y_train, X_test, y_test = load_data(train_file, test_file)

clf.fit(X_train, y_train)


def predict_flower(sepal_length, sepal_width, petal_length, petal_width):
    input_features = [[sepal_length, sepal_width, petal_length, petal_width]]
    prediction = clf.predict(input_features)
    if prediction[0] == 'setosa':
        print("预测类别:setosa")
    elif prediction[0] == 'versicolor':
        print("预测类别:versicolor")
    elif prediction[0] == 'virginica':
        print("预测类别:virginica")
    else:
        print("未知类别")
    return prediction[0]


def get_input():
    sepal_length = float(input("请输入花萼长度:"))
    sepal_width = float(input("请输入花萼宽度:"))
    petal_length = float(input("请输入花瓣长度:"))
    petal_width = float(input("请输入花瓣宽度:"))
    return sepal_length, sepal_width, petal_length, petal_width


print("实验二决策树分类")
while True:
    try:
        user_input = input("输入 'exit' 退出:")
        if user_input.lower() == 'exit':
            print("程序结束")
            break
            
            
        sepal_length, sepal_width, petal_length, petal_width = get_input()
        result = predict_flower(sepal_length, sepal_width, petal_length, petal_width)

    except ValueError:
            print("输入有误,请重新输入。")



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1603578.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++算法题 - 矩阵

目录 36. 有效的数独54. 螺旋矩阵48. 旋转图像73. 矩阵置零289. 生命游戏 36. 有效的数独 LeetCode_link 请你判断一个 9 x 9 的数独是否有效。只需要 根据以下规则 ,验证已经填入的数字是否有效即可。 数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现…

Linux安装和使用Android Debug Bridge(ADB)

目录 1、开发环境和工具 2、ADB是什么? 3、安装ADB 3.1、使用包管理器安装 ADB 3.2、手动安装 ADB 4、使用ADB 4.1、连接设备 4.2、执行shell命令 4.3、安装应用程序 4.4、截取屏幕截图 4.5、模拟按键和手势 4.6、上传文件到Android设备 4.7、从Android设备下载文件…

el-table 表格列里添加 树

<el-table-column label"部门名称" align"center"><template slot-scope"scope"><el-cascader filterable :disabled"type 3 ? true : false" :show-all-levels"false":ref"provinceTree scope.row.…

【机器学习300问】72、神经网络的隐藏层数量和各层神经元节点数如何影响模型的表现?

评估深度学习的模型的性能依旧可以用偏差和方差来衡量。它们反映了模型在预测过程中与理想情况的偏离程度&#xff0c;以及模型对数据扰动的敏感性。我们简单回顾一下什么是模型的偏差和方差&#xff1f; 一、深度学习模型的偏差和方差 偏差&#xff1a;衡量模型预测结果的期望…

[Meachines][Easy] Usage

Main # nmap -sV -sC 10.10.11.18 --min-rate 1000 # echo 10.10.11.18 usage.htb admin.usage.htb >> /etc/hosts 在/forget-password发现存在SQL注入 emailmaptnh%40log.comAND5212%3dBENCHMARK(5000000,MD5(0x62434473))--NKGG $ sqlmap -r request.txt --level 5 -…

js: UrlDecode解码、UUID和GUID、阿拉伯数字转为中文数字

UrlDecode解码&#xff1a; UrlDecode 是一个 JavaScript 函数&#xff0c;用于将经过 URL 编码的字符串转换为普通字符串。 URL 编码是将特殊字符转换为它们的百分比编码表示形式的过程。这些特殊字符包括空格、斜线、井号&#xff08;#&#xff09;等。UrlDecode 函数将这些…

【C++杂货铺】继承

目录 &#x1f308;前言&#x1f308; &#x1f4c1; 继承的概念和定义 &#x1f4c2; 概念 &#x1f4c2; 定义 &#x1f4c1; 基类和派生类对象赋值转换 &#x1f4c1; 继承中的作用域 &#x1f4c1; 派生类的默认成员函数 构造函数 析构函数 拷贝构造函数 赋值重载…

html select 支持内容过滤列表 -bootstrap实现

实现使用bootstrap-select插件 http://silviomoreto.github.io/bootstrap-select <!DOCTYPE html> <html> <meta charset"UTF-8"> <head><title>jQuery bootstrap-select可搜索多选下拉列表插件-www.daimajiayuan.com</title>&…

【Pytorch】Conv1d

conv1d 先看看官方文档 再来个简单的例子 import torch import numpy as np import torch.nn as nndata np.arange(1, 13).reshape([1, 4, 3]) data torch.tensor(data, dtypetorch.float) print("[data]:\n", data) conv nn.Conv1d(in_channels4, out_channels1…

ARM作业day8

温湿度数据采集应用&#xff1a; 由上图可知&#xff1a; 控制温湿度采集模块的引脚是PF14&#xff08;串行时钟线&#xff09;和PF15&#xff08;串行数据线&#xff09;&#xff1a;控制温湿度采集模块的总线是AHB4&#xff0c;通过GPIOF串口和RCC使能完成初始化操作。 控制…

批量插入10w数据方法对比

环境准备(mysql5.7) CREATE TABLE user (id bigint(20) NOT NULL AUTO_INCREMENT COMMENT 唯一id,user_id bigint(10) DEFAULT NULL COMMENT 用户id-uuid,user_name varchar(100) NOT NULL COMMENT 用户名,user_age bigint(10) DEFAULT NULL COMMENT 用户年龄,create_time time…

ubuntu 查询mysql的用户名和密码 ubuntu查看username

ubuntu 查询mysql的用户名和密码 ubuntu查看username 文章标签mysqlUbuntu用户名文章分类MySQL数据库 一.基本命令 1.查看Ubuntu版本 $ lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 16.04.5 LTS Release: 16.04 Coden…

HarmonyOS开发实例:【分布式手写板】

介绍 本篇Codelab使用设备管理及分布式键值数据库能力&#xff0c;实现多设备之间手写板应用拉起及同步书写内容的功能。操作流程&#xff1a; 设备连接同一无线网络&#xff0c;安装分布式手写板应用。进入应用&#xff0c;点击允许使用多设备协同&#xff0c;点击主页上查询…

接口压力测试 jmeter--入门篇(一)

一 压力测试的目的 评估系统的能力识别系统的弱点&#xff1a;瓶颈/弱点检查系统的隐藏的问题检验系统的稳定性和可靠性 二 性能测试指标以及测算 【虚拟用户数】&#xff1a;线程用户【并发数】&#xff1a;指在某一时间&#xff0c;一定数量的虚拟用户同时对系统的某个功…

如何使用 ArcGIS Pro 制作边界晕渲效果

在某些出版的地图中&#xff0c;边界有类似于“发光”的晕渲效果&#xff0c;这里为大家介绍一下如何使用ArcGIS Pro 制作这种晕渲效果&#xff0c;希望能对你有所帮助。 数据来源 教程所使用的数据是从水经微图中下载的行政区划数据&#xff0c;除了行政区划数据&#xff0c…

【C++进阶】C++中的继承

一、概述 作为C的三大特性之一封装&#xff0c;继承&#xff0c;多态 中的继承&#xff0c;我们在进阶部分一定要详细说明。请跟着如下的小标题进入深度学习。 二、正文 1.继承的概念及定义 首先&#xff0c;我们先要知道什么是继承&#xff0c; 继承 (inheritance)机制是面…

Unity之OpenXR+XR Interaction Toolkit快速监听手柄任意按键事件

前言 当我们开发一个VR时,有时希望监听一个手柄按键的点击事件,或者一个按钮的Value值等。但是每次有可能监听的按钮有不一样,有可能监听的值不一样,那么每次这么折腾,有点累了,难道就没有一个万能的方法,让我可以直接监听我想要的某个按钮的事件么? 答案是肯定的,今…

vscode 搭建stm32开发环境记录(eide+cortex-debug+jlink)

前言 clion使用的快过期了&#xff0c;所以就准备使用vscode 来代替clion作为代码开发环境 vscode 插件安装 创建个空白工程 添加项目相关的源文件&#xff0c;和配置宏定义和头文件目录 编译和烧录(ok) 结合cortex-debug 结果(测试ok)

Prometheus + Grafana 搭建监控仪表盘

目标要求 1、需要展现的仪表盘&#xff1a; SpringBoot或JVM仪表盘 Centos物理机服务器&#xff08;实际为物理分割的虚拟服务器&#xff09;仪表盘 2、展现要求: 探索Prometheus Grafana搭建起来的展示效果&#xff0c;尽可能展示能展示的部分。 一、下载软件包 监控系统核心…

政安晨:【深度学习神经网络基础】(十一)—— 激活函数的导数以及在反向传播中的应用

目录 线性激活函数的导数 Softmax激活函数的导数 S型激活函数的导数 双曲正切激活函数的导数 ReLU激活函数的导数 如何在反向传播中应用 批量训练和在线训练 随机梯度下降 反向传播权重更新 选择学习率和动量 Nesterov动量 政安晨的个人主页&#xff1a;政安晨 欢迎…