深度强化学习中的动作屏蔽(Action Masking)

news2025/1/12 20:55:15

RLlib中的example有一个代码是action_masking,很感兴趣,所以学习了一下
主要功能是:

  • “动作屏蔽”允许代理根据当前观察选择动作。这在许多实际场景中非常有用,在这些场景中,不同的时间步长可以执行不同的操作。
  • 解释动作屏蔽的博客文章:https://boring-guy.sh/posts/masking-rl/ RLlib
  • 支持动作屏蔽,即通过稍微调整环境和模型来禁止这些动作,如本示例所示。 在这里,ActionMaskEnv 包装了一个底层环境(这里是 RandomEnv),根据环境的观察仅将所有操作的子集定义为有效。如果选择了无效的操作,环境会引发错误 - 这绝不能发生!环境构造 Dict 观察,其中 obs[“observations”] 保存原始观察,obs[“action_mask”] 保存有效操作。
  • 为了避免选择无效的操作,使用了ActionMaskModel。该模型采用原始观察结果,计算相应操作的逻辑,然后将所有无效操作的逻辑设置为零,从而禁用它们。
  • 这仅适用于离散操作。

博客原文

简介

当我开始深度强化学习时,我工作的环境中的每个时间步都无法执行特定操作。

让我们具体说明不可能或不可用动作的概念:假设您想开发一个代理来玩马里奥赛车。接下来,假设代理有空库存(没有香蕉🍌或任何东西)。代理无法执行“使用库存中的对象”操作。将代理限制为有意义的操作选择将使其能够以更智能的方式进行探索并输出更好的策略。

现在您了解了不可能或不可用操作的概念,自然的问题是:“我如何管理不可能的操作?” 🤔 我实施的第一个解决方案是,如果智能体采取不可能的操作,则分配负奖励。它的表现比不限制动作的选择,但我对这种方法不满意,因为它不能阻止代理选择不可能的动作。

然后我决定使用动作屏蔽(Action Masking)。这种方法实现起来简单且优雅,因为它限制代理只采取“有意义”的动作。

在我的深度强化学习实践中,我了解到有很多方法可以使用Masking。Masking可用于神经网络中的任何级别并用于不同的任务。不幸的是,除了 Costa Huang 的这篇精彩文章 [7] 之外,很少有强化学习的Masking实现可用。

这篇博文的范围是解释Masking的概念并通过图形和代码进行说明。事实上,这些Masking可以对我们在阅读这篇博文时看到的许多约束进行建模。请注意,整个过程是完全可微的。简而言之,Masking是为了简化您的生活。

要求

  • 马尔可夫决策过程 (MDP) 的概念
  • 策略梯度和 Q 学习算法的概念
  • PyTorch 的一些知识或 numpy 的基础知识
  • 自注意力的概念。

如果您想了解这个概念是什么,我邀请您阅读这篇解释 Transformer 的精彩文章 [6]

动作方面 Action level

概念:

深度强化学习中Masking的主要功能是过滤掉不可能或不可用的动作。例如,在《星际争霸 II》和《Dota 2》中,每个时间步的动作总数分别为 1 0 26 10^{26} 1026 1 , 837 , 080 1,837,080 1,837,080 。然而,每个时间步的可能操作空间仅占可用操作空间的一小部分。因此,使用Masking有两个优点:

  • 第一个是避免给环境带来无效的行为。
  • 第二个是它是一种简单的方法,可以通过减少行动来管理广阔的空间。

在这里插入图片描述
图1说明了动作屏蔽的原理。其背后的想法很简单,它包括替换不可能的操作相关的 logits为 − ∞ -∞

那么,为什么应用这个掩码可以防止选择不可能的动作呢?
1. 基于价值的算法(Q-Learning):
在基于价值的方法中,我们选择动作价值函数的最高估计值 Q ( s , a ) Q(s,a) Q(s,a)
a = arg max ⁡ a ∈ A Q ( s , . ) a=\argmax\limits_{a\in A}Q(s,.) a=aAargmaxQ(s,.)
通过应用掩码,与不可能动作相关的 Q 值将等于 − ∞ -∞ ,因此它们永远不会是最高值,因此永远不会被选择。

2. 基于策略的算法(策略梯度):
在基于策略的方法中,我们根据模型输出的概率分布对动作进行采样:
a =   π θ ( . ∣ s ) a=~\pi_\theta(.|s) a= πθ(.∣s)
因此,有必要将与不可能动作相关的概率设置为0。在我们使用Masking的时候不可能的动作是 − ∞ -∞ 。我们使用 softmax 函数从 logits 转移到概率域:
Softmax ( z ⃗ ) i = e z i ∑ j = 1 K e z j for  i = 1 , … , K  and  z = ( z 1 , … , z K ) ∈ R K . \text{Softmax}(\vec{z})_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}} \quad \text{for} \ i = 1, \ldots, K \ \text{and} \ z = (z_1, \ldots, z_K) \in \mathbb{R}^K. Softmax(z )i=j=1Kezjezifor i=1,,K and z=(z1,,zK)RK.
考虑到我们已将与不可能动作相关的 logits 值设置为 − ∞ -∞ ,对这些动作进行采样的概率等于0。

实现

现在让我们练习并实现离散动作空间和基于策略的算法的动作屏蔽。我使用 Costa Huang 的论文和动作屏蔽代码 [7] 作为起点。想法很简单,我们继承 PyTorch 的 Categorical 类并添加一个可选的 mask 参数。
当我们应用Masking时,我们会替换不可能动作的 logits。
然而,由于我们使用 float32,因此我们需要以 32 位表示的最小值。在 PyTorch 中,我们通过运行 torch.finfo(torch.float.dtype).min 来获取它,即 -3.40e+38。
最后,对于一些基于策略的方法,例如近端策略优化(PPO)[12],有必要计算模型输出的概率分布熵。在我们的例子中,我们将仅计算可用操作的熵。

from typing import Optional

import torch
from torch.distributions.categorical import Categorical
from torch import einsum
from einops import  reduce


class CategoricalMasked(Categorical):
    def __init__(self, logits: torch.Tensor, mask: Optional[torch.Tensor] = None):
        self.mask = mask
        self.batch, self.nb_action = logits.size()
        if mask is None:
            super(CategoricalMasked, self).__init__(logits=logits)
        else:
            self.mask_value = torch.tensor(
                torch.finfo(logits.dtype).min, dtype=logits.dtype
            )
            logits = torch.where(self.mask, logits, self.mask_value)
            super(CategoricalMasked, self).__init__(logits=logits)

    def entropy(self):
        if self.mask is None:
            return super().entropy()
        # Elementwise multiplication
        p_log_p = einsum("ij,ij->ij", self.logits, self.probs)
        # Compute the entropy with possible action only
        p_log_p = torch.where(
            self.mask,
            p_log_p,
            torch.tensor(0, dtype=p_log_p.dtype, device=p_log_p.device),
        )
        return -reduce(p_log_p, "b a -> b", "sum", b=self.batch, a=self.nb_action)

以下代码块的目的是向您展示如何使用操作掩码。首先,我们创建虚拟逻辑和具有相同形状的虚拟蒙版。

logits_or_qvalues = torch.randn((2, 3), requires_grad=True) # batch size, nb action
print(logits_or_qvalues) 
# tensor([[-1.8222,  1.0769, -0.6567],
#         [-0.6729,  0.1665, -1.7856]])

mask = torch.zeros((2, 3), dtype=torch.bool) # batch size, nb action
mask[0][2] = True
mask[1][0] = True
mask[1][1] = True
print(mask) # False -> mask action 
# tensor([[False, False,  True],
#         [ True,  True, False]])

然后我们比较有和没有遮蔽的动作。

head = CategoricalMasked(logits=logits_or_qvalues)
print(head.probs) # Impossible action are not masked
# tensor([[0.0447, 0.8119, 0.1434], There remain 3 actions available
#         [0.2745, 0.6353, 0.0902]]) There remain 3 actions available

head_masked = CategoricalMasked(logits=logits_or_qvalues, mask=mask)
print(head_masked.probs) # Impossible action are  masked
# tensor([[0.0000, 0.0000, 1.0000], There remain 1 actions available
#         [0.3017, 0.6983, 0.0000]]) There remain 2 actions available

print(head.entropy())
# tensor([0.5867, 0.8601])

print(head_masked.entropy())
# tensor([-0.0000, 0.6123])

我们可以观察到,当我们应用掩码时,与不可能的动作相关的概率等于0 。因此,我们的智能体永远不会选择不可能的动作。
最后,当我们在熵计算中不包括不可能的动作时,我们就得到了一致的值。这种校正后的熵计算使代理能够仅在有效动作上最大化其探索。
这么酷的把戏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1244398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决:Gitee + PicGo配置图床失败

解决:Gitee PicGo配置图床失败 PicGo安装插件的时候选择:gitee-uploader,不要选择gitee! 在Gitee新建的图床仓库中设置一个images文件夹,用来保存上传的图片,但是要注意在PicGo中的path中要写上路径/img…

视频直播美颜SDK全面解析:美颜SDK技术对比

美颜SDK的出现,为直播主和用户提供了更丰富的美颜体验。 一、美颜SDK的基本原理 美颜SDK多种技术协同工作,使得直播画面更加细腻、自然,给用户带来更好的视觉感受。不同的SDK可能采用不同的算法和处理流程,从而产生各具特色的美…

解决:ImportError: cannot import name ‘Sequence‘ from ‘collections‘

解决:ImportError: cannot import name ‘Sequence‘ from ‘collections‘ 背景 在使用之前的代码时,报错: File “G:\research\code\MicroDE_py\plot_bcic_iv_4_ecog_trial.py”, line 262, in from skorch.helper import predefined_spl…

适用于文件传输需求高,文件传输数据量大的aspera替代方案

与大多数点对点文件传输工具一样,Aspera提供了FTP和其他基于TCP的文件传输方法的可靠替代方案。 本文将主要介绍Aspera。但也要谈谈Aspera以及何时何地考虑将镭速作为Aspera替代方案。 Aspera的独特之处在于跨平台、WAN 优化文件传输的早期创新者。在这方面&#x…

11-22 SSM整合1

请求参数 (这里的形参数据都是SpringMvc注入的) controller里的方法不是我们来调用的 是由SpringMvc的前端控制器所调用的(前端控制器调用了处理器 由处理器和适配器去调用我们controller里的方法),controller里的方法叫handler->处理器 SpringMVC的Controller方…

Pikachu(二)

CSRF (跨站请求伪造)概述 Cross-site request forgery 简称为“CSRF”,在CSRF的攻击场景中攻击者会伪造一个请求(这个请求一般是一个链接),然后欺骗目标用户进行点击,用户一旦点击了这个请求,整个攻击就完…

JS数组常用的20种方法详解(每一个方法都有例子,超全面,超好理解的教程,干货满满)

目录 1.会改变原数组的方法(7种) 1.push() 2.pop() 3.unshift() 4.shift() 5.reverse() 6.sort() 7.splice() 2.不改变原数组的方法(13种,返回的新数组是从原数组浅拷贝来的) 1.concat() 2.join() 3.slice…

算法通关村第十二关-白银挑战字符串经典题目

大家好我是苏麟 , 今天带来字符串相关的题目 . 大纲 反转问题字符串反转K个一组反转仅仅反转字母反转字符串中的单词 反转问题 字符串反转 描述 : 编写一个函数,其作用是将输入的字符串反转过来。输入字符串以字符数组 s的形式给出。 题目 : LeetCode 344. 反转…

【MATLAB】全网入门快、免费获取、持续更新的科研绘图教程系列1

1 【MATLAB】科研绘图第一期点线图 %% Made by Lwcah %% 公众号:Lwcah %% 知乎、B站、小红书、抖音同名账号:Lwcah,感谢关注~ %% 更多MATLABSCI绘图教程敬请观看~%% 清除变量 clc; clear all; close all;%% 一幅图的时候figureWidth 8.5;figureHeight …

AI:87-基于深度学习的街景图像地理位置识别

🚀 本文选自专栏:人工智能领域200例教程专栏 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 ✨✨✨ 每一个案例都附带有在本地跑过的代码,详细讲解供大家学习,希望可以帮到大家。欢迎订阅支持,正在不断更新中,…

VR云游:让旅游产业插上数字化翅膀,打造地方名片

自多地入冬降温以来,泡温泉成了许多人周末度假的选择,在气温持续走低的趋势下,温泉游也迎来了旺季;但是依旧有些地区温度依旧温暖,例如南京的梧桐美景也吸引了不少游客前去打卡,大家穿着汉服与金黄的树叶合…

【Hello Go】Go语言并发编程

并发编程 概述基本概念go语言的并发优势 goroutinegoroutine是什么创建goroutine如果主goroutine退出runtime包GoschedGoexitGOMAXPROCS channel无缓冲的channel有缓冲的channelrange和close单向channel 定时器TimerTicker Select超时 概述 基本概念 并行和并发概念 并行 &…

佳易王个体诊所病历登记系统查询软件教程

佳易王个体诊所病历登记系统查询软件教程 在开处方时可以随时查看该病人的历史病历。 软件功能: 1、配方模板:可以自由添加配方分类,预先设置药品配方,可以一键导入电子处方。 2、正常开药:可以灵活选择药品&#x…

什么是持续集成的自动化测试?

持续集成的自动化测试 如今互联网软件的开发、测试和发布,已经形成了一套非常标准的流程,最重要的组成部分就是持续集成(Continuous integration,简称CI,目前主要的持续集成系统是Jenkins)。 那么什么是持…

【图文详解】SiamFC++与图注意力的强强联合:单目标追踪系统

1.研究背景与意义 随着计算机视觉技术的不断发展,单目标追踪(Single Object Tracking, SOT)作为计算机视觉领域的一个重要研究方向,已经在许多实际应用中得到了广泛的应用。单目标追踪系统可以通过分析视频序列中的目标运动&…

【Typroa使用】Typroa+PicGo-Core(command line)+gitee免费图片上传配置

TyproaPicGo-Core(command line)gitee免费图片上传配置 本文是在win10系统下配置typroapicGo-Core(command line)gitee图片上传的教程。需要的环境和工具有: gitee账号,新建仓库及token令牌;已经安装了的typroa,需要0.9.98版本以上…

Python 字典(dict)基础学习

一、字典的基础定义(key:value)键值对 my_dict {"王力宏": 99, "周杰伦": 88, "林俊杰": 77} my_dict2 {} my_dict3 dict() print(my_dict) print(my_dict2) print(my_dict3) 字典基础定义 字典名 {key1:value1,key2:value2,key3:value3}…

shell 脚本的函数和数组

函数 —— 封装的一个公式:sin、cos、tan —— 函数为脚本的别名 —— 函数就是一个功能模块,在函数中写执行的命令即可;使用函数可以避免代码重复,增加可读性,简化脚本,使用函数可以将大的工程分割为若…

【C++初阶】STL详解(六)Stack与Queue的介绍与使用

本专栏内容为:C学习专栏,分为初阶和进阶两部分。 通过本专栏的深入学习,你可以了解并掌握C。 💓博主csdn个人主页:小小unicorn ⏩专栏分类:C 🚚代码仓库:小小unicorn的代码仓库&…

UE5 中的computer shader使用

转载:UE5 中的computer shader使用 - 知乎 (zhihu.com) 目标 通过蓝图输入参数,经过Compture Shader做矩阵运算 流程 1. 新建插件 2. 插件设置 3. 声明和GPU内存对齐的参数结构 4. 声明Compture Shader结构 5. 参数绑定 6. 着色器实现 7. 分配 work gr…