DPO: Direct Preference Optimization 介绍

news2024/9/22 1:25:31

DPO 是 RLHF 的屌丝版本,RLHF 需要加载 4 个模型(2个推理,2个训练),DPO 只需要加载 2 个模型(1个推理,一个训练)。

RLHF:

DPO:

 

DPO 原理

DPO 的本质是监督对比学习:通过对每条prompt提供两条不同的answer,并给出这两个answer的偏好偏序,让模型输出更接近good answer,同时更远离 bad answer。

这个过程中并不强制要求上述两者同时满足,只要接近good answer的程度大于bad answer就是有效的训练,比如与good answer远离了,但是与bad answer远离的更多也是有效的。

DPO loss

 

σ :sigmoid函数

β :超参数,一般在0.1 - 0.5之间

y_w :某条偏好数据中好的response,w就是win的意思

y_l :某条偏好数据中差的response,l就是loss的意思,所以偏好数据也叫comparision data

\pi_\theta(y_w|x) :给定输入x, 当前policy model生成好的response的累积概率(每个tokne的概率求和,具体看代码)

\pi_{ref}(y_l|x) :给定输入x, 原始模型(reference model)生成坏的response的累积概率

开始训练时,reference model和policy model都是同一个模型,只不过在训练过程中reference model不会更新权重。

简化形式:忽略 logsigmoid 并取对数

由于最初loss前面是有个负号的,所以优化目标是让本简化公式最大,即希望左半部分和右半部分的margin越大越好,左半部分的含义是good response相较于没训练之前的累积概率差值,右半部分代表bad response相较于没训练之前的累计概率差值,如果这个差值,即margin变大了。

 DPO 数据集

可以由prompt 模板: Human: prompt. Assistant: chosen/rejected 构成如下数据:Anthropic/hh-rlhf dataset

 DPO trainer 期望数据集具有非常特定的格式。 给定两个句子时,模型将被训练为直接优化偏好:那一个句子最相关。

Huagging Face DPO Trainer

与 PPO 期望 AutoModelForCausalLMWithValueHead 作为值函数相比,DPO 训练器期望 AutoModelForCausalLM 模型。 

 dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

Loss 选择:

  • RSO 作者建议在 SLiC 论文中的归一化似然上使用 hinge损失。 DPOTrainer 可以通过 loss_type="hinge" 参数切换到此损失,这种情况下的 beta 是margin的倒数。
  • IPO 作者对 DPO 算法提供了更深入的理论理解,并识别了过度拟合的问题,并提出了一种替代损失,可以通过训练器的 loss_type="ipo" 参数来使用。
  • cDPO 是对 DPO 损失的调整,其中我们假设偏好标签有一定的噪声,可以通过 label_smoothing 参数(0 到 0.5 之间)传递到 DPOTrainer,然后使用保守的 DPO 损失。 使用 loss_type="cdpo" 参数给训练器来使用它。
  • KTO 损失的导出是为了直接最大化 LLM 代的效用,而不是偏好的对数似然。 因此,数据集不一定是偏好,而是期望的完成与不期望的完成。 对于 DPOTrainer 所需的配对偏好数据,请使用训练器的 loss_type="kto_pair" 参数来利用此损失,而对于所需和不需要的数据的更一般情况,请使用尚未实现的 KTOTrainer。

简单实例

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopy

torch.manual_seed(0)
if __name__ == "__main__":
    # 超参数
    beta = 0.1
    # 加载模型
    policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=1000, num_hidden_layers=1, hidden_size=128))
    reference_model = deepcopy(policy_model)

    # data
    prompt_ids = [1, 2, 3, 4, 5, 6]
    good_response_ids = [7, 8, 9, 10]
    # 对loss稍加修改可以应对一个good和多个bad的情况
    bad_response_ids_list = [[1, 2, 3, 0], [4, 5, 6, 0]]

    # 转换成模型输入 [3, 10]
    input_ids = torch.LongTensor(
        [prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]]
    )
    # labels 提前做个shift [3, 9]
    labels = torch.LongTensor(
        [
            [-100] * len(prompt_ids) + good_response_ids,
            *[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]
        ]
    )[:, 1:]
    loss_mask = (labels != -100)
    labels[labels == -100] = 0
    # 计算 policy model的log prob
    # policy_model(input_ids)["logits"] [3, 10, 1000] 句末的推理结果无效直接忽略
    logits = policy_model(input_ids)["logits"][:, :-1, :]
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    all_logps = (per_token_logps * loss_mask).sum(-1)
    # 暂时写死第一个是good response的概率, 三个例子中第一个是 good answer, 后两个是 bad answer
    policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]

    # 计算 reference model的log prob
    with torch.no_grad():
        logits = reference_model(input_ids)["logits"][:, :-1, :]
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        all_logps = (per_token_logps * loss_mask).sum(-1)
        # 暂时写死第一个是good response的概率
        reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]

    # 计算loss,会自动进行广播
    logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)
    loss = -F.logsigmoid(beta * logits).mean()
    print(loss)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2135239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

canfd 卡 canfd-422ac在汽车电子测试中的大作用

随着汽车电子的高速发展,车内信息的急剧增多,传统的CAN总线的数据传输能力已经很难满足车辆ECU的数据传输需求了,此时CANFD就应运而生了。![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/c3822ac2b2ed4694a58132e0b4743e99.png)CAN…

无需部署,云电脑带你秒变AI绘画大师

在人工智能的浪潮中,AI 绘画逐渐火爆,通过关键字描述就能让AI把你想要的画面完美展示出来。 当前最火的绘画软件当属 Midjonary, Midjourney 界面比较直观,适合新手以及非专业的用户,风格比较多样,可以轻松…

Python “集合” 100道实战题目练习,巩固知识、检查技术

本文主要是作为Python中列表的一些题目,方便学习完Python的集合之后进行一些知识检验,感兴趣的小伙伴可以试一试,含选择题、判断题、实战题、填空题,答案在第五章。 在做题之前可以先学习或者温习一下Python的列表,推荐…

Python selenium 破解腾讯滑块行为验证码

直接上代码: from selenium import webdriver from selenium.webdriver.common.action_chains import ActionChains import time,re,requests from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait from seleniu…

leetcode两个题:只出现一次的数字、杨辉三角等的介绍

文章目录 前言一、只出现一次的数字二、杨辉三角总结 前言 leetcode两个题&#xff1a;只出现一次的数字、杨辉三角等的介绍 一、只出现一次的数字 0跟任何数异或结果都是任何数相同的数异或结果为0 class Solution { public:int singleNumber(vector<int>& nums) …

Navigation之使用Safe Args传递数据(二)

系列文章目录 Navigation的简单使用(一&#xff09; 一、Safe Args传递数据 1.引入库 1.将 Safe Args 添加到您的项目&#xff0c;请在顶层 build.gradle 文件中包含以下 classpath&#xff1a; buildscript {repositories {google()}dependencies {def nav_version "…

2024年9月114日(使用kubectl run 创建pod|配置文件创建pod)

一、pod 1、更改镜像站 [rootk8s-master ~]# vim /etc/docker/daemon.json {"registry-mirrors": ["https://do.nark.eu.org","https://dc.j8.work","https://docker.m.daocloud.io","https://dockerproxy.com","http…

如何将本地项目上传到GitHub(SSH连接)

在个人GitHub中新建项目(远程仓库)&#xff0c;添加一个README文件&#xff0c;方便后面验证 记住这个默认分支&#xff0c;我这里是main&#xff0c;你的可能是master或其他 先复制下SSH地址 在项目文件夹中右键打开Git命令行 初始化本地仓库&#xff0c;同时指定默认分支为ma…

MyBatis 面试题11-27

11、Mybatis 是如何将 sql 执行结果封装为目标对象并返回的? 都有哪些映射形式&#xff1f; Mybatis 在执行 SQL 查询后&#xff0c;会将结果集封装为目标对象并返回。这主要依赖于 Mybatis 的映射机制&#xff0c;它提供了两种主要的映射形式&#xff1a; 第一种&#xff1…

【触想智能】工控一体机在船舶航运上应用的优势和应用场景分析

随着船舶航运业的发展&#xff0c;工控一体机在船舶航运领域上的应用越来越广泛。工控一体机的功能和性能可以加强船舶航运领域的自动化和智能化水平。 下面&#xff0c;触想智能小编针对工控一体机在船舶航运领域上应用的优势和应用场景进行简单分析&#xff0c;给大家借鉴参考…

LVGL控件之表格(lv_table)

目录 一、概述二、表格1、设置单元格的值2、行和列的设置3、宽度和高度的设置4、合并单元格5、滚动6、事件7、API 函数 一、概述 Table&#xff08;表格&#xff09;是由包含文本的行、列和单元格构建的。 表格对象非常轻量级&#xff0c;因为仅存储文本。没有为各个单元格创…

Altium Designer常用操作备忘笔记

Altium Designer常用操作备忘笔记 Chapter1 Altium Designer常用操作备忘笔记Chapter2 Altium Designer 22.1.2使用总结&#xff08;常更&#xff09;一、原理图1.1 绘制元器件原理图1.2 绘制元器件封装1.3 修改原理图网格1.4 修改原理图库后更新当前原理图1.5 旋转和翻转1.6 悬…

Leetcode Hot 100刷题记录 -Day15(螺旋矩阵)

螺旋矩阵 问题描述&#xff1a; 给你一个 m 行 n 列的矩阵 matrix &#xff0c;请按照 顺时针螺旋顺序 &#xff0c;返回矩阵中的所有元素。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,2,3],[4,5,6],[7,8,9]]输出&#xff1a;[1,2,3,6,9,8,7,4,5] 示例 2&#xff1a; 输…

104. 二叉树的最大深度【 力扣(LeetCode) 】

零、LeetCode 原题 104. 二叉树的最大深度 一、题目描述 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 二、测试用例 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出…

[git] MacBook 安装git

文章目录 1.Mac Git 安装2. 开发者工具安装 CommandLineTools安装完成&#xff0c;错误解决 3. git 账户配置账户设置生成秘钥git 或者 gitee 仓库添加公钥查看全局账户命令多账户设置config文件测试连接clone到本地 1.Mac Git 安装 Mac一般自带Git工具&#xff0c;也就是说已…

监听html元素是否被删除,删除之后重新生成被删除的元素

/*** 监听水印是否清除和修改*/ export function watermarkClear() {// 添加水印的盒子let box: any document.querySelector(.dplayer-video-wrap)// 水印let watermark: any document.querySelector(.dplayer-logo)// 观察器的配置&#xff08;需要观察什么变动&#xff09…

css scrollbar-width: none 隐藏默认滚动条

.table-box{ flex: 1; overflow-y: scroll; scrollbar-width: none;} scrollbar-width: none; 隐藏默认滚动条

计算机二级自学笔记(程序题1部分)

(1)b fun函数内a数组作为参数&#xff0c;下方提供了b函数作为中间数组所以在这进行初始化 (2)2 第二个循环内将a数组分为1-27与27-55两部分&#xff0c;1-27存放在b数组的奇数位&#xff0c;27-55存放在b数组的偶数位 (3)a[k] 将b数组传递回数组a nm mn 已经有的参数…

使用 PyTorch 构建 MNIST 手写数字识别模型

引言 MNIST 数据集是一个经典的机器学习数据集&#xff0c;包含了 70,000 张手写数字图像&#xff0c;其中 60,000 张用于训练&#xff0c;10,000 张用于测试。每张图像都是 28x28 像素的灰度图像&#xff0c;并且已经被居中处理以减少预处理步骤。本文将介绍如何使用 PyTorch…

剪辑必备,6个可白嫖的视频素材网站。

找视频素材就上这6个网站&#xff0c;免费下载&#xff0c;赶紧收藏好&#xff01; 1、菜鸟图库 视频素材下载_mp4视频大全 - 菜鸟图库 菜鸟图库网素材非常丰富&#xff0c;网站主要以设计类素材为主&#xff0c;高清视频素材也很多&#xff0c;像风景、植物、动物、人物、科技…