(pytorch进阶之路)Informer

news2024/9/29 0:32:22

论文:Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting (AAAI’21 Best Paper)

看了一下以前的论文学习学习,我也是重应用吧,所以代码部分会比较多,理论部分就一笔带过吧

论文作者也很良心的给出了colab,就大大加快了看源码是怎么实现的速度:https://colab.research.google.com/drive/1_X7O2BkFLvqyCdZzDZvV2MB0aAvYALLC

那么源码主要看什么呢,首先是issue,github的issue里面如果压根就跑不了,那就不用花时间了,如果没太大的错误说明代码没有致命的错误

第二步是看数据,源数据是什么,数据如何预处理

第三步看模型实现,一般就在model文件夹下面,这一步比较简单,重点看创新点部分如何实现的

第四步pth,看看复现结果


文章目录

  • 模型框架
  • 代码地址

模型框架

在这里插入图片描述
创新点:ProbSparse Attention
主要思想就是用top-k选择最有用的信息


代码地址

https://github.com/zhouhaoyi/Informer2020

下载好代码和数据,仔细阅读Data的说明,我们得知得把数据放到data/ETT文件夹下面

parser部分大致看看什么意思,model,data,root_path,data_path,单卡多卡和num_workers设置一下,结合上下文推测大致的意思,同时github里面提供了数据字典,我们至少需要修改data和data_path参数

由于我是windows上debug的,所以args如果是required=True的话参数需要我们手动填就很麻烦,个人习惯就都改成False先

右键运行成功,那么就可以逐步debug了


main_informer.py运行,逐渐运行到
exp.train(setting)
进入train函数

		train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

首先_get_data取数据,进入函数看看,data_dict里面看到了Dataset_Custom,就知道它是可以自定义数据的,后面实例化dataset,实例化dataset再实例化dataloader,数据集做好了

dataset中看看怎么预处理数据的,dataset里面有__read_data__和__getitem__函数,上下文分析__read_data__就是预处理的步骤,因为看到了StandardScaler,里面做了一个标准化

time_features函数对时间维度做特征编码,思想很简单,但是代码写特别复杂

最后构造dataloader


往下走到epoch开始迭代训练数据,到_process_one_batch函数

pred, true = self._process_one_batch(
    train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)

_process_one_batch进一步处理数据和输入进model,dec_input先全0或者全1进行初始化

然后enc_inputh后面48个和dec_input按dim=1维度进行拼接

dec_input前面的48个就是时序的观测值,我们要预测后面的24个

model输入是96,12的enc_input,enc_mark是96,4时间编码特征
dec_input是72,12,dec_mark是72,4


model 部分

主要是attention模块(其他都比较简单),在model/attn.py,看ProbAttention class,直接看forward函数

首先划分QKV,96个seqlen中选25个(U_part)

重点来了,_prob_QK函数

scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

进入_prob_QK

首先K扩充了-3的维度,K_expand=(32,8,96,96,64)

index_sample随机采样出0~96的96×25的矩阵,K_sample取出(32,8,96,25,64)

Q和K_sample计算内积的到Q_K_sample(32,8,96,25)
Q_K_sample上计算max,选出M_top个max波峰最大的Q,得到Q_reduce(25个Q)

Q_reduce再和96个K做内积

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))  # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # find the Top_k query with sparisty measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   M_top, :]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top

_get_initial_context函数显示了如果没有选择到的Q,说明比较平庸,直接用平均V来表示

	 V_sum = V.mean(dim=-2)

_update_context
只更新25个Q

context_in[
            torch.arange(B)[:, None, None],
            torch.arange(H)[None, :, None],
            index, :]\
            = torch.matmul(attn, V).type_as(context_in)

attention做完
回到forward,做了一个蒸馏操作,MaxPool1d,stride=2,做个下采样
96len变成48len

ConvLayer(
  (downConv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), padding_mode=circular)
  (norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation): ELU(alpha=1.0)
  (maxPool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)

encoder做完
做decoder,用的模块和encoder一致,还有一个cross attention,都老生常谈,跳过…

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/375413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微服务的Feign到底是什么

Feign是什么 分区是一种数据库优化技术,它可以将大表按照一定的规则分成多个小表,从而提高查询和维护的效率。在分区的过程中,数据库会将数据按照分区规则分配到不同的分区中,并且可以在分区中使用索引和其他优化技术来提高查询效…

目标检测论文阅读:CBNet算法笔记

标题:CBNet: A Composite Backbone Network Architecture for Object Detection 期刊:TIP2022 论文地址:https://ieeexplore.ieee.org/document/9932281/ 官方代码:https://github.com/VDIGPKU/CBNetV2 作者单位:北京大…

【正点原子FPGA连载】第二十章AXI4接口之DDR读写实验 摘自【正点原子】DFZU2EG_4EV MPSoC之嵌入式Vitis开发指南

1)实验平台:正点原子MPSoC开发板 2)平台购买地址:https://detail.tmall.com/item.htm?id692450874670 3)全套实验源码手册视频下载地址: http://www.openedv.com/thread-340252-1-1.html 第二十章AXI4接口…

如何查看Spring Boot各版本的变化

目录 1.版本 2.基础特性和使用 3.新增特性和Bug修复 1.版本 打开Spring官网,点进Spring Boot项目我们会发现在不同版本后面会跟着不同的标签: 这些标签对应不同的版本,其意思如下: GA正式版本,通常意味着该版本已…

VsCode安装PlatformIO 开发ESP arduino,买的板子或者随便ESP,PlatformIO添加Board(不是自定义Board)

这次主要记录怎么给新建选板子的时候没有的板子下程序 我这里是一块 WiFi Kit 32 (V3) PlatformIO里面只有到V2 先从头开始,安装PlatformIO 安装PlatformIO 直接搜索安装 安装有时候会比较慢,左侧出现蚂蚁图标之后点击会显示 右下角会提示正在安…

【神经网络】Transformer基础问答

1.Transforme与LSTM的区别 transformer和LSTM最大的区别就是LSTM的训练是迭代的,无法并行训练,LSTM单元计算完T时刻信息后,才会处理T1时刻的信息,T 1时刻的计算依赖 T-时刻的隐层计算结果。而transformer的训练是并行了&#xff0…

AndroidStudio打包HBuilderX的H5+项目为安卓App【一次过,无任何异常报错】

目录 1.查看HBuilderX的版本号 2.下载Dcloud上对应的安卓SDK 3.下载完安卓SDK后,我们解压它,注意不要放在任何有中文组成的文件夹中【是否有中文决定于你鼠标单击上面路径后,第一张图还没鼠标单击,第二张已鼠标单击&#xff0c…

【前端工程化】01-Node.js基础

Node.js基础认识NodeNode的定义Node的应用场景Node的安装和版本管理Node的基本操作Node.js执行文件Node的参数传递Node的REPL认识Node Node的定义 Node.js是一个基于V8 JavaScript引擎的JavaScript运行时环境 Node.js为JavaScript提供了一些服务器级别的操作API 文件读写网…

背靠“湘潭系”的谭新乔,能带领湖南裕能再上一个台阶吗?

文丨熔财经作者|kinki近日,磷酸铁锂正极材料龙头湖南裕能正式登陆A股,上市当天市值超过了400亿元,投资者中一签可赚1.49万元,可谓近年低迷的资本市场中一支“大肉签”。不过在 “开门红”之后,湖南裕能的股价便一路下挫…

ETL工具(kettle) 与 ETL产品(BeeloadBeeDI) 差之毫厘,谬以千里

E T L——是英文Extract-Transform-Load的缩写,用来描述将数据从来源端经过抽取(extract)、转换(transform)、加载(load)至目的端的过程。工具——原指工作时所需用的器具,后引申为达…

Clickhouse学习(一):MergeTree概述

MergeTree一、Clickhouse表引擎概述二、MergeTree表引擎<一>、ReplacingMergeTree引擎<二>、SummingMergeTree引擎<三>、AggregatingMergeTree引擎三、MergeTree分区一、Clickhouse表引擎概述 MergeTree表引擎:允许根据日期和主键创建索引 1、ReplacingMerge…

实践IC-GVINS: 以惯导为核心的GNSS-Visual-INS组合导航系统

视觉导航系统对环境比较敏感&#xff0c;受到光照变化、重复纹理、动态物体等影响&#xff1b;惯性导航系统(INS)则完全自主工作&#xff0c;不受外部环境影响&#xff0c;能够实现连续、高频的自主导航&#xff0c;但其误差发散较快。两者组合能够取长补短&#xff0c;形成视觉…

毕业设计 基于STM32单片机生理监控心率脉搏TFT彩屏波形曲线设计

基于STM32单片机生理监控心率脉搏TFT彩屏波形曲线设计1、项目简介1.1 系统构成1.2 系统功能2、部分电路设计2.1 STM32F103C8T6核心系统电路设计2.2心率检测电路设计2.3 TFT2.4寸彩屏电路设计3、部分代码展示3.1 ADC初始化3.2 获取ADC采样值3.3 LCD引脚初始化3.3 在LCD指定位置显…

15 Nacos客户端实例注册源码分析

Nacos客户端实例注册源码分析 实例客户端注册入口 流程图&#xff1a; 实际上我们在真实的生产环境中&#xff0c;我们要让某一个服务注册到Nacos中&#xff0c;我们首先要引入一个依赖&#xff1a; <dependency><groupId>com.alibaba.cloud</groupId>&l…

Android与flutter混合开发

这里我使用的android studio版本是2020.3.1&#xff1b;flutter版本2.5.3。此前在网上搜索的很多教教程版本都不一样&#xff0c;新版的IDE和SDK让我遇到了很多坑故这里整理一下。一、创建项目1.在Android项目中点击File->New->New Flutter Project。File->New->Ne…

认识STM32和如何构建STM32工程

STM32介绍什么是单片机单片机(Single-Chip Microcomputer)是一种集成电路芯片&#xff0c;把具有数据处理能力的中央处理器CPU、随机存储器RAM、只读存储器ROM、多种/0口和中断系统、定时器/计数器等功能(可能还包括显示驱动电路、脉宽调制电路、模拟多路转换器、A/D转换器等电…

快速找到外贸客户的9种方法(建议收藏)

所有外贸企业想要做好外贸出口的头等大事&#xff0c;就是要快速的找到优质的外贸客户和订单&#xff0c;没有订单的达成&#xff0c;所有的努力都是图劳&#xff0c;还有可能会陷入一种虚假的繁荣&#xff0c;每天都很忙&#xff0c;但是没有结果。今天&#xff0c;小编就来分…

在VScode中添加Linux中的Docker容器中的Python解释器

VScode编辑器在安装好Python插件之后会自动选择环境变量中排序最高的那一个解释器作为默认解释器&#xff0c;而想要额外添加新的Python解释器就需要自己设置。 VScode编辑器安装在本地电脑 支持Python的docker安装在远程服务器 第一步&#xff0c;在/usr/local/下新建pytho…

Telnet 基础实验1: Telnet 实验

Telnet 基础实验1&#xff1a; Telnet 实验 拓扑图 配置命令 R1 的配置 undo ter mo sys sys R1 interface g0/0/0 ip address 192.168.1.1 255.255.255.0 qR2 的配置 undo ter mo system-view sysname R2 interface g0/0/0 ip address 192.168.1.2 255.255.255.0 q两台设…

微信小程序和webview使用postMessage交互

小程序和webview能交互&#xff0c;但是没有你想的那个完美小程序向webview传递参数只能使用url携带参数webview向小程序传递参数可以使用postMessage, 但是注意了&#xff0c;postMessage只会在特定的时机执行&#xff0c;请看官方文档由此可见&#xff0c;如果你想点击webvie…