深度强化学习 Actor-Critic演员评论家 PPO

news2024/9/25 15:21:33

将策略(Policy Based)和价值(Value Based)相结合的方法:Actor-Critic算法,在强化学习领域最受欢迎的A3C算法,DDPG算法,PPO算法等都是AC框架。

 一、Actor-Critic算法简介

Actor-Critic从名字上看包括两部分,演员(Actor)和评价家(Critic)。其中Actor使用的是策略函数,负责生成动作(Action)并和环境交互。而Critic使用的是价值函数,负责评估Actor的表现,并指导Actor下一阶段的动作。

import gym
import itertools
import matplotlib
import numpy as np
import sys
import tensorflow as tf
import collections

if "../" not in sys.path:
    sys.path.append("../")
from Lib.envs.cliff_walking import CliffWalkingEnv
from Lib import plotting

matplotlib.style.use('ggplot')

env = CliffWalkingEnv()


class PolicyEstimator():
    """
    策略函数逼近
    """

    def __init__(self, learning_rate=0.01, scope="policy_estimator"):
        with tf.variable_scope(scope):
            self.state = tf.placeholder(tf.int32, [], "state")
            self.action = tf.placeholder(dtype=tf.int32, name="action")
            self.target = tf.placeholder(dtype=tf.float32, name="target")

            # This is just table lookup estimator
            state_one_hot = tf.one_hot(self.state, int(env.observation_space.n))
            self.output_layer = tf.contrib.layers.fully_connected(
                inputs=tf.expand_dims(state_one_hot, 0),
                num_outputs=env.action_space.n,
                activation_fn=None,
                weights_initializer=tf.zeros_initializer)

            self.action_probs = tf.squeeze(tf.nn.softmax(self.output_layer))
            self.picked_action_prob = tf.gather(self.action_probs, self.action)

            self.loss = -tf.log(self.picked_action_prob) * self.target

            self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            self.train_op = self.optimizer.minimize(
                self.loss, global_step=tf.contrib.framework.get_global_step())

    def predict(self, state, sess=None):
        sess = sess or tf.get_default_session()
        return sess.run(self.action_probs, {self.state: state})

    def update(self, state, target, action, sess=None):
        sess = sess or tf.get_default_session()
        feed_dict = {self.state: state, self.target: target, self.action: action}
        _, loss = sess.run([self.train_op, self.loss], feed_dict)
        return loss


class ValueEstimator():
    """
    值函数逼近器
    """

    def __init__(self, learning_rate=0.1, scope="value_estimator"):
        with tf.variable_scope(scope):
            self.state = tf.placeholder(tf.int32, [], "state")
            self.target = tf.placeholder(dtype=tf.float32, name="target")

            # This is just table lookup estimator
            state_one_hot = tf.one_hot(self.state, int(env.observation_space.n))
            self.output_layer = tf.contrib.layers.fully_connected(
                inputs=tf.expand_dims(state_one_hot, 0),
                num_outputs=1,
                activation_fn=None,
                weights_initializer=tf.zeros_initializer)

            self.value_estimate = tf.squeeze(self.output_layer)
            self.loss = tf.squared_difference(self.value_estimate, self.target)

            self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            self.train_op = self.optimizer.minimize(
                self.loss, global_step=tf.contrib.framework.get_global_step())

    def predict(self, state, sess=None):
        sess = sess or tf.get_default_session()
        return sess.run(self.value_estimate, {self.state: state})

    def update(self, state, target, sess=None):
        sess = sess or tf.get_default_session()
        feed_dict = {self.state: state, self.target: target}
        _, loss = sess.run([self.train_op, self.loss], feed_dict)
        return loss


def actor_critic(env, estimator_policy, estimator_value, num_episodes, discount_factor=1.0):
    """
    Actor Critic 算法.通过策略梯度优化策略函数逼近器

    参数:
        env: OpenAI环境.
        estimator_policy: 待优化的策略函数
        estimator_value: 值函数逼近器,用作评论家
        num_episodes: 回合数
        discount_factor: 折扣因子

    返回值:
        EpisodeStats对象,包含两个numpy数组,分别存储片段长度和片段奖励
    """

    # Keeps track of useful statistics
    stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))

    Transition = collections.namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])

    for i_episode in range(num_episodes):
        state = env.reset()

        episode = []

        for t in itertools.count():

            action_probs = estimator_policy.predict(state)
            action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
            next_state, reward, done, _ = env.step(action)

            episode.append(Transition(
                state=state, action=action, reward=reward, next_state=next_state, done=done))

            stats.episode_rewards[i_episode] += reward
            stats.episode_lengths[i_episode] = t

            # 计算TD目标
            value_next = estimator_value.predict(next_state)
            td_target = reward + discount_factor * value_next
            td_error = td_target - estimator_value.predict(state)

            # 更新值函数逼近
            estimator_value.update(state, td_target)

            # 更新策略逼近
            # 使用TD误差作为优势估计
            estimator_policy.update(state, td_error, action)

            print("\rStep {} @ Episode {}/{} ({})".format(
                t, i_episode + 1, num_episodes, stats.episode_rewards[i_episode - 1]), end="")

            if done:
                break

            state = next_state

    return stats


tf.reset_default_graph()

global_step = tf.Variable(0, name="global_step", trainable=False)
policy_estimator = PolicyEstimator()
value_estimator = ValueEstimator()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    stats = actor_critic(env, policy_estimator, value_estimator, 300)

plotting.plot_episode_stats(stats, smoothing_window=10)

二、邻近策略优化(Proximal Policy Optimization,PPO)

邻近策略优化(Proximal Policy Optimization,PPO)算法解决的问题是离散动作空间和连续动作空间的强化学习问题,是on-policy的强化学习算法。

算法主要思想:策略pi接受状态s,

输出动作概率分布,在动作概率分布中采样动作,执行动作,得到回报,跳到下一个状态。在这样的步骤下,我们可以使用策略pi收集一批样本,然后使用梯度下降算法学习这些样本,但是当策略pi的参数更新后,这些样本不能继续被使用,还要重新使用策略pi与环境互动收集数据,真的非常耗时。因此采用重要性采样,使这些样本可以被重复使用

1. 模型结构

图片

PPO是基于Actor-Critic架构的,这个架构的优势是解决了连续动作空间的问题。

  • actor网络的输入为状态,输出为动作概率(对于离散动作空间而言)或者动作概率分布参数(对于连续动作空间而言)
  • critic网络的输入为状态,输出为状态的价值。

actor网络输出的动作使优势越大越好,critic网络输出的状态价值越准确越好。

2. 产生experience的过程

图片

已知一个状态s0,

  • 通过 actor网络 得到所有动作的概率(图中以三个动作:a,b,c为例),
  • 然后依概率采样得到动作a0,
  • 然后将a0输入到环境中得到s1和r1,

状态价值v(s0)通过critic网络输出得到,这样就得到一个experience: (s0,a0,r1,v(s0,logP(a0|s0)),然后将experience放入经验池中。

以上是离散动作的情况,如果是连续动作,就输出概率分布的参数(比如高斯分布的均值和方差),然后按照概率分布去采样得到动作a0。

经验池的意义是为了更方便的计算一条轨迹上状态的累积折扣回报v(st)以及优势A(st,at),而不是消除experience的相关性。

3. 网络更新

3.1 actor网络的更新流程

优势函数A的定义为:

图片

因为Actor网络需要输出的动作优势尽可能地大,所以它的训练需要用以下表达式作为Loss函数:

图片

其中

图片

反映了新旧策略差异的程度。

对于上式等价于如下形式:

图片

A大于0表示此时策略更好,要加大优化力度。目标函数取最大,那么就会尽量取大的r值,但如果更新力度过大,新旧策略差异就会太大,即

图片

,那么clip操作和min操作会进行限制,防止了过度优化。

PPO算法使用多步TD,因此它需要跑完一条轨迹后,才开始计算各个状态的累积回报和动作的优势。具体而言,状态价值是通过critic网络输出得到的,动作优势是通过先计算

图片

,然后用

图片

作为折扣因子去计算动作优势,公式如下:

图片

3.2 Critic网络的更新流程

Actor网络更新后,接着拿从经验池buffer中采出的数据进行Critic网络的更新(数据已经计算了状态价值,折扣回报Gt的计算是基于多步TD的方法,从那个状态开始,用每一步环境返回的奖励R与折扣因子相乘后累加,即:

图片

其中

图片

为网络的估计值,更新方式为:计算好的折扣回报与Critic网络预测当前状态价值做差,用MSEloss作为Loss函数,对神经网络进行训练。

算法流程如下:

图片

参考链接:强化学习PPO算法介绍PPO算法解决的问题是离散动作空间和连续动作空间的强化学习问题,是on-policy的强化学习算法。icon-default.png?t=N7T8https://mp.weixin.qq.com/s/pG9UzN1NjfBy4ZvgnNVRwQ

第十二章 深度强化学习-Actor-Critic演员评论家第十二章 Actor-Critic演员评论家我们在上一章中介绍了策略梯度(Policy Gradient)方icon-default.png?t=N7T8https://mp.weixin.qq.com/s?__biz=MzU1OTkwNzk4NQ==&mid=2247485611&idx=1&sn=5bf388ead8a1edc0051665d7b6f7825b&chksm=fc115d55cb66d4434d701ce86138e0345dffb909657abdad1e8f0da3eb700bccc1a4b86c52b4&scene=21#wechat_redirect

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1701022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Geoserver发布shp图层服务的样式控制及样式生成方式

在利用geoserver发布视频图层服务时,shp图层的样式可以在QGis文件中进行编辑;shp文件编辑后,需要导出样式文件,并在geoserver中进行注册,发布时对应shp图层文件时,需要选中对应样式,加载图层服务…

WorkPlus移动应用平台集成单点登录,实现统一门户解决方案

随着企业数字化转型的深入,移动办公已经成为企业提高工作效率和员工协作的重要途径。为了更好地管理企业移动应用,提升员工体验,简化登录流程,许多企业开始采用集成单点登录技术的企业移动应用平台,实现统一门户的目标…

实验室课程|基于SprinBoot+vue的实验室课程管理系统(源码+数据库+文档)

实验室课程管理系统 目录 基于SprinBootvue的实验室课程管理系统 一、前言 二、系统设计 三、系统功能设计 1管理员功能模块 2学生功能模块 3教师功能模块 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主介…

PMP考试通关秘籍

考试大纲 考试大纲:考察维度3 个(人、过程、商业环境);更加贴近真实项目趋势;侧重点从做事到关注人;对于项目经理的软技能要求更高,匹配 PM 能力模型。 人员(42%)&…

55页PDF|人工智能通用大模型(ChatGPT)的进展、风险与应对(附下载)

👉获取方式: 😝有需要的小伙伴,可以保存图片到wx扫描二v码免费领取【保证100%免费】🆓

3D技术的应用领域

3D技术在现代科技和工业中有广泛的应用,其涵盖的领域非常广泛,从娱乐到医学,再到制造业和建筑,3D技术正在改变我们理解和互动的方式。以下是一些主要的应用领域。北京木奇移动技术有限公司,专业的软件外包开发公司&…

k8s devops实战教程+生产实践+可就业

k8s devops实战教程 简介教程涉及到内容教程获取学习教程后的收货助学群 简介 越来越多的企业应用云原生化,催生很多应用的部署方式也发生了很多变化。 从物理机部署应用过度到虚机部署应用再到应用容器化,从单应用再到服务拆分为微服务,靠人…

linux查看是否被入侵(一)

1、查看当前系统状态 [rootbastion-IDC ~]#top #一般挖矿等病毒点用CPU比较大 2、查看当前登录用户(w\who) 3、检查系统日志 检查系统错误登陆日志,统计IP重试次数 [rootbastion-IDC ~]# lastb 4、查看近期用户登录情况 [rootkvm01 ~]# last -n 5 #-n 5 表示…

element el-table表格表头某一列表头文字或者背景修改颜色

效果如下 整体代码 &#xff0c;具体方法在最下面&#xff01; <el-table v-loading"listLoading" :data"sendReceivList" element-loading-text"Loading" border fit ref"tableList" :header-cell-class-name"addClass&quo…

揭秘APP广告变现的高效秘诀:如何让你的APP更赚钱?

在数字化时代&#xff0c;APP已成为人们获取信息、娱乐休闲的重要平台。对于许多内容创作者来说&#xff0c;如何通过APP实现盈利&#xff0c;是一个亟待解决的问题。而APP广告变现项目&#xff0c;正是其中一种备受关注的盈利模式。那么&#xff0c;如何有效地利用APP广告变现…

安泰电子:功率放大器的选择方法有哪些

选择适合的功率放大器是实现电子系统中的关键步骤之一。以下是一些选择功率放大器的常用方法和考虑因素&#xff1a; 功率需求&#xff1a;首先确定你的系统需要多大的功率输出。功率输出需求通常由被驱动设备的功率要求决定。计算出所需功率后&#xff0c;选择一个具有适当功率…

绿色阅读:旧书回收,让知识循环

在快节奏的现代社会中&#xff0c;知识的获取和更新速度日新月异。然而&#xff0c;在这个信息爆炸的时代&#xff0c;我们是否曾想过&#xff0c;那些曾经陪伴我们度过无数日夜、给予我们智慧和启迪的旧书&#xff0c;在它们完成使命后&#xff0c;是否应该被遗忘在角落&#…

IdentiFace——多模态人脸识别系统,可捕捉从情绪到性别的所有信息及其潜力

1. 概述 面部识别系统的开发极大地推动了计算机视觉领域的发展。如今&#xff0c;人们正在积极开发多模态系统&#xff0c;将多种生物识别特征高效、有效地结合起来。 本文介绍了一种名为 IdentiFace 的多模态人脸识别系统。该系统利用基于 VGG-16 架构的模型&#xff0c;将人…

Go 语言安装部署(超详细版本)

在学习和使用 Go 语言时&#xff0c;正确的安装和配置是非常重要的一步。本文将介绍如何在不同操作系统上安装 Go 语言&#xff0c;并讨论一些常见的配置选项&#xff0c;帮助读者更好地了解和使用 Go 语言。无论是初学者还是有一定经验的开发者&#xff0c;都能从本文中获得有…

buuctf-相册

题目提示找到邮箱 下载是一个apk文件 他都不建议安装到手机了 我还是不找麻烦动调了吧 他说是mail,那行吧 找mail 找到就是这一段 base64 s3 notebook 这里可以看见加载了native库 所以要IDA 打开so文件 apk就是一个压缩包,直接解压就行 lib里面就有so文件 再根据熟知的…

深度剖析整型和浮点型数据在内存中的存储(C语言)

目录 整型在内存中的存储 为什么整型在内存中存储的是补码&#xff1f; 大小端字节序 为什么有大端小端&#xff1f; 浮点型家族 浮点数在内存中的存储 long long 整型在内存中的存储 整型在内存中有三种二进制表示形式&#xff1a;原码&#xff0c;反码&#xff0c;补码…

网站笔记:huggingface model memory calculator

Model Memory Utility - a Hugging Face Space by hf-accelerate 这个工具可以计算在 Hugging Face Hub上托管的大型模型训练和执行推理时所需的vRAM内存量。模型所需的最低推荐vRAM内存量表示为“最大层”的大小&#xff0c;模型的训练大约是其大小的4倍&#xff08;针对Adam…

Python-3.12.0文档解读-内置函数id()详细说明+记忆策略+常用场景+巧妙用法+综合技巧

一个认为一切根源都是“自己不够强”的INTJ 个人主页&#xff1a;用哲学编程-CSDN博客专栏&#xff1a;每日一题——举一反三Python编程学习Python内置函数 Python-3.12.0文档解读 目录 详细说明 概述 参数 返回值 特性 实现细节&#xff08;CPython&#xff09; 安全…

M-G370PDG惯性测量单元,可实时监测天线的姿态和位置变化

动中通天线系统通常包括天线、卫星信号跟踪器、调制解调器、电源管理单元和用户终端设备等部分。其中&#xff0c;天线是系统的关键部件&#xff0c;负责接收和发送卫星信号。随着移动载体的运动&#xff0c;天线需要实时调整方向&#xff0c;以保持与卫星的稳定连接。卫星信号…

uniapp页面vue3下拉触底发送获取新数据请求实现分页功能

页面下拉触底获取新数据实现分页功能实现方式有两种&#xff0c;根据自己的业务需求来定&#xff0c;不同的方案适用场景不一样&#xff0c;有的是一整个页面下拉获取新数据&#xff0c;有的是部分盒子内容滚动到底部时候实现获取新数据&#xff0c;下面讨论一下两种方式的区别…