Chapter4.1 Coding an LLM architecture

news2025/1/8 21:16:10

文章目录

  • 4 Implementing a GPT model from Scratch To Generate Text
    • 4.1 Coding an LLM architecture

4 Implementing a GPT model from Scratch To Generate Text

  • 本章节包含

    1. 编写一个类似于GPT的大型语言模型(LLM),这个模型可以被训练来生成类似人类的文本。
    2. Normalizing layer activations to stabilize neural network training
    3. 在深度神经网络中添加shortcut connections,以更有效地训练模型
    4. 实现 Transformer 模块以创建不同规模的 GPT 模型
    5. 计算 GPT 模型的参数数量及其存储需求

    在上一章中,学习了多头注意力机制并对其进行了编码,它是LLMs的核心组件之一。在本章中,将编写 LLM 的其他构建块,并将它们组装成类似 GPT 的模型


4.1 Coding an LLM architecture

  • 诸如GPT和Llama等模型,基于原始Transformer架构中的decoder部分,因此,这些LLM通常被称为"decoder-like" LLMs,与传统的深度学习模型相比,LLM规模更大,这主要归因于它们庞大的参数数量,而非代码量。因为它的许多组件都是重复的,下图提供了类似 GPT LLM 的自上而下视图

本章将详细构建一个最小规模的GPT-2模型(1.24亿参数),并展示如何加载预训练权重以兼容更大规模的模型。

  • 1.24亿参数GPT-2模型的配置细节包括:

    GPT_CONFIG_124M = {
        "vocab_size": 50257,    # Vocabulary size
        "context_length": 1024, # Context length
        "emb_dim": 768,         # Embedding dimension
        "n_heads": 12,          # Number of attention heads
        "n_layers": 12,         # Number of layers
        "drop_rate": 0.1,       # Dropout rate
        "qkv_bias": False       # Query-Key-Value bias
    }
    

    我们使用简短的变量名以避免后续代码行过长

    1. "vocab_size" 词汇表大小,由 BPE tokenizer 支持,值为 50,257。
    2. "context_length" 模型的最大输入标记数量,通过 positional embeddings 实现。
    3. "emb_dim" token输入的嵌入大小,将每个token转换为 768 维向量。
    4. "n_heads" 多头注意力机制中的注意力头数量。
    5. "n_layers" 是模型中 transformer 块的数量
    6. "drop_rate" 是 dropout 机制的强度,第 3 章讨论过;0.1 表示在训练期间丢弃 10% 的隐藏单元以缓解过拟合
    7. "qkv_bias" 决定多头注意力机制(第 3 章)中的 Linear 层在计算查询(Q)、键(K)和值(V)张量时是否包含偏置向量;我们将禁用此选项,这是现代 LLMs 的标准做法;然而,我们将在第 5 章将 OpenAI 的预训练 GPT-2 权重加载到我们的重新实现时重新讨论这一点。
  • 下图中的方框展示了我们为实现最终 GPT 架构所需处理的各个概念的顺序。我们将从第一步开始,即一个我们称为 DummyGPTModel 的 GPT 骨架占位符:

    import torch
    import torch.nn as nn
    
    
    class DummyGPTModel(nn.Module):
        def __init__(self, cfg):
            super().__init__()
            self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
            self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
            self.drop_emb = nn.Dropout(cfg["drop_rate"])
            
            # Use a placeholder for TransformerBlock
            self.trf_blocks = nn.Sequential(
                *[DummyTransformerBlock(cfg) for _ in range(cfg["n_layers"])])
            
            # Use a placeholder for LayerNorm
            self.final_norm = DummyLayerNorm(cfg["emb_dim"])
            self.out_head = nn.Linear(
                cfg["emb_dim"], cfg["vocab_size"], bias=False
            )
    
        def forward(self, in_idx):
            batch_size, seq_len = in_idx.shape
            tok_embeds = self.tok_emb(in_idx)
            pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
            x = tok_embeds + pos_embeds
            x = self.drop_emb(x)
            x = self.trf_blocks(x)
            x = self.final_norm(x)
            logits = self.out_head(x)
            return logits
    
    
    class DummyTransformerBlock(nn.Module):
        def __init__(self, cfg):
            super().__init__()
            # A simple placeholder
    
        def forward(self, x):
            # This block does nothing and just returns its input.
            return x
    
    
    class DummyLayerNorm(nn.Module):
        def __init__(self, normalized_shape, eps=1e-5):
            super().__init__()
            # The parameters here are just to mimic the LayerNorm interface.
    
        def forward(self, x):
            # This layer does nothing and just returns its input.
            return x
    
    1. DummyGPTModel:简化版的 GPT 类模型,使用 PyTorch 的 nn.Module 实现。
    2. 模型组件:包括标记嵌入、位置嵌入、丢弃层、变换器块、层归一化和线性输出层。
    3. 配置字典:配置通过 Python 字典传入,如 GPT_CONFIG_124M,用于传递模型配置。
    4. forward 方法:描述数据从输入到输出的完整流程。计算嵌入 → 应用 dropout → 通过 transformer blocks 处理 → 应用归一化 → 生成 logits。
    5. 占位符DummyLayerNormDummyTransformerBlock 是待实现的组件。
  • 数据流动:下图提供了 GPT 模型中数据流动的高层次概述。

    使用 tiktoken 分词器对由 GPT 模型的两个文本输入组成的批次进行分词:

    import tiktoken
    
    tokenizer = tiktoken.get_encoding("gpt2")
    
    batch = []
    
    txt1 = "Every effort moves you"
    txt2 = "Every day holds a"
    
    batch.append(torch.tensor(tokenizer.encode(txt1)))
    batch.append(torch.tensor(tokenizer.encode(txt2)))
    batch = torch.stack(batch, dim=0)
    print(batch)
    
    """输出"""
    tensor([[6109, 3626, 6100,  345],
            [6109, 1110, 6622,  257]]
          )
    

    接下来,我们初始化一个包含 1.24 亿参数的 DummyGPTModel 实例,并将 tokenized batch 输入其中。

    torch.manual_seed(123)
    model = DummyGPTModel(GPT_CONFIG_124M)
    
    logits = model(batch)
    print("Output shape:", logits.shape)
    print(logits)
    
    """输出"""
    Output shape: torch.Size([2, 4, 50257])
    tensor([[[-0.9289,  0.2748, -0.7557,  ..., -1.6070,  0.2702, -0.5888],
             [-0.4476,  0.1726,  0.5354,  ..., -0.3932,  1.5285,  0.8557],
             [ 0.5680,  1.6053, -0.2155,  ...,  1.1624,  0.1380,  0.7425],
             [ 0.0447,  2.4787, -0.8843,  ...,  1.3219, -0.0864, -0.5856]],
    
            [[-1.5474, -0.0542, -1.0571,  ..., -1.8061, -0.4494, -0.6747],
             [-0.8422,  0.8243, -0.1098,  ..., -0.1434,  0.2079,  1.2046],
             [ 0.1355,  1.1858, -0.1453,  ...,  0.0869, -0.1590,  0.1552],
             [ 0.1666, -0.8138,  0.2307,  ...,  2.5035, -0.3055, -0.3083]]],
           grad_fn=<UnsafeViewBackward0>)
    
    1. 输出张量:输出张量有两行,分别对应两个文本样本。每个文本样本由 4 个标记组成;每个标记是一个 50,257 维的向量,这与标记器的词汇表大小一致。
    2. 嵌入维度:50,257 维对应词汇表中的唯一标记,后处理阶段将其转换回 token IDs 并解码为单词。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2273428.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git revert回滚

回退中间的某次提交&#xff08;此操作在预生产分支上比较常见&#xff09;&#xff0c;建议此方式使用命令进行操作&#xff08;做好注释&#xff0c;方便后续上线可以找到这个操作&#xff09; Git操作&#xff1a; 命令&#xff1a;revert -n 版本号 1&#xff1a;git re…

1. 使用springboot做一个音乐播放器软件项目【前期规划】

背景&#xff1a; 现在大部分音乐软件都是要冲会员才可以无限常听的。对于喜欢听音乐的小伙伴&#xff0c;资金又比较紧张&#xff0c;是那么的不友好。作为程序员的我&#xff0c;也是喜欢听着歌&#xff0c;敲着代码。 最近就想做一个音乐播放器的软件&#xff0c;在内网中使…

Flutter项目开发模版,开箱即用(Plus版本)

前言 当前案例 Flutter SDK版本&#xff1a;3.22.2 本文&#xff0c;是由这两篇文章 结合产出&#xff0c;所以非常建议大家&#xff0c;先看完这两篇&#xff1a; Flutter项目开发模版&#xff1a; 主要内容&#xff1a;MVVM设计模式及内存泄漏处理&#xff0c;涉及Model、…

配置数据的抗辐照加固方法

SRAM 型FPGA 的配置存储器可以看成是由0 和1 组成的二维阵列&#xff0c;帧的高度为矩阵阵列的高度&#xff0c;相同结构的配置帧组成配置列&#xff0c;如CLB 列、IOB 列、输入输出互联(Input Output Interconnect,IOI)列、全局时钟(Global Clock, GCLK)列、BRAM 列和BRAM 互联…

【学习路线】Python 算法(人工智能)详细知识点学习路径(附学习资源)

学习本路线内容之前&#xff0c;请先学习Python的基础知识 其他路线&#xff1a; Python基础 >> Python进阶 >> Python爬虫 >> Python数据分析&#xff08;数据科学&#xff09; >> Python 算法&#xff08;人工智能&#xff09; >> Pyth…

【Vue3项目实战系列一】—— 全局样式处理,导入view-ui-plus组件库,定制个性主题

&#x1f609; 你好呀&#xff0c;我是爱编程的Sherry&#xff0c;很高兴在这里遇见你&#xff01;我是一名拥有十多年开发经验的前端工程师。这一路走来&#xff0c;面对困难时也曾感到迷茫&#xff0c;凭借不懈的努力和坚持&#xff0c;重新找到了前进的方向。我的人生格言是…

SCAU期末笔记 - 数据库系统概念往年试卷解析

数据库搞得人一头雾水&#xff0c;题型太多太杂&#xff0c;已经准备摆烂了。就刷刷往年试卷&#xff0c;挂不挂听天由命。 2019年 Question 1 选择题 1. R ∩ S R∩S R∩S等于一下哪个选项&#xff1f; 画个文氏图秒了 所以选A. R ∩ S R − ( R − S ) R∩SR-(R-S) R∩…

oxml中创建CT_Document类

概述 本文基于python-docx源码&#xff0c;详细记录CT_Document类创建的过程&#xff0c;以此来加深对Python中元类、以及CT_Document元素类的认识。 元类简介 元类&#xff08;MetaClass&#xff09;是Python中的高级特性。元类是什么呢&#xff1f;Python是面向对象编程…

Fabric环境部署

官方下载文档&#xff1a;A Blockchain Platform for the Enterprise — Hyperledger Fabric Docs main documentation 1.1 创建工作目录 将Fabric代码按照GO语言的推荐方式进行存放&#xff0c;创建目录结构并切换到该目录下。具体命令如下&#xff1a; mkdir -p ~/go/src/g…

TCP与DNS的报文分析

场景拓扑&#xff1a; 核心路由配置&#xff1a; 上&#xff08;DNS&#xff09;&#xff1a;10.1.1.1/24 下(WEB)&#xff1a;20.1.1.1/24 左&#xff08;client&#xff09;&#xff1a;192.168.0.1/24 右(PC3)&#xff1a;192.168.1.1/24Clint2配置&a…

怎么管理电脑usb接口,分享四种USB端口管理方法

怎么管理电脑usb接口&#xff0c;分享四种USB端口管理方法 USB接口作为电脑重要的外部接口&#xff0c;方便了数据传输和设备连接。 然而&#xff0c;不加管理的USB接口也可能带来安全隐患&#xff0c;例如数据泄露、病毒传播等。 因此&#xff0c;有效管理电脑USB接口至关重…

教育咨询系统架构与功能分析

一、系统架构 服务端 服务端&#xff1a;Java&#xff08;最低JDK1.8&#xff0c;支持JDK11以及JDK17&#xff09;数据库&#xff1a;MySQL数据库&#xff08;标配5.7版本&#xff0c;支持MySQL8&#xff09;ORM框架&#xff1a;Mybatis&#xff08;集成通用tk-mapper&#x…

nginx http反向代理

系统&#xff1a;Ubuntu_24.0.4 1、安装nginx sudo apt-get update sudo apt-get install nginx sudo systemctl start nginx 2、配置nginx.conf文件 /etc/nginx/nginx.conf&#xff0c;但可以在 /etc/nginx/sites-available/ 目录下创建一个新的配置文件&#xff0c;并在…

代码管理助手-Git

前言 Git 是一个版本控制系统&#xff0c;可以帮助你记录文件的每一次修改。这样&#xff0c;如果你在编程时不小心把代码写错了&#xff0c;可以很容易地回退到之前的版本。最重要的是&#xff0c;Git 是完全免费的&#xff0c;用户可以在自己的计算机上安装和使用 Git&#x…

注册中心如何选型?Eureka、Zookeeper、Nacos怎么选

这是小卷对分布式系统架构学习的第9篇文章&#xff0c;第8篇时只回答了注册中心的工作原理的内容&#xff0c;面试官的第二个问题还没回答&#xff0c;今天再来讲讲各个注册中心的原理&#xff0c;以及区别&#xff0c;最后如何进行选型 上一篇文章&#xff1a;如何设计一个注册…

WebRtc02: WebRtc架构、目录结构、运行机制

整体架构 WebRtc主要分为三层&#xff1a; CAPI层&#xff1a;外层调用Session管理核心层&#xff1a;包括视频引擎、音频引擎、网络传输 可由使用者重写视频引擎&#xff1a;编解码器、视频缓存、视频增强音频引擎&#xff1a;编解码器、音频缓存、回音消除、降噪传输&#x…

【Java】JVM内存相关笔记

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域有各自的用途&#xff0c;以及创建和销毁的时间&#xff0c;有的区域随着虚拟机进程的启动而一直存在&#xff0c;有些区域则是依赖用户线程的启动和结束而建立和销毁。 程序计数器&am…

基于springboot的网上商城购物系统

作者&#xff1a;学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等 文末获取“源码数据库万字文档PPT”&#xff0c;支持远程部署调试、运行安装。 目录 项目包含&#xff1a; 开发说明&#xff1a; 系统功能&#xff1a; 项目截图…

STM32 拓展 低功耗案例3:待机模式 (hal)

配置PA0的两种方式&#xff1a; 第一种 第二种 复制寄存器代码然后对其进行修改 mian.c /* USER CODE BEGIN Header */ /********************************************************************************* file : main.c* brief : Main program body…

VLMs之Agent之CogAgent:《CogAgent: A Visual Language Model for GUI Agents》翻译与解读

VLMs之Agent之CogAgent&#xff1a;《CogAgent: A Visual Language Model for GUI Agents》翻译与解读 导读&#xff1a;这篇论文介绍了CogAgent&#xff0c;一个专注于图形用户界面 (GUI) 理解和导航的视觉语言模型 (VLM)。这篇论文提出了一种新的视觉语言模型 CogAgent&#…