图神经网络GNN(一)GraphEmbedding

news2024/11/29 2:42:45

DeepWalk


使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。
SkipGram优化的目标函数:P(Context(x)|x;θ)
θ = argmax P(Context(x)|x;θ)
DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程,下面的代码中,node_proj是一个简单的线性映射函数,加上elu激活函数,可以看作Encoder的过程。Encoder结束后就得到了Embedding后的隐变量表示。其实GraphEmbedding要的就是这个node_proj,但是由于没有标签,只有训练数据的内部特征,怎么去训练呢?这就需要看我们的训练任务了,个人理解,也就是说,这种无监督的embedding后的结果取决于你的训练任务,也就是Decoder过程。Embedding后的编码对Decoder过程越有利,损失函数也就越小,编码做的也就越好。在word2vec中,有两种训练任务,一种是给定当前词,预测其前两个及后两个词发生的条件概率,采用这种训练任务做出的embedding就是skip-gram;还有一种是给定当前词前两个及后两个词,预测当前词出现的条件概率,采用这种训练任务做出的embedding就是CBOW.DeepWalk作者的论文中采用的是skip-gram。故复现也采用skip-gram进行复现。
针对skip-gram对应的训练任务,代码中的node_proj相当于编码器,h_o_1和h_o_2相当于解码器。Encoder和Decoder可以先联合训练,训练结束后,可以只保留Encoder的部分,舍弃Decoder的部分。当再来一个独热编码的时候,可以直接通过node_proj映射,即完成了独热编码的embedding过程。
(本代码假定在当前结点去往各邻接结点的可能性相同,即不考虑边的权重)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F
import networkx as nx
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Categorical
import matplotlib.pyplot as plt


class MyGraph():
    def __init__(self,device):
        super(MyGraph, self).__init__()
        self.G = nx.read_edgelist(path='data/wiki/Wiki_edgelist.txt',create_using=nx.DiGraph(),
                                  nodetype=None,data=[('weight',int)])
        self.adj_matrix = nx.attr_matrix(self.G)
        self.edges = nx.edges(self.G)
        self.edges_emb = torch.eye(len(self.G.edges)).to(device)
        self.nodes_emb = torch.eye(len(self.G.nodes)).to(device)

class GraphEmbedding(nn.Module):
    def __init__(self,nodes_num,edges_num,device,emb_dim = 10):
        super(GraphEmbedding, self).__init__()
        self.device = device
        self.nodes_proj = nn.Parameter(torch.randn(nodes_num,emb_dim))
        self.edges_proj = nn.Parameter(torch.randn(edges_num,emb_dim))
        self.h_o_1 = nn.Parameter(torch.randn(emb_dim,nodes_num * 2))
        self.h_o_2 = nn.Parameter(torch.randn(nodes_num * 2,nodes_num))

    def forward(self,G:MyGraph):
        self.nodes_proj,self.edges_proj = self.nodes_proj.to(self.device),self.edges_proj.to(device)
        self.h_o_1,self.h_o_2 = self.h_o_1.to(self.device),self.h_o_2.to(self.device)
        # Encoder
        edges_emb,nodes_emb = torch.matmul(G.edges_emb,self.edges_proj),torch.matmul(G.nodes_emb,self.nodes_proj)
        nodes_emb = F.elu_(nodes_emb)
        edges_emb,nodes_emb = edges_emb.to(device),nodes_emb.to(device)
        # Decoder
        policy = self.DeepWalk(G,gamma=5,window=2)
        outputs = torch.matmul(torch.matmul(nodes_emb[policy[:,0]],self.h_o_1),self.h_o_2)
        policy,outputs = policy.to(device),outputs.to(device)
        return policy,outputs

    def DeepWalk(self,Graph:MyGraph,gamma:int,window:int,eps=1e-9):
        # Calculate transpose matrix
        adj_matrix = torch.tensor(Graph.adj_matrix[0], dtype=torch.float32)
        for i in range(adj_matrix.shape[0]):
            adj_matrix[i,:] /= (torch.sum(adj_matrix[i]) + eps)

        adj_nodes = Graph.adj_matrix[1].copy()
        random.shuffle(adj_nodes)
        nodes_idx, route_result = [],[]
        for node in adj_nodes:
            node_idx = np.where(np.array(Graph.adj_matrix[1]) == node)[0].item()
            node_list = self.Random_Walk(adj_matrix,window=window,node_idx=node_idx)
            route_result.append(node_list)
        return torch.tensor(route_result)

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):
        node_list = [node_idx]
        for i in range(window):
            pi = self.HMM_process(adj_matrix,node_idx)
            if torch.sum(pi) == 0:
                pi += 1 / pi.shape[0]
            node_idx = Categorical(pi).sample().item()
            node_list.append(node_idx)
        return node_list

    def HMM_process(self,adj_matrix:torch.Tensor,node_idx:int,eps=1e-9):

        pi = torch.zeros((1, adj_matrix.shape[0]), dtype=torch.float32)
        pi[:,node_idx] = 1.0
        pi = torch.matmul(pi,adj_matrix)
        pi = pi.squeeze(0) / (torch.sum(pi) + eps)
        return pi


if __name__ == "__main__":
    epochs = 200
    device = torch.device("cuda:1")
    cross_entrophy_loss = CrossEntropyLoss().to(device)
    Graph = MyGraph(device)
    Embedding = GraphEmbedding(nodes_num=len(Graph.G.nodes), edges_num=len(Graph.G.edges),device=device).to(device)
    optimizer = torch.optim.Adam(Embedding.parameters(),lr=1e-5)
    scheduler=CosineAnnealingLR(optimizer,T_max=50,eta_min=0.05)
    loss_list = []
    epoch_list = [i for i in range(1,epochs+1)]
    for epoch in range(epochs):
        policy,outputs = Embedding(Graph)
        outputs = outputs.unsqueeze(1).repeat(1,policy.shape[-1]-1,1).reshape(-1,outputs.shape[-1])
        optimizer.zero_grad()
        loss = cross_entrophy_loss(outputs, policy[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()
        loss_list.append(loss.item())
        print(f"Loss : {loss.item()}")
    plt.plot(epoch_list,loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('CrossEntrophyLoss')
    plt.title('Loss-Epoch curve')
    plt.show()

在这里插入图片描述

Node2Vec

在这里插入图片描述
在这里插入图片描述
修改Random_Walk函数如下:

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):
        node_list = [node_idx]
        for i in range(window):
            pi = self.HMM_process(adj_matrix,node_idx)
            if torch.sum(pi) == 0:
                pi += 1 / pi.shape[0]
            if i > 0:
                v,t = node_list[-1],node_list[-2]
                x_list = torch.nonzero(adj_matrix[v]).squeeze(-1)
                for x in x_list:
                    if t == x:  # 0
                        pi[x] *= 1/self.p
                    elif adj_matrix[t][x] == 1:  # 1
                        pi[x] *= 1
                    else:   # 2
                        pi[x] *= 1/self.q
            node_idx = Categorical(pi).sample().item()
            node_list.append(node_idx)
        return node_list

结果如下,这里令p=2,q=3,即1/p=0.5,1/q=0.33,会相对保守周围。结果似乎好了那么一点点。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1055376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【多线程初阶】多线程案例之线程池

文章目录 前言1. 什么是线程池1.1 线程池的优势 2. 标准库中的线程池2.1 聊聊工厂模式2.2 Executors 创建线程池的几种方式2.3 ThreadPoolExecutor 构造方法中的几个参数2.3.1 RejectedExecutionHandler handler 的几个拒绝策略 3. 自己实现一个线程池总结 前言 本文主要给大家…

【COMP304 LEC3】

LEC 3 1. Contingent Formulas: 定义:Truth or falsity of a propositional formula depends on the truth/falsity of the atoms in the formula 例子:p ∧ q is true if both p and q are true, false otherwise.这里p和q就是atoms&…

[React] 性能优化相关 (一)

文章目录 1.React.memo2.useMemo3.useCallback4.useTransition5.useDeferredValue 1.React.memo 当父组件被重新渲染的时候,也会触发子组件的重新渲染,这样就多出了无意义的性能开销。如果子组件的状态没有发生变化,则子组件是不需要被重新渲…

计算机网络笔记 第二章 物理层

2.1 物理层概述 物理层要实现的功能 物理层接口特性 机械特性 形状和尺寸引脚数目和排列固定和锁定装置 电气特性 信号电压的范围阻抗匹配的情况传输速率距离限制 功能特性 -规定接口电缆的各条信号线的作用 过程特性 规定在信号线上传输比特流的一组操作过程&#xff0…

论文研读 - share work - QPipe:一种并行流水线的查询执行引擎

QPipe:一种并行流水线的查询执行引擎 QPipe: A Simultaneously Pipelined Relational Query Engine 关系型数据库通常独立执行并发的查询,每个查询都需执行一系列相关算子。为了充分利用并发查询中的数据扫描与计算,现有研究提出了丰富的技术…

进程之间的通信方式(共享存储,消息传递,管道通信)

进程通信 进程间通信(Inter-Process Communication,IPC)是指两个进程之间产生数据交互。进程是分配系统资源的单位(包括内存地址空间),因此各进程拥有的内存地址空间相互独立。为了保证安全,一个进程不能直接访问另一…

计算机网络学习易错点

目录 概述 1.internet和Internet的区别 2.面向连接和无连接 3.不同的T 4.传输速率和传播速率 5.传播时延和传输时延(发送时延) 6.语法,语义和同步 一.物理层 1.传输媒体与物理层 2.同步通信和异步通信 3.位同步(比特同…

【算法分析与设计】贪心算法(下)

目录 一、单源最短路径1.1 算法基本思想1.2 算法设计思想1.3 算法的正确性和计算复杂性1.4 归纳证明思路1.5 归纳步骤证明 二、最小生成树2.1 最小生成树性质2.1.1 生成树的性质2.1.2 生成树性质的应用 2.2 Prim算法2.2.1 正确性证明2.2.2 归纳基础2.2.3 归纳步骤2.3 Kruskal算…

debian设置允许ssh连接

解决新debian系统安装后不能通过ssh连接的问题。 默认情况下,Debian系统不开启SSH远程登录,需要手动安装SSH软件包并设置开机启动。 > 设置允许root登录传送门:debian设置允许root登录 首先检查/etc/ssh/sshd_config文件是否存在。 注意…

TFT LCD刷新原理及LCD时序参数总结(LCD时序,写的挺好)

cd工作原理目前不了解,日后会在博客中添加这一部分的内容。 1.LCD工作原理[1] 我对LCD的工作原理也仅仅处在了解的地步,下面基于NXP公司对LCD工作原理介绍的ppt来学习一下。 LCD(liquid crystal display,液晶显示屏) 是由液晶段阵列组成,当…

【EasyPoi】SpringBoot使用EasyPoi自定义模版导出Excel

EasyPoi 官方文档&#xff1a;http://doc.wupaas.com/docs/easypoi Excel模版导出 引入依赖 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency…

newstarctf

wp web: 1.rce 可以发现这个变量名有下划线也有点。 $code$_POST[e_v.a.l]; 这时候如果直接按这个变量名来传参&#xff0c;php 是无法接收到这个值的&#xff0c;具体原因是 php 会自动把一些不合法的字符转化为下划线&#xff08;注&#xff1a;php8以下&#xff09;&am…

【论文极速读】IMAGEBIND —— 通过图片作为桥梁桥联多模态语义

【论文极速读】IMAGEBIND —— 通过图片作为桥梁桥联多模态语义 FesianXu 20230929 at Baidu Search Team 前言 当前大部分多模态工作都集中在图片-文本、视频-文本中&#xff0c;关于音频、深度图、热值图的工作则比较少&#xff0c;在IMAGEBIND中&#xff0c;作者提出了一种…

Neo4j最新安装教程(图文版)

目录 一、软件介绍 二、下载软件 1、官方下载 2、云盘下载 三、安装教程 1、首先配置Neo4j的环境变量 2、启动neo4j服务器 3、访问界面 一、软件介绍 官网地址&#xff1a;https://neo4j.com/ Neo4j是一个高性能、可扩展的图数据库管理系统。它专注于存储、查询和处理大…

Explain执行计划字段解释说明---ID字段说明

ID字段说明 1、select查询的序列号,包含一组数字&#xff0c;表示查询中执行select子句或操作表的顺序 2、ID的三种情况 &#xff08;1&#xff09;id相同&#xff0c;执行顺序由上至下。 &#xff08;2&#xff09;id不同&#xff0c;如果是子查询&#xff0c;id的序号会…

大数据Flink(九十四):DML:TopN 子句

文章目录 DML:TopN 子句 DML:TopN 子句 TopN 定义(支持 Batch\Streaming):TopN 其实就是对应到离线数仓中的 row_number(),可以使用 row_number() 对某一个分组的数据进行排序 应用场景

cesium 雷达扫描 (扫描线)

cesium 雷达扫描 (扫描线) 1、实现方法 图中的线使用polyline方法绘制,外面的圆圈是用ellipse方法绘制(当然也不指这一种方法),图中线的扫描转动效果是实时改变线的经纬度来实现(知道中心点经纬度、又已知方向和距离可以求出端点的经纬度)使用CallbackProperty方法来实…

数据结构与算法——18.avl树

这篇文章我们来看一下avl树 目录 1.概述 2.AVL树的实现 1.概述 我们前面讲了二叉搜索树&#xff0c;它是有一个key值&#xff0c;然后比父节点key值大的在左边&#xff0c;小的在右边。这样设计是为了便于查找。但是有一种极端的情况&#xff0c;就是所有的结点都在一边&am…

7.3 调用函数

前言&#xff1a; 思维导图&#xff1a; 7.3.1 函数调用的形式 我的笔记&#xff1a; 函数调用的形式 在C语言中&#xff0c;调用函数是一种常见的操作&#xff0c;主要有以下几种调用方式&#xff1a; 1. 函数调用语句 此时&#xff0c;函数调用独立存在&#xff0c;作为…

艺术表现形式

abstract expressionism 抽象表现主义 20世纪中期的一种艺术运动&#xff0c;包括多种风格和技巧&#xff0c;特别强调艺术家通过非传统和通常非具象的手段表达态度和情感的自由。 抽象表现主义用有力的笔触和滴落的颜料来表达情感和自发性。 简单地结合“abstract expression…