ASGCN之图卷积网络(GCN)

news2024/12/23 20:51:59

文章目录

  • 前言
  • 1. 理论部分
    • 1.1 为什么会出现图卷积网络?
    • 1.2 图卷积网络的推导过程
    • 1.3 图卷积网络的公式
  • 2. 代码实现
  • 参考资料


前言

本文从使用图卷积网络的目的出发,先对图卷积网络的来源与公式做简要介绍,之后通过一个例子来代码实现图卷积网络。


1. 理论部分

1.1 为什么会出现图卷积网络?

无论是CNN还是RNN,面对的都是规则的数据,面对图这种不规则的数据,原有网络无法对齐进行特征提取,而图这种数据在社会中广泛存在,需要设计一种方法对图数据进行提取,图卷积网络(Graph Convolutional Networks)的出现刚好解决了这一问题。
在这里插入图片描述

1.2 图卷积网络的推导过程

推导部分涉及通信相关知识,其主要核心是时域卷积等价于频域相乘,将时域卷积运算等价到频域进行相乘运算,再将相乘结果转化到时域。GCN的强悍之处在于,即使不训练,完全使用随机初始化的参数W,GCN提取出来的特征就以及十分优秀了。

1.3 图卷积网络的公式

公式由来请参考文献 图卷积网络(Graph Convolutional Networks, GCN)详细介绍,其网络的简易结构如下图所示。
在这里插入图片描述
图卷积的层与层之间的计算公式为:
H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) \pmb{H^{(l+1)}=\sigma ( \tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)} )} H(l+1)=σ(D~21A~D~21H(l)W(l))
式中:

A ~ \tilde{A} A~: A ~ = A + I \tilde{A}=A+I A~=A+I,A为图的邻接矩阵,I为单位矩阵;
D ~ \tilde{D} D~: D ~ \tilde{D} D~ A ~ \tilde{A} A~的度矩阵(degree matrix),表示每个结点度的数量, D i i = ∑ j = 1 i A i j D_{ii}=\sum_{j=1}^iA_{ij} Dii=j=1iAij;
H:每一层的特征,对于输入层,其是X;
σ \sigma σ:非线性激活函数;
W:连接层的权重参数;

2. 代码实现

在ASGCN中卷积层的计算公式为:
h i l = R e l U ( ∑ j = 1 n A i j W l g j l ) d i + 1 + b l ) \pmb{h_i^{l}=RelU(\frac{\sum_{j=1 }^{n} A_{ij} W^lg_{j}^{l})}{d_i+1}+b^l)} hil=RelU(di+1j=1nAijWlgjl)+bl)
依据计算公式构建代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConvolution(nn.Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = nn.Parameter(torch.FloatTensor(out_features))

    def forward(self, text, adj):
        hidden = torch.matmul(text, self.weight) # 权重self.weight随机产生
        denom = torch.sum(adj, dim=1, keepdim=True) + 1  # 加一保证做除法时分母不为零
        output = torch.matmul(adj, hidden) / denom
        output = F.relu(output + self.bias)
        print(output)
        return output
def main():
    # 假设该句子经过构建依赖树后的邻接矩阵为adj
    adj =torch.tensor([
        [1., 1., 0., 0., 0., 0., 0., 1., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 1.],
    ])
    # 假设一个句子中有10个单词,从前向后单词对应的索引为[0, 1, 2, 3, 3, 4, 6,0, 1, 2]
    input = torch.tensor([0, 1, 2, 3, 3, 4, 6,0, 1, 2], dtype=torch.long)
    embedding = torch.nn.Embedding(10, 50)
    x = embedding(input)  # 生成每个单词对应的词嵌入,维度为50
    gc1 = GraphConvolution(50, 10)
    gc1(x, adj)
if __name__ == '__main__':
    main()

输出:
tensor([[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07, 3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07, 3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21],
[1.1561e+19, 6.8794e+11, 2.7253e+20, 3.0866e+29, 1.1547e+19, 4.1988e+07,3.0357e+32, 1.1547e+19, 6.4069e+02, 4.3066e+21]],
grad_fn=)


参考资料

  1. 图卷积网络 GCN Graph Convolutional Network(谱域GCN)的理解和详细推导
  2. 图卷积网络(Graph Convolutional Networks, GCN)详细介绍

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/386619.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux命令行安装Oracle19c教程和踩坑经验

安装 下载 从 Oracle官方下载地址 需要的版本,本次安装是在Linux上使用yum安装,因此下载的是RPM。另外,需要说明的是,Oracle加了锁的下载需要登录用户才能安装,而用户是可以免费注册的,这里不做过多说明。 …

最新使用nvm控制node版本步骤

一、完全卸载已经安装的node、和环境变量 ①、打开控制面板的应用与功能,搜索node,点击卸载 ②、打开环境变量,将node相关的所有配置清除 ③、打开命令行工具,输入node-v,没有版本号则卸载成功 二、下载nvm安装包 ①…

SBUS的协议详解

SBUS 1.串口配置: 100k波特率, 8位数据位(在stm32中要选择9位), 偶校验(EVEN), 2位停止位, 无控流,25个字节, 2.协议格式: [startbyte] [data1][data2]……

单月涨粉超600w,直播销售额破5亿,2月的黑马都是谁?

2月的抖音,黑马多多,处处有爆点。有直播间热度不减,在带货领域持续位列前茅;有达人通过“抓马式”连麦直播,涨粉657w;有0.01元的低价商品,一天热卖超1000w个。那么,2月有哪些主播表现…

微服务实战03-注册数据服务

EurekaServer ,它扮演的角色是注册中心,用于注册各种微服务,以便于其他微服务找到和访问。有了EurekaServer,还需要一些微服务,注册到EurekaServer上去。 这一节,我们来写一个注册微服务。为了简单起见&am…

【同步工具类:Phaser】

同步工具类:Phaser介绍特性动态调整线程个数层次Phaser源码分析state 变量解析构造函数对state变量赋值阻塞方法arrive()awaitAdvance()业务场景实现CountDownLatch功能代码测试结果实现 CyclicBarrier功能代码展示测试结果总结介绍 一个可重复使用的同步屏障,功能…

26- AlexNet和VGG模型分析 (TensorFlow系列) (深度学习)

知识要点 AlexNet 是2012年ISLVRC 2012竞赛的冠军网络。 VGG 在2014年由牛津大学著名研究组 VGG 提出。 一 AlexNet详解 1.1 Alexnet简介 AlexNet 是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络&#xff0…

paddle推理部署(cpu)

我没按照官方文档去做,吐槽一下,官方文档有点混乱。。一、概述总结起来,就是用c示例代码,用一个模型做推理。二、示例代码下载https://www.paddlepaddle.org.cn/paddle/paddleinferencehttps://github.com/PaddlePaddle/Paddle-In…

Clion连接Docker,使用HElib库

文章目录需求Clion连接服务器内的DockerDockerCLionDocker内配置HElib库参考需求 HElib库是用C编写的同态加密开源库,一般在Linux下使用为了不混淆生产环境,使用Docker搭建HElib运行环境本地在Windows下开发,使用的IDE为Clion,本…

动态规划:leetcode 121. 买卖股票的最佳时机、122. 买卖股票的最佳时机II

leetcode 121. 买卖股票的最佳时机leetcode 122.买卖股票的最佳时机IIleetcode 121. 买卖股票的最佳时机给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日…

node版本管理工具nvm

1.标题卸载nvm和node.js 系统变量中删除nvm添加变量:NVM_HOME和NVM_SYMLINK环境变量中 path:删除nvm自动添加的变量 Path %NVM_HOME%;%NVM_SYMLINK%删除自身安装node环境,参考图一图二 图一 图二 2.安装nvm nvm-window下载------https:/…

ES window 系统环境下连接问题

环境问题:(我采用的版本是 elasticsearch-7.9.3)注意 开始修正之前的配置:前提:elasticsearch.yml增加或者修正一下配置:xpack.security.enabled: truexpack.license.self_generated.type: basicxpack.secu…

对象实例化【JVM】

JVM对象实例化简介/背景一、创建对象的方式1. new2. Class对象的newInstance方法3. Construstor对象的newInstance(xx)方法4. 使用clone方法二、创建对象的步骤1. 判断对象是否已经加载、链接、初始化2. 为对象分配内存3. 处理并发安全问题4. 初始化分配到的空间5. 设置对象的对…

Tech Lead如何引导团队成员解决问题?

作为一个开发团队的Tech Lead,当团队成员向你寻求帮助时,你有没有说过下面这些话? 你别管了,我来解决这个问题你只要。。。就行了你先做其他的吧,我研究一下,然后告诉你怎么做 当我们说这些话时&#xff…

腾讯免费企业邮箱迁移记录

本文记录在重新申请腾讯企业邮箱的过程。 背景 很多年前,将域名latelee.org 迁移到了阿里云,当时因政策原因无法实名,但能使用。去年3月,阿里云提示无法续费,紧急将其转到外面某服务,继续使用,…

IP地址的工作原理

如果您想了解特定设备为何未按预期方式进行连接,或者想要排查网络无法正常工作的可能原因,它可以帮助您了解 IP 地址的工作原理。互联网协议的工作原理与任何其他语言相同,即使用设定的准则进行通信以传递信息。所有设备都使用此协议与其他连…

jq获取同级或者下级的dom节点的操作

1.使用find找到对应的class或者其他 var class_dom1 obj.find(.class名称);或者 find(span .class名称)2.使用添加背景颜色来确定当前的查找位置 class_dom1.css(background,red);3.通过parent来找到它的上级的dom节点 var parent_li_dom1 class_dom1.parent(li.parent_li…

进阶指针——(2)

本次讲解重点: 6. 函数指针数组 7. 指向函数指针数组的指针 8. 回调函数 在前面我们已经讲解了进阶指针的一部分,我们回顾一下在进阶指针(1)我们学过的难点知识点: int my_strlen(const char* str) {return 0; }int main() {//指针数…

创宇盾重保经验分享,看政府、央企如何防护?

三月重保已经迫近,留给我们的准备时间越来越少,综合近两年三月重保经验及数据总结,知道创宇用实际案例的防护效果说话,深入解析为何创宇盾可以在历次重保中保持“零事故”成绩,受到众多部委、政府、央企/国企客户的青睐…

HACKTHEBOX——Irked

nmapnmap -sV -sC -Pn -T4 -oA nmap 10.10.10.117可能是因为网络原因,与目标链接并不稳定,因此添加了参数-Pn,也只扫描了常见的端口扫描可以看到只开启了3个端口,22,80和111。但是在访问web时,页面提示运行着irc因此再…