import math
import matplotlib.pyplot as plt
class Node:
def __init__(self, data, left=None, right=None):
self.data = data
self.left = left
self.right = right
# 创建KDTree类
class KDTree:
def __init__(self, k):
self.k = k
def create_tree(self,dataset,depth):
if not dataset:
return None
mid_index=len(dataset)//2 # 中位数
axis = depth%self.k # 按照哪个坐标轴划分
sorted_dataset = sorted(dataset,key=(lambda x : x[axis])) # 按照坐标轴划分
mid_data = sorted_dataset[mid_index]#中位数数据值
current_node = Node(mid_data) # 创建当前节点
left_data = sorted_dataset[:mid_index] # 划分左节点数据
right_data = sorted_dataset[mid_index+1:] # 划分右节点数据
current_node.left = self.create_tree(left_data,depth+1) # 创建左子树
current_node.right = self.create_tree(right_data,depth+1) # 创建右子树
return current_node
def search(self, tree, new_data):
self.nearest_point = None # 当前最邻近点
self.nearest_val = None # 当前最邻近点与目标节点间距离
def dfs(node,depth): # 深度优先搜索
# 递归找叶子节点
if not node:
return None
axis = depth % self.k
if new_data[axis] < node.data[axis]:
dfs(node.left, depth+1)
else:
dfs(node.right, depth+1)
# 比较距离,判断是否更新最近邻点
dist = self.distance(new_data,node.data)
if not self.nearest_val or dist<self.nearest_val:
self.nearest_val = dist
self.nearest_point = node.data
# 判断是否遍历该节点另一边子树
if abs(new_data[axis]-node.data[axis]) <= self.nearest_val: # 计算父节点在其分割特征上的data距离目标点在该特征上的data的距离。若该距离小于 nearest_val,则进入另一个孩子节点,否则不进入
if new_data[axis] < node.data[axis]: # 之前若先遍历左子树,现在就要遍历右子树
dfs(node.right, depth+1)
else:
dfs(node.left, depth+1)
dfs(tree, 0)
return self.nearest_point
def distance(self,new_data, new_val):
res = 0
for i in range(self.k):
res += (new_data[i]-new_val[i])**2
return math.sqrt(res)
if __name__ == '__main__':
data_set = [[3,3],[5,4],[5,6],[2,7],[9,1],[2,5],[3,2],[2,0]
new_data = [2,9]
k = len(data_set[0])
kd_tree = KDTree(k)
our_tree = kd_tree.create_tree(data_set,0)
predict = kd_tree.search(our_tree,new_data)
print(f"Nearest Point of {new_data} is {predict}")
plt.scatter([x[0] for x in data_set],[x[1] for x in data_set],c='purple',label='train_data')
plt.scatter(new_data[0],new_data[1],c='red',label='target_data')
plt.plot([predict[0], new_data[0]], [predict[1],new_data[1]], c='green',label='Nearest Point',linestyle='--')
plt.legend()
plt.show()
Node
类用于表示KD树的节点。data
保存当前节点的数据点。left
和right
分别指向左子树和右子树。KDTree
类用于创建和操作KD树。k
表示数据点的维度。
create_tree
方法用于递归地创建KD树。dataset
是要构建树的数据集。depth
表示当前节点的深度,用于确定划分的轴。- 根据深度计算轴并排序数据集,选择中位数作为当前节点的数据点。
- 递归地创建左子树和右子树。
search
方法用于在KD树中查找离new_data
最近的点。self.nearest_point
和self.nearest_val
用于保存当前找到的最近点及其距离。- 定义深度优先搜索
dfs
函数,递归地搜索树,更新最近点和距离。 - 检查是否需要遍历另一边的子树。
- 主程序创建数据集
data_set
和要查找的点new_data
。 - 初始化
KDTree
实例并创建KD树。 - 使用
search
方法查找最近点并打印结果。 - 使用
matplotlib
绘制数据点和最近邻点的连线。
参考文献Kd Tree算法详解_kd-tree-CSDN博客
Python手撸机器学习系列(十一):KNN之kd树实现_knn原理及python代码实现建立kd树-CSDN博客