LLM —— 强化学习(RLHF-PPO和DPO)学习笔记

news2024/9/24 13:25:17

强化学习整体流程

请添加图片描述
智能体执行动作与环境进行交互,根据奖励R的反馈结果不断进行更新。

价值函数

请添加图片描述
奖励将会考虑两个方面的奖励,一个当下的奖励,一个是未来的奖励(为了防止陷入局部最优解)。

LLM强化学习

请添加图片描述

强化学习模型分类

需要LLM在训练过程中做生成的方法是 On Policy,其余的为Off Policy。On Policy是包含了反馈机制,Off Policy不包含反馈机制。

1、On Policy

(1)RLHF模型

组成部分
请添加图片描述
具有四个部分,演员模型、参考模型、评论家模型和奖励模型。
其中,演员模型和评论家模型是需要训练改变参数的,而奖励模型和参考模型在训练中不改变参数。

训练过程
请添加图片描述
首先,由第一个SFT阶段后,会得到SFT模型。然后,使用SFT模型作为参考模型和演员模型。使用偏好数据训练一个Reward模型作为PPO阶段的奖励模型和评论家模型。在第三部训练过程中,Reference模型和Reward模型的参数将不会变。Actor模型和Critic模型将会随着训练变化参数。

模型部分

  1. Actor模型
    请添加图片描述
    Actor模型是我们最终训练的目标模型,生成的reponse中每个token将作为生成的动作,后续会被奖励模型和评论家模型进行评判。
    请添加图片描述
    prompt为 S t S_t St,response为 A t A_t At

  2. Reference模型
    请添加图片描述
    Reference模型主要是为了防止Actor模型训练后与SFT阶段产生的模型差异过大,RLFH阶段只是为了让Actor模型生成符合人类偏好的数据,但并不希望与SFT阶段的生成效果偏差过大。因此,需要Reference模型去来 “纠正” 它。

  3. Reward模型
    请添加图片描述
    Reward模型用于计算生成token的当下收益,输入prompt+response,生成的Response中米哥token会作为动作输入给RW模型的最后一层Value Head得到 R t R_t Rt 奖励值。
    请添加图片描述
    其中,x为prompt, y w y_w yw为偏好正例数据response, y l y_l yl为偏好负例数据reponse。

  4. Critic模型
    在这里插入图片描述
    Critic模型会评价Actor模型未来的收益,因为Actor模型会随着训练不断变化,Critic模型也需要随着变化去更新,来更好的评判Actor模型未来会产生的收益。
    请添加图片描述
    R t R_t Rt是当下收益, V t + 1 V_{t+1} Vt+1是未来收益,两个结合后为对未来收益更综合的评估。

RLHF-PPO过程
5. prompt输入给Actor模型,会得到动作 A t A_t At。同时,也将prompt输入给Reference模型,得到参考的 A t A_t At
6. 将得到的动作 A t A_t At 输入给 Critic模型和 Reward模型分别得到 V t V_t Vt R t R_t Rt
7. 将Actor模型的 A t A_t At和Peference模型的 A t A_t At进行比较,得到一个KL散度分数。
8. 结合 R t R_t Rt V t V_t Vt和KL散度分数,对Actor模型和Critic模型进行联合优化更新Loss。
9. 更新 Actor模型和 Critic模型。

RLHF现存问题

主要集中在训练时的问题。

  1. 算力消耗大
    请添加图片描述

  2. 容易崩溃
    请添加图片描述
    因为牵扯模型多且需要协同配合,容易出现训练崩溃。

(2)deepspeed chat

请添加图片描述
1、如何高效生成答案?

DeepSpeed混合引擎
挑战1:在大多数高度优化的系统中,训练和推理通常使用两个不同的后端。

原因是这两个目标通常在不同的情况下使用——训练用于模型更新,推理用于模型部署。
在RLHF微调中,生成与训练是存在串行的方式,演员模型需要在每一步为每个查询生成答
案。因此,标准的训练模式可能是RLHF微调的瓶颈,因为它没有针对推理进行优化。

挑战2:因为模型分布在不同的GPU上。在生成过程中,生成步骤需要在GPU之间收集参数后进行推理,通信成本将会非常高

为了克服这两个挑战,引入了DeepSpeed混合引擎(DeepSpeed-HE)。
1.这个引擎可以自动在DeepSpeed提供的训练引擎和推理引擎之间切换
2.DeepSpeed-HE可以自动将ZeRO训练模式更改为张量并行推理,消除了重复参数收集的
需要。

2、如何处理多个模型使用的大量内存消耗?

第一,得益于DeepSpeed ZeRO优化,我们可以将模型参数和优化器分布在整个用于训练的GPU系统上。这显著减少了这些模型所需的内存消耗。

第二,参考模型与PPO训练循环中的演员模型大小相同,这需要相当数量的内存。然而,这个参考模型只在我们需要“旧行为概率“时才被调用。因此,参考模型的计算成本低于演员模型。为了减少内存压力,我们提供了一个单模型卸载选项,只将参考模型卸载到CPU。我们观察到,在相同的训练批量大小下,卸载参考模型(到CPU)与否的吞吐量影响很小。然而,如果演员模型卸载到CPU,训练速度会显著减慢。

第三,优化器的优化状态消耗了大量的训练内存。为了缓解这个问题,采用LoRA训练的方式,它只更新训练期间参数的一小部分。结果,与标准训练相比,优化状态要小得多

(3)实战案例
1、偏好数据集的构建

请添加图片描述
基于XuanYuan-6B进行RLHF落地实战,该模型主要应用于金融领域。

Prompt构建
重点关注两个方面:一方面是数据的丰富性和多样性,一方面是数据的质量。

  • 数据的多样性保证
    请添加图片描述
    把通用性、安全性和金融性进行了更细粒度的拆分,得到了多个子项,并按照一定的量级和比例收集每一子项的数据。这样可以使收集的prompt覆盖到不同的方面,同事具备合理的量级和配比。

  • 数据质量保证
    专业人员对数据进行清晰:删除或修改有明显错误的prompt或格式有瑕疵的prompt。经过清洗后,获得4W+高质量的prompt数据。

Response生成
为保证RM训练数据和测试数据分布的一致性,避免出现OOD(Out of distribution)问题,生成步骤:

  1. 使用XuanYuan-6B-SFT来产生response。在强化学习阶段,RM模型的输入是Actor模型的输出,Actor的初始状态为XuanYuan-6B-SFT。
  2. 使用XuanYuan-6B-SFT的采样参数,提高其采样参数中temperature和top_p的值,然后再生成response,以保证response的多样性,以及其包含的偏好信息的多样性。

偏好标注
当前业界主要有两种流行的标注方式:rank标注和pair标注。

  • rank标注
    一个Prompt包含多个response(一般为4个),标注者要对多个response进行排序,之后根据排序信息,可以将response两两组合,构建形如所示的偏好数据。Instruct-GPT即采用这类标注方式。
  • pair标注
    一条prompt仅生成两个reponse,标注者直接比较两个reponse,标出哪条response更符合偏好。此外,一些标注方法也要求标出偏好的强度。Anthropic和LlaMA2即采用pair形式的偏好标注。

实践

请添加图片描述

  • 放弃rank标注方式: 原因是标注速度慢且不同的标注人员对标注结果评判的一致性较低。
  • 采用pair标注方式: 直接比较两个response进行标注,并且要求标注出偏好的强度,以收集更多的偏好信息,来提升RM的泛化性能。
  • 具体标注步骤: 标注页面可以选择8个档位进行标注,从左到右依次命名为A3、A2、A1、A0、B0、B1、B2、B3。其中,A3表示A优于B的程度,B3表示B优于A的程度,其他档位依次类推。
  • 制订了一套完善的标注标准: 覆盖了实际中可能出现的大多数场景,并在标注过程中不断发现和解决新出现的问题,不断扩充完善我们的标注标准。
  • 对交付的标注结果进行严格的质检: 如果数据不合格会重新进行标注,直至满足验收标准。
  • 删除了偏好强度最低的数据(即A0和B0): 偏好强度低意味着两个response较为接近,未包含明显的偏好信息。这类数据歧义较大,会让模型感觉比较"困惑",不利于模型进行偏好建模。

最终的数据量:约6W+条偏好数据,其中90%用于训练,剩余10%用于测试。

2、RM训练

架构
请添加图片描述
XuanYuan-6B-SFT作为RM的基本架构,去掉最后的LM_head layer(softmax层,输出词表中每个token的概率),并将其替换为value_head layer。Value_head layer为一个线性层,输入是XuanYuan-6B-SFT次顶层的特征,输出为一个一维的reward分数。

损失函数
请添加图片描述
在实践中计算损失有两种方式:token-level的对比损失和sentence-level的对比损失

  • token-level的对比损失
    参考DeepSpeed-Chat中做法,使用token-level的对比损失来进行RM训练。
    训练阶段:

    • 步骤1:对于 y c y_c yc y r y_r yr,先找到他们第一个不相同的token所在的位置,作为起始位置
    • 步骤2:找到两个response结束的位置,并取两者中的最长长度,作为结束位置。
    • 计算从起始位置,到结束位置,相同位置上 y c y_c yc y r y_r yr之间的对比损失,最后求对比损失的均值作为该条件偏好样本的损失。(逐字对比)
    • 预测阶段:取response最后一个token对应的reward作为该response的reward。
  • sentence-level的对比损失
    为保证训练/测试的一致性,训练时应该取 y c y_c yc 最后一个token的reward和 y c y_c yc 最后一个token的reward来计算对比损失。

  • 上述两个在实践中的实际情况
    实验对比了两种损失函数的表现,结果表明sentence-level损失训练RM可获得更高的测试精度,但是RM不仅用于给reward打分,还用与强化训练阶段critic model的初始化。我们发现使用sentence-level损失训练的RM初始化critic model后,强化训练会变得不稳定,难以收敛。因此,我们仍使用token-level损失来进行RM训练,虽然精度会小幅度下降,但是强化训练的稳定性会有较高提升。

模型选择
在RM训练阶段,训练多个epoch,并在每个epoch结束后存储当前RM,之后选择合适的RM进行后续强化训练。在选择RM时,我们主要看以下几点:

  1. 测试精度:因为测试精度客观反应了RM打分合理性;
  2. RM输出的reward值:如果reward值过小或过大,在后续强化训练时会产生数值问题,导致训练无法正常进行;
  3. 接受和拒绝response奖励值之间的差距:具体做法是计算测试集中的接受reward的均值和拒绝reward均值,观察两个均值之间是否存在一定的差距。如果存在一定的差距,则说明RM有较强的鲁棒性。(差距越大,对比的体现就越明显)

最终选择RM测试精度是63%,输出尺度在[-1, 1]区间内,差距为0.5。

3、RLHF训练

模型结构
actor model和reference model:XuanYuan-6B-SFT
critic model和reward model:XuanYuan-6B-RM进行初始化

训练中actor model和critic model需要更新,而reference model和reward model保持不变。

数据
强化训练的数据为prompt数据。

数据组成:偏好数据的prompt,增加了额外的新prompt,比例为1:1。

  • 偏好数据中的prompt用于强化训练会使训练过程更为"容易",很大程度上可以避免RM打分不准而导致的一系列问题,如reward hacking、训练不收敛等。
  • 仅采用偏好数据中的prompt是不够的,这样模型见到的数据过于局限,不利于提升模型的泛化性能,因此增加了额外的新prompt一起用于强化训练。新prompt的构建方式和偏好数据中prompt构建方式相同。≈

训练

  • 训练参考
    训练过程参考了Instruct-GPT】LlaMA2以及Anthropic的做法。在实现上,参考了DeepSpeed-Chat框架。

  • 强化训练的目标
    请添加图片描述

  • 超参选择

    • KL的权重 β = 0.05 \beta=0.05 β=0.05
      β \beta β 是一个超参数,用于平衡探索和保持现状之间的权衡。过高的 β \beta β 会使模型接近初始模型 π 0 \pi_0 π0(不怎么探索),强化训练效果不明显;过低的 β \beta β 会过度优化 reward值(探索过头),容易造成reward hacking。

    • actor model和critic model的学习率设置为5e-7
      过高的学习率会让RM值快速上升,容易造成reward hacking;而过低的学习率会极大降低训练速度。

    • loss精度
      在计算loss时,使用fp32的数据精度,避免loss的数值问题引起的训练不稳定现象。

    • 训练了约300 PPO step
      训练中重点关注critic loss和RM reward值的变化,critic loss整体上应呈现下降趋势,而RM reward整体上应呈现上升趋势。
      注:RM reward上升的过高也是一种异常现象,此时大概率出现了reward hacking。

  • 模型选择

    • 每训练20个PPO step,存储当前的actor model。
    • 训练完成后,根据RM reward变化情况,挑选几个不同阶段的代表性模型进行快速的人工评估。
    • 人工评估时对比对象是强化训练前的SFT模型,即XuanYuan-6B-SFT。
      评估完成后统计good(actor response > SFT model response),same(actor response = SFT model response),bad(actor response < SFT model response)数量。然后,选择最有优势的actor model进行更正式的人工评估。
  • 人工评估

    • 聘请了专业的评估人员进行模型评估,评估题目覆盖通用型、安全性、金融垂类等不同范畴。
    • 每道题目均由三个不同的评估人员进行评估,来避免不同评估人员的洗好偏差。
    • 评估题目对其他人员完全封闭,避免研发人员同构构造类似的评估题目进行训练来获得更好的评估结果。

模型评估效果:
请添加图片描述
模型在通用性的各细分领域的评估结果:
请添加图片描述
从结果来看,在大多数子领域,经过强化训练后,模型的能力都有了明显的提升。在日常对话、逻辑推理、内容创作和安全性等子领域,强化带来的效果提升都很明显。

然后,在一些其他子领域,比如信息摘要、翻译等,强化训练并未带来明显的进步。在后续工作中,需要补充更多的偏好数据,同事提升偏好标注质量,来进一步补齐这些弱项的能力。

模型在金融子领域的各细分领域的评估结果:
请添加图片描述
在金融知识理解、金融业务分析连个子领域,强化训练带来了明显的能力提升。而在其他子领域,强化训练并未取得逾期的效果。对这些子领域,需要补充更多的高质量偏好数据,提高RM对这类prompt和response打分的准确性,进而提升强化训练的效果。

2、Off Policy

(1)DPO模型

加载2个模型,其中一个推理,另外一个训练,直接在偏好数据上进行训练即可。开始训练时,reference model和policy model都是同一个模型,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2084365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

四十四、【人工智能】【机器学习】- Kernel Ridge Regression(KRR)

系列文章目录 第一章 【机器学习】初识机器学习 第二章 【机器学习】【监督学习】- 逻辑回归算法 (Logistic Regression) 第三章 【机器学习】【监督学习】- 支持向量机 (SVM) 第四章【机器学习】【监督学习】- K-近邻算法 (K-NN) 第五章【机器学习】【监督学习】- 决策树…

【已解决】”只读方式“下的PPT可以编辑吗?

以“只读方式”打开的PPT文件&#xff0c;在编辑时会受到一些限制&#xff0c;那怎样才能正常编辑呢&#xff1f;根据PPT不同模式的“只读方式”&#xff0c;解决方法也不同&#xff0c;一起来看看吧&#xff01; 情况一&#xff1a;PPT属性设置为“只读” 当PPT文件在文件属性…

Python中排序算法之选择排序

选择排序算法是对《Python中排序算法之冒泡排序》中提到的冒泡排序算法的改进。 1 选择排序原理 选择排序是在参加排序的所有元素中找到数值最小&#xff08;或最大&#xff09;的元素&#xff0c;如果它不是左侧第一个元素&#xff0c;就使它与左侧第一个元素中的数据相互交…

CKKWWKKW-Dip-K-NH2;LTX-315;巯基化修饰溶瘤肽;CAS:1345407-05-7

【CKKWWKKW-Dip-K-NH2 简介】 CKKWWKKW-Dip-K-NH2&#xff0c;也被称为LTX-315&#xff0c;是一种具有抗癌活性的溶瘤肽。分子量为1439.79&#xff0c;分子式为C78H106N18O9。氨基酸序列为Lys-Lys-Trp-Trp-Lys-Lys-Trp-Dip-Lys-NH2。LTX-315被发现对多种癌细胞具有抑制作用&…

Git之1.5版本重要特性及用法实例(五十三)

简介&#xff1a; CSDN博客专家、《Android系统多媒体进阶实战》一书作者. 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a; 多媒体系统工程师系列…

Java 使用 POI 导出Excel,设置同一个单元格的内容显示不同的文字颜色

在使用Apache POI的库生成Excel的时候&#xff0c;如何在一个Cell中的文字中显示不同的颜色&#xff1f;下面是一个示例代码&#xff0c;演示如何在单元格中设置不同颜色的文本。 代码 // 创建工作簿和工作表 Workbook workbook new XSSFWorkbook(); Sheet sheet workbook.c…

鸿蒙OS试题(7)

46在组件中&#xff0c;经常需要使用字符串、图片等资源。HSP中的组件需要使用资源时&#xff0c;一般将其所用资源放在HSP包内&#xff0c;而非放在HSP的使用方处&#xff0c;以符合高内聚低合的原则。下面访问HSP资源错误的是 A.通过$r访问HSP中的资源。lmage($r(app.media.…

免费分享:2020年全球10m分辨率红树林(附下载方法)

Google Earth Engine (GEE) 是一个强大的云端地理信息处理平台&#xff0c;‌由Google与卡内基美隆大学和美国地质调查局共同开发。‌ 它提供了一个存取卫星图像和其他地球观测数据数据库的途径&#xff0c;‌并具备足够的运算能力来处理这些数据。‌ MSIC算法是指基于时间序列…

SCI FI SHOOTER CHARACTERS PACK VOL 1

这个包是科幻射击角色包第一卷的升级版。如果您已经拥有旧版本,您可以使用升级路径,从降价中受益,并享受升级后的版本*** 此包包含11个SCi FI角色,可随时填充您的项目: 外星步兵 外国雇佣兵 外星特种部队通灵者 外星赏金猎人 外星战争老兵 外星战士 人类太空海盗兵 海盗中…

【运维】解决Ubuntu 22.04 desktop版本打不开终端

问题 我是在Visual Box中创建的虚拟机&#xff0c;基于Ubuntu 22.04.4 desktop amd64版本。创建之后&#xff0c;在应用列表中打开terminal&#xff0c;并没有启动&#xff0c;过一会&#xff0c;程序自动退出 解决 这种一般都是语言和地区设置的不一致 比如&#xff1a;地区…

Linux上安装Conda以管理Python环境

在Windows下装了Linux发行版Debian&#xff0c;以后不用来回开启VMware啦&#xff01;并在Debian中安装了Conda,记录一下所需命令(其他版本如Ubuntu中安装是一样的命令)。 目录 1.WSL 2.安装Conda 3.Python环境配置 1.WSL Install WSL | Microsoft Learn 微软官网 ①以管理…

让视频播放更智能、更流畅!开源视频播放器项目GSYVideoPlayer

GSYVideoPlayer&#xff1a;简单、强大、灵活。一切尽在GSYVideoPlayer - 精选真开源&#xff0c;释放新价值。 概览 GSYVideoPlayer是一个为Android应用开发者提供的开源视频播放解决方案。它通过提供一套简洁直观的API&#xff0c;使得视频播放功能的集成变得简单快捷。开发…

六西格玛培训教你用多变量分析找问题根源——张驰咨询

在六西格玛培训的殿堂里&#xff0c;多变量分析不仅是学员们掌握的一项关键技能&#xff0c;更是他们通往卓越绩效之路上的一把重要钥匙。这门深奥而强大的工具&#xff0c;不仅拓宽了学员们的数据分析视野&#xff0c;还为他们提供了在复杂系统中寻找最优解、实现持续改进的能…

Oracle ADG切换检查及操作

一、配置检查 1、数据库名称及log_archive_config检查 使用命令 show parameter name; show parameter log_archive_config; 查看点 查看数据库db_unique_name、db_name、service_names 设置查看log_archive_config是否配置了正确的生产及容灾db_unique_name 确认点 生…

Spring Cloud Open Feign 超时配置及源码分析

前言&#xff1a; 在开发 Spring Cloud 微服务项目时候&#xff0c;Feign 调用是非常常见的&#xff0c;Feign 调用的底层还是 HTTP 的远程调用&#xff0c;会有超时问题&#xff0c;如果没有搞清楚超时问题&#xff0c;生产环境的调用肯那个会有种种问题出现&#xff0c;本篇…

pymysql cursor使用教程

Python之PyMySQL的使用&#xff1a; 在python3.x中&#xff0c;可以使用pymysql来MySQL数据库的连接&#xff0c;并实现数据库的各种操作&#xff0c;本次博客主要介绍了pymysql的安装和使用方法。 PyMySQL的安装 一、.windows上的安装方法&#xff1a; 在python3.6中&…

图像字幕Image Captioning——使用语法和语义正确的语言描述图像

1. 什么是图像字幕 Image Captioning&#xff08;图像字幕生成&#xff09; 是计算机视觉和自然语言处理&#xff08;NLP&#xff09;领域的一个交叉研究任务&#xff0c;其目标是自动生成能够描述给定图像内容的自然语言句子。这项任务要求系统不仅要理解图像中的视觉内容&…

NLP从零开始------文本中阶序列处理之语言模型(完整版)

语言模型( language model) 用于计算一个文字序列的概率&#xff0c; 评估该序列作为一段文本出现在通用或者特定场景中的可能性。每个人的语言能力蕴涵了一个语言模型&#xff0c;当我们说出或写下一段话的时候&#xff0c;已经在不自觉地应用语言模型来帮助我们决定这段话中的…

ceph-rgw zipper的设计理念(2)

本文简介 书接上文。本文以CreateBucket为例进行详细讲述设计理念以及接口变化趋势。 1、接收请求和协议处理请求 rgw_asio_frontend.cc 主要功能&#xff1a;回调函数注册和请求处理 void handle_connection(boost::asio::io_context& context,RGWProcessEnv& env…

如何使用IDEA搭建Mybatis框架环境(详细教程)

文章目录 ☕前言为什么学习框架技术Mybatis框架简介 &#x1f379;一、如何配置Mybatis框架环境1.1下载需要MyBatis的jar文件1.2部署jar文件1.3创建MyBatis核心配置文件configuration.xml1.4.创建持久类(POJO)和SQL映射文件1.5.创建测试类 &#x1f9cb;二、 MyBatis框架的优缺…