InstructGPT 论文阅读笔记

news2025/1/11 12:00:30

目录

简介

数据集                                

详细实现

实验结果

参考资料


简介

InstructGPT 模型是在论文《Training language models to follow instructions with human feedback》被提出的,OpenAI在2022年1月发布了这篇文章。

论文摘要翻译:把语言模型做得更大并不意味着让它们更好的遵循用户的意图。例如,大型语言模型可能会生成不真实、有毒或对用户没有帮助的输出。换句话说,这些模型没有与他们的用户对齐。本文中我们展示了一种通过微调人类反馈来使语言模型与用户在一系列任务中对齐意图的方法。从一组标注员手写的prompts和从OpenAI API提交的prompts开始,我们收集了一个有标注员演示的预期模型行为的数据集,用该数据集使用有监督学习微调了GPT-3模型。我们接着收集了一个模型输出结果排名数据集,用它通过RLHF(reinforcement learning from human feedback)来进一步微调这个有监督模型,我们把生成的模型称为InstructGPT。在我们的prompt 数据集的人类评估中,1.3B参数的InstructGPT模型结果比175B参数的GPT-3结果更被接受,尽管少了100倍的参数。此外,InstructGPT模型在提高真实性和减少有毒输出生成的同时,在公共NLP数据集上具有最小的性能回归。尽管InstructGPT仍然会犯一些简单的错误,但我们的结果表明,根据人类的反馈进行微调是使语言模型与人类意图对齐的一个有前景的方向

在前言部分开头就提到大语言模型可以通过给一些任务示例、被prompted的方式来执行一系列NLP任务,但是这些模型经常有一些不符合意图的行为,如捏造事实、产生有偏见或有毒的文本,或者完全不遵循指令。出现这些行为的原因是因为大语言模型的目标函数--预测从互联网得到的文本的下一个token,与“有用且安全地遵循用户的指令”这一目标是不同。因此作者们说语言模型的目标函数是misaligned。

那怎么避免模型出现这些不符合意图的行为,让语言模型变得helpful、honest、harmless呢?

  • helpful (they should help the user solve their task)
  • honest (they shouldn't fabricate information or mislead the user)
  • harmless (they should not cause physical, psychological, or social harm to people or the environment)

InstructGPT的做法是使用微调的方式来对齐(align)语言模型,特别是使用了RLHF(reinforcement learning from human feedback), 如下图1通过三步来对齐语言模型。

原论文中的Figure 2

数据集                                

在详细的三个步骤说明前,先来看看prompt数据集是怎么得到的:

  • 数据集中的prompt主要来源是用户使用OpenAI API时的prompt输入,特别是通过instructGPT早期版本的playground 接口提交的。对于每个用户ID,最多只取200个prompt。 并通过检查prompt是否有很长的公共前缀来进行去重,在划分train/validation/test数据集时,会通过用户ID来划分,以便验证集和测试集不包含来自训练集中数据的用户的数据。为了避免模型学习到用户的敏感信息,将个人身份信息去除了(personally identifiable information (PII))。
  • 一个有40人的标注团队也写了一些prompt, 共有三种类型的prompt:

    • Plain: We simply ask the labelers to come up with an arbitrary task, while ensuring the tasks had sufficient diversity.

    • Few-shot: We ask the labelers to come up with an instruction, and multiple query/response pairs for that instruction.

    • User-based: We had a number of use-cases stated in waitlist applications to the OpenAI API. We asked labelers to come up with prompts corresponding to these use cases.

基于这这些prompt, 生成了三个数据集:

  • SFT数据集,有标注员提供的prompt的演示输出结果,共13k的训练prompts, 用来训练第一步的SFT模型
  • RM数据集,有标注员标注的模型输出排名,共33k的训练prompts, 用来训练第二步中的RM模型。
  • PPO数据,没有任何人类标注,用来作为第三步的RLHF微调的输入,共31k 的训练数据(数据只来自API,也就是没有使用标注员生成的prompt)

数据集详细大小数据如下图, 数据集中的96%都是英语。

原论文中table 6 

prompt 类型分布如下截图 :

原论文中table 1

更详细的关于数据集的说明在论文附录A, 如何使用筛选测试来挑选标注员以及提供给标注员的标注指南可参考附录B

详细实现

第一步:

使用监督学习来微调GPT-3模型:

  • using a cosine learning rate decay, and residual dropout of 0.2
  • 通过在验证集上的RM 分数来进行SFT的模型选择
  • 一共训练了16 epochs, 虽然在1个epoch之后再验证集上的损失开始出现过拟合,但是发现训练更多的epoch可以同时提升RM和人类偏好评分

第二步:

从第一部分的去掉最后的unembedding layer的SFT模型开始,对于输入的prompt和语言模型输出结果,训练一个奖励模型(RM)输出一个标量奖励。

只使用了6B的模型,不仅节约大量计算,更因为175B的模型的训练很不稳定,所以不适合作为RL的值函数。

对于每个prompt,给每个标注员K=4到K=9的responses 来排序,故会生成\binom{K}{2}个比较对。如果直接将比较对shuffle组成一个数据集,只对数据集进行一次训练就会过拟合。而如果将一个prompt的\binom{K}{2}比较对作为同一个batch的元素,这不仅节省了计算量(对于每一个completion只需要一次前向传播,而不是对K个completion要有\binom{K}{2}个前向传播),并且不会过拟合,可以极大的提高验证集准确率和log loss。

奖励模型的损失函数定义为下式,是一个pairwise的ranking loss。 

loss(\theta ) = - \frac{1}{\binom{K}{2}} E_{(x, y_w, y_l) \sim D}[log(\sigma (r_{\theta} (x, y_w) - r_{\theta}(x, y_l)))]

式中r_{\theta}(x, y) 是参数为\theta的奖励模型对于prompt x和completion y的标量输出,对于比较对 y_w 和y_ly_w是更被接受的completion,D是人工比较对的数据集。

由于RM损失对奖励的偏移是不变的,因此使用偏差对奖励模型进行规范化(normalize),以便在进行强化学习之前标注员演示达到平均得分0。

第三步:

使用PPO来对SFT模型在试验环境中进行微调, 试验环境是一个展示随机的用户prompt并期望给出一个response的bandit 环境。对于给定的prompt和response, 环境根据第二部分的奖励模型给出一个奖励并结束这一回合。

给每个token添加了一个从SFT 模型得到的per-token KL penalty 以防止奖励模型的过度优化, 值函数是从奖励模型(RM)初始化得到的,将这个模型称为“PPO”。

为了修复在公共NLP数据集上的performance regressions, 将预训练梯度加入到PPO梯度,这些模型被称为“PPO-ptx"。 

因此,在RL训练过程中是最大化如下组合目标函数:

objective(\phi ) = E_{(x, y)\sim D_{\pi_{\phi}^{RL}}} \left[ r_{\theta}(x, y) - \beta log(\pi_{\phi}^{RL} (y|x) / \pi^{SFT} (y|x)) \right] + \\ \qquad \ \ \ \gamma E_{x \sim D_{pretrain} } \left[ log(\pi_{\phi}^{RL} (x)) \right]

式中,\pi_{\phi}^{RL}是学习到的RL策略, \pi^{SFT}是第一步的监督训练模型,D_{pretrain} 是预训练数据集。KL奖励系数\beta控制KL惩罚的程度,预训练损失系数 \gamma控制预训练梯度的程度, 对于“PPO”模型,\gamma

为0, 除非特殊说明,论文中的InstructGPT 都是指PPO-ptx 模型。

Baseline:  

  • 比较PPO模型与SFT模型和GPT-3的效果, 也比较了在prompt中加入一些例子给GPT-3(称为GPT-3-prompted)。
  • 在FLAN和T0数据集比较InstructGPT 与微调的175B的GPT-3。 分别在大约100万个样本上对它们进行微调,并选择在验证集中获得最高奖励模型分数的checkpoint。

(模型的更多训练细节在附录C。)

实验结果

几个关键结论:

  • 与GPT-3的输出相比,标注员明显更喜欢InstructGPT输出(如下图)。

原论文中的Figure 1
  • 模型可以泛化到没有参与过训练标注的标注员的喜好
  • 公共 NLP 数据集不能反映如何使用语言模型
  • InstrucGPT 相比与GPT-3在真实性上有所提高
  • InstrucGPT 相比与GPT-3在有毒内容上有一点点提高,但是在偏见上没有提高
  • 通过修改RLHF的微调程序,可以减小在公共NLP数据集上的performance regressions。
  • InstructGPT模型显示了对RLHF微调分布之外的指令的有效的泛化
  • InstructGPT仍然会犯一些简单的错误

参考资料

1. Ouyang, Long, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll Wainwright, Pamela Mishkin, Chong Zhang, et al. n.d. “Training Language Models to Follow Instructions with Human Feedback.” 

2. https://openai.com/research/instruction-following 

3. InstructGPT 论文精读【论文精读·48】_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/475702.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AttributeError: ‘Document‘ object has no attribute ‘pageCount‘ PyMuPDF库

这可能是由于PyMuPDF库更新导致的,里面的一些函数名发生了变化 1. AttributeError: Document object has no attribute pageCount 将 pageCount改为 page_count 2. AttributeError: Matrix object has no attribute preRotate 将preRotate改为prerotate 3.Attribut…

关于FFMPEG中的filter滤镜的简单介绍

滤镜的作用主要是对原始的音视频数据进行处理以实现各种各样的效果。比如叠加水印,翻转缩放视频等。 下图表示的正常转码流程,滤镜在解码和编码中间,虚线表示可有可无。 使用命令查看ffmpeg支持的滤镜 ffmpeg -filters 查看某个滤镜的详细参…

k210点亮LED灯

开发板上自带的3个led灯接线如图。 点亮led灯主要使用两个模块,如下: fm.register(pin,function,forceFalse) 【pin】芯片外部 IO 【function】芯片功能 【force】True 则强制注册,清除之前的注册记录 例:fm.register(12, fm.f…

真题详解(有向图)-软件设计(六十二)

真题详解(极限编程)-软件设计(六十一)https://blog.csdn.net/ke1ying/article/details/130435971 CMM指软件成熟度模型,一般1级成熟度最低,5级成熟度最高,采用更高级的CMM模型可以提高软件质量。 初始&am…

RepVGG学习笔记

RepVGG 0 前言1 结构重参数化1.1 结构重参数化第一步(将 C o n v 2 D Conv2D Conv2D算子和 B N BN BN算子融合以及将只有 B N BN BN的分支转换成一个 C o n v 2 D Conv2D Conv2D算子)1.2 结构重参数化第二步(多分支的 3 3 3\times3 33卷积融…

安全运营 ldap监控域控信息

0x00 背景 公司有多个主域,子域,有的子域因为境外数据安全的问题无法把日志传输到境内。那么如何在没有日志的情况下监控子域或者互信域的组织单元(OU)信息呢。 由于访问互信域要在域控上进行,本文根据最小权限原则监控普通用户也可以访问的…

Packet Tracer - 配置和验证小型网络

Packet Tracer - 配置和验证小型网络 地址分配表 设备 接口 IP 地址 子网掩码 默认网关 RTA G0/0 10.10.10.1 255.255.255.0 不适用 G0/1 10.10.20.1 255.255.255.0 不适用 SW1 VLAN1 10.10.10.2 255.255.255.0 10.10.10.1 SW2 VLAN1 10.10.20.2 255.25…

【C++】set和map的使用

对于STL容器来说,有很多相似的功能,所以这里主要将与之前不同的功能说清楚 文章目录 1.对于set与set的简单理解2. setinsert迭代器遍历countmultisetinsertfindcount 3. mapinsert与迭代器的使用统计水果次数 operator []operator[]的实现理解对整体的拆…

Nginx:常见的面试题和答案

1. 什么是Nginx? 答:Nginx是一款高性能的Web服务器和反向代理服务器,用于HTTP、HTTPS、SMTP、POP3和IMAP协议,同时用于处理高并发的请求,提供快速、可靠的服务。 2. Nginx的优点是什么? Nginx的优点包括&#xff1a…

【BeautifulSoup上】——05全栈开发——如桃花来

目录索引 介绍:解析库: 安装:pip install BeautifulSoup4pip install lxml 标签选择器:1.string属性:.name属性:获取标签中的属性值: 实用——标准选择器:使用find_all()根据标签名查…

五、C++内存管理机制 —— primitives(侯捷)

侯捷 C八部曲笔记汇总 - - - 持续更新 ! ! ! 一、C 面向对象高级开发 1、C面向对象高级编程(上) 2、C面向对象高级编程(下) 二、STL 标准库和泛型编程 1、分配器、序列式容器 2、关联式容器 3、迭代器、 算法、仿函数 4、适配器、补充 三、C 设计模式 四、C 新标准 五、C 内存管…

2023-04-29 动态规划介绍

2023-04-29 动态规划介绍 动态规划是运筹学课程的一部分 多阶段决策问题 有一类活动的过程,可以分成若干个互相联系的阶段,在它的每一阶段都需要作出决策,从而使整个过程达到最好的活动效果 当然,每个阶段的决策的选取不是任意确…

dc-6靶机

1.使用nmap进行信息搜集,存活主机,端口 192.168.85.184是存活主机,发现开放22,80端口 2.访问192.168.85.184的80端口 发现被重定向了,修改hosts文件 vim /etc/hosts 添加一行 192.168.85.174 wordy3.对网站进行信息搜…

彻底解决 Lost connection to MySQL server at ‘reading initial communication packet’, system error: 0 解决方法

当我遇到这错误的时候,我去网上也找过对应解决方法,出现这个的原因有很多种情况 大多是解决Linux系统里的 我是windows系统里的MySQL服务出问题了,所有那些方法对我来说毫无意义. 好了,说一下我的解决办法,其实也很简单 只需要卸载mysql服务,注册表也要删干净,也要把环境变…

C的文件操作

🐖前言 🐕1.为什们我们要用文件 在我们之前写程序时,如果使用scanf函数用键盘输入数据,这些东西都放到内存当中,一旦退出程序,那么这些数据就会消失,比如就像我们写的通讯录,不管是…

Shiro相关知识

1、Shiro功能概述 Apache Shiro是一个功能强大且易于使用的 Java 安全框架,可执行身份验证、授权、加密和会话管理。 主要功能: Authentication:身份认证。登录时验证身份信息。 Authorization:授权操作。访问控制的过程&…

CSS布局基础(标签类型,盒子模型)

布局基础 元素显示类型,盒子模型 标签类型块元素常见块元素 行内元素常见行内元素 行内块元素常见行内块 模式转换显示类型显著区别 盒子模型盒子组成布局描述边框圆角 内边距外边距块元素居中盒子内行内(块)元素居中 外边距使用陷阱两盒子外…

【进阶C语言】动态版通讯录的实现(详细讲解+全部码源)

前言 📕作者简介:热爱跑步的恒川,致力于C/C、Java、Python等多编程语言,热爱跑步,喜爱音乐的一位博主。 📗本文收录于C语言进阶系列,本专栏主要内容为数据的存储、指针的进阶、字符串和内存函数…

Linux基础IO【重定向及缓冲区理解】

✨个人主页: 北 海 🎉所属专栏: Linux学习之旅 🎃操作环境: CentOS 7.6 阿里云远程服务器 文章目录 🌇前言🏙️正文1、文件描述符1.1、先描述,再组织1.2、files_struct1.3、分配规则…

Java数组的学习(基础)

目录 第一章:数组的概念介绍 1.数组的概念 2.数组的初始化/数组的创建/数组的定义 第二章:数组的使用 数组添加元素的方法/数组的赋值 数组的遍历 数组之选择排序的升序 数组之冒泡排序的升序 数组的最小值 数组的反转 数组中常见的异常 第三…