文章目录
- 一、前言
- 二、说明
- 2.1 代码结构
- 2.2 依赖包版本
- 三、启动对话演示
- 3.1 命令行交互 cli_demo.py
- 3.2 网页交互 web_demo.py
- 四、微调模型
- 4.1 基于 P-Tuning v2 微调模型
- 4.1.1 软件依赖
- 4.1.2 下载数据集
- 4.1.3 下载模型文件
- 4.1.4 操作步骤
- 4.2 基于 Full Parameter 微调模型
- 4.3 基于LoRA微调模型
- 参考资料
一、前言
ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。 ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
-
智谱AI官方内测网站:ChatGLM
-
项目地址:https://github.com/THUDM/ChatGLM-6B/tree/main
-
模型文件:https://huggingface.co/THUDM/chatglm-6b/tree/main
-
B站讲解视频:【官方教程】ChatGLM-6B 微调:P-Tuning,LoRA,Full parameter
-
七月在线博客:ChatGLM两代的部署/微调/实现:从基座GLM、ChatGLM的LoRA/P-Tuning微调、6B源码解读到ChatGLM2的微调与实现
二、说明
2.1 代码结构
2.2 依赖包版本
由于大模型相关的各种依赖包版本更新较快,会导致各种报错,如:
'ChatGLMTokenizer' object has no attribute 'sp_tokenizer'
这里主要是由transformers
版本问题导致的,解决方案可以参考博客:https://blog.csdn.net/Tink_bell/article/details/137942170
三、启动对话演示
ChatGLM-6B下提供了cli_demo.py和web_demo.py两个文件来启动模型:
- cli_demo.py:使用命令行进行交互。
- web_demo.py:使用gradio库使用本机服务器进行网页交互。
这里,依赖包的版本会影响到代码的运行。经过多次报错与尝试,我这里使用的依赖包版本为:
transformers==4.33.0
gradio==3.39.0
3.1 命令行交互 cli_demo.py
[待补充]
3.2 网页交互 web_demo.py
web_demo.py中基于gradio库使用本机服务器进行网页交互,具体运行步骤如下:
(1)模型路径测试。
首先需要将模型地址配置为本地模型路径
原代码:
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
修改为本地模型路径后的代码:
model_path = './model/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()
(2)模型量化。
这里,由于我的GPU内存不够,所以对模型做量化操作:
model_path = './model/chatglm-6b'
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, quantization_config=bnb_config, device_map={"":0}, trust_remote_code=True)
model = model.eval()
(3)运行 web_demo.py
可以看到控制台输出
说明网页交互服务已在本地 http://127.0.0.1:7860
运行起来。
查看GPU显存占用情况,可以看到使用模型量化,最后只占用了不到5GB的显存:
(4)端口映射。
这里,由于我们是在远程服务器上运行的服务,如果想在本地浏览器访问,需要做一个端口映射,具体命令如下:
ssh -L 1234:localhost:7860 root@172.xxx.yyy.zzz
基于上述命令,将远程服务器的 7860 端口映射至本地 1234 端口。
(5)本地访问服务。
然后我们在本地浏览器打开 http://localhost:1234/
即可访问该页面,如下所示:
在输入窗口输入文本信息并提交即可实现调用ChatGLM的对话功能。
四、微调模型
ChatGLM模型的fine-tune有多种模式:
- P-Tuning v2
- LoRA
- Full parameter
其中ChatGLM官方代码仓库中给出了基于 P-Tuning v2
及 Full parameter
的方法,具体微调模型的方式可以参考B站视频:【官方教程】ChatGLM-6B 微调:P-Tuning,LoRA,Full parameter
LoRA的微调方法可以参考:https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/chatglm_v2_6b_lora
4.1 基于 P-Tuning v2 微调模型
官方给出的微调教程:https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README.md
本仓库实现了对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。
下面以 ADGEN (广告生成) 数据集为例介绍代码的使用方法。
4.1.1 软件依赖
这里,基于P-Tuning v2 微调模型对于transformers
的版本有限制,需要4.27.1版本的transformers
:
pip install transformers==4.27.1
此外,还需要安装以下依赖
pip install rouge_chinese nltk jieba datasets
4.1.2 下载数据集
ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
{
"content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
"summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
ADGEN 数据集可以从 Google Drive 或者 Tsinghua Cloud 来下载。
4.1.3 下载模型文件
模型文件:https://huggingface.co/THUDM/chatglm-6b/tree/main
可以选择 git clone
或者 手动
的方式来下载模型文件。
4.1.4 操作步骤
(1)文件及数据准备
使用 ptuning
文件夹下的代码进行微调,这里我们在当前目录下创建:
model
目录存放下载的模型文件data
目录存放ADGEN 数据文件
(2)修改 train.sh
代码
根据本地模型及数据文件目录,修改train.sh
中的相应参数:
PRE_SEQ_LEN=128
LR=2e-2
CUDA_VISIBLE_DEVICES=0 python3 main.py \
--do_train \
--train_file './data/AdvertiseGen/train.json' \
--validation_file './data/AdvertiseGen/dev.json' \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path './model/chatglm-6b/' \
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
【参数说明】:
train.sh 中的 PRE_SEQ_LEN
和 LR
分别是 soft prompt
长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。
在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16
下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
(3)运行脚本
bash train.sh
这里,我们对模型做了量化处理,可以看到占用的显存占用情况:
迭代3000次:
可以通过wandb
查看运行中的参数变化情况:
wandb中可以看到 train_loss 的变化情况:
4.2 基于 Full Parameter 微调模型
如果需要进行全参数的 Finetune,需要安装 Deepspeed,然后运行以下指令:
bash ds_train_finetune.sh
4.3 基于LoRA微调模型
ChatGLM官方仓库中的部分微调使用的是基于 P Tuning v2的微调方式,并未给出基于LoRA的微调。
LoRA的微调方法可以参考:https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/chatglm_v2_6b_lora
参考资料
- ChatGLM两代的部署/微调/实现:从基座GLM、ChatGLM的LoRA/P-Tuning微调、6B源码解读到ChatGLM2的微调与实现