InstructGPT的四阶段:预训练、有监督微调、奖励建模、强化学习涉及到的公式解读

news2024/10/23 13:32:43

1. 预训练

在这里插入图片描述

1. 语言建模目标函数(公式1):

L 1 ( U ) = ∑ i log ⁡ P ( u i ∣ u i − k , … , u i − 1 ; Θ ) L_1(\mathcal{U}) = \sum_{i} \log P(u_i \mid u_{i-k}, \dots, u_{i-1}; \Theta) L1(U)=ilogP(uiuik,,ui1;Θ)

  • 解释
    • U = { u 1 , u 2 , … , u n } \mathcal{U} = \{u_1, u_2, \dots, u_n\} U={u1,u2,,un}是输入的未标注语料(token序列)。
    • k k k是上下文窗口的大小,即预测当前词 u i u_i ui时,使用前 k k k个词( u i − k , … , u i − 1 u_{i-k}, \dots, u_{i-1} uik,,ui1)作为上下文。
    • P ( u i ∣ u i − k , … , u i − 1 ; Θ ) P(u_i \mid u_{i-k}, \dots, u_{i-1}; \Theta) P(uiuik,,ui1;Θ) 是模型根据前 k k k个词预测 u i u_i ui的条件概率,其中参数 Θ \Theta Θ是通过训练得到的神经网络参数。
    • 通过最大化对数似然,模型被训练以最小化预测和真实词之间的差距,这个过程通常通过随机梯度下降(SGD)进行。

2. Transformer解码器结构(公式2):

公式2描述了模型的架构,采用了多层的Transformer解码器。Transformer通过自注意力机制来捕捉上下文依赖关系,并对输入序列进行编码。

初始嵌入层:

h 0 = U W e + W p h_0 = U W_e + W_p h0=UWe+Wp

  • 解释
    • U = ( u − k , … , u − 1 ) U = (u_{-k}, \dots, u_{-1}) U=(uk,,u1) 是上下文窗口中输入序列的词向量。
    • W e W_e We是词嵌入矩阵,用于将输入的token转换为词向量。
    • W p W_p Wp是位置嵌入矩阵,提供每个token的位置编码,用于捕捉词序信息。
Transformer块:

h l = transformer_block ( h l − 1 ) ∀ i ∈ [ 1 , n ] h_l = \text{transformer\_block}(h_{l-1}) \quad \forall i \in [1, n] hl=transformer_block(hl1)i[1,n]

  • 解释
    • h l h_l hl表示第 l l l层的Transformer输出。
    • 每一层 h l h_l hl是通过前一层 h l − 1 h_{l-1} hl1经过Transformer块(自注意力和前馈网络)的处理得到的。
    • 共有 n n n层,每一层都通过类似的操作进行。
输出层:

P ( u ) = softmax ( h n W e T ) P(u) = \text{softmax}(h_n W_e^T) P(u)=softmax(hnWeT)

  • 解释
    • h n h_n hn是最后一层的输出,经过词嵌入矩阵的转置 W e T W_e^T WeT变换后,再通过Softmax函数计算每个词的概率分布。
    • 这个概率分布用于预测输出的目标词 u u u,Softmax确保输出的各个词的概率和为1。

总结:

  • 该模型采用无监督的方式进行预训练,利用大规模未标注语料数据,通过最大化词序列的条件概率来训练语言模型。
  • 预训练的模型架构基于Transformer解码器,通过多层自注意力机制和位置编码来有效捕捉上下文信息,并使用Softmax输出目标词的概率分布。

2. 有监督微调

在这里插入图片描述

1. 微调任务的目标函数(公式3):

P ( y ∣ x 1 , … , x m ) = softmax ( h l m W y ) P(y \mid x^1, \dots, x^m) = \text{softmax}(h_l^m W_y) P(yx1,,xm)=softmax(hlmWy)

  • 解释
    • x 1 , … , x m x^1, \dots, x^m x1,,xm 是输入的token序列。
    • h l m h_l^m hlm表示输入序列经过预训练模型(如Transformer)的最后一层输出的激活值(即特征表示)。
    • W y W_y Wy是用于预测目标标签 y y y的线性层的参数矩阵。
    • Softmax 函数将线性层的输出转换为每个类的概率分布,用于分类任务中的标签预测。

2. 最大化目标函数(公式4):

L 2 ( C ) = ∑ ( x , y ) log ⁡ P ( y ∣ x 1 , … , x m ) L_2(C) = \sum_{(x, y)} \log P(y \mid x^1, \dots, x^m) L2(C)=(x,y)logP(yx1,,xm)

  • 解释
    • 这是监督学习的目标函数,模型通过最大化预测标签 y y y 的对数概率来微调模型参数。
    • C C C是标注数据集,包含输入序列 x x x和相应的标签 y y y
    • 目标是最大化所有样本的对数似然,确保模型在监督任务中的准确性。

3. 辅助目标函数(公式5):

L 3 ( C ) = L 2 ( C ) + λ ∗ L 1 ( C ) L_3(C) = L_2(C) + \lambda \ast L_1(C) L3(C)=L2(C)+λL1(C)

  • 解释
    • 为了提高监督学习的效果,模型还结合了语言建模的辅助目标,即无监督的语言建模损失( L 1 L_1 L1 )和监督任务的损失( L 2 L_2 L2)相结合。
    • λ \lambda λ 是用于平衡两个目标的权重参数。
    • 这样做的好处是:可以通过语言模型的任务帮助监督任务更好地泛化,同时加快收敛速度。这种辅助目标在之前的研究中已证明可以有效提高性能。

总结:

  • 监督微调阶段,模型利用预训练好的参数,结合带标签的数据来优化预测性能。
  • 通过最大化预测标签 y y y的对数概率,模型适应特定任务。
  • 引入语言建模作为辅助任务,有助于提升模型的泛化能力和训练效率。

3. 奖励建模

在这里插入图片描述

这个段落介绍了 奖励模型(Reward Modeling, RM) 在 InstructGPT 模型中的训练方式,具体描述了模型如何从 监督微调模型(SFT) 继续优化以输出奖励分数(reward score),并通过 比较(comparison)来训练这个奖励模型。

核心内容解释:

  1. 奖励模型的基础

    • 奖励模型的训练是基于监督微调模型(SFT)进一步改进的。为了输出奖励,SFT 模型的最后一层(unembedding layer)被移除,留下的是可以根据给定的提示(prompt)和回应(response)输出一个标量的奖励值。
    • 为了节省计算资源,他们使用了一个 6B 参数的奖励模型(RM),因为较大规模的 175B 参数奖励模型虽然理论上可能更准确,但实际训练时表现不稳定,且不适合作为 RL 的值函数。
  2. Stiennon et al. (2020) 的方法

    • 数据集:奖励模型通过比较训练,数据集包含两个模型生成的输出(response),并根据这些生成的结果做出比较。
    • 损失函数:训练中使用了 交叉熵损失(cross-entropy loss),这些比较的标签是人类标注员根据两者的优劣给出的。交叉熵损失衡量了两个结果中哪一个应该被人类标注员优先选择,实际优化的是两者奖励分数的对数几率(log odds)。
  3. 加速比较收集过程

    • 在标注过程中,给标注员展示了 K = 4 到 9 个不同的生成回应,标注员需要对这些生成的结果进行排名。这会生成 ( K 2 ) \binom{K}{2} (2K)对比较,即每个标注任务中有 K 个回应时,将产生 K 中选取 2 个的组合数的比较对数。比如 K = 9 K = 9 K=9 时,会生成 ( 9 2 ) = 36 \binom{9}{2} = 36 (29)=36 对比较。
    • 为了避免过度拟合,研究者决定不将所有比较对一起训练(因为不同的比较对之间存在强相关性),而是仅从每个提示中抽取一对比较结果作为一个训练样本。这种方式更加 计算高效,因为对于每个完成的任务只需要一次前向传播(forward pass),而不是处理所有 ( K 2 ) \binom{K}{2} (2K) 的比较对。
  4. 奖励模型的损失函数

    奖励模型的损失函数定义如下:
    loss ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) ∼ D [ log ⁡ ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ) ] \text{loss}(\theta) = -\frac{1}{\binom{K}{2}} \mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \left( \sigma \left( r_\theta(x, y_w) - r_\theta(x, y_l) \right) \right) \right] loss(θ)=(2K)1E(x,yw,yl)D[log(σ(rθ(x,yw)rθ(x,yl)))]

    • σ ( ⋅ ) \sigma(\cdot) σ()sigmoid 函数,用于将差值映射到 [0, 1] 区间,用来表示某一结果被标注员认为更优的概率。
    • r θ ( x , y ) r_\theta(x, y) rθ(x,y) 是奖励模型对于给定提示 x x x和生成的回应 y y y所输出的奖励分数。
    • y w y_w yw y l y_l yl分别是优胜和劣胜的生成结果。
    • D D D是人类标注员比较的结果数据集。
    • 1 ( K 2 ) \frac{1}{\binom{K}{2}} (2K)1 是对所有比较的标准化处理。

解释

  • 损失函数的目标是最大化优胜结果 y w y_w yw 比劣胜结果 y l y_l yl更受偏好的概率(通过两者奖励分数差异的 sigmoid 值来实现)。这个函数通过最小化损失来优化奖励模型,使得奖励模型能够更准确地给出与人类标注员偏好一致的分数。
  1. 防止过拟合的解释(脚注 5)
    • 如果每一个比较对都被视为一个单独的数据点,那么每个生成的回应可能在训练中会得到 K − 1 K-1 K1 次更新,从而导致模型过拟合。而研究人员发现,模型过度训练甚至只需一个 epoch 就会过拟合。为了解决这个问题,他们只对每个提示下的一对回应进行一次前向传播训练,从而避免过拟合。

总结:

  • 奖励模型的训练基于人类反馈,通过比较两个模型生成的回应来进行优化。该训练过程使用了 交叉熵损失函数,优化目标是让奖励模型尽可能地预测出哪个回应更符合人类标注员的偏好。
  • 通过只选取部分比较对进行训练(而不是所有组合对),减少了计算开销,并有效避免了模型过拟合。

4. 强化学习(PPO)

在这里插入图片描述

1. 强化学习(Reinforcement Learning, RL)

在这部分,模型使用了强化学习(RL)进行微调,采用了 PPO(Proximal Policy Optimization) 算法来优化策略。PPO 是一种策略梯度算法,常用于强化学习任务中,通过限制策略更新的步长来提高训练的稳定性。

2. 环境设置

强化学习的环境被设置为一个 多臂老虎机问题(bandit environment),该环境会随机给定提示(prompt),模型需要生成相应的回应。生成的回应会通过奖励模型(reward model)来打分并结束当前回合。

为了防止模型过度优化奖励模型,训练过程中在每个 token 的输出时,加入了 KL 惩罚项,这项惩罚的来源是监督微调模型(SFT, Supervised Fine-Tuned Model)。

3. PPO-ptx 模型

为了提高模型的泛化能力,研究者还尝试将 预训练梯度(pretraining gradients) 与 PPO 的梯度混合,构建了所谓的 PPO-ptx 模型。这种方法可以解决在某些公共 NLP 数据集上性能回退的问题。

他们使用了以下的 目标函数(Objective Function)

objective ( ϕ ) = E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) ] + γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \text{objective}(\phi) = \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) \right] + \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi^{RL}_{\phi}(x)) \right] objective(ϕ)=E(x,y)DπϕRL[rθ(x,y)βlog(πSFT(yx)πϕRL(yx))]+γExDpretrain[log(πϕRL(x))]

4. 公式解读

第一部分:PPO 的核心目标

E ( x , y ) ∼ D π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) ] \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) \right] E(x,y)DπϕRL[rθ(x,y)βlog(πSFT(yx)πϕRL(yx))]

  • E ( x , y ) ∼ D π ϕ R L \mathbb{E}_{(x,y) \sim D_{\pi^{RL}_{\phi}}} E(x,y)DπϕRL表示在 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL 生成的分布 D π ϕ R L D_{\pi^{RL}_{\phi}} DπϕRL上的期望。
  • r θ ( x , y ) r_\theta(x, y) rθ(x,y) 是奖励模型 r θ r_\theta rθ对生成的响应 y y y的奖励。
  • β log ⁡ ( π ϕ R L ( y ∣ x ) π S F T ( y ∣ x ) ) \beta \log \left( \frac{\pi^{RL}_{\phi}(y \mid x)}{\pi^{SFT}(y \mid x)} \right) βlog(πSFT(yx)πϕRL(yx))是 KL 散度惩罚项,惩罚 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL偏离 SFT 模型 π S F T \pi^{SFT} πSFT的程度,其中 β \beta β控制 KL 惩罚的权重。

解释
这部分的目标是最大化模型的奖励,同时通过 KL 惩罚防止策略 π ϕ R L \pi^{RL}_{\phi} πϕRL与监督微调模型 π S F T \pi^{SFT} πSFT偏离过远。惩罚项确保模型在强化学习时不走偏,保持与原本训练目标的相似性。

第二部分:预训练损失

γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi^{RL}_{\phi}(x)) \right] γExDpretrain[log(πϕRL(x))]

  • E x ∼ D pretrain \mathbb{E}_{x \sim D_{\text{pretrain}}} ExDpretrain 表示在预训练数据分布 D pretrain D_{\text{pretrain}} Dpretrain上的期望。
  • log ⁡ ( π ϕ R L ( x ) ) \log(\pi^{RL}_{\phi}(x)) log(πϕRL(x)) 是 RL 策略 π ϕ R L \pi^{RL}_{\phi} πϕRL生成的结果的对数概率。
  • γ \gamma γ是预训练损失项的权重,控制预训练数据与强化学习的结合程度。

解释
这一部分引入了预训练的损失,使得模型能够保持在大规模预训练数据上的表现,防止模型在强化学习过程中完全依赖于奖励模型而失去通用能力。通过设置 γ \gamma γ,我们可以平衡预训练损失和强化学习损失的影响。

5. 策略和符号说明

  • π ϕ R L \pi^{RL}_{\phi} πϕRL 是在强化学习中学习到的策略,参数为 ϕ \phi ϕ
  • π S F T \pi^{SFT} πSFT 是通过监督学习微调得到的策略,它代表了模型在强化学习之前的性能。
  • β \beta β是 KL 惩罚的权重系数,控制 RL 策略和 SFT 策略的偏离程度。
  • γ \gamma γ 是预训练损失的权重系数,控制预训练梯度在 PPO 优化中的作用。

总结:

在这篇文章中,InstructGPT 使用了强化学习中的 PPO(Proximal Policy Optimization) 进行策略优化,同时通过引入 KL 散度惩罚项 来确保 RL 策略与 SFT 策略不过度偏离。此外,预训练损失 通过一个额外的项加入到了目标函数中,以解决在某些 NLP 任务上的性能回退问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2213141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智慧灌区信息化管理系统解决方案

一、方案背景 我国南方地区一些县级一般拥有5000多个大小水利设施, 尤其是灌区水利设施众多,这些灌区水利设施修建年代久,信息化程度低,但在保障农民生产、农田灌溉、抵抗自然灾害方面发挥着一定的作用,并能够最大限度…

go开发环境设置-安装与交叉编译(二)

1. 引言 Go语言,又称Golang,是Google开发的一门编程语言,以其高效、简洁和并发编程的优势受到广泛欢迎。作为一门静态类型、编译型语言,Go在构建网络服务器、微服务和命令行工具方面表现突出。 在开发过程中,开发者常…

科技云报到:大模型时代下,向量数据库的野望

科技云报到原创。 自ChatGPT爆火,国内头部平台型公司一拥而上,先后发布AGI或垂类LLM,但鲜有大模型基础设施在数据层面的进化,比如向量数据库。 在此之前,向量数据库经历了几年的沉寂期,现在似乎终于乘着Ch…

Yolov11与Yolov8在西红柿识别数据集上对比

Ultralytics 最新发布的 YOLOv11 相较于其上一代产品 YOLOv8,虽然没有发生革命性的变化,但仍有一些显著的改进(值得注意的是,YOLOv9 和 YOLOv10 并非由 Ultralytics 开发)。其中,最引人注目的变化包括&…

4.redis通用命令

文章目录 1.使用官网文档2.redis通用命令2.1set2.2get2.3.redis全局命令2.3.1 keys 2.4 exists2.5 del(delete)2.6 expire - (失效时间)2.7 ttl - 过期时间2.7.1 redis中key的过期策略2.7.2redis定时器的实现原理 2.8 type2.9 object 3.生产环境4.常用的数据结构4.1认识数据类型…

代码复现(四):DBINet

文章目录 datasets/AB2019BASDataset.pydatasets/ext_transforms.pynetwork/modules.pynetwork/DBINet.pynetwork/DBINet_Backbone.pyAB2019_train.py 代码链接:DBINet datasets/AB2019BASDataset.py 加载Australia Bushfire 2019 Burned Area Segmentation Datase…

【论文精读】RELIEF: Reinforcement Learning Empowered Graph Feature Prompt Tuning

Navigating the Digital World as Humans Do: UNIVERSAL VISUAL GROUNDING FOR GUI AGENTS 前言AbstractMotivationSolutionRELIEFIncorporating Feature Prompts as MDPAction SpaceState TransitionReward Function Policy Network ArchitectureDiscrete ActorContinuous Act…

【杂记】之语法学习第一课输入输出与数据类型与选择结构

首先学会新建源文件 1.打开DEV C 2.文件—>新建—>源代码 3.编写程序 4.编译并运行(F11) 第一个程序《Hello,World!》 题目描述 编写一个能够输出 Hello,World! 的程序。 提示: 使用英文标点符号;Hello,World! 逗号后…

8-基于双TMS320C6678 + XC7K420T的6U CPCI Express高速数据处理平台

1、板卡概述 板卡由我公司自主研发,基于6UCPCI架构,处理板包含双片TI DSP TMS320C6678芯片;一片Xilinx公司FPGA XC7K420T-1FFG1156 芯片;六个千兆网口(FPGA两个,DSP四个);DSP与FPGA之…

基于springboot+vue实现的酒店在线预订系统

基于springbootvue实现的酒店在线预订系统 (源码L文ppt)4-082 4.2 系统结构设计 构图是系统的体系结构,体系结构是体系结构体系的一部分,体系结构体系是体系结…

LabVIEW空间相机测控系统

空间相机是遥感技术中的核心设备,其在太空中的性能对任务的成功至关重要。为了确保空间相机能够在极端环境下稳定工作,地面模拟测试成为必不可少的环节。LabVIEW开发的空间相机测控系统,通过对温度、应力和应变等参数进行高精度测量&#xff…

LeetCode 3319. 第 K 大的完美二叉子树的大小

LeetCode 3319. 第 K 大的完美二叉子树的大小 给你一棵 二叉树 的根节点 root 和一个整数k。 返回第 k 大的 完美二叉子树的大小,如果不存在则返回 -1。 完美二叉树 是指所有叶子节点都在同一层级的树,且每个父节点恰有两个子节点。 子树 是指树中的某一…

计算机网络:数据链路层 —— 以太网(Ethernet)

文章目录 局域网局域网的主要特征 以太网以太网的发展100BASE-T 以太网物理层标准 吉比特以太网载波延伸物理层标准 10吉比特以太网汇聚层交换机物理层标准 40/100吉比特以太网传输媒体 局域网 局域网(Local Area Network, LAN)是一种计算机网络&#x…

本地装了个pytorch cuda

安装命令选择 pip install torch1.13.1cu116 torchvision0.14.1cu116 torchaudio0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 torch版本查看 python import torch print(torch.__version__) 查看pytorch能否使用cuda import torch# 检查CUDA是否可用…

如何用AWG实现脉冲激光输出

脉冲激光二极管提供强功率短脉冲的能力,使其成为目标指定和测距等军事应用的理想选择。事实上,开发这些二极管的许多历史动机都有军事渊源。然而,当今的技术进步和成本降低的大背景之下,在测试测量和医学领域新的应用得以开辟。 …

113.WEB渗透测试-信息收集-ARL(4)

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于: 易锦网校会员专享课 上一个内容:112.WEB渗透测试-信息收集-ARL(3)-CSDN博客 等待搭建成功 创建成…

Vue深入了解

Vue深入了解 MVVMv-model (双向数据绑定原理)异步更新keep-alive原理$nextTick原理computed 和 watch 的区别css-scoped虚拟DOMVuex && PiniaVue-router原理proxy 与 Object.defineProperty组件通信方式 MVVM <!DOCTYPE html> <html lang"en">&…

声波驱鸟 全向强声广播的应用

HT-360A多层叠形360向广播是恒星科通自主研发的一款应急广播专用设备&#xff0c;该设备内部采用1-4组换能器垂直阵列设置&#xff0c;水平采用指数函数碟形堆叠技术&#xff0c;在垂直方向上多层碟扬声器可实现360度环形垂直阵列&#xff0c;实现多层声场叠加。 系统可采用4G…

Linux驱动中的并发与竞争处理

Linux是一个多任务操作系统&#xff0c;肯定会存在多个任务共同操作同一段内存或者设备的情况&#xff0c;多个任务甚至中断都能访问的资源叫做共享资源&#xff0c;就和共享单车一样。在驱动开发中要注意对共享资源的保护&#xff0c;也就是要处理对共享资源的并发访问。比如共…

智慧校园打架斗殴检测预警系统 异常奔跑检测系统 Python 和 OpenCV 实现简单

在当今数字化时代&#xff0c;智慧校园建设已成为教育领域的重要发展方向。校园安全作为学校管理的重中之重&#xff0c;如何借助先进的技术手段实现高效、精准的安全监控&#xff0c;成为了教育工作者和技术专家共同关注的焦点。其中&#xff0c;智慧校园打架斗殴检测预警系统…