PPO(Proximal Policy Optimization Algorithms)论文解读及实现

news2025/1/12 17:21:15

论文标题:Proximal Policy Optimization Algorithms
核心思路:使用off policy 代替on policy,用一个策略网络来产生数据,用一个策略网络来更新参数,分别为policy_old和policy

0 摘要

  • Whereas standard policy gradient methods perform one gradient update per data sample, we propose a novel objective function that enables multiple epochs of minibatch updates
    由于标准的策略梯度算法每轮数据(每轮数据定义:从游戏开始到结束的一轮完整数据)只能更新一个梯度,提出了一种新的方法,使得每个数据集上可以更新多次。
  • The new methods, which we call proximal policy optimization (PPO), have some of the benefits of trust region policy optimization (TRPO),
  • 新提出的方法命名为 PPO,一些启发受益于TRPO。

1 Introduction

We propose a novel objective with clipped probability ratios, which forms a pessimistic estimate (i.e., lower bound) of the performance of the policy. To optimize policies, we alternate between sampling data from the policy and performing several epochs of optimization on the sampled data.
我们提出了一个新的目标(使概率clip在一个区间内),形成一个悲观的估计 (即,数据下界)。为了优化策略,我们在从策略中采样数据和对采样数据执行若干次优化之间交替进行。

2 Background: Policy Optimization

2.1 Policy Gradient Methods

g ^ = E ^ t [ ∇ θ l o g π θ ( a t ∣ s t ) A ^ t ] \hat g=\hat E_t[ \nabla _{\theta}log\pi_{\theta}(a_t|s_t) \hat A_t] g^=E^t[θlogπθ(atst)A^t]

where π theta is a stochastic policy and Aˆt is an estimator of the advantage function at timestep t.

L P G ( θ ) = E ^ t [ l o g π θ ( a t ∣ s t ) A ^ t ] L^{PG}(\theta)=\hat E_t[ log\pi_{\theta}(a_t|s_t) \hat A_t] LPG(θ)=E^t[logπθ(atst)A^t]

3 Clipped Surrogate Objective

在这里插入图片描述
Without a constraint, maximization of LCPI would lead to an excessively large policy update; hence, we now consider how to modify the objective, to penalize changes to the policy that move rt(θ) away from 1.
如果对(6)式没有约束,这个值将会导致过大的策略更新,我们现在考虑修改这个目标,惩罚远离1的策略变化
The main objective we propose is the following:
我们提出的目标函数如下:
在这里插入图片描述
通过公式(7),我们clip后的值取的是clip之前值的下界。
其中奖励At可能是整数,也可能是负数。

plots a single term (i.e., a single t) in LCLIP ; note that the probability ratio r is clipped at 1 − epsilon 到1+epsilon ,depending on whether the advantage is positive or negative.

在这里插入图片描述

4 Adaptive KL Penalty Coefficient

In our experiments, we found that the KL penalty performed worse than the clipped surrogate objective, however, we’ve included it here because it’s an important baseline.
在我们的实验中,发现实验KL散度惩罚项,没有CLIP效果好,这里仅作为一个baseline。
在这里插入图片描述

5 Algorithm

If using a neural network architecture that shares parameters between the policy and value function, we must use a loss function that combines the policy surrogate and a value function error term. This objective can further be augmented by adding an entropy bonus to ensure sufficient exploration.
Combining these terms, we obtain the following objective, which is (approximately) maximized each iteration:
如果使用一个神经网络架构在策略这价值函数共享参数,我们必须使用一个结合策略代理和价值函数误差项。这个目标可以被参数为添加熵奖励以确保足够的探索,
结合这些项,我们可以得到以下目标函数,最大化每一步迭代。
在这里插入图片描述
each of N (parallel) actors collect T timesteps of data. Then we construct the surrogate loss on these NT timesteps of data, and optimize it with minibatch SGD (or usually for better performance, Adam [KB14]), for K epochs.
共用大N个Actor,每个Actor收集T步的数据,然后构建surrogate 损失在这N*T的数据集上,我们使用小批量SGD来优化这个损失函数。-=
在这里插入图片描述

6 算法实现(pytorch)

以LunarLander-v2(月球着陆)游戏为例:
环境基本属性:
状态维度:state_dim=8
动作空间:action_dim=4

6.1 memory

用于存储游戏过程中的游戏数据,包含以下参数:
游戏每步选择的动作:actions=[]
游戏每步的状态:states=[]
每步动作出现的概率:logprobs=[]
每步动作的奖励:rewards=[]
是否游戏结束了:is_terminals=[]
通过以下代码获取actions,states,logprobs,
注意:memory 是policy_old 网络生成的
再通过环境执行action后,可以得到reward和is_terminal。
(完整代码见最后部分)

        state = torch.from_numpy(state).float().to(device)  	# 输入当前状态
        action_probs = self.action_layer(state)   # 经过action 网络层,输出动作概率
        dist = Categorical(action_probs)#按照给定的概率分布来进行采样
        action = dist.sample() # 安装aciton 概率,采样出一个action
        memory.states.append(state) # 添加当前状态
        memory.actions.append(action) # 添加当前动作
        memory.logprobs.append(dist.log_prob(action)) # 添加当前 概率

6.2 policy 网络结构

action 网络结构如下,输入为状态维度8,输出为action维度4。即输入环境状态,输出选择的aciton概率。
input_dim (batch_size,8)
output_dim (batch_size,4)
网络结构如下:

 # actor
self.action_layer = nn.Sequential(
                nn.Linear(state_dim, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, action_dim),
                nn.Softmax(dim=-1)
                )

critic网络输入为环境状态,维度为8,输出维度为1。即一个值,用于对当前状态的评价得分。
input_dim (batch_size,8)
output_dim (batch_size,1)
网络结构如下:

# critic
self.value_layer = nn.Sequential(
              nn.Linear(state_dim, n_latent_var),
              nn.Tanh(),
              nn.Linear(n_latent_var, n_latent_var),
              nn.Tanh(),
              nn.Linear(n_latent_var, 1)
              )
      

6.3 模型更新

上面memory存储游戏一定步数后,数据既可以开始用于网络训练。

  • 当前步reward 计算
    后续步reward对当前步reward是逐步衰减的。
    在这里插入图片描述
        rewards = []
        discounted_reward = 0
        # 使用reversed将奖励翻转,从最后一步往前计算累计衰减奖励。
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

使用新policy网络,输入old_states, old_actions,可以得到,新的action概率分布,和critical 值。

 logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
  • 论文核心公式实现
    在这里插入图片描述
# Finding the ratio (pi_theta / pi_theta__old):
 ratios = torch.exp(logprobs - old_logprobs.detach())

公式中A_t: advantages = rewards - state_values.detach()

# Finding Surrogate Loss:
 advantages = rewards - state_values.detach()

在这里插入图片描述

surr1 = ratios * advantages

在这里插入图片描述


surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
 -torch.min(surr1, surr2)

损失函数公式:
在这里插入图片描述

loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

新的policy网络更新K次后,将新的模型参数复制给老网络。然后使用老网络产生新的训练数据,再使用新的训练数据更新新的网络,如此往复循环。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/756372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python自动化办公:pptx篇

文章目录 简介能做什么PPT要素介绍官方demo高阶引申参考文献 202201笔记迁移 简介 python-pptx包是用来自动化处理ppt的。 使用的第一步是安装 pip install python-pptx相比python-docx,python-pptx的使用更为麻烦一些,原因有很多,比如说&…

波奇学Linux:make和Makefile

make和Makefile自动化构建并能决定源文件调用顺序,同时不必再写gcc命令 第一行依赖关系,第二行是tab键开头,是依赖方法 依赖关系:目标文件:依赖文件。 依赖方法:目标文件和依赖文件间的关系。 如果只有一条…

es下载历史的tar文件

第一步进入官网找到历史版本 第二步复制历史版本名称组合成下面的链接 直接get访问下载。如下链接所示只需要修改7.3.0这个版本号 https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.3.0-linux-x86_64.tar.gz

ChatGLM使用记录

ChatGLM ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存&#xff0…

opencv实战--角度测量和二维码条形码识别

文章目录 前言一、鼠标点击的角度测量二、二维码条形码识别 前言 一、鼠标点击的角度测量 首先导入一个带有角度的照片 然后下面的代码注册了一个鼠标按下的回调函数, 还有一个点的数列,鼠标事件为按下的时候就记录点,并画出点,…

uniapp微信小程序上传体积压缩包过大分包操作和上传时遇到代码质量未通过问题

1:首先我们得从项目最初阶段就得考虑项目是否要进行分包操作,如果得分包,我们应该创建一个与pages同级的文件夹,命名可以随意 2:第二部我们将需要分包的文件和页面放到分包文件夹里面subpage,这里我们得注意&#xff…

Python基础语法第三章之顺序循环条件

目录 一、顺序语句 二、条件语句 2.1什么是条件语句 2.2语法格式 2.2.1 if 2.2.2if - else 2.2.3if - elif - else 2.3缩进和代码块 2.4闰年的判断练习 2.5空语句 pass 三、循环语句 3.1while 循环 3.1.1代码示例练习 3.2 for 循环 ​3.3 continue 3.4 break 一…

给LLM装上知识:从LLM+LangChain的本地知识库问答到LLM与知识图谱的结合

前言 过去半年,随着ChatGPT的火爆,直接带火了整个LLM这个方向,然LLM毕竟更多是基于过去的经验数据预训练而来,没法获取最新的知识,以及各企业私有的知识 为了获取最新的知识,ChatGPT plus版集成了bing搜…

linux常用工具介绍

文章目录 前言目录文件查看ls1、查看详细信息(文件大小用K、M等显示)2、按照文件创建时间排序(常在查看日志时使用) sort1、排序数字 df 、du1、查看目录的大小2、查看目录 从大到小排序 显示前n个3、查看磁盘使用情况 tailf一些目…

银河麒麟高级服务器操作系统V10安装mysql数据库

一、安装前 1.检查是否已经安装mysql rpm -qa | grep mysql2.将查询出的包卸载掉 rpm -e --nodeps 文件名3.将/usr/lib64/libLLVM-7.so删除 rm -rf /usr/lib64/libLLVM-7.so4.检查删除结果 rpm -qa | grep mysql5.搜索残余文件 whereis mysql6.删除残余文件 rm -rf /usr/b…

怎么用二维码做企业介绍?企业宣传二维码2种制作方法

怎么做一个企业推广二维码呢?现在制作二维码来做宣传推广是常用的一种方式,一般需要包含企业介绍、工作环境、产品简介、宣传视频、公司地址等等方面内容,那么企业介绍二维码该如何制作?下面给大家分享一下使用二维码编辑器&#…

EventBus详解

目录 1 EventBus 简介简介角色关系图四种线程模型 2.EventBus使用步骤添加依赖注册解注册创建消息类发送消息接收消息粘性事件发送消息 使用postStick()接受消息 3 EventBus做通信优点4 源码getDefault()register()findSubscriberMethods方法findUsingReflection方法findUsingR…

前端部署项目,经常会出现下载完 node 或者 npm 运行时候发现,提示找不到

1. 首先要在下载时候选择要下载的路径,不能下载完后,再拖拽到其他文件夹,那样就会因为下载路径和当前路径不一致,导致找不到相关变量。 2. 所以一开始就要在下载时候确定要存放的路径,然后如果运行报错,就…

【Java基础教程】(十三)面向对象篇 · 第七讲:继承性详解——继承概念及其限制,方法覆写和属性覆盖,关键字super的魔力~

Java基础教程之面向对象 第七讲 本节学习目标1️⃣ 继承性1.1 继承的限制 2️⃣ 覆写2.1 方法的覆写2.2 属性的覆盖2.3 关键字 this与 super的区别 3️⃣ 继承案例3.1 开发数组的父类3.2 开发排序类3.3 开发反转类 🌾 总结 本节学习目标 掌握继承性的主要作用、实…

git指令记录

参考博客(侵权删):关于Git这一篇就够了_17岁boy想当攻城狮的博客-CSDN博客 Git工作区介绍_git 工作区_xyzso1z的博客-CSDN博客 git commit 命令详解_gitcommit_辰风沐阳的博客-CSDN博客 本博客只作为自己的学习记录,无商业用途&…

计算机存储设备

缓存为啥比内存快 内存使用 DRAM 来存储数据的、也就是动态随机存储器。内部使用 MOS 和一个电容来存储。 需要不停地给它刷新、保持它的状态、要是不刷新、数据就丢掉了、所以叫动态 、DRAM 缓存使用 SRAM 来存储数据、使用多个晶体管(比如6个)就是为了存储1比特 内存编码…

【python】python全国数据人均消费数据分析(代码+报告+数据)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化 👉荣__誉👈:阿里云博客专家博主、51CTO技术博主 &#x…

bio、nio、aio、io多路复用

BIO-同步阻塞IO NIO-同步非阻塞IO 不断的重复发起IO系统调用,这种不断的轮询,将会不断地询问内核,这将占用大量的 CPU 时间,系统资源利用率较低 IO多路复用模型-异步阻塞IO IO多路复用模型,就是通过一种新的系统调用&a…

前端开发者都应知道的 网站

1、ransform.tools 地址:transform.tools/ transform.tools 是一个网站,它可以让你转换几乎所有的东西,比如将HTML转换为JSX,JavaScript转换为JSON,CSS转换为JS对象等等。当我需要转换任何东西时,它真的帮…

Java反射机制概述

Java反射的概述 Reflection(反射)是被视为动态语言的关键,反射机制允许程序在执行期借助于Reflection API取得任何类的内部信息,并能直接操作任意对象的内部属性及方法。 加载完类之后,在堆内存的方法区中就产生了一…