强化学习时序差分算法之多步Sarsa算法——以悬崖漫步环境为例

news2024/11/15 8:32:57

0.简介

        蒙特卡洛方法利用当前状态之后每一步奖励而不使用任何价值估计,时序差分算法则只利用当前状态的奖励以及对下一状态的价值估计。

        蒙特卡洛算法是无偏的,但是它的每一步的状态转移具有不确定性,同时每一步状态采取的动作所得到的不一样的奖励最终会累计起来,从而极大影响最终的状态估计,因而其方差较大。

        时序差分算法只采用了一步状态转移以及使用了一步奖励,因而具有非常小的方差;但是它由于用到了下一状态的价值估计而不是其真实价值,故而有偏。

        多步时序差分算法则结合了二者的优势,其使用n步奖励以及之后的状态的价值估计,其公式为:

        多步Sarsa算法伪代码如下所示:

1.导入相关库

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

2.悬崖漫步环境实现环节

class cliffwalkingenv():
    def __init__(self,colnum,rownum,stepreward,cliffreward,initx,inity):
        self.colnum=colnum
        self.rownum=rownum
        self.stepreward=stepreward
        self.cliffrreward=cliffreward
        self.initx=initx
        self.inity=inity
    def step(self,action):
        change=[[0,-1],[0,1],[-1,0],[1,0]]
        self.x=min(self.colnum-1,max(0,self.x+change[action][0]))
        self.y=min(self.rownum-1,max(0,self.y+change[action][1]))
        next_state=self.y*self.colnum+self.x
        reward=self.stepreward
        done=False
        if self.y==self.rownum-1 and self.x>0:
            done=True
            if self.x!=self.colnum-1:
                reward=self.cliffrreward
        return next_state,reward,done
    def reset(self):
        self.x=self.initx
        self.y=self.inity
        return self.y*self.colnum+self.x

3.在Sarsa算法基础上进行修改,引入多步时序差分计算,实现多步(n步)Sarsa算法

class nstep_sarsa():
    """ n步Sarsa算法 "" """
    def __init__(self,n,colnum,rownum,alpha,gamma,epsilon,actionnum=4):
        self.n=n
        self.colnum=colnum
        self.rownum=rownum
        self.alpha=alpha
        self.gamma=gamma
        self.epsilon=epsilon
        self.actionnum=actionnum
        self.qtable=np.zeros([self.colnum*self.rownum,self.actionnum])
        self.statelist=[]#保存之前的状态
        self.actionlist=[]#保存之前的动作
        self.rewardlist=[]#保存之前的奖励
    def  takeaction(self,state):
        if np.random.random()<self.epsilon:
            action=np.random.randint(self.actionnum)
        else:
            action=np.argmax(self.qtable[state])
        return action
    def bestaction(self,state):#打印策略
        qmax=np.max(self.qtable[state])
        a=[0 for _ in range(self.actionnum)]
        for k in range(self.actionnum):
            if self.qtable[state][k]==qmax:
                a[k]=1
        return a
    def update(self,s0,a0,r,s1,a1,done):
        self.statelist.append(s0)
        self.actionlist.append(a0)
        self.rewardlist.append(r)
        if len(self.statelist)==self.n:#若保存的数据可以进行n步更新
            G=self.qtable[s1][a1]#得到Q(s(n+t),a(n+t))
            for i in reversed(range(self.n)):
                G=self.gamma*G+self.rewardlist[i]#不断向前计算每一步的回报
                if done and i>0:#如果到达终止状态,最后几步虽然长度不够n步,也将其进行更新
                    s=self.statelist[i]
                    a=self.actionlist[i]
                    self.qtable[s][a]+=self.alpha*(G-self.qtable[s][a])
            s=self.statelist.pop(0)
            a=self.actionlist.pop(0)
            self.rewardlist.pop(0)
            self.qtable[s][a]+=self.alpha*(G-self.qtable[s][a])#n步Sarsa的主要更新步骤
        if done:
            self.statelist=[]
            self.actionlist=[]
            self.rewardlist=[]

4.最终通过算法寻找的最优策略的显示

def printagent(agent,env,actionmeaning,disaster=[],end=[]):
    for i in range(env.rownum):
        for j in range(env.colnum):
            if (i*env.colnum+j) in disaster:
                print('****',end=' ')
            elif (i*env.colnum+j) in end:
                print('EEEE',end=' ')
            else:
                a=agent.bestaction(i*env.colnum+j)
                pistr=''
                for k in range(len(actionmeaning)):
                    pistr+=actionmeaning[k] if a[k]>0 else 'o'
                print('%s'%pistr,end=' ')
        print()

5.相关参数设置

ncol=12#悬崖漫步环境中的网格环境列数
nrow=4#悬崖漫步环境中的网格环境行数
step_reward=-1#每步的即时奖励
cliff_reward=-100#悬崖的即时奖励
init_x=0#智能体在环境中初始位置的横坐标
init_y=nrow-1#智能体在环境中初始位置的纵坐标
n_step=5#5步Sarsa算法
alpha=0.1#价值估计更新的步长
epsilon=0.1#epsilon-贪婪算法的探索因子
gamma=0.9#折扣衰减因子
num_episodes=500#智能体在环境中运行的序列总数
tqdm_num=10#进度条的数量
printreturnnum=10#打印回报的数量
actionmeaning=['↑','↓','←','→']#上下左右表示符

6.程序主体部分实现

np.random.seed(0)
returnlist=[]
env=cliffwalkingenv(colnum=ncol,rownum=nrow,stepreward=step_reward,cliffreward=cliff_reward,initx=init_x,inity=init_y)
agent=nstep_sarsa(n=n_step,colnum=ncol,rownum=nrow,alpha=alpha,gamma=gamma,epsilon=epsilon,actionnum=4)
for i in range(tqdm_num):
    with tqdm(total=int(num_episodes/tqdm_num),desc='Iteration %d'% i) as pbar:#tqdm进度条功能
        for episode in range(int(num_episodes/tqdm_num)):#每个进度条的序列数
            episodereturn=0
            state=env.reset()
            action=agent.takeaction(state)
            done=False
            while not done:
                nextstate,reward,done=env.step(action)
                nextaction=agent.takeaction(nextstate)
                episodereturn+=reward#这里回报计算不进行折扣因子衰减
                agent.update(state,action,reward,nextstate,nextaction,done)
                state=nextstate
                action=nextaction
            returnlist.append(episodereturn)
            if (episode+1)%printreturnnum==0:#每printreturnnum条序列打印一下这printreturnnum条序列的平均回报
                pbar.set_postfix({'episode':'%d'%(num_episodes/tqdm_num*i+episode+1),'return':'%.3f'%(np.mean(returnlist[-printreturnnum:]))})
            pbar.update(1)
episodelist=list(range(len(returnlist)))
plt.plot(episodelist,returnlist)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('{}-step Sarsa on{}'.format(n_step,'Cliff Walking'))
plt.show()
print('{}步Sarsa算法最终收敛得到的策略为:'.format(n_step))
printagent(agent=agent,env=env,actionmeaning=actionmeaning,disaster=list(range(37,47)),end=[47])

7.实现效果与数据

Iteration 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 626.45it/s, episode=50, return=-26.500] 
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3128.16it/s, episode=100, return=-35.200] 
Iteration 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2781.15it/s, episode=150, return=-20.100] 
Iteration 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2182.42it/s, episode=200, return=-27.200] 
Iteration 4: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2634.94it/s, episode=250, return=-19.300] 
Iteration 5: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2638.69it/s, episode=300, return=-27.400] 
Iteration 6: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2638.86it/s, episode=350, return=-28.000] 
Iteration 7: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2281.37it/s, episode=400, return=-36.500] 
Iteration 8: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2785.10it/s, episode=450, return=-27.000] 
Iteration 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2642.51it/s, episode=500, return=-19.100]

5步Sarsa算法最终收敛得到的策略为:
ooo→ ooo→ ooo→ ooo→ ooo→ ooo→ ooo→ ooo→ ooo→ ooo→ ooo→ o↓oo 
↑ooo ↑ooo ↑ooo oo←o ↑ooo ↑ooo ↑ooo ↑ooo ooo→ ooo→ ↑ooo o↓oo
ooo→ ↑ooo ↑ooo ↑ooo ↑ooo ↑ooo ↑ooo ooo→ ooo→ ↑ooo ooo→ o↓oo
↑ooo **** **** **** **** **** **** **** **** **** **** EEEE 

8.总结

        通过实验我们可以发现5步Sarsa算法的收敛性比单步Sarsa算法更快,此时多步Sarsa算法得到的策略会在最远离悬崖的一边行走,以保证最大的安全性。关于单步Sarsa算法在悬崖漫步中的实现效果见我另一篇博客:强化学习时序差分算法之Sarsa算法——以悬崖漫步环境为例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1966796.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学习Java的日子 Day61 Listener监听器

Day61 Listener监听器 JavaWeb 三大组件(Servlet、Filter、Listener) 概念 监听器用于监听web应用中某些对象信息的创建、销毁、增加&#xff0c;修改&#xff0c;删除等动作的发生&#xff0c;然后作出相应的响应处理。当范围对象的状态发生变化的时候&#xff0c;服务器自动调…

vue echarts 横向柱状图,交错正负轴标签

横向柱状图&#xff1a; 同一个页面展示多个相同横向柱状图&#xff1b; 代码如下&#xff1a; <template><div style"display: flex;justify-content: space-around;"><div v-for"(chart,index) in barChartList" :key"index"…

基于LoRA和AdaLoRA微调Qwen1.5-0.5B-Chat

本文只开放基于LoRA和AdaLoRA微调代码,具体技术可以自行学习。 Qwen1.5-0.5B-Chat权重路径:https://huggingface.co/Qwen/Qwen1.5-0.5B 数据集路径:https://github.com/DB-lost/self-llm/blob/master/dataset/huanhuan.json 1. 知识点 LoRA, AdaLoRA技术 具体技术可以去看…

Transformer 架构告诉我们什么?

欢迎来到雲闪世界。ChatGPT 等大型语言模型 (LLM) 的出色表现震惊了世界。这一突破源于 Transformer 架构的发明&#xff0c;该架构出奇地简单且可扩展。它仍然由深度学习神经网络构建。主要新增功能是所谓的“注意力”机制&#xff0c;该机制将每个单词标记置于语境中。此外&a…

睿考网:CPA考试各科难度分析

CPA考试分为专业阶段和综合阶段两个部分&#xff0c;其中专业阶段包含六个科目&#xff0c;六科难度分别为大家介绍一下。 《会计》科目是CPA专业阶段中基础且难度较低的科目&#xff0c;同时也是核心的科目。对于零基础的考生来说&#xff0c;可能会感到困难&#xff0c;需要…

【C语言】堆的实现

堆的基本概念 堆在逻辑上是完全二叉树&#xff0c;那什么又是完全二叉树呢&#xff1f; 完全二叉树简单来说就是前n-1层每个节点都有两个儿子&#xff0c;最后一层叶子紧挨着排列。 堆在物理结构上适合用数组存储。 让我们先来学习树->二叉树的基本知识&#xff08;可看文…

verilog中的$radom函数

我需要产生一个背压。背压每次经过x个时钟周期之后翻转&#xff0c;x是0到1_6000间的一个随机数。 如下图的代码&#xff0c;($random % 10)产生的是-9到9的数&#xff0c;包括0&#xff0c; ($random % 10) 1 那么值就在 -8到10之间。 always (posedge clk) beginDATA_READ…

K8S可视化管理平台KubeSphere

什么是 KubeSphere &#xff1f; KubeSphere 是一款开源项目&#xff0c;在目前主流容器调度平台 Kubernetes 之上构建的企业级分布式多租户容器管理平台&#xff0c;提供简单易用的操作界面以及向导式操作方式&#xff0c;在降低用户使用容器调度平台学习成本的同时&#xff…

DBeaver连接mysql时,报错Public Key Retrieval is not allowed

解决 在新建连接的时候&#xff0c;驱动属性里设置 allowPublicKeyRetrieval 的值为 true。

SQL进阶技巧:Hive URL解析函数详解及实际应用

目 录 0 实际业务需求 1 URL的基本组成 2 PROTOCOL 协议 3 Hive中的URL解析函数 3.1 数据准备 3.2 创建数据库 3.3 需求 3.3.1 parse_url 讲解 3.3.2 测试 3.3.3 实现需求 3.3.4 注意问题 3.5 parse_url_tuple 3.5.1 需求 3.5.2 实现需求 3.5.3 注意问题 4 小…

HTML + CSS 学习指南:从入门到精通

一、HTML CSS 简介 HTML 和 CSS 在网页开发中扮演着至关重要的角色。HTML 如同网页的骨架&#xff0c;为网页提供了基本的结构和内容。它使用各种标签来定义页面的元素&#xff0c;如标题、段落、图片、链接等&#xff0c;确保信息得以有条理地组织和呈现。 CSS 则恰似网页的…

点可云ERP进销存V8版本—购货退货单操作使用讲解

本章我们讲解购货退货单的使用场景及操作使用说明。 购货退货单是指供应商收回或退还给采购方的货物的单据。它记录了购货方向供应商退还货物的详细信息&#xff0c;一般会在货物质量问题、退货政策、错误订购等情况下发生购货退货。 购货退货单可以通过两个方式产生&#xff0…

学习记录——day24 多进程编程

创建三个进程 可以让父进程创建一个子进程&#xff0c;再由父进程或者子进程创建一个子进程 #include <myhead.h> int main(int argc, char const *argv[]) {pid_t pid fork();if (pid >0){//父进程pid_t pid1 fork();if (pid1 >0){printf("father\n"…

linux Ubuntu 安装mysql-8.0.39 二进制版本

我看到网上很多都写的乱七八糟, 我自己总结了一个 首先, 去Mysql官网上下载一个mysql-8.0.39二进制版本的安装包 这个你自己去下载我这里就写一个安装过程和遇到的坑 第一步 解压mysql压缩包和创建my.cnf文件 说明: 二进制安装指定版本MySQL的时候&#xff0c;需要手动写配置…

十月稻田玉米品类全国销量领先背后:“卖点”到“买点”的用户思维

近日&#xff0c;十月稻田在梯媒全新上线的新潮玉米广告&#xff0c;吸引了很多消费者的注意。 画面里&#xff0c;十月稻田的黄糯玉米棒金黄且饱满&#xff0c;旁白是广告语&#xff1a;“新玉米上市&#xff0c;香香香&#xff01;”。这支广告也挑起了许多观众的食欲&#…

【QGroundControl二次开发】七.QGC自定义MAVLink消息MavLink通信协议 C++应用

1. 接收解析源码分析 通过接收串口或UDP发来的的字节流buffer&#xff0c;长度lengthbuffer.size()&#xff0c;通过下列脚本解析&#xff0c;每解析出一个mavlink数据包就执行onMavLinkMessage函数 for(int i 0 ; i < length ; i){msgReceived mavlink_parse_char(MAVL…

【运维自动化】网络统一监控运维管理解决方案(PPT建设方案)

运维自动化是提升IT运维效率、降低人力成本、增强系统稳定性和可靠性的关键举措。随着业务规模的增长&#xff0c;传统的手动运维方式已难以满足快速响应和高效管理的需求。自动化运维通过脚本、工具和系统平台&#xff0c;实现日常任务自动化执行、故障预警与快速恢复、资源优…

数据结构笔记纸质总结

1.基本概念 2.复杂度 3.线性表 4.栈 5.队列 6.串 7.数组 8.矩阵 9.广义表 10.树

15.3 Zookeeper官方使用_实现分布式锁

1. 简介 2. 代码演示 2.1 客户端连接类 package com.ruoyi.common.zookeeper;import com.ruoyi.common.exception.UtilException; import

命途多舛的Concepts:从提出到剔除再到延期最后到纳入,Concepts为什么在C++中大起大落?

在C的漫长发展史中&#xff0c;Concepts&#xff08;概念&#xff09;的故事显得尤为引人注目。它的历程不仅是C社区技术演进的缩影&#xff0c;也是对软件工程实践的一次深刻反思。本文将详细剖析C的Concepts&#xff1a;它是什么&#xff0c;它的设计初衷与使用场景&#xff…