PPO和文本生成

news2024/12/23 4:51:41

策略梯度

策略梯度(Policy Gradient)方法梯度的计算如下:
E ( a t , s t ) ∈ π θ [ A ^ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \nabla_ \theta \log \pi_\theta(a_t | s_t)] E(at,st)πθ[A^tθlogπθ(atst)] A ^ t \hat A_t A^t是优势函数(advantage function) A t A_t At的估计。
A t = Q ( s t , a t ) − V ( s t ) A_t=Q(s_t, a_t)-V(s_t) At=Q(st,at)V(st)优势函数计算的是,在该状态下采取这个行动的奖励与在该状态下的平均奖励的差值。
上面的导数可以通过对下面的目标求导获得:
L P G ( θ ) = E ( a t , s t ) ∈ π θ [ A ^ t log ⁡ π θ ( a t ∣ s t ) ] L^{PG}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\hat A_t \log \pi_\theta(a_t | s_t)] LPG(θ)=E(at,st)πθ[A^tlogπθ(atst)]

PPO(Proximal Policy Optimization)

PPO有两个形式,其中一种形式PPO_CLIP的优化目标函数是:
L C L I P ( θ ) = E ( a t , s t ) ∈ π θ [ min ⁡ ( r t ( θ ) A ^ t , c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t ) ] (1) L^{CLIP}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[\min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t)] \tag{1} LCLIP(θ)=E(at,st)πθ[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)](1)其中 r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta)=\frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{old}}(a_t | s_t)} rt(θ)=πθold(atst)πθ(atst)
PPO算法中的advantage用下面的公式估计:
A ^ t = δ t + ( γ λ ) δ t + 1 + ⋯ + ( γ λ ) T − t + 1 δ T − 1 \hat A^t = \delta^t + (\gamma \lambda)\delta_{t+1} + \cdots+ (\gamma \lambda)^{T-t+1}\delta_{T-1} A^t=δt+(γλ)δt+1++(γλ)Tt+1δT1其中 δ t = r t + γ V ( s t + 1 ) − V ( s t ) \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) δt=rt+γV(st+1)V(st)
通常情况下,我们用一个网络学习策略和价值函数,这样策略和价值函数能共享参数,那么就需要结合策略代理和价值函数误差项的损失函数。再加上熵奖励(entropy bonus)来以确保足够的探索,优化目标变为:
L C L I P + V F + S ( θ ) = E ( a t , s t ) ∈ π θ [ L t C L I P ( θ ) − c 1 L t V F ( θ ) + c 2 S [ π θ ] ( s t ) ] L^{CLIP+VF+S}(\theta)=\mathbb E_{(a_t,s_t) \in \pi_\theta}[L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t)] LCLIP+VF+S(θ)=E(at,st)πθ[LtCLIP(θ)c1LtVF(θ)+c2S[πθ](st)]其中 L t V F ( θ ) = ( V θ ( s t ) − V t t a r g ) 2 L_t^{VF}(\theta)=(V_\theta(s_t)-V_t^{targ})^2 LtVF(θ)=(Vθ(st)Vttarg)2是价值函数的误差项,S是entropy bonus。

文本生成

在文本生成的情况下,给一个prompt,生成完整的response,是一个episode。动作空间是vocabulary。每生成一个词是一个时间步。

公式(1)需要advantage的估计,为了计算advantage,我们需要定义奖励(reward) r r r和估计状态价值函数 V ( s ) V(s) V(s)

用于强化学习的reward计算如下:
R ( x , y ) = r ( x , y ) − β log ⁡ π ( y ∣ x ) ρ ( y ∣ x ) R(x,y) = r(x,y) - \beta\log\frac{\pi(y|x)}{\rho(y|x)} R(x,y)=r(x,y)βlogρ(yx)π(yx)x是问题,y是回答, r ( x , y ) r(x,y) r(x,y)是reward model的输出,也就是下面代码中的score。注意这里reward model的输出称之为score,送入强化学习部分的才称为reward。 π ( y ∣ x ) \pi(y|x) π(yx)是要学习的生成模型, ρ ( y ∣ x ) \rho(y|x) ρ(yx)是参数固定的原始生成模型。
在trl库中reward的计算如下:

   def compute_rewards(
       self,
       scores: torch.FloatTensor,
       logprobs: torch.FloatTensor,
       ref_logprobs: torch.FloatTensor,
       masks: torch.LongTensor,
   ):
       """
       Compute per token rewards from scores and KL-penalty.

       Args:
           scores (`torch.FloatTensor`):
               Scores from the reward model, shape (`batch_size`)
           logprobs (`torch.FloatTensor`):
               Log probabilities of the model, shape (`batch_size`, `response_length`)
           ref_logprobs (`torch.FloatTensor`):
               Log probabilities of the reference model, shape (`batch_size`, `response_length`)
       """
       rewards, non_score_rewards = [], []
       for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
           # compute KL penalty (from difference in logprobs)
           kl = self._kl_penalty(logprob, ref_logprob)
           non_score_reward = -self.kl_ctl.value * kl
           non_score_rewards.append(non_score_reward)
           reward = non_score_reward.clone()
           last_non_masked_index = mask.nonzero()[-1]

           # reward is preference model score + KL penalty
           reward[last_non_masked_index] += score
           rewards.append(reward)
       return torch.stack(rewards), torch.stack(non_score_rewards)

可以看到上面的实现中,只将reward model的score添加到最后一个token的reward上,其他token的reward来自当前模型和 原始生成模型之间KL散度。这么做是为了减轻奖励模型的过度优化问题。

在trl库中用一个网络AutoModelForCausalLMWithValueHead学习策略 π θ ( s ) \pi_\theta(s) πθ(s)和状态价值函数 V ( s ) V(s) V(s)。AutoModelForCausalLMWithValueHead在普通AutoModelForCausalLM模型上了一个线性层nn.Linear(hidden_size, 1),用于估计状态价值函数 V ( s ) V(s) V(s)
普通AutoModelForCausalLM模型估计token概率即可作为策略 π θ ( s ) \pi_\theta(s) πθ(s)

在trl库中advantage的计算如下:

    def compute_advantages(
        self: torch.FloatTensor,
        values: torch.FloatTensor, # AutoModelForCausalLMWithValueHead输出的状态价值估计V
        rewards: torch.FloatTensor, # compute_rewards函数计算得到的rewards
        mask: torch.FloatTensor,
    ):
        lastgaelam = 0
        advantages_reversed = []
        gen_len = rewards.shape[-1]

        values = values * mask
        rewards = rewards * mask

        for t in reversed(range(gen_len)):
            nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
            delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

        returns = advantages + values
        advantages = masked_whiten(advantages, mask)
        advantages = advantages.detach()
        return values, advantages, returns

完整的PPO算法如下:
在这里插入图片描述

Reference

Proximal Policy Optimization Algorithms
Fine-Tuning Language Models from Human Preferences
Training language models to follow instructions with human feedback
https://github.com/huggingface/trl

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/861861.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

了解IL汇编跳转语句

il代码&#xff0c; .assembly extern mscorlib {}.assembly Test{.ver 1:0:1:0}.module test.exe.method static void main() cil managed{.maxstack 5.entrypointldstr "Enter First Number"call void [mscorlib]System.Console::WriteLine (string)call string …

低代码平台 数据库字段值不重复

在开发过程中&#xff0c;要求表里某字段值唯一 一、场景 在单据&#xff0c;要求某字段值不重复 查看数据模型&#xff1a; 查看单据&#xff1a; 二、问题 区域编码&#xff0c;区域名称不重复 三、解决方案 1&#xff09;数据库加索引 2&#xff09;书写保存后存储过…

Python(七十八)字符串的常用操作——字符串大小写转换操作

❤️ 专栏简介&#xff1a;本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中&#xff0c;我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 &#xff1a;本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

推荐两本书《JavaRoadmap》、《JustCC》

《JavaRoadmap》 前言 本书的受众 如果你是一名有开发经验的程序员&#xff0c;对 Java 语言语法也有所了解&#xff0c;但是却一直觉得自己没有入门&#xff0c;那么希望这本书能帮你打通 Java 语言的任督二脉。 本书的定位 它不是一本大而全的书&#xff0c;而是一本打通、…

JDBC连接数据库及改造工具类

引入mysql驱动依赖,一般会建个lib包,如果是java web项目 一般将以来包创建在web->WEB-INF下, 这里我就随便了 建议 try {} catch (SQLException throwables) {throwables.printStackTrace(); }finally {} 的写法,这里就简写了 写个工具类 public class DBUtil {static{try…

C语言案例 完数求解-09

题目&#xff1a;编写一个程序找出1000以内的所有完数。 步骤一&#xff1a;定义程序目标 编写一个C程序&#xff0c;输出1000以内的所有完数 步骤二&#xff1a;程序设计 1.完数原理&#xff1a;一个数如果恰好等于它的因子之和&#xff0c;这个数就称为“完数”。例如6 1 …

__attribute__ ((constructor))和__attribute__ ((destructor))用法

目录 1. 前言 2. __attribute__介绍 3. 测试代码 4. 总结 1. 前言 最近看代码&#xff0c;有个函数根本就没被任何函数调用&#xff0c;但从程序运行结果来看&#xff0c;该函数是被调用了的&#xff0c;找很久都没找到哪里调用了&#xff0c;最后发现该函数…

Java顺序表解析与应用

一、顺序表概念 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构&#xff0c;一般情况下采用数组存储。在数组上完成数据的增删查改。 二、主要功能接口实现 Java顺序表底层就是一个动态数组。其主要功能接口如下&#xff1a; // 1.打印顺序表&#xff0…

stack(栈)和queue(队列)

目录 1.stack的介绍和使用(栈) 1.1 stack的介绍 1.2 stack的使用 1.3stack的模拟实现 2.queue的介绍和使用(队列) 2.1queue的介绍 2.3queue的模拟实现 3.priority_queue的介绍和使用 3.1priority_queue的介绍 3.2 priority_queue的使用 3.3priority_queue的模拟实现 …

高忆管理:碳酸锂期现货价格大幅回落 行业期盼找回“价格之锚”

6月末以来&#xff0c;国内碳酸锂价格的反弹态势戛然而止&#xff0c;再度陷入接连跌落格式。现货方面&#xff0c;据上海钢联数据显现&#xff0c;电池级碳酸锂价格6月26日至今已接连22次下调&#xff0c;从31.50万元/吨下调至最新的25.60万元/吨&#xff1b;期货方面&#xf…

【问题解决】Git命令行常见error及其解决方法

以下是我一段时间没有使用xshell&#xff0c;然后用git命令行遇到的一些系列错误和他们的解决方法 遇到了这个报错&#xff1a; fatal: Not a git repository (or any of the parent directories): .git 我查阅一些博客和资料&#xff0c;可以解决的方式&#xff1a; git in…

MyBatis操作数据库常见用法总结2

文章目录 1.动态SQL使用什么是动态sql为什么用动态sql标签拼接标签拼接标签拼接标签拼接标签拼接 补充1&#xff1a;resultType和resultMap补充2&#xff1a;后端开发中单元测试工具使用&#xff08;Junit框架&#xff09; 1.动态SQL使用 以insert标签为例 什么是动态sql 是…

vue3+vite使用vite-plugin-svg-icons

使用vite-plugin-svg-icons插件显示本地svg图标 在开发项目的时候&#xff0c;经常会用到svg矢量图标&#xff0c;而且我们使用svg以后&#xff0c;页面上加载的不再是图片资源&#xff0c;这对页面性能来说是个很大的提升&#xff0c;而且我们svg文件比img要小很多&#xff0c…

使用xrdp协议远程桌面控制树莓派,无需公网IP!

远程桌面控制树莓派&#xff0c;我们可以用xrdp协议来实现&#xff0c;它内部使用的是windows远程桌面的协议。我们只需要在树莓派上安装xrdp&#xff0c;就可以在同个局域网下远程桌面控制树莓派。 而如果需要在公网下远程桌面控制树莓派&#xff0c;可以通过cpolar内网穿透&…

Linux固件子系统的实现机制简介

一、Linux固件子系统概述 固件是硬件设备自身执行的一段程序。固件一般存放在设备flash内。而出于成本和便利性的考虑&#xff0c;通常是先将硬件设备的运行程序打包为一个特定格式的固件文件&#xff0c;存储到终端系统内&#xff0c;通过终端系统给硬件设备进行升级。Linux内…

self-attention(自注意力机制)

先举个有趣的例子理解 Q 、 K 、 V Q、K、V Q、K、V&#xff1a; 将我们要查询的内容&#xff0c;和商品列表进行相似度匹配&#xff0c;先拿出相似度更高的商品列表。 再根据以往的评价&#xff0c;计算出总分&#xff0c;按照分数进行排序。 self-attention d k \sqrt{d_k}…

ubuntu 安装 python

ubuntu 安装 python 初环境与设备查询是否安装安装python 本篇文章将介绍ubuntu 安装 python 初 希望能写一些简单的教程和案例分享给需要的人 环境与设备 系统&#xff1a;ubuntu 查询是否安装 因为系统也许会自带一个python&#xff0c;所以验证一下&#xff0c;如果自…

Linux mysql5.7开启 binlog

查看 mysql是否开启 binlog。 查看命令&#xff1a; show variables like %log_bin%; log_bin OFF 是关闭的状态。 编辑my.cnf配置文件 vim /etc/my.cnf 默认的配置文件内容&#xff1a; 增加下面内容 server_id 1 binlog_format ROW log-bin mysql_log_bin 重启mysq…

uniapp 微信小程序 订阅消息

第一步&#xff0c;需要先去小程序官方挑选一下订阅模板拿到模板id 订阅按钮在头部导航上&#xff0c;所以 <u-navbar :bgColor"bgColor"><view class"u-nav-slot" slot"left" click"goSubscribe"><image :src"g…

【Java】try|catch|throws 具体详解+应用

目录 tryCatch 基本介绍 使用细节 throws异常处理 基本介绍 ​ 使用细节 自定义异常 基本概念 步骤 throw和throws的区别 tryCatch 基本介绍 使用细节 throws异常处理 基本介绍 使用细节 自定义异常 基本概念 步骤 throw和throws的区别