DI-engine强化学习入门(七)如何自定义神经网络模型

news2024/9/21 22:29:19

在强化学习中,需要根据决策问题和策略选择合适的神经网络。DI-engine中,神经网络模型可以通过两种方式指定:

  1. 使用配置文件中的cfg.policy.model自动生成默认模型。这种方式下,可以在配置文件中指定神经网络的类型(MLP、CNN等)以及超参数(隐层大小、激活函数等),DI-engine会根据这些配置自动构建神经网络模型。这种方式简单易用,适用于常见的标准网络结构。
  2. 自定义模型实例并传入Policy。这种方式下,需要用户自己定义Tensorflow/Pytorch模型类,实现前向传播等接口,然后将实例传入Policy中。这种方式灵活度高,用户可以自由设计任意结构的神经网络。但是需要用户比较熟悉网络定义和 Tensorflow/Pytorch接口。

(注:在强化学习中,策略(Policy)是指智能体(Agent)决策的规则。策略是从状态(State)到动作(Action)的映射,它定义了在给定的状态下,智能体应该采取什么动作。策略可以是确定性的(Deterministic)也可以是随机性的(Stochastic)。)
以上两种方式都会在Policy中封装为neural_net属性,策略学习会通过这个网络完成状态的 embedding 以及动作的选择。这套机制和接口为用户提供了必要的灵活性,可以根据具体问题和需求配置各种神经网络模型。
这是红色的文字

Policy 默认使用的模型是什么
DI-engine 中已经实现的 policy,默认使用 default_model 方法中表明的神经网络模型,例如在 SACPolicy 中:

@POLICY_REGISTRY.register('sac')class SACPolicy(Policy):...    def default_model(self) -> Tuple[str, List[str]]:        if self._cfg.multi_agent:            return 'maqac_continuous', ['ding.model.template.maqac']        else:            return 'qac', ['ding.model.template.qac']...

在这段代码中,我们看到的是一个名为 DI-engine 的强化学习框架中的一个策略(Policy)类的一部分。具体来说,它定义了一个使用Soft Actor-Critic, SAC 算法的策略类。这个段落描述了如何在这个框架内设置和使用策略相关的神经网络模型。

让我们逐步解释这段代码:

  1. @POLICY_REGISTRY.register('sac') 是一个装饰器,它将 SACPolicy 类注册到一个名为 POLICY_REGISTRY 的注册器中,并且用 'sac' 作为这个策略的标识符。这样的注册机制允许框架能够根据名字轻松地查找和实例化策略。
  2. class SACPolicy(Policy): 表明 SACPolicy 是从更一般的 Policy 类派生的,它是一个具体的策略实现,使用了 SAC 算法。
  3. default_model 是 SACPolicy 类的一个方法,它定义了该策略默认使用的模型。这个方法返回两个元素的元组:
  • 'maqac_continuous' 或 'qac':这是在模型注册器中注册的模型名字。根据配置是否是多智能体(multi_agent),它返回不同的模型名。
  • ['ding.model.template.maqac'] 或 ['ding.model.template.qac']:这是包含模型类的文件路径的列表。这个路径告诉 DI-engine 在哪里可以找到定义模型的代码。

4.当使用配置文件时,DI-engine 的入口文件将使用 cfg.policy.model 中的参数(比如 obs_shape, action_shape)来实例化提供的模型类。这个过程是自动化的,意味着用户定义好配置,DI-engine 将负责根据这些配置创建并初始化模型。

5.模型类会根据传入的参数构造适当的神经网络。例如,如果传入的 obs_shape 参数表明观测是一个图像,则模型可能会使用卷积层来处理输入;如果观测是一个向量,则可能使用全连接层。

简而言之,这段代码展示了 DI-engine 如何灵活地处理不同类型的策略和模型,以及如何通过配置文件来方便地自定义和实例化这些策略和模型。这种设计允许研究者和开发者能够轻松试验不同的算法和模型架构,而无需直接修改代码。

如何自定义神经网络模型

在 DI-engine 强化学习框架中,每个策略(如 SACPolicy)通常有一个关联的默认模型(通过 default_model 方法指定),这个默认模型是为特定类型的任务设计的。例如,原始的 qac 模型可能是为处理具有一维观测空间的环境设计的,即观测是一个向量。

但是!如果任务是在一个模拟器(如 dmc2gym,一个DeepMind Control Suite到OpenAI Gym接口的适配器)上运行,并且任务是 cartpole-swingup,而且你希望使用观测为像素的输入(即观测是一个图像),那么默认的 qac 模型不足以处理这样的高维度和多通道的输入。在这种情况下,观测空间的形状是 (3, height, width),其中 3 表示图像的颜色通道数(RGB),height 和 width 分别表示图像的高度和宽度。

在 dmc2gym 文档中,from_pixel 参数设定为 True 意味着环境将提供像素级的观测,而 channels_first 设定为 True 表明图像的通道维度是第一维(这是PyTorch等深度学习库通常采用的格式)。

面对这样的情况,如果你想要使用 SAC 算法处理像素级的观测,你需要自定义一个能够处理这种高维观测的模型。所以我们创建一个新的模型类,该类在内部使用卷积神经网络(CNN)来处理输入的图像数据,并适当地修改网络架构以适应任务的特定要求。

自定义模型完成后,可以将这个模型应用到 SACPolicy 中,替换原本的 qac 模型。涉及到以下几个步骤:

  1. 实现一个新的模型类,它继承自某个基础模型类,并覆盖必要的方法以支持像素级输入。
  2. 在策略配置中指定你的自定义模型,以便 DI-engine 使用你提供的模型而不是默认模型。
  3. 确保你的自定义模型注册到 DI-engine 的模型注册器中,这样框架可以识别和使用它。

自定义 model 基本步骤
1. 明确环境 (env) 和策略 (policy)
首先,需要确定你的强化学习任务的具体环境和任务。例如,我们选择 dmc2gym 环境中的 cartpole-swingup 任务,并且决定观测将以像素数据的形式提供,我们的观测空间是一个图像,其形状为 (3, height, width)。下面我们使用 SAC 算法来进行学习。

在这里,from_pixel = True 表明环境将提供基于像素的观测,channels_first = True 表明图像数据的通道维度在前,这通常是深度学习库(如 PyTorch)的标准格式。

2. 查阅策略中的 default_model 是否适用
接下来,需要检查选择的策略是否具有适用于任务的默认模型。这可以通过查看策略的文档或直接查阅源代码来完成。以 DI-engine 中的 SAC 策略为例,可以查看 SACPolicy 类中的 default_model 方法来了解默认模型:

如果进一步看一下 ding.model.template.qac 中的 QAC 模型,咱们可能会发现它仅支持一维的观测空间,而不支持像 (3, height, width) 这样的图像形状。这意味着对于我们的 cartpole-swingup 任务,需要创建一个自定义模型来处理像素级的观测。

3. 自定义模型 (custom_model) 实现
自定义模型的实现需要遵循一些基本的原则,以确保与 DI-engine 框架的兼容性。

a. 实现功能
自定义模型需要实现默认模型中的所有公共方法。包括:

  • __init__: 构造函数,对模型的各个部分进行初始化。
  • forward: 定义模型如何从输入到输出的前向传递。
  • compute_actor: 计算策略网络的输出,即给定观测值时的动作。
  • compute_critic: 计算价值网络的输出,即动作的价值。

b. 保持返回值类型一致
自定义模型的方法需要保证与原始默认模型的返回值类型一致,以便于替换使用。

c. 利用已实现的 encoder 和 head
在 ding/model/common 下有多种 encoder 和 head 的实现,可用于构建不同部分的模型:

  • Encoder: 负责对输入数据进行编码,使其适合后续的处理。例如,ConvEncoder 用于处理图像观测输入,FCEncoder 用于处理一维观测输入。

点击DI-engine强化学习入门(七)如何自定义神经网络模型 - 古月居可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1666332.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【漏洞复现】泛微OA E-Cology XmlRpcServlet文件读取漏洞

漏洞描述: 泛微OA E-Cology是一款面向中大型组织的数字化办公产品,它基于全新的设计理念和管理思想,旨在为中大型组织创建一个全新的高效协同办公环境。泛微OA E-Cology XmlRpcServlet存在任意文件读取漏洞,允许未经授权的用户读…

三星硬盘格式化后怎么恢复数据

在数字化时代,硬盘作为数据存储的核心部件,承载着我们的重要文件、照片、视频等资料。然而,不慎的格式化操作可能使我们失去宝贵的数据。面对这样的困境,许多用户可能会感到无助和焦虑。本文旨在为三星硬盘用户提供格式化后的数据…

计算机网络实验2:路由器常用协议配置

实验目的和要求 掌握路由器基本配置原理理解路由器路由算法原理理解路由器路由配置方法实验项目内容 路由器的基本配置 路由器单臂路由配置 路由器静态路由配置 路由器RIP动态路由配置 路由器OSPF动态路由配置实验环境 1. 硬件:PC机; 2. 软…

金三银四面试题(二十六):责任链模式知多少?

什么是责任链模式 责任链模式(Chain of Responsibility Pattern)是一种行为型设计模式,旨在通过将请求的处理分布在一系列对象上,从而使得多个对象可以尝试处理同一个请求。这些对象被链接成一条链,每个对象都可以对请…

stm32——OLED篇

技术笔记! 一、OLED显示屏介绍(了解) 1. OLED显示屏简介 二、OLED驱动原理(熟悉) 1. 驱动OLED驱动芯片的步骤 2. SSD1306工作时序 三、OLED驱动芯片简介(掌握) 1. 常用SSD1306指令 2. …

专业130+总分400+哈尔滨工程大学810信号与系统考研哈工程水声电子信息通信工程,真题,大纲,参考书。

毕业设计刚搞完,总结一下去年考研的复习经历,希望对大家复习有帮助,考研专业课810信号与系统130总分400,如愿上岸哈工程水声。专业课:130 哈工程水声院810专业课信号与系统难度适中,目前数一难度很高&…

【C语言/Python】嵌入式常用数据滤波处理:卡尔曼滤波器的简易实现方式(Kalman Filter)

【C语言/Python】嵌入式常用数据滤波处理:卡尔曼滤波器的简易实现方式(Kalman Filter) 文章目录 卡尔曼滤波卡尔曼滤波公式卡尔曼滤波数据处理效果C语言的卡尔曼滤波实现附录:压缩字符串、大小端格式转换压缩字符串浮点数压缩Pack…

TCP三次握手四次挥手 UDP

TCP是面向链接的协议,而UDP是无连接的协议 TCP的三次握手 三次传输过程是纯粹的不涉及数据,三次握手的几个数据包中不包含数据内容。它的应用层,数据部分是空的,只是TCP实现会话建立,点到点的连接 TCP的四次挥手 第四…

JVM堆内存分析

jmap工具查看堆内存 jmap:全称JVM Memory Map 是一个可以输出所有内存中对象的工具,可以将JVM中的heap(堆),以二进制输出成文本,打印出Java进程对应的内存 找到pid jmap -heap 19792 Attaching to process ID 19792…

贪心算法-----柠檬水找零

今日题目:leetcode860 题目链接:点击跳转题目 分析: 顾客只会给三种面值:5、10、20,先分类讨论 当收到5美元时:不用找零,面值5张数1当收到10美元时:找零5美元,面值5张数…

bevformer详解(1):论文介绍

3D 视觉感知任务,包括基于多摄像头的3D检测和地图分割对于自动驾驶系统至关重要。本文提出了一种名为BEVFormer的新框架,它通过使用空间和时间的Transformer 学习统一的BEV表示来支持多个自动驾驶感知任务。简而言之,BEVFormer通过预定义的网格形式的Bev Query与空间和时间空…

icap对flash的在线升级

文章目录 一、icap原语介绍(针对 S6 系列的 ICap),之后可以拓展到A7、K7当中去二、程序1设计2.1信号结构框图2.2 icap_delay设计2.3 icap_ctrl设计(可以当模板使用,之后修改关键参数即可) 三、程序2设计四、…

如何同时或者按顺序间隔启动多个程序

首先,需要用到的这个工具: 度娘网盘 提取码:qwu2 蓝奏云 提取码:2r1z 1、打开工具,切换到定时器模块,快捷键:Ctrl3 2、新建一个定时器,我这里演示同时打开多个程序(比…

在shell程序里如何从文件中获取第n行

问题: 有没有一种“规范”的方式来做到这一点?我一直在使用 head -n | tail -1,它可以做到这一点,但我一直想知道是否有一个Bash工具,专门从文件中提取一行(或一段行)。 所谓“规范”,我指的是一个主要功…

HTML五彩缤纷的爱心

写在前面 小编准备了一个五彩缤纷的爱心,送给各位小美女们~ 在桌面创建一个.txt文本文件,把代码复制进去,将后缀.txt改为.html,然后就可以双击运行啦! HTML简介 HTML(超文本标记语言)是一种…

Stable Diffusion是什么?

目录 一、Stable Diffusion是什么? 二、Stable Diffusion的基本原理 三、Stable Diffusion有哪些运用领域? 一、Stable Diffusion是什么? Stable Diffusion是一个先进的人工智能图像生成模型,它能够根据文本描述创造出高质量的图…

VMware安装centos7教程

文章目录 1、centos7的ios镜像下载2、CentOS7安装3、Centos配置 其他教程: 1、VMware Workstation 16 Pro安装教程 2、VMwarePro16安装Ubuntu16.04图文教程 1、centos7的ios镜像下载 官网:https://vault.centos.org/ 阿里云:https://develo…

Idea入门:一分钟创建一个Java工程

一,新建一个Java工程 1,启动Idea后,选择 [New Project] 2,完善工程信息 填写工程名称,根据实际用途取有意义的英文名称选择Java语言,可以看到还支持Kotlin、Javascript等语言选择包管理和项目构建工具Mav…

新闻资讯微信小程序开发后端+php【附源码,文档说明】

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

2024最新版JavaScript逆向爬虫教程-------基础篇之无限debugger的原理与绕过

目录 一、无限debugger的原理与绕过1.1 案例介绍1.2 实现原理1.3 绕过debugger方法1.3.1 禁用所有断点1.3.2 禁用局部断点1.3.3 替换文件1.3.4 函数置空与hook 二、补充2.1 改写JavaScript文件2.2 浏览器开发者工具中出现的VM开头的JS文件是什么? 三、实战 一、无限…