Code Lab - 34

news2025/1/11 23:02:34

        GAT里面有一些地方看的不是太懂(GAT里Multi Attention的具体做法),暂时找了参考代码,留一个疑问


1. 一个通用的GNN Stack

import torch_geometric
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

class GNNStack(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):
        super(GNNStack, self).__init__()
        conv_model = self.build_conv_model(args.model_type)
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim))
        #assert(断言)  用于判断一个表达式,在表达式条件为 false 的时候触发异常
        assert (args.num_layers >= 1), 'Number of layers is not >=1'
        for l in range(args.num_layers-1):
            self.convs.append(conv_model(args.heads * hidden_dim, hidden_dim))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), 
            nn.Linear(hidden_dim, output_dim))

        self.dropout = args.dropout
        self.num_layers = args.num_layers
        self.emb = emb

    def build_conv_model(self, model_type):
        if model_type == 'GraphSage':
            return GraphSage
        elif model_type == 'GAT':
            # When applying GAT with num heads > 1, you need to modify the 
            # input and output dimension of the conv layers (self.convs),
            # to ensure that the input dim of the next layer is num heads
            # multiplied by the output dim of the previous layer.
            # HINT: In case you want to play with multiheads, you need to change the for-loop that builds up self.convs to be
            # self.convs.append(conv_model(hidden_dim * num_heads, hidden_dim)), 
            # and also the first nn.Linear(hidden_dim * num_heads, hidden_dim) in post-message-passing.
            return GAT

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout,training=self.training)

        x = self.post_mp(x)

        if self.emb == True:
            return x

        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)

2. 实现GraphSage和GAT

2.1 GraphSage

class GraphSage(MessagePassing):
    
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):  
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        # self.lin_l is the linear transformation that you apply to embedding for central node.
        self.lin_l=Linear(in_channels,out_channels)  #Wl
        # self.lin_r is the linear transformation that you apply to aggregated message from neighbors.
        self.lin_r=Linear(in_channels,out_channels)  #Wr

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        # 调用propagation函数进行消息传递:propagate(edge_index, x=(x_i, x_j), extra=(extra_i, extra_j), size=size)
        # 我们将只使用邻居节点(x_j)的表示,因此默认情况下我们为中心节点和邻居节点传递与x=(x,x)相同的表示
        out1 = self.lin_l(x)
        out2 = self.propagate(edge_index,x = (x,x),size = size)
        out2 = self.lin_r(out2)
        out = out1 + out2
        if self.normalize:
            out = F.normalize(out)
        return out
    
    # 供propagate调用,对于所有(i,j)边,构造从邻点j到中心点i的信息
    # x_j表示 所有邻点的特征嵌入矩阵  
    def message(self, x_j):
        out = x_j
        return out
    
    # 聚合邻居信息
    def aggregate(self, inputs, index, dim_size = None):
        # The axis along which to index number of nodes.
        node_dim = self.node_dim
        out = torch_scatter.scatter(inputs,index,node_dim,dim_size=dim_size,reduce='mean')
        return out

2.2 GAT

class GAT(MessagePassing):

    def __init__(self, in_channels, out_channels, heads = 2,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout
        
        # self.lin_l is the linear transformation that you apply to embeddings 
        # Pay attention to dimensions of the linear layers, since we're using multi-head attention.
        self.lin_l = Linear(in_channels,heads*out_channels)  #W_l  这里的in_channels就是已经乘过heads的数字
        self.lin_r = self.lin_l  #W_r
        # Define the attention parameters \overrightarrow{a_l/r}^T in the above intro.
        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size = None):
        H, C = self.heads, self.out_channels
        x_l = self.lin_l(x)
        x_r = self.lin_r(x)
        x_l = x_l.view(-1,H,C)
        x_r = x_r.view(-1,H,C)
        alpha_l = (x_l * self.att_l).sum(axis=1)  #*是逐元素相乘(每个特征对应的所有节点一样处理?)。sum的维度是H(聚合)。
        alpha_r = (x_r * self.att_r).sum(axis=1)
        out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r),size=size)
        out = out.view(-1, H * C)
        return out

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        #alpha:[E, C]
        alpha = alpha_i + alpha_j  #leakyrelu的对象
        alpha = F.leaky_relu(alpha,self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training).unsqueeze(1)  #[E,1,C]
        out = x_j * alpha  #通过计算得到的alpha来计算节点信息聚合值(得到h_i^')  #[E,H,C]
        return out

    def aggregate(self, inputs, index, dim_size = None):
        out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce='sum')
        return out

3. 训练

3.1 优化器

import torch.optim as optim

def build_optimizer(args, params):
    weight_decay = args.weight_decay
    filter_fn = filter(lambda p : p.requires_grad, params)
    if args.opt == 'adam':
        optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)
    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    return scheduler, optimizer

3.2 训练

import time

import networkx as nx
import numpy as np
import torch
import torch.optim as optim
from tqdm import trange
import pandas as pd
import copy

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.nn as pyg_nn

import matplotlib.pyplot as plt


def train(dataset, args):
    
    print("Node task. test set size:", np.sum(dataset[0]['test_mask'].numpy()))
    print()
    test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # build model
    model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, 
                            args)
    scheduler, opt = build_optimizer(args, model.parameters())

    # train
    losses = []
    test_accs = []
    best_acc = 0
    best_model = None
    for epoch in trange(args.epochs, desc="Training", unit="Epochs"):
        total_loss = 0
        model.train()
        for batch in loader:
            opt.zero_grad()
            pred = model(batch)
            label = batch.y
            pred = pred[batch.train_mask]
            label = label[batch.train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)
        losses.append(total_loss)

        if epoch % 10 == 0:
          test_acc = test(test_loader, model)
          test_accs.append(test_acc)
          if test_acc > best_acc:
            best_acc = test_acc
            best_model = copy.deepcopy(model)
        else:
          test_accs.append(test_accs[-1])
    
    return test_accs, losses, best_model, best_acc, test_loader

def test(loader, test_model, is_validation=False, save_model_preds=False, model_type=None):
    test_model.eval()

    correct = 0
    # Note that Cora is only one graph!
    for data in loader:
        with torch.no_grad():
            # max(dim=1) returns values, indices tuple; only need indices
            pred = test_model(data).max(dim=1)[1]
            label = data.y

        mask = data.val_mask if is_validation else data.test_mask
        # node classification: only evaluate on nodes in test set
        pred = pred[mask]
        label = label[mask]

        if save_model_preds:
          print ("Saving Model Predictions for Model Type", model_type)

          data = {}
          data['pred'] = pred.view(-1).cpu().detach().numpy()
          data['label'] = label.view(-1).cpu().detach().numpy()

          df = pd.DataFrame(data=data)
          # Save locally as csv
          df.to_csv('CORA-Node-' + model_type + '.csv', sep=',', index=False)
            
        correct += pred.eq(label).sum().item()

    total = 0
    for data in loader.dataset:
        total += torch.sum(data.val_mask if is_validation else data.test_mask).item()

    return correct / total
  
class objectview(object):
    def __init__(self, d):
        self.__dict__ = d
for args in [
    {'model_type': 'GraphSage', 'dataset': 'cora', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 500, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01},
]:
    args = objectview(args)
    for model in ['GraphSage']:
        args.model_type = model

        # Match the dimension.
        if model == 'GAT':
          args.heads = 2
        else:
          args.heads = 1

        if args.dataset == 'cora':
            dataset = Planetoid(root='/tmp/cora', name='Cora')
        else:
            raise NotImplementedError("Unknown dataset") 
        test_accs, losses, best_model, best_acc, test_loader = train(dataset, args) 

        print("Maximum test set accuracy: {0}".format(max(test_accs)))
        print("Minimum loss: {0}".format(min(losses)))

        # Run test for our best model to save the predictions!
        test(test_loader, best_model, is_validation=False, save_model_preds=True, model_type=model)
        print()

        plt.title(dataset.name)
        plt.plot(losses, label="training loss" + " - " + args.model_type)
        plt.plot(test_accs, label="test accuracy" + " - " + args.model_type)
    plt.legend()
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/945469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

package.json 详解

文章目录 package.json1. name2. version3. description4. homepage5. bugs6. license7. author, contributors8. funding9. files10. main11. module12. browser13. bin14. man15. directories15.1 directories.bin15.2 directories.man 16. repository17. scripts18. config1…

Unity引擎修改模型顶点色的工具

大家好,我是阿赵。   之前分享过怎样通过MaxScript在3DsMax里面修改模型的顶点色。不过由于很多时候顶点色的编辑需要根据在游戏引擎里面的实际情况和shader的情况来动态调整,所以如果能在引擎里面直接修改模型的顶点色,将会方便很多。于是…

指向任意节点的带环链表

🌈图示指向任意节点的带环链表 如图: 🌈快慢指针法判断链表是否带环 🌟思路:快指针fast一次走2步,慢指针slow一次走1步,fast先进环在换中运动,随后slow进入环。两指针每同时移动…

复原20世纪复古修仙游戏

前言 在本教程中,我突发奇想,想做一个复古的修仙游戏,考虑到以前的情怀决定做个古老的躺平修仙游戏 📝个人主页→数据挖掘博主ZTLJQ的主页 个人推荐python学习系列: ☄️爬虫JS逆向系列专栏 - 爬虫逆向教学 ☄️python…

英特尔Raptor Lake Refresh第14代CPU:传闻发布日期、价格、规格等

英特尔预计将在今年秋天推出第14代Raptor Lake-S Refresh CPU。虽然即将推出的系列芯片沿用了当前的第13代英特尔核心系列,但它们实际上是相同CPU的更新版本。 Raptor Lake-s Refresh芯片没有任何官方消息,但几次所谓的泄露让我们了解了我们可能会期待什…

发光太阳聚光器的蒙特卡洛光线追踪研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

C++字符串详解

C 大大增强了对字符串的支持,除了可以使用C风格的字符串,还可以使用内置的 string 类。string 类处理起字符串来会方便很多,完全可以代替C语言中的字符数组或字符串指针。 string 是 C 中常用的一个类,它非常重要,我们…

uniapp 项目实践总结(一)uniapp 框架知识总结

导语:最近开发了一个基于 uniapp 框架的项目,有一些感触和体会,所以想记录以下一些技术和经验,在这里做一个系列总结,算是对自己做一个交代吧。 目录 简介全局文件全局组件常用 API条件编译插件开发 简介 uniapp 是…

Matlab(变量与文本读取)

目录 1.变量(数据)类型转换 1.1 字符 1.2 字符串 1.3 逻辑操作与赋值 2.Struct结构体数组 2.1函数的详细介绍: 2.1.1 cell2struct 2.1.1.1 垂直维度转换 2.1.1.2 水平维度转换 2.1.1.3 部分进行转换 2.1.2 rmfield 2.1.3 fieldnames(查…

Git分布式版本控制系统与github

第四阶段提升 时 间:2023年8月29日 参加人:全班人员 内 容: Git分布式版本控制系统与github 目录 一、案例概述 二、版本控制系统 (一) 本地版本控制 (二)集中化的版本控制系统 &…

C++day7(auto关键字、lambda表达式、C++中的数据类型转换、C++标准模板库(STL)、list、文件操作)

一、Xmind整理&#xff1a; 关键词总结&#xff1a; 二、上课笔记整理&#xff1a; 1.auto关键字 #include <iostream>using namespace std;int fun(int a, int b, float *c, char d, double *e,int f) {return 12; }int main() {//定义一个函数指针&#xff0c;指向fu…

【USRP】集成化仪器系列1 :信号源,基于labview实现

USRP 信号源 1、设备IP地址&#xff1a;默认为192.168.10.2&#xff0c;请勿 修改&#xff0c;运行阶段无法修改。 2、天线输出端口是TX1&#xff0c;请勿修改。 3、通道&#xff1a;0 对应RF A、1 对应 RF B&#xff0c;运行 阶段无法修改。 4、中心频率&#xff1a;当需要…

自然语言处理-NLP

目录 自然语言处理-NLP 致命密码&#xff1a;一场关于语言的较量 自然语言处理的发展历程 兴起时期 符号主义时期 连接主义时期 深度学习时期 自然语言处理技术面临的挑战 语言学角度 同义词问题 情感倾向问题 歧义性问题 对话/篇章等长文本处理问题 探索自然语言…

腾讯云学生免费服务器如何申请?

腾讯云学生免费服务器如何申请?学生机申请流程&#xff0c;腾讯云学生服务器优惠活动&#xff1a;轻量应用服务器2核2G学生价30元3个月、58元6个月、112元一年&#xff0c;轻量应用服务器4核8G配置191.1元3个月、352.8元6个月、646.8元一年&#xff0c;CVM云服务器2核4G配置84…

老年人跌倒智能识别算法 opencv

老年人跌倒智能识别算法通过opencvpython深度学习算法框架模型&#xff0c;老年人跌倒智能识别算法能够及时发现老年人跌倒情况&#xff0c;提供快速的援助和救援措施&#xff0c;保障老年人的安全。Python是一种由Guido van Rossum开发的通用编程语言&#xff0c;它很快就变得…

读书笔记——《万物有灵》

前言 上一本书是《走出荒野》&#xff0c;太平洋步道女王提到了这本书《万物有灵》&#xff0c;她同样是看一点撕一点的阅读。我想&#xff0c;在她穿越山河森林&#xff0c;听见鸟鸣溪流的旅行过程中&#xff0c;是不是看这本描写动物有如何聪明的书——《万物有灵》&#xf…

完善区域企业监测预警机制,助推区域产业可持续发展

“五度易链”产业大数据解决方案由产业经济、智慧招商、企业服务、数据服务四大应用解决方案组成&#xff0c;囊括了产业经济监测、产业诊断分析、企业监测预警、企业综合评估、大数据精准招商、招商智能管理、企业管理、企业培育、企业市场服务、企业金融服务、产业数据开放服…

流程解决方案公司:用低代码技术平台实现流程化办公!

很多粉丝朋友会询问道可以用什么样的软件平台实现流程化办公。作为提供流程解决方案公司&#xff0c;流辰信息专业研发低代码技术平台&#xff0c;并且一直保持自主研发的奋斗心态&#xff0c;针对不同行业的特性&#xff0c;提供专属的框架定制服务&#xff0c;为客户朋友实现…

微信小程序左上角home图标的解决方法之一 层级混乱导致的home图标显示的问题 自定义左上角左侧图标的返回路径

这个项目的编辑页在tabbar上 导致跳到tabbar得使用wx.switchTab 保存后返回原来的页面就出现了左上角的home图标 本来想通过自定义home图标的跳转路径来解决这个问题 没想到居然找不到相关内容 有清楚的朋友麻烦给我留个言不胜感激 那我写一下我的骚操作 app.js globalData: {…

移动端和PC端对比【组件库+调试vconsole +单位postcss-pxtorem+构建vite/webpack+可视化echarts/antv】

目录 组件库 移动端 vue vant PC端 react antd vue element 调试&#xff1a;vconsole vs dev tools中的控制台&#xff08;Console&#xff09; ​​​​​​​vconsole&#xff1a;在真机上调试 postcss-pxtorem&#xff1a;移动端不同的像素密度 构建工具 web…