强化学习时序差分算法之Sarsa算法——以悬崖漫步环境为例

news2024/11/27 3:55:03

1.导入必要的库环境,代码如下所示。

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

2.本悬崖漫步环境中无需提供奖励函数以及状态转移函数,而需提供一个与智能体进行交互的step()函数,该函数输入为智能体当前状态下的动作,输出为当前状态下的奖励以及智能体的下一状态,代码如下所示。

class CliffWalkingEnv:
    def __init__(self,ncol,nrow,step_reward,cliff_reward):
        self.ncol=ncol
        self.nrow=nrow
        self. x=0#记录当前智能体位置的横坐标
        self.y=self.nrow-1#记录当前智能体位置的纵坐标
        self.step_reward=step_reward
        self.cliff_reward=cliff_reward
    def step(self,action):#外部调用此函数改变当前位置
        change=[[0,-1],[0,1],[-1,0],[1,0]]#定义四个动作,change[0]:上;change[1]:下;change[2]:左;change[3]:右;坐标系原点(0,0)定义在左上角
        self.x=np.clip(self.x+change[action][0],0,self.ncol-1)#也可采用self.x=min(self.ncol-1,max(0,self.x+change[action][0]))
        self.y=np.clip(self.y+change[action][1],0,self.nrow-1)#也可采用self.y=min(self.nrow-1,max(0,self.y+change[action][1]))
        next_state=self.ncol*self.y+self.x#计算下一个状态
        reward=self.step_reward
        done=False
        if self.y==self.nrow-1 and self.x>0:#如果当前位置在悬崖或者终点
            done=True
            if self.x!=self.ncol-1:#如果在悬崖
                reward=self.cliff_reward
        return next_state,reward,done
    def reset(self):#环境重置
        self.x=0
        self.y=self.nrow-1
        return self.y*self.ncol+self.x

3.实现Sarsa算法,维护一个Q_table()表格,本表格主要存储当前策略下所有状态动作对的价值,即所有状态下各个动作的动作价值函数,Sarsa算法与环境进行交互,采用\varepsilon-贪婪策略进行采样,使用时序差分算法进行Sarsa算法更新。本程序默认终止状态时所有动作的价值为0,即终止状态的动作价值函数均为0,其在初始化为0后不会在与环境交互过程中进行更新。具体代码如下所示。

class Sarsa:
    """ Sarsa算法 """
    def __init__(self,ncol,nrow,epsilon,alpha,gamma,n_action=4):
        self.ncol=ncol
        self.nrow=nrow
        self.epsilon=epsilon
        self.alpha=alpha
        self.gamma=gamma
        self.n_action=n_action
        self.Q_table=np.zeros([self.ncol*self.nrow,self.n_action])#初始化Q(s,a)表格
    def take_action(self,state):
        if np.random.random()<self.epsilon:#如果小于epsilon,则随机选择动作
            action=np.random.randint(self.n_action)
        else:
            action=np.argmax(self.Q_table[state])
        return action
    def best_action(self,state):
        Q_max=np.max(self.Q_table[state])
        a=[0]*self.n_action#or a=[0 for _ in range(self.n_action)]
        for i in range(self.n_action):#若两动作价值一样为最大,则会记录下来
            if self.Q_table[state,i]==Q_max:
                a[i]=1
        return a
    def update(self,s0,a0,r,s1,a1):#Sarsa算法核心部分,采用时序差分算法估计动作价值函数Q,Q(s0,a0)~Q(s0,a0)+alpha(r0+gamma*Q(s1,a1)-Q(s0,a0))
        TD_error=r+self.gamma*self.Q_table[s1,a1]-self.Q_table[s0,a0]
        self.Q_table[s0,a0]+=self.alpha*TD_error

4.本案例悬崖漫步环境中相关参数设置如下所示,可以自行修改。

ncol=12#悬崖漫步环境中的网格环境列数
nrow=4#悬崖漫步环境中的网格环境行数
step_reward=-1#每步的即时奖励
cliff_reward=-100#悬崖的即时奖励
epsilon=0.1#epsilon-贪婪算法的探索因子
alpha=0.1 #价值估计更新的步长
gamma=0.9 #回报计算的折扣衰减因子
n_action=4 #动作个数

5.主程序实现部分如下所示。

env=CliffWalkingEnv(ncol,nrow,step_reward,cliff_reward)
agent=Sarsa(ncol,nrow,epsilon,alpha,gamma,n_action)
num_episode=500#智能体在环境中运行的序列的数量
episode_Gt_list=[]#记录每个序列的回报
pbar_num=10#进度条的数量
for i in range(pbar_num):#显示每个进度条
    with tqdm(total=num_episode/pbar_num,desc='Episode %d'%i) as pbar:
        for episode in range(int(num_episode/pbar_num)):#每个进度条的序列数量
            number=1#记录每个序列的步数
            episode_Gt=0
            state=env.reset()
            action=agent.take_action(state)
            done=False
            while not done:
                next_state,reward,done=env.step(action)
                next_action=agent.take_action(next_state)
                episode_Gt+=reward #回报的计算不进行折扣因子衰减,考虑远期最优。
                agent.update(state,action,reward,next_state,next_action)
                state=next_state
                action=next_action
                number+=1
            episode_Gt_list.append(episode_Gt)
            if (episode+1)%10==0:#每10条序列打印一下这十条序列的平均回报
                pbar.set_postfix({'episode':'%d' %((num_episode/pbar_num)*i+episode+1),'return':'%.3f'%np.mean(episode_Gt_list[-10:],axis=None)})
            pbar.update(1)
    print('\n')
episodes_list=list(range(len(episode_Gt_list)))
plt.plot(episodes_list,episode_Gt_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Sarsa on {}'.format('Cliff Walking'))#or "plt.title('Sarsa on %s'%'Cliff Walking') "
plt.show()

6.结果如图所示。

Episode 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 276.35it/s, episode=50, return=-119.100]


Episode 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 345.86it/s, episode=100, return=-73.600]


Episode 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 421.77it/s, episode=150, return=-33.200]


Episode 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 586.43it/s, episode=200, return=-38.700]


Episode 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 668.64it/s, episode=250, return=-27.600]


Episode 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 787.75it/s, episode=300, return=-19.800]


Episode 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 891.04it/s, episode=350, return=-19.200] 


Episode 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 875.73it/s, episode=400, return=-21.100] 


Episode 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 844.25it/s, episode=450, return=-29.200] 


Episode 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 904.60it/s, episode=500, return=-17.800] 
 

 我们可以发现随着训练的进行,Sarsa算法获得的回报越来越高,在进行500条序列的学习后,可以获得-20左右的回报,此时非常接近最优策略了。

7.如下所示的程序是实现查看Sarsa算法得到的策略在各状态下会使智能体采取何种动作的功能。

def print_agent(agent,env,action_meaning,disaster=[],end=[]):
    for i in range(env.nrow):
        for j in range(env.ncol):
            if (i*env.ncol+j) in disaster:
                print('****',end=' ')
            elif (i*env.ncol+j) in end:
                print('EEEE',end=' ')
            else:
                a=agent.best_action(i*env.ncol+j)
                pi_str=''
                for k in range(len(action_meaning)):
                    pi_str+=action_meaning[k] if a[k]>0 else 'o'
                print(pi_str,end=' ')
        print()
action_meaning=['^','v','<','>']
print('Sarsa算法最终收敛得到的策略为:')
print_agent(agent,env,action_meaning,disaster=[range(37,47)],end=[47])

8.结果如下所示:

ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo ^ooo ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo ooo> ovoo
^ooo oo<o oo<o ^ooo ^ooo oo<o ^ooo ooo> oo<o ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE

可以发现Sarsa算法会采取比较远离悬崖的策略来抵达目标。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1962074.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python time模块格式化时间的N种技巧

文末赠免费精品编程资料~~ 是不是经常对着电脑屏幕上的日期时间发呆&#xff0c;心想&#xff1a;“要是能随心所欲地格式化这些数字就好了。”今天&#xff0c;我们就一起探索Python中的时间宝藏——time模块&#xff0c;让你轻松玩转时间显示&#xff0c;从新手进阶为时间格…

AI算力的新时代:智算中心的挑战与创新

随着AI的发展&#xff0c;作为AI三要素算法、数据、算力中的基础设施——算力首先迎来了高速的发展。智算中心作为AI时代承载算力的关键基础设施&#xff0c;在政策、市场的双重驱动下进入了高速建设周期&#xff0c;其在推动数字经济发展和技术进步方面发挥着重要作用&#xf…

【Gin】深度解析:在Gin框架中优化应用程序流程的责任链设计模式(下)

【Gin】深度解析&#xff1a;在Gin框架中优化应用程序流程的责任链设计模式(下) 大家好 我是寸铁&#x1f44a; 【Gin】深度解析&#xff1a;在Gin框架中优化应用程序流程的责任链设计模式(下)✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 前言 本次文章分为上下两部分&#xf…

数学建模--微分方程

目录 常见的微分方程模型 微分方程建模的基本步骤 代码示例 常微分方程 ​编辑 ​编辑 偏微分方程 ​编辑 应用实例 结论 如何在数学建模中准确识别和选择合适的微分方程模型&#xff1f; 微分方程模型在解决实际问题中的应用案例有哪些&#xff1f; 常微分方程&a…

SpringBoot整合FFmpeg进行视频分片上传

SpringBoot整合FFmpeg进行视频分片上传------>Windows 分片上传的核心思路&#xff1a; 将文件按一定的分割规则&#xff08;静态或动态设定&#xff0c;如手动设置20M为一个分片&#xff09;&#xff0c;用slice分割成多个数据块。为每个文件生成一个唯一标识Key&#xf…

ONNX模型的量化

我们都希望从代码中榨取更多的性能&#xff0c;对吧&#xff1f; 在现代&#xff0c;充斥着需要大量计算资源的复杂机器学习算法&#xff0c;因此&#xff0c;榨取每一点性能至关重要。 传统上&#xff0c;机器学习算法是在具有支持大量并行计算能力的 GPU 上进行训练的。但是…

WordPress建站:如何使用ChemiCloud搭建外贸独立站

以前自行搭建一个网站&#xff0c;不懂一点技术那是很难完成的&#xff0c;现如今WordPress的出现极大地降低了搭建网站的技术门槛&#xff0c;不需要懂任何代码&#xff0c;只需按步骤操作就行。WordPress 是一个非常流行的开源内容管理系统&#xff08;CMS&#xff09;&#…

职业教育计算机网络综合实验实训室建设应用案例

近年来&#xff0c;职业教育在培养技能型人才方面发挥着越来越重要的作用。然而&#xff0c;传统的计算机网络技术教学模式往往重理论、轻实践&#xff0c;导致学生缺乏实际操作能力和职业竞争力。为了改变这一现状&#xff0c;唯众结合职业教育特点&#xff0c;提出了“教、学…

Kubeflow v1.7.0 创建新用户

文章目录 为新用户创建配置文件配置用户密码重启auth生效 为新用户创建配置文件 apiVersion: kubeflow.org/v1beta1 kind: Profile metadata:name: kubeflow-cyw-example-com # replace with the name of profile you want, this will be users namespace name spec:owner:k…

STC单片机UART映射printf

文章目录 使用STC-ISP生成UART初始化函数 增加如下函数&#xff0c;注意使用printf函数需要添加 #include <stdio.h> 头文件 #include <stdio.h>void Uart1_Init(void) //9600bps12.000MHz {SCON 0x50; //8位数据,可变波特率AUXR | 0x01; //串口1选择定时器2为…

【Spring】——Spring概述、IOC、IOC创建对象的方式、Spring配置、依赖注入(DI)以及自动装配知识

&#x1f3bc;个人主页&#xff1a;【Y小夜】 &#x1f60e;作者简介&#xff1a;一位双非学校的大二学生&#xff0c;编程爱好者&#xff0c; 专注于基础和实战分享&#xff0c;欢迎私信咨询&#xff01; &#x1f386;入门专栏&#xff1a;&#x1f387;【MySQL&#xff0…

LeetCode 101.对称二叉树 C写法

LeetCode 101.对称二叉树 C写法 思路&#xff1a; 将该树一分为二&#xff0c;左子树的左边与右子树的右边比&#xff0c;左子树的右边与右子树的左边比&#xff0c;不相等或者一边为空则不是对称。 代码&#x1f50e;&#xff1a; bool _isSymmetric(struct TreeNode* Leftroo…

程序员开发指南

在这个快节奏的时代&#xff0c;作为一名程序员&#xff0c;大家都希望能更快地开发出高质量的应用&#xff0c;而不是花费大量时间在基础设施和后台服务的搭建上。今天&#xff0c;我要向大家介绍一款专为懒人开发者准备的一站式开发应用的神器——MemFire Cloud。 一站式开发…

使用代理访问内网:实验二

目录 环境搭建 内网搭建&#xff08;win2019&#xff09; 跳板机搭建&#xff08;win10&#xff09; 实验步骤 1. win10上线kali 2. 借助msf做代理 3. 在攻击机上做个代理&#xff0c;访问目标网站 4. 使用SocksCap64工具&#xff0c;进行sock4a隧道的连接 5. 启用soc…

TypeScript 的主要特点和重要作用

还是大剑师兰特&#xff1a;曾是美国某知名大学计算机专业研究生&#xff0c;现为航空航海领域高级前端工程师&#xff1b;CSDN知名博主&#xff0c;GIS领域优质创作者&#xff0c;深耕openlayers、leaflet、mapbox、cesium&#xff0c;canvas&#xff0c;webgl&#xff0c;ech…

最短路(dijkstra迪杰斯特拉)

最短路径问题在图论中是一个经典的问题&#xff0c;目的是找到从一个起始顶点到其他所有顶点的最短路径。Dijkstra算法是解决非负权图最短路径问题的常用算法。下面是一个使用Dijkstra算法解决最短路径问题的Java程序例子。 动画描述(从0节点开始更新) 问题描述 假设有一个图…

【机器学习西瓜书学习笔记——模型评估与选择】

机器学习西瓜书学习笔记【第二章】 第二章 模型评估与选择2.1训练误差和测试误差错误率误差 欠拟合和过拟合2.2评估方法留出法交叉验证法自助法 2.3性能度量查准率、查全率与F1查准率查全率F1 P-R曲线ROC与AUCROCAUC 代价敏感错误率与代价曲线代价曲线 2.4比较检验假设检验&…

VSCode+Vue3无法找到模块“../components/xxxxx.vue”的声明文件的错误

莫名奇妙的错误 今天用Vue3写个demo&#xff0c;在components下面新建了一个DeviceList.Vue的文件&#xff0c;在HomeView引用它后居然报错&#xff0c;提示&#xff1a;无法找到模块“…/components/DeviceList.vue”的声明文件&#xff0c;真是离了个大谱&#xff0c;文件明…

【Redis】 拓展:Redis - BigKey方案探讨

BigKey: 用户越多&#xff0c;redis数据越多&#xff0c;bigkey会使得缓存数据更大&#xff0c;网络带宽会被占用&#xff0c;执行效率就低下&#xff0c;高并发的时候吞吐量QPS也会下降。 产生原因&#xff1a; 看如下list&#xff1a; 一个key的内容太大&#xff0c;比如1M&…

VR舒适度术语表与检查表:为MR和空间计算应用创业者准备

随着混合现实&#xff08;MR&#xff09;和空间计算应用的发展&#xff0c;确保用户在虚拟环境中的舒适度变得尤为重要。本文将介绍一套专门针对VR舒适度的术语表&#xff0c;并提供两个知名VR游戏作为示例&#xff0c;来展示如何应用这些术语。这些术语和示例可以帮助开发者更…