InstructGPT论文解读

news2024/9/20 19:59:54

介绍

在这里插入图片描述

上图可以看出InstructGPT(PPO-ptx)及变体(variant trained without pretraining mix)(PPO)显著优于GPT,1.3B的InstructGPT优于175B的GPT,SFT作为RLHF第一阶段的模型效果比GPT好。当然结果是由人来评价的。

RLHF包含三个步骤:

  1. SFT: 对话数据微调基座语言模型,
  2. RM: 评分数据微调RM模型
  3. PPO: SFT模型生成答案,RM模型评分,PPO算法对SFT模型强化学习,进一步改进生成效果

下图就是3个阶段
在这里插入图片描述

蓝色箭头表示该数据用于训练我们的一个模型。在步骤2中,框A-D是由标签者排序的模型样本

InstructGPT训练了三种尺寸(1.3B, 6B, 175B)

InstructGPT生成的答案相比GPT3更加真实,同时在摘要和问答上也不会过度编造内容。

在公共数据集上微调的模型比InstructGPT表现差,公共数据集特定任务的原因。

InstructGPT泛化能力较好,即使在训练数据中分布较小的任务上也能获得好的测试效果。

InstructGPT还是存在编造事实,简单问题给出冗长模棱两可的回答,无法检测错误的前提。

本文的其余部分结构如下:我们首先在第2节详细介绍相关工作,然后在第3节深入介绍我们的方法和实验细节,包括我们的高级方法(3.1),任务和数据集细节(3.3和3.2),人类数据收集(3.4),我们如何训练我们的模型(3.5),以及我们的评估程序(3.6)。然后,我们在第4节中展示了我们的结果,分为三个部分:API提示分布的结果(4.1),公共NLP数据集的结果(4.2)和定性结果(4.3)。最后,我们在第5节中对我们的工作进行了扩展讨论,包括对齐研究的含义(5.1),我们对齐的内容(5.2),限制(5.3),开放问题(5.4)以及本工作的更广泛影响(5.5)。

相关工作

关于RLHF和语言对齐的工作从2016到2022就一直在NLP任务上使用,InstructGPT是在更广泛的NLP任务上使用RLHF。

使用instruction微调LM,有助于提高泛化能力。

实验细节

InstructGPT使用用户提出的prompt进行第一阶段的训练,第二阶段和第三阶段可以循环,即使用RM训练PPO,然后PPO后产生更好的RM数据,再训练。

第一个InstructGPT训练数据来自标注者的标注数据作为种子数据进行训练。主要三类数据:

  1. Plain: 我们简单地要求标注者提出一个任意的任务,同时确保任务具有足够的多样性
  2. Few-shot: 我们要求标注者提出一条指令,以及该指令的多个查询/响应对。
  3. User-based: 我们在OpenAI API的等待列表应用程序中列出了许多用例。我们要求标注者提出与这些用例相对应的提示

SFT数据集包含大约13k个训练提示(来自API和标注器编写),RM数据集有33k个训练提示(来自API和标注器编写),PPO数据集有31k个训练提示(仅来自API)。[第一次训练]

提交给InstructGPT模型的更多提示见附录A.2.1,提交给GPT-3模型的提示见附录A.2.2。我们在附录A中提供了关于我们数据集的更多细节。

数据收集

更多任务更广泛的数据,也有一些敏感话题。需要对标注员进行培训。

SFT:监督学习对GPT-3微调,训练了16个epoch,cosine学习率,残差的dropout是0.2。使用RM模型在验证集上选择分数最高的SFT模型,训练一轮,SFT就会过拟合,但还是要训练更多轮。

RM:RM模型的结构是去除最后嵌入层的SFT,6B的RM比175B训练过程更稳定。评分数据是在相同输入下,多个模型给出输出,然后对所有输出进行评分。相同prompt的输出两两做成一对作为一个训练数据,这多对数据需要放在一个batch中,这样一个prompt就一次前向传递。如果分散到多个batch,相当于每对看作single,一个epoch就会过拟合。

RL:

评价:1)在用户提交给openai的prompt上测试,2)公开的NLP数据集

实验结果

instructgpt相比GPT3更能遵守prompt,不容易瞎编内容,更适合做助手。

公共的NLP数据集不能完全测试InstructGPT的能力,因为公共NLP数据集中的分类问答等任务只占到InstructGPT训练数据分布的一小部分(18%),更多的prompt是关于生成和头脑风暴的(57%)

在这里插入图片描述

在传统的NLP任务上,在zero-shot和few-shot时,InstructGPT并一定比GPT3更好。有时候给InstructGPT提问,模型没给出明确答案,但给了几个可能的答案,实际正确答案就是可能答案中,原因可能是模型有点谦虚。

讨论

增加模型对齐的成本相对于预训练是适度的:训练SFT和PPO模型的成本远小于GPT3的成本。(练我们的175B SFT模型需要4.9 petaflops/s-days,训练我们的175B PPO-ptx模型需要60 petaflops/s-days,而GPT-3需要3640 petaflops/s-days)

附加

数据集大小
在这里插入图片描述

注意RM的数据集需要进行拼凑,实际大小比上述大一个数量级。

模型细节

所有模型架构都使用GPT-3架构(Brown et al, 2020)。对于奖励模型和价值函数,将原模型的非嵌入层替换为投影层,输出标量值。所有模型都使用fp16的权重和激活,并使用fp32的权重主副本。所有模型都使用与Brown等人(2020)相同的字节对编码(BPE)。我们所有的语言模型和RL策略的上下文长度都是2k个令牌。我们过滤掉超过1k个tokens的prompts,并将最大响应长度限制为1k个tokens。所有模型都使用Adam优化器进行训练。

SFT

我们训练了16个epochs的SFT模型,残差的dropout是0.2。我们使用的余弦LR调度降低到原始学习率的10%,没有学习率预热。1.3 B和6B模型,我们使用的LR为9.65e-6,批量大小为32。对于175B,使用的LR为5.035 -6,批量大小为8。为了选择学习率,对1.3 b和6B的7个LRs和175B的5个LRs进行了几何搜索。我们还使用几何搜索调整了epoch的数量,最后的模型是根据RM分数选择的,我们发现与验证损失相比,RM分数更能预测人类偏好结果。

RM

我们训练了一个单独的6B奖励模型,我们将其用于所有大小的PPO模型。较大的175B RM有可能实现更低的验证损失,但是(1)它们的训练更不稳定,这使得它们不太适合用作PPO值函数的初始化;(2)使用175B RM和值函数大大增加了PPO的计算需求。在初步实验中,我们发现6B RMs在广泛的学习率范围内是稳定的,并导致同样强大的PPO模型。

2最终的奖励模型是根据一个6B GPT-3模型初始化的,该模型在各种公共NLP数据集(ARC、BoolQ、CoQA、DROP、MultiNLI、OpenBookQA、QuAC、RACE和Winogrande)上进行了微调。这主要是由于历史原因;当从GPT-3或SFT模型初始化RM时,我们发现类似的结果。我们在完整的奖励模型训练集(见表6)上以lr = 9e-6的学习率,cosine learning rate schedule(在训练结束时下降到其初始值的10%)和批大小为64的单个epoch进行训练。训练似乎对学习率或schedule不太敏感;学习率变化高达50%也会产生类似的效果。训练对epoch的数量非常敏感,多个epoch会迅速将模型与训练数据过拟合,验证损失变大。这里的批大小表示每批的不同数量的提示(The batch size here represents the distinct number of prompts per batch)。每个提示有K = 4到K = 9个labeled completions(组合就是从K中选择2个作为一组 ( K 2 ) \begin{pmatrix} K\\ 2 \end{pmatrix} (K2))。分数相同的组会被丢弃。因此,单个批次可以包含多达64 × ( K 2 ) \begin{pmatrix} K\\ 2 \end{pmatrix} (K2)≤2304次比较。

RLHF初始化模型的详细信息

我们从预训练的GPT-3模型初始化RLHF模型,并对演示数据集进行2次监督微调。我们还在微调期间混合了10%的预训练数据,因为我们发现这对PPO训练很有帮助(详见附录E.11)。使用余弦学习率调度,学习率最终衰减到峰值学习率的10%。1.3 b和6B型号我们使用32个batch,175B型号使用8个batch。我们比较每个模型的几个不同的峰值学习率,并选择一个在演示和预训练验证数据集上损失低的模型。对1.3B和6B模型的5个LR值进行了对数线性扫描,而175B模型选了3个。1.3 b, 6B和175B型号的最终LR分别为5e-6, 1.04e-5和2.45e-6

RLHF训练的细节

然后,我们用预训练混合从上述监督微调模型初始化RL策略。这些模型也用于计算KL奖励,与(Stiennon等人,2020)相同,β = 0.02。我们训练了256k episodes 的所有RL模型。在使用PII和基于公共前缀的重复数据删除过滤掉提示之后,这些集包括大约31k个唯一提示。每次迭代的批大小为512,minibatch大小为64。换句话说,每个批次被随机分成8个小批次,并且只训练一个内部epoch (Schulman et al, 2017)。在前10次迭代中应用恒定的学习率和预热,从峰值学习率的十分之一开始。采用加权指数移动平均,衰减率为0.992。在估计广义优势时不应用折扣(Schulman et al, 2016)。PPO的clip ratio设置为0.2,采样温度为1。

如前所述,对于所有PPO模型,我们使用6B RM和6B值函数,后者由前者初始化。通过对所有模型大小的策略使用相同的6B奖励模型和价值函数,可以比较策略模型大小对策略性能的影响。对于1.3 b和6B策略使用固定学习率9e-6,对于175B策略使用固定学习率5e-6。

我们最初的RLHF实验显示了在公共NLP数据集(如SQuADv2和DROP)上的回归,并且我们通过在PPO训练期间混合预训练梯度来缓解回归。我们使用的预训练样本数量是RL训练集数量的8倍。预训练数据是从用于训练GPT-3模型的数据集中随机抽取的。对于每个minibatch,我们连续计算PPO梯度和预训练梯度,并将它们累加到梯度缓冲区中。我们将预训练梯度乘以系数γ = 27.8(见公式2),以控制PPO和预训练分布的梯度的相对强度。

FLAN和T0模型

我们通过在FLAN和T0数据集上微调175B GPT-3模型来获得FLAN和T0基线。由于T0比FLAN(1.2M数据点)包含更多的数据(96M数据点),我们对T0进行了次采样,使每个模型的训练数据量具有可比性。请注意,原始模型在数据点可以重复的数据点上进行训练,但是在我们的数据点中,我们遍历每个数据点而不重复(为了更好地匹配我们训练SFT基线的方式)。我们应用了余弦学习率计划,并尝试每个数据集的初始学习率为4e-6和6e-6。在训练结束时,学习率下降到峰值的10%,我们在两个实验中都使用64个批处理大小。

为了选择最佳的FLAN检查点,我们使用6B奖励模型对提示的验证集的完成情况进行评分。如图13所示,在最初的400k个训练示例之后,奖励达到饱和。这表明,即使训练时间更长,也不太可能提高人类的评估表现。我们为我们的人类评估选择了RM得分最高的检查点,即学习率为4e-6并且训练了896k个示例的检查点。

参考

https://arxiv.org/abs/2203.02155

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/675450.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多EIP下的UDP通信异常分析

背景 SRE和程序在测试DDos多EIP防御方案的过程中,发现多EIP模式下, 监听的UDP端口连接会出现客户端访问异常。 表现为客户端发送一次数据后服务端这边主动断开了,或是客户端和服务端同时断开。 该问题会导致业务在多EIP方案下无法达到预期效果&#xff0…

【Visual Studio】Qt 获取系统时间,并实时更新时间,使用 C++ 语言,配合 Qt 开发串口通信界面

知识不是单独的,一定是成体系的。更多我的个人总结和相关经验可查阅这个专栏:Visual Studio。 这个需求来源于这个工程:【Visual Studio】Qt 的实时绘图曲线功能,使用 C 语言,配合 Qt 开发串口通信界面。 文章目录 Qt…

chatgpt赋能python:Python求和1-100的方法

Python求和1-100的方法 Python语言简介 Python是一个广泛使用的高级编程语言,其设计哲学强调代码可读性和语法简洁性。Python语言作为一门多范式的编程语言,支持对象、函数式和结构化编程等多种形式。Python应用领域广泛,如机器学习、桌面应…

chatgpt赋能python:Python求1!:介绍

Python求1!:介绍 在Python编程中,阶乘是一个常见的数学运算。阶乘指定的数的所有小于或等于其本身的正整数之积,例如,1!等于1,2!等于2乘以1,3!等于3乘以2乘以1,以此类推。 在这篇文…

【工程项目管理】工程项目管理实践报告

前言: 1.大学课程的大作业,觉得存起来也没什么用就干脆发出来了。。。 2.很可能有不严谨之处,各位看官如若发现欢迎指出~ 创作者文章管理系统 1 实践环节作业1:选题及任务分解WBS (1)选题 a.项目名称&a…

【1 beego学习 -MAC框架与ORM数据库】

0 beego的启动流程 1 入口 package mainimport ( //全局使用的路由和models_ "studyDemo/models"_ "studyDemo/routers"beego "github.com/beego/beego/v2/server/web" )func main() {beego.Run() }2 根据请求路由加载对应的控制器 package r…

【计算机组成原理】微程序控制器

目录 一、微程序控制器概述 二、微程序控制器设计方法 三、微指令执行过程 四、控制字段的编码方式 五、下址字段的设计方法 六、微程序入口地址的产生方法 一、微程序控制器概述 微程序:微指令构成的有序集合,一条指令对应一段微程序 微指令&…

【小沐学Android】Material Design设计规范之颜色篇

文章目录 1、简介1.1 Android1.2 Material Design 2、Material Design 12.1 材料设计2.2 颜色 3、Material Design 23.1 材料系统3.2 颜色 4、Material Design 34.1 颜色样式4.2 配色方案4.3 Material Theme Builder 结语 1、简介 1.1 Android 谷歌在2007年发布了第一个测试版…

chatgpt赋能python:Python清除代码:让你的项目更加优美

Python清除代码:让你的项目更加优美 随着时间推移和项目规模扩大,代码中可能会出现许多冗余、无用或重复的代码。这不仅会让代码难以维护,还会降低代码的性能和可读性。而Python作为一种高级编程语言,提供了许多工具和技术来清除…

牛客练习赛108 E.琉焰(非树边性质/线段树分治+可撤销并查集 or LCT)

题目 思路来源 官方题解 题解 针对每个连通块,单独考虑: 一方面, 任取连通块的某棵生成树, 对于任意非树边(u,v),把树边u到v上的所有边都选中,即被覆盖1次, 任取某个非树边集合S&#xff…

LangChain for LLM Application Development 基于LangChain开发大语言应用模型(下)

以下内容均整理来自deeplearning.ai的同名课程 Location 课程访问地址 DLAI - Learning Platform Beta (deeplearning.ai) LangChain for LLM Application Development 基于LangChain开发大语言应用模型(上) 一、LangChain: Q&A over Documents基于文…

bert4rec简介

1、bert4rec提出动机 用户行为动态变化,序列行为建模取得了不错的效果 单向结构限制了行为序列中隐藏信息的挖掘 序列神经网络顺序依赖,无法并行计算 为此,提出了 基于双向self-attention和Cloze task的用户行为序列建模方法。据我们所知…

解决Jenkins报错

解决Jenkins报错 1 linux空间不够问题1.1 报错现象1.2 定位问题1.3 解决措施 2 bash问题2.1 问题现象2.2 问题定位2.3 解决措施 3 虚拟环境问题3.1 问题现象3.2 问题定位3.3 解决措施 4 jenkins构建完成但一直转圈问题4.1 问题现象4.2 问题定位4.3 解决措施 5 jenkins自动化部署…

C高级6.24

一、整理grep、find、cut、tar、apt-get、dpkg、ln、ln-s指令 1.grep ----->查找字符串 grep 字符串 文件名 -w:按单词查找 -R:实现递归查找,主要用于路径是目录的情况 -i:不区分大小写 -n:显示行号 grep -w "^ubuntu" /etc/passwd ---->查找以ub…

【深度学习】RepVGG解析和学习体会

文章目录 前言0. Vgg1.RepVGG Block 详解 前言 论文名称:RepVGG: Making VGG-style ConvNets Great Again 论文下载地址:https://arxiv.org/abs/2101.03697 官方源码(Pytorch实现):https://github.com/DingXiaoH/RepV…

今天是世界Wi-Fi日!

很多人都不知道,今天其实是世界Wi-Fi日: 这个特殊的纪念日,是由无线宽带联盟(Wireless Broadband Alliance)确定的,并得到了互联城市咨询委员会 (CCAB)等组织的大力支持。 无线宽带联…

数据处理神器tidyverse!教你如何秒速搞定数据处理!

一、前言 在R语言中,tidyverse是一个庞大的数据分析生态系统,它由一系列数据可视化和数据处理软件包组成,能够极大地提高数据分析的效率和准确性。 在使用 Tidyverse 的过程中,我们会经常用到以下几个工具: ggplot2&am…

chatgpt赋能python:Python浮点数:介绍、精度和应用

Python浮点数:介绍、精度和应用 Python是一种高级编程语言,许多程序员使用Python编写计算机程序。与其他编程语言不同,Python是一种动态类型的语言,并且它处理浮点数时更加灵活。在本文中,我们将介绍Python浮点数的概…

python自动化办公——读取PPT写入word表格

Python自动化办公——读取PPT内容写入word表格 文章目录 Python自动化办公——读取PPT内容写入word表格一、需求分析二、导入依赖三、代码四、结果及总结 一、需求分析 📖由于我们知识图谱课程需要将课堂小组汇报的PPT总结成word文档,而我觉得一页一页复…

win10安装nginx的配置和使用方法(图文)

window10系统安装nginx服务,提供网页方面的服务。下面为详细图文安装配置教程。 1)下载nginx软件 官方下载地址:http://nginx.org/en/download.html 2)解压缩软件 unzip nginx-1.20.1.zip 或者 使用解压缩软件,下…