【动手学强化学习】part8-PPO(Proximal Policy Optimization)近端策略优化算法

news2025/3/16 5:50:23

阐述、总结【动手学强化学习】章节内容的学习情况,复现并理解代码。

文章目录

  • 一、算法背景
    • 1.1 算法目标
    • 1.2 存在问题
    • 1.3 解决方法
  • 二、PPO-截断算法
    • 2.1 必要说明
    • 2.2 伪代码
      • 算法流程简述
    • 2.3 算法代码
    • 2.4 运行结果
    • 2.5 算法流程说明
  • 三、疑问
  • 四、总结


一、算法背景

1.1 算法目标

给定“黑盒”环境,求解最优policy。

1.2 存在问题

之前介绍的基于策略的方法包括策略梯度算法和 Actor-Critic 算法。回顾一下基于策略的方法:参数化智能体的策略,并设计衡量策略好坏的目标函数,通过梯度上升的方法来最大化这个目标函数,使得策略最优。
这些方法虽然简单、直观,但在实际应用过程中会遇到训练不稳定的情况。这种算法有一个明显的缺点:当策略网络是深度模型时,沿着策略梯度更新参数,很有可能由于步长太长,策略突然显著变差,进而影响训练效果

1.3 解决方法

在更新时找到一块信任区域(trust region),在这个区域上更新策略时能够得到某种策略性能的安全性保证,这就是信任区域策略优化(trust region policy optimization,TRPO)算法的主要思想。(2015年提出
TRPO 算法在很多场景上的应用都很成功,但是我们也发现它的计算过程非常复杂,每一步更新的运算量非常大。因此,在目标函数中进行限制(引入一种更简单的剪切clip),以保证新的参数和旧的参数的差距不会太大,称之为PPO-截断。(2017年提出

  • PPO算法的核心思想:通过限制新旧策略之间的差异来稳定训练过程,使用裁剪的目标函数来替代传统的策略梯度目标,确保策略更新不会过大,从而提升训练的稳定性和效率

二、PPO-截断算法

  • 🌟算法类型
    环境依赖:❌model-based ✅model-free
    价值估计:✅non-incremental ❌incremental
    价值表征:❌tabular representation ✅function representation
    学习方式:✅on-policy ❌off-policy
    策略表征:❌value-based ✅policy-based

· non-incremental:(采用了repaly_buffer,采样一定数量的episodes才开始训练)
· function representation:(value_net(critic)网络用于学习价值函数)
· on-policy:(采样episode和优化的policy都是policy_net(actor))
· policy-based:(PPO算法是基于Actor-Critic 算法,本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习)

2.1 必要说明

  • PPO算法的优化目标是什么?

PPO算法是在TRPO算法基础上去优化的,因此二者算法优化目标一致:
max ⁡ θ E s ∼ ν π θ k E a ∼ π θ k ( ⋅ ∣ s ) [ π θ ( a ∣ s ) π θ k ( a ∣ s ) A π θ k ( s , a ) ] \max_\theta\quad\mathbb{E}_{s\sim\nu^{\pi_{\theta_k}}}\mathbb{E}_{a\sim\pi_{\theta_k}(\cdot|s)}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)}A^{\pi_{\theta_k}}(s,a)\right] θmaxEsνπθkEaπθk(s)[πθk(as)πθ(as)Aπθk(s,a)]
s . t . E s ∼ ν π θ k [ D K L ( π θ k ( ⋅ ∣ s ) , π θ ( ⋅ ∣ s ) ) ] ≤ δ \mathrm{s.t.}\quad\mathbb{E}_{s\sim\nu^{\pi_{\theta_k}}}[D_{KL}(\pi_{\theta_k}(\cdot|s),\pi_\theta(\cdot|s))]\leq\delta s.t.Esνπθk[DKL(πθk(s),πθ(s))]δ

将优化目标拆解理解一下

m a x θ max_\theta maxθ :最大化关于策略参数 θ 的目标函数。

E s ∼ ν π θ k \mathbb{E}_{s\sim\nu^{\pi_{\theta_k}}} Esνπθk :对状态 s 的期望,状态 s 从旧策略 π θ k \pi_{\theta_{k}} πθk 生成的状态分布 ν π θ k \nu^{\pi\theta_{k}} νπθk 中采样。

E a ∼ π θ k ( ⋅ ∣ s ) \mathbb{E}_{a\sim\pi_{\theta_{k}}(\cdot|s)} Eaπθk(s) :对动作 a 的期望,动作 a 从旧策略 π θ k \pi_{\theta_{k}} πθk 在状态 s 下的动作分布中采样。

π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} πθk(as)πθ(as) :新策略 π θ \pi_{\theta} πθ 和旧策略 π θ k \pi_{\theta_{k}} πθk 在状态 s 下选择动作 a 的概率比值。

A π θ k ( s , a ) A^{\pi_{\theta_k}}(s,a) Aπθk(s,a) :在旧策略 π θ k \pi_{\theta_{k}} πθk 下,状态 s 和动作 a 的优势函数。

将约束拆解理解一下

E s ∼ ν π θ k \mathbb{E}_{s\sim\nu^{\pi_{\theta_k}}} Esνπθk :对状态 s 的期望,状态 s 从旧策略 π θ k \pi_{\theta_{k}} πθk 生成的状态分布 ν π θ k \nu^{\pi\theta_{k}} νπθk 中采样。

D K L ( π θ k ( ⋅ ∣ s ) , π θ ( ⋅ ∣ s ) ) D_{KL}(\pi_{\theta_k}(\cdot|s),\pi_\theta(\cdot|s)) DKL(πθk(s),πθ(s)) :Kullback-Leibler(KL)散度,衡量旧策略 π θ k \pi_{\theta_{k}} πθk和新策略 π θ \pi_{\theta} πθ 在状态 s 下的动作分布之间的差异。

≤ δ \leq\delta δ :KL散度的期望值不超过一个预设的阈值 δ。通过限制策略更新的幅度,避免策略变化过大导致训练不稳定。

  • PPO-Clip算法做了什么改进?

PPO 的另一种形式 PPO-截断(PPO-Clip)更加直接,它在目标函数中进行限制,以保证新的参数和旧的参数的差距不会太大,即:

arg ⁡ max ⁡ θ E s ∼ ν π θ k E a ∼ π θ k ( ⋅ ∣ s ) [ min ⁡ ( π θ ( a ∣ s ) π θ k ( a ∣ s ) A π θ k ( s , a ) , c l i p ( π θ ( a ∣ s ) π θ k ( a ∣ s ) , 1 − ϵ , 1 + ϵ ) A π θ k ( s , a ) ) ] \arg\max_{\theta}\mathbb{E}_{s\sim\nu^{\pi_{\theta_k}}}\mathbb{E}_{a\sim\pi_{\theta_k}(\cdot|s)}\left[\min\left(\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)}A^{\pi_{\theta_k}}(s,a),\mathrm{clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)},1-\epsilon,1+\epsilon\right)A^{\pi_{\theta_k}}(s,a)\right)\right] argθmaxEsνπθkEaπθk(s)[min(πθk(as)πθ(as)Aπθk(s,a),clip(πθk(as)πθ(as),1ϵ,1+ϵ)Aπθk(s,a))]

逐个部分拆解理解一下

E s ∼ ν π θ k \mathbb{E}_{s\sim\nu^{\pi_{\theta_k}}} Esνπθk :对状态 s 的期望,状态 s 从旧策略 π θ k \pi_{\theta_{k}} πθk 生成的状态分布 ν π θ k \nu^{\pi\theta_{k}} νπθk 中采样。

E a ∼ π θ k ( ⋅ ∣ s ) \mathbb{E}_{a\sim\pi_{\theta_{k}}(\cdot|s)} Eaπθk(s) :对动作 a 的期望,动作 a 从旧策略 π θ k \pi_{\theta_{k}} πθk 在状态 s 下的动作分布中采样。

π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} πθk(as)πθ(as) :新策略 π θ \pi_{\theta} πθ 和旧策略 π θ k \pi_{\theta_{k}} πθk 在状态 s 下选择动作 a 的概率比值。

A π θ k ( s , a ) A^{\pi_{\theta_k}}(s,a) Aπθk(s,a) :在旧策略 π θ k \pi_{\theta_{k}} πθk 下,状态 s 和动作 a 的优势函数。

c l i p ( π θ ( a ∣ s ) π θ k ( a ∣ s ) , 1 − ϵ , 1 + ϵ ) {clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)},1-\epsilon,1+\epsilon\right) clip(πθk(as)πθ(as),1ϵ,1+ϵ) :将概率比值 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} πθk(as)πθ(as) 限制在区间 [1−ϵ,1+ϵ] 内,其中 ϵ 是一个超参数,表示截断的范围。

⑥ min(⋅,⋅) : 取两个值中的较小值,确保策略更新不会过大。

2.2 伪代码

初始化:
创建环境 env
初始化策略网络 π_θ (PolicyNet)
初始化价值网络 V_φ (ValueNet)
设置超参数: 策略学习率 α, 价值学习率 β, 折扣因子 γ, PPO截断范围 ε, 训练轮数 K 等
初始化策略优化器和价值优化器
采集轨迹:
对于每个回合:
初始化状态 s
当回合未结束时:
根据策略 π_θ 选择动作 a
在环境中执行动作 a, 获取下一个状态 s’, 奖励 r 和是否结束 done
将转移 (s, a, r, s’, done) 存入缓冲区
更新当前状态 s 为 s’
结束循环
结束循环
计算优势和目标:
对于缓冲区中的每个转移:
计算TD目标: TD_target = r + γ * V_φ(s’) * (1 - done)
计算TD误差: δ = TD_target - V_φ(s)
使用广义优势估计 (GAE) 计算优势函数 A(s, a)
结束循环
PPO更新:
对于 k in 1 到 K:
对于缓冲区中的每个批次:
计算概率比值: r = π_θ_new(a|s) / π_θ_old(a|s)
计算代理目标1: L^CLIP = min(r * A, clip(r, 1-ε, 1+ε) * A)
策略损失: L^actor = -mean(L^CLIP)
价值损失: L^critic = mean((V_φ(s) - TD_target)^2)
使用Adam优化器更新策略网络以最大化 L^CLIP
使用Adam优化器更新价值网络以最小化 L^critic
结束循环
结束循环
重复:
重复上述过程,直到收敛

算法流程简述

2.3 算法代码

import gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PPO:
    ''' PPO算法,采用截断方式 '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # 一条序列的数据用来训练轮数
        self.eps = eps  # PPO中截断范围的参数
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1,
                                                            actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps,
                                1 + self.eps) * advantage  # 截断
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
            critic_loss = torch.mean(
                F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 500
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
            epochs, eps, gamma, device)

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('PPO on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('PPO on {}'.format(env_name))
plt.show()

2.4 运行结果

Iteration 0: 100%|██████████| 50/50 [00:05<00:00, 9.51it/s, episode=50, return=182.300]
Iteration 1: 100%|██████████| 50/50 [00:06<00:00, 8.10it/s, episode=100, return=173.100]
Iteration 2: 100%|██████████| 50/50 [00:09<00:00, 5.30it/s, episode=150, return=200.000]
Iteration 3: 100%|██████████| 50/50 [00:07<00:00, 6.28it/s, episode=200, return=200.000]
Iteration 4: 100%|██████████| 50/50 [00:11<00:00, 4.36it/s, episode=250, return=200.000]
Iteration 5: 100%|██████████| 50/50 [00:09<00:00, 5.02it/s, episode=300, return=200.000]
Iteration 6: 100%|██████████| 50/50 [00:08<00:00, 6.23it/s, episode=350, return=200.000]
Iteration 7: 100%|██████████| 50/50 [00:07<00:00, 6.76it/s, episode=400, return=200.000]
Iteration 8: 100%|██████████| 50/50 [00:07<00:00, 6.70it/s, episode=450, return=200.000]
Iteration 9: 100%|██████████| 50/50 [00:07<00:00, 6.28it/s, episode=500, return=200.000]

在这里插入图片描述
在这里插入图片描述

2.5 算法流程说明

  • 初始化参数
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 500
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")
  • 环境设置
env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
            epochs, eps, gamma, device)
  • 网络定义
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

策略网络:policy_net(actor)设定为4-128-2的全连接层网络,输入为state=4维,输出为action概率=2维
y = s o f t m a x ( f c 2 ( r e l u ( f c 1 ( x ) ) ) ) , x = s t a t e s y=softmax\left(fc_2\left(relu\left(fc_1\left(x\right)\right)\right)\right),x=states y=softmax(fc2(relu(fc1(x)))),x=states价值网络:value_net(critic)设定为4-128-1的全连接层网络,输入维state=4维,输出为state value=1维
y = f c 2 ( r e l u ( f c 1 ( x ) ) ) , x = s t a t e s y=fc_2(relu(fc_1(x))),x=states y=fc2(relu(fc1(x))),x=states

  • 采样episode
return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)
...
def train_on_policy_agent(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    next_state, reward, done, _ = env.step(action)
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list
...
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

①获取初始state:state = env.reset()
②根据policy_net的输出采取action:agent.take_action(state)
③与环境交互,得到(s,a,r,s’,done):env.step(action)
④将样本添加至episode:transition_dict
⑤统计episode即时奖励累加值:episode_return += reward

  • 网络更新
	agent.update(transition_dict)
	...
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,
                                               td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1,
                                                            actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps,
                                1 + self.eps) * advantage  # 截断
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
            critic_loss = torch.mean(
                F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

这一段是算法的核心部分,逐行进行解释

① 数据转换

        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
  • 功能:将经验数据(状态、动作、奖励、下一状态、是否结束)转换为PyTorch张量,并移动到指定设备(CPU或GPU)。
  • 数学背景:这些数据用于计算TD目标、TD误差和优势函数。

计算优势函数 A π θ k ( s , a ) A^{\pi_{\theta_k}}(s,a) Aπθk(s,a)

  • 计算td_target:
td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)

β t = r t + 1 + γ ⋅ v ( s t + 1 ) ⋅ ( 1 − d o n e ) \beta_t=r_{t+1}+\gamma\cdot v(s_{t+1})\cdot(1-done) βt=rt+1+γv(st+1)(1done)

  • 计算td_error:
  td_delta = td_target - self.critic(states)

δ t = β t − v ( s t ) \delta_t=\beta_t-v(s_t) δt=βtv(st)

  • 计算优势函数:
advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
...
def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)

A t = ∑ l = 0 n − 1 ( γ λ ) l δ t + l A_t=\sum_{l=0}^{n-1}(\gamma\lambda)^l\delta_{t+l} At=l=0n1(γλ)lδt+l
γ 是折扣因子。
λ 是GAE的缩放因子,控制了优势函数的时间范围。
δ t + l \delta_{t+l} δt+l是第t+l步的TD误差。
GAE结合了蒙特卡洛方法和TD方法的优点,提供了更稳定的优势估计。

计算截断后的目标函数 c l i p ( π θ ( a ∣ s ) π θ k ( a ∣ s ) , 1 − ϵ , 1 + ϵ ) {clip}\left(\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)},1-\epsilon,1+\epsilon\right) clip(πθk(as)πθ(as),1ϵ,1+ϵ)

  • 计算新策略对应旧策略的比值 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} πθk(as)πθ(as)
    old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
    log_probs = torch.log(self.actor(states).gather(1, actions))
    ratio = torch.exp(log_probs - old_log_probs)

r t = π n e w ( a t ∣ s t ) π o l d ( a t ∣ s t ) = exp ⁡ ( log ⁡ π n e w ( a t ∣ s t ) − log ⁡ π o l d ( a t ∣ s t ) ) r_t=\frac{\pi_{new}(a_t|s_t)}{\pi_{old}(a_t|s_t)}=\exp(\log\pi_{new}(a_t|s_t)-\log\pi_{old}(a_t|s_t)) rt=πold(atst)πnew(atst)=exp(logπnew(atst)logπold(atst))

  • 计算未截断的目标函数
  surr1 = ratio * advantage

L 1 C l i p = r t ⋅ A t = π n e w ( a t ∣ s t ) π o l d ( a t ∣ s t ) ⋅ A t L_1^{Clip}=r_t\cdot A_t=\frac{\pi_{new}(a_t|s_t)}{\pi_{old}(a_t|s_t)}\cdot A_t L1Clip=rtAt=πold(atst)πnew(atst)At

  • 计算截断后的目标函数
  surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage 

L 2 C l i p = c l i p ( r t , 1 − ε , 1 + ε ) ⋅ A t = c l i p ( π n e w ( a t ∣ s t ) π o l d ( a t ∣ s t ) , 1 − ε , 1 + ε ) ⋅ A t L_2^{Clip}=clip(r_t,1-\varepsilon,1+\varepsilon)\cdot A_t=clip(\frac{\pi_{new}(a_t|s_t)}{\pi_{old}(a_t|s_t)},1-\varepsilon,1+\varepsilon)\cdot A_t L2Clip=clip(rt,1ε,1+ε)At=clip(πold(atst)πnew(atst),1ε,1+ε)At

④ 更新策略+价值网络

  • 设置策略网络损失函数
 actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数

L C L I P = − E [ m i n ( L 1 C L I P , L 2 C L I P ) ] = − E [ m i n ( π n e w ( a t ∣ s t ) π o l d ( a t ∣ s t ) ⋅ A t , c l i p ( π n e w ( a t ∣ s t ) π o l d ( a t ∣ s t ) , 1 − ε , 1 + ε ) ⋅ A t ) ] L^{CLIP}=-\mathbb{E}[min(L_1^{CLIP},L_2^{CLIP})]=-\mathbb{E}\left[min\left(\frac{\pi_{new}(a_t|s_t)}{\pi_{old}(a_t|s_t)}\cdot A_t,clip(\frac{\pi_{new}(a_t|s_t)}{\pi_{old}(a_t|s_t)},1-\varepsilon,1+\varepsilon)\cdot A_t\right)\right] LCLIP=E[min(L1CLIP,L2CLIP)]=E[min(πold(atst)πnew(atst)At,clip(πold(atst)πnew(atst),1ε,1+ε)At)]

负号表示我们要最小化损失,即最大化目标函数。
这里损失函数的设置与PPO-clip的优化目标一致!

  • 设置价值网络损失函数
 critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

L V F = E [ M S E ( ( V ( s t ) − T D T a r g e t ) 2 ) ] L^{VF}=\mathbb{E}[MSE((V(s_t)-TD_Target)^2)] LVF=E[MSE((V(st)TDTarget)2)]

  • 清零梯度
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
  • 反向传播
        actor_loss.backward()
        critic_loss.backward()
  • 更新网络参数
        self.actor_optimizer.step()
        self.critic_optimizer.step()

三、疑问

暂无

四、总结

  • PPO算法是基于TRPO算法的改进,TRPO算法的数学原理比较复杂,以后有机会深入学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2315824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

游戏引擎学习第159天

回顾与今天的计划 我们在完成一款游戏的制作。这个游戏没有使用任何引擎或新库&#xff0c;而是从零开始编写的完整游戏代码库&#xff0c;您可以自行编译它&#xff0c;并且它是一个完整的游戏。更特别的是&#xff0c;这个游戏甚至没有使用显卡&#xff0c;所有的渲染工作都…

内网攻防——红日靶场(一)

在学习内网的过程中有着诸多不了解的内容。希望能借下面的靶场来步入内网的大门。 一、准备阶段 首先准备好我们的虚拟机 之前有学过关于&#xff1a;工作组、域、DC的概念。 了解一下此时的网络拓扑图 1.设置网络VMnet1和Vmnet8 将VMnet1作为内网&#xff0c;VMnet8作为外…

协议-LoRa-Lorawan

是什么? LoRa是低功耗广域网通信技术中的一种,是Semtech公司专有的一种基于扩频技术的超远距离无线传输技术。LoRaWAN是为LoRa远距离通信网络设计的一套通讯协议和系统架构。它是一种媒体访问控制(MAC)层协议。LoRa = PHY Layer LoRaWAN = MAC Layer功耗最低,传输最远 ![ …

redis主从搭建

1. 哨兵 1.1 ⼈⼯恢复主节点故障 Redis 的主从复制模式下&#xff0c;⼀旦主节点由于故障不能提供服务&#xff0c;需要⼈⼯进⾏主从切换&#xff0c;同时⼤量 的客⼾端需要被通知切换到新的主节点上&#xff0c;对于上了⼀定规模的应⽤来说&#xff0c;这种⽅案是⽆法接受的&…

Linux中Gdb调试工具常用指令大全

1.gdb的安装 如果你是root用户直接用指令 &#xff1a;yum install gdb &#xff1b;如果你是普通用户用指令&#xff1a;sudo yum install gdb&#xff1b; 2.gdb调试前可以对你的makefile文件进行编写&#xff1a; 下面展示为11.c文件编写的makefile文件&#xff1a; code…

操作系统-八股

进程基础&#xff1a; 进程定义&#xff1a;运行中的程序&#xff0c;有独立的内存空间和地址&#xff0c;是系统进行资源调度和分配的基本单位。 并发&#xff0c;并行 并发就是单核上面轮询&#xff0c;并行就是同时执行&#xff08;多核&#xff09;&#xff1b; 进程上下…

ICLR2025 | SLMRec: 重新思考大语言模型在推荐系统中的价值

note 问题背景&#xff1a;序列推荐&#xff08;SR&#xff09;任务旨在预测用户可能的下一个交互项目。近年来&#xff0c;大型语言模型&#xff08;LLMs&#xff09;在SR系统中表现出色&#xff0c;但它们巨大的规模使得在实际平台中应用变得低效和不切实际。 研究动机&…

71.HarmonyOS NEXT PicturePreviewImage组件深度剖析:从架构设计到核心代码实现

温馨提示&#xff1a;本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦&#xff01; HarmonyOS NEXT PicturePreviewImage组件深度剖析&#xff1a;从架构设计到核心代码实现 (一) 文章目录 HarmonyOS NEXT PicturePreviewImage组件深…

简单实现京东登录页面

Entry Component struct Index {State message: string ;build() { Column(){//顶部区域Row(){Image($r(app.media.jd_cancel)).width(20).height(20)Text(帮助)}.width(100%).justifyContent(FlexAlign.SpaceBetween)//logo图标Image($r(app.media.jd_logo)).width(250).heig…

9.贪心算法

简单贪心 1.P10452 货仓选址 - 洛谷 #include<iostream> #include<algorithm> using namespace std;typedef long long LL; const int N 1e510; LL a[N]; LL n;int main() {cin>>n;for(int i 1;i < n;i)cin>>a[i];sort(a1,a1n);//排序 LL sum 0…

大模型训练全流程深度解析

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。https://www.captainbed.cn/north 文章目录 1. 大模型训练概览1.1 训练流程总览1.2 关键技术指标 2. 数据准备2.1 数据收集与清洗2.2 数据…

每日一题---单词搜索(深搜)

单词搜索 给出一个二维字符数组和一个单词&#xff0c;判断单词是否在数组中出现&#xff0c; 单词由相邻单元格的字母连接而成&#xff0c;相邻单元指的是上下左右相邻。同一单元格的字母不能多次使用。 数据范围&#xff1a; 0 < 行长度 < 100 0 < 列长度 <…

插入排序c++

插入排序的时间复杂度为O&#xff08;N^2&#xff09;&#xff0c;和冒泡排序的时间复杂度相同&#xff0c;但是在某些情况下插入排序会更优。 插入排序的原理是&#xff1a;第1次在0~0范围内排序&#xff0c;第2次在0~1范围内排序&#xff0c;第3次在0~2范围内排序……相当于…

Swagger 从 .NET 9 中删除:有哪些替代方案

微软已经放弃了对 .NET 9 中 Swagger UI 包 Swashbuckle 的支持。他们声称该项目“不再由社区所有者积极维护”并且“问题尚未得到解决”。 这意味着当您使用 .NET 9 模板创建 Web API 时&#xff0c;您将不再拥有 UI 来测试您的 API 端点。 我们将调查是否可以在 .NET 9 中使用…

嵌入式八股ARM篇

前言 ARM篇主要介绍一下寄存器和中断机制,至于汇编这一块…还请大家感兴趣自行学习 1.寄存器 R0 - R3 R4 - R11 寄存器 R0 - R3一般用作函数传参 R4 - R11用来保存程序运算的中间结果或函数的局部变量 在函数调用过程中 注意在发生异常的时候 cortex-M0架构会自动将R0-R3压入…

使用DeepSeek和墨刀AI,写PRD文档、画原型图的思路、过程及方法

使用DeepSeek和墨刀AI&#xff0c;写PRD文档、画原型图的思路、过程及方法 现在PRD文档要如何写更高效、更清晰、更完整&#xff1f; 还是按以前的思路写PRD&#xff0c;就还是以前的样子。 现在AI这么强大&#xff0c;产品经理如何使用DeepSeek写PRD文档&#xff0c;产品经…

【VUE2】第五期——VueCli创建项目、Vuex多组件共享数据、json-server——模拟服务端api

黑马程序员视频地址&#xff1a;091-vuex的基本认知_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1HV4y1a7n4?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes&p91 目录 1 VueCli 自定义创建项目 2 Eslint代码规范 2.1 规…

rpmlib(SetVersions) is needed by can-uilts-v2019.00.0-alt1.aarch64

在我在Linux中安装离线CAN工具时&#xff0c;出现了一个问题&#xff0c; rootwanghuo:~# rpm -ivh can-uilts-v2019.00.0-alt1.aarch64.rpm error: Failed dependencies:rpmlib(SetVersions) is needed by can-uilts-v2019.00.0-alt1.aarch64 意思是尝试安装 can-uilts-v20…

CNN 稠密任务经典结构

FCN UNet FPN FCNUNETFPNpadding无&#xff08;逐渐变小&#xff09; 有&#xff08;左右对称&#xff09;上采样 双线性双线性 最近邻跳跃链接 相加 Cropcat 1x1卷积相加 三个网络差不多&#xff0c;UNet名字最直观&#xff0c;后续流传…

算法刷题整理合集(二)

本篇博客旨在记录自已的算法刷题练习成长&#xff0c;里面注有详细的代码注释以及和个人的思路想法&#xff0c;希望可以给同道之人些许帮助。本人也是算法小白&#xff0c;水平有限&#xff0c;如果文章中有什么错误或遗漏之处&#xff0c;望各位可以在评论区指正出来&#xf…