【RecBole-GNN/源码】RecBole-GNN中lightGCN源码解析

news2025/2/25 15:24:01

如果觉得我的分享有一定帮助,欢迎关注我的微信公众号 “码农的科研笔记”,了解更多我的算法和代码学习总结记录。或者点击链接扫码关注【RecBole-GNN/源码】RecBole-GNN中lightGCN源码解析

【RecBole-GNN/源码】RecBole-GNN中lightGCN源码解析


原文:https://arxiv.org/pdf/2002.02126.pdf

源码:伯乐工具箱

LightGCN架构图

输入数据源(图节点仅仅使用了用户或者物品的ID进行模型搭建):

  • ml-1m.inter
  • ml-1m.item
  • ml-1m.user

GCN聚合消息需要定义节点特征以及边

1 节点

节点特征(是需要经过训练得到合适的embedding):得到所有节点特征all_embeddings(9748(6041+3707)*64)

#定义user嵌入:6041*64
self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
#定义item嵌入:3707*64
self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
user_embeddings = self.user_embedding.weight
item_embeddings = self.item_embedding.weight
#进行组合得到:9748(6041+3707)*64
all_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)

2 边

得到所有边edge_index(1610886-1) 以及权重 edge_weight(1610886-1)

#根据.iter交互文件,获取user_id那一列作为row(805443*1)
row = self.inter_feat[self.uid_field]
#根据.iter交互文件,获取item_id那一列作为col(计数id需要加self.user_num)(805443*1)
col = self.inter_feat[self.iid_field] + self.user_num
edge_index1 = torch.stack([row, col])
edge_index2 = torch.stack([col, row])
#得到所有边矩阵2*1610886(805443+805443)
# row col //因为边是双向的
# col row 
edge_index = torch.cat([edge_index1, edge_index2], dim=1)
# 获得每个节点的度(节点的连边)
deg = degree(edge_index[0], self.user_num + self.item_num)
#对于每个节点,如果其度数为 $0$,则将其规范化因子设为 $1$,否则将其规范化因子设为 $1/\sqrt{\text{degree}}$。最终,得到的 #norm_deg 张量表示了每个节点的规范化因子。
norm_deg = 1. / torch.sqrt(torch.where(deg == 0, torch.ones([1]), deg))
#为每条边计算一个权重,该权重等于该边两个节点的规范化因子之积。(1610886*1)
edge_weight = norm_deg[edge_index[0]] * norm_deg[edge_index[1]]

3 GCN聚合

for layer_idx in range(self.n_layers):
    all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
    embeddings_list.append(all_embeddings)
#多轮嵌入求均值
lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)
#获得user和item节点的最终嵌入表示
user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])

self.propagate(edge_index, x=x, edge_weight=edge_weight) 是 PyTorch Geometric(简称 PyG)库中定义的一个函数。该函数的作用是对输入的节点特征矩阵 x 进行消息传递,更新节点特征矩阵,并返回更新后的节点特征矩阵。

其中,edge_index 是一个形状为 2 × E 2 \times E 2×E 的张量,表示图中所有边的起始节点和结束节点的编号, E E E 表示边的数量;x 是一个形状为 N × F N \times F N×F 的节点特征矩阵,表示图中所有 N N N 个节点的特征, F F F 表示每个节点的特征向量的维度;edge_weight 是一个形状为 E E E 的张量,表示图中每条边的权重。

在该函数中,消息传递的方式是通过定义一个 message 函数和一个 update 函数来实现的。message 函数的作用是将源节点的特征和边权重作为输入,计算出每条边传递的消息;update 函数的作用是将每个节点收到的消息进行聚合,并更新节点的特征。

具体来说,该函数中的 propagate 函数会对输入的 xedge_weight 执行消息传递,按照以下步骤进行:

  1. 根据输入的 edge_indexedge_weight 构造一个稀疏权重矩阵 edge_index,形状为 N × N N \times N N×N,其中 N N N 表示节点数量,矩阵中的每个元素表示一条边的权重。
  2. 调用 message 函数,将源节点的特征和边权重作为输入,计算出每条边传递的消息。
  3. 将每个节点收到的消息进行聚合,并更新节点的特征。具体来说,对于每个节点 i i i,将其所有邻居节点 j j j 的消息按照一定的方式聚合起来,得到一个新的特征向量,用于更新节点 i i i 的特征。
  4. 返回更新后的节点特征矩阵。

在实际应用中,propagate 函数通常会被多次调用,用于实现多轮消息传递,并最终得到图中所有节点的特征表示。

4 推荐任务

#获得正例和负例的各自embedding
u_embeddings = user_all_embeddings[user]
pos_embeddings = item_all_embeddings[pos_item]
neg_embeddings = item_all_embeddings[neg_item]

# calculate BPR Loss
pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
mf_loss = self.mf_loss(pos_scores, neg_scores)

# calculate regularization Loss
u_ego_embeddings = self.user_embedding(user)
pos_ego_embeddings = self.item_embedding(pos_item)
neg_ego_embeddings = self.item_embedding(neg_item)

reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)
loss = mf_loss + self.reg_weight * reg_loss

5 实验

  • 和NGCF进行实验对比:
  • 和最优模型进行对比:NGCF、Mult-VAE、GRMF
  • 消融实验:证明了非线性激活和特征转换这些GCN的结构在推荐系统中并不适用,这很可能是因为推荐系统中每个图节点仅仅使用了用户或者物品的ID进行模型搭建和训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/361792.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++】初识CC++内存管理

前言 我们都知道C&C是非常注重性能的语言,因此对于C&C的内存管理是每一个C/C学习者必须重点掌握的内容,本章我们并不是深入讲解C&C内存管理,而是介绍C&C内存管理的基础知识,为我们以后深入理解C&C内存管理做铺…

基于 U-Net 网络的遥感图像语义分割 完整代码+论文

一、研究目的U-Net 是一种由全卷积神经网络启发的对称结构网络,在医疗影像分割领域取得了很好的效果。 此次研究尝试使用 U-Net 网络在对多光谱遥感影像数据集上进行训练,尝试使用卷积神经网络自动分割出建筑,希望能够得到一种自动分割遥感影…

ElementUI分页的实现

官网地址&#xff1a;Element - The worlds most popular Vue UI framework 第一步&#xff1a;拷贝你喜欢的分页类型放在你的组件页面需要用到的分页位置 <el-paginationsize-change"handleSizeChange"current-change"handleCurrentChange":current-p…

1.JAVA-JDK安装

前言&#xff1a;工具下载地址阿里云盘&#xff1a;Java-Jdk&#xff1a;https://www.aliyundrive.com/s/JpV55xhVq2A提取码: j53y一、jdk下载&#xff1a;前往Oracle官网可免费下载地址&#xff1a;https://www.oracle.com/java/technologies/downloads/ 此处我下载的是jdk8&a…

【nas折腾篇】抉择吧,是入门还是放弃

2018年公司一位女同事问群晖的nas是否值得买。我一脸懵&#xff0c;以前给公司买云服务有采购nas盘&#xff0c;直接mount挂到服务器上当存储&#xff0c;但对于单独的nas服务器没有什么概念。一晃几年过去了&#xff0c;陆续刷到些nas服务的视频&#xff0c;周边朋友用nas的也…

nginx的介绍及源码安装

文章目录前言一、nginx介绍二、nginx应用场合三、nginx的源码安装过程1.下载源码包2.安装依赖性-安装nginx-创建软连接-启动服务-关闭服务3.创建nginx服务启动脚本4.本实验---纯代码过程前言 高可用&#xff1a;高可用(High availability,缩写为 HA),是指系统无中断地执行其功…

OSI七层模型与物理层与设备链路层

目录 协议 举例 OSI七层模型 理解七层模型 以下为OSI七层模型数据逐层封装和数据逐层解封的过程 TCP/IP参考模型 数据包的层层封装与层层拆包 各层的数据以及协议 封装所用的协议的数字表示形式 物理层 模拟信号 模拟信号特点 数字信号 数字信号特点 数据通信模…

【存储】etcd的存储是如何实现的(3)-blotdb

前两篇分别介绍了etcd的存储模块以及mvcc模块。在存储模块中&#xff0c;提到了etcd kv存储backend是基于boltdb实现的&#xff0c;其在boltdb的基础上封装了读写事务&#xff0c;通过内存缓存批量将事务刷盘&#xff0c;提升整体的写入性能。botldb是etcd的真正的底层存储。本…

CSS预处理器sass和less

文章目录CSS预处理器什么是CSS预处理器Sass和LESS背景介绍Sass背景介绍LESS的背景介绍Sass安装Sass下载Ruby安装文件安装Ruby安装Sass编译Sass命令行编译命令行编译配置选项四种编译排版演示nested 编译排版格式expanded 编译排版格式compact 编译排版格式compressed 编译排版格…

Ethernet-APL——过程自动化的新黄金标准

| Ethernet-APL为终客户和设备制造商带来益处 Ethernet-APL&#xff08;Advanced Physical Layer&#xff0c;高级物理层&#xff09;是一种两线制以太网物理层&#xff0c;它使用了由IEEE 802.3cg所定义的10BASE-T1L&#xff0c;并采用了新的工艺制造规定&#xff0c;因此构成…

2.21多线程

一.并发编程java实现并发编程的方式是多线程其他语言,主打的 并发编程并不一样Go 主要通过多协程的方式实现并发erlang 是通过actor模型实现并发JS 通过定时器和事件回调的方式实现并发二.多线程在java标准库,提供了一个Thread类,表示/操作线程Thread类可以视为Java标准库提供的…

CCNP350-401学习笔记(401-450题)

401、What is the function of vBond in a Cisco SDWAN deployment? A. initiating connections with SD-WAN routers automatically B. pushing of configuration toward SD-WAN routersC. onboarding of SDWAN routers into the SD-WAN overlay D. gathering telemetry dat…

易点天下基于 StarRocks 全面构建实时离线一体的湖仓方案

作者&#xff1a;易点天下数据平台团队易点天下是一家技术驱动发展的企业国际化智能营销服务公司&#xff0c;致力于为客户提供全球营销推广服务&#xff0c;通过效果营销、品牌塑造、垂直行业解决方案等一体化服务&#xff0c;帮助企业在全球范围内高效地获取用户、提升品牌知…

yolov5源码解读--训练策略

yolov5源码解读--训练策略超参数解读命令行参数train模型迭代测试超参数解读 hyp.scratch.yaml lr0: 0.0032 初始学习率 lrf: 0.12 使用余弦函数动态降低学习率(lr0*lrf) momentum: 0.843 动量 weight_decay: 0.00036 权重衰减项 warmup_epochs: 2.0 预热&#xf…

详解Unicode字符集以及字符编码实现(一)

在日常生活中&#xff0c;我们经常会碰到打开一个文件&#xff0c;但是文件内容乱码的问题&#xff0c;比如我想看《西游记》这部小说。 下载链接&#xff1a;https://m.ijjjxs.com/txt/dl-35-12585.html 点击TXT电子书下载&#xff0c;很快就会下载完成&#xff0c;但是使用…

【测试面试】自我分析+功能+接口自动化+性能测试面试题(大全),知己知彼百战百胜......

目录&#xff1a;导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09;前言 分析自己和面试企业…

mac tcpdump学习

学习原因 工作上遇到了重启wifi后无法发出mDNS packet的情况&#xff0c;琢磨一下用tcpdump用的命令如下 sudo tcpdump -n -k -s 0 -i en0 -w VENDOR-DUT-INTERFACE.pcapng是在测airplay BCT认证时&#xff0c;官方文档的解决方法。对tcpdump很不了解&#xff0c;现汇总如下的学…

JS中数组如何去重(ES6新增的Set集合类型)+经典two sum面试题

现在有这么一个重复数组&#xff1a;const arr [a,a,b,a,b,c]只推荐简单高效的方法&#xff0c;复杂繁琐的方法不做推荐方法一&#xff1a;const res [...new Set(arr)]Set类型是什么呢&#xff1f;Set 是ES6新增的一种新集合类型。具体知识点可以看下面附录&#xff1a;根据…

ES6中Set类型的基本使用

在ES6之前&#xff0c;存储数据的结构主要有两种&#xff1a;数组、对象。 在ES6中新增了另外两种数据结构&#xff08;存放数据的方式&#xff09;&#xff1a;Set、Map&#xff0c;以及他们的另外形式WeakSet、WeakMap。 Set的基本使用 Set是一个新增的数据结构&#xff0c…

广东望京卡牌科技有限公司,2023年团建活动圆满举行

玉兔初临&#xff0c;春天相随&#xff0c;抖擞精神&#xff0c;好运连连。春天是一个万物复苏的季节&#xff0c;来自广东的望京卡牌科技有限公司&#xff0c;也迎来了新年第一次团建活动。在“乘风破浪、追逐梦想”的口号声中&#xff0c;2023望京卡牌目标启动会团结活动正式…