一、简介
训练模型以理解并预测人类偏好是一项复杂的任务。传统方法如SFT(监督微调)通常需要较高的成本,因为这些算法需要对数据进行特定标签的标注。偏好优化(Preference Optimization)作为一种替代方案,可以简化这一过程并提供更准确的结果。通过对候选回答的对比和排序,而不是赋予固定的标签,偏好优化能够更高效地捕捉人类偏好的细微差别。
虽然偏好优化已经在大语言模型中广泛使用,但现在它也可以应用于视觉语言模型(VLM)。得益于TRL(Transformer Reinforcement Learning)的开发,现在我们可以使用TRL对VLM进行直接偏好优化(Direct Preference Optimization)。本文将介绍使用TRL和DPO对视觉语言模型进行训练的全过程。
二、偏好数据集
进行偏好优化,首先需要有一个能体现用户偏好的数据集。在双项选择的设定下,相应的数据一般包含一个提示词(Prompt)和两个候选回答,其中一个被标记为选中(chosen),另一个被标记为淘汰(rejected)。模型需要学习选择正确的回答,而不是被淘汰的回答。下图展示了一个例子:
❔ 问题: 有多少个家庭?
- ❌ 被淘汰的回答: 图片没有提供关于家庭的信息。
- ✅ 选中的回答: 图片显示了一个工会组织的表格,包含18000个家庭。
尽管选中的回答也不是完全正确(应该是18000000个家庭),但比被淘汰的回答更好。
本文将使用openbmb/RLAIF-V-Dataset作为示例数据集,该数据集包含超过83000条标注数据。可以通过以下代码查看数据集:
from datasets import load_dataset
dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train[:1%]")
sample = dataset[1]
sample["image"].show()
sample["question"]
'how many families?'
sample["rejected"]
'The image does not provide any information about families.'
sample["chosen"]
'The image shows a Union Organization table setup with 18,000 families.'
我们将要训练的 VLM 模型需要文本和图像同时作为输入,所以这里的第一步还是要对数据集格式进行改造。一条数据应该被结构化成能模拟人机对话的形式。用户提供一个提示语,其中包含一张图片和一个问题,然后模型需要能够给出一个回答。我们用以下代码实现格式转换:
from datasets import features
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)
def format(example):
# Prepare the input for the chat template
prompt = [
{
"role": "user",
"content": [{"type": "image"}, {"type": "text", "text": example["question"]}],
},
]
chosen = [
{
"role": "assistant",
"content": [{"type": "text", "text": example["chosen"]}],
},
]
rejected = [
{
"role": "assistant",
"content": [{"type": "text", "text": example["rejected"]}],
},
]
# Apply the chat template
prompt = processor.apply_chat_template(prompt, tokenize=False)
chosen = processor.apply_chat_template(chosen, tokenize=False)
rejected = processor.apply_chat_template(rejected, tokenize=False)
# Resize the image to ensure it fits within the maximum allowable
# size of the processor to prevent OOM errors.
max_size = processor.image_processor.size["longest_edge"]
example["image"].thumbnail((max_size, max_size))
return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}
# Apply the formatting function to the dataset,
# remove columns to end up with only "images", "prompt", "chosen", "rejected" columns
dataset = dataset.map(format, remove_columns=dataset.column_names)
# Make sure that the images are decoded, it prevents from storing bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True)) # to avoid bytes
dataset = dataset.cast(f)
完成了格式转换,我们来看看第一条数据:
>>> dataset[1]
{'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=L size=980x812 at 0x154505570>],
'prompt': 'User:<image>how many families?<end_of_utterance>\n',
'rejected': 'Assistant: The image does not provide any information about families.<end_of_utterance>\n',
'chosen': 'Assistant: The image shows a Union Organization table setup with 18,000 families.<end_of_utterance>\n'}
三、训练
3.1 训练需要多大的 GPU 显存?
以微调1B的模型为例子,假设模型的的每个参数用32bit存储,32bit=4byte。
每个参数通常以浮点数形式存储。FP32(32位浮点数)每个参数占用4字节的存储空间,而BF16(16位浮点数)每个参数占用2字节的存储空间。
需要用到GPU的部分:模型权重(需要加载进去)、梯度(更新参数)、优化器(状态量,SGD和Adam占用的显存空间不一样)、激活值等等
GB
- 模型权重1B = 1b x 4 byte = 4GB;
- 梯度的显存需求与模型权重相同 4GB;
- 以Adam优化器(LLM用的多)为例,Adam需要维护模型的参数、每个参数的动量和平方梯度信息,因此占用的显存大约是模型权重的3倍 [一阶动量估计(类似动量)、二阶动量估计(平方梯度)];
注意,优化器都是用FP32进行存储的,因为大量的小值累加(sum、mean)操作,如果用FP16进行会损失精度,太小的值用FP16会表示为0。
- 激活值(中间结果),反向传播和前向传播会用到,这边只是简单起见,bs=1,和模型参数一样是4GB,实际上这个计算推导很复杂,后面有机会再写~,同时Transformer中激活值和序列长度以平方次数增长;
- 输入数据:跟Batch size、样本I大小有关系,就是B x I x 4 字节,这边暂时忽略;
参数来源 | 计算公式 | 显存需求 |
---|---|---|
要训练的模型 | 32 GB | |
参考模型(这个任务额外要的,防止模型发生偏移,和要训练的模型一样大) | 32 GB | |
梯度 | 32 GB | |
优化器状态量 | 72 GB | |
合计 | 168 GB |
可以使用量化、LoRA 等技术来大幅度地减少显存需求,让训练可以进行。
3.2 使用 bfloat16 和 LoRA 后的显存需求
参数来源 | 计算公式 | 显存需求 |
---|---|---|
要训练的模型 | 16 GB | |
参考模型 | 16 GB | |
梯度 | 0.1 GB | |
优化器状态量 | 0.3 GB | |
合计 | 32.4 GB |
四、微调Llava 1.5和PaliGemma等模型
TRL的DPO实现已支持Idefics2、Llava 1.5和PaliGemma,同时TRL也在努力支持更多的模型。最简单的调用方法是使用TRL提供的示例脚本。例如,如果你想微调PaliGemma,可以使用以下命令:
accelerate launch examples/scripts/dpo_visual.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path google/paligemma-3b-pt-224 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 32 \
--dataset_num_proc 32 \
--output_dir dpo_paligemma_rlaif-v \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--lora_target_modules=all-linear
五、可视化结果
下表展示了一些可视化的结果:
Image | Question | Idefics2 | Idefics2+DPO |
---|---|---|---|
Are there two ships in this image? | Yes | No | |
Is the ground uneven in this image? | No | Yes | |
Is there one shovel in this image? | Yes | No |
六、参考链接
[1] https://huggingface.co/docs/peft/en/index
[2] https://cloud.google.com/vertex-ai/generative-ai/docs/model-garden/lora-qlora?hl=zh-cn
[3] https://huggingface.co/blog/zh/dpo_vlm