图神经网络实战(10)——归纳学习

news2024/11/25 13:16:16

图神经网络实战(10)——归纳学习

    • 0. 前言
    • 1. 转导学习与归纳学习
    • 2. 蛋白质相互作用数据集
    • 3. 构建 GraphSAGE 模型实现归纳学习
    • 小结
    • 系列链接

0. 前言

归纳学习 (Inductive learning) 通过基于已观测训练数据,建立一个通用模型,使模型能够对未见过的节点和图进行归纳预测,而转导学习(Transductive learning, 也称直推学习)是基于所有已经观测到的训练和测试数据构建模型,这种方法是通过已经有标记的节点信息来预测无标记数据节点,因此,在图神经网络 (Graph Neural Networks, GNN)、图卷积网络 (Graph Convolutional Network, GCN)、图注意力网络 (Graph Attention Networks,GAT) 和 GraphSAGE 等节中所构建的模型均属于转导学习模型。在本节中,我们将介绍图数据中的归纳学习和多标签分类,使用 GraphSAGE 模型在蛋白质相互作用 (protein-protein interactions) 数据集执行多标签分类任务,并了解归纳学习的优势和实现方法。

1. 转导学习与归纳学习

在图神经网络 (Graph Neural Networks, GNN)中,可以将学习分为两类,转导学习(Transductive learning, 也称直推学习)和归纳学习 (Inductive learning):

  • 在归纳学习中,GNN 在训练过程中只看到训练集中的数据,而在测试过程中需要对未见过的数据进行预测,这属于机器学习中典型的监督学习 (supervised learning)。在这种情况下,标签用来调整 GNN 的参数,模型需要具备良好的泛化能力,能够从有限的样本中推断出普遍适用的规律
  • 在转导学习中,GNN 在训练过程中会看到来自训练集和测试集的数据,它通过对已有的样本进行学习来进行预测和分类。模型只从训练集中学习数据,模型会尝试将已有的样本归类到已知的类别中,并根据这些样本的特征进行预测,标签用于信息扩散。转导学习不是直接从训练集中学习出一般性的规律,而是利用图数据间的相似性或连接性进行预测

我们在之前构建的图神经网络 (Graph Neural Networks, GNN) 和图卷积网络 (Graph Convolutional Network, GCN) 属于转导学习情况。而 GraphSAGE 模型可以在训练过程中使用整个图进行预测 (self(batch.x, batch.edge_index)),然后部分屏蔽这些预测,只使用训练数据计算损失并训练模型 (criterion(out[batch.train_mask], batch.y[batch.train_mask]))。
转导学习只能为固定的图生成嵌入,不能泛化到未见过的节点或图。但由于采用了邻居采样,GraphSAGE 可以在局部水平上对经过剪枝的计算图进行预测,这种情况下属于归纳学习框架,可以应用于具有相同特征模式的任何计算图。

2. 蛋白质相互作用数据集

在 GraphSAGE 一节中,我们已经在 PubMed 数据集上构建 GraphSAGE 模型实现了转导学习。接下来,我们将 GraphSAGE 应用于由 Agrawal 等人提出的蛋白质相互作用 (protein-protein interaction, PPI) 网络数据集。该数据集是 24 个图的集合,其中节点( 21,557 个)是人类蛋白质,边( 342,353 条)是人类细胞中蛋白质之间的连接。用 Gephi 制作的 PPI 图数据集可视化结果如下所示:

PPI 数据集

该数据集的目标是使用 121 个标签进行多标签分类,这意味着每个节点可以具有 0121 个标签。这不同于多类别分类,多类别分类中每个节点只会属于一个类别。接下来,我们使用 PyTorch Geometric (PyG) 实现 GraphSAGE 模型用于对 PPI 数据集执行多标签分类任务。

(1)PPI 数据集加载为三个不同的子集,训练集、验证集和测试集:

import torch
from sklearn.metrics import f1_score

from torch_geometric.datasets import PPI
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GraphSAGE

# Load training, evaluation, and test sets
train_dataset = PPI(root=".", split='train')
val_dataset = PPI(root=".", split='val')
test_dataset = PPI(root=".", split='test')

(2) 训练集包含 20 个图,而验证集和测试集只有两个图。对训练集应用邻居采样,为了方便起见,使用 Batch.from_data_list() 将所有训练图统一到一个集合中,然后应用邻居采样:

train_data = Batch.from_data_list(train_dataset)
train_loader = NeighborLoader(train_data, batch_size=2048, shuffle=True, num_neighbors=[20, 10], num_workers=2, persistent_workers=True)

(3) 训练集创建完毕后,使用 DataLoader 类创建批数据,将 batch_size 值定义为 2,与每批图的数量相对应:

val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

(4) 定义设备使批处理能够在 GPU 上进行处理。如果计算机中有 GPU,使用 GPU,否则就使用 CPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3. 构建 GraphSAGE 模型实现归纳学习

使用 PyTorch Geometrictorch_geometric.nn 模块构建 GraphSAGE 模型。

(1) 使用 GraphSAGE() 类初始化一个两层的 GraphSAGE 模型,其中隐藏维度为 512,此外,还需要使用 to(device) 将模型放置在与数据相同的设备上:

model = GraphSAGE(
    in_channels=train_dataset.num_features,
    hidden_channels=512,
    num_layers=2,
    out_channels=train_dataset.num_classes,
).to(device)

(2) fit() 函数与 GraphSAGE 一节中使用的函数类似,不同之处在于,我们希望尽可能将数据移动到 GPU 上,并且由于每批数据有两个图,因此将损失乘以 2 (data.num_graphs):

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def fit(loader):
    model.train()

    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        total_loss += loss.item() * data.num_graphs
        loss.backward()
        optimizer.step()
    return total_loss / len(loader.data)

由于 val_loadertest_loader 包含两个图且 batch_size 值为 2,因此在 test() 函数中,两个图位于同一个批数据中,而无需像训练时那样在加载器中循环。

(3) 使用度量指标 F1 分数代替准确率,F1 分数相当于精确度和召回率的调和平均值。但,模型的预测结果是 121 维的实数向量,我们需要将其转换成二进制向量,使用 out > 0 将它们与 data.y 进行比较:

@torch.no_grad()
def test(loader):
    model.eval()

    data = next(iter(loader))
    out = model(data.x.to(device), data.edge_index.to(device))
    preds = (out > 0).float().cpu()

    y, pred = data.y.numpy(), preds.numpy()
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0

(4) 对模型进行 300epoch 的训练,并打印训练过程中模型在验证数据集上的 F1 分数:

for epoch in range(301):
    loss = fit(train_loader)
    val_f1 = test(val_loader)
    if epoch % 50 == 0:
        print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')
'''
Epoch   0 | Train Loss: 12.686 | Val F1-score: 0.4866
Epoch  50 | Train Loss: 8.734 | Val F1-score: 0.7963
Epoch 100 | Train Loss: 8.600 | Val F1-score: 0.8098
Epoch 150 | Train Loss: 8.531 | Val F1-score: 0.8202
Epoch 200 | Train Loss: 8.495 | Val F1-score: 0.8230
Epoch 250 | Train Loss: 8.497 | Val F1-score: 0.8255
Epoch 300 | Train Loss: 8.432 | Val F1-score: 0.8290
'''

(5) 最后,计算测试集上的 F1 分数:

print(f'Test F1-score: {test(test_loader):.4f}')

# Test F1-score: 0.8527

可以看到,在归纳学习中,模型在 PPI 数据集上训练后的 F1 分数为 0.9360。当增加或减少隐藏维度的大小时,模型的性能会有有大幅改变,我们可以使用不同的值,如 1281,024,并观察训练的后的模型 F1 分数变化。
需要注意的是,在以上代码中,我们并未显式的使用掩码。这是由于实际上,归纳学习是由 PPI 数据集强制实现的;训练数据、验证数据和测试数据位于不同的图和数据加载器中。我们也可以使用 Batch.from_data_list() 将它们合并,然后再使用归纳学习的设定。

小结

在本节中,学习了图神经网络中转导学习(Transductive learning, 也称直推学习)和归纳学习 (Inductive learning) 的区别。其中,图神经网络中的归纳学习通常指的是从给定的训练图数据中学习出一个泛化能力强的模型,以便对未知图数据中的节点或边进行预测或分类,而转导学习通常指的是利用训练图数据和测试图数据之间的关联性进行推断,从而对给定的测试图数据进行预测或分类。并且构建了 GraphSAGE 模型在 PPI 数据集上测试了归纳学习,以执行多标签分类任务。

系列链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现
图神经网络实战(8)——图注意力网络(Graph Attention Networks, GAT)
图神经网络实战(9)——GraphSAGE详解与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1669586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

冯喜运:5.13黄金多头反扑欲“染指”2400,今日原油走势分析

【黄金消息面分析】:周一(5月13)亚市,现货黄金窄幅震荡,目前交投于2362.00美元/盎司附近。金价上周五攀升0.6%,收报2360.75美元/盎司,录得五周来最佳单周表现,因近期美国就业数据疲弱…

基于FPGA的数字信号处理(12)--定点数的舍入模式(3)收敛取整convergent

前言 在之前的文章介绍了定点数为什么需要舍入和几种常见的舍入模式。今天我们再来看看另外一种舍入模式:收敛取整convergent。 10进制数的convergent convergent: 收敛取整。它的舍入方式和四舍五入非常类似,都是舍入到最近的整数&#x…

高校推免报名|基于SSM+vue的高校推免报名系统的设计与实现(源码+数据库+文档)

高校推免报名 目录 基于SSM+vue的高校推免报名的设计与实现 一、前言 二、系统设计 三、系统功能设计 1系统功能模块 2后台登录模块 5.2.1管理员功能模块 5.2.2考生功能模版 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八…

Cross-Image Attention for Zero-Shot Appearance Transfer——【代码复现】

本文发表于SIGGRAPH 2024,是一篇关于图像编辑的论文,Github官网网址如下: garibida/cross-image-attention: “Cross-Image Attention for Zero-Shot Appearance Transfer”的正式实现 (github.com) 一、基本配置环境准备 请确保…

国产之光:SmartEDA电路仿真软件何以超越传统,引领新潮流?

在当今电子工程领域,电路仿真软件的重要性不言而喻。它不仅是工程师们进行电路设计、分析和优化的得力助手,也是学生们深入理解电路原理、提高实践操作能力的关键工具。近年来,一款名为SmartEDA的国产电路仿真软件逐渐崭露头角,以…

Python 全栈系列244 nginx upstream 负载均衡 踩坑日记

说明 最初是因为租用算力机(Python 全栈系列242 踩坑记录:租用算力机完成任务),所以想着做一个负载均衡,然后多开一些服务,把配置写在nginx里面就好了。 一开始租用了一个3080起了一个服务,后来觉得速度不够快,再起了…

el-menu 保持展开点击不收缩 默认选择第一个菜单

<el-menu:default-openeds"[/system]" 数组 默认展开第一个:collapse"isCollapse"close"handleClose" 点击关闭的时候 让菜单打开 就可以实现保持展开效果ref"menus":unique-opened"true":active-text-color"se…

笔记-跨域方式实现原理

websocket Websocket是HTML5的一个持久化的协议&#xff0c;它实现了浏览器与服务器的全双工通信&#xff0c;同时也是跨域的一种解决方案。WebSocket和HTTP都是应用层协议&#xff0c;都基于 TCP 协议。但是 WebSocket 是一种双向通信协议&#xff0c;在建立连接之后&#xff…

振弦采集仪在岩土工程中的实时监测和预警作用

振弦采集仪在岩土工程中的实时监测和预警作用 河北稳控科技振弦采集仪被广泛应用于岩土工程中的实时监测和预警。它通过对地下振弦信号的连续监测和分析&#xff0c;能够提供准确的地下结构变形和应力变化信息&#xff0c;为岩土工程的安全和稳定提供重要的支持。 振弦采集仪主…

python爬虫(四)之九章智算汽车文章爬虫

python爬虫&#xff08;四&#xff09;之九章智算汽车文章爬虫 闲来没事就写一条爬虫抓取网页上的数据&#xff0c;现在数据已经抓完&#xff0c;将九章智算汽车文章的爬虫代码分享出来。当前代码采用python编写&#xff0c;可抓取所有文章&#xff0c;攻大家参考。 import r…

宝塔安装多个版本的PHP,如何设置默认的PHP版本

如何将默认的PHP版本设置为7.3.32&#xff0c; 创建软链接指向7.3版本&#xff0c;关键命令&#xff1a;ln -sf /www/server/php/73/bin/php /usr/bin/php 然后再查看PHP版本验证一下结果 [rootlocalhost ~]# ln -sf /www/server/php/73/bin/php /usr/bin/php [rootlocalho…

共享充电宝语音芯片ic方案支持远程4g无线更新语音

一、简介 共享充电宝语音芯片ic方案支持远程4g无线wifi蓝牙更新语音 共享充电宝已经是遍布在大街小巷的好产品&#xff0c;解决了携带充电宝麻烦的痛点 但是很多的共享充电宝在人机交互方便&#xff0c;还做得不够好&#xff0c;比如&#xff1a;借、还设备没有语音提示&…

开散列哈希桶

通过上面这幅图&#xff0c;读者应该能较为直观地理解何为开散列&#xff0c;以及闭散列与开散列的区别在哪里 —— 数据的存储形式不同&#xff0c;至于其他的&#xff0c;如确定每个元素的哈希地址等一概相同。 与闭散列相比&#xff0c;开散列能够更好地处理发生冲突的元素 …

知识付费行业数字化转型:转的是什么?你知道吗!

在知识付费的浪潮中&#xff0c;数字化转型正悄然改变着这个行业的格局&#xff01;那么&#xff0c;知识付费行业数字化转型到底转的是什么呢&#xff1f;这是一个值得我们深入探讨的问题。 1.转的是商业模式&#xff1a;从传统的销售模式转向多元化的盈利模式。从简单的买卖关…

数据结构(二) 线性表

2024年5月13日一稿 线性表的定义与基本操作 数据类型相同(各个元素占用空间相同) 是有限序列 基操

Netty源码分析二NioEventLoop 剖析

剖析方向 NioEventLoop是一个重量级的类&#xff0c;其中涉及到的方法都有很复杂的继承关系&#xff0c;调用链&#xff0c;要想把源码全部过一遍工作量实在是太大了&#xff0c;于是小编就基于下面的这些常见的问题来对NioEventLoop的源码来进行剖析 1.Seletor何时创建 1.1Se…

前端Vue架构

1 理解&#xff1a; 创建视图的函数&#xff08;render&#xff09;和数据之间的关联&#xff1b; 当数据发生变化的时候&#xff0c;希望render重新执行&#xff1b; 监听数据的读取和修改&#xff1b; defineProperty&#xff1a;监听范围比较窄&#xff0c;只能通过属性描…

基于SSM的计算机课程实验管理系统的设计与实现(源码)

| 博主介绍&#xff1a;✌程序员徐师兄、8年大厂程序员经历。全网粉丝15w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f44…

架构每日一学 5:拼多多如何通过洞察人性脱颖而出?

本文首发于公众平台&#xff1a;腐烂的橘子 上一篇文章&#xff0c;我们讲到架构活动一定要顺应人性&#xff0c;今天我们就来聊一聊&#xff0c;拼多多如何通过洞察人性在电商行业脱颖而出。 拼多多从诞生到现在&#xff0c;可以说是颠覆了整个互联网的认知。 2015 年&#…

JSON 转为json串后出现 “$ref“

问题描述 转为JSON 串时出现 "$ref":"$.RequestParam.list[0]" $ref&#xff1a; fastjson数据重复的部分会用引用代替&#xff0c;当一个对象包含另一个对象时&#xff0c;fastjson就会把该对象解析成引用 “$ref”:”..” 上一级 “$ref”:”” 当前对…