创作不易,您的打赏、关注、点赞、收藏和转发是我坚持下去的动力!
图神经网络(Graph Neural Network, GNN)是针对图数据的一类神经网络模型。图数据具有节点(节点代表实体)和边(边代表节点之间的关系),因此,GNN能够处理这种复杂的关系结构,提取图结构中有用的信息。GNN的基本思想是通过消息传递(message passing)机制将节点和它们的邻居进行特征融合,从而更新节点的表示。这种表示可以用来进行节点分类、边预测或者整个图的分类等任务。
1. GNN基础知识
GNN的核心机制是基于图的消息传递和特征聚合。对于每个节点,GNN会收集其邻居节点的信息,然后通过一定的聚合函数(例如求和或平均)生成新的特征表示。
1.1 图的定义
- 节点(Node):图中的实体,记作 (v_i)。
- 边(Edge):节点之间的关系,记作 (e_{ij}),表示从节点 (v_i) 到节点 (v_j) 的连接。
- 邻居节点(Neighbors):节点 (v_i) 的直接相连节点集合,记作 (N(v_i))。
1.2 GNN的消息传递机制
GNN的基本操作包括两个步骤:
- 消息传递(Message Passing):从每个节点的邻居节点收集特征。
- 特征更新(Feature Update):将节点的特征与邻居的特征聚合,更新节点的表示。
假设节点 (v_i) 的初始特征为 (h_i^{(0)}),其第 (k) 次迭代时的特征表示为 (h_i^{(k)})。GNN通过以下两步进行更新:
- 聚合邻居特征:将节点 (v_i) 的所有邻居节点的特征聚合起来,例如求和或平均:
[
m_i^{(k)} = \text{AGGREGATE}({ h_j^{(k-1)} : j \in N(v_i) })
] - 更新节点特征:将聚合的邻居特征与节点本身的特征结合起来,更新节点的表示:
[
h_i^{(k)} = \text{UPDATE}(h_i^{(k-1)}, m_i^{(k)})
]
1.3 GNN在图分类任务中的应用
图分类任务的目标是给定一张图,预测该图的类别。常见应用包括化学分子分类、社交网络分析等。在这种任务中,GNN的目标是通过学习图的全局结构信息来预测整张图的标签。
GNN处理图分类任务的流程一般如下:
- 特征初始化:给每个节点赋予初始特征(可以是节点的属性)。
- 消息传递与特征更新:通过多层GNN层,将节点特征与其邻居进行聚合和更新。
- 图的汇总(Readout):将所有节点的特征汇总为图的表示(例如通过求平均或全连接层)。
- 分类器:使用图的表示作为输入,通过一个分类器预测图的类别。
2. Python实现示例
我们可以使用PyTorch Geometric
来实现一个简单的图分类任务。
2.1 安装依赖
首先,你需要安装PyTorch
和PyTorch Geometric
库:
pip install torch
pip install torch-geometric
2.2 数据准备
我们使用PyTorch Geometric
中的一个经典的图分类数据集MUTAG
,这是一个小型化学分子数据集,每个分子作为一张图,目标是预测分子的类别。
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
# 加载数据集
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
# 划分训练集和测试集
train_dataset = dataset[:150]
test_dataset = dataset[150:]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
2.3 定义GNN模型
我们定义一个简单的图卷积网络(GCN)用于图分类任务。
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
# 定义两个GCN层
self.conv1 = GCNConv(dataset.num_node_features, 64)
self.conv2 = GCNConv(64, 64)
# 最后一个全连接层用于图分类
self.fc = torch.nn.Linear(64, dataset.num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
# 第一层GCN + ReLU激活
x = self.conv1(x, edge_index)
x = F.relu(x)
# 第二层GCN
x = self.conv2(x, edge_index)
# 使用全局平均池化将节点特征聚合为图的特征
x = global_mean_pool(x, batch)
# 最后通过全连接层进行分类
x = self.fc(x)
return F.log_softmax(x, dim=1)
2.4 模型训练和测试
我们定义训练和测试的函数,分别用于训练模型和评估模型的性能。
# 定义设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
model.train()
total_loss = 0
for data in train_loader:
data = data.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, data.y)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
def test(loader):
model.eval()
correct = 0
for data in loader:
data = data.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(data.y).sum().item()
return correct / len(loader.dataset)
# 训练模型
for epoch in range(1, 201):
loss = train()
test_acc = test(test_loader)
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')
2.5 解释代码
- GCNConv:图卷积层,用于将节点的特征与其邻居的特征进行聚合。
- global_mean_pool:对图中的所有节点特征进行全局池化,将节点特征汇总为图的特征表示。
- forward:定义了模型的前向传播,输入图的特征和结构,输出图的类别预测。
通过上述代码,你可以用GNN进行图分类任务。这个模型会对每张图中的所有节点进行特征更新,并最终通过全连接层进行分类。