机器学习-决策树算法原理及实现-附python代码

news2025/2/25 12:37:26

1.决策树-分类树

sklearn.tree.DecisionTreeClassifier官方地址:
https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

在机器学习中,决策树是最常用也是最强大的监督学习算法,决策树主要用于解决分类问题,决策树算法 DecisionTree 是一种树形结构,采用的是自上而下的递归方法。

class sklearn.tree.DecisionTreeClassifier(, criterion=‘gini’, splitter=‘best’, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0)*

1.1基本思想

决策树是以信息熵为度量构造一个熵值下降最快的树,到叶子节点处的熵值为零或最小,此时每个叶子结点中的实例都属于同一个类别。决策树学习算法的最大优点是自学习,在学习过程中只需要对训练实例进行较好的标注,就能够进行学习,是一种无监督的学习。而建立决策树的关键:在当前状态下选择哪个属性作为分类依据,根据不同的目标函数,建立决策树有ID3、C4.5 和 CART 三种算法:

  • ID3:Iterative Dichotomiser 采用信息增益最大的特征,ID3算法的核心是在决策树的各个结点上应用信息增益准则进行特征选择,从根节点开始,对结点计算所有可能特征的信息增益,选择信息增益最大的特征作为结点的特征,并由该特征的不同取值构建子节点;对子节点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或者没有特征可选时为止。
  • C4.5:在生产决策树的过程中,使用信息增益比来进行特征选择。
  • CART:Classification And Regression Tree 对于回归树,采用的是平方误差最小化准则;对于分类树,采用基尼指数最小化准则。

在这里插入图片描述

1.2 Parameters

(1) criterion{“gini”, “entropy”, “log_loss”}, default=”gini”
衡量决策树分割质量的函数,有基尼系数、香农信息增益和log_loss,一般默认gini。

(2) splitter{“best”, “random”}, default=”best”
用于在每个节点选择拆分的策略。支持的策略是“最佳”选择最佳分割,“随机”选择最佳随机分割,默认是最佳分割。

(3) max_depth: int, default=None
树的最大深度

(4) min_samples_split: int or float, default=2
拆分内部节点所需的最小样本数

(5) min_samples_leaf: int or float, default=1
叶节点所需的最小样本数

(6) min_weight_fraction_leaf: float, default=0.0
叶节点所需的(所有输入样本的)权重总和的最小加权分数。当未提供sample_weight时,样本具有相等的权重。

(7) max_featuresint, float or {“auto”, “sqrt”, “log2”}, default=None
寻找最佳分割时需要考虑的特征数量

(8) random_stateint, RandomState instance or None, default=None
随机数

(9) max_leaf_nodes: int, default=None
以最佳优先方式生长具有max_leaf_nodes的树。最佳节点定义为节点不纯度的相对减少。如果“无”,则无限制的叶节点数。

(10) min_impurity_decrease: float, default=0.0
如果此拆分导致节点不纯度减少大于或等于此值,则节点将被拆分。

1.3剪枝

决策树对于训练集属于有很好的分类能力,但对未知的测试数据未必有很好的分类能力,泛化能力弱,可能产生过拟合现象,所以必须要剪枝处理。而以上三种决策树的剪枝过程算法相同,区别近似对于当前树的评价标准不同。

剪枝思路

  • 由完全整棵树t0开始,剪枝部分节点得到t1,再次剪枝部分节点得到t2……直到仅剩树根的树tk
  • 在验证数据集上对这k个树分别评价,选择损失函数最小的树

2.代码示例

# -*- coding: utf-8 -*-
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn import metrics
from sklearn.metrics import classification_report


class TreeModel:
    def __int__(self):
        pass

    @staticmethod
    def load_data():
        iris = datasets.load_iris()
        iris_features = iris.data
        iris_target = iris.target
        feature_name = iris.feature_names
        train_x, test_x, train_y, test_y = train_test_split(iris_features,
                                                            iris_target,
                                                            test_size=0.3,
                                                            random_state=123)
        # print(pd.DataFrame(iris_features).corr())
        return train_x, test_x, train_y, test_y, feature_name

    def train_test_model(self):
        train_x, test_x, train_y, test_y, feature_name = self.load_data()
        model = tree.DecisionTreeClassifier(criterion='gini')
        model.fit(train_x, train_y)
        model.score(train_x, train_y)
        y_pre = model.predict(test_x)
        tree_matrix = metrics.confusion_matrix(test_y, y_pre)
        print('混淆矩阵:\n', tree_matrix)
        print('结果分类报告:\n', classification_report(test_y, y_pre))
        # print('准确率:{:.2%}'.format(metrics.accuracy_score(test_y, y_pre)))

        # 特征重要性
        # model.feature_importances_
        feature_important = pd.DataFrame([*zip(feature_name, model.feature_importances_)],
                                         columns=['features', 'Gini importance'])
        print('特征重要度:\n', feature_important.sort_values(by='Gini importance'))
        return model

    @staticmethod
    def plot_tree():
        model = self.train_test_model()
        feature_name = ['sepal length', 'sepal width', 'petal length', 'petal width']
        import graphviz
        dot_data = tree.export_graphviz(model,
                                        out_file=None,
                                        feature_names=feature_name,
                                        class_names=['setosa', 'versicolor', 'virginica'],
                                        filled=True,
                                        rounded=True
                                        )
        graph = graphviz.Source(dot_data)
        graph


if __name__ == '__main__':
    TreeModel().train_test_model()

预测效果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/118641.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用双因子认证2FA替换Google authenticator谷歌令牌,助力准上市公司实现等保安全审计

21世纪初,某人力资源科技公司试水HR SaaS赛道,以大客户为目标客群,持续深耕,稳扎稳打,如今已是一家专门为中大型企业提供一体化HR SaaS及人才管理产品/解决方案的头部企业。其产品覆盖了从员工招募、入职、管理到离职的…

Linux系统安装Mysql5.7(详解)

Linux系统上安装软件的3种方式: 本次使用二进制发布包安装方式安装Mysql5.7 (一)下载Mysql5.7的二进制包 这里可以选择去Mysql官网下载,但是由于服务在外国,下载速度实在是太慢了。这里我们可以选择去阿里云的镜像网…

数据通信基础 - 解调技术(PCM)

文章目录1 概述2 脉冲编码调制技术2.1 采样2.2 量化2.3 编码3 扩展3.1 网工软考真题1 概述 #mermaid-svg-K45XtgYRoAw04KU0 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-K45XtgYRoAw04KU0 .error-icon{fill:#5522…

医疗影像工具LEADTOOLS 入门教程: 使用文档编写器创建文档 - 控制台 C#

LEADTOOLS是一个综合工具包的集合,用于将识别、文档、医疗、成像和多媒体技术整合到桌面、服务器、平板电脑、网络和移动解决方案中,是一项企业级文档自动化解决方案,有捕捉,OCR,OMR,表单识别和处理&#x…

【数据结构】Leedcode消失的数字(面试题)

目录 一、题目说明 二、题目解析 一、题目说明 题目链接: leetcode消失的数字 数组nums包含从0到n的所有整数,但其中缺了一个。请编写代码找出那个缺失的整数。你有办法在O(n)时间内完成吗? 示例1: 输入:[3,0,1] 输出&#x…

菜鸟也能懂的 - 音视频基础知识。

前言 说到视频,大家自己脑子里基本都会想起电影、电视剧、在线视频等等,也会想起一些视频格式 AVI、MP4、RMVB、MKV等等。 但是我们如果认真思考这些应该就有很多疑问,比如以下问题: mp4 和 mkv有什么区别 ? 视频封装…

Lua基本数据类型

Lua官网文档入口 http://www.lua.org/ document --> manual 一、基本数据类型 lua 中有八种基本数据型,分别是: nil,boolean,number,string,function,userdata,thread 和 tab…

vue - - - - - vue-property-decorator的使用

哪有小孩天天哭,哪有赌徒天天输 。遇到不会的技术、知识点,看得多了,掉的坑多了,也就会了。 vue-property-decorator的使用1. 单文件组件写法 - Component的使用2. 组件内使用变量3. 使用计算属性 - get的使用4. 生命周期5. metho…

Nydus 镜像扫描加速

文|余硕上海交通大学22届毕业生阿里云开发工程师从事云原生底层系统的开发和探索工作。本文 6369 字 阅读 16 分钟GitLink 编程夏令营是在 CCF 中国计算机学会指导下,由 CCF 开源发展委员会(CCF ODC)举办的面向全国高校学生的暑期…

Java字符集编码解码详细介绍

文章目录字符集字符集的基本认识字符集编码和解码字符集 字符集的基本认识 字符集基础知识 计算机底层不可以直接存储字符的。计算机中底层只能存储二进制(0、1) 二进制是可以转换成十进制的 计算机底层可以表示十进制编号。计算机可以给人类字符进行编号存储,这套…

【进阶C语言】数据的存储形式

文章目录一.数据类型分类二.整形的存储形式1.源码,反码,补码的关系内存中数据的存储——二进制源码,反码,补码的关系正数负数三.大小端1.概念2.例题:判断当前编译器的存储形式四.浮点数的存储形式1.二进制的补充&#…

【k8s系列】kube-state-metrics中kube_endpoint_address指标

文章目录背景环境操作方法1:kube_endpoint_address_not_ready选择大于0的验证方式1验证方式2方法2:kube_endpoint_address_available选小于0的方法3:kube_endpoint_address{ready"false"}选大于0的解释参考author: ningan123date: …

java基础巩固-宇宙第一AiYWM:为了维持生计,架构知识+分+微序幕就此拉开之RocketM消息中间件~整起

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 RocketMQ一、RocketMQ概念~一览无余1.消息队列有啥用?能干啥?消息队列的应用场景?2.常见的消息队列有哪些?如何进行消息队列的…

2、MySQL支持的数据类型

目录 1、整数类型 (1)fillzero:根据整数类型的长度自动添加0 (2)unsigned:非负整数 (3)bin(m):将十进制数转为m进制 2、日期时间类型 &#x…

【MySQL基础教程】函数的介绍与使用

前言 本文为 【MySQL基础教程】函数的介绍与使用 相关知识,下边具体将对字符串函数,数值函数,日期函数,流程函数等进行详尽介绍~ 📌博主主页:小新要变强 的主页 👉Java全栈学习路线可参考&…

MAXHUB+腾讯会议:为未来办公造一部动力引擎

科技领域有个规律,我们经常高估一年的变化,而低估了十年或者更长时间所可能发生的变化。不信可以做个测试,你觉得未来线上办公会怎么发展?不少朋友会说,既然线上办公是疫情到来之后的PlanB,那么随着疫情结束…

【STM32F4系列】【HAL库】【自制库】ps2手柄模块驱动

外观和电气连接 外观 手柄外观如下 接收器外观 这是接收器和底座 电气连接 需要4根连接线 单片机输出是CLK DO CS 单片机输入是DI 电源电压是3.3-5v 注意模块和单片机共地 模块不支持高速,最大时钟周期约为4us左右 因此使用软件模拟时序的方式来与模块通信 只需要将模块的4根线…

Golang Context 的几种应用场景

Golang context主要用于定义超时取消,取消后续操作,在不同操作中传递值。本文通过简单易懂的示例进行说明。 超时取消 假设我们希望HTTP请求在给定时间内完成,超时自动取消。 首先定义超时上下文,设定时间返回取消函数&#xff…

Apache POI操作百万数据excel实战方案及JDK性能监控工具Jvisualvm实战

百万数据报表概述 文章目录**百万数据报表概述****1、** **概述****2、 JDK性能监控工具介绍****2.1、 Jvisualvm概述****2.2、 Jvisualvm的位置****2.3、 Jvisualvm的使用****3、** **解决方案分析****4**、**百万数据报表导出****4.1** **需求分析****4.2** **解决方案****4.…

玩转门店管理新方法,促进营收利润加倍

门店管理的好坏是门店是否可以运营下去的重要因素,决定了门店的存亡与兴衰。以往很多门店管理者为了更简单方便,采用的是传统方式进行管理。即运用手工的方式记录和计算门店的各种信息。但是随着门店规模的扩大、商品种类的丰富、客户需求的增加以及员工…