5.基于图神经网络的点云分类

news2024/11/27 8:36:25

目录

    • 一、数据处理
    • 二、点云生成
    • 三、PointNet++
      • 阶段1:通过动态图生成进行分组
      • 阶段2:邻居聚合
    • 四、网络架构
    • 五、训练程序

       在本教程中,您将学习使用图神经网络进行点云分类的基本工具。在这里,我们得到了一个对象或点集的数据集,我们希望以这样一种方式嵌入这些对象,即在手头有任务的情况下,它们是线性可分离的。具体而言,原始点云被用作神经网络的输入,并将学习捕捉有意义的局部结构,以便对整个点集进行分类。

       让我们来看看PyTorch Geometric提供的一个简单的数据集,GeometricShapes 数据集

一、数据处理

       GeometricShapes数据集包含40种不同的二维和三维几何形状,如立方体、球体和金字塔。每种形状都有两个不同的版本,一个用于训练神经网络,另一个用于评估其性能。

%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def visualize_mesh(pos, face):
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d') # 创建一个带有3D投影的AxesSubplot对象
    ax.axes.xaxis.set_ticklabels([]) # 隐藏3D坐标轴刻度
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([])
    ax.plot_trisurf(pos[:, 0], pos[:, 1], pos[:, 2], triangles=data.face.t(), antialiased=False)
    plt.show()
from torch_geometric.datasets import GeometricShapes

dataset = GeometricShapes(root='data/GeometricShapes')
print(dataset)

data = dataset[0]
print(data)
visualize_mesh(data.pos, data.face)

data = dataset[4]
print(data)
visualize_mesh(data.pos, data.face)

在这里插入图片描述
       我们可以通过PyTorch Geometric轻松导入和实例化GeometricShapes数据集,并打印出一些信息,例如数据集的描述或关于单个示例中存在的属性的一些信息。特别地,每个对象被表示为网格,包含关于pos中的顶点和面中顶点的三角形连通性的信息(具有shape[3,num_faces])。

二、点云生成

       由于我们对点云分类感兴趣,我们可以通过使用“transforms”将网格变换为点。
       在这里,PyTorch Geometric提供了torch_geometric.transforms.SamplePoints变换,该变换将根据网格面的面积对网格面上固定数量的点进行均匀采样。
       我们可以通过dataset.transform = SamplePoints(num=...)将此转换添加到数据集中。每次从数据集中访问示例时,都会调用转换过程:

def visualize_points(pos, edge_index=None, index=None):
    fig = plt.figure(figsize=(4, 4))
    if edge_index is not None:
        for (src, dst) in edge_index.t().tolist():
            src = pos[src].tolist()
            dst = pos[dst].tolist()
            plt.plot([src[0], dst[0]], [src[1], dst[1]], linewidth=1, color='black')
    if index is None:
        plt.scatter(pos[:, 0], pos[:, 1], s=50, zorder=1000)
    else:
        mask = torch.zeros(pos.size(0), dtype=torch.bool)
        mask[index] = True
        plt.scatter(pos[~mask, 0], pos[~mask, 1], s=50, color='lightgray', zorder=1000)
        plt.scatter(pos[mask, 0], pos[mask, 1], s=50, zorder=1000)
    plt.axis('off')
    plt.show()
import torch
from torch_geometric.transforms import SamplePoints

torch.manual_seed(42)
dataset.transform = SamplePoints(num=256)

data = dataset[0]
print(data)
visualize_points(data.pos, data.edge_index)

data = dataset[4]
print(data)
visualize_points(data.pos)

在这里插入图片描述

三、PointNet++

       由于我们现在已经准备好使用点云数据集,让我们看看如何通过图神经网络和 PyTorch Geometric库的帮助来处理它。
在这里,我们将重新实现PointNet++架构,这是通过图神经网络进行点云分类/分割的开创性工作。

       PointNet++通过遵循简单的分组、邻域聚合和下采样方案来迭代处理点云:

  1. 分组阶段构建一个图,其中连接了附近的点。通常,这是通过𝑘-最近邻居搜索或通过球查询(将半径内的所有点连接到查询点)。
  2. 邻域聚合阶段执行图形神经网络层,该层为每个点聚合来自其直接邻域的信息(由前一阶段构建的图给出)。这允许PointNet++以不同的规模捕获局部信息。
  3. 下采样阶段实现了适用于具有潜在不同大小的点云的池化方案。我们将暂时忽略这一阶段,稍后再回到这一阶段。

在这里插入图片描述

阶段1:通过动态图生成进行分组

       PyTorch Geometric通过其辅助程序包torch_cluster提供用于动态图形生成的实用程序,特别是通过𝑘-最近邻和球查询生成图。

from torch_cluster import knn_graph

data = dataset[0]
data.edge_index = knn_graph(data.pos, k=6)
print(data.edge_index.shape)
visualize_points(data.pos, edge_index=data.edge_index)

data = dataset[4]
data.edge_index = knn_graph(data.pos, k=6)
print(data.edge_index.shape)
visualize_points(data.pos, edge_index=data.edge_index)

在这里插入图片描述
       在这里,我们从torch_cluster导入knn_graph函数,并通过传入输入点pos和最近邻居k的数量来调用它。作为输出,我们将接收shape[2,num_edges]edge_index张量,该张量将保存每列中源和目标节点索引的信息(称为 the sparse matrix COO format)。

阶段2:邻居聚合

PointNet++层遵循一个简单的神经消息传递方案,该方案通过:

h i ( ℓ + 1 ) = max ⁡ j ∈ N ( i ) MLP ( h j ( ℓ ) , p j − p i ) \mathbf{h}^{(\ell + 1)}_i = \max_{j \in \mathcal{N}(i)} \textrm{MLP} \left( \mathbf{h}_j^{(\ell)}, \mathbf{p}_j - \mathbf{p}_i \right) hi(+1)=jN(i)maxMLP(hj(),pjpi)

  • h i ( ℓ ) ∈ R d \mathbf{h}_i^{(\ell)} \in \mathbb{R}^d hi()Rd denotes the hidden features of point i i i in layer ℓ \ell .
  • p i ∈ R 3 \mathbf{p}_i \in \mathbb{R}^3 piR3 denotes the position of point i i i.

       我们可以利用MessagePassing 接口来实现这个层。
       MessagePassing接口通过自动处理消息传播,帮助我们创建消息传递图神经网络
       在这里,我们只需要定义其message函数以及使用哪种聚合方案,例如aggr="max" (see here for the accompanying tutorial):

from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing


class PointNetLayer(MessagePassing):
    def __init__(self, in_channels, out_channels):
        # Message passing with "max" aggregation.
        super().__init__(aggr='max')
        
        # Initialization of the MLP:
        # Here, the number of input features correspond to the hidden node
        # dimensionality plus point dimensionality (=3).
        self.mlp = Sequential(Linear(in_channels + 3, out_channels),
                              ReLU(),
                              Linear(out_channels, out_channels))
        
    def forward(self, h, pos, edge_index):
        # Start propagating messages.
        return self.propagate(edge_index, h=h, pos=pos)
    
    def message(self, h_j, pos_j, pos_i):
        # h_j defines the features of neighboring nodes as shape [num_edges, in_channels]
        # pos_j defines the position of neighboring nodes as shape [num_edges, 3]
        # pos_i defines the position of central nodes as shape [num_edges, 3]

        input = pos_j - pos_i  # Compute spatial relation.

        if h_j is not None:
            # In the first layer, we may not have any hidden node features,
            # so we only combine them in case they are present.
            input = torch.cat([h_j, input], dim=-1)

        return self.mlp(input)  # Apply our final MLP.

       可以看出,在PyTorch Geometric中实现PointNet++层非常简单。

       在 __init__ 函数中,我们首先定义我们想要应用 max aggregation,然后初始化MLP,该MLP负责将相邻节点特征以及源节点和目标节点之间的空间关系转换为(可训练的)消息。

       在 forward 函数中,我们可以开始基于edge_index传播消息,传入创建消息所需的所有内容。

       在message 函数中,我们现在可以分别通过*_j*_i访问相邻节点和中心节点信息,并为每个连接返回一条消息。

四、网络架构

       我们可以使用 knn_graphPointNetLayer 来定义我们的网络架构。
       在这里,我们感兴趣的是一种能够以 mini-batch fashion在点云上操作的架构。

       PyTorch Geometric通过创建稀疏块对角邻接矩阵(由 edge_index定义)和节点维度上的串联特征矩阵(如 pos),在小批量上实现并行化。

       为了区分小批量中的实例,存在一个名为 batch 的特殊向量,(shape [num_nodes]),其将每个节点映射到该批中的其各自的图:
batch = [ 0 ⋯ 0 , 1 ⋯ n − 2 n − 1 ⋯ n − 1 ] ⊤ \textrm{batch} = {[ 0 \cdots 0, 1 \cdots n-2 n-1 \cdots n - 1 ]}^{\top} batch=[00,1n2n1n1]

       我们需要使用这个batch向量来生成 knn_graph ,因为我们不想连接来自不同示例的节点。

       这样,我们的整体PointNe架构看起来如下:

import torch
import torch.nn.functional as F
from torch_cluster import knn_graph
from torch_geometric.nn import global_max_pool


class PointNet(torch.nn.Module):
    def __init__(self):
        super().__init__()

        torch.manual_seed(12345)
        self.conv1 = PointNetLayer(3, 32)
        self.conv2 = PointNetLayer(32, 32)
        self.classifier = Linear(32, dataset.num_classes)
        
    def forward(self, pos, batch):
        # Compute the kNN graph:
        # Here, we need to pass the batch vector to the function call in order
        # to prevent creating edges between points of different examples.
        # We also add `loop=True` which will add self-loops to the graph in
        # order to preserve central point information.
        edge_index = knn_graph(pos, k=16, batch=batch, loop=True)
        
        # 3. Start bipartite message passing.
        h = self.conv1(h=pos, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = self.conv2(h=h, pos=pos, edge_index=edge_index)
        h = h.relu()
        print(h.shape)

        # 4. Global Pooling.
        h = global_max_pool(h, batch)  # [num_examples, hidden_channels]
        print(h.shape)
        
        # 5. Classifier.
        return self.classifier(h)


model = PointNet()
print(model)

在这里插入图片描述

       在这里,我们通过继承torch.nn.Module来创建我们的网络架构,构造函数中初始化两个PointNetLayer模块和一个final linear classifier(torch.nn.Linear)。

       在forward方法中,我们首先基于节点的位置pos 动态生成一个16-nearest neighbor graph 。基于得到的图连通性,我们应用了两个基于图的卷积算子,并通过ReLU非线性对它们进行了增强。

       第一个操作获取3个输入特征(节点的位置),并将它们映射到32个输出特征。

       之后,每个点都保存关于its 2-hop neighborhood的信息,并且应该已经能够区分简单的局部形状。

       接下来,我们应用 global graph readout function,即global_max_pool,对于每个示例,其取沿着节点维度的最大值。

       最后,我们应用线性分类器将剩余的32个特征映射到40个类中的一个

五、训练程序

       我们现在准备编写两个简单的过程,分别在训练和测试数据集上训练和测试我们的模型。
       如果你不是PyTorch的新手,这个方案对你来说应该很熟悉。

import torch
from torch_geometric.transforms import SamplePoints
from torch_cluster import knn_graph
from torch_geometric.nn import global_max_pool
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import GeometricShapes
import matplotlib.pyplot as plt

dataset = GeometricShapes(root='data/GeometricShapes')

class PointNetLayer(MessagePassing): # MessagePassing:消息传播基类
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')
        self.mlp = Sequential(Linear(in_channels + 3, out_channels), ReLU(), Linear(out_channels, out_channels))

    def forward(self, h, pos, edge_index):
        return self.propagate(edge_index, h=h, pos=pos)

    def message(self, h_j, pos_j, pos_i):
        input = pos_j - pos_i
        if h_j is not None:
            input = torch.cat([h_j, input], dim=-1) # 按列拼接
        return self.mlp(input)


class PointNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = PointNetLayer(3, 32)
        self.conv2 = PointNetLayer(32, 32)
        self.classifier = Linear(32, dataset.num_classes)

    def forward(self, pos, batch):
        edge_index = knn_graph(pos, k=16, batch=batch, loop=True) # 在每个batch里,各自生成k最近邻图
        h = self.conv1(h=pos, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = self.conv2(h=h, pos=pos, edge_index=edge_index)
        h = h.relu()
        h = global_max_pool(h, batch)
        return self.classifier(h)

model = PointNet()
print(model)

# 准备数据,并进行批传入
# GeometricShapes数据集包含40种不同的2D和3D几何形状,如立方体、球体和金字塔
# 每种形状都有两个不同的版本,一个用于训练神经网络,另一个用于评估其性能
train_dataset = GeometricShapes(root='data/GeometricShapes', train=True, transform=SamplePoints(128)) # 每个样本采样128个点
test_dataset = GeometricShapes(root='data/GeometricShapes', train=False, transform=SamplePoints(128))
# 构建Dataloader
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True) # 一批为10个样本
test_loader = DataLoader(test_dataset, batch_size=10)

# 模型、优化器和损失函数
model = PointNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Adam算法
criterion = torch.nn.CrossEntropyLoss() # 交叉熵损失


def train(model, optimizer, loader):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()  # 梯度清零
        logits = model(data.pos, data.batch)  # 前向传播
        loss = criterion(logits, data.y)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 参数更新
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(train_loader.dataset) # 训练样本平均损失


def test(model, loader):
    model.eval()
    total_correct = 0
    for data in loader:
        logits = model(data.pos, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())

    return total_correct / len(loader.dataset)


loss_history = [] # 存储训练损失
test_acc_history = [] # 存储测试准确率
for epoch in range(101):
    loss = train(model, optimizer, train_loader)
    test_acc = test(model, test_loader)
    if epoch % 20 == 0:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

    loss_history.append(loss)
    test_acc_history.append(test_acc)


# 画训练损失和测试集准确率随Epoch变化图
def plot_loss_with_acc(loss_history, test_acc_history):
    Epoch_list = list(range(101))  # epoch:0-100列表
    fig, ax = plt.subplots() # 创建一个 Figure 对象和一个 Axes 对象
    ax.plot(Epoch_list, loss_history, color='blue') # loss图
    ax2 = ax.twinx() # 创建一个共享 x 轴的第二个 y 轴
    ax2.plot(Epoch_list, test_acc_history, color='red') # TestAcc图

    # 设定左边Loss轴标签和颜色
    ax.set_ylabel('Loss', color='blue')
    ax.tick_params(axis='y', labelcolor='blue')

    # 设定右边ValAcc轴标签和颜色
    ax2.set_ylabel('TestAcc', color='red')
    ax2.tick_params(axis='y', labelcolor='red')

    plt.title('Training Loss & Test Accuracy')
    plt.show()

plot_loss_with_acc(loss_history, test_acc_history) # 画图

在这里插入图片描述
在这里插入图片描述

       正如我们所看到的,即使每个类只训练一个例子,我们也能够实现大约85%的测试准确率(请注意,我们当然可以通过更长时间的训练和使用更深层次的神经网络来提高性能)。

本文内容参考:PyG官网

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/647074.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

北京大学研发基于机器学习的多能干细胞分化系统,高效、稳定制备功能性细胞

内容一览:20 世纪以来,干细胞与再生医学技术一直是国际生物医学领域的热点前沿之一。现如今,研究人员已开始探索将干细胞转变为特定类型细胞。然而,这一过程中干细胞会出现不规则生长或自发分化为不同类型细胞的情况,因…

大数据治理:数据安全

数据安全 (Data Security)一般指保护重要的、机密的纸质信息或数字信息,防止未经授权的非法访问、泄露、篡改、丢失、损坏、数据滥用等情形。数据安全涵盖的范围非常广泛,包括存储数据的硬件设备、访问数据的软件环境、访问权限控制、相关的规章制度等。…

vscode配置clangd和clang-format

vscode安装和配置 如何安装和配置vscode以搭建c开发环境,可以查看我的另一篇博客:Windows上最轻量的vscode-C开发环境搭建。 在这篇博客中,详细介绍了如何安装vscode以及应该安装哪些插件。这里不再赘述。 vscode中想使用clangd来作为语言…

Unity极坐标Shader特效,以及使用Instanced Property实现相同材质不同参数

Unity极坐标特效 先看看效果 Unity极坐标Shader特效 有时候我们需要在场景中摆放一些热点,用户点击之后出现互动,当然实现这个功能的方法有很多,作为一名程序员,当然是要用最简单的实现。用shader程序化实现它。 啥是极坐标 极坐…

鲸落送书第二期清华出版社系列丛书

1.《Node.js从基础到项目实践(视频教学版)》 《Node.js从基础到项目实践(视频教学版)》以理论结合实践的形式,讲解了Node.js 基础、框架、进阶知识和项目实践。本书为视频教学版,每一章节都有相对应的视频讲解&#xf…

番茄工作法图解——简单易行的时间管理方法

ISBN: 978-7-115-24669-1 作者:【瑞典】诺特伯格(Staffan Noteberg) 页数:136页 阅读时间:2023-06-10 推荐指数:★★★★★ 番茄工作法(意大利语:Pomodoro Technique)是一…

网工内推 | 数通专场!最高19k*13薪,HCIE/CCIE认证优先

01 嘉环科技股份有限公司 招聘岗位:数据工程师 职责描述: 1、 承担TL/TE职责,负责数通接入(路由器、交换机、安全、PTN、OLT等)相关产品的工程项目交付。 2、 作为技术负责人/交付工程师支撑项目交付,指导…

zabbix监控域名证书期限

前言 zabbix通过自定义key"domain.discovery"发现域名(Json格式),然后自动生成监控项,监控项通过自定义key"https"获取域名证书有效期,若少于30天则出发告警。 说明 名称作用domain.txt域名列表…

day06--java高级编程:多线程,枚举类,注解,反射,网络通讯

1 Day16–多线程01 1.1 程序概念 程序(program):是为完成特定任务、用某种语言编写的一组指令的集合。即指一段静态的代码,静态对象。 1.2 进程 1.2.1 概念 进程(process):是程序的一次执行过程,或是正在运行的一个程序。是一…

基于Python3.7的robotframework环境搭建步骤

Windows环境搭建 安装Python3 官网下载,我这边环境是Python 3.7.0 安装robotframework基础依赖 在dos命令输入 pip install robotframework 在线安装robotframework 在dos命令输入 pip install Pypubsub3.3.0 在线安装 Pypubsub 在dos命令输入 pip install wxPy…

汇编学习教程:寻址大总结

前言 在上篇博文中,我们主要学习了一个全新的寄存器:bp。bp 寄存器在功能和使用上与 bx 有着异曲同工之妙,只不过两人绑定的服务对象不同:bx 默认绑定的是 DS 段寄存器,而 bp 默认绑定的是 SS 段寄存器。bx 和 bp 有着…

抓包!抓包! HTTPS中间人抓包

简介 抓包是一种网络分析技术,可以用于捕获和分析数据包,通常用于网络故障排查、协议分析、安全审计等。网络上所有的数据包都是以二进制的形式在网络上传输的,抓包工具可以捕获到这些数据包并将其转换为可读的格式,方便进行分析…

Python使用阿里API进行身份证识别

Python使用阿里API进行身份证实名认证 1. 作者介绍2. 身份证识别介绍3. 调用阿里智能云API4. 代码解析4.1 完整代码4.2 实验结果 参考 1. 作者介绍 孟莉苹,女,西安工程大学电子信息学院,2021级硕士研究生,张宏伟人工智能课题组 研…

极致呈现系列之:Echarts折线图的视觉冲击力

目录 认识折线图折线图的创建折线图的美化修改折线的样式修改坐标轴的样式修改折线图上点的样式将折线设置为平滑曲线设置渐变色面积给折线图添加标记线给折线图添加标记点 折线图的交互添加鼠标悬停提示添加数据区域选择与缩放 认识折线图 折线图是一种常用的数据可视化图表&…

React中的HOC高阶组件处理

先了解函数柯里化 柯里化函数(Currying Function)是指将一个接受多个参数的函数转化成一系列只接受单个参数的函数,并且返回接受单个参数的函数,达到简化函数调用和提高可读性的目的。 简单来说,柯里化即将接收多个参…

大数据为什么如此重要?

简单来说,大数据就是结构化的传统数据再加上非结构化的新数据。那么传统数据和新数据又是什么呢?传统数据就是IT业务系统里面的数据,如客户资料、财务数据等。这些数据是结构化的,量也不是特别大,一般只是TB级。对比传…

如何让自己的代码顺利通过代码审查?

最近很多同学,都去暑期实习了,实习就意味着要在公司项目是写代码了。 大多数同学,可能面试能力不错,但是实操还是弱了一些。之前有位同学,春招靠面试能力去了大厂,然后实习刚工作的时候,要写代…

Java30天拿下-----第二天(运算符,标识符,Scanner,进制转换)

Java30天拿下-----第二天 一 运算符算术运算符赋值运算符关系运算符逻辑运算符三元运算符运算符的优先级 二 标识符关键字保留字 三 控制台接收键盘输入:Scanner四 进制进制的转换(基本功)其他进制转为十进制十进制转为其他进制二进制转为其他…

《当我谈跑步时,我谈些什么》痛楚难以避免,而磨难可以选择

《当我谈跑步时,我谈些什么》痛楚难以避免,而磨难可以选择 村上春树,日本当代小说家,情感类类型作家。主要作品有《且听风吟》《挪威的森林》《海边的卡夫卡》《奇鸟行状录》《1Q84》等。 施小炜 译 来自百度百科的一条&#xff1…

存储快速入门——【2】数据复制与容灾、云存储、大数据概念

存储快速入门——【2】数据复制与容灾、云存储、大数据概念 一、数据复制与容灾 1 恢复时间目标(RTO)和恢复点目标(RPO) 对于信息系统而言,容灾就是使信息系统具有应对一定的灾难袭击,保持系统或间断运行…