【大模型理论篇】LLaMA3结构关键模块分析

news2024/11/15 4:14:50

1. 背景介绍     

        在文章《关于LLaMA 3.1 405B以及小模型的崛起》中,我们提到,LLaMA 3.1 的模型架构基本上已经成为当前LLM 模型的标准结构,和《Transformer原理分析》中提到的结构,也类似。但相比较,其中的一些关键模块做了一些优化和改进。本篇文章主要就是针对其中的几个改进模块,做一定的分析和记录。 

        通过图1 对比【1】,来看下LLaMA结构的改进点,主要有四点:Layer Norm换成了RMSNorm、  Self-Attension采用了Group Query Attention、Positional Embedding采用了ROPE、FFN结构采用了SwiGLU。

图1. Transformer结构与LLaMA结构对比

2. 各模块介绍

2.1 RMS Norm

        RMSNorm(Root Mean Square Layer Normalization)【2】是一种用于深度神经网络的归一化技术,旨在对网络层的输出进行归一化,以加速模型训练并提高模型的稳定性。它是Layer Normalization的一种变体,但其实现方式与LayerNorm略有不同。

        RMSNorm不依赖于均值,而是通过计算特征向量的均方根(Root Mean Square, RMS)来进行归一化。具体来说,对于输入向量 x(假设其维度为 d),RMSNorm的计算方式如下:

  1. 计算均方根 (RMS):

    \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2}

    这里,x_i表示输入向量 x 的第 i 个元素。

  2. 归一化输出:

    \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot g

    其中,g 是一个可学习的标量参数,用于重新缩放归一化后的输出。

RMSNorm的特点

  • 不依赖均值: 与LayerNorm不同,RMSNorm仅使用RMS而不依赖于均值来进行归一化,这使得其计算更简单、代价更低。

  • 计算效率: 由于不计算均值,RMSNorm的计算复杂度较低,尤其是在高维度输入的情况下,RMSNorm相比LayerNorm更为高效。

  • 适用于各种网络结构: RMSNorm可以应用于各种神经网络结构,如Transformer等,能够提高模型的训练稳定性。

理解RMSNorm

        RMSNorm的核心思想是通过控制每个特征向量的长度,使网络层的输出保持稳定。通过去除均值计算,RMSNorm减少了归一化过程中的计算量,但仍然能够有效抑制梯度爆炸或消失问题,提高训练的稳定性。

        由于RMSNorm仅依赖于输入向量的RMS值,因此它在计算上比LayerNorm更简洁,并且在某些情况下可以提供类似甚至更好的性能。特别是在大规模神经网络如Transformer的训练中,RMSNorm已经成为一种常见的归一化选择。

代码示例【3】

class RMSNormLayer(lasagne.layers.Layer):
    def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.),
                 W=lasagne.init.Normal(0.05), nonlinearity=relu, **kwargs):
        super(RMSNormLayer, self).__init__(incoming, **kwargs)
        self.nonlinearity = nonlinearity
        k = self.input_shape[1]
        if b is not None:
            self.b = self.add_param(b, (k,), name="b", regularizable=False)
        if g is not None:
            self.g = self.add_param(g, (k,), name="g")

        if len(self.input_shape)==4:
            self.axes_to_sum = (2,3)
            self.dimshuffle_args = ['x',0,'x','x']
        else:
            self.axes_to_sum = 1
            self.dimshuffle_args = ['x',0]

    def get_output_for(self, input, **kwargs):
        meanS = T.mean(input ** 2,axis=self.axes_to_sum,keepdims=True)

        norm_input = input / T.sqrt(meanS + 1e-6)

        if hasattr(self, 'g'):
            activation = norm_input*self.g.dimshuffle(*self.dimshuffle_args)
        else:
            activation = norm_input
        if hasattr(self, 'b'):
            activation += self.b.dimshuffle(*self.dimshuffle_args)

        return self.nonlinearity(activation)

def rms_norm(layer, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), **kwargs):
    nonlinearity = getattr(layer, 'nonlinearity', None)
    if nonlinearity is not None:
        layer.nonlinearity = lasagne.nonlinearities.identity
    if hasattr(layer, 'b'):
        del layer.params[layer.b]
        layer.b = None
    return RMSNormLayer(layer, b, g, nonlinearity=nonlinearity, **kwargs)

2.2 Group Query Attention

        分组查询注意力(GQA, Grouped Query Attention)【4】是一种改进的注意力机制,旨在提高多头自注意力机制的计算效率和内存利用率。GQA的主要思想是将查询(query)分组,以减少计算复杂度,同时保持注意力机制的有效性。

2.2.1 GQA的关键点

  1. 分组查询: 在GQA中,查询(query)被分成多个组,每个组独立计算注意力权重。这种分组可以减少计算量,因为每组的查询数量减少了。

  2. 共享键和值: 尽管查询被分组,GQA会在所有组中共享键(key)和值(value),从而减少内存消耗,并确保模型的注意力机制仍然能够有效捕捉输入之间的关系。

  3. 计算效率: 通过减少查询的数量,GQA大幅降低了计算成本,使得模型可以更高效地处理大规模数据,尤其是在资源有限的环境中。

  4. 保留性能: 尽管进行了分组,GQA仍然能够保持与标准多头自注意力机制相似的性能。实验表明,在许多任务中,GQA可以在减少计算和内存使用的同时,仍然取得与传统方法相当的结果。

图2. GQA形式

        多头注意力拥有 H 个查询、键和数值头部。多查询注意力在所有查询头部之间共享单个键和数值头部。分组查询注意力为每组查询头部共享单个键和数值头部,从而在多头和多查询注意力之间插值。

2.2.2 标准多头自注意力机制(MHSA)公式回顾

        标准的多头自注意力机制的计算可以表示为:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中:

  • Q 是查询矩阵(Query)
  • K 是键矩阵(Key)
  • V 是值矩阵(Value)
  • d_k 是键向量的维度

        在多头自注意力机制中,多个注意力头通过不同的线性变换进行计算,然后将结果拼接起来:

\text{MHSA}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h)W^O

其中,每个注意力头的计算为:

\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

2.2.3 分组查询注意力(GQA)公式

        在GQA中,查询矩阵 Q 被分为 G 组,每组的查询向量数量为\frac{N}{G},其中 N 是查询向量的总数。键矩阵 K 和值矩阵 V 则通常保持不变,并在所有组之间共享。

        假设 Q_g​ 表示第 g 组的查询矩阵(g = 1, 2, \dots, G),那么GQA的计算可以表示为:

分组注意力计算:

\text{Attention}_g(Q_g, K, V) = \text{softmax}\left(\frac{Q_g K^T}{\sqrt{d_k}}\right)V

拼接分组输出:

\text{GQA}(Q, K, V) = \text{Concat}(\text{Attention}_1, \text{Attention}_2, \dots, \text{Attention}_G)W^O

其中:

  • \text{Attention}_g​ 表示第 g 组的注意力计算结果。
  • W^O 是输出的线性变换矩阵。

        通过这种分组方法,GQA减少了计算复杂度,同时仍然能够有效捕捉查询与键、值之间的相关性。

2.2.4 特殊情形

        在分组查询注意力(GQA)中,传统多头模型中的查询头(Q)被分成 G 个组。每个组分配一个键(K)和一个值(V)头。这种配置被表示为 GQA-G,其中 G 表示组的数量。

GQA的特殊情况【5】:

  • GQA-1 = MQA:当只有一个组(G = 1)时,GQA等同于MQA(单头查询注意力),因为所有查询头共享一个键和一个值头。
  • GQA-H = MHA:当组的数量等于头的数量(G = H)时,GQA的表现与传统的多头注意力机制(MHA)相同,每个查询头都有其独立的键和值头。

2.2.5 代码示例【6】

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor, nn


def scaled_dot_product_gqa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout: float = 0.0,
    scale: Optional[float] = None,
    mask: Optional[Tensor] = None,
    is_causal: Optional[bool] = None,
    need_weights: bool = False,
    average_attn_weights: bool = False,
    force_grouped: bool = False,
):
    """Scaled dot product attention with support for grouped queries.

    Einstein notation:
    - b: batch size
    - n / s: sequence length
    - h: number of heads
    - g: number of groups
    - d: dimension of query/key/value

    Args:
        query: Query tensor of shape (b, n, h, d)
        key: Key tensor of shape (b, s, h, d)
        value: Value tensor of shape (b, s, h, d)
        dropout: Dropout probability (default: 0.0)
        scale: Scale factor for query (default: d_query ** 0.5)
        mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is
            applied to all 'n' rows of the attention matrix. (default: None)
        force_grouped: If True, apply grouped-query attention even if the number of
            heads is equal for query, key, and value. (default: False)

    Returns:
        2-tuple of:
        - Attention output with shape (b, n, h, d)
        - (Optional) Attention weights with shape (b, h, n, s). Only returned if
          'need_weights' is True.
    """
    if (mask is not None) and (is_causal is not None):
        raise ValueError(
            "Only one of 'mask' and 'is_causal' should be provided, but got both."
        )
    elif not query.ndim == key.ndim == value.ndim == 4:
        raise ValueError(
            f"Expected query, key, and value to be 4-dimensional, but got shapes "
            f"{query.shape}, {key.shape}, and {value.shape}."
        )

    # Move sequence length dimension to axis 2.
    # This makes the attention operations below *much* faster.
    query = rearrange(query, "b n h d -> b h n d")
    key = rearrange(key, "b s h d -> b h s d")
    value = rearrange(value, "b s h d -> b h s d")

    bq, hq, nq, dq = query.shape
    bk, hk, nk, dk = key.shape
    bv, hv, nv, dv = value.shape
    if not (bq == bk == bv and dq == dk == dv):
        raise ValueError(
            "Expected query, key, and value to have the same batch size (dim=0) and "
            f"embedding dimension (dim=3), but got query: {query.shape}, "
            f"key: {key.shape}, and value: {value.shape}."
        )
    elif (hk != hv) or (nk != nv):
        raise ValueError(
            "Expected key and value to have the same size in dimensions 1 and 2, but "
            f"got key: {key.shape} and value: {value.shape}."
        )
    elif hq % hk != 0:
        raise ValueError(
            "Expected query heads to be a multiple of key/value heads, but got "
            f"query: {query.shape} and key/value: {key.shape}."
        )

    if scale is None:
        scale = query.size(-1) ** 0.5
    query = query / scale

    num_head_groups = hq // hk
    query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
    similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")

    if is_causal:
        # Mask out the upper triangular portion of the attention matrix. This prevents
        # the model from attending to tokens in the future.
        mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()

    if mask is not None:
        # Expand mask to match the shape of the attention matrix.
        # If mask is 2D, assume that it is applied to the key/value sequence dimension.
        # Else if mask is 3D, assume that it is applied to the query/key/value sequence
        # dimension for all attention heads.
        #
        # Users could also provide a 4D mask, which is applied to the query/key/value
        # sequence dimension for each attention head (though I don't have a particular
        # use case in mind for that).
        if mask.ndim == 2:
            mask = rearrange(mask, "b s -> b () () () s")
        elif mask.ndim == 3:
            mask = rearrange(mask, "b n s -> b () () n s")
        # Mask similarity values by setting them to negative infinity.  This guarantees
        # that they will not contribute to the softmax computation below.
        similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)

    attention = F.softmax(similarity, dim=-1)
    if dropout > 0.0:
        attention = F.dropout(attention, p=dropout)

    # Apply attention matrix to the value Tensor.
    out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
    # Move head dimension back to axis 2
    out = rearrange(out, "b g h n d -> b n (h g) d")

    attn_weights: Optional[Tensor] = None
    if need_weights:
        # Move the sequence dimensions back to positions 1, 2.  Move the head dimension
        # to position 3.  This more closely matches the return shape of the attention
        # output: (b, n, h, d).
        attn_weights = rearrange(attention, "b g h n s -> b n s (h g)")
        if average_attn_weights:
            attn_weights = attn_weights.mean(dim=1)

    return out, attn_weights


class MultiheadGQA(nn.Module):
    """Multi-head grouped query attention (GQA) layer.

    Reference:
        "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
        https://arxiv.org/pdf/2305.13245v1.pdf

    GQA is a variant of multihead attention (MHA) that uses fewer write heads
    (key / value) than query heads.  GQA can be viewed as a generalization of
    multi-query attention (MQA), which uses a single write head. GQA and MQA give
    significant speedups over standard MHA in decoder layers, with minimal loss in
    accuracy. In the paper, GQA is shown to be more accurate than MQA, while still
    having a significant speedup over MHA.

    NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model
    from MHA to GQA.  As a result, they do not mention parameter initialization or
    layer normalization strategies.  I follow the best practices laid out in the
    MAGNETO paper, which improves Transformer performance through better parameter
    initialization and layer norm placement.  See:
        https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
    """

    def __init__(
        self,
        embed_dim: int,
        query_heads: int,
        kv_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        layer_norm: bool = True,
        layer_norm_eps: float = 1e-5,
        gamma_init: float = 1.0,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.query_heads = query_heads
        self.kv_heads = kv_heads
        self.dropout = dropout
        self.layer_norm = layer_norm
        self.gamma_init = gamma_init

        if self.query_heads % self.kv_heads != 0:
            raise ValueError(
                f"query_heads ({query_heads}) must be divisible by "
                f"kv_heads ({kv_heads})"
            )
        elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0):
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"query_heads ({query_heads}) and kv_heads ({kv_heads})"
            )

        head_dim = embed_dim // query_heads
        if not head_dim % 8 == 0:
            raise ValueError(
                f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8"
            )
        if not head_dim <= 128:
            raise ValueError(
                f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
            )

        # Query projection layer is the same as in vanilla MHA.
        self.q_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        # Key/value projection layers have a smaller output dimension, so that
        # the we have fewer key/value attention heads after reshaping.
        kv_embed_dim = embed_dim // query_heads * kv_heads
        self.k_proj = nn.Linear(
            embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.v_proj = nn.Linear(
            embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.norm: Optional[nn.LayerNorm] = None
        if layer_norm:
            self.norm = nn.LayerNorm(
                embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
            )
        # Grouped attention output will have the same embedding dimension as the
        # key/value Tensors.  So the output projection layer needs to accept the
        # same dimension (kv_embed_dim).
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_normal_(self.q_proj.weight)
        if self.q_proj.bias is not None:
            nn.init.constant_(self.q_proj.bias, 0)
        nn.init.xavier_normal_(self.k_proj.weight)
        if self.k_proj.bias is not None:
            nn.init.constant_(self.k_proj.bias, 0)

        # NOTE: We follow the initialization strategy from MAGNETO.  See:
        # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
        # Gain (self.gamma_init) should be provided as a keyword argument when
        # initializing the larger Transformer model, since it requires knowledge
        # of the number of encoder/decoder layers in the model.

        nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
        if self.v_proj.bias is not None:
            nn.init.constant_(self.v_proj.bias, 0)
        nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        need_weights: bool = False,
        # TODO
        # attn_mask: Optional[Tensor] = None,
        is_causal: bool = False,
        average_attn_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # Notation:
        #   b - batch size
        #   n - sequence length
        #   h - number of heads
        #   d - embedding dimension
        #
        # Input shape: (b, n, d)
        q: Tensor = self.q_proj(query)
        k: Tensor = self.k_proj(key)
        v: Tensor = self.v_proj(value)

        # Unfold 'd' dimension into 'h' separate attention heads.
        q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads)
        k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads)
        v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads)
        # Apply attention, then fold 'h' attention heads back into 'd'.
        x, attn = scaled_dot_product_gqa(
            query=q,
            key=k,
            value=v,
            # TODO
            # mask=attn_mask,
            is_causal=is_causal,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            force_grouped=False,
        )
        x = rearrange(x, "b n h d -> b n (h d)")

        # NOTE: This is different from 'nn.MultiheadAttention'!  We follow the MAGNETO
        # architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra
        # layer norm before the linear output projection.  The cross-attention layer in
        # the MAGNETO decoder does not include this layer norm, so users have the
        # option to disable it (layer_norm=False).
        if self.layer_norm:
            assert self.norm is not None
            x = self.norm(x)
        # Linear projection on attention outputs.
        x = self.out_proj(x)

        return x, attn

2.3 RoPE(Rotary Positional Encoding)

        Rotary Positional Encoding (RoPE) 是一种用于自注意力模型的位置信息编码方法,用于增强模型对序列中位置信息的理解。与传统的正弦和余弦位置编码不同,RoPE通过旋转操作将位置信息直接嵌入到查询(query)和键(key)向量中,以更好地捕捉序列中不同位置的相对关系。

2.3.1 RoPE 的关键特点

  1. 旋转编码:RoPE通过在查询和键向量的不同维度之间应用旋转操作,将位置信息嵌入到向量表示中。具体来说,RoPE将每个维度的向量分成两部分,分别应用二维旋转矩阵进行处理。

  2. 相对位置信息:RoPE自然地编码了相对位置信息,而不是绝对位置信息,这使得模型能够更好地捕捉序列中元素之间的相对关系,尤其在处理长序列时表现更优。

  3. 无额外参数:RoPE不引入额外的模型参数,只是在已有的查询和键向量上应用旋转操作,因此不增加模型的复杂性。

2.3.2 RoPE 的数学表达

        对于一个维度 d 的查询或键向量 x,假设\theta 是与位置 p 相关的旋转角度,则 RoPE 的旋转操作可以表示为:

\text{RoPE}(x, p) = \begin{pmatrix} \cos(p \cdot \theta) & -\sin(p \cdot \theta) \\ \sin(p \cdot \theta) & \cos(p \cdot \theta) \end{pmatrix} \cdot \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}

        在这个操作中,x_1x_2 分别是向量 x 的一部分,旋转操作会对这些部分应用一个与位置 p 相关的旋转角度,使得向量在不同位置上具有相应的变化。

2.3.3 旋转位置编码取代绝对位置编码的原因【8】

2.3.3.1 绝对位置编码的主要缺点

        无法包括相对位置信息,虽然绝对位置编码捕捉到了单词的位置信息,但它并未捕捉整个句子(或序列)的位置信息。

示例:

一种常见的创建长度为3的绝对位置编码的方法是进行随机初始化。
假设我们得到如下结果:[0.1, 0.01, 0.5],这种绝对位置编码将确保相同的单词在不同位置将产生不同的注意力输出。

(1) 位置之间没有关系。

较大的位置索引上的位置编码可能大于或小于较小的索引位置编码
→ 位置1的编码为0.1,可能大于位置2的编码为0.01,即 [0.1 > 0.01]
→ 位置1的编码为0.1,可能也小于位置3的编码为0.5,即 [0.1 < 0.5]

如果绝对位置编码:[0.1, 0.01, 0.05]:

(2) 相对距离不一致。
位置编码之间的差异无法告诉我们单词之间的距离。
→ 位置1到位置2的距离为 abs(0.1-0.01) = 0.09。
→ 位置1到位置3的距离为 abs(0.1-0.05) = 0.05。
(理想情况下,位置1到位置3的距离应大于位置1到位置2的距离)
这意味着绝对位置编码并没有捕捉到整个句子(或序列)的位置信息!

2.3.3.2 相对位置编码的主要缺点

(1) 计算效率低下
需要在自注意力机制后增加额外的步骤。
必须创建成对位置编码矩阵,然后进行大量的张量操作,以在每个时间步获取相对位置编码。

(2)不适合推理
在推理过程中,使用一种称为 KV 缓存的方法,这有助于减少推理速度。
使用 KV 缓存的一个要求是已生成单词的编码在生成新单词时不会改变(这是绝对位置编码提供的)。
因此,相对位置编码不适合推理,因为每个标记的嵌入在每个新时间步都会发生变化。

当序列长度为2时,单词的相对位置将是 [-1, 0, 1]。
当输入序列长度为3时,单词的相对位置将是 [-2,-1,0,1,2]。
由于使用这些相对位置来获取每个单词的相对位置编码,当相对位置集发生变化时,每个单词的编码也会改变。
这意味着给定句子[‘this’, ‘is’, ‘awesome’],在生成推理标记时,单词‘this’的编码会在每个时间步改变。
这就是为什么相对位置编码不常用于推理的原因

2.3.3.3 旋转位置编码

        旋转位置编码是使用旋转矩阵对绝对位置信息进行编码并在自注意力机制中自然地结合显式相对位置依赖关系的位置编码方法。旋转矩阵是通过某个角度将一个向量旋转到另一个向量的矩阵。

        旋转位置编码通过结合绝对位置编码和相对位置编码来克服它们的缺点。

首先,看看它是如何结合这两者的:

旋转位置编码中的绝对位置编码

        在旋转位置编码示例中,对于词语“this”和“is”,旋转矩阵仅取决于 m \times \Theta_i,其中 \Theta_i 在所有词语中是共享的,而 m 只是词语的位置(而不考虑其他词语的位置)。

        类似于绝对位置编码,为每个位置都有一个位置编码(不依赖于其他词语的位置)。

旋转位置编码中的相对位置编码

        由于旋转角度取决于m \times \Theta_i,句子开头的词语会有较小的旋转角度,而句子末尾的词语会有较大的旋转角度,从而使得这些距离具有相对性。

         注意:\Theta_i 对所有词语都是相同的,唯一改变的是词语的位置 m。较大的 m 意味着 m \times \Theta_i,即较大的旋转角度。

        类似于相对位置编码,不同位置的编码之间存在一定的关系。

        为了解释为什么旋转位置编码能够克服这些缺点,回顾一下这两种方法的缺点:

  • 绝对位置编码: (1) 不包含相对位置信息
  • 相对位置编码: (2) 计算效率低,(3) 不适用于推理

因此,我们可以知道,旋转位置编码必须具备以下特性:

  1. 包含相对位置信息
  2. 计算效率高
  3. 适用于推理

旋转位置编码包含相对位置信息

        旋转位置编码通过使旋转矩阵中的角度依赖于当前词语的位置,并通过旋转矩阵的性质,告诉我们两个词语之间的距离,从而包含了相对位置信息。

        例如,即使‘pig’和‘dog’(1)出现在两个长度不同的句子中,并且(2)‘pig’和‘dog’出现在句子的不同位置,它们之间的角度仍然是相同的。 角度相同是因为‘pig’和‘dog’之间始终间隔2个词语(‘chased the’在‘pig’和‘dog’之间)。 词语的旋转角度是 m \times \Theta_i,而‘pig’和‘dog’之间相隔两个词语的角度是m_{dog} - m_{pig} = 2 \times \Theta_i

(另一种观点理解)

  1. 我们仅讨论了“旋转角度”,但编码相对距离的部分在哪里呢?答案在于以下点积的定义:“向量A和B的点积等于A的长度乘以B的长度,再乘以它们之间角度的余弦值,A\cdot B = |A||B|\cos\theta

  2. 如果计算位置1的旋转位置编码与其他位置的点积,得到的将是\cos(\theta ),它只依赖于旋转角度(更具体地说,是从位置1到其他词语位置的旋转角度差异)。

  3. 位置1的旋转角度小于位置>1的旋转角度。这意味着,其他词离位置1越远,(1)旋转角度的差异越大,(2)点积越小(因为余弦值从0到π减小),(3)其他词离位置1越远。

  4. 因此可以得出结论,旋转位置编码包含了相对位置信息。

如何使用旋转位置编码:

  • 因为我们需要点积包含相对位置编码,将旋转位置编码分别应用于查询和键,这样当我们将它们矩阵相乘时,注意力矩阵就包含了相对位置信息。

        另外,旋转位置编码类似于绝对位置编码,其编码仅依赖于当前词的位置,已经生成的词的位置信息不会改变,因此在推断过程中可以再次使用KV缓存。

进一步分析可以参考【6】。

代码示例【9】:

import torch
from torch import nn
from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention

class RotaryPositionalEmbeddings(nn.Module):
	def __init__(self, d: int, base: int = 10_000):
		super().__init__()
        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None


	def _build_cache(self, x: torch.Tensor):
	 	if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
	        return

	    seq_len = x.shape[0]
	    theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
	    seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
	    idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
	    idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
	    self.cos_cached = idx_theta2.cos()[:, None, None, :]
		self.sin_cached = idx_theta2.sin()[:, None, None, :]

	def _neg_half(self, x: torch.Tensor):
		d_2 = self.d // 2
		return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

	def forward(self, x: torch.Tensor):
		self._build_cache(x)
		x_rope, x_pass = x[..., :self.d], x[..., self.d:]
		neg_half_x = self._neg_half(x_rope)
		x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
		return torch.cat((x_rope, x_pass), dim=-1)


class RotaryPEMultiHeadAttention(MultiHeadAttention):
	def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
		super().__init__(heads, d_model, dropout_prob)
		d_rope = int(self.d_k * rope_percentage)
		self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
		self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

	def get_scores(self, query: torch.Tensor, key: torch.Tensor):
		return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

	def _test_rotary():
		x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
		x = x[:, None, None, :]
		inspect(x)

		rotary_pe = RotaryPositionalEmbeddings(4)
		inspect(rotary_pe(x))


if __name__ == '__main__':
	_test_rotary()

2.4 SwiGLU

        SwiGLU(Swish-Gated Linear Unit)是一种激活函数,旨在提高深度学习模型的性能。作为一种改进的激活函数,用于深度神经网络中的非线性变换。SwiGLU 结合了 Swish 激活函数和 Gated Linear Unit (GLU) 的特点。

SwiGLU 的激活函数定义如下:

\text{SwiGLU}(x) = \text{Swish}(x) \times \text{GLU}(x)

其中:

  • Swish 激活函数是一个具有自适应特性的激活函数,公式为:\text{Swish}(x) = x \cdot \sigma(x)其中,\sigma(x)是 Sigmoid 函数。

  • GLU 激活函数是 Gated Linear Unit 的一种变体,公式为:           \text{GLU}(x) = \text{Linear}(x) \times \sigma(\text{Linear}(x))

        SwiGLU 结合了这两种激活函数的优点,旨在提高模型的表达能力和训练效率。它通过非线性激活和门控机制来改善模型的性能,尤其是在处理复杂数据和任务时【11】。

SwiGLU(x) = Swish(W1x+b)⊗(Vx+c)

FFNSwiGLU(x) = (Swish1(xW)⊗xV)W2

具体来说,SwiGLU 的计算可以分为以下几个步骤:

  1. 线性变换:输入 x 分别通过两个线性变换 W_1W_2,得到两个中间结果 x \cdot W_1​ 和 x \cdot W_2
  2. Swish 激活:对 x \cdot W_1​ 进行 Swish 激活,得到激活后的输出 \text{Swish}(x \cdot W_1)
  3. 门控机制:将 Swish 激活的输出与另一个线性变换的输出 x \cdot W_2相乘,实现门控机制。

        Swish 激活部分提供了一个平滑的非线性函数,确保了梯度的良好流动,尤其是在靠近零的输入区域。  门控部分 通过x \cdot W_2 和 Sigmoid 门控机制引入了额外的控制,可以更灵活地选择哪些信息通过,这有助于模型捕捉更复杂的模式。最终的输出是两个部分的逐元素乘积,既保留了 Swish 激活的平滑性和良好的梯度流动特性,也引入了 GLU 的门控机制,以增强非线性表达能力。

代码示例【12】:

class SwiGLU(nn.Module):

    def __init__(self, w1, w2, w3) -> None:
        super().__init__()
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3

    def forward(self, x):
        x1 = F.linear(x, self.w1.weight)
        x2 = F.linear(x, self.w2.weight)
        hidden = F.silu(x1) * x2
        return F.linear(hidden, self.w3.weight)

F.silu函数与ß=1时的swish相同。

3. 参考材料

【1】 LLaMA Pro: Progressive LLaMA with Block Expansion

【2】Root Mean Square Layer Normalization

【3】RMSNorm GITHUB

【4】GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

【5】Demystifying GQA — Grouped Query Attention for Efficient LLM Pre-training

【6】grouped-query-attention-pytorch GITHUB

【7】RoFormer: Enhanced Transformer with Rotary Position Embedding

【8】Understanding Rotary Positional Encoding

【9】Rotary Positional Embeddings (RoPE)

【10】GLU Variants Improve Transformer

【11】SwiGLU: GLU Variants Improve Transformer

【12】What Is SwiGLU? How to Implement It? And Why Does it Work?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2057384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day38——动态库与静态库

一、linux系统中的库 库在linux系统中是一个二进制文件&#xff0c;它是由XXX.c&#xff08;不包含main函数&#xff09;文件编译而来的&#xff0c;分为静态库和动态库。 库在系统中的内容是不可见的&#xff0c;是一个二进制乱码 当程序需要使用库中的相关函数时&#xff…

PyTorch升级之旅——安装与基本知识

目录 一、安装 二、张量 创建tensor 张量的操作 广播机制 三、自动求导 四、并行计算 &#xff08;一&#xff09;网络结构分布到不同的设备中(Network partitioning) &#xff08;二&#xff09;同一层的任务分布到不同数据中(Layer-wise partitioning) &#xff08;…

Linux下编译安装redis-哨兵模式

Linux下编译安装-哨兵模式 哨兵sentinel模式 sentinel相当于是一个投票者或者哨兵&#xff0c;它时刻监视着redis集群的各个服务器&#xff0c;当主master挂了之后&#xff0c;它将进行投票进行新master的选举。下图 Sentinel的工作方式 每个Sentinel以每秒钟一次的频率向…

python从入门到精通:数据容器

数据容器介绍 一种可以容纳多份数据的数据类型&#xff0c;容纳的每一份数据称之为一个元素&#xff0c;可以是任意类型的数据&#xff0c;如字符串、数字、布尔等。 数据容器根据特点的不同&#xff0c;如&#xff1a; 是否支持重复元素 是否可以修改 是否有序&#xff0…

Web3链上聚合器声呐已全球上线,开启区块链数据洞察新时代

在全球区块链技术高速发展的浪潮中&#xff0c;在创新发展理念的驱动下&#xff0c;区块链领域的工具类应用备受资本青睐。 2024年8月20日&#xff0c;由生纳&#xff08;香港&#xff09;国际集团倾力打造的一款链上应用工具——“声呐链上聚合器”&#xff0c;即“声呐链上数…

数据分析实操案例分享:如何对人事数据进行BI分析?

在数据驱动时代&#xff0c;数据分析已经成为企业和个人获取竞争优势的关键技能。特别是在人力资源管理领域&#xff0c;数据分析的应用正变得越来越重要。通过对在职和离职数据的深入分析&#xff0c;企业不仅能够洞察员工的动态&#xff0c;揭示员工流动的模式、预测人才需求…

快速体验微软TTS服务

微软的语音合成服务&#xff08;TTS&#xff09;拥有500多种高品质的音色&#xff0c;并且在全球都有节点可以接入&#xff0c;在国内访问延迟可以控制在毫秒级。下面介绍在不需要编码的情况下&#xff0c;如何快速体验微软TTS的效果。 方式一、微软语音库UI界面 语音库地址&…

网安加·百家讲坛 | 裴伟伟:蓝牙音箱和耳机安全测评报告

作者简介&#xff1a;裴伟伟&#xff0c;洞源实验室创始人&#xff0c;国家网安基地网络安全行业专家&#xff0c;网安加社区特聘专家&#xff0c;持有CISSP、PMP证书&#xff0c;曾在HITCON、可信云大会、开源产业大会等安全论坛发表演讲。曾任国内某安全实验室负责人、某互金…

Oracle SQL - 合并重叠的期间

数据和目标 有如下数据存储了各组件的有效期间&#xff08;此处起止日期用数字代替以便查阅&#xff09;&#xff0c;目标为将有重叠的期间合并到一起。 SQL> SELECT * FROM demo_eff_periods;COMPONENT_ITEM_ID EFFECTIVITY_DATE DISABLE_DATE ----------------- -------…

Spring GateWay自定义断言工厂

文章目录 概要整体架构流程最终的处理效果小结 概要 我们在线上系统部署了&#xff0c;灰度环境和生产环境时&#xff0c;可以通过自定义断言工厂去将请求需要路由到灰度环境的用户调用灰度的的服务集群&#xff0c;将正常的用户调用正常集群。 这样&#xff0c;我们可以在上线…

【UCB CS61C】Lecture 1 - Number Representation 数制

目录 进制的定义常用的进制与换算十进制到二进制的转换二进制到十六进制、十六进制到二进制的转换二进制向 n 进制的转换 有符号数处理&#xff08;Signed Representation&#xff09;无符号整数&#xff08;Unsigned Integers&#xff09;有符号整数&#xff08;Signed Interg…

亚德诺(ADI)超静音步进电机驱动芯片——TMC2209

芯品快报:德州仪器(TI)的高性能、集成式的双全桥电机驱动器——DRV8412 芯品快报:亚德诺(ADI)超静音步进电机驱动芯片——TMC2209 原创 IPBrain平台君 集成电路大数据平台 2024年08月16日 19:18 北京 平台君今天给大家介绍一款亚德诺(ADI)公司的用于两相步进电机的超…

Elasticsearch 使用误区之四——不合理的使用 track_total_hits

0、企业级实战问题 在使用 Elasticsearch 进行搜索时&#xff0c;我们常常关心匹配查询的文档总数而将 track_total_hits 设置为 true&#xff0c;如下截图所示&#xff0c;在数据量非常大的情况下这种检索导致的问题是&#xff1a;查询特别慢&#xff0c;聚合会更慢&#xff0…

RKNN在转换过程中的均值和方差设置问题

为什么ONNX转RKNN要匹配均值和方差&#xff1f; 因为不匹配精度会下降&#xff01;&#xff01;&#xff01; 一般的类似于YOLO模型 YOLO模型在ONNX转RKNN时rknn.config设置为 一些其他模型将数据送入模型时会进行前处理&#xff0c;前处理会设置均值和方差&#xff0c;则在转…

【Nginx】实现 FastCGI

为什么会有 FastCGI &#xff1f; CGI 协议虽然解决了语言解析器和 Web Server 之间通讯的问题&#xff0c;但是它的效率很低&#xff0c;因为 Web Server每收到一个请求都会创建一个CGI 进程&#xff0c; PHP 解析器都会解析 php.ini 文件&#xff0c;初始化环境&#xff0c…

LCP142 环形链表[leetcode-7]

LCP142 环形链表 先上结果 前排提醒&#xff0c;本文有两种解法&#xff0c;和原理分析 给定一个链表的头节点 head &#xff0c;返回链表开始入环的第一个节点。 如果链表无环&#xff0c;则返回 null。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次…

数据结构与算法 - 设计

1. LRU缓存 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;则返回关键字的值&#xff0…

0819、0820梳理及一些面试题梳理

一、抓包分析 二、HTTP服务器 三、动态库与静态库 四、一些面试题 指针数组和数组指针的区别&#xff1a;指针数组本质是一个数组&#xff0c;只是数组中存储的是指针变量。数组指针存储的是该数组的起始地址&#xff0c;对该指针来说每偏移一个单位就是偏移了一整个数组的地…

如何寻找专业精密机械零件代加工工厂

在现代工业生产中&#xff0c;精密机械零件的加工质量直接关系到产品的性能和可靠性。因此&#xff0c;寻找一家专业的精密机械零件代加工工厂至关重要。以下时利和整理分享的一些关于寻找专业精密机械零件代加工工厂的关键步骤和要点&#xff0c;帮助你找到合适的合作伙伴。 首…

想投资现货黄金?在TMGM开户需要多少钱?

最近&#xff0c;越来越多的人开始关注黄金投资&#xff0c;希望通过黄金来对冲风险、保值增值。而选择一家可靠的交易平台是进行黄金投资的第一步。TMGM作为全球知名的外汇交易商&#xff0c;也为投资者提供了黄金交易服务。那么&#xff0c;在TMGM开户投资黄金&#xff0c;需…