文章目录
- 1. Llama3 整体结构
- 2. 模块详解
- 2.1 模块1: Embeddings
- 2.2 模块2: RoPE
- 2.3 模块3: Transformer Block
- 2.4 模块4: RMSNorm
- 2.5 模块5: Attention
- 2.6 模块6: ADD
- 2.7 模块7: FFN
- 2.8 模块8: Linear
1. Llama3 整体结构
llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。
注: 下一节中会参照上图中 红色圆形序号 讲解各模块。
2. 模块详解
2.1 模块1: Embeddings
llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。
2.2 模块2: RoPE
这部分今天先不写,主要是写不完,公式太多了。。。
2.3 模块3: Transformer Block
Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。Transformer Block有模块4、5、6、7组成,具体内容见对应模块。
2.4 模块4: RMSNorm
RSMNorm 是在 layer normalization 基础上优化而来,所以先简单回顾下layer normalization。(详细介绍见《Transformer(二)–论文理解:transformer 结构详解》 2.4节)
layer normalization 是根据下面的公式对
x
x
x的分布进行调整。
x
=
a
∗
x
−
x
‾
s
t
d
+
e
p
s
+
b
x = a * \frac{x - \overline{x}}{std + eps} + b
x=a∗std+epsx−x+b
其中,
x
‾
\overline{x}
x是均值,
s
t
d
std
std是标准差,
e
p
s
eps
eps为一个很小的数,防止分母为零。
a
a
a、
b
b
b为参数,
b
b
b可以为零。
我们现在来看看RMSNorm做了什么优化呢,其实他对上面的试子
x
=
a
∗
x
−
x
‾
s
t
d
+
e
p
s
+
b
x = a * \frac{x - \overline{x}}{std + eps} + b
x=a∗std+epsx−x+b进行了简化。RMSNorm的计算公式如下:
a
‾
i
=
a
i
R
M
S
(
a
)
g
i
,
w
h
e
r
e
R
M
S
(
a
)
=
1
n
Σ
i
=
1
n
a
i
2
\overline{a}_i=\frac{a_i}{RMS(a)}g_{i}, \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{a^{2}_{i}}}
ai=RMS(a)aigi,whereRMS(a)=n1Σi=1nai2
从上式可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( a i − a ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({a_i - \overline{a})}} std=n1Σi=1n(ai−a))。这种简化在计算效率上有一定提高,且原始论文也说了,在效果上没有明显影响。
下面附上meta llama3中RMSNorm的源码,方便大家理解。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
2.5 模块5: Attention
llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。
我们知道,在《Attention is all you need》一文中,作者为了提高计算效率,提出了MHA技术,思想是采用分而治之的策略,把K、Q、V 对应的切分为若干个短向量,然后使用Scaled Dot-Product Attention 计算出attention score后,再把结果拼接起来,从而避免了超大向量乘法的计算消耗,从而提高了计算效率。如下图所示。
然而,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
说了半天,其实在源码层次来就,就是在计算Scaled Dot-Product Attention之前对query进行个分组,组内共享一套Key和value。下面是meta llama3中的Attention类,方便大家理解。
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
.
.
.
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(
keys, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(
values, self.n_rep
) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(
1, 2
) # (bs, n_local_heads, cache_len + seqlen, head_dim)
# 以下是Scaled Dot-Product Attention的计算
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
2.6 模块6: ADD
此模块做了个类似残差的操作,但与残差不同的是,不是用输入减去输出,而是用输入加上输出。具体操作就是把模块4的输入与模块5的输出做加法运算。
2.7 模块7: FFN
由3个Linear组成的FeedForward网络,这里的激活函数使用的siLU。siLU的数学公式如下:
s
i
l
u
(
x
)
=
x
∗
σ
(
x
)
,
w
h
e
r
e
σ
(
x
)
i
s
t
h
e
l
o
g
i
s
t
i
c
s
i
g
m
o
i
d
.
silu(x)=x*\sigma(x), \ \ where\ \sigma(x)\ is\ the\ logistic\ sigmoid.
silu(x)=x∗σ(x), where σ(x) is the logistic sigmoid.
函数的激活曲线如下图:
在里注意下,siLU 还有一个名字叫“swish function”,这个在 pytorch 的官方文档中有说明。
下面给出主要源码。
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
)
.
.
.
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
2.8 模块8: Linear
此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
)