使用LlamaFactory进行模型微调
简介
论文地址:https://arxiv.org/pdf/2403.13372
仓库地址:https://github.com/hiyouga/LLaMA-Factory/tree/main
名词解释
1. 预训练 (Pre-training, PT)
预训练是指模型在大规模无监督数据集上进行初步训练的过程。数据通常由互联网上的文本构成,没有明确的标签。模型通过预测下一个词来学习语言的语法、语义和常识知识。
- 目标:通过大量的文本数据学习词与词之间的关系,获得广泛的语言理解能力。
- 训练方式:无监督学习。模型根据已知上下文预测下一个词,并通过这个预测结果不断调整自身的参数。
- 结果:预训练后的模型具备了基本的语言生成和理解能力,但缺乏完成特定任务的精细调整。
2. 指令微调 (Supervised Fine-tuning, SFT)
指令微调是在预训练的基础上,使用有监督数据对模型进行精细调整的过程。这里的数据是特定任务的数据集,通常包含了人类标注的输入-输出对,指导模型如何在特定指令下工作。
- 目标:提升模型在特定任务上的表现,让模型更好地理解和执行特定指令。
- 训练方式:有监督学习。模型根据给定的输入和预期的输出调整参数,以便能够在类似场景下更准确地执行任务。
- 结果:模型在特定任务(如回答问题、翻译、摘要等)上具有更好的性能和准确性。
3. 基于人工反馈的对齐 (Reinforcement Learning with Human Feedback, RLHF)
RLHF 是通过引入人工反馈,进一步调整模型行为,使其生成的内容更加符合人类期望的一个阶段。具体来说,人类会对模型的输出进行评估,模型根据这些反馈进行强化学习调整。
- 目标:让模型的输出更符合人类偏好或社会价值观,并减少不当或有害的输出。
- 训练方式:强化学习。模型会生成多个候选答案,并由人类评估这些答案的质量。基于这些评价,模型会通过强化学习算法(如Proximal Policy Optimization, PPO)调整生成策略。
- 结果:模型不仅能够理解指令,还能产生与人类偏好对齐的答案。RLHF 在模型的安全性和伦理方面也有很大的作用。
背景
开源大模型如LLaMA,Qwen,Baichuan等主要都是使用通用数据进行训练而来,其对于不同下游的使用场景和垂直领域的效果有待进一步提升,衍生出了微调训练相关的需求,包含预训练(pt),指令微调(sft),基于人工反馈的对齐(rlhf)等全链路。但大模型训练对于显存和算力的要求较高,同时也需要下游开发者对大模型本身的技术有一定了解,具有一定的门槛。
LLaMA-Factory整合主流的各种高效训练微调技术,适配市场主流开源模型,形成一个功能丰富,适配性好的训练框架。项目提供了多个高层次抽象的调用接口,包含多阶段训练,推理测试,benchmark评测,API Server等,使开发者开箱即用。同时提供了基于gradio的网页版工作台,方便初学者可以迅速上手操作,开发出自己的第一个模型。
环境准备
colab
使用google的colab可以快速且免费地使用LlamaFactory进行模型微调。
https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
本地机器
- 本地运行需要必要的cuda和pytorch,如果你在容器里运行,可以直接使用nvidia的pytorch镜像:
nvcr.io/nvidia/pytorch:24.05-py3
- 安装LlamaFactory
export PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
export USE_MODELSCOPE_HUB=1
export MODELSCOPE_CACHE=/data/modelscope
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics,modelscope]"
llamafactory-cli webui
数据集
LLamaFactory内置了一些常用的数据集,可以直接在现有的数据集上修改,也可以自己扩充新的数据集,详细参考:https://github.com/hiyouga/LLaMA-Factory/blob/main/data/README_zh.md
这里直接用身份验证数据集,修改模型的身份去验证微调效果:
import json
cd /content/LLaMA-Factory/
NAME = "测试AI"
AUTHOR = "测试科技"
with open("data/identity.json", "r", encoding="utf-8") as f:
dataset = json.load(f)
for sample in dataset:
sample["output"] = sample["output"].replace("{{"+ "name" + "}}", NAME).replace("{{"+ "author" + "}}", AUTHOR)
with open("data/identity.json", "w", encoding="utf-8") as f:
json.dump(dataset, f, indent=2, ensure_ascii=False)
目前LLamaFactory支持两种数据集:Alpaca 格式、Sharegpt 格式。相较于 alpaca 格式的数据集,sharegpt 格式支持更多的角色种类,例如 human、gpt、observation、function 等等。
Alpaca格式
指令微调 (Supervised Fine-tuning, SFT) 格式
在指令监督微调时,instruction
列对应的内容会与 input
列对应的内容拼接后作为人类指令,即人类指令为 instruction\ninput
。而 output
列对应的内容为模型回答。
如果指定,system
列对应的内容将被作为系统提示词。
history
列是由多个字符串二元组构成的列表,分别代表历史消息中每轮对话的指令和回答。注意在指令监督微调时,历史消息中的回答内容也会被用于模型学习。
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
["第一轮指令(选填)", "第一轮回答(选填)"],
["第二轮指令(选填)", "第二轮回答(选填)"]
]
}
]
预训练 (Pre-training, PT) 格式
在预训练时,只有 text
列中的内容会用于模型学习。
[
{"text": "document"},
{"text": "document"}
]
偏好数据集
偏好数据集用于奖励模型训练、DPO 训练、ORPO 训练和 SimPO 训练。
它需要在 chosen
列中提供更优的回答,并在 rejected
列中提供更差的回答。
[
{
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"chosen": "优质回答(必填)",
"rejected": "劣质回答(必填)"
}
]
Sharegpt格式
指令微调 (Supervised Fine-tuning, SFT) 格式
相比 alpaca 格式的数据集,sharegpt 格式支持更多的角色种类,例如 human、gpt、observation、function 等等。它们构成一个对象列表呈现在 conversations
列中。
注意其中 human 和 observation 必须出现在奇数位置,gpt 和 function 必须出现在偶数位置。
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "function_call",
"value": "工具参数"
},
{
"from": "observation",
"value": "工具结果"
},
{
"from": "gpt",
"value": "模型回答"
}
],
"system": "系统提示词(选填)",
"tools": "工具描述(选填)"
}
]
偏好数据集
[
{
"conversations": [
{
"from": "human",
"value": "人类指令"
},
{
"from": "gpt",
"value": "模型回答"
},
{
"from": "human",
"value": "人类指令"
}
],
"chosen": {
"from": "gpt",
"value": "优质回答"
},
"rejected": {
"from": "gpt",
"value": "劣质回答"
}
}
]
微调
开始微调
可以通过gradio页面直接微调
# 启动gradio,使用7860端口访问
llamafactory-cli webui
选择微调模型和数据集
开始微调
微调完成后的文件存在项目目录的saves文件夹下。
判断微调效果
可以通过看损失曲线是否收敛来判断本次训练的效果,如果收敛不好,则需要考虑调小学习率、减少梯度累积、增加训练轮数等方式。以下是一些损失较大的图例(需要修改参数重新微调)。
验证微调效果
微调参数解释
梯度累积 (gradient_accumulation_steps)
模型会在每次处理一个小批次后,不立即更新权重,而是先将每次的小批次梯度累积起来,直到累积到设定的次数为止(例如,4或8),然后再一次性执行权重更新。这样做的主要原因是为了减少显存占用。如果你的显存较为充裕,想要快速更新权重,可以选择较小的值,比如 4
,否则可以选择较大的值,比如 8
。
学习率 (learning_rate)
更大的学习率意味着模型每次更新的幅度更大,收敛速度会较快。然而,过大的学习率可能会导致模型在损失函数的表面上跳动过大,甚至错过全局最优解,导致模型发散或收敛到局部最优解。
5e-5:适合在模型刚开始训练时使用,或者用于一些简单任务和数据集较小的场景。较大的学习率可以帮助模型快速进入一个较好的区域,减少训练时间。
5e-6:适合在微调阶段使用,特别是当你在较大的预训练模型上进行微调时,小的学习率更为合适。因为大模型通常已经学习到了大量的知识,微调阶段只需要进行小幅度的调整,过大的学习率反而可能破坏模型原有的良好性能。
训练轮次 (training_epochs)
在深度学习中,单次迭代(即单个epoch)通常不足以让模型学到足够的信息。模型通过多轮次的训练,不断调整权重,逐渐优化自身,降低误差并提高在验证集上的性能。每次新的epoch开始,模型会使用更新的权重重新开始遍历数据,从而逐步学习数据中的模式和特征。
如果训练轮次设置得太少,模型可能无法充分学习数据中的模式,表现为欠拟合(underfitting),即训练不足,模型在训练集和验证集上的表现都较差。
果训练轮次设置得太多,模型可能过度拟合(overfitting)训练集,即模型对训练数据的表现非常好,但在验证集或测试集上的表现可能会变差,因为模型“记住”了训练数据中的噪声,而不是学习到泛化的模式。
合并LoRA 权重
使用页面的Export功能进行LoRA的权重合并
合并完成后,可以用以下代码进行测试
import transformers
import torch
model_id='/data/finetuned/llama3-8b-instruct-merged'
pipeline = transformers.pipeline(
"text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
device="cuda",
)
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "你是谁?"},
]
prompt = pipeline.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
terminators = [
pipeline.tokenizer.eos_token_id,
pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
print("欢迎使用聊天模型!输入 'exit' 退出对话。\n")
while True:
input_text = input("你: ")
if input_text.lower() == "exit":
print("对话结束。")
break
outputs = pipeline(
prompt,
max_new_tokens=256,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.9,
)
print(outputs[0]["generated_text"][len(prompt):])
问题解决
KeyError: ‘WEBP’
使用图形化界面微调模型时,页面会展示训练损失图,可能会报这个错误,这个错误表明在使用 matplotlib
保存图片为 WEBP
格式时,出现了 KeyError: 'WEBP'
,即 PIL
不支持 WEBP
格式。这通常是由于 Pillow
库没有正确安装 WEBP
格式的支持引起的。
apt-get update && apt-get install libwebp-dev
pip uninstall pillow
pip install --no-cache-dir pillow
使用python验证是否解决,期待返回True
from PIL import features
print(features.check('webp'))