精进语言模型:探索LLM Training微调与奖励模型技术的新途径

news2024/10/1 19:34:31

大语言模型训练(LLM Training)

LLMs Trainer 是一个旨在帮助人们从零开始训练大模型的仓库,该仓库最早参考自 Open-Llama,并在其基础上进行扩充。

有关 LLM 训练流程的更多细节可以参考 【LLM】从零开始训练大模型。

使用仓库之前,请先安装所有需要的依赖:

pip install -r requirements.txt

1. 继续预训练(Continue Pretraining)

继续预训练是指,在一个已有的模型上继续进行预训练增强,通常用于 英文模型的中文增强 或是 领域数据增强

我们这里以英文模型 OpenLlama 在中文数据集 MNBVC 中的 少量数据 为例来演示整个流程。

1.1 数据压缩

由于预训练数据集通常比较庞大,因此先将训练数据进行压缩并流氏读取。

首先,进入到 data 目录:

cd data

找到目录下的 compress_data.py, 在该文件中修改需要压缩的数据路径:

SHARD_SIZE = 10      # 单个文件存放样本的数量, 示例中使用很小,真实训练可以酌情增大
...

def batch_compress_preatrain_data():
    """
    批量压缩预训练数据。
    """
    source_path = 'shuffled_data/pretrain'                  # 源数据文件
    target_path = 'pretrain_data'                           # 压缩后存放地址

    files = [                                               # 这三个文件是示例数据
        'MNBVC_news',
        'MNBVC_qa',
        'MNBVC_wiki'
    ]
    ...

if __name__ == '__main__':
    batch_compress_preatrain_data()
    # batch_compress_sft_data()

Notes: 上述的 files 可以在 shuffled_data/pretrain/ 中找到,是我们准备的少量示例数据,真实训练中请替换为完整数据。

data 路径中执行 python compress_data.py, 终端将显示:

processed shuffled_data/pretrain/MNBVC_news.jsonl...
total line: 100
total files: 10
processed shuffled_data/pretrain/MNBVC_qa.jsonl...
total line: 50
total files: 5
processed shuffled_data/pretrain/MNBVC_wiki.jsonl...
total line: 100
total files: 10

随后可在 pretrain_data 中找到对应的 .jsonl.zst 压缩文件(该路径将在之后的训练中使用)。

1.2 数据源采样比例(可选)

为了更好的进行不同数据源的采样,我们提供了按照预设比例进行数据采样的功能。

我们提供了一个可视化工具用于调整不同数据源之间的分布,在 根目录 下使用以下命令启动:

streamlit run utils/sampler_viewer/web.py --server.port 8001

随后在浏览器中访问 机器IP:8001 即可打开平台。

我们查看 data/shuffled_data/pretrain 下各数据的原始文件大小:

-rw-r--r--@ 1 xx  staff   253K Aug  2 16:38 MNBVC_news.jsonl
-rw-r--r--@ 1 xx  staff   121K Aug  2 16:38 MNBVC_qa.jsonl
-rw-r--r--@ 1 xx  staff   130K Aug  2 16:37 MNBVC_wiki.jsonl

并将文件大小按照格式贴到平台中:

调整完毕后,复制上图右下角的最终比例,便于后续训练使用。

1.3 词表扩充(可选)

由于原始 Llama 的中文 token 很少,因此我们可以选择对原有的 tokenizer 进行词表扩充。

进入到 utils 目录:

cd utils

修改文件 train_tokenizer.py 中的训练数据(我们使用正式预训练训练数据集作为训练词表的数据集):

...
dataset = {
    "MNBVC_news": "../data/pretrain_data/MNBVC_news/*.jsonl.zst",
    "MNBVC_qa": "../data/pretrain_data/MNBVC_qa/*.jsonl.zst",
    "MNBVC_wiki": "../data/pretrain_data/MNBVC_wiki/*.jsonl.zst",
}

执行完 train_tokenizer.py 后,路径下会出现训练好的模型 test_tokenizer.model

随后,我们将训练好的 model 和原本的 llama model 做融合:

python merge_tokenizer.py

你可以使用 这个工具 很方便的对合并好后的 tokenizer 进行可视化。

1.4 平均初始化 extend token embedding(可选)

为了减小扩展的 token embedding 随机初始化带来模型性能的影响,我们提供使用将新 token 在原 tokenizer 中的 sub-token embedding 的平均值做为初始化 embedding 的方法。

具体使用方法在 utils/extend_model_token_embeddings.py

1.5 正式训练

当完成上述步骤后就可以开始正式进行训练,使用以下命令启动训练:

sh train_llms.sh configs/accelerate_configs/ds_stage1.yaml \
    configs/pretrain_configs/llama.yaml \
    openlm-research/open_llama_7b_v2

多机多卡则启动:

sh train_multi_node_reward_model.sh configs/accelerate_configs/ds_stage1.yaml \
    configs/pretrain_configs/llama.yaml \
    openlm-research/open_llama_7b_v2

注意,所有的训练配置都放在了第 2 个参数 configs/pretrain_configs/llama.yaml 中,我们挑几个重要的参数介绍。

  • tokenizer_path (str):tokenizer 加载路径。

  • ckpt (str):初始 model 加载路径。

  • sample_policy_file (str):数据源采样配置文件,若不包含这一项则不进行数据源采样。

  • train_and_eval (bool):该参数决定了是否在训练中执行评估函数。

  • img_log_dir (str):训练过程中的 log 图存放目录。

  • eval_methods (list):使用哪些评估函数,包括:

    • single_choice_eval: 单选题正确率测试(如: C-Eval),评估数据格式参考 eval_data/knowledge/knowledge_and_reasoning.jsonl

    • generation_eval: 生成测试,给定 prompt,测试模型生成能力,评估数据格式参考 eval_data/pretrain/generation_test.jsonl

  • work_dir (str):训练模型存放路径。

  • save_total_limit (int):最多保存的模型个数(超过数目则删除旧的模型)

2. 指令微调(Instruction Tuning)

我们准备了部分 ShareGPT 的数据作为示例数据,我们仍旧使用 OpenLlama 作为训练的基座模型。

2.1 数据压缩

同预训练一样,我们先进入到 data 目录:

cd data

找到目录下的 compress_data.py, 在该文件中修改需要压缩的数据路径:

SHARD_SIZE = 10      # 单个文件存放样本的数量, 示例中使用很小,真实训练可以酌情增大
...

def batch_compress_sft_data():
    """
    批量压缩SFT数据。
    """
    source_path = 'shuffled_data/sft'
    target_path = 'sft_data'

    files = [
        'sharegpt'
    ]
    ...

if __name__ == '__main__':
    # batch_compress_preatrain_data()
    batch_compress_sft_data()

Notes: 上述的 files 可以在 shuffled_data/sft/ 中找到,是我们准备的少量示例数据,真实训练中请替换为完整数据。

data 路径中执行 python compress_data.py, 终端将显示:

processed shuffled_data/sft/sharegpt.jsonl...
total line: 9637
total files: 964

随后可在 sft_data 中找到对应的 .jsonl.zst 压缩文件(该路径将在之后的训练中使用)。

2.2 特殊 token 扩充

受到 ChatML 的启发,我们需要在原有的 tokenizer 中添加一些 special token 用于对话系统。

一种最简单的方式是在 tokenizer 路径中找到 special_tokens_map.json 文件,并添加以下内容:

{
    ...                                         # 需要添加的特殊 token
    "system_token": "<|system|>",               # system prompt
    "user_token": "<|user|>",                   # user token
    "assistant_token": "<|assistant|>",         # chat-bot token
    "chat_end_token": "<|endofchat|>"           # chat end token
}

2.3 微调训练

当完成上述步骤后就可以开始正式进行训练,使用以下命令启动训练:

sh train_llms.sh configs/accelerate_configs/ds_stage1.yaml \
    configs/sft_configs/llama.yaml \
    openlm-research/open_llama_7b_v2

多机多卡则启动:

sh train_multi_node_reward_model.sh configs/accelerate_configs/ds_stage1.yaml \
    configs/sft_configs/llama.yaml \
    openlm-research/open_llama_7b_v2

注意,所有的训练配置都放在了第 2 个参数 configs/sft_configs/llama.yaml 中,我们挑几个重要的参数介绍。

  • tokenizer_path (str):tokenizer 加载路径。

  • ckpt (str):初始 model 加载路径。

  • train_and_eval (bool):该参数决定了是否在训练中执行评估函数。

  • img_log_dir (str):训练过程中的 log 图存放目录。

  • eval_methods (list):使用哪些评估函数,包括:

    • generation_eval: 生成测试,给定 prompt,测试模型生成能力,评估数据格式参考 eval_data/sft/share_gpt_test.jsonl

    • 暂无。

  • work_dir (str):训练模型存放路径。

  • save_total_limit (int):最多保存的模型个数(超过数目则删除旧的模型)

3. 奖励模型(Reward Model)

3.1 数据集准备

我们准备 1000 条偏序对作为示例训练数据,其中 selected 为优势样本,rejected 为劣势样本:

{
    "prompt": "下面是一条正面的评论:",
    "selected": "很好用,一瓶都用完了才来评价。",
    "rejected": "找了很久大小包装都没找到生产日期。上当了。"
}

这个步骤不再需要数据压缩,因此准备好上述结构的 .jsonl 文件即可。

3.2 RM 训练

当完成上述步骤后就可以开始正式进行训练,使用以下命令启动训练:

sh train_multi_node_reward_model.sh \
    configs/accelerate_configs/ds_stage1.yaml \
    configs/reward_model_configs/llama7b.yaml

注意,所有的训练配置都放在了第 2 个参数 configs/reward_model_configs/llama.yaml 中,我们挑几个重要的参数介绍。

  • tokenizer_path (str):tokenizer 加载路径。

  • ckpt (str):初始 model 加载路径。

  • train_and_eval (bool):该参数决定了是否在训练中执行评估函数。

  • img_log_dir (str):训练过程中的 log 图存放目录。

  • test_reward_model_acc_files (list):acc 测试文件列表。

  • work_dir (str):训练模型存放路径。

  • save_total_limit (int):最多保存的模型个数(超过数目则删除旧的模型)

项目链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/LLM/LLMsTrainer/readme.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/919500.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

36k字从Attention解读Transformer及其在Vision中的应用(pytorch版)

文章目录 0.卷积操作1.注意力1.1 注意力概述(Attention)1.1.1 Encoder-Decoder1.1.2 查询、键和值1.1.3 注意力汇聚: Nadaraya-Watson 核回归1.2 注意力评分函数1.2.1 加性注意力1.2.2 缩放点积注意力1.3 自注意力(Self-Attention)1.3.1 自注意力的定义和计算1.3.2 自注意…

DataFrame.query()--Pandas

1. 函数功能 Pandas 中的一个函数&#xff0c;用于在 DataFrame 中执行查询操作。这个方法会返回一个新的 DataFrame&#xff0c;其中包含符合查询条件的数据行。请注意&#xff0c;query 方法只能用于筛选行&#xff0c;而不能用于筛选列。 2. 函数语法 DataFrame.query(ex…

【OJ比赛日历】快周末了,不来一场比赛吗? #08.26-09.01 #16场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…&#xff09;比赛。本账号会推送最新的比赛消息&#xff0c;欢迎关注&#xff01; 以下信息仅供参考&#xff0c;以比赛官网为准 目录 2023-08-26&#xff08;周六&#xff09; #8场比赛2023-08-27…

redis在linux和windows上的安装配置(解决问题:没有可用软件包 redis)

linux系统 安装 yum install redis安装 在终端输入yum install redis安装。 报错&#xff1a;没有可用软件包 redis。 解决&#xff1a; 运行以下命令更新软件包信息&#xff1a; sudo yum clean all sudo yum update 然后继续尝试yum install redis。 如果还不成功&a…

探索最短路径问题:寻找优化路线的算法解决方案

1. 前言&#xff1a;最短路径问题的背景与重要性 在现实生活中&#xff0c;我们常常面临需要找到最短路径的情况&#xff0c;如地图导航、网络路由等。最短路径问题是一个关键的优化问题&#xff0c;涉及在图中寻找两个顶点之间的最短路径&#xff0c;以便在有限时间或资源内找…

最新AI系统ChatGPT程序源码+搭建部署教程/支持GPT4/支持ai绘画/H5端/完整知识库

一、AI系统 如何搭建部署AI创作ChatGPT系统呢&#xff1f;小编这里写一个详细图文教程吧&#xff01; SparkAi使用Nestjs和Vue3框架技术&#xff0c;持续集成AI能力到AIGC系统&#xff01; 程序核心功能&#xff1a; 程序已支持ChatGPT3.5/4.0提问、AI绘画、Midjourney绘画&…

【Axure原型分享】能统计中英文字数的多行输入框

今天和大家分享能统计中英文字数的多行输入框的原型模板&#xff0c;在输入框里输入内容后&#xff0c;能够动态根据输入框的内容&#xff0c;统计出字符数量&#xff0c;包括总字数、中文字数、英文字数、数字字数、其他标点符号的字数&#xff0c;具体效果可以观看下方视频或…

微服务架构2.0--云原生时代

云原生 云原生&#xff08;Cloud Native&#xff09;是一种关注于在云环境中构建、部署和管理应用程序的方法和理念。云原生应用能够最大程度地利用云计算基础设施的优势&#xff0c;如弹性、自动化、可伸缩性和高可用性。这个概念涵盖了许多方面&#xff0c;包括架构、开发、…

DataLoader

机器学习的五个步骤&#xff1a; 数据模块——模型——损失函数——优化器——训练 在实际项目中&#xff0c;如果数据量很大&#xff0c;考虑到内存有限、I/O 速度等问题&#xff0c;在训练过程中不可能一次性的将所有数据全部加载到内存中&#xff0c;也不能只用一个进程去加…

mmdetection基于 PyTorch 的目标检测开源工具箱 入门教程

安装环境 MMDetection 支持在 Linux&#xff0c;Windows 和 macOS 上运行。它需要 Python 3.7 以上&#xff0c;CUDA 9.2 以上和 PyTorch 1.8 及其以上。 1、安装依赖 步骤 0. 从官方网站下载并安装 Miniconda。 步骤 1. 创建并激活一个 conda 环境。 conda create --name…

厦门逗客传媒:抖音本地团购怎么入驻

随着社交媒体的不断发展&#xff0c;短视频平台已经成为了商家推广和营销的热门渠道之一。在这其中&#xff0c;抖音作为全球知名的短视频平台&#xff0c;以其巨大的用户基数和精准的推荐算法吸引了大量商家的关注。而在抖音上&#xff0c;本地团购也成为了一个备受关注的领域…

控制Unity发布的PC包的窗体

大家好&#xff0c;我是阿赵。   用Unity发布PC包接入某些渠道时&#xff0c;有时候会收到一些特殊的需求&#xff0c;比如控制窗口最大化(比如某些情况强制显示窗体)、最小化(比如老板键)、强制规定窗体置顶等。虽然我一直认为这些需求都是流氓软件行为&#xff0c;但作为一…

【每日易题】七夕限定——单身狗问题以及进阶问题位运算法的深入探讨

君兮_的个人主页 勤时当勉励 岁月不待人 C/C 游戏开发 Hello,米娜桑们&#xff0c;这里是君兮_&#xff0c;在写这篇博客的前一天是七夕&#xff0c;也是中国传统的“情人节”&#xff0c;不知道各位脱单了吗&#xff1f;碰巧最近刷题时遇到了经典的单身狗问题想带大家深入探…

消息队列前世今生 字节跳动 Kafka #创作活动

消息队列前世今生 1.1 案例一&#xff1a; 系统崩溃 首先大家跟着我想象一下下面的这个的场景&#xff0c; 看到新出的游戏机&#xff0c;太贵了买不起&#xff0c;这个时候你突然想到&#xff0c;今天抖音直播搞活动&#xff0c;打开抖音搜索&#xff0c;找到直播间以后&am…

JVM——类加载与字节码技术—编译期处理+类加载阶段

3.编译期处理 编译期优化称为语法糖 3.1 默认构造器 3.2 自动拆装箱 java基本类型和包装类型之间的自动转换。 3.3泛型集合取值 在字节码中可以看见&#xff0c;泛型擦除就是字节码中的执行代码不区分是String还是Integer了&#xff0c;统一用Object. 对于取出的Object&…

【ARM】Day9 cortex-A7核I2C实验(采集温湿度)

1. 2、编写IIC协议&#xff0c;采集温湿度值 iic.h #ifndef __IIC_H__ #define __IIC_H__ #include "stm32mp1xx_gpio.h" #include "stm32mp1xx_rcc.h" #include "led.h" /* 通过程序模拟实现I2C总线的时序和协议* GPIOF ---> AHB4* I2C1_S…

IoT DC3 是一个基于 Spring Cloud 的开源的、分布式的物联网(IoT)平台本地部署步骤

dc3 windows 本地搭建步骤&#xff1a; ​​ 必要软件环境 进入原网页# 务必保证至少需要给 docker 分配&#xff1a;1 核 CPU 以及 4G 以上的运行内存&#xff01; JDK : 推荐使用 Oracle JDK 1.8 或者 OpenJDK8&#xff0c;理论来说其他版本也行&#xff1b; Maven : 推荐…

记录《现有docker中安装spark3.4.1》

基础docker环境中存储hadoop3--方便后续查看 参考&#xff1a; 实践&#xff1a; export JAVA_HOME/opt/apache/jdk1.8.0_333 export SPARK_MASTER_IP192.168.0.220 export SPARK_WORKER_MEMORY4g export SPARK_WORKER_CORES2 export SPARK_EXECUTOR_MEMORY4g export HADOOP_H…

『SEQ日志』在 .NET中快速集成轻量级的分布式日志平台

&#x1f4e3;读完这篇文章里你能收获到 如何在Docker中部署 SEQ&#xff1a;介绍了如何创建和运行 SEQ 容器&#xff0c;给出了详细的执行操作如何使用 NLog 接入 .NET Core 应用程序的日志&#xff1a;详细介绍了 NLog 和 NLog.Seq 来配置和记录日志的步骤日志记录示例&…

微服务中间件--MQ

MQ MQa.安装RabbitMQb.消息模型c.SpringAMQP发送和接收d.WorkQueue模型e.发布订阅模型1) FanoutExchange2) DirectExchange3) TopicExchange f.消息转换器 MQ 同步调用的问题 微服务间基于Feign的调用就属于同步方式&#xff0c;存在一些问题。 耦合度高&#xff1a;每次加入…