SAMformer:通过锐度感知最小化和通道注意力解锁变换器在时间序列预测中的潜力

news2024/11/20 15:34:51

目录

  • 摘要
  • 1. 引言
    • 当前方法的局限性
    • 变换器的可训练性
    • 我们贡献的总结
  • 2. 提出的方法
    • 符号说明
    • 2.1 问题设置
    • 2.2 激励示例
      • 命题2.1(最优解的存在性)
    • 2.3 变换器的损失景观
      • 现有的解决方案
    • 2.4. SAMformer:集成所有方法
  • 3. 实验
    • 3.1 主要收获

摘要

基于变换器的架构在自然语言处理和计算机视觉中取得了突破性的表现,然而在多变量长期预测中,它们仍然不如更简单的线性基线。为了更好地理解这一现象,我们从研究一个玩具线性预测问题开始,展示了尽管变换器具有高表达能力,但它们无法收敛到其真实解。我们进一步确定了变换器的注意力机制是导致这种低泛化能力的原因。在此见解的基础上,我们提出了一种轻量级的浅层变换器模型,该模型通过锐度感知优化成功地逃避了坏的局部极小值。我们通过实验证明,这一结果适用于所有常用的真实世界多变量时间序列数据集。特别地,SAMformer超越了当前的最新方法,并与最大的基础模型MOIRAI相当,同时具有显著更少的参数。代码可以在以下网址获得:https://github.com/romilbert/samformer。

1. 引言

多变量时间序列预测是一个经典的学习问题,包含了分析时间序列以根据历史信息预测未来趋势。特别是,由于特征相关性和时间序列中的长期时间依赖性,长期预测是公认的具有挑战性的任务。这一学习问题在那些观测值按顺序收集的实际应用中十分普遍,例如医疗数据、电力消耗、温度或股票价格。为了解决这个问题,已经开发了大量的方法,从经典的数学工具到统计方法(如ARIMA)再到更现代的深度学习方法,包括递归和卷积神经网络。

最近,变换器架构在自然语言处理和计算机视觉中变得无处不在,并在这两个领域取得了突破性的表现。变换器在处理序列数据方面特别高效,这一特性自然要求它们应用于时间序列。毫不奇怪,许多工作尝试提出特定的时间序列变换器架构,以利用它们捕捉时间相互作用的能力。然而,目前在多变量时间序列预测中的最新成果是通过一个更简单的基于MLP的模型实现的,这显著优于基于变换器的方法。此外,最近的研究发现线性网络在预测任务中可以与变换器相媲美甚至更好,这对它们的实际效用提出了质疑。这一有趣的发现成为我们工作的起点。

当前方法的局限性

最近将变换器应用于时间序列数据的研究主要集中在以下两个方面:(i) 提高实现效率,减少注意力的二次成本;(ii) 将时间序列分解以更好地捕捉其中的底层模式。令人惊讶的是,这些工作中没有一个特别解决了变换器在大规模数据缺失情况下的训练不稳定性这一众所周知的问题。

变换器的可训练性

在计算机视觉和自然语言处理中,已发现注意力矩阵可能因熵或秩崩溃而受到影响。随后,提出了几种方法来克服这些问题。然而,在时间序列预测的情况下,关于如何有效地训练变换器架构而不过拟合的开放问题仍然存在。我们旨在通过消除训练的不稳定性来展示变换器在多变量长期预测中的优越性,这与先前对其局限性的看法相反。

我们贡献的总结

我们的提议提出了以下贡献:

  1. 我们展示了即使变换器架构仅限于解决一个简单的玩具线性预测问题,它仍然具有良好的泛化能力并收敛到锋利的局部极小值。我们进一步确定注意力主要负责这种现象;
  2. 我们提出了一种浅层变换器模型,称为SAMformer,结合了研究社区提出的最佳实践,包括最近在计算机视觉领域引入的可逆实例归一化和通道注意力。我们展示了使用锐度感知最小化(SAM)优化这样一个简单的变换器,可以更好地泛化并收敛到局部极小值;
  3. 我们通过实验证明了我们的方法在常见多变量长期预测数据集上的优越性。SAMformer超越了当前最先进的方法,并且在参数显著减少的情况下与最大的基础模型MOIRAI相当。

2. 提出的方法

符号说明

我们用常规字母表示标量值(例如,参数 λ \lambda λ),用粗体小写字母表示向量(例如,向量 x \mathbf{x} x),用粗体大写字母表示矩阵(例如,矩阵 M \mathbf{M} M)。我们用 M ⊤ \mathbf{M}^\top M表示矩阵 M \mathbf{M} M的转置,同样适用于向量。矩阵 M \mathbf{M} M的秩表示为 rank ( M ) \text{rank}(\mathbf{M}) rank(M),其 Frobenius 范数表示为 ∥ M ∥ F \|\mathbf{M}\|_F MF。我们令 n ~ = min ⁡ { n , m } \tilde{n} = \min\{n, m\} n~=min{n,m},并用 ∥ M ∥ ∗ = ∑ i = 1 n ~ σ i ( M ) \|\mathbf{M}\|_* = \sum_{i=1}^{\tilde{n}} \sigma_i(\mathbf{M}) M=i=1n~σi(M)表示矩阵 M \mathbf{M} M的核范数,其中 σ i ( M ) \sigma_i(\mathbf{M}) σi(M)是其奇异值,用 ∥ M ∥ 2 = σ max ⁡ ( M ) \|\mathbf{M}\|_2 = \sigma_{\max}(\mathbf{M}) M2=σmax(M)表示其谱范数。大小为 n × n n \times n n×n的单位矩阵表示为 I n \mathbf{I}_n In。符号 M ⪰ 0 \mathbf{M} \succeq 0 M0表示 M \mathbf{M} M是半正定的。

2.1 问题设置

我们考虑多变量长期预测框架:给定长度为 L L L D D D维时间序列(回看窗口),排列成 R D × L \mathbb{R}^{D \times L} RD×L矩阵 X \mathbf{X} X以便于通道注意,我们的目标是预测其接下来的 H H H个值(预测时间),表示为 Y ∈ R D × H \mathbf{Y} \in \mathbb{R}^{D \times H} YRD×H。我们假设可以访问包含 N N N个观测值的训练集 { ( X ( i ) , Y ( i ) ) } i = 0 N \{(\mathbf{X}^{(i)}, \mathbf{Y}^{(i)})\}_{i=0}^N {(X(i),Y(i))}i=0N,并将第 i i i个输入(分别为目标)时间序列的第 d d d个特征表示为 X d ( i ) ∈ R 1 × L \mathbf{X}_d^{(i)} \in \mathbb{R}^{1 \times L} Xd(i)R1×L(分别为 Y d ( i ) ∈ R 1 × H \mathbf{Y}_d^{(i)} \in \mathbb{R}^{1 \times H} Yd(i)R1×H)。我们训练一个预测器 f ω : R D × L → R D × H f_\omega : \mathbb{R}^{D \times L} \to \mathbb{R}^{D \times H} fω:RD×LRD×H,其参数为 ω \omega ω,使得在训练集上均方误差(MSE)最小化:

L train ( ω ) = 1 N D ∑ i = 0 N ∥ Y ( i ) − f ω ( X ( i ) ) ∥ F 2 L_{\text{train}}(\omega) = \frac{1}{ND} \sum_{i=0}^{N} \|\mathbf{Y}^{(i)} - f_\omega(\mathbf{X}^{(i)})\|_F^2 Ltrain(ω)=ND1i=0NY(i)fω(X(i))F2.

详细解释

  1. 多变量时间序列:我们考虑的是一个多变量时间序列,即每个时间步都有 D D D个不同的特征。例如,在气象预测中,每个时间步可能有温度、湿度、风速等多个特征。

  2. 回看窗口 L L L:回看窗口表示我们用来做预测的过去时间步的数量。比如,如果我们想根据过去的一周的数据来预测未来的天气,那么 L L L可能就是7天的长度。

  3. 预测时间 H H H:预测时间表示我们希望预测的未来时间步的数量。例如,我们可能希望根据过去的一周的数据来预测未来一天的天气,这里 H H H就是1天的长度。

  4. 数据格式:输入数据 X \mathbf{X} X是一个 R D × L \mathbb{R}^{D \times L} RD×L的矩阵,表示 D D D维特征在 L L L个时间步上的值。目标数据 Y \mathbf{Y} Y是一个 R D × H \mathbb{R}^{D \times H} RD×H的矩阵,表示 D D D维特征在 H H H个时间步上的值。

  5. 训练集:我们有 N N N个样本的训练集,每个样本由一对 ( X ( i ) , Y ( i ) ) (\mathbf{X}^{(i)}, \mathbf{Y}^{(i)}) (X(i),Y(i))组成,其中 X ( i ) \mathbf{X}^{(i)} X(i)是输入, Y ( i ) \mathbf{Y}^{(i)} Y(i)是对应的目标值。

  6. 损失函数:我们使用均方误差(MSE)作为损失函数来评估模型的预测误差。损失函数的形式为:

L train ( ω ) = 1 N D ∑ i = 0 N ∥ Y ( i ) − f ω ( X ( i ) ) ∥ F 2 L_{\text{train}}(\omega) = \frac{1}{ND} \sum_{i=0}^{N} \|\mathbf{Y}^{(i)} - f_\omega(\mathbf{X}^{(i)})\|_F^2 Ltrain(ω)=ND1i=0NY(i)fω(X(i))F2

其中, ∥ ⋅ ∥ F \|\cdot\|_F F表示 Frobenius 范数,用于计算矩阵的误差。Frobenius 范数的平方等于矩阵元素平方和,即:

∥ A ∥ F 2 = ∑ i , j ∣ a i j ∣ 2 \|\mathbf{A}\|_F^2 = \sum_{i,j} |a_{ij}|^2 AF2=i,jaij2

在这个损失函数中,我们计算每个样本 X ( i ) \mathbf{X}^{(i)} X(i)的预测值与真实值之间的误差,并对所有样本的误差进行平均,以获得整体的训练损失。

通过最小化这个损失函数,我们可以训练模型 f ω f_\omega fω,使其在训练集上具有良好的预测性能,并期望其在测试集上也能有较好的泛化能力。

2.2 激励示例

最近,Zeng 等人(2023)展示了变换器在直接将输入投射到输出的简单线性神经网络训练中表现与之相当,甚至更差。我们将这一观察作为起点,考虑以下生成模型用于我们的玩具回归问题,模拟之前考虑的时间序列预测设置:

Y = X W toy + ϵ \mathbf{Y} = \mathbf{X} \mathbf{W}_{\text{toy}} + \epsilon Y=XWtoy+ϵ(2)

我们设 L = 512 L = 512 L=512 H = 96 H = 96 H=96 D = 7 D = 7 D=7 W toy ∈ R L × H \mathbf{W}_{\text{toy}} \in \mathbb{R}^{L \times H} WtoyRL×H ϵ ∈ R D × H \epsilon \in \mathbb{R}^{D \times H} ϵRD×H具有随机正态条目,并生成 15000 个输入-目标对 ( X , Y ) (\mathbf{X}, \mathbf{Y}) (X,Y)(10000 个用于训练,5000 个用于验证),其中 X ∈ R D × L \mathbf{X} \in \mathbb{R}^{D \times L} XRD×L具有随机正态条目。

鉴于此生成模型,我们希望开发一种变换器架构,可以有效地解决等式(2)中的问题而没有不必要的复杂性。为此,我们提出简化通常的变换器编码器,通过将注意力应用于 X \mathbf{X} X并结合残差连接,将 X \mathbf{X} X添加到注意力的输出中。我们直接在这种残差连接之上添加一个线性层用于输出预测,而不是添加前馈块。正式地,我们的模型定义如下:

f ( X ) = [ X + A ( X ) X W V W O ] W f(\mathbf{X}) = [\mathbf{X} + \mathbf{A}(\mathbf{X}) \mathbf{X} \mathbf{W}_{V} \mathbf{W}_{O}] \mathbf{W} f(X)=[X+A(X)XWVWO]W(3)

其中 W ∈ R L × H \mathbf{W} \in \mathbb{R}^{L \times H} WRL×H W V ∈ R L × d m \mathbf{W}_{V} \in \mathbb{R}^{L \times dm} WVRL×dm W O ∈ R d m × L \mathbf{W}_{O} \in \mathbb{R}^{dm \times L} WORdm×L A ( X ) \mathbf{A}(\mathbf{X}) A(X)是输入序列 X ∈ R D × L \mathbf{X} \in \mathbb{R}^{D \times L} XRD×L的注意力矩阵,定义为:

A ( X ) = softmax ( X W Q W K ⊤ X ⊤ d m ) ∈ R D × D \mathbf{A}(\mathbf{X}) = \text{softmax} \left( \frac{\mathbf{X} \mathbf{W}_{Q} \mathbf{W}_{K}^\top \mathbf{X}^\top}{\sqrt{d_{m}}} \right) \in \mathbb{R}^{D \times D} A(X)=softmax(dm XWQWKX)RD×D(4)

其中 softmax 是行归一化, W Q ∈ R L × d m \mathbf{W}_{Q} \in \mathbb{R}^{L \times dm} WQRL×dm W K ∈ R L × d m \mathbf{W}_{K} \in \mathbb{R}^{L \times dm} WKRL×dm是模型的维度。softmax 使 A ( X ) \mathbf{A}(\mathbf{X}) A(X)相当随机,每行描述一个概率分布。为了简化符号,在上下文明确的情况下,我们将注意力矩阵简写为 A \mathbf{A} A,省略 X \mathbf{X} X
我们将此架构称为Transformer,并简要评论如下。首先,注意力矩阵是按通道应用的,这简化了问题并减少了过度参数化的风险,因为矩阵 W \mathbf{W} W的形状与等式(2)相同,并且由于 L > D L > D L>D,注意力矩阵变得更小。此外,在这种情况下,通道注意比时间注意更相关,因为数据生成遵循独立同分布过程。我们正式建立了模型对 W toy \mathbf{W}_{\text{toy}} Wtoy的可识别性。证明详见附录E.2。

详细解释

我们将通过以下几个步骤详细解释这个激励示例,包括模型设置、生成模型和提出的变换器架构。

1. 模型设置

我们设置了以下参数:

  • L = 512 L = 512 L=512:回看窗口的长度。
  • H = 96 H = 96 H=96:预测时间的长度。
  • D = 7 D = 7 D=7:时间序列的维度,即每个时间步有7个特征。
  • W toy ∈ R L × H \mathbf{W}_{\text{toy}} \in \mathbb{R}^{L \times H} WtoyRL×H:一个随机正态分布的矩阵。
  • ϵ ∈ R D × H \epsilon \in \mathbb{R}^{D \times H} ϵRD×H:一个具有随机正态条目的矩阵。

生成了 15000 个输入-目标对 ( X , Y ) (\mathbf{X}, \mathbf{Y}) (X,Y),其中 10000 个用于训练,5000 个用于验证。输入数据 X ∈ R D × L \mathbf{X} \in \mathbb{R}^{D \times L} XRD×L具有随机正态条目。

2. 生成模型

生成模型定义如下:

Y = X W toy + ϵ \mathbf{Y} = \mathbf{X} \mathbf{W}_{\text{toy}} + \epsilon Y=XWtoy+ϵ

这个生成模型的目的是创建一个可以通过线性变换和噪声生成目标值的简单回归问题。

3. 提出的变换器架构

为了简化通常的变换器编码器,我们提出了一个简单的变换器架构。其主要思想是将注意力应用于输入 X \mathbf{X} X,并结合残差连接,将 X \mathbf{X} X添加到注意力的输出中。我们直接在这种残差连接之上添加一个线性层用于输出预测,而不是添加前馈块。模型定义如下:

f ( X ) = [ X + A ( X ) X W V W O ] W f(\mathbf{X}) = [\mathbf{X} + \mathbf{A}(\mathbf{X}) \mathbf{X} \mathbf{W}_{V} \mathbf{W}_{O}] \mathbf{W} f(X)=[X+A(X)XWVWO]W

其中:

  • W ∈ R L × H \mathbf{W} \in \mathbb{R}^{L \times H} WRL×H
  • W V ∈ R L × d m \mathbf{W}_{V} \in \mathbb{R}^{L \times d_m} WVRL×dm
  • W O ∈ R d m × L \mathbf{W}_{O} \in \mathbb{R}^{d_m \times L} WORdm×L
  • A ( X ) \mathbf{A}(\mathbf{X}) A(X)是输入序列 X ∈ R D × L \mathbf{X} \in \mathbb{R}^{D \times L} XRD×L的注意力矩阵,定义为:

A ( X ) = softmax ( X W Q W K ⊤ X ⊤ d m ) \mathbf{A}(\mathbf{X}) = \text{softmax} \left( \frac{\mathbf{X} \mathbf{W}_{Q} \mathbf{W}_{K}^\top \mathbf{X}^\top}{\sqrt{d_{m}}} \right) A(X)=softmax(dm XWQWKX)

其中:

  • softmax 是行归一化
  • W Q ∈ R L × d m \mathbf{W}_{Q} \in \mathbb{R}^{L \times d_m} WQRL×dm
  • W K ∈ R L × d m \mathbf{W}_{K} \in \mathbb{R}^{L \times d_m} WKRL×dm
  • d m d_m dm是模型的维度

注意力矩阵 A ( X ) \mathbf{A}(\mathbf{X}) A(X)相当随机,每行描述一个概率分布。为简化符号,在上下文明确的情况下,我们将注意力矩阵简写为 A \mathbf{A} A,省略 X \mathbf{X} X

好的,让我们详细解释公式:

P = X + A ( X ) X W V W O P = \mathbf{X} + \mathbf{A}(\mathbf{X}) \mathbf{X} \mathbf{W}_V \mathbf{W}_O P=X+A(X)XWVWO

以及后续相关内容。

公式解释

符号含义

  • X \mathbf{X} X:输入矩阵,维度为 R D × L \mathbb{R}^{D \times L} RD×L,表示一个长度为 L L L D D D维时间序列。
  • A ( X ) \mathbf{A}(\mathbf{X}) A(X):输入 X \mathbf{X} X的注意力矩阵,维度为 R D × D \mathbb{R}^{D \times D} RD×D
  • W V \mathbf{W}_V WV:变换器模型中的注意力权重矩阵,维度为 R L × d m \mathbb{R}^{L \times d_m} RL×dm
  • W O \mathbf{W}_O WO:变换器模型中的输出权重矩阵,维度为 R d m × L \mathbb{R}^{d_m \times L} Rdm×L
  • P P P:通过公式计算得到的中间结果矩阵,维度为 R D × L \mathbb{R}^{D \times L} RD×L

公式拆解

  1. 输入矩阵 X \mathbf{X} X
    X \mathbf{X} X是一个 D × L D \times L D×L的矩阵,表示 D D D维特征在 L L L个时间步上的值。

  2. 注意力矩阵 A ( X ) \mathbf{A}(\mathbf{X}) A(X)
    A ( X ) \mathbf{A}(\mathbf{X}) A(X)是一个 D × D D \times D D×D的矩阵,通过对输入 X \mathbf{X} X计算得到。注意力矩阵的作用是计算输入特征之间的相关性,用于加权和组合输入特征。

  3. 注意力加权 A ( X ) X W V \mathbf{A}(\mathbf{X}) \mathbf{X} \mathbf{W}_V A(X)XWV
    A ( X ) \mathbf{A}(\mathbf{X}) A(X) X \mathbf{X} X相乘,再与权重矩阵 W V \mathbf{W}_V WV相乘,得到一个新的矩阵,这个过程可以理解为通过注意力机制对输入特征进行加权。

  4. 输出权重 W O \mathbf{W}_O WO
    上一步得到的矩阵再与输出权重矩阵 W O \mathbf{W}_O WO相乘,进一步调整特征的表示,最后得到的维度还是 R D × L \mathbb{R}^{D \times L} RD×L

  5. 残差连接
    将原始输入 X \mathbf{X} X与注意力加权和输出权重调整后的结果相加,得到最终的中间结果矩阵 P P P。残差连接(ResNet)是一种常用技术,能够缓解深度模型中的梯度消失问题,提高训练效率。

公式的意义
这个公式旨在通过变换器模型对输入特征进行加权和调整,同时保留输入特征的原始信息。通过残差连接,模型能够更好地学习到输入特征之间的复杂关系,并提高预测性能。

4. 识别性证明

为了证明我们提出的模型对 W toy \mathbf{W}_{\text{toy}} Wtoy的识别性,我们提供了如下论证:

我们希望开发一种变换器架构,可以有效地解决等式(2)中的问题而没有不必要的复杂性。为此,我们提出简化通常的变换器编码器,通过将注意力应用于 X \mathbf{X} X并结合残差连接,将 X \mathbf{X} X添加到注意力的输出中。我们直接在这种残差连接之上添加一个线性层用于输出预测,而不是添加前馈块。正式地,我们的模型定义如下:

f ( X ) = [ X + A ( X ) X W V W O ] W f(\mathbf{X}) = [\mathbf{X} + \mathbf{A}(\mathbf{X}) \mathbf{X} \mathbf{W}_{V} \mathbf{W}_{O}] \mathbf{W} f(X)=[X+A(X)XWVWO]W

注意力矩阵 A ( X ) \mathbf{A}(\mathbf{X}) A(X)是输入序列 X ∈ R D × L \mathbf{X} \in \mathbb{R}^{D \times L} XRD×L的注意力矩阵,定义为:

A ( X ) = softmax ( X W Q W K ⊤ X ⊤ d m ) \mathbf{A}(\mathbf{X}) = \text{softmax} \left( \frac{\mathbf{X} \mathbf{W}_{Q} \mathbf{W}_{K}^\top \mathbf{X}^\top}{\sqrt{d_{m}}} \right) A(X)=softmax(dm XWQWKX)

其中 softmax 是行归一化, W Q ∈ R L × d m \mathbf{W}_{Q} \in \mathbb{R}^{L \times d_m} WQRL×dm W K ∈ R L × d m \mathbf{W}_{K} \in \mathbb{R}^{L \times d_m} WKRL×dm是模型的维度。softmax 使 A ( X ) \mathbf{A}(\mathbf{X}) A(X)相当随机,每行描述一个概率分布。为简化符号,在上下文明确的情况下,我们将注意力矩阵简写为 A \mathbf{A} A,省略 X \mathbf{X} X

首先,注意力矩阵是按通道应用的,这简化了问题并减少了过度参数化的风险,因为矩阵 W \mathbf{W} W的形状与等式(2)相同,并且由于 L > D L > D L>D,注意力矩阵变得更小。此外,在这种情况下,通过注意力的注意更相关,因为数据根据 IID 分布独立同分布。我们正式建立了模型对 W toy \mathbf{W}_{\text{toy}} Wtoy的识别性。证明详见附录E.2。

命题2.1(最优解的存在性)

假设 W Q \mathbf{W}_Q WQ W K \mathbf{W}_K WK W V \mathbf{W}_V WV W O \mathbf{W}_O WO是固定的,令 P = X + A ( X ) X W V W O ∈ R D × L P = \mathbf{X} + \mathbf{A}(\mathbf{X})\mathbf{X}\mathbf{W}_V\mathbf{W}_O \in \mathbb{R}^{D \times L} P=X+A(X)XWVWORD×L。那么,存在一个矩阵 W ∈ R L × H \mathbf{W} \in \mathbb{R}^{L \times H} WRL×H,使得 P W = X W toy P\mathbf{W} = \mathbf{X}\mathbf{W}_{\text{toy}} PW=XWtoy,当且仅当 rank ( [ P X W toy ] ) = rank ( P ) \text{rank}([P \mathbf{X}\mathbf{W}_{\text{toy}}]) = \text{rank}(P) rank([PXWtoy])=rank(P),其中 [ P X W toy ] ∈ R D × ( L + H ) [P \mathbf{X}\mathbf{W}_{\text{toy}}] \in \mathbb{R}^{D \times (L+H)} [PXWtoy]RD×(L+H)是一个块矩阵。

上述假设在 P P P为满秩且 D < H D < H D<H的情况下成立,这在此玩具实验中是成立的。

因此,在用等式(2)生成的数据上拟合变换器的优化问题理论上允许无限多个最优分类器 W \mathbf{W} W

我们现在想确定注意力在解决等式(3)问题中的作用。为此,我们考虑一个模型,称为随机变换器(Random Transformer),其中仅优化 W \mathbf{W} W,而自注意权重 W Q \mathbf{W}_Q WQ W K \mathbf{W}_K WK W V \mathbf{W}_V WV W O \mathbf{W}_O WO在训练过程中固定,并按照Glorot & Bengio (2010)进行初始化。这实际上使所考虑的变换器像一个线性模型。最后,我们比较了这两个模型使用Adam优化后的局部最小值,并与Oracle模型的最小二乘解进行比较。

我们在图2中展示了两个模型的验证损失。第一个令人惊讶的发现是,两个变换器都未能恢复 W toy \mathbf{W}_{\text{toy}} Wtoy,这表明即使是优化这样一个具有良好设计的简单架构也表现出强烈的泛化缺乏。当固定自注意矩阵时,问题在某种程度上得到了缓解,尽管随机变换器仍然是次优的。这一观察在各种优化器和学习率值中保持一致,表明这种现象并不是由于次优的优化器超参数或优化器的具体选择所致。由于随机变换器和变换器之间的参数数量仅增加了2%,因此这也不是由于过拟合。因此,我们从图1中得出结论,变换器较差的泛化能力主要是由于注意力模块的可训练性问题。

2.3 变换器的损失景观

直觉 :在前一节中,我们得出结论,注意力是导致Transformer泛化能力差的原因。为了进一步理解这一现象,我们在图3a中绘制了不同训练时期的注意力矩阵。我们可以看到,注意力矩阵在第一次迭代后非常接近于单位矩阵,并且之后几乎没有变化,特别是在softmax放大了矩阵值之间的差异后。这表明全秩注意力矩阵在训练过程中出现了熵崩溃(entropy collapse),这也是Zhai等人(2023)指出的导致训练变换器困难的原因之一。

这项工作还建立了熵崩溃与变换器损失景观锐度之间的关系,我们在图3b中确认了这一点(在图5a中的真实数据上也观察到了类似行为)。Transformer在训练过程中收敛到比随机变换器更锐利的最小值,同时具有显著较低的熵(对后者来说,注意力权重在初始化时被固定,其熵在整个训练过程中保持不变)。这些病理模式表明,Transformer的失败是由于熵崩溃和损失景观的锐度。在接下来的段落中,我们探讨了文献中现有的解决方案,以缓解这些问题。

现有的解决方案

最近的研究表明,与其他残差架构相比,变换器的损失景观更为锐利(Chen et al., 2022; Zhai et al., 2023)。这可能解释了变换器在小规模数据集上训练不稳定且表现不佳的原因。变换器的锐度已经被观察和定量化:Chen等(2022)计算了损失函数Hessian的最大特征值,Zhai等(2023)则计算了注意力矩阵的熵以展示其在高锐度下的崩溃。我们在图3b中展示了这两个指标的结果,验证了我们的假设,同时揭示了两个有害现象:一方面,固定注意力的变换器的锐度比收敛到单位注意力矩阵的变换器高出几个数量级;另一方面,变换器的注意力矩阵的熵在训练过程中急剧下降。

为了找到一种既能提高泛化性能又能改善训练稳定性的解决方案,我们探讨了Chen et al.(2022)和Zhai et al.(2023)提出的两种补救措施。第一种方法是利用最近提出的锐度感知最小化(SAM)框架(Foret et al., 2021),该框架将等式(1)中的训练目标替换为:

L train SAM ( ω ) = max ⁡ ∥ ϵ ∥ ≤ ρ L train ( ω + ϵ ) L_{\text{train}}^{\text{SAM}}(\omega) = \max_{\|\epsilon\| \leq \rho} L_{\text{train}}(\omega + \epsilon) LtrainSAM(ω)=ϵρmaxLtrain(ω+ϵ)

其中 ρ > 0 \rho > 0 ρ>0是超参数(见附录D.1), ω \omega ω是模型参数。更多关于SAM的详细信息见附录D.2。第二种方法引入了光谱归一化和一个附加的可学习标量,这种技术被Zhai et al.(2023)称为 σ \sigma σReparam。更正式地,我们将每个权重矩阵 W \mathbf{W} W替换为:

W ^ = γ ∥ W ∥ 2 W \widehat{\mathbf{W}} = \frac{\gamma}{\|\mathbf{W}\|_2}\mathbf{W} W =W2γW

其中 γ ∈ R \gamma \in \mathbb{R} γR是一个可学习的参数,初始化为1。

图1中展示的结果强调了我们的变换器成功地收敛到所需的解。令人惊讶的是,这只有在使用SAM时才能实现,而 σ \sigma σReparam尽管最大化了注意力矩阵的熵,但无法接近最优表现。此外,可以在图3b中看到,使用SAM时变换器的锐度比使用 σ \sigma σReparam时低几个数量级,而获得的注意力熵与基础变换器相似。

为了更好地理解 σ \sigma σReparam的失败,可以回顾等式(5)的推导过程。Zhai等(2023)从注意力熵的一个紧下界出发,证明其随着 ∥ W Q W K ⊤ ∥ 2 2 \|\mathbf{W}_Q \mathbf{W}_K^\top\|_2^2 WQWK22的最小化而指数级增加。等式(5)作为一种简单的方法被提出以最小化这一量。然而,在通道注意的情况下,这对注意力矩阵秩产生了有害影响,从而导致注意力机制崩溃。我们在附录E.3中对此进行了形式化证明。

需要注意的是,当 W Q = W K \mathbf{W}_Q = \mathbf{W}_K WQ=WK时,上述假设成立,并且这一点已经在之前的研究中被确认。该定理证实,通过减少 ∥ W Q W K ⊤ ∥ 2 2 \|\mathbf{W}_Q \mathbf{W}_K^\top\|_2^2 WQWK22 σ \sigma σReparam减少了等式(4)中定义的注意力矩阵的核范数。尽管矩阵秩与核范数之间的直接联系并不总是存在,核范数正则化常用于鼓励压缩感知中的低秩结构。

尽管命题2.2不能直接应用于注意力矩阵 A ( X ) \mathbf{A}(\mathbf{X}) A(X),我们指出在 σ \sigma σReparam导致注意力分数 X W Q W K ⊤ X ⊤ \mathbf{X} \mathbf{W}_Q \mathbf{W}_K^\top \mathbf{X}^\top XWQWKX为秩1并具有相同行的极端情况下,注意力矩阵在行softmax应用后仍保持秩1。因此, σ \sigma σReparam可能导致注意力秩的崩溃,我们在图7中的核范数实验证实了这一点。基于这些发现,我们提出了一种新的简单变换器模型,在多变量时间序列预测中具有高性能和训练稳定性。

2.4. SAMformer:集成所有方法

提出的SAMformer基于等式(3)进行两项重要修改。首先,我们为 X \mathbf{X} X配备了可逆实例归一化(RevIN, Kim et al. 2021b),因为这种技术被证明在处理训练数据和测试数据之间的转换时非常有效。其次,如上文所述,我们使用SAM对模型进行优化,使其收敛到更平滑的局部最小值。总体而言,这为浅层变换器模型提供了一个编码器,如图4所示。

我们强调,SAMformer保持了由 D × D D \times D D×D矩阵表示的通道注意,如等式(3)所示,这不同于其他模型中使用的由 L × L L \times L L×L矩阵表示的空间(或时间)注意。这带来了两个重要好处:(i) 确保了特征排列不变性,消除了对注意层之前常见的位置信息编码的需求;(ii) 由于在大多数实际数据集中 D ≤ L D \leq L DL,因此减少了时间和内存复杂度。我们的通道注意检查了每个特征在所有时间步长上的平均影响。

在附录C.4中详细描述的消融研究验证了此实现的有效性。我们现在准备在常见的多变量时间序列预测基准上评估SAMformer,展示其优越性。

3. 实验

在本节中,我们通过实验证明了SAMformer在多变量长期时间序列预测中的定量和定性优越性。我们展示了SAMformer在常见基准上的表现超越了当前最先进的TSMixer(Chen et al., 2023)14.33%,同时参数量减少了约4倍。所有实现细节见附录A.1。

数据集
我们在8个公开可用的真实世界多变量时间序列数据集上进行了实验,这些数据集通常用于长期预测(Wu et al., 2021; Chen et al., 2023; Nie et al., 2023; Zeng et al., 2023)。具体包括四个电力变换器温度数据集ETTh1、ETTh2、ETTm1和ETTm2(Zhou et al., 2021),电力(UCI, 2015),外汇(Lai et al., 2018b),交通(California Department of Transportation, 2021),和天气(Max Planck Institute, 2021)。所有时间序列被分段,输入长度 L = 512 L = 512 L=512,预测区间 H ∈ { 96 , 192 , 336 , 720 } H \in \{96, 192, 336, 720\} H{96,192,336,720},步长为1,即每个后续窗口移动一步。更多关于数据集和时间序列准备的详细描述见附录A.2。

基线
我们将SAMformer与之前介绍的Transformer和TSMixer(Chen et al., 2023)进行比较,后者是一个基于MLP的最先进多变量基线模型。需要注意的是,Chen et al.(2023)报告了TSMixer在一个固定种子上的表现,为了公平比较,我们在不同种子上多次运行,报告平均性能。

我们还比较了iTransformer(Liu et al., 2024)、PatchTST(Nie et al., 2023)、FEDformer(Zhou et al., 2022)、Informer(Zhou et al., 2021)和Autoformer(Wu et al., 2021)等最近的SOTA多变量变换器模型的结果,所有结果均使用RevIN(Kim et al., 2021b)获得,以便于更公平地比较SAMformer及其竞争对手。关于这些基线的详细信息见附录A.3。

评估
所有模型都被训练以最小化等式(1)中的MSE损失。报告测试集上的平均MSE及其标准差(通过不同种子的5次运行)。更多的详细数据和结果,包括平均绝对误差(MAE),见附录B.1中的表6。除非另有说明,否则所有结果均通过不同种子的5次运行获得。

3.1 主要收获

SAMformer 优于最先进技术。实验结果详见表1,附录表7中提供了学生t检验分析。SAMformer 在8个数据集中的7个上显著优于其竞争对手。具体而言,它比其最佳竞争对手 TSMixer+SAM 提高了5.25%,比独立的 TSMixer 提高了14.33%,比最佳的多变量基于变换器模型 FEDformer 提高了12.36%。此外,它比 Transformer 提高了16.96%。SAMformer 还优于最近的 iTransformer,这是一种基于变换器的方法,使用了时间和空间注意,PatchTST 则针对单变量时间序列预测进行了优化。我们注意到 iTransformer 在总体表现上混合,除外汇数据集外,SAMformer 在所有数据集上都优于它。在没有 RevIN 的情况下,SAMformer 总体上提高了3.94%,但在有 RevIN 的情况下提高了多达8.38%。

SAMformer 比 PatchTST 提高了11.13%。对于每个预测区间和数据集(除外汇外),SAMformer 都排名第一或第二。显著的是,SAM的集成提高了TSMixer的泛化能力,平均增强了9.58%。表6中的 MAE 研究得出了相同的结论。由于 TSMixer 与 SAM 的训练几乎总是排名第二,它作为进一步讨论的主要基准。值得注意的是,SAMformer 的参数量比 TSMixer 少4倍,比基于变换器的方法少几个数量级。

更平滑的损失景观。在 SAMformer 的训练中引入 SAM 使其损失比 Transformer 更平滑。我们在图5a中通过比较 Transformer 和 SAMformer 在 ETTh1 和外汇数据集上的 λ max ⁡ \lambda_{\max} λmax值来说明这一点。我们的观察显示 Transformer 表现出显著更高的锐度,而 SAMformer 具有期望的行为,其损失景观的锐度小一个数量级。

提高的鲁棒性。SAMformer 在随机初始化方面表现出鲁棒性。图5b展示了 SAMformer 和 Transformer 在 ETTh1 和外汇数据集上预测区间为 H = 96 H=96 H=96的5个不同种子的测试 MSE 分布。SAMformer 在不同种子选择上持续保持性能稳定,而 Transformer 表现出显著的方差,因此对权重初始化高度依赖。此观察结果在所有数据集和预测区间上都适用,如附录B.4所示。

在这里插入图片描述
图1:我们的方法在合成数据上的说明。Oracle 是最优解,Transformer 是基本变换器, σ \sigma σReparam 是权重重缩放的变换器(Zhai 等人,2023),而 Transformer + SAM 是使用锐度感知最小化训练的变换器。Transformer 过拟合, σ \sigma σReparam 略有改善但未能达到 Oracle 的水平,而 Transformer + SAM 完美泛化。这激励了 SAMformer 的提出,一个结合了 SAM 和时间序列预测最佳实践的浅层变换器。
在这里插入图片描述
图2:泛化能力差。尽管其设计简单,Transformer仍然存在严重的过拟合问题。在随机变换器(Random Transformer)中固定注意力权重可以改善泛化能力,这暗示了注意力在防止收敛到最优局部极小值中的作用。

在这里插入图片描述
图3:变换器在线性回归中的损失景观分析。(a) 变换器的注意力矩阵从第一个时期起就陷入单位矩阵。(b, 左) 变换器比变换器+SAM收敛到更锐利的最小值,具有更大的 λ max ⁡ \lambda_{\max} λmax(约 1 0 4 10^4 104),而随机变换器具有平滑的损失景观。(b, 右) 变换器在训练期间遭遇熵崩溃,证实了其损失景观的高度锐度。

在这里插入图片描述
在这里插入图片描述
图5:(a)SAMformer比Transformer具有更平滑的损失景观。(b)SAMformer在每次初始化时都能一致地很好泛化,而Transformer则不稳定且严重依赖于种子。

在这里插入图片描述
图6:在天气数据集上的注意力矩阵。SAMformer 保持了特征之间的自相关性,而 (\sigma)Reparam 降低了秩,阻碍了信息的传播。
在这里插入图片描述
图7:不同模型的注意力矩阵的核范数: σ \sigma σReparam 根据命题2.2诱导较低的核范数,而SAMformer保持了比Transformer更高的注意力表达能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1881683.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

目标检测算法的研究现状

一、引言 目标检测是计算机视觉领域中的一项重要任务&#xff0c;旨在从图像或视频中识别并定位出感兴趣的目标物体。随着深度学习技术的快速发展&#xff0c;目标检测算法取得了显著的进步&#xff0c;并在自动驾驶、智能安防、人脸识别等领域得到了广泛应用。本文将对目标检…

2024上海初中生古诗文大会倒计时4个月:单选题真题和独家解析

现在距离2024年初中生古诗文大会初选还有4个多月时间&#xff08;11月3日正式开赛&#xff09;&#xff0c;我们继续来看10道选择题真题和详细解析&#xff0c;以下题目截取自我独家制作的在线真题集&#xff0c;都是来自于历届真题&#xff0c;去重、合并后&#xff0c;每道题…

电商价格监测:品牌控价维权的关键利器

品牌在进行控价时&#xff0c;所面对的是线上成千上万条的商品链接&#xff0c;如果仅依靠人工&#xff0c;根本无法做到准确且全面地完成电商价格监测工作。因此&#xff0c;一套准确率高的电商价格监测系统对于品牌的控价维权而言&#xff0c;其重要性不言而喻。 在形形色色的…

UE4_材质_材质节点_视差偏移BumpOffset

一、定义 凹凸贴图偏移&#xff08;BumpOffset&#xff09; 是虚幻引擎4术语&#xff0c;就是通常所谓的"视差贴图"。BumpOffset表达式可以使材质产生深度错觉&#xff0c;而不需要额外的几何体。BumpOffset材质使用灰阶_高度贴图_来提供深度信息。高度贴图中的值越…

沃创云获客系统如何帮助企业找到意向客户群体

沃创云是一家做外呼系统起家&#xff0c;越来越多客户有打电话的需求。我们突然意识到&#xff0c;大量的数据积累是电销的基础&#xff0c;那么如何找到客户以及联系方式也非常关键 通过爬虫技术&#xff0c;通过帮助企业精准地找到意向客户群体。以下是该系统如何帮助企业实…

PHP转Go系列 | GET 和 POST 请求的使用姿势

大家好&#xff0c;我是码农先森。 说到 HTTP 请求工具想必对我们做 Web 开发的程序员都不陌生&#xff0c;只要涉及到网络请求都必须使用。对于我们 PHP 程序员来说&#xff0c;最熟悉不过的就是 CURL 扩展&#xff0c;只要安装的这个扩展便可随意发起 HTTP 请求。 但在 PHP …

Transformer常见面试题

目录 1.Transformer为何使用多头注意力机制&#xff1f;&#xff08;为什么不使用一个头&#xff09; 2.Transformer为什么Q和K使用不同的权重矩阵生成&#xff0c;为何不能使用同一个值进行自身的点乘&#xff1f; &#xff08;注意和第一个问题的区别&#xff09; 3.Transf…

React+TS前台项目实战(二十二)-- 全局常用导出组件Export封装

文章目录 前言Export组件1. 功能分析2. 代码详细注释3. 使用方式4. 效果展示 总结 前言 今天我们来封装一个带导出图标的导出组件。 Export组件 1. 功能分析 通过传入链接地址&#xff0c;规定要跳转的导出页面&#xff0c;或是直接通过链接导出数据 2. 代码详细注释 // /c…

产品中心|高效能双处理器Xilinx FPGA 4通道射频收发板卡

1、产品概述 基于Xilinx XC7K325T芯片的4通道射频收发板卡&#xff0c;搭载高能效Cortex-A8内核处理器、1组16bit/2GB DDR3及1组4GB DDR3、 1组2GB Nand Flash、1路USB接口、4路高速ADC、4路高速DAC&#xff0c;支持外触发&#xff0c;外时钟。用于FPGA程序加载板卡工作温度范…

前端开发中实用小技巧

运行javascript小技巧 // 1.直接在浏览器的地址栏中输入一下代码&#xff1a;javascript:alert(hello world) // 2.注意事项ie和chrom回自动去掉开头的 【javascript:】需要手动添加火狐浏览器不支持这个技巧 // 3.场景快速测试一段js代码运行HTML代码的神奇技巧 // 1.直接在…

TensorFlow安装CPU版本和GPU版本

文章目录 前言一、TensorFlow安装CPU版本1.新建虚拟环境2.激活虚拟环境3.下载tensorflow4.验证是否下载成功 二、TensorFlow安装GPU版本1.新建虚拟环境2.激活虚拟环境3.安装tensorflow-gpu4.验证是否下载成功 前言 下载的Anaconda是Anaconda3-2024.02-1-Windows-x86_64版本 一…

firewalld防火墙转发流量到其他端口forward port rules

假设云主机eth0: 47.93.27.106 tun0: inet 10.8.0.1 netmask 255.255.255.0 Show rules for a specific zone (public) sudo firewall-cmd --zonepublic --list-all Add the tun0 interface to the public zone: sudo firewall-cmd --zonepublic --add-interfacetun0 --…

MTK7621交换芯片配置

MTK7621上自带的交换芯片为mt7530 admin@OpenWrt:~# /sbin/swconfig list Found: switch0 - mt7530 交换芯片的配置工具为swconfig程序。MTK7621采用内部的MDIO(Management Data Input/Output)接口管理MT7530的switch芯片。 MT7530共有7个物理口,通过/sbin/swconfig dev …

isupper()方法——判断字符串是否全由大写字母组成

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 语法参考 isupper()方法用于判断字符串中所有的字母是否都是大写。isupper()方法的语法格式如下&#xff1a; str.isupper() 如果字符串中包含至少…

linux虚拟机部署的MySQL如何使用外网访问?教你轻松使用cpolar在centos搭建内网穿透

文章目录 写在前面实现Linux的内网穿透1、官网账号注册2、在Linux部署我们自己的项目3、一键自动下载安装cpolar4、设置自己的token5、启动cpolar服务6、MySQL穿透测试 卸载方法 写在前面 相信很多小伙伴在本地搭建了一个MySQL数据库&#xff0c;想让其他同事或者合作者一起使…

【AI大模型】跌倒监控与健康:技术实践及如何改变未来

文章目录 1. **背景与意义**2. **关键技术与方法**2.1 传感器数据融合2.2 深度学习模型2.3 行为模式识别2.4 预测与预防 3. **应用场景**3.1 老年人跌倒预警3.2 康复患者监测3.3 高风险职业防护 4. **实践案例**案例1&#xff1a;某老年社区的跌倒预警系统案例2&#xff1a;康复…

【FreeRTOS】空闲任务

目录 空闲任务及其钩子函数介绍使用钩子函数的前提 实际操作任务如何退出&#xff1f;IDLE函数 空闲任务及其钩子函数 介绍 空闲任务(Idle任务)的作用之一&#xff1a;释放被删除的任务的内存。 除了上述目的之外&#xff0c;为什么必须要有空闲任务? 这是一个良好的程序&…

使vim创建.sh文件时自动添加头部描述信息

目录 需求解决方案vimrc配置文件常见选项 修改vimrc功能解释 效果 需求 在编写shell脚本时&#xff0c;为了便于后续阅读或修改或SOP需求&#xff0c;我们常常会在shell脚本前添加一些描述信息&#xff0c;用于标注其作用和shell版本&#xff0c;例如&#xff1a; #!/bin/bas…

mongodb在windows环境安装部署

一、mongodb 1.释义 MongoDB 是一种开源的文档型 NoSQL 数据库管理系统&#xff0c;使用 C 编写&#xff0c;旨在实现高性能、高可靠性和易扩展性。MongoDB 采用了面向文档的数据模型&#xff0c;数据以 JSON 风格的 BSON&#xff08;Binary JSON&#xff09;文档存储&#x…

Android经典面试题之Glide的缓存大揭秘

本文首发于公众号“AntDream”&#xff0c;欢迎微信搜索“AntDream”或扫描文章底部二维码关注&#xff0c;和我一起每天进步一点点 Glide缓存 关联类&#xff1a;Engine、LruResourceCache、LruCache、ActiveResources ActiveResources&#xff1a;弱引用缓存池 VisibleForTe…