深度学习(36)—— 图神经网络GNN(1)

news2024/11/27 4:32:24

深度学习(36)—— 图神经网络GNN(1)

这个系列的所有代码我都会放在git上,欢迎造访

文章目录

  • 深度学习(36)—— 图神经网络GNN(1)
    • 1. 基础知识
    • 2.使用场景
    • 3. 图卷积神经网络GCN
      • (1)基本思想
    • 4. GNN基本框架——pytorch_geometric
      • (1)数据
      • (2)可视化
      • (3)网络定义
      • (4)训练模型(semi-supervised)

1. 基础知识

  • GNN考虑的事当前的点和周围点之间的关系

  • 邻接矩阵是对称的稀疏矩阵,表示图中各个点之间的关系

  • 图神经网络的输入是每个节点的特征和邻接矩阵

  • 文本数据可以用图的形式表示吗?文本数据也可以表示图的形式,邻接矩阵表示连接关系

  • 邻接矩阵中并不是一个N* N的矩阵,而是一个source,target的2* N的矩阵
    在这里插入图片描述

  • 信息传递神经网络:每个点的特征如何更新??——考虑他们的邻居,更新的方式可以自己设置:最大,最小,平均,求和等

  • GNN可以有多层,图的结构不发生改变,即当前点所连接的点不发生改变(邻接矩阵不发生变化)【卷积中存在感受野的概念,在GNN中同样存在,GNN的感受野也随着层数的增大变大】

  • GNN输出的特征可以干什么?

    • 各个节点的特征组合,对图分类【graph级别任务】
    • 对各个节点分类【node级别任务】
    • 对边分类【edge级别任务】
    • 利用图结构得到特征,最终做什么自定义!

2.使用场景

  • 为什么CV和NLP中不用GNN?
    因为图像和文本的数据格式很固定,传统神经网络格式是固定的,输入的东西格式是固定的
  • 化学、医疗
  • 分子、原子结构
  • 药物靶点
  • 道路交通,动态流量预测
  • 社交网络——研究人
    GNN输入格式比较随意,是不规则的数据结构, 主要用于输入数据不规则的时候

3. 图卷积神经网络GCN

  • 图卷积和卷积完全不同
  • GCN不是单纯的有监督学习,多数是半监督,有的点是没有标签的,在计算损失的时候只考虑有标签的点。针对数据量少的情况也可以训练

(1)基本思想

  • 网络层次:第一层对于每个点都要做更新,最后输出每个点对应的特征向量【一般不会做特别深层的】
  • 图中的基本组成:G(原图)A(邻接)D(度)F(特征)
  • 度矩阵的倒数* 邻接矩阵 *度矩阵的倒数——>得到新的邻接矩阵【左乘对行做归一化,右乘对列做归一化】
  • 两到三层即可,太多效果不佳

4. GNN基本框架——pytorch_geometric

它实现了各种GNN的方法
注意:安装过程中不要pip install,会失败!根据自己的device和python版本去下载scatter,pattern等四个依赖,先安装他们然后再pip install torch_geometric==2.0
这里记得是2.0版本否则会出现 TypeError: Expected ‘Iterator‘ as the return annotation for __iter__ of SMILESParser, but found ty
献上github地址:这里

下面是一个demo

(1)数据

这里使用的是和这个package提供的数据,具体参考:club
在这里插入图片描述

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

在torch_geometric中图用Data的格式,Data的对象:可以在文档中详细了解在这里插入图片描述
其中的属性

  • edge_index:表示图的连接关系(start,end两个序列)
  • node features:每个点的特征
  • node labels:每个点的标签
  • train_mask:有的节点没有标签(用来表示哪些节点要计算损失)

(2)可视化

from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

在这里插入图片描述

(3)网络定义

GCN layer的定义:在这里插入图片描述
可以在官网的文档做详细了解

在这里插入图片描述
卷积层就有很多了:
在这里插入图片描述

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4) # 只需定义好输入特征和输出特征即可
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index) # 输入特征与邻接矩阵(注意格式,上面那种)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  
        
        # 分类层
        out = self.classifier(h)

        return out, h

model = GCN()
print(model)

_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')# 输出最后分类前的中间特征shape

visualize_embedding(h, color=data.y)

这时很分散
在这里插入图片描述

(4)训练模型(semi-supervised)

import time

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  
    out, h = model(data.x, data.edge_index) #h是两维向量,主要是为了画图方便 
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # semi-supervised
    loss.backward()  
    optimizer.step()  
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

然后就可以看到一系列图,看点的变化情况了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/867328.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于kubeasz部署高可用k8s集群

在部署高可用k8s之前,我们先来说一说单master架构和多master架构,以及多master架构中各组件工作逻辑 k8s单master架构 提示:这种单master节点的架构,通常只用于测试环境,生产环境绝对不允许;这是因为k8s集群…

C++学习笔记——从面试题出发学习C++

C学习笔记——从面试题出发学习C C学习笔记——从面试题出发学习C1. 成员函数的重写、重载和隐藏的区别?2. 构造函数可以是虚函数吗?内联函数可以是虚函数吗?析构函数为什么一定要是虚函数?3. 解释左值/右值、左值/右值引用、std:…

探索FSM (有限状态机)应用

有限状态机(FSM) 是计算机科学中的一种数学模型,可用于表示和控制系统的行为。它由一组状态以及定义在这些状态上的转换函数组成。FSM 被广泛用于计算机程序中的状态机制。 有限状态机(FSM)应用场景 在各种自动化系统…

LVGL学习笔记 29 - LED

目录 1. 设置颜色 2. 设置OFF颜色 3. 设置对比度 4. 改变状态 功能类似CheckBox,用一个方形或则圆形的控件显示开关状态。 lv_obj_t* led1 lv_led_create(lv_scr_act());lv_obj_t* led2 lv_led_create(lv_scr_act());lv_obj_align(led1, LV_ALIGN_CENTER, -80…

生产执行MES系统:提升企业灵活性和响应速度的关键利器

在竞争激烈的市场环境下,企业需要不断提高其灵活性和响应速度,以适应快速变化的需求和市场动态。生产执行MES(Manufacturing Execution System)系统作为信息技术的重要应用,为企业提供了强大的工具和平台,能…

redis 数据结构(一)

Redis 为什么那么快 redis是一种内存数据库,所有的操作都是在内存中进行的,还有一种重要原因是:它的数据结构的设计对数据进行增删查改操作很高效。 redis的数据结构是什么 redis数据结构是对redis键值对值的数据类型的底层的实现&#xff0c…

金蝶云星空与金蝶云星空对接集成采购入库查询连通采购入库新增(MW_写入测试)

金蝶云星空与金蝶云星空对接集成采购入库查询连通采购入库新增(MW_写入测试) 对接源平台:金蝶云星空 金蝶K/3Cloud在总结百万家客户管理最佳实践的基础上,提供了标准的管理模式;通过标准的业务架构:多会计准则、多币别、多地点、多组织、多税…

YOLOv5基础知识入门(3)— 目标检测相关知识点

前言:Hello大家好,我是小哥谈。YOLO算法发展历程和YOLOv5核心基础知识学习完成之后,接下来我们就需要学习目标检测相关知识了。为了让大家后面可以顺利地用YOLOv5进行目标检测实战,本节课就带领大家学习一下目标检测的基础知识点&…

运放电路笔记5-其它典型电路

电容的特性——“通交流,隔直流” 电容器可以用来对交流信号进行通路,同时隔离直流信号。这是因为电容器对交流信号具有低阻抗(通过)和对直流信号具有高阻抗(阻断)的特性。 在低频情况下,电容器…

Layui精简版,快速入门

目录 LayUI之入门 1.什么是layui 2.layui入门 3.自定义模块 4.用户登录 5.主页搭建 LayUI之动态树 main.jsp main.js LayUI之动态选项卡 1.选项卡 main.jsp main.js 2.用户登录 User.java UserDao.java UserAction.java R.java LayUI之用户管理 1.用户查询…

填补封闭社区一加ACE2V在151版本下安装KernelSU方式获取ROOT

背景需求,Android移动端软件太过流氓,随意驻留后台,其他root方案不满意 第一步,请将手机升级到你想稳定的版本 参考文档, 安装 | KernelSU https://kernelsu.org/zh_CN/guide/installation.html#patch-boot-image 免github下载地址 KernelSU https://mrzzoxo.lanzoub.com/b…

设计模式之工厂方法模式(FactoryMethod)

一、概述 定义一个用于创建对象的接口,让子类决定实例化哪一个类。FactoryMethod使一个类的实例化延迟到其子类。 二、适用性 1.当一个类不知道它所必须创建的对象的类的时候。 2.当一个类希望由它的子类来指定它所创建的对象的时候。 3.当类将创建对象的职责委…

[ubuntu]创建root权限的用户 该用户登录后自动切换为root用户

一、创建新用户 1、创建新用户 sudo useradd -r -m -s /bin/bash 用户名 # -r:建立系统账号 -m:自动建立用户的登入目录 -s:指定用户登入后所使用的shell2、手动为用户设置密码 passwd 用户名 二、为用户增加root权限 1、添加写权限 ch…

DIP: NAS(Neural Architecture Search)论文阅读与总结(双份快乐)

文章地址: NAS-DIP: Learning Deep Image Prior with Neural Architecture SearchNeural Architecture Search for Deep Image Prior 参考博客:https://zhuanlan.zhihu.com/p/599390720 文章目录 NAS-DIP: Learning Deep Image Prior with Neural Architecture Search1. 方法…

ChatGPT等人工智能编写文章的内容今后将成为常态

BuzzFeed股价上涨200%可能标志着“转向人工智能”媒体趋势的开始。 周四,一份内部备忘录被华尔街日报透露BuzzFeed正计划使用ChatGPT聊天机器人-风格文本合成技术来自OpenAI,用于创建个性化盘问和将来可能的其他内容。消息传出后,BuzzFeed的…

SpringBoot 3自带的 HTTP 客户端工具

原理 Spring的HTTP 服务接口是一个带有HttpExchange方法的 Java 接口,它支持的支持的注解类型有: HttpExchange:是用于指定 HTTP 端点的通用注释。在接口级别使用时,它适用于所有方法。GetExchange:为 HTTP GET请求指…

题解:ABC276E - Round Trip

题解:ABC276E - Round Trip 题目 链接:Atcoder。 链接:洛谷。 难度 算法难度:普及。 思维难度:提高。 调码难度:提高。 综合评价:困难。 算法 bfs。 思路 从起点周围四个点中任选两…

jenkins自动化构建保姆级教程(持续更新中)

1.安装 1.1版本说明 访问jenkins官网 https://www.jenkins.io/,进入到首页 点击【Download】按钮进入到jenkins下载界面 左侧显示的是最新的长期支持版本,右侧显示的是最新的可测试版本(可能不稳定),建议使用最新的…

代码随想录算法学习心得 51 | 503、下一个更大的元素II 42、接雨水...

一、下一个更大元素II 链接:力扣 描述如下:给定一个循环数组 nums ( nums[nums.length - 1] 的下一个元素是 nums[0] ),返回 nums 中每个元素的 下一个更大元素 。 数字 x 的 下一个更大的元素 是按数组遍历顺序&am…

2023年五款免费高效的在线客服系统大揭秘!

近年来,随着移动互联网的蓬勃发展,企业与消费者之间的互动方式正在迅速演变,从传统的PC端转向了更加便捷灵活的移动端。在这个变革的大背景下,为了满足日益增长的客户需求,企业对于提供优质客户服务的迫切需求也逐渐凸…