ChatGLM-RLHF(五)-PPO(Proximal Policy Optimization)原理实现代码逐行注释

news2024/11/27 11:42:14

 一,前言

从open AI 的论文可以看到,大语言模型的优化,分下面三个步骤,SFT,RM,PPO,我们跟随大神的步伐,来学习一下这三个步骤和代码实现,本章介绍PPO代码实现。


上章我们介绍了PPO算法的公式,其形式如下:

$J(\theta) = E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)] +E_{\tau \sim p_{\theta_{old}}(\tau)}[\min(r_t(\theta)A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$

下面是这段公式与代码的对应关系:

一,优势函数(Advantage)计算公式里的At:

1. 优势函数用于衡量当前状态或行动相对于期望值的优劣程度。A(s,a) = Q(s,a) - V(s),

2. $A_t$ 表示时间步 t 的优势函数,计算时使用当前策略 $\pi_{\theta}(a_t|s_t)$ 和旧策略 $\pi_{\theta_{\text{old}}}(a_t|s_t)$ 分别计算出 $r_t(\theta)$,然后使用 $A_t$$r_t(\theta)$ 计算出 $E_{\tau \sim p_{\theta_{old}}(\tau)}[\min(r_t(\theta)A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)A_t)]$部分的损失。同时,$A_t$还用于计算E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)] 部分的损失。优势函数的计算是一个重要的部分,它帮助我们估计当前策略相对于旧策略的性能提升。

3. 代码里没有直接使用Q(s,a) - V(s),而是使用了GAE-Lambda 算法进行计算,GAE-Lambda 算法通过将多个时间步的奖励加权平均,计算出当前时间步的优势函数,GAE-Lambda 算法的优点在于,它不仅考虑了当前时间步的奖励,还考虑了未来时间步的奖励,并且通过 Lambda 参数进行加权平均,从而更好地估计当前状态或行动的价值。这使得 GAE-Lambda 算法在训练强化学习模型时具有更好的稳定性和收敛性。

GAE-Lambda 算法的计算过程如下:

  1. 对于每个时间步 t,计算 delta,即当前时间步的奖励加上折扣因子乘以下一个时间步的值减去当前时间步的值:

    delta = rewards[t] + gamma * values[t+1] - values[t]

  2. 对于每个时间步 t,计算 GAE-Lambda,即 delta 加上折扣因子乘以 Lambda 倍的上一个时间步的 GAE-Lambda:

    lastgaelam = delta + gamma * lam * lastgaelam

  3. 将计算得到的 GAE-Lambda 添加到 advantages_reversed 列表中。

  4. 将 advantages_reversed 列表转换为张量,并进行维度转置,得到最终的优势函数张量 advantages。

  5. 具体如下代码

# 计算优势函数
for t in reversed(range(gen_len)):
    nextvalues = values[:,t + 1] if t < gen_len - 1 else last_values  # 获取下一个时间步的值,如果当前时间步是最后一个时间步,则使用 last_values
    delta = rewards[:, t] + self.config.gamma * nextvalues - values[:,t]  # 计算 delta,即当前时间步的奖励加上折扣因子乘以下一个时间步的值减去当前时间步的值
    lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam  # 计算 GAE-Lambda,即 delta 加上折扣因子乘以 Lambda 倍的上一个时间步的 GAE-Lambda
    advantages_reversed.append(lastgaelam)  # 将计算得到的 GAE-Lambda 添加到 advantages_reversed 列表中
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)  # 将 advantages_reversed 列表转换为张量,并进行维度转置

二,值函数的损失(Value Function Loss)的计算

值函数的损失公式通常使用均方差(Mean Squared Error,MSE)来衡量值函数的预测误差。值函数的损失公式可以表示为:

L(θ) = 0.5 * E[(V(s) - R)^2]

其中,L(θ)表示值函数的损失,θ表示值函数的参数,V(s)表示值函数对状态s的预测值,R表示实际的回报值。

这个公式的含义是,首先,通过 clip_by_value 函数将当前状态的价值函数 values 限制在一个区间内,得到 vpredclipped。然后,分别计算使用原始价值函数和限制后的价值函数计算得到的损失,即 vf_losses1 和 vf_losses2。通过计算值函数对状态的预测值与实际回报值之间的差异的平方,来衡量值函数的预测误差。然后取这些差异的平方的期望值,再乘以0.5,得到最终的损失值。最终,将两者的较大值作为值函数的损失,通过 masked_mean 函数计算期望。

            # 值函数的损失
            vpredclipped = clip_by_value(
                values, values - self.config.cliprange_value, values + self.config.cliprange_value
            )
            vf_losses1 = (values - returns) ** 2
            vf_losses2 = (vpredclipped - returns) ** 2
            vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), masks)
            vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).double(), masks)

三,策略函数的损失(Policy Function Loss)的计算:

这部分对应公式E_{\tau \sim p_{\mu}(\tau)}[\sum_{t=0}^T \rho_t A_t \log \pi_{\theta}(a_t|s_t)]

在PPO算法中,我们采用两种不同的方式计算策略损失,即pg_losses和pg_losses2。这两种方式分别对应目标函数中的两个部分。

pg_losses表示使用原始比率计算得到的损失,即:

$ L^{PG}_1(\theta) = -\frac{1}{N} \sum_{i=1}^N \sum_{t=0}^{T_i} \rho_{i,t} A_{i,t} \log \pi_{\theta}(a_{i,t}|s_{i,t}) $

其中,N表示采样轨迹的数量,$\rho_{i,t}$ 表示第 i 条轨迹在时间步 t 的重要性采样比例,$A_{i,t}$表示第 i 条轨迹在时间步 t 的优势函数。

pg_losses2表示使用限制后的比率计算得到的损失,即:

$ L^{PG}_2(\theta) = -\frac{1}{N} \sum_{i=1}^N \sum_{t=0}^{T_i} \min(r_{i,t}(\theta)A_{i,t}, \text{clip}(r_{i,t}(\theta), 1-\epsilon, 1+\epsilon)A_{i,t}) \log \pi_{\theta}(a_{i,t}|s_{i,t}) $

其中,$r_{i,t}(\theta)$ 表示第i条轨迹在时间步t的比率,$\epsilon$表示剪切幅度。

最终,将两种方式计算得到的损失取较大值,即:

pg_loss = \max(pg_losses, pg_losses2)

            # 策略函数的损失
            logprobs = F.log_softmax(logits, dim=1)
            ratio = torch.exp(logprobs - old_logprobs)
            pg_losses = -advantages * ratio
            pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
            pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), masks)
            pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).double(), masks)

总损失计算

            # 总损失
            loss = pg_loss + self.config.vf_coef * vf_loss

四, 完整代码可以参考:

GitHub - Pillars-Creation/ChatGLM-RLHF-LoRA-RM-PPO: ChatGLM-6B添加了RLHF的实现,以及部分核心代码的逐行讲解 ,实例部分是做了个新闻短标题的生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/847513.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java编程实践:实现Java接口的方法也建议加上@Override注解

说明 作为一个Java编程实践&#xff0c;实现接口的方法也强烈建议加上Override注解。这样做的好处&#xff1a; 阅读代码的时候&#xff0c;一眼就能看出来是新增的函数&#xff0c;还是实现接口的函数。加上Override注解&#xff0c;如果拼写错误&#xff0c;编译器马上就能…

电视盒子哪款好?内行整理超值网络电视盒子推荐

从事电视盒子这行已经五年了&#xff0c;很多朋友在挑选电视盒子时会咨询我的意见&#xff0c;我耗费半个月时间整理了超值网络电视盒子推荐&#xff0c;盘点目前最值得入手的五款电视盒子机型&#xff0c;想买电视盒子不知道电视盒子哪款好可以从下面五款中挑选&#xff1a; 榜…

VIM 编辑器: Bram Moolenaar

VIM 用了很长时间&#xff0c; 个人的 VIM 配置文件差不多10年没有更新了。以前写程序的时候&#xff0c; 编辑都用这个。 linux kernel&#xff0c; boost规模的代码都不在话下。现在虽然代码写的少了&#xff0c;依然是我打开文件的首选。 现在用手机了&#xff0c;配个蓝牙键…

idea中如何处理飘红提示

idea中如何处理飘红提示 在写sql时&#xff0c;总是会提示各种错误 查找资料&#xff0c;大部分都是说关提示&#xff0c;这里把错误提示选择为None即可 关掉以后&#xff0c;也确实不显示任何提示了&#xff0c;但总有一种掩耳盗铃的感觉 这个sms表明明存在&#xff0c;但是还…

android studio安卓真机调试

把usb 手机开启到usb调试模式,然后用usb线连接手机 安装adb 如果下载速度很慢,请使用vpn 终端需要先安装brew /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"使用brew安装adb brew install android-platfor…

面试遇到登录功能测试用例设计,你回答对了吗

给你一个登录功能&#xff0c;如何设计测试用例 哪怕是最常用最小的一个登录功能&#xff0c;其实涉及到的测试用例也是非常多的&#xff0c;这个题目通常会通过面试来考察求职者的综合能力&#xff0c;尤其是测试用例的设计思维&#xff0c;因为你即使你背了各种测试用例设计…

开发一款保护程序检测进程假死,精准打开保护的程序

网上很多保护程序都收费, 有免费的,可以将一般程序改成windows服务,我没用,应该很强大 功能点: 1,首先要有能加入保护程序的功能 2,不断的轮询检测程序是否已经运行 3,不断的轮询检测程序是否假死 4,一些其他检测 将保护的程序存入文件列表 保护程序运行时加载…

嵌入式开发学习(STC51-9-led点阵)

内容 点亮一个点&#xff1b; 显示数字&#xff1b; 显示图像&#xff1b; LED点阵简介 LED 点阵是由发光二极管排列组成的显示器件 通常应用较多的是8 * 8点阵&#xff0c;然后使用多个8 * 8点阵可组成不同分辨率的LED点阵显示屏&#xff0c;比如16 * 16点阵可以使用4个8 *…

恒盛策略:医药股反弹,掀涨停潮!

今天上午A股商场涨跌互现&#xff0c;上证指数一度显着跌落&#xff0c;但临近上午收盘时翻红。 作为行情风向标&#xff0c;券商板块盘中一度大幅跌落&#xff0c;但随后快速收窄跌幅&#xff0c;板块内分解较为显着&#xff0c;其中市值超越1000亿元的龙头券商之一的中金公司…

OPENCV C++(八)HOG的实现

hog适合做行人的识别和车辆识别 对一定区域的形状描述方法 可以表示较大的形状 把图像分成一个一个小的区域的直方图 用cell做单位做直方图 计算各个像素的梯度强度和方向 用3*3的像素组成一个cell 3*3的cell组成一个block来归一化 提高亮度不变性 常用SVM分类器一起使用…

HTML Emoji和Emoji 参考手册

HTML表情可以用来在网页中插入各种表情符号图标&#xff0c;丰富了网页表现形式和视觉效果。下面是一些常用HTML表情代码大全&#x1f4dc; Emoji 参考手册 HTML Emoji 扩展&#xff1a;&#x1f4cc; HTML 自定义实现emoji - (freesion.com)

native vlan tag设置错误,导致交换机无法访问

一同事找来&#xff0c;说他的一个测试交换机&#xff0c;下挂一些测试设备&#xff0c;能正常访问&#xff0c;但交换机的ip192.168.100.128却无法telnet访问&#xff0c;ping过去显示无法访问目的主机&#xff0c;让给看一下原因&#xff1f; 已知组网这个交换机接在交换机的…

用于实体对齐的联合学习实体和关系表示2019 AAAI 8.7+8.8

用于实体对齐的联合学习实体和关系表示 摘要介绍相关工作实体对齐图卷积网络 问题公式我们的方法整体架构初步实体对齐图卷积层对齐训练 近似关系表示联合实体和关系对齐 实验总结 摘要 实体对齐是在不同知识图之间集成异构知识的一种可行方法。该领域的最新发展通常采用基于嵌…

端口映射软件可以做什么?快解析如何设置端口映射?

说到端口映射&#xff0c;首先说说nat。简单地说&#xff0c;nat就是在局域网内部网络中使用内部地址&#xff0c;而当内部节点要与外部网络进行通讯时&#xff0c;就在网关处&#xff0c;将内部地址替换成公用地址&#xff0c;从而在外部公网&#xff08;internet&#xff09;…

网络系统观察之道

什么是“可观察性”&#xff1f; 当然&#xff0c;“可观察性”这个术语并不是我们发明的。我们最开始从用户那里听到这个概念&#xff0c;这些用户主要来自网站可靠性工程 (SRE) 社区。有些信息来源认为&#xff0c;这个术语起源于硅谷巨头&#xff08;如 Twitter&#xff09…

CTF流量题解http2.pcapng

使用wireshark工具打开流量文件。 根据网络协议进行分组排序&#xff0c;对流量文件里面的内容进行观察。 16进制转换&#xff0c;16进制转换文本字符串&#xff0c;在线16进制转换 | 在线工具 (sojson.com) Base64编码/解码器&#xff0c;在线解码Base64 (sojson.com) https:…

VS2008总在当前项目文件夹创建3个不必要的文件夹的解决方法

如下图所示&#xff1a; 这3个文件夹都是无必要的空文件夹&#xff08;1.Visual Studio 2008 2.Visual Studio 2008Projects 3.Visual Studio 2008Templates&#xff09;&#xff0c;每个项目都这样就有点烦躁的了。每次打开还要给你重建。 解决方法&#xff1a; 1.重置“项…

【Java可执行命令】(十八)可视化监控和管理工具 jconsole:获取 JVM的内存使用情况、线程活动、GC 行为等重要指标的可视化工具 ~

Java可执行命令之jconsole 1️⃣ 概念2️⃣ 优势和缺点3️⃣ 使用3.1 语法格式3.2 注意事项 4️⃣ 应用场景&#x1f33e; 总结 1️⃣ 概念 jconsole 是 Java Development Kit (JDK) 自带的一款图形化监控和管理工具。它旨在提供一个简单而强大的界面&#xff0c;用于监视和管…

成品短视频App源码,开启你的创意视频之旅!

短视频App如今已成为人们记录和分享生活的热门方式。你是否想过自己拥有一款属于自己的短视频App呢?有了短视频App源码&#xff0c;就能轻松实现这一愿望。本文将介绍短视频App源码的优势、开发流程和功能特点&#xff0c;助你快速创建个性化短视频App&#xff0c;开启你的创意…

并发——什么是线程,什么是进程

文章目录 1.1. 何为进程?1.2. 何为线程? 1.1. 何为进程? 进程是程序的一次执行过程&#xff0c;是系统运行程序的基本单位&#xff0c;因此进程是动态的。系统运行一个程序即是一个进程从创建&#xff0c;运行到消亡的过程。 在 Java 中&#xff0c;当我们启动 main 函数时…