实验记录 | 点云处理 | K-NN算法3种实现的性能比较

news2024/11/13 9:45:29

引言

K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较

在本文中,我将K近邻(KNN)算法的应用分为两种情况:

  • 全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。

  • 局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。


全局查询

def knn(x, k):
    """
    Input:
        x: all points, [B, C, N]
        k: k nearest points of each point
    Return:
        idx: grouped points index, [B, N, k]
    """
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx

这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。

函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。

插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。


局部查询

(1)knn_point 函数

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def knn_point(nsample, xyz, new_xyz):
    """
    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx

这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。


(2)knn_cuda 库函数

import torch

# Make sure your CUDA is available.
assert torch.cuda.is_available()

from knn_cuda import KNN
"""
if transpose_mode is True, 
    ref   is Tensor [bs x nr x dim]
    query is Tensor [bs x nq x dim]
    
    return 
        dist is Tensor [bs x nq x k]
        indx is Tensor [bs x nq x k]
else
    ref   is Tensor [bs x dim x nr]
    query is Tensor [bs x dim x nq]
    
    return 
        dist is Tensor [bs x k x nq]
        indx is Tensor [bs x k x nq]
"""

knn = KNN(k=10, transpose_mode=True)

ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()

dist, indx = knn(ref, query)  # 32 x 50 x 10

大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。

需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了


性能比较

(1)测试代码

import torch
import time
from knn_cuda import KNN

def knn(x, k):
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx

def square_distance(src, dst):
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def knn_point(nsample, xyz, new_xyz):
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx

# Custom knn implementation
def test_knn(query, k, times):
    query = query.permute(0,2,1)
    start_time = time.time()  # Start timer
    for i in range(times):
        indx = knn(query, k = k)
    end_time = time.time()  # End timer
    return end_time - start_time  # Return elapsed time

# Custom knn_point implementation
def test_knn_point(ref, query, k, times):
    start_time = time.time()  # Start timer
    for i in range(times):
        indx = knn_point(k, ref, query)
    end_time = time.time()  # End timer
    return end_time - start_time  # Return elapsed time

# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):
    knn = KNN(k=k, transpose_mode=True)
    start_time = time.time()  # Start timer
    for i in range(times):
        dist, indx = knn(ref, query)
    end_time = time.time()  # End timer
    return end_time - start_time  # Return elapsed time


# Main testing function
def test_knn_methods(ref, query, k, times):

    print("Test times: %d" % times)

    # Test custom knn
    time_knn = test_knn(query, k, times)
    print(f"knn      : {time_knn:.6f} seconds")

    # Test custom knn_point
    time_point = test_knn_point(ref, query, k, times)
    print(f"knn_point: {time_point:.6f} seconds")

    # Test knn_cuda
    time_cuda = test_knn_cuda(ref, query, k, times)
    print(f"knn_cuda : {time_cuda:.6f} seconds")
    

if __name__ == '__main__':

    # Sample input
    B, N, S, C = 32, 1024, 50, 3      # Batch size, total points, query points, coordinates
    k = 24                            # Number of nearest neighbors
    ref = torch.randn(B, N, C).cuda() # Reference points

    # Test above methods
    times_list = [1,2,3,10,50,100]
    for times in times_list:
        test_knn_methods(ref, ref, k, times)

这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knnknn_point 以及基于 knn_cuda 库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。

图注:三种实现方法的性能测评结果

上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。

当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。


(2)模型训练

图注:使用 knn 函数的训练时间
图注:使用 knn 函数的训练时间

图注:使用 knn_point 的训练时间

图注:使用 knn_cuda 库的训练时间

 

直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。


总结

我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2114318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【OpenCV2.2】图像的算术与位运算(图像的加法运算、图像的减法运算、图像的融合)、OpenCV的位运算(非操作、与运算、或和异或)

1 图像的算术运算 1.1 图像的加法运算 1.2 图像的减法运算 1.3 图像的融合 2 OpenCV的位运算 2.1 非操作 2.2 与运算 2.3 或和异或 1 图像的算术运算 1.1 图像的加法运算 add opencv使用add来执行图像的加法运算 图片就是矩阵, 图片的加法运算就是矩阵的加法运算, 这就要求加…

notepad下载安装教程

一、强大高效的代码编辑器 Notepad 是一款功能强大的代码编辑器,专为程序员和开发人员设计。无论是编写代码、处理文本文件,还是进行快速编辑,Notepad 都能提供卓越的性能和便利的功能,极大提升您的工作效率。 二、安装详细教程…

双指针(5)_单调性_有效三角形的个数

个人主页:C忠实粉丝 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 C忠实粉丝 原创 双指针(5)_单调性_有效三角形的个数 收录于专栏【经典算法练习】 本专栏旨在分享学习C的一点学习笔记,欢迎大家在评论区交流讨论💌 目录…

c++stack和list 介绍

stack介绍 堆栈是一种容器适配器,专门设计用于在 LIFO 上下文(后进先出)中运行,其中元素仅从容器的一端插入和提取。 堆栈作为容器适配器实现,容器适配器是使用特定容器类的封装对象作为其基础容器 的类,提…

mysql可重复读不能解决幻读吗?

1、可重复读和幻读的概念 1.1、可重复读 可重复读是数据库的四个隔离级别之一,可重复读可以保证在一个事物之内读取到的数据永远是相同的(通过mvcc表快照实现的),哪怕这期间有其它事务对数据做了修改,也不会影响当前事务的查询。 1.2、幻读 网上有不少博客说:幻读是一个事物内…

正规表达式例题

解析:从题意可知,a可以有零个或多个,b有1个或多个 选项A:这里a至少有1个,不符合题意 选项B:a^*bb^*,a是0个或多个,b可以是1个或多个,符合题意 选项C和选项D&#xff0…

Jenkins 通过 Version Number Plugin 自动生成和管理构建的版本号

步骤 1:安装 Version Number Plugin 登录 Jenkins 的管理界面。进入 “Manage Jenkins” -> “Manage Plugins”。在 “Available” 选项卡中搜索 “Version Number Plugin”。选中并安装插件,完成后可能需要重启 Jenkins。 步骤 2:配置…

尚品汇-支付宝下单接口显示二维码实现(四十六)

目录: (1)支付功能实现 (2)保存支付信息 (3)编写支付宝支付接口 (1)支付功能实现 支付宝有了同步通知为什么还需要异步通知? 同步回调两个作用 第一是从支付…

密保管家-随机密码本地生成

下载 简介 安全无忧:采用先进的加密算法,确保您的密码安全不外泄。 随机性强:每次生成的密码都是完全随机的,避免模式化,增加破解难度。 易于管理:简洁的界面设计让您轻松管理所有账号的密码。 独立运行:无需网络连接,所有数据本地存储,保护隐私的同时提供便捷的密…

【MATLAB】模拟退火算法

模拟退火算法的MATLAB实现 模拟退火算法简介模拟退火算法应用实例关于计算结果 模拟退火算法简介 1982年,Kirkpatrick 将退火思想引入组合优化领域,提出了一种能够有效解决大规模组合优化问题的算法,尤其对 NP 完全问题表现出显著优势。模拟…

FreeRTOS 优先级翻转以及互斥信号量

优先级翻转: 高优先级的任务反而慢执行,低优先级的任务反而优先执行 优先级翻转在抢占式内核中是非常常见的,但是在实时操作系统中是不允许出现优先级翻转的,因为优先级翻转会破坏任务的预期顺序,可能会导致未知的严重…

react | 自学笔记 | 持续更新

React自学速学笔记 数据单向流动事件为什么上述例子,是onClick{()>shoot("goal!")}而不是onClick{shoot("goal")}?event对象 条件渲染if方法&&?: 三元表达式 纯小白自学笔记,有不对的欢迎指正。 数据单向流动 单向流动…

如何确保光伏电站EPC施工的质量

说到保证EPC施工的质量,我们得先了解什么是EPC施工,是指:指总承包商按照合同约定,承担工程项目的设计、采购、施工等工作,并对工程的质量、安全、工期和造价全面负责。 EPC施工还有几个特点: 一体化服务&…

单片机毕业设计基于stm32的蔬菜大棚智能监控系统设计

文章目录 前言资料获取设计介绍功能介绍程序代码部分参考 设计清单具体实现截图参考文献设计获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP…

2.2.3 UDP的可靠传输协议QUIC 2

udp可靠传输 kcp协议 网络通畅下,kcp比tcp慢 这里直接看课件图片, 延迟ack比非延迟减少应答包数量,但是慢 kcp 讲解 kan代码ikcp.c 按照readme指南编译一下!! mkdir build cd build cmake .. make第一遍报错&#xf…

ant-design-vue中实现a-tree树形控件父子关联选中过滤的算法

在使用ant-design-vue的框架时,a-tree是比较常用的组件,比较适合处理树形结构的数据。 但是在与后台数据进行授权交互时,就不友好了。 在原生官方文档的例子中,若子项被勾选,则父级节点会被关联勾选,但这勾…

【堆的应用--C语言版】

前面一节我们都已将堆的结构(顺序存储)已经实现,对树的相关概念以及知识做了一定的了解。其中我们在实现删除操作和插入操作的时候,我们还同时实现了建大堆(小堆)的向上(下)调整算法…

【CSS in Depth 2 精译_024】4.2 弹性子元素的大小

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一章 层叠、优先级与继承(已完结) 1.1 层叠1.2 继承1.3 特殊值1.4 简写属性1.5 CSS 渐进式增强技术1.6 本章小结 第二章 相对单位(已完结) 2.1 相对…

PyInstaller问题解决 onnxruntime-gpu 使用GPU和CUDA加速模型推理

前言 在模型推理时,需要使用GPU加速,相关的CUDA和CUDNN安装好后,通过onnxruntime-gpu实现。 直接运行python程序是正常使用GPU的,如果使用PyInstaller将.py文件打包为.exe,发现只能使用CPU推理了。 本文分析这个问题…

TL-Tomcat中长连接的底层源码原理实现

长连接:浏览器告诉tomcat不要将请求关掉。 如果不是长连接,tomcat响应后会告诉浏览器把这个连接关掉。 tomcat中有一个缓冲区 如果发送大批量数据后 又不处理 那么会堆积缓冲区 后面的请求会越来越慢。