【大数据】机器学习------决策树

news2025/1/16 0:06:45

一、基本流程

决策树是一种基于树结构的分类和回归方法,它通过对特征空间进行划分,每个内部节点表示一个特征测试,每个分支代表一个测试输出,每个叶节点代表一个类别或回归值。

在这里插入图片描述

  1. 特征选择:根据某种准则(如信息增益、信息增益比、基尼指数等)选择最优特征进行划分。
  2. 决策树生成:从根节点开始,根据选定的特征对样本进行划分,生成子节点,递归地构建决策树。
  3. 决策树剪枝:通过剪枝处理防止过拟合,提高决策树的泛化能力。

二、划分选择

1. 信息增益(ID3 算法)

信息增益表示得知特征 X X X 的信息而使得类 Y Y Y 的信息不确定性减少的程度。

设数据集 D D D 的信息熵为:

在这里插入图片描述

其中 C k C_k Ck 是类别 k k k 的样本集合, ∣ C k ∣ |C_k| Ck 是类别 k k k 的样本数量, ∣ D ∣ |D| D 是数据集 D D D 的样本总数。

对于特征 A A A,信息增益为:
在这里插入图片描述

2. 信息增益比(C4.5 算法)

信息增益比克服了信息增益偏向于选择取值较多的特征的问题,定义为:

在这里插入图片描述

3. 基尼指数(CART 算法)

基尼指数表示集合的不确定性,对于数据集 D D D
在这里插入图片描述

其中 在这里插入图片描述

对于特征 A A A 的基尼指数:
在这里插入图片描述

三、剪枝处理

1. 预剪枝

在决策树生成过程中,对每个节点在划分前进行估计,如果当前节点的划分不能带来决策树泛化性能的提升,则停止划分。

2. 后剪枝

先生成完整的决策树,然后自底向上对非叶节点进行考察,若将其替换为叶节点能提高泛化性能,则进行剪枝。

四、连续与缺失值处理

1. 连续值处理

对于连续特征,通常将其离散化,如采用二分法,将连续特征的取值排序,取相邻值的平均值作为划分点,计算不同划分点的信息增益(或其他指标),选择最优划分点。

2. 缺失值处理

  • 样本权重调整:对于含有缺失值的样本,根据无缺失值样本中该特征的取值分布,将其以一定权重划分到不同子节点。
  • 属性值填充:使用一些策略(如均值、中位数、众数)填充缺失值。

五、多变量决策树

多变量决策树不是为每个节点寻找一个最优划分属性,而是试图建立一个线性组合作为划分属性,如:
在这里插入图片描述

六、代码示例

1. 使用 sklearn 实现决策树分类

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target


# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)


# 初始化决策树分类器,使用信息增益(默认)
clf = DecisionTreeClassifier(criterion='entropy')


# 训练模型
clf.fit(X_train, y_train)


# 预测
y_pred = clf.predict(X_test)


# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

在这里插入图片描述

2. 自定义决策树(使用信息增益)

import numpy as np


def entropy(y):
    """计算信息熵"""
    unique_labels = np.unique(y)
    entropy = 0
    for label in unique_labels:
        p = np.mean(y == label)
        entropy -= p * np.log2(p)
    return entropy


def information_gain(X, y, feature_index):
    """计算信息增益"""
    base_entropy = entropy(y)
    values = np.unique(X[:, feature_index])
    new_entropy = 0
    for value in values:
        sub_y = y[X[:, feature_index] == value]
        p = len(sub_y) / len(y)
        new_entropy += p * entropy(sub_y)
    return base_entropy - new_entropy


class Node:
    """决策树节点类"""
    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        self.feature_index = feature_index
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value


def build_tree(X, y, depth=0, max_depth=5):
    """构建决策树"""
    if len(np.unique(y)) == 1:
        return Node(value=y[0])
    if depth >= max_depth:
        return Node(value=np.bincount(y).argmax())
    n_features = X.shape[1]
    best_gain = 0
    best_feature = None
    best_threshold = None
    for feature_index in range(n_features):
        gain = information_gain(X, y, feature_index)
        if gain > best_gain:
            best_gain = gain
            best_feature = feature_index
    if best_gain == 0:
        return Node(value=np.bincount(y).argmax())
    feature_values = np.unique(X[:, best_feature])
    best_threshold = (feature_values[:-1] + feature_values[1:]) / 2
    best_threshold = best_threshold[np.argmax([information_gain(X, y, best_feature, t) for t in best_threshold])]
    left_indices = X[:, best_feature] <= best_threshold
    right_indices = X[:, best_feature] > best_threshold
    left = build_tree(X[left_indices], y[left_indices], depth + 1, max_depth)
    right = build_tree(X[right_indices], y[right_indices], depth + 1, max_depth)
    return Node(feature_index=best_feature, threshold=best_threshold, left=left, right=right)


def predict_sample(node, sample):
    """预测单个样本"""
    if node.value is not None:
        return node.value
    if sample[node.feature_index] <= node.threshold:
        return predict_sample(node.left, sample)
    else:
        return predict_sample(node.right, sample)


def predict(tree, X):
    """预测多个样本"""
    return np.array([predict_sample(tree, sample) for sample in X])


# 示例数据
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])


# 构建决策树
tree = build_tree(X, y, max_depth=3)


# 预测
y_pred = predict(tree, X)
print(y_pred)

在这里插入图片描述

代码解释

1. 使用 sklearn 实现决策树分类

  • load_iris() 加载鸢尾花数据集。
  • DecisionTreeClassifier(criterion='entropy') 初始化一个使用信息熵作为划分准则的决策树分类器。
  • clf.fit(X_train, y_train) 训练模型。
  • clf.predict(X_test) 对测试集进行预测。
  • accuracy_score(y_test, y_pred) 计算准确率。

2. 自定义决策树(使用信息增益)

  • entropy(y) 函数计算信息熵。
  • information_gain(X, y, feature_index) 计算信息增益。
  • Node 类定义决策树的节点。
  • build_tree(X, y, depth=0, max_depth=5) 递归构建决策树,使用信息增益选择特征和阈值进行划分。
  • predict_sample(node, sample) 对单个样本进行预测。
  • predict(tree, X) 对多个样本进行预测。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2277247.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务器数据恢复—raid5故障导致上层ORACLE无法启动的数据恢复案例

服务器数据恢复环境&故障&#xff1a; 一台服务器上的8块硬盘组建了一组raid5磁盘阵列。上层安装windows server操作系统&#xff0c;部署了oracle数据库。 raid5阵列中有2块硬盘的硬盘指示灯显示异常报警。服务器操作系统无法启动&#xff0c;ORACLE数据库也无法启动。 服…

Day05-后端Web基础——TomcatServletHTTP协议SpringBootWeb入门

目录 Web基础知识课程内容1. Tomcat1.1 简介1.2 基本使用1.2.1 下载1.2.2 安装与卸载1.2.3 启动与关闭1.2.4 常见问题 2. Servlet2.1 快速入门2.1.1 什么是Servlet2.1.2 入门程序2.1.3 注意事项 2.2 执行流程 3. HTTP协议3.1 HTTP-概述3.1.1 介绍3.1.2 特点 3.2 HTTP-请求协议3…

【已解决】【记录】2AI大模型web UI使用tips 本地

docker desktop使用 互动 如果需要发送网页链接&#xff0c;就在链接上加上【#】号 如果要上传文件就点击这个➕号 中文回复 命令它只用中文回复&#xff0c;在右上角打开【对话高级设置】 输入提示词&#xff08;提示词使用英文会更好&#xff09; Must reply to the us…

Deep4SNet: deep learning for fake speech classification

Deep4SNet&#xff1a;用于虚假语音分类的深度学习 摘要&#xff1a; 虚假语音是指即使通过人工智能或信号处理技术产生的语音记录。生成虚假录音的方法有"深度语音"和"模仿"。在《深沉的声音》中&#xff0c;录音听起来有点合成&#xff0c;而在《模仿》中…

Docker save load 镜像 tag 为 <none>

一、场景分析 我从 docker hub 上拉了这么一个镜像。 docker pull tomcat:8.5-jre8-alpine 我用 docker save 命令想把它导出成 tar 文件以便拷贝到内网机器上使用。 docker save -o tomcat-8.5-jre8-alpine.tar.gz 镜像ID 当我把这个镜像传到别的机器&#xff0c;并用 dock…

备战蓝桥杯 队列和queue详解

目录 队列的概念 队列的静态实现 总代码 stl的queue 队列算法题 1.队列模板题 2.机器翻译 3.海港 双端队列 队列的概念 和栈一样&#xff0c;队列也是一种访问受限的线性表&#xff0c;它只能在表头位置删除&#xff0c;在表尾位置插入&#xff0c;队列是先进先出&…

工厂物流管理系统方案(二):危险品车辆专用导航系统架构设计深度剖析

本文专为IT架构师、物流技术专家、软件开发工程师及对危险品运输导航技术有深入探索需求的读者撰写&#xff0c;旨在全面解析危险品车辆专用导航系统的架构设计&#xff0c;展现其技术深度与复杂性&#xff0c;为行业同仁提供权威的技术参考与实践指导。如需获取危险品车辆专用…

用 Python 从零开始创建神经网络(十九):真实数据集

真实数据集 引言数据准备数据加载数据预处理数据洗牌批次&#xff08;Batches&#xff09;训练&#xff08;Training&#xff09;到目前为止的全部代码&#xff1a; 引言 在实践中&#xff0c;深度学习通常涉及庞大的数据集&#xff08;通常以TB甚至更多为单位&#xff09;&am…

No.1|Godot|俄罗斯方块复刻|棋盘和初始方块的设置

删掉基础图标新建assets、scenes、scripts文件夹 俄罗斯方块的每种方块都是由四个小方块组成的&#xff0c;很适合放在网格地图中 比如网格地图是宽10列&#xff0c;高20行 要实现网格的对齐和下落 Node2D节点 新建一个Node2D 添加2个TileMapLayer 一个命名为Board&…

蓝桥云客第 5 场 算法季度赛

题目&#xff1a; 2.开赛主题曲【算法赛】 - 蓝桥云课 问题描述 蓝桥杯组委会创作了一首气势磅礴的开赛主题曲&#xff0c;其歌词可用一个仅包含小写字母的字符串 S 表示。S 中的每个字符对应一个音高&#xff0c;音高由字母表顺序决定&#xff1a;a1,b2,...,z26。字母越靠后…

刀客doc:快手的商业化架构为什么又调了?

一、 1月10日&#xff0c;快手商业化及电商事业部进行新一轮的架构调整。作为2025年快手的第一次大调整&#xff0c;变动最大的是负责广告业务的商业化事业部。快手商业化将原来的8个业务中心&#xff0c;现在统合成了5个&#xff0c;行业归拢看上去更加明晰了。 根据自媒体《…

6.2 MySQL时间和日期函数

以前我们就用过now()函数来获得系统时间&#xff0c;用datediff()函数来计算日期相差的天数。我们在计算工龄的时候&#xff0c;让两个日期相减。那么其中的这个now函数返回的就是当前的系统日期和时间。 1. 获取系统时间函数 now()函数&#xff0c;返回的这个日期和时间的格…

mock服务-通过json定义接口自动实现mock服务

go-mock介绍 不管在前端还是后端开发过程中&#xff0c;当我们需要联调其他服务的接口&#xff0c;而这个服务还没法提供调用时&#xff0c;那我们就要用到mock服务&#xff0c;自己按接口文档定义一个临时接口返回指定数据&#xff0c;以供本地开发联调测试。 怎么快速启动一…

sparkSQL练习

1.前期准备 &#xff08;1&#xff09;建议先把这两篇文章都看一下吧&#xff0c;然后把这个项目也搞下来 &#xff08;2&#xff09;看看这个任务 &#xff08;3&#xff09;score.txt student_id,course_code,score 108,3-105,99 105,3-105,88 107,3-105,77 105,3-245,87 1…

CSS | 实现三列布局(两边边定宽 中间自适应,自适应成比)

目录 示例1 &#xff08;中间自适应 示例2&#xff08;中间自适应 示例3&#xff08;中间自适应 示例4 &#xff08;自适应成比 示例5&#xff08;左中定宽&#xff0c;右边自适应 示例6&#xff08;中间自适应 示例7&#xff08;中间自适应 示例8&#xff08;中间定宽…

力扣 子集

回溯基础&#xff0c;一题多解&#xff0c;不同的回朔过程。 题目 求子集中&#xff0c;数组的每种元素有选与不选两种状态。因此在使用dfs与回溯时把每一个元素分别进行选与不选的情况考虑即可。可以先用dfs跳过当前元素即不选然后一直深层挖下去&#xff0c;直到挖到最深了即…

网络层协议-----IP协议

目录 1.认识IP地址 2.IP地址的分类 3.子网划分 4.公网IP和私网IP 5.IP协议 6.如何解决IP地址不够用 1.认识IP地址 IP 地址&#xff08;Internet Protocol Address&#xff09;是指互联网协议地址。 它是分配给连接到互联网的设备&#xff08;如计算机、服务器、智能手机…

RocketMQ 知识速览

文章目录 一、消息队列对比二、RocketMQ 基础1. 消息模型2. 技术架构3. 消息类型4. 消费者类型5. 消费者分组和生产者分组 三、RocketMQ 高级1. 如何解决顺序消费和重复消费2. 如何实现分布式事务3. 如何解决消息堆积问题4. 如何保证高性能读写5. 刷盘机制 &#xff08;topic 模…

C++(类和对象)

C中的类 C中兼容对C语言中struct的所有用法.同时C对struct进行了语法的升级.将struct升级成了类. // c中对于struct的改进: struct Stack {int* a;int top;int capacity; } int main() { Stack s;// 这里可以直接使用Stack进行使用,而不再需要struct关键字了return 0; }注意:…

centos 8 中安装Docker

注&#xff1a;本次样式安装使用的是centos8 操作系统。 1、镜像下载 具体的镜像下载地址各位可以去官网下载&#xff0c;选择适合你们的下载即可&#xff01; 1、CentOS官方下载地址&#xff1a;https://vault.centos.org/ 2、阿里云开源镜像站下载&#xff1a;centos安装包…