决策树实验分析(分类和回归任务,剪枝,数据对决策树影响)

news2024/11/25 10:25:08

目录

1. 前言

2. 实验分析

        2.1 导入包

        2.2 决策树模型构建及树模型的可视化展示

        2.3 概率估计

        2.4 绘制决策边界

        2.5 决策树的正则化(剪枝)

        2.6 对数据敏感

        2.7 回归任务

        2.8 对比树的深度对结果的影响

        2.9 剪枝


1. 前言

        本文主要分析了决策树的分类和回归任务,对比一系列的剪枝的策略对结果的影响,数据对于决策树结果的影响。

        介绍使用graphaviz这个决策树可视化工具

2. 实验分析

        2.1 导入包

#1.导入包
import os
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')

        2.2 决策树模型构建及树模型的可视化展示

        下载安装包:https://graphviz.gitlab.io/_pages/Download/Download_windows.html

         选择一款安装,注意安装时要配置环境变量

        注意这里使用的是鸢尾花数据集,选择花瓣长和宽两个特征

#2.建立树模型
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:,2:] # petal legth and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X,y)
#3.树模型的可视化展示
from sklearn.tree import export_graphviz
export_graphviz(
    tree_clf,
    out_file='iris_tree.dot',
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)

        然后就可以使用graphviz包中的dot.命令工具将此文件转换为各种格式的如pdf,png,如 dot -Tpng iris_tree.png -o iris_tree.png

        可以去文件系统查看,也可以用python展示

from IPython.display import Image
Image(filename='iris_tree.png',width=400,height=400)

        分析:value表示每个节点所有样本中各个类别的样本数,用花瓣宽<=0.8和<=1.75 作为根节点划分,叶子节点表示分类结果,结果执行少数服从多数策略,gini指数随着分类进行在减小。

        2.3 概率估计

        估计类概率 输入数据为:花瓣长5厘米,宽1.5厘米的花。相应节点是深度为2的左节点,因此决策树因输出以下概率:

        iris-Setosa为0%(0/54)

        iris-Versicolor为90.7%(49/54)

        iris-Virginica为9.3%(5/54)

        

#4.概率估计
print(tree_clf.predict_proba([[5,1.5]]))
print(tree_clf.predict([[5,1.5]]))

        2.4 绘制决策边界

        

#5.绘制决策边界
from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf,X,y,axes=[0,7.5,0,3],iris=True,legend=False,plot_training=True):
    #找两个特征 x1 x2
    x1s = np.linspace(axes[0],axes[1],100)
    x2s = np.linspace(axes[2],axes[3],100)
    #构建棋盘
    x1,x2 = np.meshgrid(x1s,x2s)
    #在棋盘中构建待测试数据
    X_new = np.c_[x1.ravel(),x2.ravel()]
    #将预测值算出来
    y_pred = clf.predict(X_new).reshape(x1.shape)
    #选择颜色
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    #绘制并填充不同的区域
    plt.contourf(x1,x2,y_pred,alpha=0.3,cmap=custom_cmap)

    if not iris:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contourf(x1,x2,y_pred,alpha=0.8,cmap=custom_cmap2)
    
    #可以把训练数据展示出来
    if plot_training:
        plt.plot(X[:,0][y==0],X[:,1][y==0],'yo',label='Iris-Setosa')
        plt.plot(X[:,0][y==1],X[:,1][y==1],'bs',label='Iris-Versicolor')
        plt.plot(X[:,0][y==2],X[:,1][y==2],'g^',label='Iris-Virginica')
    if iris:
        plt.xlabel('Petal length',fontsize = 14)
        plt.ylabel('Petal width',fontsize = 14)

    else:
        plt.xlabel(r'$x_1$',fontsize=18)
        plt.ylabel(r'$x_2$',fontsize=18)
    if legend:
        plt.legend(loc='lower right',fontsize=14)
    
plt.figure(figsize=(8,4))
plot_decision_boundary(tree_clf,X,y)
plt.plot([2.45,2.45],[0,3],'k-',linewidth=2)
plt.plot([2.45,7.5],[1.75,1.75],'k--',linewidth=2)
plt.plot([4.95,4.95],[0,1.75],'k:',linewidth=2)
plt.plot([4.85,4.85],[1.75,3],'k:',linewidth=2)
plt.text(1.40,1.0,'Depth=0',fontsize=15)
plt.text(3.2,1.80,'Depth=1',fontsize=13)
plt.text(4.05,0.5,'(Depth=2)',fontsize=11)
plt.title('Decision Tree decision boundareies')

plt.show()

        
    

        可以看出三种不同颜色的代表分类结果,Depth=0可看作第一刀切分,Depth=1,2 看作第二刀,三刀,把数据集切分。

        2.5 决策树的正则化(剪枝)

        决策树的正则化

        DecisionTreeClassifier类还具有一些其他的参数类似地限制了决策树的形状

        min-samples_split(节点在分割之前必须具有的样本数)

        min-samples_leaf(叶子节点必须具有的最小样本数)

        max-leaf_nodes(叶子节点的最大数量)

        max_features(在每个节点处评估用于拆分的最大特征数)

        max_depth(树的最大深度)

#6.决策树正则化
from sklearn.datasets import make_moons
X,y = make_moons(n_samples=100,noise=0.25,random_state=53)
plt.plot(X[:,0],X[:,1],"b.")
tree_clf1 = DecisionTreeClassifier(random_state=42)
tree_clf2 = DecisionTreeClassifier(random_state=42,min_samples_leaf=4)
tree_clf1.fit(X,y)
tree_clf2.fit(X,y)
plt.figure(figsize=(12,4))
plt.subplot(121)
plot_decision_boundary(tree_clf1,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('no restriction')
plt.subplot(122)
plot_decision_boundary(tree_clf2,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('min_samples_leaf={}'.format(tree_clf2.min_samples_leaf))

        可以看出在没有加限制条件之前,分类器要考虑每个点,模型变得复杂,容易过拟合。其他的一些参数读者可以自行尝试。

        2.6 对数据敏感

        决策树对于数据是很敏感的

        

#6.对数据敏感
np.random.seed(6)
Xs = np.random.rand(100,2) - 0.5
ys = (Xs[:,0] > 0).astype(np.float32) * 2

angle = np.pi /4
rotation_matrix = np.array([[np.cos(angle),-np.sin(angle)],[np.sin(angle),np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)
 
tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs,ys)
tree_clf_sr.fit(Xsr,ys)

plt.figure(figsize=(11,4))
plt.subplot(121)
plot_decision_boundary(tree_clf_s,Xs,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')

plt.subplot(122)
plot_decision_boundary(tree_clf_sr,Xsr,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')

plt.show()

         这里是把数据又旋转了45度,然而决策边界并没有也旋转45度,却是变复杂了。可以看出,对于复杂的数据,决策树是很敏感的。

        2.7 回归任务

#7.回归任务 
np.random.seed(42)
m = 200
X = np.random.rand(m,1)
y = 4 * (X-0.5)**2
y = y + np.random.randn(m,1) /10
plt.plot(X,y,'b.')
from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X,y)
from sklearn.tree import export_graphviz
export_graphviz(
    tree_reg,
    out_file='regression_tree.dot',
    feature_names=['X1'],
    rounded=True,
    filled=True
)
from IPython.display import Image
Image(filename='regression_tree.png',width=400,height=400)

 

         回归任务,这里的衡量标准就变成了均方误差。

        2.8 对比树的深度对结果的影响

#8.对比树的深度对结果的影响
from sklearn.tree import DecisionTreeRegressor
tree_reg1 = DecisionTreeRegressor(random_state=42,max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42,max_depth=3)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)

def plot_regression_predictions(tree_reg,X,y,axes=[0,1,-0.2,1],ylabel='$y$'):
    x1 = np.linspace(axes[0],axes[1],500).reshape(-1,1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel('$X_1$',fontsize =18)
    if ylabel:
        plt.ylabel(ylabel,fontsize = 18,rotation=0)
    plt.plot(X,y,'b.')
    plt.plot(x1,y_pred,'r.-',linewidth=2,label=r'$\hat{y}$')


plt.figure(figsize=(11,4))
plt.subplot(121)

plot_regression_predictions(tree_reg1,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):
    plt.plot([split,split],[-0.2,1],style,linewidth = 2)
plt.text(0.21,0.65,'Depth=0',fontsize= 15)
plt.text(0.01,0.2,'Depth=1',fontsize= 13)
plt.text(0.65,0.8,'Depth=0',fontsize= 13)
plt.legend(loc='upper center',fontsize = 18)
plt.title('max_depth=2',fontsize=14)
plt.subplot(122)
plot_regression_predictions(tree_reg2,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):
    plt.plot([split,split],[-0.2,1],style,linewidth = 2)
for split in (0.0458,0.1298,0.2873,0.9040):
    plt.plot([split,split],[-0.2,1],linewidth = 1)
plt.text(0.3,0.5,'Depth=2',fontsize= 13)
plt.title('max_depth=3',fontsize=14)

plt.show()

        不同的树的深度,对于结果产生极大的影响

        2.9 剪枝

        

#9.加一些限制
tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42,min_samples_leaf=10)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)

x1 = np.linspace(0,1,500).reshape(-1,1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)

plt.figure(figsize=(11,4))

plt.subplot(121)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred1,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('No restrctions',fontsize =14)

plt.subplot(122)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred2,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('min_samples_leaf={}'.format(tree_reg2.min_samples_leaf),fontsize =14)

plt.show()

        一目了然。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1494399.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

XSS漏洞--概念、类型、实战--分析与详解[结合靶场pikachu]

目录 一、XSS概念简述 1、XSS简介&#xff1a; 2、XSS基本原理&#xff1a; 3、XSS攻击流程&#xff1a; 4、XSS漏洞危害&#xff1a; 二、XSS类型&#xff1a; 1、反射型XSS&#xff1a; 2、存储型XSS&#xff1a; 3、DOM型XSS&#xff1a; 三、靶场漏洞复现(pikach…

C++(13)——string类

string类的由来 C语言中&#xff0c;字符串是以 \0‘ 结尾的一些字符的集合&#xff0c;为例操作方便&#xff0c;C标准库中提供了一些str 系列的库函数&#xff0c;但是这些库函数与字符串是分开的&#xff0c;不太符合OOP&#xff08;封装、继承和多态&#xff09;的思想&am…

云原生基础知识:容器技术的历史

容器化的定义&#xff1a; 容器化是一种轻量级的虚拟化技术&#xff0c;将应用程序及其所有依赖项&#xff08;包括运行时、系统工具、系统库等&#xff09;打包到一个称为容器的单独单元中。容器提供了一种隔离的执行环境&#xff0c;使得应用程序可以在不同的环境中运行&…

splay学习笔记重制版

以前写的学习笔记&#xff1a;传送门 但是之前写的比较杂乱&#xff0c;这里重制一下 问题背景 假设我们要维护一个数据结构&#xff0c;支持插入、删除、查询某个值的排名&#xff0c;查询第 k k k大的值等操作。 最直接的想法是用二叉搜索树&#xff0c;也就是左子树权值&l…

Java | 在消息对话框中显示文本

首先需要导入JOptionPane类&#xff0c;JOptionPane类属于Swing组件中的一种&#xff0c;其导入方式如下&#xff1a; import javax.swing.JOptionPane;可以使用JOptionPane的showMessageDialog方法显示消息文本。 参数格式&#xff1a; JOptionPane.showMessageDialog(paren…

jdk安装,配置path系统变量

直接点击安装 不要包含空格&#xff0c;中文字符 3.找到刚刚的路径&#xff0c;看一下&#xff0c;有东西就说明安装对了 配置path winr输入sysdm.cpl点击确定 全部依次点击 确定 即可。 验证jdk是否安装成功 看java、javac是否可用看java、javac版本号是否无问题 win…

解密程序员的“藏宝图”:我的祖传代码大公开

程序员是如何看待“祖传代码”的&#xff1f; 大家好&#xff0c;我是小明&#xff0c;一位充满好奇心和分享热情的程序员。今天&#xff0c;我要为大家揭开我心中的“藏宝图”——那些我认为值得传世的祖传代码。让我们一同踏上这场奇妙的代码冒险之旅吧&#xff01; 宝物一…

【广度优先搜索】【堆】【C++算法】407. 接雨水 II

作者推荐 【二分查找】【C算法】378. 有序矩阵中第 K 小的元素 本文涉及知识点 广度优先搜索 堆 LeetCoce407. 接雨水 II 给你一个 m x n 的矩阵&#xff0c;其中的值均为非负整数&#xff0c;代表二维高度图每个单元的高度&#xff0c;请计算图中形状最多能接多少体积的雨…

kerberos学习系列一:原理

1、简介 Kerberos 一词来源于古希腊神话中的 Cerberus —— 守护地狱之门的三头犬。 Kerberos 是一种基于加密 Ticket 的身份认证协议。Kerberos 主要由三个部分组成&#xff1a;Key Distribution Center (即KDC)、Client 和 Service。 优势&#xff1a; 密码无需进行网络传…

Tkinter实现聊天气泡对话框

功能展示&#xff1a; 运行环境&#xff1a; Python: 3.10.4 64-bit 操作系统&#xff1a;win10 64-bit 源码文件列表&#xff1a; 部分代码说明&#xff1a; 调用该接口将消息显示在聊天框中。role参数控制消息显示的位置&#xff1a;0位于对话框左边&#xff0c;1位于右边…

批次大小对ES写入性能影响初探

问题背景 ES使用bulk写入时每批次的大小对性能有什么影响&#xff1f;设置每批次多大为好&#xff1f; 一般来说&#xff0c;在Elasticsearch中&#xff0c;使用bulk API进行批量写入时&#xff0c;每批次的大小对性能有着显著的影响。具体来说&#xff0c;当批量请求的大小增…

LLM(十一)| Claude 3:Anthropic发布最新超越GPT-4大模型

2024年3月4日&#xff0c;Anthropic发布最新多模态大模型&#xff1a;Claude 3系列&#xff0c;共有Haiku、Sonnet和Opus三个版本。 Opus在研究生水平专家推理、基础数学、本科水平专家知识、代码等10个维度&#xff0c;超过OpenAI的GPT-4。 Haiku模型更注重效率&#xff0c;能…

Figma 最新版下载:无需激活码,轻松安装!

从事设计工作&#xff0c;怎么能没有设计工具呢&#xff1f;我相信许多设计师也必须使用Figma这样的软件&#xff0c;真的可以让我们的设计工作更有效率&#xff0c;但我相信你也发现Figma属于外国软件&#xff0c;自然语言也是英语&#xff0c;直到现在没有中文版本&#xff0…

论文解读:Hints for Thin Deep Nets

这篇论文是在Hinton的那篇开山之作《Distilling the Knowledge in a Neural Network》为背景提出来的&#xff0c;主要思想是使用一个宽而浅的教师模型来训练一个窄而深的学生模型。之前的知识蒸馏方法主要是训练教师网络到更浅更宽的网络&#xff0c;没有充分利用深度。而该文…

IntelliJ IDEA 下载安装及配置使用教程

一、IDEA下载 1、打开游览器输入IntelliJ IDEA – the Leading Java and Kotlin IDE (jetbrains.com) 2、点击Download&#xff0c;进入IDEA下载界面 3、 有两个版本&#xff0c;一个是Ultimate 版本为旗舰版&#xff0c;需要付费&#xff0c;包括完整的功能&#xff0c;下载后…

element-ui配置

全局配置 完整引入 Element&#xff1a; import Vue from vue; import Element from element-ui; Vue.use(Element, { size: small, zIndex: 3000 });按需引入 Element Vue.prototype.$ELEMENT { size: small, zIndex: 3000 };如果是vue.config.js中配置了externals 使用按…

设计师成长之路1

. 学习的书籍: 1.写给大家看的设计书 2,设计师要懂心理学 3,平面设计完全手册 4.去日本上设计课2:配色设计原理

【C++】102.二叉树的层序遍历

题目描述 给你二叉树的根节点 root &#xff0c;返回其节点值的 层序遍历 。 &#xff08;即逐层地&#xff0c;从左到右访问所有节点&#xff09;。 示例1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;[[3],[9,20],[15,7]]示例 2&#xff1…

fatal: unable to access ‘***‘: OpenSSL SSL_read: SSL_ERROR_SYSCALL, errno 0解决方案

本文收录于《AI绘画从入门到精通》专栏&#xff0c;专栏总目录&#xff1a;点这里。 大家好&#xff0c;我是水滴~~ 本文主要介绍在从 GitHub 上克隆 stable-diffusion-webui 项目时出现的 fatal: unable to access https://github.com/AUTOMATIC1111/stable-diffusion-webui.…

报错:module ‘collections‘ has no attribute ‘Iterable‘

使用python 高版本&#xff0c;在使用collections遇到报错&#xff1a;module ‘collections’ has no attribute ‘Iterable’ 查了资料 在python3.9 之后collections.Iterable被弃用了。 添加修改语句 collections.Iterable collections.abc.Iterable