WWW24因果论文(3/8) |通过因果干预实现图分布外泛化

news2024/9/22 17:25:28

【摘要】由于图神经网络 (GNN) 通常会随着分布变化而出现性能下降,因此分布外 (OOD) 泛化在图学习中引起了越来越多的关注。挑战在于,图上的分布变化涉及节点之间错综复杂的互连,并且数据中通常不存在环境标签。在本文中,我们采用自下而上的数据生成视角,并通过因果分析揭示了一个关键观察结果:GNN 在 OOD 泛化中失败的关键在于来自环境的潜在混杂偏差。后者误导模型利用自我图特征与目标节点标签之间的环境敏感相关性,导致在新的未见节点上出现不良的泛化。基于这一分析,我们引入了一种概念上简单但原则性的方法,用于在节点级分布变化下训练稳健的 GNN,而无需事先了解环境标签。我们的方法采用了一种源自因果推理的新学习目标,该目标协调了环境估计器和专家混合 GNN 预测器。新方法可以抵消训练数据中的混杂偏差,并促进学习可推广的预测关系。大量实验表明,我们的模型可以有效地增强各种分布偏移下的泛化能力,并且在图 OOD 泛化基准上比最先进的方法提高高达 27.4% 的准确率。

原文:Graph Out-of-Distribution Generalization via Causal Intervention
地址:https://arxiv.org/abs/2402.11494
代码:https://github.com/fannie1208/CaNet
出版:www 24
机构: 上海交通大学

写的这么辛苦,麻烦关注微信公众号“码农的科研笔记”!

1 研究问题

本文研究的核心问题是: 如何设计一个图神经网络模型,使其能够在结点属性分布发生变化时,仍然保持良好的泛化性能。

假设一个社交网络中,用户的爱好与其朋友的年龄分布密切相关。在大学生群体中,朋友都比较年轻的用户往往更喜欢篮球运动。但是这种相关性可能只在大学生群体中成立,对于职场社交网络LinkedIn,用户的年龄与其爱好的相关性可能就很弱。如果我们基于大学生的社交网络训练了一个用于预测用户爱好的图神经网络,那么将其直接应用于LinkedIn,可能会遇到泛化失败的问题。

本文研究问题的特点和现有方法面临的挑战主要体现在以下几个方面:

  • 图数据中的分布变化往往涉及到结点之间的复杂交互与关联,需要模型能够充分考虑不同结点的结构化特征。

  • 在图学习问题中,每个结点所处的环境信息通常是隐含的,难以直接获取。这为模型从观测数据中推断有用的环境信息,以指导学习过程,带来了障碍。

针对这些挑战,本文提出了一种基于因果干预的"因果网络(CaNet)"方法:

CaNet巧妙地借鉴了因果推理中的 do-calculus 思想,通过显式地对环境变量建模,消除了隐含的混淆偏差。具体来说,它引入了一个环境估计器,负责基于输入的局部子图推断可能的环境信息。同时,图神经网络的每一层都配备了一组 mixture-of-expert 的传播单元,可以动态地根据推断的环境选择不同的传播方式。通过环境估计器和图神经网络的协同优化,CaNet 可以自动地发现观测数据中的稳定关系,同时避免捕获那些容易受环境变化影响的虚假相关性。这一设计理念犹如为图神经网络装上了一副"透视镜",让它对不可见的因果机制具备了感知和适应的能力。

2 研究方法

2.1 因果分析

论文首先从因果的角度来分析GNN面临分布外泛化问题的根本原因。如图2(a)所示,论文使用有向无环图建模了节点的ego-graph特征、节点标签和环境因素之间的因果依赖关系。可以看到,作为未观测的混淆因素,会影响和的生成过程。当使用最大似然估计来训练GNN时,由于忽略了的影响,模型会错误地学习到由某些特定的引起的和之间的强相关性(如在大学生群体中,"朋友年轻"和"喜欢打篮球"往往同时成立)。然而,这类相关性是不稳定的,一旦测试环境发生变化(如职场人士的社交网络),先前学到的相关性便不再成立。这导致了GNN在分布外数据上的泛化性能显著下降。

2.2 因果干预

为了消除环境因素的混淆偏差,进而提升GNN的分布外泛化性能,论文提出了一种基于因果干预的方法。借助后门调整公式,论文指出,优化干预分布而非观测分布,可以有效避免环境因素的混淆。然而,的求解需要穷举所有可能的环境,这在实际中是不可行的。为此,论文进一步引入变分推断,得到的一个变分下界,如式(5)所示:

其中,是根据节点的ego-graph特征来推断环境因素的估计器,是给定ego-graph特征和推断的环境标签来预测节点标签的GNN。通过最大化该变分下界,可以得到协同优化的算法:环境估计器尽可能准确地推断环境因素,同时要求推断结果与ego-graph特征保持独立;而GNN预测器则根据ego-graph特征和推断的环境因素来预测节点标签。通过协同学习,模型可以学习到与具体环境无关的稳定预测模式,进而提升分布外泛化性能。

2.3 模型实例化

在CaNet中,环境估计器将环境因素表示为一系列伪环境标签向量,其中表示GNN的层数。如式(6-7)所示(太难打了,公式见原文),对于每个节点的第层表示,环境估计器首先计算该节点属于每个伪环境的概率,然后通过Gumbel-Softmax技巧对重参数化,得到。

GNN预测器的核心是一个层的混合专家传播网络,每一层包含个专家分支。如式(8-9)所示,每个分支采用独立的参数,并由推断的伪环境标签进行选择。不同分支学习ego-graph特征的不同组合模式,赋予模型更强的表征能力。GNN预测器的最后一层输出节点表示,并通过全连接层映射为节点标签的预测值。

算法1总结了CaNet的前向计算和训练优化流程。其中,环境估计器和GNN预测器通过梯度下降法交替优化,协同学习对分布外泛化有利的预测模式。

5 实验

5.1 实验场景介绍

该论文提出了一个处理图神经网络节点级别分布差异的因果干预方法CaNet。实验主要在节点属性预测任务中,验证CaNet相比其他模型在训练集和测试集节点分布不一致情况下的泛化优势。同时通过消融实验、超参数分析等进一步探究模型内部机制。

5.2 实验设置

  • Datasets:使用Cora、Citeseer、Pubmed、Twitch、Arxiv、Elliptic等6个不同规模和属性的节点预测数据集,通过时间属性、子图、动态快照等不同方式构建训练集和测试集的分布差异

  • Baseline:ERM, IRM, DeepCoral, DANN, GroupDRO, Mixup等通用OOD方法;SR-GNN, EERM等图数据OOD方法;均使用GCN和GAT作为编码器骨干

  • Implementation details:基于PyTorch 1.13和PyG 2.1, Adam优化器,训练500轮,网格搜索超参数

  • metric:Accuracy, ROC-AUC, macro F1

5.3 实验结果

5.3.1 实验一、不同数据集上的性能对比

目的:在多个数据集上验证CaNet相比其他模型处理分布差异节点的优势

涉及图表:表1,表3,图4

实验细节概述:在Cora、Citeseer、Pubmed上测试合成特征和结构导致的分布差异;在Arxiv上测试不同时间的论文节点;在Twitch上测试不同子图的节点;在Elliptic上测试不同时间快照的节点

结果:

  • CaNet在所有OOD测试集上显著优于对应的基线,在ID测试集上也有竞争力的表现

  • 在Cora和Citeseer上,CaNet在OOD数据的绝对性能接近ID数据

  • 在Arxiv的跨时间差异最大的测试集上,CaNet超出次优baseline 14.1%和27.4%

  • 在Elliptic的动态图快照测试集上,CaNet平均超出次优baseline 12.16%

5.3.2 实验二、消融实验

目的:验证正则化损失、层级环境推断等关键组件的有效性

涉及图表:图5

实验细节概述:去除正则化损失、使用复杂先验分布、采用全局环境表示、使用非参数环境估计器等简化变体

结果:

  • 正则化损失和层级环境推断能有效提升OOD性能

  • 简单先验分布优于复杂先验,更利于泛化

5.3.3 实验三、超参数分析

目的:探究伪环境数K和温度τ对模型性能的影响

涉及图表:图6

实验细节概述:在Arxiv和Twitch上分别评估不同K和τ下模型在各OOD测试集的表现

结果:

  • 性能对K不太敏感,过大或过小的K在Arxiv的OOD 2/3上可能降低性能

  • 适中的τ(如1)效果最佳,过大的τ会导致性能下降

5.3.4 实验四、可视化分析

目的:直观展现不同分支学习到的权重模式差异

涉及图表:图7,8,9,10

实验细节概述:可视化K=3时模型在Arxiv和Twitch上第一层和最后一层不同分支的权重矩阵

结果:不同分支权重有明显差异,说明mixture-of-expert结构能学习到区分不同伪环境的表达模式,利于泛化

4 总结后记

本论文针对图神经网络在面对分布偏移时泛化能力较差的问题,从因果分析的角度揭示了其根源在于未观测到的环境混淆因素。基于此分析,提出了一种通过因果干预改进图神经网络泛化性的方法CaNet。该方法引入了环境估计器和混合专家传播网络,可以在没有先验环境标签的情况下,通过优化一个新的学习目标来捕获对环境不敏感的预测关系,从而提高模型的分布外泛化能力。实验结果表明,该模型在多个具有不同类型分布偏移的数据集上,相比现有方法可以显著提升泛化性能,泛化准确率提升高达27.4%。

疑惑和想法:

  1. 除了节点层面的分布偏移,该方法是否可以推广到处理图层面或子图层面的分布偏移?

  2. 环境估计器推断出的伪环境标签对应着什么物理含义?能否赋予它们可解释性?

  3. 除了混合专家传播网络,是否可以设计其他形式的条件传播机制来建模环境因素对节点表示的影响?

可借鉴的方法点:

  1. 从因果分析角度揭示模型泛化不足的根源,为诊断和改进其他类型的图学习任务提供了新的思路。

  2. 通过优化包含对抗正则化项的目标函数来消除混淆偏差,可以推广到其他需要增强鲁棒性的机器学习场景。

  3. 环境估计器和条件传播网络的思想可以借鉴到图预训练等其他图表示学习任务中,以建模不同图之间的差异。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1706963.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

恭喜社区迎来新PMC成员!

恭喜Apache SeaTunnel社区又迎来一位PMC Memberliugddx!在社区持续活跃的两年间,大家经常看到这位开源爱好者出现在社区的各种活动中,为项目和社区发展添砖加瓦。如今成为项目PMC Member,意味着在社区中的责任更重了,他…

香橙派 AIpro的NPU随手记体验日记

昇腾AI 技术路线 8TOPS INT8(FP16)AI算力 LPDDR4X 8GB/16GB 📅 20240525 开放了原理图和源码,功能接口就不描述了手册都有描述,新手好好学习可以从底层覆盖到应用一个载板拿下 完成香橙派AIpro上手体验 镜像安装&am…

瓦罗兰特账号怎么注册 瓦罗兰特延迟高用什么加速器

《瓦罗兰特》(Valorant)是由拳头游戏(Riot Games)开发并发行的一款免费的多人在线第一人称射击游戏(FPS),它结合了传统的硬核射击机制与英雄角色的能力系统,为玩家提供了独特的竞技体…

【机器学习300问】103、简单的经典卷积神经网络结构设计成什么样?以LeNet-5为例说明。

一个简单的经典CNN网络结构由:输入层、卷积层、池化层、全连接层和输出层,这五种神经网络层结构组成。它最最经典的实例是LeNet-5,它最早被设计用于手写数字识别任务,包含两个卷积层、两个池化层、几个全连接层,以及最…

thingsboard接入臻识道闸

thingsboard 和tb-gateway 是通过源码idea启动测试开发 为了测试这里只是买了臻识道闸的摄像机模组方便调试,然后添加一个开关量开关模拟雷达 道闸品牌 臻识C3R3C5R5变焦500万车牌识别相机高速追逃费相机华厦V86像机 淘宝地址 https://item.taobao.com/item.htm?_us1thkikq4…

渗透攻击(思考题)

目录 1. windows登录的明文密码,存储过程是怎么样的,密文存在哪个文件下,该文件是否可以打开,并且查看到密文 2. 我们通过hashdump 抓取出 所有用户的密文,分为两个模块,为什么? 这两个模块分…

什么是GPT-4o,推荐GPT-4o的获取使用方法,使用GPT4o模型的最新方法教程(2024年5月16更新)

2024年5月最新GPT-4o模型使用教程和简介 2024年5月最新GPT-4o模型使用教程和简介 2024 年 5 月 13 日,openai 发布了最新的模型 GPT4o。 很多同学还不知道如何访问GPT-4、GPT-4 Turbo和GPT-4o等模型,这篇文章介绍如何在ChatGPT中访问GPT-4o&#xff0…

大模型时代下,数字员工演进全景图:RPA/IPA/Agent

从蒸汽机到电力,再到计算机,每一次技术的飞跃都极大地提升了企业效率。 如今,随着数字化转型的浪潮席卷全球,企业开始寻求新的解决方案来优化业务流程、打破数据屏障,达到提效降本的目的。在这一背景下,数字…

Python考试复习--day4

1.三角函数计算 import math aeval(input()) beval(input()) x(-bpow(2*a*math.sin(math.pi/3)*math.cos(math.pi/3),0.5))/(2*a) print(x) math库 2.分段函数B import math xeval(input()) if -6<x<0:yabs(x)5 elif 0<x<3:ymath.factorial(x) elif 3<x<6:y…

智慧园区:打造未来城市的新模式

随着城市化进程的加速和科技创新的推动&#xff0c;城市面临着诸多挑战和机遇。如何提升城市的竞争力和可持续性&#xff0c;是一个亟待解决的问题。在这个背景下&#xff0c;智慧园区作为一种新型的城市发展模式&#xff0c;引起了越来越多的关注和探索。 什么是智慧园区&…

黑马聚合的分类及实现

1、什么是聚合? 聚合是对文档数据的统计、分析、计算 聚合的常见种类有哪些? 桶(Bucket)聚合:用来对文档做分组 TermAggregation:按照文档字段值分组 Date Histogram:按照日期阶梯分组&#xff0c;例如一周为一组&#xff0c;或者一月为一组 度量(…

Docker 入门版

目录 1. 关于Docker 2. Dockr run命令中常见参数解读 3. Docker常见命令 4. Docker 数据卷 5. Docker本地目录挂载 6. 自定义镜像 Dockerfile 语法 自定义镜像模板 Demo 7. Docker网络 1. 关于Docker 在docker里面下载东西&#xff0c;就是相当于绿色面安装板&#x…

弘君资本:沪指跌0.46%,电力板块逆市爆发,半导体板块强势

28日&#xff0c;沪指早盘窄幅震动&#xff0c;午后回落走低&#xff1b;深证成指、创业板指大幅下探&#xff1b;两市成交额小幅萎缩。 截至收盘&#xff0c;沪指跌0.46%报3109.57点&#xff0c;深证成指跌1.23%报9391.05点&#xff0c;创业板指跌1.35%报1806.25点&#xff0c…

手搓顺序表(C语言)

目录 SeqList.h SeqList.c 头插尾插复用任意位置插入 头删尾删复用任意位置删除 SLtest.c 测试示例 顺序表优劣分析 SeqList.h //SeqList.h#pragma once#include <stdio.h> #include <assert.h> #include <stdlib.h> #define IN_CY 3typedef int S…

CyberDAO全国行第三站·西安圆满落幕

CyberDAO全国行第三站于2024年5月27日在西安顺利召开。以聚势启新&#xff0c;聚焦Web3新机遇&#xff0c;开启Web3财富密码为本次会议的思想路线&#xff0c;汇聚了大批Web3爱好者齐聚古城西安。CyberDAO致力于帮助更多Web3爱好者捕获行业价值。 以圆桌论坛《机遇拥抱Web3》拉…

matplotlib ---词云图

词云图是一种直观的方式来展示文本数据&#xff0c;可以体现出一个文本中词频的使用情况&#xff0c;有利于文本分析&#xff0c;通过词频可以抓住一篇文章的重点 本文通过处理一篇关于分析影响洋流流向的文章&#xff0c;分析影响洋流流向的主要因素都有哪些 文本在文末结尾 …

升级鸿蒙4.2新变化,新增 WLAN 网络自动连接开关!

手机已经成为现代人生活中不可或缺的一部分&#xff0c;手机里的功能可以满足大部分人的生活场景&#xff0c;但是最依赖的应该就是手机网络&#xff0c;手机网络突然变差怎么办——消息发不出去&#xff1f;刷新闻速度变慢&#xff1f;仔细检查后&#xff0c;发现其实不是手机…

Linux-CentOS7-解决vim修改不了主机名称(无法打开并写入文件)

Linux-CentOS7-修改主机名称 修改之后使用强制保存退出也不行。 解决办法&#xff1a; 使用hostnamectl命令进行修改 查看系统主机名和信息&#xff1a; hostnamectl这条命令会显示当前系统的主机名、操作系统信息、内核版本、架构信息等相关信息。 修改系统主机名&#xff1…

HQChart使用教程99-K线窗口设置上下间距

HQChart使用教程99-K线窗口设置上下预留间距 指标窗口布局说明设置预留间距数据结构通过Setoption设置通过ChangeIndex设置 HQChart代码地址 指标窗口布局说明 顶部预留间距(3)和底部预留间距(5) 这个部分是算在Y轴坐标上的 设置预留间距 数据结构 HorizontalReserved&#…

Hono 框架使用经验谈

Hono&#x1f525;是一个小型、快速并开源的 Serverless Web 框架&#xff0c;用 TypeScript 写就。它适用于任何JavaScript运行时&#xff1a;Cloudflare Workers&#xff0c;Fastly ComputeEdge&#xff0c;Deno&#xff0c;Bun&#xff0c;Vercel&#xff0c;Netlify&#x…