【图卷积网络】GCN基础原理简单python实现

news2024/11/19 18:30:38

基础原理讲解

应用路径

卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是边,能有效提取拓扑结构中的有效信息,实现节点分类,边预测等。

基础原理

其核心公式是:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{(l+1)}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^l) H(l+1)=σ(D~1/2A~D~1/2HlWl)
其中:

  • σ \sigma σ 是非线性激活函数
  • D ~ \tilde{D} D~是度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=jA~ij
  • A ~ \tilde{A} A~是加了自环的邻接矩阵,通常表示为 A + I A+I A+I A A A是原始邻接矩阵, I I I是单位矩阵
  • H l H^l Hl是第 l l l层的节点特征矩阵, H l + 1 H^{l+1} Hl+1是第 l + 1 l+1 l+1层的节点特征矩阵
  • W l W^l Wl是第 l l l层的学习权重矩阵

步骤讲解:
1、邻接矩阵归一化: 将邻接矩阵归一化,使得邻居节点特征对中心节点特征的贡献相等。
2、特征聚合: 通过邻接矩阵与节点特征矩阵相乘,实现邻居特征聚合。
3、线性变换: 通过可学习的权重矩阵对聚合后的特征进行线性变换。

加自环的邻接矩阵

A ~ = A + λ I \tilde{A} = A+\lambda I A~=A+λI
邻接矩阵加上一个单位矩阵, λ \lambda λ是一个可以训练的参数,但也可直接取1。加自环 是为了增强节点自我特征表示,这样在进行图卷积操作时,节点不仅会聚合来自邻居节点的特征,还会聚合自己的特征。

图卷积操作

图像卷积和图卷积
图片的卷积是一个一个卷积核,在图片上滑动着做卷积。图的卷积就是自己加邻居一起做加和。
即:
A ~ X \tilde{A}X A~X

度矩阵求解

D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=jA~ij
度矩阵的求解

标准化

在进行加和时,节点的度不同,有存在较高度值的节点和较低度值的节点,这可能导致梯度爆炸梯度消失的问题。
根据度矩阵,求逆,然后 D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~1A~D~1X,就进行了标准化,前一个 D ~ − 1 \tilde{D}^{-1} D~1是对行进行标准化,后一个 D ~ − 1 \tilde{D}^{-1} D~1是对列进行标准化。能够实现给与低度节点更大的权重,从而降低高节点的影响。
在上式推导中, D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~1A~D~1X 做了两次标准化,所以修改上式为 D ~ − 1 / 2 A ~ D ~ − 1 / 2 X \tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}X D~1/2A~D~1/2X

简单python实现

基于cora数据集实现节点分类

  • cora数据集处理
# cora数据集测试
raw_data = pd.read_csv('./data/data/cora/cora.content', sep='\t', header=None)
print("content shape: ", raw_data.shape)

raw_data_cites = pd.read_csv('./data/data/cora/cora.cites', sep='\t', header=None)
print("cites shape: ", raw_data_cites.shape)

features = raw_data.iloc[:,1:-1]
print("features shape: ", features.shape)

# one-hot encoding
labels = pd.get_dummies(raw_data[1434])
print("\n----head(3) one-hot label----")
print(labels.head(3))
l_ = np.array([0,1,2,3,4,5,6])
lab = []
for i in range(labels.shape[0]):
    lab.append(l_[labels.loc[i,:].values.astype(bool)][0])
#构建邻接矩阵
num_nodes = raw_data.shape[0]

# 将节点重新编号为[0, 2707]
new_id = list(raw_data.index)
id = list(raw_data[0])
c = zip(id, new_id)
map = dict(c)

# 根据节点个数定义矩阵维度
matrix = np.zeros((num_nodes,num_nodes))

# 根据边构建矩阵
for i ,j in zip(raw_data_cites[0],raw_data_cites[1]):
    x = map[i] ; y = map[j]
    matrix[x][y] = matrix[y][x] = 1   # 无向图:有引用关系的样本点之间取1

# 查看邻接矩阵的元素
print(matrix.shape)
  • GCN网络实现
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, adj):
        rowsum = torch.sum(adj,dim=1)
        d_inv_sqrt = torch.pow(rowsum,-0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] =0.0
        d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
        adj_normalized = torch.mm(torch.mm(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)
        
        out = torch.mm(adj_normalized,x)
        out = self.linear(out)
        return out
class GCN(nn.Module):
    def __init__(self, n_features, n_hidden, n_classes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(n_features, n_hidden)
        self.gcn2 = GCNLayer(n_hidden, n_classes)

    def forward(self, x, adj):
        x = self.gcn1(x, adj)
        x = F.relu(x)
        x = self.gcn2(x, adj)
        return x#F.log_softmax(x, dim=1)
# 示例数据(实际数据应根据具体情况加载)

features = torch.tensor(features.values, dtype=torch.float32)
adj = torch.tensor(matrix, dtype=torch.float32)
labels = torch.tensor(lab, dtype=torch.long)
# features = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32)
# adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
# labels = torch.tensor([0, 1, 0], dtype=torch.long)

# 模型参数
n_features = features.shape[1]
n_hidden = 16
n_classes = len(torch.unique(labels))

# 创建模型
model = GCN(n_features, n_hidden, n_classes)
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练模型
n_epochs = 200
for epoch in range(n_epochs):
    model.train()
    features, labels = features.cuda(), labels.cuda()
    adj = adj.cuda()
    optimizer.zero_grad()
    output = model(features, adj)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 20 == 0:
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
print("Training complete.")

参考

cora数据集及简介
图卷积网络详细介绍
GCN讲解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1896336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java集成openai,ollama,千帆,通义千问实现文本聊天

Java集成openai,ollama,千帆,通义千问实现文本聊天 本文所使用的环境如下 Win10 JDK17 SpringBoot3.2.4 集成Open AI pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmln…

Emacs有什么优点,用Emacs写程序真的比IDE更方便吗?

Emacs 是一个功能强大的文本编辑器&#xff0c;它在开发者和程序员中非常受欢迎&#xff0c;主要优点包括&#xff1a; 可定制性&#xff1a;Emacs 允许用户通过 Lisp 编程语言来自定义编辑器的行为和界面&#xff0c;几乎可以修改任何方面。扩展性&#xff1a;拥有大量的扩展…

【博士每天一篇文献-综述】Threats, Attacks, and Defenses in Machine Unlearning A Survey

1 介绍 年份&#xff1a;2024 作者&#xff1a;刘子耀&#xff0c;陈晨&#xff0c;南洋理工大学 期刊&#xff1a; 未发表 引用量&#xff1a;6 Liu Z, Ye H, Chen C, et al. Threats, attacks, and defenses in machine unlearning: A survey[J]. arXiv preprint arXiv:2403…

idea导入Maven项目

导入Maven项目 方式1&#xff1a;使用Maven面板&#xff0c;快速导入项目 打开IDEA&#xff0c;选择右侧Maven面板&#xff0c;点击 号&#xff0c;选中对应项目的pom.xml文件&#xff0c;双击即可 说明&#xff1a;如果没有Maven面板&#xff0c;选择 View > Appearance…

马拉松报名小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;赛事信息管理&#xff0c;赛事报名管理&#xff0c;活动商城管理&#xff0c;留言板管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;赛事信息&…

AI Agent框架(LLM Agent):LLM驱动的智能体如何引领行业变革,应用探索与未来展望

AI Agent框架&#xff08;LLM Agent&#xff09;&#xff1a;LLM驱动的智能实体如何引领行业变革&#xff0c;应用探索与未来展望 1. AI Agent&#xff08;LLM Agent&#xff09;介绍 1.1. 术语 Agent&#xff1a;“代理” 通常是指有意行动的表现。在哲学领域&#xff0c;Ag…

通证经济重塑经济格局

在数字化转型的全球浪潮中&#xff0c;通证经济模式犹如一股新兴力量&#xff0c;以其独特的价值传递与共享机制&#xff0c;重塑着经济格局&#xff0c;引领我们步入数字经济的新纪元。 通证&#xff0c;作为这一模式的核心&#xff0c;不仅是权利与权益的数字化凭证&#xf…

ETCD 基本介绍与常见命令的使用

转载请标明出处&#xff1a;https://blog.csdn.net/donkor_/article/details/140171610 文章目录 一、基本介绍1.1 参考1.2 什么是ETCD1.3 ETCD的特点1.4 ETCD的主要功能1.5 ETCD的整体架构1.6 什么时候用ETCD&#xff0c;什么时候用redis 二、安装三、使用3.1 etcdctl3.2 常用…

【动态规划】动态规划一

动态规划一 1.第 N 个泰波那契数2.面试题 08.01. 三步问题3.使用最小花费爬楼梯4.解码方法 点赞&#x1f44d;&#x1f44d;收藏&#x1f31f;&#x1f31f;关注&#x1f496;&#x1f496; 你的支持是对我最大的鼓励&#xff0c;我们一起努力吧!&#x1f603;&#x1f603; 1.…

快手矩阵管理系统:开启短视频营销的智能时代

在短视频内容营销的浪潮中&#xff0c;快手矩阵管理系统以其独特的优势和功能&#xff0c;成为品牌和个人创作者不可或缺的工具。本文将详细解析快手矩阵管理系统的核心功能&#xff0c;探讨它如何帮助用户高效管理多平台、多账号的内容发布和互动。 快手矩阵管理系统概述 快…

14. Revit API: Selection(选择器)

前言 这篇写选择器&#xff0c;经过前面好些篇的讲解&#xff0c;总算把前置内容都写完了。 我们来回忆下都在哪里提到过… 算了&#xff0c;直接进入正文。 一、Selection 命名空间 选择器位于Autodesk.Revit.UI.Selection命名空间下&#xff0c;关系到交互嘛&#xff0c;所…

PostMan Error:Maximum response size reached

一、问题描述 用postman本地测试&#xff0c;restful api接口导出文件&#xff0c;文件大小为190M&#xff0c;服务没问题&#xff0c;总是在导出时&#xff0c;抛出&#xff1a;Error:Maximum response size reached。开始以为是服务相应文件过大或者相应时间超时导致的。其实…

数字流的秩

题目链接 数字流的秩 题目描述 注意点 x < 50000 解答思路 可以使用二叉搜索树存储出现的次数以及数字的出现次数&#xff0c;方便后续统计数字x的秩关键在于构建树的过程&#xff0c;如果树中已经有值为x的节点&#xff0c;需要将该节点对应的数字出现次数加1&#xf…

14-8 小型语言模型的兴起

过去几年&#xff0c;我们看到人工智能能力呈爆炸式增长&#xff0c;其中很大一部分是由大型语言模型 (LLM) 的进步推动的。GPT-3 等模型包含 1750 亿个参数&#xff0c;已经展示了生成类似人类的文本、回答问题、总结文档等能力。然而&#xff0c;虽然 LLM 的能力令人印象深刻…

Redis深度解析:核心数据类型与键操作全攻略

文章目录 前言redis数据类型string1. 设置单个字符串数据2.设置多个字符串类型的数据3.字符串拼接值4.根据键获取字符串的值5.根据多个键获取多个值6.自增自减7.获取字符串的长度8.比特流操作key操作a.查找键b.设置键值的过期时间c.查看键的有效期d.设置key的有效期e.判断键是否…

免杀笔记 ---> PE

本来是想先把Shellcode Loader给更新了的&#xff0c;但是涉及到一些PE相关的知识&#xff0c;所以就先把PE给更了&#xff0c;后面再把Shellcode Loader 给补上。 声明&#xff1a;本文章内容来自于B站小甲鱼 1.PE的结构 首先我们要讲一个PE文件&#xff0c;就得知道它的结构…

Appium+python自动化(四十二)- 寿终正寝完结篇 - 结尾有惊喜,过时不候(超详解)

1.简介 按照上一篇的计划&#xff0c;今天给小伙伴们分享执行测试用例&#xff0c;生成测试报告&#xff0c;以及自动化平台。今天这篇分享讲解完。Appium自动化测试框架就要告一段落了。 2.执行测试用例&报告生成 测试报告&#xff0c;宏哥已经讲解了testng、HTMLTestRun…

springboot整合Camunda实现业务

1.bean实现 业务 1.画流程图 系统任务&#xff0c;实现方式 2.定义bean package com.jmj.camunda7test.process.config;import lombok.extern.slf4j.Slf4j; import org.camunda.bpm.engine.TaskService; import org.camunda.bpm.engine.delegate.JavaDelegate; import org.…

开源大模型和闭源大模型,打法有何区别?

现阶段&#xff0c;各个公司都有自己的大模型产品&#xff0c;有的甚至不止一个。除了小部分开源外&#xff0c;大部分都选择了闭源。那么&#xff0c;头部开源模型厂商选择开源是出于怎样的初衷和考虑&#xff1f;未来大模型将如何发展&#xff1f;我们来看看本文的分享。 在对…

Hi3861 OpenHarmony嵌入式应用入门--SNTP

sntp&#xff08;Simple Network Time Protocol&#xff09;是一种网络时间协议&#xff0c;它是NTP&#xff08;Network Time Protocol&#xff09;的一个简化版本。 本项目是从LwIP中抽取的SNTP代码&#xff1b; Hi3861 SDK中已经包含了一份预编译的lwip&#xff0c;但没有…