强化学习的Sarsa与Q-Learning的Cliff-Walking对比实验

news2024/11/18 2:37:44

强化学习的Sarsa与Q-Learning的Cliff-Walking对比实验

  • Cliff-Walking问题的描述
  • Sarsa和Q-Learning算法对比
  • 代码分享
  • 需要改进的地方
  • 引用和写在最后

Cliff-Walking问题的描述

在这里插入图片描述

悬崖行走:从S走到G,其中灰色部分是悬崖不可到达,求可行方案
建模中,掉下悬崖的奖励是-100,G的奖励是10,原地不动的奖励-1,到达非终点位置的奖励是0(与图中的示意图不一致,不过大差不差),分别使用同轨策略的Sarsa与离轨策略的Q-learning算法,经过20000幕进化迭代得出safe path,optimal path,最后根据Q值来得出最终的策略,以此来对上图进行复现

Sarsa和Q-Learning算法对比

Sarsa算法
在这里插入图片描述
Q-Learning算法

在这里插入图片描述首先要介绍的是什么是ε-greedy,即ε-贪心算法,一般取定ε为一个较小的0-1之间的值(比如0.2)
在算法进行的时候,用计算机产生一个伪随机数,当随机数小于ε时采取任意等概率选择的原则,大于ε时则取最优的动作。

在介绍完两个算法和ε-贪心算法之后,一言概之就是,Sarsa对于当前状态s的a的选择是ε-贪心的,对于s’的a‘的选择也是ε-贪心的Q-Learning与sarsa一样,只是对于s’的a‘的选择是直接取最大的。

代码分享

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches  # 图形类

np.random.seed(2022)


class Agent():
    terminal_state = np.arange(36, 48)  # 终止状态

    def __init__(self, board_rows, board_cols, actions_num, epsilon=0.2, gamma=0.9, alpha=0.1):
        self.board_rows = board_rows
        self.board_cols = board_cols
        self.states_num = board_rows * board_cols
        self.actions_num = actions_num
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha = alpha
        self.board = self.create_board()
        self.rewards = self.create_rewards()
        self.qtable = self.create_qtable()

    def create_board(self):  # 创建面板
        board = np.zeros((self.board_rows, self.board_cols))
        board[3][11] = 1
        board[3][1:11] = -1
        return board

    def create_rewards(self):  # 创建奖励表
        rewards = np.zeros((self.board_rows, self.board_cols))
        rewards[3][11] = 10
        rewards[3][1:11] = -100
        return rewards

    def create_qtable(self):  # 创建Q值
        qtable = np.zeros((self.states_num, self.actions_num))
        return qtable

    def change_axis_to_state(self, axis):  # 将坐标转化为状态
        return axis[0] * self.board_cols + axis[1]

    def change_state_to_axis(self, state):  # 将状态转化为坐标
        return state // self.board_cols, state % self.board_cols

    def choose_action(self, state):  # 选择动作并返回下一个状态
        if np.random.uniform(0, 1) <= self.epsilon:
            action = np.random.choice(self.actions_num)
        else:
            p = self.qtable[state, :]
            action = np.random.choice(np.where(p == p.max())[0])

        r, c = self.change_state_to_axis(state)
        new_r = r
        new_c = c

        flag = 0

        #状态未改变
        if action == 0:  # 上
            new_r = max(r - 1, 0)
            if new_r == r:
                flag = 1
        elif action == 1:  # 下
            new_r = min(r + 1, self.board_rows - 1)
            if new_r == r:
                flag = 1
        elif action == 2:  # 左
            new_c = max(c - 1, 0)
            if new_c == c:
                flag = 1
        elif action == 3:  # 右
            new_c = min(c + 1, self.board_cols - 1)
            if new_c == c:
                flag = 1

        r = new_r
        c = new_c
        if flag:
            reward = -1 + self.rewards[r,c]
        else:
            reward = self.rewards[r, c]

        next_state = self.change_axis_to_state((r, c))
        return action, next_state, reward


    def learn(self, s, r, a, s_,sarsa_or_q):
        # s状态,a动作,r即时奖励,s_演化的下一个动作
        q_old = self.qtable[s, a]
        # row,col = self.change_state_to_axis(s_)
        done = False
        if s_ in self.terminal_state:
            q_new = r
            done = True
        else:
            if sarsa_or_q == 0:
                if np.random.uniform(0.1) <= self.epsilon:
                    s_a = np.random.choice(self.actions_num)
                    q_new = r + self.gamma * self.qtable[s_, s_a]
                else:
                    q_new = r + self.gamma * max(self.qtable[s_, :])
            else:
                q_new = r + self.gamma * max(self.qtable[s_, :])
                # print(q_new)
        self.qtable[s, a] += self.alpha * (q_new - q_old)
        return done


    def initilize(self):
        start_pos = (3, 0)  # 从左下角出发
        self.cur_state = self.change_axis_to_state(start_pos)  # 当前状态
        return self.cur_state


    def show(self,sarsa_or_q):
        fig_size = (12, 8)
        fig, ax0 = plt.subplots(1, 1, figsize=fig_size)
        a_shift = [(0, 0.3), (0, -.4),(-.3, 0),(0.4, 0)]
        ax0.axis('off')  # 把横坐标关闭
        # 画网格线
        for i in range(self.board_cols + 1):  # 按列画线
            if i == 0 or i == self.board_cols:
                ax0.plot([i, i], [0, self.board_rows], color='black')
            else:
                ax0.plot([i, i], [0, self.board_rows], alpha=0.7,
                     color='grey', linestyle='dashed')

        for i in range(self.board_rows + 1):  # 按行画线
            if i == 0 or i == self.board_rows:
                ax0.plot([0, self.board_cols], [i, i], color='black')
            else:
                ax0.plot([0, self.board_cols], [i, i], alpha=0.7,
                         color='grey', linestyle='dashed')

        for i in range(self.board_rows):
            for j in range(self.board_cols):

                y = (self.board_rows - 1 - i)
                x = j

                if self.board[i, j] == -1:
                    rect = patches.Rectangle(
                        (x, y), 1, 1, edgecolor='none', facecolor='black', alpha=0.6)
                    ax0.add_patch(rect)
                elif self.board[i, j] == 1:
                    rect = patches.Rectangle(
                        (x, y), 1, 1, edgecolor='none', facecolor='red', alpha=0.6)
                    ax0.add_patch(rect)
                    ax0.text(x + 0.4, y + 0.5, "r = +10")

                else:
                    # qtable
                    s = self.change_axis_to_state((i, j))
                    qs = agent.qtable[s, :]
                    for a in range(len(qs)):
                        dx, dy = a_shift[a]
                        c = 'k'
                        q = qs[a]
                        if q > 0:
                            c = 'r'
                        elif q < 0:
                            c = 'g'
                        ax0.text(x + dx + 0.3, y + dy + 0.5,
                                 "{:.1f}".format(qs[a]), c=c)

        if sarsa_or_q == 0:
            ax0.set_title("Sarsa")
        else:
            ax0.set_title("Q-learning")
        if sarsa_or_q == 0:
            plt.savefig("Sarsa")
        else:
            plt.savefig("Q-Learning")
        plt.show(block=False)
        plt.pause(5)
        plt.close()

加上下面这一段,就可以使程序跑起来啦!

agent = Agent(4, 12, 4)
maxgen = 20000
gen = 1
sarsa_or_q = 0
while gen < maxgen:
    current_state = agent.initilize()
    while True:
        action, next_state, reward = agent.choose_action(current_state)
        done = agent.learn(current_state, reward, action, next_state,sarsa_or_q)
        current_state = next_state
        if done:
            break

    gen += 1

agent.show(sarsa_or_q)
print(agent.qtable)

设置sarsa_or_q分别为0和1可以查看采用不同方法计算得的结果示意图
根据Q值就可以得到最后的收敛策略
在这里插入图片描述
在这里插入图片描述

需要改进的地方

代码迭代的收敛太慢,笔者写的代码迭代了20000才收敛,这与课程中的100幕左右就收敛的结果是不一致的,算法的效率上还需要改进。值得补充的是,100幕左右收敛在迭代最大代数中并没有做到,所以在模拟仿真的时候,索性就选择了20000次,说不定提前就收敛了。
可以改进的地方:对模型进行建立,因为之前代码是无模型的,设立模型对策略进行引导会得到更好的结果,当然也有可能使问题陷入局部探索之中,这是继续深入学习需要讨论的。
与科研科研结合的地方:在研究方向上,如果要结合的话,需要学习多个个体在环境下同时学习时的处理方法
在这里插入图片描述

引用和写在最后

Cliff-Walking仿真的是Reinforcement Learning Course by David Silver中第五讲课中的例子
课程的地址给在这里
记录一下强化学习课程的学习暂时完结,完结撒花,哒哒!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/136182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(Java)【深基9.例4】求第 k 小的数

【深基9.例4】求第 k 小的数 一、题目描述 输入 nnn&#xff08;1≤n<50000001 \le n < 50000001≤n<5000000 且 nnn 为奇数&#xff09;个数字 aia_iai​&#xff08;1≤ai<1091 \le a_i < {10}^91≤ai​<109&#xff09;&#xff0c;输出这些数字的第 kk…

元旦礼第三弹!玻色量子荣登2022年中国创新力量50榜单

​2022年12月&#xff0c;国内最大的创新者社区极客公园重磅发布了全新的「中国创新力量 50 榜单&#xff08;InnoForce 50&#xff09;」——在过去一年为泛计算机科学领域及其交叉领域带来创新和突破的中国公司/机构。玻色量子凭借在光量子计算领域突出的核心竞争力&#xff…

配电网前推后带法求电力系统潮流(PythonMatlab实现)

目录 1 概述 2 数学模型 3 节点分层前推回代潮流计算及步骤 3.1 计算方法 3.2 计算步骤 4 算例及数据 5 Matlab&Python代码实现 1 概述 配电网通常是单电源全网连接、开环运行&#xff0c;即呈树状。针对配电系统分析&#xff0c;其根本就是进行潮流计算。潮流计算的…

通过反射机制访问java对象的属性 给属性赋值 读取属性的值

package com.javase.reflect;import java.lang.reflect.Field;/*** 通过反射机制&#xff0c;访问java对象的属性&#xff0c;给属性赋值&#xff0c;读取属性的值&#xff08;重点&#xff1a;五颗星*****&#xff09;* 本例中使用反射机制编写代码&#xff0c;看起来比不使用…

Hi3861鸿蒙物联网项目实战:智能温度计

华清远见FS-Hi3861开发套件&#xff0c;支持HarmonyOS 3.0系统。开发板主控Hi3861芯片内置WiFi功能&#xff0c;开发板板载资源丰富&#xff0c;包括传感器、执行器、NFC、显示屏等&#xff0c;同时还配套丰富的拓展模块。开发板配套丰富的学习资料&#xff0c;包括全套开发教程…

art-template模板引擎

1、模板引擎的基本概念 1.1、渲染UI结构时遇到的问题 var rows [] $.each(res.data, function (i, item) { // 循环拼接字符串 rows.push(<li class"list-group-item"> item.content <span class"badge cmt-date">评论时间&#xff1a; item…

C++ 使用Socket实现主机间的UDP/TCP通信

前言 完整代码放到github上了&#xff1a;cppSocketDemo 服务器端的代码做了跨平台&#xff08;POSIX和WINDOWS&#xff09;&#xff0c;基于POSIX平台&#xff08;Linux、Mac OS X、PlayStation等&#xff09;使用sys/socket.h库&#xff0c;windows平台使用winsock2.h库。 客…

STM32配置LED模块化

文章目录前言一、LED的模块化二、GPIO初始化详细解析三、LED代码封装总结前言 本篇文章将带大家深入了解GPIO的配置&#xff0c;并带大家实现LED模块化编程。 一、LED的模块化 什么叫模块化编程&#xff1f;我的理解就是每一个模块都分别写成对应的.c和.h文件&#xff0c;有…

S32K144—从0到1一个MBD模型的诞生

一个MBD模型的诞生&#xff0c;分为以下几步&#xff1a; 1、连接好硬件S32K144 EVB 2、选择一个合适的工作空间&#xff0c;新建一个simulink模型&#xff0c;保存 3、在模型中拖入模块&#xff1a; MBD_S32K1xx_Config_Information Digital_Input_ISR Periodic_Interrupt…

C++ 设计模式

设计模式序创建型模式工厂方法模式抽象工厂模式单例模式建造者模式&#xff08;生成器模式&#xff09;原型模式结构型模式适配器模式装饰器代理模式外观模式桥接模式组合模式&#xff08;部分--整体模式&#xff09;享元模式行为型模式策略模式模板模式观察者模式迭代器模式责…

对抗js前端加密的万能方法

1、前言 现在越来越多的网站采用全报文加密&#xff0c;测试的时候需要逆向提取加密算法以及密钥&#xff0c;过程十分繁琐和复杂。本文提供一种更为简单快捷的方法来解决此问题。 原理大致如下&#xff1a;使用浏览器的Override Hook加密前的数据&#xff0c;配置代理地址发…

[Linux]Linux编译器-gcc/g++

&#x1f941;作者&#xff1a; 华丞臧. &#x1f4d5;​​​​专栏&#xff1a;【LINUX】 各位读者老爷如果觉得博主写的不错&#xff0c;请诸位多多支持(点赞收藏关注)。如果有错误的地方&#xff0c;欢迎在评论区指出。 推荐一款刷题网站 &#x1f449; LeetCode刷题网站 文…

SpringBoot+Redis(官方案例)

在线文档项目结构 1.源码克隆&#xff1a;git clone https://github.com/spring-guides/gs-messaging-redis.git 2.包含两个项目initial和complete&#xff0c;initial可以根据文档练习完善&#xff0c;complete是完整项目 3.功能描述&#xff1a;构建应用程序&#xff0c;使用…

【谷粒商城基础篇】商品服务:商品维护

谷粒商城笔记合集 分布式基础篇分布式高级篇高可用集群篇简介&环境搭建项目简介与分布式概念&#xff08;第一、二章&#xff09;基础环境搭建&#xff08;第三章&#xff09;整合SpringCloud整合SpringCloud、SpringCloud alibaba&#xff08;第四、五章&#xff09;前端知…

xxx.lua入门编程

lua入门级编程,openresty的前置技能lua入门级编程,openresty的前置技能 看上图 lua示例&#xff1a; 入门示例 print("hello world!") local arr {"java","mysql","oracle"}; local map {usernamezhangsan,password123}; local fu…

Debezium 同步 PostgreSQL 数据到 RocketMQ 中

1.RocketMQ Connect概览 RocketMQ Connect是RocketMQ数据集成重要组件&#xff0c;可将各种系统中的数据通过高效&#xff0c;可靠&#xff0c;流的方式&#xff0c;流入流出到RocketMQ&#xff0c;它是独立于RocketMQ的一个单独的分布式&#xff0c;可扩展&#xff0c;可容错系…

字节二面:Redis 的大 Key 对持久化有什么影响?

Redis 的持久化方式有两种&#xff1a;AOF 日志和 RDB 快照。 所以接下来&#xff0c;针对这两种持久化方式具体分析分析。 大 Key 对 AOF 日志的影响 先说说 AOF 日志三种写回磁盘的策略 Redis 提供了 3 种 AOF 日志写回硬盘的策略&#xff0c;分别是&#xff1a; Always&am…

Git(四) - Git 分支操作

​​​​​​​ 一、什么是分支 在版本控制过程中&#xff0c;同时推进多个任务&#xff0c;为每个任务&#xff0c;我们就可以创建每个任务的单独分支。使用分支意味着程序员可以把自己的工作从开发主线上分离开来&#xff0c;开发自己分支的时候&#xff0c;不会影响主线分支…

前端面试常考 | js原型与原型链

文章目录一. 什么是原型?二. 什么是原型链?一. 什么是原型? 在js中所有的引用类型都有一个__proto__(隐式原型)属性&#xff0c;属性值是一个普通的对象。 而在js中的引用类型包括&#xff1a;Object&#xff0c;Array&#xff0c;Date&#xff0c;Function 而所有函数都有…

基于K8s的DevOps平台实践(二)

文章目录1. 流水线入门&#x1f351; 流水线基础语法&#x1f351; 脚本示例&#x1f351; 脚本解释&#x1f351; Blue Ocean2. Jenkinsfile实践&#x1f351; 演示一&#x1f351; 演示二&#x1f351; 演示三&#x1f351; 演示四&#x1f351; 总结3. 多分支流水线实践&…