图解kd树+Python实现

news2025/4/8 8:45:28

开篇

在讲解k-近邻算法的时候,我们提供的思路是:对于新到来的样本,计算该样本与训练集中所有样本之间的距离,选取训练集中距离新样本最近的k个样本中大多数样本的类别作为新的样本的类别。

也就是说,每次都要计算新的样本与训练集中全部样本的距离。但是,在实际应用中,训练集的样本量和特征维度都是比较庞大的,这就导致该算法不得不在计算距离上花费大量的时间,那有没有什么方法可以在时间开销上对之前的k-近邻算法进行优化呢?

采用以空间来换时间的思想,就引出了今天的主角:kd树

构造kd树

kd树是一种二叉树,它可以将k维特征空间中的样本进行划分存储,以便实现快速搜索。

一头雾水?没关系,来看一个经典的构造kd树的例子。

现给定一个二维的训练集:
T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}
要求构造一个平衡kd树

复制
  • 第一步,选取第0个维度作为被划分的坐标轴,并按照第0个维度从小到大排列全部样本,得到:
  {(2,3),(4,7),(5,4),(7,2),(8,1),(9,6)}

复制
  • 第二步,找到第0个维度的中位数对应的样本。注意,这里的中位数与我们之前认知的中位数有些不同,具体表现在:对于本例,第0个维度排序后分别为:2,4,5,7,8,9。按道理中位数应该是(5+7)/2=6,但是,训练集的第0个维度中并没有6,所以,我们需要选取距离6最近的出现在训练集的第0个维度中的数字作为中位数,这里,5和7都是可以的。为了便于编程,我们就统一使用下标较大的位置的数字了: 6//2=3,所以最终选择下标为3位置的数字,即数字7作为第0个维度的中位数。

  • 第三步,以第二步中选取的中位数为基准,并作为当前划分的父节点。将从小到大排序好的样本序列进行划分:第0个维度小于基准的样本被划分到当前划分的父节点的左子树,第0个维度大于基准的样本被划分当前划分的父节点的右子树。此时得到如下的树:

  • 第四步,选取新的维度,按照公式 “新的划分维度=(上一次使用的维度+1)mod  特征总维数” ,得到新的维度为:(0+1) mod 2 = 1 。

    于是以维度1替换维度0,重复第一步到第三步:

    • 对左子树{(5,4),(2,3),(4,7)}按照特征的第一个维度从小到大排序:{(2,3),(5,4),(4,7)},确定中位数下标为3//2=1,所以数字4为中位数;将(5,4)作为当前划分的父节点,第一维度大于4的作为其左子树,第一维度小于4的作为其右子树;

    • 对右子树{(9,6),(8,1)}按照特征的第一个维度从小到大排序:{(8,1),(9,6)},确定中位数下标为2//2=1,所以数字6为中位数;将(9,6)作为当前划分的父节点,第一维度大于6的作为其左子树,第一维度小于6的作为其右子树;

此时得到的树如下:

由于此时训练集中所有子区域都已划分完毕(任一子区域中不含样本点),因此kd树就构造完成了。在上面的过程中,每分一次岔,就对应特征空间的一次划分(叶子节点的左右孩子都为空,但这里仍可以看成是一种特殊的分叉【左右分支都为空】) 最终整个特征空间被划分如下:

现在来用Python实现上述过程。首先定义每个节点的数据结构:

class Node():
    def __init__(self,lchild,rchild,value):
        self.lchild=lchild#节点的左子树
        self.rchild=rchild#节点的右子树
        self.value=value#节点的数值

复制

然后初始化一个KD树的类:

class KDTree():
    def __init__(self,data):
        self.dims=len(data[0])#训练集总特征数

复制

接下来到了构建kd树的核心步骤,从之前的例子中,可以总结出我们的思路:

创建kd树的过程是递归的,所以我们可以递归地构造之:
(1) 递归地构造左子树;
(2) 递归地构造右子树;
(3) 构造父节点,将其lchild与构造好的左子树连接,将其rchild与构造好的右子树连接。
除此之外,还有一些辅助的方法,比如求指定维度的中位数,计算下一个划分维度,将会写成单独的方法以使得创建树的代码更加具有可读性。
最后,不要忘了递归出口:被划分的子区域没有样本存在时,就退出。

    def create_kdtree(self,current_data,split_dim):
        #设置递归出口:当全部样本划分完毕时就退出
        if len(current_data)==0:
            return None
        
        mid=self.cal_current_medium(current_data)#计算中位数所在下标
        data_sorted=sorted(current_data,key=lambda x:x[split_dim])#按照切分维度从小到大排序

        #下面三句代码本质上就是二叉树的后序遍历
        lchild=self.create_kdtree(data_sorted[0:mid],self.cal_split_dim(split_dim))#递归地构造左子树
        rchild=self.create_kdtree(data_sorted[mid+1:],self.cal_split_dim(split_dim))#递归地构造右子树
        return Node(lchild,rchild,data_sorted[mid])#连接从根节点出发的左右子树,并返回
    
    #计算下一个划分维度
    def cal_split_dim(self,split_dim):
        return (split_dim+1) % self.dims
    
    #计算当前维度中位数所在下标
    def cal_current_medium(self,current_data):
        return len(current_data)//2
  

复制

完整的kd树构造代码如下:

class KDTree():
    def __init__(self,data):
        self.dims=len(data[0])#训练集总特征数
   def create_kdtree(self,current_data,split_dim):
        #设置递归出口:当全部样本划分完毕时就退出
        if len(current_data)==0:
            return None
        
        mid=self.cal_current_medium(current_data)#计算中位数所在下标
        data_sorted=sorted(current_data,key=lambda x:x[split_dim])#按照切分维度从小到大排序

        #下面三句代码本质上就是二叉树的后序遍历
        lchild=self.create_kdtree(data_sorted[0:mid],self.cal_split_dim(split_dim))#递归地构造左子树
        rchild=self.create_kdtree(data_sorted[mid+1:],self.cal_split_dim(split_dim))#递归地构造右子树
        return Node(lchild,rchild,data_sorted[mid])#连接从根节点出发的左右子树,并返回
    
    #计算下一个划分维度
    def cal_split_dim(self,split_dim):
        return (split_dim+1) % self.dims
    
    #计算当前维度中位数所在下标
    def cal_current_medium(self,current_data):
        return len(current_data)//2

复制

运行下面的代码,就构造好了一棵kd树:

dataset = np.array([[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]])#构建训练数据集
kdtree = KDTree(dataset).create_kdtree(dataset,0)#创建KD树,以特征的第0个维度开始做划分

复制

搜索kd树

这里仅实现最近邻搜索。所谓最近邻,就是k-近邻中k取1时的特殊情况。我们还是以具体的例子进行说明。基于上面构造好的kd树,现在来搜索样本点(2, 4.5)的最近邻点。先把之前的图搬过来,对照该图阅读以下步骤会更容易理解:

从根节点开始:

  1. 首先来到第一层:在构造kd树时,由于(7,2)是根据维度0进行划分的,因此需要比较(2,4.5)与(7,2)的第0个维度的大小。由于2<7,因此接下来将搜索(7,2)的左子树(也就是(5,4)节点),反映到划分图上,就是去"过点(7,2)的垂直于横轴的划分线"的左侧进行接下来的搜索;

  2. 然后来到第二层:在构造kd树时,由于(5,4)是根据维度1进行划分的,因此需要比较(2,4.5)与(5,4)的第1个维度的大小。由于4.5>4,因此接下来将搜索(5,4)的右子树(也就是(4,7)节点),反映到划分图上,就是去"过点(5,4)的垂直于纵轴的划分线"的上侧进行接下来的搜索;

  3. 接着来到第三层:由于(4,7)已经是叶子节点,无左右孩子,所以从根节点(7,2)到叶子节点的搜索就完成了,当前的最近邻节点就是最后到达的叶子节点,也就是(4,7)。

  4. 现在,开始从叶子节点(4,7)向上往根节点进行搜索(这也称之为回溯):

(1)
以(2,4.5)为中心,以(2,4.5)到当前最近邻点(4,7)的距离为半径,画一个圆(这里特征是二维的,所以是圆。一般的,对于高维特征的情况,画出来的是一个超球面),真正的最近邻点一定包含在这个圆的内部。于是当前最近邻点是(4,7),最近距离为半径长度=3.2015;

(2)
从叶子节点(4,7)返回其父节点(5,4),计算(5,4)与(2,4.5)的距离为3.0413,而3.0413<3.2015,因此当前最近邻点被更新为(5,4),最近距离被更新为3.0413;

(3)
返回计算父节点(5,4)的另一子节点(这里也就是(2,3)),计算其与目标点(2,4.5)的距离为1.5,而1.5<3.0413,因此当前最近邻点被更新为(2,3),最近距离被更新为1.5;

(4)
此时父节点(5,4)的另一子节点已经搜索完毕,继续向上回溯搜索那些没有被回溯过的节点,于是来到根节点(7,2),计算(7,2)与(2,4.5)的距离为5.5901,而5.5901>1.5,因此当前最近邻点不变,最近距离也不变。由于已经回溯到了根节点,整个搜索就完毕了,当前最近邻点就是我们最终要找的最近邻点,即(2,3)。

现在,让我们用Python程序来实现以上的搜索过程。基于构造kd树的代码,需要增加搜索的方法以及一些小的变动,具体如下:

  • 由于在前向搜索的过程中,需要知道每个节点是根据哪个维度进行划分的,因此给每个节点增加一个维度属性:split_dim
class Node():
   def __init__(self,lchild,rchild,value,split_dim):
       self.lchild=lchild#节点的左子树
       self.rchild=rchild#节点的右子树
       self.value=value#节点的数值
       self.split_dim=split_dim#用来做划分的维度

复制
  • 为了便于返回最近邻点和最近距离,将这两个属性添加到kd树的属性中:
class KDTree():
   def __init__(self,data):
       self.dims=len(data[0])#总特征数
       self.nearest_point=None
       self.nearest_distance=np.inf#初始化为无穷大

复制
  • 由于涉及到了距离的比较,因此增加计算两点之间距离的方法:
   #计算两点之间的欧氏距离
   def cal_dist(sample1,sample2):
       return np.sqrt(np.sum((sample1-sample2)**2))

复制
  • 算法将从根节点开始搜索,由于是递归的,所以这里可以先写一个辅助的递归入口函数,真正实现递归的算法写在另一个方法中:
   #element:目标节点;root:kd树的根节点
   def get_nearest(self,root,element):
       search(root,element)#递归地搜索
       return self.nearest_point,self.nearest.dist

复制
  • 现在来实现递归搜索的过程:
   def search(self,node,element):
      if node is  None:
        return
   #计算当前划分维度上目标节点与当前节点的单一维度上的距离
      dist = node.value[node.split_dim] - element[node.split_dim]
      #前向搜索
      if dist>0:#当前节点在目标节点的上侧或左侧(在二维空间中)
          self.search(node.lchild,element)#递归地搜索左子树
      else:#否则,当前节点在目标节点的下侧或右侧(在二维空间中)
          self.search(node.rchild,element)#递归地搜索右子树
      #计算目标节点与当前节点的欧氏距离
      curr_dist = self.cal_dist(node.value,element)
      #更新最近邻节点
      if curr_dist < self.nearest_dist:
          self.nearest_dist = curr_dist
          self.nearest_point = node
          #print(self.nearest_point.value)
      #回溯
      #比较“最近距离”是否超过“目标节点与当前节点在当前划分维度上的距离”,超过了就说明可能在当前节点的另一侧子树中存在更近的点,所以需要到当前节点的另一侧子树中去搜索
      if self.nearest_dist > abs(dist):
          #由于是去当前节点的另一侧子树中进行搜索,因此正好与之前的前向搜索相反
          if dist>0:
              self.search(node.rchild,element)
          else:
              self.search(node.lchild,element)

复制

完整代码如下:

import numpy as np
class Node():
    def __init__(self,lchild,rchild,value,split_dim):
        self.lchild=lchild#节点的左子树
        self.rchild=rchild#节点的右子树
        self.value=value#节点的数值
        self.split_dim=split_dim#用来做划分的维度

class KDTree():
    def __init__(self,data):
        self.dims=len(data[0])#总特征数
        self.nearest_point=None
        self.nearest_dist=np.inf#初始化为无穷大
        
    def create_kdtree(self,current_data,split_dim):
        #设置递归出口:当全部样本划分完毕时就退出
        if len(current_data)==0:
            return None
        
        mid=self.cal_current_medium(current_data)#计算中位数所在下标
        data_sorted=sorted(current_data,key=lambda x:x[split_dim])#按照切分维度从小到大排序

        #下面三句代码本质上就是二叉树的后序遍历
        lchild=self.create_kdtree(data_sorted[0:mid],self.cal_split_dim(split_dim))#递归地构造左子树
        rchild=self.create_kdtree(data_sorted[mid+1:],self.cal_split_dim(split_dim))#递归地构造右子树
        return Node(lchild,rchild,data_sorted[mid],split_dim)#连接从根节点出发的左右子树,并返回
    
    #计算下一个划分维度
    def cal_split_dim(self,split_dim):
        return (split_dim+1) % self.dims
    
    #计算当前维度中位数所在下标
    def cal_current_medium(self,current_data):
        return len(current_data)//2
    
    #计算两点之间的欧氏距离
    def cal_dist(self,sample1,sample2):
        return np.sqrt(np.sum((sample1-sample2)**2))
        
    #传入kd树的根节点root和待搜索的点element,搜索element的最近邻点
    def search(self,node,element):
        if node is  None:
            return
  #计算当前划分维度上目标节点与当前节点的单一维度上的距离
        dist = node.value[node.split_dim] - element[node.split_dim]
        #前向搜索
        if dist>0:#当前节点在目标节点的上侧或左侧(在二维空间中)
            self.search(node.lchild,element)#递归地搜索左子树
        else:#否则,当前节点在目标节点的下侧或右侧(在二维空间中)
            self.search(node.rchild,element)#递归地搜索右子树
        #计算目标节点与当前节点的欧氏距离
        curr_dist = self.cal_dist(node.value,element)
        #更新最近邻节点
        if curr_dist < self.nearest_dist:
            self.nearest_dist = curr_dist
            self.nearest_point = node
            #print(self.nearest_point.value)
        #回溯
        #比较“最近距离”是否超过“目标节点与当前节点在当前划分维度上的距离”,超过了就说明可能在当前节点的另一侧子树中存在更近的点,所以需要到当前节点的另一侧子树中去搜索
        if self.nearest_dist > abs(dist):
            #由于是去当前节点的另一侧子树中进行搜索,因此正好与之前的前向搜索相反
            if dist>0:
                self.search(node.rchild,element)
            else:
                self.search(node.lchild,element)
    
     def get_nearest(self,root,element):
        self.search(root,element)
        return self.nearest_point.value,self.nearest_dist

复制

现在来测试一下:

dataset = np.array([[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]])#构建训练数据集
kdtree = KDTree(dataset)#实例化一个kd树对象
root=kdtree.create_kdtree(dataset,0)#创建KD树,且以特征的第0个维度开始做划分,最终返回的是根节点
nearest_point,nearest_dist=kdtree.get_nearest(root,[2,4.5])#搜索[2,4.5]的最近邻点
print('最近邻点:{}\n最近距离:{}'.format(nearest_point,nearest_dist))

复制

运行结果:

最近邻点:[2 3]
最近距离:1.5

复制

这和之前我们推导的结果是一致的。

最后,感谢互联网上的优秀资源,给本文提供了许多参考。

参考资料:

link

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1143137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c语言基础:L1-060 心理阴影面积

这是一幅心理阴影面积图。我们都以为自己可以匀速前进&#xff08;图中蓝色直线&#xff09;&#xff0c;而拖延症晚期的我们往往执行的是最后时刻的疯狂赶工&#xff08;图中的红色折线&#xff09;。由红、蓝线围出的面积&#xff0c;就是我们在做作业时的心理阴影面积。 现给…

4.6 其他安全性保护

思维导图&#xff1a; 4.6 其他安全性保护 1. 推理控制 (Inference Control) 定义&#xff1a;处理强制存取控制未解决的问题&#xff0c;如利用列的函数依赖关系&#xff0c;从低安全等级信息推导出高安全等级信息。示例&#xff1a;在公司信息系统中&#xff0c;姓名和职务为…

SQL查询优化---如何查询截取分析

慢查询日志 1、慢查询日志是什么 MySQL的慢查询日志是MySQL提供的一种日志记录&#xff0c;它用来记录在MySQL中响应时间超过阀值的语句&#xff0c;具体指运行时间超过long_query_time值的SQL&#xff0c;则会被记录到慢查询日志中。 具体指运行时间超过long_query_time值的…

use renv with this project create a git repository

目录 1-create a git repository 2-Use renv with this project 今天在使用Rstudio过程中&#xff0c;发现有下面两个新选项&#xff08;1&#xff09;create a git repository (2) Use renv with this project. 选中这两个选项后&#xff0c;创建新项目&#xff0c;在项目目…

Redis(01)| 数据结构

这里写自定义目录标题 Redis 速度快的原因除了它是内存数据库&#xff0c;使得所有的操作都在内存上进行之外&#xff0c;还有一个重要因素&#xff0c;它实现的数据结构&#xff0c;使得我们对数据进行增删查改操作时&#xff0c;Redis 能高效的处理。 因此&#xff0c;这次我…

此页面不能正确地重定向

这种是由于条件判断有误&#xff0c;程序不断的重定向到一个页面&#xff0c;而造成的死循环的情况 下面列举一个常出现的场景之一 1、使用过滤器实现登录验证错误处理 解释&#xff1a;当用户访问login.jsp进行登录的时候&#xff0c;这个时候请求会被Filter捕获&#xff0…

【Java基础(高级篇)】集合源码剖析

集合源码剖析 文章目录 集合源码剖析1. List接口分析1.1 ArrayList1.2 LinkedList 2. Map接口分析2.1 哈希表的物理结构2.2 HashMap中数据添加过程2.2.1 JDK7中过程分析2.2.2 JDK8中过程分析 2.3 红黑树2.4 HashMap源码剖析(JDK1.8.0_271)2.4.1 Node2.4.2 属性2.4.3 构造器2.4.…

基础课11——数据来源

随着科技的进步和数字化转型的加速&#xff0c;全球数据量正以惊人的速度增长。根据IDC的最新报告&#xff0c;2020年全球数据总量已经达到了约53 ZB&#xff08;Zettabyte&#xff0c;万亿亿GB&#xff09;&#xff0c;而这个数字在2025年预计会达到175 ZB。这种指数级增长不仅…

MAC下安装Python

MAC基本信息&#xff1a; 执行命令&#xff1a; brew install cmake protobuf rust python3.10 git wget 遇到以下问题&#xff1a; > Downloading https://mirrors.aliyun.com/homebrew/homebrew-bottles/rust-1.59.0 Already downloaded: /Users/xxxx/Library/Caches/Ho…

售后处置跟踪系统设想

售后处置跟踪系统设想 前言 随着汽车工业的发展&#xff0c;软件定义车的模式已成为主流汽车设计及智能化功能架构模式&#xff0c;通过引入SOA的软件架构设计&#xff0c;使得现有的座舱软件、云端服务软件、App软件等众多功能模块的版本迭代频次日新月异&#xff0c;发版更…

【ubuntu】 Linux(ubuntu)创建python的虚拟环境

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

语雀故障事件——P0级别事故启示录 发生肾么事了? 怎么回事?

前言 最近&#xff0c;阿里系的语雀出了一个大瓜&#xff0c;知名在线文档编辑与协同工具语雀发生故障&#xff0c;崩溃近10小时。。。。最后&#xff0c;官方发布了一则公告&#xff0c;我们一起来看看这篇公告&#xff0c;能不能有所启发。 目录 前言引出一、语雀P0故障回顾…

设计模式(19)命令模式

一、介绍&#xff1a; 1、定义&#xff1a;命令模式&#xff08;Command Pattern&#xff09;是一种行为设计模式&#xff0c;它将请求封装为一个对象&#xff0c;从而使你可以使用不同的请求对客户端进行参数化。命令模式还支持请求的排队、记录日志、撤销操作等功能。 2、组…

4+非肿瘤纯生信。氧化应激+WGCNA+药物预测筛序关键基因

今天给同学们分享一篇非肿瘤氧化应激WGCNA的生信文章“Identification of oxidative stress-related biomarkers associated with the development of acute-on-chronic liver failure using bioinformatics”&#xff0c;这篇文章于2023年10月10日发表在Scientific Reports期刊…

双十一什么东西一定要买?实用性强好物千万不能错过

一年一度的双十一购物节即将来临啦&#xff01;相信很多朋友都在等这个时间选购一些实用性比较强好物&#xff0c;平时太贵的一些家电都舍不得买&#xff0c;就是为了等到双十一这一些&#xff0c;准备买买买的朋友们&#xff0c;别着急&#xff0c;作为智能家电好物分享家的我…

轻量封装WebGPU渲染系统示例<7>-材质多pass(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/version-1.01/src/voxgpu/sample/MultiMaterialPass.ts 此示例渲染系统实现的特性: 1. 用户态与系统态隔离。 2. 高频调用与低频调用隔离。 3. 面向用户的易用性封装。 4. 渲染数据和渲染机制分离。 …

dolphinscheduler3.2.0 install报错

下载3.2.0版本代码&#xff0c;执行install报错&#xff0c;dolphinscheduler-common无法加载依赖 [ERROR] Failed to execute goal com.diffplug.spotless:spotless-maven-plugin:2.27.2:check (default) on project dolphinscheduler-common: The following files had format…

ue5 右击.uproject generator vs project file 错误

出现如下错误 Unable to find valid 14.31.31103 C toolchain for VisualStudio2022 x64 就算你升级了你的 vs installer 也不好使 那是因为 在C:\Users\{YourUserName}\AppData\Roaming\Unreal Engine\UnrealBuildTool\BuildConfiguration.xml 这个缓存配置文件中写死了 14…

Echarts渲染不报错但是没有内容

&#x1f525;博客主页&#xff1a; 破浪前进 &#x1f516;系列专栏&#xff1a; Vue、React、PHP ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 问题&#xff1a;在开发项目的时候使用了Echarts但是好端端的忽然就不渲染了 感觉很无语啊&#xff0c;毕竟好好的就不渲染了&am…