Paper Code Reading: Understanding the TGN Model Training Code


Table of Contents

  • Understanding the TGN Model Training Code
    • Paper Information
    • Hand-Drawn Sketch of the Code Flow
    • Code Training Walkthrough
          • compute_temporal_embeddings
          • update_memory
          • get_raw_messages
          • get_updated_memory
          • self.message_aggregator.aggregate
          • self.memory_updater.get_updated_memory
          • Memory
          • get_embedding_module
          • GraphAttentionEmbedding
          • TimeEncode
          • NeighborFinder
          • MergeLayer

Understanding the TGN Model Training Code

Paper Information

Paper link: https://arxiv.org/abs/2006.10637

GitHub: https://github.com/twitter-research/tgn?tab=readme-ov-file

Year: 2020

Hand-Drawn Sketch of the Code Flow

(Two hand-drawn diagrams of the training pipeline; the images are not reproduced here.)

Code Training Walkthrough

pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch, timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)

The compute_edge_probabilities function:

def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                               edge_idxs, n_neighbors=20):
    """
    Compute probabilities for edges between sources and destination and between sources and
    negatives by first computing temporal embeddings using the TGN encoder and then feeding them
    into the MLP decoder.
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :return: Probabilities for both the positive and negative edges
    source_nodes: list of source node ids
    destination_nodes: list of destination node ids
    negative_nodes: list of negatively sampled node ids
    edge_times: timestamps of the interactions between the source and destination nodes
    edge_idxs: edge indices
    """
    n_samples = len(source_nodes)
    # compute_temporal_embeddings
    source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings(source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors)

    score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                torch.cat([destination_node_embedding,
                                           negative_node_embedding])).squeeze(dim=0)
    pos_score = score[:n_samples]
    neg_score = score[n_samples:]

    return pos_score.sigmoid(), neg_score.sigmoid()
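
Before moving on, it helps to see why the first n_samples scores are positive and the rest are negative. Below is a minimal sketch, not the repo code: a plain dot product stands in for self.affinity_score (which is an MLP in TGN), and the batch size and dimension are made up for illustration.

import torch

n_samples, dim = 4, 8                          # hypothetical batch size and embedding dim
src = torch.randn(n_samples, dim)              # source embeddings
dst = torch.randn(n_samples, dim)              # positive destination embeddings
neg = torch.randn(n_samples, dim)              # negative destination embeddings

# sources are duplicated so that row i pairs with its positive destination and
# row i + n_samples pairs the same source with its negative destination
pairs_left = torch.cat([src, src], dim=0)      # [2 * n_samples, dim]
pairs_right = torch.cat([dst, neg], dim=0)     # [2 * n_samples, dim]

score = (pairs_left * pairs_right).sum(dim=1)  # stand-in for self.affinity_score
pos_score, neg_score = score[:n_samples], score[n_samples:]
print(pos_score.sigmoid().shape, neg_score.sigmoid().shape)  # torch.Size([4]) twice
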
compute_temporal_embeddings

The purpose of this method is to compute the temporal embeddings of the source, destination, and negative nodes.

def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                edge_idxs, n_neighbors=20):
    """
    Compute temporal embeddings for sources, destinations, and negatively sampled destinations.
    This method computes the temporal embeddings for sources, destinations, and negatives.
    :param source_nodes [batch_size]: source ids
    :param destination_nodes [batch_size]: destination ids
    :param negative_nodes [batch_size]: ids of negative sampled destination
    :param edge_times [batch_size]: timestamp of interaction
    :param edge_idxs [batch_size]: index of interaction
    :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
    layer
    :return: Temporal embeddings for sources, destinations and negatives
    """
	
    # n_samples is the number of source nodes in the batch
    n_samples = len(source_nodes)
    # nodes contains the ids of all nodes in this batch, size = 200 * 3 = 600
    nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
    # positives concatenates sources and destinations: the first 200 entries are source ids, the next 200 are destination ids
    positives = np.concatenate([source_nodes, destination_nodes])
    # timestamps has shape 200 * 3; edge_times are the interaction timestamps
    timestamps = np.concatenate([edge_times, edge_times, edge_times])
    # edge_times has shape [batch_size]: the time at which each source-destination interaction occurs
    memory = None
    time_diffs = None
    if self.use_memory:
        if self.memory_update_at_start:  # update the memory at the start of the batch?
            # n_nodes is the total number of nodes in the graph (9228)
            # the message store self.memory.messages is guaranteed to be empty at this point
            # the memory returned here is the freshest one: it is recomputed from each node's
            # stored messages, and the code keeps only the latest message per node
            memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                          self.memory.messages)  # memory shape = [n_nodes(9228), memory_dimension(172)], last_update shape = [n_nodes(9228)]
        else:
            memory = self.memory.get_memory(list(range(self.n_nodes)))
            last_update = self.memory.last_update

        # ===================================== the lines below compute per-node time differences ==============================
        # difference between the time a node's memory was last updated and the time at which
        # we want to compute its embedding
        # source_time_diffs shape [batch_size]
        source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[source_nodes].long()
        # standardization
        source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
        # destination_time_diffs shape [batch_size]
        destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[destination_nodes].long()
        # standardization
        destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

        negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
            negative_nodes].long()
        negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

        # time differences
        time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
                               dim=0)

    # Compute the embeddings using the embedding module
    # self.embedding_module is shown further below
    # 1. it is created via self.embedding_module = get_embedding_module()
    """
    memory: the memory object
    nodes: a node-id list combining sources, destinations, and negative samples
    timestamps: the 200 * 3 timestamp list
    self.n_layers: recursion depth, 2 here
    n_neighbors: number of neighbors to sample, 10 here
    time_diffs: standardized time differences
    """
    # node_embedding shape [600, 172]: fuses each node's own features with its neighbors' node and edge features
    node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                             source_nodes=nodes,
                                                             timestamps=timestamps,
                                                             n_layers=self.n_layers,
                                                             n_neighbors=n_neighbors,
                                                             time_diffs=time_diffs)
    # then slice out the embeddings of the three node lists
    source_node_embedding = node_embedding[:n_samples]
    destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
    negative_node_embedding = node_embedding[2 * n_samples:]

    if self.use_memory:
        # update the memory
        if self.memory_update_at_start:
            # Persist the updates to the memory only for sources and destinations (since now we have
            # new messages for them)
            self.update_memory(positives, self.memory.messages)

            assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                "Something wrong in how the memory was updated"

            # Remove messages for the positives since we have already updated the memory using them
            # the memory has already been updated with these messages, so their message lists are cleared
            self.memory.clear_messages(positives)

        unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                      source_node_embedding,
                                                                      destination_nodes,
                                                                      destination_node_embedding,
                                                                      edge_times, edge_idxs)
        unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                destination_node_embedding,
                                                                                source_nodes,
                                                                                source_node_embedding,
                                                                                edge_times, edge_idxs)
        if self.memory_update_at_start:
            # store the raw messages
            self.memory.store_raw_messages(unique_sources, source_id_to_messages)
            self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
        else:
            self.update_memory(unique_sources, source_id_to_messages)
            self.update_memory(unique_destinations, destination_id_to_messages)

        if self.dyrep:
            source_node_embedding = memory[source_nodes]
            destination_node_embedding = memory[destination_nodes]
            negative_node_embedding = memory[negative_nodes]

    return source_node_embedding, destination_node_embedding, negative_node_embedding
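
As a side note on the time-difference standardization used above, here is a tiny standalone sketch. The values of mean_time_shift_src and std_time_shift_src are hypothetical numbers, not statistics from the dataset.

import torch

edge_times = torch.tensor([100.0, 250.0, 400.0])         # interaction times of three edges
last_update = torch.tensor([40.0, 250.0, 100.0])         # last memory update of each source node
mean_time_shift_src, std_time_shift_src = 120.0, 80.0    # hypothetical dataset statistics

source_time_diffs = edge_times - last_update             # how stale each node's memory is
source_time_diffs = (source_time_diffs - mean_time_shift_src) / std_time_shift_src
print(source_time_diffs)                                 # standardized staleness fed to the embedding module
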
update_memory
def update_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
    # self.message_aggregator -> LastMessageAggregator
    unique_nodes, unique_messages, unique_timestamps = \
        self.message_aggregator.aggregate(
            nodes,
            messages)

    if len(unique_nodes) > 0:
        unique_messages = self.message_function.compute_message(unique_messages)

    # Update the memory with the aggregated messages
    # once aggregation is done, apply the update
    self.memory_updater.update_memory(unique_nodes, unique_messages,
                                      timestamps=unique_timestamps)
get_raw_messages
def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                     destination_node_embedding, edge_times, edge_idxs):
    # edge_times shape is [200, ]
    edge_times = torch.from_numpy(edge_times).float().to(self.device)
    # edge_features shape is [200, 172]
    edge_features = self.edge_raw_features[edge_idxs]

    source_memory = self.memory.get_memory(source_nodes) if not \
        self.use_source_embedding_in_message else source_node_embedding
    destination_memory = self.memory.get_memory(destination_nodes) if \
        not self.use_destination_embedding_in_message else destination_node_embedding

    source_time_delta = edge_times - self.memory.last_update[source_nodes]
    # source_time_delta_encoding [200, 172]
    source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(
        source_nodes), -1)
	# source_message shape [200, 688]
    source_message = torch.cat([source_memory, destination_memory, edge_features,
                                source_time_delta_encoding],
                               dim=1)
    messages = defaultdict(list)
    unique_sources = np.unique(source_nodes)

    for i in range(len(source_nodes)):
        messages[source_nodes[i]].append((source_message[i], edge_times[i]))

    return unique_sources, messages
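
The width of a raw message is the sum of its four parts. Assuming memory, edge, and time dimensions of 172 each, as in this walkthrough, a minimal sketch of the concatenation looks like this:

import torch

memory_dim = edge_dim = time_dim = 172
source_memory = torch.randn(1, memory_dim)
destination_memory = torch.randn(1, memory_dim)
edge_features = torch.randn(1, edge_dim)
time_delta_encoding = torch.randn(1, time_dim)   # stands in for self.time_encoder(source_time_delta)

raw_message = torch.cat([source_memory, destination_memory,
                         edge_features, time_delta_encoding], dim=1)
print(raw_message.shape)                         # torch.Size([1, 688]) = 4 * 172
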
get_updated_memory
def get_updated_memory(self, nodes, messages):
    # Aggregate messages for the same nodes
    # nodes is a list: range(n_nodes)
    # messages is the message store
    # messages are aggregated first, then the memory is updated
    # on the very first call, everything returned here is empty
    unique_nodes, unique_messages, unique_timestamps = \
        self.message_aggregator.aggregate(
            nodes,     # a list: range(n_nodes)
            messages)  # the message store

    if len(unique_nodes) > 0:
        # there are two choices for the message function:
        """
        class MLPMessageFunction(MessageFunction):
            def __init__(self, raw_message_dimension, message_dimension):
                super(MLPMessageFunction, self).__init__()
                self.mlp = self.layers = nn.Sequential(
                    nn.Linear(raw_message_dimension, raw_message_dimension // 2),
                    nn.ReLU(),
                    nn.Linear(raw_message_dimension // 2, message_dimension),
                )

            def compute_message(self, raw_messages):
                messages = self.mlp(raw_messages)
                return messages

        class IdentityMessageFunction(MessageFunction):
            def compute_message(self, raw_messages):
                # the authors use this one: it simply returns the raw messages unchanged
                return raw_messages
        """
        unique_messages = self.message_function.compute_message(unique_messages)
	
    # on the first pass of training, the tensors returned here are all zeros
    # with shapes [n_nodes, memory_dimension] and [n_nodes]
    updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
                                                                                 unique_messages,
                                                                                 timestamps=unique_timestamps)

    return updated_memory, updated_last_update
self.message_aggregator.aggregate

The code uses the "last" aggregator by default.

def get_message_aggregator(aggregator_type, device):
    if aggregator_type == "last":
        return LastMessageAggregator(device=device)
    elif aggregator_type == "mean":
        return MeanMessageAggregator(device=device)
    else:
        raise ValueError("Message aggregator {} not implemented".format(aggregator_type))

The LastMessageAggregator code:

class LastMessageAggregator(MessageAggregator):
    def __init__(self, device):
        super(LastMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages):
        """Only keep the last message for each node"""
        unique_node_ids = np.unique(node_ids)  # deduplicate node ids (no effect here, since range(n_nodes) has no duplicates)
        unique_messages = []
        unique_timestamps = []

        to_update_node_ids = []

        for node_id in unique_node_ids:  # loops over range(n_nodes) = 9228
            if len(messages[node_id]) > 0:
                """
                Each stored entry is the (source_message, edge_time) pair produced in get_raw_messages, where
                source_message = torch.cat([source_memory, destination_memory, edge_features,
                                            source_time_delta_encoding], dim=1)
                """
                to_update_node_ids.append(node_id)
                unique_messages.append(messages[node_id][-1][0])
                unique_timestamps.append(messages[node_id][-1][1])

        unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
        unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

        return to_update_node_ids, unique_messages, unique_timestamps
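
A toy example, with made-up data rather than repo data, of the "last" rule: only the most recent (message, timestamp) pair of each node with pending messages survives, and nodes without messages are skipped.

from collections import defaultdict
import numpy as np
import torch

messages = defaultdict(list)
messages[3] = [(torch.ones(4) * 1, torch.tensor(10.0)),
               (torch.ones(4) * 2, torch.tensor(25.0))]   # two pending messages: keep the later one
messages[7] = [(torch.ones(4) * 5, torch.tensor(12.0))]

node_ids = np.array([3, 5, 7])                             # node 5 has no messages and is skipped
to_update, msgs, ts = [], [], []
for node_id in np.unique(node_ids):
    if len(messages[node_id]) > 0:
        to_update.append(int(node_id))
        msgs.append(messages[node_id][-1][0])
        ts.append(messages[node_id][-1][1])

print(to_update)               # nodes 3 and 7
print(torch.stack(msgs))       # only the latest message of each updated node
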
self.memory_updater.get_updated_memory

By default, the code updates the memory with a GRU.

class SequenceMemoryUpdater(MemoryUpdater):
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(SequenceMemoryUpdater, self).__init__()
        self.memory = memory
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        self.device = device

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        if len(unique_node_ids) <= 0:
            return

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \
                                                                                          "update memory to time in the past"

        memory = self.memory.get_memory(unique_node_ids)
        self.memory.last_update[unique_node_ids] = timestamps

        updated_memory = self.memory_updater(unique_messages, memory)

        self.memory.set_memory(unique_node_ids, updated_memory)

    def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
        if len(unique_node_ids) <= 0:
            # self.memory here is the Memory object defined below
            # self.memory.memory is initialized to all zeros, shape [n_nodes, memory_dimension], with no gradient
            # self.memory.last_update is initialized to all zeros, shape [n_nodes], with no gradient
            # clone() makes a deep copy, so the original values are not affected

            # from the second batch onwards this branch is no longer taken
            return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \
                                                                                          "update memory to time in the past"

        updated_memory = self.memory.memory.data.clone()
        updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids])

        updated_last_update = self.memory.last_update.data.clone()
        updated_last_update[unique_node_ids] = timestamps

        return updated_memory, updated_last_update


class GRUMemoryUpdater(SequenceMemoryUpdater):
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)

        self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                         hidden_size=memory_dimension)
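
A minimal sketch of what the GRU update amounts to: the aggregated message is the GRUCell input and the node's current memory is its hidden state. The dimensions (message_dim = 688, memory_dim = 172) are the ones assumed throughout this walkthrough, not values read from the repo config.

import torch
import torch.nn as nn

message_dim, memory_dim, n_updated = 688, 172, 5
gru = nn.GRUCell(input_size=message_dim, hidden_size=memory_dim)

unique_messages = torch.randn(n_updated, message_dim)   # one aggregated message per updated node
old_memory = torch.zeros(n_updated, memory_dim)         # current memory of those nodes

new_memory = gru(unique_messages, old_memory)           # [n_updated, memory_dim]
print(new_memory.shape)
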
Memory
class Memory(nn.Module):

    def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                 device="cpu", combination_method='sum'):
        super(Memory, self).__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.input_dimension = input_dimension
        self.message_dimension = message_dimension
        self.device = device

        self.combination_method = combination_method

        self.__init_memory__()
	
    # initialization
    def __init_memory__(self):
        """
        Initializes the memory to all zeros. It should be called at the start of each epoch.
        """
        # Treat memory as parameter so that it is saved and loaded together with the model
        # self.memory_dimension = 172
        # self.n_nodes = 9228
        # self.memory shape is [9228, 172]: each node has its own memory vector of size 172
        # self.memory starts as an all-zero matrix
        self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                   requires_grad=False)

        # last_update shape = [9228]
        self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
                                        requires_grad=False)

        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        return self.memory[node_idxs, :]

    def set_memory(self, node_idxs, values):
        self.memory[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        return self.last_update[node_idxs]

    def backup_memory(self):
        messages_clone = {}
        for k, v in self.messages.items():
            messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]

        return self.memory.data.clone(), self.last_update.data.clone(), messages_clone

    def restore_memory(self, memory_backup):
        self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()

        self.messages = defaultdict(list)
        for k, v in memory_backup[2].items():
            self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]

    def detach_memory(self):
        self.memory.detach_()

        # Detach all stored messages
        for k, v in self.messages.items():
            new_node_messages = []
            for message in v:
                new_node_messages.append((message[0].detach(), message[1]))

            self.messages[k] = new_node_messages

    def clear_messages(self, nodes):
        for node in nodes:
            self.messages[node] = []
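
A toy walk-through of the message bookkeeping, with plain tensors and a defaultdict standing in for the Memory class: raw messages are appended per node, consumed to refresh the memory, then cleared.

from collections import defaultdict
import torch

n_nodes, memory_dim = 6, 4
memory = torch.zeros(n_nodes, memory_dim)   # stands in for Memory.memory
messages = defaultdict(list)                # stands in for Memory.messages

# store_raw_messages: append (message, timestamp) tuples for the nodes seen in the batch
messages[2].append((torch.randn(memory_dim), torch.tensor(5.0)))
messages[2].append((torch.randn(memory_dim), torch.tensor(9.0)))

# update step (here simply overwriting with the latest message, for brevity)
memory[2] = messages[2][-1][0]

# clear_messages: once the memory has been updated, the pending messages are dropped
messages[2] = []
print(memory[2], len(messages[2]))
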
get_embedding_module

Here module_type is "graph_attention".

def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
    # this is the embedding module actually used
    if module_type == "graph_attention":
        return GraphAttentionEmbedding(node_features=node_features,
                                       edge_features=edge_features,
                                       memory=memory,
                                       neighbor_finder=neighbor_finder,
                                       time_encoder=time_encoder,
                                       n_layers=n_layers,
                                       n_node_features=n_node_features,
                                       n_edge_features=n_edge_features,
                                       n_time_features=n_time_features,
                                       embedding_dimension=embedding_dimension,
                                       device=device,
                                       n_heads=n_heads, dropout=dropout, use_memory=use_memory)
    elif module_type == "graph_sum":
        return GraphSumEmbedding(node_features=node_features,
                                 edge_features=edge_features,
                                 memory=memory,
                                 neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder,
                                 n_layers=n_layers,
                                 n_node_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension,
                                 device=device,
                                 n_heads=n_heads, dropout=dropout, use_memory=use_memory)

    elif module_type == "identity":
        return IdentityEmbedding(node_features=node_features,
                                 edge_features=edge_features,
                                 memory=memory,
                                 neighbor_finder=neighbor_finder,
                                 time_encoder=time_encoder,
                                 n_layers=n_layers,
                                 n_node_features=n_node_features,
                                 n_edge_features=n_edge_features,
                                 n_time_features=n_time_features,
                                 embedding_dimension=embedding_dimension,
                                 device=device,
                                 dropout=dropout)
    elif module_type == "time":
        return TimeEmbedding(node_features=node_features,
                             edge_features=edge_features,
                             memory=memory,
                             neighbor_finder=neighbor_finder,
                             time_encoder=time_encoder,
                             n_layers=n_layers,
                             n_node_features=n_node_features,
                             n_edge_features=n_edge_features,
                             n_time_features=n_time_features,
                             embedding_dimension=embedding_dimension,
                             device=device,
                             dropout=dropout,
                             n_neighbors=n_neighbors)
    else:
        raise ValueError("Embedding Module {} not supported".format(module_type))
GraphAttentionEmbedding
class GraphEmbedding(EmbeddingModule):
    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                             neighbor_finder, time_encoder, n_layers,
                                             n_node_features, n_edge_features, n_time_features,
                                             embedding_dimension, device, dropout)

        self.use_memory = use_memory
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Recursive implementation of curr_layers temporal graph attention layers.
        A stack of temporal graph attention layers is built via recursion.
        src_idx_l [batch_size]: users / items input ids.
        cut_time_l [batch_size]: scalar representing the instant of the time where we want to extract the user / item representation.
        curr_layers [scalar]: number of temporal convolutional layers to stack.
        num_neighbors [scalar]: number of temporal neighbor to consider in each convolutional layer.

        In this walkthrough:
        memory: the memory object
        source_nodes: a node-id list combining sources, destinations, and negative samples (only on the first call; recursive calls pass other node lists)
        timestamps: the 200 * 3 timestamp list
        n_layers: recursion depth, 2 here
        n_neighbors: number of neighbors to sample, 10 here
        time_diffs: standardized time differences
        """
        assert (n_layers >= 0)
        # source_nodes_torch shape = [len(source_nodes)] (3*200 on the first call)
        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        # timestamps_torch shape = [3*200, 1]
        timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)

        # query node always has the start time -> time span == 0
        # time_encoder is a small module computing cos(linear(x)); its code is shown below
        # torch.zeros_like(timestamps_torch) is an all-zero tensor of shape [3*200, 1]
        # source_nodes_time_embedding shape = [3*200, 1, 172]
        source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))
        # self.node_features is an all-zero matrix (the nodes carry no raw features on this dataset)
        # self.node_features shape is [n_nodes, node_dim] = [9228, 172]
        # source_node_features holds the features of the batch nodes, shape [600, 172]
        source_node_features = self.node_features[source_nodes_torch, :]

        if self.use_memory:
            # add each node's memory to its current features
            source_node_features = memory[source_nodes, :] + source_node_features
		
        # ====================================== the recursion happens below ==================================
        # (when this branch recurses, it is called with n_layers = 1)
        if n_layers == 0:
            return source_node_features
        else:
            # call compute_embedding recursively; it returns node features of shape [600, 172]
            source_node_conv_embeddings = self.compute_embedding(memory,
                                                                 source_nodes,
                                                                 timestamps,
                                                                 n_layers=n_layers - 1,
                                                                 n_neighbors=n_neighbors)
            # for the 3*200 source_nodes, sample up to ten temporal neighbors at the 3*200 timestamps
            """
            neighbors shape is [3*200, n_neighbors]
            edge_idxs shape is [3*200, n_neighbors]
            edge_times shape is [3*200, n_neighbors]
            """
            neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
                source_nodes,
                timestamps,
                n_neighbors=n_neighbors)

            # convert the sampled neighbor ids of each source node to a torch tensor
            neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)

            edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

            # time differences between the 600 query timestamps and the neighbor interaction times
            edge_deltas = timestamps[:, np.newaxis] - edge_times
            edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

            # flatten to 6000 neighbor ids
            neighbors = neighbors.flatten()
            # neighbor_embeddings shape = [600*10, 172]
            neighbor_embeddings = self.compute_embedding(memory,
                                                         neighbors,  # 6000 node ids
                                                         np.repeat(timestamps, n_neighbors),  # also 6000
                                                         n_layers=n_layers - 1,
                                                         n_neighbors=n_neighbors)

            effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
            # neighbor_embeddings reshaped to [600, 10, 172]
            neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
            # edge_time_embeddings shape is [600, 10, 172]
            edge_time_embeddings = self.time_encoder(edge_deltas_torch)
			
            # self.edge_features shape [157475, 172]
            # edge_idxs shape [600, 10]
            # edge_features shape [600, 10, 172]
            edge_features = self.edge_features[edge_idxs, :]

            mask = neighbors_torch == 0
			
            # the aggregate method is shown below
            """
            n_layers: 1
            source_node_conv_embeddings: the embeddings of the original 600 batch nodes
            source_nodes_time_embedding: the time encoding of an all-zero tensor, shape [3*200, 1, 172]
            neighbor_embeddings: embeddings of the neighbors the 600 batch nodes have interacted with
            edge_time_embeddings: the encoded time differences
            edge_features: the features of the edges between each of the 600 nodes and its 10 neighbors
            mask: shape [600, 10]
            """
            source_embedding = self.aggregate(n_layers, source_node_conv_embeddings,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)

            return source_embedding

    def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        return NotImplemented


class GraphAttentionEmbedding(GraphEmbedding):
    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                      neighbor_finder, time_encoder, n_layers,
                                                      n_node_features, n_edge_features,
                                                      n_time_features,
                                                      embedding_dimension, device,
                                                      n_heads, dropout,
                                                      use_memory)

        self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
            n_node_features=n_node_features,
            n_neighbors_features=n_node_features,
            n_edge_features=n_edge_features,
            time_dim=n_time_features,
            n_head=n_heads,
            dropout=dropout,
            output_dimension=n_node_features)
            for _ in range(n_layers)])

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        attention_model = self.attention_models[n_layer - 1]

        source_embedding, _ = attention_model(source_node_features,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)

        return source_embedding
TimeEncode
class TimeEncode(torch.nn.Module):
    # Time Encoding proposed by TGAT
    def __init__(self, dimension):
        super(TimeEncode, self).__init__()

        self.dimension = dimension # 172
        self.w = torch.nn.Linear(1, dimension)

        # todo
        self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
                                           .float().reshape(dimension, -1))
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

    def forward(self, t): # -> [batch_size, seq_len, dimension]
        # t has shape [batch_size, seq_len]
        # Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
        t = t.unsqueeze(dim=2)

        # output has shape [batch_size, seq_len, dimension]
        output = torch.cos(self.w(t))

        return output
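
A quick shape check of TimeEncode with a smaller dimension (8 instead of the 172 used in the repo); the output is an element-wise cosine of a linear projection of each time value.

import numpy as np
import torch

dimension = 8                                  # 172 in the repo; smaller here for readability
w = torch.nn.Linear(1, dimension)
w.weight = torch.nn.Parameter(
    torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)).float().reshape(dimension, -1))
w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

t = torch.tensor([[0.0, 3.5, 10.0]])           # [batch_size=1, seq_len=3]
output = torch.cos(w(t.unsqueeze(dim=2)))      # [1, 3, dimension]
print(output.shape)
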
NeighborFinder
class NeighborFinder:
    def __init__(self, adj_list, uniform=False, seed=None):
        self.node_to_neighbors = []
        self.node_to_edge_idxs = []
        self.node_to_edge_timestamps = []

        for neighbors in adj_list:
            # Neighbors is a list of tuples (neighbor, edge_idx, timestamp)
            # We sort the list based on timestamp
            sorted_neighhbors = sorted(neighbors, key=lambda x: x[2])
            self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors]))  # a list of arrays: entry i holds, in time order, the ids of the nodes that node i has interacted with
            self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors]))
            self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors]))

        self.uniform = uniform

        if seed is not None:
            self.seed = seed
            self.random_state = np.random.RandomState(self.seed)

    def find_before(self, src_idx, cut_time):
        """
        Extracts all the interactions happening before cut_time for user src_idx in the overall interaction graph. The returned interactions are sorted by time.

        Returns 3 lists: neighbors, edge_idxs, timestamps

        """
        i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

        return self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i], self.node_to_edge_timestamps[
                                                                                             src_idx][:i]

    def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
        """
        Given a list of users ids and relative cut times, extracts a sampled temporal neighborhood of each user in the list.

        Params
        ------
        src_idx_l: List[int]
        cut_time_l: List[float],
        num_neighbors: int
        """
        assert (len(source_nodes) == len(timestamps))

        tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        # NB! All interactions described in these matrices are sorted in each row by time
        neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype( # shape [600, 10]
            np.int32)  # each entry in position (i,j) represent the id of the item targeted by user src_idx_l[i] with an interaction happening before cut_time_l[i]
        edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
            np.float32)  # each entry in position (i,j) represent the timestamp of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]
        edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(
            np.int32)  # each entry in position (i,j) represent the interaction index of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i]

        for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
            source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
                                                                                     timestamp)  # extracts all neighbors, interactions indexes and timestamps of all interactions of user source_node happening before cut_time

            if len(source_neighbors) > 0 and n_neighbors > 0:
                if self.uniform:  # if we are applying uniform sampling, shuffles the data above before sampling
                    sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)

                    neighbors[i, :] = source_neighbors[sampled_idx]
                    edge_times[i, :] = source_edge_times[sampled_idx]
                    edge_idxs[i, :] = source_edge_idxs[sampled_idx]

                    # re-sort based on time
                    pos = edge_times[i, :].argsort()
                    neighbors[i, :] = neighbors[i, :][pos]
                    edge_times[i, :] = edge_times[i, :][pos]
                    edge_idxs[i, :] = edge_idxs[i, :][pos]
                else:
                    # Take most recent interactions
                    source_edge_times = source_edge_times[-n_neighbors:]
                    source_neighbors = source_neighbors[-n_neighbors:]
                    source_edge_idxs = source_edge_idxs[-n_neighbors:]

                    assert (len(source_neighbors) <= n_neighbors)
                    assert (len(source_edge_times) <= n_neighbors)
                    assert (len(source_edge_idxs) <= n_neighbors)

                    neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
                    edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
                    edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

        return neighbors, edge_idxs, edge_times
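
A toy example, with made-up adjacency data, of the binary search inside find_before: since the per-node timestamps are sorted, np.searchsorted gives the cut-off index, so the slices keep only the interactions that happened strictly before cut_time.

import numpy as np

node_to_neighbors = [np.array([4, 7, 9])]              # neighbors of node 0, sorted by interaction time
node_to_edge_timestamps = [np.array([1.0, 5.0, 8.0])]  # timestamps of those interactions

cut_time = 6.0
i = np.searchsorted(node_to_edge_timestamps[0], cut_time)
print(node_to_neighbors[0][:i])                        # [4 7]: only the interactions before t = 6
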
class TemporalAttentionLayer(torch.nn.Module):
  """
  Temporal attention layer. Return the temporal embedding of a node given the node itself,
   its neighbors and the edge timestamps.
  """

  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
               output_dimension, n_head=2,
               dropout=0.1):
    super(TemporalAttentionLayer, self).__init__()

    self.n_head = n_head

    self.feat_dim = n_node_features
    self.time_dim = time_dim

    self.query_dim = n_node_features + time_dim
    self.key_dim = n_neighbors_features + time_dim + n_edge_features

    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                   kdim=self.key_dim,
                                                   vdim=self.key_dim,
                                                   num_heads=n_head,
                                                   dropout=dropout)

  def forward(self, src_node_features, src_time_features, neighbors_features,
              neighbors_time_features, edge_features, neighbors_padding_mask):
    """
    "Temporal attention model
    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features]
    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors,
    time_dim]
    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
    :param neighbors_padding_mask: float Tensor of shape [batch_size, n_neighbors]
    :return:
    attn_output: float Tensor of shape [1, batch_size, n_node_features]
    attn_output_weights: [batch_size, 1, n_neighbors]
    """
	
    # src_node_features_unrolled shape is [600, 1, 172]
    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)
	
    # concatenate the node features with the time features
    # query shape is [600, 1, 172*2]
    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
    # concatenate neighbor features, edge features, and time-difference features; key shape = [600, 10, 516]
    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

    # query shape is [1, 600, 344]
    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
    # key shape is [10, 600, 516]
    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]
    # a row that is all True along dim=1 means the node has no valid neighbors at all
    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
    # un-mask the first position of such rows so the attention softmax is not taken over a fully masked row
    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False

    # print(query.shape, key.shape)

    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                              key_padding_mask=neighbors_padding_mask)

    # mask = torch.unsqueeze(neighbors_padding_mask, dim=2)  # mask [B, N, 1]
    # mask = mask.permute([0, 2, 1])
    # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key,
    #                                                           mask=mask)
	
    # attn_output shape = [600, 344]
    # attn_output_weights = [600, 10]
    attn_output = attn_output.squeeze()
    attn_output_weights = attn_output_weights.squeeze()

    # Source nodes with no neighbors have an all zero attention output. The attention output is
    # then added or concatenated to the original source node features and then fed into an MLP.
    # This means that an all zero vector is not used.
    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

    # Skip connection with temporal attention over neighborhood and the features of the node itself
    # attn_output = [600, 172]
    attn_output = self.merger(attn_output, src_node_features)

    return attn_output, attn_output_weights
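
A minimal shape sketch of the multi-head attention call above, using the dimensions assumed in this walkthrough: the query mixes node and time features (172 + 172 = 344), while the keys and values mix neighbor, edge, and time features (172 * 3 = 516).

import torch
import torch.nn as nn

batch_size, n_neighbors = 600, 10
query_dim, key_dim = 344, 516
attn = nn.MultiheadAttention(embed_dim=query_dim, kdim=key_dim, vdim=key_dim,
                             num_heads=2, dropout=0.1)

query = torch.randn(1, batch_size, query_dim)                          # [target_len=1, batch, query_dim]
key = torch.randn(n_neighbors, batch_size, key_dim)                    # [n_neighbors, batch, key_dim]
padding_mask = torch.zeros(batch_size, n_neighbors, dtype=torch.bool)  # True marks padded neighbors

out, weights = attn(query=query, key=key, value=key, key_padding_mask=padding_mask)
print(out.shape, weights.shape)   # torch.Size([1, 600, 344]) torch.Size([600, 1, 10])
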
MergeLayer
class MergeLayer(torch.nn.Module):
    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
        self.fc2 = torch.nn.Linear(dim3, dim4)
        self.act = torch.nn.ReLU()

        torch.nn.init.xavier_normal_(self.fc1.weight)
        torch.nn.init.xavier_normal_(self.fc2.weight)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)
        h = self.act(self.fc1(x))
        return self.fc2(h)
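
A usage example of MergeLayer as the skip connection in the attention layer, assuming the class above is in scope: the 344-dim attention output is concatenated with the 172-dim source features and projected back to 172.

import torch

merger = MergeLayer(344, 172, 172, 172)       # Linear(344 + 172 -> 172), then Linear(172 -> 172)
attn_output = torch.randn(600, 344)           # attention output per batch node
src_node_features = torch.randn(600, 172)     # original source node features
out = merger(attn_output, src_node_features)
print(out.shape)                              # torch.Size([600, 172])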


꒰˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好&#xff0c;我是xiaoxie.希望你看完之后,有不足之处请多多谅解&#xff0c;让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN …