RootNeighboursDataset(helpers.dataset_classes文件中的root_neighbours_dataset.py)

news2024/10/23 7:40:53

任务类型:回归
用途:在 `RootNeighboursDataset` 中,任务是给定一棵根树,预测根节点度数为6的邻居的特征平均值。因此,模型需要基于根节点的结构,找到度为6的邻居,并计算其特征的平均值。这属于回归问题,因为目标是预测连续值(特征的平均值)

from helpers.dataset_classes.root_neighbours_dataset import RootNeighboursDataset

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensor


class RootNeighboursDataset(object):

    def __init__(self, seed: int, print_flag: bool = False):
        super().__init__()
        self.seed = seed
        self.plot_flag = print_flag
        self.generator = torch.Generator().manual_seed(seed)
        self.constants_dict = self.initialize_constants()

        self._data = self.create_data()

    def get(self) -> Data:
        return self._data

    def create_data(self) -> Data:
        # train, val, test
        data_list = []
        for num in range(self.constants_dict['NUM_COMPONENTS']):
            data_list.append(self.generate_component())
        return Batch.from_data_list(data_list)

    def mask_task(self, num_nodes_per_fold: List[int]) -> Tuple[Tensor, Tensor, Tensor]:
        num_nodes = sum(num_nodes_per_fold)
        train_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)
        val_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)
        test_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)

        train_mask[0] = True
        val_mask[num_nodes_per_fold[0]] = True
        test_mask[num_nodes_per_fold[0] + num_nodes_per_fold[1]] = True
        return train_mask, val_mask, test_mask

    def generate_component(self) -> Data:
        data_per_fold, num_nodes_per_fold = [], []
        for fold_idx in range(3):
            data = self.generate_fold(eval=(fold_idx != 0))
            num_nodes_per_fold.append(data.x.shape[0])
            data_per_fold.append(data)

        train_mask, val_mask, test_mask = self.mask_task(num_nodes_per_fold=num_nodes_per_fold)

        batch = Batch.from_data_list(data_per_fold)
        return Data(x=batch.x, edge_index=batch.edge_index, y=batch.y, train_mask=train_mask, val_mask=val_mask,
                    test_mask=test_mask)

    def initialize_constants(self) -> Dict[str, int]:
        return {'NUM_COMPONENTS': 1000, 'MAX_HUBS': 3, 'MAX_1HOP_NEIGHBORS': 10, 'ADD_HUBS': 2, 'HUB_NEIGHBORS': 5,
                'MAX_2HOP_NEIGHBORS': 3, 'NUM_FEATURES': 5}

    def generate_fold(self, eval: bool) -> Data:
        constant_dict = self.initialize_constants()
        MAX_HUBS, MAX_1HOP_NEIGHBORS, ADD_HUBS, HUB_NEIGHBORS, MAX_2HOP_NEIGHBORS, NUM_FEATURES =\
            [constant_dict[key] for key in ['MAX_HUBS', 'MAX_1HOP_NEIGHBORS', 'ADD_HUBS', 'HUB_NEIGHBORS',
                                            'MAX_2HOP_NEIGHBORS', 'NUM_FEATURES']]

        assert MAX_HUBS + ADD_HUBS <= MAX_1HOP_NEIGHBORS
        add_hubs = ADD_HUBS if eval else 0
        num_hubs = torch.randint(1, MAX_HUBS + 1, size=(1,), generator=self.generator).item() + add_hubs
        num_1hop_neighbors = torch.randint(MAX_HUBS + add_hubs, MAX_1HOP_NEIGHBORS + 1, size=(1,),
                                           generator=self.generator).item()
        assert num_hubs <= num_1hop_neighbors

        list_num_2hop_neighbors = torch.randint(1, MAX_2HOP_NEIGHBORS, size=(num_1hop_neighbors - num_hubs,),
                                                generator=self.generator).tolist()
        list_num_2hop_neighbors = [HUB_NEIGHBORS] * num_hubs + list_num_2hop_neighbors

        # 2 hop edge index
        num_nodes = 1  # root node is 0
        idx_1hop_neighbors = []
        list_edge_index = []
        for num_2hop_neighbors in list_num_2hop_neighbors:
            idx_1hop_neighbors.append(num_nodes)
            if num_2hop_neighbors > 0:
                clique_edge_index = torch.tensor([[0] * num_2hop_neighbors, list(range(1, num_2hop_neighbors + 1))])
                # clique_edge_index = torch.combinations(torch.arange(num_2hop_neighbors), r=2).T
                list_edge_index.append(clique_edge_index + num_nodes)

            num_nodes += num_2hop_neighbors + 1

        # 1 hop edge index
        idx_0hop = torch.tensor([0] * num_1hop_neighbors)
        idx_1hop_neighbors = torch.tensor(idx_1hop_neighbors)
        hubs = idx_1hop_neighbors[:num_hubs]
        list_edge_index.append(torch.stack((idx_0hop, idx_1hop_neighbors), dim=0))
        edge_index = torch.cat(list_edge_index, dim=1)

        # undirect
        edge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)
        edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)

        # features
        x = 4 * torch.rand(size=(num_nodes, NUM_FEATURES), generator=self.generator) - 2

        # labels
        y = torch.zeros_like(x)
        y[0] = torch.mean(x[hubs], dim=0)
        return Data(x=x, edge_index=edge_index, y=y)


if __name__ == '__main__':
    data = RootNeighboursDataset(seed=0, print_flag=True)

这个 RootNeighboursDataset通过随机生成的树状图数据来模拟一种节点关系,并基于图结构生成特征和标签。代码使用了 PyTorchPyTorch Geometric 的功能来处理图数据。下面逐块详细解释该代码实现:

1. RootNeighboursDataset 类构造器

import torch
from torch_geometric.data import Data, Batch
from typing import Dict, Tuple, List
from torch import Tensor


class RootNeighboursDataset(object):

    def __init__(self, seed: int, print_flag: bool = False):
        super().__init__()
        self.seed &#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2221426.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用自定义 ref 实现函数防抖

今天来简单介绍一个新的方法&#xff0c;使用自定义 ref 实现函数防抖。 1. 自定义 ref 的来源 自定义 ref 防抖函数来自于前端开发中的两个概念&#xff1a;Vue 的响应式系统 和 数防抖&#xff08;Debounce&#xff09;。 1、Vue 响应式系统&#xff1a;Vue 提供了 ref 和…

Python学习的自我理解和想法(20)

#1024程序员节|征文# 学的是b站的课程&#xff08;千锋教育&#xff09;&#xff0c;跟老师写程序&#xff0c;不是自创的代码&#xff01; 今天是学Python的第20天&#xff0c;学的内容是面向对象中的私有属性&#xff0c;私有方法&#xff0c;多态&#xff0c;单例计模式。开…

【ubuntu18.04】ubuntu18.04升级cmake-3.29.8及还原系统自带cmake操作说明

参考链接 cmake升级、更新&#xff08;ubuntu18.04&#xff09;-CSDN博客 升级cmake操作说明 下载链接 Download CMake 下载版本 下载软件包 cmake-3.30.3-linux-x86_64.tar.gz 拷贝软件包到虚拟机 cp /var/run/vmblock-fuse/blockdir/jrY8KS/cmake-3.29.8-linux-x86_64…

spring源码中的,函数式接口,注解@FunctionalInterface

调用方 /org/springframework/beans/factory/support/AbstractBeanFactory.java:333sharedInstance getSingleton(beanName, () -> {try {return createBean(beanName, mbd, args);}catch (BeansException ex) {// Explicitly remove instance from singleton cache: It mi…

高级的SQL查询技巧有哪些?

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///C爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于高级SQL查询技巧方面的相关内容&#xf…

MATLAB人脸考勤系统

MATLAB人脸考勤系统课题介绍 该课题为基于MATLAB平台的人脸识别系统。传统的人脸识别都是直接人头的比对&#xff0c;现实意义不大&#xff0c;没有一定的新意。该课题识别原理为&#xff1a;先采集待识别人员的人脸&#xff0c;进行训练&#xff0c;得到人脸特征值。测试的时…

HomeAssistant自定义组件学习-【一】

#环境准备# 按官方的步骤准备就可以&#xff0c;我是在Windows下使用VS Code开发的&#xff0c;安装了WSL&#xff08;使用模板创建组件需要在WSL环境下完成&#xff09; 官方链接&#xff1a;https://developers.home-assistant.io/docs/development_environment 环境准备好…

力扣困难题汇总(14道)

题4&#xff08;困难&#xff09;&#xff1a; 思路&#xff1a; 找两数组中位数&#xff0c;这个看起来简单&#xff0c;顺手反应就是数第(mn)/2个&#xff0c;这个难在要求时间复杂度为log(mn)&#xff0c;所以不能这样搞&#xff0c;我的思路是&#xff1a;每次切割长度为较…

【K8s】Kubernetes 词汇表

微思网络 厦门微思网络 K8S认证工程师&#xff08;CKA&#xff09;备考与学习指南https://mp.weixin.qq.com/s/XsEVpU7dKnJDBopynWW3GQ K8S-CKA课程试听:Container 概述 词汇表 此术语表旨在提供 Kubernetes 术语的完整、标准列表。其中包含特定于 Kubernetes 的技术术语以及…

uniapp修改input中placeholder样式

Uniapp官方提供了两种修改的属性方法&#xff0c;但经过测试&#xff0c;只有 placeholder-class 属性能够生效 <input placeholder"请输入手机验证码" placeholder-class"input-placeholder"/><!-- css --> <style lang"scss" s…

redis的zset实现下滑滚动分页查询思路

常规zset查询 我们redis的数据为 我们知道 我们常规查询的话 我们假如 zset 表中 有7个元素&#xff0c;然后我们进行分页查询的话&#xff0c;我们一次查3个元素&#xff0c;然后查出来元素 和元素的分数 我们redis的语法应该这样写 zrevrangebyscore wang 1000 0 withsc…

kotlin实现viewpager

说明:kotlin tablayout viewpager adapter实现滑动界面 效果图 step1: package com.example.flushfragmentdemoimport androidx.appcompat.app.AppCompatActivity import android.os.Bundle import androidx.fragment.app.Fragment import androidx.viewpager2.adapter.…

【uni-app学习-2】

一、跳转 方法&#xff1a;在methods中去定义方法&#xff1a; 上述为直接跳转&#xff0c;但是当你要跳转页面是由多个可切换页面组成比如&#xff1a; 这个页面其实是由两个页面组成&#xff0c;一个主页&#xff0c;一个我的&#xff0c;两个页面 路由配置需要用到toob…

java--多态(详解)

目录 一、概念二、多态实现的条件三、向上转型和向下转型3.1 向上转型3.2 向下转型 四、重写和重载五、理解多态5.1练习&#xff1a;5.2避免在构造方法中调用重写的方法&#xff1a; 欢迎来到权权的博客~欢迎大家对我的博客提出指导这是我的博客主页&#xff1a;点击 一、概念…

EasyExcel自定义下拉注解的三种实现方式

文章目录 一、简介二、关键组件1、ExcelSelected注解2、ExcelDynamicSelect接口&#xff08;仅用于方式二&#xff09;3、ExcelSelectedResolve类4、SelectedSheetWriteHandler类 三、实际应用总结 一、简介 在使用EasyExcel设置下拉数据时&#xff0c;每次都要创建一个SheetWr…

韩语干货topik韩语考级柯桥外语培训韩语中的惯用表达

表示递进的词尾或惯用表达 1 -을/ㄹ 뿐만 아니라 接在动词和形容词词干后面&#xff0c;表示“不仅...而且...”。该语法需要注意前后会有两个动词或形容词&#xff0c;此时两个动词或形容词的时态应保持一致。 例: 한번 파괴된 자연은 되돌리기기 쉽지 않을 뿐만 아니라 지역…

Java项目实战II基于微信小程序的原创音乐平台{UNIAPP+SSM+MySQL+Vue}(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在数字音乐…

《Order-Agnostic Data Augmentation for Few-Shot Named Entity Recognition》中文

文章汉化系列目录 文章目录 文章汉化系列目录摘要1 引言2 相关工作2.1 NER的数据增强2.2 少样本命名实体识别&#xff08;Few-Shot NER&#xff09; 3 无序数据增强3.1 公式化3.2 通过实体重排进行数据增强3.3 构建唯一的输入-输出对3.4 使用 OADA-XE 校准预测 4 实验4.1 不同D…

【ELK】初始阶段

一、logstash学习 安装的时候最好不要有中文的安装路径 使用相对路径 在 Windows PowerShell 中&#xff0c;如果 logstash 可执行文件位于当前目录下&#xff0c;你需要使用相对路径来运行它。尝试输入以下命令&#xff1a; .\logstash -e ‘input { stdin { } } output { s…

[软件工程]—嵌入式软件开发流程

嵌入式软件开发流程 1.工程文件夹目录 ├─00_Project_Management ├─00_Reference ├─01_Function_Map ├─02_Hardware ├─03_Firmware ├─04_Software ├─05_Mechanical ├─06_FCT └─07_Tools00_Project_Management 子文件夹如下所示&#xff1a; ├─00_需求导…