基于图卷积网络的人体3D网格分割

news2024/9/22 19:36:19

深度学习在 2D 视觉识别任务上取得了巨大成功。十年前被认为极其困难的图像分类和分割等任务,现在可以通过具有类似人类性能的神经网络来解决。这一成功归功于卷积神经网络 (CNN),它取代了手工制作的描述符。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

受 CNN 在 2D 图像上取得成功的推动,研究人员已将 CNN 扩展到不规则图数据,例如 3D 网格。将卷积运算扩展到非结构化图数据并非易事,因为相对位置的概念已经丢失。因此,我们专门撰写了这篇博文和一个 Colab笔记本 来回答如何将 CNN 推广到不规则图。特别是,我们讨论了如何修改 CNN 以预测 3D 网格上的分割蒙版。

1、为什么 CNN 不适用?

在欧几里得空间中定义的规则图具有固定的节点排序和邻域大小。通用图没有邻居节点排序,并且其邻居数量是可变的。

不幸的是,CNN 不能直接转换为不规则图。CNN 在欧几里得空间和规则图上工作,其中相邻节点可以通过它们与中心节点的相对位置相互区分,并且相邻节点的数量是恒定的。卷积运算利用这两个属性来得出信息节点嵌入。对于不规则和非结构化图,这些属性不满足。非结构化图没有相对位置的概念—没有顶部、底部、右侧或左侧。此外,邻域连接/大小可能因节点而异。因此,不可能在不规则图上执行卷积运算。

使用 3x3 CNN 层计算节点嵌入 h 的更新公式。请注意,为每个邻居学习唯一的权重矩阵 W。这是可能的,因为可以使用邻居与中心节点的相对位置来区分邻居。

2、图卷积网络

由于普通卷积运算不适用于不规则图,因此需要重新制定该运算。图卷积网络 (GCN) 和图卷积运算就是这样的重新制定。与传统 CNN 类似,GCN 对局部邻域进行推理。与 CNN 一样,GCN 通过聚合来自直接邻居的信息来计算嵌入。然而,GCN 确保此信息聚合操作不依赖于节点的相对位置,也不依赖于节点排序,并且与节点的邻域大小无关 - CNN 不会强制执行不变性。

  • 不规则图有节点顺序不变性,因此 GCN 的输出必须与图的节点排序无关。具体而言,无论节点排序如何,神经网络的输出都必须相同。如果输入了两个不同的图节点排序,我们不希望神经网络输出两个不同的值。此属性称为置换不变性。我们在下面进行说明。
  • 节点的连接性在整个图中可能有所不同。因此,图卷积运算必须与邻域大小无关。

对于不同顺序的计划,学习函数必须将节点映射到相同的输出,而不管它们的顺序如何(排列等变性——如图所示),并且必须将图映射到相同的向量(排列不变性),而不管图的节点顺序如何。

单个 GCN 层由三个操作定义:消息计算、聚合和更新。执行这些操作是为了导出节点嵌入。

  • 消息计算:消息计算相当于转换图中每个节点的特征向量。应用的学习转换对所有节点都是通用的,以强制置换不变性。
  • 聚合:导出所有消息后,每个节点都会通过置换等变函数(例如均值、最大值或总和函数)聚合其邻居的消息。这种消息聚合使信息能够在整个图中传播。
  • 更新:最后,更新函数将聚合消息和当前节点的先前节点嵌入结合起来。从而更新节点的嵌入。通过结合来自节点邻域和节点本身的信息,导出的节点描述符对结构和节点级信息进行编码。

具有节点嵌入 h 的 GCN 的一般消息、聚合和更新公式。消息按节点度进行归一化,聚合是邻居和当前节点的总和,并应用激活来更新嵌入。请注意,W 不依赖于当前节点或其邻居。

如需更详细的解释,请参阅斯坦福大学关于图形机器学习的课程 CS224W 的第 6 和第 7 讲。

3、图卷积网络的应用

为了探索引入的图卷积操作,我们转向图分割任务。与图像分割类似,其中像素被分配给语义类别,图分割相当于找到图节点和一组类别之间的映射。我们专注于将身体部位标签分配给人形 3D 网格节点的问题。

图形分割任务:网格中的每个顶点被分配给十二个身体部位之一

4、3D 网格数据

为了解决所提出的分割任务,我们利用了 3D 网格中编码的所有数据。3D 网格通过顶点和三角面的集合来定义表面。它由两个矩阵表示:

  • 一个维度为 (n, 3) 的顶点矩阵,其中每行指定 3D 空间中顶点的空间位置 [x, y, z]。
  • 一个维度为 (m, 3) 的面整数矩阵,其中每行包含定义三角面的顶点矩阵的三个索引。

请注意,顶点矩阵捕获节点级特征信息,而面矩阵描述节点连通性。正式地,每个网格可以转换为具有顶点 V 的图 G = (X, A),其中 X 具有维度  (|V|, 3) 并定义 V 中每个节点 u 的空间 xyz 特征,而邻接矩阵 A 具有维度  (|V|, |V|) 并定义每个节点的连通邻域。我们使用这个推导出的网格图。

5、人体3D网格分割GCN实现

我们依靠特征引导图卷积从上述推到网格图中得出节点级身体部位分配。这个 GCN 层是由 Verma 等人提出的。

通过特征引导图卷积进行节点级嵌入更新

特征引导图卷积的工作原理如下。通过聚合相邻节点(绿色)的变换嵌入以及其自身的变换嵌入来更新中心节点(红色)的嵌入。具体来说,通过将其嵌入传递到 M 个权重矩阵,为每个节点创建 M 条不同的消息。因此,总共创建了 M x Ni 条消息。这些 M x Ni 通过可调加权平均值进行聚合。这可以表示为:

其中 Nv 是顶点 v 的相邻顶点(邻居)集(包括 v 本身)。

M 条消息通过可学习的注意力机制加权。这些消息按以下方式针对每一层进行计算:

其中 u 和 c 是可学习的参数,特定于每个层。请注意,此注意力权重函数是平移不变的,这在使用空间输入特征时是可取的。它确保输入网格的平移不会影响计算出的注意力权重。

6、在 PyG 中实现

此图卷积操作可以使用 PyTorch geometric(PyG) 消息传递类来实现。三种类方法定义了一个 PyG 消息传递类。它们是 forward(执行前向传递)、 message(构造节点级消息)和 aggregate (聚合)。 message和 forward函数定义如下:

class FeatureSteeredConvolution(MessagePassing):
    """Implementation of FeatureSteeredConvolution
    References
    ----------
    .. [1] Verma, Nitika, Edmond Boyer, and Jakob Verbeek.
       "Feastnet: Feature-steered graph convolutions for 3d shape analysis."
       Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.
    """
    
    ...

    def forward(self, x, edge_index):
        """Forward pass through a feature steered convolution layer.
        Parameters
        ----------
        x: torch.tensor [|V|, in_features]
            Input feature matrix, where each row describes
            the input feature descriptor of a node in the graph.
        edge_index: torch.tensor [2, E]
            Edge matrix capturing the graph's
            edge structure, where each row describes an edge
            between two nodes in the graph.
        Returns
        -------
        torch.tensor [|V|, out_features]
            Output feature matrix, where each row corresponds
            to the updated feature descriptor of a node in the graph.
        """
        if self.with_self_loops:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index=edge_index, num_nodes=x.shape[0])

        out = self.propagate(edge_index, x=x)
        return out if self.bias is None else out + self.bias

    def _compute_attention_weights(self, x_i, x_j):
        """Computation of attention weights.
        Parameters
        ----------
        x_i: torch.tensor [|E|, in_feature]
            Matrix of feature embeddings for all central nodes,
            collecting neighboring information to update its embedding.
        x_j: torch.tensor [|E|, in_features]
            Matrix of feature embeddings for all neighboring nodes
            passing their messages to the central node along
            their respective edge.
        Returns
        -------
        torch.tensor [|E|, M]
            Matrix of attention scores, where each row captures
            the attention weights of transformed node in the graph.
        """
        if x_j.shape[-1] != self.in_channels:
            raise ValueError(
                f"Expected input features with {self.in_channels} channels."
                f" Instead received features with {x_j.shape[-1]} channels."
            )
        if self.v is None:
            attention_logits = self.u(x_i - x_j) + self.c
        else:
            attention_logits = self.u(x_i) + self.b(x_j) + self.c
        return F.softmax(attention_logits, dim=1)

    def message(self, x_i, x_j):
        """Message computation for all nodes in the graph.
        Parameters
        ----------
        x_i: torch.tensor [|E|, in_feature]
            Matrix of feature embeddings for all central nodes,
            collecting neighboring information to update its embedding.
        x_j: torch.tensor [|E|, in_features]
            Matrix of feature embeddings for all neighboring nodes
            passing their messages to the central node along
            their respective edge.
        Returns
        -------
        torch.tensor [|E|, out_features]
            Matrix of updated feature embeddings for
            all nodes in the graph.
        """
        attention_weights = self._compute_attention_weights(x_i, x_j)
        x_j = self.linear(x_j).view(-1, self.num_heads, self.out_channels)
        return (attention_weights.view(-1, self.num_heads, 1) * x_j).sum(dim=1)

完整实现可在我们的 Colab 中找到。

7、网络架构

特征引导图卷积构成了我们身体部位标记网络的主干。我们的一般网络架构如下:

  • 首先使用多层感知器编码器嵌入节点级输入特征向量(xyz 坐标)。
  • 然后,编码后的节点级特征依次通过四个特征引导卷积层。这些层各自聚合来自 12 个注意头(身体部位输出标签的数量)的消息。
  • 最后,精炼后的节点级特征通过预测头,即多感知器预测。它为每个节点输出一个类对应关系。

我们训练神经网络以最小化预测和真实分割标签之间的交叉熵损失。该损失通过整个网络反向传播以更新参数。由此产生的训练循环定义如下:

def train(net, train_data, optimizer, loss_fn, device):
    net.train()
    cumulative_loss = 0.0
    for data in train_data:
        data = data.to(device)
        optimizer.zero_grad()
        out = net(data)
        loss = loss_fn(out, data.segmentation_labels.squeeze())
        loss.backward()
        cumulative_loss += loss.item()
        optimizer.step()
    return cumulative_loss / len(train_data)

最后,我们通过计算预测分割标签与真实分割标签之间的准确度来评估网络的性能。计算结果如下所示:

def accuracy(predictions, gt_seg_labels):
    """Compute accuracy of predicted segmentation labels.
    Parameters
    ----------
    predictions: [|V|, num_classes]
        Soft predictions of segmentation labels.
    gt_seg_labels: [|V|]
        Ground truth segmentations labels.
    Returns
    -------
    float
        Accuracy of predicted segmentation labels.
    """
    predicted_seg_labels = predictions.argmax(dim=-1, keepdim=True)
    if predicted_seg_labels.shape != gt_seg_labels.shape:
        raise ValueError("Expected Shapes to be equivalent")
    correct_assignments = (predicted_seg_labels == gt_seg_labels).sum()
    num_assignemnts = predicted_seg_labels.shape[0]
    return float(correct_assignments / num_assignemnts)

8、数据集

我们在 MPI FAUST 数据集 [2] 上训练我们的网络。它包含 100 个严密的人形网格。我们通过标记单个网格来生成分割标签。这些标签通过数据集中包含的地面真实顶点对应关系转移到所有其他网格。我们使用 80 个网格来训练我们的神经网络,并对 20 个网格进行评估。

MPI FAUST 数据集中人形网格的 10 个姿势

网络在此数据集上的学习过程如下所示。我们看到类别预测收敛到正确的身体部位标签。经过训练的网络在测试数据集上的准确率达到 95%。

9、结束语

在本文中,我们展示了如何将卷积运算从欧几里得空间中定义的规则图推广到具有任意连通性的不规则图。这种推广使我们定义了 GCN,即 CNN 的扩展。与传统 CNN 不同,GCN 可以在不规则图上运行。因此,GCN 可以应用于各种数据——我们所示的网格、分子、社交网络,甚至 2D 图像。


原文链接:基于GCN的3D网格分割 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1696308.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

监控服务器性能指标,提升服务器性能

服务器是网络中最关键的组件之一,混合网络架构中的每个关键活动都以某种方式与服务器操作相关,服务器不仅是现代计算操作的支柱,也是网络通信的关键。 从发送电子邮件到访问数据库和托管应用程序,服务器的可靠性和性能直接影响到…

Cobaltstrike框架介绍

Cobaltstrike简介 cobalt strike(简称CS)是一款团队作战渗透测试神器,分为客户端及服务端,一个服务端可以对应多个客户 端,一个客户端可以连接多个服务端,可被团队进行分布式协团操作. 和MSF关系 metas…

嵌入式全栈开发学习笔记---C语言笔试复习大全23

目录 联合体 联合体的定义 联合体的长度 如果来判断设备的字节序? 如何把大端数据转换成小端数据? 枚举 枚举的定义 上一篇复习了结构体,这一节复习联合体和枚举。 说明:我们学过单片机的一般都是有C语言基础的了&#xff…

酷黑简洁大气体育直播自适应模板赛事直播门户网站源码

源码名称:酷黑简洁大气体育直播自适应模板赛事直播门户网站源码 开发环境:帝国cms 7.5 安装环境:phpmysql 支持PC与手机端同步生成html(多端同步生成插件) 带软件采集,可以挂着自动采集发布,无…

【JTS Topology Suite】Java对二维几何进行平移、缩放、旋转等坐标变换

JTS介绍 Github项目地址:https://github.com/locationtech/jts Maven库地址:https://mvnrepository.com/artifact/org.locationtech.jts JTS Topology Suite是一个用于创建和操作二维矢量几何的Java库。 JTS有对应的.NET版本NetTopologySuite库&…

PyQt6--Python桌面开发(34.QStatusBar状态栏控件)

QStatusBar状态栏控件 self.statusBar.showMessage(q.text()菜单选项被点击了,5000)

HP Laptop 15s-fq2xxx,15s-fq2706TU原厂Win11系统镜像下载

惠普星15青春版原装Windows11系统,恢复出厂开箱状态oem预装系统,带恢复重置还原 链接:https://pan.baidu.com/s/1t4Pc-Q0obApLkG8o_9Kkkw?pwdduzj 提取码:duzj 适用型号:15s-fq2xxx,15s-fq2000 15s-f…

垃圾回收机制及算法

文章目录 概要对象存活判断引用计数算法可达性分析算法对象是否存活各种引用 垃圾收集算法分代收集理论复制算法标记清除算法标记-整理算法 概要 垃圾收集(Garbage Collection, 下文简称GC),其优缺点如下: 优点&#…

从“反超”到“引领”,中国卫浴品牌凭何遥遥领先?

作者 | 曾响铃 文 | 响铃说 前不久,第28届中国国际厨房、卫浴设施展览会(以下简称“中国国际厨卫展”)在上海如期举行,就结果来说真的让人大开眼界。 冲水声比蚊子声更小的马桶、能化身无感交互平台的魔镜柜、可以语音交互的淋浴器,这些“…

揭秘章子怡成功之路:她是如何征服世界的?

章子怡的演艺生涯可谓是一部传奇❗❗❗ 从一个普通工人家庭的女孩,到如今的国际巨星 她的每一步都充满了努力和汗水 她的舞蹈基础为她日后的演艺事业奠定了坚实的基础 而她对戏剧和电影的热爱更是让她在演艺道路上不断前行 从《我的父亲母亲》到《卧虎藏龙》&…

VS Code添加高亮缩进功能

当代码缩进层次较多时,为了视觉上容易识别,一般希望可以多个缩进以不同颜色进行高亮显示, VS Code 中 indent-rainbow 插件可以实现这个功能。

【C++练级之路】【Lv.21】C++11——列表初始化和声明

快乐的流畅:个人主页 个人专栏:《算法神殿》《数据结构世界》《进击的C》 远方有一堆篝火,在为久候之人燃烧! 文章目录 引言一、列表初始化1.1 内置类型1.2 结构体或类1.3 容器 二、声明2.1 auto2.2 decltype2.3 nullptr 三、STL的…

海外链游地铁跑酷全自动搬砖挂机掘金变现项目,号称单窗口一天收益30+(教程+工具)

一、项目概述 地铁跑酷海外版国外版自动搬砖挂机掘金项目是一款结合了地铁跑酷元素的在线游戏,为玩家提供一个全新的游戏体验,使得玩家可以轻松地进行游戏,无需手动操作,节省时间和精力。 二、游戏特点 1. 自动化操作&#xff1…

爬虫案例:有道翻译python逆向

pip install pip install requestspip install base64pip install pycrytodome tools 浏览器的开发者工具,重点使用断点,和调用堆栈 工具网站:https://curlconverter.com/ 简便请求发送信息 flow 根据网站信息,preview,respon…

2024 年 5 个 GO REST API 框架

什么是API? API是一个软件解决方案,作为中介,使两个应用程序能够相互交互。以下一些特征让API变得更加有用和有价值: 遵守REST和HTTP等易于访问、广泛理解和开发人员友好的标准。API不仅仅是几行代码;这些是为移动开…

【MAC】Spring Boot 集成OpenLDAP(含本地嵌入式服务器方式)

目录 一、添加springboot ldap依赖: 二、本地嵌入式服务器模式 1.yml配置 2.创建数据库文件:.ldif 3.实体类 4.测试工具类 5.执行测试 三、正常连接服务器模式 1.yml配置 2.连接LDAP服务器配置类,初始化连接,创建LdapTem…

springboot社区助老志愿服务系统-计算机毕业设计源码96682

摘要 大数据时代下,数据呈爆炸式地增长。为了迎合信息化时代的潮流和信息化安全的要求,利用互联网服务于其他行业,促进生产,已经是成为一种势不可挡的趋势。在图书馆管理的要求下,开发一款整体式结构的社区助老志愿服务…

洗地机十大品牌排名:2024十大值得入手的洗地机盘点

随着生活水平的提高,智能清洁家电已经成为日常生活中的必需品。洗地机之所以在家庭清洁中大受欢迎,主要是因为它的多功能特性。传统的清洁方式通常需要扫帚、拖把和吸尘器分别进行操作,而洗地机将这些功能集成在一个设备中,使清洁…

MHDDoS:一个包含了56种技术的DDoS测试工具

关于MHDDoS MHDDoS是一款功能强大的DDoS服务器/站点安全测试工具,该工具包含56种技术,可以帮助广大研究人员对自己的服务器或网站执行DDoS安全测试。 工具技术 Layer7 GET | GET 泛洪 POST | POST 泛洪 OVH | 绕过OVH RHEX | 随机HEX STOMP | 绕过chk_…

Mysql-根据字段名查询字段在哪些表里

SELECT * FROM information_schema.COLUMNS WHERE COLUMN_NAMElabel_name;