图神经网络:处理复杂关系结构与图分类任务的强大工具

news2024/11/20 18:29:47

创作不易,您的打赏、关注、点赞、收藏和转发是我坚持下去的动力!图神经网络

图神经网络(Graph Neural Network, GNN)是针对图数据的一类神经网络模型。图数据具有节点(节点代表实体)和边(边代表节点之间的关系),因此,GNN能够处理这种复杂的关系结构,提取图结构中有用的信息。GNN的基本思想是通过消息传递(message passing)机制将节点和它们的邻居进行特征融合,从而更新节点的表示。这种表示可以用来进行节点分类、边预测或者整个图的分类等任务。

1. GNN基础知识

GNN的核心机制是基于图的消息传递和特征聚合。对于每个节点,GNN会收集其邻居节点的信息,然后通过一定的聚合函数(例如求和或平均)生成新的特征表示。

1.1 图的定义
  • 节点(Node):图中的实体,记作 (v_i)。
  • 边(Edge):节点之间的关系,记作 (e_{ij}),表示从节点 (v_i) 到节点 (v_j) 的连接。
  • 邻居节点(Neighbors):节点 (v_i) 的直接相连节点集合,记作 (N(v_i))。
1.2 GNN的消息传递机制

GNN的基本操作包括两个步骤:

  1. 消息传递(Message Passing):从每个节点的邻居节点收集特征。
  2. 特征更新(Feature Update):将节点的特征与邻居的特征聚合,更新节点的表示。

假设节点 (v_i) 的初始特征为 (h_i^{(0)}),其第 (k) 次迭代时的特征表示为 (h_i^{(k)})。GNN通过以下两步进行更新:

  • 聚合邻居特征:将节点 (v_i) 的所有邻居节点的特征聚合起来,例如求和或平均:
    [
    m_i^{(k)} = \text{AGGREGATE}({ h_j^{(k-1)} : j \in N(v_i) })
    ]
  • 更新节点特征:将聚合的邻居特征与节点本身的特征结合起来,更新节点的表示:
    [
    h_i^{(k)} = \text{UPDATE}(h_i^{(k-1)}, m_i^{(k)})
    ]
1.3 GNN在图分类任务中的应用

图分类任务的目标是给定一张图,预测该图的类别。常见应用包括化学分子分类、社交网络分析等。在这种任务中,GNN的目标是通过学习图的全局结构信息来预测整张图的标签。

GNN处理图分类任务的流程一般如下:

  1. 特征初始化:给每个节点赋予初始特征(可以是节点的属性)。
  2. 消息传递与特征更新:通过多层GNN层,将节点特征与其邻居进行聚合和更新。
  3. 图的汇总(Readout):将所有节点的特征汇总为图的表示(例如通过求平均或全连接层)。
  4. 分类器:使用图的表示作为输入,通过一个分类器预测图的类别。

2. Python实现示例

我们可以使用PyTorch Geometric来实现一个简单的图分类任务。

2.1 安装依赖

首先,你需要安装PyTorchPyTorch Geometric库:

pip install torch
pip install torch-geometric
2.2 数据准备

我们使用PyTorch Geometric中的一个经典的图分类数据集MUTAG,这是一个小型化学分子数据集,每个分子作为一张图,目标是预测分子的类别。

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# 加载数据集
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

# 划分训练集和测试集
train_dataset = dataset[:150]
test_dataset = dataset[150:]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
2.3 定义GNN模型

我们定义一个简单的图卷积网络(GCN)用于图分类任务。

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        # 定义两个GCN层
        self.conv1 = GCNConv(dataset.num_node_features, 64)
        self.conv2 = GCNConv(64, 64)
        # 最后一个全连接层用于图分类
        self.fc = torch.nn.Linear(64, dataset.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # 第一层GCN + ReLU激活
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        
        # 第二层GCN
        x = self.conv2(x, edge_index)
        
        # 使用全局平均池化将节点特征聚合为图的特征
        x = global_mean_pool(x, batch)
        
        # 最后通过全连接层进行分类
        x = self.fc(x)
        
        return F.log_softmax(x, dim=1)
2.4 模型训练和测试

我们定义训练和测试的函数,分别用于训练模型和评估模型的性能。

# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        output = model(data)
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

# 训练模型
for epoch in range(1, 201):
    loss = train()
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')
2.5 解释代码
  • GCNConv:图卷积层,用于将节点的特征与其邻居的特征进行聚合。
  • global_mean_pool:对图中的所有节点特征进行全局池化,将节点特征汇总为图的特征表示。
  • forward:定义了模型的前向传播,输入图的特征和结构,输出图的类别预测。

通过上述代码,你可以用GNN进行图分类任务。这个模型会对每张图中的所有节点进行特征更新,并最终通过全连接层进行分类。

大家有技术交流指导、论文及技术文档写作指导、课程知识点讲解、项目开发合作的需求可以搜索关注我私信我

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2183362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode[中等] 55.跳跃游戏

给你一个非负整数数组 nums ,你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标,如果可以,返回 true ;否则,返回 false 。 思路 贪心算法 可达位置…

html5 + css3(下)

目录 CSS基础体验cssCSS引入方式选择器选择器-标签选择器-类选择器-id选择器-通配符 文字基本样式1.1 字体大小1.2 字体粗细1.3 字体样式(是否倾斜) 文字-字体1.4 常见字体系列(了解)1.5 字体系列 拓展-层叠性font复合属性文本缩进…

erlang学习:Linux命令学习8

shell脚本案例学习 循环求 1-100 的每一步和 —案例 j0 i1 while((i<100)) do j$((ji)) echo $j ((i)) done每 30 s循环判断一次 user 用户是否登录系统 —案例 设置了一个次数&#xff0c;如果循环了五次在user文件中添加user用户&#xff0c;表示用户登录 USERS"u…

嵌入式 ADC基础知识

在现实世界中&#xff0c;常见的信号大都是模拟量&#xff0c;像温度、声音、气压等&#xff0c;但在信号的处理与传输中&#xff0c;为了减少噪声的干扰&#xff0c;较多使用的是数字量。因此我们经常会将现实中的模拟信号&#xff0c;通过 ADC 转换为数字信号进行运算、传输、…

Java | Leetcode Java题解之第442题数组中重复的数据

题目&#xff1a; 题解&#xff1a; class Solution {public List<Integer> findDuplicates(int[] nums) {int n nums.length;List<Integer> ans new ArrayList<Integer>();for (int i 0; i < n; i) {int x Math.abs(nums[i]);if (nums[x - 1] > …

端到端如火如荼, 传统规划控制还有前途吗?

近些年自动驾驶领域一定绕不开端到端, 伴随着各大车企纷纷转向拥抱端到端, 传统PnC的处境似乎愈发尴尬了起来. 但是端到端真的如水中月镜中花般美好吗? 不可否认深度学习给诸多领域带来了天翻地覆的变化, 但是自动驾驶直接关系到交通安全. 自动驾驶系统的输出, 必须具备足够的…

YOLO11改进|注意力机制篇|引入MLCA轻量级注意力机制

目录 一、MLCA注意力机制1.1MLCA注意力介绍1.2MLCA核心代码 五、添加MLCA注意力机制5.1STEP15.2STEP25.3STEP35.4STEP4 六、yaml文件与运行6.1yaml文件6.2运行成功截图 一、MLCA注意力机制 1.1MLCA注意力介绍 MLCA&#xff08;Multi-Level Channel Attention&#xff0c;多级通…

简单的微信小程序登录 注册 页面及逻辑

一、示例 二、示例代码 1.wxml <!--pages/login.wxml--> <!-- 登录注册文字 --> <view class"title">{{TitleText}}</view> <!-- 登录框 --> <view class"inputBox"><input type"text" placeholder&qu…

Nature Machine Intelligence 基于强化学习的扑翼无人机机翼应变飞行控制

尽管无人机技术发展迅速&#xff0c;但复制生物飞行的动态控制和风力感应能力&#xff0c;仍然遥不可及。生物学研究表明&#xff0c;昆虫翅膀上有机械感受器&#xff0c;即钟形感受器campaniform sensilla&#xff0c;探测飞行敏捷性至关重要的复杂气动载荷。 近日&#xff0…

国庆普及模拟赛-1 赛后总结

题目链接&#xff1a; file:///D:/C/%E9%9B%86%E8%AE%AD%E6%B5%8B%E8%AF%95/1001/2022%20-%20J2.pdf T1&#xff1a;隔离 题意如图。需要求所有时间的最短。 思路&#xff1a; 不需要进行一次次枚举&#xff0c;先算出总共要办事的总时间sum&#xff0c;如果某一次时间超过2…

Mysql数据库~~条件查询、分页查询、修改操作

目录 1.表的其他操作 1.1创建一个表 1.2对于表的排序 1.3修改某一列的名字 1.4使用表达式 1.5删除列的重复项 1.6多个列进行排序 2.条件查询 2.1条件查询语句 2.2比较运算符 2.3条件查询展示 2.4条件查询的先后问题 2.5逻辑运算符使用 2.6模糊查询匹配 2.7对于nu…

【2022工业3D异常检测文献】BTF: 结合手工制作的3D描述和颜色特征的异常检测方法

BACK TO THE FEATURE: CLASSICAL 3D FEATURES ARE (ALMOST) ALL YOU NEED FOR 3D ANOMALY DETECTION 1、Background BTF(Back to the Feature)&#xff0c;一种 结合手工制作的3D表示&#xff08;FPFH&#xff09;和基于深度颜色特征提取&#xff08;PatchCore&#xff09; 的…

关于未知物检测设备和方法(测未知物成分含量)

未知物检测是一项涉及多个学科和技术的复杂工作&#xff0c;它对于新材料的研究、开发、生产以及质量控制具有重要意义。以下是一些常用的未知物检测方法和设备&#xff1a; 光谱分析&#xff1a;包括红外光谱&#xff08;IR&#xff09;、核磁共振&#xff08;NMR&#xff09;…

【Android 13源码分析】Activity生命周期之onCreate,onStart,onResume-2

忽然有一天&#xff0c;我想要做一件事&#xff1a;去代码中去验证那些曾经被“灌输”的理论。                                                                                  – 服装…

无源码实现免登录功能

因项目要求需要对一个没有源代码的老旧系统实现免登录功能&#xff0c;系统采用前后端分离的方式部署&#xff0c;登录时前端调用后台的认证接口&#xff0c;认证接口返回token信息&#xff0c;然后将token以json的方式存储到cookie中&#xff0c;格式如下&#xff1a; 这里有…

10月1日星期二今日早报简报微语报早读

10月1日星期二&#xff0c;国庆节&#x1f1e8;&#x1f1f3;&#xff0c;农历八月廿九&#xff0c;早报#微语早读。 1、A股暴涨刷新多项历史纪录&#xff1a;两市成交总额近2.6万亿元&#xff0c;创指涨逾15%&#xff1b; 2、文旅部&#xff1a;常年不超过最高承载量的旅游景…

Docker 安装 Citus 单节点集群:全面指南与详细操作

Docker 安装 Citus 单节点集群&#xff1a;全面指南与详细操作 文章目录 Docker 安装 Citus 单节点集群&#xff1a;全面指南与详细操作一 服务器资源二 部署图三 安装部署1 创建网络2 运行脚本1&#xff09;docker-compose.cituscd1.yml2&#xff09;docker-compose.cituswk1.…

zi2zi-chain: 中国书法字体图片生成和字体制作的一站式开发

在zi2zi-pytorch的基础上&#xff0c;做了进一步的修复和完善。本项目github对应网址为https://github.com/not-bald-owl/zi2zi-chain/tree/master。 修复部分为&#xff1a;针对预处理部分的函数弃用、生僻字无法生成、训练和推理部分单卡支持改为多卡并行、以及扩展从本地的…

过去8年,编程语言的流行度发生了哪些变化?PHP下降,Objective-C已过时

前天有一个汇总9个不同排名数据的“地表最强”编程语言排行榜&#xff0c;为了更好地理解语言流行度的变化&#xff0c;作者将2016年的类似调查结果与2024年的数据进行了比较。 虽然2016年的调查只包含6个排名&#xff0c;但它仍然提供了宝贵的参考数据。 我们来看看详细的情…

C++之String类(下)

片头 嗨喽~ 我们又见面啦&#xff0c;在上一篇C之String类&#xff08;上&#xff09;中&#xff0c;我们对string类的函数有了一个初步的认识&#xff0c;这一篇中&#xff0c;我们将继续学习string类的相关知识。准备好了吗&#xff1f;咱们开始咯~ 二、标准库中的string类 …