图神经网络并在 TensorFlow 中实现

news2025/1/4 14:41:26
asokraju.medium.com

一、说明

        本文将引导您了解图神经网络 (GNN) 并使用 TensorFlow 实现该网络。在后续的 文章中,我们讨论 GNN 的不同变体及其实现。这是一个分步计划:

  1. 图神经网络 (GNN) 的使用:我们首先讨论 GNN 是什么、它们如何工作以及它们的使用地点。
  2. 理解图:在深入研究 GNN 之前,了解图的基础知识非常重要,包括节点、边、邻接矩阵和图表示。
  3. 理解图神经网络:我们还将简要介绍神经网络的基础知识,因为 GNN 是神经网络的一种。
  4. 图神经网络 (GNN) 的变体
  5. 使用 TensorFlow 实现 GNN:最后,我们将介绍使用 TensorFlow 实现简单 GNN 的过程。

二、图神经网络 (GNN) 的使用

        图神经网络 (GNN) 是一种神经网络,旨在对图数据结构执行机器学习任务。它们对于数据以图形表示的任务特别有用,例如社交网络、分子结构和推荐系统。

        GNN 的工作原理是将信息从节点传播到其邻居。图中的节点根据其邻居的状态进行更新,并且此过程会重复多次迭代。然后可以使用节点的最终状态进行预测。

        例如,在社交网络中,GNN 可用于根据用户朋友的兴趣来预测用户的兴趣。 GNN 将从每个用户的一些初始表示开始,然后根据其朋友的表示更新每个用户的表示。经过几次迭代后,每个用户的最终表示不仅会捕获他们自己的兴趣,还会捕获他们的朋友、朋友的朋友等的兴趣。

三、理解图表:

        图是一种对对象之间的关系进行建模的数学结构。它由节点(也称为顶点)和组成。节点代表对象,边代表这些对象之间的关系。

        例如,在社交网络中,每个人可以由一个节点表示,每个友谊可以由连接两个节点的边表示。

        有两种主要类型的图表:

  1. 无向图:在无向图中,边没有方向。也就是说,如果存在从节点 A 到节点 B 的边,则也存在从节点 B 到节点 A 的边。 Facebook 友谊就是这样的一个示例:如果人 A 是人 B 的朋友,那么人 B 也是人与 A 是朋友。
  2. 有向图:在有向图中,边确实有方向。也就是说,如果从节点 A 到节点 B 存在一条边,并不一定意味着从节点 B 到节点 A 也存在一条边。 Twitter 关注就是一个例子:如果 A 关注了 B,那么它就会关注 B。并不意味着B跟随A。

        图可以用多种方式表示,但最常见的方式之一是通过邻接矩阵。邻接矩阵是一个方阵,其中第 i 行第 j 列中的条目等于节点 i 和 j 之间的边数。对于无向图,邻接矩阵是对称的。

        另一种常见的表示形式是边列表,其中每条边由一对节点表示。

        了解图的这些基础知识对于理解图神经网络的工作原理至关重要,因为它们直接在图结构上运行。

四、理解图神经网络

GNN 是一种神经网络,旨在对图数据结构执行机器学习任务。它们对于数据以图形表示的任务特别有用,例如社交网络、分子结构和推荐系统。

GNN 背后的关键思想是捕获图中连接之间的依赖关系。他们通过聚合相邻节点的特征来为每个节点生成嵌入来实现这一点。然后,这些嵌入可用于执行各种任务,例如节点分类、链接预测和图分类。

以下是 GNN 工作原理的更详细的分步过程:

  1. 节点特征初始化:图中的每个节点都使用特征向量进行初始化。这可能是节点标签的单热编码、特定于节点的一些实值向量,甚至是零向量。
  2. 特征聚合:每个节点聚合其邻近节点的特征向量以更新自己的特征向量。这通常是使用一个函数来完成的,该函数接收节点及其邻居的特征向量并输出一个新的特征向量。该函数可以是简单平均值、加权和或更复杂的函数。
  3. 特征变换:然后对聚合的特征向量进行变换,通常使用线性变换,然后使用非线性激活函数。这与传统神经网络层中发生的情况类似。
  4. 重复步骤 2 和 3:重复步骤 2 和 3 一定次数的迭代。在每次迭代中,节点都会聚合并转换来自越来越大邻域的特征。
  5. 读出:最终迭代后,使用读出函数聚合图中所有节点的特征向量以产生图级输出。

GNN 的优点在于它们可以处理不同大小和形状的图,并且可以捕获图的局部和全局结构。

五、使用 TensorFlow 实现 GNN

有几个构建在 TensorFlow 之上的库提供了各种类型的 GNN 的实现,例如 Graph Nets 和 Spektral。我们可以使用这些库之一来简化实现过程。

首先,您需要安装 Spektral 库。您可以使用 pip 执行此操作:

pip install spektral

安装 Spektral 后,您可以首先导入必要的库:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout
from spektral.layers import GCNConv, global_sum_pool
from spektral.data import DisjointLoader, Dataset
from spektral.datasets import TUDataset

在此示例中,我们将使用 TUDataset,它是用于图分类的基准数据集的集合。

接下来,让我们加载数据集:

dataset = TUDataset('PROTEINS')

这将下载 PROTEINS 数据集,这是蛋白质结构的图形分类数据集。

  1. 读出:在最后一层之后,使用读出函数聚合图中所有节点的特征向量以产生图级输出。

现在,让我们看看如何使用 TensorFlow 中的 Spektral 库实现一个简单的 GraphSAGE 模型:

import spektral
from spektral.layers import GraphSageConv
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout, Dense

# Define the model
class GraphSageModel(Model):
    def __init__(self, n_hidden, n_labels):
        super().__init__()
        self.sage_conv1 = GraphSageConv(n_hidden)
        self.sage_conv2 = GraphSageConv(n_labels)
        self.dropout = Dropout(0.5)
        self.dense = Dense(n_labels, 'softmax')

    def call(self, inputs, training=False):
        x, a = inputs
        x = self.dropout(x, training=training)
        x = self.sage_conv1([x, a])
        x = self.sage_conv2([x, a])
        return self.dense(x)

# Instantiate the model
model = GraphSageModel(n_hidden=64, n_labels=dataset.n_labels)

该模型将由其节点特征表示的图作为输入x、邻接矩阵a和批次索引i.该模型首先对节点特征应用 dropout,然后应用两个图卷积层,将节点特征池化为图级表示,最后应用密集层来预测每个图的类别。

接下来,让我们编译并训练我们的模型:

model = GNN(n_hidden=64, n_labels=dataset.n_labels)
model.compile('adam', 'categorical_crossentropy', ['acc'])
loader = DisjointLoader(dataset, batch_size=32, epochs=10)
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch)

什么是global_sum_pool represent?

在图神经网络(GNN)的背景下,池化是一种用于将整个图的信息聚合成单个向量表示的技术。这对于图级预测任务特别有用,我们想要对整个图(而不是单个节点或边)进行预测。

global_sum_pool是 Spektral 库提供的一种此类池化操作。顾名思义,它只是将图中所有节点的特征向量相加以生成单个向量。此操作对于图中节点的顺序是不变的,这对于许多基于图的任务来说是一个重要属性。

值得注意的是,求和池化是一种非常简单的池化操作,GNN 中还可以使用许多其他更复杂的池化操作,例如均值池化、最大池化以及更复杂的方法,例如图注意力池化和图同构池化。池化操作的选择会对 GNN 的性能产生重大影响,而最佳选择通常取决于具体的任务和数据。

i 表示 x = self.pool(x, i) 是什么?

函数调用中的i表示每个节点的批次索引。global_sum_pool(x, i)

当您在批量设置中处理图形数据(即单个批次中的多个图形)时,您需要一种方法来指示哪些节点属于哪些图形。这是因为与图像或文本数据不同,批次中的图可以具有不同的大小(即不同数量的节点和边),因此不能简单地将它们堆叠在单个张量中。

批次索引i 是一个向量,它将每个节点分配给批次中的特定图。例如,如果批次中有两个图表,第一个有 3 个节点,第二个有 2 个节点,则批次索引 i 将为 [0, 0, 0, 1, 1]。这表明前三个节点属于第一个图,最后两个节点属于第二个图。

在后续文章中,我们讨论 GNN 的不同变体及其实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1327630.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文阅读<MULTISCALE DOMAIN ADAPTIVE YOLO FOR CROSS-DOMAIN OBJECT DETECTION>

论文链接:https://arxiv.org/pdf/2106.01483v2.pdfhttps://arxiv.org/pdf/2106.01483v2.pdf 代码链接:GitHub - Mazin-Hnewa/MS-DAYOLO: Multiscale Domain Adaptive YOLO for Cross-Domain Object DetectionMultiscale Domain Adaptive YOLO for Cross…

[JS设计模式]Command Pattern

文章目录 举例说明优点缺点完整代码 With the Command Pattern, we can decouple objects that execute a certain task from the object that calls the method. 使用命令模式,我们可以将执行特定任务的对象与调用该方法的对象解耦。 怎么理解 执行特定任务的对…

【SpringBoot】之Security进阶使用

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是君易--鑨,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的博客专栏《SpringBoot开发之Security系列》。&#x1f3af…

解决 Linux git push 贡献者不同(没有出现绿点)的问题

第一步,通过下面的指令,修改 linux git 的配置文件: vi ~/.gitconfig会进入下图界面: 进入本地(Windows)中 git 的设置界面 复制 名称 和 Email 到 gitconfig 里,不要在末尾加 (空…

直排轮滑教程4

蹬地 1,前面练习了蹬地的结构,知道蹬地方向,如何用力。下面来练习具体的蹬地的方法,轮滑蹬地有自己特点。 2,技术方法和特点:蹬地速度快,蹬地有弹性。似跳非跳蹬。 3,四轮着地。轮…

使用PE信息查看工具和Beyond Compare文件比较工具排查dll文件版本不对的问题

目录 1、问题说明 2、修改了代码,但安装版本还是有问题 3、使用PE信息查看工具查看音视频库文件(二进制)的时间戳 4、使用Beyond Compare比较两个库文件的差异 5、找到原因 6、最后 C软件异常排查从入门到精通系列教程(专栏…

NFS原理详解

一、NFS介绍 1)什么是NFS 它的主要功能是通过网络让不同的机器系统之间可以彼此共享文件和目录。 NFS服务器可以允许NFS客户端将远端NFS服务器端的共享目录挂载到本地的NFS客户端中。 在本地的NFS客户端的机器看来,NFS服务器端共享的目录就好像自己的磁…

【蓝桥杯】树的重心

树的重心 图的dfs模板 int dfs(int u) {st[u]true;for(int ih[u];i!-1;ine[i]){int je[i];if(!st[j]){dfs(j);}} }树是这样的。 邻接表: 1: 4->7->2->-1 2: 5->8->1->-1 3: 9->4->-1 4: 6->3->1->-1 5: 2->-1 6: 4->-1 7…

计算机网络 运输层下 | TCP概述 可靠传输 流量控制 拥塞控制 连接管理

文章目录 3 运输层主要协议 TCP 概述3.1 TCP概述 特点3.2 TCP连接RSVP资源预留协议 4 TCP可靠传输4.1 可靠传输工作原理4.1.1 停止等待协议4.1.2 连续ARQ协议 4.2 TCP可靠通信的具体实现4.2.1 以字节为单位的滑动窗口4.2.2 超时重传时间的选择4.2.3 选择确认SACK 5 TCP的流量控…

Python---socket之send和recv原理剖析

1. 认识TCP socket的发送和接收缓冲区 当创建一个TCP socket对象的时候会有一个发送缓冲区和一个接收缓冲区,这个发送和接收缓冲区指的就是内存中的一片空间。 2. send原理剖析 send是不是直接把数据发给服务端? 不是,要想发数据,必须得…

GEE-Sentinel-2月度时间序列数据合成并导出

系列文章目录 第一章:时间序列数据合成 文章目录 系列文章目录前言时间序列数据合成总结 前言 利用每个月可获取植被指数数据取均值,合成月度平均植被指数,然后将12个月中的数据合成一个12波段的时间数据合成数据。 时间序列数据合成 代码…

嵌入式中断理解

一、概念 中断: 在主程序运行过程中,出现了特定的中断触发条件(中断源),使得CPU暂停当前正在运行的程序,转而去处理中断程序,处理完成后又返回原来被暂停的位置继续运行。 中断优先级&#x…

YACS(上海计算机学会竞赛平台)一星级题集——水仙花指数

题目描述 定义一个正整数的十进制表示中各位数字的立方和为它的水仙花指数,给定一个整数 n,请计算它的水仙花指数。 例如 n1234 时,水仙花指数为 输入格式 单个整数:表示 n 输出格式 单个整数:表示 n 的水仙花指…

IPC之九:使用UNIX Domain Socket进行进程间通信的实例

socket 编程是一种用于网络通信的编程方式,在 socket 的协议族中除了常用的 AF_INET、AF_RAW、AF_NETLINK等以外,还有一个专门用于 IPC 的协议族 AF_UNIX,IPC 是 Linux 编程中一个重要的概念,常用的 IPC 方式有管道、消息队列、共…

深入探讨开源对话系统:IntelliQ的世界

在人工智能的快速发展时代,开源项目成为了推动技术革新的重要力量。最近,我注意到了一个特别有趣的项目——IntelliQ。这个项目旨在利用大型语言模型(LLM)构建一个多轮问答系统,不仅具备强大的意图识别和词槽填充&…

【Unity】【WebRTC】如何用Unity而不是浏览器接收远程画面

【背景】 之前几篇我们讨论了如何设置信令服务器,如何发送画面给远端以及如何用浏览器查看同步画面,今天来讨论如何实现Unity内部接收画面。 看本篇之前请先看过之前将web服务器设置和基本远程画面功能的几篇博文。(同专栏下查看&#xff09…

使用VSC从零开始Vue.js——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项——任务3:数据可视化

使用Visual Studio Code(VSC)进行Vue开发非常方便,下面是一些基本步骤: 一、下载和安装Vue 官网下载地址Download | Node.js Vue.js是基于Node.js的,所以首先需要安装Node.js,官网下载地址:No…

PMP项目管理 - 范围管理

系列文章目录 系统架构设计 PMP项目管理 - 整合管理 PMP项目管理 - 质量管理 PMP项目管理 - 采购管理 PMP项目管理 - 资源管理 PMP项目管理 - 风险管理 PMP项目管理 - 沟通管理 现在的一切都是为将来的梦想编织翅膀,让梦想在现实中展翅高飞。 Now everything is …

快速搭建Grafana Promethus 服务器监控系统

该文参考文章,其中又遇到一些问题,并解决,当前主要为了记录一下 探针 Grafana Prometheus 之比 Docker 更简单的部署流程 - 承飞之咎本文重在 Grafana Prometheus 探针 方案的部署流程,介绍和更多使用请到:探针 ̵……

【12.20】转行小白历险记 登录+注册页

一、登录注册页面逻辑 写样式布局:垂直居中、编程式路由、调后端接口正则表达式验证用户输入的密码规则校验通过后,跳转页面js兜底校验调后端接口将token值存储到vuex中,实现持久化存储 vuex不是持久化存储的,如果需要持久化存储…