深度学习算法informer(时序预测)(三)(Encoder)

news2024/12/23 14:19:57

一、EncoderLayer架构如图(不改变输入形状)

二、ConvLayer架构如图(输入形状中特征维度减半)

 三、Encoder整体

包括三部分

1. 多层EncoderLayer

2. 多层ConvLayer

3. 层归一化

代码如下

class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, 
                 d_keys=None, d_values=None, mix=False):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model//n_heads)
        d_values = d_values or (d_model//n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask
        )
        if self.mix:
            out = out.transpose(2,1).contiguous()
        out = out.view(B, L, -1)

        return self.out_projection(out), attn


class ConvLayer(nn.Module):
    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        padding = 1 if torch.__version__>='1.5.0' else 2
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  padding=padding,
                                  padding_mode='circular')
        # 批量归一化层的作用是在训练过程中对每个批次的数据进行归一化处理
        # 使其均值接近于 0,方差接近于 1,从而加速模型的训练和提高模型的稳定性
        # 不会改变形状
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.downConv(x.permute(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.maxPool(x)
        x = x.transpose(1,2)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # x [B, L, D]
        # x = x + self.dropout(self.attention(
        #     x, x, x,
        #     attn_mask = attn_mask
        # ))
        new_x, attn = self.attention(
            x, x, x,
            attn_mask = attn_mask
        )
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
        y = self.dropout(self.conv2(y).transpose(-1,1))

        return self.norm2(x+y), attn

class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x [B, L, D]
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1842417.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端安全——最新:lodash原型漏洞从发现到修复全过程

前端安全——最新:lodash原型漏洞从发现到修复全过程 1. 漏洞复现 现在很多系统的前端都是基于vue和react框架的,所以就肯定少不了引入各种依赖,而lodash作为一款非常流行的npm库,每月的下载量超过8000万次。可以说是使用的十分…

ASP.NET MVC企业级程序设计(增删,int类型转时间取余)

目录 题目: 实现过程 控制器代码 DAL BLL Index Jia 题目: 实现过程 控制器代码 using System; using System.Collections.Generic; using System.Linq; using System.Web; using System.Web.Mvc; using MvcApplication1.Models;namespace …

Palo Alto GlobalProtect App 6.3 (macOS, Linux, Windows, Andriod) - 端点网络安全客户端

Palo Alto GlobalProtect App 6.3 (macOS, Linux, Windows, Andriod) - 端点网络安全客户端 Palo Alto Networks 远程访问 VPN 客户端软件 请访问原文链接:https://sysin.org/blog/globalprotect-6/,查看最新版。原创作品,转载请保留出处。…

idea-Spring框架与ioc容器

Sping是轻量级的开源J2EE框架,可以解决企业应用开发的复杂性 Spring有两个核心部分为Ioc和AOP Ioc:控制反转,吧创建对象过程交给Sping进行管理 AOP:面向切面,不修改代码进行功能增强 创建Maven项目 IDEA-2024 就直接创建java项目即可 创…

为什么有人认为Linux不如macOS?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「Linux的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!!首先要明确你说的是哪个Lin…

Git分支的状态存储——stash命令的详细用法

文章目录 6.6 Git的分支状态存储6.6.1 git stash命令6.6.2 Git存储的基本使用6.6.3 Git存储的其他用法6.6.4 Git存储与暂存区6.6.5 Git存储的原理 6.6 Git的分支状态存储 有时,当我们在项目的一部分上已经工作一段时间后,所有东西都进入了混乱的状态&am…

前端网页开发学习(HTML+CSS+JS)有这一篇就够!

目录 HTML教程 ▐ 概述 ▐ 基础语法 ▐ 文本标签 ▐ 列表标签 ▐ 表格标签 ▐ 表单标签 CSS教程 ▐ 概述 ▐ 基础语法 ▐ 选择器 ▐ 修饰文本 ▐ 修饰背景 ▐ 透明度 ▐ 伪类 ▐ 盒子模型 ▐ 浮动 ▐ 定位 JavaScript教程 ▐ 概述 ▐ 基础语法 ▐ 函数 …

Si24R05—高度集成的低功耗 2.4G+125K SoC 芯片

Si24R05是一款高度集成的低功耗SoC芯片,具有低功耗、Low Pin Count、宽电压工作范围,集成了13/14/15/16位精度的ADC、LVD、UART、SPI、I2C、TIMER、WUP、IWDG、RTC、无线收发器、3D低频唤醒接收器等丰富的外设。内核采用RISC-V RV32IMAC(2.6 …

【BES2500x系列 -- RTX5操作系统】CMSIS-RTOS RTX -- 实时操作系统的核心,为嵌入式系统注入活力 --(一)

💌 所属专栏:【BES2500x系列】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &#x1f49…

本地部署应用测试01:ChatTTS

目的: 参考网上多种教程,学习开源大模型的部署与相关知识点,在此总结记录。 知识点: 1.大模型部署与应用大致步骤 首先需要下载项目的源码,并完成项目环境的搭建其次需要下载训练好的权重参数文件,以便…

微信公众号绑定开发者后端,报错“系统发生错误,请稍后重试”的坑

一、问题描述 在公众号后端填写完基本配置,点击保存,发现提示“系统发生错误,请稍后重试”。联系公众号客服回复,涉及开发内容不给支持-_-|| 二、经多次百度,结合实际尝试,总结解决方案如下:…

深入理解TCP协议:工作原理、报文结构及应用场景

TCP协议详解 TCP(Transmission Control Protocol,传输控制协议)是因特网协议套件中最重要的协议之一。它为应用程序提供了可靠、面向连接的通信服务。TCP协议确保数据包按顺序到达,并且没有丢失或重复。本文将详细介绍TCP协议的工…

OS复习笔记ch11-1

外围设备的管理和磁盘调度 外围设备 从CPU的角度来看,外设有几个比较重要的I/O接口(interfaces) 状态reg:向CPU报告设备的状态(忙碌/空闲)命令reg:接收CPU命令,存储 CPU 需要执行的…

Java变量:声明、作用域和命名约定

Java变量:声明、作用域和命名约定 什么是变量? 在Java中,变量是保存特定数据类型值的内存位置的名称。它是java编程中的一个基本概念,允许您在程序执行期间存储和操作数据。 Java中的变量可以保存各种类型的数据,包括…

【ai】tx2-nx:配置tritonserver2.17.0-jetpack4.6 环境并运行例子

2.17.0 for jetpack 4.6运行需要 如果在jetson上构建Triton : Note: When building Triton on Jetson, you will require a newer version of cmake. We recommend using cmake 3.21.0. Below is a script to upgrade your cmake version to 3.21.0. You can use cmake 3.18.4…

金融居间CRM系统赋能金服企业精细化管理客户

金融居间CRM系统可以帮助金融服务企业实现精细化管理客户。通过CRM系统,企业可以更好地了解和跟踪客户需求、行为和历史记录。以下是一些具体的赋能方式: 1. 客户数据集成与管理 将客户的个人信息、财务状况、交易历史等数据集成到一个统一的平台中&…

三步问题00

题目链接 三步问题 题目描述 注意点 n范围在[1, 1000000]之间结果可能很大,需要对结果模1000000007 解答思路 动态规划的思想根据dp[i - 1]、dp[i - 2]、dp[i - 3]推出dp[i]需要注意的是结果可能很大,在计算的过程中需要模1000000007防止越界 代码…

华为数通——OSPF

正掩码:/24 255.255.255.0 反掩码: 255.255.255.255 -255.-255.-255.0 0.0.0.255 例如掩码:255.255.252.0 反掩码:0.0.3.255 在反掩码里面,0 bit 表示精确匹配,1…

2024年有什么赚钱的副业推荐半年还清贷款,成功变现12.3w的全套玩法都放这里了!!!

要说推荐副业,我是最有发言权了。普通打工人一个,年轻不懂事,经常超前消费,欠了一屁股债,没得办法,就只能到处找能赚钱的门路。 尝试了30的副业,就发现能赚钱的不是太辛苦,就是需要…

目标检测——SCUT-HEAD:大规模人头检测数据集的深度剖析

引言 亲爱的读者们,您是否在寻找某个特定的数据集,用于研究或项目实践?欢迎您在评论区留言,或者通过公众号私信告诉我,您想要的数据集的类型主题。小编会竭尽全力为您寻找,并在找到后第一时间与您分享。 在…