DGL中NN模块的构造函数

news2024/12/23 23:34:25

在这里插入图片描述
上图引用自:dgl用户文档第三章(nn模块编写)

"""

构造函数完成以下几个任务:
1、设置选项。
2、注册可学习的参数或者子模块。
3、初始化参数。

"""
import torch.nn as nn
from dgl.utils import expand_as_pair
import dgl.nn
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
"""
在构造函数中,用户首先需要设置数据的维度。
对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 
对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。
对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型
包括 mean、 sum、 max 和 min。一些模块可能会使用更加复杂的聚合函数,
比如 lstm。
"""
"""注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear、 nn.LSTM 等。"""

class SAGE(nn.Module):
    def __init__(self, in_feats, out_feats, aggregator_type,
                 bias=True, norm=None, activation=None):
        super(SAGE, self).__init__()
        # 获取源节点和目标节点的输入特征维度
        self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)
        # 输出特征维度
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        # 聚合类型:mean、pool、lstm、gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
    #  构造函数的最后调用了 reset_parameters() 进行权重初始化。
    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2
    def forward(self, graph, feat):    #SAGEConv示例中的 forward() 函数
        # 输入图对象的规范检测
        with graph.local_scope():
            # 指定图类型,然后根据图类型扩展输入特征
            feat_src, feat_dst = expand_as_pair(feat, graph)
        # 消息传递和聚合
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # 除以入度
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE中gcn聚合不需要fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # 更新特征作为输出
        # 激活函数
        if self.activation is not None:
            rst = self.activation(rst)
        # 归一化
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
"""
在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,
DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:
1、检测输入图对象是否符合规范。
2、消息传递和聚合。
3、聚合后,更新特征作为输出。

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 
比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入
度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是
,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输
出不会全为0。在这种情况下,无需进行此类检验。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图和子图块。



聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息
传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的
消息传递代码 里所介绍的性能优化。

聚合后,更新特征作为输出
forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
"""


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1259029.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用opencv实现图像滤波

1 图像滤波介绍 滤波是信号和图像处理中的基本任务之一,其旨在有选择地提取图像的某些特征,可以用于在给定应用程序的上下文中传达重要信息,例如,去除图像中的噪声、提取所需的视觉特征、图像重采样等。 1.1 图像滤波理论 图像…

【论文解读】基于生成式面部先验的真实世界盲脸修复

论文地址:https://arxiv.org/pdf/2101.04061.pdf 代码地址:https://github.com/TencentARC/GFPGAN 图片解释: 与最先进的面部修复方法的比较:HiFaceGAN [67]、DFDNet [44]、Wan 等人。[61] 和 PULSE [52] 在真实世界的低质量图像…

final关键字-Java

final关键字 一、使用场景1、当不希望类被继承时,可以用final修饰。2、当不希望父类的某个方法被子类覆盖/重写(override)时,可以用final修饰。3、当不希望类的的某个属性的值被修改,可以用final修饰。4、当不希望某个局部变量被修改&#xf…

CSGO搬砖如何选品?选品软件和教程靠谱吗?

说到CSGO搬砖项目,目前平台最火的就是CSGO游戏搬砖。在CSGO搬砖项目中,选品是至关重要的环节,直接影响到利润。而选品软件可以帮助我们更快地了解市场变化、计算成本利润等关键信息,提高选品的效率和准确性。可靠的选品软件还能够…

技术人员都了解,动态代理IP和静态代理IP的区别及适用的场景

动态代理IP和静态代理IP是两种常见的代理IP技术,它们在网络通信中起到了重要的作用。虽然它们都可以用于隐蔽真实的IP地址,但在实际应用中有一些区别和适用的场景。本文将介绍这两种代理IP的区别以及它们适用于哪些场景。 一、静态代理IP 静态代理IP是指…

Java的threadd常用方法

常用API 给当前线程命名 主线程 package com.itheima.d2;public class ThreadTest1 {public static void main(String[] args) {Thread t1 new MyThread("子线程1");//t1.setName("子线程1");t1.start();System.out.println(t1.getName());//获得子线程…

eutil.dll文件缺失修复全指南,教你快速修复eutil.dll

eutil.dll缺失了要怎么办?eutil.dll是一种常见的动态链接库(DynamicLinkLibrary,DLL)文件,它在Windows操作系统中发挥着重要作用。DLL文件允许程序共享代码以执行诸如打印或连接网络之类的功能。这不仅节省了系统资源&…

浅谈API自动化测试

前言 本文主要针对API测试的概念及API测试在Choerodon中的实践展开。 API(应用程序编程接口)测试是一种软件测试,可以直接在API级别执行验证。它是集成测试的一部分,它确定API是否满足测试人员对功能,可靠性&#xf…

【第五节:微信小程序 小程序UI组件B】微信小程序入门,以思维导图的方式展开5

上图若是看不清,可私信给发大图哈 5、小程序UI组件B 表单form button 按钮 size String default 有效值 default, mini type String default 按钮的样式类型,有效值 primary, default, warn plain Bo…

[算法总结] - 蓄水池采样算法

问题描述 在长度为N的数组中,随机等概率选取K个元素,如何实现这个随机算法。 思路很简单,生成一个[0, N]的随机数index,然后返回index上的数值即可。 但是,如果输入是一个长度未知的数组比如stream,先遍历…

IDEA中Tomcat启动web项目

1.首先【Run】-->【Edit Configurations】,进入对应功能界面 2.点击左上角【】,选择Tomcat Server -->Local 3.Name输入自己中意的,下面两个port,保证没被占用就行 4.切到【Deployment】页签,点击【】&#xff…

elk日志分析系统:

elk日志分析系统: elk是一套完整的日志集中处理方案,由三个开源的软件简称组成; E:Easticsearch 简称ES是一个开源的,分布式的存储检索引擎,(索引型的非关系数据库)存储日志 由java代码开发的&#xff0…

【Java Spring】SpringBoot 五大类注解

文章目录 Spring Boot 注解简介1、五大类注解的作用2、五大类注解的关系3、通过注解获取对象4、获取Bean对象名规则解析 Spring Boot 注解简介 Spring Boot的核心就是注解。Spring Boot通过各种组合注解,极大地简化了Spring项目的搭建和开发。五大类注解是Spring B…

用泰勒展开线性化

在点附近做泰勒展开: 当和很接近的时候,很小,更小,所以可以忽略及后面的高阶项,得到 因为、都是常数,所以等式右边是 x的线性方程,在点附近进行了线性化。 举个例子: 假设 那么做一…

【Linux】安卓端JuiceSSH结合内网穿透实现远程连接服务器

目录 前言1. Linux安装cpolar2. 创建公网SSH连接地址3. JuiceSSH公网远程连接4. 固定连接SSH公网地址5. SSH固定地址连接测试 前言 处于内网的虚拟机如何被外网访问呢?如何手机就能访问虚拟机呢? 本文介绍 cpolarJuiceSSH 实现手机端远程连接Linux虚拟…

linux 命令 sudo、su 命令

sudo命令详解 1、初识sudo sudo是linux下常用的允许普通用户使用超级用户权限的工具,sudo 用来执行需要提升权限(通常是作为 root 用户)的命令,允许系统管理员让普通用户执行一些或者全部的root命令,如halt&#xff…

C++中类的静态成员、存储、this、友元和运算符重载

静态成员 在类定义中,它的成员(包括成员变量和成员函数),这些成员可以用关键字static 声明为静态的,称为静态成员。 不管这个类创建了多少个对象,静态成员只有一个拷贝,这个拷贝被所有属于这个…

搜索百度可以直接生成代码拉

先看效果图: 使用示例: 比如我要搜索“JS取一个数在两个数更近”的方法,直接搜“JS取一个数在两个数更近”,点击百度一下,就会出现想要的代码,如上图。

网站频频告警故障排查实录

故障描述 位于某Proxmox VE超融合集群上的一个网站频频报警,表现的形式是一会儿服务不可用,一会儿又恢复(如下图所示),但同一集群上的其他Web站点未发现异常。 可能的原因 1)出口带宽占满。 2)…

【技巧】Excel表格如何退出“只读方式”?

如果Excel表格被设置了“只读模式”,那每次打开Excel都会出现对话框提示是否以“只读方式”打开,并且以“只读方式”打开的Excel,如果进行更改是无法保存原文件的。那要如何退出“只读方式”呢? 首先,我们要看下Excel表…