Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch)

news2024/11/26 16:58:46

Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch)

文章目录

  • Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch)
    • GLU(Gated Linear Unit)模块
    • GRN(Gated Residual Network)门控残差网络
    • Transformer经典模块
      • Add&Normalize模块
      • Scaled Dot-Product Attention 模块
      • InterpretableMultiHeadAttention可解释的多头注意力机制
    • Variable Selection Network 变量选择网络
    • QuantileLoss分位数损失函数
    • Temporal Fusion Transformer 模型构建

在这里插入图片描述

GLU(Gated Linear Unit)模块

该模块包括一个前向传播函数forward和一个参数初始化函数init_weights。
GLU是一种多层感知器(MLP)结构,由两个子层组成,分别为全连接层和sigmoid门。
全连接层有两个,一个作为输入门,另一个作为忘记门。sigmoid门负责控制两个门的权重,使得在两个门中仅有一个门起作用。
如果激活函数activation不是None,则在输出门之前添加一个激活函数。
在init_weights函数中,使用了Xavier均匀分布初始化权重和零初始化偏置。

class GatedLinearUnit(nn.Module):
    def __init__(self, input_size,
                 hidden_layer_size,
                 dropout_rate,
                 activation = None):
        
        super(GatedLinearUnit, self).__init__()
        
        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.dropout_rate = dropout_rate
        self.activation_name = activation
        
        if self.dropout_rate:
            self.dropout = nn.Dropout(p=self.dropout_rate)
            
        self.W4 = torch.nn.Linear(self.input_size, self.hidden_layer_size)
        self.W5 = torch.nn.Linear(self.input_size, self.hidden_layer_size)
        
        if self.activation_name:
            self.activation = getattr(nn, self.activation_name)()
            
        self.sigmoid = nn.Sigmoid()
            
        self.init_weights()
            
    def init_weights(self):
        for n, p in self.named_parameters():
            if 'bias' not in n:
                torch.nn.init.xavier_uniform_(p)
#                 torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid')
            elif 'bias' in n:
                torch.nn.init.zeros_(p)
            
    def forward(self, x):
        
        if self.dropout_rate:
            x = self.dropout(x)
            
        if self.activation_name:
            output = self.sigmoid(self.W4(x)) * self.activation(self.W5(x))
        else:
            output = self.sigmoid(self.W4(x)) * self.W5(x)
            
        return output

GRN(Gated Residual Network)门控残差网络

这段代码实现了一个门控残差网络(Gated Residual Network),它是Temporal Fusion Transformer中的一个组成部分。

在这里插入图片描述

该类的初始化方法包含了多个参数:

hidden_layer_size:隐层的维度大小。
input_size:输入向量的维度大小。默认为None,若未指定则使用hidden_layer_size作为输入向量的维度大小。
output_size:输出向量的维度大小。默认为None,若未指定则输出向量的维度大小与隐层的维度大小相同。
dropout_rate:dropout概率。
additional_context:附加的上下文信息。默认为None。
return_gate:是否返回门控向量。默认为False。
在初始化方法中,首先定义了该类所需要用到的线性层,其中包括self.W1、self.W2和self.W3(若有附加的上下文信息)。接着,根据是否指定了输出向量的维度大小,定义了一个门控加和归一化网络(GateAddNormNetwork),其输入向量的维度大小取决于self.hidden_layer_size和self.output_size的大小关系。最后,对模型中的参数进行了初始化。

在该类的前向传播方法中,如果有附加的上下文信息,则将输入向量和上下文信息分别输入self.W2和self.W3中,并将它们的输出进行求和,并通过激活函数(这里使用了ELU)进行激活。接着,将该结果输入self.W1中,并将其输出通过门控加和归一化网络进行处理。如果指定了输出向量的维度大小,则将输入向量通过self.skip_linear进行线性转换,并将结果与门控加和归一化网络的输出相加后得到最终的输出向量;否则,直接将输入向量与门控加和归一化网络的输出相加后得到最终的输出向量。最后返回输出向量。

class GatedResidualNetwork(nn.Module):
    def __init__(self,
                 hidden_layer_size,
                 input_size = None,
                 output_size = None, 
                 dropout_rate = None, 
                 additional_context = None,
                 return_gate = False):
        
        super(GatedResidualNetwork, self).__init__()
        
        self.hidden_layer_size = hidden_layer_size
        self.input_size = input_size if input_size else self.hidden_layer_size
        self.output_size = output_size
        self.dropout_rate = dropout_rate
        self.additional_context = additional_context
        self.return_gate = return_gate
        
        self.W1 = torch.nn.Linear(self.hidden_layer_size, self.hidden_layer_size)
        self.W2 = torch.nn.Linear(self.input_size, self.hidden_layer_size)
        
        if self.additional_context:
            self.W3 = torch.nn.Linear(self.additional_context, self.hidden_layer_size, bias = False)
            

        if self.output_size:
            self.skip_linear = torch.nn.Linear(self.input_size, self.output_size)
            self.glu_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                   self.output_size,
                                                   self.dropout_rate)
        else:
            self.glu_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                   self.hidden_layer_size,
                                                   self.dropout_rate)
            
        self.init_weights()
            
    def init_weights(self):
        for name, p in self.named_parameters():
            if ('W2' in name or 'W3' in name) and 'bias' not in name:
                torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='leaky_relu')
            elif ('skip_linear' in name or 'W1' in name) and 'bias' not in name:
                torch.nn.init.xavier_uniform_(p)
#                 torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid')
            elif 'bias' in name:
                torch.nn.init.zeros_(p)
            
    def forward(self, x):
        
        if self.additional_context:
            x, context = x
            #x_forward = self.W2(x)
            #context_forward = self.W3(context)
            #print(self.W3(context).shape)
            n2 = F.elu(self.W2(x) + self.W3(context))
        else:
            n2 = F.elu(self.W2(x))
        
        #print('n2 shape {}'.format(n2.shape))
            
        n1 = self.W1(n2)
        
        #print('n1 shape {}'.format(n1.shape))
            
        if self.output_size:
            output = self.glu_add_norm(n1, self.skip_linear(x))
        else:
            output = self.glu_add_norm(n1, x)
            
        #print('output shape {}'.format(output.shape))
        
        return output

Transformer经典模块

Add&Normalize模块

在这里插入图片描述

transformer在进行self-Attention之后会进行一个layerNormalization 【将数据统一到固定区间内】
其中又分为batchNormalization和layerNormalization,batchNormalization 即按照batch维度化成均值为0标准差为1的数据
Layer则是纵向将每次的向量数据进行归一化
残差作用:加入未学习的原向量使得到的结果的效果至少不弱于原来的结果

在这里插入图片描述

蓝色圆圈标注的“add”是什么呢?好像初始注意力机制(如下图)中并没有add呀??

在这里插入图片描述

如图所示,在self-attention中求出 Z ( i ) = ∑ j = 1 T a i j v ( j ) Z^{\left( i \right)}=\sum_{j=1}^T{a_{ij}v^{\left( j \right)}} Z(i)=j=1Taijv(j)后直接将Z放入全连接层中,而transformer又经过了add&Norm步骤 ,这个add是什么呢???

其实add就是在Z的基础上加了一个残差块X,加入残差块X的目的是为了防止在深度神经网络训练中发生退化问题,退化的意思就是深度神经网络通过增加网络的层数,Loss逐渐减小,然后趋于稳定达到饱和,然后再继续增加网络层数,Loss反而增大。这个Add具体操作如下:其实就是将self-attention的output+input作为新的output,如下图:

在这里插入图片描述

关于首图所标记的2,Add&Norm这一过程指什么呢?
就是指对新的output做标准化,也就是上图的,对a+b做标准化

FeedForward:FeedForward又是什么呢?好像图二(注意力机制)中这里应该是全连接层(FC)呀?❓❓❓
先说一下FeedForward是什么?其实FeedForward是由全连接层(FC)与激活ReLu组成的结构,其实和bp神经网络结构差不多,输入:Add&Norm的输出(记作:Z’(i) ), FC:全连接层

在这里插入图片描述

接下来再说一下为什么要用FeedForward呢?不用单纯的FC呢?

其实主要还是想提取更深层次的特征,在Multi-Head Attention中,主要是进行矩阵乘法,即都是线性变换,而线性变换的学习能力不如非线性变换的学习能力强,我们希望通过引入ReLu激活函数,使模型增加非线性成分,强化学习能力。

回到这个模块来,它的输入是 x x x s k i p skip skip 两个张量,其中 x x x 是输入张量, s k i p skip skip 是跳跃连接(skip connection),用于保留之前层的信息,防止信息丢失。

在这个模块中,先通过 Gated Linear Unit(GLU)对 x x x 进行变换,然后将变换后的结果和 s k i p skip skip 进行加和操作,再通过 Layer Normalization 进行归一化,得到输出张量。

GLU 的作用是将 x x x 变换为与之相同维度的张量,并同时生成一个门控向量。门控向量的每个元素都是介于 0 和 1 之间的实数,用于控制该元素对最终结果的贡献程度。这里的 GLU 实现了类似于 Highway Network 的效果,即能够选择性地保留或抛弃某些信息。

Layer Normalization 的作用是对张量的每个特征维度进行归一化,使得每个特征维度上的均值为 0,方差为 1,从而加快收敛速度、提高泛化能力。

最终,这个门控加和归一化网络将 x x x s k i p skip skip 合并为一个输出张量,其中 s k i p skip skip 是经过跳跃连接保留下来的信息, x x x 是经过 GLU 变换后的信息,经过 Layer Normalization 归一化后得到的输出张量。

class GateAddNormNetwork(nn.Module):
    def __init__(self, input_size,
                 hidden_layer_size,
                 dropout_rate,
                 activation = None):
        
        super(GateAddNormNetwork, self).__init__()
        
        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.dropout_rate = dropout_rate
        self.activation_name = activation
        
        self.GLU = GatedLinearUnit(self.input_size, 
                                   self.hidden_layer_size, 
                                   self.dropout_rate,
                                   activation = self.activation_name)
        
        self.LayerNorm = nn.LayerNorm(self.hidden_layer_size)
        
    def forward(self, x, skip):
        
        output = self.LayerNorm(self.GLU(x) + skip)
            
        return output

Scaled Dot-Product Attention 模块

在实际应用中,经常会用到 Attention 机制,其中最常用的是 Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.
在这里插入图片描述

在TFT里面套用原始Transformer架构,实现了一个Scaled Dot-Product Attention。
该类定义了一个基于nn.Module的自定义PyTorch模块,包含了初始化函数和前向传递函数。S

caledDotProductAttention中的主要计算过程是计算query和key之间的点积,然后进行softmax归一化,最后加权求和得到输出。具体地,该类的初始化函数接受两个参数:dropout和scale。
dropout参数指定了在softmax之前要应用多少dropout,
而scale参数指定是否应该缩放点积结果。
该类的前向传递函数接受四个参数:q,k,v和mask。
其中,q,k和v是查询、键和值的输入向量,mask是一个可选的屏蔽向量,用于屏蔽输入中的某些位置。

在前向函数中,首先计算 q 和 k 的点积得到注意力矩阵 attn。
如果模块被初始化为缩放模式,则在计算点积时会将结果除以 k 向量长度的平方根,以缩小点积结果的值域,避免梯度消失或梯度爆炸的问题。
如果 mask 不为空,则在 attn 中填充 -1e9,以便在 softmax 操作中产生 0 的概率。
然后,通过 softmax 函数将注意力矩阵归一化为概率分布,并应用 dropout 操作以防止过拟合。最后,通过点积计算得到输出向量 output。

在输出时,函数返回输出向量 output 和注意力矩阵 attn。

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout = 0, scale = True):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.softmax = nn.Softmax(dim = 2)
        self.scale = scale
            
    def forward(self, q, k, v, mask = None):
        #print('---Inputs----')
        #print('q: {}'.format(q[0]))
        #print('k: {}'.format(k[0]))
        #print('v: {}'.format(v[0]))
        
        attn = torch.bmm(q, k.permute(0,2,1))
        #print('first bmm')
        #print(attn.shape)
        #print('attn: {}'.format(attn[0]))
        
        if self.scale:
            dimention = torch.sqrt(torch.tensor(k.shape[-1]).to(torch.float32))
            attn = attn / dimention
        #    print('attn_scaled: {}'.format(attn[0]))
            
        if mask is not None:
            #fill = torch.tensor(-1e9).to(DEVICE)
            #zero = torch.tensor(0).to(DEVICE)
            attn = attn.masked_fill(mask == 0, -1e9)
        #    print('attn_masked: {}'.format(attn[0]))
            
        attn = self.softmax(attn)
        #print('attn_softmax: {}'.format(attn[0]))
        attn = self.dropout(attn)
        
        output = torch.bmm(attn, v)
        
        return output, attn

InterpretableMultiHeadAttention可解释的多头注意力机制

在这里插入图片描述

TFT参照原始Transformer架构,实现了一个可解释的多头注意力机制,即Interpretable Multi-Head Attention。

在初始化函数中,定义了注意力头的个数n_head、输入特征的维度d_model和临时的q、k、v的维度d_k、d_q、d_v。同时,通过nn.Linear函数定义了q、k、v的线性层和v的线性层的组合,还定义了缩放点积注意力的ScaledDotProductAttention层以及输出的全连接层w_h,并初始化了权重。

在前向函数中,首先对输入的q、k、v分别通过线性层得到q_i、k_i、v_i,再利用ScaledDotProductAttention层计算出每个注意力头对应的权重矩阵和输出矩阵,并对输出矩阵进行dropout操作。然后,将多个头的输出按照最后一个维度进行拼接,再通过均值池化得到最终的输出矩阵,最后通过全连接层和dropout操作得到最终的输出和注意力权重矩阵,并将它们返回。

需要注意的是,在计算每个注意力头的权重矩阵时,利用了缩放点积注意力。该注意力的计算过程包括三个步骤:1)计算q和k的点积,即attn = torch.bmm(q, k.permute(0,2,1));2)如果需要进行缩放,则除以根号下d_k;3)将得到的结果进行softmax操作得到注意力矩阵。同时,在计算注意力矩阵时,还可以通过mask参数实现对无效位置的屏蔽。

class InterpretableMultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, dropout):
        super(InterpretableMultiHeadAttention, self).__init__()
            
        self.n_head = n_head
        self.d_model = d_model
        self.d_k = self.d_q = self.d_v = d_model // n_head
        self.dropout = nn.Dropout(p=dropout)
        
        self.v_layer = nn.Linear(self.d_model, self.d_v, bias = False)
        self.q_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_q, bias = False) 
                                       for _ in range(self.n_head)])
        self.k_layers = nn.ModuleList([nn.Linear(self.d_model, self.d_k, bias = False) 
                                       for _ in range(self.n_head)])
        self.v_layers = nn.ModuleList([self.v_layer for _ in range(self.n_head)])
        self.attention = ScaledDotProductAttention()
        self.w_h = nn.Linear(self.d_v, self.d_model, bias = False)
        
        self.init_weights()
            
    def init_weights(self):
        for name, p in self.named_parameters():
            if 'bias' not in name:
                torch.nn.init.xavier_uniform_(p)
#                 torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid')
            else:
                torch.nn.init.zeros_(p)
        
    def forward(self, q, k, v, mask = None):
        
        heads = []
        attns = []
        for i in range(self.n_head):
            qs = self.q_layers[i](q)
            ks = self.k_layers[i](k)
            vs = self.v_layers[i](v)
            #print('qs layer: {}'.format(qs.shape))
            head, attn = self.attention(qs, ks, vs, mask)
            #print('head layer: {}'.format(head.shape))
            #print('attn layer: {}'.format(attn.shape))
            head_dropout = self.dropout(head)
            heads.append(head_dropout)
            attns.append(attn)
            
        head = torch.stack(heads, dim = 2) if self.n_head > 1 else heads[0]
        #print('concat heads: {}'.format(head.shape))
        #print('heads {}: {}'.format(0, head[0,0,Ellipsis]))
        attn = torch.stack(attns, dim = 2)
        #print('concat attn: {}'.format(attn.shape))
        
        outputs = torch.mean(head, dim = 2) if self.n_head > 1 else head
        #print('outputs mean: {}'.format(outputs.shape))
        #print('outputs mean {}: {}'.format(0, outputs[0,0,Ellipsis]))
        outputs = self.w_h(outputs)
        outputs = self.dropout(outputs)
        
        return outputs, attn

Variable Selection Network 变量选择网络

在这里插入图片描述

这段代码是 Temporal Fusion Transformer 中 VariableSelectionNetwork 的前向传递方法。

VariableSelectionNetwork 的目的是选择哪些变量对于当前时间步骤的预测是最重要的。这里有两种情况:使用静态上下文(只有输入嵌入)和使用非静态上下文(输入嵌入和静态上下文)。

在这个方法中,如果使用了非静态上下文,输入会分成嵌入和静态上下文两个部分。
首先,使用 GatedResidualNetwork 对嵌入进行非线性转换,
然后使用 Softmax 函数将其转换为稀疏权重。
接着,对于每个要预测的特征,使用 GatedResidualNetwork 对嵌入进行非线性转换,将其与稀疏权重相乘,以生成一组经过转换的特征,最后将这些特征相加以生成最终的时间上下文。

如果只有静态上下文,则直接对输入嵌入进行非线性转换,将其与稀疏权重相乘,以生成一组转换后的特征,最后将这些特征相加以生成最终的时间上下文。
在这两种情况下,该方法返回时间上下文和稀疏权重。

class VariableSelectionNetwork(nn.Module):
    def __init__(self, hidden_layer_size,
                 dropout_rate,
                 output_size,
                 input_size = None, 
                 additional_context = None):
        super(VariableSelectionNetwork, self).__init__()
        
        self.hidden_layer_size = hidden_layer_size
        self.input_size = input_size
        self.output_size = output_size
        self.dropout_rate = dropout_rate
        self.additional_context = additional_context
            
        self.flattened_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                   input_size = self.input_size,
                                                   output_size = self.output_size,
                                                   dropout_rate = self.dropout_rate,
                                                   additional_context=self.additional_context)
        
        self.per_feature_grn = nn.ModuleList([GatedResidualNetwork(self.hidden_layer_size,
                                                                   dropout_rate=self.dropout_rate)
                                                      for i in range(self.output_size)])
    def forward(self, x):
        # Non Static Inputs
        if self.additional_context:
            embedding, static_context = x
            #print('static_context')
            #print(static_context.shape)
            
            time_steps = embedding.shape[1]
            flatten = embedding.view(-1, time_steps, self.hidden_layer_size * self.output_size)
            #print('flatten')
            #print(flatten.shape)
            
            static_context = static_context.unsqueeze(1)
            #print('static_context')
            #print(static_context.shape)

            # Nonlinear transformation with gated residual network.
            mlp_outputs = self.flattened_grn((flatten, static_context))
            #print('mlp_outputs')
            #print(mlp_outputs.shape)
            
            sparse_weights = F.softmax(mlp_outputs, dim = -1)
            sparse_weights = sparse_weights.unsqueeze(2)
            #print('sparse_weights')
            #print(sparse_weights.shape)
            
            trans_emb_list = []
            for i in range(self.output_size):
                e = self.per_feature_grn[i](embedding[Ellipsis, i])
                trans_emb_list.append(e)
            transformed_embedding = torch.stack(trans_emb_list, axis=-1)
            #print('transformed_embedding')
            #print(transformed_embedding.shape)
            
            combined = sparse_weights * transformed_embedding
            #print('combined')
            #print(combined.shape)
            
            temporal_ctx = torch.sum(combined, dim = -1)
            #print('temporal_ctx')
            #print(temporal_ctx.shape)
            
        # Static Inputs
        else:
            embedding = x
            #print('embedding')
            #print(embedding.shape)
            
            flatten = torch.flatten(embedding, start_dim=1)
            #flatten = embedding.view(batch_size, -1)
            #print('flatten')
            #print(flatten.shape)
            
            # Nonlinear transformation with gated residual network.
            mlp_outputs = self.flattened_grn(flatten)
            #print('mlp_outputs')
            #print(mlp_outputs.shape)
            
            sparse_weights = F.softmax(mlp_outputs, dim = -1)
            sparse_weights = sparse_weights.unsqueeze(-1)
#             print('sparse_weights')
#             print(sparse_weights.shape)
            
            trans_emb_list = []
            for i in range(self.output_size):
                #print('embedding for the per feature static grn')
                #print(embedding[:, i:i + 1, :].shape)
                e = self.per_feature_grn[i](embedding[:, i:i + 1, :])
                trans_emb_list.append(e)
            transformed_embedding = torch.cat(trans_emb_list, axis=1)
#             print('transformed_embedding')
#             print(transformed_embedding.shape)
    
            combined = sparse_weights * transformed_embedding
#             print('combined')
#             print(combined.shape)
            
            temporal_ctx = torch.sum(combined, dim = 1)
#             print('temporal_ctx')
#             print(temporal_ctx.shape)
        
        return temporal_ctx, sparse_weights

QuantileLoss分位数损失函数

在这里插入图片描述

这段代码定义了两个类,QuantileLossCalculator 和 NormalizedQuantileLossCalculator,用于计算量化损失(quantile loss)和标准化量化损失(normalized quantile loss)。

QuantileLossCalculator 类中的 init 方法初始化了计算器的属性 quantiles 和 output_size。quantiles 是一个浮点数列表,代表要计算的分位数,例如 [0.1, 0.5, 0.9] 表示要计算 10%、50% 和 90% 的分位数。output_size 是输出序列的长度。

QuantileLossCalculator 类中的 quantile_loss 方法计算单个分位数的量化损失。该方法使用了标准的量化损失公式,计算方法可以参考 Temporal Fusion Transformer 论文中的 “Training Procedure” 章节。该方法接受三个参数:y 是目标序列(也称为标签),y_pred 是预测序列,quantile 是分位数。该方法首先检查分位数是否在 0 和 1 之间,然后计算预测序列与目标序列之间的差值,并根据分位数计算加权误差。最后,该方法返回一个张量,表示量化损失。

QuantileLossCalculator 类中的 apply 方法计算所有指定的分位数的量化损失之和。该方法接受两个参数:a 是目标序列,b 是预测序列。该方法首先创建一个空列表 loss,然后对于每个指定的分位数,使用 quantile_loss 方法计算量化损失,并将结果添加到 loss 列表中。最后,该方法使用 torch.mean 计算 loss 列表中所有张量的平均值,表示所有指定的分位数的量化损失之和。返回的值是一个标量张量,表示总的量化损失。

NormalizedQuantileLossCalculator 类中的 init 方法和 apply 方法与 QuantileLossCalculator 类相同,不同之处在于 NormalizedQuantileLossCalculator 类计算的是标准化量化损失。标准化量化损失的计算方法与量化损失的计算方法类似,但是在计算加权误差时,使用的权重是相对于整个序列绝对误差的比例。最后,标准化量化损失通过总的标准化误差进行归一化,使得它们可以与其他模型的损失进行比较。返回的值是一个标量张量,表示标准化量化损失。

import torch

class QuantileLossCalculator():
    """Computes the combined quantile loss for prespecified quantiles.
    Attributes:
      quantiles: Quantiles to compute losses
    """

    def __init__(self, quantiles, output_size):
        """Initializes computer with quantiles for loss calculations.
            Args:
            quantiles: Quantiles to use for computations.
        """
        self.quantiles = quantiles
        self.output_size = output_size
        
    # Loss functions.
    def quantile_loss(self, y, y_pred, quantile):
        """ Computes quantile loss for pytorch.
            Standard quantile loss as defined in the "Training Procedure" section of
            the main TFT paper
            Args:
            y: Targets
            y_pred: Predictions
            quantile: Quantile to use for loss calculations (between 0 & 1)
            Returns:
            Tensor for quantile loss.
        """

        # Checks quantile
        if quantile < 0 or quantile > 1:
            raise ValueError(
                'Illegal quantile value={}! Values should be between 0 and 1.'.format(quantile))

        prediction_underflow = y - y_pred
#         print('prediction_underflow')
#         print(prediction_underflow.shape)
        q_loss = quantile * torch.max(prediction_underflow, torch.zeros_like(prediction_underflow)) + \
                (1. - quantile) * torch.max(-prediction_underflow, torch.zeros_like(prediction_underflow))
        
#         print('q_loss')
#         print(q_loss.shape)
        
#         loss = torch.mean(q_loss, dim = 1)
#         print('loss')
#         print(loss.shape)
#         return loss
           
#         return torch.sum(q_loss, dim = -1)
        return q_loss.unsqueeze(1)

    def apply(self, b, a):
        """Returns quantile loss for specified quantiles.
            Args:
            a: Targets
            b: Predictions
        """
        quantiles_used = set(self.quantiles)

        loss = []
#         loss = 0.
        for i, quantile in enumerate(self.quantiles):
            if quantile in quantiles_used:
                #print(a[Ellipsis, self.output_size * i:self.output_size * (i + 1)].shape)
#                 loss += self.quantile_loss(a[Ellipsis, self.output_size * i:self.output_size * (i + 1)],
#                                            b[Ellipsis, self.output_size * i:self.output_size * (i + 1)], 
#                                            quantile)
                #print(a[Ellipsis, self.output_size * i].shape)
                #loss += self.quantile_loss(a[Ellipsis, self.output_size * i],
                #                           b[Ellipsis, self.output_size * i], 
                #                           quantile)
                
#                 loss.append(self.quantile_loss(a[Ellipsis, self.output_size * i:self.output_size * (i + 1)],
#                                                b[Ellipsis, self.output_size * i:self.output_size * (i + 1)], 
#                                                quantile))

                loss.append(self.quantile_loss(a[Ellipsis, i],
                                               b[Ellipsis, i], 
                                               quantile))
                
#         loss_computed = torch.cat(loss, axis = -1)
#         loss_computed = torch.sum(loss_computed, axis = -1)
#         loss_computed = torch.sum(loss_computed, axis = 0)

        loss_computed = torch.mean(torch.sum(torch.cat(loss, axis = 1), axis = 1))
        
        return loss_computed
#         return loss

class NormalizedQuantileLossCalculator():
    """Computes the combined quantile loss for prespecified quantiles.
    Attributes:
      quantiles: Quantiles to compute losses
    """

    def __init__(self, quantiles, output_size):
        """Initializes computer with quantiles for loss calculations.
            Args:
            quantiles: Quantiles to use for computations.
        """
        self.quantiles = quantiles
        self.output_size = output_size
        
    # Loss functions.
    def apply(self, y, y_pred, quantile):
        """ Computes quantile loss for pytorch.
            Standard quantile loss as defined in the "Training Procedure" section of
            the main TFT paper
            Args:
            y: Targets
            y_pred: Predictions
            quantile: Quantile to use for loss calculations (between 0 & 1)
            Returns:
            Tensor for quantile loss.
        """

        # Checks quantile
        if quantile < 0 or quantile > 1:
            raise ValueError(
                'Illegal quantile value={}! Values should be between 0 and 1.'.format(quantile))

        prediction_underflow = y - y_pred
#         print('prediction_underflow')
#         print(prediction_underflow.shape)
        weighted_errors = quantile * torch.max(prediction_underflow, torch.zeros_like(prediction_underflow)) + \
                (1. - quantile) * torch.max(-prediction_underflow, torch.zeros_like(prediction_underflow))
        
        quantile_loss = torch.mean(weighted_errors)
        normaliser = torch.mean(torch.abs(quantile_loss))
        return 2 * quantile_loss / normaliser

Temporal Fusion Transformer 模型构建

在这里插入图片描述
利用上述组件构建模型

class TemporalFusionTransformer(pl.LightningModule):
    def __init__(self, hparams):
        super(TemporalFusionTransformer, self).__init__()
        
        self.hparams = hparams
        
        self.name = self.__class__.__name__

        # Data parameters
        self.time_steps = int(hparams.total_time_steps)#int(params['total_time_steps'])
        self.input_size = int(hparams.input_size)#int(params['input_size'])
        self.output_size = int(hparams.output_size)#int(params['output_size'])
        self.category_counts = json.loads(str(hparams.category_counts))#json.loads(str(params['category_counts']))
        self.num_categorical_variables = len(self.category_counts)
        self.num_regular_variables = self.input_size - self.num_categorical_variables
        self.n_multiprocessing_workers = int(hparams.multiprocessing_workers) #int(params['multiprocessing_workers'])

        # Relevant indices for TFT
        self._input_obs_loc = json.loads(str(hparams.input_obs_loc))#json.loads(str(params['input_obs_loc']))
        self._static_input_loc = json.loads(str(hparams.static_input_loc))#json.loads(str(params['static_input_loc']))
        self._known_regular_input_idx = json.loads(str(hparams.known_regular_inputs))#json.loads(str(params['known_regular_inputs']))
        self._known_categorical_input_idx = json.loads(str(hparams.known_categorical_inputs))#json.loads(str(params['known_categorical_inputs']))
        
        self.num_non_static_historical_inputs = self.get_historical_num_inputs()
        self.num_non_static_future_inputs = self.get_future_num_inputs()
        
        self.column_definition = [
                                  ('id', DataTypes.REAL_VALUED, InputTypes.ID),
                                  ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.TIME),
                                  ('power_usage', DataTypes.REAL_VALUED, InputTypes.TARGET),
                                  ('hour', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
                                  ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
                                  ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
                                  ('categorical_id', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
                                ]

        # Network params
        self.quantiles = [0.1, 0.5, 0.9]
#         self.use_cudnn = use_cudnn  # Whether to use GPU optimised LSTM
        self.hidden_layer_size = int(hparams.hidden_layer_size)#int(params['hidden_layer_size'])
        self.dropout_rate = float(hparams.dropout_rate)#float(params['dropout_rate'])
        self.max_gradient_norm = float(hparams.max_gradient_norm)#float(params['max_gradient_norm'])
        self.learning_rate = float(hparams.learning_rate)#float(params['learning_rate'])
        self.minibatch_size = int(hparams.minibatch_size)#int(params['minibatch_size'])
        self.num_epochs = int(hparams.num_epochs)#int(params['num_epochs'])
        self.early_stopping_patience = int(hparams.early_stopping_patience)#int(params['early_stopping_patience'])

        self.num_encoder_steps = int(hparams.num_encoder_steps)#int(params['num_encoder_steps'])
        self.num_stacks = int(hparams.stack_size)#int(params['stack_size'])
        self.num_heads = int(hparams.num_heads)#int(params['num_heads'])

        # Serialisation options
#         self._temp_folder = os.path.join(params['model_folder'], 'tmp')
#         self.reset_temp_folder()

        # Extra components to store Tensorflow nodes for attention computations
        self._input_placeholder = None
        self._attention_components = None
        self._prediction_parts = None

        print('*** {} params ***'.format(self.name))
        for k in vars(hparams):
            print('# {} = {}'.format(k, vars(hparams)[k]))
            
        self.train_criterion = QuantileLossCalculator(self.quantiles, self.output_size)
        self.test_criterion = NormalizedQuantileLossCalculator(self.quantiles, self.output_size)

        # Build model
        ## Build embeddings
        self.build_embeddings()
        
        ## Build Static Contex Networks
        self.build_static_context_networks()
        
        ## Building Variable Selection Networks
        self.build_variable_selection_networks()
        
        ## Build Lstm
        self.build_lstm()
        
        ## Build GLU for after lstm encoder decoder and layernorm
        self.build_post_lstm_gate_add_norm()
        
        ## Build Static Enrichment Layer
        self.build_static_enrichment()
        
        ## Building decoder multihead attention
        self.build_temporal_self_attention()
        
        ## Building positionwise decoder
        self.build_position_wise_feed_forward()
        
        ## Build output feed forward
        self.build_output_feed_forward()
        
        ## Initializing remaining weights
        self.init_weights()
        
    def init_weights(self):
        for name, p in self.named_parameters():
            if ('lstm' in name and 'ih' in name) and 'bias' not in name:
                #print(name)
                #print(p.shape)
                torch.nn.init.xavier_uniform_(p)
#                 torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid')
            elif ('lstm' in name and 'hh' in name) and 'bias' not in name:
        
                 torch.nn.init.orthogonal_(p)
            
            elif 'lstm' in name and 'bias' in name:
                #print(name)
                #print(p.shape)
                torch.nn.init.zeros_(p)
        
    def get_historical_num_inputs(self):
        
        obs_inputs = [i for i in self._input_obs_loc]
        
        known_regular_inputs = [i for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
            
        known_categorical_inputs = [i for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]
        
        wired_embeddings = [i for i in range(self.num_categorical_variables)
                            if i not in self._known_categorical_input_idx 
                            and i not in self._input_obs_loc]

        unknown_inputs = [i for i in range(self.num_regular_variables)
                          if i not in self._known_regular_input_idx
                          and i not in self._input_obs_loc]

        return len(obs_inputs+known_regular_inputs+known_categorical_inputs+wired_embeddings+unknown_inputs)
    
    def get_future_num_inputs(self):
            
        known_regular_inputs = [i for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
            
        known_categorical_inputs = [i for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]

        return len(known_regular_inputs + known_categorical_inputs)
    
    def build_embeddings(self):
        self.categorical_var_embeddings = nn.ModuleList([nn.Embedding(self.category_counts[i], 
                                                                      self.hidden_layer_size) 
                                                     for i in range(self.num_categorical_variables)])

        self.regular_var_embeddings = nn.ModuleList([nn.Linear(1, 
                                                              self.hidden_layer_size) 
                                                  for i in range(self.num_regular_variables)])

    def build_variable_selection_networks(self):
        
        self.static_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                   input_size = self.hidden_layer_size * len(self._static_input_loc),
                                                   output_size = len(self._static_input_loc),
                                                   dropout_rate = self.dropout_rate)
        
        self.temporal_historical_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                                input_size = self.hidden_layer_size *
                                                                        self.num_non_static_historical_inputs,
                                                                output_size = self.num_non_static_historical_inputs,
                                                                dropout_rate = self.dropout_rate,
                                                                additional_context=self.hidden_layer_size)
        
        self.temporal_future_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                            input_size = self.hidden_layer_size *
                                                                        self.num_non_static_future_inputs,
                                                            output_size = self.num_non_static_future_inputs,
                                                            dropout_rate = self.dropout_rate,
                                                            additional_context=self.hidden_layer_size)
        
    def build_static_context_networks(self):
        
        self.static_context_variable_selection_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                                          dropout_rate=self.dropout_rate)
        
        self.static_context_enrichment_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                              dropout_rate=self.dropout_rate)

        self.static_context_state_h_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                           dropout_rate=self.dropout_rate)
        
        self.static_context_state_c_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                           dropout_rate=self.dropout_rate)
        
    def build_lstm(self):
        self.historical_lstm = nn.LSTM(input_size = self.hidden_layer_size,
                                       hidden_size = self.hidden_layer_size,
                                       batch_first = True)
        self.future_lstm = nn.LSTM(input_size = self.hidden_layer_size,
                                   hidden_size = self.hidden_layer_size,
                                   batch_first = True)
        
    def build_post_lstm_gate_add_norm(self):
        self.post_seq_encoder_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                                 self.hidden_layer_size,
                                                                 self.dropout_rate,
                                                                 activation = None)
        
    def build_static_enrichment(self):
        self.static_enrichment = GatedResidualNetwork(self.hidden_layer_size,
                                                      dropout_rate = self.dropout_rate,
                                                      additional_context=self.hidden_layer_size)
        
    def build_temporal_self_attention(self):
        self.self_attn_layer = InterpretableMultiHeadAttention(n_head = self.num_heads, 
                                                               d_model = self.hidden_layer_size,
                                                               dropout = self.dropout_rate)
        
        self.post_attn_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                           self.hidden_layer_size,
                                                           self.dropout_rate,
                                                           activation = None)
        
    def build_position_wise_feed_forward(self):
        self.GRN_positionwise = GatedResidualNetwork(self.hidden_layer_size,
                                                     dropout_rate = self.dropout_rate)
        
        self.post_tfd_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                         self.hidden_layer_size,
                                                         self.dropout_rate,
                                                         activation = None)
        
    def build_output_feed_forward(self):
        self.output_feed_forward = torch.nn.Linear(self.hidden_layer_size, 
                                                   self.output_size * len(self.quantiles))
         
    def get_decoder_mask(self, self_attn_inputs):
        """Returns causal mask to apply for self-attention layer.
        Args:
        self_attn_inputs: Inputs to self attention layer to determine mask shape
        """
        len_s = self_attn_inputs.shape[1]
        bs = self_attn_inputs.shape[0]
        mask = torch.cumsum(torch.eye(len_s), 0)
        mask = mask.repeat(bs,1,1).to(torch.float32)

        return mask.to(DEVICE)
    
    def get_tft_embeddings(self, regular_inputs, categorical_inputs):
        # Static input
        if self._static_input_loc:
            static_regular_inputs = [self.regular_var_embeddings[i](regular_inputs[:, 0, i:i + 1]) 
                                    for i in range(self.num_regular_variables)
                                    if i in self._static_input_loc]
            #print('static_regular_inputs')
            #print([print(emb.shape) for emb in static_regular_inputs])
            
            static_categorical_inputs = [self.categorical_var_embeddings[i](categorical_inputs[Ellipsis, i])[:,0,:] 
                                         for i in range(self.num_categorical_variables)
                                         if i + self.num_regular_variables in self._static_input_loc]
            #print('static_categorical_inputs')
            #print([print(emb.shape) for emb in static_categorical_inputs])
            static_inputs = torch.stack(static_regular_inputs + static_categorical_inputs, axis = 1)
        else:
            static_inputs = None
            
        # Target input
        obs_inputs = torch.stack([self.regular_var_embeddings[i](regular_inputs[Ellipsis, i:i + 1])
                                     for i in self._input_obs_loc], axis=-1)
        
        # Observed (a prioir unknown) inputs
        wired_embeddings = []
        for i in range(self.num_categorical_variables):
            if i not in self._known_categorical_input_idx \
            and i not in self._input_obs_loc:
                e = self.categorical_var_embeddings[i](categorical_inputs[:, :, i])
                wired_embeddings.append(e)

        unknown_inputs = []
        for i in range(self.num_regular_variables):
            if i not in self._known_regular_input_idx \
            and i not in self._input_obs_loc:
                e = self.regular_var_embeddings[i](regular_inputs[Ellipsis, i:i + 1])
                unknown_inputs.append(e)
                
        if unknown_inputs + wired_embeddings:
            unknown_inputs = torch.stack(unknown_inputs + wired_embeddings, axis=-1)
        else:
            unknown_inputs = None
            
        # A priori known inputs
        known_regular_inputs = [self.regular_var_embeddings[i](regular_inputs[Ellipsis, i:i + 1])
                                for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
        #print('known_regular_inputs')
        #print([print(emb.shape) for emb in known_regular_inputs])
        
        known_categorical_inputs = [self.categorical_var_embeddings[i](categorical_inputs[Ellipsis, i])
                                    for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]
       #print('known_categorical_inputs')
       #print([print(emb.shape) for emb in known_categorical_inputs])

        known_combined_layer = torch.stack(known_regular_inputs + known_categorical_inputs, axis=-1)
        
        return unknown_inputs, known_combined_layer, obs_inputs, static_inputs
        
    def forward(self, all_inputs):

        regular_inputs = all_inputs[:, :, :self.num_regular_variables].to(torch.float)
        #print('regular_inputs')
        #print(regular_inputs.shape)
        categorical_inputs = all_inputs[:, :, self.num_regular_variables:].to(torch.long)
        #print('categorical_inputs')
        #print(categorical_inputs.shape)
        
        unknown_inputs, known_combined_layer, obs_inputs, static_inputs \
            = self.get_tft_embeddings(regular_inputs, categorical_inputs)
        
        # Isolate known and observed historical inputs.
        if unknown_inputs is not None:
              historical_inputs = torch.cat([
                  unknown_inputs[:, :self.num_encoder_steps, :],
                  known_combined_layer[:, :self.num_encoder_steps, :],
                  obs_inputs[:, :self.num_encoder_steps, :]
              ], axis=-1)
        else:
              historical_inputs = torch.cat([
                  known_combined_layer[:, :self.num_encoder_steps, :],
                  obs_inputs[:, :self.num_encoder_steps, :]
              ], axis=-1)
                
        #print('historical_inputs')
        #print(historical_inputs.shape)
        
        # Isolate only known future inputs.
        future_inputs = known_combined_layer[:, self.num_encoder_steps:, :]
        #print('future_inputs')
        #print(future_inputs.shape)
              
        #print('static_inputs')
        #print(static_inputs.shape)
        
        static_encoder, sparse_weights = self.static_vsn(static_inputs)
        
        #print('static_encoder')
        #print(static_encoder.shape)
        
        #print('sparse_weights')
        #print(sparse_weights.shape)
        
        static_context_variable_selection = self.static_context_variable_selection_grn(static_encoder)
        #print('static_context_variable_selection')
        #print(static_context_variable_selection.shape)
        static_context_enrichment = self.static_context_enrichment_grn(static_encoder)
        #print('static_context_enrichment')
        #print(static_context_enrichment.shape)
        static_context_state_h = self.static_context_state_h_grn(static_encoder)
        #print('static_context_state_h')
        #print(static_context_state_h.shape)
        static_context_state_c = self.static_context_state_c_grn(static_encoder)
        #print('static_context_state_c')
        #print(static_context_state_c.shape)
        
        historical_features, historical_flags \
        = self.temporal_historical_vsn((historical_inputs,
                                        static_context_variable_selection))
        #print('historical_features')
        #print(historical_features.shape)
        #print('historical_flags')
        #print(historical_flags.shape)
        
        future_features, future_flags \
        = self.temporal_future_vsn((future_inputs,
                                    static_context_variable_selection))
        #print('future_features')
        #print(future_features.shape)
        #print('future_flags')
        #print(future_flags.shape)
        
        history_lstm, (state_h, state_c) \
        = self.historical_lstm(historical_features,
                               (static_context_state_h.unsqueeze(0),
                                static_context_state_c.unsqueeze(0)))
        #print('history_lstm')
        #print(history_lstm.shape)
        #print('state_h')
        #print(state_h.shape)
        #print('state_c')
        #print(state_c.shape)
        
        future_lstm, _ = self.future_lstm(future_features,
                                          (state_h,
                                           state_c))
        #print('future_lstm')
        #print(future_lstm.shape)
        
        # Apply gated skip connection
        input_embeddings = torch.cat((historical_features, future_features), axis=1)
        #print('input_embeddings')
        #print(input_embeddings.shape) 
        
        lstm_layer = torch.cat((history_lstm, future_lstm), axis=1)
        #print('lstm_layer')
        #print(lstm_layer.shape) 
        
        temporal_feature_layer = self.post_seq_encoder_gate_add_norm(lstm_layer, input_embeddings)
        #print('temporal_feature_layer')
        #print(temporal_feature_layer.shape)  
        
        # Static enrichment layers
        expanded_static_context = static_context_enrichment.unsqueeze(1)
        
        enriched = self.static_enrichment((temporal_feature_layer, expanded_static_context))
        #print('enriched')
        #print(enriched.shape)    
        
        # Decoder self attention
        #self.mask = self.get_decoder_mask(enriched)
        #print('enriched')
        #print(enriched.shape)
        x, self_att = self.self_attn_layer(enriched, 
                                           enriched, 
                                           enriched,
                                           mask = self.get_decoder_mask(enriched))
        #print('x')
        #print(x.shape)
        #print('self_att')
        #print(self_att.shape)
        
        x = self.post_attn_gate_add_norm(x, enriched)
        #print('x')
        #print(x.shape)
        
        # Nonlinear processing on outputs
        decoder = self.GRN_positionwise(x)
        #print('decoder')
        #print(decoder.shape)
        
        # Final skip connection
        transformer_layer = self.post_tfd_gate_add_norm(decoder, temporal_feature_layer)
        #print('transformer_layer')
        #print(transformer_layer.shape)
        
        outputs = self.output_feed_forward(transformer_layer[Ellipsis, self.num_encoder_steps:, :])
        #print('outputs')
        #print(outputs.shape)
        
        #ipdb.set_trace()
        
        return outputs
    
    def loss(self, y_hat, y):
        return self.train_criterion.apply(y_hat, y)
    
    def test_loss(self, y_hat, y):
        return self.test_criterion.apply(y_hat, y, self.quantiles[1])
    
    def training_step(self, batch, batch_nb):
        x, y, _ = batch
        
        x = x.to(torch.float)
        y = y.to(torch.float)
#         print('y')
#         print(y.shape)
        y_hat = self.forward(x)
#         print('y_hat')
#         print(y_hat.shape)
        loss = self.loss(y_hat, torch.cat([y, y, y], dim = -1))
        #print(loss.shape)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}
    
    def validation_step(self, batch, batch_nb):
        x, y, _ = batch
        x = x.to(torch.float)
        y = y.to(torch.float)
        y_hat = self.forward(x)
        #print(y_hat.shape)
        #print(torch.cat([y, y, y], dim = -1).shape)
        loss = self.loss(y_hat, torch.cat([y, y, y], dim = -1))
        #print(loss)
        return {'val_loss': loss}
    
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
    
    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y, _ = batch
        x = x.to(torch.float)
        y = y.to(torch.float)
        y_hat = self.forward(x)
        return {'test_loss': self.test_loss(y_hat[Ellipsis, 1], y[Ellipsis, 0])}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
    
    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return [torch.optim.Adam(self.parameters(), lr=self.learning_rate)]
    
    def plot_grad_flow(self, named_parameters):
        ave_grads = []
        layers = []
        for name, p in named_parameters:
            if p.grad is not None:
                if (p.requires_grad) and ("bias" not in name):
                    layers.append(name)
                    ave_grads.append(p.grad.abs().mean())
                    self.logger.experiment.add_histogram(tag=name, values=p.grad,
                                                         global_step=self.trainer.global_step)
            else:
                 print('{} - {}'.format(name, p.requires_grad))
            
        plt.plot(ave_grads, alpha=0.3, color="b")
        plt.hlines(0, 0, len(ave_grads), linewidth=1, color="k" )
        plt.xticks(list(range(0,len(ave_grads), 1)), layers, rotation='vertical')
        plt.xlim(left=0, right=len(ave_grads))
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.rcParams["figure.figsize"] = (20, 5)
    
    def on_after_backward(self):
        # example to inspect gradient information in tensorboard
        if self.trainer.global_step % 25 == 0:  
            self.plot_grad_flow(self.named_parameters())
    
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(train_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=1)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(valid_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=1)
    
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(test_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/384850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Leetcode】【简单】704. 二分查找

给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target &#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回 -1。 示例 1: 输入: nums [-1,0,3,5,9,12], target 9 输出: 4 解释: 9 出现…

HNU工训中心:数模转换实验报告

D 级任务&#xff1a;实验 74194 的仿真验证 1. 实验任务 任务:实验 74194 的仿真验证&#xff0c;掌握 Quartus 仿真的基本原则和常规步骤&#xff0c;记录移 位寄存器的数据读写&#xff0c;并描述仿真波形&#xff0c;结果分析并记录到实验报告。 实验实作 最终的电路图…

“探索未来:VR全景直播技术引领新媒体时代”

随着虚拟现实技术的不断发展&#xff0c;VR全景直播已经成为了越来越受欢迎的直播形式。VR全景直播可以让观众通过虚拟现实设备亲临直播现场&#xff0c;享受身临其境的观看体验。VR全景直播是什么&#xff1f; VR全景直播是虚拟现实技术和直播的结合。相对于传统直播&#xff…

SOLIDWORKS提升 SAE赛车队成绩

SOLIDWORKS Education Edition实施结果&#xff1a; ●将车队成绩提升至史上最高水平 ●使用仿真工具验证赛车设计 ●促进车队成员之间的交流 ●增强设计可视化试用及报价新罕布什尔大学是美国顶级研究机构之一。UNH 拥有多个学术部门、跨学科机构和研究中心&#xff0c;屹立在…

APP发布----HBuilder

pc端&#xff1a;桌面应用&#xff1a;electronjs–vasode web应用&#xff1a; 纯前端开发—vue、react等框架 混合开发 — php、jsp、asp App开发类型 原生开发&#xff1a; Jave – AndroidObject – iosc – 鸿蒙运行效率最高、开发复杂、页面一致性 web-app&#xff08;…

新C++(12):位图与海量数据处理

"无关利弊或对错&#xff0c;勇气一直在"现在给你一道面试题:给40亿个不重复的无符号整数&#xff0c;没有任何排序。给出一个无符号整数&#xff0c;如何快速地判断&#xff0c;该数是否存在于40亿多个这样的无符号整数之中&#xff1f;唔……我想将这40亿个多数进行…

25- 卷积神经网络(CNN)原理 (TensorFlow系列) (深度学习)

知识要点 卷积神经网络的几个主要结构: 卷积层&#xff08;Convolutions&#xff09;: Valid :不填充&#xff0c;也就是最终大小为卷积后的大小. Same&#xff1a;输出大小与原图大小一致&#xff0c;那么N ​变成了​N2P. padding-零填充. 池化层&#xff08;Subsampli…

STM32HAL库1:认识时钟树(STM32F407)

目录 1.认识时钟树 1.1什么时钟 1.2认识时钟树 2.配置系统时钟 2.1系统适中的配置过程 2.2外设时钟使能和失能 2.3sys_stm32_clock_init函数&#xff08;F407&#xff09; 2.4SystemClock_Config函数&#xff08;F407&#xff09; [F1] [F4] 1.认识时钟树 1.1什么时钟 …

【禅道测试环境搭建及安装】Linux上的禅道安装教程,从环境搭建开始

目录 一、操作环境 二、安装VMware 三、安装FinalShell 四、在VMware里安装CentOS 1.前置&#xff1a;CentOS的下载 2.CentOS的安装 3.查看网络 五、用宿主机连接远程的虚拟机 六、禅道包的下载与CentOS配置的修改 七、上传项目包并安装禅道 一、操作环境 Windows 10…

Substrate 基础教程(Tutorials) -- 授权特定节点

五、授权特定节点 在添加可信节点中&#xff0c;您看到了如何使用一组已知的验证器节点构建一个简单的网络。该教程演示了一个简化版的许可网络&#xff08;permissioned network&#xff09;。在一个被许可的网络中&#xff0c;只有被授权的节点&#xff08;authorized nodes…

STM32CubexMX与FreeRTOS学习

FreeRTOS结合MX软件开发&#xff0c;基础配置直接生成&#xff0c;我们只需要会操作即可&#xff0c;操作一些API函数&#xff0c;注意事项就是&#xff1a;头文件、二值信号量一开始就有、定时器需要打开并且配置周期 1-LED-EXTI 上下拉模式和触发模式不要选择错误 void HAL_…

Elasticsearch:如何在 Elastic 中实现图片相似度搜索

作者&#xff1a;Radovan Ondas 在本文章&#xff0c;我们将了解如何通过几个步骤在 Elastic 中实施相似图像搜索。 开始设置应用程序环境&#xff0c;然后导入 NLP 模型&#xff0c;最后完成为你的图像集生成嵌入。 Elastic 图像相似性搜索概览 >> 如何设置环境 第一步…

【11】linux命令每日分享——useradd添加用户

大家好&#xff0c;这里是sdust-vrlab&#xff0c;Linux是一种免费使用和自由传播的类UNIX操作系统&#xff0c;Linux的基本思想有两点&#xff1a;一切都是文件&#xff1b;每个文件都有确定的用途&#xff1b;linux涉及到IT行业的方方面面&#xff0c;在我们日常的学习中&…

网页全屏截图、在线截图API

在我们的数字时代&#xff0c;网页截图已经成为了一个非常常见的需求&#xff0c;无论是为了记录网页信息、保存重要信息或是与他人分享自己看到的内容&#xff0c;截图工具都是必不可少的。今天&#xff0c;我想向大家推荐一个非常好用的在线网页全屏截图工具&#xff0c;它的…

获取html元素相对屏幕的位置

1. 写在前面 Javascript 由三部分构成&#xff0c;ECMAScript&#xff0c;DOM和BOM。 ECMAScript(核心)   描述了JS的语法和基本对象 Bom:(browser object model&#xff09;&#xff1b;浏览器对象模型&#xff0c;提供一些属性和方法可以操作浏览器&#xff0c;浏览器可…

【spring】Spring Data --Spring Data JPA

Spring Data 的委托是为数据访问提供熟悉且符合 Spring 的编程模型&#xff0c;同时仍保留着相关数据存储的特​​殊特征。 它使使用数据访问技术、关系和非关系数据库、map-reduce 框架和基于云的数据服务变得容易。这是一个伞形项目&#xff0c;其中包含许多特定于给定数据库…

详述java的设计模式(二)

1.工厂模式 简单工厂模式&#xff08;Simple Factory Pattern&#xff09;属于创建型模式&#xff0c;是指由一个工厂类根据传入的参数&#xff0c;动态决定应该创建哪一个产品类的对象&#xff0c;而被创建的对象通常都具有共同的父类或接口&#xff0c;这种模式又叫做静态工…

数据结构和算法学习

文章目录精通一个领域切题四件套算法算法的五个条件流程图数据结构数据与信息数据信息数据结构和算法数据结构算法时间复杂度空间复杂度数组 Array优点缺点数组和链表的区别时间复杂度链表 Linked List优点缺点时间复杂度单向链表双向链表循环链表双向循环链表堆栈 Stack队列 Q…

涂鸦开发资源中心硬核上线!支持产物包可视化、自由裁剪开发框架等功能

开发过程对于程序猿们来说&#xff0c;就像建碉堡&#xff0c;碉堡的建成速度、创造力&#xff0c;以及精致和牢固程度&#xff0c;取决于工具称不称手。不称手的开发工具&#xff0c;不仅能让人心力交瘁&#xff0c;还会多掉几根头发。 为了帮助开发者获得更好的开发体验&…

【搞透C语言指针】那年我双手插兜, 不知道指针是我的对手

☃️内容专栏&#xff1a;【C语言】进阶部分 ☃️本文概括&#xff1a; 征服C语言指针&#xff01;一篇文章搞清楚指针的全部要点。 ☃️本文作者&#xff1a;花香碟自来_ ☃️发布时间&#xff1a;2023.3.3 目录 一、字符指针 二、指针数组 三、数组指针 1.数组指针的定义…