Reinforce算法原理及Tensorflow代码实现

news2024/12/23 17:59:32

       Q-learning和DQN算法都是强化学习中的Value-based的方法,它们都是先经过Q值来选择动作。强化学习中还有另一大类是策略梯度方法(Policy Gradient Methods)。Policy Gradient 是一类直接针对期望回报(Expected Return)通过梯度下降(Gradient Descent)进行策略优化的强化学习方法。这一类方法避免了其他传统强化学习方法所面临的一些困难,比如,没有一个准确的价值函数,或者由于连续的状态和动作空间,以及状态信息的不确定性而导致的难解性(Intractability)。其中最著名的就是Policy Gradient,Policy Gradient算法又可以根据更新方式分为两大类:

蒙特卡罗更新方法:Reinfoce算法(回合更新)

时序差分更新方法:Actor-Critic算法(单步更新)

回顾蒙特卡罗方法和时序差分方法

       蒙特卡罗方法可以理解为算法完成一个回合之后,再利用这个回合的数据去学习,做一次更新。因为我们已经获得了整个回合的数据,所以也能够获得每一个步骤的奖励,我们可以很方便地计算每个步骤的未来总奖励,��Gt​。��Gt​是未来总奖励,代表从这个步骤开始,我们能获得的奖励之和。�1G1​代表我们从第一步开始,往后能够获得的总奖励。�2G2​代表从第二步开始,往后能够获得的总奖励。

       相比蒙特卡罗方法一个回合更新一次,时序差分方法是每个步骤更新一次,即每走一步,更新一次, 时序差分方法的更新频率更高。时序差分方法使用Q函数来近似地表示未来总奖励��Gt​。

Reinfoce算法原理

       Reinfoce使用蒙特卡罗方法估计每个状态下采取动作所获得的奖励期望值,然后用这些估计值计算策略梯度并更新策略参数。因为Reinfoce算法是一种无模型算法,它不需要对环境建立模型,也不需要预测值函数等中间步骤,相比其他强化学习算法更加简单和直接。

       Reinfoce算法在策略的参数空间中直观地通过梯度上升的方法逐步提高策略的性能。

▽�(�)=��∼��[∑�′=0∞▽������(��′∣��′)��′∑�=�′∞��−�′��]▽J(θ)=Eτ∼πθ​​[t′=0∑∞​▽θ​logπθ​(At′​∣St′​)γt′t=t′∑∞​γt−t′Rt​]

       由于折扣因子给未来的奖励赋予了较低的权重,使用折扣因子还有助于减少估计梯度时的方差大的问题。实际使用中,��′γt′经常被去掉,从而避免了过分强调轨迹早期状态的问题。

       虽然Reinfoce简单直观,但它的一个缺点是对梯度的估计有较大的方差。对于一个长度为L的轨迹,奖励��Rt​的随机性可能对L呈指数级增长。为了减轻估计的方差太大这个问题,一个常用的方法是引进一个基准函数�(��)b(Si​)。这里对�(��)b(Si​)的要求是:它只能是一个关于状态��Si​的函数(或者更确切地说,它不能是关于��Ai​的函数)。有了基准函数�(��)b(St​)之后,强化学习目标函数的梯度 ▽�(�)▽J(θ)可以表示成:

▽�(�)=��∼��[∑�′=0∞▽������(��′∣��′)(∑�=�′∞��−�′��−�(��′))]▽J(θ)=Eτ∼πθ​​[t′=0∑∞​▽θ​logπθ​(At′​∣St′​)(t=t′∑∞​γt−t′Rt​−b(St′​))]

Reinfoce算法的代码实现

算法伪代码:

代码详解:

考虑将整个算法放入一个类中,并将各部分代码写入对应的函数。这样可以使得代码更为简洁易读。PolicyGradient 类的结构如下所示:

 

ruby

复制代码

class PolicyGradient: def __init__(self, state_dim, action_num, learning_rate=0.02, gamma=0.99): ...... def get_action(self, s, greedy=False): # 基于动作分布选择动作 ...... def store_transition(self, s, a, r): # 存储从环境中采样的交互数据 ...... def learn(self): # 使用存储的数据进行学习和更新 ...... def _discount_and_norm_rewards(self): # 计算折扣化回报并进行标准化处理 ...... def save(self): # 存储模型 ...... def load(self): # 载入模型 ......

初始化函数先后创建了一些变量、模型并选择 Adam 作为策略优化器。在代码中,我们可以看出这里的策略网络只有一层隐藏层。

 

ini

复制代码

def __init__(self, state_dim, action_num, learning_rate=0.02, gamma=0.99): self.gamma = gamma self.state_buffer, self.action_buffer, self.reward_buffer = [], [], [] input_layer = tl.layers.Input([None, state_dim], tf.float32) layer = tl.layers.Dense( n_units=30, act=tf.nn.tanh, W_init=tf.random_normal_initializer(mean=0, stddev=0.3), b_init=tf.constant_initializer(0.1))(input_layer) all_act = tl.layers.Dense( n_units=action_num, act=None, W_init=tf.random_normal_initializer(mean=0, stddev=0.3), b_init=tf.constant_initializer(0.1))(layer) self.model = tl.models.Model(inputs=input_layer, outputs=all_act) self.model.train() self.optimizer = tf.optimizers.Adam(learning_rate)

在初始化策略网络之后,我们可以通过 get_action() 函数计算某状态下各动作的概率。通过设置’greedy=True’,可以直接输出概率最高的动作

 

scss

复制代码

def get_action(self, s, greedy=False): _logits = self.model(np.array([s], np.float32)) _probs = tf.nn.softmax(_logits).numpy() if greedy: return np.argmax(_probs.ravel()) return tl.rein.choice_action_by_probs(_probs.ravel())

但此时,我们选择的动作可能并不好。只有通过不断学习之后,网络才能做出越来越好的判断。每次的学习过程由 learn() 函数完成。我们使用标准化后的折扣化奖励交叉熵损失来更新模型。在每次更新后,学过的转移数据将被丢弃。

 

python

复制代码

def learn(self): discounted_ep_rs_norm = self._discount_and_norm_rewards() with tf.GradientTape() as tape: _logits = self.model(np.vstack(self.ep_obs)) neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=np.array(self.ep_as)) loss = tf.reduce_mean(neg_log_prob * discounted_ep_rs_norm) grad = tape.gradient(loss, self.model.trainable_weights) self.optimizer.apply_gradients(zip(grad, self.model.trainable_weights)) self.ep_obs, self.ep_as, self.ep_rs = [], [], [] # 清空片段数据 return discounted_ep_rs_norm

learn() 函数需要使用智能体与环境交互得到的采样数据。因此我们需要使用 store_tran-sition() 来存储交互过程中的每个状态、动作和奖励。

 

ruby

复制代码

def store_transition(self, s, a, r): self.ep_obs.append(np.array([s], np.float32)) self.ep_as.append(a) self.ep_rs.append(r)

策略梯度算法使用蒙特卡罗方法。因此,我们需要计算折扣化回报,并对回报进行标准化,也有助于学习。

 

python

复制代码

def _discount_and_norm_rewards(self): discounted_ep_rs = np.zeros_like(self.ep_rs) running_add = 0 for t in reversed(range(0, len(self.ep_rs))): running_add = running_add * self.gamma + self.ep_rs[t] discounted_ep_rs[t] = running_add# 标准化片段奖励 discounted_ep_rs -= np.mean(discounted_ep_rs) discounted_ep_rs /= np.std(discounted_ep_rs) return discounted_ep_rs

先准备好环境和算法。在创建好环境之后,我们产生一个名为 agent的 PolicyGradient 类的实例。

 

ini

复制代码

env = gym.make(ENV_ID).unwrapped np.random.seed(RANDOM_SEED) tf.random.set_seed(RANDOM_SEED) env.seed(RANDOM_SEED) agent = PolicyGradient( action_num=env.action_space.n, state_dim=env.observation_space.shape[0], ) t0 = time.time()

在训练模式中,使用模型输出的动作来和环境进行交互,之后存储转移数据并在每个片段更新策略。为了简化代码,智能体将在每局结束时直接进行更新。

 

css

复制代码

if args.train: all_episode_reward = [] for episode in range(TRAIN_EPISODES): state = env.reset() episode_reward = 0 for step in range(MAX_STEPS): if RENDER: env.render() action = agent.get_action(state) next_state, reward, done, info = env.step(action) agent.store_transition(state, action, reward) state = next_state episode_reward += reward if done: break agent.learn() print(’Training | Episode: {} / {} | Episode Reward: {:.0f} | Running Time:{:.4f}’.format( episode + 1, TRAIN_EPISODES, episode_reward, time.time() - t0))

在每局游戏结束后的部分增加一些代码,以便更好地显示训练过程。我们显示每个回合的总奖励和通过滑动平均计算的运行奖励。之后可以绘制运行奖励以便更好地观察训练趋势。最后,存储训练好的模型。

 

scss

复制代码

agent.save() plt.plot(all_episode_reward) if not os.path.exists(’image’): os.makedirs(’image’) plt.savefig(os.path.join(’image’, ’pg.png’))

如果我们使用测试模式,则过程更为简单,只需要载入预训练的模型,再用它和环境进行交互即可。

 

css

复制代码

if args.test: agent.load() for episode in range(TEST_EPISODES): state = env.reset() episode_reward = 0 for step in range(MAX_STEPS): env.render() state, reward, done, info = env.step(agent.get_action(state, True)) episode_reward += reward if done: break print(’Testing | Episode: {} / {} | Episode Reward: {:.0f} | Running Time:{:.4f}’.format( episode + 1, TEST_EPISODES, episode_reward, time.time() - t0))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/636216.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电气火灾监控系统如何有效的预防木材加工企业电气火灾隐患

摘要:本文分析了木材加工企业的特点、现状及常见电气火灾隐患,提出了消灭电气火灾隐患的措施。结尾介绍了木材加工企业常用电气设备的选用及电气火灾监控系统在其低压配电系统的应用方案及产品选型。 关键词:木材加工企业;电气火…

网络服务——DHCP服务

个人简介:云计算网络运维专业人员,了解运维知识,掌握TCP/IP协议,每天分享网络运维知识与技能。座右铭:海不辞水,故能成其大;山不辞石,故能成其高。 个人主页:小李会科技的…

我和老刘又被搞惨了

前两天在调试PHY的时候遇到了一堆问题,老刘都不耐其烦的搞定了,这次我们开始调试音频部分,音频部分很简单,无非就是录音,要是能把录音的音频拿到了,那就万事大吉了。老刘也是信心满满,老刘对我说…

操作系统性能提升之内核锁优化

性能为王,系统的性能提升是每一个工程师的追求。目前,性能优化主要集中在消除系统软件堆栈中的低效率上或绕过高开销的系统操作。例如,内核旁路通过在用户空间中移动多个操作来实现这个目标,还有就是为某些类别的应用程序重构底层…

提升效率,从这款智能挂灯开始

❤️作者主页:小虚竹 ❤️作者简介:大家好,我是小虚竹。2022年度博客之星评选TOP 10🏆,Java领域优质创作者🏆,CSDN博客专家🏆,华为云享专家🏆,掘金年度人气作…

华为项目经理就是CEO,华为对项目经理的要求是什么?

项目经理要向上发展,下面我们来看看华为对项目经理的要求。 原文出自:pmo前沿

太空大战-第14届蓝桥杯国赛Scratch真题中级组第6题

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第148讲。 太空大战,本题是2023年5月28日上午举行的第14届蓝桥杯国赛Scratch图形化编程中级组真题第6题&am…

Flume学习---3、自定义Interceptor、自定义Source、自定义Sink

1、自定义Interceptor 1、案例需求 使用 Flume 采集服务器本地日志,需要按照日志类型的不同,将不同种类的日志发往不同的分析系统。 2、需求分析 在实际的开发中,一台服务器产生的日志类型可能有很多种,不同类型的日志可能需要发…

【微信公众平台对接】有关【上传图文消息内的图片获取URL】调用示例

1、微信接口说明: 2、调用示例 /*** 上传图文消息内的图片获取URL** param image* return*/PostMapping("uploadImg")public String uploadImg(MultipartFile image) {return wechatOpenService.uploadImg(image);}/*** 上传图文消息内的图片获取URL* htt…

css魔法:伪元素content内容竟然可以用css函数!

🌻 前言 CSS 伪元素用于设置元素指定部分的样式。伪元素中 ::before 和 ::after 是最常用的,它们分别用于在dom元素前/后插入内容,本文内容就是关于 ::before 和 ::after 的 content 内容的一些冷门用法展开的。 一般我们在使用伪元素时&…

基于Java+jsp+servlet的养老院管理系统设计和实现《收藏版》

基于Javajspservlet的养老院管理系统设计和实现《收藏版》 博主介绍:5年java开发经验,专注Java开发、定制、远程、指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取源码联系方…

大麦生成链接 大麦生成订单截图 抢票成功截图

一键生成购票链接 一键生成订单截图 下载程序:https://pan.baidu.com/s/16lN3gvRIZm7pqhvVMYYecQ?pwd6zw3

微服务工程搭建过程中的注意点

1、父工程pom.xml文件 1:父工程的maven坐标; 2:packaging使用pom; 原因:在Spring Cloud微服务工程中,通常会采用多模块的方式进行开发,父工程的pom文件中的packaging标签设置为pom,是…

操作系统 | 知识梳理 | 复习(上)

目录 📚操作系统概述 🐇操作系统中的抽象概念 📚准备知识 🐇中断输入输出 🐇软件中断 🐇处理器特权级 🐇操作系统的结构 📚程序的结构 🐇运行时视图简介 &…

SQL语句中EXISTS的详细用法大全

SQL语句中EXISTS的详细用法大全 前言一、建表1.在MySQL数据库建表语句2.在ORACLE数据库建表语句 二、在SELECT语句中使用EXISTS1.在SQL中使用EXISTS2.在SQL中使用NOT EXISTS3.在SQL中使用多个NOT EXISTS4.在SQL中使用多个EXISTS5.在SQL中使用NOT EXISTS和EXISTS 三、在DELETE语…

jmeter非gui运行,jtl生成了,但是html报告没有生成

jmeter非gui运行,jtl生成了,但是html报告没有生成,查看log,内容如下: 22:45:00,913 ERROR o.a.j.JMeter: Error generating dashboard: org.apache.jmeter.report.dashboard.GenerationException: Error while proces…

谷歌的passkey是什么?

谷歌的passkey是什么? 谷歌正在研发一种名为“Passkey”的新技术,它将用于用户身份验证。Passkey不同于传统的密码,它采用了硬件加密密钥(如安全密钥或生物识别方式)以及双因素身份验证等技术,可以更好地保…

微信:把元宇宙装进小程序

作为月活13.09亿的国民级应用,微信的每次小升级都很容易形成现象级。2023开年,微信放大招,试图把元宇宙装进小程序。 微信小程序 XR-FRAME 不久前,微信官方在开放社区贴出了“XR-FRAME”开发指南,这是一套为小程序定制…

RocketMQ 快速入门教程,手把手教教你干代码

目录 RocketMQ定义为什么要用消息中间件?应用解耦流量削峰数据分发 RocketMQ各部分角色介绍NameServer主机(Broker)生产者(Producer)消费者(Consumer)消息(Message) 使用RocketMQ的核心概念主题(Topic)消息队列(Message Queue)分组(Group)标签(Tag)偏移量(Offset) 普…

企业级信息系统开发讲课笔记4.11 Spring Boot中Spring MVC的整合支持

文章目录 零、学习目标一、Spring MVC 自动配置(一)自动配置概述(二)Spring Boot整合Spring MVC 的自动化配置功能特性 二、Spring MVC 功能拓展实现(一)创建Spring Boot项目 - SpringMvcDemo2021&#xff…