【DGL】图分类

news2024/11/27 10:10:16

目录

    • 概述
    • 数据集
    • 定义Data Loader
    • DGL中的batched graph
    • 定义模型
    • 训练
    • 参考

概述

除了节点级别的问题——节点分类、边级别的问题——链接预测之外,还有整个图级别的问题——图分类。经过聚合、传递消息得到节点和边的新的表征后,映射得到整个图的表征。

数据集

dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)
g = dataset[0]
print(g)
print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)
(Graph(num_nodes=42, num_edges=204,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), tensor(0))
Node feature dimensionality: 3
Number of graph categories: 2

共1113个图,每个图中的节点的特征维度是3,图的类别数是2.

定义Data Loader

from torch.utils.data.sampler import SubsetRandomSampler

from dgl.dataloading import GraphDataLoader

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)

取80%用作训练集,其余用作测试集
mini-batch操作,取5个graph打包成一个大的batched graph

it = iter(train_dataloader)
batch = next(it)
print(batch)
[Graph(num_nodes=259, num_edges=1201,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), tensor([0, 1, 0, 0, 0])]

DGL中的batched graph

在这里插入图片描述
在每个mini-batch里面,batched graph是由dgl.batch对graph进行打包的

batched_graph, labels = batch
print(
    "Number of nodes for each graph element in the batch:",
    batched_graph.batch_num_nodes(),
)
print(
    "Number of edges for each graph element in the batch:",
    batched_graph.batch_num_edges(),
)

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)
Number of nodes for each graph element in the batch: tensor([ 55,  16, 116,  31,  41])
Number of edges for each graph element in the batch: tensor([209,  70, 584, 153, 185])
The original graphs in the minibatch:
[Graph(num_nodes=55, num_edges=209,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=16, num_edges=70,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=116, num_edges=584,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=31, num_edges=153,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=41, num_edges=185,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})]

定义模型

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")#取所有节点的'h'特征的平均值来表征整个图  readout

model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

一个batched graph中,不同的图是完全分开的,即没有边连接两个图,所有消息传递函数仍然具有相同的结果(和没有打包之前相比)。
其次,将对每个图分别执行readout功能。假设批次大小为B,要聚合的特征维度为D,则读取出的形状为(B, D)。

训练

for epoch in range(20):
    num_correct = 0
    num_trains = 0
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        num_trains += len(labels)
        num_correct += (pred.argmax(1)==labels).sum().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('train accuracy: ', num_correct/num_trains)

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata['attr'].float())
    num_correct += (pred.argmax(1)==labels).sum().item()
    num_tests += len(labels)

print("Test accuracy: ", num_correct/num_tests)
train accuracy:  0.7404494382022472
train accuracy:  0.7426966292134831
train accuracy:  0.7471910112359551
train accuracy:  0.7539325842696629
train accuracy:  0.7584269662921348
train accuracy:  0.7674157303370787
train accuracy:  0.7629213483146068
train accuracy:  0.7617977528089888
train accuracy:  0.7584269662921348
train accuracy:  0.7707865168539326
train accuracy:  0.7629213483146068
train accuracy:  0.7651685393258427
train accuracy:  0.7629213483146068
train accuracy:  0.7561797752808989
train accuracy:  0.7606741573033707
train accuracy:  0.7584269662921348
train accuracy:  0.7617977528089888
train accuracy:  0.7707865168539326
train accuracy:  0.7629213483146068
train accuracy:  0.7539325842696629

Test accuracy:  0.26905829596412556

效果非常一般 明显过拟合 应该和没有边特征,节点特征信息不足有关。

参考

https://docs.dgl.ai/tutorials/blitz/5_graph_classification.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/335824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pycharm无法通过外网访问阿里云服务器中的Flask解决方案

一、修改/添加安全组端口这是第一种方案,也是能解决大部分问题的一个方案。由于我的服务器是阿里云的,所以在阿里云的ECS云服务器控制台中,管理安全组,添加5000和8000端口以便测试。经过测试,外网依旧无法访问。二、修…

Shell脚本之——Hadoop3单机版安装

目录 1.解压 2.文件重命名 3.配置环境变量 4.hadoop-env.sh 5.core-site.xml 6. hdfs-site.xml 7. mapred-site.xml 8.yarn-site.xml 9.完整脚本代码(注意修改主机名) 10.重启环境变量 11.初始化 12.启动服务 13.jps查询节点 1.解压 tar -zxf /opt/install/hadoo…

【速通版】吴恩达机器学习笔记Part3

目录 1.多元线性回归 a.特征缩放 可行的缩放方式: 1.除以最大值: 2.mean normalization: 3.Z-score normalization b.learning curve: c.learning rate: 2.多项式回归 3.classification logistics regression 1.多元线性回归 其意义很…

UML术语标准和分类

一、UML术语标准 1.中文UML术语标准 中国软件行业协会(CSIA)与日本UML建模推进协会(UMTP)共同在中国推动的UML专家认证,两个协会共同颁发认证证书、两国互认,CSIA与UMTP共同推出了UML中文术语…

(record)QEMU安装最小linux系统——TinyCore(命令行版)

文章目录QEMU安装最小linux系统——TinyCore参考QEMU使用qemu创建tinycore虚拟机再次启动文件保存QEMU安装最小linux系统——TinyCore 简单记录安装过程和记录点 参考 [原创] qemu 与 Tiny Core tinycore的探索 QEMU qemu不多介绍,这里是在WSL2上安装的linux版…

最近很火的一部电视(狂飙)像安欣和高启强这样类型的人,谁更合适做软件测试工程师

狂飙》央视收视率狂飙。央视发布《狂飙》收视成绩,全剧平均收视1.54%,平均收视份额6.99%,单集最高收视率2.20%,单集最高收视份额10.69%;晚间电视剧类节目第一。可以说还部剧为今年开了个好头,一开年就引爆收…

财报解读:四季度营收超预期,优步却越来越“不务正业”了

“公司第四季度的业绩表现将是强劲的”。 公布2022年第三季度财报时,优步的高管给出了这样的预告,给资本市场打了一针“强心剂”。然而有人对此表示质疑,后疫情时代,带着新模式、新车型的全新网约车公司层出不穷,车企…

Java面试数据库

目录 一、关系型数据库 数据库权限 表设计及创建 表数据相关 数据库架构优化 二、非关系型数据库 redis 今天给大家稍微整理了一下,内容有数据表设计的三大范式原则、sql查询如何优化、redis数据的击穿、穿透、雪崩等...,以及相关的面试题&#xff0…

Intel中断体系(1)中断与异常处理

文章目录概述中断与异常中断可屏蔽中断与不可屏蔽中断(NMI)异常异常分类中断与异常向量中断描述符表中断描述符中断与异常处理中断与异常处理过程堆栈切换错误码64位模式下的中断异常处理64位中断描述符64位处理器下的堆栈切换相关参考概述 中断是现代计…

不用创建项目,直接在 VS 里快速测试 C/C++ 代码

概述 Visual Studio 强大、方便,但是每次写代码都要先创建新项目,这对于一些简单的代码测试来说有点不方便。 本文介绍一种使用 VS 快速测试代码的方法。 该方法适用任何版本的 VS。“不用创建项目”,是指不用“手工”创建项目&#xff0c…

Python Scrapy 爬虫简单教程

1. Scrapy install 准备知识 pip 包管理Python 安装XpathCssWindows安装 Scrapy $>- pip install scrapy Linux安装 Scrapy $>- apt-get install python-scrapy 2. Scrapy 项目创建 在开始爬取之前,必须创建一个新的Scrapy项目。进入自定义的项目目录中&am…

爆火出圈的ChatGPT,真的那么好用吗?

近期,ChatGPT在互联网行业爆火! 这个由人工智能研究和部署公司OpenAI开发的“交互机器人”,在今年1月其全球月活跃用户已达1亿,成为史上用户增长速度最快的消费级应用。 爆火的ChatGPT到底是什么? ChatGPT是一个原型人…

Java程序的执行顺序、简述对线程池的理解

点个关注,必回关 文章目录一、Java程序是如何执行的二、合理利用线程池能够带来三个好处一、Java程序是如何执行的 我们日常的工作中都使用开发工具(IntelliJ IDEA 或 Eclipse 等)可以很方便的调试程序,或者是通 过打包工具把项目…

删除Node.js,安装nvm,看这一篇就够了(有坑)

nvm的作用就是可以任意切换Node.js的版本,所以在下载nvm之前,现将系统中的Node.js全部删除,若之前没有安装过,可忽略第一步。 删除Node.js 一、程序和功能处找到Node.js,并删除 二、删除Node.js相关的目录文件 C:\Program Fil…

Science:北京脑研究中心李莹实验室揭示性满足感的分子机制

短暂的社交经历(例如,性经历)可导致内部状态的长期变化并影响社会行为,如交配、攻击。例如,在成功交配射精后,许多物种迅速表现出对交配倾向的抑制有数小时、数天或更长时间,这种效应称为性满足…

【报复性赚钱】2023年5大风口行业

今天就来和大家分享一下,在时代的洪流下,普通人如何顺应大势抓住机遇! 实现人在风口上,猪都会飞起来。 根据对市场的观察及各平台数据分析结果,结合国家政策和经济专家的分析,小编预测了2023年将会迎来大…

“1+1>2”!《我要投资》与天际汽车再度“双向奔赴”!

文|螳螂观察 作者| 图霖 胡海泉老师重磅回归、创始人现场真情告白……新一季的《我要投资》,不仅维持了往季在专业度上的高水准,也贡献了不少高话题度的“出圈”时刻。 在竞争激烈的的综艺节目竞技场,能举办数季的节目,往往都是…

Linux修改文件时间或创建新文件:touch

每个文件在Linux下面都记录了许多的时间参数,其实是三个主要的变动时间 修改时间(modification time,mtime):当该文件的【内容数据】变更时,就会更新这个时间,内容数据是指文件的内容&#xff…

Zabbix 构建监控告警平台(一)--部署安装

监控对象监控收集信息方式Zabbix 部署 1.监控对象 源代码: *.html *.jsp *.php *.py 数据库: MySQL,MariaDB,Oracle,SQL Server,DB2 应用软件:Nginx,Apache,PHP,Tomcat agent 集群: LVS,Keepalived,HAproxy…

期望风险, 经验风险和结构风险

经验风险模型关于所有训练集上的平均损失称为经验风险或经验损失.公式如下:至此, 我们通过计算单点误差损失的平均值来衡量(刻画)模型对训练集拟合的好坏, 但是我们如何衡量模型对未知数据的拟合能力呢, 也就是如何衡量模型在全体数据集上的性能, 因此我们引入概率论中两个随机…