CMU 10423 Generative AI:HW1(理论部分)

news2024/12/22 19:50:05

备注:S24版GitHub上有某CMU学生分享了自己的全套理论+编程作业,以下内容的整理结合了我自己的理解查阅、GPT4的解答、以及CMU学生的答案。

文章目录

  • 0 作业概述
  • 1 RNN语言模型
    • 1.1 问题1:Elman(即RNN) 网络模型条件输出问题
      • 讲义中问题的原文
      • Elman 网络介绍
      • 问题的通俗解释
      • 解答
        • 解题思路
        • 参数选择
        • 验证
    • 1.2 问题2:双向RNN与自回归语言模型的可实现性探讨
      • 讲义中问题的原文
      • 解答
        • 问题背景:
        • 问题1.2(a) 的解释:
        • 问题1.2(b) 的解释:
        • 1.2(b)概率分布的定义:
        • 1.2(a)和1.2(b)的总结:
  • 2 transformer语言模型
    • 2.1(a) 问题1:乘法注意力(multiplicative attention)
      • **解答**
        • **理由**:
        • **关于函数类的解释**:
        • 我(GPT4)不赞同CMU学生答案的理由:
        • **总结**:
    • 2.1(b) 问题2:连接注意力(concatenated attention)
      • 解答
        • **理由:**
        • **总结:**
    • 2.1(c) 问题3:加性注意力(Additive Attention)
      • 解答
        • **理由:**
        • **总结:**
    • 问题2.2:自注意力与多头注意力的对称性与多头/单头注意力的等价性问题
      • 解答
  • 3 Sliding Window Attention(滑动窗口注意力)
    • 解答
      • 3.1(a) 公式版的时间复杂度
      • 3.1(b) 公式版的空间复杂度
      • 3.1(c, d, e) 代码版实现、空间复杂度、时间复杂度
      • 代码流程总结:
      • 时间复杂度分析:
      • 空间复杂度分析:

0 作业概述

这次作业主要围绕生成式文本模型,具体包括以下几个部分:

  1. RNN语言模型 (7分)
  • 构建递归神经网络(RNN)用于语言建模,解答与RNN递归方程相关的数值问题,讨论双向RNN能否用于自回归语言模型。

  • Transformer语言模型 (15分)

    • 比较不同类型的注意力机制:缩放点积注意力、乘法注意力、连接注意力和加法注意力,并分析它们是否能够学习不同类别的函数。
  • 滑动窗口注意力 (8分)

    • 探讨滑动窗口注意力的时间和空间复杂度,并编写伪代码以优化滑动窗口注意力的计算效率。
  • 编程题:RoPE与GQA (24分)

    • 实现旋转位置嵌入(RoPE)和分组查询注意力(GQA),并对比它们在训练时间、内存消耗和模型性能上的表现。

1 RNN语言模型

1.1 问题1:Elman(即RNN) 网络模型条件输出问题

讲义中问题的原文

题目大意如下:

1.1 (3 分) 数值问题:考虑一个 RNN (Elman 网络),其输入为 x t ∈ { 0 , 1 } 2 x_t \in \{0, 1\}^2 xt{0,1}2,隐藏向量为 h t ∈ R 2 h_t \in \mathbb{R}^2 htR2,输出单元为 y t ∈ R y_t \in \mathbb{R} ytR,其中 t ∈ { 1 , . . . , T } t \in \{1, ..., T\} t{1,...,T}。假设递归公式如下所示:

h t = slide ( W h h h t − 1 + W h x x t + b h ) h_t = \text{slide}(W_{hh} h_{t-1} + W_{hx} x_t + b_h) ht=slide(Whhht1+Whxxt+bh)

y t = slide ( W y h h t + b y ) y_t = \text{slide}(W_{yh} h_t + b_y) yt=slide(Wyhht+by)

其中,激活函数 slide(a) 定义为: slide ( a ) = min ⁡ ( 1 , max ⁡ ( 0 , a ) ) \text{slide}(a) = \min(1, \max(0, a)) slide(a)=min(1,max(0,a))。请定义参数 W h h ∈ R 2 × 2 , W h x ∈ R 2 × 2 , W y h ∈ R 1 × 2 , b h ∈ R 2 , b y ∈ R W_{hh} \in \mathbb{R}^{2 \times 2}, W_{hx} \in \mathbb{R}^{2 \times 2}, W_{yh} \in \mathbb{R}^{1 \times 2}, b_h \in \mathbb{R}^2, b_y \in \mathbb{R} WhhR2×2,WhxR2×2,WyhR1×2,bhR2,byR使得满足以下条件:如果存在 r , s ≤ t r, s \leq t r,st,使得 x r , 0 = 1 x_{r,0} = 1 xr,0=1 x s , 1 = 1 x_{s,1} = 1 xs,1=1,那么 y t = 1 y_t = 1 yt=1,否则 y t = 0 y_t = 0 yt=0。假设 h 0 = [ 0 , 0 ] T h_0 = [0, 0]^T h0=[0,0]T

Elman 网络介绍

突然冒出个elman网络,查阅了些资料,其实就是我们熟知的最基本的RNN,在此介绍一下Elman网络早年的架构图。只看下图三层网络结构,不要看公式:

在这里插入图片描述

Elman网络是一种简单的递归神经网络(RNN),由Jeffrey Elman在1990年提出。它是RNN早期的基本形式之一,具有一个“隐含层”来存储前一时刻的状态信息,从而使网络能够处理序列数据(如时间序列、文本等),并捕捉输入数据中的时序依赖性。

输入(紫色):代表每个时间步的输入信息。

隐藏层(黄色):相当于一个全连接

输出(粉色):最终输出的结果

承接层(白色):这个就相当于常规RNN中的h_t,即缓存的是上一个时间步(即上一次的输入)的隐藏层的输出结果。

贴下当前RNN架构图的绘制方法:

在这里插入图片描述

问题的通俗解释

这个问题的核心是在构建一个简单的递归神经网络(RNN),该网络需要根据输入序列中的特定模式来输出0或1。具体来说,问题要求我们设计网络的参数,使得当输入序列的某些位置满足特定条件时,网络的输出为1,否则为0。

通俗来说:

  1. 输入是什么?
  • 每个时刻 t t t,你会输入一个二维向量 x t ∈ { 0 , 1 } 2 x_t \in \{0, 1\}^2 xt{0,1}2,也就是说,这个向量有两个元素,每个元素都是0或1。

  • 网络要做什么?

    • 网络要通过不断处理这些输入,并根据这些输入计算出输出 y t y_t yt。输出的规则是:如果在某些时刻:

      • 输入的第一个元素 x t , 0 x_{t,0} xt,0曾经为1,并且
      • 输入的第二个元素 x t , 1 x_{t,1} xt,1也曾经为1,

      那么输出 y t y_t yt应该是1,否则输出应该是0。

  • 如何实现?

    • 网络有一个“记忆”(隐藏状态 h t h_t ht),它会记住过去发生的事情。我们需要让网络“记住”是否见过第一个元素为1和第二个元素为1的情况。
    • 比如,如果在之前的某个时刻 t t t,你输入的 x t , 0 = 1 x_{t,0} = 1 xt,0=1,那么网络的“记忆”应该更新,记住这一点。同样,如果在某个时刻 x t , 1 = 1 x_{t,1} = 1 xt,1=1,网络也要记住。
    • 一旦网络“记住”了这两个条件(即第一个元素和第二个元素都见过1),它就输出1;如果这两个条件都没发生或只发生了一个,它就输出0。

解答

CMU学生的解答:

在这里插入图片描述

以下是GPT的解答:

问题给出的RNN的结构为:

  1. h t = slide ( W h h h t − 1 + W h x x t + b h ) h_t = \text{slide}(W_{hh} h_{t-1} + W_{hx} x_t + b_h) ht=slide(Whhht1+Whxxt+bh)
  2. y t = slide ( W y h h t + b y ) y_t = \text{slide}(W_{yh} h_t + b_y) yt=slide(Wyhht+by)

其中,slide函数的作用是将值限制在 [0, 1] 之间,即:

slide ( a ) = min ⁡ ( 1 , max ⁡ ( 0 , a ) ) \text{slide}(a) = \min(1, \max(0, a)) slide(a)=min(1,max(0,a))

需要我们定义参数 W h h , W h x , W y h , b h , b y W_{hh}, W_{hx}, W_{yh}, b_h, b_y Whh,Whx,Wyh,bh,by,使得:

  • y t = 1 y_t = 1 yt=1,当存在 r , s ≤ t r, s \leq t r,st使得 x r , 0 = 1 x_{r,0} = 1 xr,0=1并且 x s , 1 = 1 x_{s,1} = 1 xs,1=1
  • 否则 y t = 0 y_t = 0 yt=0

假设初始状态 h 0 = [ 0 , 0 ] T h_0 = [0, 0]^T h0=[0,0]T

解题思路
  1. 目标条件分析:
  • 我们希望当存在一个时刻 r r r 使得 x r , 0 = 1 x_{r,0} = 1 xr,0=1(即输入的第一个元素为1)并且存在另一个时刻 s s s 使得 x s , 1 = 1 x_{s,1} = 1 xs,1=1(即输入的第二个元素为1)时,输出 y t = 1 y_t = 1 yt=1

  • 因此我们可以利用隐藏状态 h t h_t ht 来存储这些信息,即记录是否见过 x t , 0 = 1 x_{t,0} = 1 xt,0=1 x t , 1 = 1 x_{t,1} = 1 xt,1=1

  • 隐藏状态 ​ h t h_t ht 的设计:

    • 将隐藏状态设计为 h t = [ h t ( 0 ) , h t ( 1 ) ] h_t = [h_t^{(0)}, h_t^{(1)}] ht=[ht(0),ht(1)],其中:
      • h t ( 0 ) h_t^{(0)} ht(0) 用来记录是否见过 x t , 0 = 1 x_{t,0} = 1 xt,0=1
      • h t ( 1 ) h_t^{(1)} ht(1) 用来记录是否见过 x t , 1 = 1 x_{t,1} = 1 xt,1=1
    • h t ( 0 ) = 1 h_t^{(0)} = 1 ht(0)=1 时,说明在当前时刻之前至少有一个 x r , 0 = 1 x_{r,0} = 1 xr,0=1
    • h t ( 1 ) = 1 h_t^{(1)} = 1 ht(1)=1 时,说明在当前时刻之前至少有一个 x s , 1 = 1 x_{s,1} = 1 xs,1=1
  • 递归更新 ​ h t h_t ht

    • 对于 W h x W_{hx} Whx,当 x t , 0 = 1 x_{t,0} = 1 xt,0=1 时,需要更新 h t ( 0 ) = 1 h_t^{(0)} = 1 ht(0)=1
    • 对于 W h x W_{hx} Whx,当 x t , 1 = 1 x_{t,1} = 1 xt,1=1 时,需要更新 h t ( 1 ) = 1 h_t^{(1)} = 1 ht(1)=1
    • 这些信息将通过递归公式传递,因此设计 W h h W_{hh} Whh 来使之前的状态能够保持。
  • 输出 ​ y t y_t yt 的设计:

    • h t ( 0 ) = 1 h_t^{(0)} = 1 ht(0)=1 h t ( 1 ) = 1 h_t^{(1)} = 1 ht(1)=1 时,输出 y t = 1 y_t = 1 yt=1,否则 y t = 0 y_t = 0 yt=0
参数选择

我们可以选择以下参数来实现这个目标:

  1. W h x = ( 1 0 0 1 ) W_{hx} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} Whx=(1001)

这样当 x t , 0 = 1 x_{t,0} = 1 xt,0=1时,隐藏状态 h t ( 0 ) h_t^{(0)} ht(0)更新为1;当 x t , 1 = 1 x_{t,1} = 1 xt,1=1时,隐藏状态 h t ( 1 ) h_t^{(1)} ht(1)更新为1。

  1. W h h = ( 1 0 0 1 ) W_{hh} = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} Whh=(1001)

这可以确保上一时刻的隐藏状态会传递到当前时刻。

  1. b h = ( 0 0 ) b_h = \begin{pmatrix} 0 \\ 0 \end{pmatrix} bh=(00)

无需额外的偏置项。

  1. W y h = ( 1 1 ) W_{yh} = \begin{pmatrix} 1 & 1 \end{pmatrix} Wyh=(11)

输出 y t y_t yt取决于隐藏状态的两个分量的和,当且仅当两个分量都是1时,才能满足 y t = 1 y_t = 1 yt=1

  1. b y = − 1 b_y = -1 by=1

这样只有当两个分量的和为2时,才能使得输出 y t y_t yt通过 slide 函数达到1,否则输出为0。

验证

我们需要验证:当且仅当 x r , 0 = 1 x_{r,0} = 1 xr,0=1 x s , 1 = 1 x_{s,1} = 1 xs,1=1存在时,输出 y t = 1 y_t = 1 yt=1,否则 y t = 0 y_t = 0 yt=0

我们从初始状态 h 0 = ( 0 0 ) h_0 = \begin{pmatrix} 0 \\ 0 \end{pmatrix} h0=(00)开始,逐步推导隐藏状态 h t h_t ht和输出 y t y_t yt的计算过程。

步骤 1:当 x t = ( 1 0 ) x_t = \begin{pmatrix} 1 \\ 0 \end{pmatrix} xt=(10)

  • 输入 x t x_t xt 的第一个元素为 1,第二个元素为 0。
  • 递归更新 h t h_t ht

h t = slide ( W h h h t − 1 + W h x x t + b h ) h_t = \text{slide}(W_{hh} h_{t-1} + W_{hx} x_t + b_h) ht=slide(Whhht1+Whxxt+bh)

h t = slide ( ( 1 0 0 1 ) ( h t − 1 ( 0 ) h t − 1 ( 1 ) ) + ( 1 0 0 1 ) ( 1 0 ) ) h_t = \text{slide} \left( \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} h_{t-1}^{(0)} \\ h_{t-1}^{(1)} \end{pmatrix} + \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} 1 \\ 0 \end{pmatrix} \right) ht=slide((1001)(ht1(0)ht1(1))+(1001)(10))

h t = slide ( ( h t − 1 ( 0 ) h t − 1 ( 1 ) ) + ( 1 0 ) ) = slide ( ( h t − 1 ( 0 ) + 1 h t − 1 ( 1 ) ) ) h_t = \text{slide} \left( \begin{pmatrix} h_{t-1}^{(0)} \\ h_{t-1}^{(1)} \end{pmatrix} + \begin{pmatrix} 1 \\ 0 \end{pmatrix} \right) = \text{slide} \left( \begin{pmatrix} h_{t-1}^{(0)} + 1 \\ h_{t-1}^{(1)} \end{pmatrix} \right) ht=slide((ht1(0)ht1(1))+(10))=slide((ht1(0)+1ht1(1)))

h t = ( 1 h t − 1 ( 1 ) ) h_t = \begin{pmatrix} 1 \\ h_{t-1}^{(1)} \end{pmatrix} ht=(1ht1(1))

因此,此时 h t ( 0 ) = 1 h_t^{(0)} = 1 ht(0)=1,表示见到了 x t , 0 = 1 x_{t,0} = 1 xt,0=1

步骤 2:当 x t = ( 0 1 ) x_t = \begin{pmatrix} 0 \\ 1 \end{pmatrix} xt=(01)

  • 输入 x t x_t xt 的第一个元素为 0,第二个元素为 1。
  • 递归更新 h t h_t ht

h t = slide ( W h h h t − 1 + W h x x t + b h ) h_t = \text{slide}(W_{hh} h_{t-1} + W_{hx} x_t + b_h) ht=slide(Whhht1+Whxxt+bh)

h t = slide ( ( 1 0 0 1 ) ( h t − 1 ( 0 ) h t − 1 ( 1 ) ) + ( 1 0 0 1 ) ( 0 1 ) ) h_t = \text{slide} \left( \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} h_{t-1}^{(0)} \\ h_{t-1}^{(1)} \end{pmatrix} + \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} 0 \\ 1 \end{pmatrix} \right) ht=slide((1001)(ht1(0)ht1(1))+(1001)(01))

h t = slide ( ( h t − 1 ( 0 ) h t − 1 ( 1 ) + 1 ) ) h_t = \text{slide} \left( \begin{pmatrix} h_{t-1}^{(0)} \\ h_{t-1}^{(1)} + 1 \end{pmatrix} \right) ht=slide((ht1(0)ht1(1)+1))

h t = ( h t − 1 ( 0 ) 1 ) h_t = \begin{pmatrix} h_{t-1}^{(0)} \\ 1 \end{pmatrix} ht=(ht1(0)1)

此时 h t ( 1 ) = 1 h_t^{(1)} = 1 ht(1)=1,表示见到了 x t , 1 = 1 x_{t,1} = 1 xt,1=1

步骤 3:输出 y t y_t yt的计算

  • 输出层的公式为:

y t = slide ( W y h h t + b y ) y_t = \text{slide}(W_{yh} h_t + b_y) yt=slide(Wyhht+by)

其中 W y h = ( 1 1 ) W_{yh} = \begin{pmatrix} 1 & 1 \end{pmatrix} Wyh=(11) b y = − 1 b_y = -1 by=1

所以当 h t = ( 1 1 ) h_t = \begin{pmatrix} 1 \\ 1 \end{pmatrix} ht=(11) 时:

y t = slide ( 1 × 1 + 1 × 1 − 1 ) = slide ( 1 ) = 1 y_t = \text{slide}(1 \times 1 + 1 \times 1 - 1) = \text{slide}(1) = 1 yt=slide(1×1+1×11)=slide(1)=1

即输出为 1。

h t = ( 1 0 ) h_t = \begin{pmatrix} 1 \\ 0 \end{pmatrix} ht=(10) h t = ( 0 1 ) h_t = \begin{pmatrix} 0 \\ 1 \end{pmatrix} ht=(01) 时:

y t = slide ( 1 × 1 + 1 × 0 − 1 ) = slide ( 0 ) = 0 y_t = \text{slide}(1 \times 1 + 1 \times 0 - 1) = \text{slide}(0) = 0 yt=slide(1×1+1×01)=slide(0)=0

y t = slide ( 1 × 0 + 1 × 1 − 1 ) = slide ( 0 ) = 0 y_t = \text{slide}(1 \times 0 + 1 \times 1 - 1) = \text{slide}(0) = 0 yt=slide(1×0+1×11)=slide(0)=0

即输出为 0。

验证总结:

通过上面的推导,我们可以验证以下结果:

  • h t ( 0 ) = 1 h_t^{(0)} = 1 ht(0)=1 h t ( 1 ) = 1 h_t^{(1)} = 1 ht(1)=1,即同时见到了 x t , 0 = 1 x_{t,0} = 1 xt,0=1 x t , 1 = 1 x_{t,1} = 1 xt,1=1时,输出 y t = 1 y_t = 1 yt=1
  • 如果仅有一个分量为1,或者两个分量都为0,输出 y t = 0 y_t = 0 yt=0

因此,所选的参数符合题目要求,成功实现了题目中要求的条件。

1.2 问题2:双向RNN与自回归语言模型的可实现性探讨

讲义中问题的原文

自回归语言模型定义了一个序列 x 1 : T x_{1:T} x1:T 上的概率分布形式如下:

p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ x 1 , … , x t − 1 ) p(x_{1:T}) = \prod_{t=1}^{T} p(x_t | x_1, \dots, x_{t-1}) p(x1:T)=t=1Tp(xtx1,,xt1)

(a) (2 分) 简答:假设我们给定一个输入 x 1 : T x_{1:T} x1:T,并且定义一个双向 RNN,如下所示:

f t = σ ( W f f t − 1 + W f x x t + b f ) , ∀ t ∈ { 1 , … , T } f_t = \sigma(W_f f_{t-1} + W_{fx} x_t + b_f), \quad \forall t \in \{1, \dots, T\} ft=σ(Wfft1+Wfxxt+bf),t{1,,T}

g t = σ ( W g g t + 1 + W g x x t + b g ) , ∀ t ∈ { 1 , … , T } g_t = \sigma(W_g g_{t+1} + W_{gx} x_t + b_g), \quad \forall t \in \{1, \dots, T\} gt=σ(Wggt+1+Wgxxt+bg),t{1,,T}

h t = σ ( W h f t + W h g t + b g ) , ∀ t ∈ { 1 , … , T } h_t = \sigma(W_h f_t + W_h g_t + b_g), \quad \forall t \in \{1, \dots, T\} ht=σ(Whft+Whgt+bg),t{1,,T}

注意 f t f_t ft从左侧建立上下文, g t g_t gt从右侧建立上下文, h t h_t ht结合了两者。我们是否可以定义形式为 p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ h t − 1 ) p(x_{1:T}) = \prod_{t=1}^{T} p(x_t | h_{t-1}) p(x1:T)=t=1Tp(xtht1)的自回归语言模型?如果可以,定义这个概率分布;如果不可以,请解释原因。

(b) (2 分) 简答:假设 BiRNN( x 1 : t − 1 x_{1:t-1} x1:t1) 计算子序列 x 1 : t − 1 x_{1:t-1} x1:t1 上的双向 RNN,并返回 h t − 1 h_{t-1} ht1。我们是否可以定义形式为 p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ BiRNN ( x 1 : t − 1 ) ) p(x_{1:T}) = \prod_{t=1}^{T} p(x_t | \text{BiRNN}(x_{1:t-1})) p(x1:T)=t=1Tp(xtBiRNN(x1:t1)) 的自回归语言模型?如果可以,定义这个概率分布;如果不可以,请解释原因。

解答

CMU学生的解答:

1.2(a):

在这里插入图片描述

1.2(b):

在这里插入图片描述

以下是GPT解答:

1.2问题涉及到自回归语言模型以及双向RNN的结构设计。问题的核心是探讨如何在双向RNN的基础上定义自回归语言模型,或者说能否通过某些改动来实现这种模型。下面我将对问题进行通俗解释。

问题背景:

首先,什么是自回归语言模型?它是通过对前面已经看到的序列来预测下一个单词或字符。具体来说,自回归模型的核心是使用已经看到的序列来计算当前时刻的条件概率:

p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ x 1 , … , x t − 1 ) p(x_1:T) = \prod_{t=1}^{T} p(x_t | x_1, \dots, x_{t-1}) p(x1:T)=t=1Tp(xtx1,,xt1)

这意味着:我们在生成序列时,每个时间步 t t t的输出 x t x_t xt都依赖于之前的所有输入 x 1 , … , x t − 1 x_1, \dots, x_{t-1} x1,,xt1

问题1.2(a) 的解释:
  • 双向RNN(BiRNN) 是一种扩展的RNN,它同时考虑序列的前后文信息,即它的两个隐藏状态:

    • f t f_t ft:从左到右处理,构建从左侧传来的上下文。
    • g t g_t gt:从右到左处理,构建从右侧传来的上下文。
  • 现在问题问:能否使用这种双向RNN来定义一个自回归语言模型?

通俗地说,自回归语言模型要求在生成 x t x_t xt 时,不能利用未来的信息(即 t + 1 , t + 2 , … t+1, t+2, \dots t+1,t+2,)。然而,双向RNN的结构同时利用了前向和后向的隐藏状态,即它利用了未来的上下文信息。因此,这与自回归语言模型的要求不符。

结论:不可以直接使用双向RNN定义自回归语言模型,因为双向RNN的设计依赖未来的信息,而自回归模型只依赖过去的信息。

问题1.2(b) 的解释:
  • 在这个问题中,假设我们使用一个双向RNN,但只让它处理到时间步 t − 1 t-1 t1 的输入(即不考虑未来的输入)。
  • 现在问题问:如果我们只使用时间步 ​ t − 1 t-1 t1 的双向RNN隐藏状态来生成 x t x_t xt,是否可以定义自回归语言模型?

这个问题实际上是在问:是否可以利用双向RNN的隐藏状态(只处理到 ​ t − 1 t-1 t1)来进行自回归预测?

通俗地说,如果我们在生成当前的 x t x_t xt 时,只使用之前时间步 x 1 , … , x t − 1 x_1, \dots, x_{t-1} x1,,xt1 的信息,那么这是符合自回归模型要求的。因为它只使用了过去的信息,符合自回归模型不使用未来信息的条件。

结论:可以使用双向RNN处理到时间步 t − 1 t-1 t1 的隐藏状态来定义自回归语言模型,因为此时没有使用未来的输入序列。

1.2(b)概率分布的定义:
  • 自回归语言模型的基本公式为:

p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ x 1 , … , x t − 1 ) p(x_{1:T}) = \prod_{t=1}^{T} p(x_t | x_1, \dots, x_{t-1}) p(x1:T)=t=1Tp(xtx1,,xt1)

这意味着每个时刻 x t x_t xt 的输出只依赖于之前时刻的输入。

  • 在这个问题中,双向RNN的隐藏状态 h t − 1 h_{t-1} ht1 (包括 f t − 1 f_{t-1} ft1 g t − 1 g_{t-1} gt1)能够捕捉之前的输入序列信息。因此,我们可以用隐藏状态 h t − 1 h_{t-1} ht1 来代替 x 1 , … , x t − 1 x_1, \dots, x_{t-1} x1,,xt1 作为当前时刻 x t x_t xt 的条件:

p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ h t − 1 ) = ∏ t = 1 T p ( x t ∣ f t − 1 , g t − 1 ) p(x_{1:T}) = \prod_{t=1}^{T} p(x_t | h_{t-1}) = \prod_{t=1}^{T} p(x_t | f_{t-1}, g_{t-1}) p(x1:T)=t=1Tp(xtht1)=t=1Tp(xtft1,gt1)

其中, f t − 1 f_{t-1} ft1 是从左到右的隐藏状态, g t − 1 g_{t-1} gt1 是从右到左的隐藏状态。

最终结论

我们可以使用双向RNN处理到时间步 t − 1 t-1 t1的隐藏状态 h t − 1 h_{t-1} ht1来定义自回归语言模型,其概率分布为:

p ( x 1 : T ) = ∏ t = 1 T p ( x t ∣ h t − 1 ) = ∏ t = 1 T p ( x t ∣ f t − 1 , g t − 1 ) p(x_{1:T}) = \prod_{t=1}^{T} p(x_t | h_{t-1}) = \prod_{t=1}^{T} p(x_t | f_{t-1}, g_{t-1}) p(x1:T)=t=1Tp(xtht1)=t=1Tp(xtft1,gt1)

该公式表明,序列的生成是基于前一时刻的隐藏状态,符合自回归语言模型的要求。

1.2(a)和1.2(b)的总结:
  • 1.2(a):不能直接使用双向RNN定义自回归语言模型,因为它依赖未来的信息。
  • 1.2(b):如果我们只使用双向RNN处理到时间步 t − 1 t-1 t1 的信息来生成 x t x_t xt,则可以定义自回归语言模型,因为它只依赖过去的信息。

2 transformer语言模型

Transformer 使用缩放点积注意力(scaled-dot-product attention):

s t , j = k j T q t ∣ k ∣ , ∀ j , t s_{t,j} = \frac{k_j^T q_t}{\sqrt{|k|}}, \forall j, t st,j=k kjTqt,j,t

a t = softmax ( s t ) , ∀ t a_t = \text{softmax}(s_t), \forall t at=softmax(st),t

其中,值(values)、查询(queries)和键(keys)分别由以下公式给出: v j = W v T x j v_j = W_v^T x_j vj=WvTxj, q j = W q T x j q_j = W_q^T x_j qj=WqTxj, 以及 k j = W k T x j k_j = W_k^T x_j kj=WkTxj对于所有 j j j。这里, v j , q j , k j ∈ R d k v_j, q_j, k_j \in \mathbb{R}^{d_k} vj,qj,kjRdk

2.1(a) 问题1:乘法注意力(multiplicative attention)

乘法注意力将注意力权重定义为:

s ~ t , j = k j T W s q t ∣ k ∣ , ∀ j , t \tilde{s}_{t,j} = \frac{k_j^T W_s q_t}{\sqrt{|k|}}, \forall j, t s~t,j=k kjTWsqt,j,t

其中, W s ∈ R d k × d k W_s \in \mathbb{R}^{d_k \times d_k} WsRdk×dk是一个参数矩阵。一个使用乘法注意力的 Transformer 能否学习到与使用更简单的缩放点积注意力不同的函数类?简要说明你的理由。

解答

CMU学生的解答:(GPT4并不认同该答案)

在这里插入图片描述

GPT的解答:

结论是的,使用乘法注意力的 Transformer 能够学习到与缩放点积注意力不同的函数类

理由
  1. 乘法注意力的额外自由度
  • 在 **乘法注意力 **中,查询向量 q t q_t qt 经过了一个可训练的线性变换 W s W_s Ws,这使得相似性度量具有更高的灵活性。这个矩阵 W s W_s Ws 是额外的参数,能够对查询进行方向和尺度上的调整。

  • 相较之下,**缩放点积注意力 **中的查询 q t q_t qt 和键 k j k_j kj 直接进行点积运算,并通过 1 ∣ k ∣ \frac{1}{\sqrt{|k|}} k 1 进行缩放,但没有进一步的变换。这种相似性度量是固定的,仅依赖于查询和键向量的方向性。

  • 扩展的函数类

    • **缩放点积注意力 **的函数类是基于直接的点积计算,即:

    s t , j = k j T q t ∣ k ∣ s_{t,j} = \frac{k_j^T q_t}{\sqrt{|k|}} st,j=k kjTqt

    这是一个简单的相似性度量函数,计算的是查询和键之间的夹角或相对方向。

    • **乘法注意力 **的函数类更广泛,因为查询 q t q_t qt 在与键 k j k_j kj 点积之前,先经过了线性变换 W s W_s Ws。这意味着它不仅可以表示点积相似性,还可以通过学习 W s W_s Ws 来调整查询向量,从而学习到更复杂的相似性关系。

    s ~ t , j = k j T W s q t ∣ k ∣ \tilde{s}_{t,j} = \frac{k_j^T W_s q_t}{\sqrt{|k|}} s~t,j=k kjTWsqt

因此,乘法注意力能够学习到比缩放点积注意力更加灵活和复杂的相似性函数,能够捕捉到不同的模式。

关于函数类的解释
  • **函数类(function class) **是指一组具有相似结构或行为的函数的集合。它描述了从输入到输出的映射关系。在机器学习中,函数类是由模型的结构、参数以及计算方法所决定的。
    • 在注意力机制中,函数类表示的是用于计算查询 q t q_t qt 和键 k j k_j kj 之间相似性得分的函数集合。例如,缩放点积注意力的函数类包括直接通过点积计算相似性的函数,而乘法注意力的函数类由于引入了线性变换,涵盖了更多的相似性度量方式。
我(GPT4)不赞同CMU学生答案的理由:

某人在截图中的观点认为:“乘法注意力无法学习到与缩放点积注意力不同的函数类”,因为它仅仅是对缩放点积注意力的线性变换版本。这个观点过于局限,忽视了线性变换 ​ W s W_s Ws 的可训练性和灵活性。

  • 为什么不认同?
    • 可训练的 W s W_s Ws:乘法注意力通过引入矩阵 W s W_s Ws对查询 q t q_t qt进行线性变换,使得相似性度量不再局限于简单的点积,而是能够学习和表达更多复杂的关系。这就意味着,乘法注意力可以表示出缩放点积注意力无法表达的相似性度量。
    • 不同的函数类:虽然在表面形式上,乘法注意力的点积结构与缩放点积注意力相似,但由于矩阵 W s W_s Ws的存在,乘法注意力的函数类要比缩放点积注意力的函数类更广。因此,乘法注意力能够学习到与缩放点积注意力不同的函数类
总结
  • 乘法注意力引入了额外的线性变换 W s W_s Ws,使得它能够表示比缩放点积注意力更复杂的相似性函数,因此它们学习的函数类是不同的。
  • 某人的观点忽略了乘法注意力中的 W s W_s Ws 的可训练性,这也是我不认同其结论的原因。

2.1(b) 问题2:连接注意力(concatenated attention)

连接注意力将注意力权重定义为:

s ^ t , j = w s T [ k j ; q t ] , ∀ j , t \hat{s}_{t,j} = w_s^T [k_j; q_t], \forall j, t s^t,j=wsT[kj;qt],j,t

其中, w s ∈ R 2 d k w_s \in \mathbb{R}^{2d_k} wsR2dk是一个参数向量,且 [ a ; b ] [a; b] [a;b]表示向量 a a a b b b的连接。是否存在参数 w s w_s ws,使得 s ^ t , j \hat{s}_{t,j} s^t,j能近似等于两个向量 k j k_j kj q t q_t qt之间的角度 θ \theta θ,或 cos ⁡ ( θ ) \cos(\theta) cos(θ)?简要说明理由(不需要正式的证明)。

解答

CMU学生:

在这里插入图片描述

GPT4解答:

结论不存在这样的固定参数 w s w_s ws,使得 s ^ t , j \hat{s}_{t,j} s^t,j 能够在查询 q t q_t qt 和键 k j k_j kj 变化的情况下,始终近似表示它们之间的角度 θ \theta θ 或余弦相似度 cos ⁡ ( θ ) \cos(\theta) cos(θ)

理由:
  1. 连接操作与固定参数的局限性
  • 在连接注意力中, [ k j ; q t ] [k_j; q_t] [kj;qt]是通过将两个向量 k j ∈ R d k k_j \in \mathbb{R}^{d_k} kjRdk q t ∈ R d k q_t \in \mathbb{R}^{d_k} qtRdk拼接起来形成一个 2 d k 2d_k 2dk维的向量:

[ k j ; q t ] = ( k j q t ) ∈ R 2 d k [k_j; q_t] = \begin{pmatrix} k_j \\ q_t \end{pmatrix} \in \mathbb{R}^{2d_k} [kj;qt]=(kjqt)R2dk

然后使用固定的权重向量 w s w_s ws对该拼接向量进行线性变换:

s ^ t , j = w s T [ k j ; q t ] \hat{s}_{t,j} = w_s^T [k_j; q_t] s^t,j=wsT[kj;qt]

这个线性变换只能计算出一个固定的线性组合,它无法直接表示两个向量之间的非线性关系,例如角度 θ \theta θ或余弦相似度 cos ⁡ ( θ ) \cos(\theta) cos(θ),因为它们依赖于向量的相对方向和长度。

  • 余弦相似度的非线性特性

    • 余弦相似度 cos ⁡ ( θ ) \cos(\theta) cos(θ)的公式为:

    cos ⁡ ( θ ) = k j T q t ∥ k j ∥ ∥ q t ∥ \cos(\theta) = \frac{k_j^T q_t}{\|k_j\| \|q_t\|} cos(θ)=kj∥∥qtkjTqt

    这是一个非线性函数,它不仅涉及向量 k j k_j kj q t q_t qt的点积,还需要计算它们的长度 ∥ k j ∥ \|k_j\| kj ∥ q t ∥ \|q_t\| qt。而连接注意力中的线性变换 w s T [ k j ; q t ] w_s^T [k_j; q_t] wsT[kj;qt]并没有这种非线性处理能力,因此无法准确表示或近似余弦相似度。

  • 输入的变化与固定参数的关系

    • 在注意力机制中,查询 q t q_t qt和键 k j k_j kj是随每次输入变化的,而 w s w_s ws是一个固定的参数向量。由于 w s w_s ws不会随输入变化,因此它无法动态地调整来适应每个不同的 q t q_t qt k j k_j kj,从而计算出它们之间的角度 θ \theta θ或余弦相似度。
  • 点积与余弦相似度的不同

    • 点积( k j T q t k_j^T q_t kjTqt)与余弦相似度( cos ⁡ ( θ ) \cos(\theta) cos(θ))之间有一个关键区别:点积是一个线性度量,而余弦相似度是一个非线性度量,涉及向量的方向和大小。固定的线性变换 w s T [ k j ; q t ] w_s^T [k_j; q_t] wsT[kj;qt]无法表示这种非线性关系。
    • 因此,即便通过合适的权重组合, w s w_s ws也不能完全捕捉到 cos ⁡ ( θ ) \cos(\theta) cos(θ)这样复杂的非线性相似度
总结:
  • 不存在固定的参数 w s w_s ws,可以使得连接注意力的输出 s ^ t , j \hat{s}_{t,j} s^t,j 近似等于两个向量 k j k_j kj q t q_t qt 之间的角度 θ \theta θ 或余弦相似度 cos ⁡ ( θ ) \cos(\theta) cos(θ)
  • 连接注意力中的线性变换 w s T [ k j ; q t ] w_s^T [k_j; q_t] wsT[kj;qt] 只能表示线性关系,而无法捕捉两个向量间的非线性相似性,例如角度或余弦相似度。

2.1© 问题3:加性注意力(Additive Attention)

加性注意力将注意力权重定义为:

s ^ t , j = w s T tanh ⁡ ( W s [ k j ; q t ] ) , ∀ j , t \hat{s}_{t,j} = w_s^T \tanh(W_s [k_j; q_t]), \forall j, t s^t,j=wsTtanh(Ws[kj;qt]),j,t

其中, w s ∈ R d s w_s \in \mathbb{R}^{d_s} wsRds W s ∈ R d s × 2 d k W_s \in \mathbb{R}^{d_s \times 2d_k} WsRds×2dk是参数,且 d s d_s ds是超参数。是否存在参数 w s , W s w_s, W_s ws,Ws,使得 s ^ t , j \hat{s}_{t,j} s^t,j近似等于向量 k j k_j kj q t q_t qt之间的角度 θ \theta θ cos ⁡ ( θ ) \cos(\theta) cos(θ)?简要说明理由(不需要正式的证明)。

解答

CMU学生:

在这里插入图片描述

GPT解答:

结论存在这样的参数 w s w_s ws W s W_s Ws,可以使加性注意力中的输出 s ^ t , j \hat{s}_{t,j} s^t,j近似表示 k j k_j kj q t q_t qt之间的余弦相似度 cos ⁡ ( θ ) \cos(\theta) cos(θ),这是因为加性注意力引入了非线性激活函数 tanh ⁡ \tanh tanh和可学习的线性变换,具备足够的表示能力来近似复杂的非线性关系。

理由:
  1. 非线性激活函数的作用
  • 加性注意力中的公式使用了 tanh ⁡ \tanh tanh作为非线性激活函数。 tanh ⁡ \tanh tanh函数能够将线性组合的输入(在这里是 W s [ k j ; q t ] W_s [k_j; q_t] Ws[kj;qt])转化为非线性输出,具备近似任意非线性函数的能力,包括余弦相似度这样的非线性关系。

  • 神经网络的近似能力

    • 神经网络的一个基本性质是通用函数逼近器,即通过足够的参数和非线性激活函数,神经网络可以近似任意的连续函数。这里的 w s T tanh ⁡ ( W s [ k j ; q t ] ) w_s^T \tanh(W_s [k_j; q_t]) wsTtanh(Ws[kj;qt])可以看作是一个简单的神经网络结构,它能够通过训练来逼近 cos ⁡ ( θ ) \cos(\theta) cos(θ),即查询 q t q_t qt和键 k j k_j kj之间的余弦相似度。
  • 线性变换与非线性结合的灵活性

    • W s W_s Ws是一个线性变换矩阵,它可以对输入 [ k j ; q t ] [k_j; q_t] [kj;qt]进行初步的线性映射。结合非线性激活函数 tanh ⁡ \tanh tanh,这提供了足够的自由度来表示复杂的相似性度量(如余弦相似度)。
    • 参数 w s w_s ws的线性组合进一步将这些非线性输出组合在一起。因此,经过适当的训练,参数 w s w_s ws W s W_s Ws完全可以学习到近似 cos ⁡ ( θ ) \cos(\theta) cos(θ)的映射。
总结:
  • 存在这样的参数 w s w_s ws W s W_s Ws,使得加性注意力的输出 s ^ t , j \hat{s}_{t,j} s^t,j 可以近似表示查询 q t q_t qt 和键 k j k_j kj 之间的角度或余弦相似度。这是因为加性注意力结合了非线性激活函数 tanh ⁡ \tanh tanh 和线性变换,具备近似复杂非线性关系的能力。

问题2.2:自注意力与多头注意力的对称性与多头/单头注意力的等价性问题

自注意力(Self-attention)通常通过矩阵乘法来计算。这里我们考虑没有因果注意力掩码的多头注意力(multi-headed attention)。

设:

X = [ x 1 , . . . , x N ] T X = [x_1, ..., x_N]^T X=[x1,...,xN]T

V ( i ) = X W v ( i ) V^{(i)} = XW^{(i)}_v V(i)=XWv(i)

K ( i ) = X W k ( i ) K^{(i)} = XW^{(i)}_k K(i)=XWk(i)

Q ( i ) = X W q ( i ) Q^{(i)} = XW^{(i)}_q Q(i)=XWq(i)

S ( i ) = Q ( i ) ( K ( i ) ) T / d k S^{(i)} = Q^{(i)}(K^{(i)})^T / \sqrt{d_k} S(i)=Q(i)(K(i))T/dk

A ( i ) = softmax ( S ( i ) ) A^{(i)} = \text{softmax}(S^{(i)}) A(i)=softmax(S(i))

X ′ ( i ) = A ( i ) V ( i ) X'^{(i)} = A^{(i)}V^{(i)} X(i)=A(i)V(i)

X ′ = concat ( X ′ ( 1 ) , . . . , X ′ ( h ) ) X' = \text{concat}(X'^{(1)}, ..., X'^{(h)}) X=concat(X(1),...,X(h))

其中, N N N是序列长度, h h h是注意力头的数量,每一行与 i i i相关的定义适用于所有 i ∈ { 1 , . . . , h } i \in \{1, ..., h\} i{1,...,h}

(a) (3分) 简答:注意力矩阵 A ( i ) A^{(i)} A(i)是否总是对称的?如果是,请证明它是对称的。如果不是,请描述一个确保它对称的条件。

(b) (4分) 简答:假设我们有两个注意力头 h = 2 h = 2 h=2,令 d k = d m / h d_k = d_m / h dk=dm/h,并且我们有一个单一输入 X X X。令 X ′ X' X为在参数 W v ( 1 ) , W k ( 1 ) , W q ( 1 ) , W v ( 2 ) , W k ( 2 ) , W q ( 2 ) ∈ R d m × d k W^{(1)}_v, W^{(1)}_k, W^{(1)}_q, W^{(2)}_v, W^{(2)}_k, W^{(2)}_q \in R^{d_m \times d_k} Wv(1),Wk(1),Wq(1),Wv(2),Wk(2),Wq(2)Rdm×dk上执行的多头注意力的输出。现在,假设我们使用这些相同的参数,并沿行拼接得到新的参数:

W v ′ = concat ( W v ( 1 ) , W v ( 2 ) ) , W k ′ = concat ( W k ( 1 ) , W k ( 2 ) ) , W q ′ = concat ( W q ( 1 ) , W q ( 2 ) ) ∈ R d m × d m W'_v = \text{concat}(W^{(1)}_v, W^{(2)}_v), W'_k = \text{concat}(W^{(1)}_k, W^{(2)}_k), W'_q = \text{concat}(W^{(1)}_q, W^{(2)}_q) \in R^{d_m \times d_m} Wv=concat(Wv(1),Wv(2)),Wk=concat(Wk(1),Wk(2)),Wq=concat(Wq(1),Wq(2))Rdm×dm

X ′ ′ X'' X′′为在参数 W v ′ , W k ′ , W q ′ W'_v, W'_k, W'_q Wv,Wk,Wq上执行的单头注意力的输出。在这种情况下, X ′ ′ = X ′ X'' = X' X′′=X吗?请证明你的答案。

解答

CMU学生:

2.2(a):

在这里插入图片描述

2.2(b):

在这里插入图片描述

GPT解答:

问题2.2 (a) 的解答

注意力矩阵 A ( i ) A^{(i)} A(i)并不总是对称的。具体来说, A ( i ) A^{(i)} A(i)的计算方式是通过:

A ( i ) = softmax ( S ( i ) ) A^{(i)} = \text{softmax}(S^{(i)}) A(i)=softmax(S(i))

其中, S ( i ) = Q ( i ) ( K ( i ) ) T / d k S^{(i)} = Q^{(i)}(K^{(i)})^T / \sqrt{d_k} S(i)=Q(i)(K(i))T/dk

由于 Q ( i ) Q^{(i)} Q(i) K ( i ) K^{(i)} K(i)是通过不同的权重矩阵 W q ( i ) W_q^{(i)} Wq(i) W k ( i ) W_k^{(i)} Wk(i)从输入 X X X中生成的,因此 Q ( i ) ≠ K ( i ) Q^{(i)} \neq K^{(i)} Q(i)=K(i)是常见的情况。因此, S ( i ) S^{(i)} S(i)通常不会是对称的矩阵。通过 softmax 操作后, A ( i ) A^{(i)} A(i)也不会是对称的。

确保对称的条件

要确保注意力矩阵 A ( i ) A^{(i)} A(i)是对称的,需要使得 Q ( i ) = K ( i ) Q^{(i)} = K^{(i)} Q(i)=K(i)。换句话说,只有当查询矩阵 Q ( i ) Q^{(i)} Q(i)和键矩阵 K ( i ) K^{(i)} K(i)完全相同时,注意力矩阵 S ( i ) S^{(i)} S(i)才会对称,而这要求 W q ( i ) = W k ( i ) W_q^{(i)} = W_k^{(i)} Wq(i)=Wk(i)。在这种情况下, S ( i ) = Q ( i ) ( Q ( i ) ) T / d k S^{(i)} = Q^{(i)}(Q^{(i)})^T / \sqrt{d_k} S(i)=Q(i)(Q(i))T/dk 是对称的,从而 A ( i ) A^{(i)} A(i)也会对称。

总结:

  • A ( i ) A^{(i)} A(i)通常不是对称的。
  • W q ( i ) = W k ( i ) W_q^{(i)} = W_k^{(i)} Wq(i)=Wk(i)时, A ( i ) A^{(i)} A(i)才会是对称的。

问题2.2 (b) 的解答

问题的核心是在比较两个不同的注意力机制输出是否相等:一种是多头注意力(multi-headed attention),另一种是将注意力头的参数沿行拼接并作为单头注意力(single-headed attention)的输出。

给定的情况是:

  • 多头注意力使用的参数为 W v ( 1 ) , W k ( 1 ) , W q ( 1 ) W^{(1)}_v, W^{(1)}_k, W^{(1)}_q Wv(1),Wk(1),Wq(1) W v ( 2 ) , W k ( 2 ) , W q ( 2 ) W^{(2)}_v, W^{(2)}_k, W^{(2)}_q Wv(2),Wk(2),Wq(2)
  • 单头注意力则使用拼接后的参数 W v ′ = concat ( W v ( 1 ) , W v ( 2 ) ) W'_v = \text{concat}(W^{(1)}_v, W^{(2)}_v) Wv=concat(Wv(1),Wv(2)) W k ′ = concat ( W k ( 1 ) , W k ( 2 ) ) W'_k = \text{concat}(W^{(1)}_k, W^{(2)}_k) Wk=concat(Wk(1),Wk(2)),和 W q ′ = concat ( W q ( 1 ) , W q ( 2 ) ) W'_q = \text{concat}(W^{(1)}_q, W^{(2)}_q) Wq=concat(Wq(1),Wq(2))

问题: X ′ ′ = X ′ X'' = X' X′′=X吗?

答案:

X ′ ′ ≠ X ′ X'' \neq X' X′′=X

解释:

  1. 多头注意力的特性

在多头注意力机制中,输入 X X X被分别投影到不同的查询、键和值空间中(每个注意力头的维度为 d k d_k dk),计算出相应的注意力矩阵 A ( i ) A^{(i)} A(i)并作用在值向量上。这些不同的注意力头可以捕捉输入序列中的不同模式。最终的输出 X ′ X' X是将各个头的输出 X ′ ( i ) X'^{(i)} X(i)进行拼接后得到的。

  1. 单头注意力的特性

在单头注意力机制中,输入 X X X是通过一个拼接后的参数矩阵 W v ′ , W k ′ , W q ′ W'_v, W'_k, W'_q Wv,Wk,Wq投影到一个维度更大的查询、键和值空间中。这意味着单头注意力只执行一次注意力计算,无法像多头注意力那样对输入序列进行多角度的模式捕捉。

  1. 拼接的影响

虽然多头注意力的参数可以通过拼接形式转换为单头注意力的参数,但由于单头注意力只进行一次注意力计算,而多头注意力在每个头上进行独立的注意力计算并拼接结果,因此单头注意力的输出 X ′ ′ X'' X′′与多头注意力的输出 X ′ X' X不同。

换句话说,单头注意力将所有信息合并到一个注意力头中,而多头注意力可以独立地处理多个注意力头,捕捉不同的上下文关系。因此,拼接后的单头注意力无法完全重现多头注意力的效果。

总结

  • X ′ ′ ≠ X ′ X'' \neq X' X′′=X,因为多头注意力和单头注意力的计算方式不同,拼接参数并不能保证输出相等。

3 Sliding Window Attention(滑动窗口注意力)

滑动窗口注意力的最简单定义方法是将因果掩码 M M M 设置为仅包含 1 2 w + 1 \frac{1}{2}w + 1 21w+1 个标记,其中最右侧的窗口元素为当前标记(即在对角线上)。然后我们的注意力计算为:

X ′ = softmax ( Q K T d k + M ) V X' = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) V X=softmax(dk QKT+M)V

例如,如果序列长度为 N = 6 N = 6 N=6,窗口大小为 w = 4 w = 4 w=4,则掩码矩阵 M M M为:

M = [ 0 − ∞ − ∞ − ∞ − ∞ − ∞ 0 0 − ∞ − ∞ − ∞ − ∞ 0 0 0 − ∞ − ∞ − ∞ − ∞ 0 0 0 − ∞ − ∞ − ∞ − ∞ 0 0 0 − ∞ − ∞ − ∞ − ∞ 0 0 0 ] M = \begin{bmatrix} 0 & -\infty & -\infty & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty & -\infty & -\infty \\ 0 & 0 & 0 & -\infty & -\infty & -\infty \\ -\infty & 0 & 0 & 0 & -\infty & -\infty \\ -\infty & -\infty & 0 & 0 & 0 & -\infty \\ -\infty & -\infty & -\infty & 0 & 0 & 0 \end{bmatrix} M= 000000000000000

(a) (1 分) 简答:如果我们按照公式中的矩阵乘法实现滑动窗口,时间复杂度在 N N N w w w的情况下是什么?(对于此问题及后续问题,假设两个矩阵 X ∈ R m × n X \in \mathbb{R}^{m \times n} XRm×n Y ∈ R n × p Y \in \mathbb{R}^{n \times p} YRn×p相乘的计算成本为 O ( m n p ) O(mnp) O(mnp))。

(b) (1 分) 简答:如果我们按照公式中的矩阵乘法实现滑动窗口,空间复杂度在 N N N w w w的情况下是什么?

© (4 分) 伪代码:编写伪代码或数学表达式,用于一个函数,它接收查询、键、值和窗口大小 w w w,并计算 X ′ X' X

SCALEDDOTPRODUCTATTENTION ( Q , K , V , d k , w ) \text{SCALEDDOTPRODUCTATTENTION}(Q, K, V, d_k, w) SCALEDDOTPRODUCTATTENTION(Q,K,V,dk,w)

你的伪代码/数学表达式必须具有比上述天真矩阵乘法方法更低的渐进计算复杂度。你的解决方案可以并且应该包括 for 循环。假设可以访问一个函数 softmax(x),该函数对向量 x x x应用 softmax,以及一个函数 tensor(·),用于构造任意形状的向量、矩阵、张量。

(d) (1 分) 简答:你的伪代码的空间复杂度是多少(以 N N N w w w表示)?

(e) (1 分) 简答:你的伪代码的时间复杂度是多少(以 N N N w w w表示)?


通过这部分问题,要求理解滑动窗口注意力的实现及其计算复杂度,重点在于如何在序列上应用局部注意力窗口,并分析其效率。

解答

CMU学生的a~e答案:

在这里插入图片描述

备注:关于严格按照公式计算的空间复杂度,这里CMU学生应该错了,空间复杂度应该是N^2,
因为要存储Q*K矩阵的结果维度也就是N^2。

在这里插入图片描述

GPT解答:


3.1(a) 公式版的时间复杂度

问题:如果我们严格按照公式 X ′ = softmax ( Q K T d k + M ) V X' = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) V X=softmax(dk QKT+M)V实现滑动窗口注意力,时间复杂度是多少?

解答

  1. 计算 Q K T / d k QK^T / \sqrt{d_k} QKT/dk
  • Q ∈ R N × d k Q \in \mathbb{R}^{N \times d_k} QRN×dk K ∈ R N × d k K \in \mathbb{R}^{N \times d_k} KRN×dk,计算 Q K T QK^T QKT需要 N 2 N^2 N2次点积,每个点积的复杂度是 O ( d k ) O(d_k) O(dk)

  • 时间复杂度: O ( N 2 d k ) O(N^2 d_k) O(N2dk)

  • 加上掩码矩阵 M M M

    • M M M N × N N \times N N×N矩阵,将它与 Q K T QK^T QKT相加的时间复杂度是 O ( N 2 ) O(N^2) O(N2)
  • 计算 softmax

    • Softmax 应用在 N × N N \times N N×N矩阵上,因此时间复杂度是 O ( N 2 ) O(N^2) O(N2)
  • 计算 A V AV AV

    • A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N V ∈ R N × d k V \in \mathbb{R}^{N \times d_k} VRN×dk,矩阵乘法的时间复杂度是 O ( N 2 d k ) O(N^2 d_k) O(N2dk)

总的时间复杂度

O ( N 2 d k ) + O ( N 2 ) = O ( N 2 d k ) O(N^2 d_k) + O(N^2) = O(N^2 d_k) O(N2dk)+O(N2)=O(N2dk)

如果 d k d_k dk是常量,则时间复杂度简化为:

O ( N 2 ) O(N^2) O(N2)


3.1(b) 公式版的空间复杂度

问题:如果我们严格按照公式实现滑动窗口注意力,空间复杂度是多少?

解答

  1. 存储 Q Q Q K K K V V V
  • Q Q Q K K K V V V的大小分别是 N × d k N \times d_k N×dk,所以需要的存储空间是 O ( N d k ) O(N d_k) O(Ndk)

  • 存储 Q K T QK^T QKT

    • Q K T QK^T QKT是一个 N × N N \times N N×N的矩阵,因此需要 O ( N 2 ) O(N^2) O(N2)的空间。
  • 存储 softmax 的中间结果和 A A A

    • softmax 操作需要存储 A ∈ R N × N A \in \mathbb{R}^{N \times N} ARN×N,即 O ( N 2 ) O(N^2) O(N2)
  • 存储 X ′ X' X

    • X ′ ∈ R N × d k X' \in \mathbb{R}^{N \times d_k} XRN×dk,因此存储 X ′ X' X需要 O ( N d k ) O(N d_k) O(Ndk)

总的空间复杂度

O ( N 2 ) + O ( N d k ) = O ( N 2 ) O(N^2) + O(N d_k) = O(N^2) O(N2)+O(Ndk)=O(N2)

d k d_k dk是常量的情况下,空间复杂度简化为:

O ( N 2 ) O(N^2) O(N2)


3.1(c, d, e) 代码版实现、空间复杂度、时间复杂度

问题:编写伪代码,优化滑动窗口注意力计算,使其具有比天真矩阵乘法方法更低的渐进计算复杂度。

解答

优化滑动窗口注意力的关键在于减少每个位置与序列中每个元素的交互,而仅与滑动窗口内的元素进行交互,从而将复杂度从 O ( N 2 ) O(N^2) O(N2)降到 O ( N w ) O(N w) O(Nw)

代码如下:

import torch
import torch.nn.functional as F

def sliding_window_causal_attention(Q, K, V, d_k, w):
    """
    实现滑动窗口的因果注意力机制
    Q, K, V: 输入的查询、键和值张量,形状为 (N, d_k)
    d_k: 每个向量的维度
    w: 滑动窗口大小
    返回: 经过滑动窗口因果注意力机制处理后的输出 X_prime
    """
    N = Q.size(0)  # 序列长度
    X_prime = torch.zeros_like(Q)  # 初始化输出矩阵

    for i in range(N):
        # 只向前看,因此窗口从 max(0, i - w // 2) 到 i
        window_start = max(0, i - w // 2)
        window_end = i  # 只看当前位置及之前的位置

        # 提取键和值的窗口部分
        K_window = K[window_start:window_end+1]  # 窗口内的键
        V_window = V[window_start:window_end+1]  # 窗口内的值

        # 计算 Q[i] 与 K_window 的点积,并除以 sqrt(d_k)
        scores = torch.matmul(Q[i], K_window.T) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

        # 创建因果掩码,确保当前位置只关注前面的元素
        mask = torch.full_like(scores, float('-inf'))  # 初始化为 -inf
        mask[i - window_start:] = 0  # 当前位置及其前面的部分设置为有效

        # 计算 softmax 并应用掩码
        attention_weights = F.softmax(scores + mask, dim=-1)

        # 计算加权的输出
        X_prime[i] = torch.matmul(attention_weights, V_window)

    return X_prime

# 测试
N = 6  # 序列长度
d_k = 4  # 向量维度
w = 4  # 滑动窗口大小

# 随机生成查询、键和值
Q = torch.rand(N, d_k)
K = torch.rand(N, d_k)
V = torch.rand(N, d_k)

output = sliding_window_causal_attention(Q, K, V, d_k, w)
print(output)

# 我测试的运行结果:
tensor([[0.0323, 0.9104, 0.7727, 0.9223],
        [0.3455, 0.5752, 0.7489, 0.8813],
        [0.3887, 0.2673, 0.6308, 0.3600],
        [0.8112, 0.2752, 0.5635, 0.5418],
        [0.2759, 0.4999, 0.2751, 0.8016],
        [0.9672, 0.7985, 0.7158, 0.2517]])

代码流程总结:

  1. 初始化:为输出矩阵 X_prime 分配与查询向量 Q Q Q 相同形状的零矩阵,用于存储最终的注意力结果。
  2. 滑动窗口定义:对每个序列位置 i i i,确定滑动窗口的范围,该窗口的起点为 max ⁡ ( 0 , i − w / / 2 ) \max(0, i - w//2) max(0,iw//2),终点为 i i i。该窗口确保当前元素 i i i 只能关注自己和前面 w / 2 w/2 w/2 个元素。
  3. 提取窗口内的键和值:从序列中提取滑动窗口范围内的键 K K K 和值 V V V,用于后续的计算。
  4. 计算点积注意力分数:对查询向量 Q [ i ] Q[i] Q[i] 与窗口内的键 K K K 进行点积计算,并且进行缩放操作(除以 d k \sqrt{d_k} dk )。
  5. 应用因果掩码:生成掩码矩阵,屏蔽窗口外的元素,确保当前元素 i i i 只关注自己及前面 w / 2 w/2 w/2 个元素。
  6. Softmax归一化:对注意力分数应用 softmax 操作,得到注意力权重。
  7. 加权求和:使用注意力权重对窗口内的值 V V V 进行加权求和,计算最终的注意力输出。
  8. 返回输出:最终返回经过滑动窗口注意力机制处理后的输出矩阵 X ′ X' X

时间复杂度分析:

对于每个序列位置 i i i,计算点积、掩码应用、Softmax以及加权求和的操作都限制在一个大小为 w / 2 w/2 w/2的窗口内。因此,对于序列长度 N N N和窗口大小 w w w,代码的整体时间复杂度为:

  • 每个位置的时间复杂度:在窗口 w w w内进行矩阵点积和加权求和,每次计算的时间复杂度为 O ( w d k ) O(w d_k) O(wdk)
  • 整体时间复杂度:对于整个序列,每个序列位置的计算都限制在窗口大小 w w w内,因此总的时间复杂度为:

O ( N w d k ) O(N w d_k) O(Nwdk)

其中 N N N是序列长度, w w w是窗口大小, d k d_k dk是查询和键的向量维度。

空间复杂度分析:

  • 存储键、查询和值:每个键、查询和值的大小为 N × d k N \times d_k N×dk,因此存储 Q Q Q K K K V V V的空间复杂度为 O ( N d k ) O(N d_k) O(Ndk)
  • 存储掩码矩阵:每个位置的掩码矩阵大小为 w / 2 w/2 w/2,但并不需要存储全局的 N × N N \times N N×N掩码矩阵,因此掩码的空间复杂度是 O ( w ) O(w) O(w)
  • 存储中间结果:每个位置的注意力分数和权重只涉及窗口内的元素,大小为 O ( w ) O(w) O(w)
  • 总空间复杂度:代码的空间复杂度主要由存储 Q Q Q K K K V V V以及中间计算的结果决定,最终为:

O ( N d k + w ) O(N d_k + w) O(Ndk+w)

其中 N N N是序列长度, d k d_k dk是向量维度, w w w是窗口大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2125762.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IDEAJ真正修改maven(.m2)在Windows环境下缓存路径

IDEAJ真正修改maven(.m2)缓存路径的方法 下面这种方式虽然当前项目生效了,IntelliJ IDEA修改默认.m2和.gradle缓存路径-CSDN博客文章浏览阅读251次,点赞4次,收藏8次。文章浏览阅读1.3k次。1,File -ProjectStructure - Artifacts &…

RISC-V (十二)系统调用

系统模式:用户态和内核态 当前的代码都是实现在machine模式下。 系统模式的切换 epc寄存器的值存放的是ecall指本身的地址 。 用ecall指令 系统调用的执行流程 mret这条指令会利用status的mpp值恢复到之前的特权级别。 蓝色的线表示涉及到权限切换。 系统调用的传…

【VUE】pinia持久化存储

前言:状态持久化存储的意义在于它能够确保用户在与应用程序交互时,其操作状态、用户偏好、应用数据等关键信息在页面刷新、浏览器关闭或重新启动后依然得以保留,从而提供连贯、无缝的用户体验,避免因状态丢失导致的不便和重复操作…

绑定域名解析怎么做?

在当今数字化时代,拥有一个网站已经成为许多个人和企业展示自己、提供服务或进行商业活动的重要方式。而要让网站能够在互联网上被访问到,绑定域名解析是一个关键的步骤。 绑定域名解析究竟该怎么做呢? 一、了解域名解析的基本概念 域名解…

【重要】MThings V0.7.1更新要点

下载地址http://gulink.cn/download 01. [新增]逻辑控制功能。 “逻辑控制”作为一项全新的商用版功能,通过使用直观的可视化积木搭建方法,为系统级数据处理提供了高度的可定制性。它能够简化复杂的数据计算、控制算法设计和仿真验证流程,同…

探索未来住宿体验:酒店触摸开关的科技魅力

在快节奏的现代生活中,人们对于旅行住宿的期待已远不止于基本的休息与安眠,而是更加注重个性化、便捷化与智能化。随着科技的飞速发展,酒店行业也迎来了前所未有的变革,其中,触摸开关作为智能客房的标志性元素&#xf…

【网络安全】-xss跨站脚本攻击-pikachu

文章目录 前言 什么是xss跨站脚本攻击? 1.xss的分类: 1.1 反射型xss 1.2 存储型xss 1.3 Dom型xss: 2.同源策略: 2.1同源策略的定义 2.2同源策略的绕过 2.3 绕过同源策略的绕过 前言 什么是xss跨站脚本攻击&#xff1…

两个月冲刺软考——重点理解传值方式与传引用方式的区别

1.总线的分类(按功能划分) 数据总线负责传输实际的数据。 地址总线用于指定数据的来源或目的地的内存地址。 控制总线传输控制信号,如读写指令和其他操作指令。 2.传值方式与传引用方式 传值方式:形参取的是实参的值,形参的改变不会导致调…

开源数据集 FreiHAND rgb 三维手势建模 手部关键点 >> DataBall

开源数据集 FreiHAND rgb 三维手势建模 手部关键点 mano hand 混合现实 深度学习 人工智能 FreiHAND是一个用于评估和训练深度神经网络以从单色图像中估计手部姿态和形状的数据集,这是在我们的论文中提出的。其当前版本包含32560个独特的训练样本和3960个独特的评估…

【JavaScript】LeetCode:26-30

文章目录 26 矩阵置零27 螺旋矩阵28 旋转图像29 搜索二维矩阵Ⅱ30 相交链表 26 矩阵置零 2次双重for循环。第1次:将matrix[i][j]为0时的i、j分别存放于数组res_i、res_j,记录有哪些行、列应该置为0。第2次:将记录中的行、列置为0。 /**- par…

揭秘蛇形机器人的主动SLAM算法和障碍物避让策略

更多优质内容,请关注公众号:智驾机器人技术前线 1.论文信息 论文标题:An active SLAM with multi-sensor fusion for snake robots based on deep reinforcement learning 作者:Xin Liu, Shuhuan Wen, Yaohua Hu, Fei Han, Hong…

如何利用免费工具轻松设计出专业Logo?

Logo 作为品牌的象征和视觉核心,承载了品牌的价值和理念。无论是创业公司还是个人品牌,拥有一个独特的 Logo 都显得尤为重要。然而,设计一个专业的 Logo 通常需要高昂的设计费用,许多人因此望而却步。幸运的是,随着互联…

视频合并实用教程分享,教你6个合并视频方法,不可错过!

多个视频怎么合成一个视频?如何把2个视频合成一个?怎么把多个视频合成一个视频?您是否也曾产生过这样的疑问呢?在如今互联网高速发展的时代,各行各业都难免需要涉及到视频制作领域,如果您正在考虑视频如何有…

HarmonyOS开发之模拟器地图点击无效问题

问题描述 当我们模拟器已经运行了,但是需求中需要引用地图,地图的点击事件无效,模拟器地图点击事件以及地拖拖拽无效,这个时候需要配置一下通过 chrome://inspect/#devices 通常用于远程调试移动设备上的 Chrome 浏览器 1、先配…

Cortex-A7:ARM官方推荐的嵌套中断实现机制

0 参考资料 ARM Cortex-A(armV7)编程手册V4.0.pdf ARM体系结构与编程第2版 1 前言 Cortex-M系列内核MCU中断硬件原生支持嵌套中断,开发者不需要为了实现嵌套中断而进行额外的工作。但在Cortex-A7中,硬件原生是不支持嵌套中断的,这从Cortex-A…

Sentinel实时监控不展示问题

问题 官方插件 Endpoint 支持,可以实时统计出SpringBoot的健康状况和请求的调用信息 在使用 Endpoint 特性之前需要在 Maven 中添加 spring-boot-starter-actuator 依赖,并在配置中允许 Endpoints 的访问。 Spring Boot 1.x 中添加配置 management.se…

硬件工程师笔试面试——MOS管

目录 8、MOS管 8.1 基础 MOS管原理图 MOS实物图 8.1.1 概念 8.1.2 特点 8.1.3 类型 7.2 相关问题 7.2.1 MOS管在不同应用中的阈值电压和最大漏极电流通常是多少? 7.2.2 如何根据电路设计选择合适的MOS管类型? 7.2.3 MOS管在高频应用中的优势是什么,它如何影响电路…

数据结构(4)栈和队列

一、栈 栈是一种线性的数据存储结构,由于栈只允许从一端进行数据的插入和删除。因此,栈是先进后出(FILO)的特性。 系统栈是由系统进行创建和维护的,数据结构中的栈是由自己创建和维护的。 栈的类型有:满增栈…

网络基础入门指南(三)

一、远程管理交换机 1.配置IP地址 远程管理需要通过IP地址访问网络设备交换机的接口,默认无法配置IP地址需要使用虚接口vlan1 2.配置远程登录密码 远程管理需要配置VTY接口VTY是虚拟终端,是一种网络设备远程连接的方式vty 0 4表示可同时打开5个会话 3…

数控纵切自动车床

数控纵切自动车床,作为现代机械加工领域的重要设备,集成了高精度、高效率与自动化等诸多优势。下面,我将从多个方面为您详细介绍这一先进设备。 ‌一、设备概述‌ 数控纵切自动车床,顾名思义,是一种结合了数控技术与纵…