GPT(Generative Pre-Training)论文解读及源码实现(二)

news2024/11/15 23:25:40

本篇为gpt2的pytorch实现,参考 nanoGPT

nanoGPT如何使用见后面第5节

1 数据准备及预处理

data/shakespeare/prepare.py 文件源码分析

1.1 数据划分

下载数据后90%作为训练集,10%作为验证集

with open(input_file_path, 'r') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

1.2 数据编码

使用tiktoken包进行gpt2编码,gpt2默认编码方式为 bpe

enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data[:100])
val_ids = enc.encode_ordinary(val_data[:100])
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

>>>
	train has 31 tokens
	val has 40 tokens

如上取了train_data100个字符,编码后为31个tokens, 100个val 编码后为40个token. 可以通过enc.decode(train_ids) 还原为原始文本数据

train_ids 输出形式为:[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13, 198, 198, 5962, 22307, 25, 198, 1639]

2 训练数据加工

构造训练数据X,Y,其中target数据Y为X平移一位生成,每次取batch_size个数据

    data = train_data if split == 'train' else val_data # data: 301966
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])

3 模型训练

3.1 GPT模型结构

3.1.1 embedding层

token embedding 和位置embedding
(batch_size 取4,句子长度取8,则输入x shape =[4,8])
embedding后维度,如下图所示

  • token embedding shape=[4,8,128]
  • 位置embeddign shape=[8,128]
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),

其中位置信息根据句子长度生成

pos = torch.arange(0, x.size(1), dtype=torch.long, device=device)

在这里插入图片描述

3.1.2 attention 层(带因果推断的attention,即需要上三角maske)

输入x shape=[4,8,128],
通过线性层后 shape: q=k=v=[4,8,128]
将embedding维度进行多头划分后,shape =[4,4,8,32]
(torch2 支持因果attention )

在这里插入图片描述

重点:attention 中mask实现,
即给上三角矩阵填充负无穷大数(负无穷在softmax时,值为0,即权重为0)

 L, S = query.size(-2), key.size(-2)
 scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
 attn_bias = torch.zeros(L, S, dtype=query.dtype)
 if is_causal:
     assert attn_mask is None
     temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
     attn_bias.to(query.dtype)

在这里插入图片描述

3.1.3 block层

  • gpt2 会有n_layer个block层,每个block层由layer normal层,attention层,mlp层构成(具体可以参考transformer)

  • block层,由attention层和全连接层组成,
    输入x shape为 [4,8,128]
    输出attention shape为 [4,8,128]
    输出MLP shape为 [4,8,128]

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x)) # shape [4,8,128]
        x = x + self.mlp(self.ln_2(x)) # shape [4,8,128]
        return x

3.2 损失函数

输入x shape [4,8,128]
输出 logits shape [4,8, 50304],即词典中每个单词的得分
loss为交叉熵损失,为一个标量,

 logits = self.lm_head(x) #; logits shape: [4,8, 50304]
 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) # targets 即为之前训练数据的 Y数据

4 模型推理

4.1 模型加载

加载训练时保存的模型
在这里插入图片描述

4.2 定义数据处理的编解码器

数据编解码器与训练时一致

    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)

4.3 数据生成(重点)

  • 第一次输入只有一个字符
    idx_cond=idx: shape =[1,1]
    logits shape =[1,50304]
    从topK个中安概率随机取一个
    和上面的idx拼接,作为第二次的输入
  • 第二次输入
    idx_cond=idx: shape =[1,2]
    logits shape =[1,50304]
    然后再从topk中安概率随机取一个进行拼接

    -直到达到最大输出活着终止字符
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

5 GPT2 使用

5.1 下载git源码

  • git clone https://github.com/karpathy/nanoGPT.git
  • 安装依赖包(建议安装torch2以上版本,其他包不限制版本)
pip install torch numpy transformers datasets tiktoken wandb tqdm

(mac pytorch 安装: pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

5.1 数据下载

测试下载数据及,并编码成数字格式

python data/shakespeare/prepare.py

5.2 模型训练

参数解释:

  • device :使用的GPU 类型,可以是cuda ,cpu , mps
  • compile 是否使用编译优化,torch2版本支持(mac mps 不支持)
  • eval_iters 迭代次数
  • block_size:训练句子长度(演示最大句子长度只取了8)
  • batch_size : batch size
  • n_layer: 使用多少个transformer block
  • n_head: attention 头数
  • n_embd: embedding 维度
  • dropout:dropout 比例
config/train_shakespeare_char.py --device=mps --compile=False --eval_iters=20 --log_interval=1 --block_size=8 --batch_size=4 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0

(备注: 我使用的是shakespeare数据集,因此将配置文件train_shakespeare_char.py 进行了修改 wandb_project = ‘shakespeare’ ;dataset = ‘shakespeare’)

5.3 模型推理

python sample.py --out_dir=out-shakespeare

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1365015.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

yolo 分割label格式标注信息图片显示可视化查看

参考: https://github.com/ultralytics/ultralytics/issues/3137 https://blog.csdn.net/weixin_42357472/article/details/135218349?spm=1001.2014.3001.5501 需要把坐标信息在图片上显示 代码 1)只画出了坐标边缘 import cv2 import numpy as np from random impor…

html 原生网页使用ElementPlus 日期控件el-date-picker换成中文

项目&#xff1a; 原生的html,加jQuery使用不习惯&#xff0c;新html页面导入vue3,element plus做界面&#xff0c;现在需要把日历上英文切成中文。 最终效果&#xff1a; 导入能让element plus日历变成中文脚本&#xff1a; elementplus, vue3对应的js都可以通过创建一个vu…

idea 以文本形式输出 SpringBoot项目 目录结构

第1步&#xff1a;AltF12 打开 Terminal 终端 第2步&#xff1a;cd 到 项目路径下 第3步&#xff1a;使用 tree 命令 结果 D:. ├─.mvn │ └─wrapper ├─applog │ └─logs ├─src │ ├─main │ │ ├─java │ │ │ └─com │ │ │ └─zhangziwa …

【软件测试】学习笔记-如何做好单元测试

什么是单元测试&#xff1f; 在正式开始今天的话题之前&#xff0c;我先给你分享一个工厂生产电视机的例子。 工厂首先会将各种电子元器件按照图纸组装在一起构成各个功能电路板&#xff0c;比如供电板、音视频解码板、射频接收板等&#xff0c;然后再将这些电路板组装起来构…

【计算机网络】网络编程套接字socket--UDP/TCP简单服务器实现/TCP协议通信流程

文章目录 一、预备知识1.IP和端口号2.TCP协议和UDP协议3.网络字节序 二、socket编程接口1.socket 常见API2.sockaddr结构 三、UDP服务器相关重要接口介绍sendtorecvfrompopen 1.udpServer.hpp2.udpServer.cc3.udpClient.hpp4.udpClient.cc5.onlineUser.hpp 四、TCP服务器socket…

高性能、可扩展、分布式对象存储系统MinIO的介绍、部署步骤以及代码示例

详细介绍 MinIO 是一款流行的开源对象存储系统&#xff0c;设计上兼容 Amazon S3 API&#xff0c;主要用于私有云和边缘计算场景。它提供了高性能、高可用性以及易于管理的对象存储服务。以下是 MinIO 的详细介绍及优缺点&#xff1a; 架构与特性&#xff1a; 开源与跨平台&am…

HTML---JavaScript操作DOM对象

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 本章目标 了解DOM的分类和节点间的关系熟练使用JavaScript操作DOM节点 访问DOM节点 能够熟练的进行节点的创建、添加、删除、替换等 能够熟练的设置元素的样式 能够灵活运用JavaScript获取元素…

SpringBoot学习(八)-SpringBoot + Dubbo + zookeeper

分布式DubboZookeeper 1、分布式理论 1&#xff09;什么是分布式系统&#xff1f; 在《分布式系统原理与范型》一书中有如下定义&#xff1a;“分布式系统是若干独立计算机的集合&#xff0c;这些计算机对于用户来说就像单个相关系统”&#xff1b; 分布式系统是由一组通过…

【AI视野·今日CV 计算机视觉论文速览 第284期】Fri, 5 Jan 2024

AI视野今日CS.CV 计算机视觉论文速览 Fri, 5 Jan 2024 Totally 62 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computer Vision Papers Learning to Prompt with Text Only Supervision for Vision-Language Models Authors Muhammad Uzair Khattak, Muhammad F…

2024.1.7-实战-docker方式给自己网站部署prometheus监控ecs资源使用情况-2024.1.7(测试成功)

实战-docker方式给自己网站部署prometheus监控ecs资源使用情况-2024.1.7(测试成功) 目录 最终效果 原文链接 https://onedayxyy.cn/docs/prometheus-grafana-ecs 参考模板 https://i4t.com/ https://grafana.frps.cn &#x1f530; 额&#xff0c;注意哦: 他这个是通过frp来…

【Flutter 开发实战】Dart 基础篇:从了解背景开始

想要学会用 Flutter 开发 App&#xff0c;就不可避免的要学习另一门很有意思的编程语言 —— Dart。很多小伙伴可能在学习 Flutter 之前可能都没听说过这门编程语言&#xff0c;我也是一样&#xff0c;还以为 Dart 是为了 Flutter 而诞生的&#xff1b;然而&#xff0c;当我们去…

网页设计与制作web前端设计html+css+js成品。电脑网站制作代开发。vscodeDrea 【企业公司宣传网站(HTML静态网页项目实战)附源码】

网页设计与制作web前端设计htmlcssjs成品。电脑网站制作代开发。vscodeDrea 【企业公司宣传网站&#xff08;HTML静态网页项目实战&#xff09;附源码】 https://www.bilibili.com/video/BV1Hp4y1o7RY/?share_sourcecopy_web&vd_sourced43766e8ddfffd1f1a1165a3e72d7605

分布式系统架构设计之分布式消息队列基础知识

随着微服务、大数据和云计算的普及&#xff0c;分布式系统已经成为现代软件架构的核心。在分布式系统中&#xff0c;各个组件间的通信和数据交换尤其重要&#xff0c;而消息队列正是实现这一目标的关键技术之一。 在分布式架构设计过程中&#xff0c;架构师们需要对消息队列有…

爬虫瑞数4案例:网上房地产

声明&#xff1a; 该文章为学习使用&#xff0c;严禁用于商业用途和非法用途&#xff0c;违者后果自负&#xff0c;由此产生的一切后果均与作者无关 一、瑞数简介 瑞数动态安全 Botgate&#xff08;机器人防火墙&#xff09;以“动态安全”技术为核心&#xff0c;通过动态封装…

深度解析基于模糊数学的C均值聚类算法

深度解析基于模糊数学的C均值聚类算法 模糊C均值聚类 (FCM)聚类步骤&#xff1a;FCM Python代码&#xff1a; 模糊C均值聚类 (FCM) 在数据挖掘和聚类分析领域&#xff0c;C均值聚类是一种广泛应用的方法。模糊C均值聚类&#xff08;FCM&#xff09;是C均值聚类的自然升级版。相…

创建Vue3项目

介绍 使用命令创建vue3项目 示例 第一步&#xff1a;执行创建项目命令 npm create vuelatest第二步&#xff1a;填写输入项 第三步&#xff1a;进入study-front-vue3文件夹 cd study-front-vue3第四步&#xff1a;执行npm命令安装依赖 npm install第五步&#xff1a;运行…

Vue中Vuex的环境搭建和原理分析及使用

Vuex的环境搭建 Vuex是Vue实现集中式数据管理的Vue的一个插件&#xff0c;集中式可以理解为一个老师给多个学生讲课。 Vue2.0版本的安装&#xff1a; npm i vuex3 使用Vuex需要在store中的index.js引入Vuex和main.js中引入store,目的是让vm和vc都能看到$store。实现多个组件…

2024--Django平台开发-Django知识点(三)

day03 django知识点 项目相关路由相关 urls.py视图相关 views.py模版相关 templates资源相关 static/media 1.项目相关 新项目 开发时&#xff0c;可能遇到使用其他的版本。虚拟环境 老项目 打开项目虚拟环境 1.1 关于新项目 1.系统解释器命令行【学习】 C:/python38- p…

大模型LLM训练的数据集

引言 2021年以来&#xff0c;大预言模型的开发和生产使用呈现出爆炸式增长。除了李开复、王慧文、王小川等“退休”再创业的互联网老兵&#xff0c;在阿里巴巴、腾讯、快手等互联网大厂的中高层也大胆辞职&#xff0c;加入这波创业浪潮。 通用大模型初创企业MiniMax完成了新一…

目标检测-One Stage-YOLOv4

文章目录 前言一、目标检测网络组成二、BoF&#xff08;Bag of Freebies&#xff09;1. 数据增强2.语义分布偏差问题3.损失函数IoUGIoUDIoUCIoU 三、BoS(Bag of Specials)增强感受野注意力机制特征融合激活函数后处理 四、YOLO v4的网络结构和创新点1.缓解过拟合&#xff08;Bo…