机器学习——决策树剪枝算法

news2025/1/18 9:10:08

机器学习——决策树剪枝算法

决策树是一种常用的机器学习模型,它能够根据数据特征的不同进行分类或回归。在决策树的构建过程中,剪枝算法是为了防止过拟合,提高模型的泛化能力而提出的重要技术。本篇博客将介绍剪枝处理的概念、预剪枝和后剪枝方法,以及决策树的损失函数(目标函数),并使用Python实现以上所有的算法。

1. 剪枝处理

在决策树的构建过程中,为了防止过拟合,通常会对生成的决策树进行剪枝处理。剪枝的目的是通过降低树的复杂度来提高模型的泛化能力。

2. 预剪枝与后剪枝

预剪枝是在决策树生成过程中,在决策树生长的过程中,根据一定的条件提前终止分支的生成。常用的预剪枝条件包括限制树的最大深度、叶节点最小样本数等。

后剪枝是在决策树生成完成后,通过一定的方法对决策树进行剪枝。后剪枝的思想是先生成一颗完全生长的决策树,然后根据损失函数(目标函数)对节点进行逐个判断,判断删除某一节点后是否能提高模型的泛化能力,如果能,则删除该节点。

3. 决策树的损失函数

决策树的损失函数(目标函数)是在剪枝过程中判断节点是否应该被剪枝的依据。通常使用的损失函数包括基于误分类率、基尼指数和交叉熵等。

3.1 基于误分类率的损失函数:

C α ( T ) = C ( T ) + α ∣ T ∣ C_{\alpha}(T) = C(T) + \alpha|T| Cα(T)=C(T)+αT

其中, C ( T ) C(T) C(T)是模型对训练数据的误分类率, ∣ T ∣ |T| T是决策树的叶节点个数, α \alpha α是调节参数。

3.2 基于基尼指数的损失函数:

C α ( T ) = C ( T ) + α ∣ T ∣ C_{\alpha}(T) = C(T) + \alpha|T| Cα(T)=C(T)+αT

其中, C ( T ) C(T) C(T)是模型的基尼指数, ∣ T ∣ |T| T是决策树的叶节点个数, α \alpha α是调节参数。

3.3 基于交叉熵的损失函数:

C α ( T ) = C ( T ) + α ∣ T ∣ C_{\alpha}(T) = C(T) + \alpha|T| Cα(T)=C(T)+αT

其中, C ( T ) C(T) C(T)是模型的交叉熵, ∣ T ∣ |T| T是决策树的叶节点个数, α \alpha α是调节参数。

4. Python实现

接下来,将使用Python实现预剪枝和后剪枝两种剪枝算法,并在相同的数据集上进行比较。

4.1 预剪枝算法

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建决策树模型(预剪枝)
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3, min_samples_split=5, min_samples_leaf=2, random_state=42)
clf.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("Pre-pruning Accuracy:", accuracy)

4.2 后剪枝算法

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

#```python
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建决策树模型(后剪枝)
clf = DecisionTreeClassifier(criterion='entropy', random_state=42)
clf.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("Before Pruning Accuracy:", accuracy)

# 后剪枝
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(criterion='entropy', random_state=42, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

best_clf = clfs[test_scores.index(max(test_scores))]
y_pred = best_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("After Pruning Accuracy:", accuracy)

示例

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建决策树模型(后剪枝)
clf = DecisionTreeClassifier(criterion='entropy', random_state=42)
clf.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print("Before Pruning Accuracy:", accuracy)

# 后剪枝
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(criterion='entropy', random_state=42, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

best_clf = clfs[test_scores.index(max(test_scores))]
y_pred = best_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("After Pruning Accuracy:", accuracy)

# 绘制准确率随着剪枝参数的变化曲线
plt.figure(figsize=(10, 6))
plt.plot(ccp_alphas, train_scores, marker='o', label='Train', drawstyle="steps-post")
plt.plot(ccp_alphas, test_scores, marker='o', label='Test', drawstyle="steps-post")
plt.xlabel("CCP Alpha")
plt.ylabel("Accuracy")
plt.title("Accuracy vs. CCP Alpha for Decision Tree Pruning")
plt.legend()
plt.show()

在这里插入图片描述

5. 总结

本篇博客介绍了决策树的剪枝算法,包括预剪枝和后剪枝两种方法,以及决策树的损失函数(目标函数)。通过Python实现了预剪枝和后剪枝算法,并在相同的数据集上进行了比较。

预剪枝通过限制决策树的生长来防止过拟合,但可能会导致欠拟合。后剪枝是在决策树生成完成后,通过一定的方法对决策树进行剪枝,可以更好地提高模型的泛化能力。在实际应用中,需要根据具体问题的特点和数据集的情况选择合适的剪枝算法,并通过调参来优化模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1539033.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《优化接口设计的思路》系列:第九篇—用好缓存,让你的接口速度飞起来

一、前言 大家好!我是sum墨,一个一线的底层码农,平时喜欢研究和思考一些技术相关的问题并整理成文,限于本人水平,如果文章和代码有表述不当之处,还请不吝赐教。 作为一名从业已达六年的老码农&#xff0c…

vue2 自定义 v-model (model选项的使用)

效果预览 model 选项的语法 每个组件上只能有一个 v-model。v-model 默认会占用名为 value 的 prop 和名为 input 的事件,即 model 选项的默认值为 model: {prop: "value",event: "input",},通过修改 model 选项,即可自定义v-model …

35 跨域相关问题, 以及常见的解决方式

前言 跨域相关 这是一个 经常会碰到的问题 然后 常见的解决方式 也大概就是几种, 各有各的问题 这里仅仅是 从理论上 来探讨这个问题 主流的解决方式 是通过代理, 将不同域 合并到同一个域 测试用例 测试用例如下, 这里仅仅是一个简单的数据展示 获取对方 “/config.jso…

【c++入门】引用,内联函数,auto

🔥个人主页:Quitecoder 🔥专栏:c笔记仓 朋友们大家好,本节我们来到c中一个重要的部分:引用 目录 1.引用的基本概念与用法1.1引用特性1.2使用场景1.3传值、传引用效率比较1.4引用做返回值1.5引用和指针的对…

Kubernetes(k8s)集群健康检查常用的五种指标

文章目录 1、节点健康指标2、Pod健康指标3、服务健康指标4、网络健康指标5、存储健康指标 1、节点健康指标 节点状态:检查节点是否处于Ready状态,以及是否存在任何异常状态。 资源利用率:监控节点的CPU、内存、磁盘等资源的使用情况&#xf…

SpringCloud从入门到精通速成(二)

文章目录 1.Nacos配置管理1.1.统一配置管理1.1.1.在nacos中添加配置文件1.1.2.从微服务拉取配置 1.2.配置热更新1.2.1.方式一1.2.2.方式二 1.3.配置共享1)添加一个环境共享配置2)在user-service中读取共享配置3)运行两个UserApplication&…

c语言食堂就餐排队问题290行

定制魏:QTWZPW,获取更多源码等 目录 题目 数据结构 函数设计 结构设计 总结 效果截图 ​ 主函数代码 题目 设计一个程序来模拟食堂就餐排队问题,通过输入学生人数和面包数量,计算有多少学生能够吃到午餐。 数据结构 该…

原神x星穹铁道文本转原神语音源码

《原神》x《星穹铁道》文本转原神语音源码介绍文案 探索未知的奇幻世界,与心仪的角色共舞冒险之旅——《原神》与《星穹铁道》的梦幻联动,为你带来前所未有的游戏体验!而此刻,我们将为你揭秘一项革命性的创新:文本转原…

T470 双电池机制

ThinkPad系列电脑牛黑科技双电池管理体系技术,你知道吗? - 北京正方康特联想电脑代理商 上文的地址 在放电情况下:优先让外置电池放电,当放到一定电量后开始让内置电池放电。 在充电情况下:优先给内置电池充电,当充…

数据结构从入门到精通——希尔排序

希尔排序 前言一、希尔排序( 缩小增量排序 )二、希尔排序的特性总结三、希尔排序动画演示四、希尔排序具体代码实现test.c 前言 希尔排序是一种基于插入排序的算法,通过比较相距一定间隔的元素来工作,各趟比较所用的距离随着算法的进行而减小&#xff0…

c++核心学习5

4.6继承 有些类与类之间存在特殊的关系,例如下图中: 我们发现,定义这些类时,下级别的成员除了拥有上一级的共性,还有自己的特性。这个时候我们就可以考虑利用继承的技术,减少重复代码 4.6.1继承的基本语法…

学点儿Java_Day9_字符串操作

1 实现trim方法 实现简单的trim方法,实现传入一个字符串,返回忽略前导空格和尾部空格。 public String myTrim(String str) {if (str null || str.isEmpty()) {//"".equals(str)return null;}char[] chars str.toCharArray();int start 0…

GD32串口通信PB6,PB7

我发现GD32很多接口都需要冲映射,刚开始还是不习惯,还要打开要选打开AFIO时钟。算了,直接看代码: 1,usart.c //#include "usart.h"//void USART_GPIO_init(void) //{ // //初始化引脚 // rcu_periph_clock_enable(RCU…

Qt打开已有工程方法

在Qt中,对于一个已有工程如何进行打开? 1、首先打开Qt Creator 2、点击文件->打开文件或项目,找到对应文件夹下的.pro文件并打开 3、点击配置工程 这样就打开对应的Qt项目了,点击运行即可看到对应的效果 Qt开发涉及界面修饰…

网络工程师笔记15(OSPF协议-2)

OSPF协议 OSPF是典型的链路状态路由协议,是目前业内使用非常广泛的 IGP 协议之一。 Router-ID(Router ldentifier,路由器标识符),用于在一个 OSPF 域中唯一地标识一台路由器。Router-ID 的设定可以通过手工配置的方式,或使用系统自…

宏集PLC如何应用于建筑的3D打印?

案例概况 客户:Rebuild 合作伙伴:ASTOR 应用:用于建筑的大尺寸3D打印 应用产品:3D混凝土打印机 一、应用背景 自从20世纪80年代以来,增材制造技术(即3D打印)不断发展。大部分3D打印技术应…

day11【网络编程】-综合案例

day11【网络编程】 第三章 综合案例 3.1 文件上传案例 文件上传分析图解 【客户端】输入流,从硬盘读取文件数据到程序中。【客户端】输出流,写出文件数据到服务端。【服务端】输入流,读取文件数据到服务端程序。【服务端】输出流&#xf…

scDEA一键汇总12种单细胞差异分析方法 DESeq2、edgeR、MAST、monocle、scDD、Wilcoxon

问题来源 单细胞可以做差异分析,但是究竟选择哪种差异分析方法最靠谱呢? 解决办法 于是我去检索文献,是否有相关研究呢? https://academic.oup.com/bib/article/23/1/bbab402/6375516 文章指出,现有的差异分析方法…

Linux基础-Makefile

目录 一、Make简介 二、Makefile基本结构 示例: 补充(Makefile): 伪目标: 三、创建和使用变量 变量定义的方式: 简单方式: 递归方式: 用?定义变量 为变量添加值 预定义变量 例 自动变量 例 …

数据结构从入门到精通——快速排序

快速排序 前言一、快速排序的基本思想常见方式通用模块 二、快速排序的特性总结三、三种快速排序的动画展示四、hoare版本快速排序的代码展示普通版本优化版本为什么要优化快速排序代码三数取中法优化代码 五、挖坑法快速排序的代码展示六、前后指针快速排序的代码展示七、非递…