1.colab0
1.1 数据集
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
print(f'Dataset:{dataset}:')
print('======================')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')
初始化 KarateClub
数据集后,我们首先可以检查它的一些属性。例如,我们可以看到这个数据集正好包含一个图形,并且这个数据集中的每个节点都被分配了一个 34 维特征向量(它唯一地描述了空手道俱乐部的成员)。此外,该图正好包含 4 个类,它们代表每个节点所属的社区。
data = dataset[0]
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
print(f'节点个数:{data.num_nodes}')
print(f'边个数:{data.num_edges}')
print(f'平均节点的度:{(data.num_)/}')
print(f'训练的节点数:{data.train_mask.sum()}')
print(f'训练节点标签率:{int(data.train_mask.sum())/data.num_nodes:.2f}')
print(f'包含孤立节点:{data.has_isolated_nodes()}')
print(f'包含自环:{data.has_self_loops()}')
print(f'是否是无向图:{data.is_undirected()}')
-
获取第一个图对象:
data = dataset[0]
从数据集中获取第一个图对象。 -
打印图对象:
print(data)
打印出图对象的详细信息,方便了解图的基本结构和属性。 -
打印分隔线:
print('==============================================================')
打印一条分隔线,使输出结果更清晰。 -
统计信息:
data.num_nodes
:打印节点的数量。data.num_edges
:打印边的数量。平均节点度
:通过计算data.num_edges / data.num_nodes
来获得平均每个节点的边数。训练节点数量
:通过data.train_mask.sum()
打印训练节点的数量。训练节点标签率
:通过计算int(data.train_mask.sum()) / data.num_nodes
来获得训练节点在所有节点中的比例。包含孤立节点
:调用data.has_isolated_nodes()
检查图中是否有孤立节点。包含自环
:调用data.has_self_loops()
检查图中是否包含自环。是否为无向图
:调用data.is_undirected()
检查图是否为无向图。
在 PyTorch Geometric 中,每个图都由一个单独的 Data
对象表示,该对象包含描述其图表示的所有信息。我们可以随时通过 print(data)
打印 data
对象,以获取其属性及其形状的简要总结:
Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])
我们可以看到这个 data
对象包含四个属性:
edge_index
属性保存了图的连接信息,即每条边的源节点和目标节点索引的元组。- PyG 将node_features称为
x
(每个34个节点被分配一个34维的特征向量)。 - PyG 将node_lables称为
y
(每个节点被分配到一个类别)。 - 还有一个额外的属性称为
train_mask
,它描述了哪些节点的社区分配已经知道。
总的来说,我们只知道4个节点的真实标签(每个社区一个),任务是推断其余节点的社区分配。
![[Pasted image 20240807213401.png]]
data
对象还提供了一些实用函数来推断底层图的一些基本属性。例如,我们可以很容易地推断图中是否存在孤立节点(即没有任何节点与之相连的边)、图中是否包含自环(即
(
v
,
v
)
∈
E
(v, v) \in \mathbb{E}
(v,v)∈E),或者图是否是无向图(即对于每条边
(
v
,
w
)
∈
E
(v, w) \in \mathbb{E}
(v,w)∈E,也存在边
(
w
,
v
)
∈
E
(w, v) \in \mathbb{E}
(w,v)∈E)。
1.2 边索引 Edge Index
通过打印 edge_index
,我们可以进一步了解 PyG 如何在内部表示图的连通性。
我们可以看到,对于每条边,edge_index
保存了一个包含两个节点索引的元组,其中第一个值描述了源节点的节点索引,第二个值描述了边的目标节点的节点索引。
这种表示法被称为 COO格式(坐标格式),通常用于表示稀疏矩阵。
PyG 不使用密集表示
A
∈
{
0
,
1
}
∣
V
∣
×
∣
V
∣
\mathbf{A} \in \{ 0, 1 \}^{|\mathbb{V}| \times |\mathbb{V}|}
A∈{0,1}∣V∣×∣V∣ 来保存邻接信息,而是稀疏表示图,即只保存
A
\mathbf{A}
A 中非零条目的坐标/值。
我们可以通过将图转换为 networkx
库的格式进一步可视化,该库除了实现图的操作功能外,还提供了强大的可视化工具:
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
# 将 PyG 的图对象转换为 NetworkX 的图对象
G = to_networkx(data, to_undirected=True)
# 绘制图形
plt.figure(figsize=(8, 8))
nx.draw(G, with_labels=True, node_color=data.y, cmap=plt.get_cmap('Set2'))
plt.show()
-
导入库:导入
networkx
、to_networkx
和matplotlib
库。 -
转换图对象:使用
to_networkx(data, to_undirected=True)
将 PyG 的图对象转换为 NetworkX 的图对象,并将图设置为无向图。 -
绘制图形:
- 设置图形大小为 8x8。
- 使用
nx.draw
绘制图形,设置显示节点标签,节点颜色为data.y
(节点标签),并使用Set2
颜色映射。 - 显示图形。
-
COO格式(坐标格式):
假设我们有一个4x4的稀疏矩阵
0 0 3 0
1 0 0 0
0 2 0 4
0 0 5 0
COO格式表示
row_indices = [0, 1, 2, 2, 3]
col_indices = [2, 0, 1, 3, 2]
values = [3, 1, 2, 4, 5]
这三个列表共同表示了矩阵中的非零元素
每个非零元素由(行索引, 列索引, 值)表示
在COO格式中,我们只存储非零元素的行索引、列索引和对应的值。这种方法特别适合表示稀疏矩阵,即大部分元素为零的矩阵。
2. 稀疏表示 vs 密集表示:
- 密集表示:存储矩阵的所有元素,包括零元素。对于|V|×|V|的矩阵,需要存储|V|^2个元素。
- 稀疏表示:只存储非零元素及其位置。对于边数为|E|的图,只需要存储2|E|个索引(行和列)和|E|个值。
3. 在图神经网络中的应用:
- 图通常是稀疏的,即大多数节点之间没有直接连接。
- 使用COO格式可以显著减少内存使用,特别是对于大规模图。
- PyG(PyTorch Geometric)使用这种格式来高效地表示图结构。
4. 在PyG中的使用:
- PyG使用`edge_index`张量(==形状为[2, num_edges]==)来存储边的连接信息。
- 如果边有权重,还可以使用额外的`edge_attr`张量。
```python
edge_index = data.edge_index
print(edge.index.t())
未转置的张量 tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3,
3, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7,
7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
[ 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2,
3, 7, 13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0,
1, 2, 7, 12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1,
2, 3, 0, 2, 30, 32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2,
3, 33, 32, 33, 32, 33, 5, 6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0,
1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23,
24, 33, 2, 31, 33, 23, 26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32,
33, 2, 8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15,
18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])
1.3 可视化
nx.draw_networkx
函数参数解释:
-
h:
- 这是一个 NetworkX 图对象。
-
pos=nx.spring_layout(h, seed=42):
pos
参数指定节点的布局。nx.spring_layout(h, seed=42)
使用 spring layout 算法生成图的节点布局,其中seed=42
保证布局的可重复性。nx.spring_layout
是一种力导向布局算法,节点位置由节点之间的模拟弹簧力和排斥力决定,通常可以生成较为美观的图布局。
-
with_labels=False:
with_labels
参数指定是否在节点上显示标签。设置为False
表示不显示节点标签。
-
node_color=color:
node_color
参数指定节点的颜色。color
可以是颜色字符串、颜色列表或者颜色数组,用于区分不同类别的节点。
-
cmap=“Set2”:
cmap
参数指定颜色映射(colormap)。"Set2"
是一个预定义的颜色映射,通常用于分类数据。
当然,以下是对 visualize
函数以及其中参数的详细解释,结合了 plt.scatter
函数的用法和示例:
plt.scatter
函数参数解释:
-
h[:, 0] 和 h[:, 1]:
h
是一个二维数组或张量,其中h[:, 0]
表示所有点的 x 坐标,h[:, 1]
表示所有点的 y 坐标。- 这两个参数定义了散点图中每个点的位置。
-
s=140:
s
代表散点的大小。- 这里设置每个点的大小为 140,可以根据需要调整。
-
c=color:
c
指定每个点的颜色。color
通常是一个数组,表示每个点的颜色,可以用来区分不同类别的点。
-
cmap=“Set2”:
cmap
指定颜色映射(colormap)。"Set2"
是一个预定义的颜色映射,通常用于分类数据。- 可以根据需要选择不同的颜色映射,比如
"viridis"
、"plasma"
等。
visualize
函数参数解释:
-
h:图节点的嵌入或坐标,可以是 PyTorch 张量或 NetworkX 图对象。
- 如果是张量,通常表示节点在嵌入空间中的坐标,用于绘制散点图。
- 如果是 NetworkX 图对象,表示整个图的结构,用于绘制网络图。
-
color:节点的颜色信息。可以是用于区分不同节点类别的数组,或直接用于节点着色的颜色值。
-
epoch(可选):训练的当前轮数。在训练过程中,可以用来显示当前训练到第几轮。
-
loss(可选):当前轮的损失值。在训练过程中,可以用来显示当前轮的损失值。
-
accuracy(可选):包含训练和验证准确率的字典。在训练过程中,可以用来显示当前轮的训练和验证准确率。
accuracy['train']
:训练集的准确率。accuracy['val']
:验证集的准确率。
# 导入必要的库
%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
# 定义可视化函数
def visualize(h, color, epoch=None, loss=None, accuracy=None):
# 设置图像大小
plt.figure(figsize=(7,7))
# 移除x轴和y轴的刻度,使图像更清晰
plt.xticks([])
plt.yticks([])
# 检查输入的h是否为PyTorch张量
if torch.is_tensor(h):
# 将张量转换为NumPy数组,numpy只能在cpu上运行
h = h.detach().cpu().numpy()
# 创建散点图
plt.scatter(h[:,0],h[:,1],s=140,c = color, cmap = "Set2")
# 如果提供了额外的信息(epoch, loss, accuracy),在图上显示
if epoch is not None and loss is not None and accuracy['train'] is not None and accuracy['val'] is not None:
plt.xlabel((f'Epoch:{epoch},Loss:{loss.item():.4f} \n'
f'Traing Accuracy:{accuracy["train"]*100:.2f}% \n'
f'Validation Accuracy:{accuracy["val"]*100:.2f}%'),
fontsize=16)
# 如果h不是张量,则假设其为NetworkX图并进行可视化
else:
nx.draw_networkx(h,pos=nx.spring_layout(h,seed=42),with_labels= False,node_color = color,cmap="Set2")
# 显示图像
plt.show()
-
导入语句:确保导入了必要的库用于绘图(
matplotlib
)、处理张量(torch
)和处理图结构(networkx
)。 -
函数定义:
visualize
函数接受以下参数:h
:可以是PyTorch张量或NetworkX图。color
:节点或数据点的颜色信息。epoch
、loss
、accuracy
:可选参数,用于在图上显示附加信息。
-
图像配置:设置图像大小并移除x轴和y轴的刻度,使可视化效果更清晰。
-
张量可视化:如果
h
是张量,将其转换为NumPy数组并使用plt.scatter
绘制散点图。如果提供了额外的信息(epoch
、loss
、accuracy
),则在图上显示。 -
图结构可视化:如果
h
不是张量,则假设其为NetworkX图结构,并使用nx.draw_networkx
进行可视化,采用spring布局来定位节点。 -
显示图像:最后,调用
plt.show()
来显示图像。
1.4 GCN
这里,我们首先在 __init__
中初始化所有的构建模块,并在 forward
中定义网络的计算流程。
我们首先定义并堆叠三层图卷积层。每一层对应于从每个节点的1跳邻居(直接邻居)聚合信息,但当我们将这些层组合在一起时,我们能够从每个节点的3跳邻居(所有距离3跳以内的节点)聚合信息。
此外,GCNConv
层将节点特征的维度减少到 2,即从
34
→
4
→
4
→
2
34 \rightarrow 4 \rightarrow 4 \rightarrow 2
34→4→4→2。每个 GCNConv
层都通过 tanh 非线性激活函数进行增强。
之后,我们应用一个线性变换(torch.nn.Linear
)作为分类器,将我们的节点映射到4个类别/社区中的一个。
我们返回最终分类器的输出以及我们的GNN生成的最终节点嵌入。
接下来,通过 GCN()
初始化我们的最终模型,打印我们的模型会生成其使用的所有子模块的摘要。
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(1234)
self.conv1 = GCNConv(
dataset.num_features,4)
self.conv2 = GCNConv(4,4)
self.conv3 = GCNConv(4,2)
self.classifier = Linear(2,dataset.num_classes)
def forward(self,x,edge_index):
h = self.conv1(x,edge_index)
h = h.tanh()
h = self.conv2(h,edge_index)
h = h.tanh()
h = self.conv3(h,edge_index)
h = h.tanh()
out = self.classifier(h)
return out, h
model = GCN()
print(model)
1.5 训练
下面是对代码的详细解释和翻译,展示了如何在训练过程中使用定义的 GCN
模型,并每 10 轮可视化一次节点嵌入:
-
导入必要的库和模块:
import time from IPython.display import Javascript # 限制输出单元格的高度 display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 430})'''))
-
初始化模型、损失函数和优化器:
model = GCN(dataset) # 初始化 GCN 模型 criterion = torch.nn.CrossEntropyLoss() # 定义损失函数 optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 定义优化器
-
定义训练函数:
def train(data): optimizer.zero_grad() # 清除梯度 out, h = model(data.x, data.edge_index) # 执行一次前向传播 loss = criterion(out[data.train_mask], data.y[data.train_mask]) # 仅基于训练节点计算损失 loss.backward() # 反向传播计算梯度 optimizer.step() # 基于梯度更新参数 accuracy = {} # 计算训练集上的准确率 predicted_classes = torch.argmax(out[data.train_mask], axis=1) target_classes = data.y[data.train_mask] accuracy['train'] = torch.mean( torch.where(predicted_classes == target_classes, 1, 0).float()) # 计算整个图上的验证准确率 predicted_classes = torch.argmax(out, axis=1) target_classes = data.y accuracy['val'] = torch.mean( torch.where(predicted_classes == target_classes, 1, 0).float()) return loss, h, accuracy
-
训练模型并每 10 轮可视化一次节点嵌入:
for epoch in range(500): loss, h, accuracy = train(data) # 每10轮可视化一次节点嵌入 if epoch % 10 == 0: visualize(h, color=data.y, epoch=epoch, loss=loss, accuracy=accuracy) time.sleep(0.3) # 暂停0.3秒
详细解释:
-
初始化模型、损失函数和优化器:
GCN(dataset)
:初始化 GCN 模型。torch.nn.CrossEntropyLoss()
:定义交叉熵损失函数,用于分类任务。torch.optim.Adam(model.parameters(), lr=0.01)
:定义 Adam 优化器,学习率为 0.01。
-
定义训练函数
train
:optimizer.zero_grad()
:清除梯度。model(data.x, data.edge_index)
:执行前向传播,得到输出和节点嵌入。criterion(out[data.train_mask], data.y[data.train_mask])
:计算训练集上的损失。loss.backward()
:反向传播计算梯度。optimizer.step()
:基于梯度更新模型参数。predicted_classes = torch.argmax(out[data.train_mask], axis=1)
:计算训练集上的预测类别。accuracy['train']
:计算训练集上的准确率。predicted_classes = torch.argmax(out, axis=1)
:计算整个图上的预测类别。accuracy['val']
:计算验证集上的准确率。
-
训练模型并每 10 轮可视化一次:
for epoch in range(500)
:循环 500 轮。if epoch % 10 == 0
:每 10 轮执行一次可视化。visualize(h, color=data.y, epoch=epoch, loss=loss, accuracy=accuracy)
:可视化节点嵌入,并显示训练轮数、损失和准确率。time.sleep(0.3)
:暂停 0.3 秒,以便观察可视化结果。
运行代码示例:
假设我们有一个数据集 KarateClub
,可以使用如下代码:
from torch_geometric.datasets import KarateClub
# 加载KarateClub数据集
dataset = KarateClub()
data = dataset[0] # 获取第一个图对象
# 初始化GCN模型
model = GCN(dataset)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 定义训练函数
def train(data):
optimizer.zero_grad()
out, h = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
accuracy = {}
predicted_classes = torch.argmax(out[data.train_mask], axis=1)
target_classes = data.y[data.train_mask]
accuracy['train'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())
predicted_classes = torch.argmax(out, axis=1)
target_classes = data.y
accuracy['val'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())
return loss, h, accuracy
# 训练模型并可视化
for epoch in range(500):
loss, h, accuracy = train(data)
if epoch % 10 == 0:
visualize(h, color=data.y, epoch=epoch, loss=loss, accuracy=accuracy)
time.sleep(0.3)
通过上述代码,我们可以在训练过程中每 10 轮可视化一次节点嵌入,直观地展示模型的训练效果和节点的嵌入分布情况。
以下是对 loss = criterion(out[data.train_mask], data.y[data.train_mask])
这一行代码的详细解释:
解释:
-
out
:out
是模型的输出,形状为[num_nodes, num_classes]
,表示每个节点的类别得分(logits)。
-
data.train_mask
:data.train_mask
是一个布尔掩码,形状为[num_nodes]
,指示哪些节点用于训练。- 例如,如果
data.train_mask[i]
为True
,则节点i
用于训练。
-
out[data.train_mask]
:- 这部分代码使用布尔掩码
data.train_mask
从out
中选择训练节点的输出。 - 结果是一个形状为
[num_train_nodes, num_classes]
的张量,表示训练节点的类别得分。
- 这部分代码使用布尔掩码
-
data.y
:data.y
是节点的真实标签,形状为[num_nodes]
,每个节点一个标签。
-
data.y[data.train_mask]
:- 使用布尔掩码
data.train_mask
从data.y
中选择训练节点的真实标签。 - 结果是一个形状为
[num_train_nodes]
的张量,表示训练节点的真实标签。
- 使用布尔掩码
-
criterion
:criterion
是交叉熵损失函数 (torch.nn.CrossEntropyLoss
),用于计算预测类别分布与真实类别分布之间的差异。
-
loss = criterion(out[data.train_mask], data.y[data.train_mask])
:- 这行代码计算模型在训练节点上的损失。
- 将训练节点的输出
out[data.train_mask]
作为预测值,将训练节点的标签data.y[data.train_mask]
作为真实值,传递给交叉熵损失函数criterion
,计算损失loss
。
示例代码和上下文:
下面是完整的上下文代码,用于训练图卷积网络模型并计算损失:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
import matplotlib.pyplot as plt
import networkx as nx
import time
# 定义 GCN 模型
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
torch.manual_seed(12345)
self.conv1 = GCNConv(dataset.num_features, 4)
self.conv2 = GCNConv(4, 4)
self.conv3 = GCNConv(4, 2)
self.classifier = Linear(2, dataset.num_classes)
def forward(self, x, edge_index):
h = self.conv1(x, edge_index)
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh() # 最后的 GNN 嵌入空间
# 应用最终的线性分类器
out = self.classifier(h)
return out, h
# 使用 GPU 训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# 加载数据集
dataset = KarateClub()
data = dataset[0].to(device)
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
# 定义模型并移至 GPU
model = GCN().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 训练函数
def train(data):
model.train()
optimizer.zero_grad()
out, h = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
accuracy = {}
# 计算训练准确率
predicted_classes = torch.argmax(out[data.train_mask], axis=1)
target_classes = data.y[data.train_mask]
accuracy['train'] = torch.mean((predicted_classes == target_classes).float()).item()
# 计算验证准确率
predicted_classes = torch.argmax(out, axis=1)
target_classes = data.y
accuracy['val'] = torch.mean((predicted_classes == target_classes).float()).item()
return loss.item(), h, accuracy
# 设置训练和验证掩码
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.train_mask[:int(0.6 * data.num_nodes)] = True
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.val_mask[int(0.6 * data.num_nodes):int(0.8 * data.num_nodes)] = True
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask[int(0.8 * data.num_nodes):] = True
# 记录训练过程中的损失和准确率
epochs = []
losses = []
train_accuracies = []
val_accuracies = []
# 训练模型
for epoch in range(500):
loss, h, accuracy = train(data)
epochs.append(epoch)
losses.append(loss)
train_accuracies.append(accuracy['train'])
val_accuracies.append(accuracy['val'])
if epoch % 10 == 0:
# 将嵌入移回 CPU 以进行可视化
h_cpu = h.detach().cpu()
visualize(h_cpu, color=data.y.cpu(), epoch=epoch, loss=loss, accuracy=accuracy)
time.sleep(0.3)
# 打印最终模型
print(model)
# 绘制训练过程中的损失和准确率
plot_training_curves(epochs, losses, train_accuracies, val_accuracies)
总结:
在训练过程中,通过选择训练掩码中的节点,我们仅基于这些节点的预测值和真实值计算损失,从而优化模型参数。这种方法有助于确保模型能够更好地泛化到未见过的数据。
这段代码的作用是为训练、验证和测试集创建掩码,以便在图卷积网络(GCN)训练过程中指定哪些节点用于训练、验证和测试。具体地,这段代码设置了训练和验证集的掩码。下面是详细解释:
代码解释:
# 设置训练集掩码
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.train_mask[:int(0.6 * data.num_nodes)] = True
# 设置验证集掩码
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.val_mask[int(0.6 * data.num_nodes):int(0.8 * data.num_nodes)] = True
# 设置测试集掩码
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask[int(0.8 * data.num_nodes):] = True
详细解释:
-
设置训练集掩码:
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool) data.train_mask[:int(0.6 * data.num_nodes)] = True
torch.zeros(data.num_nodes, dtype=torch.bool)
:创建一个大小为data.num_nodes
的布尔张量,所有值初始化为False
。data.train_mask[:int(0.6 * data.num_nodes)] = True
:将前 60% 节点的掩码值设置为True
,表示这些节点用于训练。- 这样,训练集包含了前 60% 的节点。
-
设置验证集掩码:
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool) data.val_mask[int(0.6 * data.num_nodes):int(0.8 * data.num_nodes)] = True
torch.zeros(data.num_nodes, dtype=torch.bool)
:创建一个大小为data.num_nodes
的布尔张量,所有值初始化为False
。data.val_mask[int(0.6 * data.num_nodes):int(0.8 * data.num_nodes)] = True
:将节点索引从 60% 到 80% 的掩码值设置为True
,表示这些节点用于验证。- 这样,验证集包含了从 60% 到 80% 的节点。
-
设置测试集掩码:
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool) data.test_mask[int(0.8 * data.num_nodes):] = True
torch.zeros(data.num_nodes, dtype=torch.bool)
:创建一个大小为data.num_nodes
的布尔张量,所有值初始化为False
。data.test_mask[int(0.8 * data.num_nodes):] = True
:将节点索引从 80% 到 100% 的掩码值设置为True
,表示这些节点用于测试。- 这样,测试集包含了最后 20% 的节点。
示例代码:
为了更好地理解,我们可以将这些掩码应用到一个具体的数据集上,例如 KarateClub
数据集:
from torch_geometric.datasets import KarateClub
import torch
# 加载KarateClub数据集
dataset = KarateClub()
data = dataset[0]
# 设置训练、验证和测试集掩码
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.train_mask[:int(0.6 * data.num_nodes)] = True
data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.val_mask[int(0.6 * data.num_nodes):int(0.8 * data.num_nodes)] = True
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask[int(0.8 * data.num_nodes):] = True
# 打印掩码信息
print(f'Train mask: {data.train_mask}')
print(f'Validation mask: {data.val_mask}')
print(f'Test mask: {data.test_mask}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Number of validation nodes: {data.val_mask.sum()}')
print(f'Number of test nodes: {data.test_mask.sum()}')
运行结果:
Train mask: tensor([ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False], dtype=torch.bool)
Validation mask: tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, False, False, False, False, False, False, False, False], dtype=torch.bool)
Test mask: tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, True, True, True, True, True, True, True, True], dtype=torch.bool)
Number of training nodes: 34
Number of validation nodes: 20
Number of test nodes: 10
总结
# Read the uploaded file and extract the functions and libraries used
file_path = '/mnt/data/colab0代码详解.md'
# Read the file
with open(file_path, 'r', encoding='utf-8') as file:
content = file.readlines()
# Extract functions and libraries
functions = []
libraries = []
for line in content:
line = line.strip()
if line.startswith('import ') or line.startswith('from '):
libraries.append(line)
elif line.startswith('def '):
function_name = line.split('(')[0].replace('def ', '')
functions.append(function_name)
# Create a table with the extracted information
import pandas as pd
data = {
"Functions": functions,
"Libraries": libraries
}
df = pd.DataFrame(data)
df
# There was an error due to different lengths of functions and libraries.
# Let's handle the lengths and then create a DataFrame.
# Ensure both lists have the same length by filling the shorter list with empty strings
max_length = max(len(functions), len(libraries))
# Extend the shorter list to match the max length
functions.extend([''] * (max_length - len(functions)))
libraries.extend([''] * (max_length - len(libraries)))
# Create the DataFrame again
data = {
"Functions": functions,
"Libraries": libraries
}
df = pd.DataFrame(data)
import ace_tools as tools; tools.display_dataframe_to_user(name="Extracted Functions and Libraries", dataframe=df)
df
Functions | Libraries |
---|---|
visualize | from torch_geometric.datasets import KarateClub |
init | import networkx as nx |
forward | from torch_geometric.utils import to_networkx |
train | import matplotlib.pyplot as plt |
train | import torch |
init | import networkx as nx |
forward | import matplotlib.pyplot as plt |
train | import time |
from IPython.display import Javascript # 限制输出单元格的高度 | |
from torch_geometric.datasets import KarateClub | |
import torch | |
import torch.nn.functional as F | |
from torch.nn import Linear | |
from torch_geometric.nn import GCNConv | |
from torch_geometric.datasets import KarateClub | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
import time | |
from torch_geometric.datasets import KarateClub | |
import torch |