ChatGPT的强化学习部分介绍——PPO算法实战LunarLander-v2

news2025/1/12 17:51:51

PPO算法

近线策略优化算法(Proximal Policy Optimization Algorithms) 即属于AC框架下的算法,在采样策略梯度算法训练方法的同时,重复利用历史采样的数据进行网络参数更新,提升了策略梯度方法的学习效率。 PPO重要的突破就是在于对新旧新旧策略器参数进行了约束,希望新的策略网络和旧策略网络的越接近越好。 近线策略优化的意思就是:新策略网络要利用到旧策略网络采样的数据进行学习,不希望这两个策略相差特别大,否则就会学偏。PPO依然是openai在2017年提出来的,论文地址

PPO-clip的损失函数,其中损失包含三个部分:

  • clip部分:加权采样后的优势值越好(可理解层价值网络的评分),同时采样通过clip防止新旧策略网络相差过大
  • VF部分: 价值网络预测的价值和环境真是的回报值越接近越好
  • S 部分:策略网络输出策略的熵值,越大越好,这个explore的思想,希望策略网络输出的动作分布概率不要太集中,提高了每个动作都有机会在环境中发生的可能。

实战部分

导入必要的包

%matplotlib inline
import matplotlib.pyplot as plt

from IPython import display

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm.notebook import tqdm
准备环境

准备好openai开发的LunarLander-v2的游戏环境。

seed = 543
def fix(env, seed):
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
import gym
import random
env = gym.make('LunarLander-v2' ,render_mode='rgb_array')
fix(env, seed) # fix the environment Do not revise this !!!

下面是采用代码去输出 环境的 观测值,一个8维向量,动作是一个标量,4选一。

  • 该环境共有 8 个观测值,分别是: 水平坐标 x; 垂直坐标 y; 水平速度; 垂直速度; 角度; 角速度; 腿1触地; 腿2触地;
  • 可以采取四种离散的行动,分别是: 0 代表不采取任何行动 1.代表主引擎向左喷射 2 .代表主引擎向下喷射 3 .代表主引擎向右喷射
  • 环境中的 reward 大致是这样计算: 小艇坠毁得到 -100 分; 小艇在黄旗帜之间成功着地则得 100~140 分; 喷射主引擎(向下喷火)每次 -0.3 分; 小艇最终完全静止则再得 100 分

在这里插入图片描述

随机动作玩5把

env.reset()

img = plt.imshow(env.render())

done = False
rewords = []
for i in range(5):
    env.reset()[0]
    img = plt.imshow(env.render())
    total_reward = 0
    done = False
    while not done:
        action = env.action_space.sample()
        observation, reward, done, _ , _= env.step(action)
        total_reward += reward
        img.set_data(env.render())
        display.display(plt.gcf())
        display.clear_output(wait=True)
    rewords.append(total_reward)

在这里插入图片描述

只有一把分数是正的,只有一次平安落地,小艇最终完全静止
在这里插入图片描述

搭建PPO agent

class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCriticDiscrete(nn.Module):
    def __init__(self, state_dim, action_dim, n_latent_var):
        super(ActorCriticDiscrete, self).__init__()

        # actor
        self.action_layer = nn.Sequential(
                nn.Linear(state_dim, 128),
                nn.ReLU(),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, action_dim),
                nn.Softmax(dim=-1)
                )

        # critic
        self.value_layer = nn.Sequential(
               nn.Linear(state_dim, 128),
               nn.ReLU(),
               nn.Linear(128, 64),
               nn.ReLU(),
               nn.Linear(64, 1)
               )

    def act(self, state, memory):
        state = torch.from_numpy(state).float()
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)
        action = dist.sample()

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))

        return action.item()

    def evaluate(self, state, action):
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()

        state_value = self.value_layer(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy

class PPOAgent:
    def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.timestep = 0
        self.memory = Memory()

        self.policy = ActorCriticDiscrete(state_dim, action_dim, n_latent_var)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
        self.policy_old = ActorCriticDiscrete(state_dim, action_dim, n_latent_var)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def update(self):   
        # Monte Carlo estimate of state rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.memory.rewards), reversed(self.memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards:
        rewards = torch.tensor(rewards, dtype=torch.float32)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.stack(self.memory.states).detach()
        old_actions = torch.stack(self.memory.actions).detach()
        old_logprobs = torch.stack(self.memory.logprobs).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values : 新策略 重用 旧样本进行训练 
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta__old): 
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:计算优势值
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages ###  重要性采样的思想,确保新的策略函数和旧策略函数的分布差异不大
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages ### 采样clip的方式过滤掉一些新旧策略相差较大的样本
            loss = -torch.min(surr1, surr2)  + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

    def step(self, reward, done):
        self.timestep += 1 
        # Saving reward and is_terminal:
        self.memory.rewards.append(reward)
        self.memory.is_terminals.append(done)

        # update if its time
        if self.timestep % update_timestep == 0:
            self.update()
            self.memory.clear_memory()
            self.timstamp = 0

    def act(self, state):
        return self.policy_old.act(state, self.memory)

训练PPO agent


state_dim = 8 ### 游戏的状态是个8维向量
action_dim = 4 ### 游戏的输出有4个取值
n_latent_var = 256           # 神经元个数
update_timestep = 1200      # 每多少补跟新策略
lr = 0.002                  # learning rate
betas = (0.9, 0.999)
gamma = 0.99                # discount factor
K_epochs = 4                # update policy for K epochs
eps_clip = 0.2              # clip parameter for PPO  论文中表明0.2效果不错
random_seed = 1 

agent = PPOAgent(state_dim ,action_dim,n_latent_var,lr,betas,gamma,K_epochs,eps_clip)
# agent.network.train()  # Switch network into training mode 
EPISODE_PER_BATCH = 5  # update the  agent every 5 episode
NUM_BATCH = 200     # totally update the agent for 400 time


avg_total_rewards, avg_final_rewards = [], []

# prg_bar = tqdm(range(NUM_BATCH))
for i in range(NUM_BATCH):

    log_probs, rewards = [], []
    total_rewards, final_rewards = [], []
    values    = []
    masks     = []
    entropy = 0
    # collect trajectory
    for episode in range(EPISODE_PER_BATCH):
        ### 重开一把游戏
        state = env.reset()[0]
        total_reward, total_step = 0, 0
        seq_rewards = []
        for i in range(1000):  ###游戏未结束

            action = agent.act(state) ### 按照策略网络输出的概率随机采样一个动作
            next_state, reward, done, _, _ = env.step(action) ### 与环境state进行交互,输出reward 和 环境next_state
            state = next_state
            total_reward += reward
            total_step += 1     
            rewards.append(reward) ### 记录每一个动作的reward
            agent.step(reward, done)   
            if done:  ###游戏结束
                final_rewards.append(reward)
                total_rewards.append(total_reward)
                break

    print(f"rewards looks like ", np.shape(rewards))  
    if len(final_rewards)> 0 and len(total_rewards) > 0:
        avg_total_reward = sum(total_rewards) / len(total_rewards)
        avg_final_reward = sum(final_rewards) / len(final_rewards)
        avg_total_rewards.append(avg_total_reward)
        avg_final_rewards.append(avg_final_reward)

PPO agent在玩5把游戏


fix(env, seed)
agent.policy.eval() # set the network into evaluation mode
test_total_reward = []
for i in range(5):
    actions = []
    state = env.reset()[0]
    img = plt.imshow(env.render())
    total_reward = 0
    done = False
    while not done :
        action= agent.act(state)
        actions.append(action)
        state, reward, done, _, _ = env.step(action)
        total_reward += reward
        img.set_data(env.render())
        display.display(plt.gcf())
        display.clear_output(wait=True)
    test_total_reward.append(total_reward)

从下图可以看到,玩得比随机的时候要好很多,有三把都平安着地,小艇最终完全静止。
在这里插入图片描述

其中有3把得分超过200,证明ppo在300轮已经学到了如何玩这个游戏,比之前的随机agent要强了不少。
在这里插入图片描述
github 源码

参考

ChatGPT的强化学习部分介绍
基于人类反馈的强化学习(RLHF) 理论
Anaconda安装GYM &关于Box2D安装的相关问题
conda-forge / packages / box2d-py
jupyter Notebook 内核似乎挂掉了,它很快将自动重启

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/505401.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leecode101 ——对称二叉树

对称二叉树:Leecode 101 leecode 101 对称二叉树 根据题目描述,首先想清楚,对称二叉树要比较的是哪两个节点。对于二叉树是否对称,要比较的是根节点的左子树与根节点的右子树是不是相互翻转的,其实也就是比较两个树,…

A+CLUB活动预告 | 2023年5月

五月立夏已至,万物并秀。 ACLUB五月活动预告新鲜出炉,本月我们将会有2场峰会活动(深圳)、1场沙龙活动(深圳)和1场管理人支持计划(上海),期待您的参与! 注&a…

06- 算法解读 Fast R-CNN (目标检测)

要点: Fast R-CNN 属于 Two-stage detector 回归损失参考:https://www.cnblogs.com/wangguchangqing/p/12021638.html 二 Fast R-CNN算法 Fast R-CNN 是作者 Ross Girshick 继 R-CNN 后的又一力作。同样使用 VGG16 作为网络的 backbone , …

您应该使用WhatsApp电子商务的3个理由

随着WhatsApp进军电子商务领域并扩大其为企业提供的服务种类,它正迅速成为客户参与的完美工具和对话式商务的绝佳平台。使用WhatsApp作为主要渠道向客户销售商品和服务的前景产生了WhatsApp电子商务一词。 WhatsApp已被加冕为世界领先的消息传递平台。对于最终用户来…

二十三、SQL 数据分析实战(10个简单的SQL题目)

文章目录 题目1: 比赛名单整理题目2: 热门游戏排行题目3: 社区生鲜App覆盖分析题目4: 社区团购行为分析题目5: 统计字符出现次数题目6: 找出各类别商品销量最高的商品题目7: 找出每个部门薪资第二高的员工题目8: 游戏玩家登录情况分析题目9: 用户首单消费金额题目10: 参与优惠活…

这可能是你读过最透彻的TCC方案讲解|分布式事务系列(三)

本文从两个场景说起,详细描述了TCC的详细过程,以及对比2PC有什么区别,适用什么样的场景。 点击上方“后端开发技术”,选择“设为星标” ,优质资源及时送达 在面试前复习 TCC 的时候你是不是这样做的:百度TC…

百度再掀智能手机风云:推出小度AI智能手机

我是卢松松,点点上面的头像,欢迎关注我哦! 太突然了!百度进军手机市场,据报百度也要进军手机市场了。5月底发布首款AI智能手机。目前这款新机处于发布前的最后准备阶段。 这款智能手机将采用百度的AI技术,预计会具备…

react如何渲染包含html标签元素的字符串

如何渲染包含html标签元素的字符串 最近有个搜索替换的需求,用户可以输入关键字信息来匹配出对应的数据,然后对其关键字进行标记显示,如下图所示: 实现上面的需求的思路就是前端去判断检索内容,将内容中对应的关键字…

又失眠了!

下班晚,洗漱之后,就这个点了,睡不着,爬上来和大家随意 BB 几句。 一个人成长的过程,也是自我认同感不断增强的过程,在这个过程中,一个稳定的精神内核不断夯实,你不会为谁的贬低而诚惶…

Linux性能监控与调优工具

Linux性能监控与调优工具 文章目录 Linux性能监控与调优工具1.使用top、 vmstat、 iostat、 sysctl等常用工具2.使用高级分析手段, 如OProfile、 gprof4.使用LTP进行压力测试5.使用Benchmark评估系统 除了保证程序的正确性以外, 在项目开发中往往还关心性…

体验讯飞星火认知大模型,据说中文能力超越ChatGPT

📋 个人简介 💖 作者简介:大家好,我是阿牛,全栈领域优质创作者。😜📝 个人主页:馆主阿牛🔥🎉 支持我:点赞👍收藏⭐️留言&#x1f4d…

Nature | 生成式人工智能如何构建更好的抗体

疫情高峰期,研究人员竞相开发一些首批有效的COVID-19治疗方法:从已经康复的人的血液中分离出来的抗体分子。 现在,科学家已经证明,生成式人工智能(AI)可以通过一些繁琐的过程提供捷径,提出增强抗…

代码随想录算法训练营day29 | 491.递增子序列,46.全排列,47.全排列 II

代码随想录算法训练营day29 | 491.递增子序列,46.全排列,47.全排列 II 491.递增子序列解法一:回溯(map进行数层去重)解法二:回溯(仅针对本题,不具有普适性) 46.全排列解法…

有一说一,这是我看到的全网最新最全的SpringBoot后端接口规范了

一、前言 一个后端接口大致分为四个部分组成:接口地址(url)、接口请求方式(get、post等)、请求数据(request)、响应数据(response)。虽然说后端接口的编写并没有统一规范…

一个.Net功能强大、易于使用、跨平台开源可视化图表

可视化图表运用是非常广泛的,比如BI系统、报表统计等。但是针对桌面应用的应用,很多报表都是收费的,今天给大家推荐一个免费.Net可视化开源的项目! 项目简介 基于C#开发的功能强大、易于使用、跨平台高质量的可视化图表库&#…

Shiro 入门概述

目录 是什么 为什么要用 Shiro Shiro 与 SpringSecurity 的对比 基本功能 原理 是什么 Apache Shiro 是一个功能强大且易于使用的 Java 安全(权限)框架。Shiro 可以完 成:认证、授权、加密、会话管理、与 Web 集成、缓存 等。借助 Shiro 您可以快速轻松 地保护…

Linux 中的文件锁定命令:flock、fcntl、lockfile、flockfile

在 Linux 系统中,文件锁定是一种对文件进行保护的方法,可以防止多个进程同时访问同一个文件,从而导致数据损坏或者冲突。文件锁定命令是一组用于在 Linux 系统中实现文件锁定操作的命令,它们可以用于对文件进行加锁或解锁&#xf…

直击中国国际金融展:实在智能携多项科技成果亮相,展现数字金融力量

4月25日-27日,中国国际金融展于北京首钢会展中心成功举办。作为我国规格最高、历史最久的金融科技展,本次展会以“荟萃金融科技成果,展现数字金融力量,谱写金融服务中国式现代化新篇章”为主题,吸引了众多国内金融机构…

企业邮箱选购,需关注哪些重要因素?

企业邮箱选择考虑哪些问题?应该从企业邮箱安全、企业邮箱的稳定性、企业邮箱专业、方便迁移到新的企业邮箱、企业邮箱邮件的到达率、功能强大的企业邮箱、企业邮箱手机客户端设置等方面考虑。 1.企业邮箱安全 企业邮箱应考虑病毒防治能力。Zoho Mail企业邮箱从物理安…

华硕笔记本系统更新后开机自动蓝屏怎么U盘重装系统?

华硕笔记本系统更新后开机自动蓝屏怎么U盘重装系统?有用户将自己的华硕笔记本进行系统升级之后,遇到了开机自动蓝屏的情况。遇到这个问题我们怎么去进行解决呢?接下来一起来看看怎么通过U盘重装系统的方法解决此问题吧。 准备工作&#xff1a…