DGL的图数据处理管道

news2024/11/16 7:34:35

DGL在 dgl.data 里实现了很多常用的图数据集。它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道。 DGL推荐用户将图数据处理为 dgl.data.DGLDataset 的子类。该类为导入、处理和保存图数据提供了简单而干净的解决方案。

DGL中的Dataset类

它是处理、导入以及保存dgl.data提供的图数据集的基类,实现了处理图的基本模板。同时,它还提供了一种标准且方便的方式来导入图数据,同时用户可以存储有关数据集的图,也可以存储图中的节点特征、标签、掩码以及诸如类别数标签数等基本信息。该模块的工作方式如下:
在这里插入图片描述
由上图可见,这种加载方式与torch的其他组件导入数据集方式几乎相同。同时Dataset也是学习Pytorch过程中重点关注的内容,这里是一篇Hugging-Face实现中文句子情感分析的完整流程:HuggingFace-利用BERT预训练模型实现中文情感分类(下游任务)
可以看到该部分在定义数据集时直接使用的是Dataset类,而它进一步调用了datasets包中的load_dataset

搭建自己的Dataset

下面定义了一个名为MyDataset的类,它继承于dgl.data.DGLDataset,我们将它作为处理位于远程服务器或本地磁盘上图数据集的接口。

from dgl.data import DGLDataset

class MyDataset(DGLDataset):
    """ 用于在DGL中自定义图数据集的模板:

    Parameters
    ----------
    url : str
        下载原始数据集的url。
    raw_dir : str
        指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/
    save_dir : str
        处理完成的数据集的保存目录。默认:raw_dir指定的值
    force_reload : bool
        是否重新导入数据集。默认:False
    verbose : bool
        是否打印进度信息。
    """
    def __init__(self,
                 url=None,
                 raw_dir=None,
                 save_dir=None,
                 force_reload=False,
                 verbose=False):
        super(MyDataset, self).__init__(name='dataset_name',
                                        url=url,
                                        raw_dir=raw_dir,
                                        save_dir=save_dir,
                                        force_reload=force_reload,
                                        verbose=verbose)

    def download(self):
        # 将原始数据下载到本地磁盘
        pass

    def process(self):
        # 将原始数据处理为图、标签和数据集划分的掩码
        pass

    def __getitem__(self, idx):
        # 通过idx得到与之对应的一个样本
        pass

    def __len__(self):
        # 数据样本的数量
        pass

    def save(self):
        # 将处理后的数据保存至 `self.save_path`
        pass

    def load(self):
        # 从 `self.save_path` 导入处理后的数据
        pass

    def has_cache(self):
        # 检查在 `self.save_path` 中是否存有处理后的数据
        pass

原始的DGLDataset类中有抽象函数process()__ getitem __(idx)以及__len __()。DGL要求子类必须实现这些函数,同时DGL建议实现保存和导入函数,这样做的目的是对于处理的大型数据集,可以节省大量的时间开销。
DGLDataset 的目的是提供一种标准且方便的方式来导入图数据。 用户可以存储有关数据集的图、特征、标签、掩码,以及诸如类别数、标签数等基本信息。 诸如采样、划分或特征归一化等操作建议在 DGLDataset 子类之外完成。

下载原始数据

如果用户的数据集已经在本地磁盘中,请确保它被存放在目录 raw_dir 中。 如果用户想在任何地方运行代码而又不想自己下载数据并将其移动到正确的目录中,则可以通过实现函数 download() 来自动完成。

如果数据集是一个zip文件,可以直接继承 dgl.data.DGLBuiltinDataset 类。后者支持解压缩zip文件。 否则用户需要自己实现 download(),其实现的步骤可以参考QM7Dataset类,其内容如下:

import os
from dgl.data.utils import download

def download(self):
    # 存储文件的路径
    file_path = os.path.join(self.raw_dir, self.name + '.mat')
    # 下载文件
    download(self.url, path=file_path)

请注意上面的代码是将.mat的文件下载到用户指定的self.raw_dir中,但是如果是.gz.tar.tar.gz.tgz文件则需要参考下面的BitcoinOTCDataset 类下载.gz文件的方法:

from dgl.data.utils import download, check_sha1

def download(self):
    # 存储文件的路径,请确保使用与原始文件名相同的后缀
    gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
    # 下载文件
    download(self.url, path=gz_file_path)
    # 检查 SHA-1
    if not check_sha1(gz_file_path, self._sha1_str):
        raise UserWarning('File {} is downloaded but the content hash does not match.'
                          'The repo may be outdated or download may be incomplete. '
                          'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
    # 将文件解压缩到目录self.raw_dir下的self.name目录中
    self._extract_gz(gz_file_path, self.raw_path)

上面的代码会将文件解压缩到 self.raw_dir 下的目录 self.name 中。 如果该类继承自 dgl.data.DGLBuiltinDataset 来处理zip文件, 则它也会将文件解压缩到目录 self.name 中。一个可选项是用户可以按照上面的示例检查下载后文件的SHA-1字符串,以防作者在远程服务器上更改了文件

数据处理

本小节主要是用于实现MyDataset中的process()函数,它主要是根据不同图任务进行区分的。

图上的机器学习任务主要包括三个大类:

  1. 整张图的分类:也就是将看到的图看做一个整体,预测该图属于什么类型或什么领域,一般这种图的数据规模较小
  2. 节点分类:主要是对于图中的节点进行预测
  3. 链路预测:这种问题主要是针对于节点分类的回归问题,也就是不需要知道明确的节点分类,只需要预测出边或者节点的特征值

这里着重说一下第三点:这个有点类似于之前完成的Hugging-Face的例子,由于我不想对PLM进行训练,而是直接使用它已经训练好的冻结后的梯度和参数,所以我可以直接使用它作为一个特征提取器,这种做的效果如果不是追求SOTA的话往往也是比较好的。具体的还是参考这篇:
HuggingFace-利用BERT预训练模型实现中文情感分类(下游任务)

本节重点介绍了处理图、特征和划分掩码的标准方法。

处理整图分类数据集

Dataset部分

整图分类数据集与用小批次训练的典型机器学习任务中的大多数数据集类似。 因此,需要将原始数据处理为 dgl.DGLGraph 对象的列表和标签张量的列表。 此外,如果原始数据已被拆分为多个文件,则可以添加参数 split导入数据的特定部分
下面是 QM7bDataset 的示例:

from dgl.data import DGLDataset

class QM7bDataset(DGLDataset):
    _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
           'datasets/qm7b.mat'
    _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'

    def __init__(self, raw_dir=None, force_reload=False, verbose=False):
        super(QM7bDataset, self).__init__(name='qm7b',
                                          url=self._url,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)

    def process(self):
        mat_path = self.raw_path + '.mat'
        # 将数据处理为图列表和标签列表
        self.graphs, self.label = self._load_graph(mat_path)

    def __getitem__(self, idx):
        """ 通过idx获取对应的图和标签

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        (dgl.DGLGraph, Tensor)
        """
        return self.graphs[idx], self.label[idx]

    def __len__(self):
        """数据集中图的数量"""
        return len(self.graphs)

加载图数据集

可以看到这里的process()函数并没有直接对原图数据集进行处理,而是使用的_load_graph()函数,对于这个函数我们可以通过QM7bDataset的源码看到:

def _load_graph(self, filename):
    data = io.loadmat(filename)
    labels = F.tensor(data['T'], dtype=F.data_type_dict['float32'])
    feats = data['X']
    num_graphs = labels.shape[0]
    graphs = []
    for i in range(num_graphs):
        edge_list = feats[i].nonzero()
        g = dgl_graph(edge_list)
        g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
                                dtype=F.data_type_dict['float32'])
        graphs.append(g)
    return graphs, labels

_load_graph最终返回了处理后的图特征列表与标签特征列表。

增加额外属性

用户还可以向类添加属性以指示一些有用的数据集信息。在 QM7bDataset 中, 用户可以添加属性 num_tasks 来指示此多任务数据集中的预测任务总数

@property
def num_tasks(self):
    """每个图的标签数,即预测任务数。"""
    return 14

使用定义的Dataset

编写完以上代码后,就可以使用下面的方式使用定义的QM7bDataset

import dgl
import torch

from dgl.dataloading import GraphDataLoader

# 数据导入
dataset = QM7bDataset()
num_tasks = dataset.num_tasks

# 创建 dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)

# 训练
for epoch in range(100):
    for g, labels in dataloader:
        # 用户自己的训练代码
        pass

处理节点分类数据集

与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。 DGL建议使用节点掩码来指定数据集的划分。
此外,DGL推荐重新排列图的节点/边,使得相邻节点/边的ID位于邻近区间内。这个过程 可以提高节点/边的邻居的局部性,为后续在图上进行的计算与分析的性能改善提供可能。 DGL提供了名为 dgl.reorder_graph() 的API用于此优化。

Dataset部分

以内置数据集 CitationGraphDataset 为例:

from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url

class CitationGraphDataset(DGLBuiltinDataset):
    _urls = {
        'cora_v2' : 'dataset/cora_v2.zip',
        'citeseer' : 'dataset/citeseer.zip',
        'pubmed' : 'dataset/pubmed.zip',
    }

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']
        if name.lower() == 'cora':
            name = 'cora_v2'
        url = _get_dgl_url(self._urls[name])
        super(CitationGraphDataset, self).__init__(name,
                                                   url=url,
                                                   raw_dir=raw_dir,
                                                   force_reload=force_reload,
                                                   verbose=verbose)

    def process(self):
        # 跳过一些处理的代码
        # === 跳过数据处理 ===

        # 构建图
        g = dgl.graph(graph)

        # 划分掩码
        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask

        # 节点的标签
        g.ndata['label'] = torch.tensor(labels)

        # 节点的特征
        g.ndata['feat'] = torch.tensor(_preprocess_features(features),
                                       dtype=F.data_type_dict['float32'])
        self._num_tasks = onehot_labels.shape[1]
        self._labels = labels
        # 重排图以获得更优的局部性
        self._g = dgl.reorder_graph(g)

    def __getitem__(self, idx):
        assert idx == 0, "这个数据集里只有一个图"
        return self._g

    def __len__(self):
        return 1

为简便起见,这里省略了 process() 中的一些代码,以突出展示用于处理节点分类数据集的关键部分:划分掩码。 节点特征和节点的标签被存储在 g.ndata 中。
这里 getitem(idx) 和 len() 的实现也发生了变化, 这是因为节点分类任务通常只用一个图。掩码在PyTorchTensorFlow中是bool张量,在MXNet中是float张量。

使用定义的Dataset

下面中使用 dgl.data.CitationGraphDataset 的子类 dgl.data.CiteseerGraphDataset 来演示如何使用用于节点分类的数据集:

# 导入数据
dataset = CiteseerGraphDataset(raw_dir='')
graph = dataset[0]

# 获取划分的掩码
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']

# 获取节点特征
feats = graph.ndata['feat']

# 获取标签
labels = graph.ndata['label']

处理链接预测数据集

链接预测数据集的处理与节点分类相似,数据集中通常只有一个图。
以内置的数据集 KnowledgeGraphDataset 为例,其中省略了详细的数据处理代码以突出展示处理链接预测数据集的关键部分:

Dataset部分

# 创建链接预测数据集示例
class KnowledgeGraphDataset(DGLBuiltinDataset):
    def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
        self._name = name
        self.reverse = reverse
        url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
        super(KnowledgeGraphDataset, self).__init__(name,
                                                    url=url,
                                                    raw_dir=raw_dir,
                                                    force_reload=force_reload,
                                                    verbose=verbose)

    def process(self):
        # 跳过一些处理的代码
        # === 跳过数据处理 ===

        # 划分掩码
        g.edata['train_mask'] = train_mask
        g.edata['val_mask'] = val_mask
        g.edata['test_mask'] = test_mask

        # 边类型
        g.edata['etype'] = etype

        # 节点类型
        g.ndata['ntype'] = ntype
        self._g = g

    def __getitem__(self, idx):
        assert idx == 0, "这个数据集只有一个图"
        return self._g

    def __len__(self):
        return 1

在上面代码中,图的edata存储了划分掩码。

使用定义的Dataset

from dgl.data import FB15k237Dataset

# 导入数据
dataset = FB15k237Dataset()
graph = dataset[0]

# 获取训练集掩码
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
src, dst = graph.edges(train_idx)

# 获取训练集中的边类型
rel = graph.edata['etype'][train_idx]

保存和加载数据

DGL建议用户实现保存和加载数据的函数,将处理后的数据缓存在本地磁盘中。 这样在多数情况下可以帮用户节省大量的数据处理时间。DGL提供了4个函数让任务变得简单。

  • dgl.save_graphs()dgl.load_graphs(): 保存DGLGraph对象和标签到本地磁盘和从本地磁盘读取它们。
  • dgl.data.utils.save_info()dgl.data.utils.load_info(): 将数据集的有用信息(python dict对象)保存到本地磁盘和从本地磁盘读取它们。

下面的示例显示了如何保存和读取图和数据集信息的列表。

import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import makedirs, save_info, load_info

def save(self):
    # 保存图和标签
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    save_graphs(graph_path, self.graphs, {'labels': self.labels})
    # 在Python字典里保存其他信息
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    save_info(info_path, {'num_classes': self.num_classes})

def load(self):
    # 从目录 `self.save_path` 里读取处理过的数据
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    self.graphs, label_dict = load_graphs(graph_path)
    self.labels = label_dict['labels']
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    self.num_classes = load_info(info_path)['num_classes']

def has_cache(self):
    # 检查在 `self.save_path` 里是否有处理过的数据文件
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    return os.path.exists(graph_path) and os.path.exists(info_path)

注意:有些情况下不适合保存处理过的数据;
例如,在内置数据集 GDELTDataset 中, 处理过的数据比较大。所以这个时候,在 __getitem__(idx) 中处理每个数据实例是更高效的方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1251934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BTP-K710自定义纸张大小

进入控制面板 输入 control 打开打印机 点击查看设备和打印机 打开打印机属性 鼠标右击 -> 点击打开打印机属性 如果上面没有你想要的尺寸 点标签库 -> 选中Custom Paper-BPLZ -> 然后修改宽度和高度 -> 点击应用 选择Custom Paper-BPLZ 最后点击确定 这样…

计数问题+约瑟夫问题(map)

目录 一、计数问题 二、约瑟夫问题 一、计数问题 #include<iostream> #include<map> using namespace std; int main() {int n,x;cin>>n>>x;map<int,int>m;for(int i1;i<n;i){if(i>1 && i<10){m[i];}else{int temp i;while (…

Task.Run为什么会存在内存泄漏的风险?

由于值类型是拷贝方式的赋值&#xff0c;捕获的本地变量和类成员是指向各自的值&#xff0c;对本地变量的捕获不会影响到整个类。但如果把类中的值类型改为引用类型&#xff0c;那这两者最终指向的是同一个对象值&#xff0c;这是否意味着使用本地变量还是无法避免内存泄漏&…

1.如何修改seruat对象的行名 2.FeaturePlot如何把所有阳性表达的spot放到图的前面

本有解决标题中的两个问题 1.答案是修改不了&#xff0c;不如重新制作一个seurat对象。 试图使用rownames&#xff08;obj&#xff09;featurenames是不成功的 记录 客户需求遇到一个问题&#xff1a;作者提供的rds文件行名为ensemble id&#xff0c;如何改成gene symbol。 …

解决hbuilder使用android studio模拟器不能热更新

hbuilder使用android studio模拟器编&#xff0c;在编写代码时&#xff0c;不能热更新&#xff0c;总是需要重启虚拟机中的程序&#xff0c;hbuilderx的版本是3.1.22&#xff0c;android studio的版本是4.2.2 同时在hbuilderx中出现如下报错信息&#xff1a; 报错信息&#x…

ASP产品通过网络安全专用产品安全认证

什么是网络安全专用产品安全检测&#xff1f; 网络安全专用产品安全检测是指对网络关键设备和网络安全专用产品进行安全性评估和检测&#xff0c;以确保其符合相关标准和法规的要求&#xff0c;能够有效地抵御网络攻击和威胁。该检测由具备资格的机构进行&#xff0c;采用认证…

禁止指定电脑程序运行的2种方法

你可能要问了&#xff0c;为什么要禁止电脑程序运行呢&#xff0c;因为有的公司要净化公司的工作环境&#xff0c;防止某些刺头员工在公司电脑上瞎搞。也有部分家长&#xff0c;是为了防止自己家的孩子利用电脑乱下载东西。 今天就分享2种禁止指定电脑程序运行的方法&#xff1…

C语言SO EASY(ZZULIOJ1220: SO EASY)

题目描述 Superbin最近在研究初等数论&#xff0c;初等数论 是研究数的规律&#xff0c;特别是整数性质的数学分支。它是数论的一个最古老的分支。它以算术方法为主要研究方法&#xff0c;主要内容有整数的整除理论、同余理论、连分数理论和某些特殊不定方程。 是定义在正整数…

【DP】mobiusp正在创作乐曲

输入样例1&#xff1a; 5 2 1 7 7 1 3 输出样例1&#xff1a; 2 输入样例2&#xff1a; 10 3 2 5 6 4 4 5 7 3 5 6 输出样例2&#xff1a; 1 #include<iostream> #include<cstring> #include<algorithm> #include<vector> using namespace std; typede…

遥遥领先!TinyEngine 低代码引擎更新升级!AI 已成功部署!

官网更新 1. 协议规范优化更新 协议规范交互优化&#xff0c;能够在一页里面自由切换&#xff0c;提高用户体验&#xff0c;后端SDK文档字段描述补齐&#xff0c;助力开发者自助进行服务端开发。 2.使用手册全新改版 去掉之前的学院课程&#xff0c;新设计新分类让结构一目了…

【云备份】文件操作实用工具类设计

文章目录 为什么要单独设计文件工具类&#xff1f;整体实现Filesize ——文件大小stat接口 LastMTime ——最后一次修改时间LastATime —— 最后一次访问时间FileName —— 文件名称GetPostLen ——获取文件指定位置 指定长度的数据GetContnet —— 读取文件数据SetContent ——…

文献速递:人工智能在新生儿重症监护室:现在是时候了

人工智能在新生儿重症监护室&#xff1a;现在是时候了 01 文献速递介绍 文章介绍了AI的多学科特性&#xff0c;包括计算机科学、数学、神经科学和哲学。AI的目标是通过各种计算和算法技术创建智能机器。尽管早期对人类水平AI的预测未能实现&#xff0c;但对AI的期待仍然强烈…

图书馆管理系统源码(Java)

Book包->内含Book类与BookList类 Book类 用于初始化图书并对其进行操作 BookList类 用于存放多本图书 Book类 package Book;public class Book {private String name;private String author;private int price;private String type;private boolean isBorrow;//写一个…

10.打印楼梯图案和笑脸【2023.11.25】

1.问题描述 打印楼梯图案和笑脸 2.解决思路 3.代码实现 #include<stdio.h> int main(){printf("11\n");for(int i0;i<10;i){for(int j0;j<i;j){printf("FF"); }printf("\n");} return 0; }4.运行结果

Java研学-异常

一 异常 异于正常状态的显示,即控制台显示结果与预期的不一致 // 抛出异常java.lang.ArrayIndexOutOfBoundsException int[] nums new int[2]; System.out.print(nums[2]);二 异常分类 1 检测性异常:有称为非运行时异常,一般是在编写代码的过程中编程软件直接报错 2 非检测…

用Python写的自动答题脚本,正确率99%,刷课必备

咱们上学的时候&#xff0c;有些课程&#xff0c;不是很重要所以没去看&#xff0c;但是又要刷题&#xff0c;这时候&#xff0c;自动答题脚本就很关键&#xff0c;全程全自动&#xff0c;正确率还嘎嘎高&#xff01; 所以咱们今天来用Python写一个自动答题脚本&#xff0c;正…

图的建立基本操作

#include <stdio.h> #include <stdlib.h> #include <stdbool.h> // 添加头文件#define MAX_VERTEX_NUM 100 //图中最大顶点数//struct ArcNode* nextarc; //ArcNode* firstarc; //这两个是很有必要的&#xff0c;如果你没有这两个指针&#xff0c;你就无法判…

2.前端--HTML标签1【2023.11.25】

1.基本语法规范 HTML 标签是由尖括号包围的关键词&#xff0c;例如 <html>。HTML 标签通常是成对出现的&#xff0c;例如 和 &#xff0c;我们称为双标签。有些特殊的标签必须是单个标签&#xff08;极少情况&#xff09;&#xff0c;例如 <br />我们称为单标签。 …

2022年12月 Scratch(三级)真题解析#中国电子学会#全国青少年软件编程等级考试

Scratch等级考试(1~4级)全部真题・点这里 一、单选题(共25题,每题2分,共50分) 第1题 默认小猫角色和气球角色都是显示状态,小猫程序如下图所示,气球没有程序,点击绿旗,舞台上最终显示的效果是?( ) A:可能出现6个不同位置的小猫和6个小球 B:可能出现6个不同位…

Python基于jieba+wordcloud实现文本分词、词频统计、条形图绘制及不同主题的词云图绘制

目录 序言&#xff1a;第三方库及所需材料函数模块介绍分词词频统计条形图绘制词云绘制主函数 效果预览全部代码 序言&#xff1a;第三方库及所需材料 编程语言&#xff1a;Python3.9。 编程环境&#xff1a;Anaconda3&#xff0c;Spyder5。 使用到的主要第三方库&#xff1a;…