DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in

news2024/11/27 16:29:12

一、数据处理
用于训练 RNN 的 mini-batch 数据不同于通常的数据。 这些数据通常应按时间序列排列。 对于 DI-engine, 这个处理是在 collector 阶段完成的。 用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。 这里将在下一节中具体解释。

什么是数据处理?
数据处理指的是为循环神经网络(RNN)训练准备时间序列数据的过程。这个过程包括将收集到的数据组织成适当格式的小批量(mini-batches),这些批量数据将用于网络的训练。这一步骤通常发生在DI-engine的collector阶段,也就是数据收集和预处理发生的地方。用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。例如,如果你设置 learn_unroll_len = 10 和 burnin_step = 5,那么 RNN 实际接收的输入序列长度将是 15:前 5 步为 burn-in(用于预热隐藏状态),接下来的 10 步作为学习的一部分。这样设置可以帮助 RNN 在计算梯度和进行权重更新时,有一个更加准确的隐藏状态作为起点。
部分名词解释

  • mini-batches:在机器学习中,特别是在训练神经网络时,数据一般被分成小的批次进行处理,这些批次被称为 “mini-batch”。一个 mini-batch 包含了一组样本,这组样本用于执行单次迭代的前向传播和反向传播,以更新网络的权重。使用 mini-batches 而不是单个样本或整个数据集(后者称为 “batch” 或 “full-batch”)可以平衡计算效率和内存限制,有助于提高学习的稳定性和收敛速度。
  • collector阶段:在 DI-engine中,collector 阶段是指环境与智能体交互并收集经验数据的过程。在这个阶段,智能体根据其当前的策略执行操作,环境则返回新的状态、奖励和其他可能的信息,如是否达到终止状态。收集到的数据(经常被称为经验或转换)随后被用于训练智能体的模型,例如对策略或价值函数进行更新。

为什么要进行数据处理:

  1. 保持时间依赖性:RNN的核心优势是处理具有时间序列依赖性的数据,比如语言、视频帧、股票价格等。正确的数据处理确保了这些时间依赖性在训练数据中得以保留,使得模型能够学习到数据中的序列特征。
  2. 提高学习效率:通过将数据划分为与模型期望的序列长度匹配的批次,可以提高模型学习的效率。这样做可以确保网络在每次更新时都接收到足够的上下文信息。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的输入数据。例如,标准的RNN只需要过去的信息,而一些变体如LSTM或GRU可能会处理更长的序列。特定的算法,如R2D2,还可能需要额外的步骤(如burn-in),以便更好地初始化网络状态。
  4. 处理不规则长度:在现实世界的数据集中,序列长度往往是不规则的。数据处理确保了每个mini-batch都有统一的序列长度,这通常通过截断过长的序列或填充过短的序列来实现。
  5. 优化内存和计算资源:通过将数据组织成具有固定时间步长的批次,可以更有效地利用GPU等计算资源,因为这些资源在处理固定大小的数据时通常更高效。
  6. 稳定学习过程:特别是在强化学习中,使用如n-step返回或经验回放的技术,可以帮助模型从环境反馈中学习,并减少方差,从而稳定学习过程。

如何进行数据处理

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)    return get_train_sample(data, self._sequence_len)

 代码段 def _get_train_sample(self, data: list) 是一个方法,它的作用是从收集到的数据中提取用于训练 RNN 的样本。这个方法会在两个步骤中处理数据:

  • N步返回计算(get_nstep_return_data): 这个函数接受原始的经验数据,然后计算所谓的 N 步返回值。N 步返回是一个在强化学习中用于临时差分(Temporal Difference, TD)学习的概念,它考虑了从当前状态开始的未来 N 步的累积奖励。计算这个值需要使用折现因子 gamma。这个步骤的目的是为了让智能体学习如何根据当前的行动预测未来的奖励,这是强化学习中价值函数估计的重要部分。
  • 训练样本获取(get_train_sample): 在得到 N 步返回值之后,这个函数进一步处理数据以生成训练样本。具体地,它会根据 self._sequence_len(即时间序列长度或者 RNN 的历史长度)来选择数据序列。这意味着每个训练样本将是一个具有 self._sequence_len 长度的数据序列,这对于训练 RNN 来说是必要的,因为 RNN 需要一定长度的历史来维护其内部状态(或记忆)。

有关这两个数据处理功能的工作流程见下图:

二、初始化隐藏状态 (Hidden State)
RNN用于处理具有时间依赖性的信息。RNN的隐藏状态(Hidden State)是其记忆的一部分,它能够捕捉到前一时间步长的信息。这些信息对于预测下一个动作或状态非常关键。在此上下文中,初始化RNN的隐藏状态是一个重要的步骤,它确保了RNN在开始新的数据批次处理时具有正确的起始状态。
策略的 _learn_model 需要初始化 RNN。这些隐藏状态来自 _collect_model 保存的 prev_state。 用户需要通过 _process_transition 函数将这些状态添加到 _learn_model 输入数据字典中。 

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:    transition = {        'obs': obs,        'action': model_output['action'],        'prev_state': model_output['prev_state'], # add ``prev_state`` key here        'reward': timestep.reward,        'done': timestep.done,    }    return transition

点击DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in - 古月居 可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1650503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【热门话题】实用Chrome命令:提升前端开发效率的利器

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 实用Chrome命令:提升前端开发效率的利器引言目录1. 快速打开Chrome …

批量将GOID转成GO term名并添加BP,MF,CC分类信息

基因本体论(Gene Ontology,GO,https://www.geneontology.org)是一个广泛应用于生物信息学领域的知识库,它提供了一套标准化的词汇和分类体系,用于描述基因功能、细胞组分和生物过程。GO旨在统一科研人员对基…

【Delphi7】Access violation at address 0019F7C3. Write of address 0019F7C3.

这里写目录标题 问题基本情况问题描述1、启动Delphi 开发程序 时连续报如下错误2、打开“工程”菜单下的“选项”页面时时连续报如下错误 解决方案1、打开“高级系统设置”2、打开“性能选项”3、添加“数据执行保护”的程序4、选择“数据执行保护”的程序5、应用“数据执行保护…

【免费】虚拟同步发电机(VSG)惯量阻尼自适应控制仿真模型【simulink】

目录 主要内容 仿真模型要点 2.1 整体仿真模型 2.2 电压电流双闭环模块 2.3 SVPWM调制策略 2.4 无功电压模块 2.5 自适应控制策略及算法 部分结果 下载链接 主要内容 该模型为simulink仿真模型,主要实现的内容如下: 随着风力发电、光…

数据结构复习/学习9--二叉树

一、堆与完全二叉树 1.堆的逻辑与物理结构 2.父节点与子节点的下标 3.大小根堆 二、堆的实现(大根堆为例) 注意事项总结: 注意堆中插入与删除数据的位置和方法与维持大根堆有序时的数据上下调整 三、堆排序 1.排升序建大堆效率高 注意事项…

信锐交换机简介及应用说明(1)

交换机关键参数及分类 1.线速 线速是指交换机的端口上每秒钟传输的bit数,单位为bps(bit per second,即每秒传输多少bit,一个bit也就是一个二进制数0或者1)。以我们常见的例子来说明的话,比如100M的网卡就…

(三)JSP教程——JSP动作标签

JSP动作标签 用户可以使用JSP动作标签向当前输出流输出数据&#xff0c;进行页面定向&#xff0c;也可以通过动作标签使用、修改和创建对象。 <jsp:include>标签 <jsp:include>标签将同一个Web应用中静态或动态资源包含到当前页面中。资源可以是HTML、JSP页面和文…

论文复现丨多车场带货物权重车辆路径问题:改进邻域搜索算法

引言 本系列文章是路径优化问题学习过程中一个完整的学习路线。问题从简单的单车场容量约束CVRP问题到多车场容量约束MDCVRP问题&#xff0c;再到多车场容量时间窗口复杂约束MDCVRPTW问题&#xff0c;复杂度是逐渐提升的。 如果大家想学习某一个算法&#xff0c;建议从最简单…

Xshell打开XFTP提示需要下载

使用xshell无法启动xftp的问题&#xff0c;下载Xftp&#xff1a;百度网盘 请输入提取码 使用方法&#xff1a; 解压以后&#xff0c;右键运行“!)绿化处理.bat”即可。

保姆级教程:从 0 到 1 将项目发布到 Maven 中央仓库【2024年5月】

前言 大家好&#xff0c;我叫阿杆&#xff0c;不叫阿轩 最近写了一个参数校验组件&#xff0c;名字叫 spel-validator&#xff0c;是基于 javax.validation 的一个扩展&#xff0c;目的是简化参数校验。 我把项目开源到了GitHub https://github.com/stick-i/spel-validator …

【C语言】用数组和函数实现扫雷游戏

用数组和函数实现扫雷游戏 游戏界面&#xff1a; 代码如下&#xff1a; game.h #pragma once #include <stdio.h> #include <stdlib.h> #include <time.h> #define EASY_COUNT 10 #define ROW 9 #define COL 9 #define ROWS ROW2 #define COLS COL2 //初始…

[正则表达式]正则表达式语法与运用(Regular Expression, Regex)

0. 在线工具 RegExr: Learn, Build, & Test RegEx 1. 场景列举 vim Linux命令行 sublime 编辑器 java、python等语言中 ... ... 不同场景、不同版本语法可能不一样 2. 以下示例数据与基本语法 &2024 &As20242024# 2024sA#abdcefgha_bdcefghABASDSADAASDASD…

经常发文章的你是否想过定时发布是咋实现的?

前言 可乐他们团队最近在做一个文章社区平台,由于人手不够,前后端都是由前端同学来写。后端使用 nest 来实现。 某一天周五下午,可乐正在快乐摸鱼,想到周末即将来临,十分开心。然而,产品突然找到了他,说道:可乐,我们要做一个文章定时发布功能。 现在我先为你解释一…

已经做了小20年电商梦的腾讯,终于找到了破局的方向~

我是王路飞。 随着短视频的爆火、抖音电商成功开辟出短视频/直播带货的电商新赛道。 已经做了小20年电商梦的腾讯&#xff0c;终于找到了破局的方向~ 这个方向&#xff0c;就是被腾讯马老板亲口认证为&#xff1a;全村&#xff08;全公司&#xff09;希望所在的视频号。 内…

Python-100-Days: Day11 Files and Exception

1.读取csv文件 读取文本文件时&#xff0c;需要在使用open函数时指定好带路径的文件名&#xff08;可以使用相对路径或绝对路径&#xff09;并将文件模式设置为r&#xff08;如果不指定&#xff0c;默认值也是r&#xff09;&#xff0c;然后通过encoding参数指定编码&#xf…

基于FPGA的DDS波形发生器VHDL代码Quartus仿真

名称&#xff1a;基于FPGA的DDS波形发生器VHDL代码Quartus仿真&#xff08;文末获取&#xff09; 软件&#xff1a;Quartus 语言&#xff1a;VHDL 代码功能&#xff1a; DDS波形发生器VHDL 1、可以输出正弦波、方波、三角波 2、可以控制输出波形的频率 DDS波形发生器原理…

手撸Mybatis(五)——连接数据库进行insert,update和delete

本专栏的源码&#xff1a;https://gitee.com/dhi-chen-xiaoyang/yang-mybatis。 引言 在上一章中&#xff0c;我们成功实现了数据库的连接&#xff0c;以及单个字段的查询、resultType映射查询、resultMap映射查询。在本章&#xff0c;我们将讲解关于增加、修改和删除操作。 …

【MATLAB源码-第204期】基于matlab的语音降噪算法对比仿真,谱减法、维纳滤波法、自适应滤波法;参数可调。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 语音降噪技术的目的是改善语音信号的质量&#xff0c;通过减少或消除背景噪声&#xff0c;使得语音更清晰&#xff0c;便于听者理解或进一步的语音处理任务&#xff0c;如语音识别和语音通讯。在许多实际应用中&#xff0c;如…

MySQL45讲(一)(40)

回顾binlog_formatstatement STATEMENT 记录SQL语句。日志文件小&#xff0c;节约IO&#xff0c;但是对一些系统函数不能准确复制或不能复制&#xff0c;如now()、uuid()等 在RR隔离级别下&#xff0c;binlog_formatstatement 如果执行insert select from 这条语句是对于一张…

动态规划算法:简单多状态问题

例题一 解法&#xff08;动态规划&#xff09;&#xff1a; 算法思路&#xff1a; 1. 状态表⽰&#xff1a; 对于简单的线性 dp &#xff0c;我们可以⽤「经验 题⽬要求」来定义状态表⽰&#xff1a; i. 以某个位置为结尾&#xff0c;巴拉巴拉&#xff1b; ii. 以某个位置为起…