6.如何用CSV文件生成异构图数据集

news2025/1/23 0:05:32

       我们将使用GroupLens研究小组收集的MovieLens数据集
       这个数据集描述了MovieLens的五星评级和标记活动。该数据集包含来自600多名用户的9000多部电影的约10万个评分。我们将使用该数据集生成两种节点类型,分别保存电影和用户的数据,以及一种连接用户和电影的边类型,表示用户对特定电影的评分关系。
       首先,我们将数据集下载到任意文件夹(在本例中为当前目录):

from torch_geometric.data import download_url, extract_zip

url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'
extract_zip(download_url(url, '.'), '.')

movies_path = './ml-latest-small/movies.csv'
ratings_path = './ml-latest-small/ratings.csv'

打开数据集,就可以看到以下一些文件:
在这里插入图片描述

import pandas as pd

print(pd.read_csv(movies_path).head()) # DataFrame对象,默认显示前5行
print(pd.read_csv(ratings_path).head())

在这里插入图片描述
       为了用PyG数据格式表示这些数据,我们首先定义了一个方法load_node_csv(),该方法读取*.csv文件并返回形状为[num_nodes,num_features]的节点级特征表示x:

import torch

def load_node_csv(path, index_col, encoders=None, **kwargs): # **kwargs用于在函数定义中接收任意数量的关键字参数,是一个字典
    df = pd.read_csv(path, index_col=index_col, **kwargs) # 读取*.csv
    mapping = {index: i for i, index in enumerate(df.index.unique())} # 将索引映射成连续值

    x = None
    if encoders is not None:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    return x, mapping

在这里插入图片描述
       这里,load_node_csv()从路径读取*.csv文件,并创建一个字典映射,将其索引列映射到范围{0,…,num_rows-1}中的连续值。这是必要的,因为我们希望我们的最终数据表示尽可能紧凑,例如,第一行中的电影表示应该可以通过x[0]访问。

from sentence_transformers import SentenceTransformer
class SequenceEncoder:
    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        x = self.model.encode(df.values, show_progress_bar=True,
                              convert_to_tensor=True, device=self.device)
        print(x.shape)
        return x.cpu()

       SequenceEncoder类加载一个由model_name给定的预先训练的NLP模型,并使用它将字符串列表编码为形状为[num_strings,embedding_dim]的PyTorch张量。我们可以使用此SequenceEncodermovies.csv文件的标题进行编码。

       以类似的方式,我们可以创建另一个编码器,将电影类型转换为分类标签。为此,我们首先需要找到数据中存在的所有电影类型,创建shape[num_movies,num_genres]的特征表示x,并在类型j存在于电影i中的情况下将1分配给x[i,j]:

class GenresEncoder:
    def __init__(self, sep='|'):
        self.sep = sep

    def __call__(self, df):
        genres = set(g for col in df.values for g in col.split(self.sep))
        mapping = {genre: i for i, genre in enumerate(genres)}

        x = torch.zeros(len(df), len(mapping))
        for i, col in enumerate(df.values):
            for genre in col.split(self.sep):
                x[i, mapping[genre]] = 1
                
        print(x.shape)
        return x

       有了这个,我们可以通过以下方式获得我们对电影的最终呈现:
在这里插入图片描述
       类似地,我们也可以使用load_node_csv()来获得从userId到连续值的用户映射。但是,此数据集中没有用户的其他特征信息。因此,我们没有定义任何编码器:
在这里插入图片描述

       这样,我们就可以初始化HeteroData对象,并将两种节点类型传递给它:

from torch_geometric.data import HeteroData

data = HeteroData()

data['user'].num_nodes = len(user_mapping)  # Users do not have any features.
data['movie'].x = movie_x

print(data)
print(movie_x.shape)

在这里插入图片描述
       由于用户没有任何节点级别的信息,我们只定义其节点数。因此,在异构图模型的训练过程中,我们可能需要通过torch.nn.Embedding以端到端的方式学习不同的用户嵌入。

       接下来,我们来看看根据用户的评分将他们与电影联系起来。为此,我们定义了一个方法load_edge_csv(),该方法从ratings.csv返回shape[2,num_ratings]的最终edge_index表示,以及原始*.csv文件中存在的任何其他功能:

def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,
                  encoders=None, **kwargs):
    df = pd.read_csv(path, **kwargs)

    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    #print(len(src))
    #print(len(dst))
    edge_index = torch.tensor([src, dst])

    edge_attr = None
    if encoders is not None:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)
        
    #print(edge_attr.shape)

    return edge_index, edge_attr

       这里,src_index_coldst_index_col分别定义源节点和目标节点的索引列。我们进一步利用节点级映射src_mappingdst_mapping来确保原始索引在我们的最终表示中被映射到正确的连续索引。

       对于文件中定义的每条边,它会在src_mappingdst_mapping中查找正向索引,并适当地移动数据。

       load_node_csv()类似,编码器用于返回额外的边特征信息。例如,为了从ratings.csv中的rating列加载ratings,我们可以定义一个IdentityEncoder,它只需将浮点值列表转换为PyTorch张量:

class IdentityEncoder:
    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, df):
        return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)

       这样,我们就可以完成我们的HeteroData对象了:

edge_index, edge_label = load_edge_csv(
    ratings_path,
    src_index_col='userId',
    src_mapping=user_mapping,
    dst_index_col='movieId',
    dst_mapping=movie_mapping,
    encoders={'rating': IdentityEncoder(dtype=torch.long)},
)

data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_label

print(data)

在这里插入图片描述
       该HeteroData对象是PyG中异构图的原生格式,可以用作异构图模型的输入。

本文内容参考:PyG官网
视频讲解:4.如何用CSV文件生成异构图数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/836701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【我们一起60天准备考研算法面试(大全)-第三十四天 34/60】【前缀和】【北邮】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…

【Linux】Linux下git的使用

文章目录 一、什么是git二、git发展史三、Gitee仓库的创建1.新建仓库2.复制仓库链接3.在命令行克隆仓库3.1仓库里的.gitignore是什么3.2仓库里的git是什么 三、git的基本使用1.将克隆仓库的新增文件添加到暂存区(本地仓库)2.将暂存区的文件添加到.git仓库中3.将.git仓库中的变化…

4.msf辅助模块

目录 1 在虚拟机中设置与外部相同的网段 2 当前内网中的可用IP arp_sweep 3 搜索指定IP的TCP端口信息 portscan/tcp 4 扫描http服务的路由 http/dir_scanner 5 SSH密码爆破 ssh/ssh_login 1 在虚拟机中设置与外部相同的网段 我真实机的地址的网段是192.168.0 我虚拟…

【大模型】开源且可商用的大模型通义千问-7B(Qwen-7B)来了

【大模型】开源且可商用的大模型通义千问-7B(Qwen-7B)来了 新闻通义千问 - 7B 介绍评测表现快速使用环境要求安装相关的依赖库推荐安装flash-attention来提高你的运行效率以及降低显存占用使用 Transformers 运行模型使用 ModelScope 运行模型 量化长文本…

SAP标准搜索帮助(Search Help)改造之标准增强点

1. 搜索帮助加载前 包含程序:LWDTMO01 行:40 标准搜索帮助输出前的控制(影响标准Search Help CDS View Search Help(如果在标准Search Help搜索帮助出口函数上修改控制参数,则不会影响 CDS View Search Help&#xf…

【Kubernetes】Kubernetes之二进制部署

kubernetes 一、Kubernetes 的安装部署1. 常见的安装部署方式1.1 Minikube1.2 Kubeadm1.3 二进制安装部署2. K8S 部署 二进制与高可用的区别2.1 二进制部署2.2 kubeadm 部署二、Kubernetes 二进制部署过程1. 服务器相关设置以及架构2. 操作系统初始化配置3. 部署 etcd 集群4. 部…

Vue——formcreate表单设计器自定义组件实现(二)

前面我写过一个自定义电子签名的formcreate表单设计器组件,那时初识formcreate各种使用也颇为生疏,不过总算套出了一个组件不是。此次时隔半年又有机会接触formcreate,重新熟悉和领悟了一番各个方法和使用指南。趁热打铁将此次心得再次分享。…

python爬虫1:基础知识

python爬虫1:基础知识 前言 ​ python实现网络爬虫非常简单,只需要掌握一定的基础知识和一定的库使用技巧即可。本系列目标旨在梳理相关知识点,方便以后复习。 目录结构 文章目录 python爬虫1:基础知识1. 基础认知1.1 什么是爬虫&…

【2023】XXL-Job 具体通过docker 配置安装容器,再通过springboot执行注册实现完整流程

【2023】XXL-Job 具体通过docker 配置安装容器,再通过springboot执行注册实现 一、概述二、安装1、拉取镜像2、创建数据库3、创建容器并运行3、查看容器和日志4、打开网页 127.0.0.1:9051/xxl-job-admin/ 三、实现注册测试1、创建一个SpringBoot项目、添加依赖。2、…

steam搬砖项目拆解,长久稳定

steam搬砖指将"CS:GO"的游戏道具从国外游戏平台搬到国内的游戏平台(一般都是在网易BUFF)进行贩卖,从而赚取道具商品差价或者汇率的差价。 首先,Steam是全球最大的游戏平台,拥有上亿的玩家,同时在…

ISC 2023︱诚邀您参与赛宁“安全验证评估”论坛

​​8月9日-10日,第十一届互联网安全大会(简称ISC 2023)将在北京国家会议中心举办。本次大会以“安全即服务,开启人工智能时代数字安全新范式”为主题,打造全球首场AI数字安全峰会,赋予安全即服务新时代内涵…

数据驱动+自动化测试

自动化测试代码优化 setUp 在每个测试用例执行之前执行 tearDown 在每个测试用例执行完以后执行 所以,可以利用setUp,把测试用例中的通用代码提取出来,减少冗余 数据驱动测试:优化自动化测试 安装: pip install p…

JDK19 - synchronized关键字导致的虚拟线程PINNED

JDK19 - synchronized关键字导致的虚拟线程PINNED 前言一. PINNED是什么意思1.1 synchronized 绑定测试1.2 synchronized 关键字的替代 二. -Djdk.tracePinnedThreads的作用和坑2.1 死锁案例测试2.2 发生原因的推测2.3 总结 前言 在 虚拟线程详解 这篇文章里面,我们…

Protues 仿真报错Internal Exception: access violation in module ‘UNKNOWN‘[7ADEEEA9]

在使用STM32F103C8进行Protues仿真设计的时候,出现了这个报错,通过查找和定位问题,发现是我在配置供电网络的时候配置错误,要配置成如下: 至于为什么回这样,我猜想应该是和这个软件导入STM32芯片的时候&…

300个智商测试FLASH智商游戏ACCESS数据库

最近在找IQ测试方面的数据,网上大多只留传着33道题这种类型,其他的又因各种条件(比如图片含水印等)不能弄,这是从测智网下载的一些测试智商的游戏数据,游戏文件是FLASH的,扩展名是SWF。 数据包总…

FineReport主题组件使用

主题 添加主题 服务器-》模版主题管理,设置决策报表与普通报表的模版主题: 修改内置模版,打开,点击另存为设置自己的主题名称,保存主题 根据自己需求设置模版相关样式:模版背景、单元格样式、图表样式、…

模板方法设计模式(C++)

定义 定义一个操作中的算法的骨架(稳定),而将一些步骤延迟(变化)到子类中。Template Method使得子类可以不改变(复用)一个算法的结构即可重定义(override重写)该算法的某些特定步骤。 ——《设计模式》GoF Template Method模式是一种非常基…

转录组下游分析 | 懒人分析推荐

写在前面 今天在GitHub看到一个博主写的RNASeqTool的ShinApp,里面包含了PCA、DESeq2、volcano、NormEnrich、GSEA、Gene tred analysis和WGCNA分析。使用后还是很方便的,就此推荐给大家。感兴趣可以自己操作即可。 GitHub网址 https://github.com/Cha…

QT以管理员身份运行

以下配置后,QT在QT Creator调试时,或者生成的.exe程序,都将会默认以管理员身份运行。 一、MSVC编译器 1、在Pro文件中添加以下代码: QMAKE_LFLAGS /MANIFESTUAC:\"level\requireAdministrator\ uiAccess\false\\" …

Java 8 中使用 Stream 遍历树形结构

在实际开发中,我们经常会开发菜单,树形结构,数据库一般就使用父id来表示,为了降低数据库的查询压力,我们可以使用Java8中的Stream流一次性把数据查出来,然后通过流式处理,我们一起来看看&#x…