1. 背景介绍
在文章《关于LLaMA 3.1 405B以及小模型的崛起》中,我们提到,LLaMA 3.1 的模型架构基本上已经成为当前LLM 模型的标准结构,和《Transformer原理分析》中提到的结构,也类似。但相比较,其中的一些关键模块做了一些优化和改进。本篇文章主要就是针对其中的几个改进模块,做一定的分析和记录。
通过图1 对比【1】,来看下LLaMA结构的改进点,主要有四点:Layer Norm换成了RMSNorm、 Self-Attension采用了Group Query Attention、Positional Embedding采用了ROPE、FFN结构采用了SwiGLU。
图1. Transformer结构与LLaMA结构对比
2. 各模块介绍
2.1 RMS Norm
RMSNorm(Root Mean Square Layer Normalization)【2】是一种用于深度神经网络的归一化技术,旨在对网络层的输出进行归一化,以加速模型训练并提高模型的稳定性。它是Layer Normalization的一种变体,但其实现方式与LayerNorm略有不同。
RMSNorm不依赖于均值,而是通过计算特征向量的均方根(Root Mean Square, RMS)来进行归一化。具体来说,对于输入向量 x(假设其维度为 d),RMSNorm的计算方式如下:
-
计算均方根 (RMS):
这里,表示输入向量 x 的第 i 个元素。
-
归一化输出:
其中,g 是一个可学习的标量参数,用于重新缩放归一化后的输出。
RMSNorm的特点
不依赖均值: 与LayerNorm不同,RMSNorm仅使用RMS而不依赖于均值来进行归一化,这使得其计算更简单、代价更低。
计算效率: 由于不计算均值,RMSNorm的计算复杂度较低,尤其是在高维度输入的情况下,RMSNorm相比LayerNorm更为高效。
适用于各种网络结构: RMSNorm可以应用于各种神经网络结构,如Transformer等,能够提高模型的训练稳定性。
理解RMSNorm
RMSNorm的核心思想是通过控制每个特征向量的长度,使网络层的输出保持稳定。通过去除均值计算,RMSNorm减少了归一化过程中的计算量,但仍然能够有效抑制梯度爆炸或消失问题,提高训练的稳定性。
由于RMSNorm仅依赖于输入向量的RMS值,因此它在计算上比LayerNorm更简洁,并且在某些情况下可以提供类似甚至更好的性能。特别是在大规模神经网络如Transformer的训练中,RMSNorm已经成为一种常见的归一化选择。
代码示例【3】
class RMSNormLayer(lasagne.layers.Layer):
def __init__(self, incoming, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.),
W=lasagne.init.Normal(0.05), nonlinearity=relu, **kwargs):
super(RMSNormLayer, self).__init__(incoming, **kwargs)
self.nonlinearity = nonlinearity
k = self.input_shape[1]
if b is not None:
self.b = self.add_param(b, (k,), name="b", regularizable=False)
if g is not None:
self.g = self.add_param(g, (k,), name="g")
if len(self.input_shape)==4:
self.axes_to_sum = (2,3)
self.dimshuffle_args = ['x',0,'x','x']
else:
self.axes_to_sum = 1
self.dimshuffle_args = ['x',0]
def get_output_for(self, input, **kwargs):
meanS = T.mean(input ** 2,axis=self.axes_to_sum,keepdims=True)
norm_input = input / T.sqrt(meanS + 1e-6)
if hasattr(self, 'g'):
activation = norm_input*self.g.dimshuffle(*self.dimshuffle_args)
else:
activation = norm_input
if hasattr(self, 'b'):
activation += self.b.dimshuffle(*self.dimshuffle_args)
return self.nonlinearity(activation)
def rms_norm(layer, b=lasagne.init.Constant(0.), g=lasagne.init.Constant(1.), **kwargs):
nonlinearity = getattr(layer, 'nonlinearity', None)
if nonlinearity is not None:
layer.nonlinearity = lasagne.nonlinearities.identity
if hasattr(layer, 'b'):
del layer.params[layer.b]
layer.b = None
return RMSNormLayer(layer, b, g, nonlinearity=nonlinearity, **kwargs)
2.2 Group Query Attention
分组查询注意力(GQA, Grouped Query Attention)【4】是一种改进的注意力机制,旨在提高多头自注意力机制的计算效率和内存利用率。GQA的主要思想是将查询(query)分组,以减少计算复杂度,同时保持注意力机制的有效性。
2.2.1 GQA的关键点
-
分组查询: 在GQA中,查询(query)被分成多个组,每个组独立计算注意力权重。这种分组可以减少计算量,因为每组的查询数量减少了。
-
共享键和值: 尽管查询被分组,GQA会在所有组中共享键(key)和值(value),从而减少内存消耗,并确保模型的注意力机制仍然能够有效捕捉输入之间的关系。
-
计算效率: 通过减少查询的数量,GQA大幅降低了计算成本,使得模型可以更高效地处理大规模数据,尤其是在资源有限的环境中。
-
保留性能: 尽管进行了分组,GQA仍然能够保持与标准多头自注意力机制相似的性能。实验表明,在许多任务中,GQA可以在减少计算和内存使用的同时,仍然取得与传统方法相当的结果。
图2. GQA形式
多头注意力拥有 H 个查询、键和数值头部。多查询注意力在所有查询头部之间共享单个键和数值头部。分组查询注意力为每组查询头部共享单个键和数值头部,从而在多头和多查询注意力之间插值。
2.2.2 标准多头自注意力机制(MHSA)公式回顾
标准的多头自注意力机制的计算可以表示为:
其中:
- Q 是查询矩阵(Query)
- K 是键矩阵(Key)
- V 是值矩阵(Value)
- 是键向量的维度
在多头自注意力机制中,多个注意力头通过不同的线性变换进行计算,然后将结果拼接起来:
其中,每个注意力头的计算为:
2.2.3 分组查询注意力(GQA)公式
在GQA中,查询矩阵 Q 被分为 G 组,每组的查询向量数量为,其中 N 是查询向量的总数。键矩阵 K 和值矩阵 V 则通常保持不变,并在所有组之间共享。
假设 表示第 g 组的查询矩阵(),那么GQA的计算可以表示为:
分组注意力计算:
拼接分组输出:
其中:
- 表示第 g 组的注意力计算结果。
- 是输出的线性变换矩阵。
通过这种分组方法,GQA减少了计算复杂度,同时仍然能够有效捕捉查询与键、值之间的相关性。
2.2.4 特殊情形
在分组查询注意力(GQA)中,传统多头模型中的查询头(Q)被分成 G 个组。每个组分配一个键(K)和一个值(V)头。这种配置被表示为 GQA-G,其中 G 表示组的数量。
GQA的特殊情况【5】:
- GQA-1 = MQA:当只有一个组(G = 1)时,GQA等同于MQA(单头查询注意力),因为所有查询头共享一个键和一个值头。
- GQA-H = MHA:当组的数量等于头的数量(G = H)时,GQA的表现与传统的多头注意力机制(MHA)相同,每个查询头都有其独立的键和值头。
2.2.5 代码示例【6】
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor, nn
def scaled_dot_product_gqa(
query: Tensor,
key: Tensor,
value: Tensor,
dropout: float = 0.0,
scale: Optional[float] = None,
mask: Optional[Tensor] = None,
is_causal: Optional[bool] = None,
need_weights: bool = False,
average_attn_weights: bool = False,
force_grouped: bool = False,
):
"""Scaled dot product attention with support for grouped queries.
Einstein notation:
- b: batch size
- n / s: sequence length
- h: number of heads
- g: number of groups
- d: dimension of query/key/value
Args:
query: Query tensor of shape (b, n, h, d)
key: Key tensor of shape (b, s, h, d)
value: Value tensor of shape (b, s, h, d)
dropout: Dropout probability (default: 0.0)
scale: Scale factor for query (default: d_query ** 0.5)
mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is
applied to all 'n' rows of the attention matrix. (default: None)
force_grouped: If True, apply grouped-query attention even if the number of
heads is equal for query, key, and value. (default: False)
Returns:
2-tuple of:
- Attention output with shape (b, n, h, d)
- (Optional) Attention weights with shape (b, h, n, s). Only returned if
'need_weights' is True.
"""
if (mask is not None) and (is_causal is not None):
raise ValueError(
"Only one of 'mask' and 'is_causal' should be provided, but got both."
)
elif not query.ndim == key.ndim == value.ndim == 4:
raise ValueError(
f"Expected query, key, and value to be 4-dimensional, but got shapes "
f"{query.shape}, {key.shape}, and {value.shape}."
)
# Move sequence length dimension to axis 2.
# This makes the attention operations below *much* faster.
query = rearrange(query, "b n h d -> b h n d")
key = rearrange(key, "b s h d -> b h s d")
value = rearrange(value, "b s h d -> b h s d")
bq, hq, nq, dq = query.shape
bk, hk, nk, dk = key.shape
bv, hv, nv, dv = value.shape
if not (bq == bk == bv and dq == dk == dv):
raise ValueError(
"Expected query, key, and value to have the same batch size (dim=0) and "
f"embedding dimension (dim=3), but got query: {query.shape}, "
f"key: {key.shape}, and value: {value.shape}."
)
elif (hk != hv) or (nk != nv):
raise ValueError(
"Expected key and value to have the same size in dimensions 1 and 2, but "
f"got key: {key.shape} and value: {value.shape}."
)
elif hq % hk != 0:
raise ValueError(
"Expected query heads to be a multiple of key/value heads, but got "
f"query: {query.shape} and key/value: {key.shape}."
)
if scale is None:
scale = query.size(-1) ** 0.5
query = query / scale
num_head_groups = hq // hk
query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")
if is_causal:
# Mask out the upper triangular portion of the attention matrix. This prevents
# the model from attending to tokens in the future.
mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()
if mask is not None:
# Expand mask to match the shape of the attention matrix.
# If mask is 2D, assume that it is applied to the key/value sequence dimension.
# Else if mask is 3D, assume that it is applied to the query/key/value sequence
# dimension for all attention heads.
#
# Users could also provide a 4D mask, which is applied to the query/key/value
# sequence dimension for each attention head (though I don't have a particular
# use case in mind for that).
if mask.ndim == 2:
mask = rearrange(mask, "b s -> b () () () s")
elif mask.ndim == 3:
mask = rearrange(mask, "b n s -> b () () n s")
# Mask similarity values by setting them to negative infinity. This guarantees
# that they will not contribute to the softmax computation below.
similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)
attention = F.softmax(similarity, dim=-1)
if dropout > 0.0:
attention = F.dropout(attention, p=dropout)
# Apply attention matrix to the value Tensor.
out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
# Move head dimension back to axis 2
out = rearrange(out, "b g h n d -> b n (h g) d")
attn_weights: Optional[Tensor] = None
if need_weights:
# Move the sequence dimensions back to positions 1, 2. Move the head dimension
# to position 3. This more closely matches the return shape of the attention
# output: (b, n, h, d).
attn_weights = rearrange(attention, "b g h n s -> b n s (h g)")
if average_attn_weights:
attn_weights = attn_weights.mean(dim=1)
return out, attn_weights
class MultiheadGQA(nn.Module):
"""Multi-head grouped query attention (GQA) layer.
Reference:
"GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
https://arxiv.org/pdf/2305.13245v1.pdf
GQA is a variant of multihead attention (MHA) that uses fewer write heads
(key / value) than query heads. GQA can be viewed as a generalization of
multi-query attention (MQA), which uses a single write head. GQA and MQA give
significant speedups over standard MHA in decoder layers, with minimal loss in
accuracy. In the paper, GQA is shown to be more accurate than MQA, while still
having a significant speedup over MHA.
NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model
from MHA to GQA. As a result, they do not mention parameter initialization or
layer normalization strategies. I follow the best practices laid out in the
MAGNETO paper, which improves Transformer performance through better parameter
initialization and layer norm placement. See:
https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
"""
def __init__(
self,
embed_dim: int,
query_heads: int,
kv_heads: int,
dropout: float = 0.0,
bias: bool = True,
layer_norm: bool = True,
layer_norm_eps: float = 1e-5,
gamma_init: float = 1.0,
device: Optional[Union[torch.device, str]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.query_heads = query_heads
self.kv_heads = kv_heads
self.dropout = dropout
self.layer_norm = layer_norm
self.gamma_init = gamma_init
if self.query_heads % self.kv_heads != 0:
raise ValueError(
f"query_heads ({query_heads}) must be divisible by "
f"kv_heads ({kv_heads})"
)
elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0):
raise ValueError(
f"embed_dim ({embed_dim}) must be divisible by "
f"query_heads ({query_heads}) and kv_heads ({kv_heads})"
)
head_dim = embed_dim // query_heads
if not head_dim % 8 == 0:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8"
)
if not head_dim <= 128:
raise ValueError(
f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
)
# Query projection layer is the same as in vanilla MHA.
self.q_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
# Key/value projection layers have a smaller output dimension, so that
# the we have fewer key/value attention heads after reshaping.
kv_embed_dim = embed_dim // query_heads * kv_heads
self.k_proj = nn.Linear(
embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
)
self.v_proj = nn.Linear(
embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
)
self.norm: Optional[nn.LayerNorm] = None
if layer_norm:
self.norm = nn.LayerNorm(
embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
)
# Grouped attention output will have the same embedding dimension as the
# key/value Tensors. So the output projection layer needs to accept the
# same dimension (kv_embed_dim).
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
)
self._reset_parameters()
def _reset_parameters(self):
nn.init.xavier_normal_(self.q_proj.weight)
if self.q_proj.bias is not None:
nn.init.constant_(self.q_proj.bias, 0)
nn.init.xavier_normal_(self.k_proj.weight)
if self.k_proj.bias is not None:
nn.init.constant_(self.k_proj.bias, 0)
# NOTE: We follow the initialization strategy from MAGNETO. See:
# https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
# Gain (self.gamma_init) should be provided as a keyword argument when
# initializing the larger Transformer model, since it requires knowledge
# of the number of encoder/decoder layers in the model.
nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
if self.v_proj.bias is not None:
nn.init.constant_(self.v_proj.bias, 0)
nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
need_weights: bool = False,
# TODO
# attn_mask: Optional[Tensor] = None,
is_causal: bool = False,
average_attn_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
# Notation:
# b - batch size
# n - sequence length
# h - number of heads
# d - embedding dimension
#
# Input shape: (b, n, d)
q: Tensor = self.q_proj(query)
k: Tensor = self.k_proj(key)
v: Tensor = self.v_proj(value)
# Unfold 'd' dimension into 'h' separate attention heads.
q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads)
k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads)
v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads)
# Apply attention, then fold 'h' attention heads back into 'd'.
x, attn = scaled_dot_product_gqa(
query=q,
key=k,
value=v,
# TODO
# mask=attn_mask,
is_causal=is_causal,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
force_grouped=False,
)
x = rearrange(x, "b n h d -> b n (h d)")
# NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO
# architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra
# layer norm before the linear output projection. The cross-attention layer in
# the MAGNETO decoder does not include this layer norm, so users have the
# option to disable it (layer_norm=False).
if self.layer_norm:
assert self.norm is not None
x = self.norm(x)
# Linear projection on attention outputs.
x = self.out_proj(x)
return x, attn
2.3 RoPE(Rotary Positional Encoding)
Rotary Positional Encoding (RoPE) 是一种用于自注意力模型的位置信息编码方法,用于增强模型对序列中位置信息的理解。与传统的正弦和余弦位置编码不同,RoPE通过旋转操作将位置信息直接嵌入到查询(query)和键(key)向量中,以更好地捕捉序列中不同位置的相对关系。
2.3.1 RoPE 的关键特点
旋转编码:RoPE通过在查询和键向量的不同维度之间应用旋转操作,将位置信息嵌入到向量表示中。具体来说,RoPE将每个维度的向量分成两部分,分别应用二维旋转矩阵进行处理。
相对位置信息:RoPE自然地编码了相对位置信息,而不是绝对位置信息,这使得模型能够更好地捕捉序列中元素之间的相对关系,尤其在处理长序列时表现更优。
无额外参数:RoPE不引入额外的模型参数,只是在已有的查询和键向量上应用旋转操作,因此不增加模型的复杂性。
2.3.2 RoPE 的数学表达
对于一个维度 d 的查询或键向量 x,假设 是与位置 p 相关的旋转角度,则 RoPE 的旋转操作可以表示为:
在这个操作中, 和 分别是向量 x 的一部分,旋转操作会对这些部分应用一个与位置 p 相关的旋转角度,使得向量在不同位置上具有相应的变化。
2.3.3 旋转位置编码取代绝对位置编码的原因【8】
2.3.3.1 绝对位置编码的主要缺点
无法包括相对位置信息,虽然绝对位置编码捕捉到了单词的位置信息,但它并未捕捉整个句子(或序列)的位置信息。
示例:
一种常见的创建长度为3的绝对位置编码的方法是进行随机初始化。
假设我们得到如下结果:[0.1, 0.01, 0.5],这种绝对位置编码将确保相同的单词在不同位置将产生不同的注意力输出。(1) 位置之间没有关系。
较大的位置索引上的位置编码可能大于或小于较小的索引位置编码
→ 位置1的编码为0.1,可能大于位置2的编码为0.01,即 [0.1 > 0.01]
→ 位置1的编码为0.1,可能也小于位置3的编码为0.5,即 [0.1 < 0.5]
如果绝对位置编码:[0.1, 0.01, 0.05]:
(2) 相对距离不一致。
位置编码之间的差异无法告诉我们单词之间的距离。
→ 位置1到位置2的距离为 abs(0.1-0.01) = 0.09。
→ 位置1到位置3的距离为 abs(0.1-0.05) = 0.05。
(理想情况下,位置1到位置3的距离应大于位置1到位置2的距离)
这意味着绝对位置编码并没有捕捉到整个句子(或序列)的位置信息!
2.3.3.2 相对位置编码的主要缺点
(1) 计算效率低下
需要在自注意力机制后增加额外的步骤。
必须创建成对位置编码矩阵,然后进行大量的张量操作,以在每个时间步获取相对位置编码。(2)不适合推理
在推理过程中,使用一种称为 KV 缓存的方法,这有助于减少推理速度。
使用 KV 缓存的一个要求是已生成单词的编码在生成新单词时不会改变(这是绝对位置编码提供的)。
因此,相对位置编码不适合推理,因为每个标记的嵌入在每个新时间步都会发生变化。
当序列长度为2时,单词的相对位置将是 [-1, 0, 1]。
当输入序列长度为3时,单词的相对位置将是 [-2,-1,0,1,2]。
由于使用这些相对位置来获取每个单词的相对位置编码,当相对位置集发生变化时,每个单词的编码也会改变。
这意味着给定句子[‘this’, ‘is’, ‘awesome’],在生成推理标记时,单词‘this’的编码会在每个时间步改变。
这就是为什么相对位置编码不常用于推理的原因。
2.3.3.3 旋转位置编码
旋转位置编码是使用旋转矩阵对绝对位置信息进行编码并在自注意力机制中自然地结合显式相对位置依赖关系的位置编码方法。旋转矩阵是通过某个角度将一个向量旋转到另一个向量的矩阵。
旋转位置编码通过结合绝对位置编码和相对位置编码来克服它们的缺点。
首先,看看它是如何结合这两者的:
旋转位置编码中的绝对位置编码
在旋转位置编码示例中,对于词语“this”和“is”,旋转矩阵仅取决于 ,其中 在所有词语中是共享的,而 m 只是词语的位置(而不考虑其他词语的位置)。
类似于绝对位置编码,为每个位置都有一个位置编码(不依赖于其他词语的位置)。
旋转位置编码中的相对位置编码
由于旋转角度取决于,句子开头的词语会有较小的旋转角度,而句子末尾的词语会有较大的旋转角度,从而使得这些距离具有相对性。
注意: 对所有词语都是相同的,唯一改变的是词语的位置 m。较大的 m 意味着 ,即较大的旋转角度。
类似于相对位置编码,不同位置的编码之间存在一定的关系。
为了解释为什么旋转位置编码能够克服这些缺点,回顾一下这两种方法的缺点:
- 绝对位置编码: (1) 不包含相对位置信息
- 相对位置编码: (2) 计算效率低,(3) 不适用于推理
因此,我们可以知道,旋转位置编码必须具备以下特性:
- 包含相对位置信息
- 计算效率高
- 适用于推理
旋转位置编码包含相对位置信息
旋转位置编码通过使旋转矩阵中的角度依赖于当前词语的位置,并通过旋转矩阵的性质,告诉我们两个词语之间的距离,从而包含了相对位置信息。
例如,即使‘pig’和‘dog’(1)出现在两个长度不同的句子中,并且(2)‘pig’和‘dog’出现在句子的不同位置,它们之间的角度仍然是相同的。 角度相同是因为‘pig’和‘dog’之间始终间隔2个词语(‘chased the’在‘pig’和‘dog’之间)。 词语的旋转角度是 ,而‘pig’和‘dog’之间相隔两个词语的角度是。
(另一种观点理解)
我们仅讨论了“旋转角度”,但编码相对距离的部分在哪里呢?答案在于以下点积的定义:“向量A和B的点积等于A的长度乘以B的长度,再乘以它们之间角度的余弦值,”
如果计算位置1的旋转位置编码与其他位置的点积,得到的将是,它只依赖于旋转角度(更具体地说,是从位置1到其他词语位置的旋转角度差异)。
位置1的旋转角度小于位置>1的旋转角度。这意味着,其他词离位置1越远,(1)旋转角度的差异越大,(2)点积越小(因为余弦值从0到π减小),(3)其他词离位置1越远。
因此可以得出结论,旋转位置编码包含了相对位置信息。
如何使用旋转位置编码:
- 因为我们需要点积包含相对位置编码,将旋转位置编码分别应用于查询和键,这样当我们将它们矩阵相乘时,注意力矩阵就包含了相对位置信息。
另外,旋转位置编码类似于绝对位置编码,其编码仅依赖于当前词的位置,已经生成的词的位置信息不会改变,因此在推断过程中可以再次使用KV缓存。
进一步分析可以参考【6】。
代码示例【9】:
import torch
from torch import nn
from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention
class RotaryPositionalEmbeddings(nn.Module):
def __init__(self, d: int, base: int = 10_000):
super().__init__()
self.base = base
self.d = d
self.cos_cached = None
self.sin_cached = None
def _build_cache(self, x: torch.Tensor):
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
return
seq_len = x.shape[0]
theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
self.cos_cached = idx_theta2.cos()[:, None, None, :]
self.sin_cached = idx_theta2.sin()[:, None, None, :]
def _neg_half(self, x: torch.Tensor):
d_2 = self.d // 2
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
def forward(self, x: torch.Tensor):
self._build_cache(x)
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
neg_half_x = self._neg_half(x_rope)
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
return torch.cat((x_rope, x_pass), dim=-1)
class RotaryPEMultiHeadAttention(MultiHeadAttention):
def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
super().__init__(heads, d_model, dropout_prob)
d_rope = int(self.d_k * rope_percentage)
self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))
def _test_rotary():
x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)
x = x[:, None, None, :]
inspect(x)
rotary_pe = RotaryPositionalEmbeddings(4)
inspect(rotary_pe(x))
if __name__ == '__main__':
_test_rotary()
2.4 SwiGLU
SwiGLU(Swish-Gated Linear Unit)是一种激活函数,旨在提高深度学习模型的性能。作为一种改进的激活函数,用于深度神经网络中的非线性变换。SwiGLU 结合了 Swish 激活函数和 Gated Linear Unit (GLU) 的特点。
SwiGLU 的激活函数定义如下:
其中:
-
Swish 激活函数是一个具有自适应特性的激活函数,公式为:其中,是 Sigmoid 函数。
-
GLU 激活函数是 Gated Linear Unit 的一种变体,公式为:
SwiGLU 结合了这两种激活函数的优点,旨在提高模型的表达能力和训练效率。它通过非线性激活和门控机制来改善模型的性能,尤其是在处理复杂数据和任务时【11】。
SwiGLU(x) = Swish(W1x+b)⊗(Vx+c)
FFNSwiGLU(x) = (Swish1(xW)⊗xV)W2
具体来说,SwiGLU 的计算可以分为以下几个步骤:
- 线性变换:输入 x 分别通过两个线性变换 和 ,得到两个中间结果 和 。
- Swish 激活:对 进行 Swish 激活,得到激活后的输出 。
- 门控机制:将 Swish 激活的输出与另一个线性变换的输出 相乘,实现门控机制。
Swish 激活部分提供了一个平滑的非线性函数,确保了梯度的良好流动,尤其是在靠近零的输入区域。 门控部分 通过 和 Sigmoid 门控机制引入了额外的控制,可以更灵活地选择哪些信息通过,这有助于模型捕捉更复杂的模式。最终的输出是两个部分的逐元素乘积,既保留了 Swish 激活的平滑性和良好的梯度流动特性,也引入了 GLU 的门控机制,以增强非线性表达能力。
代码示例【12】:
class SwiGLU(nn.Module):
def __init__(self, w1, w2, w3) -> None:
super().__init__()
self.w1 = w1
self.w2 = w2
self.w3 = w3
def forward(self, x):
x1 = F.linear(x, self.w1.weight)
x2 = F.linear(x, self.w2.weight)
hidden = F.silu(x1) * x2
return F.linear(hidden, self.w3.weight)
F.silu函数与ß=1时的swish相同。
3. 参考材料
【1】 LLaMA Pro: Progressive LLaMA with Block Expansion
【2】Root Mean Square Layer Normalization
【3】RMSNorm GITHUB
【4】GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
【5】Demystifying GQA — Grouped Query Attention for Efficient LLM Pre-training
【6】grouped-query-attention-pytorch GITHUB
【7】RoFormer: Enhanced Transformer with Rotary Position Embedding
【8】Understanding Rotary Positional Encoding
【9】Rotary Positional Embeddings (RoPE)
【10】GLU Variants Improve Transformer
【11】SwiGLU: GLU Variants Improve Transformer
【12】What Is SwiGLU? How to Implement It? And Why Does it Work?