普通策略梯度算法原理及PyTorch实现【VPG】

news2024/11/16 17:37:09

有没有想过强化学习 (RL) 是如何工作的?

在本文中,我们将从头开始构建最简单的强化学习形式之一 —普通策略梯度(VPG)算法。 然后,我们将训练它完成著名的 CartPole 挑战 — 学习从左向右移动购物车以平衡杆子。 在此过程中,我们还将完成对 OpenAI 的 Spinning Up 学习资源的第一个挑战。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎

本文的代码可以在 这里 找到。

1、我们的方法

我们将通过创建一个简单的深度学习模型来解决这个问题,该模型接受观察并输出随机策略(即采取每个可能行动的概率)。

然后,我们需要做的就是通过在环境中采取行动并使用此策略来收集经验。

当我们有足够的批量经验(几个episode经验的集合)后,我们需要转向梯度下降来改进模型。 在较高层面上,我们希望增加策略的预期回报,这意味着调整权重和偏差以增加高预期回报行动的概率。 就 VPG 而言,这意味着使用策略梯度定理,该定理给出了该预期回报的梯度方程(如下所示)。

这就是全部内容了—所以让我们开始编码吧!

2、创建模型

我们将首先创建一个带有一个隐藏层的非常简单的模型。 第一个线性层从 CartPole 的观察空间获取输入特征,最后一层返回可能结果的值。

def create_model(number_observation_features: int, number_actions: int) -> nn.Module:
    """Create the MLP model

    Args:
        number_observation_features (int): Number of features in the (flat)
        observation tensor
        number_actions (int): Number of actions

    Returns:
        nn.Module: Simple MLP model
    """
    hidden_layer_features = 32

    return nn.Sequential(
        nn.Linear(in_features=number_observation_features,
                  out_features=hidden_layer_features),
        nn.ReLU(),
        nn.Linear(in_features=hidden_layer_features,
                  out_features=number_actions),
    )

3、获取策略

我们还需要为每个时间步获取一个模型策略(以便我们知道如何采取行动)。 为此,我们将创建一个 get_policy 函数,该函数使用模型输出策略下每个操作的概率。 然后,我们可以返回一个分类(多项式)分布,该分布可用于选择根据这些概率随机分布的特定动作。

def get_policy(model: nn.Module, observation: np.ndarray) -> Categorical:
    """Get the policy from the model, for a specific observation

    Args:
        model (nn.Module): MLP model
        observation (np.ndarray): Environment observation

    Returns:
        Categorical: Multinomial distribution parameterized by model logits
    """
    observation_tensor = torch.as_tensor(observation, dtype=torch.float32)
    logits = model(observation_tensor)

    # Categorical will also normalize the logits for us
    return Categorical(logits=logits)

4、从策略中采样动作

从这个分类分布中,对于每个时间步长,我们可以对其进行采样以返回一个动作。 我们还将获得该动作的对数概率,这在稍后计算梯度时会很有用。

def get_action(policy: Categorical) -> tuple[int, float]:
    """Sample an action from the policy

    Args:
        policy (Categorical): Policy

    Returns:
        tuple[int, float]: Tuple of the action and it's log probability
    """
    action = policy.sample()  # Unit tensor

    # Converts to an int, as this is what Gym environments require
    action_int = action.item()

    # Calculate the log probability of the action, which is required for
    # calculating the loss later
    log_probability_action = policy.log_prob(action)

    return action_int, log_probability_action

5、计算损失

梯度的完整推导如这里所示。 宽松地说,它是每个状态-动作对的对数概率之和乘以该对所属的整个轨迹的回报的梯度。 额外的外层和汇总若干个情节(即一批),因此我们有重要的数据。

要使用 PyTorch 计算此值,我们可以做的是计算下面的伪损失,然后使用 .backward() 获取上面的梯度(注意我们刚刚删除了梯度项):

这通常被称为损失,但它并不是真正的损失,因为它不依赖于模型的性能。 它只是对于获取策略梯度有用。

def calculate_loss(epoch_log_probability_actions: torch.Tensor, epoch_action_rewards: torch.Tensor) -> float:
    """Calculate the 'loss' required to get the policy gradient

    Formula for gradient at
    https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#deriving-the-simplest-policy-gradient

    Note that this isn't really loss - it's just the sum of the log probability
    of each action times the episode return. We calculate this so we can
    back-propagate to get the policy gradient.

    Args:
        epoch_log_probability_actions (torch.Tensor): Log probabilities of the
            actions taken
        epoch_action_rewards (torch.Tensor): Rewards for each of these actions

    Returns:
        float: Pseudo-loss
    """
    return -(epoch_log_probability_actions * epoch_action_rewards).mean()

6、单个epoch训练

将以上所有内容放在一起,我们现在准备好训练一个epoch了。 为此,我们只需循环播放情节(episode)即可创建批次。 在每个情节中,创建一系列可用于训练模型的动作和奖励(即经验)。

def train_one_epoch(env: gym.Env, model: nn.Module, optimizer: Optimizer, max_timesteps=5000, episode_timesteps=200) -> float:
    """Train the model for one epoch

    Args:
        env (gym.Env): Gym environment
        model (nn.Module): Model
        optimizer (Optimizer): Optimizer
        max_timesteps (int, optional): Max timesteps per epoch. Note if an
            episode is part-way through, it will still complete before finishing
            the epoch. Defaults to 5000.
        episode_timesteps (int, optional): Timesteps per episode. Defaults to 200.

    Returns:
        float: Average return from the epoch
    """
    epoch_total_timesteps = 0

    # Returns from each episode (to keep track of progress)
    epoch_returns: list[int] = []

    # Action log probabilities and rewards per step (for calculating loss)
    epoch_log_probability_actions = []
    epoch_action_rewards = []

    # Loop through episodes
    while True:

        # Stop if we've done over the total number of timesteps
        if epoch_total_timesteps > max_timesteps:
            break

        # Running total of this episode's rewards
        episode_reward: int = 0

        # Reset the environment and get a fresh observation
        observation = env.reset()

        # Loop through timesteps until the episode is done (or the max is hit)
        for timestep in range(episode_timesteps):
            epoch_total_timesteps += 1

            # Get the policy and act
            policy = get_policy(model, observation)
            action, log_probability_action = get_action(policy)
            observation, reward, done, _ = env.step(action)

            # Increment the episode rewards
            episode_reward += reward

            # Add epoch action log probabilities
            epoch_log_probability_actions.append(log_probability_action)

            # Finish the action loop if this episode is done
            if done == True:
                # Add one reward per timestep
                for _ in range(timestep + 1):
                    epoch_action_rewards.append(episode_reward)

                break

        # Increment the epoch returns
        epoch_returns.append(episode_reward)

    # Calculate the policy gradient, and use it to step the weights & biases
    epoch_loss = calculate_loss(torch.stack(
        epoch_log_probability_actions),
        torch.as_tensor(
        epoch_action_rewards, dtype=torch.float32)
    )

    epoch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return np.mean(epoch_returns)

7、运行算法

现在可以运行算法了。

def train(epochs=40) -> None:
    """Train a Vanilla Policy Gradient model on CartPole

    Args:
        epochs (int, optional): The number of epochs to run for. Defaults to 50.
    """

    # Create the Gym Environment
    env = gym.make('CartPole-v0')

    # Use random seeds (to make experiments deterministic)
    torch.manual_seed(0)
    env.seed(0)

    # Create the MLP model
    number_observation_features = env.observation_space.shape[0]
    number_actions = env.action_space.n
    model = create_model(number_observation_features, number_actions)

    # Create the optimizer
    optimizer = Adam(model.parameters(), 1e-2)

    # Loop for each epoch
    for epoch in range(epochs):
        average_return = train_one_epoch(env, model, optimizer)
        print('epoch: %3d \t return: %.3f' % (epoch, average_return))


if __name__ == '__main__':
    train()

大约 40 个 epoch 后,可以看到模型已经很好地学习了环境(得分 180+/ 200):

epoch:  26       return: 118.070
epoch:  27       return: 114.659
epoch:  28       return: 135.405
epoch:  29       return: 144.000
epoch:  30       return: 143.972
epoch:  31       return: 152.091
epoch:  32       return: 166.065
epoch:  33       return: 162.613
epoch:  34       return: 166.806
epoch:  35       return: 172.933
epoch:  36       return: 173.241
epoch:  37       return: 181.071
epoch:  38       return: 186.222
epoch:  39       return: 176.793

原文链接:普通策略梯度实现 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1284557.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哈希与哈希表

哈希表的概念 哈希表又名散列表,官话一点讲就是: 散列表(Hash table,也叫哈希表),是根据关键码值(Key value)而直接进行访问的数据结构。也就是说,它通过把关键码值映射到表中一个位置来访问记…

MySQL的多表查询

多表关系 一对多(多对一)-> 多对多-> 一对一-> 概述 概述 多表查询分类 内连接 代码演示--> -- 内连接演示 -- 1.查询每一个员工的姓名,及关联的部门的名称(隐式内连接实现) select emp.name, dept.name from emp,dept where emp.dept_id dept.id; …

10、外观模式(Facade Pattern,不常用)

外观模式(Facade Pattern)也叫作门面模式,通过一个门面(Facade)向客户端提供一个访问系统的统一接口,客户端无须关心和知晓系统内部各子模块(系统)之间的复杂关系,其主要…

sql面试题之“互相关注的人”(方法三)

题目:某社交平台有关注这个功能,关注的同时也会被关注。现有需求需要找出平台上哪些用户之间互相关注。 文章目录 题目如下:一、数据准备二、建表并导入数据1.建表2.导入数据3.数据分析和实现思路小结: 题目如下: 某社…

[RK-Linux] 移植Linux-5.10到RK3399(三)| 检查eMMC与SD卡配置

这个专题主要记录把 RK Linux-5.10 移植到 ROC-RK3399-PC Pro 的过程。 文章目录 一、eMMC二、SD 卡三、两个接口的区别一、eMMC RK3399 的 eMMC 接口如图: datasheet 介绍: 实际上,连接 eMMC 存储器用的是 SDHCI 接口。SDHCI(Secure Digital Host Controller Interface)…

【数据结构】最短路径——Floyd算法

一.问题描述 给定带权有向图G(V,E),对任意顶点 V (ij),求顶点到顶点的最短路径。 转化为: 多源点最短路径求解问题 解决方案一: 每次以一个顶点为源点调用Dijksra算法。时间复杂…

香港虚拟信用卡如何办理,支持香港apple id

什么是虚拟信用卡? 虚拟信用卡,英文称之为Virtual Credit Card Numbers,就是指没有实体卡片,是基于银行卡上面的BIN码所生成的虚拟账号。通常用于进行网络交易,使用起来很方便,也很安全。 它与实体信用卡…

算法-01-递归

1-理解递归 斐波那契数列(Fibonacci sequence),又称黄金分割数列 ,以兔子繁殖为例子而引入,故又称“兔子数列”,其数值为:1、1、2、3、5、8、13、21、34……特点是 从第三个数开始,第…

HOST文件被挟持,无法上网,如何解决。

问题: 晚上开机,突然发现无法联网,提示网络异常 解决: 首先网络诊断,host文件被劫持,修复后,仍然不行。 然后测试手机热点,发现仍然无法联网 尝试用火绒修复,无果。 所有…

Linux Camera Driver(2):CIS设备注册(DTS)

一:MIPI接口 1、硬件接口 MIPI接口以rv1109和gc2053的硬件为例进行说明: 2、ISP驱动 注意配置事项: endpoint配置,必须指定data-lanes,否则无法识别为mipi类型 链接方式:sensor->csi_dphy->isp->ispp (1)sensor节点配置 根据原理图可知:mipicsi_clk0即引…

Linux系统安装Python3环境

1、默认情况下,Linux会自带安装Python,可以运行python --version命令查看,如图: 我们看到Linux中已经自带了Python2.7.5。再次运行python命令后就可以使用python命令窗口了(CtrlD退出python命令窗口)。 2…

STM32F407-14.3.11-01互补输出和死区插入

互补输出和死区插入 高级控制定时器(TIM1 和 TIM8)可以输出两路互补信号,并管理输出的关断与接通瞬间。 这段时间通常称为死区,用户必须根据与输出相连接的器件及其特性(电平转换器的固有延迟、开关器件产生的延迟...&…

MySQL之时间戳(DateTime和TimeStamp)

MySQL之时间戳(DateTime和TimeStamp) 文章目录: MySQL之时间戳(DateTime和TimeStamp)一、DateTime类型二、TimeStamp类型三、DateTime和TimeStamp的区别 当插入数据时,需要自动记录一个时间时候&#xff0c…

llama.cpp部署(windows)

一、下载源码和模型 下载源码和模型 # 下载源码 git clone https://github.com/ggerganov/llama.cpp.git# 下载llama-7b模型 git clone https://www.modelscope.cn/skyline2006/llama-7b.git查看cmake版本: D:\pyworkspace\llama_cpp\llama.cpp\build>cmake --…

git 本地改动无法删除

1. 问题 记录下git遇到奇怪的问题,本地有些改动不知道什么原因无法删除 git stash, git reset --hard HEAD 等都无法生效,最终通过强制拉取线上解决 如下图: 2. 解决 git fetch --all git reset --hard origin/master执行这两…

LeedCode刷题---双指针问题

顾得泉:个人主页 个人专栏:《Linux操作系统》 《C/C》 《LeedCode刷题》 键盘敲烂,年薪百万! 双指针简介 常见的双指针有两种形式,一种是对撞指针,一种是左右指针。 对撞指针:一般用于顺序结构中&…

马斯克极简5步工作法 —— 筑梦之路

马斯克的五步流程法则: 第一步:确定需求 第二步:极力删除零件或过程 第三步:简化和优化 第四步:加快周期时间 第五步:自动化特别注意:完成前三步之前,千万不要考虑加速和自动化&…

JVM类加载全过程

Java虚拟机类加载的全过程,即加载,验证,准备,解析,初始化 一、加载 加载 是 类加载过程中的一个阶段, 有以下三部分组成 1)通过一个类的全限定名来获取定义此类的二进制流 2)将这…

跨域问题的解决办法

1、产生跨域问题样式 前台写完,直接访问后台接口,会产生跨域的问题,需要配置文件去解决这个问题 从源http://localhost:8080访问http://localhost:8088/books的XMLHttpRequest已被CORS策略阻止:请求的资源上没有Access- control - allow - o…

C++学习之路(十八)C++ 用Qt5实现一个工具箱(点击按钮以新窗口打开功能面板)- 示例代码拆分讲解

上篇文章,我们用 Qt5 实现了在小工具箱中添加了《增加托盘图标并且增加显示和退出菜单》功能。今天我们把按钮打开功能的方式改一改,让点击按钮以新窗口打开功能面板。下面我们就来看看如何来规划开发这样的小功能并且添加到我们的工具箱中吧。 老规矩&…