环境配置
Python 3.10.12
transformers 4.36.2
torch 2.0.1
下载demo代码
- 在官方网址https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo 下载demo代码
- cd 进入文件夹 pip install -r requirements.txt 安装一些包
基本知识
SFT 全量微调: 4张显卡平均分配,每张显卡占用 48346MiB 显存
P-TuningV2 微调: 1张显卡,占用 18426MiB 显存
LORA 微调: 1张显卡,占用 14082MiB 显存
微调数据文件格式因需求不同而有所差别:
微调模型的对话能力
微调模型的对话和工具能力
此处demo选用AdvertiseGen 数据集
新建convert.py 将AdvertiseGen 转为AdvertiseGen _fix
convert.py 代码官方给了,代码如下:
import json
from typing import Union
from pathlib import Path
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def _mkdir(dir_name: Union[str, Path]):
dir_name = _resolve_path(dir_name)
if not dir_name.is_dir():
dir_name.mkdir(parents=True, exist_ok=False)
def convert_adgen(data_dir: Union[str, Path], save_dir: Union[str, Path]):
def _convert(in_file: Path, out_file: Path):
_mkdir(out_file.parent)
with open(in_file, encoding='utf-8') as fin:
with open(out_file, 'wt', encoding='utf-8') as fout:
for line in fin:
dct = json.loads(line)
sample = {'conversations': [{'role': 'user', 'content': dct['content']},
{'role': 'assistant', 'content': dct['summary']}]}
fout.write(json.dumps(sample, ensure_ascii=False) + '\n')
data_dir = _resolve_path(data_dir)
save_dir = _resolve_path(save_dir)
train_file = data_dir / 'train.json'
if train_file.is_file():
out_file = save_dir / train_file.relative_to(data_dir)
_convert(train_file, out_file)
dev_file = data_dir / 'dev.json'
if dev_file.is_file():
out_file = save_dir / dev_file.relative_to(data_dir)
_convert(dev_file, out_file)
convert_adgen('data/AdvertiseGen', 'data/AdvertiseGen_fix')
最后运行完convert.py ,会出现如下图格式:
数据集eg. :
Old:{“content”: “类型#上衣材质#牛仔布颜色#白色风格#简约图案#刺绣衣样式#外套衣款式#破洞”, “summary”: “简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。”}
New:{“conversations”: [{“role”: “user”, “content”: “类型#裙*裙长#半身裙”}, {“role”: “assistant”, “content”: “这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。”}]}
下载ChatGLM3-6B模型
由于chatglm3-6b 是用modelscope 下载
姑新建python文件,内容如下:
from modelscope import snapshot_download
model_dir = snapshot_download("ZhipuAI/chatglm3-6b", revision = "v1.0.0")
print(model_dir)
所以/media/zr/Data/Models/LLM/chatglm3-6b == print(model_dir)
微调,执行:
新建sh文件,内容如下:
CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python finetune_hf.py data/AdvertiseGen_fix /media/zr/Data/Models/LLM/chatglm3-6b configs/lora.yaml
注意此处的/media/zr/Data/Models/LLM/chatglm3-6b 要换成 你自己下载的chatglm3-6b的路径
使用微调的数据集进行推理:执行
新建sh文件,内容如下:
CUDA_VISIBLE_DEVICES=0 NCCL_P2P_DISABLE="1" NCCL_IB_DISABLE="1" python inference_hf.py output/checkpoint-3000/ --prompt "类型#裙*版型#显瘦*材质#网纱*风格#性感*裙型#百褶*裙下摆#压褶*裙长#连衣裙*裙衣门襟#拉链*裙衣门襟#套头*裙款式#拼接*裙款式#拉链*裙款式#木耳边*裙款式#抽褶*裙款式#不规则"
输出:这款连衣裙采用了网纱拼接的压褶设计,视觉上很显瘦,搭配木耳边套头设计,更具有性感的气质。不规则的裙摆,更具有灵动性。而拉链设计,方便穿脱。百褶裙摆,优雅而灵动。
如图所示:
参考:
官方教程十分详细,值得一看,且里面还有许多参数说明。
https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/README.md
https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb