机器学习:回归决策树(Python)

news2025/1/19 2:41:06

一、平方误差的计算

square_error_utils.py

import numpy as np


class SquareErrorUtils:
    """
    平方误差最小化准则,选择其中最优的一个作为切分点
    对特征属性进行分箱处理
    """
    @staticmethod
    def _set_sample_weight(sample_weight, n_samples):
        """
        扩展到集成学习,此处为样本权重的设置
        :param sample_weight: 各样本的权重
        :param n_samples: 样本量
        :return:
        """
        if sample_weight is None:
            sample_weight = np.asarray([1.0] * n_samples)
        return sample_weight

    @staticmethod
    def square_error(y, sample_weight):
        """
        平方误差
        :param y: 当前划分区域的目标值集合
        :param sample_weight: 当前样本的权重
        :return:
        """
        y = np.asarray(y)
        return np.sum((y - y.mean()) ** 2 * sample_weight)

    def cond_square_error(self, x, y, sample_weight):
        """
        计算根据某个特征x划分的区域中y的误差值
        :param x: 某个特征划分区域所包含的样本
        :param y: x对应的目标值
        :param sample_weight: 当前x的权重
        :return:
        """
        x, y = np.asarray(x), np.asarray(y)
        error = 0.0
        for x_val in set(x):
            x_idx = np.where(x == x_val)  # 按区域计算误差
            new_y = y[x_idx]  # 对应区域的目标值
            new_sample_weight = sample_weight[x_idx]
            error += self.square_error(new_y, new_sample_weight)
        return error

    def square_error_gain(self, x, y, sample_weight=None):
        """
        平方误差带来的增益值
        :param x: 某个特征变量
        :param y: 对应的目标值
        :param sample_weight: 样本权重
        :return:
        """
        sample_weight = self._set_sample_weight(sample_weight, len(x))
        return self.square_error(y, sample_weight) - self.cond_square_error(x, y, sample_weight)
    

 二、树的结点信息封装


class TreeNode_R:
    """
    决策树回归算法,树的结点信息封装,实体类:setXXX()、getXXX()
    """
    def __init__(self, feature_idx: int = None, feature_val=None, y_hat=None, square_error: float = None,
                 criterion_val=None, n_samples: int = None, left_child_Node=None, right_child_Node=None):
        """
        决策树结点信息封装
        :param feature_idx: 特征索引,如果指定特征属性的名称,可以按照索引取值
        :param feature_val: 特征取值
        :param square_error: 划分结点的标准:当前结点的平方误差
        :param n_samples: 当前结点所包含的样本量
        :param y_hat: 当前结点的预测值:Ci
        :param left_child_Node: 左子树
        :param right_child_Node: 右子树
        """
        self.feature_idx = feature_idx
        self.feature_val = feature_val
        self.criterion_val = criterion_val
        self.square_error = square_error
        self.n_samples = n_samples
        self.y_hat = y_hat
        self.left_child_Node = left_child_Node  # 递归
        self.right_child_Node = right_child_Node  # 递归

    def level_order(self):
        """
        按层次遍历树...
        :return:
        """
        pass

    # def get_feature_idx(self):
    #     return self.get_feature_idx()
    #
    # def set_feature_idx(self, feature_idx):
    #     self.feature_idx = feature_idx


三、回归决策树CART算法实现

import numpy as np
from utils.square_error_utils import SquareErrorUtils
from utils.tree_node_R import TreeNode_R
from utils.data_bin_wrapper import DataBinsWrapper


class DecisionTreeRegression:
    """
    回归决策树CART算法实现:按照二叉树构造
    1. 划分标准:平方误差最小化
    2. 创建决策树fit(),递归算法实现,注意出口条件
    3. 预测predict_proba()、predict() --> 对树的搜索
    4. 数据的预处理操作,尤其是连续数据的离散化,分箱
    5. 剪枝处理
    """
    def __init__(self, criterion="mse", max_depth=None, min_sample_split=2, min_sample_leaf=1,
                 min_target_std=1e-3, min_impurity_decrease=0, max_bins=10):
        self.utils = SquareErrorUtils()  # 结点划分类
        self.criterion = criterion  # 结点的划分标准
        if criterion.lower() == "mse":
            self.criterion_func = self.utils.square_error_gain  # 平方误差增益
        else:
            raise ValueError("参数criterion仅限mse...")
        self.min_target_std = min_target_std  # 最小的样本目标值方差,小于阈值不划分
        self.max_depth = max_depth  # 树的最大深度,不传参,则一直划分下去
        self.min_sample_split = min_sample_split  # 最小的划分结点的样本量,小于则不划分
        self.min_sample_leaf = min_sample_leaf  # 叶子结点所包含的最小样本量,剩余的样本小于这个值,标记叶子结点
        self.min_impurity_decrease = min_impurity_decrease  # 最小结点不纯度减少值,小于这个值,不足以划分
        self.max_bins = max_bins  # 连续数据的分箱数,越大,则划分越细
        self.root_node: TreeNode_R() = None  # 回归决策树的根节点
        self.dbw = DataBinsWrapper(max_bins=max_bins)  # 连续数据离散化对象
        self.dbw_XrangeMap = {}  # 存储训练样本连续特征分箱的端点

    def fit(self, x_train, y_train, sample_weight=None):
        """
        回归决策树的创建,递归操作前的必要信息处理(分箱)
        :param x_train: 训练样本:ndarray,n * k
        :param y_train: 目标集:ndarray,(n, )
        :param sample_weight: 各样本的权重,(n, )
        :return:
        """
        x_train, y_train = np.asarray(x_train), np.asarray(y_train)
        self.class_values = np.unique(y_train)  # 样本的类别取值
        n_samples, n_features = x_train.shape  # 训练样本的样本量和特征属性数目
        if sample_weight is None:
            sample_weight = np.asarray([1.0] * n_samples)
        self.root_node = TreeNode_R()  # 创建一个空树
        self.dbw.fit(x_train)
        x_train = self.dbw.transform(x_train)
        self._build_tree(1, self.root_node, x_train, y_train, sample_weight)

    def _build_tree(self, cur_depth, cur_node: TreeNode_R, x_train, y_train, sample_weight):
        """
        递归创建回归决策树算法,核心算法。按先序(中序、后序)创建的
        :param cur_depth: 递归划分后的树的深度
        :param cur_node: 递归划分后的当前根结点
        :param x_train: 递归划分后的训练样本
        :param y_train: 递归划分后的目标集合
        :param sample_weight: 递归划分后的各样本权重
        :return:
        """
        n_samples, n_features = x_train.shape  # 当前样本子集中的样本量和特征属性数目
        # 计算当前数结点的预测值,即加权平均值,
        cur_node.y_hat = np.dot(sample_weight / np.sum(sample_weight), y_train)
        cur_node.n_samples = n_samples

        # 递归出口判断
        cur_node.square_error = ((y_train - y_train.mean()) ** 2).sum()
        # 所有的样本目标值较为集中,样本方差非常小,不足以划分
        if cur_node.square_error <= self.min_target_std:
            # 如果为0,则表示当前样本集合为空,递归出口3
            return
        if n_samples < self.min_sample_split:  # 当前结点所包含的样本量不足以划分
            return
        if self.max_depth is not None and cur_depth > self.max_depth:  # 树的深度达到最大深度
            return

        # 划分标准,选择最佳的划分特征及其取值
        best_idx, best_val, best_criterion_val = None, None, 0.0
        for k in range(n_features):  # 对当前样本集合中每个特征计算划分标准
            for f_val in sorted(np.unique(x_train[:, k])):  # 当前特征的不同取值
                region_x = (x_train[:, k] <= f_val).astype(int)  # 是当前取值f_val就是1,否则就是0
                criterion_val = self.criterion_func(region_x, y_train, sample_weight)
                if criterion_val > best_criterion_val:
                    best_criterion_val = criterion_val  # 最佳的划分标准值
                    best_idx, best_val = k, f_val  # 当前最佳特征索引以及取值

        # 递归出口的判断
        if best_idx is None:  # 当前属性为空,或者所有样本在所有属性上取值相同,无法划分
            return
        if best_criterion_val <= self.min_impurity_decrease:  # 小于最小不纯度阈值,不划分
            return
        cur_node.criterion_val = best_criterion_val
        cur_node.feature_idx = best_idx
        cur_node.feature_val = best_val

        # print("当前划分的特征索引:", best_idx, "取值:", best_val, "最佳标准值:", best_criterion_val)
        # print("当前结点的类别分布:", target_dist)

        # 创建左子树,并递归创建以当前结点为子树根节点的左子树
        left_idx = np.where(x_train[:, best_idx] <= best_val)  # 左子树所包含的样本子集索引
        if len(left_idx) >= self.min_sample_leaf:  # 小于叶子结点所包含的最少样本量,则标记为叶子结点
            left_child_node = TreeNode_R()  # 创建左子树空结点
            # 以当前结点为子树根结点,递归创建
            cur_node.left_child_Node = left_child_node
            self._build_tree(cur_depth + 1, left_child_node, x_train[left_idx],
                             y_train[left_idx], sample_weight[left_idx])

        right_idx = np.where(x_train[:, best_idx] > best_val)  # 右子树所包含的样本子集索引
        if len(right_idx) >= self.min_sample_leaf:  # 小于叶子结点所包含的最少样本量,则标记为叶子结点
            right_child_node = TreeNode_R()  # 创建右子树空结点
            # 以当前结点为子树根结点,递归创建
            cur_node.right_child_Node = right_child_node
            self._build_tree(cur_depth + 1, right_child_node, x_train[right_idx],
                             y_train[right_idx], sample_weight[right_idx])

    def _search_tree_predict(self, cur_node: TreeNode_R, x_test):
        """
        根据测试样本从根结点到叶子结点搜索路径,判定所属区域(叶子结点)
        搜索:按照后续遍历
        :param x_test: 单个测试样本
        :return:
        """
        if cur_node.left_child_Node and x_test[cur_node.feature_idx] <= cur_node.feature_val:
            return self._search_tree_predict(cur_node.left_child_Node, x_test)
        elif cur_node.right_child_Node and x_test[cur_node.feature_idx] > cur_node.feature_val:
            return self._search_tree_predict(cur_node.right_child_Node, x_test)
        else:
            # 叶子结点,类别,包含有类别分布
            return cur_node.y_hat

    def predict(self, x_test):
        """
        预测测试样本x_test的预测值
        :param x_test: 测试样本ndarray、numpy数值运算
        :return:
        """
        x_test = np.asarray(x_test)  # 避免传递DataFrame、list...
        if self.dbw.XrangeMap is None:
            raise ValueError("请先进行回归决策树的创建,然后预测...")
        x_test = self.dbw.transform(x_test)
        y_test_pred = []  # 用于存储测试样本的预测值
        for i in range(x_test.shape[0]):
            y_test_pred.append(self._search_tree_predict(self.root_node, x_test[i]))
        return np.asarray(y_test_pred)

    @staticmethod
    def cal_mse_r2(y_test, y_pred):
        """
        模型预测的均方误差MSE和判决系数R2
        :param y_test: 测试样本的真值
        :param y_pred: 测试样本的预测值
        :return:
        """
        y_test, y_pred = y_test.reshape(-1), y_pred.reshape(-1)
        mse = ((y_pred - y_test) ** 2).mean()  # 均方误差
        r2 = 1 - ((y_pred - y_test) ** 2).sum() / ((y_test - y_test.mean()) ** 2).sum()
        return mse, r2

    def _prune_node(self, cur_node: TreeNode_R, alpha):
        """
        递归剪枝,针对决策树中的内部结点,自底向上,逐个考察
        方法:后序遍历
        :param cur_node: 当前递归的决策树的内部结点
        :param alpha: 剪枝阈值
        :return:
        """
        # 若左子树存在,递归左子树进行剪枝
        if cur_node.left_child_Node:
            self._prune_node(cur_node.left_child_Node, alpha)
        # 若右子树存在,递归右子树进行剪枝
        if cur_node.right_child_Node:
            self._prune_node(cur_node.right_child_Node, alpha)

        # 针对决策树的内部结点剪枝,非叶结点
        if cur_node.left_child_Node is not None or cur_node.right_child_Node is not None:
            for child_node in [cur_node.left_child_Node, cur_node.right_child_Node]:
                if child_node is None:
                    # 可能存在左右子树之一为空的情况,当左右子树划分的样本子集数小于min_samples_leaf
                    continue
                if child_node.left_child_Node is not None or child_node.right_child_Node is not None:
                    return
            # 计算剪枝前的损失值(平方误差),2表示当前结点包含两个叶子结点
            pre_prune_value = 2 * alpha
            if cur_node and cur_node.left_child_Node is not None:
                pre_prune_value += (0.0 if cur_node.left_child_Node.square_error is None
                                    else cur_node.left_child_Node.square_error)
            if cur_node and cur_node.right_child_Node is not None:
                pre_prune_value += (0.0 if cur_node.right_child_Node.square_error is None
                                    else cur_node.right_child_Node.square_error)

            # 计算剪枝后的损失值,当前结点即是叶子结点
            after_prune_value = alpha + cur_node.square_error

            if after_prune_value <= pre_prune_value:  # 进行剪枝操作
                cur_node.left_child_Node = None
                cur_node.right_child_Node = None
                cur_node.feature_idx, cur_node.feature_val = None, None
                cur_node.square_error = None

    def prune(self, alpha=0.01):
        """
        决策树后剪枝算法(李航)C(T) + alpha * |T|
        :param alpha: 剪枝阈值,权衡模型对训练数据的拟合程度与模型的复杂度
        :return:
        """
        self._prune_node(self.root_node, alpha)
        return self.root_node




 四、回归决策树算法的测试

test_decision_tree_R.py

import numpy as np
import matplotlib.pyplot as plt
from decision_tree_R import DecisionTreeRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor


obj_fun = lambda x: np.sin(x)
np.random.seed(0)
n = 100
x = np.linspace(0, 10, n)
target = obj_fun(x) + 0.3 * np.random.randn(n)
data = x[:, np.newaxis]  # 二维数组

tree = DecisionTreeRegression(max_bins=50, max_depth=10)
tree.fit(data, target)
x_test = np.linspace(0, 10, 200)
y_test_pred = tree.predict(x_test[:, np.newaxis])
mse, r2 = tree.cal_mse_r2(obj_fun(x_test), y_test_pred)


plt.figure(figsize=(14, 5))
plt.subplot(121)
plt.scatter(data, target, s=15, c="k", label="Raw Data")
plt.plot(x_test, y_test_pred, "r-", lw=1.5, label="Fit Model")
plt.xlabel("x", fontdict={"fontsize": 12, "color": "b"})
plt.ylabel("y", fontdict={"fontsize": 12, "color": "b"})
plt.grid(ls=":")
plt.legend(frameon=False)
plt.title("Regression Decision Tree(UnPrune) and MSE = %.5f R2 = %.5f" % (mse, r2))

plt.subplot(122)
tree.prune(0.5)
y_test_pred = tree.predict(x_test[:, np.newaxis])
mse, r2 = tree.cal_mse_r2(obj_fun(x_test), y_test_pred)
plt.scatter(data, target, s=15, c="k", label="Raw Data")
plt.plot(x_test, y_test_pred, "r-", lw=1.5, label="Fit Model")
plt.xlabel("x", fontdict={"fontsize": 12, "color": "b"})
plt.ylabel("y", fontdict={"fontsize": 12, "color": "b"})
plt.grid(ls=":")
plt.legend(frameon=False)
plt.title("Regression Decision Tree(Prune) and MSE = %.5f R2 = %.5f" % (mse, r2))


plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1441884.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flask基础学习

1.debug、host、port 模式修改 1) debug模式 默认debug模式是off&#xff0c;在修改代码调试过程中需要暂停重启使用&#xff0c;这时可修改on模式解决。 同时在debug模式开启下可看到出错信息。 下面有关于Pycharm社区版和专业版修改debug模式的区别 专业版 社区版&#…

3.1 Verilog 连续赋值

关键词&#xff1a;assign&#xff0c; 全加器 连续赋值语句是 Verilog 数据流建模的基本语句&#xff0c;用于对 wire 型变量进行赋值。&#xff1a; 格式如下 assign LHS_target RHS_expression &#xff1b; LHS&#xff08;left hand side&#xff09; 指赋值操作…

Go语言每日一题——链表篇(七)

传送门 牛客面试笔试必刷101题 ----------------删除链表的倒数第n个节点 题目以及解析 题目 解题代码及解析 解析 这一道题与昨天的题目在解题思路上有一定的相似之处&#xff0c;都是基于双指针定义快慢指针&#xff0c;这里我们让快指针先走n步&#xff0c;又因为n一定…

[Angular 基础] - 自定义事件 自定义属性

[Angular 基础] - 自定义事件 & 自定义属性 之前的笔记&#xff1a; [Angular 基础] - Angular 渲染过程 & 组件的创建 [Angular 基础] - 数据绑定(databinding) [Angular 基础] - 指令(directives) 以上是能够实现渲染静态页面的基础 之前的内容主要学习了怎么通过…

C语言-3

定义指针 /*指针的概念:1.为了方便访问内存中的内容&#xff0c;给每一个内存单元&#xff0c;进行编号&#xff0c;那么我们称这个编号为地址&#xff0c;也就是指针。2.指针也是一种数据类型&#xff0c;指针变量有自己的内存&#xff0c;里面存储的是地址&#xff0c;也就是…

VMware17上安装centos7.9

一、下载安装包&#xff1a; 1、VMware安装 VMware 下载地址&#xff1a; https://www.vmware.com/cn/products/workstation-pro.html VMware下载后安装即可 安装教程可以参考VMware安装教程 2、CentOs7.9下载地址&#xff1a; http://mirrors.aliyun.com/centos/7.9.2009/iso…

HTML+CSS:全景轮播

效果演示 实现了一个简单的网页布局&#xff0c;其中包含了五个不同的盒子&#xff0c;每个盒子都有一个不同的背景图片&#xff0c;并且它们之间有一些间距。当鼠标悬停在某个盒子上时&#xff0c;它的背景图片会变暗&#xff0c;并且文字会变成白色。这些盒子和按钮都被放在一…

如何用Hexo搭建一个优雅的博客

引言 在数字化时代&#xff0c;拥有一个个人博客已经成为许多人展示自己技能、分享知识和与世界互动的重要方式。而在众多博客平台中&#xff0c;Hexo因其简洁、高效和易于定制的特点而备受青睐。本文将详细介绍如何从零开始搭建一个Hexo博客&#xff0c;让你的个人博客在互联…

Git版本与分支

目录 一、Git 二、配置SSH 1.什么是SSH Key 2.配置SSH Key 三、分支 1.为什么要使用分支 2.四个环境及特点 3.实践操作 1.创建分支 2.查看分支 3.切换分支 4.合并分支 5.删除分支 6.重命名分支 7.推送远程分支 8.拉取远程分支 9.克隆指定分支 四、版本 1.什…

Linux操作系统基础(三):虚拟机与Linux系统安装

文章目录 虚拟机与Linux系统安装 一、系统的安装方式 二、虚拟机概念 三、虚拟机的安装 四、Linux系统安装 1、解压人工智能虚拟机 2、找到解压目录中的node1.vmx 3、启动操作系统 虚拟机与Linux系统安装 一、系统的安装方式 Linux操作系统也有两种安装方式&#xf…

蓝桥杯每日一题------背包问题(一)

背包问题 阅读小提示&#xff1a;这篇文章稍微有点长&#xff0c;希望可以对背包问题进行系统详细的讲解&#xff0c;在看的过程中如果有任何疑问请在评论区里指出。因为篇幅过长也可以进行选择性阅读&#xff0c;读取自己想要的那一部分即可。 前言 背包问题可以看作动态规…

js手写Promise(上)

目录 构造函数resolve与reject状态改变状态改变后就无法再次改变 代码优化回调函数中抛出错误 thenonFulfilled和onRejected的调用时机异步then多个then 如果是不知道或者对Promise不熟悉的铁铁可以先看我这篇文章 Promise 构造函数 在最开始&#xff0c;我们先不去考虑Promi…

精简还是全能?如何在 Full 和 Lite 之间做出最佳选择!关于Configuration注解的Full模式与Lite模式(SpringBoot2)

&#x1f3c3;‍♂️ 微信公众号: 朕在debugger© 版权: 本文由【朕在debugger】原创、需要转载请联系博主&#x1f4d5; 如果文章对您有所帮助&#xff0c;欢迎关注、点赞、转发和订阅专栏&#xff01; 前言 关于 Configuration 注解&#xff0c;相信在座的各位 Javaer 都…

可达鸭二月月赛——基础赛第六场(周五)题解,这次四个题的题解都在这一篇文章内,满满干货,含有位运算的详细用法介绍。

姓名 王胤皓 T1 题解 T1 题面 T1 思路 样例输入就是骗人的&#xff0c;其实直接输出就可以了&#xff0c;输出 Hello 2024&#xff0c;注意&#xff0c;中间有一个空格&#xff01; T1 代码 #include<bits/stdc.h> using namespace std; #define ll long long int …

Swift 使用 Combine 管道和线程进行开发 从入门到精通八

Combine 系列 Swift Combine 从入门到精通一Swift Combine 发布者订阅者操作者 从入门到精通二Swift Combine 管道 从入门到精通三Swift Combine 发布者publisher的生命周期 从入门到精通四Swift Combine 操作符operations和Subjects发布者的生命周期 从入门到精通五Swift Com…

ANSI Escape Sequence 下落的方块

ANSI Escape Sequence 下落的方块 1. ANSI Escape 的用途 无意中发现 B站有人讲解&#xff0c; 完全基于终端实现俄罗斯方块。 基本想法是借助于 ANSI Escape Sequence 实现方方块的绘制、 下落动态效果等。对于只了解 ansi escape sequence 用于 log 的颜色打印的人来说&…

(每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理第10章 项目进度管理(四)

博主2023年11月通过了信息系统项目管理的考试&#xff0c;考试过程中发现考试的内容全部是教材中的内容&#xff0c;非常符合我学习的思路&#xff0c;因此博主想通过该平台把自己学习过程中的经验和教材博主认为重要的知识点分享给大家&#xff0c;希望更多的人能够通过考试&a…

【Java EE】----SpringBoot的日志文件

1.SpringBoot使用日志 先得到日志对象通过日志对象提供的方法进行打印 2.打印日志的信息 3.日志级别 作用&#xff1a; 可以筛选出重要的信息不同环境实现不同日志级别的需求 ⽇志的级别分为&#xff1a;&#xff08;1-6级别从低到高&#xff09; trace&#xff1a;微量&#…

高级数据结构与算法 | 布谷鸟过滤器(Cuckoo Filter):原理、实现、LSM Tree 优化

文章目录 Cuckoo Filter基本介绍布隆过滤器局限变体 布谷鸟哈希布谷鸟过滤器 实现数据结构优化项Victim Cache备用位置计算半排序桶 插入查找删除 应用场景&#xff1a;LSM 优化 Cuckoo Filter 基本介绍 如果对布隆过滤器不太了解&#xff0c;可以看看往期博客&#xff1a;海量…

CentOS 7安装Nodejs

说明&#xff1a;本文介绍如何在云服务器上CentOS 7操作系统上安装Nodejs。以及安装过程中遇到的问题。 下载压缩包&解压 首先&#xff0c;先去官网下载Linux版本的Node。 将下载下来的压缩包&#xff0c;上传到云服务器上&#xff0c;解压。配置环境变量。 &#xff08…