【强化学习】gymnasium自定义环境并封装学习笔记

news2024/11/27 13:30:14

【强化学习】gymnasium自定义环境并封装学习笔记

  • gym与gymnasium简介
    • gym
    • gymnasium
  • gymnasium的基本使用方法
  • 使用gymnasium封装自定义环境
    • 官方示例及代码
    • 编写环境文件
      • __init__()方法
      • reset()方法
      • step()方法
      • render()方法
      • close()方法
  • 注册环境
    • 创建包 Package(最后一步)
    • 创建自定义环境示例
  • 参考文献

gym与gymnasium简介

gym

  • gym(OpenAI Gym)和gymnasium是两个不同的Python库,它们都旨在为强化学习研究提供环境和工具

  • gym出现的原因:不同于监督学习那样需要的是数据集,强化学习需要的是运行任务所需的环境,研究人员需要拥有标准化的环境和模块化的强化学习代码,方便复用以及方便研究人员能够在相同的环境和条件下测试算法

  • gym通过提供一个统一的接口,

  • gym(OpenAI Gym)是由OpenAI团队开发的,是最早和最广泛使用的强化学习环境库之一

  • 用于开发和比较强化学习算法的工具包和测试平台,提供了一个统一的接口来控制和交互各种环境

  • 截止2023年,Gym 已经不再更新或维护,最新版本为v0.26.2
    在这里插入图片描述

  • Gym的最新版本为v0.26.2,并且从这个版本开始,Gym的维护工作由Farama Foundation接手,并推出了Gymnasium

  • gym官网
    在这里插入图片描述

gymnasium

  • 所有Gym的开发工作已经转移到Gymnasium
  • gymnasium是一个较新的库,它试图解决gym中的一些限制和问题,并提供更现代化的接口
  • gymnasium设计时考虑了与gym的兼容性。它提供了一个兼容层,使得大多数gym环境可以直接在gymnasium中使用,无需或只需很少的修改
  • gymnasium官网
    在这里插入图片描述

gymnasium的基本使用方法

暂时先略过,日后补上,先介绍gymnasium封装自定义环境


使用gymnasium封装自定义环境

  • gymnasium官方介绍封装自定义环境的文档,本文主要基于此文档
  • 官方提供了示例代码,链接在此
  • 安装gymnasium命令:
pip install gymnasium

官方示例及代码

  • 官方使用的示例代码结构,如下所示
    在这里插入图片描述
  • wrappers是指包装器,用于修改或增强现有环境的行为,而不需要直接修改环境的源代码
  • 使用 wrappers 的一个关键优势是它们提供了一种灵活的方式来修改和扩展环境的功能,而不需要改变环境本身的实现。这使得研究人员可以专注于算法的开发,同时利用 wrappers 来适应不同的实验条件和研究目标。
  • env文件夹下的文件是环境名字

在这里插入图片描述

  • 在命令行中可使用tree 命令查看目录及文件结构,windows下需要使用/F参数来显示文件
tree /F C:\path\to\directory

在这里插入图片描述

编写环境文件

  • 所有自定义环境必须继承抽象类gymnasium.Env
  • 同时需要定义metadata,在 Gym 环境中,metadata 字典包含了环境的元数据,这些数据提供了关于环境行为和特性的额外信息

“render_modes”: 这个键的值是一个列表,指明了环境支持的渲染模式。在这个例子中,环境支持两种渲染模式:
“human”: 这种模式通常是指在屏幕上以图形界面的形式渲染环境,适合人类观察者观看。
“rgb_array”: 这种模式下,环境的渲染结果会以 RGB 数组的形式返回,这可以用于机器学习算法的输入,或者进行进一步的处理和分析。
“render_fps”: 这个键表示环境渲染的帧率,即每秒钟可以渲染的帧数。在这个例子中,4 表示环境将以每秒 4 帧的速率进行渲染。这通常用于控制渲染速度,使动画的播放更加平滑或符合特定的显示需

在这里插入图片描述

  • 在环境文件中需要实现__init__(),reset().setp(),render(),close()等方法,确保环境能够按照强化学习的标准工作流程运行
  • 定义action_space,智能体可以执行的动作类型和范围;定义observation_space,智能体可以观察到的状态的类型和范围
  • from gymnasium import spaces
  • 连续的空间使用spaces.Box 定义,low 和 high 参数指定了取值范围。
  • 离散的空间使用spaces.Discrete,参数指定可能的数量

init()方法

  • 初始化方法,用于设置环境的初始状态。这里可以定义环境参数、初始化状态空间和动作空间等
  • 定义 action_space 和 observation_space时,需要从 Gymnasium 的 spaces 模块导入spaces
  • spaces 模块提供了多种空间类型,用于表示强化学习环境中可能的动作和观察的类型和结构
    在这里插入图片描述
import numpy as np
import pygame

import gymnasium as gym
from gymnasium import spaces


class GridWorldEnv(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, size=5):
        self.size = size  # The size of the square grid
        self.window_size = 512  # The size of the PyGame window

        # Observations are dictionaries with the agent's and the target's location.
        # Each location is encoded as an element of {0, ..., `size`}^2, i.e. MultiDiscrete([size, size]).
        self.observation_space = spaces.Dict(
            {
                "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )

        # We have 4 actions, corresponding to "right", "up", "left", "down"
        self.action_space = spaces.Discrete(4)

        """
        The following dictionary maps abstract actions from `self.action_space` to
        the direction we will walk in if that action is taken.
        I.e. 0 corresponds to "right", 1 to "up" etc.
        """
        self._action_to_direction = {
            0: np.array([1, 0]),
            1: np.array([0, 1]),
            2: np.array([-1, 0]),
            3: np.array([0, -1]),
        }

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None

reset()方法

  • 用于重置环境状态,在每个训练周期(episode)开始时,reset() 方法被调用以重置环境到一个初始状态
  • 每次训练周期结束并且接收到结束信号(done 标志)时,会调用 reset 方法来重置环境状态
  • 用户可以通过 reset 方法传递一个 seed 参数,用于初始化环境使用的任何随机数生成器,确保环境行为的确定性和可复现性
def reset(self, seed=None, options=None):
    # We need the following line to seed self.np_random
    super().reset(seed=seed)

    # Choose the agent's location uniformly at random
    self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)

    # We will sample the target's location randomly until it does not coincide with the agent's location
    self._target_location = self._agent_location
    while np.array_equal(self._target_location, self._agent_location):
        self._target_location = self.np_random.integers(
            0, self.size, size=2, dtype=int
        )

    observation = self._get_obs()
    info = self._get_info()

    if self.render_mode == "human":
        self._render_frame()

    return observation, info

step()方法

  • step()方法是环境与智能体交互的核心,包含了环境逻辑的核心部分
  • step()方法处理动作,更新环境状态,并返回五个值组成的元组(observation, reward, terminated, truncated, info):观察(observation)、奖励(reward)、是否终止(terminated)、是否截断(truncated)和附加信息(info)

五元组的含义(observation, reward, terminated, truncated, info)

观察(Observation):这是环境状态的表示,智能体根据这个观察来选择动作。观察可以是状态的一部分或全部,也可以是经过加工的信息,如图像、向量等。观察是智能体与环境交互的直接输入。
奖励(Reward):这是一个标量值,表示智能体执行动作后从环境中获得的即时反馈。奖励用于指导智能体学习哪些行为是好的,哪些是不好的。在许多任务中,智能体的目标是最大化其获得的总奖励。
是否终止(Terminated/Done):这是一个布尔值,表示当前周期(episode)是否结束。如果为 True,则表示智能体已经完成了任务,或者环境已经达到了一个终止状态,智能体需要重新开始新的周期。
是否截断(Truncated):这也是一个布尔值,与 done 相似,但表示周期结束的原因可能不是任务完成,而是其他原因,如超时、达到某个特定的中间状态或违反了某些规则。在某些实现中,truncated 可能与 done 相同或不被使用。
附加信息(Info):这是一个字典,包含除观察、奖励、终止和截断之外的额外信息。这些信息可以包括关于状态转换的元数据,如是否处于探索阶段、环境的内部计数器、额外的性能评估指标等。

def step(self, action):
    # Map the action (element of {0,1,2,3}) to the direction we walk in
    direction = self._action_to_direction[action]
    # We use `np.clip` to make sure we don't leave the grid
    self._agent_location = np.clip(
        self._agent_location + direction, 0, self.size - 1
    )
    # An episode is done iff the agent has reached the target
    terminated = np.array_equal(self._agent_location, self._target_location)
    reward = 1 if terminated else 0  # Binary sparse rewards
    observation = self._get_obs()
    info = self._get_info()

    if self.render_mode == "human":
        self._render_frame()

    return observation, reward, terminated, False, info
  • info可通过_get_info方法获取,该方法用于收集和返回除了观察和奖励之外的其他有用信息。这些信息可以包括关于环境状态的额外数据
  • _get_obs方法负责将环境的内部状态转换为智能体可以观察的形式,通常涉及到从环境状态中提取相关信息,并将其格式化为智能体能够理解和使用的数据结构
def _get_obs(self):
    return {"agent": self._agent_location, "target": self._target_location}
def _get_info(self):
    return {
        "distance": np.linalg.norm(
            self._agent_location - self._target_location, ord=1
        )
    }

render()方法

  • render 方法用于将环境的状态可视化
  • 使用 Gymnasium 创建自定义环境时,PyGame 是一种流行的库,用于渲染环境的视觉表示。PyGame 允许创建图形窗口,并将环境的状态绘制到屏幕上,这对于需要视觉反馈的强化学习任务非常有用

渲染模式:

  • “human”:以图形界面的形式渲染,适用于人类观察者。
  • “rgb_array”:返回一个 RGB 图像数组,可以用于机器学习模型或进一步处理。
  • 下面为示例代码中的render方法
def render(self):
    if self.render_mode == "rgb_array":
        return self._render_frame()

def _render_frame(self):
    if self.window is None and self.render_mode == "human":
        pygame.init()
        pygame.display.init()
        self.window = pygame.display.set_mode(
            (self.window_size, self.window_size)
        )
    if self.clock is None and self.render_mode == "human":
        self.clock = pygame.time.Clock()

    canvas = pygame.Surface((self.window_size, self.window_size))
    canvas.fill((255, 255, 255))
    pix_square_size = (
        self.window_size / self.size
    )  # The size of a single grid square in pixels

    # First we draw the target
    pygame.draw.rect(
        canvas,
        (255, 0, 0),
        pygame.Rect(
            pix_square_size * self._target_location,
            (pix_square_size, pix_square_size),
        ),
    )
    # Now we draw the agent
    pygame.draw.circle(
        canvas,
        (0, 0, 255),
        (self._agent_location + 0.5) * pix_square_size,
        pix_square_size / 3,
    )

    # Finally, add some gridlines
    for x in range(self.size + 1):
        pygame.draw.line(
            canvas,
            0,
            (0, pix_square_size * x),
            (self.window_size, pix_square_size * x),
            width=3,
        )
        pygame.draw.line(
            canvas,
            0,
            (pix_square_size * x, 0),
            (pix_square_size * x, self.window_size),
            width=3,
        )

    if self.render_mode == "human":
        # The following line copies our drawings from `canvas` to the visible window
        self.window.blit(canvas, canvas.get_rect())
        pygame.event.pump()
        pygame.display.update()

        # We need to ensure that human-rendering occurs at the predefined framerate.
        # The following line will automatically add a delay to keep the framerate stable.
        self.clock.tick(self.metadata["render_fps"])
    else:  # rgb_array
        return np.transpose(
            np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
        )

close()方法

  • close 方法用于在环境不再使用时进行清理操作,例如关闭图形界面窗口、释放资源或执行其他必要的清理任务
  • 是一个没有参数也没有返回值的方法
  • 如果环境使用 PyGame 或其他图形库创建了渲染窗口,close 方法应该关闭这些窗口。
def close(self):
    if self.window is not None:
        pygame.display.quit()
        pygame.quit()

注册环境

  • 编写完上述与环境相关的代码后,需要注册自定义环境
  • 注册自定义环境是为了使gymnasium检测到该环境
from gymnasium.envs.registration import register

register(
     id="gym_examples/GridWorld-v0",
     entry_point="gym_examples.envs:GridWorldEnv",
     max_episode_steps=300,
)
  • environment ID由三部分组成,①命名空间gym_examples(可选) ②强制名称GridWorld ③版本v0(可选)
  • entry_point参数在注册自定义环境时使用,它指定了如何导入这个环境类
  • 格式通常是module:classname
  • module 是包含环境类的 Python 模块的路径。
    classname 是环境中具体的类的名称
  • 其他可指定的参数如下所示:

在这里插入图片描述

  • 经过注册的自定义环境GridWorldEnv可由以下命令创建
env = gymnasium.make('gym_examples/GridWorld-v0')
  • gym-examples/gym_examples/envs/init.py 文件中需要包含以下的内容
from gym_examples.envs.grid_world import GridWorldEnv

创建包 Package(最后一步)

  • 将代码构建为python的包,方便地在不同项目中重用自定义的环境代码
  • gym-examples/setup.py中写入以下内容
from setuptools import setup

setup(
    name="gym_examples",
    version="0.0.1",
    install_requires=["gymnasium==0.26.0", "pygame==2.1.0"],
)

此处可以将"==“改为”>=",如果不打算做图形化可以删去pygame==2.1.0

安装自定义环境(在包含 setup.py 的目录中执行)

pip install -e .
  • 安装成功后会生成gym_examples.egg-info文件夹

创建自定义环境示例

  • 使用以下命令
import gym_examples
env = gymnasium.make('gym_examples/GridWorld-v0')
  • 传参的版本
import gym_examples
env = gymnasium.make('gym_examples/GridWorld-v0', size=10)

参考文献

  1. Gymnasium Documentation:Make your own custom environment
  2. 深度强化学习:gymnasium下创建自己的环境(保姆式教程)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1816299.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【QT5】<知识点> QT常用知识(更新中)

目录 一、更改文本颜色和格式 二、QT容器类 三、字符串与整数、浮点数之间的转换 四、QString常用功能 五、SpinBox的属性介绍 六、滑动、滚动、进度条和表盘LCD 七、时间、日期、定时器 一、更改文本颜色和格式 动态设置字体粗体:QFont对象的setBold方法动态…

Yapi代码执行 waf绕过实战记录

本文记录了2021年一次有趣的客户目标测试实战。这次经历颇为特别,因此我将其整理成笔记,并在此分享,希望对大家有所帮助。 事件起因 疫情在家办公,准备开始划水的一天,这时接到 boss 的电话说要做项目,老…

微调技术:人工智能领域的神奇钥匙

在人工智能的浪潮中,深度学习技术凭借其强大的数据处理和学习能力,已成为推动科技进步的重要引擎。然而,深度学习模型的训练往往需要大量的数据和计算资源,这在某些特定场景下成为了限制其发展的瓶颈。为了解决这个问题&#xff0…

SolidWorks 2016 SP5安装教程

软件介绍 Solidworks软件功能强大,组件繁多。 Solidworks有功能强大、易学易用和技术创新三大特点,这使得SolidWorks 成为领先的、主流的三维CAD解决方案。 SolidWorks 能够提供不同的设计方案、减少设计过程中的错误以及提高产品质量。SolidWorks 不仅…

JavaWeb6 Tomcat+postman请求、响应

Web服务器 对HTTP协议操作进行封装,简化web程序开发 部署web项目,对外提供网上信息浏览服务 Tomcat 轻量级web服务器,支持servlet,jsp等少量javaEE规范 也被称为web容器,servlet容器 Springboot有内置Tomcat nginx…

网络编程2----UDP简单客户端服务器的实现

首先我们要知道传输层提供的协议主要有两种,TCP协议和UDP协议,先来介绍一下它们的区别: 1、TCP是面向连接的,UDP是无连接的。 连接的本质是双方分别保存了对方的关键信息,而面向连接并不意味着数据一定能正常传输到对…

[CUDA 学习笔记] 稀疏矩阵向量乘法(SpMV) CUDA 实现与优化

稀疏矩阵向量乘法(SpMV) CUDA 实现与优化 本文主要围绕基于 CUDA 的 SpMV 实现进行介绍, 包括几种典型稀疏矩阵存储格式下 SpMV 的朴素实现, 以及 CSR 格式下的几种优化实现. 稀疏矩阵存储格式 稀疏矩阵即含有大量零元的矩阵. 对于稀疏矩阵, 像稠密矩阵一样使用二维数组来存…

物业管理的隐形杀手:纸质点检表,你还在用吗?

在日常的生活中,我们经常会看到小区物业保洁、客服人员在工作岗位忙忙碌碌,但忽略了默默为我们提供舒适环境的“隐形守护者”——物业设施设备。然而,一旦这些设备出现故障,我们的日常生活就会陷入混乱。那么,如何确保…

比特币不是解决货币伦理的「灵丹妙药」

原文标题:《Bitcoin is no ‘silver bullet’ for money’s ethical problems》 撰文:Stephen Katte 编译:Chris,Techub News 本文来源香港Web3媒体:Techub News 比特币和法定货币经常因货币伦理问题而受到批评&am…

AcWing 1639:拓扑顺序 ← 链式前向星

【题目来源】https://www.acwing.com/problem/content/1641/【题目描述】 这是 2018 年研究生入学考试中给出的一个问题: 以下哪个选项不是从给定的有向图中获得的拓扑序列? 现在,请你编写一个程序来测试每个选项。 【输入格式】 第一行包含两…

ffmpeg实现视频播放 ----------- Javacv

什么是Javacv和FFmpeg? Javacv是一个专门为Java开发人员提供的计算机视觉库,它基于FFmpeg和Opencv库,提供了许多用于处理图 像、视频和音频的功能。FFmpeg是一个开源的音视频处理工具集,它提供了用于编码、解码、转换和播放音视频…

MyBatis 参数上的处理的细节内容

1. MyBatis 参数上的处理的细节内容 文章目录 1. MyBatis 参数上的处理的细节内容2. MyBatis 参数上的处理3. 准备工作4. 单个(一个)参数4.1 单个(一个)简单类型作为参数4.2 单个(一个) Map集合 作为参数4.3 单个(一个) 实体类POJO作为参数 5. 多个参数5.1 Param注解(命名参数)…

免费学习通刷课(免费高分)Pro版

文章目录 概要整体架构流程小结 概要 关于上一版的免费高分的学习通刷课,有很多人觉得还得登录太复杂了,然后我又发现了个神脚本,操作简单,可以后台挂着,但是还是建议调整速度到2倍速,然后找到你该刷的课&…

论文图片颜色提取

论文绘图的时候有些颜色不知道怎么选取,参考其他论文,将其他论文中的颜色提取下来,用取色器识别出来,记录如下: 颜色代码:#BEAED4 190,174,212 颜色代码:#C4CBCB 196,203,203 颜色代码&am…

【JVM】之常见面试题

文章目录 1.JVM中的内存区域划分2.JVM的类加载机制2.1 加载2.2 验证2.3 准备2.4 解析2.5 初始化2.6 类加载的时机 3 类加载器4.双亲委派模型5.JVM中的垃圾回收策略5.1 找谁是垃圾5.1.1 引用计数法5.1.2 可达性分析法 5.2 释放垃圾5.2.1 标记清除算法5.2.2 复制算法5.2.3 标记整…

GoogleDeepMind联合发布医学领域大语言模型论文技术讲解

Towards Expert-Level Medical Question Answering with Large Language Mod 这是一篇由Google Research和DeepMind合作发表的论文,题为"Towards Expert-Level Medical Question Answering with Large Language Models"。 我先整体介绍下这篇论文的主要内容&#x…

[CAN] 创建解析CAN报文DBC文件教程

👉本教程需要先安装CANdb软件,[CAN] DBC数据库编辑器的下载与安装 🙋前言 DBC(全称为Database CAN),是用于描述单个CAN网络中各逻辑节点的信息。 DBC是汽车ECU(Electronic Control Unit,电子控制单元&…

批量文件重命名软件

因为日常用电脑的时候,经常都会遇到需要对当前目录下的文件,进行重命名。最好是按照自己的规则上来进行批量重命名。我试了几款软件,都感觉不是很好,不是要收费,就是各种乱七八糟的流氓广告。本想着干脆自己写算了,在绝望之际,找到了这款软件,亲测,确实还用,特别是满…

python 10个高频率的自动化脚本(干货,速度收藏)

1. 文件操作:自动备份文件 场景:每日自动备份重要文件到指定目录。 import shutilimport datetimedef backup_file(src, dst_folder): now datetime.datetime.now().strftime(%Y%m%d%H%M%S) dst_path f"{dst_folder}/backup_{now}_{src.s…

bugku---misc---ping

1、下载附件,解压后是一个流量包 2、用wireshark分析,发现都是清一色的icmp报文,只能看看内容。 3、点了几条流量,发现有个地方连起来是flag 4、最终将所有的拼起来,得到flag flag{dc76a1eee6e3822877ed627e0a04ab4a}…