【Actor-Critic】演员评论家模型

news2024/10/5 12:44:29

本博客代码部分参考了《动手学强化学习》

基于值函数的方法(DQN)和基于策略的方法(REINFORCE),其中基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数。那么,一个很自然的问题是,有没有什么方法既学习价值函数,又学习策略函数呢?答案就是 Actor-Critic。Actor-Critic 是囊括一系列算法的整体架构,目前很多高效的前沿算法都属于 Actor-Critic 算法。需要明确的是,Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。


文章目录

  • Actor-Critic 算法
    • Actor 模型
    • Critic 模型
    • 状态动作预测
    • 更新梯度
  • 一些概念
    • 模型的比较
    • TD算法,贝尔曼方程
    • policy-base 与 value-base
    • on-policy 与 off-policy
    • on-line 与 off-line
    • model-base 与 model-free
    • Q-learning 与 Sarsa
    • 区别


Actor-Critic 算法

class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # 策略网络
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)  # 价值网络
        # 策略网络优化器
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)  # 价值网络优化器
        self.gamma = gamma
        self.device = device

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)

        # 时序差分目标
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                       dones)
        td_delta = td_target - self.critic(states)  # 时序差分误差
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # 均方误差损失函数
        critic_loss = torch.mean(
            F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()  # 计算策略网络的梯度
        critic_loss.backward()  # 计算价值网络的梯度
        self.actor_optimizer.step()  # 更新策略网络的参数
        self.critic_optimizer.step()  # 更新价值网络的参数

Actor 模型

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

Actor模型最后输出的是根据State返回的动作概率,从F.softmax(self.fc2(x), dim=1)可以看出

Critic 模型

class Valuenet(nn.Module):
	def __init__(self, state_dim, hidden_dim, action_dim):
		super(Valuenet, self).__init__()
		self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
		self.fc2 = torch.nn.Linear(hidden_dim, 1)
	
	def forward(self, x):
		x = F.relu(self.fc1(x))
		return self.fc2(x)

Critic模型最后输出的是对当前状态预测的最大值,与之前的DQN有些许不同的是,Critic模型输出的维度为1维度,也就是说Critic只要关注当前得分预测,不用参与实际的动作选择。

状态动作预测

def take_action(self, state):
    state = torch.tensor([state], dtype=torch.float).to(self.device)
    probs = self.actor(state)
    action_dist = torch.distributions.Categorical(probs)
    action = action_dist.sample()
    return action.item()

可以看到与环境交互的时候,只有Actor参与了预测,Critic则是在计算误差的时候,才有调用。

更新梯度

Reinforce 模型的损失计算推导可以参考该篇博客
Acotr模型的损失计算推导可以参考该博客

REINFORCE 通过蒙特卡洛采样的方法(也就是从最后一步反向累加奖励)对策略梯度的估计是无偏的,但是方差非常大。通过引入基线函数(baseline function)来减小方差,也就是Actor-Critic 中的self.critic(states)

REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,这同时也要求任务具有有限的步数,而 Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。

 # 时序差分目标
 td_target = rewards + self.gamma * self.critic(next_states) * (1 -
                                                                dones)
 td_delta = td_target - self.critic(states)  # 时序差分误差
 log_probs = torch.log(self.actor(states).gather(1, actions))
 actor_loss = torch.mean(-log_probs * td_delta.detach())
 # 均方误差损失函数
 critic_loss = torch.mean(
     F.mse_loss(self.critic(states), td_target.detach()))
 self.actor_optimizer.zero_grad()
 self.critic_optimizer.zero_grad()
 actor_loss.backward()  # 计算策略网络的梯度
 critic_loss.backward()  # 计算价值网络的梯度
 self.actor_optimizer.step()  # 更新策略网络的参数
 self.critic_optimizer.step()  # 更新价值网络的参数

Acto 算法通过计算
td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)作为 TD目标
td_delta = td_target - self.critic(states) 计算时序差分误差
actor_loss = torch.mean(-log_probs * td_delta.detach()) 计算Actor 误差,其中的- 是因为根据公式推导,我们需要让目标函数趋向于期望价值,所以需要梯度上升来寻找最大值,在公式推导中没有-,但是优化器是用梯度下降进行参数更新,所以使用-做方向的调整。

Critic算法通过计算
critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach())) 作为Critic的误差

一些概念

模型的比较

  • Q-learning、DQN 及 DQN 改进算法都是基于价值(value-based), 他们通过选择最大价值动作来与环境交互

  • Actor-Critic 结合了两者的特点,Reinforce 作为 Actor 部分, DQN 作为 Critic 部分,从而结合了两者的有点

  • Reinforce、Actor-Critic 通过 SARSA 样本数据进行训练,所以他们是 on-policy基于策略,也就是学习一个策略,从策略中进行动作概率分布抽样

  • REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新, Actor-Critic 算法则可以在每一步之后都进行更新,并且不对任务的步数做限制。

  • TRPO 是在 Actor-Critic 的基础上加入了更新幅度限,也就是制信任区域(trust region),从而避免模型效果的震荡。

  • PPO 是 TRPO 的改良版, 基于 TRPO 的思想,但是 PPO算法实现更加简单,没有TRPO 的计算那么复杂和远算量那么大。

  • PPO 有两种形式,一是 PPO-惩罚,二是 PPO-截断,

  • PPO-截断总是比 PPO-惩罚表现得更好, 大量实验表明。

  • REINFORCE、Actor-Critic 以及两个改进算法——TRPO 和 PPO, 这类算法有一个共同的特点:它们都是在线策略算法,这意味着它们的样本效率(sample efficiency)比较低。

  • TRPO(trust region policy optimization,TRPO)。当策略网络是深度模型时,沿着策略梯度更新参数,很有可能由于步长太长,策略突然显著变差,进而影响训练效果。针对这个问题,考虑在更新时找到一块信任区域(trust region),在这个区域上更新策略时能够得到某种策略性能的安全性保证,这就是信任区域策略优化(trust region policy optimization,TRPO)算法的主要思想。TRPO 算法在 2015 年被提出,它在理论上能够保证策略学习的性能单调性,并在实际应用中取得了比策略梯度算法更好的效果。

TD算法,贝尔曼方程

policy-base 与 value-base

  1. Policy-Based RL(基于概率)

通过感官分析所处的环境,直接输出下一步要采取的各种动作的概率,然后根据概率采取行动,所以每种动作都有可能被选中,只是可能性不同。如,Policy Gradients等。

  1. Value-Based RL(基于价值)

输出所有动作的价值,根据最高价值来选择动作。如,Q learning、Sarsa等。(对于不连续的动作,这两种方法都可行,但如果是连续的动作基于价值的方法是不能用的,我们只能用一个概率分布在连续动作中选择特定的动作)。

  1. Actor-Critic

结合这两种方法建立一种Actor-Critic的方法,基于概率会给出做出的动作,基于价值会对做出的动作的价值二者的综合。

on-policy 与 off-policy

  1. 无论是在线策略(on-policy)算法还是离线策略(off-policy)算法,都有一个共同点:智能体在训练过程中可以不断和环境交互,得到新的反馈数据。
  2. 二者的区别主要在于在线策略算法会直接使用这些反馈数据,而离线策略算法会先将数据存入经验回放池中,需要时再采样

On-Policy Learning:
在On-Policy学习中,代理学习并改进当前正在执行的策略。它会根据当前策略收集的数据进行学习,因此策略的改进可能会受到当前策略的限制。On-Policy方法通常用于需要连续决策和探索的情境。
Off-Policy Learning:
在Off-Policy学习中,代理学习一个策略,但同时也可以使用来自不同策略的经验数据。这使得代理可以更灵活地学习,并且可以更有效地重用以前的经验。Off-Policy方法通常更具有样本效率,因为它们可以更好地利用之前的经验。

on-line 与 off-line

  1. online RL(在线强化学习) 学习过程中,智能体需要和真实环境进行交互(边玩边学习)。并且在线强化学习可分为on-policy RL和off-policy RL。on-policy(在线策略学习)采用的是当前策略搜集的数据训练模型,每条数据仅使用一次,如,Actor-Critic、Sarsa等。off-policy(离线策略学习)训练采用的数据不需要是当前策略搜集的,如Q learning。

  2. offline RL(离线强化学习) 学习过程中,不与真实环境进行交互,只从过往经验(dataset)中直接学习,而dataset是采用别的策略收集的数据,并且采集数据的策略并不是近似最优策略。。

Online Learning:在线学习是指代理在与环境互动的同时学习。它不断地采集经验,并根据当前的经验进行学习和决策。Online学习适用于需要实时决策的情境,但它也可能导致学习过程中的探索成本。
Offline Learning:离线学习是指代理在与环境互动之前收集一些经验数据,然后在离线状态下进行学习。这样可以避免在线学习的探索成本,但需要足够多的先前数据来训练模型。Offline学习在某些情况下更稳定,但可能无法应对快速变化的环境。

model-base 与 model-free

Model-Based Learning: 在Model-Based学习中,代理(学习者)试图建立一个关于环境的模型,该模型可以预测状态转移和奖励。代理使用这个模型来规划和执行动作,以最大化预期奖励。Model-Based方法通常需要较多的计算资源来构建和维护环境模型。
Model-Free Learning: 在Model-Free学习中,代理不试图建立环境模型,而是直接学习策略或价值函数,以根据观察到的经验来进行动作选择。Model-Free方法通常更适用于复杂或不确定的环境,因为它们不需要对环境进行精确的建模。基于值函数的方法 DQN、基于策略的方法 REINFORCE 以及两者结合的方法 Actor-Critic。

Q-learning 与 Sarsa

Q learning 与 Sarsa 由于TD算法实际内容上存在的差异,导致了其对样本数据利用的不同。Q learning 能够使用Exprience replay 而 Sarsa 不能使用Exprience replay
参考

Q-learning的目标是求解“真正”的 Q ∗ ( s , a ) Q^{*}(s,a) Q(s,a) ,而Sarsa的目标则是求解 Q π ( s , a ) Q_\pi(s,a) Qπ(s,a)
在Q-learning中,我们一般会采用experience replay技术,即准备一个数据库并不断把Agent新产生的 ( s , a , r , s ′ ) (s,a,r,s^{\prime}) (s,a,r,s) 数据集存入数据库中。我们每次会从数据库中随机抽取一个batch的数据集用以训练,这意味着每次训练时我们用到的数据集可能是Agent在很久以前产生的。但是,无论我们用到的数据是Agent在训练中的哪一个阶段产生的,数据都是服从环境分布的,所以它们当然都可以被用以训练。

在Sarsa中,情况则与Q-learning很不一样。对于 ( s , a , r , s ′ , a ′ ) (s,a,r,s',a') (s,a,r,s,a) 的训练数据集,我们不但要求 ( r , s ′ ) (r,s^{\prime}) (r,s) 应该服从环境分布,也要求 a ′ a^{\prime} a 必须服从 π \pi π 关于 s ′ s^{\prime} s 的条件分布。在训练中,Q表的内容会不断被改变,所以Agent产生数据的策略 也会不断被改变。这意味在Agent过去产生的 ( s , a , r , s ′ , a ′ ) (s,a,r,s',a') (s,a,r,s,a) 中, ( s ′ , a ′ ) (s^{\prime},a^{\prime}) (s,a) 可能不服从现在策略 π \pi π 对应的条件分布,因此Agent在过去产生的数据就不能用以现在的训练。

由于上述的原因,我们不能在Sarsa中采用experience replay。在训练中,设当前Agent产生数据的策略为 π \pi π
。我们可以一次性用Agent产生大量服从环境及 π \pi π 分布的数据,并用这些数据来进行训练。而训练过后,Q表的内容发生了变化,这意味着Agent产生数据的策略变成了与 π \pi π 不同的 π ′ \pi^{\prime} π 。这时,刚才那些服从环境与 π \pi π 分布的 ( s , a , r , s ′ , a ′ ) (s,a,r,s',a') (s,a,r,s,a) 数据就变得不再有价值,我们只能将其丢弃。接下来,我们就要让Agent用当前产生数据的策略 继续产生大量的数据,并进行下一步的训练。
我们将Q-learning中那种experience replay的训练方式称作off-policy的,而将Sarsa中这种“边学边玩”的训练方式称作on-policy的。不难看出,二者的核心差别就在于off-policy方法中我们只是要求数据服从于环境分布,而on-policy方法中我们却要求数据要服从环境与当前策略的分布。如果只要求数据服从环境分布,由于过去产生的数据都是服从环境分布的,所以我们当前可以将其储存下来多次利用;如果要求数据服从环境与当前策略分布,由于训练中我们的策略可能会一直发生改变,用过去策略产生的数据很可能无法用于当前的训练,所以我们的数据就不能被多次利用。

区别

  1. off-line 在现实生活中的许多场景下,让尚未学习好的智能体和环境交互可能会导致危险发生,或是造成巨大损失。例如,在训练自动驾驶的规控智能体时,如果让智能体从零开始和真实环境进行交互,那么在训练的最初阶段,它操控的汽车无疑会横冲直撞,造成各种事故。再例如,在推荐系统中,用户的反馈往往比较滞后,统计智能体策略的回报需要很长时间。而如果策略存在问题,早期的用户体验不佳,就会导致用户流失等后果。因此,离线强化学习(offline reinforcement learning)的目标是,在智能体不和环境交互的情况下,仅从已经收集好的确定的数据集中,通过强化学习算法得到比较好的策略。
    在这里插入图片描述

  2. On-Policy和Off-Policy 是关于策略评估和改进的方式。二者的区别主要在于在线策略算法会直接使用这些反馈数据,而离线策略算法会先将数据存入经验回放池中,需要时再采样。

  3. Model-Based和Model-Free 是关于如何对环境建模和学习的方法。Model-Based建立模型,Model-Free直接学习策略或价值函数。


https://blog.csdn.net/niulinbiao/article/details/134081800
https://blog.csdn.net/qq_43585760/article/details/133557729
https://wjrsbu.smartapps.cn/pages/article/index?id=602217717
https://zhuanlan.zhihu.com/p/166412379
https://hrl.boyuai.com/chapter/2/actor-critic%E7%AE%97%E6%B3%95

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1510973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

八数码题解

179. 八数码 - AcWing题库 首先要明确八数码问题的小结论,当原始序列中逆序对数列为奇数时一定无解,反之一定有解。 解法一:BFSA* 首先思考用纯BFS解决这个问题。 大致的框架就是: 队列q,状态数组dist,…

(ConvE)Convolutional 2D Knowledge Graph Embeddings

论文地址:https://arxiv.org/pdf/1707.01476.pdf 一、研究领域 知识图谱受限于知识构建方式的不足,常常伴随着不完备的特点,因此需要知识推理和补齐技术,来根据已有的事实来合理推断出新的事实以补充知识图谱,使其更完备。链路预测任务是知识推理和补齐技术的主要手段,用…

Python 导入Excel三维坐标数据 生成三维曲面地形图(面) 4-1、线条平滑曲面(原始图形)

环境和包: 环境 python:python-3.12.0-amd64包: matplotlib 3.8.2 pandas 2.1.4 openpyxl 3.1.2 scipy 1.12.0 代码: import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from scipy.interpolate import griddata fro…

数据分析实战-Python实现博客评论数据的情感分析

数据分析实战-Python实现博客评论数据的情感分析 学习建议SnowNLP基础什么是SnowNLP?SnowNLP情感分析 SnowNLP使用SnowNLP安装情感分析中文分词关键词提取拼音、词性标准 SnowNLP实战-博客评论数据的情感分析数据准备数据获取数据分析 总结 学习建议 现在很多网站、…

SpringBoot整合阿里云文件上传OSS以及获取oss临时访问url

SpringBoot整合阿里云文件上传OSS 1. 引入相关依赖<!--阿里云 OSS依赖--><dependency><groupId>com.aliyun.oss</groupId><artifactId>aliyun-sdk-oss</artifactId><version>3.10.2</version></dependency><dependen…

【MySQL】超详细_数据库的约束_MySQL的详细查询

复习前面MySQL的基础操作&#xff0c;目的是让我们有印象&#xff01;&#xff01;在这篇文章中&#xff0c;我主要写的是数据库的约束和查询操作的详细、深入讲解&#xff01; 基础操作 &#xff08;复习->【MySQL】超详细-基础操作&#xff09; 插入 insert -> inser…

https代理相对socks5代理有什么优势?

随着互联网的快速发展&#xff0c;代理服务已成为许多人在访问敏感或地理位置受限的网站时所依赖的工具。其中&#xff0c;HTTPS代理和SOCKS5代理是两种最常用的代理服务类型。本文将探讨HTTPS代理相对SOCKS5代理的优势。 1、安全性 HTTPS代理使用SSL/TLS协议对客户端和代理服…

C++ 矩形类

思维导图&#xff1a; #include <iostream> using namespace std; class Rect { private:int width;int height; public:void init(int w,int h){widthw;heighth;}void set_w(int w){widthw;}void set_h(int h){heighth;}void show(){cout << "perimeter &qu…

基于51单片机的LED点阵显示屏设计

目录 摘要 II Abstract III 第一章 绪论 1 1.1 课题背景 1 1.2 选题意义 1 1.3 论文主要内容 1 第二章 方法论证对比 3 2.1 单片机编程语言 3 2.2 控制系统设计 3 2.3 显示方式 3 第三章 系统硬件设计 4 3.1 总体硬件设计 4 3.2 系统各硬件电路介绍 5 3.2.1 电源电路设计介绍 …

蓝牙系列七:开源蓝牙协议栈BTStack数据处理(Wireshark抓包分析)

继续蓝牙系列的研究。 在上篇博客&#xff0c;通过阅读BTStack的源码&#xff0c;大体了解了其框架&#xff0c;对于任何一个BTStack的应用程序都有一个main函数&#xff0c;这个main函数是统一的。这个main函数做了某些初始化之后&#xff0c;最终会调用到应用程序提供的btst…

prometheus 原理(架构,promql表达式,描点原理)

大家好&#xff0c;我是蓝胖子&#xff0c;提到监控指标&#xff0c;不得不说prometheus&#xff0c;今天这篇文章我会对prometheus 的架构设计&#xff0c;promql表达式原理和监控图表的绘图原理进行详细的解释。来让大家对prometheus的理解更加深刻。 架构设计 先来看看&am…

Docker容器化技术(使用Dockerfile制作镜像)

Docker中的镜像分层 Docker 支持通过扩展现有镜像&#xff0c;创建新的镜像。实际上&#xff0c;Docker Hub 中 99% 的镜像都是通过在 base 镜像中安装和配置需要的软件构建出来的。 1、Docker 镜像为什么分层 镜像分层最大的一个好处就是共享资源。 比如说有多个镜像都从相…

python 通过代理服务器 连接 huggingface下载模型,并运行 pipeline

想在Python 代码中运行时下载模型&#xff0c;启动代理服务器客户端后 1. 检查能否科学上网 $ curl -x socks5h://127.0.0.1:1080 https://www.example.com <!doctype html> <html> <head><title>Example Domain</title><meta charset"…

Python: 如何绘制核密度散点图和箱线图?

01 数据样式 这是数据样式&#xff1a; 要求&#xff08;我就懒得再复述一遍了&#xff0c;直接贴图&#xff09;&#xff1a; Note&#xff1a;数据中存在无效值NA&#xff08;包括后续的DEM&#xff09;&#xff0c;需要注意 02 提取DEM 这里我就使用gdal去提取一下DEM列…

./ 相对路径与node程序的启动目录有关

node:internal/fs/sync:78 return binding.openSync( ^ Error: ENOENT: no such file or directory, open D:\前端的学习之路\项目\codeHub\keys\private_key.pem at Object.open (node:internal/fs/sync:78:18) at Object.openSync (node:fs:565:…

Java后台面试相关知识点解析

文章目录 JavaJava中四种引用类型及使用场景集合HashMap源码及扩容策略HashMap死循环问题ConcurrentHashMap与HashtableConCurrentHashMap 1.8 相比 1.7判断单链表是否有环&#xff0c;并且找出环的入口 IO线程池线程池的几种创建方式判断线程是否可以回收线程池的7大核心参数线…

菜鸟学会Linux的方法

系统安装是初学者的门槛&#xff0c;系统安装完毕后&#xff0c; 很多初学者不知道该如何学习&#xff0c;不知道如何快速进阶&#xff0c; 下面作者总结了菜鸟学好Linux技能的大绝招&#xff1a; 初学者完成Linux系统分区及安装之后&#xff0c;需熟练掌握Linux系统管理必备命…

蓝桥省赛倒计时 35 天-bfs 和 dfs

#include <iostream> using namespace std; int t; int m,n; char mp[55][55];//不能写成 int 数组 bool vis[55][55]; int dx[ ]{1,0,-1,0},dy[ ]{0,1,0,-1}; int res;void dfs_1(int x,int y){vis[x][y] true;//陆地向四个方向拓展for(int i0;i<4;i){int nx xdx[i…

蓝桥杯练习系统(算法训练)ALGO-973 唯一的傻子

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 问题描述 腿铮找2255有点事&#xff0c;但2255太丑了&#xff0c;所以腿铮不知道他的长相。正愁不知道到如何找他的时候&#xff0c;…

基于React低代码平台开发:直击最新高效应用构建

&#x1f3e1;浩泽学编程&#xff1a;个人主页 &#x1f525; 推荐专栏&#xff1a;《深入浅出SpringBoot》《java对AI的调用开发》 《RabbitMQ》《Spring》《SpringMVC》《项目实战》 &#x1f6f8;学无止境&#xff0c;不骄不躁&#xff0c;知行合一 文章目录…