RLAIF(0)—— DPO(Direct Preference Optimization) 原理与代码解读

news2025/3/1 1:46:28

之前的系列文章:介绍了 RLHF 里用到 Reward Model、PPO 算法。
但是这种传统的 RLHF 算法存在以下问题:流程复杂,需要多个中间模型对超参数很敏感,导致模型训练的结果不稳定。
斯坦福大学提出了 DPO 算法,尝试解决上面的问题,DPO 算法的思想也被后面 RLAIF(AI反馈强化学习)的算法借鉴,这个系列会从 DPO 开始,介绍 SPIN、self-reward model 算法。
而 DPO 本身是一种不需要强化学习的算法,简化了整个 RLHF 流程,训练起来会更简单。

原理


传统的 RLHF 步骤一般是:训练一个 reward model 对 prompt 的 response 进行打分,训练完之后借助 PPO 算法,使得 SFT 的模型和人类偏好对齐,这个过程我们需要初始化四个基本结构一致的 transformer 模型。DPO 算法,提供了一种更为简单的 loss function,而这个就是 DPO 的核心思想:针对奖励函数的 loss 函数被转换成针对策略的 loss 函数,而针对策略的 loss 函数又暗含对奖励的表示,即人类偏好的回答会暗含一个更高的奖励。

L D P O ( π θ ; π r e f ) = − E ( x , y w , y l ) ∼ D [ log ⁡ σ ( β log ⁡ π θ ( y w ∣ x ) π r e f ( y w ∣ x ) − β log ⁡ π θ ( y l ∣ x ) π r e f ( y l ∣ x ) ) ] \mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta ; \pi_{\mathrm{ref}}\right)=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right] LDPO(πθ;πref)=E(x,yw,yl)D[logσ(βlogπref(ywx)πθ(ywx)βlogπref(ylx)πθ(ylx))]

这个函数唯一的超参数是 β \beta β,决定了不同策略获得的奖励之间的 margin。
可以看出 DPO 虽然说是没有用强化学习,但是还是有强化学习的影子,相当于 beta 是一个固定的 reward,通过人类偏好数据,可以让这个奖励的期望最大化,本至上来说 dpo 算法也算是一种策略迭代。
对 loss function 求导,可得如下表达式:

∇ θ L D P O ( π θ ; π r e f ) = − β E ( x , y w , y l ) ∼ D [ σ ( r ^ θ ( x , y l ) − r ^ θ ( x , y w ) ) ⏟ higher weight when reward estimate is wrong  [ ∇ θ log ⁡ π ( y w ∣ x ) ⏟ increase likelihood of  y w − ∇ θ log ⁡ π ( y l ∣ x ) ⏟ decrease likelihood of  y l ] ] \begin{aligned} & \nabla_\theta \mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta ; \pi_{\mathrm{ref}}\right)= \\ & -\beta \mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}[\underbrace{\sigma\left(\hat{r}_\theta\left(x, y_l\right)-\hat{r}_\theta\left(x, y_w\right)\right)}_{\text {higher weight when reward estimate is wrong }}[\underbrace{\nabla_\theta \log \pi\left(y_w \mid x\right)}_{\text {increase likelihood of } y_w}-\underbrace{\nabla_\theta \log \pi\left(y_l \mid x\right)}_{\text {decrease likelihood of } y_l}]] \end{aligned} θLDPO(πθ;πref)=βE(x,yw,yl)D[higher weight when reward estimate is wrong  σ(r^θ(x,yl)r^θ(x,yw))[increase likelihood of yw θlogπ(ywx)decrease likelihood of yl θlogπ(ylx)]]

其中 r ^ θ ( x , y ) = β log ⁡ π θ ( y ∣ x ) π ref  ( y ∣ x ) \hat{r}_\theta(x, y)=\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text {ref }}(y \mid x)} r^θ(x,y)=βlogπref (yx)πθ(yx)
不得不佩服作者构思的巧妙,通过求导,作者捕捉到了”暗含“的 reward —— σ ( r ^ θ ( x , y l ) − r ^ θ ( x , y w ) ) \sigma\left(\hat{r}_\theta\left(x, y_l\right)-\hat{r}_\theta\left(x, y_w\right)\right) σ(r^θ(x,yl)r^θ(x,yw)),作者在论文里说到,当我们在让 loss 降低的过程中,这个 reward 也会变小(可以用来做 rejected_rewards)。因而在下面的代码实现里,chosen_rewards 和 rejected_rewards 就来自这个想法。

代码实现

我们来看下 trl 是如何实现 dpo loss 的,可以看到 dpo 和 ppo 相比,实现确实更为简单,而且从开源社区的反应来看,dpo 的效果也很不错,现在 dpo 已经成为偏好对齐最主流的算法之一。

def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the DPO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
    """
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    if self.reference_free:
        ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
    else:
        ref_logratios = reference_chosen_logps - reference_rejected_logps

    pi_logratios = pi_logratios.to(self.accelerator.device)
    ref_logratios = ref_logratios.to(self.accelerator.device)
    logits = pi_logratios - ref_logratios

    # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
    # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
    # calculates a conservative DPO loss.
    losses = (
        -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
        - F.logsigmoid(-self.beta * logits) * self.label_smoothing
    )

    chosen_rewards = (
        self.beta
        * (
            policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
        ).detach()
    )
    rejected_rewards = (
        self.beta
        * (
            policy_rejected_logps.to(self.accelerator.device)
            - reference_rejected_logps.to(self.accelerator.device)
        ).detach()
    )

    return losses, chosen_rewards, rejected_rewards

改进

在 Preference Tuning LLMs with Direct Preference Optimization Methods (huggingface.co) 一文中提到,dpo 也存在一些不足:
dpo 算法很容易在数据集上过拟合dpo 训练依赖成对的偏好数据集,这种数据集的构造和标注都很耗费时间。
一些研究者也根据这些问题提出了新的算法:
针对 1,deepmind 提出了 Identity Preference Optimisation (IPO),给 DPO 的 loss 加了一个正则项,避免训练快速过拟合
针对 2,ContextualAI 提出了 Kahneman-Tversky Optimisation (KTO),KTO 算法的数据集,不再是成对的偏好数据集,而是给每条数据集 “good” 或者 “bad” 的标签进行偏好对齐。
这些算法的相关内容会在介绍完 SPIN 和 self-reward model 之后进行补充,欢迎关注,感谢阅读。
最后说点题外话,通过 DPO,我们可以看出深度学习还是一门实验的学科,如果光看 DPO 算法本身,很难相信这样一个简单的算法,会这么有效。所以多动手写代码,多实验,也许有一天我们也可以发现很 work 的算法,共勉。

参考

  1. Direct Preference Optimization: Your Language Model is Secretly a Reward Model
  2. Aligning LLMs with Direct Preference Optimization (youtube.com)
  3. huggingface/trl: Train transformer language models with reinforcement learning. (github.com)
  4. Preference Tuning LLMs with Direct Preference Optimization Methods (huggingface.co)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1504343.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VMware下载与安装

准备一个Linux的系统,成本最低的方式就是在本地安装一台虚拟机,VMware是业界最好用的虚拟机软件之一 官网:https://www.vmware.com/ 下载页面:https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html …

Python 3 教程(3)

958 在 Windows 下可以不写第一行注释: #!/usr/bin/python3 第一行注释标的是指向 python 的路径,告诉操作系统执行这个脚本的时候,调用 /usr/bin 下的 python 解释器。 此外还有以下形式(推荐写法): #!/usr/bin/env p…

C# 高级特性(十一):多线程之async,await

之前使用Thread和Task启动多线程时都会遇到一个麻烦,就是如何反馈结果。在代码里就是如何设计回调函数。如果带界面还得考虑UI线程的问题。 而使用async,await可以达到两个效果。 1 不用设计回调函数,直接按单线程的格式写。 2 不用考虑UI…

【NR 定位】3GPP NR Positioning 5G定位标准解读(十)-增强的小区ID定位

前言 3GPP NR Positioning 5G定位标准:3GPP TS 38.305 V18 3GPP 标准网址:Directory Listing /ftp/ 【NR 定位】3GPP NR Positioning 5G定位标准解读(一)-CSDN博客 【NR 定位】3GPP NR Positioning 5G定位标准解读(…

面向对象高级编程下

面向对象高级编程下 面向对象高级编程下一. 转换函数二. non-explict-one-argument ctor三. explicit-one-argument ctor四. pointer-like classes1. 智能指针2. 迭代器 五. function-like classes六. namespace七. 模板1.类模板2.函数模板3.成员模板 八.模板特化和偏特化1. 模…

C语言——函数指针——函数指针变量详解

函数指针变量 函数指针变量的作用 函数指针变量是指向函数的指针,它可以用来存储函数的地址,并且可以通过该指针调用相应的函数。函数指针变量的作用主要有以下几个方面: 回调函数:函数指针变量可以作为参数传递给其他函数&…

CIA402协议笔记

文章目录 1、对象字典1.1 Mode of Operation( 606 0 h 6060_h 6060h​)1.2 Modes of opration display( 606 1 h ) 6061_h) 6061h​) 2、状态机2.1 控制字(ControlWord、6040h)2.2 状态字(StatusWord、6041h)2.3 shutd…

C语言书籍——A/陷阱之处

文章参考于文献:《C陷阱与缺陷》[美]Andrew Koening 🌈个人主页:慢了半拍 🔥 创作专栏:《史上最强算法分析》 | 《无味生》 |《史上最强C语言讲解》 | 《史上最强C练习解析》 🏆我的格言:一切只…

杨辉三角(C语言)

杨辉三角 一.什么是杨辉三角 一.什么是杨辉三角 每个数等于它上方两数之和。 每行数字左右对称,由1开始逐渐变大。 第n行的数字有n项。 前n行共[(1n)n]/2 个数。 … 当前行的数上一行的数上一行的前一列的数 void yanghuisanjian(int arr[][20], int n) {for (int i…

3.9Code

基于顺序存储结构的图书信息表的图书去重 #include<iostream> #include<stdlib.h> #include<string.h>typedef int status;#define OK 1using namespace std;typedef struct{char no[50];char name[50];float price; }Book;typedef struct{Book* elem;int …

Linux 学习(持续更新。。。)

wc命令 命令直接执行&#xff0c;输出包含四项&#xff0c;分别代表&#xff1a;行数、字数、字节数、文件。 例子:编译下列代码: #include <stdio.h> #include <stdlib.h> #include <unistd.h> #include <string.h> #include <fcntl.h> #inclu…

七、软考-系统架构设计师笔记-数据库设计基础知识

1、数据库基础概念 数据库基本概念 数据(Data)数据库(Database)数据库管理系统(DBMS)数据库系统(DBS) 1.数据(Data) 是数据库中存储的基本对象&#xff0c;是描述事物的符号记录。 数据的种类&#xff1a; 文本、图形、图像、音频、视频等。 2.数据库(Database, DB) 数据库…

JavaScript代码混淆与防格式化功能详解

在前端开发中&#xff0c;为了增加代码的安全性&#xff0c;防止恶意分析和逆向工程&#xff0c;有时候会采用一些防格式化的技术。这些技术主要通过混淆和难以阅读的方式来防止代码的易读性&#xff0c;提高代码的复杂度&#xff0c;增加攻击者分析的难度。 1. 代码压缩与混淆…

docker 使用官方镜像搭建 PHP 环境

一、所需环境&#xff1a; 1、PHP&#xff1a;7.4.33-fpm 的版本 2、Nginx&#xff1a;1.25.1 的版本 3、MySQL&#xff1a; 5.7 的版本 4、Redis&#xff1a;7.0 的版本 1.1、拉取官方的镜像 docker pull php:7.4.33-fpm docker pull nginx:1.25.1 docker pull mysql:5.7 do…

Character类

Character类 功能&#xff1a;实现了对 基本型数据的类包装。 构造方法&#xff1a; (char c) 常用方法&#xff1a; • public static boolean (char ch)//ch是数字字符返回true。 • public static boolean (char ch)//ch是字母返回 true。 • public static char (char c…

商家转账到零钱申请失败怎么办

商家转账到零钱是什么&#xff1f; 商家转账到零钱是微信商户号里的一个功能&#xff0c;以前叫做企业付款到零钱。从 2022 年 5 月 18 日开始&#xff0c;原企业付款到零钱升级为商家转账到零钱&#xff0c;已开通商户的功能使用不受影响&#xff0c;新开通商户可前往产品中心…

简介:图灵机和图灵测试

一、图灵机&#xff08;Turing machine&#xff09; 图灵机&#xff08;Turing machine&#xff09;是由英国数学家Alan Turing于1936年提出的一种抽象计算模型&#xff0c;阿兰图灵在24岁时发表论文《On Computable Numbers, with an Application to the Entscheidungsproble…

指针总结及例题总结

1 定义 指针是用来存放地址的变量 不同类型的指针变量所占用的存储空间是相同的&#xff0c;sizeof(int)sizeof(char)sizeof(double)... *是解引用操作符&#xff0c;&是取地址操作符&#xff0c;两者有着抵消作用 int a20;int* p&a;*p*&a20; 2&#xff0c;…

windows更改账户名

win R输入netplwiz 点击用户名进去&#xff0c; 修改用户名之后重启即可。

echarts 模拟时间轴播放效果

使用echarts柱状图模拟时间轴播放控制。开始/暂停/快进/慢进/点选 代码可直接放echart官方示例执行 let start new Date(2024-01-01); let end new Date(2024-01-10); let diff end - start; let dotLen 200;let data [start, end]; option {color: [#3398DB],tooltip: …