论文地址:[2006.10637] Temporal Graph Networks for Deep Learning on Dynamic Graphs (arxiv.org)
项目地址:GitHub - twitter-research/tgn: TGN: Temporal Graph Networks
作者提出了一种名为Temporal Graph Networks(TGNs)的新型深度学习框架,专门用于处理动态图数据。动态图是指图的结构或特征随时间变化的图,例如社交网络或生物互作网络。
1. 引言
- 图表示学习在多个领域取得了成功,但大多数方法假设图是静态的。
- 真实世界的交互系统通常是动态的,学习动态图是相对新颖的领域。
2. 背景
- 静态图由节点和边组成,图神经网络(GNN)通过消息传递机制来聚合邻居节点信息,生成节点嵌入。
- 动态图分为离散时间动态图(DTDG)和连续时间动态图(CTDG)。离散时间动态图(DTDG)是按时间间隔拍摄的静态图快照序列。连续时间动态图(CTDG)更通用,可以表示为事件的定时列表,其中可能包括边添加或删除、节点添加或删除以及节点或边特征转换。
- 我们的时态(多)图被建模为一系列带有时间戳的事件
3. Temporal Graph Networks (TGNs)
TGNs是为了处理连续时间动态图(CTDGs)而设计的,这类图可以用一系列时间标记的事件序列来表示。TGNs由以下几个核心模块组成:
顶部:嵌入模块使用时态图和节点的内存 (1) 生成。然后,使用嵌入来预测批量交互并计算损失 (2, 3)。
底部:这些相同的交互用于更新内存 (4, 5, 6)。这是一个简化的操作流程,可以防止在底部训练所有模块,因为它们不会接收梯度。第 3.2 节解释了如何更改操作流程以解决此问题,
3.1 核心模块
用于训练内存相关模块的 TGN 的操作流程。原始消息存储存储计算消息所需的原始信息,即消息函数的输入,我们称之为原始消息,用于过去由模型处理过的交互。这使得模型可以将交互带来的内存更新延迟到后续批次。首先,使用从前一批 (1、2、3) 中存储的原始消息计算出的消息来更新内存。然后可以使用刚刚更新的内存(灰色链接)(4)来计算嵌入。通过这样做,内存相关模块的计算直接影响损失 (5, 6),并且它们会接收梯度。最后,此批处理交互的原始消息存储在原始消息存储 (6) 中,以便在将来的批处理中使用。
记忆(Memory):
- 每个节点
i
都有一个状态向量si(t)
,表示模型到目前为止所看到的状态。 - 当节点参与事件(例如与其他节点的交互或节点特征变化)时,状态向量会更新。
- 记忆的目的是以压缩格式表示节点的历史。
消息函数(Message Function):
- 对于涉及节点
i
的每个事件,计算一个消息以更新i
的状态。 - 消息函数
msgs
,msgd
, 和msgn
是可学习的函数,例如多层感知机(MLPs)。
消息聚合器(Message Aggregator):
- 由于批处理的原因,同一节点可能在同一批中涉及多个事件。
- 使用聚合函数
agg
来合并针对同一节点的多个消息。
记忆更新器(Memory Updater):
- 节点的记忆在每次涉及该节点的事件后更新。
- 更新函数
mem
可以是例如长短期记忆网络(LSTM)或门控循环单元(GRU)这样的循环神经网络。
嵌入模块(Embedding):
- 嵌入模块用于生成节点在任何时间
t
的时序嵌入zi(t)
。 - 嵌入模块的目的是避免所谓的记忆陈旧问题,即节点长时间没有参与事件时,其记忆可能变得过时。
3.2 训练
TGN的训练策略需要解决记忆相关模块(消息函数、消息聚合器和记忆更新器)不直接影响损失函数,因此它们不会直接获得梯度的问题。为了解决这个问题,论文提出了一种训练流程,其中包括:
- 使用Raw Message Store来存储过去批次中处理的交互的原始信息。
- 在预测当前批次的交互之前,使用这些存储的信息来更新记忆。
- 这样,记忆相关的模块就可以通过影响损失函数来获得梯度。
4. 相关工作
- 论文回顾了早期关于DTDGs的模型,包括聚合快照、张量分解和RNNs等方法。
- 近期的工作开始关注CTDGs,例如使用RNNs更新节点表示。
5. 实验
- 使用Wikipedia、Reddit和Twitter数据集进行实验,关注未来边缘预测和动态节点分类任务。
- 与现有的连续时间动态图学习方法和静态图方法进行比较。
6. 结论
- TGN是一个通用框架,用于学习连续时间动态图,在多个任务和数据集上取得了最先进的结果,并且比以前的方法更快。