【图神经网络】使用DGL框架实现简单图分类任务

news2024/10/6 8:28:51

使用DGL框架实现简单图分类任务

  • 简单图分类任务
    • 实现过程
    • 打包一个图的小批量
    • 定义图分类器
      • 图卷积
      • 读出和分类
    • 准备和训练
    • 核心代码
    • 参考资料

图分类(预测图的标签)是图结构数据里一类重要的问题。它的应用广泛,可见于生物信息学、化学信息学、社交网络分析、城市计算以及网络安全。随着近来学界对于图神经网络的热情持续高涨,出现了一批用图神经网络做图分类的工作。比如 训练图神经网络来预测蛋白质结构的性质,根据社交网络结构来预测用户的所属社区等(Ying et al., 2018, Cangea et al., 2018, Knyazev et al., 2018, Bianchi et al., 2019, Liao et al., 2019, Gao et al., 2019)。

本文使用DGL框架实现简单的图分类任务,任务目标有两个:

  1. 如何使用DGL批量化处理大小各异的图数据
  2. 训练图神经网络完成一个简单的图分类任务

简单图分类任务

这里设计了一个简单的图分类任务。在DGL里已经实现了一个迷你图分类数据集(MiniGCDataset)。它由以下8类图结构数据组成。每一类图包含同样数量的随机样本。任务目标是训练图神经网络模型对这些样本进行分类。
DGL框架中的8中图

实现过程

以下是使用 MiniGCDataset 的示例代码。
首先,创建了一个拥有 100 个样本的数据集。数据集中每张图随机有 16 到 32 个节点。DGL 中所有的数据集类都符合 Sequence 的抽象结构——既可以使用 dataset[i] 来访问第 i 个样本。这里每个样本包含图结构以及它对应的标签。
创建数据集
运行以上代码,可以画出数据集中第64个样本的图结果及其对应的标签:
网格图

打包一个图的小批量

为了更高效地训练神经网络,一个常见的做法是将多个样本打包成小批量(mini-batch)。打包尺寸相同的张量样本非常简单。比如说打包两个尺寸为 2828 的图片会得到一个 22828 的张量。相较之下,打包图面临两个挑战
(1)图的边比较稀疏
(2)图的大小、形状各不相同

DGL 提供了名为 dgl.batch 的接口来实现打包一个图批量的功能。其核心思路非常简单**。将 n 张小图打包在一起的操作可以看成是生成一张含 n 个不相连小图的大图**。下图的可视化从直觉上解释了 dgl.batch 的功能。
dgl.batch
可以看到通过 dgl.batch 操作,生成了一张大图,其中包含了一个环状和一个星状的连通分量。其邻接矩阵表示则对应为在对角线上把两张小图的邻接矩阵拼接在一起(其余部分都为 0)

以下是使用 dgl.batch 的一个实际例子。这里,定义了一个 collate 函数来将 MiniGCDataset 里多个样本打包成一个小批量。

import dgl

def collate(samples):
    # 输入“samples”是一个列表
    # 每个元素都是一个二元组(图,标签)
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

正如打包 N 个张量得到的还是张量,dgl.batch 返回的也是一张图。这样的设计有两点好处。首先,任何用于操作一张小图的代码可以被直接使用在一个图批量上。其次,由于 DGL 能够并行处理图中节点和边上的计算,因此同一批量内的图样本都可以被并行计算

定义图分类器

这里使用的图分类器和应用在图像或者语音上的分类器类似——先通过多层神经网络计算每个样本的表示(representation),再通过表示计算出每个类别的概率,最后通过向后传播计算梯度。一个常见的图分类器由以下几个步骤构成:

  1. 通过图卷积(Graph Convolution)层获得图中每个节点的表示。
  2. 使用「读出」操作(Readout)获得每张图的表示。
  3. 使用 Softmax 计算每个类别的概率,使用向后传播更新参数。

下图展示了整个流程:
图分类器的步骤之后我们将分步讲解每一个步骤。

图卷积

我们的图卷积操作基本类似图卷积网络 GCN(具体可以参见我们的关于 GCN 的教程)。图卷积模型可以用以下公式表示:
h v l + 1 = R e L U ( b ( l ) + ∑ u ∈ N ( v ) h u ( l ) W ( l ) ) h_v^{l+1}=ReLU(b^{(l)}+\sum_{u\in N(v)}h_{u}^{(l)}W^{(l)}) hvl+1=ReLU(b(l)+uN(v)hu(l)W(l))
在这个例子中,对这个公式进行了微调:
h v l + 1 = R e L U ( b ( l ) + 1 ∣ N ( v ) ∣ ∑ u ∈ N ( v ) h u ( l ) W ( l ) ) h_v^{l+1}=ReLU(b^{(l)}+\frac{1}{|N(v)|} \sum_{u\in N(v)}h_{u}^{(l)}W^{(l)}) hvl+1=ReLU(b(l)+N(v)1uN(v)hu(l)W(l))
我们将求和替换成求平均可用来平衡度数不同的节点,在实验中这也带来了模型表现的提升。

此外,在构建数据集时,给每个图里所有的节点都加上了和自己的边(自环)。这保证节点在收集邻居节点表示进行更新时也能考虑到自己原有的表示。以下是定义图卷积模型的代码。这里使用 PyTorch 作为 DGL 的后端引擎(DGL 也支持 MXNet 作为后端)。

首先,使用 DGL 的内置函数定义消息传递:

import dgl.function as fn
import torch
import torch.nn as nn

# 将节点表示h作为信息发出
msg = fn.copy_src(src='h',out='m')

其次,定义消息累和函数。这里我们对收到的消息进行平均。

def reduce(nodes):
    """对所有邻接点节点特征求平均并覆盖原本的节点特征"""
    accum = torch.mean(nodes.mailbox['m'],1)
    return {'h':accum}

之后,对收到的消息应用线性变换和激活函数。

class NodeApplyModule(nn.Module):
    """将节点特征hv更新为ReLU(Whv+b)"""

    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}

最后,把所有的小模块串联起来成为 GCNLayer。

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCNLayer, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # 使用 h 初始化节点特征
        g.ndata['h'] = feature
        # 使用 update_all 接口和自定义的消息传递及累和函数更新节点表示
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

读出和分类

读出(Readout)操作的输入是图中所有节点的表示,输出则是整张图的表示。在 Google 的 Neural Message Passing for Quantum Chemistry(Gilmer et al. 2017) 论文中总结过许多不同种类的读出函数。在这个示例里,我们对图中所有节点表示取平均以作为图的表示:
h g = 1 ∣ V ∣ ∑ v ∈ V h v h_g=\frac{1}{|V|}\sum_{v\in V}h_v hg=V1vVhv

DGL 提供了许多读出函数接口,以上公式可以很方便地用 dgl.mean(g) 完成。最后将图的表示输入分类器。分类器对图表示先做了一个线性变换,然后得到每一类在 softmax 之前的 logits。具体代码如下:

import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        # 两层图卷积层
        self.layers = nn.ModuleList([
            GCNLayer(in_dim, hidden_dim, F.relu),
            GCNLayer(hidden_dim, hidden_dim, F.relu)
        ])
        # 分类层
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # 使用节点度数作为初始节点表示
        h = g.in_degrees().view(-1, 1).float()
        # 图卷积层
        for conv in self.layers:
            h = conv(g, h)
            g.ndata['h'] = h
            # 读出函数
            graph_repr = dgl.mean_nodes(g, 'h')
            # 分类层
        return self.classify(graph_repr)

准备和训练

阅读到这边的读者可以长舒一口气了。因为之后的训练过程和其他经典的图像,语音分类问题基本一致。首先创建了一个包含 400 张节点数量为 16~32的合成数据集。其中 320 张图作为训练数据集,80 张图作为测试集。

import torch.optim as optim
from torch.utils.data import DataLoader

# 创建一个训练数据集和测试数据集
trainset = MiniGCDataset(320, 16, 32)
testset = MiniGCDataset(80, 16, 32)

# 使用PyTorch的DataLoader和之前定义的collate函数
data_loader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate)

其次,创建一个刚刚定义的图神经网络模型对象。

# 其次创建一个图神经网络模型对象
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

训练过程则是经典的反向传播和梯度下降。

# 训练过程是经典的反向传播和梯度下降
epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        label = torch.tensor(label, dtype=torch.long)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

运行结果:
运行结果
下图是以上模型训练的学习曲线:
训练损失变化示意图
在训练完成后,在测试集上验证模型的表现。出于部署教程的考量,我们限制了模型训练的时间。如果你花更多时间训练模型,应该能得到更好的表现(80%-90%)。

为了更好地理解模型学到的节点和图的表示,我们使用了 t-SNE 来进行降维和可视化。
tSNE
两张小图分别可视化了做完 1 层和 2 层图卷积后的节点表示。不同颜色代表属于不同类别的图的节点。可以看到,经过训练后,属于同一类别的节点表示更加接近。并且,经过两层图卷积后这一聚类效果更明显。其原因是因为两层卷积后每个节点能接收到 2 度范围内的邻居信息。
readout之后的tSNE
底部的大图可视化了每张图在做 softmax 前的 logits,也就是图表示。可以看到通过读出函数后,图表示能非常好地各自区分开来。这一区分度比节点表示更加明显。

核心代码

import datetime
import pandas as pd

epochs = 100
log_step_freq = 10

dfhistory = pd.DataFrame(columns=['epoch', 'loss', metric_name, 'val_loss', 'val' + metric_name])
print("Start Training........")
nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("==========" * 8 + "%s" % nowtime)

for epoch in range(1, epochs + 1):
    # 训练过程
    model.train()
    epoch_loss = 0.0
    metric_sum = 0.0
    step = 1
    for iter, (bg, label) in enumerate(data_loader, 1):
        # 梯度清零
        optimizer.zero_grad()
        # 正向传播损失
        prediction = model(bg)
        metric, _ = metric_func(prediction, label)

        label = label.to(torch.long)
        loss = loss_func(prediction, label)
        # 反向传播求梯度
        loss.backward()
        optimizer.step()
        # 打印batch级别日志
        epoch_loss += loss.detach().item()
        metric_sum += metric.item()
        if step % log_step_freq == 0:
            print(("[step = %d] loss: %.3f, " + metric_name + ": %.3f") % (step, epoch_loss / step, metric_sum / step))

    # 验证循环
    model.eval()
    val_loss = 0.0
    val_metric = 0.0
    val_step = 1
    for val_iter, (bg, label) in enumerate(val_loader, 1):
        with torch.no_grad():
            prediction = model(bg)
            val_metric, y_pred_cls = metric_func(prediction, label)
            label = label.to(torch.long)
            val_loss = loss_func(prediction, label)

        val_loss += val_loss.detach().item()
        val_metric += val_metric.item()
    # 记录日志
    info = (epoch, epoch_loss / step, metric_sum / step,
            val_loss / val_step, val_metric / val_step)
    dfhistory.loc[epoch - 1] = info
    print(("\nEPOCH = %d, loss = %.3f," + metric_name +
           "  = %.3f, val_loss = %.3f, " + "val_" + metric_name + " = %.3f")
          % info)
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n" + "==========" * 8 + "%s" % nowtime)

print("Finished Training...")

运行结果

参考资料

[1] https://www.jiqizhixin.com/articles/2019-01-29-2
[2] Task4:Pytorch实现模型训练与验证
[3] Pytorch实战总结篇之模型训练、评估与使用
[4] t-SNE及pytorch实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/11798.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

aws xray ec2环境搭建和基础用法

参考资料 https://docs.amazonaws.cn/en_us/xray/latest/devguide/xray-daemon.html https://docs.aws.amazon.com/xray-sdk-for-nodejs/latest/reference/ https://github.com/aws/aws-xray-sdk-node https://docs.aws.amazon.com/xray-sdk-for-python/latest/reference/ba…

联想集团:长期前景稳定,业务转型正在提高盈利能力

来源;猛兽财经 作者:猛兽财经 由疫情驱动的个人电脑需求正在减弱 在经历了两年的个人电脑销售强劲增长之后,随着全球对疫情封锁限制的放松,由疫情引发的远程工作和在线学习趋势带来的全球个人电脑需求正在减弱。根据IDC的数据,20…

文件之间的拷贝(拷贝图片实例)java.io.FileNotFoundException: G:\dad (拒绝访问。)通过绝对路径获取各种文件名

1.报错解决 :java.io.FileNotFoundException: G:\dad (拒绝访问。) 参考文献:(364条消息) java.io.FileNotFoundException:(拒接访问)_corelone2的博客-CSDN博客_java.io.filenotfoundexception 2.code 代码参考地址:(364条消息) java中文件拷贝的几种方式_babar…

深入理解New操作符

前言 当我们对函数进行实例化时,需要用new操作符来实现。那么,对于它的底层实现原理你是否清楚呢?本文就跟大家分享下它的原理并用一个函数来模拟实现它,欢迎各位感兴趣的开发者阅读本文。 原理分析 我们通过一个具体的例子来看…

MySQL——数据库基础

文章目录什么叫做数据库?主流数据库基本使用服务器、数据库、表之间的关系MySQL逻辑结构MySQL架构MySQL分类存储引擎什么叫做数据库? 软件角度: 为用户或者用户程序提供更加方便的数据管理的软件,通过SQL语句进行! 数…

【PostgreSQL-14版本snapshot的几点优化】

最近在分析PostgreSQL-14版本性能提升的时候,关注到了Snapshots的这一部分。发现在PostgreSQL-14版本,连续合入了好几个和Snapshots相关的patch。 并且,Andres Freund也通过这些改进显著减少了已确定的快照可扩展性瓶颈,从而改进了…

【C++】C/C++内存管理

众所周知,C/C没有内存(垃圾)回收机制,所以写C/C程序常常会面临内存泄漏等问题。这一节我们一起来学习C/C的内存管理机制,深入了解这套机制有利于我们之后写出更好的C/C程序。 在那些看不到太阳的日子里,别忘…

Spring(九)- Spring自定义命名空间整合第三方框架原理解析

文章目录一、Spring通过命名空间整合第三方框架1. Dubbo 命名空间2. Context 命名空间二、Spring自定义命名空间原理解析三、手写自定义命名空间标签与Spring整合一、Spring通过命名空间整合第三方框架 1. Dubbo 命名空间 Spring 整合其他组件时就不像MyBatis这么简单了&#…

电影影院购票管理系统

1、项目介绍 电影影院购票管理系统拥有两种角色:管理员和用户 管理员:用户管理、影片管理、影厅管理、订单管理、影评管理、排片管理等 用户:登录注册、个人中心、查看电影票、电影选座、下单支付、发布影评、查看票房统计等 2、项目技术 …

14、Horizontal Pod Autoscal

一、为何进行缩扩容? 在实际生产中,经常会遇到某个服务需要扩容的场景,可能会遇到由于资源紧张或者工作负载降低而需要减少服务实例数量的场景。可以利用Deployment/RC的Scale机制来完成这些工作。二、缩扩容模式 Kubernetes 对 Pod 扩容与缩…

mysql-Innodb解析

一.计算机不同介质操作速度 相对于CPU和内存操作, 我们可以看到磁盘的操作延时明显要大得多, 一次磁盘搜索的延时需要10ms。 假入我们某一个业务操作进行了大量磁盘读写, 那可以预料到这个服务的性能肯定是非常差的, 那么到底是什…

3.2文法与语言

1、文法生成语言 推导 定义:当αAβ直接推导出αγβ,即αAβ⇒αγβ,仅当A→γ是一个产生式,且α,β∈(VT∪VN)*。 注:按照我的理解是两个字符串的推导。如果α1⇒α2⇒…⇒αn,则我们称这个序列是从α1到αn的一个…

动态规划01 背包问题(算法)

上篇文章说了,查找组成一个偶数最接近的两个素数算法: 查找组成一个偶数最接近的两个素数https://blog.csdn.net/ke1ying/article/details/127872594 本篇文章题目是 动态规划01 背包问题: 背包容量5kg,现在有三个物体&#xf…

BVH动捕文件导入到E3D骨骼树

BVH动捕文件导入到E3D骨骼树 1. BVH动捕文件 BVH动作捕捉文件有两部分组成,第一部分描述了静止状态下角色的基本骨骼结构,角色通常处于Apose或Tpose姿态下.文本用树状结构描述了各个关节点的相对位置(OFFSET xyz),连接两关节点的…

学好MySQL增删查改,争取不做CURD程序员【下篇(六个小时肝MySQL万字大总结)】

✨✨hello,愿意点进来的小伙伴们,你们好呐! 🐻🐻系列专栏:【MySQL初阶】 🐲🐲本篇内容:一套打通MySQL基础操作. 🐯🐯作者简介:一名现大二的三非编…

解决小程序-wx.canvasGetImageData()-RGB取色盘苹果手机获取颜色慢问题

简介 最近做了一个微信小程序控制蓝牙设备,通过小程序中的RGB取色盘,获取当前的RGB颜色,通过蓝牙发送给设备,设备接收到RGB以后,做出相应的调整。 图1:RGB取色盘 在安卓手机上运行正常,能够迅速…

企业实战项目rsync+inotify实现实时同步

目录 一、inotify安装和介绍 1. 安装inotify 2. inotify-tools常用命令 3. rsync inotify 实践 3.1 服务端配置 3.2 客户端配置 一、inotify安装和介绍 1. 安装inotify yum install epel-release -y yum install inotify-tools -y 2. inotify-tools常用命令 inotify-to…

C++ 使用哈希表封装模拟实现unordered_map unordered_set

一、unordered_map unordered_set 和 map set的区别 1. map set底层采取的红黑树的结构,unordered_xxx 底层数据结构是哈希表。unordered_map容器通过key访问单个元素要比map快,但它通常在遍历元素子集的范围迭代方面效率较低。 2. Java中对应的容器名…

vivo和oppo通知权限弹窗

在vivo和oppo部分手机上,首次安装app时,会弹出一个系统级的通知权限弹窗,(部分一加手机也会出现,是因为一加手机使用了OPPO的colorOS系统)如图。 这个通知权限弹窗比较坑,一来可能不符合产品对…

Word控件Spire.Doc 【文本】教程(21) ;如何在 C# 中用 Word 文档替换文本

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下,轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具,专注于创建、编辑、转…