【机器学习】图神经网络(NRI)模型原理和运动轨迹预测代码实现

news2024/12/24 21:04:59

1.引言

1.1.NRI研究的意义

在许多领域,如物理学、生物学和体育,我们遇到的系统都是由相互作用的组分构成的,这些组分在个体和整体层面上都产生复杂的动态。建模这些动态是一个重大的挑战,因为往往我们只能获取到个体的轨迹数据,而不知道其背后的相互作用机制或具体的动态模型。

以篮球运动员在球场上的运动为例,运动员的动态显然受到其他运动员的影响。作为观察者,我们能够推断出场上可能发生的各种交互,如防守、掩护等。然而,手动标注这些交互不仅繁琐,而且耗时。因此,一个更有前景的方法是在无监督的条件下学习这些底层的交互模式,这些模式可能在多种不同的任务中都具有通用性。
在这里插入图片描述

1.2.主要内容

本文将介绍一种基于图结构潜在空间的变分自编码器模型——神经关系推理(Neural Relational Inference)模型。这种模型在相关论文中被详细阐述,并配备了代码库以便于实现和实验。

神经关系推理模型旨在解决预测粒子运动轨迹的问题,特别是在存在未知粒子间相互作用的情况下。设想我们有一组粒子(例如带电粒子),它们因某种相互作用(如电磁力)而在空间中移动。我们观察到每个粒子在一段时间T内的运动轨迹,包括其位置和速度。每个粒子的新状态不仅由其当前状态决定,还受到其他粒子的影响。我们拥有一组粒子的轨迹数据,但粒子间的确切相互作用未知。

模型的目标是通过学习粒子的动态行为,基于已知的轨迹样本来预测未来的轨迹。如果已知粒子间的相互作用形式(即它们如何以图的形式相互连接),预测粒子的动态将更为直接。在这种理想情况下,每个粒子对应于图中的一个节点,而节点间的连接强度可以通过边的权重来表示。然而,在这个问题中,我们并没有获得这样的交互图。

因此,神经关系推理模型采用了变分自编码器的编码器部分,以从给定的轨迹数据中采样潜在的交互图。具体来说,编码器部分使用图神经网络(GNN)技术来捕捉粒子间的潜在关系,并生成一个能够代表这些关系的图结构。这个图结构随后被用作解码器部分的输入,以预测粒子的未来轨迹。

通过这种方式,神经关系推理模型能够同时学习粒子间的潜在交互和粒子的动态行为,从而实现更准确的轨迹预测。在训练过程中,模型通过最大化给定轨迹数据下的似然函数(即证据下界ELBO)来优化其参数,以使得生成的潜在交互图能够最好地解释和预测观察到的轨迹。

2.神经关系推理模型(NRI)原理

2.1.NRI的基本原理

神经关系推理(NRI)模型是一个专注于从观察到的轨迹中推断对象间相互作用和动态行为的模型。它由两个核心组件组成:编码器和解码器,这两个组件是联合训练的。

2.1.1.编码器

编码器负责根据观察到的轨迹数据 x x x 来预测对象间的相互作用,即潜在的图结构 z z z。这里的轨迹数据 x x x 包括 N 个对象在 T 个时间步上的特征向量集合,具体地,我们用 x i t x^t_i xit 表示第 t 个时间点上对象 v i v_i vi 的特征向量(如位置和速度)。所有对象在时间点 t 的特征集合记作 x t = ( x 1 t , … , x N t ) x^t = (x^t_1, \ldots, x^t_N) xt=(x1t,,xNt),而对象 v i v_i vi 的完整轨迹是 x i = ( x i 1 , … , x i T ) x_i = (x^1_i, \ldots, x^T_i) xi=(xi1,,xiT)

编码器 q ( z ∣ x ) q(z|x) q(zx) 的目标是输出一个分布,该分布描述了给定轨迹 x x x 下潜在图结构 z z z 的可能性。特别是, z i j z_{ij} zij 表示对象 v i v_i vi v j v_j vj 之间的离散边类型,用于表示它们之间的交互类型。编码器使用 K 种可能的交互类型对 z i j z_{ij} zij 进行一位有效编码。

2.1.2.解码器

解码器则基于编码器输出的潜在图结构 z z z 和已知的轨迹数据 x x x 来学习并预测对象的动态行为。具体来说,解码器通过以下公式定义:
p ( x ∣ z ) = ∏ t = 1 T p ( x t + 1 ∣ x t , x 1 , z ) p(x|z) = \prod_{t=1}^{T} p(x_{t+1}|x^t, x^1, z) p(xz)=t=1Tp(xt+1xt,x1,z)
它使用图神经网络(GNN)来模拟给定潜在图 z z z 和历史轨迹 x t x^t xt 下,下一个时间步 t + 1 t+1 t+1 的轨迹 x t + 1 x_{t+1} xt+1 的分布。

2.1.3.模型优化

整个模型基于变分自编码器(VAE)框架进行优化,目标是最大化证据下界(ELBO):
L = E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − KL [ q ( z ∣ x ) ∣ ∣ p ( z ) ] L = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}[q(z|x) || p(z)] L=Eq(zx)[logp(xz)]KL[q(zx)∣∣p(z)]
其中,KL 表示 Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。先验 p ( z ) p(z) p(z) 假设边类型是均匀分布的,但也可以根据需要采用其他先验分布,比如鼓励稀疏图的先验。

NRI 模型通过编码器和解码器的联合训练,能够无监督地从观察到的轨迹中学习对象间的相互作用和动态行为,这对于理解复杂系统的运动规律和交互模式具有重要意义。

2.2.与VAE编码器的差异

与原始的变分自编码器(VAE)模型相比,我们的神经关系推理(NRI)模型确实存在几个显著的不同之处。以下是对这些差异的详细改写和描述:

  1. 多时间步预测
    在原始的VAE中,解码器通常被训练来根据潜在变量(z)重构单个数据点。然而,在我们的NRI模型中,为了捕捉系统动态的连续性和交互的长期影响,我们训练解码器来预测多个时间步的轨迹,而不仅仅是单个时间步。这种设置使得解码器在预测过程中能够充分利用潜在交互图(z)中的信息,从而避免了解码器忽略潜在变量(z)的问题。

  2. 离散潜在变量与连续松弛
    原始的VAE通常处理连续的潜在变量,而我们的NRI模型则使用离散的潜在变量(z)来表示对象之间的交互类型。为了能够在反向传播过程中优化离散的潜在变量,我们采用了连续的松弛方法,如Gumbel-Softmax或Straight-Through(ST)估计器,以便能够利用重参数化技巧进行梯度传播。这种方法允许我们在保持潜在变量离散性的同时,有效地优化模型参数。

  3. 未建模的初始状态
    在原始的VAE中,通常会对整个数据序列(包括初始状态)进行建模。然而,在我们的NRI模型中,我们主要关注于对象之间的动态交互和这些交互如何影响对象的轨迹。因此,我们没有对初始状态的概率(p(x^1))进行显式建模。尽管如此,如果需要,我们可以轻松地扩展模型以包含对初始状态的建模,但这通常不会显著影响模型在动态和交互预测方面的性能。

  4. 图神经网络(GNN)的引入
    除了上述差异外,我们的NRI模型还引入了图神经网络(GNN)来捕捉对象之间的交互。GNN能够处理图结构的数据,并通过在节点和边之间传递信息来更新节点的表示。在我们的模型中,GNN被用作解码器的一部分,它根据潜在交互图(z)和历史轨迹来预测未来的轨迹。这种图结构的数据处理方法使得我们的模型能够更好地捕捉对象之间的复杂交互和依赖关系。

NRI模型通过引入多时间步预测、离散潜在变量与连续松弛、以及图神经网络等方法,在保持原始VAE框架灵活性的同时,针对动态系统和交互推理问题进行了有效的改进和优化。

模型的概览图如图 1 所示。接下来,我们将详细介绍模型的编码器和解码器部分。
在这里插入图片描述图 1. NRI模型由两个共同训练的部分构成:一个编码器,它根据输入轨迹预测潜在交互的概率分布 q ( z ∣ x ) q(z|x) q(zx);以及一个解码器,它根据编码器的潜在编码和轨迹的前一时间步生成轨迹预测。编码器采用具有多轮节点到边(v → e)和边到节点(e → v)消息传递的GNN形式,而解码器则并行运行多个GNN,每个GNN对应编码器潜在编码 q ( z ∣ x ) q(z|x) q(zx)提供的一种边类型。(图引用自论文:Neural Relational Inference for Interacting Systems

2.3.编码器

编码器在NRI模型中的核心任务是,在观察到轨迹数据 ( x = (x_1, \ldots, x_T) ) 的基础上,推断出对象间潜在的成对交互类型 ( z_{ij} )。由于真实世界的图结构通常是未知的,我们利用一个在全连接图(即每对对象之间都存在潜在的边)上运作的图神经网络(GNN)来预测这种潜在的图结构。

具体地,我们构建编码器模型如下:

q ( z i j ∣ x ) = softmax ( f enc ( x ) i j 1 : K ) q(z_{ij}|x) = \text{softmax}(f_{\text{enc}}(x)_{ij}^{1:K}) q(zijx)=softmax(fenc(x)ij1:K)

其中, f enc ( x ) f_{\text{enc}}(x) fenc(x) 是我们的编码器函数,它在一个不包含自环的全连接图上应用GNN操作。给定输入轨迹 x 1 , … , x T x_1, \ldots, x_T x1,,xT,编码器执行以下消息传递操作来逐步构建对象的表示和边的嵌入:

  1. 初始化节点嵌入:
    h j 1 = f emb ( x j ) h^1_j = f_{\text{emb}}(x_j) hj1=femb(xj)
    这里, f emb f_{\text{emb}} femb 是一个嵌入函数,它将原始轨迹数据 ( x_j ) 映射到初始的节点表示 h j 1 h^1_j hj1

  2. 边嵌入的第一层更新:
    h ( i j ) 1 = f e 1 ( [ h i 1 , h j 1 ] ) h^1_{(ij)} = f^1_e([h^1_i, h^1_j]) h(ij)1=fe1([hi1,hj1])
    对于每对节点 ( i , j ) (i, j) (i,j) f e 1 f^1_e fe1 是一个边更新函数,它接受两个相邻节点的嵌入 h i 1 h^1_i hi1 h j 1 h^1_j hj1,并输出一个更新的边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1

  3. 节点嵌入的第二层更新:
    h j 2 = f v 1 ( ∑ i ≠ j h ( i j ) 1 ) h^2_j = f^1_v\left(\sum_{i \neq j} h^1_{(ij)}\right) hj2=fv1 i=jh(ij)1
    这里, f v 1 f^1_v fv1 是一个节点更新函数,它聚合所有指向节点 j j j 的边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1,并据此更新节点 j j j 的嵌入 h j 2 h^2_j hj2

  4. 边嵌入的第二层更新(可选):
    h ( i j ) 2 = f e 2 ( [ h i 2 , h j 2 ] ) h^2_{(ij)} = f^2_e([h^2_i, h^2_j])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1860397.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

亚马逊卖家为何需要自养账号?揭秘背后的原因

亚马逊是一家极为重视用户体验的国际电商平台,因此用户的评论和星级评价对店铺排名影响深远。平台审核评论非常严格,这些评价直接影响商品在平台上的生存和发展。 在亚马逊上,用户的评分和评论对商品的成功至关重要。好评多的商品通常被认为优…

docker容器相关命令

☆ 问题描述 docker容器相关命令 ★ 解决方案 1. 拉取镜像 docker pull ubuntu2.查看镜像是否拉取成功 docker images3. 运行容器 docker run -itd --name <容器名称> -p <主机端口>:<容器端口> --cpus30 ubuntu # -p设置端口 --cpus/-c 设置核心 …

【CPP】插入排序:直接插入排序、希尔排序

目录 1.插入排序1.1直接插入排序简介代码分析 1.2直接插入对比冒泡排序简介代码对比分析(直接插入排序与冒泡的复杂度效率区别) 1.3希尔排序简介代码分析 1.插入排序 基本思想&#xff1a;把一个待排数字按照关键码值插入到一个有序序列中&#xff0c;得到一个新的有序序列。 …

Python学习笔记19:进阶篇(八)常见标准库使用之glob模块和argparse模块

前言 本文是根据python官方教程中标准库模块的介绍&#xff0c;自己查询资料并整理&#xff0c;编写代码示例做出的学习笔记。 根据模块知识&#xff0c;一次讲解单个或者多个模块的内容。 这里贴一下教程地址&#xff1a;https://docs.python.org/zh-cn/3/tutorial/stdlib.h…

执行shell脚本出现 $‘ \r‘ 符号导致执行失败【解决】

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;CSDN博客专家   &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01…

mprpc框架基础类的设计

目录 1.回顾 2.主函数书写 3.框架设计 3.1 mprpcapplication.h 3.2 rpcprovider.h 3.3 mprpcapplication.cc 3.4 mprpcprovider.cc 1.回顾 mprpc框架怎么用&#xff1f; 在上一节&#xff0c;我们完成了如何把本地服务发布成RPC服务。 我们打开example下callee下的users…

制造业ERP五大生产模式详解!

制造业面临着从成本控制、生产效率到供应链管理的挑战&#xff0c;每一个环节都需要精细化的管理和高效的协同。而ERP系统&#xff0c;作为一种集信息技术与管理思想于一体的管理工具&#xff0c;正逐渐成为制造业转型升级的关键。那么&#xff0c;通过本文你将会了解到&#x…

某某商场对账返款单,table

好久不写原生html&#xff0c;今天写了个&#xff0c;快忘完了 。。。 Double Header Table ***商场统一收银结算商户对账返款单 商场&#xff08;盖章有效&#xff09; 铺位名称&#xff1a; 铺位号&#xff1a; 制单人&#xff1a; 制单日期&#xff1a; </tr><tr&…

智慧社区:居民幸福生活的保底线,价值非常大。

大屏应该能够显示社区内的关键数据&#xff0c;如人流量、车辆数量、垃圾分类情况等。这些数据可以通过图表、数字、地图等形式展示&#xff0c;以便居民和管理者能够直观地了解社区的情况。 智慧社区可视化大屏成为一个有益于社区管理和居民生活的工具&#xff0c;提供实时、准…

北京互动阅读app开发,“身”临其境,阅读精彩

随着大数据与智能化的不断发展进步&#xff0c;线上阅读软件也越来越多&#xff0c;为了缓解对传统翻页阅读方式产生的疲劳&#xff0c;人们对线上阅读提出了新的要求。对此&#xff0c;与智能科技相结合的北京互动阅读app开发&#xff0c;以高互动、高体验感的优势&#xff0c…

Casaos之qittorrent设置(没有账号密码)

点击安装只有没有账号密码&#xff0c;只能从运行日志中找密码&#xff1a; # 查看container docker ps -a # 查看container日志 docker logs ae15cb90afbd 进入系统 最下方&#xff0c;保存。

solidity智能合约如何实现跨合约调用函数

背景 比如现在有一个需求、我需要通过外部合约获取BRC20 token的总交易量。那么我需要在brc20的转账函数里面做一些调整&#xff0c;主要是两个函数内统计转移量。然后再提供外部获取函数。 /*** dev Sets amount as the allowance of spender over the callers tokens.** Ret…

从SRE视角透视DevOps的构建精髓

SRE 侧重系统稳定性&#xff0c;DevOps 强调开发运维协作。SRE 实践助力DevOps&#xff0c;提升系统稳定性与团队协作效率。 SRE 运用软件工程的原理&#xff0c;将系统管理员的手工任务自动化&#xff0c;负责运维由系统组件构成的服务&#xff0c;确保服务稳定运行。SRE职责涵…

全国计算机等级考试WPS如何报名

全国计算机等级考试WPS如何报名&#xff1f; 注册并登录 全国计算机等级考试官网选择 考试服务-在线报名选择报考省份-开始报名

现成!小众且创新idea! 小样本+故障识别!1DGAN-SVM 批量生成样本-故障识别一体化程序!MATLAB程序,直接运行!

推荐平台&#xff1a;Matlab2022版及以上 在机器学习、深度学习领域&#xff0c;数据的多样性和数量直接影响模型的性能。生成对抗算法GAN&#xff08;Generative Adversarial Network&#xff09;通过对抗过程训练&#xff0c;能够生成高度逼真的数据样本&#xff0c;增加训练…

Verifieable FHE(VFHE):使用Plonky2来证明Zama TFHE的“Bootstrapping的正确执行”

1. 引言 Zama团队2024年论文Towards Verifiable FHE in Practice: Proving Correct Execution of TFHE’s Bootstrapping using plonky2 中&#xff1a; 首次阐述了&#xff0c;在实践中&#xff0c;将整个FHE bootstrapping操作&#xff0c;使用SNARK来证明。在其相应的http…

买卖的价差与速率之间建立的关系

import numpy as np import matplotlib.pyplot as plt# 参数设置 A 100 delta_ba np.array([1, 5, 10]) # 时间差&#xff0c;以秒为单位 k_values [0.05, 0.1, 0.01] # 不同的k值# 计算不同k值下的λb,a def calculate_lambda(A, k, delta):return A * np.exp(-k * delta…

NSSCTF-Web题目16

目录 [GDOUCTF 2023]受不了一点 1、题目 2、知识点 3、思路 [UUCTF 2022 新生赛]ez_upload 1、题目 2、知识点 3、思路 [GDOUCTF 2023]受不了一点 1、题目 2、知识点 php代码审计、数组绕过、弱比较绕过 3、思路 打开题目&#xff0c;出现代码&#xff0c;我们进行代…

AIGC已经火了好几年,现在学还抓得住这个风口吗?

简介 进入人工智能&#xff08;AI&#xff09;与游戏行业整合&#xff08;AIGC&#xff09;领域需要考虑当前的市场需求和行业发展阶段&#xff1a; 1、市场需求&#xff1a;目前&#xff0c;人工智能和游戏行业都在不断发展壮大&#xff0c;并且两者的结合也逐渐成为一个热门领…

当前周周报自动生成工具

周报自动生成工具使用指南 功能简介 每次只用在xx周报.xlsx 中进行工作内容的更改&#xff0c;然后写完之后点一下run\_clear\_formulas\_and\_save.bat就能导出一份当前周周报&#xff0c;不用进行每周都自己新建一份然后把所有需要改时间的地方都改一遍这种重复劳动。 自动…