O2O:Sample Efficient Offline-to-Online Reinforcement Learning

news2024/11/18 2:51:46

IEEE TKDE 2024
paper

Introduction

O2O存在策略探索受限以及分布偏移问题,进而导致在线微调阶段样本效率低。文章提出OEMA算法首先使用离线数据训练乐观的探索策略,然后提出基于元学习的优化方法,减少分布偏移并提高O2O的适应过程。

Method

在这里插入图片描述

optimistic exploration strategy

离线学习方法TD3+BC的行为策略 π e ( s ) \pi_e(s) πe(s)是由目标策略 π ϕ ( s ) \pi_\phi(s) πϕ(s)加上一个正态分布中采样的噪声。文章指出,目标策略被优化靠近离线数据集的保守策略,为了提高目标策略的探索能力,本文提出基于价值不确定性度量的方法:
π e = arg ⁡ max ⁡ π Q ^ U B ( s , π ( s ) ) , s . t . 1 2 ∥ π ϕ ( s ) − π ( s ) ∥ ≤ δ , \begin{aligned}\pi_{e}&=\arg\max_{\pi}\hat{Q}_{\mathrm{UB}}(s,\pi(s)),\\s.t.&\frac{1}{2}\|\pi_{\phi}(s)-\pi(s)\|\le\delta,\end{aligned} πes.t.=argπmaxQ^UB(s,π(s)),21πϕ(s)π(s)δ,
其中 Q ^ U B ( s , π ( s ) ) \hat{Q}_{\mathrm{UB}}(s,\pi(s)) Q^UB(s,π(s))为Q值的近似上界, 用来衡量认知不确定性。上述问题在保证策略约束的同时选择高不确信的动作。

不确信估计采用高斯分布。分布的均值为两个Q网络输出的均值,而方差表示如下:
σ Q ( s , a ) = ∑ i = 1 , 2 1 2 ( Q θ i ( s , a ) − μ Q ( s , a ) ) 2 = 1 2 ∣ Q θ 1 ( s , a ) − Q θ 2 ( s , a ) ∣ . \begin{gathered} \sigma_{Q}(s,a) =\sqrt{\sum_{i=1,2}\frac12(Q_{\theta_{i}}(s,a)-\mu_{Q}(s,a))^{2}} \\ =\frac12\Big|Q_{\theta_1}(s,a)-Q_{\theta_2}(s,a)\Big|. \end{gathered} σQ(s,a)=i=1,221(Qθi(s,a)μQ(s,a))2 =21 Qθ1(s,a)Qθ2(s,a) .
那么 Q ^ U B = μ Q ( s , a ) + β UB σ Q ( s , a ) \hat{Q}_{\mathrm{UB}} =\mu_Q(s,a)+\beta_\text{UB}\sigma_Q(s,a) Q^UB=μQ(s,a)+βUBσQ(s,a) β \beta β控制乐观程度,当取值-1时上式等价于:
Q ^ U B ( s , a ) ∣ β U B = − 1 = μ Q ( s , a ) − σ Q ( s , a ) = 1 2 ( Q θ 1 ( s , a ) + Q θ 2 ( s , a ) ) − 1 2 ∣ Q θ 1 ( s , a ) − Q θ 2 ( s , a ) ∣ = min ⁡ ( Q θ 1 ( s , a ) , Q θ 2 ( s , a ) ) , (9) \begin{aligned} &\hat{Q}_{\mathrm{UB}}(s,a)\Big|_{\beta_{\mathrm{UB}}=-1}=\mu_{Q}(s,a)-\sigma_{Q}(s,a) \\ &\begin{aligned}&=\frac{1}{2}(Q_{\theta_1}(s,a)+Q_{\theta_2}(s,a))-\frac{1}{2}|Q_{\theta_1}(s,a)-Q_{\theta_2}(s,a)|\end{aligned} \\ &=\min(Q_{\theta_{1}}(s,a),Q_{\theta_{2}}(s,a)),& \text{(9)} \end{aligned} Q^UB(s,a) βUB=1=μQ(s,a)σQ(s,a)=21(Qθ1(s,a)+Qθ2(s,a))21Qθ1(s,a)Qθ2(s,a)=min(Qθ1(s,a),Qθ2(s,a)),(9)
而当 β = 1 \beta=1 β=1时等价于 Q ^ U B ( s , a ) ∣ β U B = 1 = max ⁡ ( Q θ 1 ( s , a ) , Q θ 2 ( s , a ) ) , \left.\hat{Q}_\mathrm{UB}(s,a)\right|_{\beta_\mathrm{UB}=1}=\max(Q_{\theta_1}(s,a),Q_{\theta_2}(s,a)), Q^UB(s,a) βUB=1=max(Qθ1(s,a),Qθ2(s,a)),

原问题一种简单的解决方法是使用BC将其转化为无约束问题:
π e naive ( s ) = arg ⁡ max ⁡ π Q ^ UB ( s , π ( s ) ) − λ ∥ π ϕ ( s ) − π ( s ) ∥ \pi_e^\text{naive}{ ( s ) }=\arg\max_{\pi}\hat{Q}_\text{UB}{ ( s , \pi ( s ) ) }-\lambda\|\pi_\phi(s)-\pi(s)\| πenaive(s)=argπmaxQ^UB(s,π(s))λπϕ(s)π(s)
然而,由于目标策略通过策略改进不断更新,这种基于行为克隆的惩罚项无法缩小行为策略和目标策略之间的差距,违反了带约束的原问题。

为了解决该问题,提出在TD3的行为策略上增加一项扰动模型 ξ \xi ξ。行为策略改为
π e ( s ) = π ϕ ( s ) + ξ ω ( s , π ϕ ( s ) ) + ϵ \pi_e(s)=\pi_\phi(s)+\xi_\omega(s,\pi_\phi(s))+\epsilon πe(s)=πϕ(s)+ξω(s,πϕ(s))+ϵ
而对扰动模型的参数最小化下列损失函数
L ( ω ) = − E s ∼ B [ Q ^ U B ( s , π e ( s ) ) ] \mathcal{L}(\omega)=-\mathbb{E}_{s\sim\mathcal{B}}\left[\hat{Q}_{\mathrm{UB}}(s,\pi_e(s))\right] L(ω)=EsB[Q^UB(s,πe(s))]

Meta Adaptation for Distribution Shift Reduction

接着,为了解决在线微调存在的分布偏移问题,采用元学习的方法。具体的,保留两个buffer,Buffer B B B存储离线以及在线所有数据, B r B_r Br存储最新在线数据。

meta training

首先在B上训练策略:
L t r n ( ϕ ) = − E s ∼ B [ Q θ 1 ( s , π ϕ ( s ) ) ] \mathcal{L}_{trn}(\phi)=-\mathbb{E}_{s\sim\mathcal{B}}\left[Q_{\theta_1}\left(s,\pi_\phi(s)\right)\right] Ltrn(ϕ)=EsB[Qθ1(s,πϕ(s))]
然后基于SGD的一次梯度下降得到: ϕ ′ = ϕ − α ∇ ϕ L t r n ( ϕ ) \phi^{\prime}=\phi-\alpha\nabla_{\phi}\mathcal{L}_{trn}(\phi) ϕ=ϕαϕLtrn(ϕ)

meta test

然后利用最新在线数据集测试:
L t s t ( ϕ ′ ) = − E s ∼ B r [ Q θ 1 ( s , π ϕ ′ ( s ) ) ] \mathcal{L}_{tst}(\phi')=-\mathbb{E}_{s\sim\mathcal{B}_r}[Q_{\theta_1}(s,\pi_{\phi'}(s))] Ltst(ϕ)=EsBr[Qθ1(s,πϕ(s))]

meta optimization

最后将上述两个损失函数用下面的元优化目标共同优化
ϕ = arg ⁡ min ⁡ ϕ L t r n ( ϕ ) + β L t s t ( ϕ − α ∇ ϕ L t r n ( ϕ ) ) \phi=\arg\min_\phi\mathcal{L}_{trn}(\phi)+\beta\mathcal{L}_{tst}(\phi-\alpha\nabla_\phi\mathcal{L}_{trn}(\phi)) ϕ=argϕminLtrn(ϕ)+βLtst(ϕαϕLtrn(ϕ))

问题

  1. 原文中在meta optimization中对 ϕ \phi ϕ梯度更新是否修正为:
    ϕ ← ϕ − α ∂ ( L t r n ( ϕ ) + β L t s t ( ϕ − α ∇ ϕ L t r n ( ϕ ) ) ) ∂ ϕ \phi\leftarrow\phi-\alpha\frac{\partial\left(\mathcal{L}_{trn}\left(\phi\right) +\beta\mathcal{L}_{tst}\left(\phi-\alpha\nabla_{\phi}\mathcal{L}_{trn}\left(\phi\right)\right)\right)}{\partial\phi} ϕϕαϕ(Ltrn(ϕ)+βLtst(ϕαϕLtrn(ϕ)))
  2. 基于这个偏导出现的第二个问题。这是源码中元学习的训练过程
# Compute actor losse
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            """ Meta Training"""
            self.actor_optimizer.zero_grad()
            actor_loss.backward(retain_graph=True)
            self.hotplug.update(3e-4)

            """Meta Testing"""
            self.beta = max(0.0, self.beta - self.anneal_step)
            meta_actor_loss = -self.critic.Q1(meta_state, self.actor(meta_state)).mean()
            weight = self.beta * actor_loss.detach() / meta_actor_loss.detach()
            meta_actor_loss_norm = weight * meta_actor_loss
            meta_actor_loss_norm.backward(create_graph=True)

            """Meta Optimization"""
            self.actor_optimizer.step()
            self.hotplug.restore()

其中,meta-testing中计算weight以及meta_actor_loss_norm不太明白。按照本人的理解,原文计算 L t s t L_{tst} Ltst ϕ \phi ϕ求偏导:
β ∂ L t s t ( ϕ − α ∇ ϕ L t r n ( ϕ ) ) ∂ ϕ = β ∂ L t s t ( ϕ − α ∇ ϕ L t r n ( ϕ ) ) ∂ ϕ ′ ∂ ϕ ′ ∂ ϕ \frac{\beta{\color{red}\partial}\mathcal{L}_{tst}\left(\phi-\alpha\nabla_{\phi}\mathcal{L}_{trn}\left(\phi\right)\right)}{\partial\phi}=\beta\frac{{\color{red}\partial}\mathcal{L}_{tst}\left(\phi-\alpha\nabla_{\phi}\mathcal{L}_{trn}\left(\phi\right)\right)}{\partial\phi'}\frac{\partial\phi'}{\partial\phi} ϕβLtst(ϕαϕLtrn(ϕ))=βϕLtst(ϕαϕLtrn(ϕ))ϕϕ
中间的偏导自然是由meta-test小节的损失函数所得到的meta_actor_loss。而 β ∂ ϕ ′ ∂ ϕ = β ∂ L t s t ∂ ϕ / ∂ L t s t ∂ ϕ ′ = β ∂ L t r n ∂ ϕ / ∂ L t r n ∂ ϕ ′ \beta\frac{\partial\phi'}{\partial\phi}=\beta\frac{\partial L_{tst}}{\partial\phi}/\frac{\partial L_{tst}}{\partial\phi'}=\beta\frac{\partial L_{trn}}{\partial\phi}/\frac{\partial L_{trn}}{\partial\phi'} βϕϕ=βϕLtst/ϕLtst=βϕLtrn/ϕLtrn就是那个weight。
但这样应该使用从相同的Buffer中获得状态数据,这并未在源码中体现。

可能有误,欢迎指正

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1501290.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java零基础 - 数组的定义和声明

哈喽,各位小伙伴们,你们好呀,我是喵手。 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一个人虽可以走的更快,但一群人可以走的更远。 我是一名后…

React-Redux中actions

一、同步actions 1.概念 说明:在reducers的同步修改方法中添加action对象参数,在调用actionCreater的时候传递参数,数会被传递到action对象payload属性上。 2.reducers对象 说明:声明函数同时接受参数 const counterStorecre…

DDoS和CC攻击的原理

目前最常见的网络攻击方式就是CC攻击和DDoS攻击这两种,很多互联网企业服务器遭到攻击后接入我们德迅云安全高防时会问到,什么是CC攻击,什么又是DDoS攻击,这两个有什么区别的,其实清楚它们的攻击原理,也就知…

mybatis中使用<choose><when><otherwise>标签实现根据条件查询不同sql

项目场景&#xff1a; 有时候业务层未进行条件处理那么在sql怎么操作呢,这里我是将c#版本的代码改成Java版本的时候出现的问题,因为c#没有业务层 更多操作是在sql中实现的 也就是业务层和编写sql地方一起写了,当我按照c#代码改Java到写sql时发现<if>标签不能实现我们业务…

3.8 动态规划 背包问题

一.01背包 46. 携带研究材料&#xff08;第六期模拟笔试&#xff09; (kamacoder.com) 代码随想录 (programmercarl.com) 携带研究材料: 时间限制&#xff1a;5.000S 空间限制&#xff1a;128MB 题目描述: 小明是一位科学家&#xff0c;他需要参加一场重要的国际科学大会…

OpenCascade源码剖析:Handle类

Handle其实就是智能指针的上古版本&#xff0c;了解一点C11的应该对shared_ptr非常熟悉&#xff0c;那么你就把Handle当做shared_ptr来理解就没有任何问题了。 不过OCCT的Handles是侵入式的实现&#xff0c;前面讲过Standard_Transient类提供了引用计数机制&#xff0c;这个就…

新质生产力助春播春管:佳格天地连续第5年上线大数据平台,服务春季生产

随着“惊蛰”节气过去,全国各地陆续掀起春播春管热潮。今年的政府工作报告中指出,2023年我国粮食产量1.39万亿斤,再创新高。2024年要坚持不懈抓好“三农”工作,扎实推进乡村全面振兴,粮食产量预期目标1.3万亿斤以上。 粮食产量预期目标的明确为一年农事生产指引了方向。同时,新…

地址分词 | EXCEL批量进行地址分词,标准化为十一级地址

一 需求 物流需要对用户输入地址进行检查&#xff0c;受用户录入习惯地址可能存在多种问题。 地址标准化是基于地址引擎和地址大数据模型&#xff0c;自动将地址信息标准化为省、市、区市县、街镇、小区、楼栋、单元、楼层、房屋、房间等元素&#xff0c;补充层级缺失数据、构建…

导出谷歌gemma模型为ONNX

参考代码如下&#xff08;从GitHub - luchangli03/export_llama_to_onnx: export llama to onnx修改而来&#xff0c;后面会合入进去&#xff09; 模型权重链接参考&#xff1a; https://huggingface.co/google/gemma-2b-it 可以对modeling_gemma.py进行一些修改(transforme…

LLCC68与SX1278 LoRa模块的优势对比?

LLCC68和SX1278都是Semtech公司推出的LoRa调制解调器模块&#xff0c;属于LoRa模块家族。它们在无线通信领域都有着广泛的应用&#xff0c;但具体的优势会取决于具体的应用场景和需求。下面是对LLCC68和SX1278 LoRa模块的一些优势对比&#xff1a; LLCC68 LoRa模块的优势&#…

qt自定义时间选择控件窗口

效果如图&#xff1a; 布局如图&#xff1a; 参考代码&#xff1a; //DateTimeSelectWidget #ifndef DATETIMESELECTWIDGET_H #define DATETIMESELECTWIDGET_H#include <QWidget> #include <QDateTime>namespace Ui { class DateTimeSelectWidget; }class DateTim…

【手游联运平台搭建】游戏平台的作用

随着科技的不断发展&#xff0c;游戏行业也在不断壮大&#xff0c;而游戏平台作为连接玩家与游戏的桥梁&#xff0c;发挥着越来越重要的作用。游戏平台不仅为玩家提供了便捷的游戏体验&#xff0c;还为游戏开发者提供了广阔的市场和推广渠道。本文将从多个方面探讨游戏平台的作…

扩展CArray类,增加Contain函数

CArray不包含查找类的函数&#xff0c;使用不便。考虑扩展CArray类&#xff0c;增加Contain函数&#xff0c;通过回调函数暴露数组元素的比较方法&#xff0c;由外部定义。该方法相对重载数组元素的“”符号更加灵活&#xff0c;可以根据需要配置不同的回调函数进行比较 //类型…

继深圳后,重庆与鸿蒙展开原生应用开发合作

截至2023年底&#xff0c;开源鸿蒙开源社区已有250多家生态伙伴加入&#xff0c;开源鸿蒙项目捐赠人达35家&#xff0c;通过开源鸿蒙兼容性测评的伙伴达173个&#xff0c;累计落地230余款商用设备&#xff0c;涵盖金融、教育、智能家居、交通、数字政府、工业、医疗等各领域。 …

底层day3作业

思维导图 作业&#xff1a;1.总结任务的调度算法&#xff0c;把实现代码再写一下 算法&#xff1a;抢占式调度时间片轮转 1.抢占式调度&#xff1a;任务优先级高的可以打断任务优先级低的执行&#xff08;适用于不同优先级&#xff09; 2.时间片轮转&#xff1a;每一个任务拥…

react的diff源码

react 的 render 阶段&#xff0c;其中 begin 时会调用 reconcileChildren 函数&#xff0c; reconcileChildren 中做的事情就是 react 知名的 diff 过程 diff 算法介绍 react 的每次更新&#xff0c;都会将新的 ReactElement 内容与旧的 fiber 树作对比&#xff0c;比较出它们…

电脑小问题:Windows更新后黑屏

Windows 更新后黑屏解决方法 在 Windows 更新后&#xff0c;伴随了一个小问题&#xff0c;电脑启动后出现了桌面黑屏。原因可能是火绒把 explorer.exe 当病毒处理了。 下面讲解 Windows 更新后黑屏的解决方法&#xff0c;步骤如下&#xff1a; 1. 按 ctrl alt delete 组合键…

基于Python3的数据结构与算法 - 12 数据结构(列表和栈)

目录 一、引入 二、分类 三、列表 1. C语言中数组的存储方式 2. Python中列表的存储方式 四、栈 1. 栈的应用 -- 括号匹配问题 一、引入 定义&#xff1a;数据结构是指相互之间存在着一种或多种关系的数据元素的集合和该集合中数据元素之间的关系组成。简单来说&#x…

portainer管理远程docker和docker-swarm集群

使用前请先安装docker和docker-compose&#xff0c;同时完成docker-swarm集群初始化 一、portainer-ce部署 部署portainer-ce实时管理本机docker&#xff0c;使用docker-compose一键拉起 docker-compose.yml version: 3 services:portainer:container_name: portainer#imag…

Docker上部署LPG(loki+promtail+grafana)踩坑复盘

Docker上部署LPG&#xff08;lokipromtailgrafana&#xff09;踩坑复盘 声明网上配置部署踩坑 声明 参考掘金文章&#xff1a;https://juejin.cn/post/7008424451704356872 版本高的用docker compose命令&#xff0c;版本低的用docker-compose 按照文章描述&#xff0c;主要准备…