Implementing GCN and LightGCN with torch_geometric
- Preface
- Demo diagram
- GCN code
- LightGCN code
- References and acknowledgements
Preface
GCN and LightGCN implemented with torch_geometric; I am backing them up here since they may be needed later.
Demo diagram
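The demo graph used in both code sections below has 5 nodes; its directed, weighted edges come from the edge_index and edge_weight tensors defined there. A minimal sketch (using torch_geometric.utils.to_dense_adj, not part of the original demo) that prints the corresponding weighted adjacency matrix:

import torch
from torch_geometric.utils import to_dense_adj

# Same edge list as in the GCN / LightGCN demos below:
# row 0 holds source nodes, row 1 holds target nodes.
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1],
                           [4, 0, 0, 1, 0, 1, 3, 3]])
edge_weight = torch.tensor([0.5] * 8) * 2  # all weights end up as 1.0, as in the demos

# Dense 5x5 weighted adjacency matrix of the demo graph (entry [i, j] = weight of edge i -> j).
adj = to_dense_adj(edge_index, edge_attr=edge_weight, max_num_nodes=5)[0]
print(adj)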
GCN code
\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}
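In this formula (standard GCN notation, matching the add_remaining_self_loops and degree calls in the code below), \mathbf{\hat{A}} is the adjacency matrix with added self-loops, \mathbf{\hat{D}} is its diagonal degree matrix, and \mathbf{\Theta} is the learnable weight matrix (self.weight):

\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}, \qquad \hat{D}_{ii} = \sum_{j} \hat{A}_{ij}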
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from torch_geometric.nn.inits import uniform, ones
torch.manual_seed(2023)
"""
默认 \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},加自连接,按权重传递
传递完成后归一化
"""
class BaseModel(MessagePassing):
    def __init__(self, in_channels, out_channels, normalize=True, self_loops=True, bias=True, aggr='add', **kwargs):
        super(BaseModel, self).__init__(aggr=aggr, **kwargs)
        self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops
        self.normalize = normalize
        self.weight = Parameter(torch.Tensor(self.in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        uniform(self.in_channels, self.weight)
        uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_weight=None):
        if self.self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
        x = torch.matmul(x, self.weight)  # multiply by the learnable weight matrix Theta
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_weight=edge_weight)
        # propagate calls self.message, self.aggregate and self.update in turn
        # (self.aggregate is not overridden here; it does not change the values)

    def message(self, x_j, edge_index, size, edge_weight):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        return norm.view(-1, 1) * x_j if norm is not None else x_j
        # norm = edge_weight  # replace everything above with these two lines to skip normalizing the adjacency matrix
        # return norm.view(-1, 1) * x_j if norm is not None else x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)  # L2-normalize each row
        return aggr_out

    def __repr__(self):
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)
x = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0]])
GCN = BaseModel(in_channels=3, out_channels=3, self_loops=True, aggr="add")
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1], [4, 0, 0, 1, 0, 1, 3, 3]]) # 2x8
edge_weight = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
edge_weight = edge_weight * 2
h = F.leaky_relu(GCN(x, edge_index, edge_weight=edge_weight))
print(h)
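As a rough sanity check of what message() computes, the same propagation can be written with a dense adjacency matrix. The sketch below reuses x, GCN, edge_index and edge_weight from above and reproduces only the \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta} step; it ignores the bias, the row-wise F.normalize in update() and the leaky_relu, so it will not match h exactly:

# Dense illustration of the normalized propagation step only (no bias / F.normalize / leaky_relu).
ei, ew = add_remaining_self_loops(edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
row, col = ei
deg = degree(row, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * ew * deg_inv_sqrt[col]

# Put norm_e at [source, target]; aggregation sums the incoming messages at each target node.
dense = torch.zeros(x.size(0), x.size(0))
dense[row, col] = norm
out = dense.t() @ (x @ GCN.weight)
print(out)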
LightGCN code
\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}
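Compared with the GCN layer above, LightGCN drops both the learnable transformation \mathbf{\Theta} and the nonlinearity. In the LightGCN paper this propagation is applied K times and the final embedding is a uniform average of the layer-wise outputs (not shown in the single-layer demo below):

\mathbf{E}^{(k+1)} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{E}^{(k)}, \qquad \mathbf{E} = \frac{1}{K+1} \sum_{k=0}^{K} \mathbf{E}^{(k)}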
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from torch_geometric.nn.inits import uniform, ones
torch.manual_seed(2023)
"""
默认 \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \mathbf{X},不加自连接,按权重传递
传递完成后不进行归一化
"""
class BaseModel(MessagePassing):
    def __init__(self, in_channels, out_channels, normalize=False, self_loops=False, aggr='add', **kwargs):
        super(BaseModel, self).__init__(aggr=aggr, **kwargs)
        self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops
        self.normalize = normalize

    def forward(self, x, edge_index, edge_weight=None):
        if self.self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_weight=edge_weight)
        # propagate calls self.message, self.aggregate and self.update in turn
        # (self.aggregate is not overridden here; it does not change the values)

    def message(self, x_j, edge_index, size, edge_weight):
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        return norm.view(-1, 1) * x_j if norm is not None else x_j
        # norm = edge_weight  # replace everything above with these two lines to skip normalizing the adjacency matrix
        # return norm.view(-1, 1) * x_j if norm is not None else x_j

    def update(self, aggr_out):
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)  # L2-normalize each row
        return aggr_out

    def __repr__(self):
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)
x = torch.tensor(
    [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0]])
Lightgcn = BaseModel(in_channels=3, out_channels=3, self_loops=True, aggr="add")
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1], [4, 0, 0, 1, 0, 1, 3, 3]]) # 2x8
edge_weight = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
edge_weight = edge_weight * 2
h = Lightgcn(x, edge_index, edge_weight=edge_weight)
print(h)
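A multi-layer LightGCN stacks this propagation and combines the per-layer embeddings. Below is a minimal sketch of that usage, reusing x, edge_index and edge_weight from above; following the LightGCN paper it propagates without self-loops and averages the layer outputs (K=3 is just an example):

# Hypothetical multi-layer usage: propagate K times with the parameter-free layer
# and average the layer-wise embeddings, as in the LightGCN paper.
K = 3
layer = BaseModel(in_channels=3, out_channels=3, self_loops=False, aggr="add")
embs = [x]
for _ in range(K):
    embs.append(layer(embs[-1], edge_index, edge_weight=edge_weight))
final_emb = torch.stack(embs, dim=0).mean(dim=0)
print(final_emb)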
References and acknowledgements
Parts of this post draw on the following links; many thanks to their authors. Thanks♪(・ω・)ノ
Reference 1: the open-source code of the MMGCN paper
https://github.com/weiyinwei/MMGCN