Inductive Representation Learning on Large Graphs 论文/GraphSAGE学习笔记

news2024/9/21 0:37:32

1 动机

1.1 过去的方法

现存的方法大多是transductive的,也就是说,在训练图的时候需要将整个图都作为输入,为图上全部节点生成嵌入,每个节点在训练的过程中都是可知的。举个例子,上一次我学习了GCN模型,它的前向传播表达式为:

H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) H^{(l+1)}=σ(\widetilde D^{- \frac{1}{2}} \widetilde A \widetilde D^{- \frac{1}{2}} H^{(l)} W^{(l)} ) H(l+1)=σ(D 21A D 21H(l)W(l))

可以看出,对GCN的训练需要将整个图的邻接矩阵作为输入,这不利于大图的训练,因为电脑的内存可能无法支持如此巨大的输入。同时,也没有办法对图进行很好的切割,不利于分布式训练。

并且现实中很多应用的数据都会不断地变化更新,采用这种transductive的训练方式对于新增节点的情况需要进行重新训练,这增大了计算开销。

1.2 GraphSAGE

为了解决这个问题,本文的作者们提出了inductive的方法—GraphSAGE。该方法不需要将整图输入来为图中所有节点生成嵌入,而是通过对节点的领域里的邻居进行采样和聚合的方式来为独立的节点生成嵌入。因此,GraphSAGE能更好地应对unseen节点,不需要对模型重新训练。

请添加图片描述

2 流程

2.1 算法1:前向传播

算法思想:在每一层,每个节点从自己的领域聚合n个邻居的信息,然后将聚合的信息和自身信息进行加权连接并乘上非线性激活函数。随着层的增加,节点能聚合到的邻居阶数也会增加。

算法的流程如下图所示:

请添加图片描述

  • N ( v ) N(v) N(v)是从集合 { u ∈ V : ( u , v ) ∈ E } \{ u \in V : (u,v)\in \mathcal{E} \} {uV:(u,v)E}中用统一抽样的方法抽取固定个数的节点

总结一下,GraphSAGE的前向传播流程可以分为以下三步:

  1. Sample : 通过特定的方法从节点的邻居抽取固定个数的邻居
  2. Aggregate :通过特定的方法聚合抽取出来的邻居的信息
  3. Concat : 将聚合后的信息加上自身的信息从而更新节点的特征值

灵感来源:WL算法(计算图同构的算法,可以比较两个图的相似性),将WL算法种的哈希函数变成了可训练的神经网络聚合器

定理1:对于任何图,如果每个节点的特征不同(并且模型足够高维),算法 1 都存在一个参数设置使得它可以将该图中的聚类系数逼近到任意精度

2.2 采样器 Sampler

采样器的作用是选取固定个数的节点邻居,从而保持每个batch的大小固定。在本文中,作者固定大小为K,其中,对于不足邻居个数少于S的节点,则全部采样。

具体算法:

如果邻居个数小于采样数

  • sample全部邻居

如果邻居个数大于采样数

  • 如果总邻居的数量小于设定值(本论文中为21
    • 则每次在 0~n-i 范围内抽取其中一个邻居 j ,然后把将该选择的位置 j 上的邻居变为 n-i-1 的位置上的邻居,i-1 后开始下一次选择
  • 如果总邻居的数量大于设定值(本论文中为21
    • 则设立一个select_add列表存储已选择的邻居下标信息,记录选择的邻居已经在select_add列表中存在,则重新随机sample一个邻居

2.3 聚合器 Aggregator

聚合器的作用是聚合邻居信息,在本文中会对无序的数组集合(也就是节点的邻居集合)进行操作。

理想情况下,聚合函数在可训练并且能够保持强表达能力的同时还要是对称的。聚合函数的对称性确保我们的神经网络模型可以被训练并应用于任意排序的节点邻域特征集。

作者总共设计了3种聚合邻居信息的方式,分别是:

Mean aggregator

这个方法将传统的transductive GCN的传播规则变成了inductive的方式,用以下的公式来代替聚合更新的过程(没有concatenation操作):

h v k ← σ ( W ⋅ M E A N ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) ) h^k_v \leftarrow \sigma (W \cdot MEAN( \{ h_v^{k-1} \} \cup \{ h_u^{k-1} , \forall u \in \mathcal{N}(v) \} )) hvkσ(WMEAN({hvk1}{huk1,uN(v)}))

LSTM aggregator

LSTM相比Mean方法,有着更好的表达能力,但不对称。

由于LSTM需要输入是有序的,作者将节点的邻居顺序随机打乱作为输入。

Pooling aggregator

Pooling既有对称性又是可训练的,作者在本文种选择了最大池化的方法,也就是说,在聚合的时候,只选择计算值最大的邻居作为最终聚合的信息,其公式为:

A G G R E G A T E k p o o l = m a x ( { σ ( W p o o l h u i k + b ) , ∀ u i ∈ N ( v ) } ) AGGREGATE_k^{pool} = max(\{ \sigma (W_{pool} h_{u_i}^k +b), \forall u_i \in \mathcal{N}(v) \}) AGGREGATEkpool=max({σ(Wpoolhuik+b),uiN(v)})

其中,作者没有选择平均池的原因是作者发现平均池和最大池方法的差距不大。

2.4 更新 Concat

if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

源码中的连接方式非常直接,将邻居信息连接到自身信息后面。

2.5 损失函数

无监督

J G ( z u ) = − l o g ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) l o g ( σ ( − z u T z v n ) ) J \mathcal{G} (z_u) = - log(\sigma (z_u^T z_v)) - Q \cdot E_{v_n \sim P_n(v)}log(\sigma (-z_u^T z_{v_n})) JG(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzvn))

  • v v v 是同时出现在节点 u 附件的固定随机游走长度的节点
  • σ \sigma σ 是sigmoid函数
  • P n P_n Pn 是负采样分布
  • Q Q Q 是负采样数量
  • z u z_u zu 是节点u的特征,由节点u的邻居的特征得到

该基于图的损失函数鼓励相近的节点拥有相似的表征,而相离的节点拥有不同的表征

有监督

交叉熵损失

3 实验

3.1 实验设置

4个baseline:

  1. 随机分类器 (Random)
  2. 基于特征的逻辑回归分类器(忽略图结构)(Raw features)
  3. DeepWalk算法(作为基于分解的代表方法)
  4. 结合原始特征和DeepWalk嵌入的方法 (DeepWalk + features)

超参数设置:

  • 网络层数: K = 2 K=2 K=2
    • 理由:选择K=2相比k=1可以提高10-15%的准确率,但是训练时长会提高10-100倍(取决于采样个数)
  • 采样个数: S 1 = 25 , S 2 = 10 S_1=25,S_2=10 S1=25,S2=10
  • Batch size:512

三个实验,每个实验都会进行有监督和无监督训练进行对比

实验一:在一个大型引文数据集(Citation)上预测论文类别

  • 数据集:Thomson Reuters Web of Science Core Collection中2000-2005的生物领域论文
  • 图类型:无向图,进化图(数据会不断更新,也就是说,会产生很多unseen节点)
  • 类别数:6
  • 节点数:302424
  • 平均度数:9.15
  • 训练集:2000-2004年论文
  • 测试集:2005年论文(30%为验证集,用于调整超参数)

实验二:预测不同Reddit帖子所属的社区

  • 数据集:作者对2014.09发布的贴子建立了图数据集,节点标签为社区
  • 图类型:进化图
  • 节点(帖子)个数:232965
  • 类别(社区)数:50
  • 平均度数:492
  • 训练集:前20天的数据
  • 测试集:后20天的数据(30%为验证集)

实验三:总结多种PPI(生物蛋白质-蛋白质作用)图(每个图对应不同的人体组织),根据基因本体的细胞功能来为蛋白质的功能分类

  • 数据集:Molecular Signatures Database
    • 特征:positional gene sets, motif gene sets and immunological signatures
    • 标签:gene ontology sets
  • 类别数:121
  • 节点数:2373
  • 平均度数:28.8
  • 图数量:20
  • 测试集:2个图(另选2个图作为验证集)

3.2 实验结果

请添加图片描述

总体而言,基于LSTM和Pool的聚合器在平均表现和最佳表现次数上都是最好的。

4 问题

4.1 Mean aggregator

疑问来源:作者说Mean aggregator是对GCN的修改,将transductive变成了inductive?但是从源码上看,作者只是简单地对采样得到的邻居信息进行加权平均的操作。

解答:作者这里可能只是用到了卷积的思想,也就是AWX中的W卷积核。

4.2 采样器的设计

疑问来源:在运行GraphSAGE进行分类任务时,发现相同设置下的运行结果相差还是比较大的,在分类准确率上大约会有1%-5%的误差。这种分类不稳定性可能是由采样器的设计引起的。

解答:可以改变采样器的设计,比如按度来排序进行更有代表性的抽样,从而使结果更稳定。

4.3 聚合函数的对称性

疑问来源:作者谈到,理想的聚合函数需要在可训练、有强表达能力的同时具有对称性,这是因为聚合函数的对称性确保我们的神经网络模型可以被训练并应用于任意排序的节点邻域特征集。为什么对称性能够确保上述情况?

解答:对称性指的是对于输入的K个邻居,不同的顺序不会影响最终的结果。

4.4 图的改变

疑问来源:我们的理解为,GraphSAGE中每个batch存放了图中n个节点sample到的K个邻居信息,从而可以分为多个minibatch来进行聚合更新的计算。但是在看源码时,发现输入为整图的邻接矩阵,并通过邻接矩阵来得到每个节点的邻居。那么当图的结构改变时,或者加入不可见的结点时,是不是又要重新输入整图的邻接矩阵,还是说只需要输入新增节点及其邻居信息即可?

解答:接下来我们会看相关部分的源码来理解作者的做法。

4.5 Concat维度的问题

疑问来源:由于作者在进行concat的时候直接进行连接的操作,那么每一次concat都会使原有数据的维度变为两倍,是如何进行降维的?

output = tf.concat([from_self, from_neighs], axis=1)

解答:

第一层:定义权重矩阵为128 by 1433*2。concat后的数据为n by 1433 *2,点乘后得到 128 by n的矩阵,达成降维。

enc1 = Encoder(features, 1433, 128, adj_lists, agg1, *gcn*=True, *cuda*=False)

第二层:定义权重矩阵为128 by 128,再次达到降维。

enc2 = Encoder(lambda nodes : enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
            base_model=enc1, gcn=True, cuda=False)

4.6 权值矩阵问题

疑问来源:看论文时,思路还是比较清晰的,总共有3个地方可以进行权重的训练:1 聚合器中的权重矩阵;2 连接后用于降维的权重矩阵;3 用于分类的权重矩阵。但是在看源码的时候,对GraphSAGE训练了哪些权重矩阵产生了疑惑

解答:对于MEAN方法,除去用于分类的权重矩阵,总共有2个权重矩阵,分别是2层神经网络的GCN公式权重矩阵,而对于其他聚合方法,聚合器的权重矩阵只有一个,两层神经网络又分别各有一个用于降维的连接权重矩阵。

请添加图片描述

请添加图片描述

4.7 GraphSAGE 和 GCN的本质区别

疑问来源:来自于GCN作者的留言(如下

请添加图片描述

解答:说GCN和GraphSAGE最大的区别在于采样的方式其实是没有问题的。以minibatch为例,GCN可以在每个batch中存放含有固定个数节点的子图的邻接矩阵,这样同样可以保证batch size的一致,但采样得到的邻居个数在这种情况下是不固定的,在子图中有多有少。而GraphSAGE则尽量固定了采样的邻居个数,对于邻居个数大于K的节点,则采样K个邻居。按上述的思想,GraphSAGE同样可以推广到inductive,让新增的unseen节点加入所在的含有n个节点的子图进行计算,同样可以得到新增节点的特征。

但是,我认为其本质区别还是训练的对象不同,GCN是为整个图上所有节点生成嵌入,也就是训练得到的函数是对全图而言的。而GraphSAGE则是为单个节点生成嵌入,训练得到的函数是对单个节点而言,聚合邻居并连接自身信息的函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/58284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

考研数据结构大题整合_组三(LZH组)

考研数据结构大题整合 目录考研数据结构大题整合三、LZH组LZH 组一LZH 组二LZH 组三LZH 组四LZH 组五LZH 组七三、LZH组 LZH 组一 给出如图所示的无向图G的邻接矩阵和邻接表两种存储结构. (2)解答下面的问题(6分) (…

二、进程管理(五)死锁

目录 5.1死锁的定义和产生条件 5.2死锁的处理策略 5.2.1死锁预防 5.2.2死锁避免 5.2.3死锁检测和解除 5.1死锁的定义和产生条件 在并发环境下,各进程因竞争资源而造成的一种互相等待对方手里的资源,导致各进程都阻塞,都无法向前推进的现…

【GlobalMapper精品教程】034:创建漫游动画并制作漫游视频的方法

本实例讲解在globalmapper中根据路径创建漫游动画,并制作漫游视频的方法。 文章目录 一、绘制漫游路径二、创建3D虚拟漫游三、播放虚拟漫游四、保存虚拟漫游实验数据可以是点云数据、DEM、三维模型等,本文加载数字表面模型DSM进行演示。 一、绘制漫游路径 同创建矢量线状数据…

海带软件分享——日常办公学习软件分享(收藏)

>>>深度学习Tricks&#xff0c;第一时间送达<<< &#x1f680; 写在前面 &#x1f431;‍&#x1f3cd; 本期开始&#xff0c;小海带会定期推荐一些日常办公学习软件及趣味网址&#xff0c;供大家交流参考 ~ 小伙伴们记得一键三连喔&#xff01;&#x1f6…

几款好用到爆炸的在线画图工具

前言 实际工作中&#xff0c;我们经常会编写文档以及制作图表。尤其是对一名优秀的攻城狮来说&#xff0c;经常会用各种各样的软件来制作流程、思维导图、思维笔记等。一个良好的思维导图能系统概括项目工程的整体结构和开发的系统框架。要想制作一个完美的流程图、思维导图离不…

菜狗杯Misc一层一层一层地剥开我的♥wp

目录一、原题二、解题步骤对jpg图片的处理对文件名是一个心形的数据文件的处理base100解码这题完全是看着官方wp复现的&#xff0c;感觉涉及的步骤比较多但每一步本身不难&#xff0c;多记录一遍加深印象。 一、原题 原题给的是一个叫myheart.zip的文件&#xff0c;但尝试解压…

高通开发系列 - ALSA声卡驱动中音频通路kcontrol控件

By: fulinux E-mail: fulinux@sina.com Blog: https://blog.csdn.net/fulinus 喜欢的盆友欢迎点赞和订阅! 你的喜欢就是我写作的动力! 目录 高通开发系列 - ALSA声卡驱动中音频通路kcontrol控件问题背景高通音频通路如何建立widget和routemixer类控件名组合mixer类控件名拼接…

CRM客户关系管理系统(含源码+论文+答辩PPT等)

该项目采用技术&#xff1a;JSP Servlet MySQLjdbccssjs等相关技术&#xff0c;项目含有源码、文档、配套开发软件、软件安装教程、项目发布教程等 项目功能介绍&#xff1a; 系统管理&#xff1a;用户登录退出、个人资料修改 客户管理&#xff1a;客户信息管理、客户来源、联系…

Softmax回归——动手学深度学习笔记

Softmax回归&#xff0c;虽然它的名称叫做回归&#xff0c;其实它是一个分类问题。 回归VS分类 回归估计一个连续值 如&#xff1a;回归估计下个月的房价 分类预测一个离散类别 如&#xff1a; &#xff08;1&#xff09;MNIST&#xff1a;手写数字识别&#xff08;10类&…

初识springmvc

狂神的servlet回顾就不在这里写了。可以翻之前的笔记。 原生开发&#xff1a; 创建webapp的maven项目。 也就是四个文件 &#xff08;不用思考里面的代码&#xff0c;直接CV先走一遍流程&#xff09; HelloController&#xff1a; package com.Li.controller;import org.sp…

SecureCRT之Xmodem操作步骤

以锐捷S3760为例&#xff1a; 故障现象&#xff1a;s3760无法加载&#xff0c;需要重刷RGOS。 一、使用控制线连接s3760&#xff0c;开机加载引导&#xff0c;按Ctrl_B进入“BOOT MENU”页面&#xff1a; 选择【0】进入XModem操作界面&#xff1a; 说明&#xff1a; 0--更新…

android 开发——疑难杂症ANR简单介绍与解析

一、ANR介绍 ANR-application not response&#xff0c;应用无响应&#xff0c;应用开发者一般是关注自己的APP进程有没有出现&#xff0c;系统开发者会关注当前系统运行起来后整体上所有的APP进程有没有出现ANR&#xff0c;从这句话可以知道&#xff0c;只有应用进程的主线程…

多元正态分布-参数估计-书后习题回顾总结

重点考察知识点汇总 协方差矩阵 协方差矩阵为对称矩阵协方差矩阵的对角线为各分量的方差&#xff0c;其余位置(i,j)(i,j)(i,j)表示的是分量iii和分量jjj的协方差 多元正态分布的线性组合仍然服从多元正态分布 设X∼Np(μ,Σ)X\sim N_{p}(\mu,Σ)X∼Np​(μ,Σ)&#xff0c;B…

Python从零到就业

Python面向对象编程五步曲基础阶段01-Python的注释02 乱码原因02-Python变量03-数据类型04-数据类型转换05-Python运算符06-Python输入07-Python输出08-占位格式符(补充)09-Python分支10-Python循环(while)11-Python循环(for)12-Python循环打断待更新基础阶段 01-Python的注释…

C++输出四舍五入的一些小问题

嗯…今天刚去练了一会简单题 就我大一刚上学做的那种题&#xff0c;嗯&#xff0c;然后我发现我还是得调试&#xff0c;想骂人了&#xff0c;就啥样的题呢, 嗯,就这样的题&#xff0c;虽然我大一可能也过不了这种题&#xff0c;hh 现在题目里面要求一些四舍五入的问题 刚才没整…

网站都变成灰色的了,代码是怎么实现的呢?

今天来聊一聊页面的滤镜&#xff0c;或者说换肤更合适些。根据技术栈不同&#xff0c;页面换肤可以分为 web 端和 app 端&#xff0c;因此本文通过以下两部分介绍 PC 端 APP 端 一、PC 端 有关 PC 端的一键换肤&#xff0c;这个操作常用&#xff0c;所以大概率是有某个全局字…

Spring框架(三):SpringAop思想底层实现和日志应用(一):Spring代理实现

Spring框架&#xff08;三&#xff09;&#xff1a;SpringAop思想底层实现和应用&#xff08;日志&#xff09;引子Aop简介通过SpringBean实现Aop引子 痛定思痛&#xff0c;主要问题出现在自己雀氏不熟悉框架基础、一些面试题&#xff0c;以及sql的使用淡忘了。 本章节的开始是…

[附源码]计算机毕业设计springboot学生综合数据分析系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

java之NIO编程

NIO介绍 前面介绍了BIO为阻塞IO,其阻塞表现在两个方面:服务端等待客户端连接时的阻塞以及连接后没有发生数据传输时的阻塞。NIO是非阻塞IO,那么NIO是如何非阻塞的呢&#xff1f;带着这个疑问&#xff0c;开始研究NIO。 NIO有三大组件:Selector 选择器、Channel 管道、buffer 缓…

【网络层】MTU、IP数据报分片、IP详解、NAT

注&#xff1a;最后有面试挑战&#xff0c;看看自己掌握了吗 文章目录最大传送单元MTU--------以太网MTU是1500BIP数据报分片-------标识字段----同一数据报分片采用同一标识标志字段-----------只有两位有意义-------------中间为DF------dont fragment 不许分片--------DF1禁…