图神经网络实战(17)——深度图生成模型

news2024/12/28 2:06:44

图神经网络实战(17)——深度图生成模型

    • 0. 前言
    • 1. 变分图自编码器
    • 2. 自回归模型
    • 3. 生成对抗网络
    • 小结
    • 系列链接

0. 前言

我们已经学习了经典的图生成算法,虽然它们能够完成图生成任务,但也存在一些问题,促使基于图神经网络 (Graph Neural Networks, GNN) 的图生成技术的出现。深度图生成模型基于 GNN 架构,比传统技术更具表达能力。然而,缺点在于它们往往过于复杂,无法像经典方法那样进行分析和理解。主要的深度生成模型架构包括:变分自编码器 (Variational Autoencoder, VAE)、生成对抗网络 (Generative Adversarial Network, GAN)、自回归模型 (Autoregressive Model)、归一化流模型 (Normalizing Flow Model) 或扩散模型 (Diffusion Model) 等,但相较而言,前三种模型更加成熟。在本节中,将介绍三类图生成模型:基于变分自编码器 (Variational Autoencoder, VAE) 的模型、基于自回归模型 (Autoregressive Model) 和基于生成对抗网络 (Generative Adversarial Network, GAN) 的模型。

1. 变分图自编码器

我们已经知道变分自编码器 (Variational Autoencoder, VAE)可用于近似邻接矩阵,而变分图自编码器 (Variational Graph Autoencoder, VGAE) 模型由两个部分组成:编码器和解码器。编码器使用共享第一层的两个图卷积网络 (Graph Convolutional Network, GCN) 来学习每个潜正态分布的均值和方差。然后,解码器对学习到的分布进行采样,执行潜变量之间的内积。最后,得到了近似邻接矩阵 A ^ = σ ( Z T Z ) \hat A = σ(Z^TZ) A^=σ(ZTZ)
在使用图神经网络预测链接一节中,使用 A ^ \hat A A^ 来预测链接。然而,这并不是它的唯一应用,它可以直接给出一个网络的邻接矩阵,模仿训练过程中所看到的图。除了预测链接之外,我们还可以使用这个输出来生成新的图。以下是由 VGAE 模型创建的邻接矩阵的示例:

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, split_labels=True, add_negative_train_samples=False),
])

dataset = Planetoid('.', name='Cora', transform=transform)

train_data, val_data, test_data = dataset[0]

from torch_geometric.nn import GCNConv, VGAE

class Encoder(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.conv1 = GCNConv(dim_in, 2 * dim_out)
        self.conv_mu = GCNConv(2 * dim_out, dim_out)
        self.conv_logstd = GCNConv(2 * dim_out, dim_out)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

model = VGAE(Encoder(dataset.num_features, 16)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index) + (1 / train_data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)

for epoch in range(301):
    loss = train()
    val_auc, val_ap = test(val_data)
    if epoch % 50 == 0:
        print(f'Epoch: {epoch:>3} | Val AUC: {val_auc:.4f} | Val AP: {val_ap:.4f}')

val_auc, val_ap = test(val_data)
print(f'\nTest AUC: {val_auc:.4f} | Test AP: {val_ap:.4f}')

z = model.encode(test_data.x, test_data.edge_index)
adj = torch.where((z @ z.T) > 0.9, 1, 0)
print(adj)

Epoch:   0 | Val AUC: 0.7145 | Val AP: 0.7259
Epoch:  50 | Val AUC: 0.7030 | Val AP: 0.7175
Epoch: 100 | Val AUC: 0.7359 | Val AP: 0.7561
Epoch: 150 | Val AUC: 0.8535 | Val AP: 0.8618
Epoch: 200 | Val AUC: 0.8942 | Val AP: 0.8978
Epoch: 250 | Val AUC: 0.9005 | Val AP: 0.9050
Epoch: 300 | Val AUC: 0.9101 | Val AP: 0.9138
'''
Test AUC: 0.9101 | Test AP: 0.9138
tensor([[1, 0, 1,  ..., 0, 1, 1],
        [0, 1, 1,  ..., 0, 1, 1],
        [1, 1, 1,  ..., 0, 1, 1],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 1, 1],
        [1, 1, 1,  ..., 0, 1, 1]], device='cuda:0')
'''

这种技术已扩展到 VGAE 模型之外,能够输出节点和边的特征。GraphVAE 是最流行的基于 VAE 的图生成模型之一,该模型于 2018 年由 SimonovskyKomodakis 提出,旨在生成逼真的分子,这需要具备区分节点(原子)和边(化学键)的能力。
GraphVAE 考虑图 G = ( A , E , F ) G= (A, E, F) G=(A,E,F) ,其中 A A A 是邻接矩阵, E E E 是边属性矩阵, F F F 是节点属性矩阵。GraphVAE 学习了具有预定节点数的图 G ~ = ( A ~ , E ~ , F ~ ) \widetilde G=(\widetilde A,\widetilde E,\widetilde F) G =(A ,E ,F ) 的概率,其中 A ~ \widetilde A A 包含节点 ( A ~ a , a ) (\widetilde A_{a, a}) (A a,a) 和边 ( A ~ a , b ) (\widetilde A_{a,b}) (A a,b) 的概率, E ~ \widetilde E E 表示边的类别概率, F ~ \widetilde F F 包含节点的类别概率。与 VGAE 相比,GraphVAE 的编码器是一个具有边条件图卷积 (conditional graph convolutions, ECC) 的前馈网络,其解码器是一个具有三个输出的多层感知机 (multilayer perceptron, MLP),整体架构如下所示:

GraphVAE 架构

还有许多其它基于 VAE 的图生成架构,但它们的作用并不局限于模仿图,还可以添加约束条件,引导生成的图类型:

  • 添加约束的一种常用方法是在解码阶段进行检查,如约束图变分自编码器 (Constrained Graph Variational Autoencoder, CGVAE)。在此架构中,编码器是一个门控图卷积网络 (Gated Graph Convolutional Network, GGCN),解码器是一个自回归模型。自回归解码器可以验证整个过程中每个步骤的每个约束条件
  • 另一种添加约束条件的技术是使用基于 Lagrangian 的正则化器,这种正则化器计算速度更快,但生成的约束条件并不那么严格

2. 自回归模型

自回归模型 (Autoregressive Model) 也可以单独使用,自回归模型与其他模型的区别在于,模型过去的输出会作为当前输入的一部分。在此框架下,图生成成为一个连续的决策过程,既要考虑数据,又要考虑过去的决策。例如,在每一步中,自回归模型可以创建一个新节点或新链接,然后,生成的图被输入到模型中用于下一步生成,直到达到停止条件。这一过程如下图所示:

自回归模型

在实践中,可以使用循环神经网络 (Recurrent Neural Network, RNN) 来实现这种自回归模型。在 RNN 架构中,先前的输出被用作计算当前隐藏状态的输入。此外,RNN 还能处理任意长度的输入,这对于迭代生成图至关重要。但这种架构的计算比前馈网络慢,因为必须处理整个序列才能获得最终输出。最流行的两种 RNN 为门控递归单元 (Gated Recurrent Unit, GRU) 和长短期记忆 (Long Short-Term Memory, LSTM) 网络。
2018You 等人提出了 GraphRNN,是自回归模型在深度图生成方面的直接实现。该架构使用两个 RNN

  • 一个图级 RNN,用于生成节点序列(包括初始状态)
  • 一个边级 RNN,用于预测每个新添加节点的连接情况

边级 RNN 将图级 RNN 的隐藏状态作为输入,然后使用自己的输出。下图展示了模型推理时的生成机制:

GraphRNN 架构

两个 RNN 实际上是在完成一个邻接矩阵,图级 RNN 创建的每个新节点都会增加一行和一列,而边级 RNN 会用 01 进行填充。总体而言,GraphRNN 执行以下步骤:

  1. 添加新节点:图级 RNN 对图进行初始化,并将其输出反馈给边级 RNN
  2. 添加新连接:边级 RNN 会预测新节点是否与之前的每个节点相连
  3. 停止图生成:重复前两个步骤,直到边级 RNN 输出 EOS 标记,标志着生成过程结束

GraphRNN 可以学习不同类型的图(网格、社交网络、蛋白质等),其性能完全优于传统技术。与 GraphVAE 相比,GraphRNN 是模仿给定图的首选架构。

3. 生成对抗网络

与变分自编码器 (Variational Autoencoder, VAE) 一样,生成对抗网络 (Generative Adversarial Network, GAN) 也是机器学习 (Machine Learning, ML) 中著名的生成模型。在 GAN 框架中,两个神经网络在零和博弈中以不同目标展开竞争。第一个神经网络是生成器 (generator),负责创建新数据;第二个神经网络是判别器 (discriminator),负责将每个样本分为真实样本(来自训练集)或虚假样本(由生成器创建)。
为了提升模型性能,研究人员提出了多种改进原始架构的方案。Wasserstein GAN (WGAN) 通过最小化两个概率分布之间的 Wasserstein 距离(或称推土机距离)来提高训练的稳定性。这一改进可以通过引入梯度惩罚而非原始梯度剪切进一步进行完善。
将这一框架应用于深度图生成中,与其它技术一样,GAN 可以生成图以优化某些约束条件,这在寻找具有特定性质的新化合物等应用中非常有效。由于其离散性,这一问题异常庞大且复杂。

小结

图生成是生成新图的技术,并且希望所生成的图具有真实世界中图的性质。由于传统图生成方法缺乏表达能力,因此提出了更加灵活的基于 GNN 的技术。本节中,我们介绍了三类深度图生成模型: 基于变分自编码器 (Variational Autoencoder, VAE) 的模型、基于自回归模型 (Autoregressive Model) 和基于生成对抗网络 (Generative Adversarial Network, GAN) 的模型。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现
图神经网络实战(10)——归纳学习
图神经网络实战(11)——Weisfeiler-Leman测试
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
图神经网络实战(13)——经典链接预测算法
图神经网络实战(14)——基于节点嵌入预测链接
图神经网络实战(15)——SEAL链接预测算法
图神经网络实战(16)——经典图生成算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933319.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pytorch学习(四)绘制loss和correct曲线

这一次学习的时候静态绘制loss和correct曲线,也就是在模型训练完成后,对统计的数据进行绘制。 以minist数据训练为例子 import torch from torch import nn from torch.utils.data import DataLoader from torchvision import datasets from torchvisi…

GESP CCF C++ 三级认证真题 2024年6月

第 1 题 小杨父母带他到某培训机构给他报名参加CCF组织的GESP认证考试的第1级,那他可以选择的认证语言有()种。 A. 1 B. 2 C. 3 D. 4 第 2 题 下面流程图在yr输入2024时,可以判定yr代表闰年,并输出 2月是29天 &#x…

python-字符金字塔(赛氪OJ)

[题目描述] 请打印输出一个字符金字塔,字符金字塔的特征请参考样例。输入格式: 输入一个字母,保证是大写。输出格式: 输出一个字母金字塔,输出样式见样例。样例输入 C样例输出 A ABA …

【前端8】element ui常见页面布局:注意事项

【前端8】element ui常见页面布局:注意事项 写在最前面遇到的问题Element UI 常见页面布局:注意事项1. 了解基本布局组件常用的菜单1多一个下角 常用的菜单2 2. 栅格系统的使用3. 响应式布局4. Flex 布局的应用5. 避免滥用嵌套6. 处理边距和填充 小结 &a…

基于STC89C51单片机的烟雾报警器设计(煤气火灾检测报警)(含文档、源码与proteus仿真,以及系统详细介绍)

本篇文章论述的是基于STC89C51单片机的烟雾报警器设计的详情介绍,如果对您有帮助的话,还请关注一下哦,如果有资源方面的需要可以联系我。 目录 摘要 原理图 实物图 仿真图 元件清单 代码 系统论文 资源下载 摘要 随着现代家庭用火、…

TikTok内嵌跨境商城全开源_搭建教程/前端uniapp+后端源码

多语言跨境电商外贸商城 TikTok内嵌商城,商家入驻一键铺货一键提货 全开源完美运营,接在tiktok里面的商城内嵌,也可单独分开出来当独立站运营 二十一种语言,可以做很多国家的市场,支持商家入驻,多店铺等等…

服务器IP和电脑IP有什么不同

服务器IP和电脑IP有什么不同?在当今的信息化时代,IP地址作为网络世界中不可或缺的元素,扮演着举足轻重的角色。然而,对于非专业人士来说,服务器IP和电脑IP之间的区别往往模糊不清。本文旨在深入探讨这两者之间的不同&a…

若依前端和后端时间相差8小时

原因基类未设置时区 实体类继承 BaseEntity 加上timezone"GMT8" /** 创建时间 */ JsonFormat(pattern "yyyy-MM-dd HH:mm:ss" , timezone"GMT8") private Date createTime; 解决

golang程序性能提升改进篇之文件的读写---第一篇

背景:接手的项目是golang开发的(本人初次接触golang)经常出现oom。这个程序是计算和io密集型,调用流量属于明显有波峰波谷,但是因为各种原因,当前无法快速通过serverless或者动态在高峰时段调整资源&#x…

MViTv2:Facebook出品,进一步优化的多尺度ViT | CVPR 2022

论文将Multiscale Vision Transformers (MViTv2) 作为图像和视频分类以及对象检测的统一架构进行研究,结合分解的相对位置编码和残差池化连接提出了MViT的改进版本 来源:晓飞的算法工程笔记 公众号 论文: MViTv2: Improved Multiscale Vision Transforme…

Fiddler抓包过滤host及js、css等地址

1、如上图所示 在Filter页面中勾选Hide if URL contains;输入框输入 REGEX:\.(js|css|png|google|favicon\?.*) 隐藏掉包含js、css、png、google等的地址: Hide if URL contains: REGEX:\.(js|css|png|google|favicon\?.*) 2、使Filters设置生效 A…

微软新版WSL 2.3.11子系统带来“数百个新内核模块“和新功能

微软今天发布了新版的 Windows Subsystem for Linux(WSL)。与当前的 WSL 2.2.4 稳定版相比,WSL 2.3.11 具有许多特性:它从旧版的 Linux 5.15 LTS 内核转到了 Linux 6.6LTS内核。今天的发布说明指出,WSL 2.3.11 基于 Linux 6.6.36.3&#xff0…

【C++刷题】[UVA 489]Hangman Judge 刽子手游戏

题目描述 题目解析 这一题看似简单其实有很多坑,我也被卡了好久才ac。首先题目的意思是,输入回合数,一个答案单词,和一个猜测单词,如果猜测的单词里存在答案单词里的所有字母则判定为赢,如果有一个字母是答…

力扣622.设计循环队列

力扣622.设计循环队列 通过数组索引构建一个虚拟的首尾相连的环当front rear时 队列为空当front rear 1时 队列为满 (最后一位不存) class MyCircularQueue {int front;int rear;int capacity;vector<int> elements;public:MyCircularQueue(int k) {//最后一位不存…

基于python的三次样条插值原理及代码

1 三次样条插值 1.1 三次样条插值的基本概念 三次样条插值是通过求解三弯矩方程组&#xff08;即三次样条方程组的特殊形式&#xff09;来得出曲线函数组的过程。在实际计算中&#xff0c;还需要引入边界条件来完成计算。样条插值的名称来源于早期工程师制图时使用的细长木条&…

【机器学习】--过采样原理及代码详解

过采样&#xff08;Oversampling&#xff09;是一个在多个领域都有应用的技术&#xff0c;其具体含义和应用方法会根据领域的不同而有所差异。以下是对过采样技术的详细解析&#xff0c;主要从机器学习和信号处理两个领域进行阐述。 一、机器学习中的过采样 在机器学习中&…

未来的社交标杆:如何通过AI让Facebook更加智能化?

在当今信息爆炸的时代&#xff0c;社交媒体平台的智能化已成为提高用户体验和互动质量的关键因素。Facebook&#xff0c;作为全球最大的社交平台之一&#xff0c;通过人工智能&#xff08;AI&#xff09;的广泛应用&#xff0c;正不断推进其智能化进程。本文将探讨Facebook如何…

Qt日志库QsLog使用教程

前言 最近项目中需要用到日志库。上一次项目中用到了log4qt库&#xff0c;这个库有个麻烦的点是要配置config文件&#xff0c;所以这次切换到了QsLog。用了后这个库的感受是&#xff0c;比较轻量级&#xff0c;嘎嘎好用&#xff0c;推荐一波。 下载QsLog库 https://github.c…

CSS技巧专栏:一日一例 7 - 纯CSS实现炫光边框按钮特效

CSS技巧专栏&#xff1a;一日一例 7 - 纯CSS实现炫光边框按钮特效 本例效果图 案例分析 相信你可能已经在网络见过类似这样的流光的按钮&#xff0c;在羡慕别人做的按钮这么酷的时候&#xff0c;你有没有扒一下它的源代码的冲动&#xff1f;或者你当时有点冲动&#xff0c;却…

在Oxygen中比较两个目录的差异,用于编写手册两个版本的变更说明

▲ 搜索“大龙谈智能内容”关注公众号▲ 当我们对手册进行改版的时候&#xff0c;我们通常需要编写变更说明&#xff0c;如下图&#xff1a; 改版通常会改动很多文件的很多地方&#xff0c;如何知道哪些地方更改了呢&#xff1f; Oxygen提供了比较两个目录的功能&#xff0c…