Training language models to follow instructions with human feedback 论文阅读

news2024/11/24 15:30:29

论文原文:https://arxiv.org/pdf/2203.02155

论文简介

语言模型越大并不意味着它能更好的理解用户的意图,因此在这篇论文中,展示了根据人的反馈对模型进行微调,使得语言模型能够在各种人物上更好的理解用户的意图。在评估中,1.3B参数的InstructGPT模型的输出比175B GPT-3的输出更受欢迎,尽管参数少了100倍。此外,InstructGPT模型虽然在公共的数据上的效果有所降低,但是真实性和减少有害方面生成的能力提升。论文表明,尽管InstructGPT仍然会犯一些简单的错误,但根据人类反馈进行微调是能够理解人类意图的一个有效的方式和方向。
**相当于是,OpenAI提出了”align“的概念,希望模型的输出与人类的意图”对齐“,其用的方法是RLHF(Reinforcement Learning from Human Feedback)基于人类反馈的强化学习。**

方法和实验细节

在这里插入图片描述

Collect demonstration data, and train a supervised policy. (收集范例数据,并以有监督方式训练)

我们的打标签者提供了输入提示分布(prompt distribution)上所需行为的范例(有关此分布的详细信息,请参阅第 3.2 节)。 然后,我们使用有监督学习在该数据集上微调预训练的 GPT-3 模型。这部分就是根据prompts,也就是写的各种问题,进行标注,将prompts和标注的对话作为人工标注的数据集,对预训练的GPT-3进行有监督微调

Collect comparison data, and train a reward model. (收集比较数据,训练奖励模型)

我们收集了模型输出之间比较的数据集,其中打标记者根据输入标明了他们更喜欢的输出。 然后我们训练奖励模型来预测人类偏好的输出。用上一步得到的SFT模型生成各种问题的答案,再对这些答案进行比较(排序式)标注,如D>C>A=B,基于这个标注数据集,在去掉最后的嵌入层的SFT模型基础上进行有监督学习训练一个RM(reward model),这样使用模型来模仿标注者进行打分

Optimize a policy against the reward model using PPO. (使用PPO针对奖励模型优化策略)

我们使用RM奖励模型的输出作为标量奖励。 我们使用 PPO 算法微调监督策略以优化此奖励。
步骤2和步骤3可以不断迭代; 收集当前最佳策略的更多比较数据,用于训练新的 RM,然后训练新的策略。 在实践中,我们的大部分比较数据来自监管的学习,也有一些来自我们的PPO学习。用上一步的RM模型进行打分,然后分数就可以用强化学习来对SFT模型进行优化

数据集

打标签者提供了输入提示分布(prompt distribution)上所需行为的范例,根据论文所说,为了训练第一个InstructGPT模型,打标签者需要自己编写提示,分为三种:

  • Plain:只是要求标记者提出一个任意的任务,同时确保任务具有足够的多样性。
  • Few-shot:要求标注者提出一条指令,以及针对该指令的多个查询/相应对。
  • User-based:在OpenAI API的候补名单申请中陈述了许多用例,要求标注者提出与这些用例相对应的提示。
    根据这些提示,生成了三个用于微调过程的不同数据集:(1)SFT数据集,带有用于训练SFT模型的打标签者范例数据,(2)RM数据集,带有用于训练的模型已被打标签者分了等级的数据,(3)PPO数据集,没有任何人工标签,用于RLHF微调的输入。SFT数据集包含大约13k个训练提示数据(来自API和标记者编写),RM数据集有33k个训练提示数据(来自API和打标记者编写),PPO数据集有31k个训练提示数据(仅来自API)。
    在这里插入图片描述
    上表显示了API提示(特别是RM数据集)的用例类别的分布,大多数用例都是生成的,而不是分类或QA。在表二中展示了一些说明性提示(由研究人员编写,以模仿提交给InstructGPT模型的提示类型)。

任务

训练任务来自两个来源:(1)由标注者编写的提示数据集和(2)提交给API上的早期InstructGPT模型的提示数据集。这些提示非常多样化,包括生成、问答、对话、摘要、提取和其他自然语言任务。数据集超过96%是英语。
对于每个自然语言提示,任务通常是通过自然语言指令直接指定的(例如”写一个关于聪明青蛙的故事“),但也可以通过少数例子间接指定(例如给出两个青蛙故事的例子,并提示模型生成一个新的)或隐含的连续(例如提供一个关于青蛙的故事的开始)。在每种情况下,我们都要求标注者尽最大努力推断出写提示的用户的意图,并要求他们跳过任务非常不清楚的输入(相当于当任务非常不清楚的时候,可以跳过回答,避免答非所问)。此外,在我们提供给他们的指示和他们的最佳判断的指导下,标注者还需考虑到隐含的意图,如回应的真实性,以及潜在的有害输出,如有偏见或有毒的语言。

模型

我们从GPT-3预训练语言模型开始。这些模型是在广泛分布的互联网数据上进行训练的,可以适应广泛的下游任务,但行为特征不佳。从这些模型开始,我们用三种不同的技术训练模型:

  • 有监督微调(SFT——Supervised fine-tuning),我们使用监督学习对标记器演示中的GPT-3进行微调。我们训练了16个epoch,使用余弦学习率衰减,0.2的残差dropout。我们根据验证集上的RM分数进行最终的SFT模型选择。我们发现SFT模型在1个epoch后对验证损失上过拟合,然而我们发现尽管存在过拟合,但更多epochs的训练有助于RM分数和人类偏好评级。(尽管这个SFT模型训练更多的epoch会产生过拟合,但是这是为了得到后续的RM模型的初始化模型,对RM模型有帮助,并不是直接使用这个SFT模型,所以过拟合没关系
  • 奖励建模(RM——Reward model),从移除了最后的非嵌入层的SFT模型开始(GPT模型最后的softmax层是用于得到每个词的概率,去掉softmax层以后,增加一个线性层来投影,将所有词的输出投影到一个值上面,即输出一个标量的分数),我们训练了一个模型来接收提示和相应,并输出标量奖励。在本文中,我们只使用6B RM,这样可以节省大量计算,而且我们发现175B RM训练可能不稳定,因此不太适合用作RL(Reinforcement learning)中的值函数。RM在同一输入的两个模型输出之间进行比较的数据集上训练。他们使用交叉熵损失,将比较作为标签——奖励的差异代表了人类标记者更喜欢一种反应的对数几率。
    为了加速分等级数据的收集,我们向标签提供者提供 K = 4 K=4 K=4 K = 9 K=9 K=9之间的任何排名相应。这会为显示给标签者的每个提示生成 ( K 2 ) = C K 2 \binom{K}{2}=C_K^2 (2K)=CK2比较。由于分等级数据在每个标记任务中都非常相关,我们发现,如果我们简单地将分等级数据混洗到一个数据集中,在数据集上的一次遍历会导致奖励模型过拟合。相反,我们将每个提示的所有 ( K 2 ) \binom{K}{2} (2K)比较数据作为单个批处理元素进行训练。这在计算上要高效得多,因为它只需要每次完成一次RM的前向传递(而不是超过 ( K 2 ) \binom{K}{2} (2K)次前向传递),而且因为它不在过拟合,大大提高了验证准确性和日志损失。
    具体来说,奖励模型的损失函数为(这里使用的是排序中最常见的pairwise ranking loss,成对排名损失):
    在这里插入图片描述
    这里的 r θ ( x , y ) r_{\theta}(x,y) rθ(x,y)是表示prompt x x x和相应 y y y在参数为 θ \theta θ的奖励模型下的奖励值, y w y_w yw是在prompt x x x下生成的一对响应 y w y_w yw y l y_l yl中更受欢迎的那一个, D D D是比较的数据集。每一个排名对 y i , y j y_i,y_j yi,yj的损失是 − l o g ( σ ( y i − y j ) ) -log(\sigma(y_i-y_j)) log(σ(yiyj)),换成奖励函数就是 − l o g ( σ ( r θ ( x , y w ) ) − r θ ( x , y l ) ) -log(\sigma(r_{\theta}(x,y_w))-r_{\theta}(x,y_l)) log(σ(rθ(x,yw))rθ(x,yl)),然后共 C K 2 C_K^2 CK2个排序对,所以期望除以它。
    目标是最小化这个loss,也就是最大化这两个奖励的差值, l o g ( σ ) log(\sigma) log(σ)最开始的时候是把生成的每个输出对都作为单独的数据混洗到数据集中,这样的话就需要超过 ( K 2 ) \binom{K}{2} (2K)次前向传递,而且输出对之间有重复,这样容易过拟合,所以将所有的输出对都统一作为单个批处理元素进行训练,这样的话就只需要 K K K次前向传递,因为奖励模型只需要算出9个奖励。之所以取 K = 9 K=9 K=9,是因为考虑到人工标注的时候,很大一部分是花在读懂这个prompt,所以在 K = 4 K=4 K=4 K = 9 K=9 K=9之间,只多了不到一倍的时间,但是标注的数据由6变成了36,多了6倍
    最后,由于RM损失对于奖励的变化是不变的,我们使用偏差对奖励模型进行归一化,以便在进行RL之前,标记器演示的平均得分为0。
  • 强化学习(RL——Reinforcement learning),我们使用PPO在我们的环境中微调了SFT模型。该模型是一个bandit环境,它呈现随机的客户提示并期望对提示的响应。给定提示和相应,它会产生由奖励模型确定的奖励并结束情节。此外,我们在每个token上上添加了SFT模型的每个token的KL惩罚,以减轻奖励模型的过度优化。从RM初始化值函数。我们称这些模型为PPO。
    我们还尝试将预训练梯度混合到PPO梯度中,以修复公共NLP数据集上的性能回归。我们称这些模型为”PPO-ptx“。我们在RL训练中最大化以下组合目标函数:
    在这里插入图片描述
    其中 π Θ R L \pi_{\Theta}^{RL} πΘRL是学习到的RL策略, π S F T \pi^{SFT} πSFT是有监督训练的模型, D p r e t r a i n D_{pretrain} Dpretrain是预训练分布。KL奖励系数 β \beta β,预训练损失系数 γ \gamma γ分别控制KL惩罚和预训练梯度的强度。对于”PPO“模型, γ \gamma γ被设置为0,除非另有说明,本文中的InstructGPT指的是PPO-ptx模型。对于上面说的31k个prompts数据集 D D D,都使用当前的RL模型,也就是RL策略 π θ R L \pi_{\theta}^{RL} πθRL,输出 y y y,然后用RM模型得到分数 r θ ( x , y ) ,目标函数是希望这个分数最大化 r_{\theta}(x,y),目标函数是希望这个分数最大化 rθ(x,y),目标函数是希望这个分数最大化然后根据这个目标函数,更新RL模型,然后再用RM模型计算得分,反复迭代。
    目标函数中还有两项,在此分别解释一下, β l o g ( π Θ R L ( y ∣ x ) / π S F T ( y ∣ x ) ) \beta log(\pi_{\Theta}^{RL}(y|x)/\pi^{SFT}(y|x)) βlog(πΘRL(yx)/πSFT(yx))是正则项,这是PPO的主要思想,随着模型的更新,RL产生的输出 y y y和原始的 S F T SFT SFT模型输出的 y y y会逐渐不一样,即数据分布( y ∣ x y|x yx)的差异会越来越大, R L RL RL的输出可能会不准,所以论文在loss里加入了一个KL散度 KL ( P ∥ Q ) = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) = ∫ P ( x ) log ⁡ ( P ( x ) Q ( x ) )   d x \text{KL}(P \parallel Q) = \sum_{x} P(x) \log \left(\frac{P(x)}{Q(x)}\right)= \int P(x) \log \left(\frac{P(x)}{Q(x)}\right)\, dx KL(PQ)=xP(x)log(Q(x)P(x))=P(x)log(Q(x)P(x))dx,用于描述一个概率分布相对于另一个概率分布的非对称性差异,相当于用这个散度来正则,希望RLSFT的输出分布不要偏太远,因为是最大化目标函数,所以要最小化KL散度需要在前面加一个负号。
    γ E x   D p r e t r a i n [ l o g ( π Θ R L ( x ) ) ] \gamma E_x ~ D_{pretrain}[log(\pi_{\Theta}^{RL}(x))] γEx Dpretrain[log(πΘRL(x))],由于前两项目标函数只和人类排序部分有关,所以训练出来会导致模型仅仅对排序的结果较好,而在最终任务通用NLP任务上性能会下降,所以论文在loss中加入了GPT-3预训练模型的目标函数, D p r e t r a i n D_{pretrain} Dpretrain表示从训练GPT-3的预训练数据中采样 x x x,然后输入RL模型得到输出概率 π Θ R L ( x ) \pi_{\Theta}^{RL}(x) πΘRL(x),这样相当于是GPT-3本身的损失函数。

    总的来说,如果 γ = 0 \gamma=0 γ=0就是一个PPO函数,否则就是一个PPO加上一个GPT-3的目标函数的结合成为RL模型的目标函数,也就是PPO-ptx
    在这里插入图片描述

讨论

论文提出,本文使用的”对齐技术“——RLHF,是用于对齐人类系统的一个重要方法。与预训练相比,增加模型对齐的成本是适中的(仅仅标注几万条prompt数据),与训练GPT-3的花费相比(海量的各种数据),只占一小部分。上述结果也表明,RLHF在使语言模型更加helpful(真实性和无害性是被隐式优化了)方面非常有效,甚至比模型增加100倍更有效。所以,在自然语言领域,研究alignment可能比训练更大规模的模型更具性价比。
align也有争议,就是到底要align人类到什么地步,是用户让做什么就做什么,还是要理解用户更深层的、内在的一些东西。此外最后的RL模型也不是必要的,如果在第一步多标数据,在GPT-3微调,步骤会变得简单,可能更加实用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1830871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux系统之Ward服务器监控工具

Linux系统之Ward服务器监控工具 文章目录 Linux系统之Ward服务器监控工具介绍资源列表基础环境一、安装Java环境二、下载ward的jar包2.2、下载软件包 三、安装ward工具3.1、启动ward服务3.2、查看你后台启动任务3.3、监听ward服务端口 四、访问ward服务4.1、进入ward初始界面4.…

【可控图像生成系列论文(一)】MimicBrush 港大、阿里、蚂蚁集团合作论文解读

背景:考虑到用户的不同需求,图像编辑是一项实用而富有挑战性的任务,其中最困难的部分之一是准确描述编辑后的图像应该是什么样子。 创新点:在本文作者提出了一种新的编辑形式,称为模仿编辑,以帮助用户更方…

深入理解 Java 中的 volatile 关键字

暮色四合,晚风轻拂,湖面上泛起点点波光,宛如撒下了一片星河。 文章目录 前言一、CPU 三级缓存二、JMM三、并发编程正确性的基础四、volatile 关键字五、volatile 可见性六、volatile 有序性6.1 指令重排序6.2 volatile 禁止指令重排6.3 vola…

虚拟机使用桥接模式网络配置

1、获取本机的网络详细信息 windowr 输入cmd 使用ipconfig -all 一样即可 在自己的虚拟机中设置网络 虚拟机中的ip ---------192.168.36.*,不要跟自己的本机ip冲突 网关-----------192.168.36.254 一样即可 dns -----------一样即可,我多写了几个&am…

数字孪生智慧机场:引领航空未来

图扑数字孪生技术赋能智慧机场,实现运营管理和乘客服务的全面优化。实时数据监控与智能决策助力高效安全的航空体验,推动行业创新与发展。

网络安全:SQL注入防范

文章目录 网络安全:SQL注入防范引言防范措施概览使用参数化查询示例代码 输入验证和过滤示例代码 使用ORM框架示例代码 其他防范措施结语 网络安全:SQL注入防范 引言 在上一篇文章中,我们介绍了SQL注入攻击的基础知识。本文将重点讨论如何防…

【UML用户指南】-16-对高级结构建模-构件

目录 1、概念 2、构件与接口 3、可替换性 4、组织构件 5、端口 6、内部结构 6.1、部件 6.2、连接件 7、常用建模技术 7.1、对结构类建模 7.2、对API建模 构件是系统中逻辑的并且可替换的部分,它遵循并提供对一组接口的实现。好的构件用定义良好的接口来定…

来自工业界的知识库 RAG 服务(四),FinGLM 竞赛冠军项目详解

背景介绍 在 前一篇文章 中介绍过智谱组织的一个金融大模型 RAG 比赛 FinGLM 以及 ChatGLM反卷总局 团队的项目,这篇文章继续介绍下获得冠军的馒头科技的技术方案。 建议不了解比赛背景信息的可以先查看 来自工业界的知识库 RAG 服务(三),FinGLM 竞赛获…

[YOLOv10涨点改进:注意力魔改 | 轻量级的 Mixed Local Channel Attention (MLCA),加强通道信息和空间信息提取能力]

本文属于原创独家改进:一种轻量级的Mixed Local Channel Attention (MLCA)模块,该模块考虑通道信息和空间信息,并结合局部信息和全局信息以提高网络的表达效果 1.YOLOv10介绍 论文:[https://arxiv.org/pdf/2405.14458] 代码: https://gitcode.com/THU-MIG/yolov10?utm_s…

基于R-Tree的地理空间数据分析加速

几年前,我正在做一个业余项目。我想创建一个 Web 应用程序,推荐当地的特色景点,例如咖啡馆、书店或隐藏的酒吧。我的想法是在地图上显示用户触手可及的所有兴趣点。我的数据集中有数十万个点,我必须巧妙地过滤用户给定范围内的数据…

DVWA - Brute Force

DVWA - Brute Force 等级:low ​ 直接上bp弱口令爆破,设置变量,攻击类型最后一个,payload为用户名、密码简单列表 ​ 直接run,长度排序下,不一样的就是正确的用户名和密码 ​ 另解: 看一下…

3DMAX网格插入插件使用方法讲解

3DMAX网格插入插件使用方法 3DMAX网格插入插件,在选择的面上安门窗、打螺丝、挖洞、插入眼耳口鼻及其它网格模型等可以分分钟搞定!它通过将面选择替换为库中的资源来加快建模过程。非常适合硬网格和有机建模! 【适用版本】 3dMax2013及更高版…

快速欧氏聚类与普通欧氏聚类比较

1、前言 文献《FEC: Fast Euclidean Clustering for Point Cloud Segmentation》介绍了一种快速欧氏聚类方法,大概原理可以参考如下图,具体原理可以参考参考文献。 2、时间效率比较:快速欧氏聚类VS普通欧氏聚类 网上搜集的快速欧式聚类,与自己手写的普通欧式聚类进行对比,…

网络知识:这些特殊的IP地址,具体的用途你都知道吗

目录 一、0.0.0.0 二、255.255.255.255 限制广播地址 三、127.0.0.1 本机地址 四、224.0.0.1 组播地址 五、169.254.x.x 六、10.x.x.x、172.16。x。x~172.31。x。x、192.168。x。x 私有地址 对于计算机网络来说,IP地址是非常重要的概念&#xff0c…

Objective-C 学习笔记 | 协议(property)

Objective-C 学习笔记 | 协议(property) Objective-C 学习笔记 | 协议(property) Objective-C 学习笔记 | 协议(property) iOS 应用经常会用 UITableView 实例来显示数据,但是它本身不包含数据…

采集罗克韦尔AB、西门子等PLC数据发布成HTTP接口

智能网关IGT-DSER集成了多种PLC的原厂协议,方便实现各种PLC的原厂协议转换为HTTP协议的服务端,通过网关的参数配置软件绑定JSON文件的字段与PLC寄存器地址,即可接收来自客户端的GET、PUT和POST命令,解析和打包JSON文件(JSON文件格…

去哪儿网PMO张璐受邀为第十三届中国PMO大会演讲嘉宾

全国PMO专业人士年度盛会 去哪儿网PMO张璐女士受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾,演讲议题为“数字化助力组织目标落地”。大会将于6月29-30日在北京举办,敬请关注! 议题简要 本次议题将分享去哪儿流程标准化&工具化…

我用chatgpt写了一款程序

众所周知,Chatgpt能够帮助人们写代码,前几天苏音试着完全用Chatgpt写一款Python程序 有一句话我很赞同,未来能代替人的不是AI,是会使用AI的人。 最终,写下来效果还不错,完全提升了我的办公效率。 开发前…

告密者斯诺登:永远不要信任 OpenAI 或其 ChatGPT 等产品|TodayAI

为什么 OpenAI 变得越来越难以信任 OpenAI,一家以开发先进人工智能技术而闻名的公司,正面临越来越多的信任危机。近期,一些令人不安的迹象使人们对这家公司的透明度和安全性产生了质疑。 首先,在 OpenAI 的旧金山办公室外&#…

顺安蜘蛛池四川官网下载

baidu搜索:如何联系八爪鱼SEO? baidu搜索:如何联系八爪鱼SEO? baidu搜索:如何联系八爪鱼SEO? 虽然影视泛目录很火,但超度站群版本自出现以来-直流量稳定,可惜这两年起站全靠域名。但话说回来,咱不能否认,只要用的域名好,做啥泛目录都有好…