llama3 implemented from scratch 笔记

news2024/10/10 0:52:02

github地址:https://github.com/naklecha/llama3-from-scratch?tab=readme-ov-file

分词器的实现

from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json
import matplotlib.pyplot as plt

tokenizer_path = "Meta-Llama-3-8B/tokenizer.model"
special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)},
)

tokenizer.decode(tokenizer.encode("hello world!"))

读取模型文件

model = torch.load("Meta-Llama-3-8B/consolidated.00.pth")
print(json.dumps(list(model.keys())[:20], indent=4))
[
    "tok_embeddings.weight",
    "layers.0.attention.wq.weight",
    "layers.0.attention.wk.weight",
    "layers.0.attention.wv.weight",
    "layers.0.attention.wo.weight",
    "layers.0.feed_forward.w1.weight",
    "layers.0.feed_forward.w3.weight",
    "layers.0.feed_forward.w2.weight",
    "layers.0.attention_norm.weight",
    "layers.0.ffn_norm.weight",
    "layers.1.attention.wq.weight",
    "layers.1.attention.wk.weight",
    "layers.1.attention.wv.weight",
    "layers.1.attention.wo.weight",
    "layers.1.feed_forward.w1.weight",
    "layers.1.feed_forward.w3.weight",
    "layers.1.feed_forward.w2.weight",
    "layers.1.attention_norm.weight",
    "layers.1.ffn_norm.weight",
    "layers.2.attention.wq.weight"
]
with open("Meta-Llama-3-8B/params.json", "r") as f:
    config = json.load(f)
config
{'dim': 4096,
 'n_layers': 32,
 'n_heads': 32,
 'n_kv_heads': 8,
 'vocab_size': 128256,
 'multiple_of': 1024,
 'ffn_dim_multiplier': 1.3,
 'norm_eps': 1e-05,
 'rope_theta': 500000.0}
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = torch.tensor(config["rope_theta"])

将文本转换为 tokens(这里没有手动实现分词器)

这里用 tiktoken 作为 tokenizer

prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = [128000] + tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']

将令牌嵌入(这里用的内置的神经网络模块,也没有手动实现)

总之,[17, 1]的 tokens 现在变成了 [17, 4096]的嵌入向量
在这里插入图片描述

embedding_layer = torch.nn.Embedding(vocab_size, dim)
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape
torch.Size([17, 4096])

使用均方根 RMS 对嵌入进行归一化

这里并不会进行形状的改变,值只是进行了归一化,为了防止除以零的情况,会设置一个 norm_eps
在这里插入图片描述

# def rms_norm(tensor, norm_weights):
#     rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5
#     return tensor * (norm_weights / rms)
def rms_norm(tensor, norm_weights):
    return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights

tensor.pow(2):

这一步将输入张量 tensor 中的每个元素进行平方操作。假设 tensor 的形状为 (batch_size, seq_len, hidden_dim),那么 tensor.pow(2) 的结果形状仍然是 (batch_size, seq_len, hidden_dim),但每个元素都被平方了。

tensor.pow(2).mean(-1, keepdim=True):

这一步计算张量在最后一个维度(即 hidden_dim 维度)上的均值。mean(-1, keepdim=True) 表示在最后一个维度上求均值,并且保持该维度的形状(即 keepdim=True)。结果的形状为 (batch_size, seq_len, 1)

tensor.pow(2).mean(-1, keepdim=True) + norm_eps:

这一步在均值的基础上加上一个小的常数 norm_eps,以避免除零错误。norm_eps 通常是一个非常小的正数,例如 1e-8。

torch.rsqrt(...):

torch.rsqrt 是平方根的倒数(即 1 / sqrt(x))。这一步计算的是 1 / sqrt(mean + norm_eps),即 RMS 值的倒数。

tensor * torch.rsqrt(...):

这一步将输入张量 tensor 乘以 RMS 值的倒数,从而实现归一化。归一化后的张量在最后一个维度上的 RMS 值为1。

* norm_weights:

最后,将归一化后的张量乘以 norm_weightsnorm_weights 是一个可学习的权重张量,形状为 (hidden_dim,),用于对归一化后的特征进行缩放。

通常,归一化操作会将特征缩放到一个固定的范围,然而,不同的特征可能需要不同的缩放因子来更好地适应模型的需求。通过引入可学习的权重,模型可以根据数据的特点和任务的需求,自动调整每个特征的缩放因子。

构建 transformer 的第一层

在这里插入图片描述

归一化

# 这里是attention之前的normalization
token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"])
token_embeddings.shape
torch.Size([17, 4096])

手动实现注意力

在这里插入图片描述
从模型中加载查询(query)、键(key)、值(value)和输出(output)向量时,我们注意到它们的形状分别是 [4096x4096]、[1024x4096]、[1024x4096]、[4096x4096]。

假设我们有以下形状的矩阵:

query_matrix: [4096x4096]

key_matrix: [1024x4096]

value_matrix: [1024x4096]

output_matrix: [4096x4096]

我们可以通过以下方式解开它们:

解开查询

q_layer0 = model["layers.0.attention.wq.weight"]
head_dim = q_layer0.shape[0] // n_heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)
q_layer0.shape
torch.Size([32, 128, 4096])

32 是 llama3 的注意力头的数量,128 是查询向量的大小,4096 是令牌嵌入的大小。

实现第一层的第一个头

查询权重矩阵的大小是 [128, 4096]

q_layer0_head0 = q_layer0[0]
q_layer0_head0.shape
torch.Size([128, 4096])

现在将查询权重矩阵和令牌嵌入相乘,以接收对令牌的查询

在这里插入图片描述
最终的形状是 [17, 128],这是因为有 17 个令牌,和 128 长度的查询。

q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)
q_per_token.shape
torch.Size([17, 128])

位置编码

当前阶段是,我们为提示(prompt)中的每个令牌都有一个查询向量,但是单独的查询向量并不知道它在提示中的位置,在例子中,使用了三次 “the” 标记的查询向量([1, 128])。使用 RoPE 旋转位置编码来执行这些旋转。
在这里插入图片描述

q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
q_per_token_split_into_pairs.shape
torch.Size([17, 64, 2])

这一步将查询向量分成对,并对每对应用旋转角度偏移。
在这里插入图片描述

用复数的点积来旋转向量

# 生成一个从0到1的等间隔序列,分成64个部分。这个序列表示每个部分的归一化位置
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64
# 计算频率freqs,这里的rope_theta是llama3给的500000.0
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)

# 生成一个 [17, 64] 的矩阵,其中每一行对应一个标记的频率。torch.outer函数计算两个向量的外积,生成一个矩阵
freqs_for_each_token = torch.outer(torch.arange(17), freqs)
# 将频率转换为复数形式,其中实部为1,虚部为频率。torch.polar函数生成复数形式的向量,其中模为1,相位为频率
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape

等间隔序列 zero_to_one_split_into_64_parts:

tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])

频率:

tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
        2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
        8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
        2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
        2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
        6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
        1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
        1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
        4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])

z = r ⋅ e i θ z=r \cdot e^{i \theta} z=reiθ表示一个旋转角度为 θ \theta θ的复数
旋转矩阵中的每一个元素freqs_cis[i,j]可以表示为 e i ⋅ f r e q s _ f o r _ e a c h _ t o k e n [ i , j ] e ^{i⋅{freqs\_for\_each\_token[i,j]}} eifreqs_for_each_token[i,j],其中 i i i是标记的索引, j j j是频率的索引。
这就是所有 token 对应的旋转矩阵,下面进行相乘得到旋转后的所有 token
的查询

现在我们有了每个 token 查询的复数(角度变化向量)

我们可以将我们的查询转换为复数然后进行点积以根据位置旋转查询。

q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
q_per_token_as_complex_numbers.shape
torch.Size([17, 64])
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
q_per_token_as_complex_numbers_rotated.shape
torch.Size([17, 64])

这样就是旋转后的查询。

得到旋转后的向量之后

通过将查询再次从复数看成实数(从[a+bj]的存储形式变成[a, b]),可以得到

q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)
q_per_token_split_into_pairs_rotated.shape
torch.Size([17, 64, 2])

旋转后的对现在已经合并,我们现在有了一个新的查询向量,其形状是[17, 128]

q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
q_per_token_rotated.shape
torch.Size([17, 128])

键,几乎和查询的处理是一样的

键也生成维度为 128 的键向量。键的权重数量只有查询(queries)的 1/4,这是因为键的权重在 4 个注意力头之间共享,以减少所需的计算量。键也像查询一样旋转以添加位置信息,因为同样的原因。

k_layer0 = model["layers.0.attention.wk.weight"]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)
k_layer0.shape
torch.Size([8, 128, 4096])
k_layer0_head0 = k_layer0[0]
k_layer0_head0.shape
torch.Size([128, 4096])
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)
k_per_token.shape
torch.Size([17, 128])
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
k_per_token_split_into_pairs.shape
torch.Size([17, 64, 2])
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
k_per_token_as_complex_numbers.shape
torch.Size([17, 64])
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
k_per_token_split_into_pairs_rotated.shape
torch.Size([17, 64, 2])
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
k_per_token_rotated.shape
torch.Size([17, 128])

在这个阶段,现在有每个令牌的查询和键的旋转值

在这里插入图片描述

下一步,把查询和键相乘

这样做会给我们一个分数,将每个标记与其他标记进行映射。这个分数描述了每个标记的查询与每个标记的键之间的关系。这就是自注意力机制(Self-Attention)😃

注意力分数矩阵(qk_per_token)的形状为 [17x17],其中 17 是提示中的标记数量。

详细解释
在自注意力机制中,我们通过计算查询(queries)和键(keys)之间的点积来生成注意力分数。注意力分数矩阵描述了每个标记与其他标记之间的关系。
在这里插入图片描述

qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5
qk_per_token.shape
torch.Size([17, 17])

现在我们必须 mask 查询键分数

在训练过程中,Llama3 的未来标记的 qk 分数被掩码。
为什么?因为在训练过程中,我们只使用过去的标记来预测未来的标记。
因此,在推理过程中,我们将未来的标记设置为零。
在这里插入图片描述

# 显示注意力分数矩阵的热力图
def display_qk_heatmap(qk_per_token):
    _, ax = plt.subplots()
    # 生成热力图,使用 `viridis` 颜色映射
    im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis')
    ax.set_xticks(range(len(prompt_split_as_tokens)))
    ax.set_yticks(range(len(prompt_split_as_tokens)))
    ax.set_xticklabels(prompt_split_as_tokens)
    ax.set_yticklabels(prompt_split_as_tokens)
    ax.figure.colorbar(im, ax=ax)
    
display_qk_heatmap(qk_per_token)

在这里插入图片描述

# 生成一个掩码矩阵, 初始都为-inf
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
# 将掩码矩阵转换为上三角矩阵,diagonal=1保留对角线下一个元素及其以上的元素,其余为0
mask = torch.triu(mask, diagonal=1)
mask
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
qk_per_token_after_masking = qk_per_token + mask
display_qk_heatmap(qk_per_token_after_masking)

在这里插入图片描述

qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
display_qk_heatmap(qk_per_token_after_masking_after_softmax)

Values

在这里插入图片描述
值权重在每 4 个注意力头(所以总共 8 个注意力头)之间共享,以节省计算量。这意味着每个注意力头使用相同的值权重矩阵。

v_layer0 = model["layers.0.attention.wv.weight"]
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)
v_layer0.shape
torch.Size([8, 128, 4096])

第一层, 第一个权重矩阵为:

v_layer0_head0 = v_layer0[0]
v_layer0_head0.shape
torch.Size([128, 4096])

值向量

在这里插入图片描述
我们现在使用值权重来获取每个标记的注意力值,其大小为 [17x128],其中 17 是提示中的标记数量,128 是每个标记的值向量的维度。

v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)
v_per_token.shape
torch.Size([17, 128])

注意力

在自注意力机制中,我们将注意力分数矩阵与值矩阵相乘,生成最终的注意力输出。注意力输出的形状为 [17x128]

qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
qkv_attention.shape
torch.Size([17, 128])

多头注意力

在这里插入图片描述
我们现在有了第一层和第一个注意力头的注意力值。现在,我将运行一个循环,对第一层的每个注意力头执行与上述单元格相同的数学运算。

qkv_attention_store = []

for head in range(n_heads):
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//4] # key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//4] # value weights are shared across 4 heads
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    mask = torch.triu(mask, diagonal=1)
    qk_per_token_after_masking = qk_per_token + mask
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    qkv_attention_store.append(qkv_attention)

len(qkv_attention_store)
32

在这里插入图片描述
我们现在有了第一层所有 32 个注意力头的 qkv_attention 矩阵。接下来,把所有注意力分数合并成一个大小为 [17x4096] 的大矩阵。

stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
stacked_qkv_attention.shape
torch.Size([17, 4096])

权矩阵,最后一个步骤

在这里插入图片描述

w_layer0 = model["layers.0.attention.wo.weight"]
w_layer0.shape

在完成第 0 层注意力机制的最后一步是,将注意力输出与权重矩阵相乘。具体来说,我们将最终的注意力输出矩阵与权重矩阵相乘,生成最终的注意力输出。

torch.Size([4096, 4096])

这是一个简单的线性层,所以我们只需要进行矩阵乘法(matmul)。

embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T)
embedding_delta.shape
torch.Size([17, 4096])

在这里插入图片描述
我们现在有了注意力机制之后的嵌入值变化,这应该加到原始的标记嵌入值上。

embedding_after_edit = token_embeddings_unnormalized + embedding_delta
embedding_after_edit.shape
torch.Size([17, 4096])

我们将其归一化然后运行一个前馈神经网络通过嵌入 δ \delta δ

在这里插入图片描述

embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])
embedding_after_edit_normalized.shape
torch.Size([17, 4096])

在加载前馈网络(Feed-Forward Network, FFN)的权重并实现前馈网络时,我们需要执行以下步骤:

在这里插入图片描述
在 Llama3 中,他们使用了 SwiGLU 前馈网络。这种网络架构在模型需要时能够很好地添加非线性。如今,在大型语言模型(LLMs)中使用这种前馈网络架构是非常标准的。

w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]
output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
output_after_feedforward.shape
torch.Size([17, 4096])

在 Llama3 中,前馈网络使用了 SwiGLU 架构。具体来说,前馈网络由三个线性层组成,其中第一个线性层的输出通过 Swish 激活函数,然后与第三个线性层的输出相乘,最后通过第二个线性层生成新的嵌入值。

Swish 激活函数:
Swish 激活函数是一种平滑的非线性函数,定义为:
S w i s h ( x ) = x ⋅ σ ( β x ) Swish(x)=x\cdot \sigma(\beta x) Swish(x)=xσ(βx)
其中, σ \sigma σ 是 sigmoid 函数, β \beta β 是一个可学习的参数(通常设为 1). Swish 激活函数在许多情况下表现优于 ReLU 和其他常见的激活函数.

GLU (Gated Linear Unit):
GLU 是一种门控机制,用于控制信息的流动。GLU 的定义为:
G L U ( a , b ) = a ⋅ σ ( b ) GLU(a, b)=a\cdot \sigma(b) GLU(a,b)=aσ(b)
其中, a a a b b b 是两个线性变换的输出, σ \sigma σ是 sigmoid 函数. GLU 通过门控信号 σ ( b ) \sigma(b) σ(b)来控制 a a a 的信息流动.

SwiGLU 结构:
SwiGLU 结合了 Swish 激活函数和 GLU 结构,定义为:
S w i G L U ( x , W 1 , W 2 , W 3 ) = S w i s h ( x W 1 ) ⋅ σ ( x W 3 ) W 2 SwiGLU(x, W_1, W_2, W_3)=Swish(xW_1)\cdot \sigma(xW_3)W_2 SwiGLU(x,W1,W2,W3)=Swish(xW1)σ(xW3)W2

我们终于在第一层之后为每个标记生成了新的编辑后的嵌入值。

在自注意力机制和前馈网络之后,我们为每个标记生成了新的嵌入值。这些新的嵌入值包含了更多的上下文信息,从而提高了模型的性能和理解能力。

在完成之前,我们还有 31 层要处理(只需要一个循环)。

你可以想象这个编辑后的嵌入值包含了第一层中所有查询的信息。现在,每一层都会对提出的问题进行越来越复杂的编码,直到我们有一个嵌入值,它包含了我们需要了解的关于下一个标记的所有信息。

layer_0_embedding = embedding_after_edit+output_after_feedforward
layer_0_embedding.shape
torch.Size([17, 4096])

总和

final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]
        v_layer_head = v_layer[head//4]
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(128)**0.5
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit+output_after_feedforward

现在我们有了最终的嵌入, 这是模型对下一个令牌的最好猜测

嵌入的形状和常规令牌的形状相同 [ 17 , 4096 ] [17, 4096] [17,4096].
在这里插入图片描述

final_embedding = rms_norm(final_embedding, model["norm.weight"])
final_embedding.shape
torch.Size([17, 4096])

最后, 将嵌入解码成令牌值

在这里插入图片描述
我们将使用输出解码器将最终嵌入解码成令牌

model["output.weight"].shape
torch.Size([128256, 4096])

我们使用最后一个标记的嵌入值来预测下一个值。

根据《银河系漫游指南》这本书,42 是“生命、宇宙以及一切的终极问题的答案”。所以大多数 LLMs 在这里都会回答 42.

# 通过线性层生成 logits 向量, 训练过程中隐式调用了 softmax
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
logits.shape
torch.Size([128256])

预测的 token number 是 2983, 解码后是 42

next_token = torch.argmax(logits, dim=-1)
next_token
tensor(2983)
tokenizer.decode([next_token.item()])
'42'

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2200568.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据新视界 --大数据大厂之 GraphQL 在大数据查询中的创新应用:优化数据获取效率

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

12.2 Linux_进程间通信_共享内存

概述 什么是共享内存&#xff1a; 共享内存又叫内存映射&#xff0c;可以通过mmap()映射普通文件。 实际上就是将磁盘中的一个文件映射到内存的一个缓冲区中去&#xff0c;这样进程就可以直接将这块空间当作普通内存来访问&#xff0c;不需要再使用I/O中的read/write去访问这…

霍普菲尔德(Hopfield)神经网络求解旅行商问题TSP,提供完整MATLAB代码,复制粘贴即可运行

Hopfield神经网络是以美国物理学家约翰霍普菲尔德&#xff08;John Hopfield&#xff09;的名字命名的。他在1982年提出了这种类型的神经网络模型&#xff0c;因此通常被称为Hopfield网络。旅行商问题&#xff08;Traveling Salesman Problem&#xff0c;TSP&#xff09;是一个…

IEDA创建文件模板

1、点击设置-编辑器-文件与代码模板 2、输入对应的名称、扩展名、文件名 3、复制模板代码-点击应用、确定即可 4、新建配置项目&#xff0c;右键点击新建选择SpringMVC即可&#xff08;刚刚模板中的名称&#xff09;

D32【python 接口自动化学习】- python基础之输入输出与文件操作

day32 文件编码 学习日期&#xff1a;20241009 学习目标&#xff1a;输入输出与文件操作&#xfe63;-44 文件编码&#xff1a; 如何解决不同操作系统的文件乱码问题&#xff1f; 学习笔记&#xff1a; 为什么产生乱码 常见操作系统的文件编码 以不同的编码打开文件 # 以gb…

Linux学习网络编程学习(TCP和UDP)

文章目录 网络编程主要函数介绍1、socket函数2、bind函数转换端口和IP形式的函数 3、listen函数4、accept函数网络模式&#xff08;TCP&UDP&#xff09;1、面向连接的TCP流模式2、UDP用户数据包模式 编写一个简单服务端编程5、connect函数编写一个简单客户端编程 超级客户端…

如何实现不同VLAN间互通?

问题描述 客户要求不同VLAN的PC机互通&#xff0c;如下图拓扑所示。 此外&#xff0c;仅允许在设备 LSW3 上进行配置修改。 分析 由于所有的PC都在同一个网段&#xff0c;当任何一个设备想要和另一个设备通信时&#xff0c;它会首先根据数据交互的流程广播一个ARP请求报文来获…

1. Keepalived概念和作用

1.keepalived概念 (1)解决单点故障(组件免费) (2)可以实现高可用HA机制 (3)基于VRR协议(虚拟路由沉余协议) 2.keepalived双机主备原理

枚举+二分,CF 325B - Stadium and Games

目录 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 二、解题报告 1、思路分析 2、复杂度 3、代码详解 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 325B - Stadium and Games 二、解题报告 1、思路分析 考虑 一个可能的初…

QD1-P8 HTML格式化标签

本节学习&#xff1a;HTML 格式化标签。 本节视频 www.bilibili.com/video/BV1n64y1U7oj?p8 ‍ 一、font 标签 用途&#xff1a;定义文本的字体大小、颜色和 face&#xff08;字体类型&#xff09;。 示例 <!DOCTYPE html> <html><head><meta cha…

10.9QT对话框以及QT的事件机制处理

MouseMoveEvent(鼠标移动事件) widget.cpp #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);// 设置窗口为无边框&#xff0c;去掉标题栏等装饰this->setWi…

如何使用ArcGIS Pro设置一个图层不同标注

在有些时候&#xff0c;需要对某个要素进行突出显示&#xff08;比如省会城市&#xff09;&#xff0c;那就需要标注不同的样式&#xff0c;这里为大家介绍一下一个图层不同标注的方法。 分类标注 现在有一张广东省的行政区划图&#xff0c;想要突出标注广州市&#xff0c;虽…

超详解C++类与对象(中)

目录 1. 构造函数 1.1. 定义 1.2. 注意 2.析构函数 2.1定义 2.2注意 3.拷贝构造函数 3..1. 定义 3.2. 注意 4.运算符重载 4.1. 定义 5. 赋值运算符重载 5.1. 定义 5.2. 注意 ​​​​​​​ &#x1f493; 博客主页&#xff1a;C-SDN花园GGbond ⏩ 文章专…

大模型学习----什么是RAG

大模型快速定制的 RAG&#xff08;Retrieval-Augmented Generation&#xff09;方法 一、什么是 RAG RAG&#xff08;Retrieval-Augmented Generation&#xff09;即检索增强生成&#xff0c;它是一种结合了检索和语言生成的技术&#xff0c;旨在利用外部知识源来增强大型语言…

YOLO11改进|注意力机制篇|引入全局上下文注意力机制GCA

目录 一、【】注意力机制1.1【GCA】注意力介绍1.2【GCA】核心代码 二、添加【GCA】注意力机制2.1STEP12.2STEP22.3STEP32.4STEP4 三、yaml文件与运行3.1yaml文件3.2运行成功截图 一、【】注意力机制 1.1【GCA】注意力介绍 下图是【GCA】的结构图&#xff0c;让我们简单分析一下…

SQL优化 where谓词条件is null优化

1.创建测试表及谓词条件中包含is null模拟语句 create table t641 as select * from dba_objects; set autot trace select SUBOBJECT_NAME,OBJECT_NAME from t641 where OBJECT_NAMEWRI$_OPTSTAT_SYNOPSIS$ and SUBOBJECT_NAME is null; 2.全表扫描逻辑读1237 3.创建等值谓词条…

PE结构之导出表

导出表结构中各种值的意义 ​​​​​​ 根据函数地址表遍历函数名称RVA表,和上面的图是逆过程 //函数地址表 和当前内存中的位置DWORD AddressOfFunctionsFOA RVAToFOA(LPdosHeader, LPexprotDir->AddressOfFunctions);PDWORD LPFunctionsAddressInMemary (PDWORD)((cha…

flask发送邮件

开通邮件IMAP/SMTP服务 以网易邮箱为例 点击开启发送验证后会收到一个密钥&#xff0c;记得保存好 编写代码 安装flask-mail pip install flask-mail在config.py文件中配置邮件信息 MAIL_SERVER&#xff1a;邮件服务器 MAIL_USE_SSL&#xff1a;使用SSL MAIL_PORT&#…

【计算机网络】网络相关技术介绍

文章目录 NAT概述NAT的基本概念NAT的工作原理1. **基本NAT&#xff08;静态NAT&#xff09;**2. **动态NAT**3. **NAPT&#xff08;网络地址端口转换&#xff0c;也称为PAT&#xff09;** 底层实现原理1. **数据包处理**2. **转换表**3. **超时机制** NAT的优点NAT的缺点总结 P…

Linux:多线程中的生产消费模型

多线程 生产消费模型三种关系两个角色一个交易场所交易场所的实现&#xff08;阻塞队列&#xff09;pthread_cond_wait 接口判断阻塞队列的空或满时&#xff0c;需要使用while测试一&#xff1a;单消费单生产案例测试二&#xff1a;多生产多消费案例 生产消费模型 消费者与生产…