简单的二元语言模型bigram实现

news2025/3/10 12:24:06
  • 内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili 
  • 项目:https://github.com/karpathy/ng-video-lecture

Bigram模型是基于当前Token预测下一个Token的模型。例如,如果输入序列是`[A, B, C]`,那么模型会根据`A`预测`B`,根据`B`预测`C`,依此类推,实现自回归生成。在生成新Token时,通常只需要最后一个Token的信息,因为每个预测仅依赖于当前Token。

1. 训练batch数据形式

训练数据是:

 训练目标是:

2. 定义词嵌入层

nn.Embedding 层输出的是可学习的浮点数,将token索引 (B,T) 直接映射为logits,即输入(4,8),输出 (4,8,65),其中输入每个数字,被映射成logit向量(这些值通过 F.cross_entropy 内部自动进行 softmax 转换为概率分布),比如上面输入tokens有个24被映射成如下。

logits = [1.0, 0.5, -2.0, ..., 3.2]  # 共65个浮点数

softmax后得到。

probs = [0.15, 0.12, 0.01, ..., 0.20]  # 和为1的概率分布

这样输出的是每个位置的概率分布。

交叉熵函数会自动计算每个位置的概率分布与真实标签之间的损失,并取平均。

简单的大语言模型,基于Bigram的结构,即每个token仅根据前一个token来预测下一个token。具体实现如下。

from torch.nn import functional as F  # 导入PyTorch函数模块
torch.manual_seed(1337)  # 固定随机种子保证结果可复现

class BigramLanguageModel(nn.Module):  # 定义Bigram语言模型类

    def __init__(self, vocab_size):
        super().__init__()  # 继承父类初始化方法
        # 定义词嵌入层:将token索引直接映射为logits
        # 输入输出维度均为vocab_size(词汇表大小)
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        # 前向传播函数
        # idx: inputs, 输入序列 (B, T),B=批次数,T=序列长度
        # targets: 目标序列 (B, T)
        
        # (4,8) -> (4,8,65)
        logits = self.token_embedding_table(idx)  # (B, T, C)
        # 通过嵌入层获得每个位置的概率分布,C=词汇表大小
        
        # (4*8,C), (4*8,) -> (1,)
        B, T, C = logits.shape  # 解包维度:批次数、序列长度、词表大小
        logits = logits.view(B*T, C)  # 展平为二维张量 (B*T, C)
        targets = targets.view(B*T)    # 目标展平为一维张量 (B*T)
        loss = F.cross_entropy(logits, targets)  # 计算交叉熵损失
        
        return logits, loss  # 返回logits(未归一化概率)和损失值

# 假设 vocab_size=65(例如52字母+标点)
vocab_size = 65
m = BigramLanguageModel(vocab_size)  # 实例化模型

# 假设输入数据(代码中未定义):
# xb: 输入批次 (B=4, T=8),例如 tensor([[1,2,3,...], ...])
# yb: 目标批次 (B=4, T=8)
logits, loss = m(xb, yb)  # 执行前向传播

print(logits.shape)  # 输出logits形状:torch.Size([32, 65])
# 解释:32 = B*T = 4*8,65=词表大小(每个位置65种可能)

print(loss)  # 输出损失值:tensor(4.8786, grad_fn=<NllLossBackward>)
# 解释:初始随机参数下,损失值约为-ln(1/65)=4.17,实际值因参数初始化略有波动

3. 代码逻辑分步解释

# 假设输入和目标的形状均为 (B=4, T=8)
# 输入示例(第一个样本):
inputs[0] = [24, 43, 58, 5, 57, 1, 46, 43]
targets[0] = [43, 58, 5, 57, 1, 46, 43, 39]

3.1 Softmax后的概率分布意义 

当模型处理输入序列时,每个位置会输出一个长度为vocab_size的logits向量。即

输入: (4,8);

输出:(4,8,65). 65维度向量是每个输入token的下一个token的概率分布。

例如,当输入序列为 [24, 43, 58, 5, 57, 1, 46, 43] 时:
- 在第1个位置(token=24),模型预测下一个token(对应target=43)的概率分布p[0].shape=(65,),下一个输出是43的概率为p[0][target[0]]=p[0][43];
- 在第2个位置(token=43),模型预测下一个token(对应target=58)的概率分布p[1].shape=(65),下一个输出是58的概率为p[1][target[1]]=p[1][58];
- 以此类推,每个位置的logits经过softmax后得到一个概率分布,即每个输入位置,都会预测下一个token概率分布
具体来说:
   logits.shape = (4, 8, 65) → softmax后形状不变->p.shape=(4,8,65),但每行的65个值变为概率(和为1)
这些概率表示模型认为「当前token的下一个token」是词汇表中各token的可能性。

3.2 交叉熵计算步骤

假设logits初始形状为 (4, 8, 65)
B, T, C = logits.shape  # B=4, T=8, C=65

# 展平logits和targets:
logits_flat = logits.view(B*T, C)  # 形状 (32, 65)
targets_flat = targets.view(B*T)    # 形状 (32,)

# 交叉熵计算(PyTorch内部过程):
# 对logits_flat的每一行(共32行)做softmax,得到概率分布probs (32, 65)
# 对每个样本i,取probs[i][targets_flat[i]],即真实标签对应的预测概率(此概率是下一个token是targets_flat[i]的概率
# 计算负对数损失:loss = -mean(log(probs[i][targets_flat[i]]))(pytorch实现是将targets_flat所谓索引
loss = F.cross_entropy(logits_flat, targets_flat)  # 输出标量值

3.3 示例计算

# 以第一个样本的第一个位置为例:
# 输入token=24,目标token=43
# 模型输出的logits[0,0]是一个65维向量(这里logits.shape=[4,8,65]),例如:
logits_example = logits[0,0]  # 形状 (65,)
probs_example = F.softmax(logits_example, dim=-1)  # 形状 (65,)
# 假设probs_example[43] = 0.15(模型预测下一个token=43的概率为15%)
# 则此位置的损失为 -log(0.15) ≈ 1.897 (注意-log(p)是一个x范围在[0,1]之间单调递减函数)

# 最终损失是所有32个位置类似计算的均值。
# 初始损失约为4.87(接近均匀分布的理论值 -ln(1/65)≈4.17)

4. 测试生成文本

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx) # 没有输入target时,返回的logits未被展平。  
            # focus only on the last time step
            logits = logits[:, -1, :] # (B,T,C) -> (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

以下是Bigram模型生成过程的逐步详解,以输入序列[24, 43, 58, 5, 57, 1, 46, 43]为例,说明模型如何从初始输入[24]开始逐步预测下一个词: 

4.1 初始输入:[24]

  • 输入形状idx = [[24]]B=1批次,T=1序列长度)。

  • 前向传播

    • 通过嵌入层,模型输出logits形状为(1, 1, 65),表示对当前词24的下一个词的预测分数。

    • 假设logits[0, 0, 43] = 5.0(词43的logit较高),其他位置logits较低(如logits[0, 0, :] = [..., 5.0, ...])。

  • 概率分布

    • 对logits应用softmax,得到概率分布probs。例如:

      probs = [0.01, ..., 0.8(对应43), 0.01, ...]  # 总和为1
  • 采样

    • 根据probs,使用torch.multinomial采样,选中词43的概率最大。

  • 更新输入

    • 43拼接到序列末尾,新输入为idx = [[24, 43]](形状(1, 2))。


4.2 输入:[24, 43]

  • 前向传播

    • 模型处理整个序列,输出logits形状为(1, 2, 65),对应两个位置的预测:

      • 第1个位置(词24)预测下一个词(已生成43)。

      • 第2个位置(词43)预测下一个词。

    • 提取最后一个位置的logits:logits[:, -1, :](形状(1, 65))。

    • 假设logits[0, -1, 58] = 6.0(词58的logit较高)。

  • 概率分布

    • probs = [0.01, ..., 0.85(对应58), 0.01, ...]

  • 采样

    • 选中词58

  • 更新输入

    • 新输入为idx = [[24, 43, 58]](形状(1, 3))。


4.3 输入:[24, 43, 58]

  • 前向传播

    • logits形状为(1, 3, 65)

    • 提取最后一个位置(词58)的logits,假设logits[0, -1, 5] = 4.5

  • 概率分布

    • probs = [0.01, ..., 0.7(对应5), ...]

  • 采样

    • 选中词5

  • 更新输入

    • 新输入为idx = [[24, 43, 58, 5]](形状(1, 4))。


4.4 重复生成直到序列完成

  • 后续步骤

    • 输入[24, 43, 58, 5] → 预测词57

    • 输入[24, 43, 58, 5, 57] → 预测词1

    • 输入[24, 43, 58, 5, 57, 1] → 预测词46

    • 输入[24, 43, 58, 5, 57, 1, 46] → 预测词43

  • 最终序列

    • idx = [[24, 43, 58, 5, 57, 1, 46, 43]]

注意:上面输入序列是越来越长的,为何说预测下一个词只跟上一个词有关?如果只跟一个词有关,为何不每次只输入一个词,然后预测下一个词?

虽然理论上可以仅传递最后一个词,但实际实现中传递完整序列的原因(视频作者说的,固定generate函数形式,我这里理解的是代码简洁):

  • 代码简洁性:无需在每次生成时截取最后一个词,直接复用统一的前向传播逻辑;

实验验证

若修改代码,每次仅传递最后一个词:

def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
        last_token = idx[:, -1:]          # 仅取最后一个词 (B, 1)
        logits, _ = self(last_token)       # 输出形状 (B, 1, C)
        probs = F.softmax(logits[:, -1, :], dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

4.5 完整代码

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = BigramLanguageModel(vocab_size)
m = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2311864.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机视觉之dlib人脸关键点绘制及微笑测试

dlib人脸关键点绘制及微笑测试 目录 dlib人脸关键点绘制及微笑测试1 dlib人脸关键点1.1 dlib1.2 人脸关键点检测1.3 检测模型1.4 凸包1.5 笑容检测1.6 函数 2 人脸检测代码2.1 关键点绘制2.2 关键点连线2.3 微笑检测 1 dlib人脸关键点 1.1 dlib dlib 是一个强大的机器学习库&a…

Windows11下玩转 Docker

一、前提准备 WSL2&#xff1a;Windows 提供的一种轻量级 Linux 运行环境&#xff0c;具备完整的 Linux 内核&#xff0c;并支持更好的文件系统性能和兼容性。它允许用户在 Windows 系统中运行 Linux 命令行工具和应用程序&#xff0c;而无需安装虚拟机或双系统。Ubuntu 1.1 安…

Android 平台架构系统启动流程详解

目录 一、平台架构模块 1.1 Linux 内核 1.2 硬件抽象层 (HAL) 1.3 Android 运行时 1.4 原生 C/C 库 1.5 Java API 框架 1.6 系统应用 二、系统启动流程 2.1 Bootloader阶段 2.2 内核启动 2.3 Init进程&#xff08;PID 1&#xff09; 2.4 Zygote与System Serv…

【C++设计模式】第四篇:建造者模式(Builder)

注意&#xff1a;复现代码时&#xff0c;确保 VS2022 使用 C17/20 标准以支持现代特性。 分步骤构造复杂对象&#xff0c;实现灵活装配 1. 模式定义与用途 核心目标&#xff1a;将复杂对象的构建过程分离&#xff0c;使得同样的构建步骤可以创建不同的表示形式。 常见场景&am…

使用GitLink个人建站服务部署Allure在线测试报告

更多技术文章&#xff0c;访问软件测试社区 文章目录 &#x1f680;前言&#x1f511;开通GitLink个人建站服务1. 前提条件2. 登录GitLink平台&#xff08;https://www.gitlink.org.cn/login&#xff09;3. 进入设置>个人建站>我的站点4. 新建站点5. 去仓部进行部署6. 安…

专业学习|多线程、多进程、多协程加速程序运行

学习资料来源&#xff1a;【2021最新版】Python 并发编程实战&#xff0c;用多线程、多进程、多协程加速程序运行_哔哩哔哩_bilibili 若有侵权&#xff0c;联系删除。 一、程序的提速方法——多线程、多进程、多协程 在现代编程中&#xff0c;多线程、多进程和多协程是三种常见…

C/C++蓝桥杯算法真题打卡(Day3)

一、P8598 [蓝桥杯 2013 省 AB] 错误票据 - 洛谷 算法代码&#xff1a; #include<bits/stdc.h> using namespace std;int main() {int N;cin >> N; // 读取数据行数unordered_map<int, int> idCount; // 用于统计每个ID出现的次数vector<int> ids; …

烟花燃放安全管控:智能分析网关V4烟火检测技术保障安全

一、方案背景 在中国诸多传统节日的缤纷画卷中&#xff0c;烟花盛放、烧纸祭祀承载着人们的深厚情感。一方面&#xff0c;烟花璀璨&#xff0c;是对节日欢庆氛围的热烈烘托&#xff0c;寄托着大家对美好生活的向往与期许&#xff1b;另一方面&#xff0c;袅袅青烟、点点烛光&a…

【Bert系列模型】

目录 一、BERT模型介绍 1.1 BERT简介 1.2 BERT的架构 1.2.1 Embedding模块 1.2.2 双向Transformer模块 1.2.3 预微调模块 1.3 BERT的预训练任务 1.3.1 Masked Language Model (MLM) 1.3.2 Next Sentence Prediction (NSP) 1.4 预训练与微调的关系 1.5 小结 二、BERT…

android接入rocketmq

一 前言 RocketMQ 作为一个功能强大的消息队列系统&#xff0c;不仅支持基本的消息发布与订阅&#xff0c;还提供了顺序消息、延时消息、事务消息等高级功能&#xff0c;适应了复杂的分布式系统需求。其高可用性架构、多副本机制、完善的运维管理工具&#xff0c;以及安全控制…

前端数据模拟 Mock.js 学习笔记

mock.js介绍 Mock.js是一款前端开发中拦截Ajax请求再生成随机数据响应的工具&#xff0c;可以用来模拟服务器响应 优点是&#xff1a;非常方便简单&#xff0c;无侵入性&#xff0c;基本覆盖常用的接口数据类型支持生成随机的文本、数字、布尔值、日期、邮箱、链接、图片、颜…

用DeepSeek-R1-Distill-data-110k蒸馏中文数据集 微调Qwen2.5-7B-Instruct!

下载模型与数据 模型下载&#xff1a; huggingface&#xff1a; Qwen/Qwen2.5-7B-Instruct HF MirrorWe’re on a journey to advance and democratize artificial intelligence through open source and open science.https://hf-mirror.com/Qwen/Qwen2.5-7B-Instruct 魔搭&a…

DeepSeek大模型 —— 全维度技术解析

DeepSeek大模型 —— 全维度技术解析 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;可以分享一下给大家。点击跳转到网站。 https://www.captainbed.cn/ccc 文章目录 DeepSeek大模型 —— 全维度技术解析一、模型架构全景解析1…

EasyRTC嵌入式音视频通话SDK:基于ICE与STUN/TURN的实时音视频通信解决方案

在当今数字化时代&#xff0c;实时音视频通信技术已成为人们生活和工作中不可或缺的一部分。无论是家庭中的远程看护、办公场景中的远程协作&#xff0c;还是工业领域的远程巡检和智能设备的互联互通&#xff0c;高效、稳定的通信技术都是实现这些功能的核心。 EasyRTC嵌入式音…

qt open3dAlpha重建

qt open3dAlpha重建 效果展示二、流程三、代码效果展示 二、流程 创建动作,链接到槽函数,并把动作放置菜单栏 参照前文 三、代码 1、槽函数实现 void on_actionAlpha_triggered();//alpha重建 void MainWindow::

《深入浅出数据索引》- 公司内部培训课程笔记

深入浅出数据索引 内容&#xff1a;索引理论&#xff0c;索引常见问题&#xff0c;索引最佳实践&#xff0c;sql优化实战&#xff0c;问答 哈希不支持范围查询 4层 几个亿 5层 几十亿上百亿 B树的分裂&#xff0c;50-50分裂 都是往上插一个元素&#xff08;红黑树是左右旋转&a…

PPT 技能:巧用 “节” 功能,让演示文稿更有序

在制作PPT时&#xff0c;你是否遇到过这样的情况&#xff1a;幻灯片越来越多&#xff0c;内容越来越杂&#xff0c;找某一页内容时翻得眼花缭乱&#xff1f;尤其是在处理大型PPT文件时&#xff0c;如果没有合理的结构&#xff0c;编辑和调整都会变得非常麻烦。这时候&#xff0…

Xss漏洞问题

https://bu1.github.io/2021/01/12/%E7%AC%AC%E5%8D%81%E4%BA%8C%E5%91%A8%EF%BC%9AXSS%E6%BC%8F%E6%B4%9E%E5%AD%A6%E4%B9%A0%E5%AE%9E%E6%88%98/ 后端绕开了前端&#xff0c;直接调用接口入库&#xff1a; <select οnchange“alert(1)”>12 前端拿到这个文本后&…

Docker概念与架构

文章目录 概念docker与虚拟机的差异docker的作用docker容器虚拟化 与 传统虚拟机比较 Docker 架构 概念 Docker 是一个开源的应用容器引擎。诞生于 2013 年初&#xff0c;基于 Go 语言实现。Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中&#xf…

3.使用ElementUI搭建侧边栏及顶部栏

1. 安装ElementUI ElementUI是基于 Vue 2.0 的桌面端组件库。使用之前&#xff0c;需要在项目文件夹中安装ElementUI&#xff0c;在终端中输入以下命令&#xff0c;进行安装。 npm i element-ui -S并在main.js中引入ElementUI 2. 使用elmentUI组件进行页面布局 2.1 清空原…