【torchrl】强化学习训练流程

news2024/11/17 8:26:53

在这里插入图片描述

1 采集数据阶段

上面这个循环是用来采集数据,并且加入到replay buffer中。最终获取的数据是
- s: 当前状态,或者observation
- a: 当前动作,后面重要性采样需要用到
- pa: 选择当前动作的概率,后面重要性采样用到
- r: 当前的奖励值
- s’: 下一个状态

有些算法可能会直接在这里估计state value,即这步的下面一步合并到这步中,那么就会增加state value, 和next state value
这里也可以采集多步,再加入到replay buffer中。

2 训练阶段

一般是采集一步数据,加入replay buffer,然后采样若干数据训练多次。也有是采集多步数据,然后训练多次。

2.1 critic损失计算

  1. 计算target value需要通过advatage+state value计算。
  2. advatage具体的计算方式是:
    adv = r + γ v ′ − v \text{adv} = r+\gamma v'-v adv=r+γvv
    这里的v和v’是当前value和下一步value,adv间接表示了动作的价值。
    当多步的adv加权求和之后,可以获得vae,它表示了更加优秀的动作价值估计,在计算actor loss时使用。

adv+v就是当前状态价值了,或者说 v t = r + γ v ′ v_t=r+\gamma v' vt=r+γv,这个值就是critic的目标。

2.2 actor损失的计算

  1. 首先通过当前的actor网络,估计出了动作action的概率pa
  2. 损失就是-vae(pa-pa’),这里pa’是之前选择动作action的概率。
  3. 对于PPO来说,可以对损失进行进一步裁剪。

以上的训练阶段的两个损失,均是通过PPOLoss计算的。具体可以看:https://github.com/pytorch/rl/blob/main/torchrl/objectives/ppo.py

问题

1 在收集之后马上计算state value,和从replay buffer取出来后再计算state value,两者有什么区别?

下面是gpt回答:

  1. 将状态值加入到Replay Buffer中

优点:

  • 稳定性:由于状态值是与状态和动作对应的估计值,这样存储在replay buffer中的值是基于当时的critic网络计算的,避免了因critic网络更新而引起的估计不稳定。
    减少计算开销:在每次使用replay buffer中的样本进行训练时,直接读取存储的状态值,可以减少重新计算的开销。

缺点:

  • 过时的估计:随着训练的进行,critic网络不断更新,存储在replay buffer中的状态值可能变得过时,不再准确反映当前的网络状态。
  • 存储空间:需要额外的存储空间来保存这些状态值,增加了内存的需求。
  1. 通过当前的Critic网络重新估计

优点:

  • 最新的估计:每次使用replay buffer中的样本时,使用当前的critic网络重新估计状态值,保证了状态值反映的是最新的网络状态。
  • 避免过时信息:由于每次都重新计算,可以避免使用过时的信息进行更新,提高了训练的有效性。

缺点:

  • 计算开销增加:每次从replay buffer中取样本时都需要通过当前的critic网络重新估计状态值,增加了计算开销。
  • 可能的估计不稳定:由于critic网络在训练过程中不断更新,状态值的估计可能会有较大的波动,导致训练的不稳定性。

两者应该都可以,在torchrl中也会在vae计算时候检查是否已经估计了state value,没有的话会自动帮你调用critic估计一下。

2 在网络前加上RNN是否破坏了马尔可夫性

如果只能获取有限状态,应该是不影响。如果是所有状态,则影响。

3 replay buffer应该存储什么

如果采集到数据,马上就计算state value,那么其实不需要保存state,也就是critic(no grad)这一步可以放在step之前,然后在replay buffer中不再存储state,而是state value。这两种方式都可以,看自己选择了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1706870.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MVCC机制

个人理解篇,不一定对,应付面试的时候看的 MVCC(Multi-Version Concurrency Control)全称多版本并发控制,主要用在隔离模式下的提交读、可重复读模式下,依赖于readview和undolog链 一、readview 1、结构 字段 备注 m_ids 活跃…

合作伙伴推广不积极?跟奖金到账时间有关!

在推广返现活动中,对于合作伙伴推广者来说,奖金是否及时到账是他们最关心的问题之一。如果品牌主一直不审批奖励数据,推广者则无法及时收到奖金,这很容易影响他们的推广积极性和忠诚度。怎样能够提高奖励审核的效率呢?…

失落的方舟台服账号怎么注册 失落的方舟台服注册收不到验证码

《失落的方舟》(Lost Ark)是由韩国Smilegate公司研发的一款大型多人在线角色扮演游戏(MMORPG)。该游戏以其精美的画面、丰富的剧情、动作类游戏的战斗手感以及广阔的开放世界而著称,自发布以来便吸引了全球众多游戏玩家…

边缘计算网关的主要功能有哪些?天拓四方

随着物联网(IoT)的快速发展和普及,边缘计算网关已经成为了数据处理和传输的重要枢纽。作为一种集成数据采集、协议转换、数据处理、数据聚合和远程控制等多种功能的设备,边缘计算网关在降低网络延迟、提高数据处理效率以及减轻云数…

k8s中的集群调度

文章目录 k8s中的集群调度Pod 创建流程 通过指定节点来创建pod所在的node节点通过标签来指定pod创建在哪个节点上pod 的亲和性Pod的亲和性和反亲和性亲和性(Affinity)反亲和性(Anti-Affinity) 污点与容忍污点(Taint&am…

探索未来设计新境界,PSAI插件 艺术创作神器来袭!

想象一下,如果有一个工具,能够让你的设计工作变得既简单又高效,那会是怎样的体验?现在,梦想成真了! 这是一款革命性的PSAI设计插件,专为创意人士打造。它将彻底改变你的设计流程,让你…

【python脚本】修改目标检测的xml标签(VOC)类别名

需求: 在集成多个数据集一同训练时,可能会存在不同数据集针对同一种目标有不同的类名,可以通过python脚本修改数据内的类名映射,实现统一数据集标签名的目的。 代码: # -*- coding: utf-8 -*- # Time : 2023/9/11 1…

js setTimeout、setInterval、promise、async await执行顺序梳理

基础知识 async: 关键字用于标记一个函数为异步函数,该函数中有一个或多个promise对象,需要等待执行完成后才会继续执行。 await:关键字,用于等待一个promise对象执行完,并返回其中的值,只能在async函数内部使用。可…

RT_Thread内核源码分析(一)——CM3内核和上下文切换

目录 一、程序存储分析 1.1 CM3内核寻址空间映射 1.2 程序静态存储和动态执行 二、CM3内核相关知识 2.1 操作模式和特权极别 2.2 环境相关寄存器 2.2.1 通用寄存器组, 2.2.2 状态寄存器组 2.2.3 模式切换环境自动保存 2.2.4 函数调用形参位置 2.3 …

OC IOS 文件解压缩预览

热很。。热很。。。。夏天的城市只有热浪没有情怀。。。 来吧,come on。。。 引用第三方库: pod SSZipArchive 开发实现: 一、控制器实现 头文件控制器定义: // // ZipRarViewController.h // // Created by carbonzhao on 2…

solidworks 3D草图案例2-方块异形切

单位mm 单位mm 长方体 底面是48mm*48mm,高为60mm 3D草图 点击线,根据三视图,绘制角度线, 由于三点确定一个面,因此确定三点就可以了 基准面 点击参考几何体-基准面,依次点击3个点 曲面切除 完成后点击插…

02--大数据Hadoop集群实战

前言: 前面整理了hadoop概念内容,写了一些概念和本地部署和伪分布式两种,比较偏向概念或实验,今天来整理一下在项目中实际使用的一些知识点。 1、基础概念 1.1、完全分布式 Hadoop是一个开源的分布式存储和计算框架&#xff0…

Serverless应用引擎SAE评测|一分钟部署在线游戏

Serverless应用引擎SAE评测|一分钟部署在线游戏 什么是Serverless应用引擎SAE一分钟部署在线游戏SAE控制台 资源释放其他操作 在进行Serverless应用引擎SAE评测之前,首先需要了解一下什么是SAE。 什么是Serverless应用引擎SAE Serverless应用引擎SAE(Se…

超频是什么意思?超频的好处和坏处

你是否曾经听说过超频?在电脑爱好者的圈子里,这个词似乎非常熟悉,但对很多普通用户来说,它可能还是一个神秘而陌生的存在。 电脑超频是什么意思 电脑超频(Overclocking),顾名思义,是…

C++面向对象程序设计 - 标准输出流

在C中,标准输出流通常指的是与标准输出设备(通常是终端或控制台)相关联的流对象。这个流对象在C标准库中被定义为std::cout、std::err、std::clog,它们是std::ostream类的一个实例。 一、cout,cerr和clog流 ostream类…

VLDB ’25 最后 6 天截稿,58 个顶会信息纵览;ISPRS 城市分割数据集上线

「顶会」板块上线 hyper.ai 官网啦!该板块为大家提供最新最全的 CCF A 类计算机顶会信息,包含会议简介、截稿倒计时、投稿链接等。 你是不是已经注册了顶会,但对截稿时间较为模糊,老是在临近 ddl 时才匆忙提交;又或者…

监控云安全的9个方法和措施

如今,很多企业致力于提高云计算安全指标的可见性,这是由于云计算的安全性与本地部署的安全性根本不同,并且随着企业将应用程序、服务和数据移动到新环境,需要不同的实践。检测云的云检测就显得极其重要。 如今,很多企业…

模拟量4~20mA电流传感器接线方式

一、模拟量4~20mA电流传感器接线方式 无源双线制是常见的电流型传感器接线方式,它具有简单、经济的特点。其接线方式如下: 传感器的“”接到数据采集器的电源“”上, 传感器的“-”端子连接到数据采集器的“AI”端子上, 数据采集器…

翻译《The Old New Thing》- What did MakeProcInstance do?

What did MakeProcInstance do? - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20080207-00/?p23533 Raymond Chen 2008年02月07日 MakeProcInstance 做了什么? MakeProcInstance 宏实际上什么也不做。 #define MakeProcInst…

HackTheBox-Machines--Beep

Beep测试过程 1 信息收集 nmap端口扫描 gryphonwsdl ~ % nmap -sC -sV 10.129.137.179 Starting Nmap 7.94 ( https://nmap.org ) at 2024-05-28 14:39 CST Nmap scan report for 10.129.229.183 Host is up (0.28s latency). Not shown: 988 closed tcp ports (conn-refused…