PPO算法逐行代码详解

news2024/10/7 10:25:55

前言:本文会从理论部分、代码部分、实践部分三方面进行PPO算法的介绍。其中理论部分会介绍PPO算法的推导流程,代码部分会给出PPO算法的各部分的代码以及简略介绍,实践部分则会通过debug代码调试的方式从头到尾的带大家看清楚应用PPO算法在cartpole环境上进行训练的整体流程,进而帮助大家将理论与代码实践相结合,更好的理解PPO算法。

文章目录

      • 1. 理论部分
      • 2. 代码部分
        • 2.1 神经网络的定义
        • 2.2 PPO算法的定义
        • 2.3 on policy算法的训练代码
      • 3. 实践部分

1. 理论部分

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

2. 代码部分

这里使用的是《动手学强化学习》中提供的代码,我将这本书中的代码整理到了github上,并且方便使用pycharm进行运行和调试。代码地址:https://github.com/zxs-000202/dsx-rl

代码核心部分整体上可以分为三部分,分别是关于神经网络的类的定义(PolicyNet,ValueNet),关于PPO算法的类的定义(PPO),on policy算法训练流程的定义(train_on_policy_agent)。

2.1 神经网络的定义

策略网络actor采用两层全连接层,第一层采用relu激活函数,第二层采用softmax函数。

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

价值网络critic采用两层全连接层,第一层和第二层均采用relu激活函数,第二层最后输出的维度是1,表示t时刻某个状态s的价值V

class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
2.2 PPO算法的定义

这部分是PPO算法最核心的部分。本部分包括神经网络的初始化、相关参数的定义,如何根据状态s选择动作,以及网络参数是如何更新的。整体的代码如下:

class PPO:
    ''' PPO算法,采用截断方式 '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # 一条序列的数据用来训练轮数
        self.eps = eps  # PPO中截断范围的参数
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1,
                                                            actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps,
                                1 + self.eps) * advantage  # 截断
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()
2.3 on policy算法的训练代码

我们知道on policy算法与off policy算法的区别就在与进行采样的网络和用来参数更新的训练网络是否是一个网络。PPO是一种on policy的算法,on policy算法要求训练的网络参数一更新就需要重新进行采样然后训练。但PPO有点特殊,它是利用重要性采样方法来实现数据的多次利用,提高了数据的利用效率。

从下面的代码中可以看到每当一个episode的数据采样完毕之后都会执行agent.update(transition_dict)。在update中会将当前采样得到的数据(也就是当前episode的每个t时刻的 [ 当前状态,动作,奖励,是否结束,下一个状态 ] 这个五元组)通过重要性采样的方法进行多次的神经网络参数的更新,这个具体过程我会在第三部分实践部分进行详细说明。

def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode + 1),
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list

3. 实践部分

该部分带领大家从代码运行流程上从头到尾过一遍。
本次PPO算法训练应用的gym环境是CartPole-v0,如下图。该gym环境的状态空间是四个连续值用来表示杆所处的状态,而动作空间是两个离散值用来表示给杆施加向左或者向右的力(也就是action为0或者1)。
在这里插入图片描述
代码中前面的部分是网络、算法、训练参数的定义,真正开始训练是在return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)。我们在这里打一个断点进行调试。
在这里插入图片描述
比较重要的训练参数是episode为500,epoch为10。表示一共采样500个episode的数据,每个episode数据采样完毕后再update阶段进行10次的梯度下降参数更新。
在这里插入图片描述
执行到这里我们跳到take_action里面看看agent是如何选取动作的。
在这里插入图片描述
可以看到首先将输入的state转换为tensor,然后因为actor网络的最后一层是softmax函数,所以通过actor网络输出两个执行两个动作可能性的大小,然后通过action_dist = torch.distributions.Categorical(probs) action = action_dist.sample()根据可能性大小进行采样最后得到这次选择动作1进行返回。

第一个episode中我们采样到的数据如下:
在这里插入图片描述
可以看到第一个episode执行了41个step之后done了。
接下来我们跳进update函数中看看具体如何用这个episode采样的数据进行神经网络参数的更新。
在这里插入图片描述
首先从transition_dict中将采样的41个step的数据取出来存到tensor中方便之后进行运算。
在这里插入图片描述
接下来计算td_targettd_deltaadvantageold_log_probs
其中td_target表示t时刻的得到的奖励值r加上根据t+1时刻critic估计的状态价值V。用公式表示的话就是:

td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)

在这里插入图片描述
td_delta表示的是td_targett时刻采用critic估计的状态价值之间的差值。用公式表示的话就是:

td_delta = td_target - self.critic(states)

在这里插入图片描述
advantage则表示的是某状态下采取某动作的优势值,也就是下面PPO算法公式中的A

advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)

在这里插入图片描述
old_log_probs则表示的是旧策略下某个状态下采取某个动作的概率值的对数值。对应下面PPO算法公式中的clip函数中分母。

old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach()

在这里插入图片描述

到这里我们计算出这些值的shape如下图所示。接下来我们就要进入一个循环中,循环epoch=10次进行参数的神经网络参数的更新。
在这里插入图片描述

具体而言首先在每个循环中计算一次当前的actor参数下对应的stateaction的概率值的对数值,也就是下面分式中的分子。
这里用到的高中对数数学公式有:
在这里插入图片描述
在这里插入图片描述
通过上面两个公式可以计算得到
在这里插入图片描述

然后计算这个分式的值ratio
之后结合下面的公式和代码可以看出surr1surr2分别是下图公式中对应的值。
在这里插入图片描述

log_probs = torch.log(self.actor(states).gather(1, actions))
ratio = torch.exp(log_probs - old_log_probs)
surr1 = ratio * advantage
surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断

然后计算actorloss值和criticloss

actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

计算得到的数据的shape如下图:
在这里插入图片描述
最后对actorcritic进行梯度下降进行参数更新。

self.actor_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
ctor_loss.backward()
critic_loss.backward()
self.actor_optimizer.step()
self.critic_optimizer.step()

至此第一次循环执行完毕,接下来还需要再进行九次这种循环,把当前episode的41个step的数据再进行九次actorcritic参数的更新。

然后update相当于执行完毕了。在这里插入图片描述
接下来一个episode的训练过程相当于结束了,我们一共有500个episode需要训练,再循环499次就完成了整个训练的流程。

还有一部分比较重要的就是通过GAE计算优势函数Advantage的代码。

def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

这部分就留给大家结合第一部分理论部分最后一张图来进行理解了。

参考资料

  1. 磨菇书easy-rl https://datawhalechina.github.io/easy-rl/#/chapter5/chapter5
  2. 《动手学强化学习》 https://hrl.boyuai.com/chapter/2/ppo%E7%AE%97%E6%B3%95

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1089583.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

三、静态路由实验

拓扑图: 两个路由器分了三个网段出来,首先对两台PC机进行配置 进入R1路由器对两边链路进行ip配置 对AR2进行相同的配置,然后我们查看R1的路由表,里面有一些直连的信息。 三个网段的设备现在可以互通,我们要实现跨网段…

[数据结构]——单链表超详细总结

带你走进链表的世界 目录:一、线性表的概念二、顺序表三、链表3.1 链表的概念3.2 链表的分类3.3 无头单向非循环链表的实现3.4 带头双向循环链表的实现 四、顺序表和链表的区别和联系 目录: 链表是个优秀的结构,没有容量概念,可以…

Python接口测试 requests.post方法中data与json参数区别

引言 requests.post主要参数是data与json,这两者使用是有区别的,下面我详情的介绍一下使用方法。 Requests参数 1. 先可以看一下requests的源码: 1 2 3 4 5 6 7 8 9 10 11 12 13 def post(url, dataNone, jsonNone, **kwargs): r&quo…

STM32CUBEMX_DMA串口空闲中断接收+接收发送缓冲区

STM32CUBEMX_DMA串口空闲中断接收接收发送缓冲区 前言: 我了解的串口接收指令的方式有:在这里插入图片描述 1、接收数据中断特定帧尾 2、接收数据中断空闲中断 3、DMA接收空闲中断 我最推荐第三种,尤其是数据量比较大且频繁的时候 串口配置 …

Vmware Linux虚拟机安装教程(Centos版)

文章目录 1.Vmware-workstation安装软件2.双击下载的安装包开始安装3.打开VMware-workstation,输入密钥4.Centos7.6安装软件5.新建虚拟机6.为虚拟机配置映像文件7.开启虚拟机,配置环境7.1 Install Centos 77.2 选择简体中文字体7.3 软件选择7.4 安装位置…

LeetCode【20】 有效的括号

题型:栈 题目: 代码: public boolean isValidReview(String s) {//(1)从s的i0位置一次开始压栈,遇到左括号压栈,不管是大中小三种,左括号,压//(2&#xf…

基于 EventBridge 轻松搭建消息集成应用

作者:昶风 前言 本篇文章主要介绍基于阿里云 EventBridge 的消息集成能力,结合目前消息产品的需求热点,从能力范围到场景实战,对 EventBridge 的消息集成解决方案进行了概要的介绍。 从消息现状谈起 消息队列作为应用解耦&…

Redis 基础—Redis Desktop Manager(Redis可视化工具)安装及使用教程

Redis Desktop Manager 是一个可视化的 Redis 数据库管理工具,可以方便地查看和操作 Redis 数据库。使用 Redis Desktop Manager 可以大大提高 Redis 数据库的管理效率。 RDM的安装和配置 首先,您需要下载和安装Redis Desktop Manager。 安装完成后&am…

高校教务系统密码加密逻辑及JS逆向——皖南医学院

高校教务系统密码加密逻辑及JS逆向 本文将介绍高校教务系统的密码加密逻辑以及使用JavaScript进行逆向分析的过程。通过本文,你将了解到密码加密的基本概念、常用加密算法以及如何通过逆向分析来破解密码。 本文仅供交流学习,勿用于非法用途。 一、密码加…

提升企业服务行业管理效率的关键策略与方法」

To B 服务,也即企业服务近年来大家关注的重点,企业服务是指为企业客户提供运营管理相关的工具与系统解决方案等服务,通过专业化的服务可以实现降本增效,提高企业的经营效率。 近年来,人口红利的消失促使企业服务需求快…

DC电源模块低温是否影响转换效率

BOSHIDA DC电源模块低温是否影响转换效率 DC电源模块是一种常用的电源转换装置,其主要作用是将输入的电源信号变换成需要的输出电源信号。在实际应用中,DC电源模块的性能会受到多种因素的影响,其中低温也是一个重要的影响因素。本文将从转换…

kettle应用-从数据库抽取数据到excel

本文介绍使用kettle从postgresql数据库中抽取数据到excel中。 首先,启动kettle 如果kettle部署在windows系统,双击运行spoon.bat或者在命令行运行spoon.bat 如果kettle部署在linux系统,需要执行如下命令启动 chmod x spoon.sh nohup ./sp…

c++设计模式之单例设计模式

💂 个人主页:[pp不会算法v](https://blog.csdn.net/weixin_73548574?spm1011.2415.3001.5343) 🤟 版权: 本文由【pp不会算法^v^】原创、在CSDN首发、需要转载请联系博主 💬 如果文章对你有帮助、欢迎关注、点赞、收藏(一键三连)和订阅专栏哦…

以命令行形式执行Postman脚本(使用Newman)

目录 一、背景 二、Newman的安装 三、脚本准备 四、Newman的执行 五、生成报告 一、背景 ​ Postman的操作离不开客户端。但是在一些情况下可能无法使用客户端去进行脚本执行。比如在服务端进行接口测试。由此我们引入了Newman。Newman基于Node.js开发,它使您可…

【MySQL】表的查询与连接

文章目录 预备工作一、表的基本查询1、简单基本查询2、分组聚合统计3、基本查询练习 二、表的复合查询1、多表查询2、子查询2.1 **单行子查询**2.2 **多行子查询**2.3 **多列子查询**2.4 在from子句中使用子查询 3、合并查询 三、表的连接1、自连接2、内连接3、外连接 预备工作…

【C++】关键字 命名空间 输入输出 缺省函数

一,C关键字 C 总计 63 个关键字,C语言 32 个关键字 直接上图: asmdoifreturntrycontinueautodoubleinlineshorttypedefforbooldynamic_castintsignedtypeidpublicbreakelselongsizeoftypenamethrowcaseenummutablestaticunionwchar_tcatche…

Win11更新后瘦身C盘的两个小技巧

每当windows更新完后,就很容易出现一个现象,那便是C盘存储变红了。 这个时候,就会有方法指出:把C盘中的系统更新文件清理掉吧,这样C盘就又能瘦回去了! 然而,当你兴冲冲地按照网上的教程点击C…

OPTEE之KASAN地址消毒动态代码分析

安全之安全(security)博客目录导读 目录 一、KASAN简介 二、OPTEE_OS中KASAN配置选项 三、OPTEE_OS中KASAN配置选项打开 一、KASAN简介 内核地址消毒器(KASAN)是Linux内核的快速内存损坏检测器,KASAN检测slab、page_alloc、vmalloc、stack和全局内存中的越界、u…

Yarn基础入门

文章目录 一、Yarn资源调度器1、架构2、Yarn工作机制3、HDFS、YARN、MR关系4、作业提交之HDFS&MapReduce 二、Yarn调度器和调度算法1、先进先出调度器(FIFO)2、容量调度器(Capacity Scheduler)3、公平调度器(Fair …

信钰证券:首板第二天买入技巧?

股票上市第一天,也就是所谓的“首板”,一般会引起商场的高度注重。那么关于投资者而言,如安在接下来的第二天进行买入是个十分要害的决议计划。本文将从多个角度剖析首板第二天买入技巧,供读者参阅。 首先,多数人或许…