WWW2022 | 基于领域增强的图对比协同过滤方法+代码实践

news2025/1/11 14:17:40

嘿,记得给“机器学习与推荐算法”添加星标


今天跟大家分享一篇将对比学习应用于图协同过滤方法的文章,该论文发表于WWW2022会议上。其主要思想是在图神经网络协同过滤方法上应用了两种领域类型的对比学习方法,分别是显式的结构领域和隐式的语义间的领域,相比于随机采样的对比学习方式,其挖掘了用户或者物品的邻居关系并开发了对比学习在推荐系统上的潜力,实验表明该方法在多个数据集上取得了良好的推荐性能。

1b9792813e93ed3fcfd1ee667901befa.png

论文:https://arxiv.org/abs/2202.06200
代码:https://github.com/RUCAIBox/NCL/blob/master/ncl.py

近年来,图协同过滤方法得到了非常广泛的关注。虽然可以实现较好的推荐性能,但其仍然存在数据稀疏等问题。为了缓解数据稀疏问题,常见的做法是在图协同过滤方法的基础上引用对比学习。然而,当前主流的基于对比学习的图协同过滤方法主要通过随机采样的方式来构建对比对,但这样的方式忽略了用户或者物品间的邻居关系,因此不能将对比学习推荐方法的威力发挥到极致。基于此,该文提出了一种邻域增强的对比学习推荐方法NCL。

f9ae029957611afa42816a49cec3b7f0.png

其可以显式的将潜在的邻域信息建模在对比对中。其中本文从图结构和语义空间中引用了两种具体的邻域对比对,即结构对比对(structure contrastive pair)和语义对比对(semantic contrastive pair),更加直观的图示可见图1。对于结构对比对,主要是从交互图中提取的,其将当前用户以及当前用户的邻居当做正对比对。对于语义对比对,主要是在Embedding空间中将当前用户的Embedding与所在的簇中心Embedding当做正样本。

该方法将LightGCN作为backbone,通常经过propagate和readout过程来生成用户和物品的特征表示,具体的公式如下:

be8ac8672cbff7a6a453bb995e9371e4.png

其中表示第层用户的特征表示,表示用户的邻居,表示GNN的层数,表示用户的初始Embedding。聚合了该用户和其邻居在第层的特征表示。整合了层的特征表示以此来获得对于用户在多阶邻居上的语义特征表示,常见的readout操作比如last-layer only、concatenation以及weighted sum等。物品的特征表示具有类似的上述过程。

该方法在聚合的过程中丢弃了非线性变换、特征转换以及自连接,所以对于用户和物品的特征聚合形式如下:

434535d1cac490fc4a8bc820381e4262.png
all_embeddings = self.get_ego_embeddings()
  embeddings_list = [all_embeddings]
  for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)):
      all_embeddings = torch.sparse.mm(self.norm_adj_mat, all_embeddings)
      embeddings_list.append(all_embeddings)

该方法在生成最终的第层表示时采用加权求和(weighted sum)的方法,具体形式如下:

37e5929276fd1725731d50dbee91346a.png
lightgcn_all_embeddings = torch.stack(
  embeddings_list[: self.n_layers + 1], dim=1
        )
  lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

当获得用户和物品的特征表示后采用内积的方式进行推荐,即。

def predict(self, interaction):
     user = interaction[self.USER_ID]
     item = interaction[self.ITEM_ID]

     user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

     u_embeddings = user_all_embeddings[user]
     i_embeddings = item_all_embeddings[item]
     scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
     return scores

然后采用BPR损失进行监督式训练。

f40f5e2b523e4be6adbad20b7e8eff74.png
gamma = 1e-10
    pos_score = torch.randn(3, requires_grad=True)
    neg_score = torch.randn(3, requires_grad=True)
    loss = -torch.log(gamma + torch.sigmoid(pos_score - neg_score)).mean()

通过优化BPR损失可以对用户和物品之间的交互进行建模。然而,用户(或物品)内的高阶邻居关系对于推荐也是有价值的。例如,用户更有可能购买与邻居相同的产品。接下来将介绍本文所提出的两个对比学习对,以捕捉用户和物品的潜在邻居关系。具体的示意图如下图所示。

120ebb6eed4cfa4452bba17669c0258c.png

基于结构邻域的对比学习

为了充分利用对比学习的优势,首先将每个用户(或物品)与其显式的结构邻居进行对比,然后再通过GNN进行聚合得到最终的表示。其中,基本GNN模型的第层的输出表示每个节点跳结构邻居的加权和,因此可以利用其偶数跳的输出来表示该节点的结构领域。具体而言,我们将用户自身的嵌入和偶数层GNN的相应输出的嵌入视为正对比对。基于InfoNCE损失来进行优化,具体如下所示:

c66daf8907c13724965822bba0d6f5d2.png

其中,表示GNN模型层的输出,也就是用户的阶邻居的表示,当然得是偶数。表示当前用户的特征表示。同理物品侧的对比损失如下:

115d43d9effd9a0b8e78d3b89c2ba2b8.png
def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
        current_user_embeddings, current_item_embeddings = torch.split(
            current_embedding, [self.n_users, self.n_items]
        )
        previous_user_embeddings_all, previous_item_embeddings_all = torch.split(
            previous_embedding, [self.n_users, self.n_items]
        )

        current_user_embeddings = current_user_embeddings[user]
        previous_user_embeddings = previous_user_embeddings_all[user]
        norm_user_emb1 = F.normalize(current_user_embeddings)
        norm_user_emb2 = F.normalize(previous_user_embeddings)
        norm_all_user_emb = F.normalize(previous_user_embeddings_all)
        pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
        ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        current_item_embeddings = current_item_embeddings[item]
        previous_item_embeddings = previous_item_embeddings_all[item]
        norm_item_emb1 = F.normalize(current_item_embeddings)
        norm_item_emb2 = F.normalize(previous_item_embeddings)
        norm_all_item_emb = F.normalize(previous_item_embeddings_all)
        pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
        ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)

        ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
        return ssl_loss

基于语义邻域的对比学习

基于结构领域的对比对显式地建模了由交互图定义的邻居。然而,结构对比损失对用户/物品的同质邻居一视同仁,这不可避免地将噪声信息引入到对比对中为了减轻这种印象,本文考虑将语义空间的领域信息引用对比学习中。具体的,通过学习每个用户和物品的潜在原型(prototype)来构造语义邻居。基于这一思想,进一步提出了原型对比目标,以探索潜在的语义邻居,并将其纳入对比学习中,以更好地捕捉协同过滤中用户和物品的语义特征。特别是,相似的用户/物品往往落在相邻的嵌入空间中,原型是指一组语义邻居的集群的中心。因此,本文将聚类算法应用于用户和物品的嵌入,以获得用户或物品的原型。由于该过程不能进行端到端优化,所以使用EM算法学习所提出的原型对比目标。形式上,GNN模型的目标是最大化以下对数似然函数:

2e35d6c63aaca7dfcccfaa6889bba6a6.png

其中,表示模型参数,表示交互矩阵,表示用户的原型,表示当前用户的向量表示。

def e_step(self):
        user_embeddings = self.user_embedding.weight.detach().cpu().numpy()
        item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
        self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
        self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)

    def run_kmeans(self, x):
        """Run K-means algorithm to get k clusters of the input tensor x"""
        import faiss

        kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
        kmeans.train(x)
        cluster_cents = kmeans.centroids

        _, I = kmeans.index.search(x, 1)

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(cluster_cents).to(self.device)
        centroids = F.normalize(centroids, p=2, dim=1)

        node2cluster = torch.LongTensor(I).squeeze().to(self.device)
        return centroids, node2cluster

然后再根据用户当前的向量表示以及原型进行优化,因此基于原型对比对的InfoNCF损失如下:

5041c76eb43868e213c43fbd5fa10688.png

其中,表示用户的原型,其是通过K-means聚类算法来计算得出的,一共有个聚类中心。同理物品侧的损失函数如下:

15fd2f1a848626506a317040e3b68863.png
def ProtoNCE_loss(self, node_embedding, user, item):
        user_embeddings_all, item_embeddings_all = torch.split(
            node_embedding, [self.n_users, self.n_items]
        )

        user_embeddings = user_embeddings_all[user]  # [B, e]
        norm_user_embeddings = F.normalize(user_embeddings)

        user2cluster = self.user_2cluster[user]  # [B,]
        user2centroids = self.user_centroids[user2cluster]  # [B, e]
        pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
        pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
        ttl_score_user = torch.matmul(
            norm_user_embeddings, self.user_centroids.transpose(0, 1)
        )
        ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

        proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

        item_embeddings = item_embeddings_all[item]
        norm_item_embeddings = F.normalize(item_embeddings)

        item2cluster = self.item_2cluster[item]  # [B, ]
        item2centroids = self.item_centroids[item2cluster]  # [B, e]
        pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
        pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
        ttl_score_item = torch.matmul(
            norm_item_embeddings, self.item_centroids.transpose(0, 1)
        )
        ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
        proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

        proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
        return proto_nce_loss

最终通过将以上损失函数进行相加然后通过Adama优化算法进行优化,值得注意的是,在计算原型的过程中不是端到端的,因此需要EM算法来交替更新用户和物品的特征向量以及它们的原型向量。

最后该算法在5个数据集上对比了8种方法,实验结果表明所提组件在图协同过滤方法上的优越性。

d2200f6e0c259e9097bef98cac71e5ff.png

欢迎干货投稿 \ 论文宣传 \ 合作交流

推荐阅读

论文周报 | 推荐系统领域最新研究进展

深度推荐系统调参技巧总结

CCF推荐列表重磅更新, RecSys升级成为B类会议, 中国科学: 信息科学成为A类期刊...

由于公众号试行乱序推送,您可能不再准时收到机器学习与推荐算法的推送。为了第一时间收到本号的干货内容, 请将本号设为星标,以及常点文末右下角的“在看”。

喜欢的话点个在看吧👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/71006.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TGK-Planner-前后端路径规划(基于梯度的后端无约束优化)

高速移动无人机的在线路径规划一直是学界当前研究的难点,引起了大量机器人行业的研究人员与工程师的关注。然而无人机的计算资源有限,要在短时间内规划出一条安全可执行的路径,这就要求无人机的运动规划算法必须轻型而有效。本文将介绍一种无…

electron-vue中报错 Cannot use import statement outside a module解决方案(亲测有效!!!)

错误: Cannot use import statement outside a module(不能在模块之外使用导入语句)。 原因: 安装的某个依赖包里使用了import语法,因为我们打包输出的是commonjs规范,所以不识别import语法而导致报错。 可以从 .electron-vue/w…

PrimoBurnerSDK蓝光刻录工具开发工具包

PrimoBurnerSDK蓝光刻录工具开发工具包 PrimoBurnerSDK是一个CD、DVD和蓝光刻录工具开发工具包。它还提供了一个全面灵活的API,用于快速轻松地实现各种燃烧/翻录替代方案。 PrimoBurner SDK for.NET的强大功能: 自2003年以来一直在发展的广泛使用的老式发…

比机器人还智能的数字孪生地下停车场监管系统!

现在的停车场管理大多采用人工或智能收费系统,两种方式都有一个弊端就是无法直接知晓停车场内部信息。 车驶入停车场只能自行寻找停车位,工作人员也只有走进停车场才能知晓停车场内部情况,无可避免造成很多麻烦。 停车场智慧监管系统结合数…

期货开户交易操作技巧

期货交易的时候需要有一些操作技巧,以及要注意一些操作上常见的错误。 个人建议刚刚开始交易的投资者期货交易的投资者,一定要多看慢做,首先要摒弃做这个会一夜暴富的想法。抱着个想法来的往往都会折戟沉沙,一去不复返了。所以我…

基于springboot+mybatis+mysql+vue中学生成绩管理系统

基于springbootmybatismysqlvue中学生成绩管理系统一、系统介绍二、功能展示1.登陆2.用户管理(管理员)3.班主任信息管理(管理员)4.教师信息管理(管理员、班主任)5.学生信息管理(管理员)6.成绩信息管理(管理员、班主任、…

一个人,仅30天!开发一款3D竞技足球游戏!他究竟经历了些什么?

今天,晓衡向大家推荐一款Coco Store 优质 3D足球竞技游戏 资源《足球快斗》玩法介绍:游戏为 7V7 足球竞技类玩法。玩家控制本队的一个球员(脚下高亮圆圈显示的是玩家),其他球员和守门员为电脑AI控制,期间可…

Jvm上如何运行其他语言?JSR223规范最详细讲解

一 在Java的平台里,其实是可以执行其他的语言的。包括且不仅限于jvm发展出来的语言。 有的同学可能会说,在java项目里执行其他语言,这不吃饱了撑着么,java体系那么庞大,各种工具一应俱全,放着好好的java不…

责任链模式在复杂数据处理场景中的实战

相信大家在日常的开发中都遇到过复杂数据处理和复杂数据校验的场景,本文从一线开发者的角度,分享了责任链模式在这种复杂数据处理场景下的实战案例,此外,作者在普通责任链模式的基础上进行了升级改造,可以适配更加复杂…

34_DAC原理及数模转换实验

目录 数模转换原理 DAC模块框图 事件选择控制数字模拟转换 DAC转换 DAC数据格式 选择DAC触发 DAC输出电压计算 硬件连接 DAC配置步骤 实验源码 数模转换原理 STM32的DAC模块(数字/模拟转换模块)是12位数字输入,电压输出型的DAC。DAC可以配置为8位或12位模式,也可以与…

linux安装nginx

1.nginx官网 http://nginx.org/en/download.html 下载安装包,如图所示下载nginx-1.23.2,并上传到指定目录:/usr/local/src/nginx 2.解压 tar -zxvf nginx-1.23.2.tar.gz3.安装nginx, cd /usr/local/src/nginx/nginx-1.23.2 该目录…

Titanic 泰坦尼克数据集 特诊工程 机器学习建模

以下内容为讲课时使用到的泰坦尼克数据集分析、建模过程,整体比较完整,分享出来,希望能帮助大家。部分内容由于版本问题,可能无法顺利运行。 Table of Contents 1 经典又有趣的Titanic问题1.1 目标1.2 解决方法1.3 项目目的2…

Vector-常用CAN工具 - CANoe入门到精通_03

NetWork Node 前面已经介绍了CANoe的基本情况、硬件环境搭建、CANoe软件环境配置,今天我们就来聊一下NetWork Node,在我们的测试工作中,大部分情况我们默认CANoe作为一个Client端,但是有些情况,我们需要实时监测被测件…

Akka 学习(四)Remote Actor

目录一 介绍1.1 Remote Actor1.2 适用场景1.3 踩坑点二 实战2.1 需求2.2 Java 版本2.2.1 效果图2.2.2 实体类2.2.3 服务端Actor 处理2.2.4 服务端配置文件2.2.5 客服端Actor处理2.2.6 客服端配置文件2.2.7 测试2.3 Scala 版本2.3.1 效果2.2.3 服务端Actor处理2.3.4 客户端Actor…

使用 Excel 数据透视表深入研究数据分析

问题 1(文章数据在底部) 为美国选民案例研究创建一个数据透视表,并用它来回答以下问题: A) 有多少个州的选民人口百分比低于 55%?哪些州? 答:有5个州的选民人数低于55%,分别是得克萨斯州、阿肯色州、俄克拉荷马州、夏威夷州和西弗吉尼亚州。 步骤:根据以下结果,创建…

基于jsp+java+ssm的社会保险信息管理系统-计算机毕业设计

项目介绍 课题研究的基本内容及预期目标或成果 用户注册与登录功能,在单位注册功能中有申请管理功能,填写具体信息。 系统管理员: 1)个人密码修改:实现了管理员用户密码信息的修改。 2)参保人员管理&a…

ORACE dbca创建报错Oracle system identifier(SID) “orcl“

最近项目需要通过备份恢复oracle实例,必须使用orcl,通过dbca创建实例是提示如下报错: 查看日志,$ORACLE_HOME/cfgtoollogs/dbca/dbcaui.log EVERE: [FATAL] A database instance with Oracle system identifier(SID) "orcl&…

零基础入门推荐系统 - 新闻推荐 - 实操2

内容导航: 零基础入门推荐系统 - 新闻推荐 - 实操2比赛数据分析:用户属性分析:训练集和测试集中分别有多少用户?用户城市分布有什么规律?平均每个用户会点击多少个文章?点击来源与文章点击次数是否存在关联?用户行为分析:零基础入…

【车载开发系列】UDS诊断---读取周期标识符($0x2A)

【车载开发系列】UDS诊断—读取周期标识符($0x2A) UDS诊断---读取周期标识符($0x2A)【车载开发系列】UDS诊断---读取周期标识符($0x2A)一.概念定义二.报文格式1)请求报文2)初始响应3…