图神经网络DGL库之消息传递

news2024/10/1 15:51:29

图神经网络DGL库之消息传递

  • 1 消息传递
    • 1.1 图解
    • 1.2 语法格式
      • 1.2.1 message函数
      • 1.2.2 reduce函数
      • 1.2.3 update函数
      • 1.2.4 apply_nodes函数
      • 1.2.5 apply_edges函数
  • 2 具体例子
    • 2.1 建图
    • 2.2 消息传递
      • 2.2.1 函数构造
      • 2.2.2 边更新
      • 2.2.3 节点更新
      • 2.2.4 消息聚合
        • 1 未使用更新函数
        • 2 使用更新函数

1 消息传递

1.1 图解

在这里插入图片描述
对上图的绿色框的函数进行解释:

  • 消息函数(message function):消息函数可以接收源节点的e.src.data,边的e.data以及目标节点的e.dst.data,之后将三者数据进行一些操作(例如加和),最终将数据存放在Mailbox
  • apply_nodes函数:可以使用目标节点的e.dst.data数据进行一些操作(例如e.dst.data+1)
  • 聚合函数(reduce function):可以获取目标节点以及Mailbox数据。将Mailbox数据提取出来,并清空Mailbox,之后更新目标节点。

对上图未提及的函数进行说明:

  • apply_edges函数:可以将消息函数操作后的数据附加在边上
  • update_all函数(更新):启用消息函数和聚合函数,即开始更新节点的流程(消息传递+消息聚合)。

1.2 语法格式

1.2.1 message函数

message函数采用单个参数edges(具有三个成员src,dst和data)分别用于访问源节点,目标节点和边的特征,如下:

def message_func(edges):

1.2.2 reduce函数

reduce函数采用单个参数节点nodes。 节点的成员属性mailbox可以用来访问节点收到的信息,然后做一些运算

  • 一些最常见的聚合运算包括sum,max,min等

如下:

def reducer(nodes):

1.2.3 update函数

调用节点计算的接口是update_all(),它在单个API调用里合并了消息生成、消息聚合和节点特征更新。update_all的参数是消息函数,reduce函数和更新函数

  • 更新函数是可选择参数,用户可以不使用,DGL不推荐在 update_all 中指定更新函数
  • 该函数相当于开始更新节点的流程(消息传递+消息聚合+节点特征更新)。

1.2.4 apply_nodes函数

语法格式:

DGLGraph.apply_nodes(func, v='__ALL__', ntype=None, inplace=False)

参数解释如下:

  • func:用于更新节点特征的函数。
  • v:默认是更新所有节点。
  • ntype:可选,节点类型名称。如果图中只有一个类型的节点,则可以省略。
  • 最后一个已弃用

1.2.5 apply_edges函数

DGLGraph.apply_edges(func, edges='__ALL__', etype=None, inplace=False)

参数解释如下:

  • func:用于生成新的边特征。
  • v:默认是更新所有边。
  • ntype:可选,边类型名称。如果图中只有一个类型的边,则可以省略。
  • 最后一个已弃用

2 具体例子

2.1 建图

示例图如下:
在这里插入图片描述
建图代码如下:

import dgl
import torch

g = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])

2.2 消息传递

2.2.1 函数构造

该消息传递方式将源节点的特征和边的特征进行聚合

def message_func(edges):
    # 常规属性操作如下:
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp

2.2.2 边更新

将消息传递函数应用在边上,更新边的特征,代码如下:

import dgl
import torch

g = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
print('origin edge feat')
print(g.edata)
print('-------------------------------')
def message_func(edges):
    # 常规属性操作如下:
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp

g.apply_edges(message_func)
print('updata edge')
print(g.edata)

运行时,以(0,1)边为例,m=节点0的n_feat + 该边的e_feat,即20+2000=2020,以此类推,结果如下:
在这里插入图片描述

2.2.3 节点更新

import dgl
import torch

# 创建图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000])  # 边特征
g.ndata['n_feat'] = torch.tensor([20, 21, 22])  # 节点特征
print('Original node features:')
print(g.ndata)
print('-------------------------------')

g.apply_nodes(lambda nodes: {'n_feat': nodes.data['n_feat'] * 2})

# 打印更新后的节点特征
print('Updated node features:')
print(g.ndata)

将节点信息×2,结果如图所示:
在这里插入图片描述

2.2.4 消息聚合

1 未使用更新函数
import dgl
import torch

g = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
print('origin node feat')
print(g.ndata)
print('-------------------------------')
def message_func(edges):
    # 常规属性操作如下:
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp
def reducer(nodes):
    # DGL中,批次中的节点是按照图的划分和计算需求确定的
    print('batch nodes: ',nodes.nodes())
    # nodes.mailbox 只包含在 message_func 中生成并发送到节点的消息
    print('mailbox: ',nodes.mailbox)
    print('--------------------------')
    # 每一行进行求和,目的是将数据转成列表格式
    tmp = {'h': torch.sum(nodes.mailbox['m'],dim=1)}
    return tmp
g.update_all(message_func, reducer)
print('updata node')
print(g.ndata)
print('edge')
print(g.edata)

经过了消息生成、消息聚合和节点特征更新过程,将新特征h更新到节点的特征字典中。

  • 注意:这个过程并不会把特征m加到边的特征字典中

在这里插入图片描述

2 使用更新函数

很少这么用,不建议

import dgl
import torch

# 创建图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000])  # 边特征
g.ndata['n_feat'] = torch.tensor([20, 21, 22])  # 节点特征
print('Original node features:')
print(g.ndata)
print('-------------------------------')

# 消息传递函数
def message_func(edges):
    # 计算消息:边特征 + 源节点特征
    return {'m': edges.data['e_feat'] + edges.src['n_feat']}

# 聚合函数
def reducer(nodes):
    # 打印批次节点和邮件箱内容
    print('Batch nodes: ', nodes.nodes())
    print('Mailbox: ', nodes.mailbox)
    print('--------------------------')
    # 对消息进行求和
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

# 更新节点特征的函数
def update_node_features(nodes):
    # 使用聚合后的特征更新节点特征
    # nodes.data['h'] 是聚合后的消息
    # nodes.data['n_feat'] 是节点的原始特征
    updated_feat = nodes.data['n_feat'] + nodes.data['h']
    return {'h': updated_feat}

# 执行消息传递和聚合
g.update_all(message_func, reducer)

# 在消息传递后,使用 apply_nodes 更新节点特征
g.update_all(message_func, reducer,update_node_features) # 获取聚合结果

# 打印更新后的节点特征
print('Updated node features:')
print(g.ndata)
print('Edge features:')
print(g.edata)

结果如图所示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2182652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

M3u8视频由手机拷贝到电脑之后,通过potplayer播放报错找不到文件地址怎么解决?

该文章前面三节主要介绍M3u8视频是什么,视频播放错误(找不到地址)的解决方法在后面 M3U8是一种多媒体播放列表文件格式,主要用于流媒体播放。 一、文件格式特点 1. 文本文件:M3U8是一个采用 UTF-8 编码的文本文件,这意味着它可…

Shell入门基础学习笔记

目录 第1章 Shell概述 第2章 Shell解析器 第3章 Shell脚本入门 第4章 Shell中的变量 4.1 系统变量 4.2 自定义变量 4.3 特殊变量:$n 4.4 特殊变量:$# 4.5 特殊变量:$*、$ 4.6 特殊变量:$? 第5章 运算符 …

玩机进阶教程----MTK芯片机型修改串码IMEI 修改MEID 修复基带步骤详细演示 总结

在前面的博文中有对MTK芯片机型修改参数步骤做过解析。但其中有些步骤友友不太了解。在以前MTK芯片 3G 4G的机型中有使用老版本修改工具SN_Writer_Tool来修改,但对于新版本mtk芯片机型兼容性不是太好。而且局限于必须有基带BP AP文件。今天针对新工具Modem META 修改 做个补充…

国外问卷调查匠哥已经不带人了,但是还可以交流

国外问卷调查匠哥已经不带人了,但是还可以来和匠哥交流, 为啥不带人了呢? 从今年年初开始,匠哥在带学员的过程中发现: 跟往年同样的收费,同样的教学,甚至我付出的时间精力比以前还多&#xff…

Java | Leetcode Java题解之第447题回旋镖的数量

题目&#xff1a; 题解&#xff1a; class Solution {public int numberOfBoomerangs(int[][] points) {int ans 0;for (int[] p : points) {Map<Integer, Integer> cnt new HashMap<Integer, Integer>();for (int[] q : points) {int dis (p[0] - q[0]) * (p[…

如何构建一个生产级的AI平台(1)?

本文概述了生成式 AI 平台的常见组件、它们的作用以及它们的实现方式。 本文重点介绍部署 AI 应用程序的整体架构。 它讨论了需要哪些组件以及构建这些组件时的注意事项。 它不是关于如何构建 AI 应用程序。 这就是整体架构的样子。 这是一个相当复杂的系统。 这篇文章将从最…

基于Leaflet和天地图的细直箭头和突击方向标绘实战

目录 前言 一、细直箭头和突击方向的类设计 1、总体类图 2、对象区别 二、标绘绘制的具体实现 1、绘制时序图 2、相关点的具体绘制 3、最终的成果 三、总结 前言 今天是10月1日国庆节&#xff0c;迎来我们伟大祖国75周年的华诞。有国才有家&#xff0c;在这里首先祝我们…

【vs code(cursor) ssh连不上服务器(2)】但是 Terminal 可以连上,问题解决 ✅

【vs code(cursor) ssh连不上服务器】但是 Terminal 可以连上&#xff0c;问题解决 ✅ 对于类似的问题&#xff0c;之前的解决方法是清洗配置文件再重新连接。当重新连接不起作用时&#xff0c;可以再试下本文的方法。 问题描述&#xff1a;SSH 超时错误 vs code 连不上 ssh…

解决方案:机器学习中,回归及分类常用的模型评估指标有哪些

文章目录 一、现象二、解决方案回归任务的评价指标&#xff1a;均方误差 (MSE):平均绝对误差 (MAE): 分类任务的评价指标&#xff1a;准确率 (Accuracy):混淆矩阵 (Confusion Matrix):精确度 (Precision):召回率 (Recall):F1分数 (F1 Score):ROC曲线 (Receiver Operating Chara…

Qt的互斥量用法

目的 互斥量的概念 互斥量是一个可以处于两态之一的变量:解锁和加锁。这样&#xff0c;只需要一个二进制位表示它&#xff0c;不过实际上&#xff0c;常常使用一个整型量&#xff0c;0表示解锁&#xff0c;而其他所有的值则表示加锁。互斥量使用两个过程。当一个线程(或进程)…

ubuntu 24.04如何分配内存

24版与之前有一点不同&#xff0c;这里记录一下我的经历&#xff0c;希望有帮助 1.进入ubuntu直接试用&#xff0c;没有之前的安装向导&#xff08;如图&#xff09;&#xff0c;在屏幕的左上角会找到安装Ubuntu 2.分配内存 24的手动分配内存&#xff0c;不需要分配系统内存&…

IOT平台颜值天花板?延凡科技物联网平台让人惊叹不已

IOT平台颜值天花板&#xff1f;延凡科技物联网平台让人惊叹不已 在物联网的时代&#xff0c;AIOT平台凭借智能化的管理和决策能力&#xff0c;为多个行业带来了巨大的提升。本文将为大家介绍AIOT物联网平台的核心功能、应用场景以及它是如何改变我们的生活的。 平台简介 AIOT物…

二维环境下的TDOA测距定位的MATLAB代码,带中文注释

TDOA测距定位程序介绍 概述 本MATLAB程序实现了基于时间差到达&#xff08;TDOA&#xff09;技术的二维测距定位&#xff0c;能够处理4个或任意数量&#xff08;大于3个&#xff09;的锚节点。在无线定位和导航系统中&#xff0c;TDOA是一种常用的定位方法&#xff0c;通过测量…

一款免费开源的接口测试工具——ApiFox详细教程

前言 APIfox是一种功能强大的接口测试工具&#xff0c;它可以帮助用户轻松地进行REST API的自动化测试和文档编写。本文将从以下几个方面介绍APIfox的基本使用方法、特点和优势。 一、什么是APIfox&#xff1f; APIfox是一款基于Web的REST API测试工具&#xff0c;通过创建测…

论文笔记:LAFF 文本到视频检索的新基准

整理了ECCV2022 Lightweight Attentional Feature Fusion: A New Baseline for Text-to-Video Retrieval 论文的阅读笔记 背景模型问题定义LAFF(Lightweight Attention Feature Fusion)LAFF Block 实验消融实验可视化对比试验 这篇文章提出了一种新颖灵活的特征融合方式&#x…

初步认识产品经理

产品经理 思考问题的维度 1️⃣为什么要抓住核心用户&#xff1f; 所有和产品有关系的群体就是用户&#xff0c;存在共性和差异了解用户的付费点&#xff0c;更好的优化产品是否使用&#xff1a;&#xff08;目标用户-已使用产品&#xff1a;种子用户-尝鲜&#xff1b;核心用…

【Golang】深入解读Go语言中的错误(error)与异常(panic)

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

【Pyecharts】时间线柱状图x轴坐标重复出现并重叠

问题描述 如图右侧显示多的一列坐标 解决方案 降低pyecharts版本&#xff1a;pip install pyecharts2.0.5

ChatGPT与R语言融合技术在生态环境数据统计分析、绘图(回归和混合效应模型、多元统计分析)

自2022年GPT&#xff08;Generative Pre-trained Transformer&#xff09;大语言模型的发布以来&#xff0c;它以其卓越的自然语言处理能力和广泛的应用潜力&#xff0c;在学术界和工业界掀起了一场革命。在短短一年多的时间里&#xff0c;GPT已经在多个领域展现出其独特的价值…

vue2接入高德地图实现折线绘制、起始点标记和轨迹打点的完整功能(提供Gitee源码)

目录 一、申请密钥 二、安装element-ui 三、安装高德地图依赖 四、完整代码 五、运行截图 六、官方文档 七、Gitee源码 一、申请密钥 登录高德开放平台&#xff0c;点击我的应用&#xff0c;先添加新应用&#xff0c;然后再添加Key。 ​ 如图所示填写对应的信息&…