ETH开源PPO算法学习

news2024/12/24 21:02:48

前言

项目地址:https://github.com/leggedrobotics/rsl_rl

项目简介:快速简单的强化学习算法实现,设计为完全在 GPU 上运行。这段代码是 NVIDIA Isaac GYM 提供的 rl-pytorch 的进化版。

下载源码,查看目录,整个项目模块化得非常好,每个部分各司其职。下面我们自底向上地进行讲解加粗的部分。

rsl_rl/
│ __init__.py

├─algorithms/
│ │ __init__.py
│ │ ppo.py # PPO算法的实现
│ │
├─env/
│ │ __init__.py
│ │ vec_env.py # 实现并行处理多个环境的向量化环境
│ │
├─modules/
│ │ __init__.py
│ │ actor_critic.py # 定义 Actor-Critic 网络结构
│ │ actor_critic_recurrent.py # 定义包含循环层的 Actor-Critic 网络
│ │ normalizer.py # 数据正规化工具,有助于训练过程的稳定性
│ │
├─runners/
│ │ __init__.py
│ │ on_policy_runner.py # 实现用于执行 on-policy 算法训练循环的运行器
│ │
├─storage/
│ │ __init__.py
│ │ rollout_storage.py # 存储和管理策略 rollout 数据的工具
│ │
└─utils/
│ __init__.py
│ neptune_utils.py # 用于与 Neptune.ai 集成的工具
│ utils.py # 通用实用工具函数
│ wandb_utils.py # 用于与 Weights & Biases 集成的工具

rollout 数据储存和管理(rollout_storage.py)

定义了一个名为 RolloutStorage 的类,用于存储和管理在强化学习训练过程中从环境中收集到的数据(称为rollouts)。

  • 定义Transition

用于存储单个时间步的所有相关数据,包括观察值、动作、奖励、完成标志(dones)、值函数估计、动作的对数概率、动作的均值和标准差,以及可能的隐藏状态(对于使用循环网络的情况)。

  • 特权观察值(Privileged Observations)

除了self.observations外还有self.privileged_observations的使用,在强化学习中是指那些在训练期间可用但在实际部署或测试时不可用的额外信息。这些信息通常提供了环境的内部状态或其他有助于学习的提示,但在现实世界应用中可能难以获得或完全不可用。在训练期间使用特权观察值的一种常见方法是通过教师-学生架构(我们常常也称作特权学习),其中一个拥有全部信息的教师模型(可以访问特权观察值)来指导一个学生模型(只能访问普通观察值)。学生模型的目标是模仿教师模型的决策,尽管它没有直接访问特权信息。

  • 奖励和优势的计算
    def compute_returns(self, last_values, gamma, lam):
        advantage = 0
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        # Compute and normalize the advantages
        self.advantages = self.returns - self.values
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

这段代码实现的是在强化学习中计算回报(returns)和优势(advantages)的逻辑,具体是使用了一种称为广义优势估算(Generalized Advantage Estimation, GAE)的方法。GAE是一种权衡偏差和方差以及平滑回报信号的技术,由以下几个数学公式定义:

  1. TD残差(Temporal Difference Residual):
    δ t = R t + γ V ( S t + 1 ) ( 1 − d o n e t ) − V ( S t ) \delta_t = R_t + \gamma V(S_{t+1}) (1 - done_t) - V(S_t) δt=Rt+γV(St+1)(1donet)V(St)
    其中, δ t \delta_t δt是时刻 t t t的TD残差, R t R_t Rt是奖励, γ \gamma γ是折扣因子, V ( S t ) V(S_t) V(St)是状态 S t S_t St的价值函数估计, d o n e t done_t donet是表示当前状态是否为终止状态的指示函数(如果当前状态为终止状态,则 d o n e t = 1 done_t = 1 donet=1;否则, d o n e t = 0 done_t = 0 donet=0)。如果 d o n e t = 1 done_t = 1 donet=1,那么 γ V ( S t + 1 ) \gamma V(S_{t+1}) γV(St+1)项将为 0,因为终止状态之后没有未来回报。

  2. GAE优势估计:
    A t G A E ( γ , λ ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l A_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} AtGAE(γ,λ)=l=0(γλ)lδt+l
    在代码中,这个无限求和是通过迭代地计算来近似的,具体的迭代公式为:
    A t = δ t + ( γ λ ) A t + 1 ( 1 − d o n e t ) A_t = \delta_t + (\gamma \lambda) A_{t+1} (1 - done_t) At=δt+(γλ)At+1(1donet)
    其中, A t A_t At是时刻 t t t的优势估计, λ \lambda λ是用来平衡TD估计和蒙特卡罗估计之间权重的参数。

  3. 回报的计算:
    G t = A t + V ( S t ) G_t = A_t + V(S_t) Gt=At+V(St)
    其中, G t G_t Gt是时刻 t t t的回报估计。

代码中使用的变量名与数学符号的对应关系:

变量名数学符号含义
rewards[step] R t R_t Rt时刻 t t t的奖励
gamma γ \gamma γ折扣因子,用于计算未来奖励的现值
values[step] V ( S t ) V(S_t) V(St)状态 S t S_t St在当前策略下的价值函数估计
dones[step] d o n e t done_t donet指示当前状态 S t S_t St是否为终止状态的标志(1 表示终止,0 表示非终止)
delta δ t \delta_t δt时刻 t t t的 TD 残差
advantage A t A_t At时刻 t t t的优势估计,根据 GAE 方法计算
lam λ \lambda λ用于 GAE 计算中平衡 TD 估计和蒙特卡罗估计之间权重的参数
returns[step] G t G_t Gt时刻 t t t的回报估计
advantages A t n o r m A_t^{norm} Atnorm标准化后的优势估计
mu_A, sigma_A μ A \mu_A μA, σ A \sigma_A σA优势估计的平均值和标准差
epsilon ϵ \epsilon ϵ避免除零错误而加的小常数,通常取值为 1e-8

代码中的循环从最后一个转换开始向前迭代,使用以上的数学公式来计算每一步的优势和回报。最后,它还对优势进行了标准化处理,即从每个优势中减去所有优势的平均值,并除以标准差,以减少训练期间的方差并加速收敛。标准化公式如下:
A t n o r m = A t − μ A σ A + ϵ A_t^{norm} = \frac{A_t - \mu_A}{\sigma_A + \epsilon} Atnorm=σA+ϵAtμA
其中, μ A \mu_A μA是优势的平均值, σ A \sigma_A σA是优势的标准差, ϵ \epsilon ϵ​ 是为了防止除以零而加的一个小常数(在代码中为 1e-8)。

  • 轨迹的平均长度

类中并没有显式存储轨迹的长度,轨迹长度隐含在self.dones之中。代码中使用的方法是:将每个环境中最后一步置为‘1’,然后flatten(展开)、拼接所有环境中的dones得到flat_dones,差分数组中为‘1’位置的索引得到智能体在每个环境中的步数,即轨迹长度。这个统计量有助于了解训练过程中智能体的表现。

  • mini-batch迭代器

mini_batch_generator 函数通过在多个训练周期(num_epochs)内,从经验回放缓冲区中随机选择小批量数据(包括观察值 observations、动作 actions、奖励 rewards 等)来生成小批量数据集。该函数利用 torch.randperm 生成随机索引 indices 来随机化数据抽样,进而支持基于批处理的学习方法,如梯度下降。通过每次只处理必要的数据量,该生成器在优化模型参数的同时,也优化了内存使用,确保了训练过程的高效性和灵活性。

(未完待续)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1480163.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue cesium加载点与定位到指定位置

vue cesium定位到指定位置 window.viewer.camera.flyTo({destination: Cesium.Cartesian3.fromDegrees(point.longDeg, point.latDeg, 6500000), orientation: {heading: 6.2079384332084935, roll: 0.00031509431759868534, pitch: -1.535}, duration: 3})vue cesium加载点 …

Talk|卡内基梅隆大学熊浩宇:Open-world Mobile Manipulation-开放世界机器人学习系统

本期为TechBeat人工智能社区第575期线上Talk。 北京时间2月29日(周四)20:00,卡内基梅隆大学研究生—熊浩宇的Talk已准时在TechBeat人工智能社区开播! 他与大家分享的主题是: “Open-world Mobile Manipulation-开放世界机器人学习系统”,将向…

sora会是AGI的拐点么?

©作者|谢国斌 来源|神州问学 OpenAI近期发布的Sora是一个文本到视频的生成模型。这项技术可以根据用户输入的描述性提示生成视频,延伸现有视频的时间,以及从静态图像生成视频。Sora可以创建长达一分钟的高质量视频,展示出对用户提示的精…

经典语义分割(一)全卷积神经网络FCN

经典语义分割(一)全卷积神经网络FCN 1 FCN网络介绍 FCN(Fully Convolutional Networks,全卷积网络) 用于图像语义分割,它是首个端对端的针对像素级预测的全卷积网络,自从该网络提出后,就成为语义分割的基…

【论文阅读笔记】Explicit Visual Prompting for Low-Level Structure Segmentations

1.介绍 Explicit Visual Prompting for Low-Level Structure Segmentations 低级结构分割的显式视觉提示 2023年发表在IEEE CVPR Paper Code 2.摘要 检测图像中低级结构(低层特征)一般包括分割操纵部分、识别失焦像素、分离阴影区域和检测隐藏对象。虽…

前端面试知识点合集

原型和原型链 任何函数都可以作为构造函数。当该函数通过 new 关键字调用的时候,就称之为构造函数。 var Parent function(){}//定义一个函数,那它只是一个普通的函数,不能称它为构造函数var instance new Parent(); //这时这个Parent就不…

论文阅读_代码生成模型_CodeGeeX

英文名称: CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Evaluations on HumanEval-X 中文名称: CodeGeeX:一种用于代码生成的预训练模型,并在HumanEval-X上进行多语言评估 链接: https://arxiv.org/abs/2303.17568 代码: http…

【六袆 - MySQL】MySQL 5.5及更高版本中,InnoDB是新表的默认存储引擎;

InnoDB 这是一个MySQL组件,结合了高性能和事务处理能力,以确保可靠性、健壮性和并发访问。它体现了ACID设计哲学。它作为一个存储引擎存在,处理使用ENGINEINNODB子句创建的或修改的表。请参阅第14章“InnoDB存储引擎”以获取有关架构细节和管…

AOP案例(黑马学习笔记)

需求 需求:将案例中增、删、改相关接口的操作日志记录到数据库表中 ● 就是当访问部门管理和员工管理当中的增、删、改相关功能接口时,需要详细的操作日志,并保存在数据表中,便于后期数据追踪。 操作日志信息包含: ●…

本地复制文本无法在Ubuntu终端中粘贴问题

在公司,安装Ubuntu环境后无法粘贴。 查询并自己实践后,解决方法如下: 1. sudo apt-get autoremove open-vm-tools 2. sudo apt-get install open-vm-tools-desktop 3.重启虚拟机 又可以愉快的复制粘贴了

解析平面设计师的任务:4个要点带你全面了解!

平面设计师是整个市场相对稀缺但需求非常大的职业,许多设计师都是主要公司竞争的对象。平面设计在我们的日常生活中非常常见,涉及广告设计、标志设计、名片设计等领域。因此,本文将从四个方面详细介绍平面设计。 什么是平面设计 说到平面设…

低功耗的CMOS实时时钟/日历电路,内置报警和定时器功能采用 DIP8、 SOP8、 TSSOP8三种封装形式,应用于移动电话,便携仪器上——D8563

D8563是低功耗的CMOS实时时钟/日历电路,它提供一个可编程时钟输出,一个中断输出和掉电检测器,所有的地址和数据通过IC总线接口串行传递。最大总线速度为400Kbitss每次读写数据后,内嵌的字地址寄存器会自动产生增量。 主要特点: …

Semantic human matting

1.introduction 数据集包括,时尚模特数据集,超过18.8w张模特图,从中选出35311张图片,DIM数据集,仅包含人类的图像,202个前景图像,背景来自coco数据集和互联网,背景图不含人类&#x…

输入一个字符,判断该数是否为素数

//输入一个字符,判断该数是否为素数,若是,输出该数,若否,输出大于该整数的第一个素数。例如,输入:14 输出:17(因为17是大于14的第一个素数) 代码&#xff1a…

从DDR到DDR2的变化

1、DDR2设计思路 前文分别讲解了SDRAM的工作原理和SDRAM到DDR的变化,DDR采用双沿传输数据,为了提高传输数据的速率,先后推出了DDR-200、DDR-266、DDR-333、DDR-400,后面的数字表示数据传输的速率,对应的时钟频率分别为…

不容错过!这7款视频格式转换器免费版真的好用【全】

随着数字媒体的不断发展,视频制作和分享已经成为人们生活中的常态。然而,不同的设备和平台对视频格式的要求却各不相同,这给视频编辑和分享带来了一定的困扰。 因此,免费的视频格式转换器变得至关重要。以下是7款视频格式转换器免…

什么是WiFi 7

福建厦门微思网络始于2002年,面向全国招生! 主要课程:华为、思科、红帽、Oracle、VMware、CISP安全系列、PMP....... 网络工程师实用课程华为HCIA课程介绍 网络工程师使用课程华为HCIP课程介绍 网络工程师使用课程华为HCIE课程介绍 WiFi …

Vue3使用JSX/TSX

文章目录 1. 什么是 JSX & TSX?JSX(JavaScript XML)TSX(TypeScript XML) 2.Vue3 中使用 TSX基本渲染 & 响应式 & 事件 3.JSX 和 template 哪个好呢?总结 1. 什么是 JSX & TSX? 提示:JSX…

Premiere模板,唯美大气光斑闪烁效果照片展示视频制作模板

Premiere模板,唯美大气光斑闪烁效果照片展示视频制作PR电子相册模板mogrt下载。 特点:Premiere Pro 2023或更高版本,高清分辨率:19201080,每秒25帧的帧速率,包括教程视频。来自PR模板网,下载地址&#xff1…

基于ssm学生公寓管理系统的设计与开发论文

学生公寓管理系统的设计与实现 摘要 如今,科学技术的力量越来越强大,通过结合较为成熟的计算机技术,促进了学校、医疗、商城等许多行业领域的发展。为了顺应时代的变化,各行业结合互联网、人工智能等技术,纷纷开展了…