Online Decision Transformer

news2025/1/26 15:38:26

摘要

  • 最近的工作表明,离线强化学习 (RL) 可以表述为序列建模问题 (Chen et al., 2021; Janner et al., 2021),并通过类似于大规模语言建模的方法来解决。 然而,RL 的任何实际实例化还涉及在线组件,其中在被动离线数据集上预训练的策略通过与环境的特定任务交互进行微调。 我们提出了在线决策Transformer(ODT),这是一种基于序列建模的 RL 算法,将离线预训练与在线微调融合在一个统一的框架中。 我们的框架使用序列级熵正则化器与自回归建模目标相结合,以实现样本有效的探索和微调。 根据经验,我们表明 ODT 在 D4RL 基准测试的绝对性能上与最先进的技术具有竞争力,但在微调过程中显示出更显着的收益。

引言

  • 序列建模的生成式预训练已成为许多领域和模式中机器学习的统一范式,特别是在语言和视觉方面(Radford 等人,2018;Chen 等人,2020;Brown 等人,2020;Lu 等人, 2022)。 最近,这种预训练范式已扩展到离线强化学习 (RL)(Chen 等人,2021 年;Janner 等人,2021 年),其中训练代理以自回归最大化离线数据集中轨迹的可能性。 在训练期间,这种范式本质上将离线 RL 转换为监督学习问题(Schmidhuber,2019;Srivastava 等人,2019;Emmons 等人,2021)。 然而,这些工作呈现出不完整的画面,因为通过离线 RL 学习的策略受到训练数据集质量的限制,需要通过在线交互对感兴趣的任务进行微调。 这种监督学习范式是否可以扩展到在线环境仍然是一个悬而未决的问题。
  • 与语言和感知不同,RL 的在线微调与预训练阶段根本不同,因为它涉及通过探索获取数据。 对探索的需求使得离线 RL 的传统监督学习目标(例如,均方误差)在在线环境中不足。 此外,据观察,对于标准在线算法,访问离线数据通常会对在线性能产生零甚至负面影响(Nair 等,2020)。 因此,离线预训练的整体流程以及对 RL 策略的在线微调需要仔细考虑训练目标和协议。
  • 我们介绍了在线决策变压器(ODT),这是一种RL的学习框架,它将离线预训练与在线微调相结合,以实现样本高效的策略优化。我们的框架建立在先前为离线RL引入的决策变换器(DT)(Chen et al,2021)架构的基础上,特别适用于在线交互成本高昂的场景,这需要离线预训练和样本有效的微调。我们确定了与DTs不兼容的几个关键缺点,并对其进行了纠正,从而为我们的整体渠道带来了卓越的性能。
  • 首先,我们从确定性政策转向随机政策,以确定在线阶段的探索目标。我们通过类似于最大RL框架的策略熵来量化探索(Levine,2018。然而,与传统框架不同的是,ODT的策略熵在轨迹的总体水平上受到限制(与单个时间步长相反),并且其双重形式规范了监督学习目标(与直接收益最大化相反)。接下来,我们开发了一种与ODT的体系结构和训练协议一致的新型重放缓冲区(Mnih等,2015)。缓冲区存储轨迹,并通过ODT的在线推出进行填充。由于ODT参数化返回条件策略,我们进一步研究在在线推出期间指定所需返回的策略。但是,此值可能与推出期间观察到的真实返回不匹配。为了应对这一挑战,我们将后验experience replay(Andrychowicz等,2017)的概念扩展到我们的设置,并在增强它们之前用正确的return tokens重新标记推出的轨迹。
  • 根据经验,我们通过将其性能与D4RL基准上的最新算法进行比较来验证我们的整体框架(Fu等,2020)。我们发现,由于我们的微调策略,我们的相对改进优于其他基线(Nair等,2020;Kostrikov等,2021a),同时在考虑基础模型的预训练结果时表现出竞争性的绝对表现。最后,我们通过严格的消融和额外的实验设计来补充我们的主要结果,以证明和验证我们方法的关键组成部分。

相关工作

  • 我们的工作包括两个广泛的研究途径,我们在这里详细介绍。
  • RL的Transformer。最近令人兴奋的进展是将离线RL问题制定为上下文条件序列建模问题(Chen等,2021;Janner等,2021)。这些工作建立在强化学习作为监督学习范式(Schmidhuber,2019;Srivastava等,2019;Emmons等,2021)的基础上,其侧重于以任务规范(例如,target goal或return)为条件的动作序列的预测建模,而不是显式学习Q函数或策略梯度。Chen等人(2021)将Transformer训练为无模型上下文条件策略,Janner等人(2021)将Transformer训练为策略和模型,并表明波束搜索可用于改进纯无模型性能。然而,这些工作仅探索离线RL设置,其类似于Transformer传统上在自然语言处理应用中训练的固定数据集。我们的工作重点是将这些结果扩展到在线微调环境,显示出与最先进的RL方法的竞争力。
  • 离线RL方法主要将conservative保守组件添加到现有的非策略RL方法中以防止分布外推,但需要对超参数进行许多调整和重新调整才能起作用(Kumar等,2020a;Kostrikov等,2021b))。与我们的工作类似,Fujimoto和Gu(2021)展示了将行为克隆术语添加到离线RL方法的好处,并且该术语的简单添加允许将非策略RL算法以最小的变化移植到离线设置。
  • 离线RL与在线微调。虽然ODT源于与传统RL方法不同的视角,但现有的许多工作都集中在对给定离线数据集进行预训练的相同范例上,并在在线环境中进行微调。Nair等人(2020)表明,将离线或非策略性RL方法应用于离线预培训和在线微调制度往往无助于甚至阻碍绩效。这种政策外方法的不良表现可归因于政策外引导错误积累(Munos,20032005;Farahmand等,2010;Kumar等,2019)。在离线RL方法中,在线微调制度中的不良表现可以通过过度保守来解释,这在离线制度中是必要的,以防止价值高估超出分配状态。Nair等人(2020)首次提出了一种适用于离线和在线培训制度的算法。最近的工作(Kostrikov et al。,2021a)也提出了一种基于期望的离线RL隐式Q学习算法,该算法也显示出强大的在线微调性能,因为该策略是通过避免分发外行为的行为克隆步骤提取的动作。
  • Lee等人(2021)通过平衡重放方案和一系列功能来解决离线在线设置问题,以在离线培训期间保持保守主义。Lu等人(2021)改进了AWAC(Nair et al。,2020),它在在线微调阶段表现出崩溃,在在线阶段纳入了积极的抽样和探索。我们还发现积极的抽样和探索是良好的在线微调的关键,但是我们将展示ODT中这些特征是如何自然发生的,从而导致一种简单的端到端方法,可以自动适应离线和在线设置。

预赛

  • 我们假设我们的环境可以建模为马尔可夫决策过程 (MDP),可以描述为 M = < S , A , p , P , R , γ > M=<S, A, p, P, R, γ> M=<S,A,p,P,R,γ>,其中 S S S 是状态空间, A A A 是动作空间, P ( s t + 1 ∣ s t , a t ) P(s_{t+1}|s_t,a_t) P(st+1stat) 是转换的概率分布, R ( s t , a t ) R(s_t,a_t) R(stat) 是奖励函数, γ γ γ 是折扣因子(Bellman,1957)。 代理从从固定分布 p ( s 1 ) p(s_1) p(s1) 采样的初始状态 s 1 s_1 s1 开始,然后在每个时间步 t t t 它从状态 s t ∈ S s_t \in S stS a t ∈ A a_t \in A atA 采取行动并移动到下一个状态 s t + 1   P ( ⋅ ∣ s t , a t ) s_{t+1}~P(\cdot |s_t, a_t) st+1 P(st,at)。 在每个动作之后,代理都会收到一个确定性的奖励 r t = R ( s t , a t ) r_t=R(s_t,a_t) rt=R(st,at)。 请注意,我们的算法也直接适用于部分可观察马尔可夫决策过程 (POMDP),但我们使用 MDP 框架以便于阐述。

3.1 设置和符号

  • 我们对决策转换器 (DT) 的在线微调感兴趣(Chen 等人,2020 年),其中代理将可以访问非平稳训练数据分布 T T T。最初,在预训练期间, T T T 对应于线数据分布 T o f f l i n e T_{offline} Toffline,并通过离线数据集 T o f f l i n e T_{offline} Toffline 访问。 在微调期间,它通过重播缓冲区 T r e p l a y T_{replay} Treplay 访问。 令 τ τ τ 表示轨迹并令 ∣ τ ∣ |τ | τ 表示它的长度。 轨迹 τ τ τ 在时间步长 t t t 的返回 (RTG),gt “ř|τ|t1“trt1 ,是该时间步长的未来奖励总和。 让“ pa1, . . . , a|τ|q, s “ ps1, . . . , s|τ|q 和 g “ pg1, . . . , g|τ|q 分别表示 τ 的动作序列、状态和 RTG。

Online Decision Transformer

  • 由于训练数据的限制,在纯离线数据集上训练的 RL 策略通常不是最优的,因为离线轨迹可能没有高回报并且仅覆盖状态空间的有限部分。 提高性能的一种自然策略是通过在线交互微调预训练的 RL 代理。 然而,标准决策转换器的学习公式不足以进行在线学习,正如我们将在实验消融中展示的那样,当天真地用于在线数据采集时会崩溃。 在本节中,我们介绍了对决策转换器的关键修改,以实现高效采样的在线微调。
  • 作为第一步,我们提出了一个广义的概率学习目标。 我们将扩展此公式以解释在线决策转换器 (ODT) 中的探索。 在概率设置中,我们的目标是学习最大化数据集可能性的随机策略。 例如,对于连续动作空间,我们可以使用具有对角线的多元高斯分布的标准选择(Haarnoja 等人,2018a;Fujimoto & Gu,2021;Kumar 等人,2020b;Emmons 等人,2021) 协方差矩阵,用于模拟以状态和 RTG 为条件的动作分布。 让 θ θ θ 表示策略参数。 正式地,我们的政策是
    在这里插入图片描述
  • 其中协方差矩阵 Σ θ Σ_θ Σθ 假定为对角矩阵。 给定随机策略,我们最大化训练数据集中轨迹的对数似然,或等效地最小化负对数似然 (NLL) 损失。
    在这里插入图片描述
  • 我们在这里考虑的策略包含 DT 考虑的确定性策略。 优化目标 (1) 等同于优化 (3),假设协方差矩阵 Σθ 是对角矩阵并且方差在所有维度上都相同,这是我们假设涵盖的特例。

4.1 最大熵序列建模

  • 在线 RL 算法的关键属性是能够平衡探索-开发权衡。 即使使用随机策略,如等式 (3) 中的传统 DT 公式也没有考虑探索。 为了解决这个缺点,我们首先通过定义为的策略熵来量化探索:
    在这里插入图片描述
  • 其中 H [ π θ ( a k ) ] H[π_θ(a_k)] H[πθ(ak)] 表示分布 π θ ( a k ) π_θ(a_k) πθ(ak) 的香农熵。 策略熵取决于数据分布 T T T,它在离线预训练阶段是静态的,但在微调期间是动态的,因为它取决于探索期间获得的在线数据。
  • 类似于许多现有的 max-ent RL 算法 (Levine, 2018),例如 Soft Actor Critic (SAC, Haarnoja et al. (2018a;b)),我们明确地对策略熵施加一个下限以鼓励探索。 也就是说,我们有兴趣解决以下约束问题:
    在这里插入图片描述
  • 其中 β 是一个前缀超参数。 继 Haarnoja 等人 (2018b) 之后,在实践中,我们解决了 (5) 的对偶问题,以避免显式处理不等式约束。 即,我们考虑拉格朗日 Lpθ, λq “ Jpθq `λpβ ´HTθra|s, gsq 并通过交替优化 θ 和 λ 来解决问题 maxλě0minθLpθ, λq。 用固定的 λ 优化 θ 等同于
  • 最后,我们就实际优化细节方面与 SAC 的相似性发表了几点评论。 首先,我们没有完全解决子问题(6)和(7)。 对于它们两者,我们每次只进行一次梯度更新,也就是一步交替梯度下降。 其次,HTθ ra|s, gs 的计算涉及积分。 我们使用单样本蒙特卡洛估计来近似每个积分,并且样本被重新参数化以进行低方差梯度计算。 正如 Haarnoja 等人 (2018b) 也指出的那样,我们经常观察到问题 (5) 中的约束很紧,因此实际熵 HTθra|s, gs 与 β 匹配。

Training Pipeline

  • 我们使用变压器架构实例化上述公式。 在线决策转换器 (ODT) 建立在 DT 架构之上,并包含由于随机策略而产生的变化。 我们通过输出端的两个独立的全连接层来预测策略均值和对数方差。 算法 1 总结了 ODT 中的整体微调管道,其中详细的内部训练步骤在算法 2 中进行了描述。我们在下面概述了这些算法的主要组成部分,并在附录 B 中讨论了其他设计选择和超参数。
  • 轨迹级回放缓冲器。我们使用重放缓冲区来记录我们过去的经历并定期更新。对于大多数现有的 RL 算法,重放缓冲区由转换组成。在 rollout 中的每一步在线交互之后,策略或 Q 函数都会通过梯度步骤进行更新,并执行策略以收集新的转换以添加到重放缓冲区中。然而,对于 ODT,我们的回放缓冲区包含轨迹而不是转换。离线预训练后,我们通过离线数据集中回报率最高的轨迹初始化回放缓冲区。每次我们与环境交互时,我们都会使用当前策略完全推出一个情节,然后以先进先出的方式使用收集到的轨迹刷新重播缓冲区。之后,我们再次更新策略并推出,如算法 1 所示。与 Haarnoja 等人 (2018a) 类似,我们还观察到使用平均动作评估策略通常会带来更高的回报,但使用采样更有好处在线探索的行动,因为它会产生更多样化的数据。
  • 事后回报重新贴标签。 Hindsight experience replay (HER) 是一种在奖励稀疏的环境中提高目标条件代理的样本效率的方法(Andrychowicz 等人,2017 年;Rauber 等人,2017 年;Ghosh 等人,2019 年) . 这里的关键思想是将智能体的轨迹重新标记为已实现的目标,而不是预期目标。 对于 ODT,我们正在学习以初始 RTG 为条件的策略。 然而,在政策推出期间获得的回报和诱导的 RTG 可能与预期的 RTG 不同。 受 HER 的启发,我们用实现的回报为展开的轨迹 τ τ τ 重新标记 RTG 代币,这样在最后一个时间步 g ∣ τ ∣ g_{|\tau|} gτ 的 RTG 代币恰好是代理 r ∣ τ ∣ r_{|τ|} rτ 获得的奖励,参见算法 2 的第 6 行。 这种返回重新标记策略适用于奖励稀疏和密集的环境。
  • RTG 调节。 ODT 需要一个超参数,初始 RTG g o n l i n e g_{online} gonline,用于收集额外的在线数据(参见算法 1 的第 4 行)。 此前,Chen 等人 (2021) 表明,离线 DT 的实际评估回报与经验上的初始 RTG 具有很强的相关性,并且通常可以推断出超过离线数据集中观察到的最大回报的 RTG 值。 对于 ODT,我们发现最好将此超参数设置为专家回报的一个小的、固定的比例(在我们的实验中设置为 2)。 我们还试验了更大的值以及随时间变化的课程(例如,离线和在线数据集中最佳评估回报的分位数),但我们发现这些相对于固定的、缩放的 RTG 而言是次优的。
  • 抽样策略。 与 DT 类似,算法 2 使用两步采样过程来确保重放缓冲区 Treplay 中长度为 K 的子轨迹被均匀采样。 我们首先以与其长度成正比的概率采样单个轨迹,然后统一采样长度为 K 的子轨迹。对于具有非负密集奖励的环境,我们的采样策略类似于重要性采样。 在这些情况下,轨迹的长度与其返回高度相关,正如我们在附录 F 中进一步强调的那样。

动态训练

  • 我们评论了一些关于 ODT 训练动态及其影响的经验观察。 我们首先展示一个示例运行,其中 ODT 的在线返回饱和,表明训练已经收敛。 我们将自己限制在算法 1 收敛的情况下,讨论 ODT 的训练动态。 这种收敛假设使我们能够分析学习目标在训练过程中的含义,以及初始 RTG 令牌对 ODT 策略的行为变化。 我们强调算法 1 的收敛保证是一个悬而未决的问题,超出了本文的范围,我们的主张将主要通过实验来指导。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/46453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Express:Express 中间件

中间件的概念 1. 什么是中间件 中间件&#xff08;Middleware &#xff09;&#xff0c;特指业务流程的中间处理环节。 2. 现实生活中的例子 在处理污水的时候&#xff0c;一般都要经过三个处理环节&#xff0c;从而保证处理过后的废水&#xff0c;达到排放标准。 处理污水…

aws cloudformation 堆栈集的创建和使用

资料 使用 AWS CloudFormation StackSets 跨多个 AWS 账户和区域配置资源AWS cloudformation示例模板堆栈集堆栈实例状态原因 很多组织使用大量的 AWS 账户&#xff0c;通常用 AWS Organizations 将这些账户组织为分层结构&#xff0c;分组为不同的组织部门 (OU)。并且希望确…

Teams app 的 SSO 机制

我们来继续我们的 Teams sample 之旅&#xff0c;上一个讲了 Tab app&#xff0c;那我们这里再深入一步&#xff0c;看一下如何使用 sso 机制。 sso 是一个很有用机制&#xff0c;它可以让我们的 teams app 能获取当前用户的身份。sso 很多时候比较难彻底理解&#xff0c;在开…

刷爆力扣之公平的糖果交换

刷爆力扣之公平的糖果交换 HELLO&#xff0c;各位看官大大好&#xff0c;我是阿呆 &#x1f648;&#x1f648;&#x1f648; 今天阿呆继续记录下力扣刷题过程&#xff0c;收录在专栏算法中 &#x1f61c;&#x1f61c;&#x1f61c; 该专栏按照不同类别标签进行刷题&#xff…

【数据链路层】循环冗余码CRC、后退N帧协议GBN、选择重传协议SR、CSMA/CA

文章目录循环冗余码CRC多帧滑动窗口连续ARQ协议后退N帧协议GBN选择重传协议SRCSMA/CA---针对无线局域网处理隐蔽站问题RTS&#xff0c;CTS循环冗余码CRC /*** 计算CRC16校验码** param bytes* return* [1,3,4,1,205,1,18,235,173]*/public static String CRC16(byte[] bytes) {…

终于见识到了微服务的天花板!SpringCloud全线手册,太强了

后台都是在问微服务架构的面试题怎么答&#xff0c;想聊聊微服务架构了。微服务架构一跃成为 IT 领域炙手可热的话题也就这两年的事&#xff0c;大量一线互联网公司因为庞大的业务体量和业务需求&#xff0c;纷纷投入了微服务架构的建设中&#xff0c;像阿里巴巴、百度、美团等…

Kamiya丨Kamiya艾美捷大鼠微量白蛋白酶联免疫吸附试验说明书

Kamiya艾美捷大鼠微量白蛋白酶联免疫吸附试验预期用途&#xff1a; 大鼠微量白蛋白酶联免疫吸附试验&#xff08;ELISA&#xff09;是一种高灵敏度的双位点酶联免疫吸附试验&#xff08;ELISA&#xff09;大鼠生物样品中微量白蛋白的测定。仅供研究使用。 引言 白蛋白&#x…

Java项目:ssm学生学籍管理系统

作者主页&#xff1a;源码空间站2022 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 SSM项目-学生学籍管理系统。该项目分管理员、老师、学生三种用户角色。每种角色分别对应不同的菜单&#xff1b; 以下分别介绍各个角色对应的功…

[附源码]计算机毕业设计springboot基于Java的日用品在线电商平台

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

二进制数据的贝叶斯非参数聚类算法(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 部分运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 利用图像结构信息是字典学习的难点,针对传统非参数贝叶斯算法对图像结构信息利用不充分,以及算法运行效率低下的问题,该文…

GoLand2022.2.5版本Hello调动Greetings包

安装Goland2022.2.5 版本 1.官网下载goland-2022.2.5.dmg版本&#xff08;Mac)版本。如果是windows版本也可以直接下载&#xff09; 2.配置gopath&#xff0c;基本都是配置.我这里配置为/usr/local/go 作为全目录&#xff0c;如果是windows&#xff0c;直接在环境中配置path路…

Mysql基础知识篇(二)

1.UNION 与 UNION ALL 的区别&#xff1f; 如果使用 UNION&#xff0c;会在表链接后筛选掉重复的记录行如果使用 UNION ALL&#xff0c;不会合并重复的记录行从效率上说&#xff0c;UNION ALL 要比 UNION 快很多&#xff0c;如果合并没有刻意要删除重复行&#xff0c;那么就使…

自动化测试框架

自动化测试框架1.自动化测试框架核心功能1.数据驱动2.页面驱动3.关键字驱动2.关键字驱动实现-文档形式3.关键字驱动实现-表格形式1.自动化测试框架核心功能 这三种驱动测试可以结合使用来完成系统的自动化测试。可以将测试数据 1.数据驱动 将测试代码和测试数据分离&#xff…

科技云报道:云计算走向工业互联网“深水区”

科技云报道原创。 在新科技革命中&#xff0c;将网格化、信息化与智能化深度融合的工业互联网&#xff0c;正在将人、机、物全面互联&#xff0c;实现全要素、全产业链、全价值链的连接&#xff0c;推动传统产业加快转型升级、助力新兴产业加速发展壮大。 工业如何在快速变革…

培训机构借助创客匠人发力线上业务

疫情反反复复&#xff0c;传统线下教学受到严重影响,转型线上、借力线上发展业务成为行业主流趋势。但是,没有线上经验,人手不足的线下教培机构是否可以转型线上做教学服务,实现招生引流呢? 答案是——可以!用对工具,选对模式,其实很简单! 有很多没有专门线上运营团队,甚至是…

《计算机体系结构量化研究方法》1.7 可信任度

主要内容 计算机是在不同的抽象层上设计和构造的。我们可以逐级深入计算机的不同层面&#xff0c;将每个组件放大为一个完整的子系统进行查看&#xff0c;直到深入到独立的晶体管为止。尽管有些故障会波及整个系统&#xff0c;比如掉电&#xff0c;但许多故障可以被限制在模块…

leetcode-每日一题-1758-生成交替二进制字符串的最少操作数(简单,数学思想)

这道题标记为简单题是正常的&#xff0c;因为当你想到0或者1开头的时候就已经结束了看看我的分析 那么知道这个信息之后就很简单了&#xff0c;加上我们的位运算符号^作为标记即可&#xff0c;大家看看代码实现 1758. 生成交替二进制字符串的最少操作数 难度简单88收藏分享切换…

R语言和Tableau通过情感分析,我们可以从特朗普的推文得到什么?

社交媒体分析的许多用途中的一些是情绪分析&#xff0c;我们评估特定问题的帖子是积极还是消极。我们把社交媒体分析、机器学习、预测建模等集成到文本数据挖掘中。最近我们被客户要求撰写关于推文的研究报告&#xff0c;包括一些图形和统计输出。 在这篇文章中&#xff0c;我…

使用React.ts创建一个密码生成器的简单示例

目录密码生成器DemoFeature知识点React TypeScript —— Function Components为元素(::before/::after)绑定点击事件React如何正确定义对象数组在React中设置复选框check属性三目运算符实现React动态绑定class和style参考资料密码生成器Demo 使用密码生成器工具创建随机密码。P…

Java基于springboot +vue网上超市购物网站 多商家

随着我国信息化的发展&#xff0c;大家更多的是希望通过网络获取到更多的直接所需的信息&#xff0c;而商品一直以来就是人类永恒的追求之一&#xff0c;如何能够享有到更多的商品是很多人一直以来关系的问题。 本系统通过在线网购的方式让用户可以在需要购买商品但是有没有时间…