KGAT: Knowledge Graph Attention Network for Recommendation

news2025/1/11 23:44:47

[1905.07854] KGAT: Knowledge Graph Attention Network for Recommendation (arxiv.org)

LunaBlack/KGAT-pytorch (github.com)

目录

1、背景

2、任务定义

3、模型

3.1 Embedding layer

3.2 Attentive Embedding Propagation Layers

3.3 Model Prediction

3.4 Optimization

4、部分代码解读

4.1 数据集

4.2 数据集的处理

4.3 模型

4.4 模型训练


1、背景

CF方法,基于相似用户或者相似商品属性推荐,无法利用属性等各种side information,例如u1和u2相似,则可能会推荐i2

基于特征的SL模型,例如FM/NFM/Wide&Deep可以利用side-information,i1和i2有相同属性e1则推荐i2。

但是基于特征的SL模型单独的建模每个实例,没有建模实例之间的交互,无法从集体行为提取有用信息。例如u1很难对u1很难建模

尽管e1是连接导演和演员字段的桥梁。因此,我们认为这些方法没有充分探索高阶连通性,并且没有触及组合高阶关系

为了解决基于特征的SL模型的局限性,可以将知识图与用户-项目图的混合结构称为协同知识图(CKG),但是也有挑战:1)与目标用户具有高阶关系的节点随着订单规模的增加而急剧增加,这会给模型带来计算过载;2)高阶关系对预测的贡献不平等,这需要模型仔细加权(或选择)它们。

已经有一些基于CKG模型进行推荐的方法:

(1)基于路径的方法

提取携带高阶信息的路径并将其输入到预测模型,但是第一阶段的路径选择对性能影响很大,而且定义有效元路径需要领域知识,工作过量很大。

(2)基于正则化的方法

基于正则化的方法设计额外损失项捕获KG结构,正则化推荐模型。联合训练推荐和KGC两个任务,两个任务间共享item embedding。这些方法不是直接将高阶关系插入为推荐而优化的模型中,而是以隐式的方式对它们进行编码。由于缺乏显式建模,既不能保证捕获远程连接,也不能解释高阶建模的结果。

考虑到上述局限性,作者开发一个能够以高效、显式和端到端方式利用KG中的高阶信息的模型。

2、任务定义

User-Item Bipartite Graph:在推荐场景中,通常有历史用户项目交互(例如,购买和点击),将交互数据表示为用户-物品二部图。

Knowledge Graph:除了user-item之间的交互外,还有物品的side-information(例如,物品属性和外部知识)。通常,这些辅助数据由真实世界的实体和额外的知识组成。作者将side-information组成知识图,

 并有一个实体和item的对齐的集合,

 Collaborative Knowledge Graph:定义了CKG的概念,它将用户行为和商品知识编码为一个统一的关系图首先将每个用户行为表示为三元组(u, interaction,i),其中y_ui = 1表示为用户u与项目i之间的附加关系interaction。然后基于entity-item对齐集,将user-item图与KG无缝集成为统一图

制定本文要解决的推荐任务:

•输入:协作知识图G,包括用户项二部图G1和知识图G2。

•输出:预测函数,预测用户u采用物品i的概率为\hat{y}_{ui}

3、模型

最主要的是将user-item交互也融入KG中计算

3.1 Embedding layer

使用TransR建模,首先将头实体eh和尾实体er利用由特定于关系的投影矩阵Wr投影到关系所在的空间,然后再计算三元组得分(投影的头实体+关系得到的向量,和投影的尾实体向量越相似越好,g越小越好)

 损失采用对比学习方法

3.2 Attentive Embedding Propagation Layers

(1)权重计算

选择tanh作为非线性激活函数。这使得注意力得分依赖于关系r空间中eh和et之间的距离,为更接近的实体传播更多信息,为简单,只使用内积计算

 接着使用softmax

最终的注意力得分能够建议应该给予哪些邻居节点更多的关注来捕获协作信号。在进行前向传播时,注意流会提示需要关注的部分数据,这可以视为推荐背后的解释。

(2)消息传递

为了表征实体h的一阶连通性结构,计算了h的自我网络的线性组合(自我网络,是h为头实体的三元组的集合)

(3)聚合

最后一个阶段是将实体自己本身的表征eh和它的自我网络表征e_Nh聚合为实体h的新表征——更正式地说, e{_{(1)}^{h}}=f(e_{h},e_{N_{h}})

  • GCN Aggregator:将两个表征向量求和,并经过一个非线性激活函数

  • GraphSage Aggregator :将两个表征向量拼接,经过一个非线性函数

  • Bi-Interaction Aggregator:设计了两种函数,求和,以及两个特征向量元素积,并求和再通过一个非线性函数

 (4)高阶传播

以上展示的是一阶传播和聚合的例子,很容易可以推广到高阶。

在第l步中,递归地将实体的表示表示为:

 实体h在自我网络内传播的信息定义如下:

3.3 Model Prediction

执行L层后,得到用户节点u的多个表示,即\{e_{(u)}^{1},e_{(u)}^{2},...,e_{(u)}^{L}\};与项目节点i类似,得到\{e_{(i)}^{1},e_{(i)}^{2},...,e_{(i)}^{L}\}。由于第l层的输出是图1中根于u(或i)的l的树结构深度的消息聚合,因此不同层的输出强调的是不同阶次的连通性信息。因此,采用层聚合机制,将每一步的表示concatence成单个向量

这样一来,不仅可以通过进行嵌入传播操作来丰富初始嵌入,还可以通过调整L来控制传播强度。

最后,对用户表征与物品表征进行内积,从而预测其匹配得分

3.4 Optimization

使用BPR损失优化推荐模型,它假设观察到的交互,这表明更多的用户偏好,应该被赋予比未观察到的更高的预测值:

 (u,i)是观察到的真实的交互,(u,j)是未观察到(负样本)交互。

最后的损失函数,联合嵌入损失和推荐系统损失以及正则化

 在训练时,交替优化KG嵌入损失和CF推荐损失。

4、部分代码解读

4.1 数据集

最后的CKG图由user-item交互二部图以及补充item信息的KG组成。

  • train.txt/test.txt

训练数据集,由user id 和此user交互的itemID list组成。测试集和训练集中出现的交互为positive sample,没有观察到的交互作为negative sample。

  • user_list.txt

由原来的user id,已经映射到CKG dataset中的id组成org_id remap_id

  • item_list.txt

由原来的item id,已经映射到CKG dataset中的id,以及item在freebase中对应的id组成org_id remap_id freebase_id

  • entity_list.txt

表明KG中的实体,由原来的在freebase中的entity id,已经映射到CKG dataset中的id组成org_id remap_id

  • relation_list.txt

表明KG中的relation,由原来的在freebase中的relation id,已经映射到CKG dataset中的id组成org_id remap_id

4.2 数据集的处理

  • 将kg添加逆关系,并对关系重新编号,做法是+2;将user-item交互图融入kg中,将user重新编号user id+实体总数,将user-item编码为0,将item-user编码为1
  • 采样一个batch_size的数据,包含bath_size的user,以及为每一个user采样user-item交互的正样例,负样例
  • 对产生的CKG图{h:(r,t)}进行采样生成负例正例。
loader_base.py

    def load_cf(self, filename):
        """
        函数说明:对user-item交互矩阵进行处理
        Return:
            (user, item) - user和其作用的item
            user_dict - {user-id:[item1,item2,..],}
        """
        user = []
        item = []
        user_dict = dict()

        lines = open(filename, 'r').readlines()
        for l in lines:
            tmp = l.strip()
            inter = [int(i) for i in tmp.split()]

            if len(inter) > 1:
                user_id, item_ids = inter[0], inter[1:]
                item_ids = list(set(item_ids))

                for item_id in item_ids:
                    user.append(user_id)
                    item.append(item_id)
                user_dict[user_id] = item_ids

        user = np.array(user, dtype=np.int32)
        item = np.array(item, dtype=np.int32)
        return (user, item), user_dict

    def statistic_cf(self):
        """
        获取user、item、训练集、测试集总数
        """
        self.n_users = max(max(self.cf_train_data[0]), max(self.cf_test_data[0])) + 1
        self.n_items = max(max(self.cf_train_data[1]), max(self.cf_test_data[1])) + 1
        self.n_cf_train = len(self.cf_train_data[0])
        self.n_cf_test = len(self.cf_test_data[0])


    def load_kg(self, filename):
        """
        读取最后的CKG数据,返回dataframe形式
        """
        kg_data = pd.read_csv(filename, sep=' ', names=['h', 'r', 't'], engine='python')
        kg_data = kg_data.drop_duplicates()
        return kg_data


    def sample_pos_items_for_u(self, user_dict, user_id, n_sample_pos_items):
        """
        对user-item交互正样本进行采样
        """
        pos_items = user_dict[user_id]
        n_pos_items = len(pos_items)

        sample_pos_items = []
        while True:
            if len(sample_pos_items) == n_sample_pos_items:
                break

            pos_item_idx = np.random.randint(low=0, high=n_pos_items, size=1)[0]
            pos_item_id = pos_items[pos_item_idx]
            if pos_item_id not in sample_pos_items:
                sample_pos_items.append(pos_item_id)
        return sample_pos_items


    def sample_neg_items_for_u(self, user_dict, user_id, n_sample_neg_items):
        """
        为user-item交互采样负样例
        """
        pos_items = user_dict[user_id]

        sample_neg_items = []
        while True:
            if len(sample_neg_items) == n_sample_neg_items:
                break

            neg_item_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
            if neg_item_id not in pos_items and neg_item_id not in sample_neg_items:
                sample_neg_items.append(neg_item_id)
        return sample_neg_items


    def generate_cf_batch(self, user_dict, batch_size):
        """
        采样batch_size的user,并对对这些user采样正样本,负样本
        """
        exist_users = user_dict.keys()
        if batch_size <= len(exist_users):
            batch_user = random.sample(exist_users, batch_size)
        else:
            batch_user = [random.choice(exist_users) for _ in range(batch_size)]

        batch_pos_item, batch_neg_item = [], []
        for u in batch_user:
            # 为每一个采样的user生成一个正样例和一个负样例
            batch_pos_item += self.sample_pos_items_for_u(user_dict, u, 1)
            batch_neg_item += self.sample_neg_items_for_u(user_dict, u, 1)

        batch_user = torch.LongTensor(batch_user)
        batch_pos_item = torch.LongTensor(batch_pos_item)
        batch_neg_item = torch.LongTensor(batch_neg_item)
        return batch_user, batch_pos_item, batch_neg_item


    def sample_pos_triples_for_h(self, kg_dict, head, n_sample_pos_triples):
        """
        为融合user-item交互的CKG图采样正例
        """
        pos_triples = kg_dict[head]
        n_pos_triples = len(pos_triples)

        sample_relations, sample_pos_tails = [], []
        while True:
            if len(sample_relations) == n_sample_pos_triples:
                break

            pos_triple_idx = np.random.randint(low=0, high=n_pos_triples, size=1)[0]
            tail = pos_triples[pos_triple_idx][0]
            relation = pos_triples[pos_triple_idx][1]

            if relation not in sample_relations and tail not in sample_pos_tails:
                sample_relations.append(relation)
                sample_pos_tails.append(tail)
        return sample_relations, sample_pos_tails


    def sample_neg_triples_for_h(self, kg_dict, head, relation, n_sample_neg_triples, highest_neg_idx):
        """
        为融合user-item交互的CKG图采样负例
        """
        pos_triples = kg_dict[head]

        sample_neg_tails = []
        while True:
            if len(sample_neg_tails) == n_sample_neg_triples:
                break

            tail = np.random.randint(low=0, high=highest_neg_idx, size=1)[0]
            if (tail, relation) not in pos_triples and tail not in sample_neg_tails:
                sample_neg_tails.append(tail)
        return sample_neg_tails


    def generate_kg_batch(self, kg_dict, batch_size, highest_neg_idx):
        """为训练集CKG中每一个头实体采样一个正例的(r,t),一个负例的t"""
        exist_heads = kg_dict.keys()
        if batch_size <= len(exist_heads):
            batch_head = random.sample(exist_heads, batch_size)
        else:
            batch_head = [random.choice(exist_heads) for _ in range(batch_size)]

        batch_relation, batch_pos_tail, batch_neg_tail = [], [], []
        for h in batch_head:
            relation, pos_tail = self.sample_pos_triples_for_h(kg_dict, h, 1)
            batch_relation += relation
            batch_pos_tail += pos_tail

            neg_tail = self.sample_neg_triples_for_h(kg_dict, h, relation[0], 1, highest_neg_idx)
            batch_neg_tail += neg_tail

        batch_head = torch.LongTensor(batch_head)
        batch_relation = torch.LongTensor(batch_relation)
        batch_pos_tail = torch.LongTensor(batch_pos_tail)
        batch_neg_tail = torch.LongTensor(batch_neg_tail)
        return batch_head, batch_relation, batch_pos_tail, batch_neg_tail

loader_kgat.py
    def construct_data(self, kg_data):
        """
        函数说明:创建逆边,并把user-item交互图融入,创建CKG
        """
        # add inverse kg data
        n_relations = max(kg_data['r']) + 1
        inverse_kg_data = kg_data.copy()
        inverse_kg_data = inverse_kg_data.rename({'h': 't', 't': 'h'}, axis='columns')
        inverse_kg_data['r'] += n_relations
        kg_data = pd.concat([kg_data, inverse_kg_data], axis=0, ignore_index=True, sort=False)

        # re-map user id
        kg_data['r'] += 2
        self.n_relations = max(kg_data['r']) + 1
        self.n_entities = max(max(kg_data['h']), max(kg_data['t'])) + 1
        self.n_users_entities = self.n_users + self.n_entities

        # re-map user id = user-item中的id + num_entities
        self.cf_train_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_train_data[0]))).astype(np.int32), self.cf_train_data[1].astype(np.int32))
        self.cf_test_data = (np.array(list(map(lambda d: d + self.n_entities, self.cf_test_data[0]))).astype(np.int32), self.cf_test_data[1].astype(np.int32))

        self.train_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.train_user_dict.items()}
        self.test_user_dict = {k + self.n_entities: np.unique(v).astype(np.int32) for k, v in self.test_user_dict.items()}

        # add interactions to kg data
        # 将user-item交互数据融入kg中user交互item的关系编码为0,item-user交互编码为1
        cf2kg_train_data = pd.DataFrame(np.zeros((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        cf2kg_train_data['h'] = self.cf_train_data[0]
        cf2kg_train_data['t'] = self.cf_train_data[1]

        inverse_cf2kg_train_data = pd.DataFrame(np.ones((self.n_cf_train, 3), dtype=np.int32), columns=['h', 'r', 't'])
        inverse_cf2kg_train_data['h'] = self.cf_train_data[1]
        inverse_cf2kg_train_data['t'] = self.cf_train_data[0]

        self.kg_train_data = pd.concat([kg_data, cf2kg_train_data, inverse_cf2kg_train_data], ignore_index=True)
        self.n_kg_train = len(self.kg_train_data)

        # construct kg dict
        h_list = []
        t_list = []
        r_list = []

        self.train_kg_dict = collections.defaultdict(list)
        self.train_relation_dict = collections.defaultdict(list)

        for row in self.kg_train_data.iterrows():
            h, r, t = row[1]
            h_list.append(h)
            t_list.append(t)
            r_list.append(r)

            self.train_kg_dict[h].append((t, r))
            self.train_relation_dict[r].append((h, t))

        self.h_list = torch.LongTensor(h_list)
        self.t_list = torch.LongTensor(t_list)
        self.r_list = torch.LongTensor(r_list)

4.3 模型

  • 权重的计算:内积计算相似性,越相似的尾实体,则应该传递更多消息,权重应该更大

  • 消息的传递和聚合

聚合ego-netework嵌入的加权和以及自身嵌入 

  • 损失函数:包含CF的损失和KGC的损失,以及参数的正则化部分

 

  •  预测

将多层的消息传递聚合结果拼接起来,然后进行内积运算,得到用户点击某物品的概率

KGAT.py

import torch
import torch.nn as nn
import torch.nn.functional as F


def _L2_loss_mean(x):
    return torch.mean(torch.sum(torch.pow(x, 2), dim=1, keepdim=False) / 2.)


class Aggregator(nn.Module):

    def __init__(self, in_dim, out_dim, dropout, aggregator_type):
        super(Aggregator, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dropout = dropout
        self.aggregator_type = aggregator_type

        self.message_dropout = nn.Dropout(dropout)
        self.activation = nn.LeakyReLU()

        if self.aggregator_type == 'gcn':
            self.linear = nn.Linear(self.in_dim, self.out_dim)       # W in Equation (6)
            nn.init.xavier_uniform_(self.linear.weight)

        elif self.aggregator_type == 'graphsage':
            self.linear = nn.Linear(self.in_dim * 2, self.out_dim)   # W in Equation (7)
            nn.init.xavier_uniform_(self.linear.weight)

        elif self.aggregator_type == 'bi-interaction':
            self.linear1 = nn.Linear(self.in_dim, self.out_dim)      # W1 in Equation (8)
            self.linear2 = nn.Linear(self.in_dim, self.out_dim)      # W2 in Equation (8)
            nn.init.xavier_uniform_(self.linear1.weight)
            nn.init.xavier_uniform_(self.linear2.weight)

        else:
            raise NotImplementedError


    def forward(self, ego_embeddings, A_in):
        """
        ego_embeddings:  (n_users + n_entities, in_dim)
        A_in:            (n_users + n_entities, n_users + n_entities), torch.sparse.FloatTensor
        """
        # Equation (3)
        side_embeddings = torch.matmul(A_in, ego_embeddings)

        if self.aggregator_type == 'gcn':
            # Equation (6) & (9)
            embeddings = ego_embeddings + side_embeddings
            embeddings = self.activation(self.linear(embeddings))

        elif self.aggregator_type == 'graphsage':
            # Equation (7) & (9)
            embeddings = torch.cat([ego_embeddings, side_embeddings], dim=1)
            embeddings = self.activation(self.linear(embeddings))

        elif self.aggregator_type == 'bi-interaction':
            # Equation (8) & (9)
            sum_embeddings = self.activation(self.linear1(ego_embeddings + side_embeddings))
            bi_embeddings = self.activation(self.linear2(ego_embeddings * side_embeddings))
            embeddings = bi_embeddings + sum_embeddings

        embeddings = self.message_dropout(embeddings)           # (n_users + n_entities, out_dim)
        return embeddings


class KGAT(nn.Module):

    def __init__(self, args,
                 n_users, n_entities, n_relations, A_in=None,
                 user_pre_embed=None, item_pre_embed=None):

        super(KGAT, self).__init__()
        self.use_pretrain = args.use_pretrain

        self.n_users = n_users
        self.n_entities = n_entities
        self.n_relations = n_relations

        self.embed_dim = args.embed_dim
        self.relation_dim = args.relation_dim

        self.aggregation_type = args.aggregation_type
        self.conv_dim_list = [args.embed_dim] + eval(args.conv_dim_list)
        self.mess_dropout = eval(args.mess_dropout)
        self.n_layers = len(eval(args.conv_dim_list))

        self.kg_l2loss_lambda = args.kg_l2loss_lambda
        self.cf_l2loss_lambda = args.cf_l2loss_lambda

        self.entity_user_embed = nn.Embedding(self.n_entities + self.n_users, self.embed_dim)
        self.relation_embed = nn.Embedding(self.n_relations, self.relation_dim)
        self.trans_M = nn.Parameter(torch.Tensor(self.n_relations, self.embed_dim, self.relation_dim))

        if (self.use_pretrain == 1) and (user_pre_embed is not None) and (item_pre_embed is not None):
            other_entity_embed = nn.Parameter(torch.Tensor(self.n_entities - item_pre_embed.shape[0], self.embed_dim))
            nn.init.xavier_uniform_(other_entity_embed)
            entity_user_embed = torch.cat([item_pre_embed, other_entity_embed, user_pre_embed], dim=0)
            self.entity_user_embed.weight = nn.Parameter(entity_user_embed)
        else:
            nn.init.xavier_uniform_(self.entity_user_embed.weight)

        nn.init.xavier_uniform_(self.relation_embed.weight)
        nn.init.xavier_uniform_(self.trans_M)

        self.aggregator_layers = nn.ModuleList()
        for k in range(self.n_layers):
            self.aggregator_layers.append(Aggregator(self.conv_dim_list[k], self.conv_dim_list[k + 1], self.mess_dropout[k], self.aggregation_type))

        # A是邻接矩阵
        self.A_in = nn.Parameter(torch.sparse.FloatTensor(self.n_users + self.n_entities, self.n_users + self.n_entities))
        if A_in is not None:
            self.A_in.data = A_in
        self.A_in.requires_grad = False


    def calc_cf_embeddings(self):
        """
        计算多层的消息传递和聚合
        """
        ego_embed = self.entity_user_embed.weight
        all_embed = [ego_embed]

        for idx, layer in enumerate(self.aggregator_layers):
            ego_embed = layer(ego_embed, self.A_in)
            norm_embed = F.normalize(ego_embed, p=2, dim=1)
            all_embed.append(norm_embed)

        # Equation (11)
        all_embed = torch.cat(all_embed, dim=1)         # (n_users + n_entities, concat_dim)
        return all_embed


    def calc_cf_loss(self, user_ids, item_pos_ids, item_neg_ids):
        """
        user_ids:       (cf_batch_size)
        item_pos_ids:   (cf_batch_size)
        item_neg_ids:   (cf_batch_size)
        """
        all_embed = self.calc_cf_embeddings()                       # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                            # (cf_batch_size, concat_dim)
        item_pos_embed = all_embed[item_pos_ids]                    # (cf_batch_size, concat_dim)
        item_neg_embed = all_embed[item_neg_ids]                    # (cf_batch_size, concat_dim)

        # Equation (12)
        pos_score = torch.sum(user_embed * item_pos_embed, dim=1)   # (cf_batch_size)
        neg_score = torch.sum(user_embed * item_neg_embed, dim=1)   # (cf_batch_size)

        # Equation (13)
        # cf_loss = F.softplus(neg_score - pos_score)
        cf_loss = (-1.0) * F.logsigmoid(pos_score - neg_score)
        cf_loss = torch.mean(cf_loss)

        l2_loss = _L2_loss_mean(user_embed) + _L2_loss_mean(item_pos_embed) + _L2_loss_mean(item_neg_embed)
        loss = cf_loss + self.cf_l2loss_lambda * l2_loss
        return loss


    def calc_kg_loss(self, h, r, pos_t, neg_t):
        """
        h:      (kg_batch_size)
        r:      (kg_batch_size)
        pos_t:  (kg_batch_size)
        neg_t:  (kg_batch_size)
        """
        r_embed = self.relation_embed(r)                                                # (kg_batch_size, relation_dim)
        W_r = self.trans_M[r]                                                           # (kg_batch_size, embed_dim, relation_dim)

        h_embed = self.entity_user_embed(h)                                             # (kg_batch_size, embed_dim)
        pos_t_embed = self.entity_user_embed(pos_t)                                     # (kg_batch_size, embed_dim)
        neg_t_embed = self.entity_user_embed(neg_t)                                     # (kg_batch_size, embed_dim)

        r_mul_h = torch.bmm(h_embed.unsqueeze(1), W_r).squeeze(1)                       # (kg_batch_size, relation_dim)
        r_mul_pos_t = torch.bmm(pos_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)
        r_mul_neg_t = torch.bmm(neg_t_embed.unsqueeze(1), W_r).squeeze(1)               # (kg_batch_size, relation_dim)

        # Equation (1)
        pos_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_pos_t, 2), dim=1)     # (kg_batch_size)
        neg_score = torch.sum(torch.pow(r_mul_h + r_embed - r_mul_neg_t, 2), dim=1)     # (kg_batch_size)

        # Equation (2)
        # kg_loss = F.softplus(pos_score - neg_score)
        kg_loss = (-1.0) * F.logsigmoid(neg_score - pos_score)
        kg_loss = torch.mean(kg_loss)

        l2_loss = _L2_loss_mean(r_mul_h) + _L2_loss_mean(r_embed) + _L2_loss_mean(r_mul_pos_t) + _L2_loss_mean(r_mul_neg_t)
        loss = kg_loss + self.kg_l2loss_lambda * l2_loss
        return loss


    def update_attention_batch(self, h_list, t_list, r_idx):
        """
        更新注意力权重
        """
        r_embed = self.relation_embed.weight[r_idx]
        W_r = self.trans_M[r_idx]

        h_embed = self.entity_user_embed.weight[h_list]
        t_embed = self.entity_user_embed.weight[t_list]

        # Equation (4)
        r_mul_h = torch.matmul(h_embed, W_r)
        r_mul_t = torch.matmul(t_embed, W_r)
        v_list = torch.sum(r_mul_t * torch.tanh(r_mul_h + r_embed), dim=1)
        return v_list


    def update_attention(self, h_list, t_list, r_list, relations):
        device = self.A_in.device

        rows = []
        cols = []
        values = []

        for r_idx in relations:
            index_list = torch.where(r_list == r_idx)
            batch_h_list = h_list[index_list]
            batch_t_list = t_list[index_list]

            batch_v_list = self.update_attention_batch(batch_h_list, batch_t_list, r_idx)
            rows.append(batch_h_list)
            cols.append(batch_t_list)
            values.append(batch_v_list)

        rows = torch.cat(rows)
        cols = torch.cat(cols)
        values = torch.cat(values)

        indices = torch.stack([rows, cols])
        shape = self.A_in.shape
        A_in = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

        # Equation (5)
        A_in = torch.sparse.softmax(A_in.cpu(), dim=1)
        self.A_in.data = A_in.to(device)


    def calc_score(self, user_ids, item_ids):
        """
        user_ids:  (n_users)
        item_ids:  (n_items)
        计算user点击item的得分
        """
        all_embed = self.calc_cf_embeddings()           # (n_users + n_entities, concat_dim)
        user_embed = all_embed[user_ids]                # (n_users, concat_dim)
        item_embed = all_embed[item_ids]                # (n_items, concat_dim)

        # Equation (12)
        cf_score = torch.matmul(user_embed, item_embed.transpose(0, 1))    # (n_users, n_items)
        return cf_score


    def forward(self, *input, mode):
        if mode == 'train_cf':
            return self.calc_cf_loss(*input)
        if mode == 'train_kg':
            return self.calc_kg_loss(*input)
        if mode == 'update_att':
            return self.update_attention(*input)
        if mode == 'predict':
            return self.calc_score(*input)


4.4 模型训练

主要包括交替训练CF与KGC两个任务,并在每次交替训练后更新消息传递的权重。

main_kgat.py

def train(args):
    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    log_save_id = create_log_id(args.save_dir)
    logging_config(folder=args.save_dir, name='log{:d}'.format(log_save_id), no_console=False)
    logging.info(args)

    # GPU / CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    data = DataLoaderKGAT(args, logging)
    if args.use_pretrain == 1:
        user_pre_embed = torch.tensor(data.user_pre_embed)
        item_pre_embed = torch.tensor(data.item_pre_embed)
    else:
        user_pre_embed, item_pre_embed = None, None

    # construct model & optimizer
    model = KGAT(args, data.n_users, data.n_entities, data.n_relations, data.A_in, user_pre_embed, item_pre_embed)
    if args.use_pretrain == 2:
        model = load_model(model, args.pretrain_model_path)

    model.to(device)
    logging.info(model)

    cf_optimizer = optim.Adam(model.parameters(), lr=args.lr)
    kg_optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # initialize metrics
    best_epoch = -1
    best_recall = 0

    Ks = eval(args.Ks)
    k_min = min(Ks)
    k_max = max(Ks)

    epoch_list = []
    metrics_list = {k: {'precision': [], 'recall': [], 'ndcg': []} for k in Ks}

    # train model
    for epoch in range(1, args.n_epoch + 1):
        time0 = time()
        model.train()

        # train cf
        time1 = time()
        cf_total_loss = 0
        n_cf_batch = data.n_cf_train // data.cf_batch_size + 1
        # 交替训练CF与KGC
        for iter in range(1, n_cf_batch + 1):
            time2 = time()
            # 采样一个cf_batch_size的user list,并为user list中的每一个user采样一个正样例和负样例。
            cf_batch_user, cf_batch_pos_item, cf_batch_neg_item = data.generate_cf_batch(data.train_user_dict, data.cf_batch_size)
            cf_batch_user = cf_batch_user.to(device)
            cf_batch_pos_item = cf_batch_pos_item.to(device)
            cf_batch_neg_item = cf_batch_neg_item.to(device)


            cf_batch_loss = model(cf_batch_user, cf_batch_pos_item, cf_batch_neg_item, mode='train_cf')

            if np.isnan(cf_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (CF Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_cf_batch))
                sys.exit()

            cf_batch_loss.backward()
            cf_optimizer.step()
            cf_optimizer.zero_grad()
            cf_total_loss += cf_batch_loss.item()

            if (iter % args.cf_print_every) == 0:
                logging.info('CF Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_cf_batch, time() - time2, cf_batch_loss.item(), cf_total_loss / iter))
        logging.info('CF Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_cf_batch, time() - time1, cf_total_loss / n_cf_batch))

        # train kg
        time3 = time()
        kg_total_loss = 0
        n_kg_batch = data.n_kg_train // data.kg_batch_size + 1

        for iter in range(1, n_kg_batch + 1):
            time4 = time()
            kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail = data.generate_kg_batch(data.train_kg_dict, data.kg_batch_size, data.n_users_entities)
            kg_batch_head = kg_batch_head.to(device)
            kg_batch_relation = kg_batch_relation.to(device)
            kg_batch_pos_tail = kg_batch_pos_tail.to(device)
            kg_batch_neg_tail = kg_batch_neg_tail.to(device)

            kg_batch_loss = model(kg_batch_head, kg_batch_relation, kg_batch_pos_tail, kg_batch_neg_tail, mode='train_kg')

            if np.isnan(kg_batch_loss.cpu().detach().numpy()):
                logging.info('ERROR (KG Training): Epoch {:04d} Iter {:04d} / {:04d} Loss is nan.'.format(epoch, iter, n_kg_batch))
                sys.exit()

            kg_batch_loss.backward()
            kg_optimizer.step()
            kg_optimizer.zero_grad()
            kg_total_loss += kg_batch_loss.item()

            if (iter % args.kg_print_every) == 0:
                logging.info('KG Training: Epoch {:04d} Iter {:04d} / {:04d} | Time {:.1f}s | Iter Loss {:.4f} | Iter Mean Loss {:.4f}'.format(epoch, iter, n_kg_batch, time() - time4, kg_batch_loss.item(), kg_total_loss / iter))
        logging.info('KG Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s | Iter Mean Loss {:.4f}'.format(epoch, n_kg_batch, time() - time3, kg_total_loss / n_kg_batch))
        # 交替训练完一次更新注意力权重
        # update attention
        time5 = time()
        # h_list/t_list/r_list是CKG图中所有的头实体、关系、尾实体列表
        h_list = data.h_list.to(device)
        t_list = data.t_list.to(device)
        r_list = data.r_list.to(device)
        relations = list(data.laplacian_dict.keys())
        model(h_list, t_list, r_list, relations, mode='update_att')
        logging.info('Update Attention: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time5))

        logging.info('CF + KG Training: Epoch {:04d} | Total Time {:.1f}s'.format(epoch, time() - time0))

        # evaluate cf
        if (epoch % args.evaluate_every) == 0 or epoch == args.n_epoch:
            time6 = time()
            _, metrics_dict = evaluate(model, data, Ks, device)
            logging.info('CF Evaluation: Epoch {:04d} | Total Time {:.1f}s | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
                epoch, time() - time6, metrics_dict[k_min]['precision'], metrics_dict[k_max]['precision'], metrics_dict[k_min]['recall'], metrics_dict[k_max]['recall'], metrics_dict[k_min]['ndcg'], metrics_dict[k_max]['ndcg']))

            epoch_list.append(epoch)
            for k in Ks:
                for m in ['precision', 'recall', 'ndcg']:
                    metrics_list[k][m].append(metrics_dict[k][m])
            best_recall, should_stop = early_stopping(metrics_list[k_min]['recall'], args.stopping_steps)

            if should_stop:
                break

            if metrics_list[k_min]['recall'].index(best_recall) == len(epoch_list) - 1:
                save_model(model, args.save_dir, epoch, best_epoch)
                logging.info('Save model on epoch {:04d}!'.format(epoch))
                best_epoch = epoch

    # save metrics
    metrics_df = [epoch_list]
    metrics_cols = ['epoch_idx']
    for k in Ks:
        for m in ['precision', 'recall', 'ndcg']:
            metrics_df.append(metrics_list[k][m])
            metrics_cols.append('{}@{}'.format(m, k))
    metrics_df = pd.DataFrame(metrics_df).transpose()
    metrics_df.columns = metrics_cols
    metrics_df.to_csv(args.save_dir + '/metrics.tsv', sep='\t', index=False)

    # print best metrics
    best_metrics = metrics_df.loc[metrics_df['epoch_idx'] == best_epoch].iloc[0].to_dict()
    logging.info('Best CF Evaluation: Epoch {:04d} | Precision [{:.4f}, {:.4f}], Recall [{:.4f}, {:.4f}], NDCG [{:.4f}, {:.4f}]'.format(
        int(best_metrics['epoch_idx']), best_metrics['precision@{}'.format(k_min)], best_metrics['precision@{}'.format(k_max)], best_metrics['recall@{}'.format(k_min)], best_metrics['recall@{}'.format(k_max)], best_metrics['ndcg@{}'.format(k_min)], best_metrics['ndcg@{}'.format(k_max)]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/771284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

docker容器引擎(一)

docker 一、docker的理论部分docker的概述容器受欢迎的原因容器与虚拟机的区别docker核心概念 二、安装docker三、docker镜像操作四、docker容器操作 一、docker的理论部分 docker的概述 一个开源的应用容器引擎&#xff0c;基于go语言开发并遵循了apache2.0协议开源再Linux容…

ThreeJS打造自己的人物

hello&#xff0c;大家好&#xff0c;我是better&#xff0c;今天为大家分享如何使用Three打造属于自己的3D人物模型。 人物建模 当下有很多人物建模的网站&#xff0c;这里给大家分享的 Ready Player Me - Create a Full-Body 3D Avatar From a Photo 前往这个网址&#xff…

AI销售工具:驱动销售团队效率和个性化服务的未来

在数字化时代&#xff0c;AI销售工具成为推动销售行业发展的重要力量。这些创新工具融合了人工智能技术和销售流程&#xff0c;以提高销售团队的效率和提供个性化服务为目标。随着科技的不断进步&#xff0c;AI销售工具正引领着销售行业走向一个更加智能和高效的未来。 AI驱动的…

专题-【哈希函数】

14年三-3&#xff09; 20年三-2&#xff09; 哈希表&#xff1a; 查找成功&#xff1a; ASL(12113144)/817/8&#xff1b; 查找失败&#xff08;模为7&#xff0c;0-6不成功次数进行计算&#xff09;&#xff1a; ASL(3217654)/74。

Qgis3.16ltr+VS2017二次开发环境搭建(保姆级教程)

1.二次开发环境搭建 下载osgeo4w-setup.exeDownload QGIShttps://www.qgis.org/en/site/forusers/download.html 点击OSGeo4W Network Installer 点击下载 OSGeo4W Installer 运行程序 osgeo4w-setup.exe&#xff0c;出现以下界面&#xff0c;点击下一页。 选中install from i…

【全方位解析】如何写好技术文章

前言 为何而写 技术成长&#xff1a;相对于庞大的计算机领域的知识体系&#xff0c;人的记忆还是太有限了&#xff0c;而且随着年龄的增大&#xff0c;记忆同样也会逐渐衰退&#xff0c;正如俗话所说“好记性不如烂笔头”。并且在分享博客的过程中&#xff0c;我们也可以和大…

小白带你学习Linux的rsync的基本操作(二十四)

目录 前言 一、概述 二、特性 1、快速 2、安全 三、应用场景 四、数据的同步方式 五、rsync传输模式 六、rsync应用 七、rsync命令 1、格式 2、选项 3、举例 4、配置文件 5、练习 八、rsyncinotfy实时同步 1、服务器端 2、开发客户端 前言 Rsync是一个开源的…

光线追踪计算加速:包围盒

包围盒&#xff08;Bounding box&#xff09;是加速光线追踪&#xff08;Ray Tracing&#xff09;的最简单方法&#xff0c;不一定将其视为加速结构&#xff0c;但这无疑是减少渲染时间的最简单方法。 推荐&#xff1a;用 NSDT设计器 快速搭建可编程3D场景。 使用包围盒来加速光…

椒图——靶场模拟

先查看ip&#xff0c;10.12.13.232模拟的外网ip&#xff0c;其他的模拟内网ip&#xff0c;服务里面搭建好的漏洞环境。 #第一个测试项目&#xff0c;web风险发现 新建&#xff0c;下发任务&#xff0c;点威胁检测&#xff0c;webshell&#xff0c;点扫描任务&#xff0c;点新…

迅镭激光赋能工程机械,客户连续复购激光加工设备达双赢!

工程机械是装备制造业的重要组成部分&#xff0c;当前&#xff0c;我国已成为门类齐全、规模庞大、基础坚实、竞争力强的工程机械设备制造大国。 随着工程机械产业正在全面向智能化、绿色化转型&#xff0c;激光加工成为推动工程机械产业转型升级的重要工具&#xff0c;越来越多…

一道SQL题

有个搞数仓的朋友不知道从哪儿弄了个题。。。 做了做体验了一下。。。 记录记录。 分析 要保证每天都要做新题 5天必须都做题&#xff0c;不然GG 最后一天必须做新题&#xff0c;如果最后一天做新题了&#xff0c;前面那几天没做新题&#xff0c;做的是老题 最后一天&#…

初识mysql之理解索引

目录 一、 primary key对索引的影响 1. 主键数据有序问题 2. mysql中的page 3. 主键排序问题 二、理解多个page 1. 数据在page中的保存 2. 页目录 3. 单页情况 4. 多页情况 5. 为什么除了叶子节点外的其他节点不保存数据&#xff0c;只保存目录 6. 为什么叶子节点全…

创意网页模板免费下载,让你的网站与众不同!

今天给大家带来的网站模板素材&#xff0c;网站类型丰富&#xff0c;包含户外旅行、餐饮、个人网站等等&#xff0c;可以学习和参考其中的布局排版和配色。 ⬇⬇⬇点击获取更多设计资源 https://js.design/community?categorydesign&sourcecsdn&planbbqcsdn772 1、设…

域内信息收集

将网络中多台计算机逻辑上组织到一起进行集中管理&#xff0c;这种区别于工作组的逻辑环境叫 做域。域是由域控制器(Domain Controller)和成员计算机组成&#xff0c;域控制器就是安装了活动 目录(Active Directory)的计算机。活动目录提供了存储网络上对象信息并使用网络使用该…

图像处理之canny边缘检测(非极大值抑制和高低阈值)

Canny 边缘检测方法 Canny算子是John F.Canny 大佬在1986年在其发表的论文 《Canny J. A computational approach to edge detection [J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 1986 (6): 679-698.》提出来的。 检测目标&#xff1a; 低错误率…

2023年软件测试八股文(含答案+文档)

Part1 1、你的测试职业发展是什么&#xff1f; 测试经验越多&#xff0c;测试能力越高。所以我的职业发展是需要时间积累的&#xff0c;一步步向着高级测试工程师奔去。而且我也有初步的职业规划&#xff0c;前3年积累测试经验&#xff0c;按如何做好测试工程师的要点去要求自…

汽车新品研发用泛微事井然,全过程数字化、可视化

产品研发是汽车制造产业链的运营过程中的初始阶段&#xff0c;是提升汽车企业创新力、竞争力的重要一环&#xff0c;不仅要洞悉市场变化&#xff0c;还要有效协同企业内部的各类资源… 汽车新产品研发项目周期长、资源投入大&#xff0c;面临着诸多挑战&#xff1a; 1、市场需…

采集传感器的物联网网关怎么采集数据?

随着工业4.0和智能制造的快速发展&#xff0c;物联网&#xff08;IoT&#xff09;技术的应用越来越广泛&#xff0c;传感器在整个物联网系统中使用非常普遍&#xff0c;如温度传感器、湿度传感器、光照传感器等&#xff0c;对于大部分物联网应用来说&#xff0c;采集传感器都非…

02.MySQL——CURD

文章目录 表的增删改查Create单行数据全列插入多行数据指定列插入插入否则更新替换——REPLACE RetrieveSELECT 列WHERE 条件结果排序筛选分页结果 UpdateDelete删除数据截断表 插入查询结果聚合函数group bywhere和having SQL查询中关键字优先级函数日期函数字符串函数数学函数…

Spring 事务控制

1. 编程式事务控制相关对象 1.1 平台事务管理器 1.2 事务定义对象 1.3 事务状态对象 关系&#xff1a; PlatformTransactionManager TransactionManager TransactionStatus 2. 基于XML的声明式事务控制 切点&#xff1a;&#xff08;目标对象&#xff09;业务方法&#xff…