图神经网络实战——利用节点回归预测网络流量

news2024/12/26 23:26:32

图神经网络实战——利用节点回归预测网络流量

    • 0. 前言
    • 1. 数据集分析
    • 2. 实现 GCN 模型执行节点回归
    • 3. 模型测试
    • 相关链接

0. 前言

在机器学习中,回归指的是对连续值的预测。通常与分类形成鲜明对比,分类的目标是找到正确的类别(即离散值,而非连续值)。在图数据中,分类和回归分别对应于节点分类和节点回归。在本节中,我们将尝试预测每个节点的连续值,而非分类变量。

1. 数据集分析

为了利用节点回归预测网络流量,在本节中,我们将使用 Wikipedia Network 数据集,Wikipedia Network 数据集由 Rozemberckzi 等人于 2019 年引入。它由三个页面网络组成:chameleons (包含 2277 个节点和 31421 条边)、crocodiles (包含 11631 个节点和 170918 条边)和 squirrels (包含 5201 个节点和 198493 条边)。在这些数据集中,节点代表文章,边代表文章之间的相互链接,节点特征反映了文章中包含的特定词语,我们的目标是预测 201812 月的平均流量的对数。
在本节中,我们将在 chameleon 数据集上应用图卷积网络 (Graph Convolutional Network, GCN) 来预测网络流量。

(1) 导入 WikipediaNetwork 并下载 chameleon 数据集,应用转换函数 RandomNodeSplit() 随机创建一个评估掩码和一个测试掩码:

from torch_geometric.datasets import WikipediaNetwork
import torch_geometric.transforms as T

dataset = WikipediaNetwork(root=".", name="chameleon", transform = T.RandomNodeSplit(num_val=200, num_test=500))
data = dataset[0]

(2) 打印该数据集的相关信息:

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of unique features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')

输出结果如下所示:

Dataset: WikipediaNetwork()
-------------------
Number of graphs: 1
Number of nodes: 2277
Number of unique features: 2325
Number of classes: 5

Graph:
------
Edges are directed: True
Graph has isolated nodes: False
Graph has loops: True

(3) 在以上输出中,可以看出数据集中有五个类别。但是,我们需要执行的是节点回归任务,而不是分类。实际上,这五个类别正是我们想要预测的连续值的分段函数,但这些标签并不满足我们的需求,因此必须手动进行修改。首先,下载 wikipedia.zip 文件,解压缩后导入 pandas 并使用它加载目标值:

import pandas as pd

df = pd.read_csv('wikipedia/chameleon/musae_chameleon_target.csv')

(4) 使用 np.log10() 对目标值应用对数函数,因为我们的目标是预测月平均流量的对数:

values = np.log10(df['target'])

(5)data.y 重新定义为上一步中获取的连续值张量。需要注意的是,为了便于演示,我们在本例中未对这些值进行归一化处理(通常是获取优秀模型的标准预处理步骤):

data.y = torch.tensor(values)
print(data.y)

tensor([2.2330, 3.9079, 3.9329,  ..., 1.9956, 4.3598, 2.4409],
       dtype=torch.float64)

(6) 将节点度的分布进行可视化:

from torch_geometric.utils import degree
from collections import Counter

# Get list of degrees for each node
degrees = degree(data.edge_index[0]).numpy()

# Count the number of nodes for each degree
numbers = Counter(degrees)

# Bar plot
fig, ax = plt.subplots()
ax.set_xlabel('Node degree')
ax.set_ylabel('Number of nodes')
plt.bar(numbers.keys(), numbers.values())
plt.show()

节点度分布

与 Cora 和 Facebook Page-Page 数据集相比,该分布的尾部较短,但形状相似,大多数节点只有一个或几个邻居,但其中一些节点作为 "枢纽"节点可以连接 80 多个节点。

(7) 在节点回归的任务中,节点度的分布并不是唯一需要检查的分布类型,目标值的分布同样重要。事实上,非正态分布(如节点度)往往更难预测,可以使用 Seaborn 库绘制目标值,并将其与 scipy.stats.norm 提供的正态分布进行比较:

import seaborn as sns
from scipy.stats import norm

df['target'] = values
fig = sns.distplot(df['target'], fit=norm)
plt.show()

目标值的分布

可以看到该分布不完全是正态分布,也不像节点度分布那样近似指数分布,因此模型有机会能够很好地预测这些值。

2. 实现 GCN 模型执行节点回归

接下来,使用 PyTorch Geometric 实现图卷积网络 (Graph Convolutional Network, GCN) 架构用于执行节点回归任务。

(1) 定义 GCN 类和 __init__() 初始化方法,使用三个神经元数量递减的 GCNConv 层。这种编码器架构能够迫使模型选择最相关的特征来预测目标值,并添加了一个线性层,令预测输出不局限于 -1 和 `1 之间:

class GCN(torch.nn.Module):
    """Graph Convolutional Network"""
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.gcn1 = GCNConv(dim_in, dim_h*4)
        self.gcn2 = GCNConv(dim_h*4, dim_h*2)
        self.gcn3 = GCNConv(dim_h*2, dim_h)
        self.linear = torch.nn.Linear(dim_h, dim_out)

(2)forward() 方法中使用 GCNConv 层和 nn.Linear 层,但不再需要使用 log_softmax 函数,因为模型的目标并不是预测类别:

    def forward(self, x, edge_index):
        h = self.gcn1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.gcn2(h, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.gcn3(h, edge_index)
        h = torch.relu(h)
        h = self.linear(h)
        return h

(3)fit() 方法中使用 F.mse_loss() 函数替代分类任务中使用的交叉熵损失,使用均方差 (Mean Squared Error, MSE) 作为模型性能的评价指标,MSE 定义如下:
M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 MSE=\frac 1N\sum_{i=1}^N(y_i-\hat y_i)^2 MSE=N1i=1N(yiy^i)2
完整的 fit() 方法代码如下:

    def fit(self, data, epochs): 
        optimizer = torch.optim.Adam(self.parameters(),
                                      lr=0.02,
                                      weight_decay=5e-4)

        self.train()
        for epoch in range(epochs+1):
            optimizer.zero_grad()
            out = self(data.x, data.edge_index)
            loss = F.mse_loss(out.squeeze()[data.train_mask], data.y[data.train_mask].float())
            loss.backward()
            optimizer.step()
            if epoch % 20 == 0:
                val_loss = F.mse_loss(out.squeeze()[data.val_mask], data.y[data.val_mask])
                print(f"Epoch {epoch:>3} | Train Loss: {loss:.5f} | Val Loss: {val_loss:.5f}")

(4)test() 方法中同样包含 MSE

    def test(self, data):
        self.eval()
        out = self(data.x, data.edge_index)
        return F.mse_loss(out.squeeze()[data.test_mask], data.y[data.test_mask].float())

(5) 实例化 GCN 模型,模型包含 128 个隐藏维度、 1 个输出维度(目标值),并训练 200epoch

# Create the Vanilla GNN model
gcn = GCN(dataset.num_features, 128, 1)
print(gcn) 

# Train
gcn.fit(data, epochs=200)

模型训练过程

3. 模型测试

(1) 模型训练完成后进行测试,获得在测试集上的 MSE

# Test
loss = gcn.test(data)
print(f'\nGCN test loss: {loss:.5f}\n')

# GCN test loss: 0.80326

MSE 损失本身并不是最合适指标,可以使用以下两个指标得到更有意义的结果:

  • RMSE:衡量误差的平均值:
    R M S E = M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 RMSE=\sqrt {MSE}=\sqrt {\frac 1N\sum_{i=1}^N(y_i-\hat y_i)^2} RMSE=MSE =N1i=1N(yiy^i)2
  • 平均绝对误差 (Mean Absolute Error, MAE):预测值和实际值之间的平均绝对差值:
    M A E = 1 N ∑ i = 1 N ∣ y i − y ^ i ∣ MAE=\frac 1N\sum_{i=1}^N |y_i-\hat y_i| MAE=N1i=1Nyiy^i

接下来,使用 Python 实现上述两种模型评价指标。

(2) 可以直接从 scikit-learn 库中导入 MSEMAE

from sklearn.metrics import mean_squared_error, mean_absolute_error

(3) 使用 .detach().numpy() 将模型预测值的 PyTorch 张量转换为 NumPy 数组:

out = gcn(data.x, data.edge_index)
y_pred = out.squeeze()[data.test_mask].detach().numpy()
mse = mean_squared_error(data.y[data.test_mask], y_pred)
mae = mean_absolute_error(data.y[data.test_mask], y_pred)

(4) 使用 scikit-learn 库函数计算 MSEMAE,使用 np.sqrt() 计算 MSE 的平方根得到 RMSE

print('=' * 43)
print(f'MSE = {mse:.4f} | RMSE = {np.sqrt(mse):.4f} | MAE = {mae:.4f}')
print('=' * 43)
'''
===========================================
MSE = 0.8033 | RMSE = 0.8962 | MAE = 0.7409
===========================================
'''

不同指标可以用于比较不同的模型。为了直观可视化模型性能,可使用散点图,其中横轴代表预测值,纵轴代表实际值,在 Seaborn 库中可以使用函数 regplot() 实现这种可视化:

fig = sns.regplot(x=data.y[data.test_mask].numpy(), y=y_pred)
fig.set(xlabel='Ground truth', ylabel='Predicted values')
plt.show()

模型性能

虽然我们没有使用基线模型,但仍然可以看出模型可以得到不错的预测结果,因为离群值很少。尽管数据集很小,但足以说明 GCN 在多种应用中都能发挥作用。如果我们想改进模型性能,可以调整超参数并进行误差分析,以了解异常值的来源。

相关链接

图神经网络实战(1)——图神经网络(Graph Neural Networks, GNN)基础
图神经网络实战(2)——图论基础
图神经网络实战(3)——基于DeepWalk创建节点表示
图神经网络实战(4)——基于Node2Vec改进嵌入质量
图神经网络实战(5)——常用图数据集
图神经网络实战(6)——使用PyTorch构建图神经网络
图神经网络实战(7)——图卷积网络(Graph Convolutional Network, GCN)详解与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1605647.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++_智能指针

文章目录 前言一、智能指针原理二、库支持的智能指针类型1.std::auto_ptr2.std::unique_ptr3.std::shared_ptr4.std::weak_ptr 三、删除器总结 前言 智能指针是一种采用RAII思想来保护申请内存不被泄露的方式来管理我们申请的内存,对于RAII,我们之前也已…

这是刚发布的人形机器人?不,分明是《午夜凶铃》现实版

波士顿动力公司大名鼎鼎的人形机器人Atlas,你一定见识过吧。 Atlas可以像人一样行走、奔跑和攀爬 | 波士顿动力公司 这款用液压系统打造的机器人产品,经过十多年的调试升级,才终于拥有了人类一般灵活的身手。在波士顿动力公司历年来放出的视频…

OpenHarmony UI开发-ohos-svg

简介 ohos-svg是一个SVG图片的解析器和渲染器,解析SVG图片并渲染到页面上。它支持大部分 SVG 1.1 规范,包括基本形状、路径、文本、样式和渐变,它能够渲染大多数标准的 SVG 图像。ohos-svg的优点是性能好、内存占用低。 效果展示 SVG图片解析并绘制: …

第七周学习笔记DAY.4-方法重写与多态

学完本次课程后,你能够: 实现方法重写 深入理解继承相关概念 了解Object类 会使用重写实现多态机制 会使用instanceof运算符 会使用向上转型 会使用向下转型 什么是方法重写 方法的重写或方法的覆盖(overriding) 1.子类根据…

【STM32CubeIDE 1.15.0】汉化包带路径配置过程

一、IDE软件下载 二、汉化版包路径 三、IDE软件板载汉化包 一、IDE软件下载 ST官网IDE下载链接 二、汉化版包路径 https://mirrors.ustc.edu.cn/eclipse/technology/babel/update-site/ 找不到就到.cn后面一级一级进 三、IDE软件板载汉化包 https://mirrors.ustc.edu…

Jmeter 压测-Jprofiler定位接口相应时间长

1、环境准备 执行压测脚本,分析该接口tps很低,响应时间很长 高频接口在100ms以内,普通接口在200ms以内 2、JProfiler分析响应时间长的方法 ①JProfiler录制数据 压测脚本,执行1-3分钟即可 ②分析接口相应时间长的方法 通过Me…

Django之rest_framework(四)

扩展的视图类介绍 rest_framework提供了几种后端视图(对数据资源进行增删改查)处理流程的实现,如果需要编写的视图属于这几种,则视图可以通过继承相应的扩展类来复用代码,减少自己编写的代码量 官网:3 - Class based views - Django REST framework rest_framework.mixi…

cobaltstrike 流量隐藏

云函数 新建一个云函数,在代码位置进行修改 首先导入 yisiwei.zip 的云函数包 PYTHON # -*- coding: utf8 -*- import json, requests, base64def main_handler(event, context):C2 https://49.xx.xx.xx # 这里可以使用 HTTP、HTTPS~下角标~ path event[path]h…

在Windows 11/10/8中打开计算机管理的几种方法,总有一种适合你

序言 计算机管理是Windows中一个功能强大的工具,允许你管理和监视计算机系统的各个方面。使用“计算机管理”,你可以快速访问“设备管理器”、“磁盘管理”、“本地用户管理”等。本文将向你展示如何在Windows 11/10/8中打开“计算机管理器”。 网上有很多方法可以打开计算…

Spring Security详细学习第一篇

Spring Security 前言Spring Security入门编辑Spring Security底层原理UserDetailsService接口PasswordEncoder接口 认证登录校验密码加密存储退出登录 前言 本文是作者学习三更老师的Spring Security课程所记录的学习心得和笔记知识,希望能帮助到大家 Spring Sec…

单分支:if语句

示例&#xff1a; /*** brief how about if? show you here.* author wenxuanpei* email 15873152445163.com(query for any question here)*/ #define _CRT_SECURE_NO_WARNINGS//support c-library in Microsoft-Visual-Studio #include <stdio.h>#define if_state…

C语言学习/复习23---

一、数据的存储 二、数据类型的介绍 三、整型在内存中的存储 将原码转换为补码。如果数是正数&#xff0c;则补码与原码相同&#xff1b;如果数是负数&#xff0c;则先将原码按位取反&#xff0c;然后加1。将补码转换原补码。如果数是正数&#xff0c;则补码与原码相同&#x…

【笔试强训】双指针的思想!

1.数组中字符串的最小距离 题目链接 解题思路&#xff1a; 小技巧 ✌&#xff1a;标记两个字符串是否被找到&#xff0c;每次找到一个字符串就更新一次答案来保证找到的是最小距离。 实现代码&#xff1a; #include <iostream> using namespace std;int main() {in…

快手本地生活服务商系统怎么操作?

当下&#xff0c;抖音和快手两大短视频巨头都已开始布局本地生活服务&#xff0c;想要在这一板块争得一席之地。而这也很多普通人看到了机遇&#xff0c;选择成为抖音和快手的本地生活服务商&#xff0c;通过将商家引进平台&#xff0c;并向其提供代运营服务&#xff0c;而成功…

截图快捷键失效的解决方法 _ 统信UOS _ 麒麟KOS _ 中科方德NFS

原文链接&#xff1a;截图快捷键失效的解决方法 | 统信UOS | 麒麟KOS | 中科方德NFS Hello&#xff0c;大家好啊&#xff01;在日常使用计算机时&#xff0c;截图功能是我们经常需要用到的一个实用工具&#xff0c;它可以帮助我们快速保存屏幕上的信息&#xff0c;用于报告错误…

恭喜上岸的准研究生们,入学后还有这些奖学金

很多学校都开设了研究生的新生奖学金&#xff0c;有些学校是不分学校等级的全覆盖&#xff0c;比如北京科技大学前两年给研一新生每人发1万。 一般来说&#xff0c;新生奖学金的等级划分就是按考研成绩&#xff0c;所以大家一定要尽可能的考高的分数&#xff0c;不仅仅对评奖学…

云HIS医院管理系统源码 SaaS模式 B/S架构 基于云计算技术

一、系统概述 云HIS系统源码是一款满足基层医院各类业务需要的健康云产品。该系统能帮助基层医院完成日常各类业务&#xff0c;提供病患预约挂号支持、收费管理、病患问诊、电子病历、开药发药、住院检查、会员管理、财务管理、统计查询、医生工作站和护士工作站等一系列常规功…

累积分布函数图(CDF)的介绍、matlab的CDF图绘制方法(附源代码)

在对比如下两个误差的时候&#xff0c;怎么直观地分辨出来谁的误差更低一点&#xff1f;&#xff1a; 通过这种误差时序图往往不容易看出来。 但是如果使用CDF图像&#xff0c;以误差绝对值作为横轴&#xff0c;以横轴所示误差对应的累积概率为纵轴&#xff0c;绘制曲线图&am…

gitlab(docker)安装及使用

GitLab GitLab 是一个用于仓库管理系统的开源项目&#xff0c;使用Git作为代码管理工具&#xff0c;并在此基础上搭建起来的Web服务。 下载(docker) 查询docker镜像gitlab-ce gitlab-ce是它的社区版 [rootlocalhost ~]# docker search gitlab-ce NAME …

Xshell和XFtp下载和使用

Xshell和XFtp下载和使用 最好是官网直接下载。 链接: Xshell官网 Xshell官网最近出了免费个人使用版&#xff0c;但是我直接下载的话感觉非常非常慢&#xff0c;或许挂个梯子会好的多。看到图片的红色字没&#xff0c;可能被骗的人比较多。运行之前的Xshell会显示需要最新版的软…