ray.rllib-入门实践-11: 自定义模型/网络

news2025/1/27 3:33:41

在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:

1. 定义自己的模型。

2. 向ray注册自定义的模型

3. 在config中配置使用自定义的模型

环境配置:

        torch==2.5.1

        ray==2.10.0

        ray[rllib]==2.10.0

        ray[tune]==2.10.0

        ray[serve]==2.10.0

        numpy==1.23.0

        python==3.9.18

一、 定义自己的模型 

需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法, 其代码框架如下:

import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class My_Model(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...
    
    def forward(self, input_dict, state, seq_lens): ...
    
    def value_function(self): ...

示例如下:

## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

class My_Model(TorchModelV2, nn.Module):
    def __init__(self, obs_space:gym.spaces.Space, 
                 action_space:gym.spaces.Space, 
                 num_outputs:int, 
                 model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数
                 name:str
                 ,*, custom_arg1, custom_arg2):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)
        nn.Module.__init__(self)
        ## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值
        print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")

        ## 定义网络层
        obs_dim = int(np.product(obs_space.shape))
        action_dim = int(np.product(action_space.shape))
        ## shareNet
        self.shared_fc = nn.Linear(obs_dim,128)
        ## actorNet
        self.actorNet = nn.Linear(128, action_dim)
        ## criticNet
        self.criticNet = nn.Linear(128,1)

        self._feature = None 

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        self._feature = self.shared_fc.forward(obs)
        action_logits = self.actorNet.forward(self._feature)
        return action_logits, state 
    
    def value_function(self):
        value = self.criticNet.forward(self._feature).squeeze(1)
        return value 

        在rllib中,每个算法的所有网络都被汇集到同一个 ModelV2 类下,供算法调用。actor 网络和critic网络可以在外面定义,也可以在model内部直接定义。 model的forward用于返回actor网络的输出, value_function函数用于返回critic网络的输出。 网络结构和网络层共享可以自定义设置。输入输出接口,需要与上面保持严格一致。

二、 向ray注册自定义模型

        ray.rllib.model.ModelCatalog 类,用于向ray注册自定义的model, 还可以用于获取env的 preprocessors 和 action distributions。

import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 

ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)

三、 在算法中配置并使用自定义的模型

主要是在 config.training() 模块中的 model 子模块中传入两个配置信息:

        1)"custom_model":"my_torch_model" ,                      
         2)"custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  

两个关键字固定不变,填入自己注册的模型名和对应的模型参数即可。

可以有以下三种配置代码的编写方式:

配置方法1:

## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print 

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

## 配置使用自定义的模型
config = config.training(model= {"custom_model":"my_torch_model" ,                      
                                 "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
## 主要在上面两行配置使用自己的模型
##    配置 model 的 "custom_model" 项,用于指定rllib算法所使用的模型
##    配置 model 的 "custom_model_config" 项,用于传入自定义的网络参数,供自定义的model使用。
##    这两个关键词不可更改。

algo = config.build()
## 4. 执行训练
result = algo.train()
print(pretty_print(result))

与以上配置内容一样,还可以用以下两种配置写法:

配置方法2:

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

## 配置自定义模型
model_config_dict = {}
model_config_dict["custom_model"] = "my_torch_model" 
model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
config = config.training(model= model_config_dict)  

algo = config.build()

 配置方法3(推荐):

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

## 配置自定义模型
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}

algo = config.build()

 代码汇总:

"""
在ray.rllib中定义和使用自己的模型, 分为以下三个步骤:
1. 定义自己的模型。 
    需要继承自 TFModel 或 TorchModelV2, 并重写需要自定义的方法
    import torch.nn as nn
    from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
    class CustomTorchModel(TorchModelV2, nn.Module): ## 重构以下函数, 函数接口不能变。 
        def __init__(self, obs_space, action_space, num_outputs, model_config, name, *, custom_arg1, custom_arg2): ...
        def forward(self, input_dict, state, seq_lens): ...
        def value_function(self): ...

2. 向ray注册自定义的模型
    from ray.rllib.models import ModelCatalog
    ModelCatalog.register_custom_model("wzg_torch_model", CustomTorchModel)

3. 在config中配置使用自定义的模型
    model_config_dict = {
        "custom_model":"wzg_torch_model",
        "custom_model_config":{
            "custom_arg1": 1,
            "custom_arg2": 2}
    }
    config = PPOConfig()
    # config = config.training(model = model_config_dict)
    config.model["custom_model"] = "wzg_torch_model"
    config.model["custom_model_config"] = {"custom_arg1": 1,
                                    "custom_arg2": 2}
"""

## 1. 定义自己的模型
import numpy as np 
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym 
from gymnasium import spaces  
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

class My_Model(TorchModelV2, nn.Module):
    def __init__(self, obs_space:gym.spaces.Space, 
                 action_space:gym.spaces.Space, 
                 num_outputs:int, 
                 model_config:ModelConfigDict,  ## PPOConfig.training(model = ModelConfigDict), 调用的是config.model中的参数
                 name:str
                 ,*, custom_arg1, custom_arg2):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,model_config,name)
        nn.Module.__init__(self)
        ## 测试 custom_arg1 , custom_arg2 传递进来的是什么数值
        print(f"===========================  custom_arg1 = {custom_arg1}, custom_arg2 = {custom_arg2}")

        ## 定义网络层
        obs_dim = int(np.product(obs_space.shape))
        action_dim = int(np.product(action_space.shape))
        ## shareNet
        self.shared_fc = nn.Linear(obs_dim,128)
        ## actorNet
        self.actorNet = nn.Linear(128, action_dim)
        ## criticNet
        self.criticNet = nn.Linear(128,1)

        self._feature = None 

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"].float()
        self._feature = self.shared_fc.forward(obs)
        action_logits = self.actorNet.forward(self._feature)
        return action_logits, state 
    
    def value_function(self):
        value = self.criticNet.forward(self._feature).squeeze(1)
        return value 

## 2. 向ray注册自定义模型
import ray 
from ray.rllib.models import ModelCatalog # ModelCatalog 类: 用于注册 models, 获取env的 preprocessors 和 action distributions。 

ModelCatalog.register_custom_model(model_name="my_torch_model", model_class = My_Model)
ray.init()

## 3. 在训练中使用自定义模型
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print 

config = PPOConfig()
config = config.environment("CartPole-v1")
config = config.rollouts(num_rollout_workers=2)
config = config.framework(framework="torch") 

# ## 配置自定义模型:方法 1
# config = config.training(model= {"custom_model":"my_torch_model" ,                      
#                                  "custom_model_config": {"custom_arg1": 1, "custom_arg2": 2,}})  
# ## 配置自定义模型:方法 2
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config = config.training(model= model_config_dict) 

## 配置自定义模型: 方法 3 (个人更喜欢, 因为嵌套层次少)
config.model["custom_model"] = "my_torch_model"
config.model["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}

## 错误方法:
# model_config_dict = {}
# model_config_dict["custom_model"] = "my_torch_model" 
# model_config_dict["custom_model_config"] = {"custom_arg1": 1, "custom_arg2": 2,}
# config.model = model_config_dict # 会清空 model 里面的其他默认配置,导致报错

algo = config.build()

## 4. 执行训练
result = algo.train()
print(pretty_print(result))


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2283751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PCIE模式配置

对于VU系列FPGA,当DMA/Bridge Subsystem for PCI Express IP配置为Bridge模式时,等同于K7系列中的AXI Memory Mapped To PCI Express IP。

【由浅入深认识Maven】第2部分 maven依赖管理与仓库机制

文章目录 第二篇:Maven依赖管理与仓库机制一、前言二、依赖管理基础1.依赖声明2. 依赖范围(Scope)3. 依赖冲突与排除 三、Maven的仓库机制1. 本地仓库2. 中央仓库3. 远程仓库 四、 版本管理策略1. 固定版本2. 版本范围 五、 总结 第二篇&…

备赛蓝桥杯之第十五届职业院校组省赛第一题:智能停车系统

提示:本篇文章仅仅是作者自己目前在备赛蓝桥杯中,自己学习与刷题的学习笔记,写的不好,欢迎大家批评与建议 由于个别题目代码量与题目量偏大,请大家自己去蓝桥杯官网【连接高校和企业 - 蓝桥云课】去寻找原题&#xff0…

力扣 Hot 100 题解 (js版)更新ing

🚩哈希表 ✅ 1. 两数之和 Code: 暴力法 复杂度分析: 时间复杂度: ∗ O ( N 2 ) ∗ *O(N^2)* ∗O(N2)∗,其中 N 是数组中的元素数量。最坏情况下数组中任意两个数都要被匹配一次。空间复杂度:O(1)。 /…

DeepSeek-R1:性能对标 OpenAI,开源助力 AI 生态发展

DeepSeek-R1:性能对标 OpenAI,开源助力 AI 生态发展 在人工智能领域,大模型的竞争一直备受关注。最近,DeepSeek 团队发布了 DeepSeek-R1 模型,并开源了模型权重,这一举动无疑为 AI 领域带来了新的活力。今…

CY T 4 BB 5 CEB Q 1 A EE GS MCAL配置 - MCU组件

1、ResourceM 配置 选择芯片信号: 2、MCU 配置 2.1 General配置 1) McuDevErrorDetect: - 启用或禁用MCU驱动程序模块的开发错误通知功能。 - 注意:采用DET错误检测机制作为安全机制(故障检测)时,不能禁用开发错误检测。2) McuGetRamStateApi - enable/disable th…

校园商铺管理系统设计与实现(代码+数据库+LW)

摘 要 信息数据从传统到当代,是一直在变革当中,突如其来的互联网让传统的信息管理看到了革命性的曙光,因为传统信息管理从时效性,还是安全性,还是可操作性等各个方面来讲,遇到了互联网时代才发现能补上自…

【JavaWeb学习Day13】

Tlias智能学习系统 需求: 部门管理:查询、新增、修改、删除 员工管理:查询、新增、修改、删除和文件上传 报表统计 登录认证 日志管理 班级、学员管理(实战内容) 部门管理: 01准备工作 开发规范-…

springboot使用tomcat浅析

springboot使用tomcat浅析 关于外部tomcat maven pom配置 // 打包时jar包改为war包 <packaging>war</packaging>// 内嵌的tomcat的scope标签影响范围设置为provided&#xff0c;只在编译和测试时有效&#xff0c;打包时不带入 <dependency><groupId>…

如何使用CRM数据分析优化销售和客户关系?

嘿&#xff0c;大家好&#xff01;你有没有想过为什么有些公司在市场上如鱼得水&#xff0c;而另一些却在苦苦挣扎&#xff1f;答案可能就藏在他们的销售策略和客户关系管理&#xff08;CRM&#xff09;系统里。今天我们要聊的就是如何通过有效的 CRM 数据分析来提升你的销售额…

Qt 控件与布局管理

1. Qt 控件的父子继承关系 在 Qt 中&#xff0c;继承自 QWidget 的类&#xff0c;通常会在构造函数中接收一个 parent 参数。 这个参数用于指定当前空间的父控件&#xff0c;从而建立控件间的父子关系。 当一个控件被设置为另一控件的子控件时&#xff0c;它会自动成为该父控…

电力场效应晶体管(电力 MOSFET),全控型器件

电力场效应晶体管&#xff08;Power MOSFET&#xff09;属于全控型器件是一种电压触发的电力电子器件&#xff0c;一种载流子导电&#xff08;单极性器件&#xff09;一个器件是由一个个小的mosfet组成以下是相关介绍&#xff1a; 工作原理&#xff08;栅极电压控制漏极电流&a…

一文讲解Java中的重载、重写及里氏替换原则

提到重载和重写&#xff0c;Java小白应该都不陌生&#xff0c;接下来就通过这篇文章来一起回顾复习下吧&#xff01; 重载和重写有什么区别呢&#xff1f; 如果一个类有多个名字相同但参数不同的方法&#xff0c;我们通常称这些方法为方法重载Overload。如果方法的功能是一样…

Pandas基础02(DataFrame创建/索引/切片/属性/方法/层次化索引)

DataFrame数据结构 DataFrame 是一个二维表格的数据结构&#xff0c;类似于数据库中的表格或 Excel 工作表。它由多个 Series 组成&#xff0c;每个 Series 共享相同的索引。DataFrame 可以看作是具有列名和行索引的二维数组。设计初衷是将Series的使用场景从一维拓展到多维。…

Meta-CoT:通过元链式思考增强大型语言模型的推理能力

大型语言模型&#xff08;LLMs&#xff09;在处理复杂推理任务时面临挑战&#xff0c;这突显了其在模拟人类认知中的不足。尽管 LLMs 擅长生成连贯文本和解决简单问题&#xff0c;但在需要逻辑推理、迭代方法和结果验证的复杂任务&#xff08;如高级数学问题和抽象问题解决&…

【时时三省】(C语言基础)二进制输入输出

山不在高&#xff0c;有仙则名。水不在深&#xff0c;有龙则灵。 ----CSDN 时时三省 二进制输入 用fread可以读取fwrite输入的内容 字符串以文本的形式写进去的时候&#xff0c;和以二进制写进去的内容是一样的 整数和浮点型以二进制写进去是不一样的 二进制输出 fwrite 字…

【go语言】数组和切片

一、数组 1.1 什么是数组 数组是一组数&#xff1a;数组需要是相同类型的数据的集合&#xff1b;数组是需要定义大小的&#xff1b;数组一旦定义了大小是不可以改变的。 1.2 数组的声明 数组和其他变量定义没有什么区别&#xff0c;唯一的就是这个是一组数&#xff0c;需要给…

SQL-leetcode—1179. 重新格式化部门表

1179. 重新格式化部门表 表 Department&#xff1a; ---------------------- | Column Name | Type | ---------------------- | id | int | | revenue | int | | month | varchar | ---------------------- 在 SQL 中&#xff0c;(id, month) 是表的联合主键。 这个表格有关…

k8s简介,k8s环境搭建

目录 K8s简介环境搭建和准备工作修改主机名&#xff08;所有节点&#xff09;配置静态IP&#xff08;所有节点&#xff09;关闭防火墙和seLinux&#xff0c;清除iptables规则&#xff08;所有节点&#xff09;关闭交换分区&#xff08;所有节点&#xff09;修改/etc/hosts文件&…

基于微信小程序的网上订餐管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…