强化学习之Actor-Critic算法(基于值函数和策略的结合)——以CartPole环境为例

news2024/11/26 8:38:14

0.简介

DQN算法作为基于值函数的方法代表,基于值函数的方法只学习一个价值函数。REINFORCE算法作为基于策略的方法代表,基于策略的方法只学习一个策略函数。Actor-Critic算法则结合了两种学习方法,其本质是基于策略的方法,因为其目标是优化一个带参的策略,只是会额外学习价值函数帮助策略函数更好地学习。

我们回顾一下在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报,来指导策略的更新。而值函数的概念正是基于期望回报,我们能不能考虑拟合一个值函数来指导策略进行学习呢?这正是 Actor-Critic 算法所做的。让我们先回顾一下策略梯度的形式,在策略梯度中,我们可以把梯度写成下面这个形式:

image.png

其中 ψ t 可以有很多种形式:

image.png

 在 REINFORCE 的最后部分,我们提到了 REINFORCE通过蒙特卡洛采样的方法对梯度的估计是无偏的,但是方差非常大,我们可以用第三种形式引入基线 (baseline) b ( s t ) 来减小方差。此外我们也可以采用 Actor-Critic 算法,估计 一个动作价值函数 Q 来代替蒙特卡洛采样得到的回报,这便是第 4 种形式。这个时候,我们也可以把状态价值函数 V  作为基线,从偍牧但是用神经网络进行估计的方法可以减小方差、提高鲁棒性。除此之外,REINFORCE 算法基于蒙特卡洛采样,只能在序列结束后进行更新,而 Actor-Critic 的方法则可以在每一步之后都进行更新。

我们将 Actor-Critic 分为两个部分: 分别是 Actor (策略网络) 和 Critic (价值网络):

  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于帮助 Actor 进行更新策略。
  • Actor 要做的则是与环境交互,并利用 Ctitic 价值函数来用策略梯度学习一个更好的策略。

image.png

 与 DQN 中一样,我们采取类似于目标网络的方法,上式中 r + γ V ω ( s t + 1 )作为时序差分目标,不会产生梯度来更新价值函数。所以价值函数的梯度为

image.png

然后使用梯度下降方法即可。接下来让我们总体看看 Actor-Critic 算法的流程吧!

  • 初始化策略网络参数 θ  ,价值网络参数 ω
  • 不断进行如下循环 (每个循环是一条序列) :

。 用当前策略 π θ 平样轨 迹 { s 1 , a 1 , r 1 , s 2 , a 2 , r 2 … }

。 为每一步数据计算: δ t = r t + γ V ω ( s t + 1 ) − V ω ( s )

。 更新价值参数 w = w + α ω ∑ t δ t ∇ ω V ω ( s )

。 更新策略参数 θ = θ + α θ ∑ t δ t ∇ θ log ⁡ π θ ( a ∣ s )

 好了!这就是 Actor-Critic 算法的流程啦,让我们来用代码实现它看看效果如何吧!

1.导库

import gym
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

2.策略网络PolicyNet定义

class PolicyNet(torch.nn.Module):#策略网络
    def __init__(self,statedim,hiddendim,actiondim):
        super(PolicyNet,self).__init__()
        self.fc1=torch.nn.Linear(statedim,hiddendim)
        self.fc2=torch.nn.Linear(hiddendim,actiondim)
    def forward(self,x):
        x=torch.nn.functional.relu(self.fc1(x))
        return torch.nn.functional.softmax(self.fc2(x),dim=1)

3.价值网络ValueNet定义

class ValueNet(torch.nn.Module):#价值网络
    def __init__(self,statedim,hiddendim):
        super(ValueNet,self).__init__()
        self.fc1=torch.nn.Linear(statedim,hiddendim)
        self.fc2=torch.nn.Linear(hiddendim,1)
    def forward(self,x):
        x=torch.nn.functional.relu(self.fc1(x))
        return self.fc2(x)

4.ActorCritic算法实现

class ActorCritic:#演员-评论家算法
    def __init__(self,statedim,hiddendim,actiondim,actor_learningrate,critic_learningrate,gamma,device):
        self.actor=PolicyNet(statedim,hiddendim,actiondim).to(device)#策略网络
        self.critic=ValueNet(statedim,hiddendim).to(device)#价值网络
        self.actor_optimizer=torch.optim.Adam(self.actor.parameters(),lr=actor_learningrate)#策略网络优化器
        self.critic_optimizer=torch.optim.Adam(self.critic.parameters(),lr=critic_learningrate)#价值网络优化器
        self.gamma=gamma
        self.device=device
    def takeaction(self,state):#根据策略网络采取动作
        state=torch.tensor([state],dtype=torch.float).to(self.device)
        probs=self.actor(state)
        actiondist=torch.distributions.Categorical(probs)
        action=actiondist.sample()
        return action.item()#返回选择的动作的索引的标量形式
    def update(self,transitiondist):#更新策略网络和价值网络
        states,actions,rewards,nextstates,dones=transitiondist["states"],transitiondist["actions"],transitiondist["rewards"],transitiondist["nextstates"],transitiondist["dones"]
        states=torch.tensor(states,dtype=torch.float).to(self.device)
        actions=torch.tensor(actions).view(-1,1).to(self.device)
        rewards=torch.tensor(rewards,dtype=torch.float).view(-1,1).to(self.device)
        nextstates=torch.tensor(nextstates,dtype=torch.float).to(self.device)
        dones=torch.tensor(dones,dtype=torch.float).view(-1,1).to(self.device)
        td_target=rewards+self.gamma*self.critic(nextstates)*(1-dones)#时序差分目标
        td_delta=td_target-self.critic(states)#时序差分误差
        log_probs=torch.log(self.actor(states).gather(1,actions))
        #.detach() 来创建一个与原始张量值相同但不可训练的副本。这个副本可以在不影响原始张量的情况下进行各种操作,并且不会在反向传播中被更新。
        actor_loss=torch.mean(-log_probs*td_delta.detach())#策略网络的损失函数;#.detach()的作用是将这个张量从计算图中分离出来,这样在计算损失时不会对其进行反向传播,通常是为了防止某些不希望被更新的部分被意外更新。
        critic_loss=torch.mean(torch.nn.functional.mse_loss(self.critic(states),td_target.detach()))#均方差损失函数
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()#计算策略网络的梯度
        critic_loss.backward()#计算价值网络的梯度
        self.actor_optimizer.step()#策略网络参数更新
        self.critic_optimizer.step()#价值网络参数更新

5.训练本算法的函数实现

def train_on_policy_agent(env,agent,episodesnum,pbarnum,printreturnnum,seedid):#训练演员-评论家算法
    returnlist=[]
    for k in range(pbarnum):
        with tqdm(total=int(episodesnum/pbarnum),desc='Iteration %d' % k) as pbar:
            for episode in range(int(episodesnum/pbarnum)):
                episodereturn=0
                transitiondist={"states":[],"actions":[],"nextstates":[],"rewards":[],"dones":[]}
                state=env.reset(seed=seedid)[0]
                done=False
                while not done:
                    action=agent.takeaction(state)
                    nextstate,reward,done,truncated,_=env.step(action)
                    done=done or truncated
                    transitiondist["states"].append(state)
                    transitiondist["actions"].append(action)
                    transitiondist["nextstates"].append(nextstate)
                    transitiondist["rewards"].append(reward)
                    transitiondist["dones"].append(done)
                    state=nextstate
                    episodereturn+=reward
                returnlist.append(episodereturn)
                agent.update(transitiondist)
                if (episode+1)%(printreturnnum)==0:
                    pbar.set_postfix({"episode":"%d"%(episodesnum/pbarnum*k+episode+1),"return":"%.3f"%np.mean(returnlist[-printreturnnum:])})
                pbar.update(1)
    return returnlist

6.移动平滑处理时间序列函数实现

def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))

7.参数配置

actor_learningrate=1e-3
critic_learningrate=1e-2
episodesnum=1000
hiddendim=128
gamma=0.98
pbarnum=10
printreturnnum=10
seedid=0
device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

8.车杆环境实验

env=gym.make("CartPole-v1")#env=gym.make("CartPole-v1",render_mode="human")
env.reset(seed=seedid)
torch.manual_seed(seedid)
statedim=env.observation_space.shape[0]
actiondim=env.action_space.n
agent=ActorCritic(statedim,hiddendim,actiondim,actor_learningrate,critic_learningrate,gamma,device)
returnlist=train_on_policy_agent(env,agent,episodesnum,pbarnum,printreturnnum,seedid)
episodelist=list(range(len(returnlist)))
plt.plot(episodelist,returnlist)
plt.xlabel("Episodes")
plt.ylabel("Returns")
plt.title("Actor-Critic on {}-{}".format(env.spec.name,env.spec.id))
plt.show()
mvreturn=moving_average(returnlist,9)
plt.plot(episodelist,mvreturn)
plt.xlabel("Episodes")
plt.ylabel("Returns")
plt.title("Actor-Critic on {}-{}".format(env.spec.name,env.spec.id))
plt.show()

9.实验结果

Actor-Critic算法很快收敛到最优策略,训练过程非常稳定,抖动情况与REINFORCE算法相比有了明显改进,这说明价值函数的引入减少了方差。 

10.小结

Actor-Critic算法是基于值函数和基于策略的方法的叠加,价值模块Critic在策略模块Actor采样的数据中学习分辨什么是好的动作,什么是不好的动作,进而指导Actor进行策略更新,随着Actor训练不断进行,与环境交互产生的数据分布也发生改变,这需要Critic尽快适应新数据分布并给出好的判别。TRPO、PPO、DDPG、SAC等深度强化学习算法都是在Actor-Critic算法基础上进行发展改进的,其作为基础,深入理解大有裨益。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2035547.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

element时间段选择器或时间选择器 只设置默认起始时间或者结束时间,不显示问题

element时间段选择器或时间选择器 只设置默认起始时间或者结束时间&#xff0c;不显示问题 <div v-for"(item,index) in [a,b]":key"item"><el-date-pickerv-if"b"v-model"value1[item]"type"datetimerange"value-…

16s功能注释Bugbase的安装使用--本地版

文章目录 概述介绍下载安装程序下载并配置环境安装依赖R包并显示帮助运行示例数据Bug及解决方法-☆ 使用输入文件准备-☆下载Greengenes数据库在QIIME2中操作R语言操作 运行Bugbase 概述 Bugbase依赖于Greegenes1与R 但是R现已更新到4.4以上&#xff0c;安装R包时会不兼容且输…

【时时三省】(C语言基础)结构体初阶

山不在高&#xff0c;有仙则名。水不在深&#xff0c;有龙则灵。 ----CSDN 时时三省 结构体的声明 结构的基础知识: 结构是一些值的集合&#xff0c;这些值称为成员变量。结构的每个成员可以是不同类型的变量 数组: 是一组相同类型的元素的集合 结构体: 也是一些值得集合…

python-二进制?十进制?(赛氪OJ)

[题目描述] 给定两个十进制整数 : A&#xff0c;B。 你需要把它们的二进制形式以十进制的运算法则相加输出结果。 例如&#xff1a; A3 , B2的时候&#xff0c;A 的二进制表示是 : 11 , &#x1d435;B 的二进制表示是 10 &#xff0c;你需要输出答案为 : 21。 输入格式…

基于 Flutter 从零开发一款产品(一)—— 跨端开发技术介绍

前言 相信很多开发者在学习技术的过程中&#xff0c;常常会陷入一种误区当中&#xff0c;就是学了很多技术理论知识&#xff0c;但是仍做不出什么产品出来&#xff0c;往往学了很多干货&#xff0c;但是并无实际的用处。其实&#xff0c;不论是做什么&#xff0c;我们都需要从…

嵌入式linux系统镜像制作day1

点击上方"蓝字"关注我们 01、前言 嵌入式设备&#xff08;例如心电图检测仪&#xff0c;售票系统等&#xff09;。尽管&#xff0c;嵌入式设备像那些智能手机一样&#xff0c;绝大多数都使用同样的硬件和软件&#xff0c;包括系统芯片SoC、储存、连接和多媒体接口、…

Could not find artifact net.sf.json-lib:json-lib:jar

一开始我改了maven的setting&#xff0c;由官网变为阿里云仓库&#xff0c;最后还是不行 <dependency><groupId>net.sf.json-lib</groupId><artifactId>json-lib</artifactId><version>2.0</version><classifier>jdk15</cl…

freeRTOS任务通知(Task Notifications)

目录 前言 一、任务通知概述 1.优势及限制 2.通知状态和通知值 二、任务通知的使用 两类函数 1.xTaskNotifyGive/ulTaskNotifyTake 2.xTaskNotify/xTaskNotifyWait 三、传输计数值代码示例 四、传输任意值代码示例 前言 所谓"任务通知"&#xff0c;你可以反…

螺旋矩阵 | LeetCode-59 | LeetCode-54 | 分类讨论

&#x1f64b;大家好&#xff01;我是毛毛张! &#x1f308;个人首页&#xff1a; 神马都会亿点点的毛毛张 &#x1f383;分类不好&#xff0c;这道题就做不出来&#xff01;&#x1f388; &#x1f4cc;LeetCode链接&#xff1a;59. 螺旋矩阵 II &#x1f4cc;LeetCode链接…

李晨晨的嵌入式学习 DAY25

今天对昨天的fork函数进行了补充&#xff0c;并且学习了exec函数 一&#xff0c;fork函数补充 1.open在fork之前 子进程会继承父进程已打开的相关信息&#xff0c;父子进程会影响同一个offset值 2.open在fork之后 父子进程各自有各自打开的文件信息&#xff0c;不相互影响 …

Luatos-lua For MacOSX

0x00 缘起 看到Luatos-soc-pc项目能够编译到MacOS平台并且运行&#xff0c;所以尝试编译&#xff1b;可是Apple Clang编译器太过于严格&#xff0c;导致编译不通过。遂换到gcc-11编译通过&#xff0c;虽然其中依旧会报错&#xff08;宏定义LUA_USE_MACOSX不起作用&#xff0c;导…

Linux驱动入门实验班——LED驱动(附百问网视频链接)

目录 一、确定引脚编号 二、编写思路 2.1驱动层 2.2应用层 三、源码 四、实现 课程链接 一、确定引脚编号 首先&#xff0c;可以在开发板上执行如下命令查看已经在使用的GPIO状态&#xff1a; cat /sys/kernel/debug/gpio 可以看到每个gpio都有对应的编号&#xff0c;…

岗位信息采集全攻略:两种方法快速获取招聘信息

摘要 本文将揭秘两大实战策略&#xff0c;助你在激烈的人才市场中迅速捕捉前程无忧上的宝贵岗位信息&#xff0c;无论是手动搜索还是利用现代技术手段&#xff0c;都能事半功倍&#xff0c;抢占先机。 正文 一、手动搜索的艺术&#xff1a;精准定位&#xff0c;深度挖掘 1.…

【网络】传输层TCP协议的报头和传输机制

目录 引言 报头和有效载荷 确认应答机制 超时重传机制 排序和去重 连接管理机制 个人主页&#xff1a;东洛的克莱斯韦克-CSDN博客 引言 TCP是传输层协议&#xff0c;全称传输控制协议。TCP报头中有丰富的字段以及协议本身会制定完善的策略来保证网络传输的可靠性。 TCP…

ICM-20948芯片详解(12)

接前一篇文章&#xff1a;ICM-20948芯片详解&#xff08;11&#xff09; 六、寄存器详解 2. USER BANK 0寄存器详述 &#xff08;56&#xff09;FIFO_EN_1 参考代码&#xff1a; 无。 &#xff08;57&#xff09;FIFO_EN_2 ACCEL_FIFO_EN 1 —— 以采样率将ACCEL-XOUT_H、…

haproxy实例

什么是haproxy Haproxy是一款提供高可用性&#xff0c;负载均衡以及基于tcp和http的的应用交付控制器的开源软件。它由法国人威利塔罗使用c语言开发的。它广泛用于管理和路由网络流量&#xff0c;并确保应用程序的高可用性和高性能。 haproxy的功能 提供第4层&#xff08;TCP层…

vulnhub系列:Hackademic.RTB1

vulnhub系列&#xff1a;Hackademic.RTB1 靶机下载 一、信息收集 nmap 扫描存活&#xff0c;根据 mac 地址寻找 IP nmap 192.168.23.0/24nmap 扫描端口&#xff0c;开放端口&#xff1a;22、80 nmap 192.168.23.143 -p- -Pn -sV -O访问80端口&#xff0c;页面发现 target …

DirectX修复工具解决问题:一步步教你排除常见错误

在日常使用电脑的过程中&#xff0c;许多用户可能会遇到与DirectX相关的问题&#xff0c;特别是在运行大型游戏或图形密集型应用程序时。这种情况下&#xff0c;选择一款合适的DirectX修复工具免费版来解决问题至关重要&#xff01; 我们将分享六款好用的DirectX修复工具&…

字节Java后端二面也太难了吧...

粉丝投稿&#xff0c;字节二面直接连环问场景题&#xff0c;难以招架&#xff0c;已经准备好市场上常见的场景题了&#xff0c;希望能帮助你&#xff01; 由于平台篇幅原因&#xff0c;很多内容展示不了&#xff0c;需要这份《java面试宝典》的伙伴们转发文章关注后&#xff…

Linux_Shell变量及运算符-05

一、Shell基础 1.1 什么是shell Shell脚本语言是实现Linux/UNIX系统管理及自W动化运维所必备的重要工具&#xff0c; Linux/UNIX系统的底层及基础应用软件的核心大都涉及Shell脚本的内容。Shell是一种编程语言, 它像其它编程语言如: C, Java, Python等一样也有变量/函数/运算…