Table of Contents
- JumpingKnowledge
- Idea
- Example
Putting the Feynman technique into practice: explaining complex theory in the simplest possible terms.
PyG JumpingKnowledge
JumpingKnowledge
So GNNs have JK too?
Idea:
Very simple.
Suppose we have a 3-layer GNN. We keep the output of every layer, i.e., we store each intermediate layer's node embeddings, and at the end apply an aggregation to the list of all layer embeddings.
The available aggregations are concatenation (cat), max pooling (max), and learned weighting via an LSTM (lstm); a minimal sketch of the first two follows.
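Here is a minimal sketch in plain PyTorch (the shapes are made up for illustration) of what cat and max do to a list of per-layer embeddings; the lstm mode instead computes a learned, per-node weighted sum of the layers:

import torch

# Pretend outputs of a 3-layer GNN: 5 nodes, 8-dimensional embeddings per layer.
xs = [torch.randn(5, 8) for _ in range(3)]

# cat: concatenate along the feature dimension -> [5, 24]
x_cat = torch.cat(xs, dim=-1)

# max: element-wise maximum across layers -> [5, 8]
x_max = torch.stack(xs, dim=-1).max(dim=-1)[0]

print(x_cat.shape, x_max.shape)  # torch.Size([5, 24]) torch.Size([5, 8])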
Parameters:
class JumpingKnowledge(mode: str, channels: Optional[int] = None, num_layers: Optional[int] = None)

- mode (str) – The aggregation scheme to use ("cat", "max" or "lstm").
- channels (int, optional) – The number of channels per representation. Only needs to be set for LSTM-style aggregation. (default: None)
- num_layers (int, optional) – The number of layers to aggregate. Only needs to be set for LSTM-style aggregation. (default: None)
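A quick way to see what each mode produces is to call the module directly on dummy data (the shapes below are arbitrary):

import torch
from torch_geometric.nn import JumpingKnowledge

xs = [torch.randn(5, 8) for _ in range(3)]  # 3 layers, 5 nodes, 8 channels each

print(JumpingKnowledge('cat')(xs).shape)                             # torch.Size([5, 24])
print(JumpingKnowledge('max')(xs).shape)                             # torch.Size([5, 8])
print(JumpingKnowledge('lstm', channels=8, num_layers=3)(xs).shape)  # torch.Size([5, 8])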
Example
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, JumpingKnowledge

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, jk_mode):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.jk = JumpingKnowledge(jk_mode, channels=hidden_channels, num_layers=num_layers)
        # 'cat' concatenates the num_layers embeddings, so the classifier input
        # is hidden_channels * num_layers; 'max' and 'lstm' keep hidden_channels.
        lin_in = hidden_channels * num_layers if jk_mode == 'cat' else hidden_channels
        self.lin = torch.nn.Linear(lin_in, out_channels)

    def forward(self, x, edge_index):
        xs = []  # collect the embedding produced by every layer
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            xs.append(x)
        x = self.jk(xs)  # aggregate the per-layer embeddings
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

# Example usage
model = GNN(in_channels=16, hidden_channels=32, out_channels=7, num_layers=3, jk_mode='max')
print(model)
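To sanity-check the forward pass, one can feed the `model` defined above a tiny random graph (the node features and edge_index here are made up):

x = torch.randn(4, 16)                     # 4 nodes, 16 input features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])  # source/target node pairs
out = model(x, edge_index)
print(out.shape)  # torch.Size([4, 7]): per-node log-probabilities over 7 classes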
By combining node representations from different layers, the Jumping Knowledge mechanism addresses some key problems in graph neural networks (notably over-smoothing as depth grows) and improves the model's expressive power and performance. Its multiple aggregation modes provide flexibility.