DGL在异构图上的GraphConv模块

news2024/12/25 10:30:41

回顾同构图GraphConv模块

首先回顾一下同构图中实现GraphConv的主要思路(以GraphSAGE为例):
在初始化模块首先是获取源节点和目标节点的输入维度,同时获取输出的特征维度。根据SAGE论文提出的三种聚合操作,需要获取所使用的聚合类型,方便后面使用Pytorch中的nn模块实现。最后是特征归一化操作。
其具体的代码段为:

获取相关输入特征

        # 获取源节点和目标节点的输入特征维度
        self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)
        # 输出特征维度
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation

根据聚合类型选择Pytorch对应的nn模块中的函数

        # 聚合类型:mean、pool、lstm、gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)

权重初始化

构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2

forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:

  1. 检测输入图对象是否符合规范。
  2. 消息传递和聚合
  3. 聚合后,更新特征作为输出。

检测输入图对象的规范性

# 输入图对象的规范检测
with graph.local_scope():
    # 指定图类型,然后根据图类型扩展输入特征
    feat_src, feat_dst = expand_as_pair(feat, graph)

对于expand_as_pair()函数,其实现的操作是如果输入的特征不是一对的话(源节点和目标节点),就根据图Graph将特征变成一对,但要求图必须是一个block,其对应的源码为:

def expand_as_pair(input_, g=None):
    """Return a pair of same element if the input is not a pair.
    如果输入不是一对,则返回相同元素的一对。


    If the graph is a block, obtain the feature of destination nodes from the source nodes.
    如果图是块,则从源节点中获取目的节点的特征。
    Parameters
    ----------
    input_ : Tensor, dict[str, Tensor], or their pairs
        The input features
    g : DGLGraph or None
        The graph.

        If None, skip checking if the graph is a block.

    Returns
    -------
    tuple[Tensor, Tensor] or tuple[dict[str, Tensor], dict[str, Tensor]]
        The features for input and output nodes
        输入和输出节点的特性
    """
    if isinstance(input_, tuple):
        return input_
    elif g is not None and g.is_block:
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()
            }
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        return input_, input_

消息传递和聚合

聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() APIDGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。

        # 消息传递和聚合
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # 除以入度
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

如果是gcn聚合方式的话还需要用到它自身的特征,但是SAGE不需要,它只需要聚合邻居的特征,这里通过一条判断语句加以区分:

        # GraphSAGE中gcn聚合不需要fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

更新特征

聚合后,更新特征作为输出——forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

        # 更新特征作为输出
        # 激活函数
        if self.activation is not None:
            rst = self.activation(rst)
        # 归一化
        if self.norm is not None:
            rst = self.norm(rst)
        return rst

异构图GraphConv模块

DGL提供了 HeteroGraphConv,用于定义异构图上GNN模块。 实现逻辑与消息传递级别的API multi_update_all() 相同,它包括:

  • 每个关系上的DGL NN模块。
  • 聚合来自不同关系上的结果。
    其对应的数学公式为:(r表示关系)

在这里插入图片描述

__ init __函数

异构图的卷积操作接受一个字典类型参数 mods。这个字典的键为关系名,值为作用在该关系上NN模块对象。参数 aggregate 则指定了如何聚合来自不同关系的结果。

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            # 获取聚合函数的内部函数
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

nn.ModuleDict() 用于保存字典中的子模块。Pytorch官方也给出了对应的示例:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

forward函数

对于前向传播函数,除了需要输入图和输入张量以外,它还需要2个额外的字典参数mod_argsmod_kwargs。这2个字典与 self.mods 具有相同的键,值则为对应NN模块自定义参数
forward() 函数的输出结果也是一个字典类型的对象。其键为 nty,其值为每个目标节点类型 nty 的输出张量的列表, 表示来自不同关系的计算结果HeteroGraphConv 会对这个列表进一步聚合,并将结果返回给用户。聚合操作主要是:

if g.is_block:
    src_inputs = inputs
    dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
    src_inputs = dst_inputs = inputs

for stype, etype, dtype in g.canonical_etypes:
    rel_graph = g[stype, etype, dtype]
    if rel_graph.num_edges() == 0:
        continue
    if stype not in src_inputs or dtype not in dst_inputs:
        continue
    dstdata = self.mods[etype](
        rel_graph,
        (src_inputs[stype], dst_inputs[dtype]),
        *mod_args.get(etype, ()),
        **mod_kwargs.get(etype, {}))
    outputs[dtype].append(dstdata)

输入 g 可以是异构图或来自异构图的子图区块。和普通的NN模块一样,forward() 函数需要分别处理不同的输入图类型

上述代码中的for循环为处理异构图计算的主要逻辑

  • 首先我们遍历图中所有的关系(通过调用 canonical_etypes)。
  • 通过关系名,我们可以使用g[ stype, etype, dtype ]的语法将只包含该关系的子图( rel_graph )抽取出来。
  • 对于二分图,输入特征将被组织为元组 (src_inputs[stype], dst_inputs[dtype])
  • 接着调用用户预先注册在该关系上的NN模块,并将结果保存在outputs字典中。

最后,HeteroGraphConv 会调用用户注册的 self.agg_fn 函数聚合来自多个关系的结果。

rsts = {}
for nty, alist in outputs.items():
    if len(alist) != 0:
        rsts[nty] = self.agg_fn(alist, nty)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1249452.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023亚太杯数学建模B题完整原创论文讲解

大家好呀,从发布赛题一直到现在,总算完成了2023亚太地区数学建模竞赛B题玻璃温室的微气候调控完整的成品论文。 本论文可以保证原创,保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃圾半成品论文。 论文共6…

ZC-OFDM模糊函数原理及仿真

文章目录 前言一、ZC 序列二、ZC-OFDM 信号1、OFDM 信号表达式2、模糊函数表达式三、MATLAB 仿真1、MATLAB 核心源码2、仿真结果①、ZC-OFDM 模糊函数②、ZC-OFDM 距离分辨率③、ZC-OFDM 速度分辨率前言 本文进行 ZC-OFDM 的原理讲解及仿真,首先看一下 ZC-OFDM 的模糊函数仿真…

roseha for windows 11+oracle 11g部署过程

文章目录 一、环境准备关闭防火墙配置hosts共享存储准备 二、部署步骤1.主机A、B安装数据库软件2.主机A进行数据库实例创建3.主机B创建数据库4.安装配置roseha软件 一、环境准备 windows server 2019 oracle 11.2.0.3 EE roseha for windows 11 5个IP地址:2心跳、3…

元宇宙vr线上展馆在线制作降低开发门槛和成本

让人人都拥有自己的元宇宙空间,说起来就是一个令人亢奋的消息,也是大家所期待的,VR元宇宙空间在线编辑平台是VRARAI元宇宙公司深圳华锐视点自主研发的平台,允许用户在虚拟环境中创建、设计和共享空间,操作简单&#xf…

Ubuntu20.04上编译安装TVM

本文主要讲述如何在ubuntu20.04平台上编译TVM代码并在python中import tvm成功。 源代码下载: git clone --recursive https://github.com/apache/tvm tvm 平台环境升级: 1) sudo apt-get update 2) sudo apt-get install -y pyth…

RK3588平台 USB框架与USB识别流程

一.USB的基本概念 在最初的标准里,USB接头有4条线:电源,D-,D,地线。我们暂且把这样的叫做标准的USB接头吧。后来OTG出现了,又增加了miniUSB接头。而miniUSB接头则有5条线,多了一条ID线,用来标识身份用的。 热插拔&am…

【信息隐藏】信息隐藏基础

00 学习资源 0.1 推荐书籍 1.多媒体安全基础导论 复旦大学出版社 蓝皮; 2.隐写学原理与技术(赵险峰)科学出版社 蓝皮 0.2 视频课程 南开大学-信息隐藏技术(没看) 0.3 代码资源 GitHub一位phd:https:/…

Spring Cloud 版本升级遇坑记:OpenFeignClient与Gateway的恩怨情仇

Spring Cloud 版本升级遇坑记:OpenFeignClient与Gateway的恩怨情仇 近日,在对项目中的 Spring Boot、Spring Cloud 以及 Spring Cloud Alibaba 进行版本升级时,遭遇了一个令人头疼的问题:Spring Cloud Gateway 在运行时一直卡住&a…

ES之x-pack-core-7.14.2许可证修改为白金版

X-Pack是什么 X-pack是elasticsearch的一个扩展包,将安全,警告,监视,图形和报告功能捆绑在一个易于安装的软件包中,虽然x-pack被设计为一个无缝的工作,但是你可以轻松的启用或者关闭一些功能。 主要分一下步…

W11安装mysql8详细保姆篇

一、MySQL的下载 目前官方最新版本是8.0.34,考虑到其稳定性、可靠性还需一定周期保证,所以使用官方版求稳定仍然建议5.7系列。MySQL官方下载链接:MySQL官网下载 二、MySQL的安装 1、右击下载完成的安装包 2、点击Custom >> Next 3、…

基于springboot实现冬奥会科普平台系统【项目源码+论文说明】

基于SpringBoot实现冬奥会科普平台系统演示 摘要 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理平台应运而生&…

销售为什么会选择使用电销这种方式 ?

在网络经济时代的大环境下,网络营销作为一种新型营销模式和营销理念,已经抢占了大部分市场。 网络营销,是指利用互联网技术和现代信息技术,以及社交媒体平台,进行产品宣传、销售、服务、品牌传播等活动的一种营销模式。…

如何在GO中写出准确的基准测试

一般来说,我们不应该对性能进行猜测。在编写优化时,会有许多因素可能起作用,即使我们对结果有很强的看法,测试它们很少是一个坏主意。然而,编写基准测试并不简单。很容易编写不准确的基准测试,并且基于这些…

Vue框架学习笔记——事件处理

文章目录 前文提要事件处理的解析过程样例代码如下:效果展示图片:v-on:click"响应函数"v-on:click简写形式响应函数添加响应函数传参占位符"$event"注意事项 前文提要 本人仅做个人学习记录,如有错误,请多包…

城市数字孪生优秀案例集 - 城市治理类 - 深圳市城市交通数字孪生建设

一、背景意义 “十四五”规划、《数字交通发展规划纲要》、《广东省数字经济促进条例》等提出“构建城市数据资源 体系,推进城市数据大脑建设,探索建设数字孪生城市”。 当前,我国 9 亿城市化人口每天出行约 16 亿人 次,主要大城…

图神经网络的介绍

1. 图神经网络概念 https://cloud.tencent.com/developer/article/2334518?areaId106001 https://blog.csdn.net/qq_44689178?typeblog; 先参考阅读这篇博主; 该文献中,介绍了 多视图 的 图神经网络的学习; 以及多视图 图神经网络的 对比…

基于SRGAN的人脸图像超分辨率

引言 SRGAN是第一个将GAN用在图像超分辨率上的模型。在这之前,超分辨率常用的损失是L1、L2这种像素损失,这使得模型倾向于学习到平均的结果,也就是给低分辨率图像增加“模糊的细节”。SRGAN引入GAN来解决这个问题。GAN可以生成“真实”的图像…

老牌开源 SVG 编辑器 SVGEdit 是如何架构的?

大家好,我是前端西瓜哥。这次简单看看 SVGEdit 的架构。 SVGEdit 的版本为 7.2.0。 SVGEdit 一款非常老牌的 SVG 图形编辑器,用于编辑处理 SVG,start 数目前是 5.8k。 它的优点在于经过多年的开发,完成度高,较为成熟&a…

Vue解析器

解析器本质上是一个状态机。但我们也曾提到,正则表达式其实也是一个状态机。因此在编写 parser 的时候,利用正则表达式能够让我们少写不少代码。本章我们将更多地利用正则表达式来实现 HTML 解析器。另外,一个完善的 HTML 解析器远比想象的要…

Android设计模式--享元模式

水不激不跃,人不激不奋 一,定义 使用共享对象可有效地支持大量的细粒度的对象 享元模式是对象池的一种实现,用来尽可能减少内存使用量,它适合用于可能存在大量重复对象的场景,来缓存可共享的对象,达到对象…