IEEE TAI 2024
paper
1 Introduction
一篇offline to online 的文章,有效解决迁移过程出现的performance drop。所提出的O2AC算法首先在离线阶段添加一项BC惩罚项,用于限制策略靠近专家策略;而在在线微调阶段,通过动态调整BC的权重,缓解performance drop。
2 Method
2.1 offline
离线阶段,采用BC结合确定性策略优化方法。最大化下列损失函数:
J
o
f
f
i
n
e
(
θ
)
=
E
(
s
,
a
)
∼
B
[
ζ
Q
ϕ
(
s
,
π
θ
(
s
)
)
−
∥
π
θ
(
s
)
−
a
∥
2
]
J_{\mathrm{offine}}(\boldsymbol{\theta})=\mathbb{E}_{(\boldsymbol{s},\boldsymbol{a})\sim\mathcal{B}}\left[\zeta Q_{\boldsymbol{\phi}}(\boldsymbol{s},\pi_{\boldsymbol{\theta}}(\boldsymbol{s}))-\left\|\pi_{\boldsymbol{\theta}}(\boldsymbol{s})-\boldsymbol{a}\right\|^2\right]
Joffine(θ)=E(s,a)∼B[ζQϕ(s,πθ(s))−∥πθ(s)−a∥2]
其中,
ζ
\zeta
ζ用于平衡BC以及一般policy iteration,其数值如下:
ζ
=
α
1
m
∑
(
s
i
,
a
i
)
∈
B
‾
∣
Q
(
s
i
,
a
i
)
∣
\zeta=\frac{\alpha}{\frac1m\sum_{(\boldsymbol{s}_i,\boldsymbol{a}_i)\in\overline{\mathcal{B}}}|Q(\boldsymbol{s}_i,\boldsymbol{a}_i)|}
ζ=m1∑(si,ai)∈B∣Q(si,ai)∣α
其中
B
‾
\overline{\mathcal{B}}
B表示从Buffer中采样地mini-batch, size为m
2.2 online
在线微调阶段,对确定性策略优化的损失函数表示如下
J
o
n
l
i
n
e
(
θ
)
=
E
(
s
,
a
)
∼
B
[
ζ
Q
ϕ
(
s
,
π
θ
(
s
)
)
−
λ
∥
π
θ
(
s
)
−
a
∥
2
]
J_{\mathrm{online}}(\boldsymbol{\theta})=\mathbb{E}_{(\boldsymbol{s},\boldsymbol{a})\sim\mathcal{B}}\left[\zeta Q_{\boldsymbol{\phi}}(\boldsymbol{s},\pi_{\boldsymbol{\theta}}(\boldsymbol{s}))-\lambda\left\|\pi_{\boldsymbol{\theta}}(\boldsymbol{s})-\boldsymbol{a}\right\|^2\right]
Jonline(θ)=E(s,a)∼B[ζQϕ(s,πθ(s))−λ∥πθ(s)−a∥2]
相较于offline,损失函数增加对BC权重因子
λ
\lambda
λ。该数值是动态减少的,实验设置为每5k steps, 减少10%。对Q价值的更新则是类似于TD3,使用两个target网络以及延时更新。
L
(
ϕ
)
=
E
(
s
,
a
)
∼
B
[
(
y
ˉ
−
Q
ϕ
(
s
,
a
)
)
2
]
where
y
ˉ
=
r
+
min
i
=
1
,
2
Q
ϕ
i
ˉ
(
s
,
′
a
′
∼
π
θ
ˉ
)
.
\begin{aligned}L(\phi)&=\mathbb{E}_{(\boldsymbol{s},\boldsymbol{a})\sim\mathcal{B}}\left[\left(\bar{y}-Q_{\boldsymbol{\phi}}(\boldsymbol{s},\boldsymbol{a})\right)^2\right]\\\\\text{where }\bar{y}&=r+\min_{i=1,2}Q_{\bar{\boldsymbol{\phi}_i}}(\boldsymbol{s},'\boldsymbol{a}'\sim\pi_{\bar{\boldsymbol{\theta}}}).\end{aligned}
L(ϕ)where yˉ=E(s,a)∼B[(yˉ−Qϕ(s,a))2]=r+i=1,2minQϕiˉ(s,′a′∼πθˉ).
伪代码如下:
Summary
有个疑问,online阶段对策略进行更新时,采样的数据(s,a)是来自replaybuffer B \mathcal{B} B。 B \mathcal{B} B包含在线阶段真实交互数据以及离线数据。如果(s,a)是OOD或者质量差数据,那么此时BC项应该尽可能地不要发挥作用。简单的调整 λ \lambda λ恐怕效果不够。可以探索添在BC项再加一个指示函数自适应地判断,“异常数据”直接截断为0.