pyg的NeighborLoader和LinkNeighborLoader

news2024/11/16 21:57:05

NeighborLoader

1 数据格式要求

需要传入加载的属性值:

class NeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], 
num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], 
input_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, 
input_time: Optional[Tensor] = None, 
replace: bool = False, 
directed: bool = True, 
disjoint: bool = False, 
temporal_strategy: str = 'uniform', 
time_attr: Optional[str] = None, 
transform: Optional[Callable] = None, 
transform_sampler_output: Optional[Callable] = None, 
is_sorted: bool = False, 
filter_per_worker: bool = False, 
neighbor_sampler: Optional[NeighborSampler] = None, **kwargs)

        data: 要求加载 torch_geometric.data.Data 或者 torch_geometric.data.HeteroData 类型数据;

        num_neighbors: 每轮迭代要采样邻居节点的个数,即第i-1轮要为每个节点采样num_neighbors[i]个节点,如果为-1,则代表所有邻居节点都将被包含(一阶相邻邻居),在异构图中,还可以使用字典来表示每个单独的边缘类型要采样的邻居数量;

        input_nodes : 中心节点集合,用来指导采样一个mini-batch内的节点,如果为None,则代表包含data中的所有节点。如果设置为 None,将考虑所有节点。在异构图中,需要作为包含节点类型和节点索引的元组传递。 (默认值:None)

        input_time (torch.Tensor, optional) – 可选值,用于覆盖 input_nodes 中给定的输入节点的时间戳。如果未设置,将使用 time_attr 中的时间戳作为默认值(如果存在)。需要设置 time_attr 才能使其工作。 (默认值:None)

        replace (bool, optional) – 如果设置为 True,将进行替换采样。 (默认值:False)

        directed (bool, optional) – 如果设置为 False,将包括所有采样节点之间的所有边。 (默认值:True)

        disjoint (bool, optional) – 如果设置为 :obj: True,每个种子节点将创建自己的不相交子图。如果设置为 True,小批量输出将有一个批量向量保存节点到它们各自子图的映射。在时间采样的情况下将自动设置为 True。 (默认值:False) 

        temporal_strategy (str, optional) -- 使用时间采样时的采样策略(“uniform”、“last”)。如果设置为“uniform”,将在满足时间约束的邻居之间统一采样。如果设置为“last”,将对满足时间约束的最后 num_neighbors 进行采样。 (默认值:“uniform”)

         transform (callable, optional) – 一个函数/转换,它接受一个采样的小批量并返回一个转换后的版本。 (默认值:None)

        transform_sampler_output (callable, optional) – 接受 SamplerOutput 并返回转换后版本的函数/转换。 (默认值:无)

        **kwargs(可选)—— torch.utils.data.DataLoader 的附加参数,例如 batch_size、shuffle、drop_last 或 num_workers。

2 上述参数使用案例:

(1)当 num_neighbors = [-1]时,获取中心节点所有的一阶邻居;

        batch_size=1,表示中心节点只有一个; 

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
import torch
import networkx as nx
import matplotlib.pyplot as plt

data = Planetoid('./dataset', name='Cora')[0]

loader_2 = NeighborLoader(
    data,
    num_neighbors=[-1],
    batch_size=1,
    input_nodes=data.n_id,
)
# 准备边数据
sampled_data_2 = next(iter(loader_2))
# sampled_data_2 输出格式:
# Data(x=[4, 1433], edge_index=[2, 3], y=[4], train_mask=[4], val_mask=[4], test_mask=[4], n_id=[4], batch_size=1)
edge_2 = np.array(sampled_data_2.edge_index).T
edge_2 = edge_2.tolist()
edge_2 = list(tuple(line) for line in edge_2)

# 画图展示
G_2 = nx.Graph()
G_2.add_edges_from(edge_2)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
option = {'font_family':'serif', 'font_size':'15', 'font_weight':'semibold'}
nx.draw_networkx(G_2, node_size=400, **option)
plt.show()

        画图展示:

         代码中的sampled_data_2中的涉及节点的输出:

sampled_data_2.n_id

# tensor([   0,  633, 1862, 2582])
# 前batch_size个节点为中心节点

(2)当 num_neighbors = [2,3]时,获取中心节点所有的一阶邻居(任选取3个节点)以及一阶邻居的邻居(任选取两个节点);

代码展示:

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
import torch
import networkx as nx
import matplotlib.pyplot as plt

data = Planetoid('./dataset', name='Cora')[0]
data.n_id = torch.arange(data.num_nodes)

loader_2 = NeighborLoader(
    data,
    num_neighbors=[2,3],
    batch_size=3,
    input_nodes=data.n_id,
)
# 准备边数据
sampled_data_2 = next(iter(loader_2))
# sampled_data_2 输出格式:
# Data(x=[11, 1433], edge_index=[2, 14], y=[11], train_mask=[11], val_mask=[11], test_mask=[11], n_id=[11], batch_size=3)
edge_2 = np.array(sampled_data_2.edge_index).T
edge_2 = edge_2.tolist()
edge_2 = list(tuple(line) for line in edge_2)

# 画图展示
G_2 = nx.Graph()
G_2.add_edges_from(edge_2)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
option = {'font_family':'serif', 'font_size':'15', 'font_weight':'semibold'}
nx.draw_networkx(G_2, node_size=400, **option)
plt.show()

         代码中的sampled_data_2中的涉及节点的输出:

sampled_data_2.n_id

# tensor([   0,    1,    2,  633, 2582,  654, 1454, 1701, 1866, 1166, 1862])
# 前batch_size个节点为中心节点

3 获得子图的id的映射

        当实际应用中我们要获取训练集和测试集的子图,因此一般输入在NeighborLoader的input_nodes参数的值对应于训练集的id和测试集的id;

        而获得的边对应的id不是实际大图中的节点id,而是后来按照顺序分配的;

例如:

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader,LinkNeighborLoader
import torch
import networkx as nx
import matplotlib.pyplot as plt

data = Planetoid('./dataset', name='Cora')[0]
data.n_id = torch.arange(data.num_nodes)
test_id = torch.tensor([i for i in range(100,120)])

loader_2 = NeighborLoader(
    data,
    num_neighbors=[2,3],
    batch_size=3,
    input_nodes=test_id,
)
# 准备边数据
sampled_data_2 = next(iter(loader_2))
# sampled_data_2 输出格式:
# 
edge_2 = np.array(sampled_data_2.edge_index).T
edge_2 = edge_2.tolist()
edge_2 = list(tuple(line) for line in edge_2)

# 画图展示
G_2 = nx.Graph()
G_2.add_edges_from(edge_2)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
option = {'font_family':'serif', 'font_size':'15', 'font_weight':'semibold'}
nx.draw_networkx(G_2, node_size=400, **option)
plt.show()

 

print(sampled_data_2.edge_index)
print(sampled_data_2.n_id)
print(sampled_data_2.num_nodes)
# 输出
tensor([[ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12,  0, 13, 14,  1, 15,  1, 16, 17,
          2,  8, 18, 19, 20],
        [ 0,  0,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  6,  6,  6,
          7,  7,  8,  8,  8]])
tensor([ 100,  101,  102, 1602, 2056,  281, 1589, 1561, 1623,   95,  315, 2073,
         734, 1628, 1347, 1382, 1745, 2596, 1769, 1772, 1771])
21

 将图进行可视化时,可以映射回大图中的id

2 LinkNeighborLoader

1 数据格式要求

需要传入加载的属性值:

class LinkNeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], 
num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], 
edge_label_index: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, 
edge_label: Optional[Tensor] = None, 
edge_label_time: Optional[Tensor] = None, replace: bool = False, 
directed: bool = True, disjoint: bool = False, 
temporal_strategy: str = 'uniform', 
neg_sampling: Optional[NegativeSampling] = None, 
neg_sampling_ratio: Optional[Union[int, float]] = None, 
time_attr: Optional[str] = None, 
transform: Optional[Callable] = None, 
transform_sampler_output: Optional[Callable] = None, 
is_sorted: bool = False, 
filter_per_worker: bool = False, 
neighbor_sampler: Optional[NeighborSampler] = None, **kwargs)

        作为基于节点的 torch_geometric.loader.NeighborLoader 的扩展派生的基于链接的数据加载器。该加载器允许在无法进行整批训练的大规模图上对 GNN 进行小批量训练

        更具体地说,这个加载器首先从输入边 edge_label_index 集合中选择一个边样本(它可能是原始图中的边,也可能不是原始图中的边),然后通过在每次迭代中采样 num_neighbors 个邻居,从这个列表中存在的所有节点构造一个子图.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/419940.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

进程调度的基本过程

进程调度的基本过程🔎 进程是什么🔎 进程管理🔎 进程中结构体的属性进程标识符(PID)内存指针文件描述符表结构体中与进程调度相关的属性进程的状态进程的优先级进程的上下文进程的记账信息🔎 总结🔎 结尾🔎…

(第十四届蓝桥真题) 整数删除(线段树+二分)

样例输入: 5 3 1 4 2 8 7 样例输出: 17 分析:这道题我想的比较复杂,不过复杂度还是够用的,我是用线段树二分来做的。 我们用线段树维护所有位置的最小值,那么我们每次删除一个数之前先求一遍最小值&a…

停车场管理系统文件录入(C++版)

❤️作者主页:微凉秋意 ✅作者简介:后端领域优质创作者🏆,CSDN内容合伙人🏆,阿里云专家博主🏆 文章目录一、案例需求描述1.1、汽车信息模块1.2、普通用户模块1.3、管理员用户模块二、案例分析三…

mysql:使用终端操作数据库

登录进入终端: mysql -u root -p 展示数据库 SHOW DATABASES; 创建数据库: CREATE DATABASE IF NOT EXISTS RUNOOB_TEST DEFAULT CHARSET utf8 COLLATE utf8_general_ci; 1. 如果数据库不存在则创建,存在则不创建。 2. 创建RUNOOB_TEST数据库…

ElasticSearch安装、启动、操作及概念简介

ElasticSearch快速入门 文件链接:https://pan.baidu.com/s/15kJtcHY-RAY3wzpJZIn4-w?pwd0k5a 提取码:0k5a 有些软件对于安装路径有一定的要求,例如:路径中不能有空格,不能有中文,不能有特殊符号&#xf…

JUC并发编程之ReentrantLock

1. 非公平锁实现原理 加锁解锁流程 构造器默认实现的是非公平锁 public ReentrantLock() {sync new NonfairSync();}NonfairSync 继承 Sync, Sync 继承 AbstractQueuedSynchronizer 没有竞争时 第一个竞争出现时 Thread-1 执行了 CAS 尝试将state 由 0 改为 1&…

Stable Diffusion免费(三个月)通过阿里云轻松部署服务

温馨提示:划重点,活动入口在这里喔,不要迷路了。 其实我就在AIGC_有没有一种可能,其实你早就在AIGC了?阿里云邀请你,体验一把AIGC级的毕加索、达芬奇、梵高等大师作画的快感。阿里云将提供免费云产品资源&…

如何使用evosuite为指定被测方法生成测试用例

目录 省流版本 准备工作 环境 evosuite获取 检验环境 参数解释 怎样表示被测方法 怎样指向被测类 其他参数 参考 省流版本 java -jar .\target\depd\evosuite-1.1.0.jar -generateTests -Dtarget_method"isLenient()Z" -class com.google.gson.stream.…

Midjourney教程(二)——Prompt基本结构

Midjourney教程——Prompt基本结构 Basic Prompt 基础版本的prompt仅仅包含图片的描述,能够满足普通的需求,如下图所示 Advanced Prompt 高级版本的prompt主要包含三个部分,如下图所示 Image Prompts(可选) prompt第一部分是Image&#x…

TCP/IP协议详解

一.引言TCP/IP 是 TCP 和 IP 两种协议群的统称,具体来说,IP 或 ICMP、TCP 或 UDP、TELNET 或 FTP、以及 HTTP 等都属于 TCP/IP 协议二.计算机网络体系结构分层计算机网络体系结构分层计算机网络体系结构分层不难看出,TCP/IP 与 OSI 在分层模块…

【C语言】迷宫问题

【C语言】迷宫问题一. 题目描述二. 思想2.1 算法---回溯算法2.2 思路分析图解三. 代码实现3.1 二维数组的实现3.2 上下左右四个方向的判断3.4 用栈记录坐标的实现3.5 完整代码四. 总结一. 题目描述 牛客网链接:https://www.nowcoder.com/questionTerminal/cf2490605…

STM32看门狗

目录 独立看门狗 IWDG 什么是看门狗? 独立看门狗本质 独立看门狗框图 独立看门狗时钟 分频系数算法: ​编辑 重装载寄存器 键寄存器 溢出时间计算公式 独立看门狗实验 需求: 硬件接线: 溢出时间计算&#xff1…

macOS设置环境变量和别名

因为我的mac所用shell是bash,所以本文中涉及的环境变量和别名配置均在~/.zshrc文件中,且在每次配置完成后,需要执行source ~/.zshrc命令使配置文件生效 环境变量 通过配置环境变量,我们可以将某个路径暴露到全局,这样可以在全局…

周总结(第一周)

3月份3个星期 *** 三个星代表不会 ** 再做 * 加强 题目1-完全二叉树(记忆) 考察数据结构 完全二叉树的深度deplog2(N1)1 完全二叉树节点的深度depiceil(log2(i1))向上舍入 完全二叉树的层次遍历,遍历每层的二叉树计算基础每层的总和,然后找出最大的和…

Talk预告 | 新加坡国立大学郑奘巍 AAAI‘23 杰出论文:大批量学习算法加速推荐系统训练

本期为TechBeat人工智能社区第486期线上Talk! 北京时间3月30日(周四)20:00,新加坡国立大学二年级博士生——郑奘巍的Talk将准时在TechBeat人工智能社区开播! 他与大家分享的主题是: “大批量学习算法加速推荐系统训练”,届时将分…

Kubernetes 多集群网络方案系列 2 -- Submariner 监控

Submariner 是一个用于连接 Kubernetes 集群的跨集群网络解决方案,可以实现集群之间的服务发现、网络通信等功能。 Prometheus 是一个开源的监控和告警系统,专门用于收集、存储和查询各种应用、系统和基础设施的实时指标数据。Prometheus 具备多维数据模…

Java开发 - MySQL主从复制初体验

前言 前面已经学到了很多知识,大部分也都是偏向于应用方面,在应用实战这条路上,博主一直觉得只有实战才是学习中最快的方式。今天带来主从复制给大家,在刚刚开始动手写的时候,才想到似乎忽略了一些重要的东西&#xf…

面试篇-揭开Spring Bean加载的神秘面纱

SpringBean加载完整过程 启动spring容器(创建beanfactory)->加载配置(注解、xml)->实例化bean(执行构造方法)->注入依赖->初始化bean(设置属性值)->使用->销毁 解析和读取 XML 配置文件或注解配置类&#xff0…

Linux嵌入式学习之Ubuntu入门(五)汇编语法学习

系列文章目录 一、Linux嵌入式学习之Ubuntu入门(一)基本命令、软件安装及文件结构 二、Linux嵌入式学习之Ubuntu入门(二)磁盘文件介绍及分区、格式化等 三、Linux嵌入式学习之Ubuntu入门(三)用户、用户组…

synchronized原理、偏向锁、轻量级锁、重量级锁、锁升级

文章目录Synchronized概念自增自减字节码指令临界区竞态条件基本使用原理查看synchronized的字节码指令序列Monitor对象的内存布局Mark Word是如何记录锁状态的偏向锁什么是偏向锁偏向锁延迟偏向偏向锁状态跟踪偏向锁撤销之调用对象HashCode偏向锁撤销之调用wait/notify轻量级锁…