图神经网络
安装Pyg
首先安装torch_geometric需要安装pytorch然后查看一下自己电脑Pytorch的版本
import torch
print(torch.__version__)
#1.12.0+cu113
然后进入官网文档网站
链接: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
安装自己的版本选择安装命令,我python用的稳定的3.8版本。如果安装失败可以考虑降低python的版本
因为我之前安装过所以显示如下
图信号数据集初入门
本次入门选用Karateclub数据集
这个数据集讲诉的是一个空手道俱乐部之间人和人的关系,每个节点代表一个人说俱乐部的两个教练吵架了,要每一个节点所代表的人进行站队通过图信号预测。
首先读取数据集
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')
#Number of graphs:1
#Number of features:34
#Number of classes:4
可以看到只有一张图每个节点有34个特征,每个特征代表的应该是每一个会员的信息,分成四类我们可以暂时理解成跟了教练A的,跟了教练B的,换了一个新教练的,和退出俱乐部的这四类。
然后我们将这个图打出来进行观察可以看到节点是分成了四类
import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
dataset = KarateClub()
print(f'Dataset{dataset}')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')
def visualize_graph(G,color):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(G,pos = nx.spring_layout(G,seed = 42),with_labels=False,node_color=color,cmap="Set2")
plt.savefig("net.jpg")
plt.show()
data = dataset[0]
print(data)
G = to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)
然后我们观察一个图的数据可以观察到一共有34个节点每个节点有34个数据一共有156条边
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
data = dataset[0]
print(data)
#Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
训练代码
接下来使用pyg进行训练
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
import matplotlib.pyplot as plt
dataset = KarateClub()
data = dataset[0]
def visualize_embedding(h,color,epoch=None,loss=None):
global i
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
h = h.detach().cpu().numpy()
plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
if epoch is not None and loss is not None:
plt.xlabel(f"Epoch:{epoch},Loss: {loss.item():.4f}",fontsize = 16,)
plt.show()
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(1234)
self.conv1 = GCNConv(dataset.num_features,4)
self.conv2 = GCNConv(4,4)
self.conv3 = GCNConv(4,2)
self.classifier = Linear(2,dataset.num_classes)
def forward(self,x,edge_index):
h = self.conv1(x,edge_index)
h = h.tanh()
h = self.conv2(h,edge_index)
h = h.tanh()
h = self.conv3(h,edge_index)
h = h.tanh()
out = self.classifier(h)
# return F.softmax(out,dim=1),h
return out,h
model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_list = []
def train(data):
optimizer.zero_grad()
out, h = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
loss_list.append(loss.item())
optimizer.step()
return loss, h
for epoch in range(401):
loss, h = train(data)
if epoch % 10 == 1:
visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
plt.plot(loss_list)
plt.show()
损失曲线如下
训练集可视化动图如下
![在这里插入图片描述](https://img-blog.csdnimg.cn/e1cff7bf8ba0498cb8abd6f61a2f83a6.gif