不说其他的解释,上来就看代码。建议先对PPO的整体流程有了解。
trl的版本为0.4.0,注:【新版的trl中代码更复杂,如果只是想读懂PPO具体怎么用trl实现的,0.4.0版本即可】
step1: rollout
ppo_trainer.generate()函数使用policy model生成rollout
step2:evaluate
使用reward model对step1产生的rollout进行evaluate,获得一个标量的score,这个score并不是rewards,step4计算得到的才是最终的rewards
step3: logprobs
从old policy model和ref model中获得rollout的logits, values等值,用于后续计算rewards。
对应的代码部分为:
step4: rewards
注意:这里产生的变量中,score变成了rewards。
PPO中,为了防止policy model过度偏离ref model,会在计算rewards过程中额外增加一项KL散度,
r
e
w
a
r
d
s
=
s
c
o
r
e
−
λ
K
L
(
π
θ
(
a
∣
s
)
∣
∣
π
θ
r
e
f
(
a
∣
s
)
)
rewards = score - \lambda KL(\pi_{\theta}(a|s)||\pi_{\theta_{ref}}(a|s))
rewards=score−λKL(πθ(a∣s)∣∣πθref(a∣s))
对应的代码部分为:
step5: train_minibatch
注意,这里的logprobs, vpreds, 与old_logprobs, old_values均是policy LM产生的,但是参数不一样。
在这里,产生logprobs, vpreds的policy LM的参数是会按照mini_batch_size进行不断更新的,所以每个mini_batch_size对于的new policy LM的参数是不一样的。而产生old_logprobs, old_values的old policy LM的参数对于每个mini_batch_size是不变的。
可以按照一般的训练神经网络的过程理解:产生old_logprobs, old_values的old policy LM的参数是按照epoch更新的,而产生logprobs, vpreds的new policy LM是按照step更新的。
对应的代码部分为:
step6: advantages
根据old_values, rewards,计算优势,在进一步计算出returns
对应的代码部分为(代码中的values为old_values):
step7: critic_loss
critic loss通常是通过均方误差(MSE)来计算。对于每一个状态,我们都有一个由critic网络预测的预期回报
v
p
r
e
d
s
vpreds
vpreds,以及一个真实的回报
r
e
t
u
r
n
s
returns
returns,critic_loss是二者的平方差。
对应的代码部分为:
step8: actor loss
actor loss是基于策略梯度的损失函数,用于优化policy。在ppo中,通常使用一种称为重要性采样(importance sampling)的技术来计算策略梯度。
m
a
x
i
m
i
z
e
θ
E
π
θ
′
[
m
i
n
(
r
t
(
θ
)
A
π
θ
o
l
d
(
s
,
a
)
,
c
l
i
p
(
r
t
(
θ
)
,
1
−
ϵ
,
1
+
ϵ
)
A
π
θ
o
l
d
(
s
,
a
)
)
]
maximize_{\theta} \ \ E_{\pi_{\theta^{'}}}[min( r_{t}(\theta) A^{\pi_{\theta_{old}}}(s,a),\ clip(r_{t}(\theta), 1-\epsilon, 1+\epsilon)A^{\pi_{\theta_{old}}}(s,a)\ )]
maximizeθ Eπθ′[min(rt(θ)Aπθold(s,a), clip(rt(θ),1−ϵ,1+ϵ)Aπθold(s,a) )]
其中,
r
t
(
θ
)
=
π
θ
(
a
∣
s
)
π
θ
o
l
d
(
a
∣
s
)
r_{t}(\theta) = {\pi_{\theta}(a|s) \over \pi_{\theta_{old}}(a|s)}
rt(θ)=πθold(a∣s)πθ(a∣s),这一项是新旧策略的比率,
A
π
θ
o
l
d
(
s
,
a
)
A^{\pi_{\theta_{old}}}(s,a)
Aπθold(s,a)是优势函数,clip是裁剪函数,将其裁剪到
[
1
−
ϵ
,
1
+
ϵ
]
[1-\epsilon,1+\epsilon]
[1−ϵ,1+ϵ]之间。这个损失函数的目标是,最大化新策略的期望回报,同时限制新旧策略之间的差异。