【论文阅读】Deep Graph Contrastive Representation Learning

news2024/10/2 3:30:19

目录

  • 0、基本信息
  • 1、研究动机
  • 2、创新点
  • 3、方法论
    • 3.1、整体框架及算法流程
    • 3.2、Corruption函数的具体实现
      • 3.2.1、删除边(RE)
      • 3.2.2、特征掩盖(MF)
    • 3.3、[编码器](https://blog.csdn.net/qq_44426403/article/details/135443921)的设计
      • 3.3.1、直推式学习
    • 3.4、损失函数的定义
    • 3.5、评估
    • 3.6、理论动机
      • 3.6.1、最大化目标函数等价于最大化互信息的下界
      • 3.6.2、三重损失
    • 3.7、实验参数设置
  • 4、代码实现
    • 4.1、RE and MF
    • 4.2、encorder
    • 4.3、GRACE
    • 4.4、loss

0、基本信息

  • 作者:Yanqiao Zhu Yichen Xu
  • 文章链接:Deep Graph Contrastive Representation Learning
  • 代码链接:Deep Graph Contrastive Representation Learning

1、研究动机

  • 现实世界中,图的标签数量较少,尽管GNNs蓬勃发展,但是训练模型时标签的可用性问题也越来越受到关心。

  • 传统的无监督图表征学习方法,例如DeepWalk和node2vec,以牺牲结构信息为代价过度强调邻近信息

  • 基于局部-全局互信息最大化框架的[[DGI]]模型,要求readout函数是单射的具有局限性,并且对节点特征随机排列,当特征矩阵稀疏时,不足以生成不同的上下文信息,导致难以学习对比目标

 本文提出的GRACE模型:首先,通过移除边和掩盖特征生成两个视图,然后最大化两个视图中结点嵌入的一致性。

2、创新点

  • 结点级图对比学习框架
  • 提出新的Corruption Function:删除边和特征掩盖

3、方法论

3.1、整体框架及算法流程

  • 首先,通过Corruption函数在原始图 G G G的基础上生成两个视图 G ~ 1 \tilde{G}_1 G~1 G ~ 2 \tilde{G}_2 G~2
  • 其次,通过编码器函数 f f f,生成两个视图的结点嵌入表征, U = f ( G ~ 1 ) U=f(\tilde{G}_1) U=f(G~1) V = f ( G ~ 2 ) V=f(\tilde{G}_2) V=f(G~2)
  • 计算对比目标函数 J \mathcal{J} J
  • 通过随机梯度下降更新参数;

GRACE的整体框架如下图所示:
在这里插入图片描述

3.2、Corruption函数的具体实现

 视图的生成是对比学习方法的关键组成部分,不同视图为每个节点提供不同的上下文,本文依赖不同视图中结点嵌入之间对比的对比方法,作者在结构和属性两个层次上破坏原始图,这为模型构建了不同的节点上下文,分别是删除边和掩蔽结点特征。

3.2.1、删除边(RE)

 随机删除原图中的部分边。
 首先,采样一个随机掩盖矩阵 R ~ ∈ { 0 , 1 } N × N \tilde{R}\in \{0,1\}^{N \times N} R~{0,1}N×N,矩阵中的每个元素服从伯努利分布,即 R ~ ∼ B ( 1 − p r ) \tilde{R}\sim \mathcal{B}(1-p_r) R~B(1pr) p r p_r pr是每条边被移除的概率;其次,用得到地掩盖矩阵与原始邻接矩阵做Hadamard积,最终得到的邻接矩阵为:
A ~ = A ∘ R ~ \tilde{A}=A\circ \tilde{R} A~=AR~
注意,上式为Hadamard积。

3.2.2、特征掩盖(MF)

 再结点特征中用零随机地掩盖部分特征。
 首先,采样一个随机向量 m ~ ∈ { 0 , 1 } F \tilde{m}\in\{0,1\}^F m~{0,1}F,向量的每个元素来自于伯努利分布,即 m ~ ∼ B ( 1 − p m ) \tilde{m}\sim \mathcal{B}(1-p_m) m~B(1pm) p r p_r pr是元素被掩盖的概率;其次,用得到地掩盖向量与原始特征做Hadamard积,最终得到的特征矩阵为:
X ~ = [ x 1 ∘ m ~ ; x 2 ∘ m ~ ; . . . ; x N ∘ m ~ ; ] \tilde{X}=[x_1 \circ\tilde{m};x_2 \circ\tilde{m};...;x_N \circ\tilde{m};] X~=[x1m~;x2m~;...;xNm~;]
注意, [ . ; . ] [.;.] [.;.]是连接运算符。

3.3、编码器的设计

 针对不同任务,transductive learning、inductive learning on large graphs和inductive learning on multiple graphs,设计不同的编码器。这里仅仅列出transductive learning的编码器设计,其他任务编码器的设计请阅读原文4.2节实验设置

3.3.1、直推式学习

 直推式学习采用了一个两层的GCN作为编码器。编码器 f f f的形式如下:
G C i ( X , A ) = σ ( D ^ 1 2 A ^ D ^ 1 2 X W i ) GC_i(X,A)=\sigma(\hat{D}^{\frac{1}{2}}\hat{A}\hat{D}^{\frac{1}{2}}XW_i) GCi(X,A)=σ(D^21A^D^21XWi)
f ( X , A ) = G C 2 ( G C 1 ( X , A ) , A ) f(X,A)=GC_2(GC_1(X,A),A) f(X,A)=GC2(GC1(X,A),A)
其中, A ^ = A + I \hat{A}=A+I A^=A+I D ^ \hat{D} D^ A ^ \hat{A} A^的度矩阵, σ ( . ) \sigma(.) σ(.)为激活函数,例如 R e L U ( . ) = m a x ( 0 , . ) \mathrm{ReLU}(.)=max(0,.) ReLU(.)=max(0,.) W i W_i Wi为可训练的权重矩阵。

3.4、损失函数的定义

 对比目标,即判别器,是将两个来自不同视图相同结点的嵌入与其他结点区分开来,最大化嵌入之间的结点级的一致性。

 对于任意一个结点 v i v_i vi,在第一个视图中的嵌入为 u i \mathbf{u}_i ui,被视作锚;在另外一个视图中的嵌入为 v i \mathbf{v}_i vi,形成正样本,两个视图中出 v i v_i vi之外的结点嵌入被视为负样本。

 简单而言,正样本:同一结点在不同视图的嵌入被视作正样本对;负样本包含两类:(1)intra-view:同一视图中的不同结点对(2)inter-view:不同视图中的不同结点对。

 判别函数定义为 θ ( u , v ) = s ( g ( u ) , g ( v ) ) \theta(u,v)=s(g(u),g(v)) θ(u,v)=s(g(u),g(v)) s s s为cosine相似度,g为非线性映射,例如两层的MLP。

综上所述,目标函数定义为:

ℓ ( u i , v i ) = log ⁡ e θ ( u i , v i ) / τ e θ ( u i , v i ) / τ ⏟ the positive pair + ∑ k = 1 N 1 [ k ≠ i ] e θ ( u i , v k ) / τ ⏟ inter-view negaive pairs + ∑ k = 1 N 1 [ k ≠ i ] e θ ( u i , u k ) / τ ⏟ intra-view negative pairs \ell(\boldsymbol{u}_i,\boldsymbol{v}_i)=\log\frac{e^{\theta(\boldsymbol{u}_i,\boldsymbol{v}_i)/\tau}}{\underbrace{e^{\theta(\boldsymbol{u}_i,\boldsymbol{v}_i)/\tau}}_{\text{the positive pair}}+\underbrace{\sum _ { k = 1 }^N\mathbb{1}_{[k\neq i]}e^{\theta(\boldsymbol{u}_i,\boldsymbol{v}_k)/\tau}}_{\text{inter-view negaive pairs}}+\underbrace{\sum _ { k = 1 }^N\mathbb{1}_{[k\neq i]}e^{\theta(\boldsymbol{u}_i,\boldsymbol{u}_k)/\tau}}_{\text{intra-view negative pairs}}} (ui,vi)=logthe positive pair eθ(ui,vi)/τ+inter-view negaive pairs k=1N1[k=i]eθ(ui,vk)/τ+intra-view negative pairs k=1N1[k=i]eθ(ui,uk)/τeθ(ui,vi)/τ

其中, 1 [ k ≠ i ] ∈ { 0 , 1 } \mathbb{1}_{[k\neq i]}\in\{0,1\} 1[k=i]{0,1}是一个指示函数,当且仅当 k ≠ i k \neq i k=i时定于1。两个视图是对称的,另一个视图定义类似 ℓ ( v i , u i ) \ell(\boldsymbol{v}_i,\boldsymbol{u}_i) (vi,ui),最后,要最大化的总体目标被定义为:

J = 1 2 N ∑ i = 1 N [ ℓ ( u i , v i ) + ℓ ( v i , u i ) ] \mathcal{J}=\dfrac{1}{2N}\sum_{i=1}^N\left[\ell(\boldsymbol{u}_i,\boldsymbol{v}_i)+\ell(\boldsymbol{v}_i,\boldsymbol{u}_i)\right] J=2N1i=1N[(ui,vi)+(vi,ui)]

3.5、评估

 类似于DGI中的线性评估方案,模型首先以无监督的方式训练,得到的嵌入被用来训练逻辑回归分类器并做测试。

3.6、理论动机

3.6.1、最大化目标函数等价于最大化互信息的下界

 定理1说明了目标函数 J \mathcal{J} J是InfoNCE目标函数的一个下界,而InfoNCE评估器是MI(即互信息)的下界,所以 J ≤ I ( X ; U , V ) \mathcal{J} \le I(X;U,V) JI(X;U,V)
所以,最大化目标函数 J \mathcal{J} J等价于最大化输入节点特征和学习节点表示之间的互信息 I ( X ; U , V ) I(X;U,V) I(X;U,V)的下界

3.6.2、三重损失

 定理2说明了最小化目标函数与最大化三重损失一致。更详细的证明请看原文。

triplet Loss是深度学习中的一种损失函数,用于训练差异性较小的样本,如人脸等。在人脸识别领域,triplet loss常被用来提取人脸的embedding。 输入数据是一个三元组,包括锚(Anchor)例、正(Positive)例、负(Negative)例,通过优化锚示例与正示例的距离小于锚示例与负示例的距离,实现样本的相似性计算。

3.7、实验参数设置

Dataset p m , 1 p_{m,1} pm,1 p m , 2 p_{m,2} pm,2 p r , 1 p_{r,1} pr,1 p r , 2 p_{r,2} pr,2lrwdepochhidfeatactivation
Cora0.30.40.20.40.0051e-5200128ReLU
Citeseer0.30.20.20.00.0011e-5200256PReLU
Pubmed0.00.20.40.10.0011e-51500256ReLU

4、代码实现

完整代码见
链接:https://pan.baidu.com/s/1g9Rhe1EjxBZ0dFgOfy3CSg
提取码:6666

4.1、RE and MF

from dgl.transforms import DropEdge
#RE
#随机删除边——使用dgl内建库DropEdge
#MF
#随机掩盖特征
def drop_feature(x, drop_prob):
    drop_masks=[]
    for i in range(x.shape[0]):
        drop_mask = torch.empty(
            size= (x.size(1),) ,
            dtype=torch.float32,
            device=x.device).uniform_(0, 1) < drop_prob
        drop_masks.append(drop_mask)
    x = x.clone()
    for i,e in enumerate(drop_masks):
        x[i,e] = 0
    return x

4.2、encorder

import dgl
import torch.nn as nn
from dgl.nn.pytorch import GraphConv
from model.GCNLayer import GCNLayer

class Encoder(nn.Module):
    def __init__(self, infeat: int, outfeat: int, act_func,base_model=GraphConv, k: int = 2):

        super(Encoder, self).__init__()

        self.base_model = base_model
        assert k >= 2
        self.k = k
        self.convs = nn.ModuleList()
        self.convs.append(base_model(infeat, 2 * outfeat))
        for _ in range(1, k-1):
            self.convs.append(base_model(2 * outfeat, 2 * outfeat))
        self.convs.append(base_model(2 * outfeat, outfeat))
        self.act_func = act_func
    def forward(self, g, x ):
        #g = dgl.add_self_loop(g)
        for i in range(self.k):
            x = self.act_func(self.convs[i](g,x))
        return x

4.3、GRACE

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dgl.nn.pytorch import GraphConv
from model.encoder import Encoder
class GRACE(nn.Module):
   def __init__(self,infeat,hidfeat,act_func,k=2) -> None:
      super(GRACE,self).__init__()
      self.encoder = Encoder(infeat,hidfeat,act_func,base_model=GraphConv,k=k)
   def forward(self,g,x):
      z =self.encoder(g,x)
      return z

4.4、loss

import torch
import torch.nn as nn
import torch.nn.functional as F
class LossFunc(nn.Module):
    def __init__(self, infeat,hidfeat,outfeat,tau) -> None:
        super(LossFunc,self).__init__()
        self.tau = tau
        self.layer1 = nn.Linear(infeat,hidfeat)
        self.layer2 = nn.Linear(hidfeat,outfeat)
    def projection(self,x):
        x = F.elu(self.layer1(x))
        x = self.layer2(x)
        return x
    def sim(self,x,y):
        x = F.normalize(x)
        y = F.normalize(y)
        return torch.mm(x, y.t())
    def sim_loss(self,h1,h2):
        f = lambda x : torch.exp(x/self.tau)
        #exp(\theta(u_i,u_j)/tau)
        intra_sim = f(self.sim(h1,h1))
        #exp(\theta(u_i,v_j)/tau)
        inter_sim = f(self.sim(h1,h2))
        return -torch.log(
            inter_sim.diag() / (intra_sim.sum(1) + inter_sim.sum(1) - intra_sim.diag())
            )
    def forward(self,u,v):
        h1 = self.projection(u)
        h2 = self.projection(v)
        loss1 = self.sim_loss(h1,h2)
        loss2 = self.sim_loss(h2,h1)
        loss_sum = (loss1 + loss2) * 0.5
        res = loss_sum.mean()
        return res

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1395253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[python]裁剪文件夹中所有pdf文档并按名称保存到指定的文件夹

最近在写论文的实验部分&#xff0c;由于latex需要pdf格式的文档&#xff0c;审稿专家需要对pdf图片进行裁剪放大&#xff0c;以保证图片质量。 原图&#xff1a; 裁剪后的图像&#xff1a; 代码粘贴如下。将input_folder和output_folder替换即可。(x1, y1)&#xff0c; (x2…

【Go面试向】rune和byte类型的认识与使用

【Go】rune和byte类型的认识与使用 大家好 我是寸铁&#x1f44a; 总结了一篇rune和byte类型的认识与使用的文章✨ 喜欢的小伙伴可以点点关注 &#x1f49d; byte和rune类型定义 byte,占用1个字节&#xff0c;共8个比特位&#xff0c;所以它实际上和uint8没什么本质区别,它表示…

一文了解【完全合作关系】下的【多智能体强化学习】

处于完全合作关系的多智能体的利益一致&#xff0c;获得的奖励相同&#xff0c;有共同的目标。比如多个工业机器人协同装配汽车&#xff0c;他们的目标是相同的&#xff0c;都希望把汽车装好。 在多智能体系统中&#xff0c;一个智能体未必能观测到全局状态 S。设第 i 号智能体…

Nginx前后端分离部署springboot和vue项目

Nginx前后端分离部署springboot和vue项目&#xff0c;其实用的比较多&#xff0c;有的小伙伴对其原理和配置还一知半解&#xff0c;现在就科普一下&#xff1a; 1、准备后端项目 后端工程无论是微服务还是单体&#xff0c;一般最终都是jar启动&#xff0c;关键点就是把后端服…

vivado RTL运行方法检查、分析方法报告、报告DRC

运行方法检查 Vivado Design Suite提供基于超快设计的自动化方法检查使用“报告方法论”命令的FPGA和SoC&#xff08;UG949&#xff09;方法论指南。您可以生成关于打开、详细阐述、综合或实现的方法论报告设计对于详细设计&#xff0c;方法报告会检查XDC和RTL文件。对于有关使…

CVE重要通用漏洞复现java php

在进行漏洞复现之前我们需要在linux虚拟机上进行docker的安装 我不喜欢win上安因为不知道为什么总是和我的vmware冲突 然后我的kali内核版本太低 我需要重新安装一个新的linux 并且配置网络 我相信这会话费我不少时间 查看版本 uname -a 需要5.5或以上的版本 看错了浪…

滚动菜单+图片ListView

目录 Fruit.java FruitAdapter MainActivity activity_main.xml fruit.xml 整体结构 Fruit.java public class Fruit {private String name;private int imageId;public Fruit(String name, int imageId) {this.name name;this.imageId imageId;}public String getNam…

AR与AI融合加速,医疗护理更便捷

根据Reports and Data的AR市场发展报告&#xff0c;到2026年&#xff0c;预计医疗保健市场中的AR/VR行业规模将达到70.5亿美元。这一趋势主要受到对创新诊断技术、神经系统疾病和疾病意识不断增长的需求驱动。信息技术领域的进步&#xff0c;包括笔记本电脑、计算机、互联网连接…

用 Python 制作可视化 GUI 界面,一键实现自动分类管理文件!

经常杂乱无章的文件夹会让我们找不到所想要的文件&#xff0c;因此小编特意制作了一个可视化GUI界面&#xff0c;通过输入路径一键点击实现文件分门别类的归档。 不同的文件后缀归类为不同的类别 我们先罗列一下大致有几类文件&#xff0c;根据文件的后缀来设定&#xff0c;大…

Vue入门七(Vuex的使用|Vue-router|LocalStorage与SessionStorage和cookie的使用|路由的两种工作模式)

文章目录 一、Vuex1&#xff09;理解vuex2&#xff09;优点3&#xff09;何时使用&#xff1f;4&#xff09;使用步骤① 安装vuex② 创建vuex③ 导入vuex④ 创建仓库Store⑤ 基本使用 5&#xff09;五个模块介绍1.State2.mutations3.actions4.Getter5.Modules 6&#xff09;购物…

【vue】ant-col多列栅格式的表单排列方式布局异常:

文章目录 一、效果&#xff1a;二、解决&#xff1a;三、问题&#xff1a; 一、效果&#xff1a; 二、解决&#xff1a; 在row中添加布局类型&#xff1a;type“flex” 三、问题&#xff1a; 后期正式环境还是存在该问题 >>>.ant-form-item {max-height: 32px; }多…

【FastAPI】P1 简单实现 a+b

目录 准备工作代码运行 说明&#xff1a;本文通过 FastAPI 实现返回两个参数 ab 的值&#xff1b; 准备工作 默认读者已准备完善 Python IDE工具以及包管理工具。 首先&#xff0c;需要安装 fastapi 和 uvicorn 库&#xff0c;如果没有请使用 pip 进行安装&#xff1a; pip…

【Android】为什么在子线程中更新UI不会抛出异常

转载请注明来源&#xff1a;https://blog.csdn.net/devnn/article/details/135638486 前言 众所周知&#xff0c;Android App在子线程中是不允许更新UI的&#xff0c;否则会抛出异常&#xff1a; android.view.ViewRootImpl$CalledFromWrongThreadException: Only the origin…

智慧灌区解决方案:针对典型灌区水利管理需求

​随着国家对农业水利的重视,各地积极推进智慧灌区建设,以实现对水资源的精准调度和科学化管理。下面我们针对典型灌区水利管理需求,推荐智慧灌区解决方案。 一、方案构成智慧水利解决方案- 智慧水利信息化系统-智慧水利平台-智慧水利公司 - 星创智慧水利 一、方案构成 (一)水…

安全加速SCDN是什么

安全加速SCDN&#xff08;Secure Content Delivery Network&#xff0c;SCDN&#xff09; 是集分布式DDoS防护、CC防护、WAF防护、BOT行为分析为一体的安全加速解决方案。已使用内容分发网络&#xff08;CDN&#xff09;或全站加速网络&#xff08;ECDN&#xff09;的用户&…

Java CAS原子操作过程及ABA问题

目录 一.什么是CAS 二.流程 三.缺点 四.ABA 问题 五.解决ABA问题 一.什么是CAS CAS&#xff08;Compare And Swap&#xff0c;比较并交换&#xff09;&#xff0c;通常指的是这样一种原子操作&#xff1a;针对一个变量&#xff0c;首先比较它的内存值与某个期望值是否相同…

边缘计算AI智能分析网关V4客流统计算法的概述

客流量统计AI算法是一种基于人工智能技术的数据分析方法&#xff0c;通过机器学习、深度学习等算法&#xff0c;实现对客流量的实时监测和统计。该算法主要基于机器学习和计算机视觉技术&#xff0c;其基本流程包括图像采集、图像预处理、目标检测、目标跟踪和客流量统计等步骤…

EasyDarwin计划新增将各种流协议(RTSP、RTMP、HTTP、TCP、UDP)、文件转推RTMP到其他视频直播平台,支持转码H.264、文件直播推送

之前我们尝试做过EasyRTSPLive&#xff08;将RTSP流转推RTMP&#xff09;和EasyRTMPLive&#xff08;将各种RTSP/RTMP/HTTP/UDP流转推RTMP&#xff0c;这两个服务在市场上都得到了比较多的好评&#xff0c;其中&#xff1a; 1、EasyRTSPLive用的是EasyRTSPClient取流&#xff…

Presents-codeforces

题目链接&#xff1a;Problem - 136A - Codeforces 解题思路&#xff1a; 这题挺有意思&#xff0c;大致意思是&#xff0c;每个人都会互相送礼物&#xff0c;可能送给自己&#xff0c;可能送给别人&#xff0c;第i个数表示第i个人要把礼物送给第i个数的人比如1 3 2&#xff0…

C++系列-第1章顺序结构-9-字符类型char

在线练习&#xff1a; http://noi.openjudge.cn/ https://www.luogu.com.cn/ 总结 本文是C系列博客&#xff0c;主要讲述字符类型char 字符类型char 在C编程语言中&#xff0c;char是一种基本的数据类型&#xff0c;它用于存储单个字符。字符可以是字母、数字、标点符号或者…