ChatGLM Efficient Tuning效率调试PEFT

news2024/12/25 8:58:33

ChatGLM Efficient Tuning

   

基于 

PEFT 的高效 ChatGLM-6B 微调。

[ English | 中文 ]

更新日志

[23/06/05] 现在我们实现了 4 比特的 LoRA 训练(也称 QLoRA)。请尝试使用 --quantization_bit 4 参数进行 4 比特量化微调。(实验性功能)

[23/06/01] 我们开源了支持 LLaMA 和 BLOOM 系列模型的高效微调框架,如果您感兴趣请关注我们的 LLaMA-Efficient-Tuning 项目。

[23/06/01] 我们新增了一个使用监督微调和 RLHF 训练医疗问答模型的例子,请移步 covid_doctor.md 查阅。

[23/05/19] 现在我们支持了在模型训练时使用验证集评估性能。请尝试使用 --dev_ratio 参数指定验证集大小

[23/04/29] 现在我们实现了 RLHF(基于人类反馈的强化学习) 训练!我们提供了几个运行 RLHF 的例子,具体内容请移步 examples 文件夹。

[23/04/25] 我们新增了一个使用自定义数据集分布式训练的例子,请移步 ads_generation.md 查阅。

[23/04/20] 我们的项目在 12 天内获得了 100 个 Star!祝贺!

[23/04/20] 我们新增了一个修改模型自我认知的例子,请移步 alter_self_cognition.md 查阅。

[23/04/19] 现在我们实现了模型融合!请尝试使用 --checkpoint_dir checkpoint1,checkpoint2 参数训练融合 LoRA 权重后的模型。

[23/04/18] 现在可以微调量化模型了!请尝试使用 quantization_bit 参数进行 4 比特或 8 比特量化微调。

[23/04/12] 现在我们加入了断点训练支持!请尝试给定 --checkpoint_dir 参数加载指定的模型断点。

[23/04/11] 现在我们实现了数据集组合训练!请尝试使用 --dataset dataset1,dataset2 参数进行组合训练。

数据集

目前我们实现了针对以下数据集的支持:

  • Stanford Alpaca
  • Stanford Alpaca (Chinese)
  • GPT-4 Generated Data
  • BELLE 2M
  • BELLE 1M
  • BELLE 0.5M
  • BELLE Dialogue 0.4M
  • BELLE School Math 0.25M
  • BELLE Multiturn Chat 0.8M
  • Guanaco Dataset
  • Firefly 1.1M
  • CodeAlpaca 20k
  • Alpaca CoT
  • Web QA (Chinese)
  • UltraChat

使用方法请参考 data/README.md 文件。

部分数据集的使用需要确认,我们推荐使用下述命令登录您的 HuggingFace 账户。

pip install --upgrade huggingface_hub
huggingface-cli login

微调方法

目前我们实现了针对以下高效微调方法的支持:

  • LoRA
    • 仅微调低秩适应器。
  • P-Tuning V2
    • 仅微调前缀编码器。
  • Freeze
    • 仅微调后几层的全连接层。

软件依赖

  • Python 3.8+, PyTorch 1.13.1
  • Transformers, Datasets, Accelerate, PEFT, TRL
  • protobuf, cpm_kernels, sentencepiece
  • jieba, rouge_chinese, nltk(用于评估)
  • gradio, mdtex2html(用于网页端交互)

以及 强而有力的 GPU

如何使用

数据准备(可跳过)

关于数据集文件的格式,请参考 data/example_dataset 文件夹的内容。构建自定义数据集时,既可以使用单个 .json 文件,也可以使用一个数据加载脚本和多个文件。

注意:使用自定义数据集时,请更新 data/dataset_info.json 文件,该文件的格式请参考 data/README.md

环境搭建(可跳过)

git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
conda create -n chatglm_etuning python=3.10
conda activate chatglm_etuning
cd ChatGLM-Efficient-Tuning
pip install -r requirements.txt

对于 Windows 用户,若要启用 LoRA 或 Freeze 的量化微调,请下载预构建的 bitsandbytes 包,目前仅支持 CUDA 11.6 和 11.7。

pip install https://github.com/acpopescu/bitsandbytes/releases/download/v0.37.2-win.1/bitsandbytes-0.37.2-py3-none-any.whl

单 GPU 微调训练

CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --fp16

关于参数信息,请查阅我们的维基。

多 GPU 分布式微调

accelerate config # 首先配置分布式环境
accelerate launch src/train_sft.py # 参数同上

注意:若您使用 LoRA 方法进行微调,请指定以下参数 --ddp_find_unused_parameters False 来避免报错。

奖励模型训练

CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
    --do_train \
    --dataset comparison_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_rm_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16

RLHF 训练

CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16

指标评估(BLEU分数和汉语ROUGE分数

CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --do_eval \
    --dataset alpaca_gpt4_zh \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_eval_result \
    --per_device_eval_batch_size 8 \
    --max_samples 50 \
    --predict_with_generate

模型预测

CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --do_predict \
    --dataset alpaca_gpt4_zh \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 8 \
    --max_samples 50 \
    --predict_with_generate

命令行测试

python src/cli_demo.py \
    --checkpoint_dir path_to_checkpoint

浏览器测试

python src/web_demo.py \
    --checkpoint_dir path_to_checkpoint

导出微调模型

python src/export_model.py \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_export

硬件需求

微调方法批处理大小模式GPU显存速度
LoRA (r=8)16FP1628GB8ex/s
LoRA (r=8)8FP1624GB8ex/s
LoRA (r=8)4FP1620GB8ex/s
LoRA (r=8)4INT810GB8ex/s
LoRA (r=8)4INT48GB8ex/s
P-Tuning (p=16)4FP1620GB8ex/s
P-Tuning (p=16)4INT816GB8ex/s
P-Tuning (p=16)4INT412GB8ex/s
Freeze (l=3)4FP1624GB8ex/s
Freeze (l=3)4INT812GB8ex/s
奖励模型训练方法批处理大小模式GPU显存速度
LoRA (r=8) + rm4FP1622GB-
LoRA (r=8) + rm1INT811GB-
RLHF 训练方法批处理大小模式GPU显存速度
LoRA (r=8) + ppo4FP1623GB-
LoRA (r=8) + ppo1INT812GB-

注:r 为LoRA 维数大小,p 为前缀词表大小,l 为微调层数,ex/s 为每秒训练的样本数。gradient_accumulation_steps 参数设置为 1。上述结果均来自于单个 Tesla V100 GPU,仅供参考。

微调 ChatGLM 的例子

训练结果

我们使用整个 alpaca_gpt4_zh 数据集微调 ChatGLM 模型,使用秩为 8 的 LoRA 方法,使用默认超参数进行单轮训练。下图为训练损失变化曲线。

评估结果

我们选择 alpaca_gpt4_zh 数据集中的前一百条数据来评估微调后的 ChatGLM 模型,并计算 BLEU 和中文 ROUGE 分数。下表为评估结果。

分数原版模型FZ (l=2)PT (p=16)LoRA (r=8)
BLEU-415.7516.8516.0617.01 (+1.26)
Rouge-134.5136.6234.8036.77 (+2.26)
Rouge-215.1117.0415.3216.83 (+1.72)
Rouge-l26.1828.1726.3528.86 (+2.68)
训练参数/4.35%0.06%0.06%

FZ:Freeze 微调,PT:P-Tuning V2 微调(为了与 LoRA 公平比较,我们使用了 pre_seq_len=16),训练参数:可训练参数占全部参数的百分比。

和现有类似项目的比较

  • THUDM/ChatGLM-6B
    • ChatGLM 基于 P-Tuning v2 微调的官方实现,使用了 ADGEN 数据集。
    • 本仓库的代码实现绝大部分参考该项目。我们进一步实现了 LoRA 微调方法。此外,我们动态地将每个批处理数据中的序列进行填充,而非将其填充到模型的最大长度,此改进可以加速模型训练。
  • mymusise/ChatGLM-Tuning
    • ChatGLM 基于 LoRA 微调的非官方实现,使用了 Stanford Alpaca 数据集。
    • 我们借鉴了该项目的一些想法。我们的训练脚本将数据预处理部分集成至训练脚本中,以避免事先生成预处理后的数据。
  • ssbuild/chatglm_finetuning
    • ChatGLM 基于多种微调方法的非官方实现,使用了 Stanford Alpaca 数据集。
    • 我们的训练脚本全部基于 Huggingface transformers 框架实现,不依赖于额外的 deep_training 框架。
  • lich99/ChatGLM-finetune-LoRA
    • ChatGLM 基于 LoRA 微调的非官方实现,使用了 Stanford Alpaca 数据集。
    • 我们利用 Huggingface PEFT 框架来引入最先进的微调方法。
  • liucongg/ChatGLM-Finetuning
    • ChatGLM 基于参数冻结、LoRA 和 P-Tuning 微调的非官方实现,使用了汽车工业数据集。
    • 我们旨在引入更多指令遵循数据集用于微调 ChatGLM 模型。
  • yanqiangmiffy/InstructGLM
    • ChatGLM 微调的非官方实现,旨在探索 ChatGLM 在指令遵循数据集上的潜力。
    • 我们将数据预处理部分集成到训练脚本中。

TODO

  •  利用 LangChain 实现能够利用外部知识的基于 ChatGLM 微调模型应用的轻松构建。
  •  实现对齐算法使模型对齐人类意图。
    •  RLHF
    •  RRHF
    •  RAFT
  •  加入更多中文数据集。
    •  BELLE
    •  pCLUE
    •  CLUECorpus
    •  GuanacoDataset
    •  FireflyDataset
  •  加入基于 ChatGPT 和 GPT-4 产生的数据集。
    •  Baize
    •  GPT-4-LLM
  •  实现参数冻结和 P-Tuning 微调方法。
  •  支持多GPU训练。
  •  加入模型评估脚本。
  •  断点加载。
  •  量化微调。
  •  撰写基于该框架的 ChatGLM 模型微调指南手册。
  •  结合模型编辑技术。(例如:MEND)
  •  加入 OpenAssistant 对话数据集用于监督微调和意图对齐。
  •  加入高质量中文开源指令数据集 COIG。

协议

本仓库的代码依照 Apache-2.0 协议开源。ChatGLM-6B 模型的使用请遵循模型协议。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/671595.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Echarts实现流程图关系图拓扑图

实现如下&#xff0c;可以横着排竖着排都可以 1.先写个div做画布 ref值随意&#xff0c;但是一点要写 <div style"height: 400px;" ref"echartdom"></div> 2.下载echarts 我这边下载的是 "echarts": "^4.9.0",最新版应…

奢侈品回收APP系统开发功能有哪些?

奢侈品售卖回收APP系统开发功能有哪些&#xff1f; 1.回收品牌分类&#xff1a;奢侈品回收APP平台可以将支持回收鉴定的奢侈品品牌及商品进行分类展示&#xff0c;方便用户查看自己的想要出售的是不是平台支持的商品。 2.商品在线检索&#xff1a;客户可以直接按…

STM32F4_红外遥控

目录 1. 红外遥控简介 2. NEC协议 3. 硬件设计 4. 实验程序详解 4.1 main.c 4.2 Remote.c 4.3 Remote.h 1. 红外遥控简介 红外遥控是一种无线、非接触的控制技术。具有抗干扰能力强&#xff0c;信息传输可靠&#xff0c;功耗低&#xff0c;成本低&#xff0c;易实现等优…

深入理解Redis的AOF和RDB持久化机制

Redis的AOF&#xff08;Append-Only File&#xff09;和RDB&#xff08;Redis Database&#xff09;是两种常见的持久化机制&#xff0c;用于将内存中的数据保存到磁盘上&#xff0c;确保数据在Redis重新启动时的持久性。本文将深入介绍AOF和RDB的原理和使用&#xff0c;帮助读…

HQChart实战教程65-自定义手机端分时图tooltip显示数据

HQChart实战教程65-自定义手机端分时图tooltip显示数据 手机端分时图tooltip步骤1. 配置手机端tooltip2. 替换k线tooltip格式化输出函数2. 格式化输出函数说明HQChart插件源码地址完整的demo源码手机端分时图tooltip hqchart手机端内置一个tooltip,显示手势所在K线的信息。默认…

邮件打开率低?来看看这几招提高邮件打开率!

什么是邮件打开率&#xff1f; 邮件打开率&#xff1a;简单来讲就是收件人打开的邮件数占发送邮件总数的百分比。 我们要做的就是如何吸引收件人打开邮件&#xff0c;那可以从以下几个方面来考虑&#xff1a; 1、邮件标题 邮件标题直接向收件人表达了这封邮件是关于什么的&am…

CSS样式优先级怎样划分?【CSS优先级规则】

定义CSS样式时&#xff0c;经常出现两个或更多样式规则应用在同一元素上的情况。此时CSS就会根据样式规则的权重&#xff0c;优先显示权重最高的样式。CSS优先级指的就是CSS样式规则的权重。在网页制作中&#xff0c;CSS为每个基础选择器都指定了不同的权重&#xff0c;方便我们…

【内存问题真的很烦人】linux内存等资源管理 以及 linux内存不足解决办法

linux内存不足解决办法 ///这一部分存在疑问 查看目录下文件夹大小 du -h --max-depth1 看具体哪个文件夹占用内存过高&#xff0c;一般是日志&#xff0c;删除即可。 ///这一部分存在疑问&#xff0c;上面的文件夹可以代表内存吗&#xff1f; 内存不够 top 命令 看内存占用…

Python就业前景如何?三大就业岗位分享

Python是一门面向对象的编程语言&#xff0c;编译速度超快&#xff0c;从诞生到现在已经20来个年头了。Python的排名从去年开始就借助人工智能持续上升&#xff0c;Python的火热&#xff0c;也带动了工程师们的就业热。 据统计&#xff0c;现在初级Python工程师的起薪一般在10…

【 Lucas-Kanade光流法】

这里写目录标题 1.1 Lucas-Kanade光流法1.1 Lucas-Kanade光流法详细步骤&#xff1a; 1.1 Lucas-Kanade光流法 Lucas-Kanade光流法是一种密集光流估计方法&#xff0c;用于计算图像中每个像素的运动向量。它假设在相邻帧之间&#xff0c;像素的灰度值不会发生大的变化&#xf…

《网络安全0-100》双钥加密体制

双钥加密体制 怎么说 没找着公钥加密在哪&#xff0c;所以就接着写了。 公钥加密&#xff0c;也叫非对称(密钥)加密&#xff0c;属于通信科技下的网络安全二级学科&#xff0c;指的是由对应的一对唯一性密钥(即公开密钥和私有密钥)组成的加密方法。它解决了密钥的发布和管理…

【ArcGIS】使用ArcGIS进行坡度分析

使用ArcGIS进行坡度分析 1 数据来源2 操作步骤参考 坡度是指过地表面任意一点的切平面与水平地面之间的夹角。坡度用来计算任–单元和邻域单元间变化的最大比率&#xff0c;如单元下降最陡的坡面(单元和它相邻单元间的高程距离的最大变化率)。 坡度分析是计算两相邻像元间的数值…

STM32的时钟系统(嵌入式学习)

STM32的时钟系统 时钟的基本概念时钟系统的组成时钟源晶体振荡器和RC振荡器的区别晶体振荡器RC振荡器 STM32G030时钟源时钟树STM32CubeMX时钟树配置 时钟的基本概念 时钟是指用于计量和同步时间的装置或系统。时钟是嵌入式系统的脉搏&#xff0c;处理器内核在时钟驱动下完成指…

Goby 漏洞发布|PandoraFMS 软件 upload_head_image.php 任意文件上传漏洞

漏洞名称&#xff1a;PandoraFMS 软件 upload_head_image.php 任意文件上传漏洞 English Name&#xff1a;PandoraFMS upload_head_image.php Arbitrary File Upload Vulnerability CVSS core: 9.0 影响资产数&#xff1a;768 漏洞描述&#xff1a; PandoraFMS是美国Pando…

【Python】文件操作 ④ ( 文件操作 | 向文件写出数据 | 使用 write 函数向文件中写出数据 | 使用 flush 函数刷新文件数据 )

文章目录 一、向文件写出数据1、使用 write 函数向文件中写出数据2、使用 flush 函数刷新文件数据3、代码示例 - 使用 write / flush 函数向文件中写出数据 一、向文件写出数据 1、使用 write 函数向文件中写出数据 Python 中 通过 调用 write 函数 向文件中写入数据 ; 语法如下…

Fiddler抓包工具之fiddler的常用快捷键

一、常用三个快捷键 ctrlX :清空所有记录 CtrlF&#xff1a;查找 F12&#xff1a;启动或者停止抓包 使用 QuickExec Fiddler2 成了网页调试必备的工具&#xff0c;抓包看数据。Fiddler2自带命令行控制。 fiddler 命令行快捷键&#xff1a;ctrl q &#xff0c;然后 输入 help…

记一次杀猪盘网站渗透

1、首先访问杀猪盘主站。 2、通过扫描子域名找到后台管理系统。 3、对其后台的登录接口进行测试&#xff0c;发现接口的用户名参数存在sql注入&#xff0c;直接跑数据。 4、注入得到后台的账密如下,用户名和safecode是明文的&#xff0c;password使用自定义加密。 跑出来了账号…

PaaS2.0、Matter、AIGC、新能源…TUYA开发者大会亮点抢先看

6月29日&#xff0c;TUYA开发者大会&#xff08;深圳&#xff09;即将开幕。作为业内备受关注的盛会&#xff0c;大会的各种“路透”消息络绎不绝。那么TUYA开发者大会将呈现哪些精彩&#xff0c;我们带大家一探究竟。 亮点1&#xff1a;IoT行业风向标 积蓄2年的硬核分享 TUYA…

【Ribbon实现客户端负载均衡和故障转移】—— 每天一点小知识

&#x1f4a7; R i b b o n 实现客户端负载均衡和故障转移 \color{#FF1493}{Ribbon实现客户端负载均衡和故障转移} Ribbon实现客户端负载均衡和故障转移&#x1f4a7; &#x1f337; 仰望天空&#xff0c;妳我亦是行人.✨ &#x1f984; 个人主页——微风撞见云的博客…

掌握会议任务追踪技巧,提高会议效率!

跟踪会议任务是有效项目管理的重要组成部分。会议可以产生许多需要完成的行动项目和任务&#xff0c;如果没有适当的跟踪&#xff0c;这些任务很容易被遗漏。在本文中&#xff0c;我们将概述如何有效地跟踪会议任务。 1、在会议中分配任务 在会议期间&#xff0c;将任务分配给特…