论文《Graph Neural Networks with convolutional ARMA filters》笔记

news2024/12/27 13:22:08

【ARMA 2021 PAMI】本文介绍了一种新型的基于**自回归移动平均(Auto-Regression Moving Average,ARMA)**滤波器的图卷积层。与多项式滤波器相比,ARMA滤波器提供了更灵活的频率响应,对噪声更鲁棒,能更好地捕获全局图结构。

本文发表在2021年PAMI期刊上,第一作者学校:UiT the Arctic University of Norway,引用量:494。

PAMI期刊简介:全称IEEE Transactions on Pattern Analysis and Machine Intelligence(IEEE模式分析与机器学习智能汇刊),影响因子很高,计算机视觉顶刊,CCF A。

查询会议:

  • 会伴:https://www.myhuiban.com/

  • CCF deadline:https://ccfddl.github.io/

原文和开源代码链接:

  • paper原文:https://arxiv.org/abs/1901.01343
  • 开源代码:https://github.com/xnuohz/ARMA-dgl

0、核心内容

本文介绍了一种新型的基于**自回归移动平均(Auto-Regression Moving Average,ARMA)**滤波器的图卷积层。

研究背景与动机

  • 传统的图神经网络在图上实施卷积操作,通常基于多项式谱滤波器。
  • 作者指出多项式滤波器的局限性,如对噪声的敏感性、不能很好地捕捉全局图结构等。

ARMA滤波器的卷积层

  • 提出了一种基于ARMA滤波器的新型图卷积层,与多项式滤波器相比,ARMA滤波器提供了更灵活的频率响应,对噪声更鲁棒,能更好地捕获全局图结构。
  • 介绍了ARMA滤波器的图神经网络实现,包括递归和分布式公式,使得卷积层训练高效、节点空间局部化,并能在测试时迁移到新图上。

理论分析

  • 通过谱分析研究了所提出的ARMA层的滤波效果。
  • 论文还讨论了ARMA滤波器与多项式滤波器的数学性质和性能对比。

实验验证

  • 在四个下游任务上进行了实验:半监督节点分类、图信号分类、图分类和图回归。
  • 实验结果显示,基于ARMA层的图神经网络在各项任务重相较于基于多项式滤波器的网络有显著的性能提升。

方法细节

  • 论文详细介绍了ARMA滤波器的理论基础,包括其在图信号处理中的应用。
  • 讨论了ARMA滤波器的非线性和可训练性,以及如何通过优化任务相关的损失函数来学习参数。

实验设置与结果

  • 论文描述了实验的设置,包括数据集、模型架构、超参数选择等。
  • 提供了详细的实验结果,证明了ARMA层在不同任务上的有效性。

结论

  • 论文总结了ARMA层的主要贡献,并强调了其在图机器学习任务中的优越性能。

1、先验知识
① 什么是加权移动平均(weighted moving average)?

在图信号处理和GNN的上下文中,加权移动平均是一种对图信号进行平滑处理的方法。

移动平均(Moving Average,MA)

  • 移动平均是一种常用的信号处理技术,用于减少数据的波动性,突出其趋势。在图信号的背景下,它涉及对图中每个节点的特征值进行局部平均,以获得更平滑的信号表示。
  • 在最简单的形式中,每个节点的新特征值是其自身和其邻居节点特征值的算术平均。

加权移动平均(Weighted Moving Average,WMA)

  • 加权移动平均是移动平均的一种扩展,其中每个邻居节点的贡献被一个权重系数所调整。这些权重反映了节点间连接的重要性或节点特征的相关性。
  • 在图信号处理中,这意味着每个节点的新特征值使其自身和邻居节点特征值的加权和。权重通常基于图的拓扑结构(例如,边的强度或节点的相似性)或其他学习到的参数。

2、从谱滤波器到多项式滤波器发展脉络

① 谱滤波器:在谱域内与非线性可训练滤波器实现卷积的GNN。

参考资料:

  • 《Spectral networks and locally connected networks on graphs》Brunna 2013
  • 《Deep convolutional networks on graph-structured data》Henaff 2015

这种滤波器选择性地缩小或放大图信号的傅里叶系数(节点特征的一个实例),然后将节点特征映射到一个新的空间。

在这里插入图片描述

谱滤波器存在的问题:在实现时需要进行特征分解,计算代价很大。

② 多项式滤波器:为了避免频域昂贵的频谱分解和投影,最先进的GNN将图滤波器作为低阶多项式,直接在节点域学习。

参考资料:

  • 《Convolutional neural networks on graphs with fast localized spectral filtering》Deferrard 2016
  • 《GCN》Kipf 2016
  • 《Variational graph auto-encoders》Kipf 2016

多项式滤波器有一个有限的脉冲响应和执行加权移动平均过滤图信号在局部节点社区,允许快速分布式实现等基于切比雪夫多项式和Lanczos迭代。

多项式滤波器的频率响应(frequency response):

在这里插入图片描述

多项式滤波器实现:

在这里插入图片描述

切比雪夫多项式滤波器实现:

在这里插入图片描述

GCN滤波器实现:

在这里插入图片描述

多项式滤波器存在的问题:建模能力有限,由于其平滑性,不能模拟频率响应的急剧变化。至关重要的是,高阶多项式对于达到高阶邻域是必要的,但它们的计算成本往往更高,过拟合训练数据,使模型对图信号或底层图结构的变化敏感。(总结:低阶多项式不能模拟频率响应的急剧变化,导致过平滑问题;高阶多项式导致过拟合问题。)

③ ARMA(自回归移动平均滤波器):与具有相同参数数量的多项式滤波器相比,ARMA滤波器提供了更多的频率响应,可以解释高阶邻域。

参考资料:

  • 《Design of graph filters and filterbanks》Tremblay 2018
  • 《Signal processing techniques for interpolation in graph structured data》Narang 2013

ARMA滤波器的频率响应:

在这里插入图片描述

ARMA滤波器的实现:

在这里插入图片描述

为了避免求逆,ARMA滤波器的近似实现(迭代方法):

在这里插入图片描述

1阶ARMA滤波器的近似实现:

在这里插入图片描述

1阶ARMA滤波器的频率响应:

k阶ARMA滤波器的近似实现:

3、ARMA层及滤波器实现

1阶ARMA滤波器的实现(本文最核心的公式,而且比较好理解):

在这里插入图片描述

定理1证明了公式14可收敛的条件(可以先忽略proof部分):

在这里插入图片描述

在实际实现ARMA算法过程中存在一些问题和挑战:

  1. 每个GCS堆栈k可能需要不同的、可能是大量的迭代次数 T k T_k Tk来收敛。这使得神经网络的实现变得麻烦,因为计算图是动态的,在训练过程中每次权重矩阵通过梯度下降更新都会变化。
  2. 为了训练参数的反向传播,如果 T k T_k Tk较大,神经网络必须多次展开,引入较高的计算代价和消失梯度问题。

解决方法:

  1. 参数 W k W_k Wk V k V_k Vk随机权重初始化。
  2. 固定收敛次数为常数 T T T

k阶ARMA滤波器的实现:

在这里插入图片描述

ARMA卷积层算法实现:

在这里插入图片描述

4、实验部分

节点分类任务的数据集和超参数:

在这里插入图片描述

节点分类任务实验结果:

5、核心卷积层源码

ARMA4NC类:

class ARMA4NC(nn.Module):
    def __init__(self,
                 in_dim,
                 hid_dim,
                 out_dim,
                 num_stacks,
                 num_layers,
                 activation=None,
                 dropout=0.0):
        super(ARMA4NC, self).__init__()

        self.conv1 = ARMAConv(in_dim=in_dim,
                              out_dim=hid_dim,
                              num_stacks=num_stacks,
                              num_layers=num_layers,
                              activation=activation,
                              dropout=dropout)

        self.conv2 = ARMAConv(in_dim=hid_dim,
                              out_dim=out_dim,
                              num_stacks=num_stacks,
                              num_layers=num_layers,
                              activation=activation,
                              dropout=dropout)
        
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, feats):
        feats = F.relu(self.conv1(g, feats))
        feats = self.dropout(feats)
        feats = self.conv2(g, feats)
        return feats

ARMA卷积层:

class ARMAConv(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 num_stacks,
                 num_layers,
                 activation=None,
                 dropout=0.0,
                 bias=True):
        super(ARMAConv, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.K = num_stacks
        self.T = num_layers
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)

        # init weight
        self.w_0 = nn.ModuleDict({
            str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
        })
        # deeper weight
        self.w = nn.ModuleDict({
            str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)
        })
        # v
        self.v = nn.ModuleDict({
            str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
        })
        # bias
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.K, self.T, 1, self.out_dim))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def reset_parameters(self):
        for k in range(self.K):
            glorot(self.w_0[str(k)].weight)
            glorot(self.w[str(k)].weight)
            glorot(self.v[str(k)].weight)
        zeros(self.bias)

    def forward(self, g, feats):
        with g.local_scope():
            init_feats = feats
            # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
            output = None

            for k in range(self.K):
                feats = init_feats
                for t in range(self.T):
                    feats = feats * norm
                    g.ndata['h'] = feats
                    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                    feats = g.ndata.pop('h')
                    feats = feats * norm

                    if t == 0:
                        feats = self.w_0[str(k)](feats)
                    else:
                        feats = self.w[str(k)](feats)
                    
                    feats += self.dropout(self.v[str(k)](init_feats))
                    feats += self.v[str(k)](self.dropout(init_feats))

                    if self.bias is not None:
                        feats += self.bias[k][t]
                    
                    if self.activation is not None:
                        feats = self.activation(feats)
                    
                if output is None:
                    output = feats
                else:
                    output += feats
                
            return output / self.K 

6、心得&代码复现

从本文中可以学到很多东西,收获很大;本文的方法也比较有意思,值得认真研读,虽然理论部分读起来非常硬核,尤其是算法implementation部分。

实验部分我自己复现的结果:

CoraCiteseerPubmedFilmSquirrelChameleonTexasCornellWisconsin
GCN81.90%71.80%79.10%23.36%34.49%51.97%54.05%54.05%64.71%
H2GCN74.20%59.80%76.60%25.72%38.04%52.63%72.97%48.65%74.51%
ARMA(dgl版)80.80%71.20%78.50%33.30%31.50%53.40%70.27%59.46%79.80%

原论文中没有对异配图进行实验,在这里补充了6个异配图的实验结果,ARMA算法的效果提升主要在异配图上。

7、参考资料
  • kimi:https://kimi.moonshot.cn/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2120159.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【每日一题】LeetCode 104.二叉树的最大深度(树、深度优先搜索、广度优先搜索、二叉树)

【每日一题】LeetCode 104.二叉树的最大深度(树、深度优先搜索、广度优先搜索、二叉树) 题目描述 给定一个二叉树 root,我们需要计算并返回该二叉树的最大深度。二叉树的最大深度是指从根节点到最远叶子节点的最长路径上的节点数。 思路分…

Uni-app 开发鸿蒙 App 全攻略

一、开发前的准备工作 开发鸿蒙 App 之前,我们需要做好充分的准备工作。首先是工具的安装与配置。 Node.js 的安装:推荐使用 LTS 版本的 Node.js。可以前往 Node.js 的官方网站下载适合自己操作系统的安装包,如 Windows 用户根据自己的系统版…

OpenHarmony鸿蒙开发( Beta5.0)智能风扇设备开发实践

样例简介 智能风扇设备不仅可以接收数字管家应用下发的指令来控制风扇开启的时间,调节风扇挡位,更改风扇定时时间,而且还可以加入到数字管家的日程管理中。通过日程可以设定风扇相关的任务,使其在特定的时间段内,风扇…

【MySQL】MySQL表的操作

目录 创建表的语法创建表的示例查看表的结构进入数据库查看自己在哪个数据库查看自己所在数据库都有哪些表查看表的详细信息查看创建表时的详细信息 修改表修改表名修改表的内容插入几个数据增加一列修改一列的所有属性删除某一列修改列的名称 删除表 创建表的语法 CREATE TAB…

DFS算法专题(二)——穷举vs暴搜vs深搜vs回溯vs剪枝【OF】决策树

目录 1、决策树 2、算法实战应用【leetcode】 2.1 题一:全排列 2.2.1 算法原理 2.2.2 算法代码 2.2 题二:子集 2.2.1 算法原理【策略一】 2.2.2 算法代码【策略一】 2.2.3 算法原理【策略二,推荐】 2.2.4 算法代码【策略二&#x…

图像去噪技术:传统中值滤波与改进中值滤波算法的比较

在数字图像处理中,去噪是一个至关重要的步骤,尤其是在图像受到椒盐噪声影响时。本文将介绍一种改进的中值滤波算法,并与传统的中值滤波算法进行比较,以展示其在去除椒盐噪声方面的有效性。 实验环境 软件:MATLAB图像…

Centos如何配置阿里云的yum仓库作为yum源?

背景 Centos在国内访问官方yum源慢,可以用国内的yum源,本文以阿里云yum源为例说明。 快速命令 sudo mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.bak sudo wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.a…

宏观学习笔记:GDP分析(二)

GDP分析(一)主要是介绍GDP相关的定义以及核算逻辑,本节主要介绍GDP的分析思路。GDP分析主要是2种方法:总量分析和结构分析。 1. 总量分析 1.1 数值选择 一般情况下,分析的对象都是 官方公布的GDP当季值。 1.2 趋势规…

利用发电量和气象数据分析来判断光伏仿真系统的准确性

随着光伏产业的迅速发展,光伏仿真系统通过集成气象数据分析、发电量分析、投融资分析及损耗估算等功能,为光伏项目的全生命周期管理提供了科学依据。 光伏仿真系统集成了气象数据分析、发电量预测、投融资分析、损耗估算及光伏设计等功能。其中&#xf…

qmt量化交易策略小白学习笔记第60期【qmt编程之期权数据--基于BS模型计算欧式期权隐含波动率--内置Python】

qmt编程之获取期权数据 qmt更加详细的教程方法,会持续慢慢梳理。 也可找寻博主的历史文章,搜索关键词查看解决方案 ! 基于BS模型计算欧式期权隐含波动率 基于Black-Scholes-Merton模型,输入期权标的价格、期权行权价、期权现价、无风险利…

【880高数】高等数学一刷错题整理

第一章 函数、极限、连续 2024.8.11日 1. 2. 3. 4. 5. 2024.8.12日 1. 2. 3. 4. 5. 6. 7. 8. 2024.8.13日 1. 2. 3. 4. 2024.8.14日 1. 2. 3. 4. 5. 第二章 一元函数微分学及其应用 2024.8.15日 1. 2. 3. 4. 5. 6. 2024.8.16日 1. 2. 3. 4. 5. 2024.8.17日 1. 2. 3. 4…

哈希表 和 算法

1.哈希表的作用:将我们要存储的数据,通过关键字与位置的关系函数,来确定具体的位置。 2.写哈希表时常出现的问题:哈希冲突/矛盾:当多个数据满足哈希函数的映射时出现 解决的方法为: 1)开放地址…

MVC设计模式与delegate

一、MVC MVC就是Model(模型)、View(视图)、Controller(控制器) 例如上面的 excel表, 数据、数据结构就是模型Model 根据数据形成的直观的、用户能直接看见的柱形图是视图View 数据构成的表格…

ABAP JSON处理应用

1. json 转换成内表 通过上传URL获取json数据并转换为内表 json to itab关键字 METHOD get_itab_for_json.DATA : lr_client TYPE REF TO if_http_client,lv_url TYPE string,lv_content_type TYPE string VALUE application/x-www-form-urlencoded,ev_xstrin…

【python报错】ModuleNotFoundError: No module named ‘utils‘

问题 想要用python语言将A文件夹的a.py脚本引用utils文件夹b.py脚本,直接引用:from utils import XXX 导致在vscode编译器报错:ModuleNotFoundError: No module named utils 这里文件夹A和utils是同级目录【其他情况,修改后面代码…

背钻设计时要优先保证哪一项,STUB长度真的是越短越好吗

高速先生成员--王辉东 人道是: 八月十八潮,壮观天下无。 鲲鹏水击三千里,组练长驱十万夫。 红旗青盖互明末,黑沙白浪相吞屠。 人生会合古难必,此情此景那两得。 小蝶托着腮望着窗外,思绪飞到千里之外…

【鸿蒙开发从0到1 day08】

鸿蒙开发基础 一.联合类型二.枚举类型三.组件和样式1. ArkUI基本语法 四.尺寸五.字体1.字体颜色2.字体样式3.LineHeight() 设置行高 上间距文字下间距4.下划线:5.对齐方式(1)水平对齐方式(2)垂直对齐方式 6.文本缩进和文本省略号设置 六.图片1.图片的等比例缩放2.占位符3.图片填…

2024腾讯互联网AI应用专场

2024腾讯互联网AI应用专场 灵魂提问: 1、AI应用场景: 智能客服智能数据分析BI 通过AI生成的内容的点击率是人工生产的103%。 2、AI时代已经来临, 依然是这些互联网巨头领导。 现在股价低迷,是不是投资的好机会。 3、agent …

矩阵怪 - 2024全新矩阵产品,一键分发抖音,快手,视频号,B站,小红书!

1. 本方案面向谁,解决了什么问题 本方案主要面向C端客户,特别是那些在各大短视频平台(如小红书、抖音、视频号、快手、B站等)上进行内容创作和分发的个人用户、自由职业者、小型团队或企业。这些用户通常面临着在多个平台上同时发…

Python爬虫如何通过滑块验证

一:定位元素的坐标 当 Selenium 定位到元素后,如果想获取元素在页面中的具体坐标位置,可以通过 element.location 的方式来得到元素的起始坐标字典(元素的左上顶点)。然后再通过 element.size 的方式来获取该元素的宽…