scikit-learn 决策树入门实践 iris花分类

news2024/11/26 4:45:22

背景

为了了解sklearn的API,以及决策树的工作原理,本文以经典的花分类问题为例,编写代码并讲解。最后深入源代码查看其实现

关键词:决策树、基尼系数、决策树可视化、特征重要性。

代码案例

训练决策树

首先要准备数据集,并调用sklearn的API训练决策树。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import matplotlib.pyplot as plt

iris = load_iris()
print("feature names", iris.feature_names)
print("target names", iris.target_names)

X = iris.data[:, 2:]
y = iris.target

print("data shape", iris.data.shape)
print("X shape", X.shape)

输出如下,每个样本有4个特征,且标签有3种取值。

feature names ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
target names ['setosa' 'versicolor' 'virginica']

取出特征集和标签集。本案例每个样本只选取后两个特征"petal length"和"petal width"(通过切片方式iris.data[:, 2:])。

X = iris.data[:, 2:]
y = iris.target

print("data shape", iris.data.shape)
print("X shape", X.shape)

尝试输出特征集如下:

  • data shape (150, 4)的含义是,原本数据集有150个样本,每个样本原本有4个特征。
  • X shape (150, 2)的含义是,由于每个样本只取后两个特征,其列数只为2。
data shape (150, 4)
X shape (150, 2)

之后,创建决策树并拟合,这里设置了最大深度为2,限制决策树的高度最多为2。

tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)

绘制决策树

绘制决策树图片。sklearn提供了plot_tree的接口。

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(
    tree_clf,
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    filled=True
)

# Save picture
fig.savefig("decistion_tree.png")

调用tree_clf.feature_importances_可以输出决策树的特征重要性:

print("feature_importances_", tree_clf.feature_importances_)

其输出如下,两个特征的重要性分别是0.0和1.0,为什么这样呢

feature_importances_ [0. 1.]

回看决策树的图,可以发现两个树中节点的判断条件分别是"petal width <= 0.8"和"petal width <= 1.75",也就是说,只用到了"petal width"这个属性,而没用到"petal length"属性。

另外,可以检验图中节点的基尼系数。取绿色的树节点为例,样本数为54,判断成3种类别的样本数分别是0、49、5。
根据基尼系数公式,计算得到 1 − ( 49 / 54 ) 2 − ( 5 / 54 ) 2 = 0.168 1-(49/54)^2-(5/54)^2=0.168 1(49/54)2(5/54)2=0.168

决策树的基尼系数公式此处不再赘述,不是本文重点

代码案例2 观察特征重要性

上一例生成的决策树节点只用到了pedal width属性,从而难以校验特征重要性的计算公式。本例把决策树的最大深度从2改为3,鼓励其使用到两种属性。

tree_clf = DecisionTreeClassifier(max_depth=3)

其生成的决策树可视化如下:

输出的两个特征的特征重要性分别是0.58和0.41

feature_importances_ [0.58561555 0.41438445]

验证特征重要性

本节将对这两个数据做验证。两种属性各个取值的基尼系数增益如下

  • "petal length<=4.85"的基尼系数增益为 0.043 − 0.444 ∗ 3 46 − 0.0 ∗ 43 46 = 0.014 0.043-0.444 * \frac{3}{46} - 0.0 * \frac{43}{46}=0.014 0.0430.4444630.04643=0.014
  • "petal length<=4.95"的基尼系数增益为 0.168 − 0.041 ∗ 48 54 − 0.444 ∗ 6 54 = 0.082 0.168-0.041 * \frac{48}{54} - 0.444 * \frac{6}{54}=0.082 0.1680.04154480.444546=0.082
  • "petal length<=2.45"的基尼系数增益为 0.667 − 0.0 ∗ 50 150 − 0.5 ∗ 100 150 = 0.33 0.667-0.0 * \frac{50}{150} - 0.5* \frac{100}{150}=0.33 0.6670.0150500.5150100=0.33
  • "petal width<=1.75"的基尼系数增益为 0.5 − 0.168 ∗ 54 100 − 0.043 ∗ 46 100 = 0.38 0.5- 0.168 * \frac{54}{100} - 0.043 * \frac{46}{100}=0.38 0.50.168100540.04310046=0.38

所以,属性的基尼系数增益之和,为其各个划分点的基尼系数加权和。

  • petal length的基尼系数增益之和为 0.014 ∗ 46 / 150 + 0.082 ∗ 54 / 150 + 0.33 ∗ 150 / 150 = 0.36 0.014*46/150+0.082*54/150+0.33*150/150=0.36 0.01446/150+0.08254/150+0.33150/150=0.36
  • petal width的基尼系数增益之和 0.38 ∗ 100 / 150 = 0.25 0.38*100/150=0.25 0.38100/150=0.25

两者作归一化后,得到0.36 / (0.36+0.25)=0.59,似乎与输出有一点偏差,这是由于舍去小数位末尾导致的。

二次验证

合并公式计算并观察可知,加权系数之间的分子分母可以消除。
也就是说 ( 0.043 − 0.444 ∗ 3 46 − 0.0 ∗ 43 46 ) ∗ 46 / 150 = 0.043 ∗ 46 / 150 − 0.444 ∗ 3 / 150 − 0.0 ∗ 43 / 150 = 0.0043 (0.043-0.444 * \frac{3}{46} - 0.0 * \frac{43}{46})*46/150=0.043*46/150-0.444*3/150-0.0*43/150=0.0043 (0.0430.4444630.04643)46/150=0.04346/1500.4443/1500.043/150=0.0043
以此法引用于每个划分点,可以计算得到另外几项:

  • 0.168 ∗ 54 / 150 − 0.041 ∗ 48 / 150 − 0.444 ∗ 6 / 150 = 0.0296 0.168*54/150-0.041 *48/150 - 0.444 *6/150=0.0296 0.16854/1500.04148/1500.4446/150=0.0296
  • 0.667 − 0.0 ∗ 50 / 150 − 0.5 ∗ 100 / 150 = 0.33 0.667-0.0 * 50/150 - 0.5* 100/150=0.33 0.6670.050/1500.5100/150=0.33
  • 0.5 ∗ 100 / 150 − 0.168 ∗ 54 / 150 − 0.043 ∗ 46 / 150 = 0.259 0.5*100/150- 0.168 * 54/150 - 0.043 * 46/150=0.259 0.5100/1500.16854/1500.04346/150=0.259

所以,两个属性的基尼系数增益之和为0.367、和0.259,归一化得到0.586和0.414,非常接近于程序输出

由此,我们可以得出结论,计算某属性的特征重要性,首先要求各个特征值的基尼系数增益,再各自乘以全局加权系数,并求和,

特征重要性的源码实现

建议先阅读参考文章:

  • feature_importances_ - 从决策树到gbdt
  • sklearn源码解析:ensemble模型 零碎记录;如何看sklearn代码,以tree的feature_importance为例

在sklearn,特征重要性的计算核心函数是cpython文件_tree.pyx的compute_feature_importances

    cpdef compute_feature_importances(self, normalize=True):
        """Computes the importance of each feature (aka variable)."""
        cdef Node* left
        cdef Node* right
        cdef Node* nodes = self.nodes
        cdef Node* node = nodes
        cdef Node* end_node = node + self.node_count
 
        cdef double normalizer = 0.
 
        cdef np.ndarray[np.float64_t, ndim=1] importances
        importances = np.zeros((self.n_features,))
        cdef DOUBLE_t* importance_data = <DOUBLE_t*>importances.data
 
        with nogil:
            while node != end_node:
                if node.left_child != _TREE_LEAF:
                    # ... and node.right_child != _TREE_LEAF:
                    left = &nodes[node.left_child]
                    right = &nodes[node.right_child]
 
                    importance_data[node.feature] += (
                        node.weighted_n_node_samples * node.impurity -
                        left.weighted_n_node_samples * left.impurity -
                        right.weighted_n_node_samples * right.impurity)
                node += 1
 
        importances /= nodes[0].weighted_n_node_samples
 
        if normalize:
            normalizer = np.sum(importances)
 
            if normalizer > 0.0:
                # Avoid dividing by zero (e.g., when root is pure)
                importances /= normalizer
 
        return importances

其中,以下代码所做行为就是在计算某特征值的加权基尼系数增益。

importance_data[node.feature] += (
   node.weighted_n_node_samples * node.impurity -
   left.weighted_n_node_samples * left.impurity -
   right.weighted_n_node_samples * right.impurity)

importance_data[node.feature]+=符号代表这个节点的增益值归属于它的所属特征,由于一个特征可能会有多个划分值(比如"petal length<=4.85"和"petal length<=4.95"都属于petal length),所以它们的增益要累加。

.impurity里的其实就是基尼系数。

weighted_n_node_samples 的含义应该是该节点的全局加权系数,即该节点的样本数除以全局样本数 n n o d e / n t o t a l n_{node}/n_{total} nnode/ntotal。比如对于上一例里右下角三个节点的全局加权系数分别是46/150、3/150、43/150。

importances /= nodes[0].weighted_n_node_samples的含义是,最后除以根节点的全局加权系数。但笔者认为通常这个值就是1。

如果设置要进行归一化,就最后除以总和,保证各特征值相加为1。

if normalize:
    normalizer = np.sum(importances)

    if normalizer > 0.0:
        # Avoid dividing by zero (e.g., when root is pure)
        importances /= normalizer

总结

  • 以sklearn基于iris数据集构建决策树为例,实践了构建决策树、可视化决策树的API。
  • 证实了"特征重要性等于基尼系数增益"的说法,以及全局加权系数的含义指 n n o d e / n t o t a l n_{node}/n_{total} nnode/ntotal,手推了计算过程,并结合源码分析进一步作证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/92601.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1.引入——如何创建Spring项目

目录 1.创建SpringBoot项目 1.未安装插件 2.安装插件 2.尝试着运行这个FirstApplication 3.Spring的核心——IOC&#xff08;控制反转&#xff09;/DI的讲解 1.相关概念&#xff1a; 2.什么是IOC&#xff0c;为什么要有IOC&#xff1f; 4.基于XML的方式&#xff0c;演示…

制造企业数字化车间MES系统方案

在市场经济越发严峻的局面下&#xff0c;现代制造业工厂越来越追求效率与精益生产管理&#xff0c;争相通过各种技术手段实现生产线上的现代管理&#xff0c;其中&#xff0c;可视化生产管理技术受到企业的关注&#xff0c;对MES系统也越来越重视。 MES系统解决的问题 1、条码…

094基于nodejs框架的学生作业管理系统vue

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 前端技术&#xff1a;nodejsvueelementui 前端&#xff1a;HTML5,CSS3、JavaScript、VUE 系统分为不同的层次&#xff1a;视图层&#xff08;vue页面…

【C语言数据结构(基础版)】第三站:链表(二)

目录 一、单链表的缺陷以及双向链表的引入 1.单链表的缺陷 2.双向链表的引入 3.八大链表结构 &#xff08;1&#xff09;单向和双向 &#xff08;2&#xff09;带头和不带头 &#xff08;3&#xff09;循环和不循环 &#xff08;4&#xff09;八种链表结构 二、带头双向…

牛掰,阿里技术人刷了四年LeetCode才总结出来的数据结构和算法手册

时间飞逝&#xff0c;转眼间毕业七年多&#xff0c;从事 Java 开发也六年了。我在想&#xff0c;也是时候将自己的 Java 整理成一套体系。 这一次的知识体系面试题涉及到 Java 知识部分、性能优化、微服务、并发编程、开源框架、分布式等多个方面的知识点。 写这一套 Java 面试…

ssm+Vue计算机毕业设计校园疫情防控管理软件(程序+LW文档)

ssmVue计算机毕业设计校园疫情防控管理软件&#xff08;程序LW文档&#xff09; 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项…

数据库拆分3--使用sharding-jdbc 子查询注意事项

最近在使用sharding-jdbc来改造项目的时候遇到了一些问题&#xff0c;主要是有关子查询的&#xff0c;记录一下。 在某一个库中新建两张表 CREATE TABLE user_t ( user_id bigint(20) NOT NULL AUTO_INCREMENT, name varchar(255) DEFAULT NULL, age int(8) DEFAULT NU…

世界杯决赛号角吹响!趁周末来搭一套足球3D+AI量化分析系统吧!

2022年卡塔尔世界杯从11月21日开赛至今&#xff0c;即将在12月18日迎来这次赛事的最后高潮。对于大部分热爱世界杯的朋友来说&#xff0c;无论之前是哪队的球迷&#xff0c;现在都在会师决赛的两支队伍上选择站队。从赛事结果看&#xff0c;最终无论哪支队伍夺冠&#xff0c;都…

01背包和完全背包

01背包 最大约数和 题目链接点击这里 题目描述 选取和不超过 SSS 的若干个不同的正整数&#xff0c;使得所有数的约数&#xff08;不含它本身&#xff09;之和最大。 输入格式 输入一个正整数 SSS。 输出格式 输出最大的约数之和。 样例 #1 样例输入 #1 11样例输出 …

有哪些值得推荐的Python学习网站?

我学习的时候&#xff0c;我发现大部分 Python 课程和资源都太通用了。 马上&#xff0c;我想学习如何使用 Python 制作网站。但是 Python 学习资源要我花几个月的时间学习语法&#xff0c;然后才能进入我感兴趣的领域。 这个问题让人感到恐惧和畏惧。我推迟了几个月。每当我…

大学生化妆品网页设计模板代码 化妆美妆网页作业成品 学校美妆官网网页制作模板 学生简单html网站设计成品

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

为什么人家的开源项目文档如此炫酷?原来用的是这款神器

VuePress简介 VuePress是Vue驱动的静态网站生成器。对比我们的Docsify动态生成网站&#xff0c;对SEO更加友好。 使用VuePress具有如下优点&#xff1a; 使用Markdown来写文章&#xff0c;程序员写起来顺手&#xff0c;配置网站非常简洁。 我们可以在Markdown中使用Vue组件&…

所谓工作能力强,其实就这五点

博客主页&#xff1a;https://tomcat.blog.csdn.net 博主昵称&#xff1a;农民工老王 主要领域&#xff1a;Java、Linux、K8S 期待大家的关注&#x1f496;点赞&#x1f44d;收藏⭐留言&#x1f4ac; #mermaid-svg-YapmQUqJ0V32EFv6 {font-family:"trebuchet ms",ve…

用三台云服务器搭建hadoop完全分布式集群

用三台云服务器搭建hadoop完全分布式集群一、硬件准备&#xff08;一&#xff09;集群配置&#xff08;二&#xff09;集群规划&#xff08;三&#xff09;Hadoop、Zookeeper、Java、CentOS版本二、基础环境配置&#xff08;一&#xff09;关闭防火墙&#xff08;二&#xff09…

[附源码]Python计算机毕业设计SSM基于Java的在线点餐系统(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

12.15

JSONP 1) JSONP 是什么 JSONP(JSON with Padding)&#xff0c;只支持 get 请求。 2) JSONP 怎么工作的&#xff1f; 在网页有一些标签天生具有跨域能力&#xff0c;比如&#xff1a;img link iframe script。 JSONP 就是利用 script 标签的跨域能力来发送请求的。 3) JSONP …

为什么你的接口性能差,实际原因就在这里?

一、前言这篇文章咱们来聊一下&#xff0c;百亿级别的海量数据场景下还要支撑每秒十万级别的高并发查询&#xff0c;这个架构该如何演进和设计&#xff1f;咱们先来看看目前系统已经演进到了什么样的架构&#xff0c;大家看看下面的图&#xff1a;首先回顾一下&#xff0c;整个…

三、Node.js模块化基础 2.0

在Node.js中&#xff0c;模块分为核心&#xff08;原生&#xff09;模块和文件&#xff08;自定义&#xff09;模块&#xff0c;核心模块就是Node.js自带的模块&#xff0c;而自定义模块则是开发者自定义的模块&#xff1b; 核心模块 核心模块有 os&#xff0c;fs&#xff0c;…

发送给Java应用程序的所有参数都必须是字符串吗?

问&#xff1a;发送给Java应用程序的所有参数都必须是字符串吗&#xff1f; 答&#xff1a; 应用程序在运行时&#xff0c;Java将所有参数存储为字符串。要使用整型或其他非字符串参数&#xff0c;必须将其进行转换&#xff0c; 问&#xff1a;既然applet是在Web页面中运行&…

大一作业HTML网页作业:中华传统文化题材网页设计5页(纯html+css实现)

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…