关于 ogbg-molhi数据集的个人解析

news2025/1/23 10:45:05

 cs224w_colab2.py这个图属性预测到底咋预测的

dataset.meta_info.T
Out[2]:
num tasks                                                                1 
eval metric                                                         rocauc
download_name                                                          hiv
version                                                                  1
url                      http://snap.stanford.edu/ogb/data/graphproppre...
add_inverse_edge                                                      True
data type                                                              mol
has_node_attr                                                         True
has_edge_attr                                                         True
task type                                            binary classification
num classes                                                              2
split                                                             scaffold
additional node files                                                 None
additional edge files                                                 None
binary                                                               False
Name: ogbg-molhiv, dtype: object

参照上面 这里的num tasks  仅适用于图属性预测? num tasks = 1

model = GCN_Graph(args['hidden_dim'],
            dataset.num_tasks, args['num_layers'],
            args['dropout']).to(device)


train_loader.dataset.data.edge_index.shape
Out[10]: torch.Size([2, 2259376])

train_loader.dataset.data.edge_attr.shape
Out[12]: torch.Size([2259376, 3])



type(train_loader.dataset.data.node_stores)
Out[26]: list

 

train_loader.dataset.data.node_stores[0]['y'].shape
Out[46]: torch.Size([41127, 1])
train_loader.dataset.data.node_stores[0]['y'].sum()
Out[47]: tensor(1443) y 中的数值求和值

torch.unique(train_loader.dataset.data.node_stores[0]['y'],return_counts=True)
Out[58]: (tensor([0, 1]), tensor([39684,  1443]))  仅0,1两类


self.node_encoder.atom_embedding_list
Out[62]: 
ModuleList(
  (0): Embedding(119, 256)
  (1): Embedding(5, 256)
  (2): Embedding(12, 256)
  (3): Embedding(12, 256)
  (4): Embedding(10, 256)
  (5): Embedding(6, 256)
  (6): Embedding(6, 256)
  (7): Embedding(2, 256)
  (8): Embedding(2, 256)
)



list(enumerate(data_loader))
Out[82]: 
[(0,
  DataBatch(edge_index=[2, 1734], edge_attr=[1734, 3], x=[807, 9], y=[32, 1], num_nodes=807, batch=[807], ptr=[33])),
  若干组 很多
        x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch
        embed = self.node_encoder(x)  #使用编码器 将原先9维的编码为256维 self.node_encoder = AtomEncoder(hidden_dim)

        out = self.gnn_node(embed, edge_index) #使用gcn得到节点嵌入 embed=X edge_index 连边/节点对 
        out = self.pool(out, batch)


batch.unique(return_counts = True)
Out[94]: 
(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
        device='cuda:0'), 这里说明有31个待训练子图(化学分子) 下图api说明了聚合过程
 tensor([30, 18, 21, 26, 12, 20, 17, 18, 36, 11, 31, 22, 21, 26, 22, 21, 21, 63,
         15, 18, 18, 29, 18, 40, 41, 19, 19, 30, 12, 21, 19, 23], 每个分子中所含有的节点(原子)数量
        device='cuda:0'))
batch.shape
Out[95]: torch.Size([758])

def global_mean_pool(x: Tensor, batch: Optional[Tensor],
                     size: Optional[int] = None) -> Tensor:

    dim = -1 if x.dim() == 1 else -2 #这里的x.dim() = 2 
# dim() → int Returns the number of dimensions of self tensor.

    if batch is None:
        return x.mean(dim=dim, keepdim=x.dim() <= 2) #keepdim=x.dim() <= 2 ??啥玩意<=
    size = int(batch.max().item() + 1) if size is None else size
    return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')

This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in PyTorch, which are missing in the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.该包由一个小型扩展库组成,该库包含用于PyTorch的高度优化的稀疏更新(分散和分段)操作,这些操作在主包中丢失。散射和分段运算可以粗略地描述为基于给定“群索引”张量的归约运算。分段运算需要对“组索引”张量进行排序,而分散运算则不受这些要求的约束。

由此(scatter)由多个节点的嵌入值最终得到这部分节点所在的子图嵌入(化学分子)。

    def forward(self, batched_data):
        # TODO: Implement a function that takes as input a
        # mini-batch of graphs (torch_geometric.data.Batch) and
        # returns the predicted graph property for each graph.
        #
        # NOTE: Since we are predicting graph level properties,
        # your output will be a tensor with dimension equaling
        # the number of graphs in the mini-batch
       
         x, edge_index, batch = batched_data.x, batched_data.edge_index, batched_data.batch
        embed = self.node_encoder(x)  #使用编码器 将原先9维的编码为256维 self.node_encoder = AtomEncoder(hidden_dim)

        out = self.gnn_node(embed, edge_index) #使用gcn得到节点嵌入 embed=X edge_index 连边/节点对
        out = self.pool(out, batch)
        out = self.linear(out)

        ############# Your code here ############
        ## Note:
        ## 1. Construct node embeddings using existing GCN model
        ## 2. Use the global pooling layer to aggregate features for each individual graph
        ## For more information please refer to the documentation:
        ## https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#global-pooling-layers
        ## 3. Use a linear layer to predict each graph's property
        ## (~3 lines of code)

        #########################################

        return out

out.shape
Out[122]: torch.Size([32, 1])
out
Out[121]: 
tensor([[-0.4690],
        [-1.0285],
        [-0.4614],
最后经过线性层 返回得到所属类别概率 运行到如下部分结束反向传播forward() (op = model(batch)# 先进入model函数 然后运行 反向传播)



def train(model, device, data_loader, optimizer, loss_fn):
    # TODO: Implement a function that trains your model by
    # using the given optimizer and loss_fn.
    model.train()  #Sets the module in training mode. data_loader.dataset.data Data(num_nodes=1049163, edge_index=[2, 2259376], edge_attr=[2259376, 3], x=[1049163, 9], y=[41127, 1])
    loss = 0

    for step, batch in enumerate(tqdm(data_loader, desc="Iteration")): #,total= data_loader.batch_sampler
    # for step, batch in tqdm(enumerate(data_loader), desc="Iteration"): #,total= data_loader.batch_sampler
      batch = batch.to(device)

      if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
          pass
      else:
        ## ignore nan targets (unlabeled) when computing training loss.
        is_labeled = batch.y == batch.y # 0/1转化为Ture/False

        ############# Your code here ############
        ## Note:
        ## 1. Zero grad the optimizer
        ## 2. Feed the data into the model
        ## 3. Use `is_labeled` mask to filter output and labels
        ## 4. You may need to change the type of label to torch.float32
        ## 5. Feed the output and label to the loss_fn
        ## (~3 lines of code)

        optimizer.zero_grad()
        # print('optimizer.zero_grad()')
        op = model(batch)# 先进入model函数 然后运行 反向传播



。。。。。。。。。。。。。。。后面计算损失 更新梯度等等

存在错误等欢迎指正! 附件为整个作业的.py文件

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/996323.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

阿里云和腾讯云2核2G服务器价格和性能对比

2核2G云服务器可以选择阿里云服务器或腾讯云服务器&#xff0c;腾讯云轻量2核2G3M带宽服务器95元一年&#xff0c;阿里云轻量2核2G3M带宽优惠价108元一年&#xff0c;不只是轻量应用服务器&#xff0c;阿里云还可以选择ECS云服务器u1&#xff0c;腾讯云也可以选择CVM标准型S5云…

SpringBoot整合Redis 并 展示使用方法

步骤 引入依赖配置数据库参数编写配置类构造RedisTemplate创建测试类测试 1.引入依赖 不写版本号&#xff0c;也是可以的 在pom中引入 <!--redis配置--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-…

Kafka消费者组重平衡(一)

文章目录 概述消费者组特点什么是 Coordinator重平衡影响 概述 重平衡&#xff0c;也就是Rebalance, 就是让一个 Consumer Group 下所有的 Consumer 实例就如何消费订阅主题的所有分区达成共识的过程。在 Rebalance 过程中&#xff0c;所有 Consumer 实例共同参与&#xff0c;…

ShardingSphere分库分表(一):高性能架构模式

互联网业务兴起之后&#xff0c;海量用户加上海量数据的特点&#xff0c;单个数据库服务器已经难以满足业务需要&#xff0c;必须考虑数据库集群的方式来提升性能。高性能数据库集群的第一种方式是“读写分离”&#xff0c;第二种方式是“数据库分片”。 文章目录 1、读写分离架…

地理测绘基础知识(6) 照射距离等值线计算

上一篇文章中&#xff0c;我们采用HPR坐标系里的向量旋转&#xff0c;在地表绘制了这样的螺旋线&#xff1a; 在复杂多样的现实应用需求中&#xff0c;还有一种更为普遍的计算需求&#xff0c;就是求取地表到全向光源的距离为D的所有点的集合&#xff08;用多边形组成的近似椭…

MongoDB简介以及安装

文章目录 1. MongoDB简介2. NoSQL简介3. MongoDB安装 1. MongoDB简介 MongoDB是一种NoSQL数据库&#xff0c;采用了文档数据库模型。它以BSON&#xff08;Binary JSON&#xff09;格式存储数据&#xff0c;支持动态模式和灵活的查询语言。MongoDB具有以下特点&#xff1a; 文…

虚拟机 + Ubuntu22.04 + ros2 (humble) colcon build turtlebot3_node失败的解决方案

一、问题描述 在虚拟机Ubuntu22.04中安装了ROS2&#xff08;humble&#xff09;,下载turtlebot3。在colcon build --symlink-install 编译的过程中turtlebot3_Fake_node一直失败&#xff0c;无法正常运行&#xff0c;影响后面的仿真测试。 二、解决方案 查阅相关资料后发现问…

JAVA 从入门到起飞 面向对象 day08 P2

老师的知识点1 在JAVA中&#xff0c;必须先设计类&#xff0c;才能获得对象。 我的理解&#xff1a; 疑问&#xff1a;为什么是这样的呢&#xff1f; 答案&#xff1a; 在 JAVA 或其他面向对象的编程语言中&#xff0c;类是对象的蓝图或模板。这意味着在你创建对象之前&am…

【已解决】在Win11上离线安装 .NET Framework 3.5的方法【含网盘离线文件】

随 Windows 11提供的是.NET Framework 4.8&#xff0c;该环境可以运行任何 .NET Framework 4.x 应用。 而.NET Framework 3.5 支持为 .NET Framework 2.0 到 3.5 生成的应用&#xff0c;需要自行安装。 当Win11的应用软件需要.net framework3.5的运行环境时&#xff0c;就会提…

领域驱动设计:微服务架构模型

文章目录 整洁架构六边形架构DDD 分层架构三种微服务架构模型的对比和分析从三种架构模型看中台和微服务设计 整洁架构 整洁架构又名“洋葱架构”。为什么叫它洋葱架构&#xff1f;整洁架构的层就像洋葱片一样&#xff0c;它体现了分层的设计思想。在整洁架构里&#xff0c;同…

跨站请求伪造

CSRF是什么&#xff1f; 跨站请求伪造(Cross Site Request Forgery&#xff0c;CSRF)是一种攻击&#xff0c;它强制浏览器客户端用户在当前对其进行身份验证后的Web 应用程序上执行非本意操作的攻击&#xff0c;攻击的重点在于更改状态的请求&#xff0c;而不是盗取数据&#x…

西部是真的地广人稀啊,常用地市东西分布差异明显

背景 最近在使用folium处理一些工作上的事情&#xff0c;这过程中发现一些GPS坐标数据的获取和置换不是太方便&#xff0c;尤其是坐标置换&#xff0c;做了一些工作进行了GPS坐标数据秘坐标置换方向的封装。 GPS坐标类封装的过程中&#xff0c;发现一些常用的GPS坐标的查取比…

安装程序报错“E: Sub-process /usr/bin/dpkg returned an error code (1)”的解决办法

今天在终端使用命令安装程序时出现了如下的报错信息。 E: Sub-process /usr/bin/dpkg returned an error code (1) 这种情况下安装什么程序最终都会报这个错&#xff0c;具体的报错截图如下图所示。 要解决这个问题&#xff0c;首先使用下面的命令进到相应的目录下。 cd /var/…

项目02—基于keepalived+mysqlrouter+gtid半同步复制的MySQL集群

文章目录 一.项目介绍1.拓扑图2.详细介绍 二.前期准备1.项目环境2.IP划分 三. 项目步骤1.ansible部署软件环境1.1 安装ansible环境1.2 建立免密通道1.3 ansible批量部署软件1.4 统一5台mysql服务器的数据 2.配置基于GTID的半同步主从复制2.1 在master上安装配置半同步的插件,再…

蓝桥杯官网练习题(玩具蛇)

题目描述 本题为填空题&#xff0c;只需要算出结果后&#xff0c;在代码中使用输出语句将所填结果输出即可。 小蓝有一条玩具蛇&#xff0c;一共有 16 节&#xff0c;上面标着数字 1 至 16。每一节都是一个正方形的形状。相邻的两节可以成直线或者成 90 度角。 小蓝还有一个…

ROS学习笔记(五)---话题发布

1. 话题通信是什么 在ROS&#xff08;机器人操作系统&#xff09;中&#xff0c;话题通信是一种常用的通信机制&#xff0c;用于在不同的ROS节点之间传递消息。话题通信基于发布者-订阅者模式&#xff0c;其中一个节点&#xff08;发布者&#xff09;发布消息到一个特定的话题…

使用最新android sdk 将jar文件编译成dex

最近需要一些比较骚的操作&#xff0c;所以需要将gson编译成dex。 因为手上有jar包&#xff0c;所以就拿出了android sdk准备一把入魂&#xff0c;结果报错不断&#xff0c;让人无奈。只好根据报错来调整编译步骤&#xff0c;不得不为安卓环境更新Debug。 1、dx变d8 并不确定…

postgresql-通用表达式

postgresql-通用表达式 入门案例简单CTE递归 CTE案例1案例2 入门案例 -- 通用表达式 with t(n) as (select 2) select * from t;简单CTE WITH cte_name (col1, col2, ...) AS (cte_query_definition ) sql_statement;WITH 表示定义 CTE&#xff0c;因此 CTE 也称为 WITH 查询…

Pandas中at、iat函数详解

前言 嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 at 函数&#xff1a;通过行名和列名来取值&#xff08;取行名为a, 列名为A的值&#xff09; iat 函数&#xff1a;通过行号和列号来取值&#xff08;取第1行&#xff0c;第1列的值&#xff09; 本文给出at、iat常见的…

Mybatis-Plus-入门简介(2)

Mybatis-Plus-入门简介 1.简介 Mybatis-Plus官网&#xff1a;https://baomidou.com/ Mybatis-Plus仓库地址&#xff1a;https://mvnrepository.com/artifact/com.baomidou/mybatis-plus-boot-starter 仓库地址&#xff1a;仓库地址&#xff1a;https://gitee.com/long-xiaozhe…