基于ROPNet项目训练modelnet40数据集进行3d点云的配置

news2024/11/27 2:23:22

项目地址: https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)(ICCV Workshop 2021)中获得了第二名。项目可以在win10环境下运行。
论文地址: https://arxiv.org/abs/2107.02583

网络简介: 一种新的深度学习模型,该模型利用具有区别特征的代表性重叠点进行配准,将部分到部分配准转换为部分完全配准。基于pointnet输出的特征设计了一个上下文引导模块,使用一个编码器来提取全局特征来预测点重叠得分。为了更好地找到有代表性的重叠点,使用提取的全局特征进行粗对齐。然后,引入一种变压器来丰富点特征,并基于点重叠得分和特征匹配去除非代表性点。在部分到完全的模式下建立相似度矩阵,最后采用加权支持向量差来估计变换矩阵。
在这里插入图片描述
实施效果: 从数据上看ROPNet与RPMNet与保持了断崖式的领先地位
在这里插入图片描述

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/ROPNet,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入ROPNet-master/src目录,执行以下安装命令。如果已经安装了torch 环境和open3d包,则不用再进行安装了

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

pip install open3d

1.3 模型与数据下载

modelnet40数据集 here [435M]
数据集下载后存储为以下路径即可。
在这里插入图片描述

官网预训练模型,无。
第三方预训练模型:使用ROPNet项目在modelnet40数据集上训练的模型

2、关键代码

2.1 dataloader

作者所提供的dataloader只能加载https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip 数据集,其所返回的tgt_cloud, src_cloud实质上是基于一个点云采样而来的。 其中的self.label2cat, self.cat2label, self.symmetric_labels等对象代码实际上是没有任何作用的。

import copy
import h5py
import math
import numpy as np
import os
import torch

from torch.utils.data import Dataset
import sys

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOR_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOR_DIR)
from utils import  random_select_points, shift_point_cloud, jitter_point_cloud, \
    generate_random_rotation_matrix, generate_random_tranlation_vector, \
    transform, random_crop, shuffle_pc, random_scale_point_cloud, flip_pc
    


half1 = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl',
         'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser',
         'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp']
half1_symmetric = ['bottle', 'bowl', 'cone', 'cup', 'flower_pot', 'lamp']

half2 = ['laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano',
         'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool',
         'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
half2_symmetric = ['tent', 'vase']


class ModelNet40(Dataset):
    def __init__(self, root, split, npts, p_keep, noise, unseen, ao=False,
                 normal=False):
        super(ModelNet40, self).__init__()
        self.single = False # for specific-class visualization
        assert split in ['train', 'val', 'test']
        self.split = split
        self.npts = npts
        self.p_keep = p_keep
        self.noise = noise
        self.unseen = unseen
        self.ao = ao # Asymmetric Objects
        self.normal = normal
        self.half = half1 if split in 'train' else half2
        self.symmetric = half1_symmetric + half2_symmetric
        self.label2cat, self.cat2label = self.label2category(
            os.path.join(root, 'shape_names.txt'))
        self.half_labels = [self.cat2label[cat] for cat in self.half]
        self.symmetric_labels = [self.cat2label[cat] for cat in self.symmetric]
        files = [os.path.join(root, 'ply_data_train{}.h5'.format(i))
                 for i in range(5)]
        if split == 'test':
            files = [os.path.join(root, 'ply_data_test{}.h5'.format(i))
                     for i in range(2)]
        self.data, self.labels = self.decode_h5(files)
        print(f'split: {self.split}, unique_ids: {len(np.unique(self.labels))}')

        if self.split == 'train':
            self.Rs = [generate_random_rotation_matrix() for _ in range(len(self.data))]
            self.ts = [generate_random_tranlation_vector() for _ in range(len(self.data))]

    def label2category(self, file):
        with open(file, 'r') as f:
            label2cat = [category.strip() for category in f.readlines()]
            cat2label = {label2cat[i]: i for i in range(len(label2cat))}
        return label2cat, cat2label

    def decode_h5(self, files):
        points, normal, label = [], [], []
        for file in files:
            f = h5py.File(file, 'r')
            cur_points = f['data'][:].astype(np.float32)
            cur_normal = f['normal'][:].astype(np.float32)
            cur_label = f['label'][:].flatten().astype(np.int32)
            if self.unseen:
                idx = np.isin(cur_label, self.half_labels)
                cur_points = cur_points[idx]
                cur_normal = cur_normal[idx]
                cur_label = cur_label[idx]
            if self.ao and self.split in ['val', 'test']:
                idx = ~np.isin(cur_label, self.symmetric_labels)
                cur_points = cur_points[idx]
                cur_normal = cur_normal[idx]
                cur_label = cur_label[idx]
            if self.single:
                idx = np.isin(cur_label, [8])
                cur_points = cur_points[idx]
                cur_normal = cur_normal[idx]
                cur_label = cur_label[idx]
            points.append(cur_points)
            normal.append(cur_normal)
            label.append(cur_label)
        points = np.concatenate(points, axis=0)
        normal = np.concatenate(normal, axis=0)
        data = np.concatenate([points, normal], axis=-1).astype(np.float32)
        label = np.concatenate(label, axis=0)
        return data, label

    def compose(self, item, p_keep):
        tgt_cloud = self.data[item, ...]
        if self.split != 'train':
            np.random.seed(item)
            R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()
        else:
            tgt_cloud = flip_pc(tgt_cloud)
            R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()

        src_cloud = random_crop(copy.deepcopy(tgt_cloud), p_keep=p_keep[0])
        src_size = math.ceil(self.npts * p_keep[0])
        tgt_size = self.npts
        if len(p_keep) > 1:
            tgt_cloud = random_crop(copy.deepcopy(tgt_cloud),
                                    p_keep=p_keep[1])
            tgt_size = math.ceil(self.npts * p_keep[1])

        src_cloud_points = transform(src_cloud[:, :3], R, t)
        src_cloud_normal = transform(src_cloud[:, 3:], R)
        src_cloud = np.concatenate([src_cloud_points, src_cloud_normal],
                                   axis=-1)
        src_cloud = random_select_points(src_cloud, m=src_size)
        tgt_cloud = random_select_points(tgt_cloud, m=tgt_size)

        if self.split == 'train' or self.noise:
            src_cloud[:, :3] = jitter_point_cloud(src_cloud[:, :3])
            tgt_cloud[:, :3] = jitter_point_cloud(tgt_cloud[:, :3])
        tgt_cloud, src_cloud = shuffle_pc(tgt_cloud), shuffle_pc(
            src_cloud)
        return src_cloud, tgt_cloud, R, t

    def __getitem__(self, item):
        src_cloud, tgt_cloud, R, t = self.compose(item=item,
                                                  p_keep=self.p_keep)
        if not self.normal:
            tgt_cloud, src_cloud = tgt_cloud[:, :3], src_cloud[:, :3]
        return tgt_cloud, src_cloud, R, t

    def __len__(self):
        return len(self.data)

2.2 模型设计

模型设计如下:
在这里插入图片描述

2.3 loss设计

其主要包含Init_loss、Refine_loss和Ol_loss。
其中Init_loss是用于计算 预测点 云 0 预测点云_0 预测点0与目标点云的mse或mae loss,
Refine_loss用于计算 预测点 云 [ 1 : ] 预测点云_{[1:]} 预测点[1:]与目标点云的加权mae loss
Ol_loss用于计算两个输入点云输出的重叠分数,使两个点云对应点的重叠分数是一样的。
在这里插入图片描述

具体实现代码如上:


import math
import torch
import torch.nn as nn
from utils import square_dists


def Init_loss(gt_transformed_src, pred_transformed_src, loss_type='mae'):

    losses = {}
    num_iter = 1
    if loss_type == 'mse':
        criterion = nn.MSELoss(reduction='mean')
        for i in range(num_iter):
            losses['mse_{}'.format(i)] = criterion(pred_transformed_src[i],
                                                   gt_transformed_src)
    elif loss_type == 'mae':
        criterion = nn.L1Loss(reduction='mean')
        for i in range(num_iter):
            losses['mae_{}'.format(i)] = criterion(pred_transformed_src[i],
                                                   gt_transformed_src)
    else:
        raise NotImplementedError

    total_losses = []
    for k in losses:
        total_losses.append(losses[k])
    losses = torch.sum(torch.stack(total_losses), dim=0)
    return losses


def Refine_loss(gt_transformed_src, pred_transformed_src, weights=None, loss_type='mae'):
    losses = {}
    num_iter = len(pred_transformed_src)
    for i in range(num_iter):
        if weights is None:
            losses['mae_{}'.format(i)] = torch.mean(
                torch.abs(pred_transformed_src[i] - gt_transformed_src))
        else:
            losses['mae_{}'.format(i)] = torch.mean(torch.sum(
                weights * torch.mean(torch.abs(pred_transformed_src[i] -
                                               gt_transformed_src), dim=-1)
                / (torch.sum(weights, dim=-1, keepdim=True) + 1e-8), dim=-1))

    total_losses = []
    for k in losses:
        total_losses.append(losses[k])
    losses = torch.sum(torch.stack(total_losses), dim=0)

    return losses


def Ol_loss(x_ol, y_ol, dists):
    CELoss = nn.CrossEntropyLoss()
    x_ol_gt = (torch.min(dists, dim=-1)[0] < 0.05 * 0.05).long() # (B, N)
    y_ol_gt = (torch.min(dists, dim=1)[0] < 0.05 * 0.05).long() # (B, M)
    x_ol_loss = CELoss(x_ol, x_ol_gt)
    y_ol_loss = CELoss(y_ol, y_ol_gt)
    ol_loss = (x_ol_loss + y_ol_loss) / 2
    return ol_loss


def cal_loss(gt_transformed_src, pred_transformed_src, dists, x_ol, y_ol):
    losses = {}
    losses['init'] = Init_loss(gt_transformed_src,
                               pred_transformed_src[0:1])
    if x_ol is not None:
        losses['ol'] = Ol_loss(x_ol, y_ol, dists)
    losses['refine'] = Refine_loss(gt_transformed_src,
                                   pred_transformed_src[1:],
                                   weights=None)
    alpha, beta, gamma = 1, 0.1, 1
    if x_ol is not None:
        losses['total'] = losses['init'] + beta * losses['ol'] + gamma * losses['refine']
    else:
        losses['total'] = losses['init'] + losses['refine']
    return losses

3、训练与预测

先进入src目录,并将modelnet40_ply_hdf5_2048.zip解压在src目录下
在这里插入图片描述

3.1 训练

训练命令及训练输出如下所示

python train.py --root modelnet40_ply_hdf5_2048/ --noise --unseen

python请添加图片描述
在训练过程中会在work_dirs\models\checkpoints目录下生成两个模型文件
在这里插入图片描述

3.2 验证

训练命令及训练输出如下所示

python eval.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --cuda --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

请添加图片描述

3.3 测试

测试训练数据的命令如下

python vis.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

具体配准效果如下所示,其中绿色点云为输入点云,红色点云为参考点云,蓝色点云为配准后的点云。可以看到蓝色点云基本与红色点云重合,可以确定其配准效果十分完好。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.4 处理自己的数据集

基于该项目训练并处理自己数据的教程后续会给出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1287115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue2项目中添加字体文件

vue2项目中添加字体文件 1、下载相关文件&#xff0c;放置文件夹中&#xff0c;这里我是在assets文件中新建了fontFamily 2、在assets文件中新建css文件 3、在页面中使用 <style lang"less" scoped> import ../../assets/css/fonts.less;.total-wrap {displa…

深度学习火车票识别系统 计算机竞赛

文章目录 0 前言1 课题意义课题难点&#xff1a; 2 实现方法2.1 图像预处理2.2 字符分割2.3 字符识别部分实现代码 3 实现效果4 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 图像识别 火车票识别系统 该项目较为新颖&#xff0c;适…

详细介绍如何使用 SSD 进行实时物体检测:单次 MultiBox 探测器-含源码

介绍 在实时对象检测中,主流范例传统上采用多步骤方法,包括边界框、像素或特征重采样以及高质量分类器应用的提议。虽然这种方法已经实现了高精度,但其计算需求往往阻碍了其对实时应用的适用性。然而,单次多框检测器 (SSD) 代表了基于深度学习的对象检测的突破性飞跃。SSD…

rust中动态数组Vec的简单使用

在Rust中&#xff0c;Vector&#xff08;简称Vec&#xff09;是一个动态数组数据结构&#xff0c;它可以动态地增加或减少其容量。Vec是Rust标准库中的一个常见类型&#xff0c;非常适合用于存储和操作一系列相同类型的值。 Vec其实是一个智能指针&#xff0c;用于在堆上分配内…

自定义软件app定制开发的需求和趋势|企业网站小程序搭建

自定义软件app定制开发的需求和趋势|企业网站小程序搭建 随着智能手机的普及和移动互联网的快速发展&#xff0c;移动应用程序&#xff08;App&#xff09;成为人们日常生活和工作中必不可少的一部分。然而&#xff0c;市面上已有的应用程序并不能完全满足用户的个性化需求&…

编写测试用例的17个技巧

1、前言 测试用例是任何测试周期的第一步&#xff0c;对任何项目都非常重要。如果在此步骤中出现任何问题&#xff0c;则在整个软件测试过程中都会扩大影响。如果测试人员在创建测试用例模板时使用正确的过程和准则&#xff0c;则可以避免这种情况。 在本文中将分享一些简单而…

synchronized底层原理(一)

文章目录 1. 问题引入2. 相关概念3. Synchronized使用4. Synchronized底层原理1. 简介2. Monitor&#xff08;管程/监视器&#xff09;3. Java语言的内置管程synchronized4. Java对象的内存布局5. 如何使用MarkWord记录锁状态6. 偏向锁7. 轻量级锁 1. 问题引入 假设我们有1000…

PostGIS学习教程十:空间索引

PostGIS学习教程十&#xff1a;空间索引 回想一下&#xff0c;空间索引是空间数据库的三个关键特性之一。空间索引使得使用空间数据库存储大型数据集成为可能。在没有空间索引的情况下&#xff0c;对要素的任何搜索都需要对数据库中的每条记录进行"顺序扫描"。索引通…

在gitlab中使用gitlab-sshd替换ssh服务

参考&#xff1a;https://docs.gitlab.com/ee/administration/operations/gitlab_sshd.html 说明 gitlab-sshd 是 OpenSSH 的轻量级替代品&#xff0c;用于提供 SSH 操作。虽然 OpenSSH 使用受限的 shell 方法&#xff0c;但 gitlab-sshd 的行为更像是一个现代的多线程服务器应…

mabatis基于xml方式和注解方式实现多表查询

前面步骤 http://t.csdnimg.cn/IPXMY 1、解释 在数据库中&#xff0c;单表的操作是最简单的&#xff0c;但是在实际业务中最少也有十几张表&#xff0c;并且表与表之间常常相互间联系&#xff1b; 一对一、一对多、多对多是表与表之间的常见的关系。 一对一&#xff1a;一张…

【Maven】安装和使用

1. Maven 概述 Maven 是一款用于管理和构建 java 项目的工具&#xff0c;可以进行依赖管理、统一项目结构和项目构建。 1.1 Maven 模型 项目对象模型 (Project Object Model)依赖管理模型(Dependency)构建生命周期/阶段(Build lifecycle & phases) 1.2 Maven 仓库 仓库的…

ROS2教程01 ROS2介绍

ROS2介绍 版权信息 Copyright 2023 Herman YeAuromix. All rights reserved.This course and all of its associated content, including but not limited to text, images, videos, and any other materials, are protected by copyright law. The author holds all right…

高级搜索——伸展树Splay详解

文章目录 伸展树Splay伸展树Splay的定义局部性原理Splay的伸展操作逐层伸展双层伸展zig-zig/zag-zagzig-zag/zag-zigzig/zag双层伸展的效果与效率 伸展树的实现动态版本实现递增分配器节点定义Splay类及其接口定义伸展操作左单旋右单旋右左/左右双旋伸展 查找操作删除操作插入操…

【laBVIEW学习】4.声音播放,自定义图标,滚动条设置,保存参数以及恢复参数

一。声音播放&#xff08;报错&#xff0c;未实现&#xff09; 1.报错4810 2.解决方法&#xff1a; 暂时未解决。 二。图片修改 1.目标&#xff1a;灯泡---》自定义灯泡 2.步骤&#xff1a; 1.右键点击--》自定义运行 表示可以制作自定义类型 2.右键--》打开自定义类型 这样就…

【Qt开发流程】之对象模型2:属性系统

描述 Qt提供了一个复杂的属性系统&#xff0c;类似于一些编译器供应商提供的属性系统。然而&#xff0c;作为一个独立于编译器和平台的库&#xff0c;Qt不依赖于非标准的编译器特性&#xff0c;如__property或[property]。 Qt解决方案适用于Qt支持的所有平台上的任何标准c编译…

java的GUI基础使用

java.awt包提供了基本的GUI设计工具&#xff0c;主要包括组件&#xff08;Component&#xff09;、容器&#xff08;Container&#xff09;和布局管理器&#xff08;LayoutManager&#xff09;&#xff1b; Java的图形用户界面的最基本组成部分是组件&#xff08;Component&…

Hadoop学习笔记(HDP)-Part.11 安装Kerberos

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …

数据库事务:保障数据一致性的基石

目录 1. 什么是数据库事务&#xff1f; 1.1 ACID特性解析 2. 事务的实现与控制 2.1 事务的开始和结束 2.2 事务的隔离级别 3. 并发控制与事务管理 3.1 并发控制的挑战 3.2 锁和并发控制算法 4. 最佳实践与性能优化 4.1 事务的划分 4.2 批处理操作 5. 事务的未来发展…

odoo自定义提示性校验

背景: 在odoo16的原生的代码里&#xff0c;可以给按钮添加一个 confirm属性&#xff0c;从而达到 提示性校验的效果。 问题&#xff1a; 这个属性加了之后一定会弹出提示性校验的对话框&#xff0c;于是如何根据我们的实际业务&#xff0c;从后端返回提示性信息&#xff0c;…

极致体验云上无缝协作

探索SOLIDWORKS云上之旅 谁适合应用3DEXPERIENCE云平台? 迈向云策略的数字化转型企业、加速新品上市的企业创新部门、资源有限的小微及初创企业 什么是3DEXPERIENCE云平台? 3DEXPERIENCE(3DX)是一种业务与创新平台,可让所有组织整体实时了解业务活动和生态系统&#xff0c…