文章目录
- 前言
- 1. 理论部分
- 1.1 为什么会出现图卷积网络?
- 1.2 图卷积网络的推导过程
- 1.3 图卷积网络的公式
- 2. 代码实现
- 参考资料
前言
本文从使用图卷积网络的目的出发,先对图卷积网络的来源与公式做简要介绍,之后通过一个例子来代码实现图卷积网络。
1. 理论部分
1.1 为什么会出现图卷积网络?
无论是CNN还是RNN,面对的都是规则的数据,面对图这种不规则的数据,原有网络无法对齐进行特征提取,而图这种数据在社会中广泛存在,需要设计一种方法对图数据进行提取,图卷积网络(Graph Convolutional Networks)的出现刚好解决了这一问题。
1.2 图卷积网络的推导过程
推导部分涉及通信相关知识,其主要核心是时域卷积等价于频域相乘,将时域卷积运算等价到频域进行相乘运算,再将相乘结果转化到时域。GCN的强悍之处在于,即使不训练,完全使用随机初始化的参数W,GCN提取出来的特征就以及十分优秀了。
1.3 图卷积网络的公式
公式由来请参考文献 图卷积网络(Graph Convolutional Networks, GCN)详细介绍,其网络的简易结构如下图所示。
图卷积的层与层之间的计算公式为:
H
(
l
+
1
)
=
σ
(
D
~
−
1
2
A
~
D
~
−
1
2
H
(
l
)
W
(
l
)
)
\pmb{H^{(l+1)}=\sigma ( \tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)} )}
H(l+1)=σ(D~−21A~D~−21H(l)W(l))
式中:
A ~ \tilde{A} A~: A ~ = A + I \tilde{A}=A+I A~=A+I,A为图的邻接矩阵,I为单位矩阵;
D ~ \tilde{D} D~: D ~ \tilde{D} D~为 A ~ \tilde{A} A~的度矩阵(degree matrix),表示每个结点度的数量, D i i = ∑ j = 1 i A i j D_{ii}=\sum_{j=1}^iA_{ij} Dii=∑j=1iAij;
H:每一层的特征,对于输入层,其是X;
σ \sigma σ:非线性激活函数;
W:连接层的权重参数;
2. 代码实现
在ASGCN中卷积层的计算公式为:
h
i
l
=
R
e
l
U
(
∑
j
=
1
n
A
i
j
W
l
g
j
l
)
d
i
+
1
+
b
l
)
\pmb{h_i^{l}=RelU(\frac{\sum_{j=1 }^{n} A_{ij} W^lg_{j}^{l})}{d_i+1}+b^l)}
hil=RelU(di+1∑j=1nAijWlgjl)+bl)
依据计算公式构建代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphConvolution(nn.Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""
def __init__(self, in_features, out_features):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
self.bias = nn.Parameter(torch.FloatTensor(out_features))
def forward(self, text, adj):
hidden = torch.matmul(text, self.weight) # 权重self.weight随机产生
denom = torch.sum(adj, dim=1, keepdim=True) + 1 # 加一保证做除法时分母不为零
output = torch.matmul(adj, hidden) / denom
output = F.relu(output + self.bias)
print(output)
return output
def main():
# 假设该句子经过构建依赖树后的邻接矩阵为adj
adj =torch.tensor([
[1., 1., 0., 0., 0., 0., 0., 1., 0., 0.],
[1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 1.],
])
# 假设一个句子中有10个单词,从前向后单词对应的索引为[0, 1, 2, 3, 3, 4, 6,0, 1, 2]
input = torch.tensor([0, 1, 2, 3, 3, 4, 6,0, 1, 2], dtype=torch.long)
embedding = torch.nn.Embedding(10, 50)
x = embedding(input) # 生成每个单词对应的词嵌入,维度为50
gc1 = GraphConvolution(50, 10)
gc1(x, adj)
if __name__ == '__main__':
main()
输出:
tensor([[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07, 3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07, 3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21]],
grad_fn=)
参考资料
- 图卷积网络 GCN Graph Convolutional Network(谱域GCN)的理解和详细推导
- 图卷积网络(Graph Convolutional Networks, GCN)详细介绍