图神经网络
- 1、引言
- 2、图神经网络
- 2.1 定义
- 2.2 原理
- 2.3 实现方式
- 2.4 算法公式
- 2.4.1 GNN
- 2.4.2 GCN
- 2.5 代码示例
- 3、总结
1、引言
小屌丝:鱼哥,给俺讲一讲图神经网络啊
小鱼:你看,我这会在忙着呢
小屌丝:啊~
小鱼:这是咋的了,
小屌丝:你咋还有这技术?
小鱼:这… 不是很平常的操作,有啥惊讶的。
小屌丝:哇哦~ 难得哦
小鱼:你这…
小屌丝:看来今晚是有贵客到哦?
小鱼:也没有了, 嘿嘿~
小屌丝: 66号技师??
小鱼:你可真能扯,我是那种人吗,我能做那种事情吗?
小屌丝:那你说你这要干嘛?
小鱼:我… 我就要烧个菜,你真是能联想翩翩
小屌丝:我…
2、图神经网络
2.1 定义
图神经网络(GNN)是一种处理图结构数据的神经网络。
与传统的神经网络不同,GNN能够直接在图结构上进行操作,捕捉节点之间的复杂关系。
这种能力让GNN成为处理社交网络分析、知识图谱、推荐系统等问题的强有力工具。
2.2 原理
GNN的核心原理基于邻居聚合策略,即:通过迭代地聚合邻居节点的信息来更新当前节点的表示。
在每次迭代中,节点会接收来自其邻居的信息,并通过一个可学习的函数(通常是神经网络)来整合这些信息,从而更新自己的状态。
这个过程会重复进行,直到达到一个稳定的状态,最终得到每个节点的高级表示,这些表示可以用于后续的任务,如节点分类、图分类等
2.3 实现方式
GNN的实现通常包括以下几个关键步骤:
- 节点表示初始化:为图中的每个节点分配初始表示(如节点特征或嵌入)。
- 邻居信息聚合:对于每个节点,从其邻居节点收集信息,并通过聚合函数(如平均、求和、最大值)将这些信息整合起来。
- 节点状态更新:结合节点当前的状态和聚合得到的邻居信息,通过一个更新函数(如全连接层)来更新节点的状态。
- 读出:对于图级别的任务,需要通过一个读出(readout)函数将所有节点的表示整合成图的总体表示。
2.4 算法公式
2.4.1 GNN
一个基本的GNN更新公式可以表示为: [ h v ( l + 1 ) = f ( h v ( l ) , □ u ∈ N ( v ) g ( h u ( l ) ) ) ] [h_v^{(l+1)} = f\left(h_v^{(l)}, \square_{u \in \mathcal{N}(v)} g\left(h_u^{(l)}\right)\right)] [hv(l+1)=f(hv(l),□u∈N(v)g(hu(l)))]
其中,
- ( h v ( l ) ) (h_v^{(l)}) (hv(l))表示节点 ( v ) (v) (v)在第 ( l ) (l) (l)层的表示,
- ( N ( v ) ) (\mathcal{N}(v)) (N(v))是 ( v ) (v) (v)的邻居节点集合,
- ( f ) (f) (f)和 ( g ) (g) (g)分别是更新函数和邻居信息聚合函数,
- ( □ ) (\square) (□)是聚合操作(如求和、平均或最大值)。
2.4.2 GCN
对于具体的GNN变体,如GCN(图卷积网络),其公式会有所不同。以GCN为例,其每一层的更新可以表示为:
[ H ( l ) = σ ( D − 1 2 A D − 1 2 H ( l − 1 ) W ( l ) ) ] [ H^{(l)} = \sigma\left(D^{-\frac{1}{2}}AD^{-\frac{1}{2}}H^{(l-1)}W^{(l)}\right) ] [H(l)=σ(D−21AD−21H(l−1)W(l))]
其中:
- ( H ( l ) ) (H^{(l)}) (H(l)) 是一个矩阵,其行表示第 ( l ) (l) (l) 层中所有节点的特征向量。
- ( A ) (A) (A) 是图的邻接矩阵。
- ( D ) (D) (D) 是度矩阵,其对角线上的元素是每个节点的度(即相邻节点的数量)。
- ( W ( l ) ) (W^{(l)}) (W(l)) 是第 ( l ) (l) (l) 层的可学习权重矩阵。
- ( σ ( ⋅ ) ) (\sigma(\cdot)) (σ(⋅)) 是激活函数,如 R e L U ReLU ReLU。
这个公式体现了GCN中的两个关键步骤:
- 邻居信息的聚合(通过 ( D − 1 2 A D − 1 2 H ( l − 1 ) ) (D^{-\frac{1}{2}}AD^{-\frac{1}{2}}H^{(l-1)}) (D−21AD−21H(l−1)) 实现)和线性变换(通过 ( W ( l ) ) (W^{(l)}) (W(l)) 实现)。
- 通过这种方式,GCN能够捕捉图的结构信息并学习节点的有效表示。
2.5 代码示例
# -*- coding:utf-8 -*-
# @Time : 2024-04-02
# @Author : Carl_DJ
'''
实现功能:
使用PyTorch框架和PyTorch Geometric库实现GNN
'''
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.optim as optim
import torch.nn.functional as F
class ComplexGNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(ComplexGNN, self).__init__()
# 第一个图卷积层,将输入特征转换为隐藏层特征
self.conv1 = GCNConv(in_channels, hidden_channels)
# 第二个图卷积层,将隐藏层特征转换为输出特征
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 输入特征通过第一个卷积层,激活函数为ReLU
x = F.relu(self.conv1(x, edge_index))
# 加入dropout,防止过拟合
x = F.dropout(x, training=self.training)
# 通过第二个卷积层得到输出特征
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# 加载数据集,这里使用Planetoid数据集作为示例,Cora是其中一个公开的图数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
# 初始化模型,指定输入特征维度、隐藏层维度和输出特征维度
model = ComplexGNN(in_channels=dataset.num_node_features, hidden_channels=16, out_channels=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 定义训练函数
def train():
model.train()
optimizer.zero_grad()
# forward pass
out = model(dataset.data.x, dataset.data.edge_index)
# 计算损失,这里使用负对数似然损失
loss = F.nll_loss(out[dataset.data.train_mask], dataset.data.y[dataset.data.train_mask])
# 反向传播
loss.backward()
optimizer.step()
return loss
# 训练模型
for epoch in range(200):
loss = train()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
代码解析:
初始化,
- 定义了一个名为ComplexGNN的图神经网络模型,它包含两个图卷积层(GCNConv)。
- 模型的输入是节点的特征向量,输出是节点类别的预测。
- 使用ReLU作为激活函数,并在第一个图卷积层后加入了dropout层以减少过拟合。
在数据处理,
- 使用了PyTorch Geometric提供的Planetoid数据集加载工具来加载Cora数据集。
Cora数据集是一个常用的图节点分类数据集,其中节点代表科学出版物,边代表引用关系。
在训练过程,
- 使用负对数似然损失(Negative Log Likelihood Loss)作为损失函数,
- 使用Adam优化器来优化模型参数。
- 训练循环中,对模型进行前向传播,计算损失,执行反向传播并更新模型参数。
3、总结
图神经网络通过其独特的结构处理图数据的能力,在多个领域显示出了卓越的性能,从社交网络分析到分子结构识别。
随着研究的深入和技术的进步,GNN将继续扩展其应用领域,为解决复杂的图结构问题提供有效的工具。
我是小鱼:
- CSDN 博客专家;
- 阿里云 专家博主;
- 51CTO博客专家;
- 企业认证金牌面试官;
- 多个名企认证&特邀讲师等;
- 名企签约职场面试培训、职场规划师;
- 多个国内主流技术社区的认证专家博主;
- 多款主流产品(阿里云等)测评一、二等奖获得者;
关注小鱼,学习【机器学习】&【深度学习】领域的知识。