花费7元训练自己的GPT 2模型

news2025/1/27 12:13:00

在上一篇博客中,我介绍了用Tensorflow来重现GPT 1的模型和训练的过程。这次我打算用Pytorch来重现GPT 2的模型并从头进行训练。

GPT 2的模型相比GPT 1的改进并不多,主要在以下方面:

1. GPT 2把layer normalization放在每个decoder block的前面。

2. 最终的decoder block之后额外添加了一个layer normalization。

3. 残差层的参数初始化根据网络深度进行调节

4. 训练集采用了webtext(45GB),而不是之前采用的bookcorpus(5GB)

5. 更深的网络结构,最大的模型拥有15亿的参数,对比GPT 1是1.2亿的参数

GPT 2有以下四种不同深度的模型架构,如图:

以下我将用pytorch代码来搭建一个GPT 2的模型,以最小的GPT 2为例,采用bookcorpus的数据,在AutoDL平台的一个40G显存的A100显卡上进行训练,看看效果如何。

模型结构

整个模型的结构和GPT 1是基本一致的。

定义一个多头注意力模块,如以下代码:

class MHA(nn.Module):
    def __init__(self, d_model, num_heads, attn_pdrop, resid_pdrop):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.attn_pdrop = attn_pdrop
        self.resid_dropout = nn.Dropout(resid_pdrop)
        self.ln = nn.Linear(d_model, d_model*3)
        self.c_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.size()
        x_qkv = self.ln(x)
        q, k, v = x_qkv.split(self.d_model, dim=2)
        q = q.view(B, T, self.num_heads, C//self.num_heads).transpose(1, 2)
        k = k.view(B, T, self.num_heads, C//self.num_heads).transpose(1, 2)
        v = v.view(B, T, self.num_heads, C//self.num_heads).transpose(1, 2)
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_pdrop if self.training else 0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.resid_dropout(y)
        return y

这个模块接收一个输入数据,大小为(batch_size, seq_len, dimension),然后进行一个线性变换层,把数据映射为(batch_size, seq_len, dimension*3)的维度,这里的dimension*3表示的是qkv这三个值的拼接。接着就把这个数据切分为q,k,v三份,然后每份都把维度调整为(batch_size, seq_len, num_head, dimension/num_head),num_head表示这个自注意力模块包含多少个head。最后就可以调用scaled_dot_product_attention进行qk的相似度计算,进行缩放之后与v值相乘。Pytorch的这个函数提供了最新的flash attention的实现,可以大幅提升计算性能。最后就是对qkv的结果进行一个线性变换,映射为一个(batch_size, seq_len, dimension)的向量。

自注意力模块的输出结果,将通过一个Feed forward层进行计算,代码如下:

class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.ln1 = nn.Linear(d_model, dff)
        self.ln2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layernorm = nn.LayerNorm(d_model)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.ln1(x)
        x = self.gelu(x)
        x = self.ln2(x)
        x = self.dropout(x)
        return x

代码很简单,就是做了两次线性变换,第一次把维度扩充到dimension*4,第二次把维度恢复为dimension。

最后定义一个decoder block模块,把多头注意力模块和feed forward模块组合起来,代码如下:

class Block(nn.Module):
    def __init__(self, d_model, num_heads, dff, attn_pdrop, resid_pdrop, dropout):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(d_model)
        self.attn = MHA(d_model, num_heads, attn_pdrop, resid_pdrop)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, dff, dropout)

    def forward(self, x):
        x = x + self.attn(self.layernorm1(x))
        x = x + self.ff(self.layernorm2(x))
        return x

有了decoder block之后,GPT 2的模型就是把这些block串起来,例如最小的GPT 2的模型结构是定义了12个decoder block。模型接收的是字符序列经过tokenizer之后的数字,然后把这些数字通过embedding层映射为向量表达,例如对每个token id,映射为784维度的一个向量。为了能在embedding的向量里面反映字符的位置信息,我们需要把字符的位置也做一个embedding,然后两个embedding相加。

输入数据经过embedding处理后,通过多个decoder block处理之后,数据的维度为(batch_size, seq_len, dimension), 我们需要通过一个权重维度为(dimension, vocab_size)的线性变换,把数据映射为(batch_size, seq_len, vocab_size)的维度。这里vocab_size表示tokenizer的单词表的长度,例如对于GPT 2所用的tokenizer,有50257个单词。对于输出数据进行softmax计算之后,我们就可以得到每个token的预测概率,从而可以和label数据,即真实的下一个token id进行比较,计算loss值。

GPT 2模型的代码如下:

class GPT2(nn.Module):
    def __init__(self, vocab_size, d_model, block_size, embed_pdrop, num_heads, dff, attn_pdrop, resid_pdrop, dropout, num_layer):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model, sparse=False)
        self.pos_embed = nn.Embedding(block_size, d_model, sparse=False)
        self.dropout_embed = nn.Dropout(embed_pdrop)
        #self.blocks = [Block(d_model, num_heads, dff, attn_pdrop, resid_pdrop, dropout) for _ in range(num_layer)]
        self.blocks = nn.ModuleList([Block(d_model, num_heads, dff, attn_pdrop, resid_pdrop, dropout) for _ in range(num_layer)])
        self.num_layer = num_layer
        self.block_size = block_size
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.token_embed.weight = self.lm_head.weight
        self.layernorm = nn.LayerNorm(d_model)

        self.apply(self._init_weights)

        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, targets=None):
        device = x.device
        b, t = x.size()
        pos = torch.arange(0, t, dtype=torch.long, device=device) 
        x = self.token_embed(x) + self.pos_embed(pos)
        x = self.dropout_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.layernorm(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            logits = self.lm_head(x[:, -1, :])
            loss = None

        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, block_size=512):
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

模型训练

定义好模型之后,我们就可以开始训练了。

首先我们需要准备训练数据集。GPT 2采用的是webtext,网上的一些公开网页数据来进行训练。在Huggingface上面有对应的一个公开数据集。不过考虑到我们的资源有限,我这次还是采用GPT 1所用的bookcorpus数据集来训练。

以下代码是下载huggingface的数据集,并用GPT 2的tokenizer来进行编码:

from datasets import load_dataset
from transformers import GPT2Tokenizer

dataset = load_dataset("bookcorpusopen", split="train")

block_size=513
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def tokenize_function(examples):
    token_ids = [tokenizer(text) for text in examples["text"]]
    total_length = [len(t["input_ids"]) for t in token_ids]
    total_length = [(l//(block_size+1))*(block_size+1) for l in total_length]
    result = []
    label = []
 
    for i in range(len(total_length)):
        result.extend([token_ids[i]["input_ids"][j:j+block_size+1] for j in range(0, total_length[i], block_size+1)])
    return {"token_ids": result}
 
ds_test = ds['train'].select(range(10000))
 
tokenized_datasets = ds_test.map(
    tokenize_function, batched=True, num_proc=8, remove_columns=["title", "text"], batch_size=100
)
 
tokenized_datasets.save_to_disk("data/boocorpusopen_10000_512tokens")

GPT1采用的bookcorpus有7000多本书,huggingface的bookcorpusopen数据集有14000多本,这里我只采用了10000本书来构建数据集,对于每本书进行tokenizer转化后,每513个token写入为1条记录。这样我们在训练时,每条记录我们采用前1-512个token作为训练,取2-513个token作为label。

以下代码将读取我们处理好的数据集,并转化为pytorch的dataloader

from datasets import load_from_disk

dataset = load_from_disk("data/boocorpusopen_10000_512tokens")
dataset = dataset.with_format("torch")
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

然后我们就可以实例化一个GPT 2的model并开始训练,具体的代码可以见repo https://github.com/gzroy/gpt2_torch.git 里面的train.py文件。

如果在本地显卡上训练,对应12层的网络结构需要30多G的显存,我的显卡是2080Ti,只有11G显存,因此只能指定6层decoder。我们可以在autodl上面租用一个40G显存的A100显卡,价格是3.45元每小时,在这个显卡上开启半精度进行训练,大约1个小时可以跑10000个迭代,batch大小为64。我总共训练了2小时,最终在训练集上的Loss值为3.5左右,准确度为35%,花费为7元。

生成文本

最后我们可以基于这个训练了1个小时的GPT 2模型来测试一下,看生成文本的效果如何,如以下代码:

from transformers import GPT2Tokenizer
from model import GPT2
import torch
from torch.nn import functional as F
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='gpt2 predict')
    parser.add_argument('--checkpoint_path', type=str, default='checkpoints/')
    parser.add_argument('--checkpoint_name', type=str, default='')
    parser.add_argument('--d_model', type=int, default=768)
    parser.add_argument('--block_size', type=int, default=512)
    parser.add_argument('--dff', type=int, default=768*4)
    parser.add_argument('--heads', type=int, default=12)
    parser.add_argument('--decoder_layers', type=int, default=6)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--input', type=str)
    parser.add_argument('--generate_len', type=int, default=100)
    parser.add_argument('--topk', type=int, default=5)
    args = parser.parse_args()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    vocab_size = len(tokenizer.get_vocab())
    model = GPT2(vocab_size, args.d_model, args.block_size, 0, args.heads, args.dff, 0, 0, 0, args.decoder_layers)
    model.to(args.device)
    model = torch.compile(model)
    checkpoint = torch.load(args.checkpoint_path+args.checkpoint_name)
    model.load_state_dict(checkpoint['model_state_dict'])

    token_id = tokenizer.encode(args.input)
    input_data = torch.reshape(torch.tensor(token_id, device=args.device), [1,-1])
    predicted = model.generate(input_data, args.generate_len, 1.0, args.topk, args.block_size)
    print("Generated text:\n-------------------")
    print(tokenizer.decode(predicted.cpu().numpy()[0]))

运行以下命令,给定一个文本的开头,然后让模型生成200字看看:

python predict.py --checkpoint_name model_1.pt --input 'it was saturday night, the street' --generate_len 200 --topk 10

生成的文本如下:

it was saturday night, the street lights blared and the street lights flickered on. A few more houses were visible.

The front door opened, and a large man stepped in and handed him one. He handed the man the keys and a small smile. It looked familiar, and then a little too familiar. The door was closed.

"Hey! You guys out there?" he said, his eyes wide.

"What are you up to?" the man asked.

"I'm just asking for you out in my office."

The man was about thirty feet away from them.

"I'm in a serious situation, but it's just the way you are."

He looked around at the man, the man looked up and down, and then his eyes met hers. He was a little older than he was, but his eyes were blue with red blood. He looked like a giant. His eyes were blue and red, and his jaw looked like a giant

可见生成的文本语法没有问题,内容上也比较连贯,上下文的逻辑也有关联。如果模型继续训练更长时间,相信生成文本的内容会更加好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/820298.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

森林生物量(蓄积量)数据处理到随机森科估算全流程

python森林生物量&#xff08;蓄积量&#xff09;估算全流程 一.哨兵2号获取/处理/提取数据1.1 影像处理与下载采用云概率影像去云采用6S模型对1C级产品进行大气校正geemap下载数据到本地NDVI 1.2 各种参数计算&#xff08;生物物理变量、植被指数等&#xff09;LAI&#xff1a…

程序员面试金典17.*

文章目录 17.01 不用加号的加法17.04 消失的数字17.05字母与数字17.06 2出现的次数17.07 婴儿名字17.08 马戏团人塔17.09 第k个数17.10 主要元素17.11 单词距离17.12 BiNode17.13 恢复空格&#xff08;未做&#xff0c;字典树dp&#xff09;17.14 最小K个数17.15 最长单词17.16…

TIA Portal(博途)V15.0 安装教程

哈喽&#xff0c;大家好&#xff0c;我是雷工。 最近项目上用到博图15.0软件&#xff0c;在虚拟机安装博图软件。下面记录安装过程。 一、安装环境 虚拟机内的Win10系统专业版64位。 二、注意事项 1、安装文件的存放路径不能含中文字符&#xff0c;软件需安装在C盘。 2、操…

uniapp实现地图点聚合

点聚合的最重要的一个地方是在 markers 中添加 joinCluster true 这个重要的属性&#xff0c;否则将无法开启点聚合功能。 其实在uniapp的官方文档里体现的不是那么清楚&#xff0c;但是在小程序文档提示的就相当清楚。 实现效果如下&#xff1a; 重点&#xff1a;需要编译在小…

PySpark介绍与安装

Spark是什么 定义&#xff1a;Apache Spark是用于大规模数据&#xff08;large-scala data&#xff09;处理的统一&#xff08;unified&#xff09;分析引擎。 简单来说&#xff0c;Spark是一款分布式的计算框架&#xff0c;用于调度成百上千的服务器集群&#xff0c;计算TB、…

免费商城搭建之java版直播商城平台规划及常见的营销模式+电商源码+小程序+三级分销+二次开发

&#xfeff; 1. 涉及平台 平台管理、商家端&#xff08;PC端、手机端&#xff09;、买家平台&#xff08;H5/公众号、小程序、APP端&#xff08;IOS/Android&#xff09;、微服务平台&#xff08;业务服务&#xff09; 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、R…

Linux 入侵痕迹清理技巧(仅限学习安全知识)

vim ~/.bash_history 查看历史操作命令&#xff1a;history history记录文件&#xff1a;more ~/.bash_history history -c #使用vim打开一个文件 vi test.txt # 设置vim不记录命令&#xff0c;Vim会将命令历史记录&#xff0c;保存在viminfo文件中。 :set history0 # 用vim的…

Qt之qml和widget混合编程调用

首先是创建一个widget项目 然后需要添加qml和quick的插件使用 QT quickwidgets qml 接着要在界面上创建一个quickwidget和按钮 创建一个c对象类 QObjectQml #ifndef QOBJECTQML_H #define QOBJECTQML_H#include <QObject> #include <QDebug> class QObjectQml …

如何去推动自己团队所提出的需求

自己团队所提出的需求是指性能优化、技术栈升级、架构调整等需求&#xff0c;偏向于技术范畴。 要推动这类需求&#xff0c;除了自己团队的努力之外&#xff0c;还需要一些外在的辅助因素。 一、时机 对于我们自己团队内部就能消化的需求&#xff0c;主要的问题就是人员&#x…

upload-labs详解------持续更新

目录 注&#xff1a; 搭建&#xff1a; pass-01&#xff08;前端绕过&#xff09; pass-02&#xff08;后缀绕过&#xff09; pass-03&#xff08;黑名单绕过&#xff09; pass-04&#xff08;Apache解析漏洞\.htaccess文件绕过&#xff09; 注&#xff1a; 本项目提供的…

稍微深度踩坑haystack + whoosh + jieba

说到django的全文检索&#xff0c;网上基本推荐的都是 haystack whoosh jieba 的方案。 由于我的需求对搜索时间敏感度较低&#xff0c;但是要求不能有数据的错漏。 但是没有调试的情况下&#xff0c;搜索质量真的很差&#xff0c;搞得我都想直接用Like搜索数据库算了。 但是…

item_search-ks-根据关键词取商品列表

一、接口参数说明&#xff1a; item_search-根据关键词取商品列表&#xff0c;点击更多API调试&#xff0c;请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/ks/item_search 名称类型必须描述keyString是调用key&#xff08;http:…

一文快速入门Byzer-python

目录 一、Byzer-Python介绍 二、Byzer-python工具语法糖 三、环境依赖 1. Python 环境搭建 2. Ray 环境搭建 3. Byzer-python 与 Ray 四、参数详解 五、数据处理 1. Byzer-python 处理数据 2. Byzer-python 代码说明 3. Byzer-python 读写 Excel 文件 4. Byzer-pytho…

FPGA及其应用

目录 1.什么是FPGA 2.FPGA的硬件结构 3.FPGA与单片机的区别 4.FPGA的具体应用场景 1.什么是FPGA FPGA&#xff08;Field-Programmable Gate Array&#xff09;是一种可编程逻辑器件&#xff0c;它由大量的可编程逻辑单元&#xff08;CLB&#xff09;和可编程连线&#xff08…

解决el-table打印时数据重复显示

1.表格数据比较多加了横向滚动和竖向滚动&#xff0c;导致打印出问题 主要原因是fixed导致&#xff0c;但是又必须得滚动和打印 方法如下&#xff1a; 1. 2. is_fixed: true,//data中定义初始值 3.打印时设置为false,记得要改回true if (key 2) { this.is_fixed false //打…

Image process ----butterworth high pass 滤波器

import numpy as np import matplotlib.pyplot as plt import cv2def Butterworth_Filter_Image():img cv2.imread(r/Users/PycharmProjects/ImageProcess/Butterworth Filter Image/Pasted Graphic 31.png,0)plt.imshow(img)# ———————————————————————…

Sublime操作技巧笔记

同时选中2个文件&#xff1a;自动切换成左右2个界面 格式化代码ctrlshifth&#xff1a; 使用快捷键ctrl shift p调出控制台&#xff0c;输入install package&#xff0c;然后输入html-css-js prettify&#xff0c;进行下载。具体的快捷键在preference > package setting &g…

P1542 包裹快递 (二分答案)(内附封面)

包裹快递 题目描述 小 K 成功地破解了密文。但是乘车到 X 国的时候&#xff0c;发现钱包被偷了&#xff0c;于是无奈之下只好作快递员来攒足路费去 Orz 教主…… 一个快递公司要将 n n n 个包裹分别送到 n n n 个地方&#xff0c;并分配给邮递员小 K 一个事先设定好的路线…

PoseiSwap:首个基于模块化设施构建的订单簿 DEX

在前不久&#xff0c;PoseiSwap 曾以1000万美元的估值&#xff0c;获得了来自于ZebecLabs基金会的150万美元的融资。此后 PoseiSwap 又以2500万美元的估值&#xff0c;从GateLabs、EmurgoVentures、Republic以及CipholioVentures等行业顶级投资机构中&#xff0c;获得了新一轮未…

QMessageBox类

QMessageBox类 静态方法例子 静态方法 调用这一些静态成员函数&#xff0c;就可以得到模态提示框 枚举值为&#xff1a; 例子 头文件&#xff1a; #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QMessageBox>QT_BEGIN_NAMESPACE…