自己动手做一个mini-chatgpt

news2024/11/18 10:24:33

开场

最近chatgpt已经火爆了,几乎是家喻户晓老少皆知啊,公测推出60天后就已经是UV人数过亿,日访问量号称也是过亿。投资chatgpt研发团队的微软也是2个月内迅速推出自己的chatgpt的bing搜索,股票下载量都是暴增啊。前面文章已经介绍过chatgpt技术可能会对整个人类组织分工带来的影响以及原因,这里就不在继续歪歪了。

chatgpt的一些思考

从这篇文章开始,我打算实现一个mini版本的chatgpt,把背后的原理算法、数据准备工作都会介绍到。这系列文章预计会有7-8篇,主要是讲实现,不会介绍transformer模型技术细节、ppo数学推理。

到最后大家可以收获一个问答式的文本生成工具,大家也可以根据自己需要定制训练自己的模型做自己想要做的事,比如一个跟懂自己智能助理、解读论文的神器、可以通过语音方式理解需求帮你控制智能家居、通过语音帮你画一幅你想要的画...

第一篇先介绍整个RLHF大训练框架,介绍SFT模型训练:数据、基本模型。先介绍单个模型大家先熟悉代码在自己机器上试跑训练下数据。

第二部分会对模型改造、代码封装,让代码能够在多卡多机上训练;更工业风。

第三部分把流程封装,三部分的代码做一个整合,到这边你就可以得到一个真正能够训练中文语料的链路框架,并且可以自己准备训练标注语料。

第四部分会给大家介绍基于这个小的chatgpt引擎做的各种应用探索。

宏观介绍

整个链路包括三块:

  1. 文本生成AGGENT,为了得到一个不错Agent我们需要用‘输入-输出’语料对训练一个不错基准模型,把这个过程叫做sft

  1. 评判文本生成好坏的Reward,为了得到Reward模型我们需要用‘输入-输出list’语料做一个排序打分模型,把这个过程叫做Reward

  1. 利用Reward反馈调试Agent模型PPO调控器

fig1.sft训练过程

fig2.reward训练过程

Rank数据打标

SFT实现

先训练一个基本的有文本生成能力的模型,可以选用GPT或者T5框架模型来做训练。

from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-lyric")
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-lyric")
text_generator = TextGenerationPipeline(model, tokenizer)   
text_generator("最美的不是下雨天,是曾与你躲过雨的屋檐", max_length=100, do_sample=True)

GPT2

数据样式:

{"id": 0, "article": [12, 43, 27912, 12, 8100, 532, 21095, 33, 12, 1377, 7214, 4621, 286, 262, 890, 5041, 351, 257, 474, 5978, 284, 534, 17627, 764, 775, 1965, 1312, 6207, 3816, 284, 2648, 5205, 286, 511, 4004, 7505, 3952, 5636, 2171, 764], "abstract": [9787, 503, 8100, 13, 785, 7183, 705, 7505, 3952, 5205, 764, 1471, 19550, 287, 319, 262, 995, 705, 82, 27627, 6386, 1660, 19392, 764]}
#这部分代码拷贝命名'dataset.py'
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset

from utils import add_special_tokens


class GPT21024Dataset(Dataset):

    def __init__(self, root_dir, ids_file, mode='train',length=None):
        self.root_dir = root_dir
        self.tokenizer = add_special_tokens()

        # with open(ids_file,'r') as f:
            # if mode=='train':
            #     self.idxs = np.array(json.load(f)['train_ids'])
            # elif mode=='valid':
            #     self.idxs = np.array(json.load(f)['valid_ids'])
            # elif mode=='test':
            #     self.idxs = np.array(json.load(f)['test_ids'])

            # self.idxs = self.idxs -min(self.idxs)
        
        self.idxs = os.listdir(root_dir)
        self.mode = mode
        if len == None:
            self.len = len(self.idxs)
        else:
            self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self,idx):

        if self.mode=='valid':
            idx = self.idxs[-idx]
        elif self.mode=='test':
            idx = self.idxs[-idx-self.len]   # assuming valid and test set of same sizes
        else:
            idx = self.idxs[idx]
        # file_name = os.path.join(self.root_dir,str(idx)+".json")
        file_name = os.path.join(self.root_dir,str(idx))
        with open(file_name,'r') as f:
              data = json.load(f)
        text = self.tokenizer.encode(self.tokenizer.pad_token)*1024
        content = data['article'] + self.tokenizer.encode(self.tokenizer.sep_token) + data['abstract']
        text[:len(content)] = content
        text = torch.tensor(text)
        sample = {'article': text, 'sum_idx': len(data['article'])}
        return sample
#训练部分代码
import argparse
from datetime import datetime
import os
import time

import numpy as np
from transformers import GPT2LMHeadModel,AdamW, WarmupLinearSchedule
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tnrange, tqdm_notebook

from dataset import GPT21024Dataset 
from utils import add_special_tokens, generate_sample, set_seed

#please change default arguments if needed

parser = argparse.ArgumentParser()
parser.add_argument("--lr",default=5e-5, type=float, help="learning rate")
parser.add_argument("--seed",default=42, type=int,  help="seed to replicate results")
parser.add_argument("--n_gpu",default=1, type=int,  help="no of gpu available")
parser.add_argument("--gradient_accumulation_steps",default=2, type=int, help="gradient_accumulation_steps")
parser.add_argument("--batch_size",default=1, type=int,  help="batch_size")
parser.add_argument("--num_workers",default=4, type=int,  help="num of cpus available")
parser.add_argument("--device",default=torch.device('cpu'), help="torch.device object")
parser.add_argument("--num_train_epochs",default=1, type=int,  help="no of epochs of training")
parser.add_argument("--output_dir",default='./output', type=str,  help="path to save evaluation results")
parser.add_argument("--model_dir",default='./weights', type=str,  help="path to save trained model")
parser.add_argument("--max_grad_norm",default=1.0, type=float, help="max gradient norm.")
parser.add_argument("--root_dir",default='./CNN/gpt2_1024_data', type=str, help="location of json dataset.")
parser.add_argument("--ids_file",default='./CNN/ids.json', type=str, help="location of train, valid and test file indexes")
args = parser.parse_args([])
print(args)

def train(args, model, tokenizer, train_dataset, valid_dataset, ignore_index):
    """ Trains GPT2 model and logs necessary details.
        Args:
            args: dict that contains all the necessary information passed by user while training
            model: finetuned gpt/gpt2 model
            tokenizer: GPT/GPT2 tokenizer
            train_dataset: GPT21024Dataset object for training data
            ignore_index: token not considered in loss calculation
    """
    writer = SummaryWriter('./output/logs')
    train_sampler = RandomSampler(train_dataset)
    train_dl = DataLoader(train_dataset,sampler=train_sampler,batch_size=args.batch_size,num_workers=args.num_workers)
    loss_fct = CrossEntropyLoss(ignore_index=ignore_index) #ignores padding token for loss calculation
    optimizer = AdamW(model.parameters(),lr=args.lr)
    scheduler = WarmupLinearSchedule(optimizer,100,80000)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = tnrange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)
    for _ in train_iterator:
        epoch_iterator = tqdm_notebook(train_dl, desc="Training")
        for step, batch in enumerate(epoch_iterator):
            inputs, labels = batch['article'].to(args.device), batch['article'].to(args.device)
            model.train()
            logits = model(inputs)[0]
            # only consider loss on reference summary just like seq2seq models
            shift_logits = logits[..., batch['sum_idx']:-1, :].contiguous()
            shift_labels = labels[..., batch['sum_idx']+1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss/args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                writer.add_scalar('loss', (tr_loss - logging_loss)/args.gradient_accumulation_steps, global_step)
                logging_loss = tr_loss
                print("loss:", loss.item(), end='\n\n')
                if (step + 1)/args.gradient_accumulation_steps == 1.0:
                    print('After 1st update: ', end='\n\n')
                    generate_sample(valid_dataset, tokenizer, model, num=2, eval_step=False,device=args.device)
                
                
            if (step + 1) % (10*args.gradient_accumulation_steps) == 0:
                results = evaluate(args, model, valid_dataset, ignore_index, global_step)
                for key, value in results.items():
                    writer.add_scalar('eval_{}'.format(key), value, global_step)
                print('After', global_step+1,'updates: ', end='\n\n')
                generate_sample(valid_dataset, tokenizer, num=2, eval_step=True,device=args.device)

# creating training and validation dataset object

train_data = GPT21024Dataset(args.root_dir,args.ids_file,mode='train',length=3000) #training on only 3000 datasets
valid_data = GPT21024Dataset(args.root_dir,args.ids_file,mode='valid',length=500)  #validation on only 500 datasets

# load pretrained GPT2
tokenizer = add_special_tokens()
ignore_idx = tokenizer.pad_token_id
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))
model.to(args.device)

#training the model

start = time.time()
train(args, model, tokenizer, train_data, valid_data, ignore_idx)
print('total time: ', (time.time()-start)/60, " minutes", end='\n\n')

print('Saving trained model...')
model_file = os.path.join(args.model_dir, 'model_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.bin'.format(len(train_data),args.num_train_epochs))
config_file = os.path.join(args.model_dir, 'config_data{}_trained_after_{}_epochs_only_sum_loss_ignr_pad.json'.format(len(train_data),args.num_train_epochs))
torch.save(model.state_dict(), model_file)
model.config.to_json_file(config_file)

这部分代码,我同步会整理到我的github整理完会把链接发上来。

T5

下次迭代更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/338642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络自定向下 -- 浅谈可靠性之rdt协议

可靠性数据传输原理 可靠指数据在传输过程中不错,不丢,不乱 运输层要为应用层提供一种服务:数据可以通过一条可靠的信道进行传输,在该信道中传输的数据不会受到损坏或者丢失, 实现这种服务的是可靠数据传输协议。 要实现这种服…

python的opencv操作记录11——阈值分割

文章目录传统图像处理分割阈值分割一个应用场景opencv库中的阈值分割固定阈值THRESH_OTSU 大津法阈值自适应阈值传统图像处理分割 现在提到图像分割,很多人会直接想到当前火爆的深度学习的各种分割网络,比如实例分割,语义分割等。其实在传统…

Java爬虫Selenium+Java+ChromeDriver

一、爬虫工具 selenium 是一个模拟浏览器操作的工具,背后有google 维护源代码,支持全部主流浏览器,支持主流的编程语言,包括:java,Python,C#,PHP,Ruby,等,在本项目上使用的Java语言。 官网:https://www.sel…

【Vue】参数传递:如何同时传递 DTO 和 file 文件

日常开发时,如果遇到较为复杂的 DTO 对象的参数传递时,通常前端使用的请求头为:application/json(JSON 格式的数据);而后端可以使用 RequestBody 接收一个 DTO 对象。但当需要在一个界面上同时传递 DTO 和附…

Java集合中的Map

MapMap接口键 值 对存储键不能重复&#xff0c;值可以重复Map三个实现类的存储结构HashMap&#xff1a;Hash表链表红黑树结构 线程不安全TreeMap&#xff1a; 底层红黑树实现HashTable&#xff1a;hash表链表红黑树 线程安全HashMapHashMap常用方法HashMap<String,String>…

[测开篇]设计测试用例的方法如何正确描述Bug

​ 文章目录为什么测试人员要写测试用例&#xff1f;怎样设计测试用例&#xff1f;&#xff08;总的方面&#xff09;1.基于需求设计测试用例&#xff08;总的方面&#xff09; 2.页面&#xff08;总的方面&#xff09; 3.非功能性测试&#xff08;具体方面&#xff09; 4.1 等…

「Python|环境安装|Windows」如何在Windows上安装Python环境?

本文主要介绍如何在Windows上安装Python&#xff0c;帮助初学者或者非程序员伙伴快速搭建可以运行python代码的环境。 文章目录安装python做一点小配置验证python如何安装指定版本的python编程语言的环境搭建一直是学习编程的第一道门槛。 对于如何在Linux系统上安装指定版本的…

谷歌蜘蛛池怎么搭建?Google蜘蛛池可以帮助谷歌排名吗?

本文主要分享关于谷歌蜘蛛池的搭建疑问&#xff0c;以及Google对谷歌排名的影响到底有多大。 本文由光算创作&#xff0c;有可能会被剽窃和修改&#xff0c;我们佛系对待这种行为吧。 谷歌蜘蛛池怎么搭建&#xff1f; 答案是&#xff1a;需要一个内链外链体系复杂的站群系统…

154、【动态规划】leetcode ——494. 目标和:回溯法+动态规划(C++版本)

题目描述 原题链接&#xff1a;494. 目标和 解题思路 &#xff08;1&#xff09;回溯法 本题的特点是nums中每个元素只能使用一次&#xff0c;分别试探加上nums[index]和减去nums[index]&#xff0c;然后递归的遍历下一个元素index 1。 class Solution { public:int res …

java中flatMap用法

java中map是把集合每个元素重新映射&#xff0c;元素个数不变&#xff0c;但是元素值发生了变化。而flatMap从字面上来说是压平这个映射&#xff0c;实际作用就是将每个元素进行一个一对多的拆分&#xff0c;细分成更小的单元&#xff0c;返回一个新的Stream流&#xff0c;新的…

1629_MIT_6.828_xv6_chapter1操作系统的组织

全部学习汇总&#xff1a;GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 这一次整理一下操作系统组织相关的知识&#xff0c;主要还是xv6教学操作系统相关的知识。当然&#xff0c;很多知识在这类技术领域是通用的。 1. 操作系统的主要功能…

SAP ABAP Odata

GetEntity和GetEntitys GetEntitys 创建Odata Project 导入结构 选择需要的字段 设定Key 勾选字段的creatable、updatable、sortable、nullable、filterable属性值。 再依上述步骤创建ZPOITEM结构和实体集 3. 创建ZPOHEADER和ZPOITEM的Association 两个实体集的关联字段&…

RocketMQ-消息消费模式 顺序消费

RocketMQ-消息消费模式 顺序消费RocketMQ-消息消费模式集群模式集群模式的演示(本身就默认)Rocketmq存储队列广播模式顺序消费如何改实现顺序消费RocketMQ-消息消费模式 集群模式 在消费模式为集群的情况下,如果机器是集群的,消息只会给集群中的其中一台机器消费到 集群模…

【数据结构】双向链表的模拟实现(无头)

目录 前言&#xff1a; 1、认识双向链表中的结点 2、认识并创建无头双向链表 3、实现双向链表当中的一些方法 3.1、遍历输出方法&#xff08;display&#xff09; 3.2、得到链表的长度&#xff08;size&#xff09; 3.3、查找关键字key是否包含在双链表中(contains) 3.…

基于I2S通讯MAX98357模块的JetsonNano声音外放

前言有很多方法可以为 Jetson 设备添加音频功能。USB 扬声器和USB 麦克风是一种简单的解决方案&#xff0c;但它们确实占用了宝贵的 USB 插槽&#xff0c;这些插槽可能更适合用于键盘、蓝牙功能、Internet Keys 和其他配件。在 Jetson 设备上&#xff0c;NVIDIA 通过 40 针 GPI…

微软发布会精华回顾:“台式电脑”抢了风头

Lightbot北京时间2016年10月26日晚10点&#xff0c;微软在纽约发布了名为 Surface Studio 的一体机、名为 Surface Dial 的配件以及外观未变的顶配版 Surface Book。同时&#xff0c;微软宣布了 Windows 10 下一个重要版本——“Creators Update”的数项新功能&#xff0c;包括…

【Linux】冯诺依曼体系结构和操作系统概念

文章目录&#x1f3aa; 冯诺依曼体系结构&#x1f680;1.体系概述&#x1f680;2.CPU和内存的数据交换&#x1f680;3.体系结构中数据的流动&#x1f3aa; 操作系统概念理解&#x1f680;1.简述&#x1f680;2.设计目的&#x1f680;3.定位&#x1f680;4.理解&#x1f680;5.管…

AOP面向切面编程思想。

目录 一、AOP工作流程 1、基本概念 2、AOP工作流程 二、AOP核心配置 1、AOP切入点表达式 2、AOP通知类型 三、AOP通知获取数据 1、获取参数 2、获取返回值 3、获取异常 四、AOP事务管理 1、Spring事务简介 2、Spring事务角色 3、事务属性 一、AOP工作流程 1、…

Linux内核启动(理论,0.11版本)分段与分页

为什么要虚拟内存 我们知道&#xff0c;在之前上微机原理时&#xff0c;我们的程序是可以直接访问内存的&#xff0c;而且访问的是直接的物理内存&#xff0c;在实模式下&#xff0c;寄存器是16位的&#xff0c;数组总线&#xff08;data bus&#xff09;是16位的&#xff0c;…

设计模式-值类型与引用类型、深拷贝与浅拷贝、原型模式详解

一. 值类型和引用类型 1. 前言 (1). 分类 值类型包括&#xff1a;布尔类型、浮点类型(float、double、decimal、byte)、字符类型(char)、整型&#xff08;int、long、short等&#xff09;、枚举(entum)、结构体(struct)。 引用类型&#xff1a;数组、字符串(string)、类、接口…