一 问题描述
在 K 维空间中有很多点,给定一个点,找出最近的 M 个点。点 p 和点 q 之间的距离是连接它们的直线段的长度。
二 输入和输出
1 输入
有多个测试用例。第 1 行包含两个非负整数 n 和 k ,分别表示点数和维数,1≤n≤5×10^4 ,1≤k≤5。下面的 n 行,每行都包含 k 个整数,表示一个点的坐标。接下来的一行包含一个正整数 t,表示查询数,1≤t≤10^4 。再接下来的每个查询都包含两行,在第 1 行中输入的 k 个整数表示给定的点;第 2 行包含一个整数 m ,表示应该找到的最近点的数量,1≤m≤10。所有坐标的绝对值都不超过10^4。
2 输出
对每个查询都输出 m +1 行:第 1 行输出“the closest m points are:”,其中 m 是点的数量;接下来输出的 m 行代表 m 个点,从近到远排列。输入的数据保证答案唯一,从给定点到所有最近 m+1 点的距离都不同。
三 输入和输出样例
1 输入样例
3 2
1 1
1 3
3 4
2
2 3
2
2 3
1
2 输出样例
the closest 2 points are:
1 3
3 4
the closest 1 points are:
1 3
四 算法分析和设计
1 分析
(1)根据输入的数据创建 KD 树。
(2)在 KD 树中查询距离给定点 p 最近的 m 个点。
2 设计
查询距离p 最近的m 个点,算法步骤如下。
(1)创建一个序对,第 1 个元素记录当前节点到 p 的距离,第 2 个元素记录当前节点;然后创建一个优先队列,存储距离 p 最近的序对,优先队列按距离最大值优先。
(2)查询时从树根开始,首先计算树根与 p 的距离,用 tmp.first 记录。
(3)若 p.x[dim]<kd[rt].x[dim],则首先在左子树 lc 中查询 , 否则在右子树 rc 中查询。在程序中,可判断若 p.x[dim]≥kd[rt].x[dim],则交换 lc 和 rc,这样就可以统一为首先在 lc 中查询。
(4)若 lc 不空,则在 lc 中递归查询 query(lc,m,dep+1,p)。
(5)若队列中的元素个数小于 m ,则直接将 tmp入队,flag=1,还需要在右子树中查询;否则若 tmp 到 p 的距离小于堆顶到 p 的距离,则堆顶出队,tmp 入队。若以 p 为球心且 p 到队列中最远点的距离为半径的超球体与划分点的另一区域有交集(d <r),则 flag=1,还需要在右子树中查询。
(6)若 rc 不空且 flag=1,则在 rc 中递归查询 query(rc,m,dep+1,p)。
五 代码
package com.platform.modules.alg.alglib.hdu4347;
import javafx.util.Pair;
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;
public class Hdu4347 {
private int N = 50010;
int n, k, t, m, idx;
Node a[] = new Node[N];
int sz[] = new int[N << 2];
Node kd[] = new Node[N << 2];
PriorityQueue<Pair<Integer, Node>> que = new PriorityQueue<>(new MyComparator());
public Hdu4347() {
for (int i = 0; i < a.length; i++) {
a[i] = new Node();
}
}
void build(int rt, int l, int r, int dep) {
if (l > r) return;
int mid = (l + r) >> 1;
idx = dep % k;
sz[rt] = 1;
sz[rt << 1] = sz[rt << 1 | 1] = 0;
Arrays.sort(a, l, r + 1);
kd[rt] = a[mid];
build(rt << 1, l, mid - 1, dep + 1);
build(rt << 1 | 1, mid + 1, r, dep + 1);
}
void query(int rt, int m, int dep, Node p) {
if (sz[rt] == 0) return;
Pair<Integer, Node> tmp = new Pair(0, kd[rt]);
for (int j = 0; j < k; j++) {
tmp = new Pair((tmp.getKey() + (tmp.getValue().x[j] - p.x[j]) * (tmp.getValue().x[j] - p.x[j])), tmp.getValue());
}
int lc = rt << 1, rc = rt << 1 | 1, dim = dep % k, flag = 0;
if (p.x[dim] >= kd[rt].x[dim]) {
int temp = lc;
lc = rc;
rc = temp;
}
if (sz[lc] > 0)
query(lc, m, dep + 1, p);
if (que.size() < m) {
que.add(tmp);
flag = 1;
} else {
if (tmp.getKey() < que.peek().getKey()) {//大顶堆,保存最邻近的m个点
que.poll();
que.add(tmp);
}
if ((p.x[dim] - kd[rt].x[dim] * (p.x[dim] - kd[rt].x[dim]) < que.peek().getKey()))
flag = 1;
}
if (sz[rc] > 0 && flag == 1)
query(rc, m, dep + 1, p);
}
public String output = "";
public String cal(String input) {
String[] line = input.split("\n");
String[] num = line[0].split(" ");
n = Integer.parseInt(num[0]);
k = Integer.parseInt(num[1]);
int count = 1;
for (int i = 0; i < n; i++) {
String[] node = line[count++].split(" ");
for (int j = 0; j < k; j++) {
a[i].x[j] = Integer.parseInt(node[j]);
}
}
build(1, 0, n - 1, 0);
t = Integer.parseInt(line[count++]);
while (t-- > 0) {
Node p = new Node();
String[] queryNode = line[count++].split(" ");
for (int i = 0; i < k; i++) {
p.x[i] = Integer.parseInt(queryNode[i]);
}
m = Integer.parseInt(line[count++]);
query(1, m, 0, p);
Node tmp[] = new Node[15];
for (int i = 0; !que.isEmpty(); que.poll()) // 大顶堆暂存到数组,逆序输出
tmp[++i] = que.peek().getValue();
output += String.format("the closest %d points are:\n", m);
for (int i = m; i > 0; i--) { // 逆序输出
output += tmp[i].x[0];
for (int j = 1; j < k; j++) {
output += String.format(" %d", tmp[i].x[j]);
}
output += "\n";
}
}
return output;
}
class Node implements Comparable<Node> {
int x[] = new int[5];
public int compareTo(Node o) {
return x[idx] > o.x[idx] ? 1 : -1; // 升序
}
}
class MyComparator implements Comparator<Pair<Integer, Node>> {
@Override
public int compare(Pair<Integer, Node> num1, Pair<Integer, Node> num2) {
return num2.getKey().compareTo(num1.getKey());
}
}
}
六 测试