PointNet++改进策略 :模块改进 | PointCAT, 使用交叉注意力机制来提升3D点云任务中提升模型精度

news2024/9/27 21:25:34
  • 论文题目:PointCAT:Cross-Attention Transformer for Point Cloud
  • 通讯地址:南京理工大学
  • 代码地址:https://github.com/xincheng-yang/PointCAT
    Pasted image 20240913085930
  • . PointCAT架构:PointCAT提出了一种基于交叉注意力机制的Transformer网络,专门用于点云表示。它通过两个不同的多尺度特征分支,利用交叉注意力机制来交换信息。通过这种方式,模型能够有效捕获点云中的长程依赖和多层次特征。
  • . 计算效率优化:为了降低模型的计算复杂度,PointCAT只使用一个分支的单类token作为查询,计算与另一个分支的注意力图,从而减少多分支结构带来的计算开销。
  • 实验结果:通过广泛的实验,论文证明了PointCAT在形状分类、部分分割和语义分割任务上取得了优异或可比的性能。模型在ModelNet40、ShapeNetPart和S3DIS等数据集上表现良好,尤其是在分类任务中取得了93.5%的整体精度。
  • 创新点
    • 提出了一个高效的层次结构,用于提取3D点云中的多尺度特征。
    • 基于这种结构,设计了一种双分支交叉注意力Transformer架构,结合了位置和内容特征,适用于点云学习任务。
    • 实验表明,PointCAT在多个任务上都能提供更加精确和可区分的特征表示。

PointCAT 核心思想

Pasted image 20240913090231
图 1. 用于形状分类的 PointCAT 架构图示。输入点首先被分为两个多尺度子集补丁,然后处理局部聚合以捕获点嵌入的高维特征。我们在 L 个交叉注意力变换层的堆栈之前向每个序列添加一个额外的可学习类标记。最后,将获得的类别标记输入 MLP 头并输出预测。

Pasted image 20240913091144
图 2. 大点分支的交叉注意力层。大分支的xcls被投影到小分支的特征维度,并作为交叉注意力的Query来与小分支的嵌入特征标记生成的Key和Value进行交互。我们通过另一个使用 Layernorm 的线性投影将尺寸对齐回大分支。最终,我们将处理后的类标记与原始补丁标记连接起来。

PointCAT的核心思想是通过 双分支跨注意力机制(Cross-Attention Mechanism) 来有效处理三维点云数据,克服其不规则和无序的特点,进而捕捉点云中的 长程依赖性多尺度特征。具体来说,PointCAT的核心理念包括以下几个关键点:

  1. 双分支架构
    • PointCAT引入了两个独立的Transformer分支,每个分支处理不同尺度的点云特征。这两个分支分别聚合局部和全局的几何信息,能够捕捉到点云在不同层级上的细节和全局结构。
  2. 跨注意力层
    • 跨注意力机制是该模型的核心创新之一。它允许两个分支之间的信息交互,即通过一个分支的特征作为查询,另一个分支的特征作为键和值,从而实现信息的互换和融合。这样既能保持局部特征的独立性,又能增强全局特征的表达能力。
  3. 层次结构与计算效率
    • 为了避免由于多分支结构带来的计算负担,PointCAT设计了一个高效的层次结构。这种设计通过对点云进行分组和降采样,减少了需要处理的点的数量,同时保持了模型的计算效率和表达能力。
  4. 位置和内容特征结合
    • 在跨注意力机制中,PointCAT不仅处理点云的位置信息,还处理点之间的内容特征。这使得模型可以更加准确地理解点云的空间布局和几何关系,提升了点云表示的精度。

通过这种双分支跨注意力架构,PointCAT能够有效地捕捉点云中的长距离依赖关系多尺度信息,并且能够在形状分类、部件分割和语义分割等任务中实现高效且准确的表现。

核心代码实现讲解

它的核心思想是通过 跨注意力机制(Cross-Attention Mechanism) 来融合多尺度特征,从而高效地捕捉点云数据中的局部和全局几何信息。这种机制能在点云分类或分割任务中取得较好的性能表现。

1. 模型架构 (PointCAT 类)

PointCAT 类是 PointCATCross 的主体结构,负责将点云数据输入网络,并输出分类结果。主要组成部分包括:

  • 嵌入层 (Embedding_Layer):将输入点云的坐标嵌入到高维特征空间。
  • 多尺度分组与聚合 (Multi_Grouping):逐层对点云进行分组采样和特征聚合,生成多尺度的局部特征。
  • 跨注意力模块 (CrossAttBlock):在大尺度和小尺度的特征之间进行信息交互。
  • 分类器 (mlp_head_largemlp_head_small):将处理后的特征通过多层感知机(MLP)输出分类结果。
class PointCAT(nn.Module):
    def __init__(self, small_dim=1024, large_dim=512, k=16, class_num=40):
        super(PointCAT, self).__init__()
        self.embedding = Embedding_Layer(in_channel=3, embedding_channel=64)
        self.sample_1 = Multi_Grouping(channel=64, groups=512, kneighbors=k, use_xyz=False, normalize="center")
        self.sample_2 = Multi_Grouping(channel=128, groups=256, kneighbors=k, use_xyz=False, normalize="center")
        self.sample_3 = Multi_Grouping(channel=256, groups=128, kneighbors=k, use_xyz=False, normalize="center")
        self.sample_4 = Multi_Grouping(channel=512, groups=64, kneighbors=k, use_xyz=False, normalize="center")

        self.class_token_small = nn.Parameter(torch.randn(1, 1, small_dim))
        self.class_token_large = nn.Parameter(torch.randn(1, 1, large_dim))

        self.cross_attn_layer = CrossAttBlock(large_dim=large_dim, small_dim=small_dim, cross_attn_depth=2,
                                              cross_attn_heads=8, channels=large_dim)
        ...

在这里:

  • sample_1sample_4 分别用于将输入点云降采样至不同的分辨率,同时提取相应的几何特征。
  • class_token_smallclass_token_large 用作跨注意力机制中的分类 token,分别对应小尺度和大尺度特征。
  • cross_attn_layer 是跨注意力层,负责将大尺度和小尺度特征进行交互。

2. 多尺度特征提取与聚合 (Multi_Grouping 类)

每个 Multi_Grouping 模块负责逐层降采样点云,并提取局部的几何特征。通过 K 近邻算法(KNN)查找点的局部邻域,然后对其进行特征聚合。

class Multi_Grouping(nn.Module):
    def __init__(self, channel, groups, kneighbors, use_xyz, normalize="center"):
        super(Multi_Grouping, self).__init__()
        self.grouper = LocalGrouper(channel=channel, groups=groups, kneighbors=kneighbors, use_xyz=use_xyz,
                                    normalize=normalize)
        self.net = Local_Aggregation(2 * channel, 2 * channel)

    def forward(self, xyz, x):  # B, N, C
        new_xyz, new_feature = self.grouper(xyz, x)
        x = self.net(new_feature)
        return new_xyz, x

forward 函数中:

  • LocalGrouper 通过 KNN 对点进行分组,找到局部邻域的点。
  • Local_Aggregation 对这些局部特征进行聚合,以提取出更加紧凑的局部表示。

3. 跨注意力机制 (CrossAttBlockCrossAttEncoder)

跨注意力机制负责在大尺度和小尺度的特征之间进行交互。该模块的作用是增强全局与局部特征的联系。

class CrossAttBlock(nn.Module):
    def forward(self, xl, xs):
        xl, xs = self.cross_att1(xl, xs)
        xl, xs = self.cross_att2(xl, xs)
        return xl, xs

CrossAttBlock 包含两个 CrossAttEncoder,通过多层次的注意力计算对不同分支的特征进行多次交互。具体的注意力计算由 CrossAttEncoder 完成:

class CrossAttEncoder(nn.Module):
    def forward(self, l, s):
        ...
        # 大尺度分支
        cal_q = conv1_l_s(large_class.unsqueeze(-1)).permute(0, 2, 1)
        cal_q = self.ln_ls1(cal_q)
        cal_qkv = torch.cat((cal_q, x_small), dim=1)
        cal_out = cal_q + cross_attn_l(cal_qkv)
        ...
        # 小尺度分支
        cal_q = conv1_s_l(small_class.unsqueeze(-1)).permute(0, 2, 1)
        cal_qkv = torch.cat((cal_q, x_large), dim=1)
        cal_out = cal_q + cross_attn_s(cal_qkv)
        ...
        return xl, xs

在这里,大尺度和小尺度的分类 token 被用作查询(query),另一个分支的特征被用作键和值,通过注意力机制来计算分支间的交互。这样,模型能够融合全局和局部的信息。

4. 特征分类与输出

经过多尺度特征提取和跨注意力交互后,模型将这些特征通过两个 MLP 头(mlp_head_largemlp_head_small)进行分类:

class PointCAT(nn.Module):
    def forward(self, x):
        ...
        x1, x2 = self.cross_attn_layer(x1, x2)
        x1 = self.mlp_head_large(x1)
        x2 = self.mlp_head_small(x2)
        x = x1 + x2
        return x

最终的分类结果是通过大尺度和小尺度特征的分类结果相加得到的。这个设计可以确保不同分辨率的特征都对分类任务有所贡献。

总结

PointCATCross 的核心在于通过 多尺度特征提取跨注意力机制 来融合不同分辨率下的几何信息。跨注意力机制能够在全局和局部特征之间建立联系,从而提升模型对复杂三维点云结构的理解能力。借鉴以上思路,可以进行如下改进来增强PointNet++ 对点云多尺度特征、全局信息以及复杂几何关系的捕捉能力,进一步提升了其在点云分类和分割任务中的表现。

  1. 引入跨尺度注意力机制:利用 PointCATCross 的跨注意力模块,实现多尺度特征的相互交互和增强。
  2. 引入分类 Token 和多层级特征交互:使用分类 Token 捕捉全局信息,并通过跨注意力机制增强局部和全局特征的交互。
  3. 增强全局上下文特征的捕捉:通过引入 Transformer 或自注意力机制,增强 PointNet++ 在全局上下文特征捕捉方面的能力。

如何改进PointNet++

要利用 PointCATCross 来改进 PointNet++,可以借鉴 PointCATCross 中的跨注意力机制、多尺度特征交互以及全局上下文特征融合等优点,从以下三个方面进行改进:

引入跨尺度注意力机制

在 PointNet++ 中的多尺度聚合(如 SA 模块)中加入跨尺度注意力机制。PointNet++ 通过半径搜索和 KNN 提取多层次的局部特征,但缺少不同尺度之间的联系和交互。通过引入 PointCATCross 的跨注意力机制,可以让不同尺度的特征相互交互,增强模型对点云全局和局部几何结构的捕捉能力。

优点:

  • 跨尺度注意力机制增强了局部和全局特征的融合,能够提高模型在处理不同分辨率点云上的表现。
  • 更好地利用多尺度信息,实现更精准的特征捕捉。

引入分类 Token 和多层级特征交互

PointCATCross 的一大特色是通过分类 Token 进行全局特征的交互与聚合。借鉴这一思路,可以在 PointNet++ 中引入类似的分类 Token 概念,在不同特征层之间使用全局 Token 来捕捉全局信息,并通过跨尺度交互让不同层的特征聚合。

优点:

  • 引入分类 Token 可以在特征提取过程中捕捉全局信息,并有效提升最终的分类或分割性能。
  • 通过多层次的特征交互,进一步增强模型对全局和局部几何关系的理解。

增强全局上下文特征的捕捉(使用 Transformer 模块)

PointNet++ 主要通过逐层局部特征提取来获取全局信息,但对长程依赖和全局上下文的捕捉较弱。可以借鉴 PointCATCross 中的全局特征融合机制,使用 Transformer 模块或全局注意力机制来增强全局特征的捕捉能力。

优点:

  • 自注意力机制能够捕捉点云中的长程依赖,特别适用于具有复杂几何结构的场景。
  • 通过 Transformer 模块,增强了模型的全局信息捕捉能力,提升了分类或分割任务的准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2137602.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python机器学习】序列到序列建模——使用序列到序列网络构建一个聊天机器人

为了寻聊天机器人,下面使用康奈尔电影对话语料库训练一个序列到序列的网络来“适当的”湖大问题或语句。以下聊天机器人示例采用的是Keras blog中的序列到序列的示例。 为训练准备语料库 首先,需要加载语料库并从中生成训练集,训练数据将决…

项目升级必备!TS装饰:简化代码、增加功能的利器 | TypeScript入门指南07

嘿,朋友!听说过TS里的装饰器没?就像给代码加了个‘魔法贴’,轻轻一点,功能升级,结构清晰。这篇文章,咱们聊聊这背后的魔法是怎么一回事! ts 入门指南系列 Ts vs Js 谁适合前端开发&a…

用Inno Setup打包QT程序输出安装包

InnoSetup打包编译好的QT程序 文章目录 InnoSetup打包编译好的QT程序介绍具体步骤自定义脚本更改引入配置文件/动态库路径申请管理员权限设置安装过程界面的图标和图片C程序依赖运行库 介绍 Inno Setup:用于打包安装程序 具体步骤 首先打开inno setup compiler 第…

MES系统:智能工厂与数字化改造的关键引擎

一、概念范畴: 数字化工厂:是一个更广泛的概念,它强调整个制造过程的数字化转型,包括从产品设计、生产准备、制造过程、管理到营销等各个环节的数字化。数字化工厂利用信息技术、自动化技术和智能化技术对传统工厂进行改造和升级…

怎样在备忘录中添加提醒?怎么设置备忘录提醒

备忘录作为我们日常生活中常用的软件,其记录事项的便捷性已经得到了广泛认可。无论是工作计划、购物清单还是个人日记,备忘录都能帮助我们将这些信息快速记录下来。然而,如果备忘录能够进一步提供提醒功能,那么它将变得更加实用&a…

为什么要用docker?

目录 一、为什么会有docker出现 二、docker理念 三、容器与虚拟机的差异 四、开发/运维(DevOps) 五、企业级应用 一、为什么会有docker出现 一款产品从开发到上线,从操作系统,到运行环境,再到应用配置。作为开发…

pytorch快速入门(一)—— 基本工具及平台介绍

前言 该pytorch学习笔记应该配合b站小土堆的《pytorch深度学习快速入门教程》使用 环境配置:Anaconda Python编译器:pycharm、jupyter 两大法宝函数 dir():知道包中有什么东西(函数 / 属性..…

串口数据波形显示工具对比

目录 1 Serial Port Plotter 1.1 界面 1.2 简介 1.3 使用方法 1.3.1 单通道示例 1.3.2 多通道示例 2 serialplot 2.1 界面 2.2 简介 2.3 使用方法 1 Serial Port Plotter 1.1 界面 1.2 简介 这是一款开源串口显示工具,项目链接: GitHub - C…

待机模式中WKUP上升沿模拟开机与关机

本篇博客重点在于标准库函数的理解与使用,搭建一个框架便于快速开发 目录 前言 待机模式 代码 wkup.h wkup.c main.c 使用注意 前言 建议先阅读下面的博客中待机模式部分。本博客主要分享代码-基于待机模式WKUP引脚的上升沿实现类似长按开机与关机的功能…

TikTok云手机解决运营效率低、封号问题

TikTok,一个拥有30亿下载量、10亿活跃用户的全球热门平台,已成为众多出海商家的首选运营阵地。其玩法与国内抖音相似,为运营者提供了熟悉的操作环境,降低了上手难度。然而,随着TikTok的快速发展,运营者们也…

德之匠信息化阶段模型

今天适逢仲秋佳节,祝大家幸福安康! 春生、夏长、秋天也是我们作为知识工作者最开心的季节。经历了项目实践、课题研究,终于能思有一得,是世上第一等的好事。 这个题目的背景是这样。近期我们在做一批中长期信息化项目的分析和总…

深度学习基础案例5--运用动态学习率构建CNN卷积神经网络实现的运动鞋识别(测试集的准确率84%)

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 前言 前几天一直很忙,一直在数学建模中,没有来得及更新,接下来将恢复正常这一次的案例很有意思:在学习动态调整…

【IP协议】解决 IP 地址不够用的问题(IP地址管理:动态分配、NAT、Ipv6)

文章目录 方案一、动态分配 IP 地址方案二、NATNAT 机制的缺点 方案三、IPv6 方案一、动态分配 IP 地址 一个设备上网就分配 IP,不上网就先不分配(权宜之计) 方案二、NAT 网络地址转换 以一当千,使用一个 IP,代表一大…

Golang | Leetcode Golang题解之第404题左叶子之和

题目: 题解: func isLeafNode(node *TreeNode) bool {return node.Left nil && node.Right nil }func sumOfLeftLeaves(root *TreeNode) (ans int) {if root nil {return}q : []*TreeNode{root}for len(q) > 0 {node : q[0]q q[1:]if no…

anaconda 离线安装PySimpleGUI

下载离线py包 PySimpleGUI-5.0.6-py3-none-any.whl(官网)pyasn1-0.6.1-py3-none-any.whl(清华镜像)rsa-4.9-py3-none-any.whl(清华镜像) 打开Anaconda Prompt 测试 # -*- utf-8 -*- import PySimpleGUI …

同城找搭子小程序有哪些?找搭子社交软件测评笔记分享

寻找搭子不再迷茫!今日测评几款热门找搭子小程序,为你开启全新社交体验。真实体验,深度剖析,帮你找到最适合的搭子平台,快来一探究竟。 1. 咕哇找搭子小程序:这是一个实名制的找搭子交友平台。正是由于实名…

任务通知(Task Notifications)

任务通知简介 相对于以前使用 FreeRTOS 内核通信的资源, 必须创建队列、 二进制信号量、计数信号量或事件组的情况, 使用任务通知显然更灵活。 按照 FreeRTOS 官方的说法, 使用任务通知比通过信号量等 IPC 通信方式解除阻塞的任务要快 45%&a…

基于RFID的门禁系统的设计(论文+源码)

1系统方案设计 通过需求分析,基于RFID的门禁系统总体设计框图。系统采用STM32单片机作为系统主控核心,利用独立按键与RFID识别模块能够实现门禁系统密码与IC卡开门功能。WiFi模块实现系统与手机APP的通信,用户可以通过手机APP进行门禁开关操…

怀孕之天赋共享:提高智商

想办法怀个哪吒享受三年,哈哈哈哈哈 话说我们有个同事,是属于憨得有点愣那种,情商也低。怀孕以后突然有一天我觉得她聪明了好几个数量级!反应也快了说话也会说了,我说是不是她的宝宝给她提高了智商?后来生…

Redis学习——数据不一致怎么办?更新缓存失败了又怎么办?

文章目录 引言正文读写缓存的数据一致性只读缓存的数据一致性删除和修改数据不一致问题操作执行失败导致数据不一致解决办法 多线程访问导致数据不一致问题总结 总结参考信息 引言 最近面试快手的时候被问到了缓存不一致怎么解决?一开始还是很懵的,因为…