JoyRL Actor-Critic算法

news2024/12/27 11:29:10

策略梯度算法的缺点

这里策略梯度算法特指蒙特卡洛策略梯度算法,即 REINFORCE 算法。 相比于 DQN 之类的基于价值的算法,策略梯度算法有以下优点。

  • 适配连续动作空间。在将策略函数设计的时候我们已经展开过,这里不再赘述。
  • 适配随机策略。由于策略梯度算法是基于策略函数的,因此可以适配随机策略,而基于价值的算法则需要一个确定的策略。此外其计算出来的策略梯度是无偏的,而基于价值的算法则是有偏的。

 但同样的,策略梯度算法也有其缺点。

  • 采样效率低。由于使用的是蒙特卡洛估计,与基于价值算法的时序差分估计相比其采样速度必然是要慢很多的,这个问题在前面相关章节中也提到过。
  • 高方差。虽然跟基于价值的算法一样都会导致高方差,但是策略梯度算法通常是在估计梯度时蒙特卡洛采样引起的高方差,这样的方差甚至比基于价值的算法还要高。
  • 收敛性差。容易陷入局部最优,策略梯度方法并不保证全局最优解,因为它们可能会陷入局部最优点。策略空间可能非常复杂,存在多个局部最优点,因此算法可能会在局部最优点附近停滞。
  • 难以处理高维离散动作空间:对于离散动作空间,采样的效率可能会受到限制,因为对每个动作的采样都需要计算一次策略。当动作空间非常大时,这可能会导致计算成本的急剧增加。

结合了策略梯度和值函数的 Actor-Critic 算法则能同时兼顾两者的优点,并且甚至能缓解两种方法都很难解决的高方差问题。

Q:为什么各自都有高方差的问题,结合了之后反而缓解了这个问题呢?

A:策略梯度算法是因为直接对策略参数化,相当于既要利用策略去与环境交互采样,又要利用采样去估计策略梯度,而基于价值的算法也是需要与环境交互采样来估计值函数的,因此也会有高方差的问题。

 而结合之后呢,Actor 部分还是负责估计策略梯度和采样,但 Critic 即原来的值函数部分就不需要采样而只负责估计值函数了,并且由于它估计的值函数指的是策略函数的值,相当于带来了一个更稳定的估计,来指导 Actor 的更新,反而能够缓解策略梯度估计带来的方差。

Q Actor-Critic算法

如图 10.1 所示,我们通常将 Actor 和 Critic 分别用两个模块来表示,即图中的策略函数( Policy )和价值函数( Value Function )。Actor与环境交互采样,然后将采样的轨迹输入 Critic 网络,Critic 网络估计出当前状态-动作对的价值,然后再将这个价值作为 Actor 网络的梯度更新的依据,这也是所有 Actor-Critic 算法的基本通用架构

A2C与A3C算法

A2C

A3C

广义优势估计

由于优势函数通本质上来说还是使用蒙特卡洛估计,因此尽管减去了基线,有时候还是会产生高方差,从而导致训练过程不稳定

实战:A2C算法

定义模型

Critic 的输入是状态,输出则是一个维度的价值,而 Actor 输入的也会状态,但输出的是概率分布

class Critic(nn.Module):
    def __init__(self,state_dim):
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits_p = F.softmax(self.fc3(x), dim=1)
        return logits_p

这里由于是离散的动作空间,根据在策略梯度章节中设计的策略函数,我们使用了 softmax 函数来输出概率分布。另外,实践上来看,由于 Actor 和 Critic 的输入是一样的,因此我们可以将两个网络合并成一个网络,以便于加速训练。这有点类似于 Duelling DQN 算法中的做法

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.action_layer = nn.Linear(256, action_dim)
        self.value_layer = nn.Linear(256, 1)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits_p = F.softmax(self.action_layer(x), dim=1)
        value = self.value_layer(x)
        return logits_p, value

动作采样

与 DQN 算法不同等确定性策略不同,A2C 的动作输出不再是 Q 值最大对应的动作,而是从概率分布中采样动作,这意味着即使是很小的概率,也有可能被采样到,这样就能保证探索性

# Categorical分布函数,能直接从概率分布中采样动作
from torch.distributions import Categorical
class Agent:
    def __init__(self):
        self.model = ActorCritic(state_dim, action_dim)
    def sample_action(self,state):
        '''动作采样函数
        '''
        state = torch.tensor(state, device=self.device, dtype=torch.float32)
        logits_p, value = self.model(state)
        dist = Categorical(logits_p) 
        action = dist.sample() 
        return action

策略更新

我们首先需要计算出优势函数,一般先计算出回报,然后减去网络输出的值即可

class Agent:
    # 定义一个Agent类

    def _compute_returns(self, rewards, dones):
        # 计算回报
        returns = []  # 初始化一个回报列表
        discounted_sum = 0  # 初始化折扣累计和
        # 从后向前遍历奖励和是否结束的序列
        for reward, done in zip(reversed(rewards), reversed(dones)):
            # 如果游戏结束,则折扣累计和重置为0
            if done:
                discounted_sum = 0
            # 否则,将奖励加上折现因子gamma乘以之前的折扣累计和
            discounted_sum = reward + (self.gamma * discounted_sum)
            # 将计算出的折扣累计和添加到回报列表的开头
            returns.insert(0, discounted_sum)
        # 将回报列表转换为PyTorch张量,并移到Agent指定的设备上
        returns = torch.tensor(returns, device=self.device, dtype=torch.float32).unsqueeze(dim=1)
        # 对回报进行归一化处理
        returns = (returns - returns.mean()) / (returns.std() + 1e-5)  # 添加一个很小的数以避免除以零
        return returns

    def compute_advantage(self):
        '''计算优势函数
        '''
        # 从记忆库中随机抽取一批经验
        logits_p, states, rewards, dones = self.memory.sample()
        # 计算回报
        returns = self._compute_returns(rewards, dones)
        # 将状态转换为PyTorch张量,并移到Agent指定的设备上
        states = torch.tensor(states, device=self.device, dtype=torch.float32)
        # 前向传播模型以获得动作的概率和对数概率
        logits_p, values = self.model(states)
        # 计算优势,即回报与批评价值的差
        advantages = returns - values
        return advantages

这里我们使用了一个技巧,即将回报归一化,这样可以让优势函数的值域在 [−1,1] 之间,这样可以让优势函数更稳定,从而减少方差。计算优势之后就可以分别计算 Actor 和 Critic 的损失函数了

class Agent:
    def compute_loss(self):
        '''计算损失函数
        '''
        logits_p, states, rewards, dones = self.memory.sample()
        returns = self._compute_returns(rewards, dones)
        states = torch.tensor(states, device=self.device, dtype=torch.float32)
        logits_p, values = self.model(states)
        advantages = returns - values
        dist = Categorical(logits_p)
        log_probs = dist.log_prob(actions)
        # 注意这里策略损失反向传播时不需要优化优势函数,因此需要将其 detach 掉
        actor_loss = -(log_probs * advantages.detach()).mean() 
        critic_loss = advantages.pow(2).mean()
        return actor_loss, critic_loss

练习题

1.相比于 REINFORCE 算法, A2C 主要的改进点在哪里,为什么能提高速度?

(1)结合了策略梯度和值函数的 Actor-Critic 算法则能同时兼顾两者的优点,并且甚至能缓解两种方法都很难解决的高方差问题

(2)A2C计算了一个优势函数来衡量实际回报与批评价值之间的差异

(3)A2C在计算回报时使用了均值标准化,这有助于加快学习的收敛速度

2.A2C 算法是 on-policy 的吗?为什么?

是的。A2C算法通过Actor-Critic实现on-policy学习。Actor负责生成行动的概率分布,而Critic负责评估状态的价值。在A2C的更新过程中,智能体根据Actor生成的策略选择行动,并使用这些行动的结果来更新Actor和Critic。因此,A2C在执行和学习时使用的是同一策略

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1404746.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你必须了解的羊奶知识,一文悉数为你揭晓

你必须了解的羊奶知识,一文悉数为你揭晓 羊奶,作为一种营养丰富的乳制品,近年来备受关注。许多人选择羊奶作为替代牛奶的选择,因为它被认为更易消化,并且具有许多健康益处。在本文中,小编羊大师将为大家介…

Mac 上网易云音乐 ncm 格式文件如何转换为 mp3 音频文件?嗨格式转换器

hello朋友们大家好,最近想着怎么把网易云的歌保存到U盘,然后放到车上去听,然后辛辛苦苦搞了一宿,第二天拿到车上发现播放不了,根本就不认识 ncm 格式,我百度了一下,发现 ncm 是网易云的专用加密…

推荐一一款小众黑科技工具,低调使用建议收藏

wireshark是个啥就不多说了,非常流行的网络封包分析软件。 可以截取各种网络封包,显示网络封包的详细信息。 软件功能十分强大,操作也不复杂。 很多小友都在后台问能不能出一期完整的抓包分析贴,今天给你们安排上了哈。 01 W…

Kafka(二)【文件存储机制 生产者】

目录 一、Kafka 文件存储机制 二、Kafka 生产者 1、生产者消息发送流程 1.1、发送原理 2、异步发送 API 2.1、普通异步发送 案例演示 2.2、带回调函数的异步发送 2.3、同步发送 API 3、生产者分区 3.1、分区的好处 3.2、生产者发送消息的分区策略 (1&am…

【Java数据结构 -- 队列:队列有关面试oj算法题】

队列、循环队列、用队列模拟栈、用栈模拟队列 1.队列1.1 什么是队列1.2 创建队列1.3 队列是否为空和获取队头元素 empty()peek()1.4 入队offer()1.5 出队(头删)poll() 2. 循环队列2.1 创建循环队列2.2 判断是否为空isEmpty()和满isFull()2.3 入队enQueue…

深入理解Linux中的动态库与静态库

🎬慕斯主页:修仙—别有洞天 ♈️今日夜电波:I Wish My Mind Would Shut Up—Ivoris 0:21━━━━━━️💟──────── 2:04 🔄 ◀️ …

基于SpringBoot的手机商城

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式 🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 &…

苏州渭塘镇应用无人机“智慧执法”

苏州渭塘镇应用无人机“智慧执法” 在今年以来,渭塘镇综合行政执法局采用了“空中地面”的立体监督模式,以实现对“互联网执法”工作的深入推进。在这一模式下,无人机巡查作为技术手段得到广泛应用,而安全生产监管信息系统和综合…

MySQL-函数-数值函数

常见的数值函数 案例

【一文秒懂】Ftrace系统调试工具使用终极指南

我的圈子: 高级工程师聚集地 我是董哥,高级嵌入式软件开发工程师,从事嵌入式Linux驱动开发和系统开发,曾就职于世界500强公司! 创作理念:专注分享高质量嵌入式文章,让大家读有所得! …

Armv8-M的TrustZone技术之内存属性单元

如果处理器包含Armv8-M安全扩展,则内存区域的安全状态由内部安全属性单元(SAU,Secure Attribution Unit)或外部实现定义的属性单元(IDAU,Implementation Defined Attribution Unit)的组合控制。…

【WinForm.NET开发】ToolStrip 控件体系结构

本文内容 ToolStripToolStripItem附件类 ToolStrip 和 ToolStripItem 类提供了一种灵活的可扩展系统,用于显示工具栏、状态和菜单项。 这些类都包含在 System.Windows.Forms 命名空间中 ,它们的名称通常都带有“ToolStrip”前缀(如 ToolStr…

yolov8 opencv dnn部署自己的模型

源码地址 本人使用的opencv c github代码,代码作者非本人 使用github源码结合自己导出的onnx模型推理自己的视频 推理条件 windows 10 Visual Studio 2019 Nvidia GeForce GTX 1070 opencv4.7.0 (opencv4.5.5在别的地方看到不支持yolov8的推理,所以只使用opencv…

HDMI、VGA、DVI、DB接口的区别

HDMI、VGA、DVI和DB(也称为DisplayPort)是不同类型的视频接口标准,它们用于连接计算机、显示器、电视和其他视频设备。 HDMI(High-Definition Multimedia Interface,高清晰度多媒体接口):HDMI支…

C语言——静态通讯录的实现

今天我们来实现一下一个静态的通讯录: 我就先展示一下几个功能: 实现一个通讯录; 通讯录可以用来存储100个人的信息,每个人的信息包括:姓名、性别、年龄、电话、住址 提供方法: 添加联系人信息删除指定…

rancher和k8s接口地址,Kubernetes监控体系,cAdvisor和kube-state-metrics 与 metrics-server

为了能够提前发现kubernetes集群的问题以及方便快捷的查询容器的各类参数,比如,某个pod的内存使用异常高企 等等这样的异常状态(虽然kubernetes有自动重启或者驱逐等等保护措施,但万一没有配置或者失效了呢)&#xff0…

容器技术2-镜像与容器储存

目录 一、镜像制作 1、ddocker build 2、docker commit 二、镜像存储 1、公共仓库 2、私有仓库 三、镜像使用 四、容器存储 1、镜像元数据 2、存储驱动 3、数据卷 一、镜像制作 1、ddocker build 基于 Dockerfile 自动构建镜像 其机制为:每一行都会基于…

Go 的 Http 请求系统指南

文章目录 快速体验请求方法URL参数响应信息BodyStatusCodeHeaderEncoding 图片下载定制请求头复杂的POST请求表单提交提交文件 CookieClient 上设置 Cookie请求上设置 Cookie 重定向和请求历史超时设置总超时连接超时读取超时 请求代理错误处理总结 前几天在 “知乎想法” 谈到…

linux安装docker(入门一)

环境:centos 7(linux) 网站 官网: https://docs.docker.com/ Docker Hub 网站: https://hub.docker.com/ 容器官方概述 一句话概括容器:容器就是将软件打包成标准化单元,以用于开发、交付和部署。 容器镜像是轻量的、可执行的独立软件包 &…

Python小细节之代码极致简化到一行(5)(列表推导式)(技法慎用)

列表、推导式 引言简化前简化后讲解简化前简化后 应用结尾 引言 简单快速 大行其道 现在我又带着简化代码来了 我思考了下 简化的代码是技巧的体现 但是简短的代码里面 蕴藏着的是Python的精华 所以 我会更加详细的解析代码的内容 致力于让每个零基础的人都看懂 简化前 m…