2024.6.16周报

news2025/1/11 4:11:06

目录

摘要

ABSTRACT

一、文献阅读

一、题目

二、摘要

三、创新点

四、模型架构

五、文章解读

1、Introduction

2、实验

3、结论

二、代码复现

1、模型代码

2、实验结果 

三、总结


摘要

本周我阅读了一篇题目为《Contaminant Transport Modeling and Source Attribution With Attention‐Based Graph Neural Network》的论文,这篇论文引入了一种新的基于注意力的图神经网络(aGNN),专门用于在有限监测数据下模拟污染物迁移并量化污染源及其传播之间的因果关系。此外,aGNN的解释性分析能有效量化每个污染源的影响,证实了其在地下污染物运移研究中的高效性和减少计算成本的能力,为地下水管理提供了一个有力的工具。通过复现其代码,对模型的架构有了更深刻的理解。

ABSTRACT

This week, I rear a paper titled "Contaminant Transport Modeling and Source Attribution With Attention-Based Graph Neural Network" . In the paper, a new attention-based graph neural network (aGNN) was introduced, which is specifically designed to simulate contaminant migration under limited monitoring data and to quantify the causal relationships between pollutant sources and their propagation. Moreover, the interpretative analysis of aGNN was shown to effectively quantify the impact of each pollution source, confirming its efficiency in studies of subsurface contaminant migration and its ability to reduce computational costs, providing a powerful tool for groundwater management. By reproducing its code, a deeper understanding of the model's architecture was gained.

一、文献阅读

一、题目

题目:Contaminant Transport Modeling and Source Attribution With Attention‐Based Graph Neural Network

期刊:Water Resources Research

链接:https://doi.org/10.1029/2023WR035278

二、摘要

文章引用了一种称为基于注意力的图神经网络(aGNN)的新型机器学习模型,该模型旨在使用稀疏的监测数据对污染物传输进行建模,并分析污染物源与特定位置观测到的浓度之间的因果关系。文章在具有不同检测设置的不同含水层系统中进行了五个综合案例研究,其中aGNN表现最佳;此外,aGNN的解释性分析有效地量化了每个污染物源的影响,总结来说,这篇论文将aGNN确立为一种高效而稳健的地下污染物迁移复杂时空学习方法,它也成为地下水管理和污染源识别的一个重要工具。

The article employs a novel machine learning model known as Attention-based Graph Neural Networks (aGNN), which is designed to model the transport of pollutants using sparse monitoring data and to analyze the causal relationships between pollutant sources and the concentrations observed at specific locations. Five comprehensive case studies were conducted in various aquifer systems with different detection setups, where the aGNN demonstrated superior performance. Furthermore, the interpretability analysis of aGNN effectively quantified the impact of each pollutant source. In summary, this paper establishes aGNN as an efficient and robust method for complex spatiotemporal learning of subsurface pollutant migration, making it a significant tool for groundwater management and pollutant source identification.

三、创新点

(1)该文提出一种基于图的深度学习方法,用于模拟受监测数据约束的污染物迁移;

(2)所提出的模型量化了每个潜在污染源对任意位置观测浓度的贡献;

(3)与基于物理的污染物传输模型相比,深度学习方法大大降低了计算成本;

四、模型架构

使用深度学习和基于物理的模型(MODFLOW和MT3DMS)两种方法进行污染物传输建模的工作流程和数据概述。这些模型在三个任务中进行评估:转导学习、归纳学习和模型解释。

图1展示了使用深度学习(DL)方法和基于物理的模型来模拟地下水质量对多源污染排放的时空响应。深度学习模型,如aGNN、CNN和RNN,不需详细的水文地质信息,而物理模型如MODFLOW和MT3DMS则依赖这些数据。DL模型通过端到端学习,整合MODFLOW和MT3DMS的功能,处理水排放、污染物释放及其浓度和地下水位的数据。文章还评估了这些模型在转导学习、归纳学习和模型解释方面的效果,特别是通过Shapley值方法来分析和量化多点污染源的影响,以提供地下水管理和污染源识别的见解。

图2展示了aGNN的体系结构,这是一个基于编码器-解码器框架的系统。该体系由五个主要模块组成:

1、输入模块

(1)编码器输入和解码器输入:这两个模块负责构建节点的特征向量(包括监测点的污染物浓度、流量动态等)、空间信息以及邻接矩阵。编码器输入通常设计过去的时间步骤;解码器输入则关注未来的时间步骤。

(2)图嵌入模块:空间嵌入:通过对节点的地理位置或其他空间属性进行编码,捕捉节点间的空间关系。

                                时间嵌入:将时间信息转换为嵌入表示,使模型能够捕捉时间变化的模式和趋势。

时间嵌入可以使用时间顺序信息,给定一个时间序列S =(s0,s1,…,sT),时间嵌入层形成一个有限维表示来表示si在序列S中的位置。研究中的时间嵌入是将正弦变换串联到时间顺序,形成矩阵TE\in\mathbb{R}^{T\times d_{emb}},其中T和d_{emb}分别是时间长度和向量维数。TE设计为式2和式3,
其中2d和2d + 1分别表示偶数维和奇数维,t为时间序列中的时间位置。时间嵌入的维数为demb
3所示,时间嵌入中的每个元素都结合了时间顺序位置和特征空间的信息。

(3)编码器模块

 查询(Q),键(K)和值(V)。其思想是将Q和一组K‐V对映射到输出,使输出表示V的加权和。权重由相应的K和Q决定,然后应用Softmax函数对权重值进行归一化。

多头自注意力机制(MSA):允许模型在处理每个节点的特征时,同时考虑多种不同的解释和侧重点,从而更好地理解数据中的复杂模式。

 Q与解码器输入相关,K和V与编码器生成的隐藏特征相关。MSA特别关注自我注意机制,该机制适用于与自身交互的输入,在数学上,Q、K和V采用相同的原始输入(如公式6中的Xq=Xk=Xv)。MSA允许模型捕获输入序列中的不同方面和依赖关系,从而对特征元素之间的关系提供更全面的理解。

图卷积网络(GCN):通过在图结构中传播和更新节点信息,学习节点的特征表示。GCN通过使用节点及其邻居的信息,增强了模型对整个网络结构的理解。

在MSA阻塞后,GCN通过图结构在节点之间交换信息来提取中间表示,从而对空间依赖关系进行建模。GCNs使用图卷积过滤器,设计用于建模节点依赖关系。GCN的主要思想是构建一个消息传递网络,其中信息沿着图内相邻节点传播。

多头注意(MTA):MAT将信息从编码器传输到解码器。MAT作为编码器和解码器之间的链接。编码器的堆叠输出作为V和K传递给MAT,并将注意力分数分配给解码器输入的表示(即Q)。解码器中的MSA和GCN进行类似于机器翻译任务的学习过程,其中,解码器输入表示需要翻译成另一种“语言”的“一种语言中的句子”。

(4)解码器模块

与编码器结构相似,解码器同样包括多头自注意力机制和GCN层。不同的是,解码器更侧重于使用编码器的输出(隐藏状态)来生成对未来状态的预测。

(5)输出模块

最终生成的是目标序列预测,如污染物在未来某一时间点在地下水中的预期移动。

五、文章解读

1、Introduction

文章提出了aGNN,一种新的基于注意力的图神经建模框架,它结合了(a)图卷积网络(GCN)、(b)注意力机制和(c)嵌入层来模拟地下水系统中的污染物输送过程。GCN通过消息通过节点和边缘提取图信息,有效学习空间模式。注意机制是变压器网络中擅长序列分析的关键组成部分。嵌入层是潜在空间学习机制,代表了时空过程中的高维性。对交通和行人轨迹的研究表明,基于注意力的图神经网络在单过程时空预测任务中表现出竞争性的表现。在本研究中,作者将其应用扩展到学习地下水流动和溶质输送问题中的多个过程。此外,在尚未研究的未监测污染位置,采用了新的坐标嵌入方法进行归纳学习。本研究的目标有三个方面。首先研究了aGNN在涉及污染物迁移建模的多过程中的性能。基于GNN、CNN和LSTM的方法适用于多步空间预测的端到端学习任务,以深入了解每个模型的执行情况。其次,根据数据的可用性和含水层的非均质性,评估了aGNN将从监测数据中学习到的知识通过归纳学习转移到未监测站点的能力。第三,采用了一种可解释的人工智能技术,即沙普利值,它起源于合作博弈论的概念。

2、实验

1、研究区域

本研究设计了两个采用非承压含水层的合成研究场地,用于方法的开发和验证。第一个研究场地面积为497,500平方米,通过MODFLOW划分为30列和15行的网格,每个网格50米x50米。研究场地设置了两侧无通量边界和两侧恒定水头边界(分别为100米和95米)。为了研究水力传导率异质性对污染物传输模型的影响,考虑了两种水力传导率场景:一种是五个不同区域的水力传导率从15到35米/天变化;另一种是水力传导率从0到50米/天变化。污染物传输在MT3DMS中以30米的均匀纵向分散性进行模拟,并设置了三个间歇性排放污染水的注水井。第二个研究场地(场景C)覆盖面积180平方公里,是第一个场地的约360倍,划分为120列和150行的网格,每个网格100米x100米,并设有四个区域的水力传导率从30到350米/天变化。两个场地都设置了监测系统,包括水位下降和污染物浓度的日常数据记录。本研究还考察了三个水力传导率场的三个监测网络,观察它们对污染物移动反应的学习过程如何受到数据大小的影响。

2、实验准备

使用MODFLOW和MT3DMS模拟生成污染物运输数据集,并用于训练和评估不同的深度学习模型。数据集中80%用于训练,20%用于性能评估。所有DL模型均通过批量优化进行训练,批次大小为16,迭代400个周期,模型输出观测位置的地下水位降低(GD)和污染物浓度(CC)的预测,预测时域为50时间步。

上表为三种检测网络中不同算法的输入维度和参数数量

3、实验结果

四种DL模型:DCRNN、aGNN、aGNN-noE(无嵌入模块的aGNN变体)和ConvLSTM。这些模型都使用编解码器框架,但在输入设计上有所不同。输入特征包括静态特征(S)、历史行为(H)和计划特征(F)。静态特征代表坐标信息,历史行为详细记录了地下水排放和污染物释放的两个计划及监测的GD和CC,计划特征包含预测期内的地下水排放和污染物释放计划。

上图(a)、(b)、(c)显示的是传导学习中,三个检测网络中,污染源及其邻居具有较大的节点强度。

上图(d)、(e)、(f)显示的是归纳学习M1、M2、M3的预测区域。 

表2总结了整个数据集的统计特征,并按80/20的比例划分为训练和测试集。结果显示aGNN在五种不同情况下的测试性能。CC的变异范围是GD的五倍,表现出更高的分散性。本研究将多目标任务(涉及GD和CC)转化为单目标,使用加权和方法,CC权重为5,GD权重为1。此外,含水层非均质性对GD的影响较小,与CC相比,水头对电导率的非均质性敏感度较低。场景C中,由于场地更大且监测井更少,所有模型的精度均有所下降。在所有算法中,aGNN在几乎所有五种情况下均获得最低RMSE和最高R^{2}(表2),表明其在模拟非均匀分布监测系统中污染物迁移方面的性能优于其他算法。

 图7展示了四种模型的预测误差。ConvLSTM在空间上的RMSE较高,通常超过1 mg/L,而DCRNN的RMSE普遍低于0.3 mg/L,尤其在A-M1、B-M1、A-M2和B-M2区域。aGNN和aGNN-noE的性能优于DCRNN,显示更小的RMSE波动,这证明了基于注意力的图卷积网络的优势。aGNN在所有模型中展示了最小的RMSE变化,突显出其在捕捉空间变化,尤其是在污染源下游区域的高效性。此外,研究还使用了相对绝对误差(RAE)来测量预测值与真实值之间的差异,发现使用aGNN时RAE降低。

3、结论

本研究开发了一种新型数据驱动模型aGNN,用于模拟非均质地下水含水层中的污染物传输,特别强调数据有限且分布不均的情况。aGNN模型结合了注意力机制、时空嵌入层和图卷积网络层(GCN),优化了污染物传输的时空学习精度,通过动态权重分配、特征转换和信息传递提高模型效率。实验结果显示,aGNN在预测精度上达到了99%的R^{2}值,证明了其高效的预测能力。此外,aGNN能够利用图学习从监测地点提供的数据推断未监测地点的观测,即使在监测井有限的大型场地或高度非均质的含水层中也能有效捕捉污染物的时空变化。aGNN还通过SHAP方法分析污染源归因,展示了其作为数值模拟模型MODFLOW和MT3DMS的有效替代品。此方法大幅减轻了基于物理模型的计算负担,特别是在需要处理大量注入井和长期管理的情景中,显著提高了计算效率。

二、代码复现

1、模型代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import numpy as np
from utils_dpl3_contam import norm_Adj


class RBF(nn.Module):
    """
    Transforms incoming data using a given radial basis function:
    u_{i} = rbf(||x - c_{i}|| / s_{i})
    Arguments:
        in_features: size of each input sample
        out_features: size of each output sample
    Shape:
        - Input: (N, in_features) where N is an arbitrary batch size
        - Output: (N, out_features) where N is an arbitrary batch size
    Attributes:
        centres: the learnable centres of shape (out_features, in_features).
            The values are initialised from a standard normal distribution.
            Normalising inputs to have mean 0 and standard deviation 1 is
            recommended.

        log_sigmas: logarithm of the learnable scaling factors of shape (out_features).

        basis_func: the radial basis function used to transform the scaled
            distances.
    """

    def __init__(self, in_features, out_features, num_vertice,basis_func):
        super(RBF, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.centres1 = nn.Parameter(torch.Tensor(num_vertice, self.in_features))  # (out_features, in_features)
        self.alpha = nn.Parameter(torch.Tensor(num_vertice,out_features))
        self.log_sigmas = nn.Parameter(torch.Tensor(out_features))
        self.basis_func = basis_func
        self.reset_parameters()


        # self.alpha1 = nn.Parameter(torch.Tensor(num_vertice, self.out_features))
    def reset_parameters(self):
        nn.init.normal_(self.centres1, 0, 1)
        nn.init.constant_(self.log_sigmas, 0)

    def forward(self, input):

        size1= (input.size(0), input.size(0), self.in_features)
        x1 = input.unsqueeze(1).expand(size1)
        c1 = self.centres1.unsqueeze(0).expand(size1)
        distances1 = torch.matmul((x1 - c1).pow(2).sum(-1).pow(0.5),self.alpha) / torch.exp(self.log_sigmas).unsqueeze(0)
        return self.basis_func(distances1) #distances1


# RBFs

def gaussian(alpha):
    phi = torch.exp(-1 * alpha.pow(2))
    return phi


def linear(alpha):
    phi = alpha
    return phi


def quadratic(alpha):
    phi = alpha.pow(2)
    return phi


def inverse_quadratic(alpha):
    phi = torch.ones_like(alpha) / (torch.ones_like(alpha) + alpha.pow(2))
    return phi


def multiquadric(alpha):
    phi = (torch.ones_like(alpha) + alpha.pow(2)).pow(0.5)
    return phi


def inverse_multiquadric(alpha):
    phi = torch.ones_like(alpha) / (torch.ones_like(alpha) + alpha.pow(2)).pow(0.5)
    return phi


def spline(alpha):
    phi = (alpha.pow(2) * torch.log(alpha + torch.ones_like(alpha)))
    return phi


def poisson_one(alpha):
    phi = (alpha - torch.ones_like(alpha)) * torch.exp(-alpha)
    return phi


def poisson_two(alpha):
    phi = ((alpha - 2 * torch.ones_like(alpha)) / 2 * torch.ones_like(alpha)) \
          * alpha * torch.exp(-alpha)
    return phi


def matern32(alpha):
    phi = (torch.ones_like(alpha) + 3 ** 0.5 * alpha) * torch.exp(-3 ** 0.5 * alpha)
    return phi


def matern52(alpha):
    phi = (torch.ones_like(alpha) + 5 ** 0.5 * alpha + (5 / 3) \
           * alpha.pow(2)) * torch.exp(-5 ** 0.5 * alpha)
    return phi


def basis_func_dict():
    """
    A helper function that returns a dictionary containing each RBF
    """

    bases = {'gaussian': gaussian,
             'linear': linear,
             'quadratic': quadratic,
             'inverse quadratic': inverse_quadratic,
             'multiquadric': multiquadric,
             'inverse multiquadric': inverse_multiquadric,
             'spline': spline,
             'poisson one': poisson_one,
             'poisson two': poisson_two,
             'matern32': matern32,
             'matern52': matern52}
    return bases
###############################################################################################################

def clones(module, N):
    '''
    Produce N identical layers.
    :param module: nn.Module
    :param N: int
    :return: torch.nn.ModuleList
    '''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def subsequent_mask(size):
    '''
    mask out subsequent positions.
    :param size: int
    :return: (1, size, size)
    '''
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0   # 1 means reachable; 0 means unreachable


class spatialGCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels):
        super(spatialGCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))  # (b*t,n,f_in)

        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix, x)).reshape((batch_size, num_of_timesteps, num_of_vertices, self.out_channels)).transpose(1, 2))


class GCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels):
        super(GCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, F_in)
        :return: (batch_size, N, F_out)
        '''
        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix, x)))  # (N,N)(b,N,in)->(b,N,in)->(b,N,out)


class Spatial_Attention_layer(nn.Module):
    '''
    compute spatial attention scores
    '''
    def __init__(self, dropout=.0):
        super(Spatial_Attention_layer, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        '''
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, T, N, N)
        '''
        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))  # (b*t,n,f_in)

        score = torch.matmul(x, x.transpose(1, 2)) / math.sqrt(in_channels)  # (b*t, N, F_in)(b*t, F_in, N)=(b*t, N, N)

        score = self.dropout(F.softmax(score, dim=-1))  # the sum of each row is 1; (b*t, N, N)

        return score.reshape((batch_size, num_of_timesteps, num_of_vertices, num_of_vertices))


class spatialAttentionGCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels, dropout=.0):
        super(spatialAttentionGCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)
        self.SAt = Spatial_Attention_layer(dropout=dropout)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''

        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        spatial_attention = self.SAt(x)  # (batch, T, N, N)

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))  # (b*t,n,f_in)

        spatial_attention = spatial_attention.reshape((-1, num_of_vertices, num_of_vertices))  # (b*T, n, n)

        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix.mul(spatial_attention), x)).reshape((batch_size, num_of_timesteps, num_of_vertices, self.out_channels)).transpose(1, 2))
        # (b*t, n, f_in)->(b*t, n, f_out)->(b,t,n,f_out)->(b,n,t,f_out)


class spatialAttentionScaledGCN(nn.Module):
    def __init__(self, sym_norm_Adj_matrix, in_channels, out_channels, dropout=.0):
        super(spatialAttentionScaledGCN, self).__init__()
        self.sym_norm_Adj_matrix = sym_norm_Adj_matrix  # (N, N)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.Theta = nn.Linear(in_channels, out_channels, bias=False)
        self.SAt = Spatial_Attention_layer(dropout=dropout)

    def forward(self, x):
        '''
        spatial graph convolution operation
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        batch_size, num_of_vertices, num_of_timesteps, in_channels = x.shape

        spatial_attention = self.SAt(x) / math.sqrt(in_channels)  # scaled self attention: (batch, T, N, N)

        x = x.permute(0, 2, 1, 3).reshape((-1, num_of_vertices, in_channels))
        # (b, n, t, f)-permute->(b, t, n, f)->(b*t,n,f_in)

        spatial_attention = spatial_attention.reshape((-1, num_of_vertices, num_of_vertices))  # (b*T, n, n)

        return F.relu(self.Theta(torch.matmul(self.sym_norm_Adj_matrix.mul(spatial_attention), x)).reshape((batch_size, num_of_timesteps, num_of_vertices, self.out_channels)).transpose(1, 2))
        # (b*t, n, f_in)->(b*t, n, f_out)->(b,t,n,f_out)->(b,n,t,f_out)



class SpatialPositionalEncoding_RBF(nn.Module):
    def __init__(self, d_model, logitudelatitudes,num_of_vertices, dropout, gcn=None, smooth_layer_num=0):
        super(SpatialPositionalEncoding_RBF, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # self.embedding = torch.nn.Embedding(num_of_vertices, d_model)
        self.embedding = RBF(2, d_model, num_of_vertices,quadratic) # gaussin nn.Linear(4, d_model-4)
        self.logitudelatitudes = logitudelatitudes
        self.gcn_smooth_layers = None
        if (gcn is not None) and (smooth_layer_num > 0):
            self.gcn_smooth_layers = nn.ModuleList([gcn for _ in range(smooth_layer_num)])

    def forward(self, x,log1,lat1):
        '''
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        # x,log,lat,t= x[0],x[1],x[2],x[3]
        batch, num_of_vertices, timestamps, _ = x.shape
        x_indexs = torch.concat((torch.unsqueeze(log1.mean(0).mean(-1),-1),torch.unsqueeze(lat1.mean(0).mean(-1),-1)),-1)# (N,)

        x_ind = torch.concat((
                              x_indexs[:, 0:1] ,
                              x_indexs[:, 1:] )
                             , axis=1)

        embed = self.embedding(x_ind.float()).unsqueeze(0)
        if self.gcn_smooth_layers is not None:
            for _, l in enumerate(self.gcn_smooth_layers):
                embed = l(embed)  # (1,N,d_model) -> (1,N,d_model)
        x = x + embed.unsqueeze(2)  # (B, N, T, d_model)+(1, N, 1, d_model)

        return self.dropout(x)


class TemporalPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len, lookup_index=None):
        super(TemporalPositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)
        self.lookup_index = lookup_index
        self.max_len = max_len
        # computing the positional encodings once in log space
        pe = torch.zeros(max_len, d_model)
        for pos in range(max_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i+1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

        pe = pe.unsqueeze(0).unsqueeze(0)  # (1, 1, T_max, d_model)
        self.register_buffer('pe', pe)
        # register_buffer:
        # Adds a persistent buffer to the module.
        # This is typically used to register a buffer that should not to be considered a model parameter.

    def forward(self, x,t):
        '''
        :param x: (batch_size, N, T, F_in)
        :return: (batch_size, N, T, F_out)
        '''
        if self.lookup_index is not None:
            x = x + self.pe[:, :, self.lookup_index, :]  # (batch_size, N, T, F_in) + (1,1,T,d_model)
        else:
            x = x + self.pe[:, :, :x.size(2), :]

        return self.dropout(x.detach())


class SublayerConnection(nn.Module):
    '''
    A residual connection followed by a layer norm
    '''
    def __init__(self, size, dropout, residual_connection, use_LayerNorm):
        super(SublayerConnection, self).__init__()
        self.residual_connection = residual_connection
        self.use_LayerNorm = use_LayerNorm
        self.dropout = nn.Dropout(dropout)
        if self.use_LayerNorm:
            self.norm = nn.LayerNorm(size)

    def forward(self, x, sublayer):
        '''
        :param x: (batch, N, T, d_model)
        :param sublayer: nn.Module
        :return: (batch, N, T, d_model)
        '''
        if self.residual_connection and self.use_LayerNorm:
            return x + self.dropout(sublayer(self.norm(x)))
        if self.residual_connection and (not self.use_LayerNorm):
            return x + self.dropout(sublayer(x))
        if (not self.residual_connection) and self.use_LayerNorm:
            return self.dropout(sublayer(self.norm(x)))


class PositionWiseGCNFeedForward(nn.Module):
    def __init__(self, gcn, dropout=.0):
        super(PositionWiseGCNFeedForward, self).__init__()
        self.gcn = gcn
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''
        :param x:  (B, N_nodes, T, F_in)
        :return: (B, N, T, F_out)
        '''
        return self.dropout(F.relu(self.gcn(x)))


def attention(query, key, value, mask=None, dropout=None):
    '''
    :param query:  (batch, N, h, T1, d_k)
    :param key: (batch, N, h, T2, d_k)
    :param value: (batch, N, h, T2, d_k)
    :param mask: (batch, 1, 1, T2, T2)
    :param dropout:
    :return: (batch, N, h, T1, d_k), (batch, N, h, T1, T2)
    '''
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # scores: (batch, N, h, T1, T2)

    if mask is not None:
        scores = scores.masked_fill_(mask == 0, -1e9)  # -1e9 means attention scores=0
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # p_attn: (batch, N, h, T1, T2)

    return torch.matmul(p_attn, value), p_attn  # (batch, N, h, T1, d_k), (batch, N, h, T1, T2)


class MultiHeadAttention(nn.Module):
    def __init__(self, nb_head, d_model, dropout=.0):
        super(MultiHeadAttention, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask: (batch, T, T)
        :return: x: (batch, N, T, d_model)
        '''
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        query, key, value = [l(x).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3) for l, x in
                             zip(self.linears, (query, key, value))]

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class MultiHeadAttentionAwareTemporalContex_qc_kc(nn.Module):  # key causal; query causal;
    def __init__(self, nb_head, d_model, num_of_lags, points_per_lag, kernel_size=3, dropout=.0):
        '''
        :param nb_head:
        :param d_model:
        :param num_of_weeks:
        :param num_of_days:
        :param num_of_hours:
        :param points_per_hour:
        :param kernel_size:
        :param dropout:
        '''
        super(MultiHeadAttentionAwareTemporalContex_qc_kc, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 2)  # 2 linear layers: 1  for W^V, 1 for W^O
        self.padding = kernel_size - 1
        self.conv1Ds_aware_temporal_context = clones(nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.padding)), 2)  # # 2 causal conv: 1  for query, 1 for key
        self.dropout = nn.Dropout(p=dropout)
        self.n_length = num_of_lags * points_per_lag


    def forward(self, query, key, value, mask=None, query_multi_segment=False, key_multi_segment=False):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask:  (batch, T, T)
        :param query_multi_segment: whether query has mutiple time segments
        :param key_multi_segment: whether key has mutiple time segments
        if query/key has multiple time segments, causal convolution should be applied separately for each time segment.
        :return: (batch, N, T, d_model)
        '''

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # deal with key and query: temporal conv
        # (batch, N, T, d_model)->permute(0, 3, 1, 2)->(batch, d_model, N, T) -conv->(batch, d_model, N, T)-view->(batch, h, d_k, N, T)-permute(0,3,1,4,2)->(batch, N, h, T, d_k)

        if query_multi_segment and key_multi_segment:
            query_list = []
            key_list = []
            if self.n_length > 0:
                query_h, key_h = [l(x.permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query[:, :, self.w_length + self.d_length:self.w_length + self.d_length + self.h_length, :], key[:, :, self.w_length + self.d_length:self.w_length + self.d_length + self.h_length, :]))]
                query_list.append(query_h)
                key_list.append(key_h)

            query = torch.cat(query_list, dim=3)
            key = torch.cat(key_list, dim=3)

        elif (not query_multi_segment) and (not key_multi_segment):

            query, key = [l(x.permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query, key))]

        elif (not query_multi_segment) and (key_multi_segment):

            query = self.conv1Ds_aware_temporal_context[0](query.permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

            key_list = []

            if self.n_length > 0:
                key_h = self.conv1Ds_aware_temporal_context[1](key[:, :,0:self.n_length, :].permute(0, 3, 1, 2))[:, :, :, :-self.padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
                key_list.append(key_h)

            key = torch.cat(key_list, dim=3)

        else:
            import sys
            print('error')
            sys.out

        # deal with value:
        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        value = self.linears[0](value).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3)

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class MultiHeadAttentionAwareTemporalContex_q1d_k1d(nn.Module):  # 1d conv on query, 1d conv on key
    def __init__(self, nb_head, d_model, num_of_lags, points_per_lag,  kernel_size=3, dropout=.0): #num_of_weeks, num_of_days, num_of_hours

        super(MultiHeadAttentionAwareTemporalContex_q1d_k1d, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 2)  # 2 linear layers: 1  for W^V, 1 for W^O
        self.padding = (kernel_size - 1)//2

        self.conv1Ds_aware_temporal_context = clones(
            nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.padding)),
            2)  # # 2 causal conv: 1  for query, 1 for key

        self.dropout = nn.Dropout(p=dropout)
        self.n_length = num_of_lags * points_per_lag  #num_of_hours * points_per_hour


    def forward(self, query, key, value, mask=None, query_multi_segment=False, key_multi_segment=False):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask:  (batch, T, T)
        :param query_multi_segment: whether query has mutiple time segments
        :param key_multi_segment: whether key has mutiple time segments
        if query/key has multiple time segments, causal convolution should be applied separately for each time segment.
        :return: (batch, N, T, d_model)
        '''

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # deal with key and query: temporal conv
        # (batch, N, T, d_model)->permute(0, 3, 1, 2)->(batch, d_model, N, T) -conv->(batch, d_model, N, T)-view->(batch, h, d_k, N, T)-permute(0,3,1,4,2)->(batch, N, h, T, d_k)

        if query_multi_segment and key_multi_segment:
            query_list = []
            key_list = []
            if self.n_length > 0:
                query_h, key_h = [l(x.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query[:, :,0: self.n_length, :], key[:, :, 0: self.n_length, :]))]
                query_list.append(query_h)
                key_list.append(key_h)

            query = torch.cat(query_list, dim=3)
            key = torch.cat(key_list, dim=3)

        elif (not query_multi_segment) and (not key_multi_segment):

            query, key = [l(x.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2) for l, x in zip(self.conv1Ds_aware_temporal_context, (query, key))]

        elif (not query_multi_segment) and (key_multi_segment):

            query = self.conv1Ds_aware_temporal_context[0](query.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

            key_list = []

            if self.n_length > 0:
                key_h = self.conv1Ds_aware_temporal_context[1](key[:, :, 0:self.n_length, :].permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
                key_list.append(key_h)

            key = torch.cat(key_list, dim=3)

        else:
            import sys
            print('error')
            sys.out

        # deal with value:
        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        value = self.linears[0](value).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3)

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class MultiHeadAttentionAwareTemporalContex_qc_k1d(nn.Module):  # query: causal conv; key 1d conv
    def __init__(self, nb_head, d_model, num_of_lags, points_per_lag,  kernel_size=3, dropout=.0):
        super(MultiHeadAttentionAwareTemporalContex_qc_k1d, self).__init__()
        assert d_model % nb_head == 0
        self.d_k = d_model // nb_head
        self.h = nb_head
        self.linears = clones(nn.Linear(d_model, d_model), 2)  # 2 linear layers: 1  for W^V, 1 for W^O
        self.causal_padding = kernel_size - 1
        self.padding_1D = (kernel_size - 1)//2
        self.query_conv1Ds_aware_temporal_context = nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.causal_padding))
        self.key_conv1Ds_aware_temporal_context = nn.Conv2d(d_model, d_model, (1, kernel_size), padding=(0, self.padding_1D))
        self.dropout = nn.Dropout(p=dropout)
        self.n_length = num_of_lags * points_per_lag


    def forward(self, query, key, value, mask=None, query_multi_segment=False, key_multi_segment=False):
        '''
        :param query: (batch, N, T, d_model)
        :param key: (batch, N, T, d_model)
        :param value: (batch, N, T, d_model)
        :param mask:  (batch, T, T)
        :param query_multi_segment: whether query has mutiple time segments
        :param key_multi_segment: whether key has mutiple time segments
        if query/key has multiple time segments, causal convolution should be applied separately for each time segment.
        :return: (batch, N, T, d_model)
        '''

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, T, T), same mask applied to all h heads.

        nbatches = query.size(0)

        N = query.size(1)

        # deal with key and query: temporal conv
        # (batch, N, T, d_model)->permute(0, 3, 1, 2)->(batch, d_model, N, T) -conv->(batch, d_model, N, T)-view->(batch, h, d_k, N, T)-permute(0,3,1,4,2)->(batch, N, h, T, d_k)

        if query_multi_segment and key_multi_segment:
            query_list = []
            key_list = []
            if self.n_length > 0:
                query_h = self.query_conv1Ds_aware_temporal_context(query[:, :, 0: self.n_length, :].permute(0, 3, 1, 2))[:, :, :, :-self.causal_padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1,
                                                                                                                4, 2)
                key_h = self.key_conv1Ds_aware_temporal_context(key[:, :,0: self.n_length, :].permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

                query_list.append(query_h)
                key_list.append(key_h)

            query = torch.cat(query_list, dim=3)
            key = torch.cat(key_list, dim=3)

        elif (not query_multi_segment) and (not key_multi_segment):

            query = self.query_conv1Ds_aware_temporal_context(query.permute(0, 3, 1, 2))[:, :, :, :-self.causal_padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
            key = self.key_conv1Ds_aware_temporal_context(query.permute(0, 3, 1, 2)).contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

        elif (not query_multi_segment) and (key_multi_segment):

            query = self.query_conv1Ds_aware_temporal_context(query.permute(0, 3, 1, 2))[:, :, :, :-self.causal_padding].contiguous().view(nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)

            key_list = []

            if self.n_length > 0:
                key_h = self.key_conv1Ds_aware_temporal_context(key[:, :, 0: self.n_length, :].permute(0, 3, 1, 2)).contiguous().view(
                    nbatches, self.h, self.d_k, N, -1).permute(0, 3, 1, 4, 2)
                key_list.append(key_h)

            key = torch.cat(key_list, dim=3)

        else:
            import sys
            print('error')
            sys.out

        # deal with value:
        # (batch, N, T, d_model) -linear-> (batch, N, T, d_model) -view-> (batch, N, T, h, d_k) -permute(2,3)-> (batch, N, h, T, d_k)
        value = self.linears[0](value).view(nbatches, N, -1, self.h, self.d_k).transpose(2, 3)

        # apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # x:(batch, N, h, T1, d_k)
        # attn:(batch, N, h, T1, T2)

        x = x.transpose(2, 3).contiguous()  # (batch, N, T1, h, d_k)
        x = x.view(nbatches, N, -1, self.h * self.d_k)  # (batch, N, T1, d_model)
        return self.linears[-1](x)


class EncoderDecoder(nn.Module):
    def __init__(self, encoder, trg_dim,decoder1, src_dense, encode_temporal_position,decode_temporal_position, generator1, DEVICE,spatial_position): #generator2,
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder1 = decoder1
        # self.decoder2 = decoder2
        self.src_embed = src_dense
        # self.trg_embed = trg_dense
        self.encode_temporal_position = encode_temporal_position
        self.decode_temporal_position = decode_temporal_position
        self.prediction_generator1 = generator1
        # self.prediction_generator2 = generator2
        self.spatial_position = spatial_position
        self.trg_dim = trg_dim
        self.to(DEVICE)

    def forward(self, src, trg,x,y,te,td):
        '''
        src:  (batch_size, N, T_in, F_in)
        trg: (batch, N, T_out, F_out)
        '''
        encoder_output = self.encode(src,x,y,te)  # (batch_size, N, T_in, d_model)

        trg_shape = self.trg_dim#int(trg.shape[-1]/2)
        return self.decode1(trg[:, :, :, -trg_shape:], encoder_output, trg[:, :, :, :trg_shape],x,y,td)#trg[:, :, :, :trg_shape],x,y,td)  # src[:,:,-1:,:2])#

    def encode(self, src,x,y,t):
        '''
        src: (batch_size, N, T_in, F_in)
        '''
        src_emb = self.src_embed(src)
        if self.encode_temporal_position ==False:
            src_tmpo_emb = src_emb
        else:
            src_tmpo_emb = self.encode_temporal_position(src_emb,t)
        if self.spatial_position == False:
            h = src_tmpo_emb
        else:
            h = self.spatial_position(src_tmpo_emb, x,y)

        return self.encoder(h)


    def decode1(self, trg, encoder_output,encoder_input,x,y,t):
        trg_embed = self.src_embed
        trg_emb_shape = self.trg_dim
        trg_emb = torch.matmul(trg, list(trg_embed.parameters())[0][:, trg_emb_shape:].T)
        if self.encode_temporal_position ==False:
            trg_tempo_emb = trg_emb
        else:
            trg_tempo_emb = self.decode_temporal_position(trg_emb, t)

        if self.spatial_position ==False:
            a =  self.prediction_generator1(self.decoder1(trg_tempo_emb, encoder_output))+encoder_input#[:,:,:,0:2]
            return a
        else:
            a =  self.prediction_generator1(self.decoder1(self.spatial_position(trg_tempo_emb,x,y), encoder_output))+encoder_input#[:,:,:,0:2]
            return a




class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, gcn, dropout, residual_connection=True, use_LayerNorm=True):
        super(EncoderLayer, self).__init__()
        self.residual_connection = residual_connection
        self.use_LayerNorm = use_LayerNorm
        self.self_attn = self_attn
        self.feed_forward_gcn = gcn
        if residual_connection or use_LayerNorm:
            self.sublayer = clones(SublayerConnection(size, dropout, residual_connection, use_LayerNorm), 2)
        self.size = size

    def forward(self, x):
        '''
        :param x: src: (batch_size, N, T_in, F_in)
        :return: (batch_size, N, T_in, F_in)
        '''
        if self.residual_connection or self.use_LayerNorm:
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, query_multi_segment=True, key_multi_segment=True))
            return self.sublayer[1](x, self.feed_forward_gcn)
        else:
            x = self.self_attn(x, x, x, query_multi_segment=True, key_multi_segment=True)
            return self.feed_forward_gcn(x)


class Encoder(nn.Module):
    def __init__(self, layer, N):
        '''
        :param layer:  EncoderLayer
        :param N:  int, number of EncoderLayers
        '''
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x):
        '''
        :param x: src: (batch_size, N, T_in, F_in)
        :return: (batch_size, N, T_in, F_in)
        '''
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, gcn, dropout, residual_connection=True, use_LayerNorm=True):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward_gcn = gcn
        self.residual_connection = residual_connection
        self.use_LayerNorm = use_LayerNorm
        if residual_connection or use_LayerNorm:
            self.sublayer = clones(SublayerConnection(size, dropout, residual_connection, use_LayerNorm), 3)

    def forward(self, x, memory):
        '''
        :param x: (batch_size, N, T', F_in)
        :param memory: (batch_size, N, T, F_in)
        :return: (batch_size, N, T', F_in)
        '''
        m = memory
        tgt_mask = subsequent_mask(x.size(-2)).to(m.device)  # (1, T', T')
        if self.residual_connection or self.use_LayerNorm:
            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, query_multi_segment=False, key_multi_segment=False))  # output: (batch, N, T', d_model)
            x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, query_multi_segment=False, key_multi_segment=True))  # output: (batch, N, T', d_model)
            return self.sublayer[2](x, self.feed_forward_gcn)  # output:  (batch, N, T', d_model)
        else:
            x = self.self_attn(x, x, x, tgt_mask, query_multi_segment=False, key_multi_segment=False)  # output: (batch, N, T', d_model)
            x = self.src_attn(x, m, m, query_multi_segment=False, key_multi_segment=True)  # output: (batch, N, T', d_model)
            return self.feed_forward_gcn(x)  # output:  (batch, N, T', d_model)


class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, memory):
        '''
        :param x: (batch, N, T', d_model)
        :param memory: (batch, N, T, d_model)
        :return:(batch, N, T', d_model)
        '''
        for layer in self.layers:
            x = layer(x, memory)
        return self.norm(x)

class EmbedLinear(nn.Module):
    def __init__(self, encoder_input_size, d_model,bias=False):
        '''
        :param layer:  EncoderLayer
        :param N:  int, number of EncoderLayers
        '''
        super(EmbedLinear, self).__init__()
        self.layers = nn.Linear(encoder_input_size, d_model, bias=bias)

    def forward(self, x):
        '''
        :param x: src: (batch_size, N, T_in, F_in)
        :return: (batch_size, N, T_in, F_in)
        '''
        #for layer in self.layers:
        y = self.layers(x)
        return y

def search_index(max_len, num_of_depend, num_for_predict,points_per_hour, units):
    '''
    Parameters
    ----------
    max_len: int, length of all encoder input
    num_of_depend: int,
    num_for_predict: int, the number of points will be predicted for each sample
    units: int, week: 7 * 24, day: 24, recent(hour): 1
    points_per_hour: int, number of points per hour, depends on data
    Returns
    ----------
    list[(start_idx, end_idx)]
    '''
    x_idx = []
    for i in range(1, num_of_depend + 1):
        start_idx = max_len - points_per_hour * units * i
        for j in range(num_for_predict):
            end_idx = start_idx + j
            x_idx.append(end_idx)
    return x_idx



def make_model(DEVICE,logitudelatitudes, num_layers, encoder_input_size,decoder_input_size, decoder_output_size, d_model, adj_mx, nb_head, num_of_lags,points_per_lag,
                 num_for_predict, dropout=.0, aware_temporal_context=True,
               ScaledSAt=True, SE=True, TE=True, kernel_size=3, smooth_layer_num=0, residual_connection=True, use_LayerNorm=True):

    # LR rate means: graph Laplacian Regularization

    c = copy.deepcopy

    norm_Adj_matrix = torch.from_numpy(norm_Adj(adj_mx)).type(torch.FloatTensor).to(DEVICE)  # 通过邻接矩阵,构造归一化的拉普拉斯矩阵

    num_of_vertices = norm_Adj_matrix.shape[0]

    src_dense = EmbedLinear(encoder_input_size, d_model, bias=False)#nn.Linear(encoder_input_size, d_model, bias=False)

    if ScaledSAt:  # employ spatial self attention
        position_wise_gcn = PositionWiseGCNFeedForward(spatialAttentionScaledGCN(norm_Adj_matrix, d_model, d_model), dropout=dropout)
    else:  #
        position_wise_gcn = PositionWiseGCNFeedForward(spatialGCN(norm_Adj_matrix, d_model, d_model), dropout=dropout)

    # encoder temporal position embedding
    max_len = num_of_lags

    if aware_temporal_context:  # employ temporal trend-aware attention
        attn_ss = MultiHeadAttentionAwareTemporalContex_q1d_k1d(nb_head, d_model, num_of_lags, points_per_lag,  kernel_size, dropout=dropout)
        attn_st = MultiHeadAttentionAwareTemporalContex_qc_k1d(nb_head, d_model,num_of_lags, points_per_lag,  kernel_size, dropout=dropout)
        att_tt = MultiHeadAttentionAwareTemporalContex_qc_kc(nb_head, d_model, num_of_lags, points_per_lag,  kernel_size, dropout=dropout)
    else:  # employ traditional self attention
        attn_ss = MultiHeadAttention(nb_head,d_model, dropout=dropout) #d_model, dropout=dropout)
        attn_st = MultiHeadAttention(nb_head,d_model, dropout=dropout)# d_model, dropout=dropout)
        att_tt = MultiHeadAttention(nb_head,d_model, dropout=dropout) #d_model, dropout=dropout)

    encode_temporal_position = TemporalPositionalEncoding(d_model, dropout, max_len)  #   en_lookup_index   decoder temporal position embedding
    decode_temporal_position = TemporalPositionalEncoding(d_model, dropout, num_for_predict)
    spatial_position = SpatialPositionalEncoding_RBF(d_model, logitudelatitudes,num_of_vertices, dropout, GCN(norm_Adj_matrix, d_model, d_model), smooth_layer_num=smooth_layer_num) #logitudelatitudes,


    encoderLayer = EncoderLayer(d_model, attn_ss, c(position_wise_gcn), dropout, residual_connection=residual_connection, use_LayerNorm=use_LayerNorm)

    encoder = Encoder(encoderLayer, num_layers)

    decoderLayer1 = DecoderLayer(d_model, att_tt, attn_st, c(position_wise_gcn), dropout, residual_connection=residual_connection, use_LayerNorm=use_LayerNorm)

    decoder1 = Decoder(decoderLayer1, num_layers)

    generator1 = nn.Linear(d_model, decoder_output_size)#



    model = EncoderDecoder(encoder,decoder_output_size,
                       decoder1,
                           src_dense,
                       encode_temporal_position,
                       decode_temporal_position,
                       generator1,
                       DEVICE,
                       spatial_position) #,generator2

    # param init
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

2、实验结果 

模型经过399个epoch训练后,在验证阶段,损失为0.1143,其他性能指标包括c-r为0.0185和L-hr为0.0215,验证阶段耗时约3.655秒,模型在第308个epoch达到最佳性能。第二张图的训练和验证损失曲线显示,训练损失从高到低逐渐稳定,验证损失经过初始波动后也趋于平稳,这表明模型随着训练逐渐适应数据,达到了较好的泛化能力。

三、总结

本周阅读的这篇论文,受益颇多,回顾了很多知识,比如说GCN、多头自注意力等,文中提到的方法aGNN大幅减轻了基于物理模型的计算负担,特别是在需要处理大量注入井和长期管理的情景中,显著提高了计算效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1839893.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第九届信也科技杯全球AI算法大赛——语音深度鉴伪识别参赛A榜 0.968961分

遗憾没有进复赛,只是第41名。先贴个A榜的成绩。A榜的前三十名晋级,个个都是99分的大佬,但是B榜的成绩就有点低了,应该是数据不同源的问题,第一名0.78分。官网链接:语音深度鉴伪识别 官方baselin:https://g…

XHS-Downloader是一款小红书图片视频下载工具

这款软件可以提取账号发布、收藏、点赞作品链接;提取搜索结果作品链接、用户链接;下载小红书作品信息;提取小红书作品下载地址;下载小红书无水印作品文件! 📑 功能清单 ✅ 采集小红书图文 / 视频作品信息…

项目五 OpenStack镜像管理与制作

任务一 理解OpenStack镜像服务 1.1 •什么是镜像 • 镜像通常 是指一系列文件或一个磁盘驱动器的精确副本 。 • 虚拟机 所使用的虚拟磁盘, 实际上是 一种特殊格式的镜像文件 。 • 云 环境下尤其需要 镜像。 • 镜像 就是一个模板,类似于 VMware 的虚拟…

每日复盘-202406019

今日关注: 20240619 六日涨幅最大: ------1--------300868--------- 杰美特 五日涨幅最大: ------1--------300462--------- 华铭智能 四日涨幅最大: ------1--------300462--------- 华铭智能 三日涨幅最大: ------1--------300462--------- 华铭智能 二日涨幅最大…

使用MAT定位线上OOM问题

目录 1.什么是OOM? 2.发生的可能原因 3.常见类型的OOM 4.如何定位问题? 4.1 获取dump文件 4.2 MAT分析 「Leak Suspects」泄露嫌疑 「Histogram」直方图 「dominator tree」支配树 「thread overview」线程视图 目录 1.什么是OOM? 2.发生的可能原因 …

MPLS TE简介

定义 MPLS TE(MPLS Traffic Engineering),即MPLS流量工程。MPLS流量工程通过建立基于一定约束条件的LSP隧道,并将流量引入到这些隧道中进行转发,使网络流量按照指定的路径进行传输,达到流量工程的目的。 …

vue3+element ui +ts 封装周范围选择器

vue3element ui ts 封装周范围选择器 在业务场景中,产品需要在页面中使用周范围选择器,我们在使用ant-design的时候里面是有自带的,但是在emement中只有指定周的范围选择器: 这个是ant-design的周范围选择器 这个是element ui 的…

C# WinForm —— 36 布局控件 GroupBox 和 Panel

1. 简介 两个可以盛放其他控件的容器,可以用于把不同的控件分组,一般不会注册事件 GroupBox:为其他控件提供可识别的分组。可通过Text属性设置标题;有边框;没有滚动条,一般用于按功能分组 Panel&#xff…

鸿蒙小案例-短视频

参加泡泡玛特写了个小demo,然后给它稍微完善了一下 基于API11 演示效果 hfvideo演示视频 主要功能集中在4个tab页内 1.首页-视频播放页 2.朋友-关注、朋友、粉丝聚合 3.消息-聊天列表 4.我的-当前用户信息展示 主页页面 1.用户主页 2.聊天页面 3.朋友页面 4.视频播放页 因为不…

【嵌入式】嵌入式Linux开发实战指南:从交叉编译到触摸屏交互

文章目录 前言:1.简介1.1. 交叉编译工具1.2. 项目开发流程:1.3. ARM开发板的连接方法 2. 开发板连接3. 系统文件 IO4. 设置共享文件夹3.1. 读文件3.2. 写文件3.2. 设置文件偏移量 4. LCD显示屏显示4.1. LCD 显示颜色4.2. 将文件下载到开发板4.2.1. 在CRT…

shadertoy-安装和使用

一、安装vscode 安装vscode流程 二、安装插件 1.安装glsl编辑插件 2.安装shader toy插件 三、创建glsl文件 test.glsl文件 float Grid(float size, vec2 fragCoord) {vec2 r fragCoord / size;vec2 grid abs(fract(r - 0.5) - 0.5) / fwidth(r);float line min(grid…

免费域名第二弹:手把手教你获取个性化免费域名并托管至Cloudflare

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 免费申请域名的方法 📒📝 注册账号📝 创建免费域名📝 将域名添加到 Cloudflare⚓️ 相关链接 ⚓️📖 介绍 📖 在如今的数字时代,拥有一个个性化的域名已经成为越来越多人的需求。无论是建立个人博客、项目展示,还…

JWT整合Gateway实现鉴权(RSA与公私密钥工具类)

一.业务流程 1.使用RSA生成公钥和私钥。私钥保存在授权中心,公钥保存在网关(gateway)和各个信任微服务中。 2.用户请求登录。 3.授权中心进行校验,通过后使用私钥对JWT进行签名加密。并将JWT返回给用户 4.用户携带JWT访问 5.gateway直接通过公钥解密JWT进…

前端页面实现【矩阵表格与列表】

实现页面&#xff1a; 1.动态表绘制&#xff08;可用于矩阵构建&#xff09; <template><div><h4><b>基于层次分析法的权重计算</b></h4><table table-layout"fixed"><thead><tr><th v-for"(_, colI…

王思聪隐形女儿曝光

王思聪"隐形"女儿曝光&#xff01;黄一鸣独自面对怀孕风波&#xff0c;坚持生下爱情结晶近日&#xff0c;娱乐圈掀起了一场惊天波澜&#xff01;前王思聪绯闻女友黄一鸣在接受专访时&#xff0c;大胆揭露了她与王思聪之间的爱恨纠葛&#xff0c;并首度公开承认&#…

VBA技术资料MF161:按需要显示特定工作表

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。“VBA语言専攻”提供的教程一共九套&#xff0c;分为初级、中级、高级三大部分&#xff0c;教程是对VBA的系统讲解&#…

重生之 SpringBoot3 入门保姆级学习(24、场景整合 kafka 消息发送服务)

重生之 SpringBoot3 入门保姆级学习&#xff08;24、场景整合 kafka 消息发送服务&#xff09; 6.4 消息发送服务 6.4 消息发送服务 访问 kafka-ui &#xff08;注意这里需要换成你自己的服务器或者虚拟机的 IP 地址&#xff0c;虚拟机可以用局域网 192.168.xxx.xxx 的地址&…

05-对混合app应用中的元素进行定位

本文介绍对于混合app应用中的元素如何进行定位。 一、app的类型 1&#xff09;Native App&#xff08;原生应用&#xff09; 原生应用是指利用Android、IOS平台官方的开发语言、开发类库、工具等进行开发的app应用&#xff0c;在应用性能和交互体验上应该是最好的。 通俗点来…

富唯智能复合机器人

复合机器人&产品概述 富唯智能复合机器人集协作机器人、移动机器人和视觉引导技术于一体&#xff0c;搭载ICD系列核心控制器&#xff0c;一体化控制整个复合机器人系统&#xff0c;并且可以对接产线系统&#xff0c;搭配我司自研的2D/3D视觉平台&#xff0c;能够轻松实现工…