xLSTM: Extended Long Short-Term Memory
引用: Beck M, Pöppel K, Spanring M, et al. xLSTM: Extended Long Short-Term Memory[J]. arXiv preprint arXiv:2405.04517, 2024.
论文链接: [2405.04517] xLSTM: Extended Long Short-Term Memory (arxiv.org)
作者: Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, Sepp Hochreiter
机构: ELLIS Unit, LIT AI Lab, Institute for Machine Learning, JKU Linz, Austria; NXAI Lab, Linz, Austria; NXAI GmbH, Linz, Austria
文章目录
- xLSTM: Extended Long Short-Term Memory
- 摘要
- 引言
- xLSTM架构
- 1. **sLSTM(Scalar LSTM)**
- 2. **mLSTM(Matrix LSTM)**
- 3. **xLSTM块(xLSTM Blocks)**
- 4. **xLSTM架构(xLSTM Architecture)**
- 实验
- 结论
摘要
- 论文提出了xLSTM,一种扩展的长短期记忆网络,旨在解决传统LSTM的局限性,并在大规模参数下进行语言建模。
- xLSTM引入了指数门控和适当的归一化与稳定技术,修改了LSTM记忆结构,包括标量记忆的sLSTM和完全可并行化的具有矩阵记忆和协方差更新规则的mLSTM。
- 通过将这些LSTM变体集成到残差块中,构建了xLSTM架构,这些架构在性能和扩展性方面与最先进的Transformers和状态空间模型相媲美。
引言
- LSTM自1990年代引入以来,在多个领域取得了成功,特别是在大型语言模型(LLMs)中。
- 引入Transformer技术后,其并行化的自注意力机制使得LSTM在大规模应用中的性能受到挑战。
- 论文提出了一个问题:当LSTM扩展到数十亿参数,并结合现代LLMs的最新技术,同时克服LSTM的已知限制时,我们能在语言建模中走多远?
xLSTM架构
1. sLSTM(Scalar LSTM)
指数门控是sLSTM中的一个创新点,它允许模型更有效地更新其记忆状态。在传统的LSTM中,门控机制通常涉及sigmoid函数,但在xLSTM中,输入门( i t i_t it)和遗忘门( f t f_t ft)可以具有指数激活函数:
c t = f t c t − 1 + i t z t c _ { t } = f _ { t } c _ { t - 1 } + i _ { t } z _ { t } ct=ftct−1+itzt
n t = f t n t − 1 + i t n _ { t } = f _ { t } n _ { t - 1 } + i _ { t } nt=ftnt−1+it
h t = o t h t ~ , h t ~ = o t / n t h _ { t } = o _ { t } \tilde{h _ { t }}, \quad \tilde{h _ { t }} = o _ { t } / n _ { t } ht=otht~,ht~=ot/nt
z t = φ ( z ~ t ) , z ~ t = w z T x t + r z h t − 1 + b z z _ { t } = \varphi ( \tilde { z } _ { t } ), \quad \tilde { z } _ { t } = w _ { z } ^ { T } x _ { t } + r _ { z } h _ { t - 1 } + b _ { z } zt=φ(z~t),z~t=wzTxt+rzht−1+bz
i t = e x p ( i ~ t ) , i ~ t = w i T x t + r i h t − 1 + b i i _ { t } = exp ( \tilde { i } _ { t } ), \quad \tilde { i } _ { t } = w _ { i } ^ { T } x _ { t } + r _ { i } h _ { t - 1 } + b _ { i } it=exp(i~t),i~t=wiTxt+riht−1+bi
f t = σ ( f ~ t ) O R e x p ( f ~ t ) , f ~ t = w f T x t + r f h t − 1 + b f f _ { t } = \sigma ( \tilde { f } _ { t } ) \quad OR \quad e x p ( \tilde { f } _ { t } ), \quad \tilde { f } _ { t } = w _ { f } ^ { T } x _ { t } + r _ { f } h _ { t - 1 } + b _ { f } ft=σ(f~t)ORexp(f~t),f~t=wfTxt+rfht−1+bf
o t = e x p ( o ~ t ) , o ~ t = w o T x t + r o h t − 1 + b o o _ { t } = exp ( \tilde { o } _ { t } ), \quad \tilde { o } _ { t } = w _ { o } ^ { T } x _ { t } + r _ { o } h _ { t - 1 } + b _ { o } ot=exp(o~t),o~t=woTxt+roht−1+bo
指数激活函数可能导致较大的值,从而导致溢出。因此,用一个额外的状态 m t m_t mt来稳定门:
m t = max ( log ( f t ) + m t − 1 , log ( i t ) ) m _ { t } = \max ( \log ( f _ { t } ) + m _ { t - 1 } , \log ( i _ { t } ) ) mt=max(log(ft)+mt−1,log(it))
i t ′ = e x p ( log ( i t ) − m t ) = e x p ( i ~ t − m t ) i _ { t } ^ { \prime } = e x p ( \log ( i _ { t } ) - m _ { t } ) = e x p ( \tilde { i } _ { t } - m _ { t } ) it′=exp(log(it)−mt)=exp(i~t−mt)
f t ′ = e x p ( log ( f t ) + m t − 1 − m t ) f _ { t } ^ { \prime } = e x p ( \log ( f _ { t } ) + m _ { t - 1 } - m _ { t } ) ft′=exp(log(ft)+mt−1−mt)
其中,$m_t $是稳定状态,用于防止梯度爆炸。
同时,sLSTM引入了新的记忆混合技术,允许在多个内存单元之间进行更复杂的交互。多个存储器单元使得能够分别经由从隐藏状态向量 h h h到存储器单元输入 z z z和门 i i i、 f f f、 o o o的循环连接 R z R_z Rz、 R i R_i Ri、 R f R_f Rf、 R o R_o Ro进行存储器混合。sLSTM可以有多个头,每个头内混合内存,但不能跨头混合。
2. mLSTM(Matrix LSTM)
mLSTM使用矩阵记忆来增强存储容量,并通过协方差更新规则来存储关键值对。
C t = f t C t − 1 + i t v t k t T C _ { t } = f _ { t } C _ { t - 1 } + i _ { t } v _ { t } k _ { t } ^ { T } Ct=ftCt−1+itvtktT
n t = f t n t − 1 + i t k t n _ { t } = f _ { t } n _ { t - 1 } + i _ { t } k _ { t } nt=ftnt−1+itkt
h t = o t ⊙ h ~ t , h ~ t = C t q t / max { ∣ n t T q t ∣ , 1 } h _ { t } = o _ { t } \odot \tilde { h } _ { t } , \quad \tilde { h } _ { t } = C _ { t } q _ { t } / \max \left\{ | n _ { t } ^ { T } q _ { t } | , 1 \right\} ht=ot⊙h~t,h~t=Ctqt/max{∣ntTqt∣,1}
q t = W q x t + b q q _ { t } = W _ { q } x _ { t } + b _ { q } qt=Wqxt+bq
k t = 1 d W k x t + b k k _ { t } = \frac { 1 } { \sqrt { d } } W _ { k } x _ { t } + b _ { k } kt=d1Wkxt+bk
v t = W v x t + b v v _ { t } = W _ { v } x _ { t } + b _ { v } vt=Wvxt+bv
i t = e x p ( i ~ t ) , i ~ t = w i T x t + b i i _ { t } = e x p ( \tilde { i } _ { t } ) , \quad \tilde { i } _ { t } = w _ { i } ^ { T } x _ { t } + b _ { i } it=exp(i~t),i~t=wiTxt+bi
f t = σ ( f ~ t ) O R e x p ( f ~ t ) , f ~ t = w f T x t + b f f _ { t } = \sigma ( \tilde { f } _ { t } ) \quad OR \quad exp(\tilde { f } _ { t }), \quad \tilde { f } _ { t } = w _ { f } ^ { T } x _ { t } + b _ { f } ft=σ(f~t)ORexp(f~t),f~t=wfTxt+bf
o t = σ ( o ~ t ) , o ~ t = w o T x t + b o o _ { t } = \sigma ( \tilde { o } _ { t } ) , \quad \tilde { o } _ { t } = w _ { o } ^ { T } x _ { t } + b _ { o } ot=σ(o~t),o~t=woTxt+bo
3. xLSTM块(xLSTM Blocks)
xLSTM块结合了sLSTM和mLSTM的特性,并通过残差连接来进一步提高性能。对于残差sLSTM块,输入首先进入sLSTM,然后是一个门控的多层感知机(MLP)。对于残差mLSTM块,输入首先通过两个MLP,然后是mLSTM,通过卷积、可学习的跳跃连接和输出门。
4. xLSTM架构(xLSTM Architecture)
xLSTM架构通过残差堆叠xLSTM块来构建,利用了预层归一化(preLayerNorm)残差骨干网络。
实验
- 论文在合成任务和长距离竞技场(Long Range Arena)上测试了xLSTM,并与其他方法进行了比较。
- 在SlimPajama数据集上进行了语言建模实验,比较了不同方法的性能。
- 进行了扩展实验,训练了更大的模型,并在更多的训练数据上评估了它们的扩展行为。
结论
- xLSTM在语言建模方面的表现至少与当前的Transformer或状态空间模型相当。
- xLSTM有潜力在强化学习、时间序列预测或物理系统建模等深度学习领域产生重大影响。