大模型RLHF算法更新换代,DeepMind提出自训练离线强化学习框架ReST

news2024/11/19 13:19:51

文章链接: https://arxiv.org/abs/2308.08998

大模型(LLMs)爆火的背后,离不开多种不同基础算法技术的支撑,例如基础语言架构Transformer、自回归语言建模、提示学习和指示学习等等。这些技术造就了像GPT-3、PaLM等基座生成模型,在这些基座模型之上,研究人员通过引入人类反馈的强化学习算法(RLHF)开发出了例如ChatGPT这些与人类偏好保持一致的可聊天模型,才将LLMs真正带领到公众视野中。RLHF由于自身在线更新的限制带来了较大的训练计算代价,且容易遭到”外部攻击“

为了解决上述问题,来自Google DeepMind的研究团队提出了一种全新的强化自训练算法(Reinforced Self-Training,ReST),ReST相比RLHF,可以以更高的效率使LLMs的输出与人类偏好保持一致。ReST的设计灵感来源于他们将语言模型的对齐问题视为一个不断增长的批量强化学习问题,因此本文作者首先从一个初始LLMs策略出发,并根据该策略生成一个离线数据集,然后使用离线RL算法使用这些样本反过来更新LLMs策略。作者重点在基础NLP任务中的机器翻译任务上对ReST算法的性能进行了评估,实验结果表明,ReST相比RLHF可以更明显的提高模型的翻译质量。

01. 引言

如何将LLMs的输出与人类偏好或价值观进行高效的对齐,是目前提升LLMs性能的关键问题,如果没有进行适当的对齐处理,LLMs可能会产生风险高或完全错误的内容,这对于下游应用程序具有毁灭性的影响。目前常用的RLHF方法通常使用人类反馈的标注数据来学习一个奖励模型,然后将其用于强化学习目标来对LLM进行微调对齐。但是RLHF通常依赖于在线RL方法,例如PPO[1]和A2C[2],这就需要在模型训练过程中多次使用奖励模型来从更新后的策略中采样新样本,这会带来高昂的计算代价。为了解决这一问题,本文提出了一个自训练强化学习算法ReST,ReST将人类标注员从反馈训练循环中丢弃,自行生成并使用离线数据进行反馈训练。作者巧妙地设计了一个内外循环机制,如下图所示。

其中外循环称为Grow循环,模型会根据当前的策略来采样生成一个对齐数据集,内循环称为Improve循环,模型会对外循环生成的数据集进行过滤(使用人类偏好评分函数对样本进行排序过滤),并将过滤后的数据继续用于微调优化策略,内外循环相互影响,以降低采样数据带来的训练成本。ReST不再依赖在线的RL损失,因而成为了一种通用的强化学习框架,允许在执行Improve循环时使用不同的离线RL损失,使整体框架更具灵活性。

02. 本文方法

2.1 ReST的整体流程

2.2 Grow外循环

2.2 Improve内循环

03. 实验效果

本文的实验主要在机器翻译基准上进行,作者选取了IWSLT 2014、WMT 2020和Web Domain三个数据集,其中前两者为常见的机器翻译数据集,后者为内部测试数据集,这些数据集都包含一组语言文本和对应人类标注员给出的真实参考翻译。作者选取了几种不同的离线强化学习算法作为baseline对比方法,包括OAC、BVM、PO、GOLD和BC。

3.1 对Improve循环进行分析

作者首先分析了ReST的两个循环步骤对最终性能的影响,例如增加Improve循环的次数是否会增加奖励模型的分数,如下图所示,灰色柱状为监督学习baseline的分数,通过调整损失函数类型、Improve steps(I)和Grow steps(G)来构成不同的ReST变体,其分数为紫色柱状所示

可以看到,随着Improve steps数量的不断增加,ReST在所有三个数据集上的平均奖励分数都得到了提高

3.2 对Grow循环进行分析

Grow步骤可以不断增加离线训练的样本数量,因此作者对比了执行单次Grow步骤和执行两次Grow步骤后的模型性能,如下图所示,执行两次Grow步骤的ReST变体在IWSLT 2014和Web Domain数据集上都有明显的提升

3.3 对损失函数进行分析

在下图中作者展示了本文方法与监督训练模型,以及使用不同损失函数的ReST变体的平均奖励分数对比,可以观察到,即使只使用单次Grow步骤,ReST的不同变体(紫色)也显着优于监督学习模型(灰色)得到的奖励分数

此外,我们也可以观察到,BC损失在单次Grow步骤的情况下,明显优于使用其他损失函数的效果

3.4 ReST与在线RL算法进行对比

作者选取PPO算法作为对比在线RL算法,PPO广泛用于各式RLHF流程中。在实验中,PPO算法可以通过单次Grow步骤访问与ReST算法相当数量的训练数据,对比结果如下表所示。

可以看到,在线PPO算法的平均奖励分数基本与ReST算法持平,但是这只是在单次Grow步骤的情况下,当ReST使用多步Grow和Improve后(并且参与训练的数据量相同),性能会得到显著的提升

04. 总结

本文提出了一种名为ReST的自训练离线强化学习算法,其中包含了一种新型的内外循环机制(分为Grow外循环和Improve内循环)来高效的调度RL过程中的策略生成和更新。同时其具有良好的拓展性,可以灵活的应用在多种不同的RL损失中,本文作者在机器翻译基准上的实验表明,使用常用的BC损失可以使ReST在多种不同的环境中得到更高的奖励分数。ReST的提出也向社区宣布,在对LLMs执行与人类偏好对齐时,可以尝试除PPO等在线RL算法之外的更多RL优化手段。

参考

[1] J. Schulman, F. Wolski, P. Dhariwal, A. Radford, and O. Klimov. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017.

[2] V. Mnih, A. P. Badia, M. Mirza, A. Graves, T. Harley, T. P. Lillicrap, D. Silver, and K. Kavukcuoglu. Asynchronous methods for deep reinforcement learning. In International Conference on Learning Representations, 2016.

作者:seven_


  关于TechBeat人工智能社区

TechBeat(www.techbeat.net)隶属于将门创投,是一个荟聚全球华人AI精英的成长社区。

我们希望为AI人才打造更专业的服务和体验,加速并陪伴其学习成长。

期待这里可以成为你学习AI前沿知识的高地,分享自己最新工作的沃土,在AI进阶之路上的升级打怪的根据地!

更多详细介绍>>TechBeat,一个荟聚全球华人AI精英的学习成长社区

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1059458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

问答雕虫1

问题: 现在有如下表 假设按时间顺序,记录中连续出现0.2 0.3 0.5 0.7四条记录记为一次有效数据组,统计一段时间范围内,有效数据组出现的次数,最终计算有效数据组在整个时间范围内的记录的占比。用mysql语句或者函数如何…

uniapp uni.showToast 一闪而过的问题

问题:在页面跳转uni.navigateBack()等操作的前或后,执行uni.showToast,即使代码中设置2000ms的显示时间,也会一闪而过。 解决:用setTimeout延后navigateBack的执行。

css自学框架之选项卡

这一节我们学习切换选项卡,两种切换方式,一种是单击切换选项,一种是鼠标滑动切换,通过参数来控制,切换方法。 一、参数 属性默认值描述tabBar.myth-tab-header span鼠标触发区域tabCon.myth-tab-content主体区域cla…

C语言动态内存管理

🐵本篇文章将会对动态内存管理相关知识进行讲解 1. 为什么要存在动态内存管理❓ 目前我们掌握了两种开辟内存的方式,分别为: int a 10;//存放一个值 int arr[] { 1,2,3,4,5,6,7,8,9,10 };//存放一组数 这两种内存开辟方式都是静态的&#…

2023年山东省安全员C证证考试题库及山东省安全员C证试题解析

题库来源:安全生产模拟考试一点通公众号小程序 2023年山东省安全员C证证考试题库及山东省安全员C证试题解析是安全生产模拟考试一点通结合(安监局)特种作业人员操作证考试大纲和(质检局)特种设备作业人员上岗证考试大…

人脸识别:FaceSDK 8.1 Crack

FaceSDK 使 Microsoft Visual C、C#、Objective C、Swift、Java、VB、Delphi 和 Python 开发人员能够为 Web、Windows、Linux、macOS、iOS 和 Android 构建具有人脸识别和基于人脸的 32 位和 64 位应用程序生物特征识别功能 FaceSDK 用于数百个应用程序,用于通过网络…

极大似然估计概念的理解——统计学习方法

目录 1.最大似然估计的概念的理解1 2.最大似然估计的概念的理解2 3.最大似然估计的概念的理解3 4.例子 1.最大似然估计的概念的理解1 最大似然估计是一种概率论在统计学上的概念,是参数估计的一种方法。给定观测数据来评估模型参数。也就是模型已知,参…

Flutter项目安装到Android手机一直显示在assembledebug

问题 Flutter项目安装到Android手机一直显示在assembledebug 原因 网络不好,gradle依赖下载不下来 解决方案 修改如下的文件 gradle-wrapper.properties 使用腾讯提供的gradle镜像下载 distributionUrlhttps://mirrors.cloud.tencent.com/gradle/gradle-7.5…

SpringBoot-Shiro安全权限框架

Apache Shiro是一个强大而灵活的开源安全框架,它干净利落地处理身份认证,授权,企业会话管理和加密。 官网: http://shiro.apache.org/ 源码: https://github.com/apache/shiro Subject:代表当前用户或…

【问题证明】矩阵方程化为特征值方程求得的特征值为什么是全部特征值?不会丢解吗?

问题 这个问题困扰了我好久,一直感觉如果有其他的特征值没法证伪,不过一直存在思想的层面,没有实际解决,今天突然想到动笔来解决,遂得解,证明如下。 证明 总结 这个证明看似证明过后很直观,但…

10.4 小任务

目录 QT实现TCP服务器客户端搭建的代码&#xff0c;现象 TCP服务器 .h文件 .cpp文件 现象 TCP客户端 .h文件 .cpp文件 现象 QT实现TCP服务器客户端搭建的代码&#xff0c;现象 TCP服务器 .h文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #includ…

智能合约漏洞,BEVO 代币损失 4.5 万美元攻击事件分析

智能合约漏洞&#xff0c;BEVO 代币损失 4.5 万美元攻击事件分析 一、事件背景 北京时间 2023 年 1 月 31 日&#xff0c;在 twitter 上看到这样一条消息&#xff1a; BEVO 代币被攻击&#xff0c;总共损失 45000 美元&#xff0c;导致 BEVO 代币的价格下跌了 99%。 有趣的是…

编程新手?跟着这个教程,用Python画出小猪佩奇

小猪佩奇是许多小朋友们的心头好&#xff0c;它的形象可爱、颜色鲜艳。你知道吗&#xff0c;只需要Python中的一个简单模块&#xff0c;我们就可以自己绘制出这个可爱的形象&#xff01;本文将教你如何使用Python的turtle模块&#xff0c;一步步画出小猪佩奇。 1. 准备工作&a…

当我们做后仿时我们究竟在仿些什么(四)

就像人类容易接受自然数&#xff0c;但对于负数缺乏某种直觉上的认识一样&#xff1b;后仿过程中经常出现的 Negative Delay 和 Negative Timing Check 也非常容易使人困惑。 Warning-[SDFCOM_NICD] Negative INTERCONNECT Delay encountered今天这篇首先简要分析这些 Negativ…

创建线程池

如何创建线程池及处理相应任务 目录 如何创建线程池及处理相应任务线程池定义解决的问题(需求)工作原理实现线程池创建示意图重要构造器创建线程池(ExecutorService)线程池任务处理常用API处理Runnable任务处理Callable任务 使用工具类(Executors)创建线程池常用API应用案例 拓…

桌面自动化工具总结

引言:产品经理提出桌面程序需要自动化的测试,避免繁琐的人肉点击。说干就干。 现有自动化工具是五花八门,我找了两个框架。 这两个框架都是基于微软的UIA 框架,链接地址 https://learn.microsoft.com/en-us/windows/win32/winauto/uiauto-providerportal?source=recommen…

以太网的MAC层

以太网的MAC层 一、硬件地址 ​ 局域网中&#xff0c;硬件地址又称物理地址或MAC地址&#xff08;因为用在MAC帧&#xff09;&#xff0c;它是局域网上每一台计算机中固化在适配器的ROM中的地址。 ​ 关于地址问题&#xff0c;有这样的定义&#xff1a;“名字指出我们所要寻…

【Spring】Bean作用域和生命周期

Bean作用域和生命周期 一. Bean 的作用域1. Bean 的 6 种作⽤域&#xff1a;①. singleton②. prototype③. request④. session⑤. application⑥. websocket单例作用域(singleton) VS 全局作⽤域(application) 2. 设置作用域 二. Spring 执行流程和 Bean 的生命周期1. Spring…

MySQL优化、锁、总结常见问题

慢 SQL 如何定位呢&#xff1f; 慢 SQL 的监控主要通过两个途径&#xff1a; 慢查询日志&#xff1a;开启 MySQL 的慢查询日志&#xff0c;再通过一些工具比如 mysqldumpslow 去分析对应的慢查询日志&#xff0c;当然现在一般的云厂商都提供了可视化的平台。服务监控&#xf…

如何实现torch.arange的tensor版本

文章目录 背景实现方案不可行的情况 背景 import torch我们都知道&#xff0c;torch.arange只支持数字&#xff0c;不支持tensor&#xff0c;如下&#xff1a; torch.arange(0,5,1)tensor([0, 1, 2, 3, 4]) 但是如果使用tensor&#xff0c;就会报错&#xff1a; torch.arang…