LLM - Implementing the LLaMA3 Network Architecture and Inference Pipeline from Scratch (Tutorial)


Welcome to follow my CSDN: https://spike.blog.csdn.net/
This article: https://spike.blog.csdn.net/article/details/141462669

(figure: llama3)

LLaMA3 is Meta's latest large language model. Its network design includes several upgrades that significantly improve performance and efficiency. The key changes are:

  1. The vocabulary grows to 128K tokens (128,256).
  2. RMS Normalization (Root Mean Square Normalization).
  3. Rotary Position Embedding (RoPE).
  4. Grouped Query Attention (GQA): 32 query heads and 8 KV heads, i.e. 8 groups in which 4 query heads share one KV head.
  5. A SwiGLU Feedforward Network.

The detailed inference pipeline is as follows.

Install the required Python libraries:

pip install tiktoken
pip install matplotlib
pip install blobfile
  • tiktoken: a fast BPE (Byte Pair Encoding) tokenizer, designed for use with OpenAI-style models.
  • blobfile: a Python library for working with cloud storage (such as Amazon S3 and Google Cloud Storage).
  • matplotlib: a plotting library, used for the visualizations below.

1. Loading the Tokenizer

The tokenizer splits text into tokens; the most common scheme is BPE (Byte Pair Encoding).

Byte Pair Encoding (BPE), also known as digram coding, is originally a data-compression technique that encodes a text string as a sequence of subword units for downstream models. OpenAI used BPE for tokenization when pre-training the GPT models, and it is now widely used across Transformer models.
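As a toy illustration (not the actual tiktoken implementation), one BPE training step can be sketched as merging the most frequent adjacent pair of symbols:

from collections import Counter

def bpe_merge_once(tokens):
    # count adjacent pairs and merge the most frequent one
    pairs = Counter(zip(tokens, tokens[1:]))
    if not pairs:
        return tokens
    (a, b), _ = pairs.most_common(1)[0]
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == (a, b):
            merged.append(a + b)  # merge the pair into one symbol
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

print(bpe_merge_once(list("aaabdaaabac")))
# ['aa', 'a', 'b', 'd', 'aa', 'a', 'b', 'a', 'c']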

The tokenizer source code:

from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import os
import matplotlib.pyplot as plt
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  #(代表仅使用第0,1号GPU)

# Meta-Llama-3-8B is downloaded from HuggingFace
model_path = "Meta-Llama-3-8B/original/consolidated.00.pth"
tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model"

special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]

mergeable_ranks = load_tiktoken_bpe(tokenizer_path)

tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)

tokenizer.decode(tokenizer.encode("hello world!"))

The parameters of tiktoken.Encoding:

  • name=Path(tokenizer_path).name: the name of the encoding object.
  • pat_str: a regular expression defining the text-splitting rules; this is the default pattern for English text.
  • mergeable_ranks=load_tiktoken_bpe(tokenizer_path): loads the BPE merge rules, which are part of the model.
  • special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}: defines the mapping for special tokens. Each special token is mapped to an integer ID, starting from the size of the loaded BPE vocabulary and increasing sequentially.
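For example (a quick sanity check, assuming the paths above): Llama3 has 128,000 BPE ranks, so the first special token <|begin_of_text|> receives id 128000:

print(len(mergeable_ranks))  # 128000
print(tokenizer.encode("<|begin_of_text|>", allowed_special="all"))  # [128000]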

Testing with hello world!: encoding followed by decoding returns the same string. Output:

hello world!

2. Loading the Model

Load the Llama3 model weights, using the pth checkpoint directly:

model = torch.load(model_path)
print(json.dumps(list(model.keys())[:20], indent=4))

Path: Meta-Llama-3-8B/original/consolidated.00.pth

Load the Llama3 configuration:

with open("Meta-Llama-3-8B/original/params.json", "r") as f:
    config = json.load(f)

The parameters:

{
	'dim': 4096,
	'n_layers': 32,
	'n_heads': 32,
	'n_kv_heads': 8,
	'vocab_size': 128256,
	'multiple_of': 1024,
	'ffn_dim_multiplier': 1.3,
	'norm_eps': 1e-05,
	'rope_theta': 500000.0
}
  • dim: hidden-state dimension, 4096. The hidden state is the model's internal representation of the input sequence.
  • n_layers: number of layers, 32. Depth determines the model's capacity and expressiveness.
  • n_heads: number of self-attention heads, 32. Self-attention captures different relations within the input sequence.
  • n_kv_heads: number of key/value heads, 8, used by grouped-query attention.
  • vocab_size: vocabulary size, 128256, covering all tokens the model can handle.
  • multiple_of: the FFN hidden dimension is rounded up to a multiple of this value, 1024, which helps computational efficiency.
  • ffn_dim_multiplier: scaling factor for the feed-forward network (FFN) hidden dimension relative to the base size, here 1.3; the FFN applies a non-linear transformation to the hidden states (see the sketch after this list for how the hidden size 14336 is derived).
  • norm_eps: epsilon of the normalization layers, 1e-05, used to stabilize normalization of intermediate representations.
  • rope_theta: base of RoPE (Rotary Position Embedding), 500000.0; RoPE encodes relative position information in the sequence.
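A short sketch of how these values determine the FFN hidden size, following the computation used in Meta's reference Llama code (the intermediate numbers are derived, not read from the config):

dim, multiple_of, ffn_dim_multiplier = 4096, 1024, 1.3

hidden = 4 * dim                                     # 16384
hidden = int(2 * hidden / 3)                         # 10922, the SwiGLU 2/3 reduction
hidden = int(ffn_dim_multiplier * hidden)            # 14198
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)  # round up to a multiple of 1024
print(hidden)  # 14336, the FFN dimension seen in section 10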

Cache the parameters:

dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])

3. Text - Token - Embedding

Convert text to tokens (prepending the <|begin_of_text|> id 128000):

prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)

Output:

[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']

Convert tokens to embeddings:

embedding_layer = torch.nn.Embedding(vocab_size, dim)
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape

Output:

torch.Size([17, 4096])
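An equivalent view (a minimal sanity check, not required by the pipeline): the embedding lookup is just row indexing into the weight matrix by token id:

manual = model["tok_embeddings.weight"][tokens].to(torch.bfloat16)
print(torch.equal(manual, token_embeddings_unnormalized))  # True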

4. Implementing RMS Normalization

The RMS Normalization function:

# equivalent explicit form, kept for reference:
# def rms_norm(tensor, norm_weights):
#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5
#     return tensor * (norm_weights / rms)
def rms_norm(tensor, norm_weights):
    # normalize by the root mean square over the last dim, then scale by the learned weights
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

RMS Normalization (Root Mean Square Normalization) normalizes the inputs to a layer. It makes the model invariant to re-scaling of its inputs and gives it an implicit learning-rate adaptation ability. Compared with Layer Normalization, RMS Normalization is more efficient because its computational cost is lower. The formula is:
$$\bar{x}_i = \frac{x_{i}}{\mathrm{RMS}(x)+\epsilon}\, g_{i}, \quad \text{where}\quad \mathrm{RMS}(x)=\sqrt{\frac{1}{n}\sum_{i=1}^{n}x_{i}^{2}}$$
Comparing Layer Normalization and RMS Normalization: Layer Normalization has both a re-scaling part and a re-centering part, while RMS Normalization drops the re-centering and keeps only the re-scaling. Research suggests that the key to LayerNorm's success is the re-scaling invariance rather than the re-centering invariance. By removing the mean computation and the shift term, RMS Normalization trains faster than Layer Normalization with essentially equal, sometimes better, quality. For reference, LayerNorm is:
$$\mu=\frac{1}{n}\sum_{i=1}^{n}x_{i}, \qquad \sigma=\sqrt{\frac{1}{n}\sum_{i=1}^{n}(x_{i}-\mu)^{2}}, \qquad y=\frac{x-\mu}{\sigma+\epsilon}\,\gamma+\beta$$
Run RMS Normalization:

token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
print(token_embeddings.shape)

Output:

torch.Size([17, 4096])
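A quick sanity check (a small sketch, not in the original walkthrough): with unit weights, each normalized row should have a root mean square of roughly 1:

normed = rms_norm(token_embeddings_unnormalized, torch.ones(dim))
print(normed.float().pow(2).mean(-1).sqrt())  # every value close to 1.0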

5. Computing the Self-Attention Query

The QKV and output (O) weights in Llama3:

print(
    model["layers.0.attention.wq.weight"].shape,
    model["layers.0.attention.wk.weight"].shape,
    model["layers.0.attention.wv.weight"].shape,
    model["layers.0.attention.wo.weight"].shape
)

Output:

torch.Size([4096, 4096])  # wq
torch.Size([1024, 4096])  # wk
torch.Size([1024, 4096])  # wv
torch.Size([4096, 4096])  # wo

Multi-Head Attention: the number of heads is n_heads = 32, so reshape the Query weight per head:

q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
print(q_layer0.shape)

The Query weight shape:

torch.Size([32, 128, 4096])

Each Query head has shape [128, 4096]:

q_layer0_head0 = q_layer0[0]
print(q_layer0_head0.shape)  # torch.Size([128, 4096])

The token embeddings have shape torch.Size([17, 4096]); multiplying by the head weight gives [17, 128]:

q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
print(q_per_token.shape)	# torch.Size([17, 128])

6. Implementing Rotary Positional Encoding (RoPE)

For RoPE, first reshape the Query into pairs of dimensions:

q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
print(q_per_token_split_into_pairs.shape)  # torch.Size([17, 64, 2])

Build the rotary position encoding:

zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
print(zero_to_one_split_into_64_parts)

Output:

tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])

The frequencies, with rope_theta = 500000.0 taken from the model config:

freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
print(freqs)

Output:

tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
        2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
        8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
        2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
        2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
        6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
        1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
        1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
        4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])

Convert to polar (complex) form:

freqs_for_each_token = torch.outer(torch.arange(17), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
print(freqs_cis.shape)  # torch.Size([17, 64])

# viewing the row of freqs_cis at index 3
value = freqs_cis[3]
plt.figure()
for i, element in enumerate(value[:17]):
    plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}")
    plt.annotate(f"{i}", xy=(element.real, element.imag), color='red')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.title('Plot of one row of freqs_cis')
plt.show()

The polar plot:

(figure: plot of one row of freqs_cis in the complex plane)

Convert the Query pairs to complex form, i.e. torch.Size([17, 64, 2]) becomes torch.Size([17, 64]), then multiply by freqs_cis:

q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
print(q_per_token_as_complex_numbers.shape) # torch.Size([17, 64])
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
print(q_per_token_as_complex_numbers_rotated.shape) # torch.Size([17, 64])

Then convert the complex numbers back to real pairs, which restores the trailing dimension: torch.Size([17, 64]) becomes torch.Size([17, 64, 2]):

q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
print(q_per_token_split_into_pairs_rotated.shape) # torch.Size([17, 64, 2])

Finally reshape back to the Query shape, torch.Size([17, 64, 2]) to torch.Size([17, 128]). The Query vectors are now rotated by the RoPE position encoding:

q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
print(q_per_token_rotated.shape) # torch.Size([17, 128])
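The defining property of RoPE, checked numerically (a small sketch reusing the complex tensors above; the two rows are treated as generic vectors): the score between a rotated query and key depends only on their relative position:

q_c = q_per_token_as_complex_numbers[0]  # one vector in complex form
k_c = q_per_token_as_complex_numbers[1]  # another row as a stand-in "key"

def rotate(vec, pos):
    return vec * torch.polar(torch.ones_like(freqs), pos * freqs)

def score(m, n):
    # real part of the complex inner product == dot product of the paired real vectors
    return (rotate(q_c, m) * rotate(k_c, n).conj()).real.sum()

print(score(torch.tensor(3.0), torch.tensor(5.0)))
print(score(torch.tensor(10.0), torch.tensor(12.0)))  # same value: only m - n matters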

7. Computing the Self-Attention Key and the QK Mask

Split the Key weight matrix into n_kv_heads: the shape goes from [1024, 4096] to [8, 128, 4096]:

k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
print(k_layer0.shape) # torch.Size([8, 128, 4096])

Apply the RoPE position encoding to the Key, in the same way as for the Query:

k_layer0_head0 = k_layer0[0]
print(k_layer0_head0.shape) # torch.Size([128, 4096])
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)  # [17, 4096] x [4096, 128] = [17, 128]
print(k_per_token.shape) # torch.Size([17, 128])
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
print(k_per_token_split_into_pairs.shape) # torch.Size([17, 64, 2])
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
print(k_per_token_as_complex_numbers.shape) # torch.Size([17, 64])
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
print(k_per_token_split_into_pairs_rotated.shape) # torch.Size([17, 64, 2])
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
print(k_per_token_rotated.shape) # torch.Size([17, 128])

Compute the scaled product of the Query and Key matrices:

qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
print(qk_per_token.shape)

Output:

torch.Size([17, 17])

The matrix form of self-attention:

$$A=\mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{k}}}\right)V$$

Visualize the attention matrix:

def display_qk_heatmap(qk_per_token):
    _, ax = plt.subplots()
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)
    
display_qk_heatmap(qk_per_token)

The heatmap:

(figure: QK attention heatmap)

The Decoder's causal mask matrix:

mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)  # negative infinity
mask = torch.triu(mask, diagonal=1)
print(mask)
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)

The masked heatmap:

(figure: QK heatmap after masking)

The probability matrix after Softmax:

qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)

The softmax heatmap:

(figure: QK heatmap after masking and softmax)
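A quick check (a minimal sketch) that the mask and softmax behave as intended: every row sums to 1, and the first token can only attend to itself:

print(qk_per_token_after_masking_after_softmax.float().sum(dim=1))  # all ~1.0
print(qk_per_token_after_masking_after_softmax[0])  # [1, 0, 0, ..., 0]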

8. Computing the Self-Attention Value and QKV

The Value weights have the same shape as the Key weights (n_kv_heads=8), but Value needs no position encoding. Finally compute the qkv_attention matrix:

v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
print(v_layer0.shape) # torch.Size([8, 128, 4096])
v_layer0_head0 = v_layer0[0]
print(v_layer0_head0.shape) # torch.Size([128, 4096])
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
print(v_per_token.shape) # torch.Size([17, 128])
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
print(qkv_attention.shape)  # torch.Size([17, 128])

9. Computing Multi-Head Grouped Query Attention

In Llama3 there are n_heads=32 query heads and n_kv_heads=8 KV heads, so every 4 query heads share one set of Key/Value heads: query heads 0-3 use the same Key and Value, heads 4-7 the next, and so on. The loop below produces 32 heads of shape [32, 17, 128]; since 32 x 128 = 4096, the concatenated output matches the embedding dimension:

qkv_attention_store = []
scale = n_heads // n_kv_heads  # 32 / 8 = 4

for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//scale] # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//scale] # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)

len(qkv_attention_store)  # 32 heads

The differences among GQA (Grouped Query Attention), MHA (Multi-Head Attention), and MQA (Multi-Query Attention) are as follows (a vectorized sketch follows the list):

  • MHA is the basic attention mechanism: the input is split into multiple heads that compute attention in parallel, each head learning a different aspect of the input, and the results are concatenated to capture different relations in the sequence.
  • MQA is an optimized variant: all heads share the same keys and values, which reduces parameters and computation and speeds up inference, possibly at some cost in quality.
  • GQA is a compromise between MHA and MQA: the query heads are partitioned into groups, and each group shares one key/value head rather than all heads sharing one. It reduces computation while preserving more diversity, balancing inference speed against model quality.
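The same KV sharing can be written without the per-head Python loop. A sketch (an illustration with random tensors, not the original code): expand the 8 KV heads to 32 with repeat_interleave, matching the head//4 indexing above, then do one batched matmul:

n_heads, n_kv_heads, seq, head_dim = 32, 8, 17, 128
q = torch.randn(n_heads, seq, head_dim)     # per-head queries (RoPE omitted here)
k = torch.randn(n_kv_heads, seq, head_dim)  # per-KV-head keys
v = torch.randn(n_kv_heads, seq, head_dim)

k = k.repeat_interleave(n_heads // n_kv_heads, dim=0)  # [32, 17, 128]
v = v.repeat_interleave(n_heads // n_kv_heads, dim=0)  # [32, 17, 128]

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5     # [32, 17, 17]
mask = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)
attn = torch.softmax(scores + mask, dim=-1) @ v        # [32, 17, 128]
print(attn.shape)  # torch.Size([32, 17, 128])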

Then concatenate the 32 heads, completing the transformation:

stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
print(stacked_qkv_attention.shape) # torch.Size([17, 4096])

The output projection weight has shape [4096, 4096]; apply this linear transformation to the attention output:

w_layer0 = model["layers.0.attention.wo.weight"]
print(w_layer0.shape) # torch.Size([4096, 4096])
embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
print(embedding_delta.shape) # torch.Size([17, 4096])

Then the residual connection: embedding_delta is added to the unnormalized embeddings (token_embeddings_unnormalized):

embedding_after_edit = token_embeddings_unnormalized + embedding_delta
print(embedding_after_edit.shape)  # torch.Size([17, 4096])

Apply RMS normalization to obtain embedding_after_edit_normalized:

embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
print(embedding_after_edit_normalized.shape)  # torch.Size([17, 4096])

10. Implementing the SwiGLU Feedforward Network

The SwiGLU feedforward network proceeds as follows:

w1 = model["layers.0.feed_forward.w1.weight"]  # torch.Size([14336, 4096])
w2 = model["layers.0.feed_forward.w2.weight"]  # torch.Size([4096, 14336]), a 3.5x expansion
w3 = model["layers.0.feed_forward.w3.weight"]  # torch.Size([14336, 4096])

tmp1 = torch.matmul(embedding_after_edit_normalized, w1.T)  # up-projection, [17, 14336]
tmp2 = torch.matmul(embedding_after_edit_normalized, w3.T)  # up-projection, [17, 14336]

output_after_feedforward = torch.matmul(torch.functional.F.silu(tmp1) * tmp2, w2.T)
print(output_after_feedforward.shape) # torch.Size([17, 4096])

# residual connection
layer_0_embedding = embedding_after_edit+output_after_feedforward
print(layer_0_embedding.shape)  # torch.Size([17, 4096])

The SwiGLU FFN computes:

$$\mathrm{SwiGLU\ FFN}(x)=\left(\mathrm{SiLU}(xW_{1}^{\top})\odot xW_{3}^{\top}\right)W_{2}^{\top}$$
Advantages of SwiGLU over the ReLU activation:

  1. Swish responds smoothly to small negative values, avoiding ReLU's hard zero output for all negative inputs.
  2. GLU's gating decides, based on the input, which information passes through and which is filtered out, helping the network learn useful representations and improving generalization.

F.silu is the Swish (SiLU) activation function:

$$\mathrm{silu}(x)=x\cdot\sigma(x), \qquad \sigma(x)=\frac{1}{1+e^{-x}}$$
A helper to plot the function:

import numpy as np

def draw(func):
    x = np.arange(-10, 10, 0.1)
    y = []
    x_torch = torch.from_numpy(x)
    for t in x_torch:
        y_1 = func(t)
        y_1 = y_1.numpy()
        y.append(y_1)
    plt.plot(x, y, label="silu")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.xlim(-7, 7)
    plt.ylim(-7, 7)
    plt.grid()
    plt.legend()
    plt.show()

Plot the SiLU curve (my_silu below is an equivalent hand-written version):

draw(torch.functional.F.silu)
def my_silu(x):
    t = 1 / (1 + torch.exp(-x))
    return x*t
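A quick check (a small sketch) that the hand-written version matches PyTorch's built-in:

x = torch.linspace(-7, 7, steps=100)
print(torch.allclose(my_silu(x), torch.nn.functional.silu(x)))  # True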

(figure: SiLU curve)

Reference: torch.nn.SiLU

11. Running the Full Multi-Layer Loop

Every layer runs the same modules; the number of layers is n_layers=32:

final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]  # 4 query heads share one key head
        v_layer_head = v_layer[head//4]  # 4 query heads share one value head
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward

Then run the final RMS Norm:

final_embedding = rms_norm(final_embedding, model["norm.weight"])
print(final_embedding.shape)  # torch.Size([17, 4096])

12. Decoding the Output Token

Decode from the features at the last position, final_embedding[-1]. The output weight has shape [128256, 4096]: the 4096-dim vector is projected onto logits over the 128256-token vocabulary:

print(model["output.weight"].shape) # torch.Size([128256, 4096])
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
print(logits.shape) # torch.Size([128256])
next_token = torch.argmax(logits, dim=-1)
print(next_token)  # tensor(2983)
output_v = tokenizer.decode([next_token.item()])
print(output_v)  # '42'

The complete input and output:

"the answer to the ultimate question of life, the universe, and everything is "
"42"

13. Complete Source Code

Reference: llama3-from-scratch. The complete source code is as follows:

import json
import os
from pathlib import Path

import tiktoken
import torch
from tiktoken.load import load_tiktoken_bpe
from tqdm import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = "4"  #(代表仅使用第0,1号GPU)


def infer_for_llama3(prompt, model_path, tokenizer_path, config_path):
    # model_path = "Meta-Llama-3-8B/original/consolidated.00.pth"
    # tokenizer_path = "Meta-Llama-3-8B/original/tokenizer.model"
    # config_path = "Meta-Llama-3-8B/original/params.json"
    special_tokens = [
                         "<|begin_of_text|>",
                         "<|end_of_text|>",
                         "<|reserved_special_token_0|>",
                         "<|reserved_special_token_1|>",
                         "<|reserved_special_token_2|>",
                         "<|reserved_special_token_3|>",
                         "<|start_header_id|>",
                         "<|end_header_id|>",
                         "<|reserved_special_token_4|>",
                         "<|eot_id|>",  # end of turn
                     ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
    mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
    tokenizer = tiktoken.Encoding(
        name=Path(tokenizer_path).name,
        pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
        mergeable_ranks=mergeable_ranks,
        special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
    )

    model = torch.load(model_path)
    with open(config_path, "r") as f:
        config = json.load(f)

    dim = config["dim"]
    n_layers = config["n_layers"]
    n_heads = config["n_heads"]
    n_kv_heads = config["n_kv_heads"]
    vocab_size = config["vocab_size"]
    multiple_of = config["multiple_of"]
    ffn_dim_multiplier = config["ffn_dim_multiplier"]
    norm_eps = config["norm_eps"]
    rope_theta = torch.tensor(config["rope_theta"])

    tokens = [128000] + tokenizer.encode(prompt)
    tokens = torch.tensor(tokens)
    prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]

    embedding_layer = torch.nn.Embedding(vocab_size, dim)
    embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
    token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)

    def rms_norm(tensor, norm_weights):
        return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

    final_embedding = token_embeddings_unnormalized

    n_tokens = len(tokens)
    zero_to_one_split_into_64_parts = torch.tensor(range(64)) / 64
    freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
    freqs_for_each_token = torch.outer(torch.arange(n_tokens), freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

    kv_scale = n_heads // n_kv_heads
    for layer in tqdm(range(n_layers), "layers"):
        qkv_attention_store = []
        layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
        q_layer = model[f"layers.{layer}.attention.wq.weight"]
        q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
        k_layer = model[f"layers.{layer}.attention.wk.weight"]
        k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
        v_layer = model[f"layers.{layer}.attention.wv.weight"]
        v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
        for head in range(n_heads):
            q_layer_head = q_layer[head]
            k_layer_head = k_layer[head // kv_scale]
            v_layer_head = v_layer[head // kv_scale]
            q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
            k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
            v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
            q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
            q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
            q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
            q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
            k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
            k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
            k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
            k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
            qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
            mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            qk_per_token_after_masking = qk_per_token + mask
            qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking,
                                                                                   dim=1).to(torch.bfloat16)
            qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
            qkv_attention_store.append(qkv_attention)

        stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
        w_layer = model[f"layers.{layer}.attention.wo.weight"]
        embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
        embedding_after_edit = final_embedding + embedding_delta
        embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
        w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
        w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
        w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
        output_after_feedforward = torch.matmul(
            torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(
                embedding_after_edit_normalized, w3.T), w2.T)
        final_embedding = embedding_after_edit + output_after_feedforward

    final_embedding = rms_norm(final_embedding, model["norm.weight"])
    logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
    next_token = torch.argmax(logits, dim=-1)
    print(f"[Info] next_token: {next_token}")
    word = tokenizer.decode([next_token.item()])
    return word

That’s all!
