强化学习13——Actor-Critic算法

news2025/1/11 7:39:48

Actor-Critic算法结合了策略梯度和值函数的优点,我们将其分为两部分,Actor(策略网络)和Critic(价值网络)

  • Actor与环境交互,在Critic价值函数的指导下使用策略梯度学习好的策略
  • Critic通过Actor与环境交互收集的数据学习,得到一个价值函数,来判断当前状态哪些动作是好,哪些动作是坏,进而帮Actor进行策略更新。

image.png

A2C算法

AC算法的目的是为了消除策略梯度算法的高仿查问题,可以引用优势函数(advantage function) A π ( s t , a t ) A^{\pi}(s_t,a_t) Aπ(st,at) ,来表示当前当前状态-动作对相对于平均水平的优势:
A π ( s t , a t ) = Q π ( s t , a t ) − V π ( s t ) A^{\pi}(s_t,a_t)=Q^{\pi}(s_t,a_t)-V^{\pi}(s_t) Aπ(st,at)=Qπ(st,at)Vπ(st)
通过与平均水平相减,可以降低方差。但需要注意的是,相减的是 V π ( s t ) V^{\pi}(s_t) Vπ(st) ,即在状态 s t s_t st 下的价值,即状态 s t s_t st 的回报的均值,而不是所有状态 s s s 的回报的均值。

可以将目标函数改为:
∇ θ J ( θ ) ∝ E π θ [ A π ( s t , a t ) ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \nabla_\theta J(\theta)\propto\mathbb{E}_{\pi_\theta}\left[A^\pi(s_t,a_t)\nabla_\theta\log\pi_\theta(a_t\mid s_t)\right] θJ(θ)Eπθ[Aπ(st,at)θlogπθ(atst)]
这就是A2C算法(Advantage Actor-Critic)算法。脱胎于A3C算法,即增加了多个进程,每一个进程都拥有一个独立的网络和环境以供训练。

image.png

广义优势估计

时序差分能有效解决高方差问题但是是有偏估计,而蒙特卡洛是无偏估计但是会带来高方差问题,因此通常会结合这两个方法形成一种新的估计方式,即 T D ( λ ) TD(\lambda) TD(λ) 估计,通过结合多步,形成新的估计方式,成为广义优势估计(generalized advantage estimation GAE)。

A GAE ( γ , λ ) ( s t , a t ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l = ∑ l = 0 ∞ ( γ λ ) l ( r t + l + γ V π ( s t + l + 1 ) − V π ( s t + l ) ) \begin{aligned} A^{\text{GAE}(\gamma,\lambda)}(s_t,a_t)& =\sum_{l=0}^\infty(\gamma\lambda)^l\delta_{t+l} \\ &=\sum_{l=0}^\infty(\gamma\lambda)^l\left(r_{t+l}+\gamma V^\pi(s_{t+l+1})-V^\pi(s_{t+l})\right) \end{aligned} AGAE(γ,λ)(st,at)=l=0(γλ)lδt+l=l=0(γλ)l(rt+l+γVπ(st+l+1)Vπ(st+l))
其中, δ t + l \delta_{t+l} δt+l 为时步 t + l t+l t+l 的TD误差,为:
δ t + l = r t + l + γ V π ( s t + l + 1 ) − V π ( s t + l ) \delta_{t+l}=r_{t+l}+\gamma V^{\pi}(s_{t+l+1})-V^{\pi}(s_{t+l}) δt+l=rt+l+γVπ(st+l+1)Vπ(st+l)
λ = 0 \lambda=0 λ=0 时,退化为单步TD误差:
A G A E ( γ , 0 ) ( s t , a t ) = δ t = r t + γ V π ( s t + 1 ) − V π ( s t ) A^{\mathrm{GAE}(\gamma,0)}(s_t,a_t)=\delta_t=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t) AGAE(γ,0)(st,at)=δt=rt+γVπ(st+1)Vπ(st)
λ = 1 \lambda=1 λ=1 时,则为蒙特卡洛估计:
A G A E ( γ , 1 ) ( s t , a t ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l = ∑ l = 0 ∞ ( γ ) l δ t + l A^{\mathrm{GAE}(\gamma,1)}(s_t,a_t)=\sum_{l=0}^\infty(\gamma\lambda)^l\delta_{t+l}=\sum_{l=0}^\infty(\gamma)^l\delta_{t+l} AGAE(γ,1)(st,at)=l=0(γλ)lδt+l=l=0(γ)lδt+l

代码实操

image.png

import gymnasium as gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils
# 定义策略网络
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

# 定义价值网络,输出一个价值,为一维张量
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

现在定义A2C算法的主题,包括采取动作和更新网络参数的两个函数。

class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # 策略网络
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)  # 价值网络
        # 策略网络优化器
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)  # 价值网络优化器
        self.gamma = gamma
        self.device = device
        
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    
    def update(self,transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        
        # 时序差分目标
        td_target=rewards+self.gamma*self.critic(next_states)*(1-dones)
        # 进行时序擦划分
        td_delta=td_target-self.critic(states)
        log_probs=torch.log(self.actor(states).gather(1,actions))
        actor_loss=torch.mean(-log_probs*td_delta.detach())
        # 均方误差损失函数
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()  # 计算策略网络的梯度
        critic_loss.backward()  # 计算价值网络的梯度
        self.actor_optimizer.step()  # 更新策略网络的参数
        self.critic_optimizer.step()  # 更新价值网络的参数
    
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                    gamma, device)

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()
  state = torch.tensor([state], dtype=torch.float).to(self.device)
Iteration 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.55it/s, episode=100, return=20.400]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.48it/s, episode=200, return=51.200]
Iteration 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:14<00:00,  6.91it/s, episode=300, return=151.500]
Iteration 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00,  3.88it/s, episode=400, return=256.700]
Iteration 4:  53%|███████████████████████████████████████████████████████████████████████████████▌                                                                      | 53/100 [00:17<00:10,  4.51it/s, episode=450, return=235.500]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1405766.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【BBuf的CUDA笔记】十三,OpenAI Triton 入门笔记一

0x0. 前言 2023年很多mlsys工作都是基于Triton来完成或者提供了Triton实现版本&#xff0c;比如现在令人熟知的FlashAttention&#xff0c;大模型推理框架lightllm&#xff0c;diffusion第三方加速库stable-fast等灯&#xff0c;以及很多mlsys的paper也开始使用Triton来实现比…

瑞_力扣LeetCode_104. 二叉树的最大深度

文章目录 题目 104. 二叉树的最大深度题解后序遍历 递归实现后序遍历 迭代实现层序遍历 题目 111. 二叉树的最小深度题解后序遍历层序遍历 题目 226. 翻转二叉树题解 &#x1f64a; 前言&#xff1a;本文章为瑞_系列专栏之《刷题》的力扣LeetCode系列&#xff0c;主要以力扣Lee…

Pandas.Series.sum() 求和(累和) 详解 含代码 含测试数据集 随Pandas版本持续更新

关于Pandas版本&#xff1a; 本文基于 pandas2.2.0 编写。 关于本文内容更新&#xff1a; 随着pandas的stable版本更迭&#xff0c;本文持续更新&#xff0c;不断完善补充。 传送门&#xff1a; Pandas API参考目录 传送门&#xff1a; Pandas 版本更新及新特性 传送门&…

sqlmap使用教程(3)-探测注入漏洞

1、探测GET参数 以下为探测DVWA靶场low级别的sql注入&#xff0c;以下提交方式为GET&#xff0c;问号&#xff08;?&#xff09;将分隔URL和传输的数据&#xff0c;而参数之间以&相连。--auth-credadmin:password --auth-typebasic &#xff08;DVWA靶场需要登录&#xf…

【GitHub项目推荐--微软开源的课程(Web开发课程/机器学习课程/物联网课程/数据科学课程)】【转载】

微软在 GitHub 开源了四大课程&#xff0c;面向计算机专业或者入门编程的同学。分别是 Web 开发课程、机器学习课程、物联网课程和数据分析课程。 四大课程在 GitHub 上共斩获 90K 的Star&#xff0c;每一课程包含 20 多小节&#xff0c;完成课程大约需要 12 周。每小节除了视…

《统计学习方法:李航》笔记 从原理到实现(基于python)-- 第1章 统计学习方法概论

文章目录 第1章 统计学习方法概论1.1 统计学习1&#xff0e;统计学习的特点2&#xff0e;统计学习的对象3&#xff0e;统计学习的目的4&#xff0e;统计学习的方法1.2.1 基本概念1.2.2 问题的形式化 1.3 统计学习三要素1.3.1 模型1.3.2 策略1.3.3 算法 1.4 模型评估与模型选择1…

【BIAI】Lecture 6 - Somatosensory systems

Lecture 6- Somatosensory systems 专业术语 somatosensory system 体感系统 Thermoreceptors 温度感受器 Photoreceptors 光感受器 Chemoreceptoprs 化学感受器 hairy skin 毛发皮肤 glabrous skin 光滑皮肤 sensory receptors 感觉受体 dermal 真皮的 epidermal 表皮的 axon…

Vue+Element(el-switch的使用)+springboot

目录 1、编写模板 2、发送请求 3、后端返数据 1.Controller类 2.interface接口&#xff08;Service层接口&#xff09; 3.Service&#xff08;接口实现&#xff09; 4.interface接口&#xff08;Mapper层接口&#xff09; 5.xml 6.效果 4、el-switch属性 1、编写模板 …

【工作】专业沟通,有效对齐信息(及时回应,做好汇报)

【工作】专业沟通&#xff0c;及时对齐信息 文章目录 一、读书笔记二、工作case三、前人case 一、读书笔记 1、书籍推荐&#xff1a; 两本值得学习的沟通方法书籍&#xff1a; 理论&#xff1a;《说话就是生产力》实践&#xff1a;《沟通的方法》 五本补充学习沟通方法的书…

Linux 强大的网络命令:nc命令操作方法

Netcat&#xff08;或简称nc&#xff09;是一个强大的网络工具&#xff0c;它在Linux系统中广泛使用&#xff0c;可用于创建各种网络连接。它被描述为"网络的瑞士军刀"&#xff0c;因为它的功能非常灵活&#xff0c;可以在网络中执行多种任务。 在大多数Linux发行版中…

Python入门(一)

anaconda安装 官网&#xff1a;https://www.anaconda.com下载 jupyter lab 简介&#xff1a; 包含了Jupyter Notebook所有功能。 JupyterLab作为一种基于web的集成开发环境&#xff0c;你可以使用它编写notebook&#xff0c;操作终端&#xff0c;编辑markdown文本&#xf…

Android:JNI实战,理论详解、Java与Jni数据调用

一.概述 上一篇博文讲解了如何搭建一个可以加载和链接第三方库、编译C/C文件的Jni Demo App。 这篇博文在这个Jni Demo App的基础上&#xff0c;从实战出发详细讲解 Jni 开发语法。 接下来&#xff0c;先用一小节将Jni开发比较重要的理论知识点过一下&#xff0c;然后进行代…

== 和 equals:对象相等性比较的细微差别

和 equals&#xff1a;对象相等性比较的细微差别 既要脚踏实地于现实生活&#xff0c;又要不时跳出现实到理想的高台上张望一眼。在精神世界里建立起一套丰满的体系&#xff0c;引领我们不迷失不懈怠。待我们一觉醒来&#xff0c;跌落在现实中的时候&#xff0c;可以毫无怨言地…

Minio 判断对象是否存在

引 Minio数据模型 中描述了 MinIO 中什么是桶&#xff0c;什么是对象&#xff0c;也给出了操作桶和操作对象的API。 在 MinIO 中&#xff0c; 对象 中间前缀 对象名称 。如何判定对象是否存在呢&#xff1f; 分析 在 MinIO 中并没有提供判断对象是否存在的操作&#xff…

VS Code Json格式化插件-JSON formatter

&#x1f9aa;整个文件格式化 按快捷键Shift Alt F &#x1f96a;仅格式化选择内容 需要选择完整的json段落即&#xff1a;{} 或 [] 括起来的部分&#xff0c;再按快捷键Ctrl K F

激光雷达行业梳理1-概述、市场、技术路线

激光雷达作为现代精确测距和感知技术的关键组成部分&#xff0c;在近几年里取得了令人瞩目的发展。作为自动驾驶感知层面的重要一环&#xff0c;相较摄像头、毫米波雷达等其他传感器具有“ 精准、快速、高效作业”的巨大优势&#xff0c;已成为自动驾驶的主传感器之一&#xff…

芋道--如何自定义业务表单,配置对应的工作流程(详细步骤)

需求描述: 芋道的动态表单就不再介绍了&#xff0c;相对来讲比较简单,跟着官网文档就可以实现&#xff0c;本文将详细的介绍如何新建独立的业务表记录申请的信息&#xff0c;并设计对应的工作流。 这里表中的每一条记录&#xff0c;都将通过流程实例编号(process_instance_id )…

mysql-进阶篇

文章目录 存储引擎MySQL体系结构相关操作 存储引擎特点InnoDBInnoDB 逻辑存储结构 MyISAMMemory三个存储引擎之间的区别存储引擎的选择 索引1. 索引结构B-TreeB-Tree (多路平衡查找树)B-Tree演变过程 BTree与 B-Tree 的区别BTree演变过程 Hash 2.索引分类3.索引语法演示 4.SQL性…

946. 验证栈序列(力扣)

946. 验证栈序列 Problem: 946. 验证栈序列 文章目录 思路解题方法复杂度Code 思路 对栈的使用 解题方法 1.我们可以通过把pushed重新一个一个入我们自己创建的栈如果某次入栈碰到与poped第一个元素相同的那我们就对poped出栈处理(即跳过第一个元素);如此循环,直到我们的栈到最…

【C++记忆站】类和对象(二)

类和对象(二) 如果一个类中什么成员都没有&#xff0c;简称为空类。 空类中真的什么都没有吗&#xff1f;并不是&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以下6个默认成员函数。 默认成员函数&#xff1a;用户没有显式实现&#xff0c;编译器会生成的成员…