Table of Contents
- Understanding the TGN Model Training Code
- Paper Information
- Hand-drawn Diagram of the Code Flow
- Training Process in Code
- compute_temporal_embeddings
- update_memory
- get_raw_messages
- get_updated_memory
- self.message_aggregator.aggregate
- self.memory_updater.get_updated_memory
- Memory
- get_embedding_module
- GraphAttentionEmbedding
- TimeEncode
- NeighborFinder
- MergeLayer
Understanding the TGN Model Training Code
Paper Information
Paper: https://arxiv.org/abs/2006.10637
GitHub: https://github.com/twitter-research/tgn?tab=readme-ov-file
Year: 2020
Hand-drawn Diagram of the Code Flow
Training Process in Code
In the training loop, each batch is scored by calling:
```python
pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch,
                                                    timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)
```
The function compute_edge_probabilities:
```python
def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                               edge_idxs, n_neighbors=20):
    """
    Compute probabilities for edges between sources and destination and between sources and
    negatives by first computing temporal embeddings using the TGN encoder and then feeding them
    into the MLP decoder.
    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :return: Probabilities for both the positive and negative edges

    source_nodes: list of source node ids
    destination_nodes: list of destination node ids
    negative_nodes: list of negatively sampled node ids
    edge_times: time at which each source node interacted with the corresponding destination node
    edge_idxs: indices of the edges (interactions)
    """
    n_samples = len(source_nodes)
    # Compute the temporal embeddings of sources, destinations and negatives
    source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(
        source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)
    # Score (source, destination) and (source, negative) pairs in one call:
    # the first n_samples rows are positive pairs, the remaining rows are negative pairs
    score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                torch.cat([destination_node_embedding,
                                           negative_node_embedding])).squeeze(dim=0)
    pos_score = score[:n_samples]
    neg_score = score[n_samples:]
    return pos_score.sigmoid(), neg_score.sigmoid()
```
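The positive and negative pairs are scored in one batched call and then split. A minimal pure-Python sketch of that concatenate, score, split and sigmoid pattern follows; it assumes a plain dot product as the scorer (TGN's affinity_score is a learned MergeLayer), and edge_probabilities is a hypothetical helper name.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    # Stand-in scorer (assumption): TGN's affinity_score is a learned MergeLayer MLP
    return sum(x * y for x, y in zip(a, b))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def edge_probabilities(src: List[Vector], dst: List[Vector],
                       neg: List[Vector]) -> Tuple[List[float], List[float]]:
    """Hypothetical helper mirroring the concat / score / split / sigmoid steps."""
    n_samples = len(src)
    # Sources are repeated so rows [0, n) pair with destinations and rows [n, 2n) with negatives
    scores = [dot(a, b) for a, b in zip(src + src, dst + neg)]
    pos_prob = [sigmoid(s) for s in scores[:n_samples]]
    neg_prob = [sigmoid(s) for s in scores[n_samples:]]
    return pos_prob, neg_prob


pos_prob, neg_prob = edge_probabilities(
    src=[[1.0, 0.0], [0.0, 1.0]],
    dst=[[2.0, 0.0], [0.0, 3.0]],    # aligned with the sources -> scores 2 and 3
    neg=[[-2.0, 0.0], [0.0, -1.0]],  # opposite direction -> scores -2 and -1
)
print(pos_prob, neg_prob)
```

Scoring both halves in one call is what lets the real model run the decoder once per batch instead of twice.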
compute_temporal_embeddings
This method computes the temporal embeddings of the source, destination and negative nodes.
```python
def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                edge_idxs, n_neighbors=20):
    """
    Compute temporal embeddings for sources, destinations, and negatively sampled destinations.
    :param source_nodes [batch_size]: source ids.
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :return: Temporal embeddings for sources, destinations and negatives
    """
    # n_samples: number of source nodes in the batch (200 here)
    n_samples = len(source_nodes)
    # nodes: ids of all nodes in this batch, size = 200 * 3 = 600
    nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
    # positives: sources and destinations combined; the first 200 entries are source ids,
    # the last 200 are destination ids
    positives = np.concatenate([source_nodes, destination_nodes])
    # edge_times has shape [batch_size]: the time of each source-destination interaction.
    # timestamps repeats it once per node group, shape [200 * 3]
    timestamps = np.concatenate([edge_times, edge_times, edge_times])
    memory = None
    time_diffs = None
    if self.use_memory:
        if self.memory_update_at_start:  # apply the stored messages at the start of the batch
            # n_nodes is the total number of nodes in the graph (9228)
            # self.memory.messages holds the raw messages stored by the previous batch;
            # it is empty only on the first batch of an epoch
            # The returned memory is up to date: each node is updated with the latest
            # message in its message list
            # memory shape = [n_nodes (9228), memory_dimension (172)], last_update shape = [n_nodes (9228)]
            memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                          self.memory.messages)
        else:
            memory = self.memory.get_memory(list(range(self.n_nodes)))
            last_update = self.memory.last_update
        # ==================== Per-node time differences ====================
        # Difference between the time a node's memory was last updated and the time
        # at which we want to compute the node's embedding.
        # source_time_diffs shape [batch_size]
        source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[source_nodes].long()
        # Standardize
        source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
        # destination_time_diffs shape [batch_size]
        destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[destination_nodes].long()
        # Standardize
        destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
        negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
            negative_nodes].long()
        negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
        # All time differences, shape [3 * batch_size]
        time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                               dim=0)
    # Compute the embeddings using the embedding module,
    # created by self.embedding_module = get_embedding_module(...) (shown below).
    # Arguments:
    #   memory         the memory matrix
    #   nodes          ids of the source, destination and negative nodes combined
    #   timestamps     the 200 * 3 timestamps
    #   self.n_layers  number of recursive layers (2 here)
    #   n_neighbors    number of neighbors sampled per node (10 here)
    #   time_diffs     standardized time differences
    # node_embedding shape [600, 172]: combines each node's own features with those of
    # its neighbors and the connecting edges
    node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                             source_nodes=nodes,
                                                             timestamps=timestamps,
                                                             n_layers=self.n_layers,
                                                             n_neighbors=n_neighbors,
                                                             time_diffs=time_diffs)
    # Split the result back into the three node groups
    source_node_embedding = node_embedding[:n_samples]
    destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
    negative_node_embedding = node_embedding[2 * n_samples:]
    if self.use_memory:
        # Update the memory
        if self.memory_update_at_start:
            # Persist the updates to the memory only for sources and destinations (since now we have
            # new messages for them)
            self.update_memory(positives, self.memory.messages)
            assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                "Something wrong in how the memory was updated"
            # Remove messages for the positives since we have already updated the memory using them
            self.memory.clear_messages(positives)
        unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                      source_node_embedding,
                                                                      destination_nodes,
                                                                      destination_node_embedding,
                                                                      edge_times, edge_idxs)
        unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                destination_node_embedding,
                                                                                source_nodes,
                                                                                source_node_embedding,
                                                                                edge_times, edge_idxs)
        if self.memory_update_at_start:
            # Store the new messages; they are applied at the start of the next batch
            self.memory.store_raw_messages(unique_sources, source_id_to_messages)
            self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
        else:
            self.update_memory(unique_sources, source_id_to_messages)
            self.update_memory(unique_destinations, destination_id_to_messages)
        if self.dyrep:
            source_node_embedding = memory[source_nodes]
            destination_node_embedding = memory[destination_nodes]
            negative_node_embedding = memory[negative_nodes]
    return source_node_embedding, destination_node_embedding, negative_node_embedding
```
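With memory_update_at_start enabled, the messages created in batch k are only stored; they are applied to the memory at the start of batch k+1, before that batch's embeddings are computed. The TGN paper motivates this ordering as the way to let the memory-related modules receive gradients. The toy simulation below shows only the ordering. It assumes one scalar memory per node and a hypothetical update rule `memory + message` in place of the GRU; ToyMemory and process_batch are illustrative names, and for brevity it persists and clears every node instead of only the positives.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class ToyMemory:
    """Scalar memory per node plus pending raw messages (illustrative only)."""

    def __init__(self, n_nodes: int) -> None:
        self.memory = [0.0] * n_nodes
        self.messages: Dict[int, List[Tuple[float, float]]] = defaultdict(list)

    def updated_copy(self) -> List[float]:
        # Like get_updated_memory: apply the LAST pending message of each node to a copy
        out = list(self.memory)
        for node, msgs in self.messages.items():
            if msgs:
                out[node] = out[node] + msgs[-1][0]  # hypothetical update rule
        return out


def process_batch(mem: ToyMemory, batch: List[Tuple[int, int, float, float]]) -> List[float]:
    """batch items: (source, destination, time, message value)."""
    # 1. Start of batch: memory reflects the messages stored by the previous batch
    memory_used_for_embeddings = mem.updated_copy()
    # 2. Persist and clear the consumed messages (simplified: all nodes, not only positives)
    mem.memory = memory_used_for_embeddings
    mem.messages.clear()
    # 3. Store this batch's messages; they take effect only in the next batch
    for src, dst, t, value in batch:
        mem.messages[src].append((value, t))
        mem.messages[dst].append((value, t))
    return memory_used_for_embeddings


mem = ToyMemory(n_nodes=3)
first = process_batch(mem, [(0, 1, 1.0, 5.0)])
second = process_batch(mem, [(1, 2, 2.0, 7.0)])
print(first)   # batch 1 sees no updates yet
print(second)  # batch 2 sees the message produced by batch 1
```

The interaction in batch 1 does not affect the memory used in batch 1 itself, only the memory seen by batch 2.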
update_memory
```python
def update_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
    # self.message_aggregator is a LastMessageAggregator
    unique_nodes, unique_messages, unique_timestamps = \
        self.message_aggregator.aggregate(
            nodes,
            messages)
    if len(unique_nodes) > 0:
        unique_messages = self.message_function.compute_message(unique_messages)
    # Update the memory with the aggregated messages (written into the stored memory)
    self.memory_updater.update_memory(unique_nodes, unique_messages,
                                      timestamps=unique_timestamps)
```
get_raw_messages
```python
def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                     destination_node_embedding, edge_times, edge_idxs):
    # edge_times shape is [200]
    edge_times = torch.from_numpy(edge_times).float().to(self.device)
    # edge_features shape is [200, 172]
    edge_features = self.edge_raw_features[edge_idxs]
    source_memory = self.memory.get_memory(source_nodes) if not \
        self.use_source_embedding_in_message else source_node_embedding
    destination_memory = self.memory.get_memory(destination_nodes) if \
        not self.use_destination_embedding_in_message else destination_node_embedding
    source_time_delta = edge_times - self.memory.last_update[source_nodes]
    # source_time_delta_encoding shape is [200, 172]
    source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
        source_nodes), -1)
    # source_message shape is [200, 688] = source memory (172) + destination memory (172)
    # + edge features (172) + time encoding (172)
    source_message = torch.cat([source_memory, destination_memory, edge_features,
                                source_time_delta_encoding],
                               dim=1)
    # Group the messages by source node; each entry is (raw message, interaction time)
    messages = defaultdict(list)
    unique_sources = np.unique(source_nodes)
    for i in range(len(source_nodes)):
        messages[source_nodes[i]].append((source_message[i], edge_times[i]))
    return unique_sources, messages
```
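The raw message is a plain concatenation, so its length is the sum of its four parts (4 * 172 = 688 here), and messages are grouped per source node together with their interaction time. A pure-Python sketch with tiny made-up dimensions; toy_time_encoding is a hypothetical fixed encoding standing in for the learned TimeEncode, last_update is a dict instead of a tensor, and only the concatenate-then-group pattern mirrors the code above.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def toy_time_encoding(dt: float, dim: int) -> List[float]:
    # Hypothetical fixed encoding: cos(dt * w_k) with w_k = 10 ** -k
    return [math.cos(dt * 10.0 ** -k) for k in range(dim)]


def build_messages(sources: List[int], src_mem: List[List[float]], dst_mem: List[List[float]],
                   edge_feats: List[List[float]], times: List[float],
                   last_update: Dict[int, float],
                   time_dim: int) -> Dict[int, List[Tuple[List[float], float]]]:
    messages: Dict[int, List[Tuple[List[float], float]]] = defaultdict(list)
    for i, node in enumerate(sources):
        dt = times[i] - last_update.get(node, 0.0)
        # [source memory | destination memory | edge features | time encoding]
        raw = src_mem[i] + dst_mem[i] + edge_feats[i] + toy_time_encoding(dt, time_dim)
        messages[node].append((raw, times[i]))  # keep the interaction time with the message
    return messages


msgs = build_messages(sources=[4, 4, 7],
                      src_mem=[[0.0, 0.0]] * 3, dst_mem=[[1.0, 1.0]] * 3,
                      edge_feats=[[0.5, 0.5]] * 3, times=[1.0, 2.0, 3.0],
                      last_update={4: 0.0, 7: 1.0}, time_dim=2)
print(len(msgs[4]), len(msgs[4][0][0]))  # 2 messages for node 4, each of length 2+2+2+2 = 8
```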
get_updated_memory
```python
def get_updated_memory(self, nodes, messages):
    # Aggregate messages for the same nodes, then compute the updated memory
    # nodes: list(range(n_nodes))
    # messages: dict mapping node id -> list of (raw message, timestamp)
    # When there are no stored messages (the first batch), all three return values are empty
    unique_nodes, unique_messages, unique_timestamps = \
        self.message_aggregator.aggregate(
            nodes,     # list(range(n_nodes))
            messages   # stored messages
        )
    if len(unique_nodes) > 0:
        # Two message functions are available (shown below); the one used here,
        # IdentityMessageFunction, returns the raw messages unchanged
        unique_messages = self.message_function.compute_message(unique_messages)
    # On the first batch this returns the all-zero initial memory
    # Shapes: [n_nodes, memory_dimension] and [n_nodes]
    updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                 unique_messages,
                                                                                 timestamps=unique_timestamps)
    return updated_memory, updated_last_update
```
The two message functions:
```python
class MLPMessageFunction(MessageFunction):
    def __init__(self, raw_message_dimension, message_dimension):
        super(MLPMessageFunction, self).__init__()
        self.mlp = self.layers = nn.Sequential(
            nn.Linear(raw_message_dimension, raw_message_dimension // 2),
            nn.ReLU(),
            nn.Linear(raw_message_dimension // 2, message_dimension),
        )

    def compute_message(self, raw_messages):
        messages = self.mlp(raw_messages)
        return messages


class IdentityMessageFunction(MessageFunction):
    def compute_message(self, raw_messages):
        # The one used here: returns the raw messages unchanged
        return raw_messages
```
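get_updated_memory and update_memory run the same update; the difference is that the first returns a modified copy while the second writes into the stored memory. That is why compute_temporal_embeddings can compute an up-to-date memory for every node but persist it only for the positives. A pure-Python sketch of the distinction, assuming scalar memories and a hypothetical update_rule standing in for the GRUCell:

```python
from typing import Dict, List, Tuple

Updates = Dict[int, Tuple[float, float]]  # node -> (aggregated message, timestamp)


def update_rule(memory: float, message: float) -> float:
    # Hypothetical stand-in for the GRUCell
    return 0.5 * memory + message


def get_updated_memory(memory: List[float], last_update: List[float],
                       updates: Updates) -> Tuple[List[float], List[float]]:
    # Work on copies (the role of .data.clone()); the stored lists are left unchanged
    new_memory, new_last = list(memory), list(last_update)
    for node, (message, t) in updates.items():
        assert new_last[node] <= t, "Trying to update memory to time in the past"
        new_memory[node] = update_rule(new_memory[node], message)
        new_last[node] = t
    return new_memory, new_last


def update_memory(memory: List[float], last_update: List[float], updates: Updates) -> None:
    # Same computation, but written back into the stored lists
    for node, (message, t) in updates.items():
        assert last_update[node] <= t, "Trying to update memory to time in the past"
        memory[node] = update_rule(memory[node], message)
        last_update[node] = t


memory, last_update = [0.0, 2.0], [0.0, 0.0]
preview, preview_time = get_updated_memory(memory, last_update, {1: (3.0, 5.0)})
print(preview, memory)      # [0.0, 4.0] [0.0, 2.0]  (stored memory untouched)
update_memory(memory, last_update, {1: (3.0, 5.0)})
print(memory, last_update)  # [0.0, 4.0] [0.0, 5.0]
```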
self.message_aggregator.aggregate
The code uses the "last" aggregator by default:
```python
def get_message_aggregator(aggregator_type, device):
    if aggregator_type == "last":
        return LastMessageAggregator(device=device)
    elif aggregator_type == "mean":
        return MeanMessageAggregator(device=device)
    else:
        raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
```
LastMessageAggregator code:
```python
class LastMessageAggregator(MessageAggregator):
    def __init__(self, device):
        super(LastMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages):
        """Only keep the last message for each node"""
        # Deduplicate node ids: range(n_nodes) has no duplicates, but `positives`
        # (sources + destinations) passed in through update_memory can repeat ids
        unique_node_ids = np.unique(node_ids)
        unique_messages = []
        unique_timestamps = []
        to_update_node_ids = []
        for node_id in unique_node_ids:  # e.g. loops over range(n_nodes) = 9228 ids
            if len(messages[node_id]) > 0:
                # Each stored entry is (source_message, edge_time), where
                # source_message = torch.cat([source_memory, destination_memory, edge_features,
                #                             source_time_delta_encoding], dim=1)
                to_update_node_ids.append(node_id)
                unique_messages.append(messages[node_id][-1][0])
                unique_timestamps.append(messages[node_id][-1][1])
        unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
        unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
        return to_update_node_ids, unique_messages, unique_timestamps
```
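LastMessageAggregator keeps only the newest pending message per node and skips nodes that have none. For contrast, the sketch also includes a mean variant; its behavior (average all pending messages, keep the latest timestamp) is an assumption here, since the MeanMessageAggregator code is not shown above. Both functions are pure-Python illustrations with hypothetical names, not TGN code.

```python
from typing import Dict, Iterable, List, Tuple

Message = Tuple[List[float], float]  # (raw message vector, timestamp)


def aggregate_last(node_ids: Iterable[int], messages: Dict[int, List[Message]]):
    """Keep only the most recent message of each node that has any."""
    to_update, vecs, times = [], [], []
    for node in sorted(set(node_ids)):  # deduplicate and sort, as np.unique does
        msgs = messages.get(node, [])
        if msgs:
            to_update.append(node)
            vecs.append(msgs[-1][0])
            times.append(msgs[-1][1])
    return to_update, vecs, times


def aggregate_mean(node_ids: Iterable[int], messages: Dict[int, List[Message]]):
    """Assumed mean variant: average all pending messages, keep the latest timestamp."""
    to_update, vecs, times = [], [], []
    for node in sorted(set(node_ids)):
        msgs = messages.get(node, [])
        if msgs:
            dim = len(msgs[0][0])
            to_update.append(node)
            vecs.append([sum(m[0][k] for m in msgs) / len(msgs) for k in range(dim)])
            times.append(msgs[-1][1])
    return to_update, vecs, times


pending = {3: [([1.0, 0.0], 1.0), ([3.0, 2.0], 4.0)], 5: []}
print(aggregate_last([3, 3, 5], pending))  # ([3], [[3.0, 2.0]], [4.0])
print(aggregate_mean([3, 3, 5], pending))  # ([3], [[2.0, 1.0]], [4.0])
```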
self.memory_updater.get_updated_memory
By default, the code updates the memory with a GRU:
```python
class SequenceMemoryUpdater(MemoryUpdater):
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(SequenceMemoryUpdater, self).__init__()
        self.memory = memory
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        self.device = device

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        # Writes the update into the stored memory
        if len(unique_node_ids) <= 0:
            return
        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
            "Trying to update memory to time in the past"
        memory = self.memory.get_memory(unique_node_ids)
        self.memory.last_update[unique_node_ids] = timestamps
        updated_memory = self.memory_updater(unique_messages, memory)
        self.memory.set_memory(unique_node_ids, updated_memory)

    def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
        # Returns an updated copy; the stored memory is left unchanged
        if len(unique_node_ids) <= 0:
            # self.memory is the Memory object defined below.
            # At initialization, self.memory.memory is an all-zero matrix of shape
            # [n_nodes, memory_dimension] that does not require gradients, and
            # self.memory.last_update is an all-zero vector of shape [n_nodes], also without gradients.
            # clone() copies the data, so the stored values are not modified.
            # Once messages have been stored (from the second batch on), this branch is skipped.
            return self.memory.memory.data.clone(), self.memory.last_update.data.clone()
        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
            "Trying to update memory to time in the past"
        updated_memory = self.memory.memory.data.clone()
        updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids])
        updated_last_update = self.memory.last_update.data.clone()
        updated_last_update[unique_node_ids] = timestamps
        return updated_memory, updated_last_update


class GRUMemoryUpdater(SequenceMemoryUpdater):
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)
        self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                         hidden_size=memory_dimension)
```
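nn.GRUCell follows the update equations given in the PyTorch documentation: r = sigmoid(W_ir x + b_ir + W_hr h + b_hr), z = sigmoid(W_iz x + b_iz + W_hz h + b_hz), n = tanh(W_in x + b_in + r * (W_hn h + b_hn)), h' = (1 - z) * n + z * h, where here x is the aggregated message and h is the node's memory. A pure-Python sketch of one step with explicit weight lists; gru_cell and its argument layout are illustrative, not PyTorch's API.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, v: Vector) -> Vector:
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gru_cell(x: Vector, h: Vector, w_ih: List[Matrix], w_hh: List[Matrix],
             b_ih: List[Vector], b_hh: List[Vector]) -> Vector:
    """One GRU step; each argument list holds the (reset, update, new) parts in that order."""
    (wir, wiz, win), (whr, whz, whn) = w_ih, w_hh
    (bir, biz, bin_), (bhr, bhz, bhn) = b_ih, b_hh
    gi_r, gi_z, gi_n = matvec(wir, x), matvec(wiz, x), matvec(win, x)
    gh_r, gh_z, gh_n = matvec(whr, h), matvec(whz, h), matvec(whn, h)
    h_new = []
    for k in range(len(h)):
        r = sigmoid(gi_r[k] + bir[k] + gh_r[k] + bhr[k])           # reset gate
        z = sigmoid(gi_z[k] + biz[k] + gh_z[k] + bhz[k])           # update gate
        n = math.tanh(gi_n[k] + bin_[k] + r * (gh_n[k] + bhn[k]))  # candidate memory
        h_new.append((1.0 - z) * n + z * h[k])                     # blend old and new
    return h_new


# With all-zero weights and biases: r = z = 0.5 and n = tanh(0) = 0, so h' = 0.5 * h
zeros2x3 = [[0.0] * 3 for _ in range(2)]  # hidden size 2, input (message) size 3
zeros2x2 = [[0.0] * 2 for _ in range(2)]
zb = [0.0, 0.0]
print(gru_cell([1.0, 2.0, 3.0], [0.4, -0.8],
               [zeros2x3] * 3, [zeros2x2] * 3, [zb] * 3, [zb] * 3))  # [0.2, -0.4]
```

Because h' is a convex combination of the old memory and a tanh output, a memory that starts at zero stays inside (-1, 1).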
Memory
```python
from collections import defaultdict

import torch
from torch import nn


class Memory(nn.Module):
    def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                 device="cpu", combination_method='sum'):
        super(Memory, self).__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.input_dimension = input_dimension
        self.message_dimension = message_dimension
        self.device = device
        self.combination_method = combination_method
        self.__init_memory__()

    # Initialization
    def __init_memory__(self):
        """
        Initializes the memory to all zeros. It should be called at the start of each epoch.
        """
        # Treat memory as parameter so that it is saved and loaded together with the model
        # self.memory_dimension = 172, self.n_nodes = 9228
        # self.memory has shape [9228, 172]: one 172-dimensional memory vector per node,
        # initialized to all zeros
        self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                   requires_grad=False)
        # last_update shape = [9228]
        self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
                                        requires_grad=False)
        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        return self.last_update[node_idxs]

    def backup_memory(self):
        messages_clone = {}
        for k, v in self.messages.items():
            messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]
        return self.memory.data.clone(), self.last_update.data.clone(), messages_clone

    def restore_memory(self, memory_backup):
        self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()
        self.messages = defaultdict(list)
        for k, v in memory_backup[2].items():
            self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]

    def detach_memory(self):
        self.memory.detach_()
        # Detach all stored messages
        for k, v in self.messages.items():
            new_node_messages = []
            for message in v:
                new_node_messages.append((message[0].detach(), message[1]))
            self.messages[k] = new_node_messages

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []
```
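backup_memory and restore_memory copy the memory, the last-update times and every pending message, so later writes cannot leak into the saved state and the state can be rolled back exactly. A pure-Python sketch of the same snapshot semantics, assuming lists instead of tensors and copy.deepcopy in the role of .clone(); ToyMemoryStore is a hypothetical name.

```python
import copy
from collections import defaultdict
from typing import Dict, List, Tuple


class ToyMemoryStore:
    """Illustrative stand-in for Memory, using lists instead of tensors."""

    def __init__(self, n_nodes: int, dim: int) -> None:
        self.memory = [[0.0] * dim for _ in range(n_nodes)]
        self.last_update = [0.0] * n_nodes
        self.messages: Dict[int, List[Tuple[List[float], float]]] = defaultdict(list)

    def backup(self):
        # Deep copies play the role of .clone(): later writes do not affect the backup
        return (copy.deepcopy(self.memory), list(self.last_update),
                {k: copy.deepcopy(v) for k, v in self.messages.items()})

    def restore(self, backup) -> None:
        memory, last_update, messages = backup
        self.memory, self.last_update = copy.deepcopy(memory), list(last_update)
        self.messages = defaultdict(list, copy.deepcopy(messages))

    def clear_messages(self, nodes: List[int]) -> None:
        for node in nodes:
            self.messages[node] = []


store = ToyMemoryStore(n_nodes=2, dim=2)
store.messages[0].append(([1.0, 1.0], 3.0))
snapshot = store.backup()
store.memory[0][0] = 9.0  # modify after taking the snapshot
store.clear_messages([0])
store.restore(snapshot)
print(store.memory[0], store.messages[0])  # [0.0, 0.0] [([1.0, 1.0], 3.0)]
```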
get_embedding_module
Here module_type is "graph_attention".
```python
def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
    # This is the branch used for self.embedding_module
    if module_type == "graph_attention":
        return GraphAttentionEmbedding(node_features=node_features,
                                       edge_features=edge_features,
                                       memory=memory,
                                       neighbor_finder=neighbor_finder,
                                       time_encoder=time_encoder,
                                       n_layers=n_layers,
                                       n_node_features=n_node_features,
                                       n_edge_features=n_edge_features,
                                       n_time_features=n_time_features,
                                       embedding_dimension=embedding_dimension,
                                       device=device,
                                       n_heads=n_heads, dropout=dropout, use_memory=use_memory)
    elif module_type == "graph_sum":
        return GraphSumEmbedding(node_features=node_features,
                                 edge_features=edge_features,
                                 memory=memory,
                                 neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder,
                                 n_layers=n_layers,
                                 n_node_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension,
                                 device=device,
                                 n_heads=n_heads, dropout=dropout, use_memory=use_memory)
    elif module_type == "identity":
        return IdentityEmbedding(node_features=node_features,
                                 edge_features=edge_features,
                                 memory=memory,
                                 neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder,
                                 n_layers=n_layers,
                                 n_node_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension,
                                 device=device,
                                 dropout=dropout)
    elif module_type == "time":
        return TimeEmbedding(node_features=node_features,
                             edge_features=edge_features,
                             memory=memory,
                             neighbor_finder=neighbor_finder,
                             time_encoder=time_encoder,
                             n_layers=n_layers,
                             n_node_features=n_node_features,
                             n_edge_features=n_edge_features,
                             n_time_features=n_time_features,
                             embedding_dimension=embedding_dimension,
                             device=device,
                             dropout=dropout,
                             n_neighbors=n_neighbors)
    else:
        raise ValueError("Embedding Module {} not supported".format(module_type))
```
GraphAttentionEmbedding
```python
class GraphEmbedding(EmbeddingModule):
    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                             neighbor_finder, time_encoder, n_layers,
                                             n_node_features, n_edge_features, n_time_features,
                                             embedding_dimension, device, dropout)
        self.use_memory = use_memory
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Recursive implementation of curr_layers temporal graph attention layers.
        src_idx_l [batch_size]: users / items input ids.
        cut_time_l [batch_size]: scalar representing the instant of the time where we want to extract the user / item representation.
        curr_layers [scalar]: number of temporal convolutional layers to stack.
        num_neighbors [scalar]: number of temporal neighbor to consider in each convolutional layer.
        """
        # Arguments in the top-level call:
        #   memory        the memory matrix
        #   source_nodes  ids of the source, destination and negative nodes combined
        #                 (only at the top level; recursive calls pass other ids, e.g. neighbors)
        #   timestamps    the 200 * 3 timestamps
        #   n_layers      number of recursive layers (2 here)
        #   n_neighbors   number of neighbors sampled per node (10 here)
        #   time_diffs    standardized time differences
        assert (n_layers >= 0)
        # source_nodes_torch shape = [len(source_nodes)] (600 at the top level)
        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        # timestamps_torch shape = [3*200, 1]
        timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)
        # query node always has the start time -> time span == 0
        # self.time_encoder computes cos(linear(t)); its code is shown below.
        # torch.zeros_like(timestamps_torch) is all zeros with shape [3*200, 1]
        # source_nodes_time_embedding shape = [3*200, 1, 172]
        source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))
        # self.node_features has shape [n_nodes, node_dim] = [9228, 172]; in this dataset it is all zeros
        # source_node_features shape is [600, 172]
        source_node_features = self.node_features[source_nodes_torch, :]
        if self.use_memory:
            # Add each node's memory to its static node features
            source_node_features = memory[source_nodes, :] + source_node_features
        # ==================== Recursion over the layers ====================
        if n_layers == 0:
            return source_node_features
        else:
            # Recursive call on the same nodes with one layer fewer; returns shape [600, 172]
            source_node_conv_embeddings = self.compute_embedding(memory,
                                                                 source_nodes,
                                                                 timestamps,
                                                                 n_layers=n_layers - 1,
                                                                 n_neighbors=n_neighbors)
            # For each of the 3*200 nodes, fetch up to n_neighbors (10) neighbors that
            # interacted with it before its timestamp:
            #   neighbors  shape [3*200, n_neighbors]
            #   edge_idxs  shape [3*200, n_neighbors]
            #   edge_times shape [3*200, n_neighbors]
            neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
                source_nodes,
                timestamps,
                n_neighbors=n_neighbors)
            # Neighbor ids of every node in source_nodes, as a tensor
            neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)
            edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)
            # Time elapsed since each neighbor interaction, shape [600, 10]
            edge_deltas = timestamps[:, np.newaxis] - edge_times
            edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)
            # Flatten to 600 * 10 = 6000 neighbor ids
            neighbors = neighbors.flatten()
            # neighbor_embeddings shape = [600*10, 172]
            neighbor_embeddings = self.compute_embedding(memory,
                                                         neighbors,  # 6000 ids
                                                         np.repeat(timestamps, n_neighbors),  # also 6000
                                                         n_layers=n_layers - 1,
                                                         n_neighbors=n_neighbors)
            effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
            # neighbor_embeddings shape = [600, 10, 172]
            neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
            # edge_time_embeddings shape is [600, 10, 172]
            edge_time_embeddings = self.time_encoder(edge_deltas_torch)
            # self.edge_features shape [157475, 172]
            # edge_idxs shape [600, 10]
            # edge_features shape [600, 10, 172]
            edge_features = self.edge_features[edge_idxs, :]
            # Neighbor id 0 marks a padded (empty) slot
            mask = neighbors_torch == 0
            # The aggregation (GraphAttentionEmbedding.aggregate, below) receives:
            #   n_layers                     the current layer (1 in the traced call)
            #   source_node_conv_embeddings  embeddings of the 600 query nodes
            #   source_nodes_time_embedding  encoding of a zero time delta, [3*200, 1, 172];
            #                                cos(bias), which is all ones at initialization
            #   neighbor_embeddings          embeddings of their sampled neighbors
            #   edge_time_embeddings         encodings of the time deltas
            #   edge_features                features of the edges to the 10 neighbors of each node
            #   mask                         padding mask, shape [600, 10]
            source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)
            return source_embedding

    def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        return NotImplemented


class GraphAttentionEmbedding(GraphEmbedding):
    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                      neighbor_finder, time_encoder, n_layers,
                                                      n_node_features, n_edge_features,
                                                      n_time_features,
                                                      embedding_dimension, device,
                                                      n_heads, dropout,
                                                      use_memory)
        # One temporal attention layer per recursion level
        self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
            n_node_features=n_node_features,
            n_neighbors_features=n_node_features,
            n_edge_features=n_edge_features,
            time_dim=n_time_features,
            n_head=n_heads,
            dropout=dropout,
            output_dimension=n_node_features)
            for _ in range(n_layers)])

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        attention_model = self.attention_models[n_layer - 1]
        source_embedding, _ = attention_model(source_node_features,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)
        return source_embedding
```
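Each call with n_layers > 0 recurses once on the query nodes themselves and once on their N * k neighbors, so the number of layer-0 feature lookups for N nodes, k neighbors and L layers satisfies C(N, L) = C(N, L - 1) + C(N * k, L - 1) with C(N, 0) = N, which gives N * (1 + k) ** L. The sketch below counts this directly; base_lookups is a hypothetical name and k > 0 is assumed.

```python
def base_lookups(n_nodes: int, k: int, n_layers: int) -> int:
    """Number of layer-0 feature lookups made by the recursive compute_embedding."""
    if n_layers == 0:
        return n_nodes
    # One recursive call on the nodes themselves, one on their n_nodes * k neighbors
    return base_lookups(n_nodes, k, n_layers - 1) + base_lookups(n_nodes * k, k, n_layers - 1)


for layers in (0, 1, 2):
    print(layers, base_lookups(600, 10, layers), 600 * (1 + 10) ** layers)
# 0 600 600
# 1 6600 6600
# 2 72600 72600
```

With the batch of 600 nodes and 10 neighbors used above, going from one layer to two multiplies the work by 11.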
TimeEncode
```python
import numpy as np
import torch


class TimeEncode(torch.nn.Module):
    # Time Encoding proposed by TGAT
    def __init__(self, dimension):
        super(TimeEncode, self).__init__()
        self.dimension = dimension  # 172
        self.w = torch.nn.Linear(1, dimension)
        # Frequencies 1 / 10 ** x for x evenly spaced over [0, 9]; the bias starts at zero.
        # Both remain trainable parameters.
        self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
                                           .float().reshape(dimension, -1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

    def forward(self, t):  # -> [batch_size, seq_len, dimension]
        # t has shape [batch_size, seq_len]
        # Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
        t = t.unsqueeze(dim=2)
        # output has shape [batch_size, seq_len, dimension]
        output = torch.cos(self.w(t))
        return output
```
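At initialization TimeEncode maps a time delta t to cos(t * w_i + b_i), with w_i = 1 / 10 ** x_i for x_i evenly spaced over [0, 9] and b_i = 0, so its dimensions range from fast (w = 1) to very slow (w = 1e-9) oscillations. One consequence: encoding a zero delta gives cos(b), which is all ones before training. A pure-Python sketch of the initial encoding only (training later changes w and b); time_encode is a hypothetical name and dimension must be at least 2.

```python
import math
from typing import List


def time_encode(t: float, dimension: int) -> List[float]:
    """cos(t * w_i) with TimeEncode's initial weights w_i = 10 ** -(9 * i / (dimension - 1))."""
    weights = [10.0 ** -(9.0 * i / (dimension - 1)) for i in range(dimension)]
    return [math.cos(t * w) for w in weights]  # bias is 0 at initialization


print(time_encode(0.0, 4))  # [1.0, 1.0, 1.0, 1.0]
```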
NeighborFinder
```python
import numpy as np


class NeighborFinder:
    def __init__(self, adj_list, uniform=False, seed=None):
        self.node_to_neighbors = []
        self.node_to_edge_idxs = []
        self.node_to_edge_timestamps = []
        for neighbors in adj_list:
            # Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
            # We sort the list based on timestamp
            sorted_neighbors = sorted(neighbors, key=lambda x: x[2])
            # node_to_neighbors[u] lists, in time order, the nodes that node u interacted with
            self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighbors]))
            self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighbors]))
            self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighbors]))
        self.uniform = uniform
        if seed is not None:
            self.seed = seed
            self.random_state = np.random.RandomState(self.seed)

    def find_before(self, src_idx, cut_time):
        """
        Extracts all the interactions happening before cut_time for user src_idx in the overall interaction graph. The returned interactions are sorted by time.
        Returns 3 lists: neighbors, edge_idxs, timestamps
        """
        # With the default side='left', i is the number of interactions strictly before cut_time
        i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)
        return (self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i],
                self.node_to_edge_timestamps[src_idx][:i])

    def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
        """
        Given a list of users ids and relative cut times, extracts a sampled temporal neighborhood of each user in the list.
        Params
        ------
        src_idx_l: List[int]
        cut_time_l: List[float],
        num_neighbors: int
        """
        assert (len(source_nodes) == len(timestamps))
        tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        # NB! All interactions described in these matrices are sorted in each row by time
        # Entry (i, j): id of a node that source_nodes[i] interacted with before timestamps[i]; shape [600, 10]
        neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32)
        # Entry (i, j): timestamp of that interaction
        edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.float32)
        # Entry (i, j): index of that interaction
        edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32)
        for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
            # All neighbors, interaction indices and timestamps of source_node before timestamp
            source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node, timestamp)
            if len(source_neighbors) > 0 and n_neighbors > 0:
                if self.uniform:
                    # Uniform sampling with replacement from the past interactions
                    # (note: uses the global np.random, not self.random_state)
                    sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)
                    neighbors[i, :] = source_neighbors[sampled_idx]
                    edge_times[i, :] = source_edge_times[sampled_idx]
                    edge_idxs[i, :] = source_edge_idxs[sampled_idx]
                    # re-sort based on time
                    pos = edge_times[i, :].argsort()
                    neighbors[i, :] = neighbors[i, :][pos]
                    edge_times[i, :] = edge_times[i, :][pos]
                    edge_idxs[i, :] = edge_idxs[i, :][pos]
                else:
                    # Take most recent interactions
                    source_edge_times = source_edge_times[-n_neighbors:]
                    source_neighbors = source_neighbors[-n_neighbors:]
                    source_edge_idxs = source_edge_idxs[-n_neighbors:]
                    assert (len(source_neighbors) <= n_neighbors)
                    assert (len(source_edge_times) <= n_neighbors)
                    assert (len(source_edge_idxs) <= n_neighbors)
                    # Right-align; unused slots on the left stay 0 (padding)
                    neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
                    edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
                    edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs
        return neighbors, edge_idxs, edge_times
```
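find_before relies on np.searchsorted with its default side='left', which returns the number of timestamps strictly smaller than cut_time, so an interaction at exactly cut_time is excluded. The sketch below reproduces the non-uniform path (most recent neighbors, right-aligned, padded with node id 0) in pure Python with bisect.bisect_left, which has the same left-side semantics; the function names are hypothetical and n_neighbors is assumed to be positive.

```python
from bisect import bisect_left
from typing import List, Tuple


def find_before(times: List[float], cut_time: float) -> int:
    # bisect_left matches np.searchsorted(side='left'):
    # the count of interactions strictly earlier than cut_time
    return bisect_left(times, cut_time)


def most_recent_neighbors(neighbors: List[int], times: List[float], cut_time: float,
                          n_neighbors: int) -> Tuple[List[int], List[float]]:
    """Right-aligned, zero-padded window of the last n_neighbors interactions before cut_time."""
    i = find_before(times, cut_time)
    recent_nodes = neighbors[:i][-n_neighbors:]
    recent_times = times[:i][-n_neighbors:]
    pad = n_neighbors - len(recent_nodes)
    # Node id 0 is the padding value that the attention layer later masks out
    return [0] * pad + recent_nodes, [0.0] * pad + recent_times


# Interactions of one node, already sorted by time (as NeighborFinder stores them)
nbrs = [12, 7, 31, 7]
ts = [1.0, 2.0, 5.0, 5.0]
print(most_recent_neighbors(nbrs, ts, cut_time=5.0, n_neighbors=3))  # ([0, 12, 7], [0.0, 1.0, 2.0])
```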
```python
class TemporalAttentionLayer(torch.nn.Module):
    """
    Temporal attention layer. Return the temporal embedding of a node given the node itself,
    its neighbors and the edge timestamps.
    """

    def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
                 output_dimension, n_head=2,
                 dropout=0.1):
        super(TemporalAttentionLayer, self).__init__()
        self.n_head = n_head
        self.feat_dim = n_node_features
        self.time_dim = time_dim
        self.query_dim = n_node_features + time_dim
        self.key_dim = n_neighbors_features + time_dim + n_edge_features
        self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)
        self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                       kdim=self.key_dim,
                                                       vdim=self.key_dim,
                                                       num_heads=n_head,
                                                       dropout=dropout)

    def forward(self, src_node_features, src_time_features, neighbors_features,
                neighbors_time_features, edge_features, neighbors_padding_mask):
        """
        Temporal attention model
        :param src_node_features: float Tensor of shape [batch_size, n_node_features]
        :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
        :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
        :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors,
        time_dim]
        :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
        :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors]
        :return:
        attn_output: float Tensor of shape [batch_size, output_dimension]
        attn_output_weights: [batch_size, n_neighbors]
        """
        # src_node_features_unrolled shape is [600, 1, 172]
        src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)
        # Concatenate the node features with the time features
        # query shape is [600, 1, 172 * 2]
        query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
        # Concatenate the neighbor features, edge features and time-delta features
        # key shape is [600, 10, 516]
        key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)
        # query shape is [1, 600, 344]
        query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
        # key shape is [10, 600, 516]
        key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]
        # invalid_neighborhood_mask[i] is True when every neighbor slot of node i is padding,
        # i.e. node i has no neighbors before its timestamp; shape [600, 1]
        invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
        # For such nodes, unmask the first slot so the attention softmax is not taken over an
        # all-masked row (which would produce NaN); their output is zeroed below
        neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False
        attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                                  key_padding_mask=neighbors_padding_mask)
        # After squeezing: attn_output shape = [600, 344], attn_output_weights shape = [600, 10]
        attn_output = attn_output.squeeze()
        attn_output_weights = attn_output_weights.squeeze()
        # Source nodes with no neighbors have an all zero attention output. The attention output is
        # then added or concatenated to the original source node features and then fed into an MLP.
        # This means that an all zero vector is not used.
        attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
        attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)
        # Skip connection with temporal attention over neighborhood and the features of the node itself
        # attn_output shape = [600, 172]
        attn_output = self.merger(attn_output, src_node_features)
        return attn_output, attn_output_weights
```
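The key padding mask removes padded neighbor slots from the softmax; for a node whose slots are all padded, the softmax would have nothing to normalize over, which is why the layer unmasks slot 0 for such nodes and zeroes their output afterwards. A pure-Python single-head sketch of masked scaled dot-product attention. Unlike nn.MultiheadAttention it has no learned query/key/value projections and no multiple heads, and it signals the all-padded case by returning None; masked_attention is a hypothetical name.

```python
import math
from typing import List, Optional

Vector = List[float]


def masked_attention(query: Vector, keys: List[Vector], values: List[Vector],
                     pad_mask: List[bool]) -> Optional[Vector]:
    """Single-head scaled dot-product attention without learned projections (a simplification).

    pad_mask[j] is True for padded neighbor slots, which receive zero weight.
    """
    if all(pad_mask):
        return None  # softmax over an all-masked row would be 0/0 (NaN)
    scale = 1.0 / math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) * scale for key in keys]
    top = max(s for s, m in zip(scores, pad_mask) if not m)  # subtract the max for stability
    exps = [0.0 if m else math.exp(s - top) for s, m in zip(scores, pad_mask)]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


q = [1.0, 0.0]
ks = [[0.0, 0.0], [5.0, 0.0], [1.0, 1.0]]
vs = [[10.0, 10.0], [1.0, 2.0], [3.0, 4.0]]
print(masked_attention(q, ks, vs, [True, False, False]))  # slot 0 is padding and ignored
print(masked_attention(q, ks, vs, [True, True, True]))    # None
```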
MergeLayer
```python
import torch


class MergeLayer(torch.nn.Module):
    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
        self.fc2 = torch.nn.Linear(dim3, dim4)
        self.act = torch.nn.ReLU()
        torch.nn.init.xavier_normal_(self.fc1.weight)
        torch.nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x1, x2):
        # Concatenate the two inputs, then apply Linear -> ReLU -> Linear
        x = torch.cat([x1, x2], dim=1)
        h = self.act(self.fc1(x))
        return self.fc2(h)
```
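In TemporalAttentionLayer, MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension) receives the 344-dimensional attention output and the 172-dimensional node features, so fc1 maps 516 inputs to 172 and fc2 maps 172 to 172. A pure-Python sketch of the forward pass for a single example, with tiny made-up dimensions and weights; merge_layer and linear are hypothetical names and the Xavier initialization is omitted.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def linear(w: Matrix, b: Vector, x: Vector) -> Vector:
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]


def merge_layer(x1: Vector, x2: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """fc2(relu(fc1(concat(x1, x2)))) for a single example (no batch dimension)."""
    assert len(w1[0]) == len(x1) + len(x2), "fc1 expects dim1 + dim2 inputs"
    hidden = [max(0.0, h) for h in linear(w1, b1, x1 + x2)]  # ReLU
    return linear(w2, b2, hidden)


# dim1 = 2, dim2 = 1, dim3 = 2, dim4 = 1 (tiny stand-ins for 344, 172, 172, 172)
w1 = [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]]
b1 = [0.0, 0.0]
w2 = [[1.0, 1.0]]
b2 = [0.5]
print(merge_layer([3.0, 4.0], [2.0], w1, b1, w2, b2))  # [3.5]
```

The second hidden unit is cut to zero by the ReLU for x2 = [2.0] but becomes active for x2 = [-2.0], which is how the node's own features can change the merged output non-linearly.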