使用DPO微调Llama2

news2024/11/20 19:22:49

简介

0a8c71720b4dabdfcaae8d442bd45bcc.jpeg

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

本文介绍了直接偏好优化 (Direct Preference Optimization,DPO) 法,该方法现已集成至 TRL 库 中。同时,我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型, stack-exchange preference 数据集中包含了各个 stack-exchange 门户上的各种问题及其排序后的回答。

DPO 与 PPO

在通过 RL 优化人类衍生偏好时,一直以来的传统做法是使用一个辅助奖励模型来微调目标模型,以通过 RL 机制最大化目标模型所能获得的奖励。直观上,我们使用奖励模型向待优化模型提供反馈,以促使它多生成高奖励输出,少生成低奖励输出。同时,我们使用冻结的参考模型来确保输出偏差不会太大,且继续保持输出的多样性。这通常需要在目标函数设计时,除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项,这样做有助于防止模型学习作弊或钻营奖励模型。

DPO 绕过了建模奖励函数这一步,这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它,作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失,从而直接在偏好数据上优化语言模型!因此,DPO 从寻找最小化 RLHF 损失的最佳方案开始,通过改变参量的方式推导出一个 仅需 参考模型的损失!

有了它,我们可以直接优化该似然目标,而不需要奖励模型或繁琐的强化学习优化过程。

如何使用 TRL 进行训练

如前所述,一个典型的 RLHF 流水线通常包含以下几个环节:

  1. 有监督微调 (supervised fine-tuning,SFT)
  2. 用偏好标签标注数据
  3. 基于偏好数据训练奖励模型
  4. RL 优化

TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4),直接根据标注好的偏好数据优化 DPO 目标。

使用 DPO,我们仍然需要执行环节 1,但我们仅需在 TRL 中向 DPOTrainer 提供环节 2 准备好的偏好数据,而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式,它是一个含有以下 3 个键的字典:

  • prompt : 即推理时输入给模型的提示
  • chosen : 即针对给定提示的较优回答
  • rejected :  即针对给定提示的较劣回答或非给定提示的回答

例如,对于 stack-exchange preference 数据集,我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"], # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)

一旦有了排序数据集,DPO 损失其实本质上就是一种有监督损失,其经由参考模型获得隐式奖励。因此,从上层来看,DPOTrainer 需要我们输入待优化的基础模型以及参考模型:

dpo_trainer = DPOTrainer(
    model, # 经 SFT 的基础模型
    model_ref, # 一般为经 SFT 的基础模型的一个拷贝
    beta=0.1, # DPO 的温度超参
    train_dataset=dataset, # 上文准备好的数据集
    tokenizer=tokenizer, # 分词器
    args=training_args, # 训练参数,如: batch size, 学习率等
)

其中,超参 beta 是 DPO 损失的温度,通常在 0.1 到 0.5 之间。它控制了我们对参考模型的关注程度,beta 越小,我们就越忽略参考模型。对训练器初始化后,我们就可以简单调用以下方法,使用给定的 training_args 在给定数据集上进行训练了:

dpo_trainer.train()

基于 Llama v2 进行实验

在 TRL 中实现 DPO 训练器的好处是,人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库,我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。

有监督微调

如上文所述,我们先用 TRL 的 SFTTrainer 在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name, # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args, # HF Trainer arguments
)
trainer.train()

DPO 训练

SFT 结束后,我们保存好生成的模型。接着,我们继续进行 DPO 训练,我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型,并在上文生成的 stack-exchange preference 数据上,以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调,因此我们使用 Peft 的 AutoPeftModelForCausalLM 函数加载模型:

model = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # location of saved SFT model
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(
    script_args.model_name_or_path, # same model as the main one
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

可以看出,我们以 4 比特的方式加载模型,然后通过 peft_config 参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度,并报告一些关键指标,例如可以选择通过 WandB 记录并显示隐式奖励。最后,我们可以将训练好的模型推送到 HuggingFace Hub。

总结

SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到,训好的已合并模型也已上传至 HF Hub (见 此处)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/927389.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

认准这几条Web设计规范,做好Web不在话下!

在当今数字化的世界中,Web设计的重要性愈发凸显。无论是企业网站、电子商务平台还是个人博客,用户对网站的外观和体验要求越来越高。为了确保用户能够轻松访问和使用网站,遵循Web设计规范是至关重要的。本文将探讨一些关键的Web设计规范&…

Failed to start bean ‘documentationPluginsBootstrapper‘

问题描述 在集成redisson-spring-boot-starter时,项目启动时报如下错误 之前在集成swagger3.0的时候,遇到过同样的问题,原因是Springfox使用的路径匹配是基于AntPathMatcher的,而Spring Boot 2.7.X使用的是PathPatternMat…

使用oracleVM搭建虚拟机

选择新建,点击 取名字,选择你的安装路径,选择你爹镜像光盘,再勾选下面的,表示跳过一些步骤 其他的都可以默认,下一步即可 创建好了,点击设置,改变光驱,硬盘的顺序 等待它…

MES管理系统如何实现数据采集和过程控制

随着工业4.0的到来,MES管理系统解决方案已成为企业实现生产过程数字化和智能化的关键工具。MES生产管理系统不仅提供生产计划、调度、质量管理和设备维护等功能,还在数据采集和过程控制方面发挥着重要作用。本文将探讨MES生产管理系统如何实现数据采集和…

智能化追踪与实时管理:RFID技术在流水线上的革命性应用

随着科技的不断发展,物联网技术已经深入到了我们生活的方方面面,其中,射频识别(Radio Frequency Identification,简称RFID)技术被广泛应用于各行各业。在流水线生产中,RFID技术的应用也越来越广…

跨模态检索:基于OpenAI的Clip预训练模型构建以文搜图系统

目录 1 项目背景 2 关键技术 2.1 Clip模型 2.2 Milvus向量数据库 3 系统代码实现 3.1 运行环境构建 3.2 数据集下载 3.3 预训练模型下载 3.4 代码实现 3.4.1 创建向量表和索引 3.4.2 构建向量编码模型 3.4.3 数据向量化与加载 3.4.4 构建检索web 4 总结 1 项目背景…

如何数据库备份,如何将数据库备份到其他服务器

在当今的数字世界里,数据库已经成为单位和个人存储、管理和检索海量数据的关键工具。然而,随着数据量的增加,内容丢失的风险也随之增加。这就是为什么定期备份数据库变得尤为重要。本文将详细介绍如何有效备份数据库,以保护您的数…

2023高教社杯数学建模思路 - 复盘:光照强度计算的优化模型

文章目录 0 赛题思路1 问题要求2 假设约定3 符号约定4 建立模型5 模型求解6 实现代码 建模资料 0 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 问题要求 现在已知一个教室长为15米,宽为12米&…

五度易链最新“产业大数据服务解决方案”亮相,打造数据引擎,构建智慧产业

快来五度易链官网 点击网址【http://www.wdsk.net/】 看看我们都发布了哪些新功能!!! 自2015年布局产业大数据服务行业以来,“五度易链”作为全国产业大数据服务行业先锋企业,以“让数据引领决策,以智慧驾驭未来”为愿景,肩负“打…

说点大实话丨知名技术博主 Kirito 测评云原生网关

作者:徐靖峰 关注了阿里云云原生公众号,经常能看到 MSE-Higress 相关的推文,恰逢这次阿里云产品举办了一个 MSE-Higress 云原生网关的测评活动,借此机会体验了一把云原生网关的功能。 购买流程体验 购买网关时,页面明…

python入门篇04-循环(while与for),变量,函数基础

python目录 1. 前言1.1 上文传送 2. python基础使用2.1 while循环2.1.1 while循环的使用> 案例: 猜数字游戏(多经典...) 2.1.2 while双层循环> 案例: 输出9*9乘法表> 运行结果 2.2 for循环2.2.1 **for循环使用**> 案例: (字符串)查出有多少字符 2.2.2 方法range()的…

Leetcode每日一题:1448. 统计二叉树中好节点的数目

原题 给你一棵根为 root 的二叉树,请你返回二叉树中好节点的数目。 「好节点」X 定义为:从根到该节点 X 所经过的节点中,没有任何节点的值大于 X 的值。 示例 1: 输入:root [3,1,4,3,null,1,5] 输出:4 解…

网络映射会遇到哪些困难

网络映射通过将复杂的网络划分为更小、可管理的块,帮助 IT 管理员获得对其网络的更大控制和可见性,它有助于可视化不同的网络组件(如服务器、交换机端口和路由器)如何互连以执行其功能,通过表示网络设备的通信方式&…

腾讯云服务器价格表大全_轻量服务器_CVM云服务器报价明细

腾讯云服务器租用费用表:轻量应用服务器2核2G4M带宽112元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、云服务器CVM S5实例2核2G配置280.8元一年、GPU服务器GN10Xp实例145元7天,腾讯云服务器网长期更新腾讯云轻量…

无涯教程-进程 - 子进程监控

正如我们所看到的,每当我们使用fork从程序创建子进程时,都会发生以下情况- 当前进程成为父进程新进程成为子进程 如果父进程比子进程提前完成任务然后退出,会发生什么?现在谁将成为子进程的父进程?子进程的父进程是init进程,它…

业财融合背景下,全面预算管理的发展之路

随着社会经济的高速发展,单一的组织机构职能极大限制了企业发展的创新动能。业务壁垒的不断滋生造成了信息传达严重的不对等,沟通协作成本加大,业务效率降低,专业化的分工形式逐渐成为了制约企业发展的桎梏。 2016年&…

基于Python科研论文绘制学习 - task3

Seaborn seaborn 在matplotlib 的基础上进行了更高级的封装,能用更少的代码绘制配图。 1、图类型 关系型图 数据分布型图 分类数据型图 回归模型分析图 2、多子图网格型图 FacetGrid() import pandas as pd import numpy as np…

全球纳米烧结银市场年复合增长率为6.5%!

烧结银简单来讲是指经过低温烧结技术将纳米银粉&#xff08;平均粒径<0.1μm(100nm)&#xff09;印刷在承印物上&#xff0c;使之成为具有传导电流和排除积累静电荷能力的银浆&#xff0c;其由导电填料——银粉、粘合剂、溶剂及改善性能的微量添加剂组成&#xff0c;使用低熔…

云企业网CEN与转发路由器TR

云企业网CEN 云企业网CEN&#xff08;Cloud Enterprise Network&#xff09;是运行在阿里云私有全球网络上的一张高可用网络。云企业网通过转发路由器TR&#xff08;Transit Router&#xff09;帮助您在跨地域专有网络之间&#xff0c;专有网络与本地数据中心间搭建私网通信通…

苹果手机数据恢复的详细教程,果粉必看!

“照片不小心误删”、“清理内存把聊天记录删除了”、“手机重要文件丢失”……大家是否也会遇到以上的糟糕情况呢&#xff1f;“手机数据丢失”这六个字的杀伤力有多大&#xff0c;大家可想而知。 那么&#xff0c;手机删除的数据能够恢复吗&#xff1f;苹果手机数据恢复的方…