【教程】用GraphSAGE和UnsupervisedSampler进行节点表示学习

news2024/11/17 7:34:42

转载请注明出处:小锋学长生活大爆炸[xfxuezhang.cn]

目录

无监督的GraphSAGE

加载 CORA 网络数据

按需采样的无监督GraphSAGE

无监督取样器(UnsupervisedSampler)

提取节点嵌入

节点嵌入的可视化

下游任务

数据拆分

分类器训练

无监督的图表示学习的用途


随着更新迭代,代码不一定能完全运行,仅供学习一下思想~

        Stellargraph Unsupervised GraphSAGE是论文中所述GraphSAGE方法的实现: 大图上的归纳表征学习。W.L. Hamilton, R. Ying, and J. Leskovec arXiv:1706.02216 [cs.SI], 2017。

        本笔记本是一个简短的演示,说明如何使用Stellargraph无监督GraphSAGE来学习CORA引文网络中代表论文的节点的嵌入。此外,这个笔记本展示了在下游节点分类任务中使用学习到的嵌入(按主题对论文进行分类)。请注意,节点嵌入也可用于其他图机器学习任务,如链接预测、社区检测等。

无监督的GraphSAGE

        对无监督GraphSAGE图表示学习方法的高层次解释如下。

        目标: 给定一个图,只使用图的结构和节点特征来学习节点的嵌入,而不使用任何已知的节点类别标签(因此是 "无监督的";关于节点嵌入的半监督学习,见此演示)。

        无监督的GraphSAGE模型: 在无监督GraphSAGE模型中,节点嵌入是通过解决一个简单的分类任务来学习的:给定一大组从图上进行的随机行走中产生的 "正"(目标、背景)节点对(即在随机行走中某个背景窗口内共同出现的节点对),以及同样大的 "负 "节点对(根据某种分布从图上随机选择),学习一个二进制分类器,预测任意节点对是否可能在图上进行的随机行走中共同出现。通过学习这个简单的二进制节点对分类任务,该模型自动学习了从节点及其邻居的属性到高维向量空间中的节点嵌入的归纳映射,这保留了节点的结构和特征相似性。与Node2Vec等算法获得的嵌入不同,这种映射是归纳式的:给定一个新的节点(有属性)及其与图中其他节点的链接(在模型训练期间未见过),我们可以评估其嵌入,而不必重新训练模型

        在我们的无监督GraphSAGE的实现中,节点对的训练集是由图中同等数量的正负(目标,背景)节点对组成。阳性(目标,背景)节点对是在图上随机行走时共同出现的节点对,而负节点对是从图的全局节点度分布中随机抽取的。

        节点对分类器的结构如下:输入的节点对(含节点特征)与图结构一起被送入一对相同的GraphSAGE编码器,产生一对节点嵌入。然后,这些嵌入被送入一个节点对分类层,该层对这些节点嵌入应用一个二进制运算符(例如,连接它们),并将产生的节点对嵌入通过一个线性变换和一个二进制激活(例如,sigmoid),从而为节点对预测一个二进制标签。

        整个模型通过最小化所选择的损失函数(例如,预测的节点对标签和真实链接标签之间的二进制交叉熵),使用随机梯度下降法(SGD)更新模型参数来进行端到端的训练,按要求生成迷你批次的 "训练 "链接并输入模型。

        从经过训练的分类器的编码器部分获得的节点嵌入可以用于各种下游任务。在这个演示中,我们展示了这些如何用于预测节点标签。

# install StellarGraph if running on Google Colab
import sys
if 'google.colab' in sys.modules:
  %pip install -q stellargraph[demos]==1.2.1

# verify that we're using the correct version of StellarGraph for this notebook
import stellargraph as sg

try:
    sg.utils.validate_notebook_version("1.2.1")
except AttributeError:
    raise ValueError(
        f"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>."
    ) from None
import networkx as nx
import pandas as pd
import numpy as np
import os
import random

import stellargraph as sg
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import GraphSAGELinkGenerator
from stellargraph.layer import GraphSAGE, link_classification
from stellargraph.data import UniformRandomWalk
from stellargraph.data import UnsupervisedSampler
from sklearn.model_selection import train_test_split

from tensorflow import keras
from sklearn import preprocessing, feature_extraction, model_selection
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.metrics import accuracy_score

from stellargraph import globalvar

from stellargraph import datasets
from IPython.display import display, HTML

加载 CORA 网络数据

        (参见 "从Pandas加载"演示,了解如何加载数据的细节)。

        Cora数据集由2708份科学出版物组成,分为七个类别之一。引文网络由5429个链接组成。数据集中的每份出版物都由一个0/1值的单词向量描述,表示字典中相应单词的缺席/存在。词典由1433个独特的词组成。

dataset = datasets.Cora()
display(HTML(dataset.description))
G, node_subjects = dataset.load()

print(G.info())

'''
StellarGraph: Undirected multigraph
 Nodes: 2708, Edges: 5429

 Node types:
  paper: [2708]
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5429]
'''

按需采样的无监督GraphSAGE

        无监督GraphSAGE需要一个训练样本,这个样本可以以(目标,上下文)节点对列表的形式提供,也可以用一个UnsupervisedSampler实例来提供,该实例负责按需生成节点对的正负样本。在这个演示中,我们讨论后一种技术。

无监督取样器(UnsupervisedSampler)

        UnsupervisedSampler类接收了一个Stellargraph图的实例。UnsupervisedSampler中的生成器方法负责从图中生成同等数量的正负节点对样本用于训练。这些样本是通过使用UniformRandomWalk对象在图上进行均匀的随机行走而产生的。正面(目标,背景)节点对从散步中提取,对于每个正面的节点对(目标,节点),通过从图的度分布中随机抽取节点来生成相应的负面节点对。一旦样本的数量达到batch_size,生成器就会得到一个正负节点对的列表,以及它们各自的1/0标签。
        在目前的实现中,我们使用统一的随机行走来探索图的结构。漫步的长度和数量,以及开始漫步的根节点都可以由用户指定。根节点的默认列表是图的所有节点,默认的行走次数是1(每个根节点至少有一次行走),默认的行走长度是2(需要在行走中至少有一个超出根节点的节点作为潜在的正面背景)。

1. 指定其他可选参数值:根节点、每个节点的行走次数、每个行走的长度和随机种子。

nodes = list(G.nodes())
number_of_walks = 1
length = 5

2. 创建UnsupervisedSampler实例,并向其传递相关参数。

unsupervised_samples = UnsupervisedSampler(
    G, nodes=nodes, length=length, number_of_walks=number_of_walks
)

        图G和无监督采样器将被用来生成样本。

3. 创建一个节点对生成器:

        接下来,创建节点对生成器,用于取样并将训练数据流向模型。节点对生成器本质上是将节点对(目标,上下文)"映射 "到GraphSAGE的输入中:它要么接受节点对的分批,要么接受一个UnsupervisedSampler实例,该实例按要求生成节点对的分批。生成器从这些节点对中提取带有(目标,上下文)头部节点的2跳子图,并将其与相应的二进制标签一起送入带有GraphSAGE节点编码器的节点对分类器的输入层,用于模型参数的SGD更新。

指定:

  • 迷你批大小(每个迷你批的节点对数量)。
  • 训练模型的 epochs 数目。
  • GraphSAGE的1跳和2跳邻居样本的大小:

        注意,num_samples列表的长度定义了GraphSAGE编码器的层数/迭代数。在这个例子中,我们定义的是一个2层的GraphSAGE编码器。

batch_size = 50
epochs = 4
num_samples = [10, 5]

        下面我们将展示节点对生成器与无监督采样器的工作,它将按要求生成样本。

generator = GraphSAGELinkGenerator(G, batch_size, num_samples)
train_gen = generator.flow(unsupervised_samples)

        建立模型:一个2层GraphSAGE编码器作为节点表示学习器,在连接的(引用-论文,被引用-论文)节点嵌入上有一个链接分类层。

        该模型的GraphSAGE部分,两个GraphSAGE层的隐藏层大小为50,有一个偏置项,没有剔除。(可以通过指定一个正的辍学率来开启辍学,0<辍学<1)。注意,layer_sizes列表的长度必须等于num_samples的长度,因为len(num_samples)定义了GraphSAGE编码器的跳数(层数)。

layer_sizes = [50, 50]
graphsage = GraphSAGE(
    layer_sizes=layer_sizes, generator=generator, bias=True, dropout=0.0, normalize="l2"
)

# Build the model and expose input and output sockets of graphsage, for node pair inputs:
x_inp, x_out = graphsage.in_out_tensors()

        最终的节点对分类层,采取由graphsage编码器产生的一对节点嵌入,对其应用二进制运算符以产生相应的节点对嵌入(ip为内积;二进制运算符的其他选项可以通过运行带有?link_classification的单元查看),并将其传递给稠密层:

prediction = link_classification(
    output_dim=1, output_act="sigmoid", edge_embedding_method="ip"
)(x_out)

        将GraphSAGE编码器和预测层堆叠到Keras模型中,并指定损失。

model = keras.Model(inputs=x_inp, outputs=prediction)

model.compile(
    optimizer=keras.optimizers.Adam(lr=1e-3),
    loss=keras.losses.binary_crossentropy,
    metrics=[keras.metrics.binary_accuracy],
)

4. 训练模型。

history = model.fit(
    train_gen,
    epochs=epochs,
    verbose=1,
    use_multiprocessing=False,
    workers=4,
    shuffle=True,
)

        请注意,多进程是关闭的,因为在有大量节点对的训练集时,多进程会随着数据在不同进程之间的传输而大大减慢训练过程。

        另外,在Keras 2.2.4及以上版本中可以使用多个工作者,由于多线程,它可以大大加快训练过程。

提取节点嵌入

        现在,节点对分类器已经训练完毕,我们可以使用其节点编码器部分作为节点嵌入评估器。下面我们将节点嵌入评估为GraphSAGE层栈输出的激活,并将其可视化,根据主题标签给节点着色。

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from stellargraph.mapper import GraphSAGENodeGenerator
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

建立一个新的基于节点的模型

        (src, dst) 节点对分类器模型有两个相同的节点编码器:一个用于节点对中的源节点,另一个用于传递给模型的节点对中的目的节点。我们可以使用这两个相同的编码器中的任何一个来评估节点嵌入。下面我们通过定义一个新的Keras模型来创建一个嵌入模型,x_inp_src(x_inp中奇数元素的列表)和x_out_src(x_out中的第1个元素)分别作为输入和输出。注意,这个模型的权重与之前训练的节点对分类器中相应的节点编码器的权重相同。

x_inp_src = x_inp[0::2]
x_out_src = x_out[0]
embedding_model = keras.Model(inputs=x_inp_src, outputs=x_out_src)

        我们还需要一个节点生成器来给embedding_model提供图的节点。我们想为图中的所有节点评估节点嵌入:

node_ids = node_subjects.index
node_gen = GraphSAGENodeGenerator(G, batch_size, num_samples).flow(node_ids)

        我们现在使用node_gen将所有节点送入嵌入模型并提取它们的嵌入:

node_embeddings = embedding_model.predict(node_gen, workers=4, verbose=1)

节点嵌入的可视化

        接下来我们用t-SNE将节点嵌入可视化。节点的颜色描述了节点的真实类别(在Cora数据集为主题的情况下)。

node_subject = node_subjects.astype("category").cat.codes

X = node_embeddings
if X.shape[1] > 2:
    transform = TSNE  # PCA

    trans = transform(n_components=2)
    emb_transformed = pd.DataFrame(trans.fit_transform(X), index=node_ids)
    emb_transformed["label"] = node_subject
else:
    emb_transformed = pd.DataFrame(X, index=node_ids)
    emb_transformed = emb_transformed.rename(columns={"0": 0, "1": 1})
    emb_transformed["label"] = node_subject

alpha = 0.7

fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(
    emb_transformed[0],
    emb_transformed[1],
    c=emb_transformed["label"].astype("category"),
    cmap="jet",
    alpha=alpha,
)
ax.set(aspect="equal", xlabel="$X_1$", ylabel="$X_2$")
plt.title(
    "{} visualization of GraphSAGE embeddings for cora dataset".format(transform.__name__)
)
plt.show()

        观察到嵌入空间中相同颜色的节点集中在一起,说明相同主题的论文的嵌入是相似的。我们在此再次强调,节点嵌入是以无监督的方式学习的,没有使用真实的类别标签。

下游任务

        使用无监督的GraphSAGE计算的节点嵌入可以作为节点特征向量用于下游任务,如节点分类。

        在这个例子中,我们将使用节点嵌入来训练一个简单的逻辑回归分类器来预测Cora数据集中的论文题目。

# X will hold the 50 input features (node embeddings)
X = node_embeddings
# y holds the corresponding target values
y = np.array(node_subject)

数据拆分

        我们把数据分成训练集和测试集。

        我们使用5%的数据进行训练,其余95%的数据作为测试集。

X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.05, test_size=None, stratify=y
)

分类器训练

        我们在训练数据上训练一个Logistic回归分类器。

clf = LogisticRegression(verbose=0, solver="lbfgs", multi_class="auto")
clf.fit(X_train, y_train)

        预测持有的测试集。计算分类器在测试集上的准确性。

y_pred = clf.predict(X_test)
accuracy_score(y_test, y_pred)

        获得的准确率相当不错,比使用node2vec获得的节点嵌入要好,node2vec忽略了节点属性,只考虑了图结构(见这个演示)。

预测的类别

pd.Series(y_pred).value_counts()

真正的类

pd.Series(y).value_counts()

无监督的图表示学习的用途

        无监督的GraphSAGE学习无标签的图节点的嵌入。这是非常有用的,因为大多数现实世界的数据通常都是无标签的,或者有嘈杂的、不可靠的或稀疏的标签。在这种情况下,通过利用图形结构和节点的特征来学习图形中节点的低维有意义表示的无监督技术是非常有用的。
        此外,GraphSAGE是一种归纳技术,使我们能够获得未见过的节点的嵌入,而不需要重新训练嵌入模型。也就是说,GraphSAGE不是为每个节点训练单独的嵌入(如node2vec等算法中学习节点嵌入的查询表),而是学习一个函数,通过从每个节点的本地邻域采样和聚合属性,并将这些属性与节点自身的属性相结合,来生成嵌入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/544005.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8.1.0版本ELK搭建,不开启xpack认证机制

8.1.0版本ELK搭建&#xff0c;不开启xpack认证机制 部署环境安排下载安装包服务器环境配置部署elasticsearch部署kibana部署logstash部署httpd&#xff0c;filebeat配置kibana页面 部署环境安排 ip部署服务192.168.0.121kibana,elasticsearch192.168.0.83elasticsearch,logsta…

Spring boot 注解@Async不生效 无效 不起作用

今天在做公司项目时&#xff0c;有一个发邮件的需求。所以写了一个发送邮件的方法后来发现发邮件很慢&#xff0c;导致接口响应也很慢。于是我便想到要使用异步调用去处理这个方法。于是我把注解Async 加到了自己service类下的一个发邮件的一个方法&#xff0c;后来发现并没有生…

Push rejected,用Git修改已提交的注释

问题&#xff1a;有时候因注释与git规定的模板不匹配&#xff0c;会导致远程提交被拒绝 Push rejected 解决&#xff1a;修改不符合规范的注释再push即可 1、打开命令窗口 在项目根目录下右键点击出 Git批处理命令窗口。 2、查看已提交的commit 运行命令&#xff1a;git reba…

dataease源码阅读

源码&#xff1a;https://gitee.com/fit2cloud-feizhiyun/DataEase.git 文件夹目录 1.仪表盘主路由&#xff1a;frontend/src/views/panel |-- penel |-- index.vue |-- panel.js |-- appTemplate | |-- AppTemplateContent.vue | |-- index.vue | |-- component | |-- AppT…

华为OD机试真题 Java 实现【机器人活动区域】【2023Q1 200分】

一、题目描述 现有一个机器人&#xff0c;可放置于 M N的网格中任意位置&#xff0c;每个网格包含一个非负整数编号。当相邻网格的数字编号差值的绝对值小于等于 1 时&#xff0c;机器人可在网格间移动 问题&#xff1a;求机器人可活动的最大范围对应的网格点数目。 说明&a…

ESP32-C2开发板Homekit例程

准备 1.1硬件ESP32 C2开发板&#xff0c;如图1-1所示 图1-1 ESP32 C2开发板 1.2软件 CozyLife APP可以在各大应用市场搜索下载&#xff0c;也可以扫描二维码下载如图1-2所示 HomeKit flash download tool 烧录工具 esp32c2 homkit演示固件 烧录教程 打开flash_download_to…

每日一题161——对角线遍历

给你一个大小为 m x n 的矩阵 mat &#xff0c;请以对角线遍历的顺序&#xff0c;用一个数组返回这个矩阵中的所有元素。 示例 1&#xff1a; 输入&#xff1a;mat [[1,2,3],[4,5,6],[7,8,9]] 输出&#xff1a;[1,2,4,7,5,3,6,8,9] 示例 2&#xff1a; 输入&#xff1a;mat …

【大数据学习篇8】 热门品类Top10分析

在HBase命令行工具中执行“list”命令&#xff0c;查看HBase数据库中的所有数据表。学习目标/Target 掌握热门品类Top10分析实现思路 掌握如何创建Spark连接并读取数据集 掌握利用Spark获取业务数据 掌握利用Spark统计品类的行为类型 掌握利用Spark过滤品类的行为类型 掌握利用…

【嵌入式烧录刷写文件】-1.4-移动Motorola S-record(S19/SREC/mot/SX)中指定地址范围内的数据

案例背景&#xff08;共5页精讲&#xff09;&#xff1a; 有如下一段S19文件&#xff0c;将源地址范围0x9100-0x9104中数据&#xff0c;移动至一个“空的&#xff0c;未填充的”目标地址范围0xA000-0xA004。 S0110000486578766965772056312E30352EA6 S123910058595A5B5C5D5E5…

调用返回风格

主程序子程序 面向过程 单线程控制&#xff0c;把问题划分为若干个处理步骤&#xff0c;构件即为主程序和子程序&#xff0c;子程序通常可合成为模块。过程调用作为交互机制&#xff0c;即充当连接件的角色。调用关系具有层次性&#xff0c;其语义逻辑表现为主程序的正确性取…

nodejs微信小程序 vue+uniapp停车场车位管理系统sringboot+python

使用微信小程序进行应用开发&#xff0c;使用My SQL软件搭建数据库&#xff0c;管理后台数据并使用Java语言进行程序设计&#xff0c;借鉴国内现有的停车场管理系统&#xff0c;在他们的基础上进行增减和创新&#xff0c;使用Photoshop完成升降式停车场管理系统的界面部件设计&…

Python学习30:存款买房(C)

描述‪‬‪‬‪‬‪‬‪‬‮‬‪‬‫‬‪‬‪‬‪‬‪‬‪‬‮‬‪‬‭‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‮‬‪‬‪‬‪‬‪‬‪‬‮‬‭‬‫‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‪‬‪‬‪‬‪‬‪‬‪‬‮‬‭‬‫‬‪‬‪‬‪‬‪‬‪‬‮‬‫‬‪‬ 你刚刚大学毕业&#xff0c;…

龙蜥开发者说:构建软件包?不,是构建开源每一个角落!| 第 20 期

「龙蜥开发者说」第 20 期来了&#xff01;开发者与开源社区相辅相成&#xff0c;相互成就&#xff0c;这些个人在龙蜥社区的使用心得、实践总结和技术成长经历都是宝贵的&#xff0c;我们希望在这里让更多人看见技术的力量。本期故事&#xff0c;我们邀请了龙蜥社区开发者任博…

JavaWeb-Ajax的学习

Ajax 今日目标&#xff1a; 能够使用 axios 发送 ajax 请求熟悉 json 格式&#xff0c;并能使用 Fastjson 完成 java 对象和 json 串的相互转换使用 axios json 完成综合案例 概述 AJAX (Asynchronous JavaScript And XML)&#xff1a;异步的 JavaScript 和 XML。 我们先来…

LED显示屏的部件组成

LED显示屏通常由以下几个主要部件组成&#xff1a; LED模块&#xff1a;LED模块是构成LED显示屏的基本单元&#xff0c;包含多个LED发光元件以及相应的电路和连接器。LED模块通常以方形或长方形的形式存在&#xff0c;可以根据需要组合成各种尺寸和形状的显示屏。免费送你Led模…

【2023秋招】华为od-4.14三道题思路题解

2023大厂笔试模拟练习网站&#xff08;含题解&#xff09; www.codefun2000.com 最近我们一直在将收集到的各种大厂笔试的解题思路还原成题目并制作数据&#xff0c;挂载到我们的OJ上&#xff0c;供大家学习交流&#xff0c;体会笔试难度。现已录入200道互联网大厂模拟练习题&…

Go Etcd 分布式锁实战

1 分布式锁概述 谈到分布式锁&#xff0c;必然是因为单机锁无法满足要求&#xff0c;在现阶段微服务多实例部署的情况下&#xff0c;单机语言级别的锁&#xff0c;无法满足并发互斥资源的安全访问。常见的单机锁如Java的jvm锁Lock、synchronized&#xff0c;golang的Mutex等 对…

mysql8之前如何实现row_number() over(partition by xxx order by xxx asc/desc)

文章目录 背景问题分析难点解决方案&#xff1a;总结公式多字段作为分组如何处理 背景 最近笔者在进行对广告业务的数据统计时遇到这种情况&#xff0c;业务方嫌弃离线数仓太慢&#xff0c;又无需太高的实时性本该使用即席查询的OLAP去做&#xff0c;但是当前公司调研的OLAP还…

Unity 2022 版本 寻路 NavMesh

首先装包 先给地图 和 阻挡 设置为静态 然后给地上行走的地方 添加组件 可以直接bake 然后会显示蓝色的可行走路径 player 添加插件 然后给角色添加脚本 using System.Collections; using System.Collections.Generic; using UnityEngine;public class PlayerMove : Mon…

SpringBoot自动配置底层源码解析

1&#xff0c;配置分类 对于一个Spring项目&#xff0c;主要就是有两种配置 一种是类似端口号、数据库地址、用户名密码等一种是各种Bean&#xff0c;比如整合Mybatis需要配置的MapperFactoryBean&#xff0c;比如整合事务需要配置DataSourceTransactionManager SpringBoot中…