图神经网络实战(19)——异构图神经网络

news2025/1/13 17:33:47

图神经网络实战(19)——异构图神经网络

    • 0. 前言
    • 1. 异构图
      • 1.1 异构图基本概念
      • 1.2 构建异构图数据集
    • 2. 将同构图神经网络转换为异构图神经网络
      • 2.1 数据集介绍
      • 2.2 同构图注意力网络
      • 2.3 异构图神经网络
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何生成包含不同类型节点(原子)和边(键)的分子结构,这种技术在其它应用中也具有广泛用途,例如推荐系统(用户和商品)、社交网络(关注者和被关注者)或网络安全(路由器和服务器)。我们将这类图称为异构图 (heterogeneous graph),与同构图 (homogeneous graph) 相对,后者只涉及一种类型的节点和一种类型的边。在本节中,我们将回顾关于同构图神经网络 (Graph Neural Networks, GNN) 与消息传递神经网络框架的相关概念,以扩展 GNN 架构适用于异构图。首先,我们将创建自定义异构数据集。然后,将同构架构转化为异构架构。

1. 异构图

1.1 异构图基本概念

异构图 (heterogeneous graph)是表示不同实体间关系的强大工具,拥有不同类型的节点和边会创建更复杂但也更难学习的图结构。同时,异构图的一个主要问题是,来自不同类型节点或边的特征不一定具有相同的意义或维度。
因此,合并不同的特征会破坏大量信息。而同构图 (homogeneous graph) 则不同,在同构图中,每个节点或边的每个维度都具有完全相同的含义。
异构图是一种更通用的网络,可以表示不同类型的节点和边。从形式上看,异构图定义为由节点集 V V V 和边集 E E E 组成的图 G = ( V , E ) G = (V, E) G=(V,E),在异构图中,包括节点类型映射函数 ϕ : V → A ϕ :V→A ϕ:VA (其中 A A A 表示节点类型集),以及边类型映射函数 ψ : E → R ψ:E→R ψ:ER (其中 R R R 表示边类型集)。下图是一个具有三种节点类型和三种边类型的异构图。

异构图

在上图中,我们可以看到三种类型的节点(用户、游戏和开发者)和三种类型的边(关注、游戏和开发)。它代表了一个涉及人员(用户和开发者)和游戏的网络,可用于游戏推荐等各种应用。如果这个图包含数百万个元素,它就可以用作图结构的知识数据库或知识图谱。知识图谱能够用来回答查询,比如“谁玩 Dev 1 开发的游戏?”。
类似的查询可以提取有用的同质图。例如,我们可能只想考虑玩 Game 1 的用户,输出结果为 User 1User 2。我们也可以创建更复杂的查询,例如“谁是玩 Dev 1 开发的游戏的用户?”结果是相同的,但遍历了两个关系来获得用户,这种查询称为元路径 (meta-path)。
在第一个例子中,元路径是 User → Game → User (通常表示为 UGU),而在第二个例子中,我们的元路径是 User → Game → Dev → Game → User (或表示为 UGDGU)。需要注意的是,起点节点类型和终点节点类型是相同的。元路径是异构图中的一个基本概念,通常用于衡量不同节点的相似性。

1.2 构建异构图数据集

接下来,我们使用 PyTorch Geometric (PyG) 实现异构图,使用数据对象 HeteroData 创建一个数据对象来存储上示异构图。

(1)torch_geometric.data 中导入 HeteroData 类,并创建变量 data

from torch_geometric.data import HeteroData

data = HeteroData()

(2) 首先,存储节点特征。例如,可以使用 data['user'].x 访问用户特征。我们使用一个维度为 [num_users, num_features_users] 的张量作为输入,其中 num_users 表示用户数量,num_features_users 表示用户特征数量。在本例中,内容并不重要,因此我们将创建一个用 1 表示 user 1、用 2 表示 user 2、用 3 表示 user 3 的特征向量:

data['user'].x = torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]) # [num_users, num_features_users]

(3) 我们对游戏特征和开发者特征重复以上过程。需要注意的是,特征向量的维度并不相同;这是异构图在处理不同表示时的一个重要优势:

data['game'].x = torch.Tensor([[1, 1], [2, 2]])
data['dev'].x = torch.Tensor([[1], [2]])

(4) 接下来,在节点之间建立联系。连接具有不同的含义,因此我们将创建三组边索引。我们可以使用三元组(源节点类型、边缘类型、目标节点类型)来声明每组边索引,例如 data['user','follows','user'].edge_index。然后,将连接存储在一个维数为 [2, num_edge] 的张量中,其中 num_edge 表示边的数量:

data['user', 'follows', 'user'].edge_index = torch.Tensor([[0, 1], [1, 2]]) # [2, num_edges_follows]
data['user', 'plays', 'game'].edge_index = torch.Tensor([[0, 1, 1, 2], [0, 0, 1, 1]])
data['dev', 'develops', 'game'].edge_index = torch.Tensor([[0, 1], [0, 1]])

(5) 边也可以具有特征,例如,边 plays 可以包括用户玩相应游戏的小时数。我们假设 user 1 玩了 2 小时 game 1user 2 玩了半小时 game 110 小时 game 2user 3 玩了 12 小时 game 2

data['user', 'plays', 'game'].edge_attr = torch.Tensor([[2], [0.5], [10], [12]])

(6) 最后,打印 data 对象来验证结果:

print(data)
'''
HeteroData(
  user={ x=[3, 4] },
  game={ x=[2, 2] },
  dev={ x=[2, 1] },
  (user, follows, user)={ edge_index=[2, 2] },
  (user, plays, game)={
    edge_index=[2, 4],
    edge_attr=[4, 1]
  },
  (dev, develops, game)={ edge_index=[2, 2] }
)
'''

从以上实现中可以看出,不同类型的节点和边并不共享相同的张量,甚至它们的维度也并不相同。因此,我们需要思考如何使用图神经网络 (Graph Neural Networks, GNN) 聚合来自多个张量的信息。
在同构图中,我们只关注单一类型的节点,权重矩阵的大小适合与预定义的维度相乘。然而,当具有不同维度的输入时,该如何实现 GNN

2. 将同构图神经网络转换为异构图神经网络

2.1 数据集介绍

为了更好地理解如何将同构图神经网络 (Graph Neural Networks, GNN) 转换为异构 GNN,我们以一个真实的数据集为例。DBLP 计算机科学文献提供了一个包含四种节点类型的数据集,分别是论文(papers14328 篇)、术语(terms7723 个)、作者(authors4057 个)和会议(conferences20 个)。该数据集的目标是将作者正确地分为四类研究领域——数据库 (database)、数据挖掘 (data mining)、人工智能 (artificial intelligence) 和信息检索 (information retrieval)。作者的节点特征是他们在论文中可能使用的 334 个关键词组成的词袋( “0” 或 “1”),不同节点类型之间的关系如下所示。

请添加图片描述

这些节点类型的维度和语义关系并不相同。在异构图中,节点之间的关系至关重要,这也是需要考虑节点对的原因。例如,不需要向 GNN 层输入作者节点,而是考虑 (作者、论文) 这种节点对。这意味着我们现在需要为每个关系建立一个 GNN 层;在这种情况下,“to” 关系是双向的,因此我们需要建立六个层。
这些新层具有独立的权重矩阵,适用于每种节点类型的正确维度。现在我们有了六个不共享任何信息的不同层,可以通过引入跳跃连接 (skip-connections)、共享层 (shared layers)、跳转知识 (jumping knowledge) 等方法来解决信息共享问题。
在将同构模型转化为异构模型之前,我们先在 DBLP 数据集上实现经典的图注意力网络 (Graph Attention Networks,GAT) 模型。GAT 无法考虑不同的关系;我们必须给它一个唯一的邻接矩阵,将作者相互连接起来。可以通过使用元路径技术生成这种邻接矩阵,如作者-论文-作者,将同一篇论文的作者连接起来。
也可以通过随机游走构建一个良好的邻接矩阵。即使图是异构的,也可以进行探索,并连接经常出现在相同序列中的节点。

2.2 同构图注意力网络

接下来,使用 PyTorch Geometric (PyG) 在 DBLP 数据集上实现经典图注意力网络 (Graph Attention Networks,GAT) 架构。

(1) 导入所需的库:

from torch import nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import GAT

(2) 使用特定语法定义要使用的元路径:

metapaths = [[('author', 'paper'), ('paper', 'author')]]

(3) 使用 AddMetaPaths 转换函数自动计算元路径。使用 drop_orig_edge_types=True 从数据集中移除其他关系( GAT 只能考虑一种关系):

transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True)

(4) 加载 DBLP 数据集并打印相关信息:

dataset = DBLP('.', transform=transform)
data = dataset[0]
print(data)

输出结果如下所示,可以看到转换函数创建的 (author, metapath_0, author) 关系:

输出结果

(5) 创建一个单层 GAT 模型,其中 in_channels=-1 用于执行懒初始化(模型将自动计算值),out_channels=4 用于将作者节点分为四种类别:

model = GAT(in_channels=-1, hidden_channels=64, out_channels=4, num_layers=1)

(6) 实例化 Adam 优化器并尝试将模型和数据转移到GPU中:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

(7) 定义 test() 函数用于评估模型预测的准确性:

@torch.no_grad()
def test(mask):
    model.eval()
    pred = model(data.x_dict['author'], data.edge_index_dict[('author', 'metapath_0', 'author')]).argmax(dim=-1)
    acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
    return float(acc)

(8) 创建训练循环:

for epoch in range(101):
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict['author'], data.edge_index_dict[('author', 'metapath_0', 'author')])
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        train_acc = test(data['author'].train_mask)
        val_acc = test(data['author'].val_mask)
        print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')

模型输出

(9) 在测试集上对训练后的模型进行了测试:

test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')

# Test accuracy: 73.29%

使用元路径将异构数据集缩减为同构数据集,并应用了经典 GAT 架构。模型的测试准确率为 73.29%,这可以作为其他技术进行比较的基准。

2.3 异构图神经网络

接下来,构建图注意力网络 (Graph Attention NetworksGAT) 模型的异构版本。如前所示,我们需要六个(而不再是一个) GAT 层。在 PyTorch Geometric 可以使用 to_hetero()to_hetero_bases() 函数自动完成。to_hetero() 函数需要三个重要参数:

  • module: 要转换的同构模型
  • metadata: 有关图的异构性质的信息,用元组 (node_types, edge_types) 表示,其中 node_typesedge_types 分别表示节点类型和边类型
  • aggr:聚合算子,用于聚合由不同关系(例如,求和、最大值或均值)生成的节点嵌入

同构 GAT (左图)和使用 to_hetero() 得到的异构版本(右图)如下所示。

同构 GAT 与 异构 GAT

异构 GAT 的实现过程于同构 GAT 模型相似。

(1) 首先,从 PyTorch Geometric 中导入 GNN 层:

from torch_geometric.nn import GATConv, Linear, to_hetero

(2) 加载 DBLP 数据集:

dataset = DBLP(root='.')
data = dataset[0]

(3) 当我们打印这个数据集的信息时,注意到会议节点没有任何特征。这于我们的架构假设(每个节点类型都有自己的特征)相违背,可以通过生成零值作为特征来解决此问题:

data['conference'].x = torch.zeros(20, 1)

(4) 创建 GAT 类,其中包含 GAT 层和线性层,使用 (-1, -1) 元组再次进行懒初始化:

class GAT(torch.nn.Module):
    def __init__(self, dim_h, dim_out):
        super().__init__()
        self.conv = GATConv((-1, -1), dim_h, add_self_loops=False)
        self.linear = nn.Linear(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.conv(x, edge_index).relu()
        h = self.linear(h)
        return h

(5) 实例化 GAT 模型,并使用 to_hetero() 进行转换:

model = GAT(dim_h=64, dim_out=4)
model = to_hetero(model, data.metadata(), aggr='sum')
print(model)

模型架构

(5) 实例化 Adam 优化器,并尝试将模型和数据转移到 GPU 上:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

(6) 编写 test() 函数,不需要指定任何关系,因为模型会考虑所有关系:

@torch.no_grad()
def test(mask):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict)['author'].argmax(dim=-1)
    acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
    return float(acc)

(7) 实现训练循环:

for epoch in range(101):
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)['author']
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        train_acc = test(data['author'].train_mask)
        val_acc = test(data['author'].val_mask)
        print(f'Epoch: {epoch:>3} | Train Loss: {loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val Acc: {val_acc*100:.2f}%')
'''
...
Epoch:  60 | Train Loss: 0.5049 | Train Acc: 98.00% | Val Acc: 73.25%
Epoch:  80 | Train Loss: 0.2687 | Train Acc: 99.25% | Val Acc: 76.75%
Epoch: 100 | Train Loss: 0.1574 | Train Acc: 100.00% | Val Acc: 76.50%
'''

(8) 训练后模型在测试数据集上的测试准确率如下:

test_acc = test(data['author'].test_mask)
print(f'Test accuracy: {test_acc*100:.2f}%')

# Test accuracy: 78.39%

异构 GAT 的测试准确率为 78.39%,比之同构版本有了较大提高 (+5.10%)。

小结

在本节中,我们扩展了消息传递神经网络 (Message Passing Neural Network, MPNN) 框架,以考虑由不同类型的节点和边组成的异构图。这种特殊的图可以表示实体之间的各种关系,这比单一类型的连接具有更高的表达能力。此外,我们还介绍了如何利用 PyTorch Geometric 将同构图神经网络 (Graph Neural Networks, GNN) 转换为异构 GNN,描述了异构图注意力网络 (Graph Attention Networks,GAT) 中的不同层,将节点对作为输入来模拟它们之间的关系。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法
图神经网络实战(17)——深度图生成模型
图神经网络实战(18)——消息传播神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2077389.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4、Unity【基础】画线功能Linerenderer、物理系统Physics

文章目录 画线功能Linerenderer1、LineRenderer是什么2、LineRender参数相关3、LineRender代码相关思考1 请写一个方法,传入一个中心点,传入一个半径,用LineRender画个圆出来思考2 在Game窗口长按鼠标用LineRender画出鼠标移动的轨迹 核心系统…

Android studio设置国内镜像代理(HTTP Proxy)教程详解

前些天发现了一个蛮有意思的人工智能学习网站,8个字形容一下"通俗易懂,风趣幽默",感觉非常有意思,忍不住分享一下给大家。 👉点击跳转到教程 1、Android Studio是在谷歌的服务器上,初次安装Android Studio时下载SDK可能…

6Valley 14.2 免授权php – 跨境电商在线商城 – 完整的电子商务APP端和web端程序

6Valley 14.2 Nulled – 多供应商电子商务 – 完整的电子商务移动应用程序、Web、卖家和管理面板 后台可自定义收款,和翻译多国语言,中文需要自己对比翻译!一般用不到中文。毕竟是跨境电商平台 带商家即时通讯,全套带文档和APP双…

Spring(2)

目录 一、使用注解开发 1.1 主要注解 1.2 衍生注解 1.3 xml与注解 二、使用Java的方式配置Spring 三、代理模式 3.1 静态代理 3.1.1 角色分析 3.1.2 代码步骤 3.1.3 优点 3.1.4 缺点 3.2 动态代理 3.2.1 代码步骤 四、AOP 4.1 使用Spring的API接口 4.2 使用自定义…

云计算实训36——mysql镜像管理、同步容器和宿主机时间、在容器外执行容器内命令、容器的ip地址不稳定问题、基础镜像的制作、镜像应用

一、线上考试系统的数据虚拟化技术部署 1.部署前段服务器 步骤一:将资源上传到服务器 将dist.zip上传给服务器 下载unzip的包 yum -y install unzip 解压 unzip dist.zip 步骤二:创建基础容器在服务器上 启动服务 systemctl start docker.servic…

LVS部署——DR集群

目录 一、LVS—DR工作原理 二、LVS-DR数据流向 三、LVS-DR模式特点和优缺点 3.1、特点 3.2、优缺点 四、LVS-DR中的ARP问题 4.1、IP地址冲突 4.2、第二次访问请求失败 五、部署LVS-DR集群 5.1、实验准备 5.2、配置负载调度器(192.168.20.15) …

SeaweedFS 分布式存储安装weed

下载Single Binary weed Start 官方推荐 https://github.com/seaweedfs/seaweedfs 下载 https://github.com/seaweedfs/seaweedfs/releases解压 single binary file weed or weed.exe. wget https://github.com/seaweedfs/seaweedfs/releases/download/3.72/darwin_amd64.…

【前端面试基础】计算机网络、浏览器、操作系统

计算机网络 一、网络协议与模型 什么是协议? 协议是指计算机系统中完成特定任务所必需的规则和约定,特别是数据传输和交换的规则和约定。OSI和TCP/IP是什么? OSI(开放式系统互连参考模型)是一种网络架构模型&#xf…

C#复习之索引器

知识点一:索引器基本概念 基本概念: 让对象可以像数组一样通过索引访问其中元素,使程序看起来更直观,更容易编写 知识点二:索引器语法 //value代表传入的值 知识点三:索引器的使用 知识点四&#xff1a…

大容量永磁同步电机转速电流双环PID控制MATLAB仿真模型

电气仔推送 模型简介 同步电机采用转速环PI控制,电流环PI控制,在电机转速300r/min,输出转矩为40000N.m时,电机的输出功率为1.25MW。模型各部分完整,电流输出正弦度好,谐波含量低。赠送建模说明文件&#…

PAT (Basic Level) Practice (中文)

1003 我要通过 通过观察不难发现在一个规律:P之前A的个数*P和T之间A的个数等于T之后A的个数答案才正确 总结一下如何才能答案正确? 1.必须只能有P,A,T这三种字符 2.P和T之间必须要有A 3.P之前A的个数*P和T之间A的个数等于T之…

【HTML】模拟消息折叠效果【附源代码】

文件结构 收起效果 展开效果 HTML部分 HTML部分定义了网页的结构和内容。 <!DOCTYPE html> 声明了文档类型和HTML版本。<html> 元素是所有其他HTML元素的父元素。<head> 元素包含了文档的元数据&#xff0c;如字符集、视口设置、标题和链接的样式表。<b…

高效又经济,乔拓云助力,快速上线功能全面的小程序解决方案

乔拓云模板化小程序开发费用 在当今数字化时代&#xff0c;小程序成为企业拓展市场的新利器。乔拓云平台提供模板化开发方案&#xff0c;让您的小程序能同时覆盖微信与百度&#xff0c;迅速触达更多用户。 选择乔拓云模板&#xff0c;无需从零开始设计&#xff0c;直接复用精美…

ssrf+redis未授权访问漏洞复现

目录 靶场搭建 报错问题解决 组合利用 使用goherus生成payload 靶场搭建 首先我们进入ubutuo拉取靶场 docker run -d -p 8765:80 8023/pikachu-expect:latest 报错问题解决 如果出现docker报错&#xff0c;靶场一直拉取不下来 解决办法&#xff1a;配置镜像加速器 vim /et…

二叉树中查找值为x的节点(递归查找)

一&#xff1a;前提 本文紧接此篇博客&#xff1a; 递归实现 前/中/后序 遍历二叉树 的详细讲解-CSDN博客 模型依旧为&#xff1a; 二&#xff1a;代码 三&#xff1a;递归展开 假设找3&#xff1a; 假设找 7,7不存在&#xff0c;最后返回NULL 左&#xff1a; 右&#xff1…

机器学习 第5章 神经网络

这里写目录标题 5.1 神经元模型5.2 感知机与多层网络5.3 误差逆传播算法5.4 其他常见神经网络5.4.1 RBF网络5.4.2 ART网络5.4.3 SOM网络5.4.4 级联相关网络 5.5 深度学习 5.1 神经元模型 神经网络是一种由神经元构成的计算模型&#xff0c;模拟了生物神经系统的工作原理。神经…

【MySQL】优化 - 深分页

深分页 问题优化方法子查询延迟关联游标 问题 就是查询偏移量过大的场景&#xff0c;会导致查询性能较低&#xff0c;例如 # MySQL 在无法利用索引的情况下跳过1000000条记录后&#xff0c;再获取10条记录 SELECT * FROM t_order ORDER BY id LIMIT 1000000, 10首先&#xff…

嵌入式:用J-Link Commander和J-Flash进行Flash编程的区别

相关阅读 嵌入式https://blog.csdn.net/weixin_45791458/category_12768532.html?spm1001.2014.3001.5482 J-Link Commander和J-Flash都是用于Flash编程的工具&#xff0c;但它们的功能和应用场景有所不同。以下是两者的区别&#xff1a; J-Link Commander: 类型: 命令行工…

.NET应用UI框架DevExpress XAF v24.1 - 可用性进一步增强

DevExpress XAF是一款强大的现代应用程序框架&#xff0c;允许同时开发ASP.NET和WinForms。DevExpress XAF采用模块化设计&#xff0c;开发人员可以选择内建模块&#xff0c;也可以自行创建&#xff0c;从而以更快的速度和比开发人员当前更强有力的方式创建应用程序。 在DevEx…

为什么说中医的本质是医“中”

日前&#xff0c;与一位懂中医的朋友朋友聊天&#xff0c;他言简意赅地指出“中医的本质就是医‘中’”。反思后总结如下&#xff0c;以飨读者&#xff0c;同时欢迎批评指正&#xff01; “中医的本质是医‘中’”强调了中医的核心在于其整体观和辩证方法。中医“中”的本质在于…