torch_geometric实现GCN和LightGCN

news2025/1/23 9:29:40

torch_geometric实现GCN和LightGCN

  • 题记
  • demo示意图
  • GCN代码
  • LightGCN代码
  • 参考博文及感谢

题记

使用torch_geometric实现GCN和LightGCN,以后可能要用,做一下备份

demo示意图

在这里插入图片描述

GCN代码

X ′ = D ^ − 1 / 2 A ^ D ^ − 1 / 2 X Θ \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta} X=D^1/2A^D^1/2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from torch_geometric.nn.inits import uniform, ones


torch.manual_seed(2023)
"""
 默认   \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},加自连接,按权重传递
        传递完成后归一化
"""


class BaseModel(MessagePassing):
    def __init__(self, in_channels, out_channels, normalize=True, self_loops=True, bias=True, aggr='add', **kwargs):
        super(BaseModel, self).__init__(aggr=aggr, **kwargs)
        self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops
        self.normalize = normalize
        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        uniform(self.in_channels, self.weight)
        uniform(self.in_channels, self.bias)


    def forward(self, x, edge_index, edge_weight=None):
        if self.self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
        x = torch.matmul(x, self.weight)  # 表示乘以一个可学习参数矩阵
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_weight=edge_weight)
        # propagate 依次调用self.message、self.aggregate和self.update方法(self.aggregate,略,无数值修改)

    def message(self, x_j, edge_index, size, edge_weight):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        return norm.view(-1, 1) * x_j if norm is not None else x_j

        # norm = edge_weight   #将上面全部注释,即没有对邻接矩阵的归一化
        # return norm.view(-1, 1) * x_j if norm is not None else x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)  # 按行进行归一化
        return aggr_out

    def __repr(self):
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


x = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0]])
GCN = BaseModel(in_channels=3, out_channels=3, self_loops=True, aggr="add")
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1], [4, 0, 0, 1, 0, 1, 3, 3]])  # 2x8
edge_weight = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
edge_weight = edge_weight * 2

h = F.leaky_relu(GCN(x, edge_index, edge_weight=edge_weight))
print(h)

LightGCN代码

X ′ = D ^ − 1 / 2 A ^ D ^ − 1 / 2 X \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} X=D^1/2A^D^1/2X

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from torch_geometric.nn.inits import uniform, ones

torch.manual_seed(2023)
"""
 默认   \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \mathbf{X},不加自连接,按权重传递
        传递完成后不进行归一化
"""


class BaseModel(MessagePassing):
    def __init__(self, in_channels, out_channels, normalize=False, self_loops=False, aggr='add', **kwargs):
        super(BaseModel, self).__init__(aggr=aggr, **kwargs)
        self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops
        self.normalize = normalize

    def forward(self, x, edge_index, edge_weight=None):
        if self.self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_weight=edge_weight)
        # propagate 依次调用self.message、self.aggregate和self.update方法(self.aggregate,略,无数值修改)

    def message(self, x_j, edge_index, size, edge_weight):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        return norm.view(-1, 1) * x_j if norm is not None else x_j

        # norm = edge_weight   #将上面全部注释,即没有对邻接矩阵的归一化
        # return norm.view(-1, 1) * x_j if norm is not None else x_j

    def update(self, aggr_out):
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)  # 按行进行归一化
        return aggr_out

    def __repr(self):
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


x = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0]])
Lightgcn = BaseModel(in_channels=3, out_channels=3, self_loops=True, aggr="add")
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1], [4, 0, 0, 1, 0, 1, 3, 3]])  # 2x8
edge_weight = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
edge_weight = edge_weight * 2

h = Lightgcn(x, edge_index, edge_weight=edge_weight)
print(h)

参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 MMGCN论文开源代码
https://github.com/weiyinwei/MMGCN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/883011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 僵死进程

fork复制进程之后,会产生一个进程叫做子进程,被复制的进程就是父进程。不管父进程先结束,还是子进程先结束,对另外一个进程完全没有影响,父进程和子进程是两个不同的进程。 一、孤儿进程 现在有以下代码:…

【Vue-Router】路由元信息

路由元信息(Route Meta Information)是在路由配置中为每个路由定义的一组自定义数据。这些数据可以包含任何你希望在路由中传递和使用的信息,比如权限、页面标题、布局设置等。Vue Router 允许你在路由配置中定义元信息,然后在组件…

国产32位单片机XL32F001,带1 路 12bit ADC,I2C、SPI、USART 等外设

XL32F001 系列单片机采用高性能的 32 位 ARM Cortex-M0内核,宽电压工作范围的 MCU。嵌入 24KbytesFlash 和 3Kbytes SRAM 存储器,最高工作频率 24MHz。包含多种不同封装类型多款产品。芯片集成 I2C、SPI、USART 等通讯外设,1 路 12bit ADC&am…

【校招VIP】测试方案之测试用例分析

考点介绍 测试用例是测试岗面试和工作后的核心,在面试里对测试用例的分析是高频考查点。但是很多同学因为没有真实的商业产品需求,只能简单的看别人的用例学习,导致面试时被一个陌生问题卡住。 比如最简单的用户名密码输入,在商业…

2023年网络安全比赛--综合渗透测试(超详细)

一、竞赛时间 180分钟 共计3小时 二、竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 1.扫描目标靶机将靶机开放的所有端口,当作flag提交(例:21,22,23); 2.扫描目标靶机将靶机的http服务版本信息当作flag提交(例:apache 2.3.4); 3.靶机网站存在目录遍历漏洞,请将…

数据库--SQL关键字的执行顺序

一条sql语句通常包括: select from join where group by having order by 聚合函数 limit top 浅谈执行顺序: 1)、首先确定一点,并不是按照我们写的语句顺序,从左—>右执行的 2)、…

JVM——分代收集理论和垃圾回收算法

一、分代收集理论 1、三个假说 弱分代假说:绝大多数对象都是朝生夕灭的。 强分代假说:熬过越多次垃圾收集过程的对象越难以消亡。 这两个分代假说共同奠定了多款常用的垃圾收集器的一致的设计原则:收集器应该将Java堆划分出不同的区域&…

R语言实现免疫浸润分析(1)

免疫浸润分析是生物信息学研究中的一项关键内容,它旨在评估肿瘤微环境中不同类型的免疫细胞组成。免疫细胞在肿瘤发展和治疗中起着至关重要的作用,因为它们可以影响肿瘤的生长、扩散和对治疗的响应。 为了了解免疫细胞在肿瘤中的分布和数量,…

【潮州饶平】联想 IBM x3850 x6 io主板故障 服务器维修

哈喽 最近比较忙也好久没有更新服务器维修案例了,这次分享一例潮州市饶平县某企业工厂一台IBM System x3850 x6服务器亮黄灯告警且无法正常开机的服务器故障问题。潮州饶平ibm服务器维修IO主板故障问题 故障如下图所示: 故障服务器型号:IBM 或…

客达天下项目案例

本资料转载于传智播客https://www.itheima.com/ https://space.bilibili.com/3493265607232348 黑马程序员主办的全日制统招大学——大同互联网职业技术学院 预计2024年开始招生,敬请持续关注! B站视频入口:002_接口项目介绍_哔哩哔哩_bili…

互联网发展历程:从布线到无线,AC/AP的崭新时代

互联网的发展,一直在追求更便捷、更灵活的连接方式。在网络的早期,布线问题常常让人头疼。一项革命性的技术应运而生,那就是“无线AC/AP”。 布线问题的烦恼:繁琐的布线 早期网络的布线工作常常耗费时间和精力,尤其在大…

随机森林:人类基因组中病毒片段识别

百万年前人类基因组中基因组中就已经嵌入了病毒序列,其中一部分在某些条件下会致病,通过基因测序获得海量片段之后就可以判断正常基因和病毒序列了。 我们根据这种包含众多碱基的基因测序结果从中选取部分特征,关于特征的选取也是有好有坏的…

剑指offer62.圆圈中最后剩下的数字

这道题在算法课上的一个小故事上有一个类似的,就是一个军官打了败仗,带着他的几个兵逃到一个山洞,他们不想当俘虏想自杀,但是军官不想自杀但是又不好意思走,于是军官想了个办法,他们几个人围成一个圈&#…

数据库的新工具datagrip

datagrip的安装(一路next即可) 首先,双击datagrip安装包,会出现下面的界面,然后直接点击next 继续点击next 选中tatagrip,然后在点击next 点击install 勾选datagrip,然后在点击finish 直接点击…

【面试题】JavaScript高级四、高阶技巧

JavaScript高级四、高阶技巧 1、深浅拷贝 首先浅拷贝和深拷贝只针对引用类型 (1)浅拷贝 浅拷贝:拷贝对象的属性的值(简单类型存的值就是值本身,引用类型存的值是对象的堆地址),所以如果拷贝的…

大模型PEFT技术原理(二):P-Tuning、P-Tuning v2

随着预训练模型的参数越来越大,尤其是175B参数大小的GPT3发布以来,让很多中小公司和个人研究员对于大模型的全量微调望而却步,近年来研究者们提出了各种各样的参数高效迁移学习方法(Parameter-efficient Transfer Learning&#x…

gitee上传一个本地项目到一个空仓库

gitee上传一个本地项目到一个空仓库 引入 比如,你现在本地下载了一个半成品的框架,现在想要把这个本地项目放到gitee的仓库上,这时就需要我们来做到把这个本地项目上传到gitee上了。 具体步骤 1. 登录码云 地址:https://gite…

基于安防监控EasyCVR视频汇聚融合技术的运输管理系统的分析

一、项目背景 近年来,随着物流行业迅速发展,物流运输费用高、运输过程不透明、货损货差率高、供应链协同能力差等问题不断涌现,严重影响了物流作业效率,市场对于运输管理数字化需求愈发迫切。当前运输行业存在的难题如下&#xf…

Hlang社区项目说明

文章目录 前言Hlang社区技术前端后端 前言 Hello,欢迎来到本专栏,那么这也是第一次做这种类型的专栏,如有不做多多指教。那么在这里我要隆重介绍的就是这个Hlang这个项目。 首先,这里我要说明的是,我们的这个项目其实是分为两个…

【第三阶段】kotlin语言中的语法异常处理与自定义异常特点

fun main() {var name:String?nulltry{checkException(name)println(name!!.length)//不管name是不是null 后面都会执行}catch(e:Exception){println("你好$e")} }fun checkException(name:String?){name?:throw CustException()//?: 如果name为null 执行后面的抛…