使用PCReg.PyTorch项目训练自己的数据集进行点云配准

news2025/1/31 11:29:38

项目地址: https://github.com/zhulf0804/PCReg.PyTorch/tree/main
网络简介: 网络是基于PointNet + Concat + FC的,它没有其它复杂的结构,易于复现。因其简洁性,这里暂且把其称作点云配准的Benchmark。因作者源码中复杂的(四元数, 旋转矩阵, 欧拉角之间)的变换操作和冗余性,且其PyTorch版本的不完整性(缺少评估模型等,最近又更新了),
项目详细介绍: 基于深度学习的点云配准Benchmark

本文方法与常见的图像配准逻辑类似,基于采样与transfrom操作从源点云生成目标点云,然后进行训练与评测。总体看来效果不如open3d自带的fgr方法,可以作为入门级项目进行使用。

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/PCReg.PyTorch/tree/main,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入PCReg.PyTorch-main目录,执行以下安装命令:

pip install -r requirements.txt
python -m pip install open3d>=0.9

emd loss编译
如果不做训练使用,可以不用进行emd loss编译。
cd loss/cuda/emd_torch & python setup.py install
在编译过程中,很有可能碰到报错

  File "C:\Users\Administrator\miniconda3\lib\site-packages\torch\utils\cpp_extension.py", line 499, in build_extensions
    _check_cuda_version(compiler_name, compiler_version)
  File "C:\Users\Administrator\miniconda3\lib\site-packages\torch\utils\cpp_extension.py", line 387, in _check_cuda_version
    raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
RuntimeError:
The detected CUDA version (12.1) mismatches the version that was used to compile
PyTorch (11.7). Please make sure to use the same CUDA versions.

这是由于PyTorch 的cuda版本与系统自带的cuda版本不同所导致的,可以先使用一下命令卸载过往的torch版本(慎重操作),然后重新安装torch;也可以在conda环境中重新创建一个符合系统cuda版本的torch环境。
pip uninstall torch torchvision torchaudio

在上述信息输出中,博主的cuda版本为12.1,我们可以打开pytorch官网找打符合自己电脑cuda版本的pytorch安装命令.
在这里插入图片描述
如果cuda版本较早,可以在 https://pytorch.org/get-started/previous-versions/ 中找到安装命令。

在正确的安装cuda版本后,重新执行命令即可实现emd-loss的安装
在这里插入图片描述

1.3 模型与数据下载

  • modelnet40数据集 [here, 435M]

  • 可用的预训练模型 [Complete, pwd: c4z7, 16.09 M] or [Paritial, pwd: pcno, 16.09] first) 模型下载好后将其放置到PCReg.PyTorch项目根路径下即可。

2. 关键代码说明

2.1 数据加载器

在data目录下有CustomData.py和ModelNet40.py两个文件,其中ModelNet40文件对应modelnet40数据集的加载,CustomData文件对应自己个人数据集的加载。从两个文件的__getitem__函数中可以发现,模型不是基于数据对进行训练的。其依据ref_cloud随机采样生成ref_cloud,然后对ref_cloud进行transform操作。具体实例如下所示:

    def __getitem__(self, item):
        file = self.files[item]
        ref_cloud = readpcd(file, rtype='npy')
        ref_cloud = random_select_points(ref_cloud, m=self.npts)
        ref_cloud = pc_normalize(ref_cloud)
        R, t = generate_random_rotation_matrix(-20, 20), \
               generate_random_tranlation_vector(-0.5, 0.5)
        src_cloud = transform(ref_cloud, R, t)
        if self.train:
            ref_cloud = jitter_point_cloud(ref_cloud)
            src_cloud = jitter_point_cloud(src_cloud)
        return ref_cloud, src_cloud, R, t

在以上代码中需要注意的是,所有的点云都进行了坐标值的归一化处理

2.2 模型结构

在model目录下有benchmark.py、fgr.py、icp.py,分别为模型配准,fgr配准,icp配准方法。其中fgr配准与icp配准方法是使用open3d库实现。

benchmark

benchmark为本文模型,其是基于PointNet所实现的一个孪生网络,核心代码如Benchmark类所示。其基于encoder提取2个点云的特征,然后简单的使用全连接层将两个点云的特征进行交互,然后再输出两个点云的特征。在这里最为重要的是loss的设计,即如何设计优化目标,使模型参数以点云配准为优化方向

class Benchmark(nn.Module):
    def __init__(self, gn, in_dim1, in_dim2=2048, fcs=[1024, 1024, 512, 512, 256, 7]):
        super(Benchmark, self).__init__()
        self.in_dim1 = in_dim1
        self.encoder = PointNet(in_dim=in_dim1, gn=gn)
        self.decoder = nn.Sequential()
        for i, out_dim in enumerate(fcs):
            self.decoder.add_module(f'fc_{i}', nn.Linear(in_dim2, out_dim))
            if out_dim != 7:
                if gn:
                    self.decoder.add_module(f'gn_{i}',nn.GroupNorm(8, out_dim))
                self.decoder.add_module(f'relu_{i}', nn.ReLU(inplace=True))
            in_dim2 = out_dim

    def forward(self, x, y):
        x_f, y_f = self.encoder(x), self.encoder(y)
        concat = torch.cat((x_f, y_f), dim=1)
        out = self.decoder(concat)
        batch_t, batch_quat = out[:, :3], out[:, 3:] / torch.norm(out[:, 3:], dim=1, keepdim=True)
        batch_R = batch_quat2mat(batch_quat)
        if self.in_dim1 == 3:
            transformed_x = batch_transform(x.permute(0, 2, 1).contiguous(),
                                            batch_R, batch_t)
        elif self.in_dim1 == 6:
            transformed_pts = batch_transform(x.permute(0, 2, 1)[:, :, :3].contiguous(),
                                            batch_R, batch_t)
            transformed_nls = batch_transform(x.permute(0, 2, 1)[:, :, 3:].contiguous(),
                                              batch_R)
            transformed_x = torch.cat([transformed_pts, transformed_nls], dim=-1)
        else:
            raise ValueError
        return batch_R, batch_t, transformed_x


class IterativeBenchmark(nn.Module):
    def __init__(self, in_dim, niters, gn):
        super(IterativeBenchmark, self).__init__()
        self.benckmark = Benchmark(gn=gn, in_dim1=in_dim)
        self.niters = niters

    def forward(self, x, y):
        transformed_xs = []
        device = x.device
        B = x.size()[0]
        transformed_x = torch.clone(x)
        batch_R_res = torch.eye(3).to(device).unsqueeze(0).repeat(B, 1, 1)
        batch_t_res = torch.zeros(3, 1).to(device).unsqueeze(0).repeat(B, 1, 1)
        for i in range(self.niters):
            batch_R, batch_t, transformed_x = self.benckmark(transformed_x, y)
            transformed_xs.append(transformed_x)
            batch_R_res = torch.matmul(batch_R, batch_R_res)
            batch_t_res = torch.matmul(batch_R, batch_t_res) \
                          + torch.unsqueeze(batch_t, -1)
            transformed_x = transformed_x.permute(0, 2, 1).contiguous()
        batch_t_res = torch.squeeze(batch_t_res, dim=-1)
        #transformed_x = transformed_x.permute(0, 2, 1).contiguous()
        return batch_R_res, batch_t_res, transformed_xs
fgr配准方法
import copy
import open3d as o3d


def fpfh(pcd, normals):
    pcd.normals = o3d.utility.Vector3dVector(normals)
    pcd_fpfh = o3d.registration.compute_fpfh_feature(
        pcd,
        o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=64))
    return pcd_fpfh


def execute_fast_global_registration(source, target, source_fpfh, target_fpfh):
    distance_threshold = 0.01
    result = o3d.registration.registration_fast_based_on_feature_matching(
        source, target, source_fpfh, target_fpfh,
        o3d.registration.FastGlobalRegistrationOption(
            maximum_correspondence_distance=distance_threshold))
    transformation = result.transformation
    estimate = copy.deepcopy(source)
    estimate.transform(transformation)
    R, t = transformation[:3, :3], transformation[:3, 3]
    return R, t, estimate


def fgr(source, target, src_normals, tgt_normals):
    source_fpfh = fpfh(source, src_normals)
    target_fpfh = fpfh(target, tgt_normals)
    R, t, estimate = execute_fast_global_registration(source=source,
                                                      target=target,
                                                      source_fpfh=source_fpfh,
                                                      target_fpfh=target_fpfh)
    return R, t, estimate

ICP配准方法
import copy
import numpy as np
import open3d as o3d


def icp(source, target):
    max_correspondence_distance = 2 # 0.5 in RPM-Net
    init = np.eye(4, dtype=np.float32)
    estimation_method = o3d.pipelines.registration.TransformationEstimationPointToPoint()

    reg_p2p = o3d.pipelines.registration.registration_icp(
        source=source,
        target=target,
        init=init,
        max_correspondence_distance=max_correspondence_distance,
        estimation_method=estimation_method
    )

    transformation = reg_p2p.transformation
    estimate = copy.deepcopy(source)
    estimate.transform(transformation)
    R, t = transformation[:3, :3], transformation[:3, 3]
    return R, t, estimate

3.基本使用

modelnet40数据集的评测及训练可以使用一下代码实现

    # Iterative Benchmark
    python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --checkpoint your_ckpt_path/test_min_loss.pth --cuda
    
    # Visualization
    # python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --checkpoint your_ckpt_path/test_min_loss.pth  --show
    
    # ICP
    # python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --method icp
    
    # FGR
    # python modelnet40_evaluate.py --root your_data_path/modelnet40_ply_hdf5_2048 --method fgr --normal

  • train

    CUDA_VISIBLE_DEVICES=0 python modelnet40_train.py --root your_data_path/modelnet40_ply_hdf5_2048
    

这里注意讲述训练与评测自己的数据集,其中自己数据集的路径如下所示,里面都是处理好的pcd点云数据。
在这里插入图片描述
具体格式为:

    |- CustomData(dir)
        |- train_data(dir)
            - train1.pcd
            - train2.pcd
            - ...
        |- val_data(dir)
            - val1.pcd
            - val2.pcd
            - ...

3.1 ICP方法性能评测

可以加上 --show 参数来查看每一个配准的数据

python custom_evaluate.py --root cumstom_data --infer_npts 2048  --method icp --normal

如果出现以下报错,则用open3d.pipelines.registration替换open3d.registration,具体可以用本博文的ICP配准方法替换掉models\icp.py中的内容

Traceback (most recent call last):
  File "custom_evaluate.py", line 142, in <module>
    evaluate_icp(args, test_loader)
  File "custom_evaluate.py", line 98, in evaluate_icp
    R, t, pred_ref_cloud = icp(npy2pcd(src_cloud), npy2pcd(ref_cloud))
  File "D:\点云AI配准\PCReg.PyTorch-main\models\icp.py", line 9, in icp
    estimation_method = o3d.registration.TransformationEstimationPointToPoint()
AttributeError: module 'open3d' has no attribute 'registration'

具体执行输出如下所示:
在这里插入图片描述

3.2 模型训练

训练命令:
python custom_train.py --root cumstom_data --train_npts 2048
在这里插入图片描述
训练好的模型保存在work_dirs\models\checkpoints目录中
在这里插入图片描述

评测命令:
python custom_evaluate.py --infer_npts 2048 --root cumstom_data --checkpoint work_dirs\models\checkpoints\test_min_loss.pth --show
其中绿色点云为源点云,红色点云为参考点云,蓝色点云为配准后的源点云,可以看到蓝色点云与红色点云完全没有对齐,这表明训练效果极其不佳。这或许是训练数据太少所导致的,毕竟本次实验只有18个点云数据。
在这里插入图片描述

4、原文效果

下图是作者论文中的配准效果图
在这里插入图片描述
在modelnet40数据集上相关精度信息如下所示,可以确定,本文方法与FGR方法相比没有显著性优势。

  • Point-to-Point Correspondences(R error is large due to EMDLoss, see here)
Methodisotropic Risotropic tanisotropic R(mse, mae)anisotropic t(mse, mae)time(s)
ICP11.440.1617.64(5.48)0.22(0.07)0.07
FGR0.010.000.07(0.00)0.00(0.00)0.19
IBenchmark5.680.079.77(2.69)0.12(0.03)0.022
IBenchmark + ICP3.650.049.22(1.66)0.11(0.02)
  • Noise Data(infer_npts = 1024)
Methodisotropic Risotropic tanisotropic R(mse, mae)anisotropic t(mse, mae)
ICP12.140.1718.32(5.86)0.23(0.08)
FGR4.270.0611.55(2.43)0.09(0.03)
IBenchmark6.250.089.28(2.94)0.12(0.04)
IBenchmark + ICP5.100.0710.51(2.39)0.13(0.03)
  • Partial-to-Complete Registration(infer_npts = 1024)
Methodisotropic Risotropic tanisotropic R(mse, mae)anisotropic t(mse, mae)
ICP21.330.3222.83(10.51)0.31(0.15)
FGR9.490.1219.51(5.58)0.17(0.06)
IBenchmark15.020.2215.78(7.45)0.21(0.10)
IBenchmark + ICP9.210.1314.73(4.43)0.18(0.06)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278983.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习记录--logistic回归函数的计算图

计算图用于logistic回归函数 先回顾一下单一样本的logistic回归损失函数的公式&#xff0c;公式如下&#xff1a; 将logistic函数用计算图表示出来(以两个基础量为例)&#xff0c;计算图如下&#xff1a; 前向传播已经完成&#xff0c;接下来完成后向传播 运用链式法则依次求…

使用 STM32 微控制器读取光电传感器数据的实现方法

本文介绍了如何使用 STM32 微控制器读取光电传感器数据的实现方法。通过配置和使用STM32的GPIO和ADC功能&#xff0c;可以实时读取光电传感器的模拟信号并进行数字化处理。本文将介绍硬件连接和配置&#xff0c;以及示例代码&#xff0c;帮助开发者完成光电传感器数据的读取。 …

【探索Linux】—— 强大的命令行工具 P.19(多线程 | 线程的概念 | 线程控制 | 分离线程)

阅读导航 引言一、 Linux线程概念1. 什么是线程2. 线程的概念3. 线程与进程的区别4. 线程异常 二、Linux线程控制1. POSIX线程库2. 创建线程 pthread_create() 函数&#xff08;1&#xff09;头文件&#xff08;2&#xff09;函数原型&#xff08;3&#xff09;参数解释&#x…

elasticsearch 内网下如何以离线的方式上传任意的huggingFace上的NLP模型(国内闭坑指南)

es自2020年的8.x版本以来&#xff0c;就提供了机器学习的能力。我们可以使用es官方提供的工具eland&#xff0c;将hugging face上的NLP模型&#xff0c;上传到es集群中。利用es的机器学习模块&#xff0c;来运维部署管理模型。配合es的管道处理&#xff0c;来更加便捷的处理数据…

深度学习记录--logistic回归损失函数向量化实现

前言 再次明确向量化的目的&#xff1a;减少for循环的使用&#xff0c;以更少的代码量和更快的速度来实现程序 正向传播的向量化 对于,用向量化来实现&#xff0c;只需要就可以完成&#xff0c;其中,, ps.这里b只是一个常数&#xff0c;但是依然可以加在每个向量里(python的…

TLS协议握手流程

浅析 TLS&#xff08;ECDHE&#xff09;协议的握手流程&#xff08;图解&#xff09; - 知乎 前言 通过 wireshark 抓取 HTTPS 包&#xff0c;理解 TLS 1.2 安全通信协议的握手流程。 重点理解几个点&#xff1a; TLS 握手流程&#xff1a;通过 wireshark 抓取 HTTPS 包理解…

项目中遇到的半导体公司

作为一个技术人&#xff0c;我并不是亲美&#xff0c;从技术的实事求是角度讲&#xff0c;不得不感叹欧美的半导体技术。他们的datasheet能学到的东西太多太多&#xff1b;我甚至佩服他们缜密的逻辑。从他们的文章中领悟我们技术到底有多low&#xff0c;没办法一个一个了解所有…

Nginx转发内网Flv视频流

1、环境说明 Docker Nginx&#xff1a;1.21.5 实现Nginx ssl转发内网flv视频流 2、配置nginx.conf http {upstream live {server 10.10.10.10:8300;keepalive 64;}map $http_upgrade $connection_upgrade {default upgrade; close;}server {listen 80;listen 443…

ssm+vue的罪犯信息管理系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频&#xff1a; ssmvue的罪犯信息管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;ssm vue前后端分离项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&…

Redis hash表源码解析

整体数据结构&#xff1a;链式hash解决hash冲突、采用渐进式hash来完成扩容过程。 /** 哈希表节点*/ typedef struct dictEntry {// 键void *key;// 值union {void *val;uint64_t u64;int64_t s64;} v;// 指向下个哈希表节点&#xff0c;形成链表struct dictEntry *next;} dict…

Adobe Acrobat DC 将PDF转曲步骤

1、编辑--更多--背景--添加 2、只需要将不透明度调为0即可。 3、工具--印刷制作 4、拼合器预览 5、只需要将下面标出来的地方勾选即可 6、可以另存为&#xff0c;不影响源文件 7、检查是否成功&#xff0c;文件--属性--字体为空&#xff0c;说明成功了 参考资料&#xff1a; …

详解Spring中BeanPostProcessor在Spring工厂和Aop发挥的作用

&#x1f609;&#x1f609; 学习交流群&#xff1a; ✅✅1&#xff1a;这是孙哥suns给大家的福利&#xff01; ✨✨2&#xff1a;我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 &#x1f96d;&#x1f96d;3&#xff1a;QQ群&#xff1a;583783…

【论文阅读】ICRA: An Intelligent Clustering Routing Approach for UAV Ad Hoc Networks

文章目录 论文基本信息摘要1.引言2.相关工作3.PROPOSED SCHEME4.实验和讨论5.总结补充 论文基本信息 《ICRA: An Intelligent Clustering Routing Approach for UAV Ad Hoc Networks》 《ICRA:无人机自组织网络的智能聚类路由方法》 Published in: IEEE Transactions on Inte…

FISCO-BCOS 在ARM系统架构搭建节点(国密版)

问题&#xff1a; 使用 fisco-bcos v2.9.1 搭建一个节点&#xff0c;批量上链1000条数据&#xff0c;在上链200条-400条数据之间节点会出现异常&#xff0c;导致后面数据不能上链。 系统环境 操作系统&#xff1a;统信 查看系统构架 ld -version rootuos-PC:/# ld -version …

RK3568平台开发系列讲解(Linux系统篇)netlink 监听广播信息

** 🚀返回专栏总目录 文章目录 一、什么是netlink 机制二、netlink 的使用2.1、创建 socket2.2、绑定套接字2.3、接收数据沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将介绍如何通过 netlink 监听广播信息。 一、什么是netlink 机制 Netlink 是 Linux 内核中…

【C/C++笔试练习】公有派生、构造函数内不执行多态、抽象类和纯虚函数、多态中的缺省值、虚函数的描述、纯虚函数的声明、查找输入整数二进制中1的个数、手套

文章目录 C/C笔试练习选择部分&#xff08;1&#xff09;公有派生&#xff08;2&#xff09;构造函数内不执行多态&#xff08;3&#xff09;抽象类和纯虚函数&#xff08;4&#xff09;多态中的缺省值&#xff08;5&#xff09;程序分析&#xff08;6&#xff09;重载和隐藏&a…

使用gcloud SDK 管理和部署 Cloud run service

查看cloud run 上的service 列表&#xff1a; gcloud run services list > gcloud run services listSERVICE REGION URL LAST DEPLOYED BY LAST DEPL…

【Go语言反射reflect】

Go语言反射reflect 一、引入 先看官方Doc中Rob Pike给出的关于反射的定义&#xff1a; Reflection in computing is the ability of a program to examine its own structure, particularly through types; it’s a form of metaprogramming. It’s also a great source of …

Android BT HCI分析简介

对于蓝牙开发者来说&#xff0c;通过HCI log可以帮助我们更好地分析问题&#xff0c;理解蓝牙协议&#xff0c;就好像网络开发一定要会使用Wireshark分析网络协议一样。 本篇主要介绍HCI log的作用、如何抓取一份HCI log&#xff0c;并结合一个实际的例子来说明如何分析HCI log…

004、简单页面-基础组件

之——基础组件 目录 之——基础组件 杂谈 正文 1.Image 1.0 数据源 1.1 缩放 1.2 大小 1.3 网络图片 2.Text 2.0 数据源 2.1 大小 2.2 粗细 2.3 颜色 2.5 样式字体 2.6 基础示例 2.7 对齐 2.8 省略 2.9 划线 3.TextInput 3.1 输入类型 3.2 提示文…