图模型训练

news2024/9/22 9:37:09

一、依赖安装

网址:pyg-team/pytorch_geometric: Graph Neural Network Library for PyTorch (github.com)

找到此处,点击here进入依赖安装界面

找到自己安装的torch版本并点击,,进入安装依赖

二、用库自带的数据集

代码:

import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx


def visualiize_graph(G,color):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G,seed = 42), node_color=color, cmap = "Set2", with_labels=False)

    plt.show()

def visualize_embedding(h,color,epoch = None,loss = None):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0], h[:,1], c=color, cmap = "Set2", s = 140)
    if epoch is not None and loss is not None:
        plt.xlabel("Epoch: {}, Loss: {:.4f}".format(epoch, loss))
    plt.show()


from torch_geometric.datasets import KarateClub

dataset = KarateClub()

data = dataset[0]
print(data)
#展示点之间的关系
edge_index = data.edge_index
print(edge_index.t())

#可视化
G = to_networkx(data, to_undirected=True)
visualiize_graph(G,data.y)

查看数据集格式

Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

数据可视化

三、模型网络搭建

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub

from test import visualize_embedding

dataset = KarateClub()
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self,x,edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()

        #分类层
        out = self.classifier(h)
        return out,h

model = GCN()
# print(model)

_,h = model(data.x, data.edge_index)
print(f'Hidden state size: {h.size()}')

visualize_embedding(h,color = data.y)

四、模型训练

代码:

import time
import torch.nn
from model import GCN
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub

model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def visualiize_graph(G,color):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G,seed = 42), node_color=color, cmap = "Set2", with_labels=False)

    plt.show()

def visualize_embedding(h,color,epoch = None,loss = None):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0], h[:,1], c=color, cmap = "Set2", s = 140)
    if epoch is not None and loss is not None:
        plt.xlabel("Epoch: {}, Loss: {:.4f}".format(epoch, loss))
    plt.show()

def train(data):
    optimizer.zero_grad()
    out,h = model(data.x,data.edge_index)
    loss = criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss,h

dataset = KarateClub()
data = dataset[0]

for epoch in range(101):
    loss,h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h,color = data.y, epoch = epoch, loss = loss)
        time.sleep(0.3)

结果:

 

 

最后会发现不同颜色的点逐渐分散开,loss越来越小 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2051134.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WCT系列(二):SyncTransactionQueue类详解

SyncTransactionQueue类: 接上一回的WindowContainerTransaction类讲解,上一篇博客根据TaskView.java中的updateTaskVisibility()方法分析了WindowContainerTransaction的功能及使用。本次继续上一篇的思路,主要拆解syncTransactionQueue类。…

【JUC】06-可重入锁

可重入锁:又称递归锁。在外层使用锁后,内层仍然可以使用,并不发生死锁,这样的锁就叫可重入锁。synchronized默认是一个可重入锁。 public class Demo01 {public synchronized void m1() {System.out.println(Thread.currentThrea…

软件函数过期-软件开发故障处理-开发语言升级-全栈软件架构师-软件修仙界掌握几十门开发语言

一、软件界通用关键字 obsolete,deprecated,deprecation 二、多语言全栈,所有语言混合开发是什么?十几门开发语言 组合1、php/java/aspJSCandroid 平台物联网设备,智能音箱 组合2:C#PHPPYTHON 组合3&am…

云计算的三大服务模式:IaaS、PaaS、SaaS的深入解析

在数字化转型的浪潮中,云计算以其独特的灵活性、可扩展性和成本效益,正逐渐成为企业IT架构的核心。云计算提供了三种主要的服务模式,分别是基础设施即服务(IaaS)、平台即服务(PaaS)和软件即服务…

【算法/学习】双指针

✨ 少年要迎着朝阳,活得肆无忌惮 🌏 📃个人主页:island1314 🔥个人专栏:算法学习 🚀 欢迎关注:👍点赞 &a…

挑战1G内存!如何在千万记录中找到最热TOP10查询串?

我是小米,一个喜欢分享技术的29岁程序员。如果你喜欢我的文章,欢迎关注我的微信公众号“软件求生”,获取更多技术干货! 哈喽大家好!我是你们的技术小伙伴小米,今天又来和大家分享一个非常实用的算法题!假设我们现在有1000w个查询记录,这些记录中有很多重复的内容,但去…

内存碎片问题—容器启动状态卡在ContainerCreating

线上发现部分容器处于ContainerCreating状态: 查看kubelet日志: [rootdc07-prod-k8s-node /root] journalctl -u kubelet Jul 01 00:45:30 prod-k8s-node kubelet[12227]: I0701 00:45:30.491326 12227 kubelet.go:1908] SyncLoop (ADD, "api"): &quo…

RK3568笔记五十五:yolov10训练部署测试

若该文为原创文章,转载请注明原文出处。 yolov8还没熟悉,yolov10就出来了,本篇记录使用yolov10训练自己的数据,并部署到rk3568上。 参考大佬的博客yolov10 瑞芯微RKNN、地平线Horizon芯片部署、TensorRT部署,部署工程难度小、模型推理速度快_yolov10 rknn-CSDN博客 一、…

【网络编程】基于UDP的TFTP文件传输

1)tftp协议概述 简单文件传输协议,适用于在网络上进行文件传输的一套标准协议,使用UDP传输 特点: 是应用层协议 基于UDP协议实现 数据传输模式 octet:二进制模式(常用) mail:已经不再…

深度学习入门:卷积神经网络 | CNN概述,图像基础知识,卷积层,池化层(还在等什么!!!超详解!!!)

目录 🍔 前言 🍔 图像基础知识 1. 像素和通道的理解 2. 小节 🍔 卷积层 1. 卷积计算 2. Padding 3. Stride 4. 多通道卷积计算 5. 多卷积核卷积计算 6. 特征图大小 7. PyTorch 卷积层 API 7. 小节 🍔 池化层 1. 池…

WEB之文件上传

一:思维导图 二:相关问题解答 1,什么是文件上传漏洞? 文件上传漏洞是一种常见的网络安全问题,它发生在网络应用程序允许用户上传文件到服务器的功能中。如果这一功能没有得到适当的安全控制和验证,攻击者就可以利用…

web开发,过滤器,前后端交互

目录 web开发概述 web开发环境搭建 Servlet概述 Servlet的作用: Servlet创建和使用 Servlet生命周期 http请求 过滤器 过滤器的使用场景: 通过Filter接口来实现: 前后端项目之间的交互: 1、同步请求 2、异步请求 优化…

利用telnet发送QQ邮箱的电子邮件时遇到的问题(2024最新)

问题1:即使在控制面板启用telnet客户端也无法使用telnet 解决:使用管理员权限打开cmd,执行命令:dism /online /Enable-Feature /FeatureName:TelnetClient,之后根据弹出信息键入Y重启即可 参考链接:https:…

开源新宠:RAG2SQL工具,超越Text2SQL的7K Star之作

查询数据库离不开SQL,那如何快速构建符合自己期望的SQL呢?AI发展带来了Text2SQL的能力,众多产品纷纷提供了很好的支持。 今天我们分享一个开源项目,它在Text2SQL的基础上还要继续提高,通过加入RAG的能力进一步增强&am…

虹软科技25届校招笔试算法 A卷

目录 1. 第一题2. 第二题3. 论述题 ⏰ 时间:2024/08/18 🔄 输入输出:ACM格式 ⏳ 时长:2h 本试卷分为不定项选择,编程题,必做论述题和选做论述题,这里只展示编程题和必做论述题,一共三…

代码随想录算法训练营_day17

题目信息 654. 最大二叉树 题目链接: https://leetcode.cn/problems/maximum-binary-tree/题目描述: 给定一个不重复的整数数组 nums 。 最大二叉树 可以用下面的算法从 nums 递归地构建: 创建一个根节点,其值为 nums 中的最大值。递归地在最大值 左边 的 子数组前…

AVI-Talking——能通过语音生成很自然的 3D 说话面孔

概述 论文地址:https://arxiv.org/pdf/2402.16124v1.pdf 逼真的人脸三维动画在娱乐业中至关重要,包括数字人物动画、电影视觉配音和虚拟化身的创建。以往的研究曾试图建立动态头部姿势与音频节奏之间的关联模型,或使用情感标签或视频剪辑作…

【数据结构与算法】如何构建最小堆

最小堆的定义 最小堆,作为一种独特且重要的数据结构,它是一种特殊的二叉树。在这种二叉树中,有一个关键的规则:每一个父节点所存储的值,都必然小于或者等于其对应的子节点的值。这一规则确保了根节点总是承载着整个堆…

机器学习(3)-- 一元线性回归

文章目录 线性回归训练模型测试模型线性回归方程测试实用性 总结 线性回归 线性回归算法是一种用于预测一个或多个自变量(解释变量)与因变量(响应变量)之间关系的统计方法。这种方法基于线性假设,即因变量是自变量的线…

【学习笔记】Day 16-17

一、进度概述 1、ddnet_main 相关代码学习(预计 3-4 天) 二、详情 1、顶层结构 关于代码顶层结构的一些思考和总结,其中下图为师兄代码的文件结构 总结: 对于一个优秀的代码,其文件结构一定也是清晰的&#…