回顾分类决策树相关知识并利用python实现

news2024/9/23 3:31:54

         大家好,我是带我去滑雪!

         决策树(Decision Tree)是一种基本的分类与回归方法,呈树形结构,在分类问题中,表示预计特征对实例进行分类的过程。它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布。分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点node)和有向边directed edge)组成。

        结点有两种类型:内部结点internal node)和叶结点leaf node)。内部结点表示一个特征或属性,叶结点表示一个。从根节点开始,对样本数据某一个特征进行测试,根据测试结果,将对应样本分配到子节点上,形成两个或者多棵子树,然后再对每一棵子树重复同样的过程,直到样本数据分到某个叶子节点。简单来说,决策树就是通过一系列规则对数据进行分类的过程

目录

1、如何构建决策树?

2、特征选择

 3、决策树生成

 4、决策树剪枝

5、实现决策树案例1—自定义函数

(1)导入相关模块

(2)导入数据

(3)自定义函数

(4)调用函数

(5)模型预测

6、使用python中的模块实现决策树

(1)导入相关模块

(2)导入数据

(3)计算模型得分

(4)测试集预测并计算混淆矩阵

(5)计算模型评价指标

(6)绘制ROC曲线


1、如何构建决策树?

       主要步骤如下:

1)选择最优特征(构建根节点)决策树是依靠信息熵变化的程度来选择最优特征的,并以此分割训练数据集;

2)生成决策树(叶子节点的选择):若子集被基本正确分类,构建叶结点,否则,继续选择新的最优特征;重复上面两步,直到所有训练数据子集被正确分类;决策树的三种常见算法,则是根据选择最优属性时计算的信息熵函数不同划分的。ID3 是根据信息熵,C4.5是根据信息增益率。CART是采用了基尼Gini系数。

3剪枝(防止过拟合)剪枝是把通过训练集得到的决策树,切掉一些叶子节点,一方面可以减少判决步骤,提高判决效率。但更主要的是,防决策树模型对训练集数据过拟合,使决策树具有更好的泛化能力。

2、特征选择

        特征选择在于选取对训练数据具有分类能力的特征。这样可以提高决策树学习的效率。如果利用一个特征进行分类的结果与随机分类的结果没有很大差别,则称这个特征是没有分类能力的。经验上人叼这样的特征对决策树学习的精度影响不大。通常特征选择的准则是信息增益和信息增益比,特征选择是决定用哪个特征来划分特征空间。

 3、决策树生成

 4、决策树剪枝

        优秀的决策树是:在具有好的拟合和泛化能力的同时,深度小、叶结点少。所以为了更好的契合真实情况,需要对决策树进行剪枝操作。一般情况下有两种剪枝方法

      预剪枝构造决策树的过程中,先对每个结点在划分前进行估计,如果当前结点的划分不能提高泛化能力,则停止划分,将当前结点标记为叶结点

     后剪枝:生成一颗完整的决策树之后,自底向上地对内部结点进行考察,若此内部结点变为叶结点,可以提升泛化性能,则做此替换

     总结:预剪枝 ----> 边构造边剪枝,后剪枝 ----> 构造完再剪枝

5、实现决策树案例1—自定义函数

(1)导入相关模块

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
import math
from math import log
import pprint

(2)导入数据

def create_data():
    datasets = [['青年', '否', '否', '一般', '否'],
               ['青年', '否', '否', '好', '否'],
               ['青年', '是', '否', '好', '是'],
               ['青年', '是', '是', '一般', '是'],
               ['青年', '否', '否', '一般', '否'],
               ['中年', '否', '否', '一般', '否'],
               ['中年', '否', '否', '好', '否'],
               ['中年', '是', '是', '好', '是'],
               ['中年', '否', '是', '非常好', '是'],
               ['中年', '否', '是', '非常好', '是'],
               ['老年', '否', '是', '非常好', '是'],
               ['老年', '否', '是', '好', '是'],
               ['老年', '是', '否', '好', '是'],
               ['老年', '是', '否', '非常好', '是'],
               ['老年', '否', '否', '一般', '否'],
               ]
    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']
    # 返回数据集和每个维度的名称
    return datasets, labels

datasets, labels = create_data()
train_data = pd.DataFrame(datasets, columns=labels)
train_data

输出结果:

(3)自定义函数

class Node:
    def __init__(self, root=True, label=None, feature_name=None, feature=None):
        self.root = root
        self.label = label
        self.feature_name = feature_name
        self.feature = feature
        self.tree = {}
        self.result = {
            'label:': self.label,
            'feature': self.feature,
            'tree': self.tree
        }

    def __repr__(self):
        return '{}'.format(self.result)

    def add_node(self, val, node):
        self.tree[val] = node

    def predict(self, features):
        if self.root is True:
            return self.label
        return self.tree[features[self.feature]].predict(features)


class DTree:
    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon
        self._tree = {}

    # 熵
    @staticmethod
    def calc_ent(datasets):
        data_length = len(datasets)
        label_count = {}
        for i in range(data_length):
            label = datasets[i][-1]
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        ent = -sum([(p / data_length) * log(p / data_length, 2)
                    for p in label_count.values()])
        return ent

    # 经验条件熵
    def cond_ent(self, datasets, axis=0):
        data_length = len(datasets)
        feature_sets = {}
        for i in range(data_length):
            feature = datasets[i][axis]
            if feature not in feature_sets:
                feature_sets[feature] = []
            feature_sets[feature].append(datasets[i])
        cond_ent = sum([(len(p) / data_length) * self.calc_ent(p)
                        for p in feature_sets.values()])
        return cond_ent

    # 信息增益
    @staticmethod
    def info_gain(ent, cond_ent):
        return ent - cond_ent

    def info_gain_train(self, datasets):
        count = len(datasets[0]) - 1
        ent = self.calc_ent(datasets)
        best_feature = []
        for c in range(count):
            c_info_gain = self.info_gain(ent, self.cond_ent(datasets, axis=c))
            best_feature.append((c, c_info_gain))
        # 比较大小
        best_ = max(best_feature, key=lambda x: x[-1])
        return best_

    def train(self, train_data):
        """
        input:数据集D(DataFrame格式),特征集A,阈值eta
        output:决策树T
        """
        _, y_train, features = train_data.iloc[:, :
                                               -1], train_data.iloc[:,
                                                                    -1], train_data.columns[:
                                                                                            -1]
        # 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T
        if len(y_train.value_counts()) == 1:
            return Node(root=True, label=y_train.iloc[0])

        # 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T
        if len(features) == 0:
            return Node(
                root=True,
                label=y_train.value_counts().sort_values(
                    ascending=False).index[0])

        # 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征
        max_feature, max_info_gain = self.info_gain_train(np.array(train_data))
        max_feature_name = features[max_feature]

        # 4,Ag的信息增益小于阈值eta,则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T
        if max_info_gain < self.epsilon:
            return Node(
                root=True,
                label=y_train.value_counts().sort_values(
                    ascending=False).index[0])

        # 5,构建Ag子集
        node_tree = Node(
            root=False, feature_name=max_feature_name, feature=max_feature)

        feature_list = train_data[max_feature_name].value_counts().index
        for f in feature_list:
            sub_train_df = train_data.loc[train_data[max_feature_name] ==
                                          f].drop([max_feature_name], axis=1)

            # 6, 递归生成树
            sub_tree = self.train(sub_train_df)
            node_tree.add_node(f, sub_tree)

        # pprint.pprint(node_tree.tree)
        return node_tree

    def fit(self, train_data):
        self._tree = self.train(train_data)
        return self._tree

    def predict(self, X_test):
        return self._tree.predict(X_test)

(4)调用函数

datasets, labels = create_data()
train_data = pd.DataFrame(datasets, columns=labels)
dt = DTree()
tree = dt.fit(train_data)
tree

输出结果:

{'label:': None, 'feature': 2, 'tree': {'否': {'label:': None, 'feature': 1, 'tree': {'否': {'label:': '否', 'feature': None, 'tree': {}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}

(5)模型预测

dt.predict(['老年', '否', '否', '一般'])

输出结果:

'否'

6、使用python中的模块实现决策树

(1)导入相关模块

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

(2)导入数据

data=pd.read_csv("E:\工作\硕士\博客\博客44-/车辆保险训练集.csv")
X_raw=data.iloc[:,:-1]
X = pd.get_dummies(X_raw)
y=data.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2, random_state=100)
data.info()
data.head(10)

部分输出结果:

(3)计算模型得分

clf = DecisionTreeClassifier(criterion='gini',max_depth=6, max_features=6,random_state=123)
model=clf.fit(X_train, y_train)
clf.score(X_test, y_test)

输出结果:

0.8025

(4)测试集预测并计算混淆矩阵

y_pred_clf = clf.predict_proba(X_test)
y_pred_clf[:5]

pred = model.predict(X_test)
pred[:5]

输出结果:

array([0, 0, 0, 1, 0], dtype=int64)

table = pd.crosstab(y_test, pred, rownames=['Actual'], colnames=['Predicted'])
table

输出结果:

(5)计算模型评价指标

from sklearn import metrics
print("树模型在测试集上的准确度为:", metrics.accuracy_score(pred, y_test))
print("树模型在测试集上的精度为:", metrics.average_precision_score(pred, y_test))
print("树模型在测试集上的召回率为:", metrics.recall_score(pred, y_test))
print("树模型在测试集上的F1度量为:", metrics.f1_score(pred, y_test))
metrics.plot_confusion_matrix(clf, X_test, y_test)  
plt.show()

输出结果:

树模型在测试集上的准确度为: 0.8025
树模型在测试集上的精度为: 0.7072256945930905
树模型在测试集上的召回率为: 0.7408450704225352
树模型在测试集上的F1度量为: 0.7690058479532164

(6)绘制ROC曲线

 import sklearn.metrics as metrics
y_pred_clf = clf.predict_proba(X_test)[:,1]
fpr, tpr, threshold=metrics.roc_curve(y_test,y_pred_clf)
roc_auc = metrics.auc(fpr, tpr)
#plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)#生成ROC曲线
lw = 2
plt.plot(fpr, tpr, color='darkorange',lw=lw, label='AUC = %0.2f' % roc_auc)  ###假正率为横坐标,真正率为纵坐标做曲线
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('1 - Specificity')
plt.ylabel('Sensitivity')
# plt.title('ROCs for Densenet')
plt.legend(loc="lower right")

plt.savefig("E:\工作\硕士\博客\博客44-/squares1.png",
            bbox_inches ="tight",
            pad_inches = 1,
            transparent = True,
            facecolor ="w",
            edgecolor ='w',
            dpi=300,
            orientation ='landscape')

输出结果:

需要数据集的家人们可以去百度网盘(永久有效)获取:

链接:https://pan.baidu.com/s/1E59qYZuGhwlrx6gn4JJZTg?pwd=2138
提取码:2138 


更多优质内容持续发布中,请移步主页查看。

   点赞+关注,下次不迷路!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/708096.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多表-DDL以及DQL

多表DDL 个表之间也可能存在关系 存在在一对多和多对多和一对多的关系 一对多&#xff08;外键&#xff09; 在子表建一哥字段&#xff08;列&#xff09;和对应父表关联 父表是一&#xff0c;对应子表的多&#xff08;一个部门对应多个员工&#xff0c;但一个员工只能归属一…

结构体和数据结构--从基本数据类型到抽象数据类型、结构体的定义

在冯-诺依曼体系结构中&#xff0c;程序代码和数据都是以二进制存储的&#xff0c;因此对计算机系统和硬件本身而言&#xff0c;数据类型的概念其实是不存在的。 在高级语言中&#xff0c;为了有效的组织数据&#xff0c;规范数据的使用&#xff0c;提高程序的可读性&#xff0…

使用Streamlit和Matplotlib创建交互式折线图

大家好&#xff0c;本文将介绍使用Streamlit和Matplotlib创建一个用户友好的数据可视化Web应用程序。该应用程序允许上传CSV文件&#xff0c;并为任何选定列生成折线图。 构建Streamlit应用程序 在本文中&#xff0c;我们将指导完成创建此应用程序的步骤。无论你是专家还是刚刚…

three.js利用点材质打造星空

最终效果如图&#xff1a; 一、THREE.BufferGeometry介绍 这里只是做个简单的介绍&#xff0c;详细的介绍大家可以看看THREE.BufferGeometry及其属性介绍 THREE.BufferGeometry是Three.js中的一个重要的类&#xff0c;用于管理和操作几何图形数据。它是对THREE.Geometry的一…

leetcode 226. 翻转二叉树

2023.7.1 这题依旧可以用层序遍历的思路来做。 在层序遍历的代码上将所有节点的左右节点进行互换即可实现二叉树的反转。 下面上代码&#xff1a; class Solution { public:TreeNode* invertTree(TreeNode* root) {queue<TreeNode*> que;if(root nullptr) return{};que…

gradio库中的Dropdown模块:创建交互式下拉菜单

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…

2020年全国硕士研究生入学统一考试管理类专业学位联考逻辑试题——纯享题目版

&#x1f3e0;个人主页&#xff1a;fo安方的博客✨ &#x1f482;个人简历&#xff1a;大家好&#xff0c;我是fo安方&#xff0c;考取过HCIE Cloud Computing、CCIE Security、CISP等证书。&#x1f433; &#x1f495;兴趣爱好&#xff1a;b站天天刷&#xff0c;题目常常看&a…

编译原理期末复习简记(更新中~)

注意&#xff1a;该复习简记只是针对我校期末该课程复习纲要进行的&#xff0c;仅供参考 第一章 引论 编译程序是什么&#xff1f; 编译程序是一个涉及分析和综合的复杂系统 编译程序组成 编译程序通常由以下内容组成 词法分析器 输入 组成源程序的字符串输出 记号/单词序列语法…

Jenkins+Gitlab+Springboot项目部署Jar和image两种方式

Springboot环境准备 利用spring官网快速创建springboot项目。 添加一个controller package com.example.demo;import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController;RestController public class…

新华三眼中的AI天路

ChatGPT的火爆&#xff0c;在全球范围内掀起了新一轮的AI风暴。如今&#xff0c;各行各业都在讨论AI&#xff0c;各个国家都在密集进行新一轮的AI基础设施建设与技术投入。 但眼前的盛景并非突然到来&#xff0c;就拿这一轮大模型热潮来说&#xff0c;谷歌早在2018年底就发布了…

协议速攻 IIC协议详解

介绍 IIC是一种 同步 半双工 串行 总线 同步 指的是同一根时钟线(SCL) 半双工 可以进行双向通信&#xff0c;但是收发不能同时进行&#xff0c;发的时候禁止接收&#xff0c;接的时候禁止发送 串行 数据是一位一位发送的 总线 两根线(SCL SDA)可以接多个IIC类型器件&#…

《统计学习方法》——逻辑斯蒂回归和最大熵模型

参考资料&#xff1a; 《统计学习方法》李航通俗理解信息熵 - 知乎 (zhihu.com)拉格朗日函数为什么要先最大化&#xff1f; - 知乎 (zhihu.com) 1 逻辑斯蒂回归 1.1 逻辑斯蒂回归 输入 x ( x ( 1 ) , x ( 2 ) , ⋯ , x ( n ) , 1 ) T x(x^{(1)},x^{(2)},\cdots,x^{(n)},1…

【动态规划算法练习】day11

文章目录 一、1312. 让字符串成为回文串的最少插入次数1.题目简介2.解题思路3.代码4.运行结果 二、1143. 最长公共子序列1.题目简介2.解题思路3.代码4.运行结果 三、1035. 不相交的线1.题目简介2.解题思路3.代码4.运行结果 总结 一、1312. 让字符串成为回文串的最少插入次数 1…

DevOps系列文章之 设计一个简单的DevOps系统

前置条件 gitlab gitlab-runner k8s docker 1. gitlab创建群组 创建群组的好处是,对项目进行分组,群组内的资源可以共享,这里创建了一个tibos的群组 2. 在群组创建一个项目 这里创建一个空白项目,项目名为Gourd.Test,将项目克隆到本地,然后在该目录下创建一个.net core3.1的w…

Spring Cloud Alibaba Seata源码分析

目录 一、Seata源码分析 1、Seata源码入口 1.1、2.0.0.RELEASE 1.2、2.2.6.RELEASE 2、Seata源码分析-2PC核心源码 3、Seata源码分析-数据源代理 3.1、数据源代理DataSourceProxy 4、Seata源码分析- Seata服务端&#xff08;TC&#xff09;源码 一、Seata源码分析 Sea…

P1dB、IIP3、OIP3、IMD定义及关系

P1dB 1分贝压缩输出功率。放大器有一个线性动态范围&#xff0c;在这个范围内&#xff0c;放大器的输出功率随输入功率线性增加。随着输入功率的继续增加&#xff0c;放大器进入非线性区&#xff0c;其输出功率不再随输入功率的增加而线性增加&#xff0c;也就是说&#xff0c;…

【新星计划·2023】Linux文件权限讲解

作者&#xff1a;Insist-- 个人主页&#xff1a;insist--个人主页 作者会持续更新网络知识和python基础知识&#xff0c;期待你的关注 前言 这篇文章&#xff0c;将带你详细的了解一下 Linux 系统里面有哪些重要的文件&#xff1f;。 不过&#xff0c;每个文件都有相当多的属性…

ROS学习篇之传感器(二)IMU(超核IMU HI266)

文章目录 一.确定IMU型号二.安装驱动1.找到驱动的包2.解压该压缩包3.安装步骤说明4.具体安装5.检查IMU的usb接口是否插到电脑 三.在RVIZ中的显示1.复制示例下的src里的文件复制到自己的src下2.自己的文件目录3.尝试编译一下4.示例的文件说明5.运行Demo6.配置Rviz 四.查看IMU的实…

【深入了解系统性能优化】「实战技术专题」全方面带你透彻探索服务优化技术方案(系统服务调优)

全方面带你透彻探索服务优化技术方案&#xff08;服务器系统性能调优&#xff09; 调优意义计划分析 流程相关分析优化分析Nginx请求服务日志将请求热度最高的接口进行优化异步调用优化方式注意要点 分析调用链路追踪体系建立切面操作分析性能和数据统计存储相关的调用以及耗时…

Pycharm中画图警告:MatplotlibDeprecationWarning

前言&#xff1a; \textcolor{Green}{前言&#xff1a;} 前言&#xff1a; &#x1f49e;这是由于在python中画图出现的问题&#xff0c;一般不会有错。因为它只是个警告&#xff0c;但是我们也可以知道解决这个问题的方法&#xff0c;防止后面出问题的时候知道怎么解决。 前因…