图神经网络:(处理点云)PointNet++的实现

news2024/11/19 0:31:14

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook和有关文献。超链。提取码8848。

文章目录

    • 简单前置工作学习
    • 文献阅读
    • PointNet++的实现
    • 模型问题

简单前置工作学习

工作目标:根据点云去进行40分类。
工作流程:1.读取PyG内置的几何图形数据。2.随机但是均匀采样。3.K最邻近算法构边建图。4.使用PointNet++进行图分类。
导库,下载数据,导库,定义函数

from torch_geometric.datasets import GeometricShapes
dataset=GeometricShapes(root='/Data/GeometricShapes')
import matplotlib.pyplot as plt
def visualize_mesh(pos,face):
    fig=plt.figure()
    ax=fig.add_subplot(111,projection='3d')
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([])
    ax.plot_trisurf(pos[:,0],pos[:,1],pos[:,2],triangles=face.t(),antialiased=False)
    plt.show()

PS1:这段代码会在C盘生成一个DATA的文件并将数据集放在DATA中,有强迫症注意一下。
PS2:就是几何图形网格。细节可以点击这里。
打印信息与可视化

print(dataset)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述
导库以及定义函数

from torch_geometric.transforms import SamplePoints
import torch
def visualize_points(pos,edge_index=None,index=None):
    fig=plt.figure(figsize=(4, 4))
    if edge_index is not None:
        for (src,dst) in edge_index.t().tolist():
            src=pos[src].tolist()
            dst=pos[dst].tolist()
            plt.plot([src[0],dst[0]],[src[1],dst[1]],linewidth=1,color='black')
    if index is None:
        plt.scatter(pos[:,0],pos[:,1],s=50,zorder=1000)
    else:
       mask=torch.zeros(pos.size(0),dtype=torch.bool)
       mask[index]=True
       plt.scatter(pos[~mask,0],pos[~mask,1],s=50,color='lightgray',zorder=1000)
       plt.scatter(pos[mask,0],pos[mask,1],s=50,zorder=1000)
    plt.axis('off')
    plt.show()

从图形表面均匀地采样,打印信息与可视化

dataset.transform=SamplePoints(num=256)
data=dataset[0]
print(data)
visualize_points(data.pos)
data=dataset[4]
print(data)
visualize_points(data.pos)

jupyter notebook内输出如下
在这里插入图片描述

文献阅读

参考文献: PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space

文章概述: “Deep learning on point sets for 3d classification and segmentation”是参考文献之前前沿工作,核心思想对每个点空间编码然后聚合所有单点要素到全局的空间。显然这样无法捕捉局部特征。受到卷积神经网络启发,这里参考文献便就来了。具体步骤: 第一步:进行局部划分;第二步:组合局部特征;第三步:加工局部特征,重复上述过程直到点云所有特征都被利用。所以面临三个问题。第一问:如何进行局部划分。第二问:如何组合局部特征。第三问:如何加工局部特征。解决第一问:Farthest Point Sampling,FPS。解决第二问:Ball Query。解决第三问:上面那篇文章的Point

分层的点云学习器: Sampling layer: Farthest Point Sampling,FPS。 可以使用K最近邻算法但是不好。固定一个区域更加有普适性。PS:注意一下KNN与Ball Query的区别。Grouping layer: 输入: N × ( d + C ) N \times (d+C) N×(d+C) 以及 N ′ × d N'\times d N×d 输出: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C)。符号说明: N N N是点的数量, d d d是质心坐标, C C C是点的特征维数, N ′ N' N是质心数量, K K K是邻域内点数量。Point Net layer: 输入: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C) 输出: N ′ × ( d + C ′ ) N' \times (d + C') N×(d+C) 。这个模型鲁棒性强,对于不均匀的数据效果同样。这个图挺好的。
在这里插入图片描述
PS1:原文还有其他很好工作,有兴趣有时间建议去看,但是我们这里跳过。
PS2:对于上面前置工作,由于采用是均匀的,可以这样建图。如下:
导库

from torch_cluster import knn_graph

打印信息与可视化

data=dataset[0]
data.edge_index=knn_graph(data.pos,k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)
data=dataset[4]
data.edge_index=knn_graph(data.pos, k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)

jupyter notebook内输出如下
在这里插入图片描述

PointNet++的实现

我们使用数学公式首先进行EdgeConv的描述: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , h j ( l − 1 ) − h i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},h_j^{(l-1)}-h_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),hj(l1)hi(l1))。Point++类似于这个公式: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , p j ( l − 1 ) − p i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},p_j^{(l-1)}-p_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),pj(l1)pi(l1))
搭建多层的PointNel++

from torch_geometric.nn import MessagePassing
from torch.nn import Sequential,Linear,ReLU

class PointNetLayer(MessagePassing):

    def __init__(self,in_channels,out_channels):
        super().__init__(aggr='max')
        self.mlp=Sequential(Linear(in_channels+3,out_channels),ReLU(),Linear(out_channels,out_channels))
        
    def forward(self,h,pos,edge_index):
        return self.propagate(edge_index,h=h,pos=pos)
    
    def message(self,h_j,pos_j,pos_i):
        input=pos_j-pos_i
        if h_j is not None:
            input=torch.cat([h_j,input],dim=-1)
        return self.mlp(input)
from torch_geometric.nn import global_max_pool

class PointNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1=PointNetLayer(3,32)
        self.conv2=PointNetLayer(32,32)
        self.classifier=Linear(32,dataset.num_classes)
        
    def forward(self,pos,batch):
        edge_index=knn_graph(pos,k=16,batch=batch,loop=True)
        h=self.conv1(h=pos,pos=pos,edge_index=edge_index)
        h=h.relu()
        h=self.conv2(h=h,pos=pos,edge_index=edge_index)
        h=h.relu()
        h=global_max_pool(h,batch)
        return self.classifier(h)

model=PointNet()
print(model)
#输出如下
#PointNet(
#  (conv1): PointNetLayer()
#  (conv2): PointNetLayer()
#  (classifier): Linear(in_features=32, out_features=40, bias=True)
#)

导库,训测拆分数据变换以及划分批量

from torch_geometric.loader import DataLoader
train_dataset=GeometricShapes(root='/Data/GeometricShapes',train=True,transform=SamplePoints(128))
test_dataset=GeometricShapes(root='/Data/GeometricShapes',train=False,transform=SamplePoints(128))
train_loader=DataLoader(train_dataset,batch_size=10,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=10)

进行实验

model=PointNet();optimizer=torch.optim.Adam(model.parameters(),lr=0.01);criterion=torch.nn.CrossEntropyLoss()

def train(model,optimizer,loader):
    model.train()
    total_loss=0
    for data in loader:
        optimizer.zero_grad()
        logits=model(data.pos,data.batch)
        loss=criterion(logits,data.y)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()*data.num_graphs
    return total_loss/len(train_loader.dataset)

def test(model,loader):
    model.eval()
    total_correct=0
    for data in loader:
        logits=model(data.pos,data.batch)
        pred=logits.argmax(dim=-1)
        total_correct+=int((pred==data.y).sum())
    return total_correct/len(loader.dataset)

for epoch in range(1,51):
    loss=train(model,optimizer,train_loader)
    test_acc=test(model,test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
#输出如下(这里只有最后一次):
#Epoch: 50, Loss: 0.7294, Test Accuracy: 0.8250

模型问题

出现问题: 由于模型使用坐标进行输入并且选择笛卡尔坐标系传递信息所以旋转坐标就不可行。可以按照如下方式进行实验。

from torch_geometric.transforms import Compose,RandomRotate
random_rotate=Compose([
    RandomRotate(degrees=180,axis=0),
    RandomRotate(degrees=180,axis=1),
    RandomRotate(degrees=180,axis=2),
])
dataset=GeometricShapes(root='/DATA//GeometricShapes',transform=random_rotate)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述

transform=Compose([
    random_rotate,
    SamplePoints(num=128),
])
test_dataset=GeometricShapes(root='/DATA/GeometricShapes',train=False,transform=transform)
test_loader=DataLoader(test_dataset,batch_size=10)
test_acc=test(model,test_loader)
print(f'Test Accuracy: {test_acc:.4f}')
#输出如下:
#Test Accuracy: 0.2000
print(len(test_dataset))
#输出如下:
#40

可以看到,模型效果,就不好了。有解决方法的。暂时就这样吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/545092.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

convLSTM2D 层使用方法解析(Keras库)

最近在研究时序图像分类问题,需要用到convLSTM层提取特征,所以在此仔细分析一下keras.layers.ConvLSTM2D层的使用方法。深度学习框架是tensorflow 官方文档:recurrent/#convlstm2d - Keras 中文文档 下面这部分内容摘自官方文档 ConvLSTM2D…

Axure 轮播图如何制作

近来在学习axure,用的版本为Axure 9,给大家讲一下怎么使用轮播图,老规矩保姆式教学法 一、作图 1.创建新的页面,方便我们做图 2.在元件库搜索“动态面板”字样,设置一个动态面板,为什么要设置呢&#xff…

IIC总线通讯协议学习

​ IIC(最简单的总线通讯,简单意味着通用和普适性) iic通讯一般采用一主多从的方式.同一时间要么在发送信息,要么在读取信息(半双工通讯) ​​​​​​​​​​​​​​ ​​​​ ​​​ 标准的写数据帧(主机向从机写数据) 解释以上的写数据帧 S:起始信号(在SCL…

Linux学习---VMWare安装和CentOS7安装

1、 VMWare安装 1、VMware16安装包 链接:https://pan.baidu.com/s/1TKf5szN6k5Hk4HH4zqBgrg 提取码:zhm6 –来自百度网盘超级会员V1的分享 2、VMWare安装流程 (1)找到下载好的安装包,双击运行程序 (2&…

云贝餐饮连锁V2-2.7.7 【新增】外卖新订单提醒

独立版:云贝餐饮连锁V2、版本更新至2.7.7,小程序、公众号版本,全插件,包含微信公众号小程序;包更新,独立版; 带商家端,修复收银台、排队点餐、堂食点餐;最新版更新了&…

【pytorch损失函数(3)】nn.L1Loss()和nn.SmoothL1Loss()

文章目录 【回归损失函数】L1(MAE)、L2(MSE)、Smooth L1 Loss详解1. L1 Loss(Mean Absolute Error,MAE)1.1 数学定义1.2 、使用场景与问题1.3 、如何使用 2. L2 Loss(Mean Squared E…

最流行的开源 LLM (大语言模型)整理

本文对国内外公司、科研机构等组织开源的 LLM 进行了全面的整理。 Large Language Model (LLM) 即大规模语言模型,是一种基于深度学习的自然语言处理模型,它能够学习到自然语言的语法和语义,从而可以生成人类可读的文本。 所谓"语言模…

MTK平台的SWT异常的简单总结(1)——WatchDog

SWT系列资料很多来源于Google (1)概念相关 SWT是SoftWare Watchdog Timeout的缩写,在Android系统中,为了监控SystemServer是否处于正常运行状态,加入了SWT线程来监控SystemServer中重要线程和Service的运行情况。判断…

多线程-程序、进程、线程与并行、并发的概念

多线程 本章要学习的内容: 专题1:相关概念的理解专题2:多线程创建方式一:继承Thread类专题3:多线程创建方式二:实现Runnable接口专题4:Thread类的常用方法专题5:多线程的优点、使用…

合肥工业大学信息隐藏实验报告

✅作者简介:CSDN内容合伙人、信息安全专业在校大学生🏆 🔥系列专栏 :信息隐藏实验报告 📃新人博主 :欢迎点赞收藏关注,会回访! 💬舞台再大,你不上台&#xff…

OpenCV基础操作(1)图片及视频基础操作、常用绘图函数

OpenCV基础操作(1)图片、视频、绘图函数 import cv2 as cv import numpy as np1、图像的读取、显示、保存 使用函数 cv2.imread() 读入图像。 第一个参数是幅图路径, 第二个参数是要告诉函数应该如何读取这幅图片。 • cv2.IMREAD_COLOR(1):读入一副彩色…

模板字符串、startsWith()方法和endsWith()方法、repeat()、Set数据结构、Set对象实例方法、遍历Set

模版字符串 ES6新增的创建字符串的方式,使用反引号定义 示例 <script>// 1.模板字符串可以解析变量 ${}显示变量的值let name 张三;let sayHello HEllo,我的名字叫${name};console.log(name);console.log(sayHello);let result {name: "zhangsan",age: 20…

激光切割机在使用过程中常见故障有哪些(一)

由于不少客户在使用光纤激光切割机的过程中&#xff0c;因为操作不当等原因&#xff0c;造成激光切割机出现一些小故障&#xff0c;这些故障虽然不大&#xff0c;但是却会对正常使用工期造成延误&#xff0c;甚至造成损失&#xff0c;所以了解光纤激光切割机的常见故障迫在眉睫…

本地电脑远程服务器,复制大文件报:未指定错误的解决办法

1、本地电脑快捷键WINR 打开运行窗口 2、输入 \\IP地址\磁盘$。如下&#xff1a; 3、上一步点击确定&#xff0c;即远程到了相应的磁盘&#xff0c;可在本地进行复制粘贴。

北京打响大模型地方战第一枪:公布通用人工智能发展21项措施

21项&#xff01;北京就促进AGI创新发展措施征集意见。 作者 | 李水青 来源 | 智东西 ID | zhidxcom 智东西5月16日消息&#xff0c;近日&#xff0c;《北京市促进通用人工智能创新发展的若干措施&#xff08;2023-2025年&#xff09;&#xff08;征求意见稿&#xff09;》…

【C++】基础知识--程序的结构(1)

C简介&#xff1a; C 是一种静态类型的、编译式的、通用的、大小写敏感的、不规则的编程语言&#xff0c;支持过程化编程、面向对象编程和泛型编程。 C 被认为是一种中级语言&#xff0c;它综合了高级语言和低级语言的特点。 C 是由 Bjarne Stroustrup 于 1979 年在新泽西州…

Codeforces Round 873 (Div. 2) 题解

5.18晚VP&#xff0c;共AC三题&#xff0c;ABC题&#xff0c;感觉难度还是挺大的&#xff0c;做起来一点也不顺手。。。A题秒出&#xff0c;卡在了B题&#xff0c;在B题花费了好多时间&#xff0c;还没有C题做得顺利。。。B题开始想错了&#xff0c;思路不对&#xff0c;但确实…

LeetCode225.用队列实现栈

&#x1f4ad;前言&#xff1a; 建议本题和LeetCode232对比实现 syseptember的个人博客&#xff1a;LeetCode232.栈模拟队列http://t.csdn.cn/HCEDg 题目 思路 ❗注意&#xff1a;本题的逻辑结构是栈&#xff0c;物理结构是队列&#xff0c;我们需要通过2个队列模拟栈的操作。…

Doxygen源码分析:构建过程简介,并生成doxygen自身的C++文档

2023-05-19 11:52:17 ChrisZZ imzhuofoxmailcom Hompage https://github.com/zchrissirhcz 文章目录 1. doxygen 版本2. 找出所有的 CMakeLists.txt 和 *.cmake 文件3. cmake 构建目标清单4. 生成 Doxygen 自己的文档 1. doxygen 版本 zzLegion-R7000P% git log …

LabVIEWCompactRIO 开发指南23 Web服务

LabVIEWCompactRIO 开发指南23 Web服务 LabVIEW8.6中引入的LabVIEWWeb服务提供了一种开放的标准方式&#xff0c;可通过Web与VI进行通信。考虑一个部署在分布式系统中的LabVIEW应用程序。LabVIEW提供了网络流等功能来建立通信&#xff0c;但许多开发人员需要一种方式&#xf…