Informer 论文学习笔记

news2024/9/22 11:32:19

论文:《Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting》
代码:https://github.com/zhouhaoyi/Informer2020
地址:https://arxiv.org/abs/2012.07436v3
特点

  1. 实现时间与空间复杂度为 O ( L ln ⁡ L ) \mathcal{O}(L\ln L) O(LlnL) 的自注意力;
  2. 使用自注意力提纯(Distilling)的方法,降低了特征的冗余;
  3. 以生成式的风格一次性输出长序列预测结果,杜绝了 One-by-One 方式中存在的误差积累;
  4. 基于上面的内容,创建新的 LSTF 模型 Informer。

核心贡献

  1. 用新的自注意力模块 ProbSparse Self-Attention 降低了原始 Self-Attention 的时间与空间复杂度;
  2. 提出 Self-Attention 净化(Distilling) 方法,进一步降低模型整体的复杂度;

Informer 模型的整体结构

在这里插入图片描述

ProbSparse Self-Attention

先介绍一下算法的整体流程,后面再介绍具体含义和原因。

Require:Tensor Q ∈ R m × d , K ∈ R n × d , V ∈ R n × d \pmb{Q}\in\mathbb{R}^{m\times d},\pmb{K}\in\mathbb{R}^{n\times d},\pmb{V}\in\mathbb{R}^{n\times d} QRm×d,KRn×d,VRn×d

  1. print set hyperparameter c c c, u = c ln ⁡ m u=c\ln m u=clnm and U = m ln ⁡ n U=m\ln n U=mlnn
  2. randomly select U U U dot-product pairs from K \pmb{K} K to K ˉ \bar{\pmb{K}} Kˉ
  3. set the sample score S ˉ = Q K ˉ T \bar{\pmb{S}}=\pmb{Q}\bar{\pmb{K}}^T Sˉ=QKˉT
  4. compute the measurement M = max ⁡ ( S ˉ ) − mean ( S ˉ ) M=\max(\bar{\pmb{S}})-\text{mean}(\bar{\pmb{S}}) M=max(Sˉ)mean(Sˉ) by row
  5. set Top- u \text{Top-}u Top-u queries under M M M as Q ˉ \bar{\pmb{Q}} Qˉ
  6. set S 1 = softmax ( Q ˉ K T / d ) ⋅ V \pmb{S}_1=\text{softmax}(\bar{\pmb{Q}}\pmb{K}^T/\sqrt{d})\cdot \pmb{V} S1=softmax(QˉKT/d )V
  7. set S 0 = mean ( V ) \pmb{S}_0=\text{mean}(\pmb{V}) S0=mean(V)
  8. set S = { S 1 , S 0 } \pmb{S}=\{\pmb{S}_1,\pmb{S}_0\} S={S1,S0} by their original rows accordingly

Ensure:self-attention feature map S \pmb{S} S

ProbSparse Self-Attention 的基本思想

利用原始 Self-Attention 中的稀疏性,降低算法的时间与空间复杂度。
核心方法利用下式选出对 value 更有价值的 query

M ˉ ( q i , K ) = max ⁡ j { q i k j T d } − 1 L K Σ j = 1 L K q i k j T d \bar{M}(\pmb{q}_i,\pmb{K})=\max_{j}\{\frac{\pmb{q}_i\pmb{k}_j^T}{\sqrt{d}}\}-\frac{1}{L_K}\Sigma^{L_K}_{j=1}\frac{\pmb{q}_i\pmb{k}_j^T}{\sqrt{d}} Mˉ(qi,K)=jmax{d qikjT}LK1Σj=1LKd qikjT

即算法中的 3 与 4。

为什么用这种方法?
原始 Self-Attention softmax ( Q K T / d ) ⋅ V \text{softmax}(\pmb{Q}\pmb{K}^T/\sqrt{d})\cdot \pmb{V} softmax(QKT/d )V 可改写为下面的概率形式:
A ( q i , K , V ) = Σ j k ( q i , k j ) Σ l k ( q i , k l ) v j = E p ( k j ∣ q i ) [ v j ] \mathcal{A}(\pmb{q}_i,\pmb{K},\pmb{V})=\Sigma_j\frac{k(\pmb{q}_i,\pmb{k}_j)}{\Sigma_l k(\pmb{q}_i,\pmb{k}_l)}\pmb{v}_j=\mathbb{E}_{p(\pmb{k}_j|\pmb{q}_i)}[\pmb{v}_j] A(qi,K,V)=ΣjΣlk(qi,kl)k(qi,kj)vj=Ep(kjqi)[vj]

k ( ⋅ , ⋅ ) k(\cdot,\cdot) k(,) 的含义不再赘述。

为度量 query 的稀疏性,可以考虑 p ( k j ∣ q i ) p(\pmb{k}_j|\pmb{q}_i) p(kjqi) 与均匀分布 q ( k j ∣ q i ) = 1 / L K q(\pmb{k}_j|\pmb{q}_i)=1/L_K q(kjqi)=1/LK`之间的 KL 散度 K L ( q ∣ ∣ p ) = − Σ 1 L K ln ⁡ ( k ( q i , k j ) Σ l k ( q i , k l ) L K ) KL(q||p)=-\Sigma\frac{1}{L_K}\ln(\frac{k(\pmb{q}_i,\pmb{k}_j)}{\Sigma_l k(\pmb{q}_i,\pmb{k}_l)}L_K) KL(q∣∣p)=ΣLK1ln(Σlk(qi,kl)k(qi,kj)LK),展开并舍弃常数项之后可得第 i 个 query 的稀疏性度量为:
M ( q i , K ) = ln ⁡ Σ j = 1 L K e q i k j T d − 1 L K Σ j = 1 L K q i k j T d M(\pmb{q}_i,\pmb{K})=\ln\Sigma^{L_K}_{j=1}e^{\frac{\pmb{q}_i\pmb{k}^T_j}{\sqrt{d}}}-\frac{1}{L_K}\Sigma^{L_K}_{j=1}\frac{\pmb{q}_i\pmb{k}^T_j}{\sqrt{d}} M(qi,K)=lnΣj=1LKed qikjTLK1Σj=1LKd qikjT

基于 M,可以选用 Top-u 的 queries 构成的 Q ˉ \bar{\pmb{Q}} Qˉ 代替 Q 计算自注意力(文中设置 u = c ln ⁡ L Q u=c\ln L_Q u=clnLQ,其中 c 是超参数)。

为什么要使用这两个分布的 KL 散度?为什么M可以度量注意力的稀疏性?:Self-Attention 涉及到了点积运算,该运算表明 p ( k j ∣ q i ) p(\pmb{k}_j|\pmb{q}_i) p(kjqi) 与均匀分布 q ( k j ∣ q i ) = 1 / L K q(\pmb{k}_j|\pmb{q}_i)=1/L_K q(kjqi)=1/LK 之间的差别越大越好,这启发我们使用 M 作为稀疏性的度量。
新问题:M 中的第一项实际计算时的复杂度仍旧是 O ( L 2 ) \mathcal{O}(L^2) O(L2) 的。
解决方式:基于 Lemma 1 与 Proposition 1,先随机采样 U = L K ln ⁡ L Q U=L_K\ln L_Q U=LKlnLQ 个 k-q 对,然后在这 U 个 k-q 对上计算 M ˉ = max ⁡ j { q i k j T d } − mean j { q i k j T d } \bar{M}=\max_{j}\{\frac{\pmb{q}_i\pmb{k}^T_j}{\sqrt{d}}\}-\text{mean}_{j}\{\frac{\pmb{q}_i\pmb{k}^T_j}{\sqrt{d}}\} Mˉ=maxj{d qikjT}meanj{d qikjT} 作为 M 的近似值,最后选定 top-u 个 query 用作 Self-Attention 计算。(即算法中的 1、2、5 和 6,这里两次降低计算量)

补充

  • Lemma 1For each query q i ∈ R d \pmb{q}_i\in\mathbb{R}^d qiRd and k j ∈ R d \pmb{k}_j\in\mathbb{R}^d kjRd in the keys set K \pmb{K} K, we have the bound as ln ⁡ L K ≤ M ( q i , K ) ≤ ln ⁡ L K + M ˉ ( q i , K ) \ln L_K\leq M(\pmb{q}_i,\pmb{K})\leq\ln L_K +\bar{M}(\pmb{q}_i,\pmb{K}) lnLKM(qi,K)lnLK+Mˉ(qi,K). When q i ∈ K \pmb{q}_i\in\pmb{K} qiK, it also holds.(它说明可以用 M ˉ \bar{M} Mˉ 做近似计算。利用凸函数证明)
  • Proposition 1: Assuming k j ∼ N ( μ , Σ ) \pmb{k}_j\sim\mathcal{N}(\mu,\Sigma) kjN(μ,Σ) and we let q k i \pmb{q}\pmb{k}_i qki denote set { ( q i k j T ) / d ∣ j = 1 , ⋯   , L K } \{(\pmb{q}_i\pmb{k}_j^T)/\sqrt{d}|j=1,\cdots,L_K\} {(qikjT)/d j=1,,LK}, then ∀ M m = max ⁡ i M ( q i , K ) \forall M_m=\max_i M(\pmb{q}_i,\pmb{K}) Mm=maxiM(qi,K) there exist κ > 0 \kappa>0 κ>0 such that: in the interval ∀ q 1 , q 2 ∈ { q ∣ M ( q , K ) ∈ [ M m , M m − κ ) } \forall\pmb{q}_1,\pmb{q}_2\in\{\pmb{q}|M(\pmb{q},\pmb{K})\in[M_m,M_m-\kappa)\} q1,q2{qM(q,K)[Mm,Mmκ)}, if M ˉ ( q 1 , K ) > M ˉ ( q 2 , K ) \bar{M}(\pmb{q}_1,\pmb{K})>\bar{M}(\pmb{q}_2,\pmb{K}) Mˉ(q1,K)>Mˉ(q2,K) and Var ( q k 1 ) > Var ( q k 2 ) \text{Var}(\pmb{q}\pmb{k}_1)>\text{Var}(\pmb{q}\pmb{k}_2) Var(qk1)>Var(qk2), we have high probability that M ( q 1 , K ) > M ( q 2 , K ) M(\pmb{q}_1,\pmb{K})>M(\pmb{q}_2,\pmb{K}) M(q1,K)>M(q2,K).(采样后不影响排序,这说明采样之后仍旧可以保证 Top-u 的可靠性。利用对数正态分布及数值化样例定性式证明)

Self-Attention Distilling

目的:在自注意力模块之后,过滤掉 value 中的冗余信息。
方式:使用 CNN、MaxPooling 进行下采样:

\pmb{X}^t_{j+1}=\text{MaxPool}(\text{ELU}(\text{Conv1d}([\pmb{X}^t_j]_{AB})))

其中,CNN 的 kernel-size=3,pooling 的 stride=2,整体的空间复杂度为: O ( ( 2 − ϵ ) L log ⁡ L ) \mathcal{O}((2-\epsilon)L\log L) O((2ϵ)LlogL) ϵ \epsilon ϵ 是一个小量(原因是: 1 + 1 2 + 1 4 + 1 8 + ⋯ 1+\frac{1}{2}+\frac{1}{4}+\frac{1}{8}+\cdots 1+21+41+81+)。


其他

  1. Decoder:与原始 Transformer 的一致;
  2. 生成式推断(Generative Inference):一次性输出长序列预测结果,而非迭代地逐个输出结果。
  3. Loss Function:MSE
  4. 位置嵌入(Position Embedding):局部时间戳的位置嵌入(PE,使用sin函数)、全局时间戳的位置嵌入(SE,用于日月周节日等特殊时间点) PE ( L x × ( t − 1 ) + i , ) + Σ [ SE ( L x × ( t − 1 ) + i ) ] p \text{PE}_{(L_x\times(t-1)+i,)}+\Sigma[\text{SE}_{(L_x\times(t-1)+i)}]_p PE(Lx×(t1)+i,)+Σ[SE(Lx×(t1)+i)]p
    # PE
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    # SE
    minute_x  = nn.Embedding( 4, d_model)(x[:,:,4])
    hour_x    = nn.Embedding(24, d_model)(x[:,:,3])
    weekday_x = nn.Embedding( 7, d_model)(x[:,:,2])
    day_x     = nn.Embedding(32, d_model)(x[:,:,1])
    month_x   = nn.Embedding(13, d_model)(x[:,:,0])
    se = hour_x + weekday_x + day_x + month_x + minute_x
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/811680.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LaTex4【下载模板、引入文献】

下载latex模板:(模板官网一般都有,去找) 我这随便找了一个: 下载得到一个压缩包,然后用overleaf打开👇: (然后改里面的内容就好啦) 另外,有很多在线的数学公式编辑器&am…

1 Python的前世今生

为什么要学Python 这个问题,仁者见仁,智者见智。编程界有一句名言:“人生苦短,我用Python”,这句话似乎道出了一些原因。Python是一门简单直观的语言,更是一门注重可读性和效率的语言。解决同一个问题&…

数据结构空间复杂度

数据结构空间复杂度 空间复杂度常见的复杂度对比 空间复杂度 空间复杂度也是一个数学表达式,是对一个算法在运行过程中临时额外占用存储空间大小的量度 。 空间复杂度不是程序占用了多少bytes的空间,因为这个也没太大意义,所以空间复杂度算的…

jmeter常用的性能测试监听器

jmeter中提供了很多性能数据的监听器,我们通过监听器可以来分析性能瓶颈 本文以500线程的阶梯加压测试结果来描述图表。 常用监听器 1:Transactions per Second 监听动态TPS,用来分析吞吐量。其中横坐标是运行时间,纵坐标是TPS…

【后端面经】微服务构架 (1-6) | 隔离:如何确保心悦会员体验无忧?唱响隔离的鸣奏曲!

文章目录 一、前置知识1、什么是隔离?2、为什么要隔离?3、怎么进行隔离?A) 机房隔离B) 实例隔离C) 分组隔离D) 连接池隔离 与 线程池隔离E) 信号量隔离F) 第三方依赖隔离二、面试环节1、面试准备2、基本思路3、亮点方案A) 慢任务隔离B) 制作库与线上库分离三、章节总结 …

【Linux】TCP协议

​🌠 作者:阿亮joy. 🎆专栏:《学会Linux》 🎇 座右铭:每个优秀的人都有一段沉默的时光,那段时光是付出了很多努力却得不到结果的日子,我们把它叫做扎根 目录 👉TCP协议&…

java设计模式-建造者(Builder)设计模式

介绍 Java的建造者(Builder)设计模式可以将产品的内部表现和产品的构建过程分离开来,这样使用同一个构建过程来构建不同内部表现的产品。 建造者设计模式涉及如下角色: 产品(Product)角色:被…

通过clone的方式,下载huggingface中的大模型(git lfs install)

1、如图:可以手动一个个文件下载,但是那样太慢了,此时,可以点击下图圈起来的地方。 2、点击【Clone repository】,在命令行中,输入【git lfs install】(安装了这个,才会下载大文件&a…

【Git】git企业开发命令整理,以及注意点

1.git企业开发过程 业务的分支大概有以下几个: master:代码随时可能上线 develop:代码最新 feature/xxx:实际业务开发分支 release/xxx:预发布分支 fix:修复bug分支 过程大概是这样的: 首…

机器学习知识经验分享之六:决策树

python语言用于深度学习较为广泛,R语言用于机器学习领域中的数据预测和数据处理算法较多,后续将更多分享机器学习数据预测相关知识的分享,有需要的朋友可持续关注,有疑问可以关注后私信留言。 目录 一、R语言介绍 二、R语言安装…

【1.3】Java微服务:Spring Cloud版本说明

✅作者简介:大家好,我是 Meteors., 向往着更加简洁高效的代码写法与编程方式,持续分享Java技术内容。 🍎个人主页:Meteors.的博客 💞当前专栏: 微服务 ✨特色专栏: 知识分享 &#x…

python实现递推算法解决分鱼问题

一、问题描述 A、B、C、D、E5个人合伙夜间捕鱼,凌晨时都已经疲惫不堪,于是各自在河边的树丛中找地方睡着了。第二天日上三竿时,A第一个醒来,他将鱼平分为5份,把多余的一条扔回河中,然后拿着自己的一份回家…

如何快速同步第三方平台数据?

前言 最近知识星球中有位小伙伴问了我一个问题:如何快速同步第三方平台数据? 他们有个业务需求是:需要同步全国34个省市,多个系统的8种业务数据,到他们公司的系统当中。 他们需求同步全量的数据和增量的数据。 全量…

Ray

public Ray(Vector3 origin, Vector3 direction); 射线:origin为起始点,direction为射线方向 public static bool Raycast(Ray ray); 物理射线监测:返回值为bool型,可以确定射线有无碰撞到碰撞体 public static bool Raycast(R…

isp调试工具环境搭建及其介绍!

一、isp调试环境搭建: 后期调试isp,是在rv1126提供的RKISP2.x Tuner工具上进行调试,所以我们大前提必须要把这个环境和一些操作先搞熟悉来,后面有一些专用术语,我们遇到了再去看,现在专门看一些专用术语&am…

Spring Cloud简单记录

1. Spring Cloud是什么 工作这么多年,哈哈。。。没深入理解spring,spring cloud也是没有用过。趁着周末,搞一搞概念,先搞清楚是什么,虽然是什么只有用过之后才能理解的更具体,但是还是需要先整体的熟悉一下…

联想拯救者如何开启独显直连

不同机型有不同的切换方式,下面就分别给大家讲一下: 显卡模式切换方式一: 打开联想电脑管家,选择游戏模式,在左侧菜单栏选择显卡模式,然后就能看到显卡的输出模式了,默认是混合模式&#xff0c…

React之组件的生命周期

React之组件的生命周期 一、概述二、整体说明三、挂载阶段四、更新阶段五、卸载阶段 一、概述 生命周期:一个事务从创建到最后消亡经历的整个过程组件的生命周期:组件从被创建到挂载到页面中运行,再到组件不用时卸载的过程意义:理解组件的生…

RT1052的定时器

文章目录 1 通用定时器1.1 定时器框图1.2 实现周期性中断 2 相关寄存器3 定时器配置3.1 时钟使能3.2 初始化GPT1定时器3.2.1 base3.2.2 initConfig3.2.2.1 clockSorce3.2.2.2 divider3.2.2.3 enablexxxxx 3.3 设置 GPT1 比较值3.3.1 base3.3.2 channel3.3.3 value 3.4 设置 GPT…

10-矩阵(matrix)_方阵_对称阵_单位阵_对角阵

矩阵及其运算 [ a 11 ⋯ a 1 n ⋯ ⋯ ⋯ a m 1 ⋯ a m n ] \begin{bmatrix} a_{11} & \cdots & a_{1n} \\ \cdots & \cdots & \cdots \\ a_{m1} & \cdots & a_{mn} \\ \end{bmatrix} ​a11​⋯am1​​⋯⋯⋯​a1n​⋯amn​​ ​ 矩阵就是二维数组&…