DPO 是 RLHF 的屌丝版本,RLHF 需要加载 4 个模型(2个推理,2个训练),DPO 只需要加载 2 个模型(1个推理,一个训练)。
RLHF:
DPO:
DPO 原理
DPO 的本质是监督对比学习:通过对每条prompt提供两条不同的answer,并给出这两个answer的偏好偏序,让模型输出更接近good answer,同时更远离 bad answer。
这个过程中并不强制要求上述两者同时满足,只要接近good answer的程度大于bad answer就是有效的训练,比如与good answer远离了,但是与bad answer远离的更多也是有效的。
DPO loss
σ :sigmoid函数
β :超参数,一般在0.1 - 0.5之间
:某条偏好数据中好的response,w就是win的意思
:某条偏好数据中差的response,l就是loss的意思,所以偏好数据也叫comparision data
:给定输入x, 当前policy model生成好的response的累积概率(每个tokne的概率求和,具体看代码)
:给定输入x, 原始模型(reference model)生成坏的response的累积概率
开始训练时,reference model和policy model都是同一个模型,只不过在训练过程中reference model不会更新权重。
简化形式:忽略 logsigmoid 并取对数
由于最初loss前面是有个负号的,所以优化目标是让本简化公式最大,即希望左半部分和右半部分的margin越大越好,左半部分的含义是good response相较于没训练之前的累积概率差值,右半部分代表bad response相较于没训练之前的累计概率差值,如果这个差值,即margin变大了。
DPO 数据集
可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf dataset
DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。
Huagging Face DPO Trainer
与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
Loss 选择:
- RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
- IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
- cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
- KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。
简单实例
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopy
torch.manual_seed(0)
if __name__ == "__main__":
# 超参数
beta = 0.1
# 加载模型
policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128))
reference_model = deepcopy(policy_model)
# data
prompt_ids = [1, 2, 3, 4, 5, 6]
good_response_ids = [7, 8, 9, 10]
# 对loss稍加修改可以应对一个good和多个bad的情况
bad_response_ids_list = [[1, 2, 3, 0], [4, 5, 6, 0]]
# 转换成模型输入 [3, 10]
input_ids = torch.LongTensor(
[prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]]
)
# labels 提前做个shift [3, 9]
labels = torch.LongTensor(
[
[-100] * len(prompt_ids) + good_response_ids,
*[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]
]
)[:, 1:]
loss_mask = (labels != -100)
labels[labels == -100] = 0
# 计算 policy model的log prob
# policy_model(input_ids)["logits"] [3, 10, 1000] 句末的推理结果无效直接忽略
logits = policy_model(input_ids)["logits"][:, :-1, :]
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
all_logps = (per_token_logps * loss_mask).sum(-1)
# 暂时写死第一个是good response的概率, 三个例子中第一个是 good answer, 后两个是 bad answer
policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]
# 计算 reference model的log prob
with torch.no_grad():
logits = reference_model(input_ids)["logits"][:, :-1, :]
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
all_logps = (per_token_logps * loss_mask).sum(-1)
# 暂时写死第一个是good response的概率
reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]
# 计算loss,会自动进行广播
logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)
loss = -F.logsigmoid(beta * logits).mean()
print(loss)