大型语言模型(LLM)——直接偏好优化完整指南

news2025/1/11 8:17:07

概述

将大型语言模型 (LLM) 与人类价值观和偏好相结合是一项挑战。传统方法,例如 [从人类反馈中强化学习]((RLHF)通过整合人类输入来完善模型输出,为这一领域的研究铺平了道路。然而,RLHF 可能非常复杂且资源密集,需要大量的计算能力和数据处理。 直接偏好优化 (DPO)作为一种新颖且更精简的方法出现,为这些传统方法提供了一种有效的替代方案。通过简化优化过程,DPO 不仅减少了计算负担,还增强了模型快速适应人类偏好的能力

偏好协调的必要性

在深入研究 DPO 之前,我们必须先了解为什么将 LLM 与人类偏好相结合如此重要。尽管 LLM 具有令人印象深刻的能力,但经过大量数据集训练的 LLM 有时会产生不一致、有偏见或与人类价值观不一致的输出。这种不一致可以表现为多种方式:

  • 生成不安全或有害的内容
  • 提供不准确或误导性的信息
  • 训练数据中存在偏差

为了解决这些问题,研究人员开发了利用人工反馈来微调 LLM 的技术。其中最突出的方法是 RLHF。

了解 RLHF:DPO 的前身

人类反馈强化学习 (RLHF) 一直是将 LLM 与人类偏好相结合的首选方法。让我们分解 RLHF 流程以了解其复杂性:

a) 监督微调 (SFT):该过程首先在高质量响应数据集上对预先训练的 LLM 进行微调。此步骤可帮助模型为目标任务生成更相关、更连贯的输出。

b) 奖励模型:训练单独的奖励模型来预测人类偏好。这涉及:

  • 根据给定的提示生成响应对
  • 让人类评价他们喜欢哪种反应
  • 训练模型来预测这些偏好

c) 强化学习:经过微调的 LLM 随后使用强化学习进一步优化。奖励模型提供反馈,引导 LLM 生成符合人类偏好的反应。

下面是一个简化的 Python 伪代码,用于说明 RLHF 过程:

虽然 RLHF 有效,但它有几个缺点:

  • 它需要训练和维护多个模型(SFT、奖励模型和 RL 优化模型)
  • RL 过程可能不稳定,并且对超参数敏感
  • 计算成本高昂,需要通过模型进行多次前向和后向传递

这些限制促使人们寻找更简单、更有效的替代方案,从而导致了 DPO 的发展。

直接偏好优化:核心概念

在这里插入图片描述

此图对比了两种将 LLM 输出与人类偏好相一致的不同方法:基于人类反馈的强化学习 (RLHF) 和直​​接偏好优化 (DPO)。RLHF 依靠奖励模型通过迭代反馈循环来指导语言模型的策略,而 DPO 则使用偏好数据直接优化模型输出以匹配人类偏好的响应。此比较突出了每种方法的优势和潜在应用,为未来 LLM 如何训练以更好地符合人类期望提供了见解。

DPO 背后的关键思想:

a) 隐性奖励模型:DPO 将语言模型本身视为隐式奖励函数,从而消除了对单独奖励模型的需求。

b) 基于策略的制定:DPO 不优化奖励函数,而是直接优化策略(语言模型)以最大化首选响应的概率。

c) 闭式解:DPO 利用数学洞察力,可以对最佳策略提供闭式解,从而避免进行迭代 RL 更新。

实施 DPO:实用代码演练

下图展示了使用 PyTorch 实现 DPO 损失函数的代码片段。该函数在改进语言模型如何根据人类偏好对输出进行优先排序方面起着至关重要的作用。以下是关键组件的细分:

  • 函数签名:本 dpo_loss 函数接受几个参数,包括策略日志概率(pi_logps),参考模型对数概率(ref_logps),以及代表优先和非优先完成情况的指数(yw_idxs, yl_idxs)此外, beta 参数控制KL惩罚的强度。
  • 对数概率提取:代码从策略和参考模型中提取首选和不首选完成的对数概率。
  • 对数比率计算:针对策略模型和参考模型,计算了优先完成和不优先完成的对数概率之间的差异。该比率对于确定优化的方向和幅度至关重要。
  • 损失和奖励计算:损失计算如下: logsigmoid 函数,而奖励则通过缩放策略和参考日志概率之间的差异来确定 beta.
    在这里插入图片描述
    使用 PyTorch 的 DPO 损失函数

DPO 的数学原理

DPO 是对偏好学习问题的一个巧妙的重新表述。下面是分步分解:

a)起点:KL 约束奖励最大化

原始 RLHF 目标可以表示为:
在这里插入图片描述
地点:

  • πθ 是我们正在优化的策略(语言模型)
  • r(x,y) 是奖励函数
  • πref 是参考策略(通常是初始 SFT 模型)
  • β 控制 KL 散度约束的强度

b) 最优策略形式: 可以证明,该目标的最优策略采取如下形式:

π_r(y|x) = 1/Z(x) * πref(y|x) * exp(1/β * r(x,y))

其中 Z(x) 是归一化常数。

c) 奖励策略二元性: DPO 的关键见解是用最优策略来表达奖励函数:

r(x,y) = β * log(π_r(y|x) / πref(y|x)) + β * log(Z(x))

d) 偏好模型假设偏好遵循 Bradley-Terry 模型,我们可以将偏好 y1 而非 y2 的概率表示为:

p*(y1 ≻ y2 | x) = σ(r*(x,y1) - r*(x,y2))

其中 σ 是逻辑函数。

e) DPO 目标 将我们的奖励策略二元性代入偏好模型,我们得出 DPO 目标:

L_DPO(πθ; πref) = -E_(x,y_w,y_l)~D [log σ(β * log(πθ(y_w|x) / πref(y_w|x)) - β * log(πθ(y_l|x) / πref(y_l|x)))]

可以使用标准梯度下降技术来优化该目标,而无需 RL 算法。

实施 DPO

现在我们了解了 DPO 背后的理论,让我们看看如何在实践中实现它。我们将使用 蟒蛇PyTorch 对于此示例:

import torch
import torch.nn.functional as F
class DPOTrainer:
    def __init__(self, model, ref_model, beta=0.1, lr=1e-5):
        self.model = model
        self.ref_model = ref_model
        self.beta = beta
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
     
    def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs):
        """
        pi_logps: policy logprobs, shape (B,)
        ref_logps: reference model logprobs, shape (B,)
        yw_idxs: preferred completion indices in [0, B-1], shape (T,)
        yl_idxs: dispreferred completion indices in [0, B-1], shape (T,)
        beta: temperature controlling strength of KL penalty
        Each pair of (yw_idxs[i], yl_idxs[i]) represents the indices of a single preference pair.
        """
        # Extract log probabilities for the preferred and dispreferred completions
        pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
        ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]
        # Calculate log-ratios
        pi_logratios = pi_yw_logps - pi_yl_logps
        ref_logratios = ref_yw_logps - ref_yl_logps
        # Compute DPO loss
        losses = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))
        rewards = self.beta * (pi_logps - ref_logps).detach()
        return losses.mean(), rewards
    def train_step(self, batch):
        x, yw_idxs, yl_idxs = batch
        self.optimizer.zero_grad()
        # Compute log probabilities for the model and the reference model
        pi_logps = self.model(x).log_softmax(-1)
        ref_logps = self.ref_model(x).log_softmax(-1)
        # Compute the loss
        loss, _ = self.compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs)
        loss.backward()
        self.optimizer.step()
        return loss.item()
# Usage
model = YourLanguageModel()  # Initialize your model
ref_model = YourLanguageModel()  # Load pre-trained reference model
trainer = DPOTrainer(model, ref_model)
for batch in dataloader:
    loss = trainer.train_step(batch)
    print(f"Loss: {loss}")

挑战和未来方向

虽然 DPO 比传统 RLHF 方法具有显著优势,但仍然存在挑战和有待进一步研究的领域:

a)可扩展至更大的模型:

随着语言模型的规模不断扩大,如何有效地将 DPO 应用于具有数千亿个参数的模型仍然是一个悬而未决的挑战。研究人员正在探索以下技术:

  • 高效的微调方法(例如,LoRA、前缀调整)
  • 分布式训练优化
  • 梯度检查点和混合精度训练

使用 LoRA 与 DPO 的示例:

from peft import LoraConfig, get_peft_model
class DPOTrainerWithLoRA(DPOTrainer):
    def __init__(self, model, ref_model, beta=0.1, lr=1e-5, lora_rank=8):
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        self.model = get_peft_model(model, lora_config)
        self.ref_model = ref_model
        self.beta = beta
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
# Usage
base_model = YourLargeLanguageModel()
dpo_trainer = DPOTrainerWithLoRA(base_model, ref_model)

开发能够有效适应偏好数据有限的新任务或领域的 DPO 技术是一个活跃的研究领域。正在探索的方法包括:

  • 用于快速适应的元学习框架
  • 基于提示的 DPO 微调
  • 将学习从一般偏好模型转移到特定领域
c)处理模糊或冲突的偏好:

现实世界的偏好数据通常包含歧义或冲突。提高 DPO 对此类数据的稳健性至关重要。潜在的解决方案包括:

  • 概率偏好建模
  • 主动学习解决歧义
  • 多代理偏好聚合

概率偏好建模的示例:

class ProbabilisticDPOTrainer(DPOTrainer):
    def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs, preference_prob):
        # Compute log ratios
        pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
        ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]
         
        log_ratio_diff = pi_yw_logps.sum(-1) - pi_yl_logps.sum(-1)
        loss = -(preference_prob * F.logsigmoid(self.beta * log_ratio_diff) +
                 (1 - preference_prob) * F.logsigmoid(-self.beta * log_ratio_diff))
        return loss.mean()
# Usage
trainer = ProbabilisticDPOTrainer(model, ref_model)
loss = trainer.compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, preference_prob=0.8)  # 80% c
d)将DPO与其他对准技术相结合:

将 DPO 与其他对齐方法相结合可以产生更强大、更强大的系统:

  • 明确约束满足的宪法人工智能原则
  • 用于复杂偏好引出的辩论和递归奖励建模
  • 用于推断底层奖励函数的逆向强化学习

DPO 与体质 AI 相结合的示例:

class ConstitutionalDPOTrainer(DPOTrainer):
    def __init__(self, model, ref_model, beta=0.1, lr=1e-5, constraints=None):
        super().__init__(model, ref_model, beta, lr)
        self.constraints = constraints or []
    def compute_loss(self, pi_logps, ref_logps, yw_idxs, yl_idxs):
        base_loss = super().compute_loss(pi_logps, ref_logps, yw_idxs, yl_idxs)
         
        constraint_loss = 0
        for constraint in self.constraints:
            constraint_loss += constraint(self.model, pi_logps, ref_logps, yw_idxs, yl_idxs)
         
        return base_loss + constraint_loss
# Usage
def safety_constraint(model, pi_logps, ref_logps, yw_idxs, yl_idxs):
    # Implement safety checking logic
    unsafe_score = compute_unsafe_score(model, pi_logps, ref_logps)
    return torch.relu(unsafe_score - 0.5)  # Penalize if unsafe score > 0.5
constraints = [safety_constraint]
trainer = ConstitutionalDPOTrainer(model, ref_model, constraints=constraints)

实际考虑和最佳实践

在为实际应用实施 DPO 时,请考虑以下提示:

a) 数据质量:偏好数据的质量至关重要。确保您的数据集:

  • 涵盖多种输入和期望行为
  • 具有一致且可靠的偏好注释
  • 平衡不同类型的偏好(例如事实性、安全性、风格)

b) 超参数调整:虽然 DPO 的超参数比 RLHF 少,但调整仍然很重要:

  • β(beta):控制偏好满足度与参考模型偏差之间的权衡。从以下值开始 0.1-0.5.
  • 学习率:使用比标准微调更低的学习率,通常在以下范围内: 1e-6 至 1e-5.
  • 批次大小:更大的批次大小(32-128) 通常对偏好学习很有效。

c) 迭代细化:DPO可以迭代应用:

  1. 使用 DPO 训练初始模型
  2. 使用经过训练的模型生成新的响应
  3. 收集有关这些回应的新偏好数据
  4. 使用扩展的数据集重新训练
    在这里插入图片描述

直接偏好优化性能

该图深入研究了 GPT-4 等 LLM 与人类判断在各种训练技术(包括直接偏好优化 (DPO)、监督微调 (SFT) 和近端策略优化 (PPO))中的性能对比。该表显示,GPT-4 的输出越来越符合人类偏好,尤其是在摘要任务中。GPT-4 与人类审阅者之间的一致性水平表明该模型能够生成与人类评估者产生共鸣的内容,几乎与人类生成的内容一样接近。

案例研究和应用

为了说明 DPO 的有效性,让我们看一些实际应用及其一些变体:

  • 迭代 DPO:此变体由 Snorkel (2023) 开发,将拒绝采样与 DPO 相结合,从而实现更精细的训练数据选择过程。通过对多轮偏好采样进行迭代,该模型能够更好地泛化并避免过度拟合嘈杂或有偏见的偏好。
  • 首次公开募股(迭代偏好优化):IPO 由 Azar 等人 (2023) 提出,增加了一个正则化项来防止过度拟合,这是基于偏好的优化中常见的问题。此扩展允许模型在遵循偏好和保留泛化能力之间保持平衡。
  • 韩国旅游观光局 (知识转移优化):Ethayarajh 等人 (2023) 的最新变体 KTO 完全摒弃了二元偏好。相反,它专注于将知识从参考模型转移到策略模型,以优化与人类价值观的更顺畅和更一致的一致性。
  • 用于跨领域学习的多模态 DPO 作者:Xu 等人(2024 年):一种将 DPO 应用于不同模态(文本、图像和音频)的方法,展示了其在将模型与人类偏好相结合方面跨不同数据类型的多功能性。这项研究强调了 DPO 在创建能够处理复杂、多模态任务的更全面的 AI 系统方面的潜力。

结论

直接偏好优化代表了语言模型与人类偏好相一致的重大进步。它的简单性、效率和有效性使其成为研究人员和从业人员的强大工具。

通过利用直接偏好优化的强大功能并牢记这些原则,您可以创建不仅具有令人印象深刻的功能而且与人类价值观和意图紧密结合的语言模型。

原文地址:https://www.unite.ai/direct-preference-optimization-a-complete-guide/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2051548.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CVE-2024-27198 和 CVE-2024-27199:JetBrains TeamCity 服务器的漏洞利用及其防护措施

引言 JetBrains TeamCity 作为一个广泛使用的持续集成和部署工具,其安全性备受关注。然而,最近披露的CVE-2024-27198和CVE-2024-27199两个漏洞揭示了该平台存在的重大安全隐患。这些漏洞允许攻击者通过绕过身份验证机制,创建未经授权的管理员…

Java代码基础算法练习-乘阶求和-2024.08.18

对应的源代码可以在我的 Gitee 仓库中找到&#xff0c;欢迎star~ [Gitee 仓库](https://gitee.com/yukongji/java-basic-algorithm) 任务描述&#xff1a; 求Sn1!2!3!4!5!…n!之值&#xff0c;其中n是一个数字(n<10)。 解决思路&#xff1a; 输入: 读取用户输入的 n 值。检查…

Java Sream中自定义Collector实现复杂数据收集方法

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

雷达气象学(10)——双偏振雷达及其变量

文章目录 10.1 双偏振雷达的优势10.2 双偏振变量——反映降水粒子特性的变量10.3 差分反射率 Z D R Z_{DR} ZDR​10.3.1 差分反射率的定义10.3.2 雨滴的差分反射率10.3.3 冰、雪和冰雹的差分反射率10.3.4 层云降水的差分反射率10.3.5 对流降水的差分反射率 10.4 相关系数 ρ …

Chrome快捷键提高效率

浏览效率提高快捷建 快速切换标签页 Ctrl 数字&#xff08;1或者2&#xff09;&#xff0c;标签页数字从左到右为顺序&#xff0c;1开始。快速切换标签页。 Ctrl1 到 Ctrl8 切换到标签栏中指定位置编号所对应的标签页 Ctrl9切换到最后一个标签页 CtrlTab 或 CtrlPgDown 切…

Elasticsearch中磁盘水位线的深度解析

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

Clobbering DOM attributes to bypass HTML filters

目录 寻找注入点 代码分析 payload构造 注入结果 寻找注入点 DOM破坏肯定是出现在js文件中&#xff0c;我们首先来看源码 /resources/labheader/js/labHeader.js这个源码没什么问题我们重点关注在下面两个源码上 /resources/js/loadCommentsWithHtmlJanitor.js这个源码中重…

从关键新闻和最新技术看AI行业发展(第二十九期2024.7.29-8.11) |【WeThinkIn老实人报】

写在前面 【WeThinkIn老实人报】旨在整理&挖掘AI行业的关键新闻和最新技术&#xff0c;同时Rocky会对这些关键信息进行解读&#xff0c;力求让读者们能从容跟随AI科技潮流。也欢迎大家提出宝贵的优化建议&#xff0c;一起交流学习&#x1f4aa; 欢迎大家关注Rocky的公众号&…

docker 安装mino服务,启动报错: Fatal glibc error: CPU does not support x86-64-v2

背景 docker 安装mino服务&#xff0c;启动报错&#xff1a; Fatal glibc error: CPU does not support x86-64-v2 原因 Docker 镜像中的 glibc 版本要求 CPU 支持 x86-64-v2 指令集&#xff0c;而你的硬件不支持。 解决办法 降低minio对应的镜像版本 经过验证&#xff1a;qu…

sanic + webSocket:股票实时行情推送服务实现

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐&#xff1a;「storm…

Linux中的内核编程

Linux内核是操作系统的核心组件&#xff0c;负责管理系统的资源、提供硬件抽象和执行系统调用。内核编程是一项涉及操作系统核心的高级任务&#xff0c;它允许开发人员直接与系统内核进行交互&#xff0c;实现更高效、更特定的功能。本文将深入探讨Linux中的内核编程&#xff0…

2021年上半年网络工程师考试上午真题

2021年上半年网络工程师考试上午真题 网络工程师历年真题含答案与解析 第 1 题 以下关于RISC和CISC计算机的叙述中&#xff0c;正确的是&#xff08; &#xff09;。 (A) RISC不采用流水线技术&#xff0c;CISC采用流水线技术(B) RISC使用复杂的指令&#xff0c;CISC使用简…

事件驱动架构的事件版本管理

有一种办法&#xff1a;发送会议邀请给所有团队&#xff0c;经过101次会议后&#xff0c;发布维护横幅&#xff0c;所有人同时点击发布按钮。或... 可用适配器&#xff0c;但微调。没错&#xff01;就像软件开发中90%问题一样&#xff0c;有种模式帮助你找到聪明解决方案。 1…

C Primer Plus(中文版)第13章编程练习,仅供参考

第十三章编程练习 对于文件的操作是程序开发过程中必不可少的。首先&#xff0c;来看一下第一题&#xff0c;对13.1程序进行修改&#xff0c;输入文件名&#xff0c;而不是命令行参数。完整程序代码以及运行结果如下&#xff1a; #include<stdio.h> #include<stdlib…

【数据结构篇】~单链表(附源码)

【数据结构篇】~链表 链表前言链表的实现1.头文件2.源文件 链表前言 链表是一种物理存储结构上非连续、非顺序的存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的指针链接次序实现的。 1、链式机构在逻辑上是连续的&#xff0c;在物理结构上不一定连续​ 2、结点一般是从…

Java二十三种设计模式-命令模式(18/23)

命令模式&#xff1a;将请求封装为对象的策略 概要 本文全面探讨了命令模式&#xff0c;从基础概念到实现细节&#xff0c;再到使用场景、优缺点分析&#xff0c;以及与其他设计模式的比较&#xff0c;并提供了最佳实践和替代方案&#xff0c;旨在帮助读者深入理解命令模式并…

【xr-frame】微信小程序xr-frame典型案例

微信小程序xr-frame典型案例 在之前的工作中&#xff0c;我大量使用XR-Frame框架进行AR开发&#xff0c;并积累了一些案例和业务代码。其中包括2D图像识别、手部动作识别、Gltf模型加载、动态模型加载、模型动画等内容。小程序部分使用TypeScript编写&#xff0c;而XR-Frame组…

利用puppeteer将html网页生成图片

1.什么是puppeteer&#xff1f; Puppeteer是一个Node库&#xff0c;它提供了一个高级API来通过DevTools协议控制Chromium或Chrome。 可以使用Puppeteer来自动化完成浏览器的操作&#xff0c;官方给出的一些使用场景如下&#xff1a; 生成页面PDF抓取 SPA&#xff08;单页应用…

3.Windows Login Unlocker-忘记电脑密码也可以解决

想要解锁Windows系统的开机密码&#xff0c;但官网的传统方法只适合Windows本地账户&#xff0c;对微软账户或PIN码()束手无策&#xff1f;别担心&#xff0c;小编之前推荐过的「Windows Login Unlocker」软件能为您排忧解难。这款出色的工具不仅能够轻松绕过各种Windows密码&a…

C语言-写一个用矩形法求定积分的通用函数,分别求积分区间为[0,1]sinx,cosx,e的x方的定积分

一、题目要求&#xff1a; 二、思路 ①数学方面:矩形法求定积分的公式 将积分图形划分成为指定数量的矩形&#xff0c;求取各个矩形的面积&#xff0c;然后最终进行累加得到结果 1.积分区间: [num1, num2] 2.分割数量:count 每个矩形的边长:dx(num2-num1)/count 3.被积分…