trl中的PPO代码解析(炒冷饭版)

news2024/11/15 18:01:51

不说其他的解释,上来就看代码。建议先对PPO的整体流程有了解。

trl的版本为0.4.0,注:【新版的trl中代码更复杂,如果只是想读懂PPO具体怎么用trl实现的,0.4.0版本即可】
在这里插入图片描述

step1: rollout

ppo_trainer.generate()函数使用policy model生成rollout
在这里插入图片描述

step2:evaluate

使用reward model对step1产生的rollout进行evaluate,获得一个标量的score,这个score并不是rewards,step4计算得到的才是最终的rewards
在这里插入图片描述

step3: logprobs

从old policy model和ref model中获得rollout的logits, values等值,用于后续计算rewards。
在这里插入图片描述

对应的代码部分为:
在这里插入图片描述

step4: rewards

注意:这里产生的变量中,score变成了rewards。
在这里插入图片描述
PPO中,为了防止policy model过度偏离ref model,会在计算rewards过程中额外增加一项KL散度,
r e w a r d s = s c o r e − λ K L ( π θ ( a ∣ s ) ∣ ∣ π θ r e f ( a ∣ s ) ) rewards = score - \lambda KL(\pi_{\theta}(a|s)||\pi_{\theta_{ref}}(a|s)) rewards=scoreλKL(πθ(as)∣∣πθref(as))
对应的代码部分为:
在这里插入图片描述

step5: train_minibatch

注意,这里的logprobs, vpreds, 与old_logprobs, old_values均是policy LM产生的,但是参数不一样。
在这里,产生logprobs, vpreds的policy LM的参数是会按照mini_batch_size进行不断更新的,所以每个mini_batch_size对于的new policy LM的参数是不一样的。而产生old_logprobs, old_values的old policy LM的参数对于每个mini_batch_size是不变的。

可以按照一般的训练神经网络的过程理解:产生old_logprobs, old_values的old policy LM的参数是按照epoch更新的,而产生logprobs, vpreds的new policy LM是按照step更新的。
在这里插入图片描述

对应的代码部分为:
在这里插入图片描述

step6: advantages

根据old_values, rewards,计算优势,在进一步计算出returns
在这里插入图片描述
对应的代码部分为(代码中的values为old_values):
在这里插入图片描述

step7: critic_loss

critic loss通常是通过均方误差(MSE)来计算。对于每一个状态,我们都有一个由critic网络预测的预期回报 v p r e d s vpreds vpreds,以及一个真实的回报 r e t u r n s returns returns,critic_loss是二者的平方差。
对应的代码部分为:
在这里插入图片描述

step8: actor loss

actor loss是基于策略梯度的损失函数,用于优化policy。在ppo中,通常使用一种称为重要性采样(importance sampling)的技术来计算策略梯度。
m a x i m i z e θ    E π θ ′ [ m i n ( r t ( θ ) A π θ o l d ( s , a ) ,   c l i p ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A π θ o l d ( s , a )   ) ] maximize_{\theta} \ \ E_{\pi_{\theta^{'}}}[min( r_{t}(\theta) A^{\pi_{\theta_{old}}}(s,a),\ clip(r_{t}(\theta), 1-\epsilon, 1+\epsilon)A^{\pi_{\theta_{old}}}(s,a)\ )] maximizeθ  Eπθ[min(rt(θ)Aπθold(s,a), clip(rt(θ),1ϵ,1+ϵ)Aπθold(s,a) )]
其中, r t ( θ ) = π θ ( a ∣ s ) π θ o l d ( a ∣ s ) r_{t}(\theta) = {\pi_{\theta}(a|s) \over \pi_{\theta_{old}}(a|s)} rt(θ)=πθold(as)πθ(as),这一项是新旧策略的比率, A π θ o l d ( s , a ) A^{\pi_{\theta_{old}}}(s,a) Aπθold(s,a)是优势函数,clip是裁剪函数,将其裁剪到 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon,1+\epsilon] [1ϵ,1+ϵ]之间。这个损失函数的目标是,最大化新策略的期望回报,同时限制新旧策略之间的差异。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2160794.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从入门到精通:QT 100个关键技术关键词

Qt基础概念 Qt Framework - 一个跨平台的C图形用户界面应用程序开发框架。它不仅提供了丰富的GUI组件,还包括网络、数据库访问、多媒体支持等功能。 Qt Creator - Qt官方提供的集成开发环境(IDE),集成了代码编辑器、项目管理工具、…

2024年AI技术爆发的元年,用对工具,让你副业比主业赚得多!

大家好,我是强哥 文字的力量不容小觑,或许你没有多好的文笔,或许你已经很久没有拿笔写字了,但是没关系,我们有工具! AI时代的到来,不会用工具,那你可就OUT了 如果你觉得文字不能赚…

Convert excel format exception.You can try specifying the ‘excelType‘

在使用easyexcel读取文件流获取集合的时候报了这个错 在点进代码抛出异常的地方,发现这么一段逻辑 是通过文件流的前8个字节来判断文件的类型,实际上这种判断规则是无法保证准确的。然后自然的想到是不是引入的jar包版本太旧了,所以有这个b…

Axure大屏可视化模板:跨领域数据分析平台原型案例

随着信息技术的飞速发展,数据可视化已成为各行各业提升管理效率、优化决策过程的重要手段。Axure作为一款强大的原型设计工具,其大屏可视化模板在农业、园区、城市、企业数据可视化、医疗等多个领域得到了广泛应用。本文将通过几个具体案例,展…

安全测试|如何使用burpsuite+xray实现联动测试

目的:安全测试过程中手动分析测试与xray自动化扫描测试结合,这样可以从多层保障安全测试的分析,针对平台业务接口量大的安全测试是十分有用的,可以实现双向测试同时开始。 1.xray 安装和使用 1.1 下载地址:xray commu…

git push错误:Out of memory, malloc failed (tried toallocate 947912704 bytes)

目录 一、错误截图 二、解决办法 一、错误截图 因项目文件过大,http.postBuffer设置的内存不够,所以报错。 二、解决办法 打开cmd窗口,执行如下命令即可 git config --global http.postBuffer 1024000000 如图所示 执行完成以后&#…

WinCC中归档数据片段的时间和尺寸设置

1.归档数据片段介绍工控人加入PLC工业自动化精英社群 1.1 概述 WinCC V6.2 开始的后台数据库采用了MS SQL Server 2005 ,所以归档方式与V5 有所不同,它的运行数据存放在数据片段(segment)当中,工程师可以…

Protobuf:基本概念与使用流程

Protobuf:基本概念与使用流程 基本概念Linux 安装使用流程.proto文件编译使用 运行机制 基本概念 在进行网络编程时,经常需要进行数据传输,只有双方主机都保证数据格式的一致性,才能保证数据被正常解析。这个过程称为序列化与反序…

召回04 离散特征的处理

推荐系统会将一个id映射成一个向量 Qne-Hot编码 Embedding(嵌入): 把每个类别映射成一个低维的稠密向量

Drive.js 的一些 Api 使用记录

文章目录 2024 年 drive.js 的基础使用想在下一步的时候处理些逻辑呢?(同步)Element 的各种选择器 2024 年 drive.js 的基础使用 安装就跳过了 npm install driver.js ,一行代码就可以搞定 官网的 Basic Usage 基础使用的截图如下: 想在下…

C++番外篇——对于继承中子类与父类对象同时定义其析构顺序的探究

思考这样一串代码的运行结果&#xff1a; #include <iostream> using namespace std; class Person { public:~Person() { cout << "~Person()" << endl; } }; class Student:public Person { public:~Student() { cout << "~Student(…

线程池工作原理?

线程池的工作原理&#xff1a; 当任务过来时&#xff0c;如果线程池中的线程数小于核心线程数&#xff0c;就创建线程。&#xff08;默认情况下&#xff0c;线程池不会预先创建线程&#xff0c;但可以配置&#xff09;当核心线程数满了以后&#xff0c;提交过来的任务会放到阻塞…

Axure9破解

1.下载安装包 通过百度网盘分享的文件&#xff1a;Axure RP 9.zip 链接&#xff1a;https://pan.baidu.com/s/1Lcu-gg4qF8tTkOlt7bC2ww?pwdwmqq 提取码&#xff1a;wmqq 2.设置登录以及破解码 位置&#xff1a;帮助-管理授权-添加key Licensee&#xff1a;123456 Key&#…

Ping32:一站式终端安全解决方案,企业安心之选

在数字化时代&#xff0c;企业的终端安全面临着前所未有的挑战。随着网络威胁的日益复杂化和多样化&#xff0c;如何确保终端设备的安全稳定运行&#xff0c;保护企业敏感数据不被泄露&#xff0c;成为了每个企业必须面对的重要课题。正是在这样的背景下&#xff0c;Ping32作为…

第十四届蓝桥杯嵌入式国赛

一. 前言 本篇博客主要讲述十四届蓝桥杯嵌入式的国赛题目&#xff0c;包括STM32CubeMx的相关配置以及相关功能实现代码以及我在做题过程中所遇到的一些问题和总结收获。如果有兴趣的伙伴还可以去做做其它届的真题&#xff0c;可去 蓝桥云课 上搜索历届真题即可。 二. 题目概述 …

探索LLM中的CoT链式推理:ECHO方法深度解读

近年来&#xff0c;随着大型语言模型&#xff08;LLMs&#xff09;的快速发展&#xff0c;如何有效利用这些模型进行复杂任务的推理成为了研究热点。其中&#xff0c;链式思考&#xff08;Chain-of-Thought, CoT&#xff09;推理方法作为一种有效的策略&#xff0c;能够显著提升…

Redhat 6,7,8系(复刻系列) 一键部署Oracle12c zip

Oracle12c前言 Oracle 12c是甲骨文公司推出的一款关系数据库管理系统,它引入了多项创新特性,如多租户架构、大数据处理和云部署,适用于企业级应用。以下是Oracle 12c的详细介绍: Oracle 12c的主要特点 高性能:通过多线程处理、自动优化等技术,提高了数据库的查询和处理…

云栖大会 | 天润融通发布微藤智能体平台,中国客户联络正式进入“智能体时代”

9月19日&#xff0c;以“云启智跃&#xff0c;产业蝶变”为主题的2024云栖大会在杭州正式开幕。大会持续三天&#xff0c;聚焦AI时代的技术升级与实践应用&#xff0c;设有三大主论坛、400多个分论坛&#xff0c;并开放4万平方米的智能科技展区&#xff0c;展示全球百余款AI应用…

CHARLS数据库系列教程(3)---绘制(加权和不加权)基线表一

CHARLS 是一项具备中国大陆 45 岁及以上人群代表性的追踪调查&#xff0c;旨在建设一个高质量的公共微观数据库&#xff0c;采集的信息涵盖社会经济状况和健康状况等多维度的信息&#xff0c;以满足老龄科学研究的需要。 为利用国际上最佳的数据采集方式&#xff0c;并确保研究…

2024年工业制造企业CRM研究报告:需求清单、市场格局、案例分析

我国是世界上产业体系最完备的国家&#xff0c;拥有全球规模最大、门类最齐全的生产制造体系&#xff0c;在500种主要工业产品中&#xff0c;有四成以上产品产量位居全球第一。2023年制造业增加值达33万亿元&#xff0c;占世界的比重稳定在30%左右&#xff0c;我国制造业增加值…