GNN:graph neural network 图神经网络,是⼀种连接模型,通过⽹络中节点之间的信息传递(message passing)的⽅式来获取图中的依存关系(dependence of graph),GNN通过从节点任意深度的邻居来更新该节点状态,这个状态能够表示状态信息。由于 GNN 在图节点之间强大的建模功能,使得与图分析相关的研究领域取得了突破。图神经网络(GNN)是一类基于深度学习的处理图域信息的方法。由于其较好的性能和可解释性,现已被广泛应用到各个领域。涵盖了推荐系统、组合优化、计算机视觉、物理 / 化学以及药物发现等领域。
一、数据集介绍
数据集中只有一张图。
该图描述了一个空手道俱乐部会员的社交关系,以34名会员作为节点,如果两位会员在俱乐部之外仍保持社交关系,则在节点间增加一条边。
每个节点具有一个34维的特征向量,一共有78条边。在收集数据的过程中,管理人员 John A 和 教练 Mr. Hi(化名)之间产生了冲突,会员们选择了站队,一半会员跟随 Mr. Hi 成立了新俱乐部,剩下一半会员找了新教练或退出了俱乐部。通过收集到的图数据,Zachary 进行了分类,除1名会员外都分类正确。将原图进行抽象可得到下图:
二、GNN实战
1. 导入所需的包
%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
# KarateClub是torch_geometric内置的数据集
from torch_geometric.datasets import KarateClub
注:torch_geometric库的安装不能直接pip install,具体的安装方法可以参考之前的blog:https://blog.csdn.net/m0_51339444/article/details/128611141
2. 定义可视化函数
def visualize_graph(G, color):
plt.figure(figsize=(5,5))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
node_color=color, cmap='Set2')
plt.show()
3. 导入并查看KarateClub数据集
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print(f'Number of the graphs: {len(dataset)}')
print(f'Number of the features: {dataset.num_features}')
print(f'Number of the classes: {dataset.num_classes}')
data = dataset[0]
print(data)
# edge_index是邻接矩阵,表示每两个点之间的关联
edge_index = data.edge_index
# 打印出每个点分别和谁有关系
print(edge_index.t())
这里对上一个运行结果解释一下,这是整个数据集的全部生态环境了,x是特征,就是一个一个的点,第一个34表示一共有34个点,即34个样本,第二个34表示每个样本是34维的向量(即34个特征);edge_index是邻接矩阵,表示每两个点之间的关联,第一个元素一定是2,表示两个点之间的边,156表示一共有156个关系,即156条边;train_mask记录了34个数据中有标签与否,有标签是True,没有标签是False。
4. 使用networkx进行可视化展示
# 将处理好(对应的标准格式)的data传入to_networkx,再传入visualize_graph(最上面自己定义的)绘图
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)
5. 搭建网络
这里会使用到torch_geometric的方法(封装好的函数),有疑问的地方可以去官网查询API,这里拍个链接:https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(1234)
self.conv1 = GCNConv(dataset.num_features, 4) # 两个参数分别为输入特征和输出特征
self.conv2 = GCNConv(4,4)
self.conv3 = GCNConv(4,2)
self.classifier = Linear(2, dataset.num_classes)
# x是特征,没经过一层后数据都是不断变化的,即x 变成h,h不断变成新的h,而edge_index邻接矩阵是一直不变的,谁和谁之间有联系是不变的
def forward(self, x, edge_index):
h = self.conv1(x, edge_index) # 输入特征和邻接矩阵
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh()
# 分类层
out = self.classifier(h)
# out是输出,h是中间结果(conv3的输出)(一个2维的向量(方便绘图打印))
return out, h
model = GCN()
print(model)
由于数据集比较小,因此搭建小网络即可,网络参数如下:
6. 进行embedding操作并可视化
def visualize_embedding(h, color, epoch=None, loss=None):
plt.figure(figsize=(5,5))
plt.xticks([])
plt.yticks([])
h = h.detach().cpu().numpy()
plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap='Set2')
if epoch is not None and loss is not None:
plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
plt.show()
model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
visualize_embedding(h, color=data.y)
7. 训练模型
import time
model = GCN()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train(data):
optimizer.zero_grad()
out, h = model(data.x, data.edge_index)
# 这里体现了半监督的思想,只拿有标签的计算损失,没有标签的不参与计算
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss, h
for epoch in range(401):
loss, h = train(data)
if epoch % 10 == 0:
visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
time.sleep(0.3)
可以看到,随着epoch的增大,损失函数逐渐收敛,可视化结果逐渐将三种颜色分成了三个类别(类似聚类的结果)。