引言
K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较。
在本文中,我将K近邻(KNN)算法的应用分为两种情况:
-
全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。
-
局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。
全局查询
def knn(x, k):
"""
Input:
x: all points, [B, C, N]
k: k nearest points of each point
Return:
idx: grouped points index, [B, N, k]
"""
inner = -2*torch.matmul(x.transpose(2, 1), x)
xx = torch.sum(x**2, dim=1, keepdim=True)
pairwise_distance = -xx - inner - xx.transpose(2, 1)
idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
return idx
这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。
函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。
插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。
局部查询
(1)knn_point 函数
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def knn_point(nsample, xyz, new_xyz):
"""
Input:
nsample: max sample number in local region
xyz: all points, [B, N, C]
new_xyz: query points, [B, S, C]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
return group_idx
这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。
(2)knn_cuda 库函数
import torch
# Make sure your CUDA is available.
assert torch.cuda.is_available()
from knn_cuda import KNN
"""
if transpose_mode is True,
ref is Tensor [bs x nr x dim]
query is Tensor [bs x nq x dim]
return
dist is Tensor [bs x nq x k]
indx is Tensor [bs x nq x k]
else
ref is Tensor [bs x dim x nr]
query is Tensor [bs x dim x nq]
return
dist is Tensor [bs x k x nq]
indx is Tensor [bs x k x nq]
"""
knn = KNN(k=10, transpose_mode=True)
ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()
dist, indx = knn(ref, query) # 32 x 50 x 10
大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。
需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了。
性能比较
(1)测试代码
import torch
import time
from knn_cuda import KNN
def knn(x, k):
inner = -2*torch.matmul(x.transpose(2, 1), x)
xx = torch.sum(x**2, dim=1, keepdim=True)
pairwise_distance = -xx - inner - xx.transpose(2, 1)
idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
return idx
def square_distance(src, dst):
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def knn_point(nsample, xyz, new_xyz):
sqrdists = square_distance(new_xyz, xyz)
_, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
return group_idx
# Custom knn implementation
def test_knn(query, k, times):
query = query.permute(0,2,1)
start_time = time.time() # Start timer
for i in range(times):
indx = knn(query, k = k)
end_time = time.time() # End timer
return end_time - start_time # Return elapsed time
# Custom knn_point implementation
def test_knn_point(ref, query, k, times):
start_time = time.time() # Start timer
for i in range(times):
indx = knn_point(k, ref, query)
end_time = time.time() # End timer
return end_time - start_time # Return elapsed time
# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):
knn = KNN(k=k, transpose_mode=True)
start_time = time.time() # Start timer
for i in range(times):
dist, indx = knn(ref, query)
end_time = time.time() # End timer
return end_time - start_time # Return elapsed time
# Main testing function
def test_knn_methods(ref, query, k, times):
print("Test times: %d" % times)
# Test custom knn
time_knn = test_knn(query, k, times)
print(f"knn : {time_knn:.6f} seconds")
# Test custom knn_point
time_point = test_knn_point(ref, query, k, times)
print(f"knn_point: {time_point:.6f} seconds")
# Test knn_cuda
time_cuda = test_knn_cuda(ref, query, k, times)
print(f"knn_cuda : {time_cuda:.6f} seconds")
if __name__ == '__main__':
# Sample input
B, N, S, C = 32, 1024, 50, 3 # Batch size, total points, query points, coordinates
k = 24 # Number of nearest neighbors
ref = torch.randn(B, N, C).cuda() # Reference points
# Test above methods
times_list = [1,2,3,10,50,100]
for times in times_list:
test_knn_methods(ref, ref, k, times)
这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knn
、knn_point
以及基于 knn_cuda
库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。
上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。
当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。
(2)模型训练
直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。
总结
我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。