【rl-agents代码学习】01——总体框架

news2025/1/23 4:03:42

文章目录

  • rl-agent Get start
    • Installation
    • Usage
    • Monitoring
  • 具体代码

学习一下rl-agents的项目结构以及代码实现思路。

source: https://github.com/eleurent/rl-agents

rl-agent Get start

Installation

pip install --user git+https://github.com/eleurent/rl-agents

Usage

rl-agents中的大部分例子可以通过cd到scripts文件夹 cd scripts,执行 python experiments.py命令实现。

Usage:
  experiments evaluate <environment> <agent> (--train|--test)
                                             [--episodes <count>]
                                             [--seed <str>]
                                             [--analyze]
  experiments benchmark <benchmark> (--train|--test)
                                    [--processes <count>]
                                    [--episodes <count>]
                                    [--seed <str>]
  experiments -h | --help

Options:
  -h --help            Show this screen.
  --analyze            Automatically analyze the experiment results.
  --episodes <count>   Number of episodes [default: 5].
  --processes <count>  Number of running processes [default: 4].
  --seed <str>         Seed the environments and agents.
  --train              Train the agent.
  --test               Test the agent.

evaluate命令允许在给定的环境中评估给定的agent。例如,

# Train a DQN agent on the CartPole-v0 environment
$ python3 experiments.py evaluate configs/CartPoleEnv/env.json configs/CartPoleEnv/DQNAgent.json --train --episodes=200

每个agent都按照标准接口与环境交互:

action = agent.act(state)
next_state, reward, done, info = env.step(action)
agent.record(state, action, reward, next_state, done, info)

环境的配置文件

{
    "id": "intersection-v0",
    "import_module": "highway_env",
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 15,
        "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
        "features_range": {
            "x": [-100, 100],
            "y": [-100, 100],
            "vx": [-20, 20],
            "vy": [-20, 20]
        },
        "absolute": true,
        "order": "shuffled"
    },
    "destination": "o1"
}

agent的配置文件,核心就是"__class__": "<class 'rl_agents.agents.deep_q_network.pytorch.DQNAgent'>",利用agent_factory进行agent的创建。

{
    "__class__": "<class 'rl_agents.agents.deep_q_network.pytorch.DQNAgent'>",
    "model": {
        "type": "MultiLayerPerceptron",
        "layers": [128, 128]
    },
    "gamma": 0.95,
    "n_steps": 1,
    "batch_size": 64,
    "memory_capacity": 15000,
    "target_update": 512,
    "exploration": {
        "method": "EpsilonGreedy",
        "tau": 15000,
        "temperature": 1.0,
        "final_temperature": 0.05
    }
}

如果部分key缺失的话,会使用默认的值agent.default_config()

最后,可以在基准(baseline)测试中安排一批实验。然后在几个进程上并行执行所有实验。

# Run a benchmark of several agents interacting with environments
$ python3 experiments.py benchmark cartpole_benchmark.json --test --processes=4

基准配置文件包含环境配置列表和agent配置列表。

{
    "environments": ["configs/CartPoleEnv/env.json"],
    "agents": [
        "configs/CartPoleEnv/DQNAgent.json",
        "configs/CartPoleEnv/LinearAgent.json",
        "configs/CartPoleEnv/MCTSAgent.json"
    ]
}

Monitoring

有几种工具可用于监控agent性能:

  • Run metadata:为了可重复性,将运行所用的环境和agent配置合并,并保存到metadata.*.json文件中。
  • Gym Monitor:每次运行的主要统计数据(episode rewards, lengths, seeds)都会记录到episode_batch.*.stats.json文件中。可以通过运行scripts/analyze.py来自动可视化这些数据。
  • Logging:agent可以通过标准的Python日志记录库发送消息。默认情况下,所有日志级别为INFO的消息都会保存到logging.*.lo文件中。要保存日志级别为DEBUG的消息,请添加选项scripts/experiments.py --verbose
  • Tensorboard:默认情况下,一个tensoboard writer会记录有关有用标量、图像和模型图的信息到运行目录。可以通过运行以下命令来进行可视化:tensorboard --logdir <path-to-runs-dir>

具体代码

rl-agents核心代码集中在rl-agents文件夹和scripts文件夹中,其中,rl-agents主要实现相关的算法,scripts为相应的配置文件。
在这里插入图片描述

experiments.py为入口程序,先从它看起,其相应的用法如下:
在这里插入图片描述

Usage:
  experiments evaluate <environment> <agent> (--train|--test) [options]
  experiments benchmark <benchmark> (--train|--test) [options]
  experiments -h | --help

Options:
  -h --help              Show this screen.
  --episodes <count>     Number of episodes [default: 5].
  --no-display           Disable environment, agent, and rewards rendering.
  --name-from-config     Name the output folder from the corresponding config files
  --processes <count>    Number of running processes [default: 4].
  --recover              Load model from the latest checkpoint.
  --recover-from <file>  Load model from a given checkpoint.
  --seed <str>           Seed the environments and agents.
  --train                Train the agent.
  --test                 Test the agent.
  --verbose              Set log level to debug instead of info.
  --repeat <times>       Repeat several times [default: 1].

首先从main函数开始,根据evaluate或者benchmark执行相应的任务。暂且先从evaluate入手。

def main():
    opts = docopt(__doc__)
    if opts['evaluate']:
        for _ in range(int(opts['--repeat'])):
            evaluate(opts['<environment>'], opts['<agent>'], opts)
    elif opts['benchmark']:
        benchmark(opts)

evaluate主要完成envagent的创建以及evaluation 对象的创建,再根据选择train或test执行不同的程序。

def evaluate(environment_config, agent_config, options):
    """
        Evaluate an agent interacting with an environment.

    :param environment_config: the path of the environment configuration file
    :param agent_config: the path of the agent configuration file
    :param options: the evaluation options
    """
    logger.configure(LOGGING_CONFIG)
    if options['--verbose']:
        logger.configure(VERBOSE_CONFIG)
    env = load_environment(environment_config)
    agent = load_agent(agent_config, env)
    run_directory = None
    if options['--name-from-config']:
        run_directory = "{}_{}_{}".format(Path(agent_config).with_suffix('').name,
                                  datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
                                  os.getpid())
    options['--seed'] = int(options['--seed']) if options['--seed'] is not None else None
    evaluation = Evaluation(env,
                            agent,
                            run_directory=run_directory,
                            num_episodes=int(options['--episodes']),
                            sim_seed=options['--seed'],
                            recover=options['--recover'] or options['--recover-from'],
                            display_env=not options['--no-display'],
                            display_agent=not options['--no-display'],
                            display_rewards=not options['--no-display'])
    if options['--train']:
        evaluation.train()
    elif options['--test']:
        evaluation.test()
    else:
        evaluation.close()
    return os.path.relpath(evaluation.run_directory)

Evaluation类中主要包含以下函数:
在这里插入图片描述

__init__的一些参数说明

参数描述
env要解决的环境,可能是包装了AbstractEnv的环境
agent解决环境的AbstractAgent agent
directory工作空间目录路径
run_directory运行目录路径
num_episodes运行的episode数
trainingagent是处于训练模式还是测试模式
sim_seed环境/agent随机性源的种子
recover从文件中恢复agent参数。如果为True,则使用默认的最新保存。如果为字符串,则将其用作路径。
display_env渲染环境,并有一个监视器录制其视频
display_agent如果支持,将agent图形添加到环境查看器中
display_rewards通过episodes显示agent的性能
close_env当评估结束时,是否应该关闭环境
step_callback_fn在每个环境步骤之后调用的回调函数。它接受以下参数:(episode, env, agent, transition, writer)。

首先看一下train,根据agent是否有batched属性,分为run_batched_episodesrun_episodes

    def train(self):
        self.training = True
        if getattr(self.agent, "batched", False):
            self.run_batched_episodes()
        else:
            self.run_episodes()
        self.close()

run_episodes就是一般强化学习的基本过程,注意其中的reset step 等函数都是经过封装的。实现自己的算法时需要注意。run_batched_episodes则主要实现一些并行计算的任务,这一部分等之后再详细介绍。

在这里插入图片描述

    def run_episodes(self):
        for self.episode in range(self.num_episodes):
            # Run episode
            terminal = False
            self.reset(seed=self.episode)
            rewards = []
            start_time = time.time()
            while not terminal:
                # Step until a terminal step is reached
                reward, terminal = self.step()
                rewards.append(reward)

                # Catch interruptions
                try:
                    if self.env.unwrapped.done:
                        break
                except AttributeError:
                    pass

            # End of episode
            duration = time.time() - start_time
            self.after_all_episodes(self.episode, rewards, duration)
            self.after_some_episodes(self.episode, rewards)

test为模型测试部分

    def test(self):
        """
        Test the agent.

        If applicable, the agent model should be loaded before using the recover option.
        """
        self.training = False
        if self.display_env:
            self.wrapped_env.episode_trigger = lambda e: True
        try:
            self.agent.eval()
        except AttributeError:
            pass
        self.run_episodes()
        self.close()

其中eval也需要进行重写。

    def eval(self):
        """
            Set to testing mode. Disable any unnecessary exploration.
        """
        pass

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1202166.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第一章:线性查找

系列文章目录 文章目录 系列文章目录前言一、线性查找二、实现查找算法三、循环不变量四、复杂度分析五、常见复杂度六、测试算法性能总结 前言 从线性查找入手算法。 一、线性查找 线性查找目的在线性数据结构中一个一个查找目标元素输入数组和目标元素输出目标元素所在的索…

LCA

定义 最近公共祖先简称 LCA&#xff08;Lowest Common Ancestor&#xff09;。两个节点的最近公共祖先&#xff0c;就是这两个点的公共祖先里面&#xff0c;离根最远的那个。 性质 如果 不为 的祖先并且 不为 的祖先&#xff0c;那么 分别处于 的两棵不同子树中&#…

Clickhouse学习笔记(13)—— Materialize MySQL引擎

该引擎用于监听 binlog 事件&#xff0c;类似于canal、Maxwell等组件 ClickHouse 20.8.2.3 版本新增加了 MaterializeMySQL 的 database 引擎&#xff0c;该 database 能映射到 MySQL中的某个database &#xff0c;并自动在ClickHouse中创建对应ReplacingMergeTree。 ClickHous…

【Python Opencv】图片与视频的操作

文章目录 前言一、opencv图片1.1 读取图像1.2 显示图像1.3 写入图像1.4 示例代码 二、Opencv视频2.1 从相机捕获视频获取摄像头一帧一帧读取显示图片VideoCapture 中的get和set函数示例代码 2.2 从文件播放视频示例代码 2.3 保存视频示例代码 总结 前言 在计算机视觉和图像处理…

As Const:一个被低估的 TypeScript 特性

目录 理解 as const TypeScript的期望与现实 解决方案&#xff1a;as const 与 object.freeze 的比较 一个配合 as const 的更清洁的 go to root 函数 使用 as const 提取对象值 基于Vue3.0的优秀低代码项目 你有没有感觉 TypeScript中可能有一些被低估但却非常有用的工…

解析JSON字符串:属性值为null的时候不被序列化

如果希望属性值为null及不序列化&#xff0c;只序列化不为null的值。 1、测试代码 配置代码&#xff1a; mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); 或者通过注解JsonInclude(JsonInclude.Include.NON_NULL) //常见问题2&#xff1a;属性为null&a…

操作系统 | 虚拟机及linux的安装

​ &#x1f308;个人主页&#xff1a;Sarapines Programmer&#x1f525; 系列专栏&#xff1a;《操作系统实验室》&#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 目录结构 1.操作系统实验之虚拟机及linux的安装 1.1 实验目的 1.2 实验内容 1.3 实验步骤 …

修改django开发环境runserver命令默认的端口

runserver默认8000端口 虽然python manage.py runserver 8080 可以指定端口&#xff0c;但不想每次runserver都添加8080这个参数 可以通过修改manage.py进行修改&#xff0c;只需要加三行&#xff1a; from django.core.management.commands.runserver import Command as Ru…

蓝桥杯 选择排序

选择排序的思想 选择排序的思想和冒泡排序类似&#xff0c;是每次找出最大的然后直接放到右边对应位置&#xff0c;然后将最 右边这个确定下来&#xff08;而不是一个一个地交换过去&#xff09;。 再来确定第二大的&#xff0c;再确定第三大的… 对于数组a[]&#xff0c;具体…

虹科方案 | 汽车电子电气架构设计仿真解决方案

来源&#xff1a;虹科汽车电子 虹科方案 | 汽车电子电气架构设计仿真解决方案 导读 本文将介绍面向服务&#xff08;SOA&#xff09;的汽车TSN网络架构&#xff0c;并探讨RTaW-Pegase仿真与设计软件在TSN网络设计中的应用。通过RTaW将设计问题分解&#xff0c;我们可以更好地理…

低代码、零代码开源与不开源:区别解析

在如今日益发展的数字时代&#xff0c;程序开发变得越来越重要。为了实现日益提高的业务需求&#xff0c;开发人员必须能够以更高效、更灵活的方式构建和交货软件解决方案。低代码和零代码开源是近几年流行的两种开发方法。本文将探讨它们与传统非开源程序开发的差别&#xff0…

javaSE的发展历史以及openjdk和oracleJdk

1 JavaSE 的发展历史 1.1 Java 语言的介绍 SUN 公司在 1991 年成立了一个称为绿色计划&#xff08;Green Project&#xff09;的项目&#xff0c;由 James Gosling&#xff08;高斯林&#xff09;博士领导&#xff0c;绿色计划的目的是开发一种能够在各种消费性电子产品&…

PostGIS学习教程一:PostGIS介绍

一、什么是空间数据库 PostGIS是一个空间数据库&#xff0c;Oracle Spatial和SQL Server(2008和之后版本&#xff09;也是空间数据库。 但是这意味着什么&#xff1f;是什么使普通数据库变成空间数据库&#xff1f; 简短的答案是… 空间数据库像存储和操作数据库中其他任何…

前端面试之事件循环

什么是事件循环 首先&#xff0c; JavaScript是一门单线程的语言&#xff0c;意味着同一时间内只能做一件事&#xff0c;这并不意味着单线程就是阻塞&#xff0c;而是实现单线程非阻塞的方法就是事件循环 在JavaScript中&#xff0c;所欲任务都可以分为&#xff1a; 同步任务…

大洋钻探系列之二IODP 342航次是干什么的?(上)

本文简单介绍一下大洋钻探IODP 342航次&#xff0c;从中&#xff0c;我们一窥大洋钻探航次的风采。 IODP342的航次报告在网络上可以下载&#xff0c;英文名字叫《Integrated Ocean Drilling ProgramExpedition 342 Preliminary Report》&#xff0c;航次研究的主要内容是纽芬兰…

ts学习02-数据类型

新建index.html <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </h…

同为科技(TOWE)主副控智能自动断电桌面PDU插排

在这个快节奏的现代社会&#xff0c;我们越来越需要智能化的产品来帮助我们提高生活质量和工作效率&#xff0c;同时&#xff0c;为各种家用电器及电子设备充电成为不少消费者新的痛点。桌面插排如何高效、安全地管理这些设备&#xff0c;成为了一个亟待解决的问题。同为科技&a…

享受JoySSL证书买赠活动,提升您的网站安全和用户信任!

互联网时代&#xff0c;网站安全性和用户信任度变得尤为重要。作为您网站的保护盾&#xff0c;SSL证书是确保数据传输安全和建立可信连接的关键组成部分。在这个背景下&#xff0c;我们非常激动地宣布JoySSL平台推出了令人兴奋的SSL证书买赠活动&#xff1a;买二送一&#xff0…

如何选择共享wifi项目服务商,需要注意哪些?

在移动互联网时代&#xff0c;无线网络已经成为人们生活中不可或缺的一部分。随着5G时代的到来&#xff0c;共享WiFi项目成为了市场上备受关注的焦点。在众多共享WiFi公司中&#xff0c;如何选择共享wifi项目服务商合作&#xff0c;今天我们就来盘点下哪些公司可靠&#xff01;…

CC1310F128RSMR Sub-1GHz超低功耗无线微控制器芯片

CC1310F128RSMR QFN-32 Sub-1GHz超低功耗无线微控制器 CC1310F128RSMR是一款低成本、 超低功耗、Sub-1 GHz射频器件&#xff0c;它是Simplel ink微控制器(MCU)平台的一部分。该平台由Wi- Fi组成、蓝牙低功耗&#xff0c;Sub-1 GHz&#xff0c;以太网&#xff0c;Zigbee线程和主…