稀疏注意力:时间序列预测的局部性和Transformer的存储瓶颈

news2025/1/14 1:16:12

        时间序列预测是许多领域的重要问题,包括对太阳能发电厂发电量、电力消耗和交通拥堵情况的预测。在本文中,提出用Transformer来解决这类预测问题。虽然在我们的初步研究中对其性能印象深刻,但发现了它的两个主要缺点:(1)位置不可知性:规范Transformer架构中的点积自关注对局部上下文不敏感,这可能使模型在时间序列中容易出现异常;(2)内存瓶颈:正则Transformer的空间复杂度随序列长度L呈二次增长,使得直接建模长时间序列变得不可行。

        为了解决这两个问题,首先提出了卷积自注意力,通过使用因果卷积产生查询和键,以便更好地将本地上下文纳入注意机制。然后,提出了仅O(L(log L)^2)内存开销的LogSparse Transformer,在内存预算受限的情况下,提高了对细粒度、长期依赖性强的时间序列的预测精度。在合成数据和真实世界数据集上的实验表明,它比最先进的技术更有优势。        

1. 引言

        深度神经网络被提出作为另一种解决方案,其中递归神经网络(RNN)已被用于以自回归的方式对时间序列建模。然而,众所周知,RNN很难训练。由于梯度消失和爆炸问题。尽管出现了各种变体,包括LSTM和GRU(门控循环单元),问题仍然没有解决。如何对长期依赖关系进行建模成为实现良好性能的关键步骤。

        规范Transformer的空间复杂度随输入长度L呈二次增长,这对直接建模细粒度长时间序列造成了内存瓶颈。 主要贡献:①成功地将Transformer架构应用于时间序列预测,并在合成数据集和真实数据集上进行了广泛的实验,以验证Transformer在处理长期依赖关系方面比基于RNN的模型更好的潜在价值。

        ②提出卷积自注意,通过使用因果卷积在自注意层产生查询和键。感知局部上下文(如形状)的查询键匹配可以帮助模型降低训练损失,进一步提高预测精度。

        ③提出LogSparse Transformer,只有O(L(log L)^2)空间复杂度来打破内存瓶颈,不仅使细粒度的长时间序列建模可行,而且与规范Transformer相比,使用更少的内存可以产生相当甚至更好的结果。

2. 相关工作

        时间序列预测领域中不同方法的发展和挑战,并强调了几种主要的模型。首先,文章提到ARIMA模型,这是时间序列预测中非常著名的一种方法。ARIMA模型因其统计性质和Box-Jenkins方法论而备受推崇,后者是一种在模型选择过程中广泛使用的方法。因此,ARIMA模型通常是实践者在时间序列预测中首先尝试的工具。然而,ARIMA模型有一些局限性。它假设时间序列是线性的,这在处理更复杂的、非线性的时间序列时可能表现不佳。此外,ARIMA模型的扩展性有限,难以应用于大规模的预测任务,并且每个时间序列都必须独立拟合,这意味着无法在相似的时间序列之间共享信息。

        相反,有些方法尝试通过矩阵分解的方法处理相关时间序列数据,把预测问题看作矩阵分解问题。另外,还有研究提出了分层贝叶斯方法,从图模型的角度来学习多个相关的计数时间序列。接着,文章介绍了深度神经网络在时间序列预测中的应用。这些模型可以捕捉相关时间序列之间的共享信息,从而提高预测的准确性。比如,有研究将传统的自回归(AR)模型与递归神经网络(RNN)结合起来,采用编码器-解码器的方式对概率分布进行建模。另一种方法使用RNN作为编码器,使用多层感知机(MLP)作为解码器,以解决误差累积问题,并且能够进行多步并行预测。此外,还有模型使用全局RNN来直接输出线性状态空间模型(SSM)的参数,目的是用局部线性片段来近似非线性动态。也有研究通过使用局部高斯过程来处理每个时间序列中的噪声,同时使用全局RNN来建模共享模式。另一些方法则试图结合AR模型和SSM的优势,保持复杂的潜在过程以进行多步并行预测。

        Transformer在序列建模中取得了很大成功,并且已被广泛应用于翻译、语音、音乐和图像生成等领域。然而,当处理极长的序列时,注意力机制的计算复杂度会随着序列长度的增加而呈二次方增长,这在处理高粒度且具有强长期依赖性的时间序列时,成为了一个严重的问题。

3. 背景

3.1 问题定义

        首先定义了一个时间序列预测的问题。在这个问题中,我们有一个包含N个相关单变量时间序列的集合,每个时间序列记为 z_{i,1:t_0},表示从时间1到时间 t_0​ 的观测值。目标是预测这些时间序列未来的 \tau 个时间步的值,即 z_{i,t_0+1:t_0+\tau}​。此外,假设有一个与时间相关的协变量集合 x_{i,1:t_0+\tau},其维度为d,这些协变量可能包括诸如星期几、一天中的小时等已知信息。我们需要建模条件分布 p(z_{i,t_0+1:t_0+\tau}|z_{i,1:t_0},x_{i,1:t_0+\tau};\omega),其中 \omega 是所有时间序列共享的可学习参数。

        接着,问题被简化为学习一个一步预测模型,即 p(z_t | z_{1:t-1}, x_{1:t}; \omega),其中 \omega 表示模型的可学习参数。为了充分利用观测值和协变量,作者将它们连接起来,形成一个扩展矩阵 y_t = [z_{t-1}, x_t],然后通过 Y_t = [y_1, \cdots, y_t]^T 表示所有的观测数据和协变量的集合。接下来,研究探索了一个合适的模型 z_t \sim f(Y_t),用于预测给定 Y_t​ 时 z_t​ 的分布。

        然后,文章介绍了Transformer模型,并提出将其作为函数 f 的实例,因为Transformer通过多头自注意力机制能够捕捉到时间序列中的长短期依赖性。不同的注意力头可以专注于不同的时间模式,这使得Transformer在时间序列预测中成为一个很有潜力的候选模型。

        在自注意力层中,多头自注意力子层同时将 Y 转换为H个不同的查询矩阵 Q_h、键矩阵 K_h​ 和值矩阵 V_h​,其中 h = 1, \cdots, H。这些矩阵通过线性投影获得,它们的学习参数分别为 W_Q^h​、 W_K^h​ 和 W_V^h​。在这些线性投影之后,缩放点积注意力机制计算出一系列的向量输出 O_h​,这些输出是通过公式 O_h=\mathrm{Attention}(Q_h,K_h,V_h)=\mathrm{softmax}\left(\frac{Q_hK_h^T}{\sqrt{d_k}}\cdot M\right)V_h 计算得到的。这里,掩码矩阵 MMM 被应用于过滤右侧的注意力,以避免未来信息泄露。然后,所有的 O_h 被连接起来并再次进行线性投影。最后,在注意力输出上叠加了一个由两层全连接网络和中间ReLU激活层组成的位置前馈子层。

4. 方法论

4.1 增强Transformer的局部性

        时间序列中的模式可能由于各种事件(如假期和极端天气)随时间发生显著变化的现象。因此,判断一个观测点是异常点、变更点还是模式的一部分,很大程度上依赖于其周围的上下文。然而,在经典Transformer的自注意力层中,查询和键之间的相似性是基于它们逐点值来计算的,未能充分利用局部上下文信息(如形状)。这种对局部上下文不敏感的查询-键匹配可能会导致自注意力模块混淆观测值的性质,从而引发潜在的优化问题。

        为了解决这个问题,提出了卷积自注意力机制。图1(c)和(d)展示了这种卷积自注意力的架构。不同于使用核大小为1且步幅为1的卷积(即矩阵乘法),采用核大小为k且步幅为1的因果卷积,将输入(经过适当的填充)转换为查询和键。因果卷积确保当前位置不会访问未来信息。通过使用因果卷积,生成的查询和键能够更加感知局部上下文,从而基于局部上下文信息(如局部形状)来计算相似性,而不是简单的逐点取值,这有助于提高预测的准确性。值得注意的是,当k=1时,卷积自注意力将退化为经典自注意力,因此它可以看作经典自注意力的一种广义形式。

        因果卷积:具体来说,假设输入序列的长度为 T,卷积核的大小为 k。在因果卷积中,输入会在前面添加 k-1个零填充,这样卷积运算就只会考虑当前和之前的时间步,而不会涉及未来的时间步。这种方式保证了模型在训练和推理时遵循时间顺序,从而保持因果性。 

        图1展示了经典自注意力层和卷积自注意力层的比较。图1(a)显示了经典自注意力可能错误地逐点匹配输入的情况,图1(b)则展示了经典自注意力在Transformer中的应用。而图1(c)和(d)则展示了卷积自注意力如何通过形状匹配来正确匹配最相关的特征。

4.2 突破Transformer的内存瓶颈

        首先对经典Transformer在traffic-f数据集上学习到的注意力模式进行了定性评估。traffic-f数据集包含旧金山湾区963条车道的占用率数据,每20分钟记录一次。在traffic-f数据集上训练了一个10层的经典Transformer,并对学习到的注意力模式进行了可视化。引入某种形式的稀疏性而不会显著影响性能。更重要的是,对于长度为 L 的序列,计算每对单元之间的注意力分数会导致 O(L^2) 的内存使用量,使得对具有精细粒度和强长期依赖的长时间序列进行建模变得非常困难。

        为了解决这个问题,提出了LogSparse Transformer,这种方法只需要计算每个单元在每层中的 O(\log L) 个点积。此外,只需要堆叠最多O(\log L)层,模型就能够访问每个单元的信息。因此,总的内存使用成本仅为 O(L(\log L)^2)。我们将 I_k^l​ 定义为在第 k 层到第 k+1 层计算过程中,单元 l 表示可以访问的单元索引集。在标准的Transformer自注意力中, I_k^l = \{j : j \leq l\},这意味着每个单元都可以访问其所有过去的单元及其自身,如图3(a)所示。

        然而,这种算法在输入长度增加时会导致空间复杂度的二次增长。为了解决这个问题,提出选择 I_k^l 的一个子集 I_k^l \subseteq \{j : j \leq l\},使得 \|I_k^l\| 不会随着 l 的增加而增长得太快。选择索引的一个有效方法是 |I_k^l| \propto \log L

        需要注意的是,单元 l 是在第 k 层中通过加权组合索引为 I_k^l​ 的单元生成的,并且可以将这些信息传递给下一层的后续单元。令 S_k^l 为包含所有到第 k 层为止传递给单元 l 的单元索引的集合。为了确保每个单元接收到所有之前的单元及其自身的信息,堆叠的层数 \tilde{k}_l 应满足 S_{\tilde{k}_l}^l = \{j : j \leq l\},即对于每个 lj \leq l,存在一个具有 \tilde{k}_l 条边的路径 P_{jl} = (j, p_1, p_2, \dots, l),其中 j \in I_1^{p_1}, p_1 \in I_2^{p_2}, \dots, p_{\tilde{k}_l-1} \in I_{\tilde{k}_l}^l

        通过允许每个单元仅以指数步长访问其之前的单元及其自身来提出LogSparse自注意力。即对于所有的 k 和 lI_k^l = \{l \% 2^{\lfloor \log_2 l \rfloor}, l \% 2^{\lfloor \log_2 l \rfloor - 1}, \dots, l \% 2^0, l\},其中 \lfloor \cdot \rfloor 表示向下取整运算,如图3(b)所示。

        定理1表明,尽管每层的内存使用量从 O(L^2) 减少到 O(L\log^2 L),但信息仍然可以从任意单元流向另一个单元,只需稍微“加深”模型——将层数设为 \lfloor \log_2 L \rfloor+1。这意味着总体内存使用量为 O(L(\log^2 L)^2),解决了Transformer在GPU内存限制下的扩展性瓶颈。此外,随着两个单元之间的距离增大,路径的数量会以 log_2(l - j) 的超指数速率增加,这表明LogSparse Transformer在建模精细的长期依赖关系时能够实现丰富的信息流动。

4.3 Logparase注意力

        LogSparse注意力是一种针对长序列时间复杂度和内存使用量进行优化的自注意力机制。它是对经典Transformer模型中自注意力机制的一种改进,旨在解决处理长序列时计算资源消耗过大的问题。

        在传统的Transformer中,自注意力机制的计算复杂度是二次方的,即对于长度为 L 的序列,每个元素需要与其他 L-1个元素进行相互计算,导致整个序列的计算量和内存使用量为 O(L^2)。这种复杂度在处理长序列时会变得非常昂贵,尤其是在需要处理大量数据的情况下。

LogSparse注意力 通过以下方式优化了这一过程:

  1. 选择性注意力(Selective Attention):LogSparse注意力并不计算每个序列元素与所有其他元素之间的注意力得分,而是引入了一种稀疏化策略。具体来说,它允许每个元素只关注一小部分与其相关的元素,而不是所有元素。这些相关元素的选择是基于指数步长的,即每个元素只与其之前的少量元素进行注意力计算,这些元素之间的距离按对数规律增长。这意味着,如果当前元素是第 l 个,那么它只会与之前的一些元素进行计算,而这些元素的索引为 l,l-2^1, l-2^2,\dots,l-2^{\left \lfloor log_2l \right \rfloor} 等。

  2. 对数级别的复杂度(Logarithmic Complexity):这种选择性注意力策略将原本的 O(L^2) 复杂度降低到了 O(L \log^2 L)。因为每个元素只需计算 O(\log L) 个注意力分数,而整个序列需要堆叠 O(\log L) 层,以确保所有元素都能互相通信。

        通过这种方法,LogSparse注意力在处理长序列时能够显著减少内存使用和计算时间,同时保留Transformer模型的强大建模能力,特别是对于长时间依赖关系的建模非常有效。

4.3.1 增强模型性能 

如何在LogSparse自注意力机制的基础上进一步增强模型的性能,同时保持计算复杂度的控制?

  1. 局部注意力(Local Attention):在LogSparse Transformer中,虽然每个单元只需要访问先前的一些关键单元,但为了更好地捕捉局部信息(如趋势),可以让每个单元密集地关注其左侧邻近的单元,窗口大小为 O(\log^2 L)。这样,每个单元可以利用更多的局部信息来进行当前步的预测。在这种局部窗口之外,仍然可以继续采用LogSparse注意力策略,如图3(c)所示。

  2. 重启注意力(Restart Attention):这个策略是将整个输入序列(长度为 L)分成多个子序列,每个子序列的长度为 L_{sub},其中 L_{sub} \approx L。对每个子序列分别应用LogSparse注意力策略,类似于重新开始注意力计算的过程。这样可以减少模型处理每个子序列时的复杂性,并且每个子序列可以独立地进行信息处理,如图3(d)所示。

  3. 结合局部注意力和重启注意力:使用局部注意力和重启注意力不会改变LogSparse自注意力策略的计算复杂度,但会增加更多的信息路径,并减少路径中所需的边数。这意味着可以在不增加计算成本的情况下,提高模型对局部信息的捕捉能力和对长序列的处理能力。通过结合局部注意力和重启注意力,模型能够更有效地捕捉序列中的各种模式和趋势。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2044997.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++_2_ inline内联函数 宏函数(2/3)

C推出了inline关键字,其目的是为了替代C语言中的宏函数。 我们先来回顾宏函数: 宏函数 现有个需求:要求你写一个Add(x,y)的宏函数。 正确的写法有一种,错误的写法倒是五花八门,我们先来“见不贤而自省也。” // …

windows下部署redis3.2

一、下载redis3.2的包 6.2.6的包也有,但无法安装为Windows服务,暂时舍弃。 直接运行: redis-server redis.windows.conf 修改密码, 对应 redis.windows.conf 中的 requirepass 节点,注意去掉前面的# 修改端口,对应…

缺陷检测AI 重要参数解释

一、参数介绍 基本参数 True Positives (TP) True Positives (TP) 是一个用于评估模型性能的术语。它指的是模型正确预测为正例(Positive)的样本数量,即实际为正例且被正确分类为正例的样本数量。 False Positives (FP) FP (False Posit…

Python 文件目录操作,以及json.dump() 和 json.load()

import os 是用来引入 Python 标准库中的 os 模块的,这个模块提供了与操作系统交互的功能。这个模块常用于文件和目录操作,比如获取文件的目录路径、创建目录等。 如果你在代码中需要使用与操作系统相关的功能(例如获取目录名、检查文件是否…

qt-11基本对话框(消息框)

基本对话框--消息框 msgboxdlg.hmsgboxdlg.cppmain.cpp运行图QustionMsgInFormationMsgWarningMsgCriticalMsgAboutMsgAboutAtMsg自定义 msgboxdlg.h #ifndef MSGBOXDLG_H #define MSGBOXDLG_H#include <QDialog> #include <QLabel> #include <QPushButton>…

Cesium模型制作,解决Cesium加载glb/GLTF显示太黑不在中心等问题

Cesium模型制作&#xff0c;解决Cesium加载glb/GLTF显示太黑不在中心等问题 QQ可以联系这里&#xff0c;谢谢

电商搜索新纪元:大模型引领购物体验革新

随着电商行业的蓬勃发展&#xff0c;搜索技术作为连接用户与商品的桥梁&#xff0c;其重要性日益凸显。在技术不断革新的今天&#xff0c;电商搜索技术经历了哪些阶段&#xff1f;面对大模型的飞速发展&#xff0c;企业又将如何把握趋势&#xff0c;应对挑战&#xff1f;为了深…

openinstall支持抖音游戏小手柄监测,助力游戏联运生态高效增长

近来&#xff0c;抖音“小手柄”功能风靡游戏广告生态&#xff0c;通过新颖的联运形式成功将游戏广告触达到抖音整个流量池&#xff0c;由于受众较广&#xff0c;小手柄也是目前直播场数、点赞数最高的形式。 为了帮助广告主快速捕捉流量红利&#xff0c;打通抖音小手柄的数据…

【选型指南】大流量停车场和高端停车场如何选择停车场管理系统?

在当今快节奏的城市生活中&#xff0c;大型停车场和高端车场的运营者面临着一系列挑战&#xff0c;尤其是在车辆流量大和客户期望高的情况下。选择一个合适的停车场管理系统&#xff0c;不仅关系到日常运营的效率&#xff0c;更关系到客户的满意度和车场的整体形象。 捷顺科技认…

螺纹钢生产线中测径仪对基圆和负公差的测量和影响

螺纹钢生产线中测径仪的作用 在螺纹钢生产线中&#xff0c;测径仪是一种关键的检测设备&#xff0c;它负责对螺纹钢的基圆直径、横肋和纵肋等尺寸进行实时测量。测径仪的数据对于监控和控制螺纹钢的生产质量至关重要&#xff0c;尤其是在进行负公差轧制时&#xff0c;它能够提供…

K8S中使用英伟达GPU —— 筑梦之路

前提条件 根据不同的操作系统&#xff0c;安装好显卡驱动&#xff0c;并能正常识别出来显卡&#xff0c;比如如下截图&#xff1a; GPU容器创建流程 containerd --> containerd-shim--> nvidia-container-runtime --> nvidia-container-runtime-hook --> libnvid…

【Spring Boot - 注解】@ResponseBody 注解:处理 JSON 响应

文章目录 一、ResponseBody 注解概述1. 注解的功能2. 主要功能 二、ResponseBody 的工作原理1. 接口定义2. 消息转换器3. 自动配置与默认行为 三、ResponseBody 的应用场景1. RESTful API 的实现2. 返回复杂数据结构3. 错误处理和异常处理 四、ResponseBody 的配置和自定义1. 自…

Rabbit的学习——从安装到集群

一、MQ概念 1.1、异步通讯和同步通讯 1.2、同步调用和异步调用 1.2.1、同步调用 1.2.2、异步调用 1.3、消息队列的作用 1.3.1、流量削峰/限流 1.3.2、 应用解耦 1.3.3、异步处理 1.4、消息队列的两种模式 1.4.1、点对点模式 1.4.2、发布/订阅模式 二、RabbitMQ 2.1…

MyBatis Plus批量写入慢?

1. 数据库连接配置 在使用 MyBatis Plus 进行批量插入之前&#xff0c;首先需要配置数据库连接。在连接 URL 中添加 &rewriteBatchedStatementstrue&#xff0c;以提高批量插入的性能。以下是一个示例&#xff1a; spring.datasource.urljdbc:mysql://localhost:3306/your…

路径规划 | 基于改进蝙蝠算法的多无人机路径规划(Matlab)

目录 效果一览基本介绍程序设计参考文献 效果一览 基本介绍 路径规划 | 基于改进蝙蝠算法的多无人机路径规划&#xff08;Matlab&#xff09; 蝙蝠算法&#xff08;Bat Algorithm&#xff09;是一种基于自然界蝙蝠群体行为的启发式优化算法。该算法模拟了蝙蝠在寻找食物时的行为…

Linux 内核源码分析---内核ICMP协议分析

因特网控制报文协议ICMP&#xff08;Internet Control Message Protocol&#xff09; 是一个差错报告机制&#xff0c;是TCP/IP协议簇中的一个重要子协议&#xff0c;通常被IP层或更高层协议&#xff08;TCP或UDP&#xff09;使用&#xff0c;属于网络层协议&#xff0c;主要用…

论文阅读-Transformer Layers as Painters

1. 摘要 尽管大语言模型现在已经被广泛的应用于各种任务&#xff0c;但是目前对其并没有一个很好的认知。为了弄清楚删除和重组预训练模型不同层的影响&#xff0c;本文设计了一系列的实验。通过实验表明&#xff0c;预训练语言模型中的lower和final layers与中间层分布不一致…

四路一体行车记录仪,语音提示注意行人,保障车辆行驶安全

在叉车、货车、客车等行业中&#xff0c;随着运输业务量的不断增加&#xff0c;行车安全问题已经成为了一大难题。经常会发生车祸、司乘人身安全无保障、货物损失等意外情况&#xff0c;这些事件不仅会给企业带来经济损失&#xff0c;也会影响对应行业的整体形象。 如何提高运…

服装行业的利器:RFID智能吊挂分拣系统

服装行业的利器&#xff1a;RFID智能吊挂分拣系统 服装业继续走粗放型老路的利润空间越来越小&#xff0c;行业内过度竞争利润降低&#xff0c;原料价格上涨导致成本上升。企业内部生产技术创新不足、工厂生产效率低&#xff0c;导致产出不够、货期竞争乏力。企业为了盈利生存…

C++中STL的sring类常用接口及其源码解析

1. 为什么会有string类&#xff1f; C语言中的字符串 C语言中&#xff0c;字符串是以\0结尾的一些字符的集合&#xff0c;为了操作方便&#xff0c;C标准库中提供了一些str系列的库函数&#xff0c; 但是这些库函数与字符串是分离开的&#xff0c;不太符合OOP的思想&#xff0…