强化学习实战3:Sarsa 与 Q-Learning 算法求解迷宫问题

news2024/9/25 9:40:40

前置知识

首先实验环境依然是我们之前说的迷宫环境,然后是一些基本术语,应该都是比较熟悉的:

在这里插入图片描述
在这里插入图片描述

强化学习的算法大概有两类,一类是策略迭代(讲究的是策略 Π ),还有一类是价值迭代,也就是本节要说的内容。

在价值迭代算法的类型中有两个非常重要的算法,即 Sarsa 和 Q-learning。

本节介绍 Sarsa。

Sarsa 名称的由来:取的是一个小 trajectory 的首字母缩写:state action reward state action。

同时这也很形象的展示了 Sarsa 算法的原理,就是根据不断的重复 state action reward state action 这一系列行为从而进行迭代。

Sarsa 维护了一个 Q-table,而 Q-table 一开始是我们进行初始化的,然后通过价值迭代的方式我们跟环境不断的交互然后不断的更新和学习这个 Q-table,也就是说 Q-table 实际上是一个待学习的参数。因此显然对于 Q-table 而言,其行索引表示状态,列索引则表示行为。最后要注意的一点是 Q-table 的值不是概率分布,就是单纯的值(应该就是 reward 值),也就是最大的 value 显然就是我们最应该采取的 action。

基于 Q-table 我们可以评估 action value 和 state value 。

在使用 Sarsa 算法去更新 Q-table 的时候,我们基于的是贝尔曼方程。

而 Q-learning 和 Sarsa 仅仅只是公式上有一个区别,在下一节中将会进行介绍。

Sarsa 算法各部分实现

Q-table 的实现

# 刻画环境:边界 border 和 障碍 barrier
theta_0 = np.array([
    [np.nan, 1, 1, np.nan],  # 表示S0时的策略,即agent不能往上、不能往左走,但可以往右和下走
    [np.nan, 1, np.nan, 1],
    [np.nan, np.nan, 1, 1],
    [1, np.nan, np.nan, np.nan],
    [np.nan, 1, 1, np.nan],
    [1, np.nan, np.nan, 1],
    [np.nan, 1, np.nan, np.nan],
    [1, 1, np.nan, 1],
    # S8 已经是终点了,因此不再需要上下左右到处走了
])

# ---------------------------Q-table---------------------------------
n_states, n_actions = theta_0.shape

# Q-table,状态是离散的(S0 到 S7),动作也是离散的(上下左右)
# 下面这是元素级别的乘法,也就是对位元素相乘
Q = np.random.rand(n_states, n_actions) * theta_0
print(Q)

输出结果如下:

在这里插入图片描述

这个部分就是单纯的随机初始化一下我们的 Q-table,以便于后面进行迭代更新。

ε-greedy 的实现

# -------------------------ε-greedy--------------------------------
# 对于 ε-greedy,其是一个 探索 和 利用 策略
# 将 theta_0 转换为 策略 Π,而 Π 其实就是概率值嘛
def cvt_theta_0_to_pi(theta):
    m, n = theta.shape
    pi = np.zeros((m, n))
    for r in range(m):
        pi[r, :] = theta[r, :] / np.nansum(theta[r, :])
    return np.nan_to_num(pi)


pi_0 = cvt_theta_0_to_pi(theta_0)


# epsilon-ε
# ε-greedy 这样一个策略是用来选取 action 的
# s 表示当前状态,Q 表示 Q-table, eps 是一个超参数, pi_0 是策略
def get_action(s, Q, eps, pi_0):
    # 动作空间是 0 1 2 3
    action_space = list(range(4))
    # eps, explore
    if np.random.rand() < eps:
        action = np.random.choice(action_space, p=pi_0[s, :])
    else:
        # 1-eps, exploit
        action = np.nanargmax(Q[s, :])
    return action

ε-greedy 策略的作用是用来帮助选取 action 的,对于 get_action 函数其具体的解释如下:

在这里插入图片描述
在这里插入图片描述

Sarsa 算法的实现

接下来介绍 Sarsa 算法,也就是进行 QΠ(s,a) 的算法。

在这里插入图片描述

# --------------------------- Sarsa ---------------------------
# gamma 就是 折扣率 参数, eta 是我们需要给定的超参数
def sarsa(s, a, r, s_next, a_next, Q, eta, gamma):
    if s_next == 8:
        Q[s, a] = Q[s, a] + eta * (r - Q[s, a])
    else:
        Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])

解决迷宫问题

代码封装

封装上述代码如下:

# --------------------------- 环境创建 --------------------------
# MazeEnv 类维护着状态,以及 step 函数的返回
class MazeEnv(gym.Env):
    def __init__(self):
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        if action == 0:
            self.state -= 3
        elif action == 1:
            self.state += 1
        elif action == 2:
            self.state += 3
        elif action == 3:
            self.state -= 1
        done = False
        reward = 0
        if self.state == 8:
            done = True
            reward = 1
        return self.state, reward, done, {}


# Agent 类基于当前环境中的状态选择动作形成策略
class Agent:
    def __init__(self):
        # action space
        self.actions = list(range(4))
        # 刻画环境:边界 border 和 障碍 barrier
        self.theta_0 = np.array([
            [np.nan, 1, 1, np.nan],  # 表示S0时的策略,即agent不能往上、不能往左走,但可以往右和下走
            [np.nan, 1, np.nan, 1],
            [np.nan, np.nan, 1, 1],
            [1, np.nan, np.nan, np.nan],
            [np.nan, 1, 1, np.nan],
            [1, np.nan, np.nan, 1],
            [np.nan, 1, np.nan, np.nan],
            [1, 1, np.nan, 1],
            # S8 已经是终点了,因此不再需要上下左右到处走了
        ])
        # 策略 Π
        self.pi = self._cvt_theta_0_to_pi()
        # Q-table
        self.Q = np.random.rand(*self.theta_0.shape) * self.theta_0
        # 超参数
        self.eta = 0.1
        # 折扣率
        self.gamma = 0.9
        # ε-greedy 策略的超参数
        self.eps = 0.5

    # 将 theta_0 转换为 策略 Π,而 Π 其实就是概率值嘛
    def _cvt_theta_0_to_pi(self):
        m, n = self.theta_0.shape
        pi = np.zeros((m, n))
        for r in range(m):
            pi[r, :] = self.theta_0[r, :] / np.nansum(self.theta_0[r, :])
        return np.nan_to_num(pi)

    def get_action(self, s):
        # eps, explore 探索
        if np.random.rand() < self.eps:
            action = np.random.choice(self.actions, p=self.pi[s, :])
        else:
            # 1-eps, exploit 利用
            action = np.nanargmax(self.Q[s, :])
        return action

    def sarsa(self, s, a, r, s_next, a_next):
        if s_next == 8:
            self.Q[s, a] = self.Q[s, a] + self.eta * (r - self.Q[s, a])
        else:
            self.Q[s, a] = self.Q[s, a] + self.eta * (r + self.gamma * self.Q[s_next, a_next] - self.Q[s, a])

训练

训练代码如下:

# --------------------------------- 训练 ---------------------------------------
maze = MazeEnv()
agent = Agent()
episode = 0
while True:
    # 下面这行代码会创建一个新的一维数组old_Q,其长度与agent.Q的行数(即状态的数量)相同。
    # old_Q中的每个元素都是对应状态下行(动作)中的最大Q值(忽略NaN)。
    """
    np.nanmax会返回数组中所有非NaN元素的最大值
    axis=1:这个参数指定了np.nanmax函数应该沿着哪个轴来计算最大值。
    在NumPy中,二维数组的轴(axis)是一个维度。
    axis=0表示沿着列(垂直方向)计算,而axis=1表示沿着行(水平方向)计算。
    因此,np.nanmax(agent.Q, axis=1)的意思是对于agent.Q中的每一行(即每个状态对应的所有动作),忽略NaN值,找到最大的Q值。
    """
    old_Q = np.nanmax(agent.Q, axis=1)
    # 每一次 episode 开始时都重置状态为初始状态
    s = maze.reset()
    # 通过 ε-greedy 策略选取一个 action
    a = agent.get_action(s)
    # 记录历史 state-action 对儿
    s_a_history = [[s, np.nan]]
    # 循环跑出每一 episode 的 trajectory
    while True:
        # 将列表末尾第一个 state-action 对儿中的 action 更新一下
        s_a_history[-1][1] = a
        s_next, reward, done, _ = maze.step(a)
        s_a_history.append([s_next, np.nan])
        if done:
            a_next = np.nan
        else:
            a_next = agent.get_action(s_next)
        agent.sarsa(s, a, reward, s_next, a_next)
        if done:
            break;
        else:
            a = a_next
            s = maze.state
    # 这行代码计算了智能体Q表中每个状态下新旧最大Q值之间的绝对差异的总和,并将这个总和赋值给变量update
    """
    np.abs(np.nanmax(agent.Q, axis=1) - old_Q):
    这部分代码首先计算np.nanmax(agent.Q, axis=1)(即每个状态下的新最大Q值)与old_Q(即每个状态下的旧最大Q值)之间的差,
    然后使用np.abs函数取这些差的绝对值。
    这样做的目的是消除正负差异的方向性,只关注差异的大小。
    结果是一个一维数组,其元素表示每个状态下新旧最大Q值之间的绝对差异。
    最后,np.sum函数被用来计算上一步得到的差异数组中所有元素的总和。
    这个总和可以被视为一种“更新量”或“变化量”,它量化了从old_Q到当前agent.Q中每个状态下的最大Q值所发生的总体变化。
    """
    update = np.sum(np.abs(np.nanmax(agent.Q, axis=1) - old_Q))
    episode += 1
    agent.eps /= 2
    print(episode, update, len(s_a_history))
    if episode > 100 or update < 1e-5:
        break

# 最终 Q 表的样子
print("-------------------------- Q-table ---------------------------------")
print(agent.Q)

运行效果如下:

在这里插入图片描述
中间省略…
在这里插入图片描述

从输出结果可以看出,前面七轮是在不断震荡的,从第七轮以后就稳定了,agent 只需要七步就可以走到终点。

另外从最后的 Q-table 也可以看出来,agent 在每一个状态都可以以较大概率选取最优的 action,因此只需要七步就可以走到终点。

可视化展现

都是之前介绍过的代码,直接贴出来了:

# 可视化展现
# 创建一个新的图形对象,并设置其大小为 5x5 英寸
fig = plt.figure(figsize=(5, 5))

# 获取当前图形对象的轴对象
ax = plt.gca()

# 设置坐标轴的范围
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

# 绘制红色的方格边界,表示迷宫的结构
plt.plot([2, 3], [1, 1], color='red', linewidth=2)
plt.plot([0, 1], [1, 1], color='red', linewidth=2)
plt.plot([1, 1], [1, 2], color='red', linewidth=2)
plt.plot([1, 2], [2, 2], color='red', linewidth=2)

# 在指定位置添加文字标签,表示每个状态(S0-S8)、起点和终点
plt.text(0.5, 2.5, 'S0', size=14, ha='center')
plt.text(1.5, 2.5, 'S1', size=14, ha='center')
plt.text(2.5, 2.5, 'S2', size=14, ha='center')
plt.text(0.5, 1.5, 'S3', size=14, ha='center')
plt.text(1.5, 1.5, 'S4', size=14, ha='center')
plt.text(2.5, 1.5, 'S5', size=14, ha='center')
plt.text(0.5, 0.5, 'S6', size=14, ha='center')
plt.text(1.5, 0.5, 'S7', size=14, ha='center')
plt.text(2.5, 0.5, 'S8', size=14, ha='center')
plt.text(0.5, 2.3, 'Start', ha='center')
plt.text(2.5, 0.3, 'Goal', ha='center')

# 设置坐标轴的显示参数,使得坐标轴不显示
plt.tick_params(axis='both', which='both',
                bottom=False, top=False,
                right=False, left=False,
                labelbottom=False, labelleft=False)

# 在起点位置绘制一个绿色的圆形表示当前位置
line, = ax.plot([0.5], [2.5], marker='o', color='g', markersize=60)


def init():
    line.set_data([], [])
    return (line,)


def animate(i):
    state = s_a_history[i][0]
    x = (state % 3) + 0.5
    y = 2.5 - int(state / 3)
    line.set_data(x, y)


anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(s_a_history), interval=200, repeat=False)
anim.save('maze_0.mp4')
# 视频观测有时候不太友好,我们还可以使用 IPython 提供的 HTML 的交互式工具
# 由于 PyCharm 不支持显示 IPython 的交互式输出,因此我们这里将 IPython 的输出转换为 HTML 文件再打开
with open('animation.html', 'w') as f:
    f.write(anim.to_jshtml())

可视化的结果是动态的,这里就不展示了,就是 agent 可以很直接快速的找到最终状态。

Q-Learning 算法

实际上,只需要修改 Sarsa 算法代码中的一行,Sarsa 就变成了 Q-Learning 了。

因此这里一起把 Q-Learning 实现了。

重点关注二者区别,主要是算法思想上的不同,二者公式分别如下:

在这里插入图片描述

从公式上可以知道,Sarsa 是策略依赖型的(on-policy),而 Q-Learning 是策略关闭型的(off-policy)。

核心代码如下:

def q_learning(self, s, a, r, s_next):
	if s_next == 8:
		self.Q[s, a] = self.Q[s, a] + self.eta * (r - self.Q[s, a])
	else:
		self.Q[s, a] = self.Q[s, a] + self.eta * (r + self.gamma * np.nanmax(self.Q[s_next, :]) - self.Q[s, a])

然后训练代码也只需要改一行:

# agent.sarsa(s, a, reward, s_next, a_next)
agent.q_learning(s, a, reward, s_next)

其余同上面的 Sarsa 算法一样,效果也是差不多的,这里不再赘述。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1920573.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电脑 DNS 缓存是什么?如何清除?

DNS&#xff08;Domain Name System&#xff0c;域名系统&#xff09;是互联网的重要组成部分&#xff0c;负责将人类易记的域名转换为机器可读的 IP 地址&#xff0c;从而实现网络通信。DNS 缓存是 DNS 系统中的一个关键机制&#xff0c;通过临时存储已解析的域名信息&#xf…

lnmp+DISCUZ+WORDPRESS

lnmpDISCUZWORDPRESS lnmpDISCUZ&#xff08;论坛的一个服务&#xff09; l&#xff1a;linux操作系统 n&#xff1a;nginx前端页面的web服务 php&#xff1a;动态请求转发的中间件 mysql&#xff1a;数据库 保存用户和密码以及论坛的相关内容 mysql8.0.30安装&#xff1a…

微信综合购物商城小程序ui模板源码

微信电商小程序前端页面&#xff0c;综合购物商城ui界面模板。主要功能包含&#xff1a;电商主页、商品分类、购物车、购物车结算、我的个人中心管理、礼券、签到、新人专享、专栏、商品详情页、我的订单、我的余额、我的积分、我的收藏、我的地址、我的礼券等。这是一款非常齐…

单相整流-TI视频课笔记

目录 1、单相半波整流 1.1、单相半波----电容滤波---超轻负载 1.2、单相半波----电容滤波---轻负载 1.3、单相半波----电容滤波---重负载 2、全波整流 2.1、全波整流的仿真 2.2、半波与全波滤波的对比 3、全桥整流电路 3.1、全波和全桥整流对比 3.2、半波全波和全桥…

【Linux杂货铺】2.进程优先级

1.进程优先级基本概念 进程优先级是操作系统中用于确定进程调度顺序的一个指标。每个进程都会被分配一个优先级&#xff0c;优先级较高的进程会在调度时优先被执行。进程优先级的设定通常根据进程的重要性、紧急程度、资源需求等因素来确定。操作系统会根据进程的优先级来决定进…

出现 failed to remove xxxx: Invalid argument 解决方法

目录 前言1. 问题所示2. 原理分析3. 解决方法 前言 这好像是一个Git的一个Bug&#xff0c;对应有个下下策的解决方式 1. 问题所示 Git提交的时候出现如下问题 Git warning:failed to remove debug.log:invalid argumentgit clean -f -1 --F&#xff1a;\xxx failed to rem…

准备工作+1、请求和响应+2、模型和管理站点

Django快速入门——创建一个基本的投票应用程序 准备工作1、创建虚拟环境2、安装django 1、请求和响应&#xff08;1&#xff09;创建项目&#xff08;2&#xff09;用于开发的简易服务器&#xff08;3&#xff09;创建投票应用&#xff08;4&#xff09;编写第一个视图1、编写…

Python(四)---序列

文章目录 前言1.列表1.1.列表简介1.2.列表的创建1.2.1.基本方式[]1.2.2.list()方法1.2.3.range()创建整数列表1.2.4.推导式生成列表 1.3. 列表各种函数的使用1.3.1.增加元素1.3.2.删除元素1.3.3.元素的访问和计数1.3.4.切片1.3.5.列表的排序 1.4.二维列表 2.元组2.1.元组的简介…

mybatis基础语法

Mybatis快速入门 1.需求 使用MyBatis查询所有的用户, 封装到List集合 2.分析 创建maven工程&#xff08;jar&#xff09;&#xff0c;添加坐标创建pojo创建UserDao接口创建UserDao映射文件创建Mybatis核心配置文件SqlMapConfig.xml编写java代码测试 3.实现 准备工作&…

《梦醒蝶飞:释放Excel函数与公式的力量》11.3 ISTEXT函数

第11章&#xff1a;信息函数 第三节 11.3 ISTEXT函数 11.3.1 简介 ISTEXT函数是Excel中的一个信息函数&#xff0c;用于检查指定单元格中的内容是否为文本。如果单元格内容是文本&#xff0c;则返回TRUE&#xff1b;否则返回FALSE。ISTEXT函数在数据验证、条件格式化和逻辑判…

【排序算法】插入排序(希尔排序)

目录 一.直接插入排序 1.基本思想 2.实现 3.特性 1.效率 2.时间复杂度&#xff1a;O(N^2) 3.空间复杂度&#xff1a;O(1) 4.稳定性&#xff1a;稳定 二.希尔排序 1.基本思想 2.实现 3.特性 1.效率 2.时间复杂度&#xff1a;O(N^1.3) ​编辑 3.空间复杂度&#xff…

AI在软件开发中的角色:辅助创新还是自动化取代?

文章目录 每日一句正能量前言&#xff1a;人工智能与软件开发的未来交汇点AI工具现状AI对开发者的影响工作方式的改变需要掌握的新技能保持竞争力的策略结论 AI开发的未来AI在软件开发领域的未来发展方向AI是否可能完全取代开发者如何在AI时代规划开发者的职业发展结论 后记&am…

【通过pnpm创建vite项目】

vue3最新项目技术构建后台管理系统 一、技术要求二、安装pnpm2.1 构建vite三、项目配置3.1 eslint 配置3.2 prettier配置3.3 stylelint配置3.4 配置husky3.5 配置commitlint3.6 pnpm 强制安装四、Element-plus 引入4.1 完整引入4.2 国际化配置4.3 配置别名4.4 Env环境配置4.5 s…

教育与社会的发展

生产力与教育的关系 政治经济制度与教育的关系 文化和人口与教育的关系

《梦醒蝶飞:释放Excel函数与公式的力量》11.4 ISERROR函数

第11章&#xff1a;信息函数 第四节 11.4 ISERROR函数 11.4.1 简介 ISERROR函数是Excel中的一个信息函数&#xff0c;用于检查指定单元格或表达式是否产生错误。如果单元格或表达式产生任何类型的错误&#xff08;如N/A、VALUE!、REF!等&#xff09;&#xff0c;则返回TRUE&…

子任务:IT运维的精细化管理之道

在当今的企业运营中&#xff0c;信息技术已成为支撑业务发展的核心力量。根据Gartner的报告&#xff0c;IT服务管理&#xff08;ITSM&#xff09;的有效实施可以显著提升企业的运营效率&#xff0c;降低成本高达15%&#xff0c;同时提高服务交付速度和质量。随着业务的复杂性和…

Python中对asyncio的实际使用

前言&#xff1a;一般涉及异步编程我都无脑用celery&#xff0c;但是最近在做一个项目&#xff0c;项目不大&#xff0c;也不涉及定时任务&#xff0c;所以就用了asyncio。 asyncio是python自带的模块&#xff0c;比celery轻量&#xff0c;使用起来也简单。以前学习过&#xf…

java中Error与Exception的区别

java中Error与Exception的区别 1、错误&#xff08;Error&#xff09;1.1 示例 2、 异常&#xff08;Exception&#xff09;2.1 示例 3、 区别总结 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 当我们谈论编程中的错误&#xff08;Error&…

【LeetCode】917:翻转字符串

方法&#xff1a;双指针 class Solution { public:bool isletter(char ch){if(ch>a&&ch<z)return true;if(ch>A&&ch<Z)return true;return false;}string reverseOnlyLetters(string s) {int lens.size();int left0,rightlen-1;string s1;while(le…

60、基于浅层神经网络的数据拟合(matlab)

1、基于浅层神经网络的数据拟合的简介、原理以及matlab实现 1&#xff09;内容说明 基于浅层神经网络的数据拟合是一种常见的机器学习方法&#xff0c;用于通过输入数据来拟合一个非线性函数。这种方法通常包括一个输入层、一个或多个隐藏层和一个输出层。神经网络通过学习权…