图神经网络:处理点云

news2024/11/6 9:38:11

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook和有关文献。超链。提取码8848。

文章目录

    • 简单前置工作学习
    • 文献阅读
    • Point++的实现
    • 模型问题

简单前置工作学习

工作目标:根据点云去进行40分类。
工作流程:1.读取PyG内置的几何图形数据。2.随机但是均匀采样。3.K最邻近算法构边建图。4.使用Point++进行图分类。
导库,下载数据,导库,定义函数

from torch_geometric.datasets import GeometricShapes
dataset=GeometricShapes(root='/Data/GeometricShapes')
import matplotlib.pyplot as plt
def visualize_mesh(pos,face):
    fig=plt.figure()
    ax=fig.add_subplot(111,projection='3d')
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([])
    ax.plot_trisurf(pos[:,0],pos[:,1],pos[:,2],triangles=face.t(),antialiased=False)
    plt.show()

PS1:这段代码会在C盘生成一个DATA的文件并将数据集放在DATA中,有强迫症注意一下。
PS2:就是几何图形网格。细节可以点击这里。
打印信息与可视化

print(dataset)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述
导库以及定义函数

from torch_geometric.transforms import SamplePoints
import torch
def visualize_points(pos,edge_index=None,index=None):
    fig=plt.figure(figsize=(4, 4))
    if edge_index is not None:
        for (src,dst) in edge_index.t().tolist():
            src=pos[src].tolist()
            dst=pos[dst].tolist()
            plt.plot([src[0],dst[0]],[src[1],dst[1]],linewidth=1,color='black')
    if index is None:
        plt.scatter(pos[:,0],pos[:,1],s=50,zorder=1000)
    else:
       mask=torch.zeros(pos.size(0),dtype=torch.bool)
       mask[index]=True
       plt.scatter(pos[~mask,0],pos[~mask,1],s=50,color='lightgray',zorder=1000)
       plt.scatter(pos[mask,0],pos[mask,1],s=50,zorder=1000)
    plt.axis('off')
    plt.show()

从图形表面均匀地采样,打印信息与可视化

dataset.transform=SamplePoints(num=256)
data=dataset[0]
print(data)
visualize_points(data.pos)
data=dataset[4]
print(data)
visualize_points(data.pos)

jupyter notebook内输出如下
在这里插入图片描述

文献阅读

参考文献: PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space

文章概述: “Deep learning on point sets for 3d classification and segmentation”是参考文献之前前沿工作,核心思想对每个点空间编码然后聚合所有单点要素到全局的空间。显然这样无法捕捉局部特征。受到卷积神经网络启发,这里参考文献便就来了。具体步骤: 第一步:进行局部划分;第二步:组合局部特征;第三步:加工局部特征,重复上述过程直到点云所有特征都被利用。所以面临三个问题。第一问:如何进行局部划分。第二问:如何组合局部特征。第三问:如何加工局部特征。解决第一问:Farthest Point Sampling,FPS。解决第二问:Ball Query。解决第三问:上面那篇文章的Point

分层的点云学习器: Sampling layer: Farthest Point Sampling,FPS。 可以使用K最近邻算法但是不好。固定一个区域更加有普适性。PS:注意一下KNN与Ball Query的区别。Grouping layer: 输入: N × ( d + C ) N \times (d+C) N×(d+C) 以及 N ′ × d N'\times d N×d 输出: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C)。符号说明: N N N是点的数量, d d d是质心坐标, C C C是点的特征维数, N ′ N' N是质心数量, K K K是邻域内点数量。Point Net layer: 输入: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C) 输出: N ′ × ( d + C ′ ) N' \times (d + C') N×(d+C) 。这个模型鲁棒性强,对于不均匀的数据效果同样。这个图挺好的。
在这里插入图片描述
PS1:原文还有其他很好工作,有兴趣有时间建议去看,但是我们这里跳过。
PS2:对于上面前置工作,由于采用是均匀的,可以这样建图。如下:
导库

from torch_cluster import knn_graph

打印信息与可视化

data=dataset[0]
data.edge_index=knn_graph(data.pos,k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)
data=dataset[4]
data.edge_index=knn_graph(data.pos, k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)

jupyter notebook内输出如下
在这里插入图片描述

Point++的实现

我们使用数学公式首先进行EdgeConv的描述: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , h j ( l − 1 ) − h i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},h_j^{(l-1)}-h_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),hj(l1)hi(l1))。Point++类似于这个公式: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , p j ( l − 1 ) − p i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},p_j^{(l-1)}-p_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),pj(l1)pi(l1))
搭建多层的Point++

from torch_geometric.nn import MessagePassing
from torch.nn import Sequential,Linear,ReLU

class PointNetLayer(MessagePassing):

    def __init__(self,in_channels,out_channels):
        super().__init__(aggr='max')
        self.mlp=Sequential(Linear(in_channels+3,out_channels),ReLU(),Linear(out_channels,out_channels))
        
    def forward(self,h,pos,edge_index):
        return self.propagate(edge_index,h=h,pos=pos)
    
    def message(self,h_j,pos_j,pos_i):
        input=pos_j-pos_i
        if h_j is not None:
            input=torch.cat([h_j,input],dim=-1)
        return self.mlp(input)
from torch_geometric.nn import global_max_pool

class PointNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1=PointNetLayer(3,32)
        self.conv2=PointNetLayer(32,32)
        self.classifier=Linear(32,dataset.num_classes)
        
    def forward(self,pos,batch):
        edge_index=knn_graph(pos,k=16,batch=batch,loop=True)
        h=self.conv1(h=pos,pos=pos,edge_index=edge_index)
        h=h.relu()
        h=self.conv2(h=h,pos=pos,edge_index=edge_index)
        h=h.relu()
        h=global_max_pool(h,batch)
        return self.classifier(h)

model=PointNet()
print(model)
#输出如下
#PointNet(
#  (conv1): PointNetLayer()
#  (conv2): PointNetLayer()
#  (classifier): Linear(in_features=32, out_features=40, bias=True)
#)

导库,训测拆分数据变换以及划分批量

from torch_geometric.loader import DataLoader
train_dataset=GeometricShapes(root='/Data/GeometricShapes',train=True,transform=SamplePoints(128))
test_dataset=GeometricShapes(root='/Data/GeometricShapes',train=False,transform=SamplePoints(128))
train_loader=DataLoader(train_dataset,batch_size=10,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=10)

进行实验

model=PointNet();optimizer=torch.optim.Adam(model.parameters(),lr=0.01);criterion=torch.nn.CrossEntropyLoss()

def train(model,optimizer,loader):
    model.train()
    total_loss=0
    for data in loader:
        optimizer.zero_grad()
        logits=model(data.pos,data.batch)
        loss=criterion(logits,data.y)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()*data.num_graphs
    return total_loss/len(train_loader.dataset)

def test(model,loader):
    model.eval()
    total_correct=0
    for data in loader:
        logits=model(data.pos,data.batch)
        pred=logits.argmax(dim=-1)
        total_correct+=int((pred==data.y).sum())
    return total_correct/len(loader.dataset)

for epoch in range(1,51):
    loss=train(model,optimizer,train_loader)
    test_acc=test(model,test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
#输出如下(这里只有最后一次):
#Epoch: 50, Loss: 0.7294, Test Accuracy: 0.8250

模型问题

出现问题: 由于模型使用坐标进行输入并且选择笛卡尔坐标系传递信息所以旋转坐标就不可行。可以按照如下方式进行实验。

from torch_geometric.transforms import Compose,RandomRotate
random_rotate=Compose([
    RandomRotate(degrees=180,axis=0),
    RandomRotate(degrees=180,axis=1),
    RandomRotate(degrees=180,axis=2),
])
dataset=GeometricShapes(root='/DATA//GeometricShapes',transform=random_rotate)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述

transform=Compose([
    random_rotate,
    SamplePoints(num=128),
])
test_dataset=GeometricShapes(root='/DATA/GeometricShapes',train=False,transform=transform)
test_loader=DataLoader(test_dataset,batch_size=10)
test_acc=test(model,test_loader)
print(f'Test Accuracy: {test_acc:.4f}')
#输出如下:
#Test Accuracy: 0.2000
print(len(test_dataset))
#输出如下:
#40

可以看到,模型效果,就不好了。有解决方法的。暂时就这样吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/544957.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用项目管理思维来过5·20,真实太酷啦!

明天就是一年一度的520啦,阿道单身多年的同事刚京在四月成功使用SWOT分析模型相亲成功,牵手女嘉宾。二人眼看着就要迎来在一起后的第一个节日520,刚京却因为没有头绪而陷入了不知所措的焦虑。 团队成员齐上阵,用项目管理思维&…

使用 Apache Flink 开发实时 ETL

Apache Flink 是大数据领域又一新兴框架。它与 Spark 的不同之处在于,它是使用流式处理来模拟批量处理的,因此能够提供亚秒级的、符合 Exactly-once 语义的实时处理能力。Flink 的使用场景之一是构建实时的数据通道,在不同的存储之间搬运和转…

<组件封装:Vue + elementUi 通过excel文件实现 “ 批量导入 ” 表单数据,生成对应新增信息 >

Vue elementUi 通过excel文件实现 “ 批量导入 ” 表单数据,生成对应新增信息 👉 前言👉 一、封装组件对应API及绑定事件> Attributes> Event 👉 二、实现案例> HTML父组件模板> 子组件模板 👉 三、效果演…

线程相关基础知识

一、相关概念 1.1 cpu 中央处理器(central processing unit, 简称cpu ),计算机系统的 运算 和 控制 核心 1.2 cpu核心数和线程数 cpu核心数指cpu 内核数量,如双核、四核、八核。 cpu线程数是一种逻辑的概念,就是模…

基于 SpringBoot + Redis 实现分布式锁

大家好,我是余数,这两天温习了下分布式锁,然后就顺便整理了这篇文章出来。文末附有源码链接,需要的朋友可以自取。 至于什么是分布式锁,这里不做赘述,不了解的可以自行去查阅资料。 文章目录 实现要点项目…

android13 FLAG_BLUR_BEHIND 壁纸高斯模糊,毛玻璃背景方案设计-千里马framework实战

hi,粉丝朋友们! 今天有个学员朋友,问到了一个高斯模糊相关问题,这个高斯模糊相关的需求我相对还是比较熟悉,下面来重点讲解一下新版本高斯模糊相关的实现。 更多framework干货知识手把手教学 Log.i("qq群",“422901085…

[230528] 托福阅读真题|TPO66 13/30|整卷得分22/30|9:45~10:45|15:40~16:40

The Actor and the Audience P1 rehearsev 排练;排演anticipate v 预期;预料;预见 audiencen 观众brilliantadj 灿烂的;绝妙的rehearsaln 排练;预演;排演crumblev 崩塌stage frightn 怯场(演员…

自动化测试框架?这应该是全网最全自动化框架总结了,你要的都有...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 Python自动化测试&…

学术加油站|基于LSM-tree存储系统的内存管理,最大限度降低I/O成本

本文系北京理工大学科研助理牛颂登所著,本篇也是 OceanBase 学术系列稿件第 10 篇。欢迎访问 OceanBase 官网获取更多信息:https://www.oceanbase.com/ 「牛颂登:北京理工大学科研助理,硕士期间在电子科技大学网络空间安全研究院从…

资深老鸟总结,Selenium自动化测试实战小技巧,不要再走弯路了...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 Selenium4自动化测…

数据库小技能:数据报表

文章目录 I 需求1.1 补贴II 实现思路2.1 生成资金调节报表数据III Dto3.1 报表基本查询IV 接口I 需求 代理商调节活动汇总商户调节活动汇总激励金日月汇总数据源:活动流水表(上游回调) 1.1 补贴 调节活动补贴= D0补贴+T1补贴。(比如交易金额满足1000,转T1) 补贴金额 =…

图扑数字孪生智慧灯杆,“多杆合一”降本增效

前言 随着智慧城市建设的不断深入,智慧灯杆作为城市基础设施的重要组成部分,正在成为城市智能化和绿色化的重要手段之一。 效果展示 图扑智慧灯杆系统在城市道路照明领域引入信息化手段,通过构建路灯物联网,实现了现代化的路灯按…

线性代数 --- Gram-Schmidt, 格拉姆-施密特正交化(下)

Gram-Schmidt正交化过程 到目前为止,我们都是在反复强调“对于无解的方程组Axb而言,如果矩阵A是标准正交矩阵的话,就怎么怎么好了。。。。”。因为,不论是求投影还是计算最小二乘的正规方程,他们都包含了。当A为标准正…

yolov4论文解读

数据层面上的数据增强 四张照片拼接成一张进行训练 相当于增大了batch-size,更适合于单GPU。 Mosaic data augmentation 马赛克数据增强 self-adversarial training(SAT) 自我对抗训练 DropBlock Label Smoothing 损失函数 由IOU改进到CIOU 网络结构 CSPNet&…

Win10 WLAN驱动正常但仍然不显示无线网络解决办法

Win10 WLAN驱动正常但仍然不显示无线网络解决办法 写作背景过程解决方案结尾 写作背景 本菜鸡重置了电脑的网络,然后重新启动后 WLAN 不见了,连不了 WIFI 了,很疑惑,后来经过一番搜索找到了问题所在,写下本篇文章以记…

Spark/Flink广播实现作业配置动态更新

前言 在实时计算作业中,往往需要动态改变一些配置,举几个栗子: 实时日志ETL服务,需要在日志的格式、字段发生变化时保证正常解析;实时NLP服务,需要及时识别新添加的领域词与停用词;实时风控服…

访问学者J1签证面签的七个问题

作为访问学者,申请J1签证面签时可能会遇到一些常见问题。下面知识人网小编将介绍七个访问学者面签可能遇到的问题,并提供相应的答案。 问题一:您将在美国进行何种类型的学术研究? 答案:我将在美国从事学术研究&#x…

普冉PY32L020单片机简介,主频最高48MHZ

PY32L020单片机是一颗32 位 ARM Cortex-M0内核,宽电压工作范围的 MCU。这颗MCU的价格跟八位单片机相差不大,性价比可以说是非常的高了。来看看PY32L020的配置吧。 PY32L020单片机产品特性: 内核: — 32 位 ARM Cortex - M0 — 最…

飞浆AI studio人工智能课程学习(2)-Prompt优化思路|十个技巧高效优化Prompt|迭代法|Trick法|通用法|工具辅助

文章目录 优化思路上节课的例子问题分析思路解析 Prompt优化技巧Prompt优化原理 十个技巧高效优化Prompt迭代法Trick法工具法通用技巧│定基础通用技巧│做强调需求强调怎么做? 通用技巧│提预设Trick法│戴高帽原理 Trick法│说好话以基础计算为例: Trick法│给提示…

小红书数据分析:如何用ChatGPT输出爆文笔记

ChatGPT的热度依旧不减,随着技术升级,越来越多更高级的玩法被发掘。今天我们就来聊聊,如何用ChatGPT写出小红书风格的文章。 首先,小红书笔记制作分为两个步骤: 1、找选题 2、写小红书风格的笔记 我们用例子说话&a…