用PointNet分类3D点云

news2024/11/22 3:14:03

在本教程中,我们将学习如何训练PointNet进行分类。 我们将主要关注数据和训练过程; 展示如何从头开始编码 Point Net 的教程位于此处。 本教程的代码位于这个Github库中,我们将使用的笔记本位于这个Github库中。 一些代码的灵感来自于这个Github库。

在这里插入图片描述

推荐:用 NSDT设计器 快速搭建可编程3D场景。

1、获取数据

我们将使用只有 16 个类的较小版本的 shapenet 数据集。 如果你使用的是Colab,可以运行以下代码来获取数据。 警告,这将需要很长时间。

!wget -nv https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate
!unzip shapenetcore_partanno_segmentation_benchmark_v0.zip
!rm shapenetcore_partanno_segmentation_benchmark_v0.zip

如果你想在本地运行,请访问上面第一行的链接,数据将自动下载为 zip 文件。

该数据集包含 16 个带有类标识符的文件夹(自述文件中称为“synsetoffset”)。 文件夹结构为:

synsetoffset
  |- points                  # 来自 ShapeNetCore 模型的均匀采样点
  |- point_labels            # 每点分割标签
  |- seg_img                 #标签的可视化
train_test_split:           #带有训练/验证/测试拆分的 JSON 文件

自定义 PyTorch 数据集位于此处,解释代码超出了本教程的范围。 需要了解的重要一点是,数据集可以获取 (point_cloud, class) 或 (point_cloud, seg_labels)。 在训练和验证期间,我们向点云添加高斯噪声,并围绕垂直轴(本例中为 y 轴)随机旋转它们。 我们还对点云进行最小-最大归一化,以便它们的范围为 0-1。 我们可以像这样创建 shapenet 数据集的实例:

from shapenet_dataset import ShapenetDataset

# __getitem__ returns (point_cloud, class)
train_dataset = ShapenetDataset(ROOT, npoints=2500, split='train', classification=True)

2、探索数据

在开始任何训练之前,让我们先探讨一些训练数据。 为此,我们将使用 Open3d 版本 0.16.0(必须为 0.16.0 或更高版本)。

!pip install open3d==0.16.0

我们现在可以使用以下代码查看示例点云。 你应该注意到,每次运行代码时点云都会以不同的方向显示。

import open3d as o3
from shapenet_dataset import ShapenetDataset

sample_dataset = train_dataset = ShapenetDataset(ROOT, npoints=20000, split='train', 
                                                 classification=False, normalize=False)

points, seg = sample_dataset[4000]

pcd = o3.geometry.PointCloud()
pcd.points = o3.utility.Vector3dVector(points)
pcd.colors = o3.utility.Vector3dVector(read_pointnet_colors(seg.numpy()))

o3.visualization.draw_plotly([pcd])

在这里插入图片描述

图 1.随机旋转的噪声点云。 Y 轴是纵轴

你可能不会注意到噪音有太大差异,因为我们添加的量很小; 我们添加了少量,因为不想极大地破坏结构,但这一小量足以对模型产生影响。 现在我们就来看看训练分类的数据频率。
在这里插入图片描述

图 2. 训练分类数据点直方图

从图2中我们可以看出,这绝对不是一个平衡的训练集。 因此,我们可能想要应用类别权重,甚至使用焦点损失来帮助我们的模型学习。

3、PointNet损失函数

当训练PointNet进行分类时,我们可以使用 PyTorch 中的标准交叉熵损失,但我们还想添加包括论文中提到的正则化项。

正则化项强制特征变换矩阵正交,但为什么呢? 特征变换矩阵旨在旋转(变换)点云的高维表示。 我们如何确定这种学习的高维旋转实际上是在旋转点云? 为了回答这个问题,让我们考虑一些所需的旋转属性。

我们希望学习到的旋转是仿射的,这意味着它保留结构。 我们希望确保它不会做一些奇怪的事情,例如将其映射回较低维度的空间或弄乱结构。 我们不能只绘制 nx64 点云来检查这一点,但我们可以通过鼓励旋转正交来让模型学习有效的旋转。 这是因为正交矩阵同时保留长度和角度,而旋转矩阵是一种特殊类型的正交矩阵 。 我们可以“鼓励”模型通过使用以下项进行正则化来学习正交旋转矩阵:

在这里插入图片描述

图 3.PointNet正则化项

我们利用正交矩阵的一个基本属性,即它们的列和行是正交向量。 对于完全正交的矩阵,图 3 中的正则化项将等于 0。

在训练期间,我们只需将此项添加到我们的损失中。 如果你已经完成了之前关于如何编码PointNet的教程,可能还记得特征转换矩阵 A 由分类头返回。

现在让我们编写PointNet损失函数的代码。 我们已经添加了加权(平衡)交叉熵损失和焦点损失的术语,但解释它们超出了本教程的范围。 其代码位于此处。 该代码改编自这个Github库。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetLoss(nn.Module):
    def __init__(self, alpha=None, gamma=0, reg_weight=0, size_average=True):
        super(PointNetLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reg_weight = reg_weight
        self.size_average = size_average

        # sanitize inputs
        if isinstance(alpha,(float, int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,(list, np.ndarray)): self.alpha = torch.Tensor(alpha)

        # get Balanced Cross Entropy Loss
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=self.alpha)

    def forward(self, predictions, targets, A):

        # get batch size
        bs = predictions.size(0)

        # get Balanced Cross Entropy Loss
        ce_loss = self.cross_entropy_loss(predictions, targets)

        # reformat predictions and targets (segmentation only)
        if len(predictions.shape) > 2:
            predictions = predictions.transpose(1, 2) # (b, c, n) -> (b, n, c)
            predictions = predictions.contiguous() \
                                     .view(-1, predictions.size(2)) # (b, n, c) -> (b*n, c)

        # get predicted class probabilities for the true class
        pn = F.softmax(predictions)
        pn = pn.gather(1, targets.view(-1, 1)).view(-1)

        # get regularization term
        if self.reg_weight > 0:
            I = torch.eye(64).unsqueeze(0).repeat(A.shape[0], 1, 1) # .to(device)
            if A.is_cuda: I = I.cuda()
            reg = torch.linalg.norm(I - torch.bmm(A, A.transpose(2, 1)))
            reg = self.reg_weight*reg/bs
        else:
            reg = 0

        # compute loss (negative sign is included in ce_loss)
        loss = ((1 - pn)**self.gamma * ce_loss)
        if self.size_average: return loss.mean() + reg
        else: return loss.sum() + reg

4、训练PointNet用于分类

现在我们已经了解了数据和损失函数,我们可以继续进行训练。

对于我们的训练需要量化模型的表现。 通常我们会考虑损失和准确性,但对于这个分类问题,我们需要一个衡量错误分类和正确分类的指标。 想想典型的混淆矩阵:真阳性、假阴性、真阴性和假阳性; 我们想要一个在所有这些方面都表现良好的分类器。

马修斯相关系数 (MCC) 量化了我们的模型在所有这些指标上的表现,并且被认为是比准确性或 F1 分数更可靠的单一性能指标。 MCC 的范围从 -1 到 1,其中 -1 是最差的性能,1 是最好的性能,0 是随机猜测。 我们可以通过 torchmetrics 将 MCC 与 PyTorch 结合使用。

from torchmetrics.classification import MulticlassMatthewsCorrCoef

mcc_metric = MulticlassMatthewsCorrCoef(num_classes=NUM_CLASSES).to(DEVICE)

训练过程是一个基本的 PyTorch 训练循环,在训练和验证之间交替。

我们使用 Adam 优化器和我们的点净损失函数以及上面图 3 中描述的正则化项。对于点净损失函数,我们选择设置 alpha,它对每个样本的重要性进行加权。

我们还设置了 gamma 来调节损失函数并迫使其专注于困难示例,其中困难示例是那些以较低概率分类的示例。 有关更多详细信息,请参阅笔记本中的注释。 人们注意到,使用循环学习率时模型训练得更好,因此我们在这里实现了它。

import torch.optim as optim
from point_net_loss import PointNetLoss

EPOCHS = 50
LR = 0.0001
REG_WEIGHT = 0.001 

# manually downweight the high frequency classes
alpha = np.ones(NUM_CLASSES)
alpha[0] = 0.5  # airplane
alpha[4] = 0.5  # chair
alpha[-1] = 0.5 # table

gamma = 1

optimizer = optim.Adam(classifier.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0001, max_lr=0.01, 
                                              step_size_up=2000, cycle_momentum=False)
criterion = PointNetLoss(alpha=alpha, gamma=gamma, reg_weight=REG_WEIGHT).to(DEVICE)

classifier = classifier.to(DEVICE)

请按照笔记本进行训练循环,并确保你有 GPU。 如果没有,请删除调度程序并将学习率设置为 0.01,几个 epoch 后你应该会得到足够好的结果。 如果遇到任何 PyTorch 用户警告(由于 nn.MaxPool1D 的未来更新),可以通过以下方式抑制它们:

import warnings
warnings.filterwarnings("ignore")

5、训练结果

在这里插入图片描述

我们可以看到,训练和验证的准确率都上升了,但 MCC 仅在训练时上升,而在验证时却没有上升。 这可能是由于验证和测试分组中某些类的样本量非常小造成的; 因此在这种情况下,MCC 可能不是用于验证和测试的最佳单一指标。 这需要更多的调查来确定 MCC 何时是一个好的指标; 即多少不平衡对于 MCC 来说是过多? 每个类别需要多少样本才能使 MCC 有效?

我们来看看测试结果:

在这里插入图片描述

我们看到测试准确度约为 85%,但 MCC 略高于 0。由于我们只有 16 个类,让我们查看笔记本中的混淆矩阵,以更深入地了解测试结果。

在这里插入图片描述

图 6. 测试数据混淆矩阵。 资料来源:作者。

大多数情况下,分类是可以的,但也有一些不太常见的类别,例如“火箭”或“滑板”。 该模型在这些类别上的预测性能往往较差,而在这些不太常见的类别上的性能是导致 MCC 下降的原因。

另一件需要注意的事情是,当你检查结果(如笔记本中所示)时,将在更频繁的分类中获得良好的准确性和自信的表现。 然而,在频率较低的课程中,你会发现置信度较低且准确性较差。

6、检查关键集

现在我们将研究本教程中最有趣的部分,即关键集。 关键集是点云集的基本基础点。 这些点定义了它的基本结构。 这里有一些代码展示了如何可视化它们。

from open3d.web_visualizer import draw 


critical_points = points[crit_idxs.squeeze(), :]
critical_point_colors = read_pointnet_colors(seg.numpy())[crit_idxs.cpu().squeeze(), :]

pcd = o3.geometry.PointCloud()
pcd.points = o3.utility.Vector3dVector(critical_points)
pcd.colors = o3.utility.Vector3dVector(critical_point_colors)

# o3.visualization.draw_plotly([pcd])
draw(pcd, point_size=5) # does not work in Colab

这里有一些可视化,请注意,我使用“draw()”来获得更大的点大小,但它在 Colab 中不起作用。

在这里插入图片描述

图 7.点云集及其由PointNet学习的相应关键集

我们可以看到,关键集展示了其对应点云的整体结构,它们本质上是稀疏采样的点云。 这表明训练后的模型实际上已经学会了区分差异结构,并表明它实际上能够根据每个点云类别的区别结构对其进行分类。

7、结束语

我们学习了如何从头开始训练PointNet以及如何可视化点集。 如果你真的感兴趣,请尝试提高整体分类性能。 以下是一些帮助你入门的建议:

  • 使用不同的损失函数
  • 在循环学习率调度程序中尝试不同的设置
  • 尝试对PointNet架构进行修改
  • 尝试不同的数据增强
  • 使用更多数据 → 尝试完整的 shapenet 数据集

原文链接:PointNet分类3D点云 — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/849372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机网络笔记】第二章物理层

物理层 1、四大特性 ①机械特性:接口是怎样的(接口所用接线器的形状和尺寸,引脚数目和排列,固定和锁定装置等)。 ②电气特性:用多少的电 ③功能特性:线路上电平电压的特性 ④过程特性&#xf…

高通GPIO寄存器值参数意义和设置(深度理解)

目录 1、GPIO寄存器参数值及含义 2、读写寄存器地址 3、施密特触发器引起的滞后 4、高通设备树pinctrl的定义及配置 5、高通设备树GPIO的msmxxxx-pinctrl.dtsi结构定义 本文主要介绍基于高通的gpio配置,其中最少见的就是配置gpio的滞后效应引起的电压偏移对实际使用…

基于Helm快速部署私有云盘NextCloud

1. 添加源 helm repo add nextcloud https://nextcloud.github.io/helm/2. 编写values.yaml 为了解决通过不被信任的域名访问。请联系您的管理员。如果您就是管理员,请参照 config.sample.php 中的示例编辑 config/config.php 中的“trusted_domains”设置。 nex…

Vue中使用uuid生成(脚手架创建自带的)

1.utils 说明:一般封装工具函数。 // 单例模式 import { v4 as uuidv4 } from uuid; // 要生成一个随机的字符串,且每次执行不能发生变化 // 游客身份还要持久存储 function getUUID(){// 先从本地获取uuid,本地存储里面是否有let uuid_tok…

【OpenCV常用函数:轮廓检测+外接矩形检测】cv2.findContours()+cv2.boundingRect()

文章目录 1、cv2.findContours()2、cv2.boundingRect() 1、cv2.findContours() 对具有黑色背景的二值图像寻找白色区域的轮廓,因此一般都会先经过cvtColor()灰度化和threshold()二值化后的图像作为输入。 cv2.findContous(image, mode, method[, contours[, hiera…

Chapter 13: Network Programming | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介Network ProgrammingNetworked programsHypertext Transfer Protocol - HTTPThe world’s simplest web browserRetrieving an image over HTTPRetrieving web pages with urllibReading binary files using urllibParsing HTML and scra…

【DP+矩阵加速】CF691 E

Problem - 691E - Codeforces 题意&#xff1a; 思路&#xff1a; 有人只会暴力DP忘记矩阵快速幂怎么写了 Code&#xff1a; #include <bits/stdc.h>#define int long longusing i64 long long;using namespace std;const int N 1e2 10; const int mod 1e9 7;int…

【Transformer】自注意力机制Self-Attention

1. Transformer 由来 & 特点 1.1 从NLP领域内诞生 "Transformer"是一种深度学习模型&#xff0c;首次在"Attention is All You Need"这篇论文中被提出&#xff0c;已经成为自然语言处理&#xff08;NLP&#xff09;领域的重要基石。这是因为Transfor…

Oracle单实例升级补丁

目录 1.当前DB环境2.下载补丁包和opatch的升级包3.检查OPatch的版本4.检查补丁是否冲突5.关闭数据库实例&#xff0c;关闭监听6.应用patch7.加载变化的SQL到数据库8.ORACLE升级补丁查询 oracle19.3升级补丁到19.18 1.当前DB环境 [oraclelocalhost ~]$ cat /etc/redhat-releas…

[LeetCode - Python]69. x 的平方根(Easy);367. 有效的完全平方数(Easy)

1.题目&#xff1a; 69. x 的平方根(Easy) 1.代码&#xff1a; class Solution:def mySqrt(self, x: int) -> int:# 思路&#xff1a;二分法&#xff0c;左闭右开# 额外添加1&#xff1a;判断0&#xff0c;1是否符合&#xff1b;if x 0 or x 1 :return xleft , right ,…

基于微信小程序的传染病酒店隔离平台设计与实现(Java+spring boot+MySQL+微信小程序)

获取源码或者论文请私信博主 演示视频&#xff1a; 基于微信小程序的传染病酒店隔离平台设计与实现&#xff08;Javaspring bootMySQL微信小程序&#xff09; 使用技术&#xff1a; 前端&#xff1a;html css javascript jQuery ajax thymeleaf 微信小程序 后端&#xff1a;…

使用线性回归预测票房收入 -- 机器学习项目基础篇(10)

当一部电影被制作时&#xff0c;导演当然希望最大化他/她的电影的收入。但是我们能通过它的类型或预算信息来预测一部电影的收入会是多少吗&#xff1f;这正是我们将在本文中学习的内容&#xff0c;我们将学习如何实现一种机器学习算法&#xff0c;该算法可以通过使用电影的类型…

# ⛳ Docker 安装、配置和详细使用教程-Win10专业版

目录 ⛳ Docker 安装、配置和详细使用教程-Win10专业版&#x1f69c; 一、win10 系统配置&#x1f3a8; 二、Docker下载和安装&#x1f3ed; 三、Docker配置&#x1f389; 四、Docker入门使用 ⛳ Docker 安装、配置和详细使用教程-Win10专业版 &#x1f69c; 一、win10 系统配…

20230808在WIN10下使用python3将TXT文件转换为DOCX

20230808在WIN10下使用python3将TXT文件转换为DOCX 2023/8/8 19:30 缘起&#xff0c;由于google的文档翻译不支持SRT/TXT格式的字幕&#xff0c;因此需要将SRT格式的字幕转为DOCX。 Ch4.Unreported.World.2022.Mexicos.Psychedelic.Toads.1080p.HDTV.x265.AAC.MVGroup.org.mkv …

FK-坦克大战制作(一)菜单制作

1、Cocos Creator新建2d项目 2.在资源管理器中新建场景menu 新建scences文件夹》新建场景》改名为menu 3.在层级管理器的Canvas下新建Layout节点&#xff0c;并在此节点下新建Label标签 4.双击Label&#xff0c;在属性检查器中进行编辑 5. 添加动画&#xff1a;(对文本进行放大…

代码随想录算法训练营day57

文章目录 Day57回文子串题目思路代码 最长回文子序列题目思路代码 Day57 回文子串 647. 回文子串 - 力扣&#xff08;LeetCode&#xff09; 题目 给你一个字符串 s &#xff0c;请你统计并返回这个字符串中 回文子串 的数目。 回文字符串 是正着读和倒过来读一样的字符串。…

JavaWeb学习|JSP相关内容

1.什么是JSP Java Server Pages: Java服务器端页面&#xff0c;也和Servlet一样&#xff0c;用于动态Web技术! 最大的特点: 。写JSP就像在写HTML 。区别: 。HTML只给用户提供静态的数据 。JSP页面中可以嵌入JAVA代码&#xff0c;为用户提供动态数据 JSP最终也会被转换成为一…

使用Python和wxPython将图片转换为草图

导语: 将照片转换为艺术风格的草图是一种有趣的方式&#xff0c;可以为您的图像添加独特的效果。在本文中&#xff0c;我们将介绍如何使用Python编程语言和wxPython图形用户界面库来实现这一目标。我们将探讨如何使用OpenCV库将图像转换为草图&#xff0c;并使用wxPython创建一…

科研热点|5本Scopus期刊不再被收录,Scopus期刊目录更新(附下载)!

此次Scopus期刊目录更新后&#xff0c;有5本期刊不再被收录&#xff08;Discontinued titles July 2023&#xff09;&#xff0c;同上次更新时相比&#xff0c;此次又新增139本期刊(Accepted titles)进入Scopus数据库。目前Scopus 来源出版物列表&#xff08;Scopus Sources&am…

[Java]JDK新特性

目录 一、JDK新特性 1.1Java Record 1.1.1Record的使用 1.1.2Instance Methods 1.1.3静态方法 Static Method 1.1.4Record构造方法 1.1.5Record与Lombok 1.1.6Record实现接口 1.1.7Local Record 1.1.8嵌套Record 1.1.9instanceof判断Record类型 1.1.10总结 1.2Swit…