Welcome to follow my CSDN: https://spike.blog.csdn.net/
Original article: https://spike.blog.csdn.net/article/details/141462669
LLaMA3 is Meta's latest large language model. Several upgrades across the network design noticeably improve performance and efficiency. The key changes are:
- Vocabulary enlarged to 128k tokens.
- RMS Normalization (Root Mean Square Normalization).
- Rotary Position Embedding (RoPE).
- Grouped Query Attention (GQA): 32 query heads in 8 groups, i.e. 8 KV heads, so every 4 query heads share one KV head.
- SwiGLU Feedforward Network.
The detailed inference walkthrough follows.
Install the required Python libraries:
pip install tiktoken
pip install matplotlib
pip install blobfile
- tiktoken: a fast BPE (Byte Pair Encoding) tokenizer for use with OpenAI-style models.
- matplotlib: plotting library, used below to visualize RoPE and the attention maps.
- blobfile: a Python library for working with cloud storage (such as Amazon S3 and Google Cloud Storage).
1. Load the Tokenizer
The tokenizer splits text into tokens; the common scheme is BPE (Byte Pair Encoding).
Byte Pair Encoding (BPE), also known as digram coding, is an algorithm that encodes a text string into a table of tokens for use in a downstream model. BPE began as a data-compression technique; OpenAI adopted it for tokenization when pre-training the GPT models, and it is now widely used across Transformer models.
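As a toy illustration of the core BPE idea (not the tiktoken implementation), a single merge step replaces the most frequent adjacent pair:

```python
from collections import Counter

def bpe_merge_once(tokens):
    # Count adjacent pairs and merge every non-overlapping occurrence
    # of the most frequent one, scanning left to right.
    pairs = Counter(zip(tokens, tokens[1:]))
    if not pairs:
        return tokens
    (a, b), _ = pairs.most_common(1)[0]
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
            merged.append(a + b)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("aaabdaaabac")
print(bpe_merge_once(tokens))  # ['aa', 'a', 'b', 'd', 'aa', 'a', 'b', 'a', 'c']
```

Repeating this step builds the merge table; tiktoken ships the learned table as mergeable_ranks and only applies it at encode time.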
The tokenizer code:
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import os
import matplotlib.pyplot as plt
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # use only GPU 0
# Meta-Llama-3-8B weights downloaded from HuggingFace
model_path = "Meta-Llama-3-8B/original/consolidated.00.pth"
tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model"
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(
name=Path(tokenizer_path).name,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks,
special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)
tokenizer.decode(tokenizer.encode("hello world!"))
The parameters of tiktoken.Encoding:
- name=Path(tokenizer_path).name: the name of the encoding object.
- pat_str: a regular expression defining how the text is pre-split before BPE; this is the default pattern for English.
- mergeable_ranks=load_tiktoken_bpe(tokenizer_path): the BPE merge rules, shipped as part of the model.
- special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}: the mapping for special tokens. Each special token is mapped to an integer id, assigned sequentially starting from the number of loaded BPE merge rules.
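A quick sanity check on these sizes (arithmetic only; the 128,000 and 256 counts come from the token list above and the config below):

```python
# The 128,256-entry vocabulary splits into 128,000 BPE ranks plus 256 special
# tokens, so "<|begin_of_text|>" -- the first special token, prepended to
# every prompt below -- receives id 128000.
n_bpe_ranks = 128000
n_special = 10 + len(range(5, 256 - 5))  # 10 listed tokens + 246 reserved ones
print(n_special)                # 256
print(n_bpe_ranks + n_special)  # 128256, the vocab_size in params.json
```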
Testing with hello world!: encoding followed by decoding returns the identical string:
hello world!
2. Load the Model
Load the Llama3 weights directly from the pth checkpoint:
model = torch.load(model_path)
print(json.dumps(list(model.keys())[:20], indent=4))
The checkpoint path:
Meta-Llama-3-8B/original/consolidated.00.pth
Load the Llama3 configuration:
with open("Meta-Llama-3-8B/original/params.json", "r") as f:
    config = json.load(f)
The parameters:
{
'dim': 4096,
'n_layers': 32,
'n_heads': 32,
'n_kv_heads': 8,
'vocab_size': 128256,
'multiple_of': 1024,
'ffn_dim_multiplier': 1.3,
'norm_eps': 1e-05,
'rope_theta': 500000.0
}
- dim: the hidden-state dimension, 4096; the hidden state is the model's internal representation of the input sequence.
- n_layers: the number of layers, 32; depth determines the model's capacity and expressiveness.
- n_heads: the number of heads in the self-attention mechanism, 32; the heads capture different relationships within the input sequence.
- n_kv_heads: the number of key/value heads, 8.
- vocab_size: the vocabulary size, 128256, covering all tokens the model can process.
- multiple_of: the feedforward hidden dimension is rounded up to a multiple of this value, 1024, which helps computational efficiency.
- ffn_dim_multiplier: a multiplier applied when deriving the feedforward (FFN) hidden dimension from the hidden-state dimension, here 1.3; the FFN applies a non-linear transform to the hidden states.
- norm_eps: the epsilon of the normalization layers, 1e-05.
- rope_theta: the base of RoPE (Rotary Position Embedding), 500000.0; RoPE encodes relative position information in the sequence.
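For reference, the FFN hidden size seen later (14336) can be derived from dim, ffn_dim_multiplier, and multiple_of; the rounding below follows Meta's llama reference code, so treat it as a reconstruction rather than part of this walkthrough:

```python
# Derive the SwiGLU hidden dimension from the config values above.
dim = 4096
ffn_dim_multiplier = 1.3
multiple_of = 1024

hidden_dim = 4 * dim                               # 16384
hidden_dim = int(2 * hidden_dim / 3)               # 10922
hidden_dim = int(ffn_dim_multiplier * hidden_dim)  # 14198
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)  # 14336, matching the w1/w3 shapes in the FFN section
```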
Cache the parameters:
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])
3. Text → Tokens → Embeddings
Convert the text into tokens:
prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
Output:
[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']
Convert the tokens into embeddings:
embedding_layer = torch.nn.Embedding(vocab_size, dim)
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape
Output:
torch.Size([17, 4096])
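The embedding step above is nothing more than row selection from the weight matrix, which a toy example (assumed small sizes, not the real 128256 x 4096 table) makes explicit:

```python
import torch

emb = torch.nn.Embedding(10, 4)           # 10-token vocab, 4-dim embeddings
ids = torch.tensor([3, 1, 3])             # token ids may repeat
out = emb(ids)
print(torch.equal(out, emb.weight[ids]))  # True: lookup == row indexing
```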
4. Implement RMS Normalization
The rms_norm function (a commented-out equivalent form is shown first):
# def rms_norm(tensor, norm_weights):
#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps) ** 0.5
#     return tensor * (norm_weights / rms)
def rms_norm(tensor, norm_weights):
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
RMS Normalization (Root Mean Square Normalization) normalizes the inputs of a layer. It makes the model invariant to re-scaling of the inputs and gives it an implicit learning-rate adaptation ability. Compared with Layer Normalization, RMS Normalization is more efficient because its computational complexity is lower. The formula is:

\overline{x}_{i} = \frac{x_{i}}{RMS(x)+\epsilon} \cdot g_{i}, \quad \text{where} \quad RMS(x)=\sqrt{\frac{1}{n}\sum_{i=1}^{n}x_{i}^{2}}
Comparing Layer Normalization and RMS Normalization: Layer Normalization contains both a re-centering (shift) part and a re-scaling part, while RMS Normalization drops the shift and keeps only the scaling. Research suggests that the key to LayerNorm's success is the invariance provided by re-scaling, not the invariance provided by re-centering. By dropping the mean and shift computations, RMS Normalization trains faster than Layer Normalization with essentially equal, sometimes better, results. For reference, Layer Normalization is:

\mu=\frac{1}{n}\sum_{i=1}^{n}x_{i}, \quad \sigma=\sqrt{\frac{1}{n}\sum_{i=1}^{n}(x_{i}-\mu)^{2}}, \quad y=\frac{x-\mu}{\sigma+\epsilon}\gamma+\beta
Apply RMS Normalization:
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
print(token_embeddings.shape)
Output:
torch.Size([17, 4096])
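A quick check (random data, unit gain weights) that rms_norm does what the formula says, namely scale each row to roughly unit RMS:

```python
import torch

norm_eps = 1e-5

def rms_norm(tensor, norm_weights):
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

x = torch.randn(17, 4096)
y = rms_norm(x, torch.ones(4096))  # unit weights isolate the scaling step
rms = y.pow(2).mean(-1).sqrt()
print(torch.allclose(rms, torch.ones(17), atol=1e-3))  # True
```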
5. Compute the Self-Attention Query
The QKV and output (O) weights of Llama3:
print(
model["layers.0.attention.wq.weight"].shape,
model["layers.0.attention.wk.weight"].shape,
model["layers.0.attention.wv.weight"].shape,
model["layers.0.attention.wo.weight"].shape
)
Output:
torch.Size([4096, 4096]) # wq
torch.Size([1024, 4096]) # wk
torch.Size([1024, 4096]) # wv
torch.Size([4096, 4096]) # wo
Multi-Head Attention uses n_heads = 32 heads:
q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
print(q_layer0.shape)
The query weights:
torch.Size([32, 128, 4096])
Each query head has shape [128, 4096]:
q_layer0_head0 = q_layer0[0]
print(q_layer0_head0.shape) # torch.Size([128, 4096])
The token embeddings have shape torch.Size([17, 4096]); multiplying by the head weight produces [17, 128]:
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
print(q_per_token.shape) # torch.Size([17, 128])
6. Implement Rotary Positional Encoding (RoPE)
For RoPE, first reshape the per-token queries into pairs:
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
print(q_per_token_split_into_pairs.shape) # torch.Size([17, 64, 2])
Build the rotary position encoding:
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
print(zero_to_one_split_into_64_parts)
Output:
tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
0.9844])
The frequencies, with rope_theta = 500000.0 taken from the model config:
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
print(freqs)
Output:
tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])
Convert to polar form:
freqs_for_each_token = torch.outer(torch.arange(17), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
print(freqs_cis.shape) # torch.Size([17, 64])
# viewing the row of freqs_cis at index 3
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:17]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()
The polar-coordinate plot:
Convert the query pairs into complex numbers, i.e. torch.Size([17, 64, 2]) becomes torch.Size([17, 64]), then multiply by freqs_cis:
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
print(q_per_token_as_complex_numbers.shape) # torch.Size([17, 64])
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
print(q_per_token_as_complex_numbers_rotated.shape) # torch.Size([17, 64])
Convert the complex numbers back to real pairs, restoring the last dimension: torch.Size([17, 64]) becomes torch.Size([17, 64, 2]):
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
print(q_per_token_split_into_pairs_rotated.shape) # torch.Size([17, 64, 2])
Finally reshape back to the query dimension, from torch.Size([17, 64, 2]) to torch.Size([17, 128]); the query vectors have now been rotated by the RoPE position encoding:
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
print(q_per_token_rotated.shape) # torch.Size([17, 128])
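Because freqs_cis has unit modulus, the multiplication rotates each query pair without changing its length; a self-contained check on random data (toy query with the same 17 x 128 shapes as above):

```python
import torch

n_tokens, head_dim = 17, 128
q = torch.randn(n_tokens, head_dim)

freqs = 1.0 / (500000.0 ** (torch.arange(64) / 64))
angles = torch.outer(torch.arange(n_tokens), freqs)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit-modulus complex

q_complex = torch.view_as_complex(q.view(n_tokens, -1, 2))
q_rotated = torch.view_as_real(q_complex * freqs_cis).reshape(n_tokens, head_dim)

# Rotation changes angles only, so the per-token norm is preserved.
print(torch.allclose(q.norm(dim=-1), q_rotated.norm(dim=-1), atol=1e-3))  # True
```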
7. Compute the Self-Attention Key and the QK Mask
Split the key weights into n_kv_heads heads, reshaping [1024, 4096] into [8, 128, 4096]:
k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
print(k_layer0.shape) # torch.Size([8, 128, 4096])
Apply the RoPE position encoding to the keys, just as for the queries:
k_layer0_head0 = k_layer0[0]
print(k_layer0_head0.shape) # torch.Size([128, 4096])
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T) # [17, 4096] x [4096, 128] = [17, 128]
print(k_per_token.shape) # torch.Size([17, 128])
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
print(k_per_token_split_into_pairs.shape) # torch.Size([17, 64, 2])
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
print(k_per_token_as_complex_numbers.shape) # torch.Size([17, 64])
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
print(k_per_token_split_into_pairs_rotated.shape) # torch.Size([17, 64, 2])
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
print(k_per_token_rotated.shape) # torch.Size([17, 128])
Multiply the Query matrix by the Key matrix:
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
print(qk_per_token.shape)
Output:
torch.Size([17, 17])
The matrix form of self-attention:
A=softmax\left(\frac{QK^{\top}}{\sqrt{d_{k}}}\right)V
Display the attention matrix as a heatmap:
def display_qk_heatmap(qk_per_token):
    _, ax = plt.subplots()
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)

display_qk_heatmap(qk_per_token)
The heatmap:
The decoder's causal mask matrix:
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)  # -inf blocks attention to future positions
mask = torch.triu(mask, diagonal=1)
print(mask)
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)
The masked heatmap:
The probability matrix after softmax:
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)
The heatmap:
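On a toy 4-token example, the effect of the -inf upper triangle is easy to see: after softmax, every position distributes its attention only over itself and earlier positions:

```python
import torch

n = 4
mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
scores = torch.zeros(n, n) + mask  # pretend all raw QK scores are 0
probs = torch.softmax(scores, dim=1)
print(probs[0].tolist())  # [1.0, 0.0, 0.0, 0.0]: token 0 sees only itself
print(probs[3].tolist())  # [0.25, 0.25, 0.25, 0.25]: the last token sees all four
```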
8. Compute the Self-Attention Value and QKV
The Value weights have the same head layout as the Keys, n_kv_heads=8, but the Values need no position encoding. Finally compute the qkv_attention matrix:
v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
print(v_layer0.shape) # torch.Size([8, 128, 4096])
v_layer0_head0 = v_layer0[0]
print(v_layer0_head0.shape) # torch.Size([128, 4096])
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
print(v_per_token.shape) # torch.Size([17, 128])
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
print(qkv_attention.shape) # torch.Size([17, 128])
9. Compute Grouped Query Attention over All Heads
In Llama3 there are n_heads=32 query heads and n_kv_heads=8 KV heads, so every 4 query heads share 1 KV head: queries 0-3 use the same Key and Value, queries 4-7 the next KV head, and so on. Running all 32 heads yields a [32, 17, 128] result; since 32 x 128 = 4096, concatenating the heads matches the embedding dimension:
qkv_attention_store = []
scale = n_heads // n_kv_heads  # 32 / 8 = 4
for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head // scale]  # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head // scale]  # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)
    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)
len(qkv_attention_store)  # 32 heads
GQA (Grouped Query Attention) differs from MHA (Multi-Head Attention) and MQA (Multi-Query Attention) as follows:
- MHA is the baseline attention mechanism: the input is split across multiple heads that compute attention in parallel, each head learning different aspects of the input, and the results are merged to capture different relationships in the sequence.
- MQA is an optimized variant in which all heads share the same keys and values, which cuts parameters and computation and speeds up inference, possibly at some cost in accuracy.
- GQA is the compromise between MHA and MQA: the query heads are partitioned into groups and each group shares one key/value head, rather than all heads sharing one. This reduces computation while retaining more diversity, balancing inference speed against model accuracy.
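The grouping boils down to the head // scale index used above; a minimal sketch of the query-head to KV-head mapping:

```python
# With 32 query heads and 8 KV heads, integer division maps each block of
# 4 consecutive query heads onto one shared KV head.
n_heads, n_kv_heads = 32, 8
scale = n_heads // n_kv_heads  # 4 query heads per KV head
mapping = [head // scale for head in range(n_heads)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
print(mapping[-1])  # 7, the last KV head
```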
Now merge the 32 heads back into one matrix:
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
print(stacked_qkv_attention.shape) # torch.Size([17, 4096])
The output-projection weight is [4096, 4096]; apply this linear transform to the concatenated attention output:
w_layer0 = model["layers.0.attention.wo.weight"]
print(w_layer0.shape) # torch.Size([4096, 4096])
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
print(embedding_delta.shape) # torch.Size([17, 4096])
Then apply the residual connection: the delta embedding_delta is added to the unnormalized input matrix (token_embeddings_unnormalized):
embedding_after_edit = token_embeddings_unnormalized + embedding_delta
print(embedding_after_edit.shape) # torch.Size([17, 4096])
Then embedding_after_edit_normalized applies RMS normalization:
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
print(embedding_after_edit_normalized.shape) # torch.Size([17, 4096])
10. Implement the SwiGLU Feedforward Network
The SwiGLU feedforward computation:
w1 = model["layers.0.feed_forward.w1.weight"] # torch.Size([14336, 4096])
w2 = model["layers.0.feed_forward.w2.weight"] # torch.Size([4096, 14336]); the hidden dim is 3.5x the model dim
w3 = model["layers.0.feed_forward.w3.weight"] # torch.Size([14336, 4096])
tmp1 = torch.matmul(embedding_after_edit_normalized, w1.T) # up-projection, [17, 14336]
tmp2 = torch.matmul(embedding_after_edit_normalized, w3.T) # up-projection, [17, 14336]
output_after_feedforward = torch.matmul(torch.functional.F.silu(tmp1) * tmp2, w2.T)
print(output_after_feedforward.shape) # torch.Size([17, 4096])
# residual connection
layer_0_embedding = embedding_after_edit+output_after_feedforward
print(layer_0_embedding.shape) # torch.Size([17, 4096])
The SwiGLU FFN is:

SwiGLU\ FFN(x)=\left(SiLU(xW_{1}^{\top}) \odot xW_{3}^{\top}\right)W_{2}^{\top}
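A toy-sized version of the same gate structure (random weights, dims 8 → 16 → 8; not the real Llama3 weights) shows the shapes flowing through:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 8)                            # 3 tokens, model dim 8
w1, w3 = torch.randn(16, 8), torch.randn(16, 8)  # up-projections
w2 = torch.randn(8, 16)                          # down-projection

# SiLU(x W1^T) gates x W3^T elementwise, then W2 projects back down.
out = torch.matmul(F.silu(x @ w1.T) * (x @ w3.T), w2.T)
print(out.shape)  # torch.Size([3, 8]): back to the model dimension
```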
Advantages of SwiGLU over the ReLU function:
- Swish still responds (weakly) to negative inputs, overcoming ReLU's hard zero output for all negative values.
- GLU adds a gate: the input itself decides which information passes and which is filtered out, helping the network learn useful representations and improving generalization.
F.silu is the swish activation function:
silu(x)=x \cdot \sigma(x), \quad \sigma(x)=\frac{1}{1+e^{-x}}
Plot the function:
import numpy as np

def draw(func):
    x = np.arange(-10, 10, 0.1)
    y = []
    x_torch = torch.from_numpy(x)
    for t in x_torch:
        y_1 = func(t)
        y_1 = y_1.numpy()
        y.append(y_1)
    plt.plot(x, y, label="silu")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.xlim(-7, 7)
    plt.ylim(-7, 7)
    plt.grid()
    plt.legend()
    plt.show()

The SiLU plot:
draw(torch.functional.F.silu)

A hand-written equivalent:
def my_silu(x):
    t = 1 / (1 + torch.exp(-x))
    return x * t
Reference: torch.nn.SiLU
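A quick numeric check that the hand-written my_silu above matches PyTorch's built-in (within float precision):

```python
import torch
import torch.nn.functional as F

def my_silu(x):
    return x * (1 / (1 + torch.exp(-x)))

x = torch.linspace(-5.0, 5.0, steps=11)
print(torch.allclose(F.silu(x), my_silu(x), atol=1e-6))  # True
```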
11. Run the Full Multi-Layer Loop
Every layer runs the same modules, for n_layers=32 layers:
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head // 4]
        v_layer_head = v_layer[head // 4]
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit + output_after_feedforward
Then apply the final RMS Norm:
final_embedding = rms_norm(final_embedding, model["norm.weight"])
print(final_embedding.shape) # torch.Size([17, 4096])
12. Decode the Output Token
Decode from the last position's features, final_embedding[-1]. The output weight is [128256, 4096]: the input vector has 4096 dimensions and there are 128256 candidate tokens:
print(model["output.weight"].shape) # torch.Size([128256, 4096])
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
print(logits.shape) # torch.Size([128256])
next_token = torch.argmax(logits, dim=-1)
print(next_token) # tensor(2983)
output_v = tokenizer.decode([next_token.item()])
print(output_v) # '42'
The complete input and output:
"the answer to the ultimate question of life, the universe, and everything is "
"42"
13. Complete Source Code
Reference: llama3-from-scratch. The complete source:
import json
import os
from pathlib import Path
import tiktoken
import torch
from tiktoken.load import load_tiktoken_bpe
from tqdm import tqdm
os.environ['CUDA_VISIBLE_DEVICES'] = "4"  # use only GPU 4
def infer_for_llama3(prompt, model_path, tokenizer_path, config_path):
    # model_path = "Meta-Llama-3-8B/original/consolidated.00.pth"
    # tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model"
    # config_path = "Meta-Llama-3-8B/original/params.json"
    special_tokens = [
        "<|begin_of_text|>",
        "<|end_of_text|>",
        "<|reserved_special_token_0|>",
        "<|reserved_special_token_1|>",
        "<|reserved_special_token_2|>",
        "<|reserved_special_token_3|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|reserved_special_token_4|>",
        "<|eot_id|>",  # end of turn
    ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
    mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
    tokenizer = tiktoken.Encoding(
        name=Path(tokenizer_path).name,
        pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
        mergeable_ranks=mergeable_ranks,
        special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
    )
    model = torch.load(model_path)
    with open(config_path, "r") as f:
        config = json.load(f)
    dim = config["dim"]
    n_layers = config["n_layers"]
    n_heads = config["n_heads"]
    n_kv_heads = config["n_kv_heads"]
    vocab_size = config["vocab_size"]
    multiple_of = config["multiple_of"]
    ffn_dim_multiplier = config["ffn_dim_multiplier"]
    norm_eps = config["norm_eps"]
    rope_theta = torch.tensor(config["rope_theta"])
    tokens = [128000] + tokenizer.encode(prompt)
    tokens = torch.tensor(tokens)
    prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
    embedding_layer = torch.nn.Embedding(vocab_size, dim)
    embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
    token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)

    def rms_norm(tensor, norm_weights):
        return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

    final_embedding = token_embeddings_unnormalized
    n_tokens = len(tokens)
    zero_to_one_split_into_64_parts = torch.tensor(range(64)) / 64
    freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
    freqs_for_each_token = torch.outer(torch.arange(n_tokens), freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
    kv_scale = n_heads // n_kv_heads
    for layer in tqdm(range(n_layers), "layers"):
        qkv_attention_store = []
        layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
        q_layer = model[f"layers.{layer}.attention.wq.weight"]
        q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
        k_layer = model[f"layers.{layer}.attention.wk.weight"]
        k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
        v_layer = model[f"layers.{layer}.attention.wv.weight"]
        v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
        for head in range(n_heads):
            q_layer_head = q_layer[head]
            k_layer_head = k_layer[head // kv_scale]
            v_layer_head = v_layer[head // kv_scale]
            q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
            k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
            v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
            q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
            q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
            q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
            q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
            k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
            k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
            k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
            k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
            qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
            mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            qk_per_token_after_masking = qk_per_token + mask
            qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
            qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
            qkv_attention_store.append(qkv_attention)
        stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
        w_layer = model[f"layers.{layer}.attention.wo.weight"]
        embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
        embedding_after_edit = final_embedding + embedding_delta
        embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
        w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
        w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
        w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
        output_after_feedforward = torch.matmul(
            torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(
                embedding_after_edit_normalized, w3.T), w2.T)
        final_embedding = embedding_after_edit + output_after_feedforward
    final_embedding = rms_norm(final_embedding, model["norm.weight"])
    logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
    next_token = torch.argmax(logits, dim=-1)
    print(f"[Info] next_token: {next_token}")
    word = tokenizer.decode([next_token.item()])
    return word
That’s all!