强化学习Q-learning实践

news2025/1/31 20:51:29

1. 引言

前篇文章介绍了强化学习系统红的基本概念和重要组成部分,并解释了Q-learning算法相关的理论知识。本文的目标是在Python3中实现该算法,并将其应用于实际的实验中。
闲话少说,我们直接开始吧!

2. Taxi-v3 Env

为了使本文具有实际具体的意义,特意选择了一个简单而基本的环境,可以让大家充分欣赏Q-learning算法的优雅。我们选择的环境是OpenAI GymTaxi-v3,该环境简单明了,是强化学习RL领域的优秀入门样例。实际上Taxi-v3由一个grid map组成,如下图示:
在这里插入图片描述

其中,该环境下的agent是一名出租车司机,他必须接客户(红色小人)并将其送到目的地(图中的小房子)。

3. States

一版来说,States的作用如下 (1) 确定action(2)计算执行action的奖励reward(3)计算到下一状态的转换所需的信息。

观察上图,我们的网格grid map的大小为5x5,所以出租车所有可能的选择有25个。除此之外,等待接车的乘客可以在四个可能的接车点(标记为Y、R、G、B)处等待当然也可以在出租车里,所以乘客所有可能的选择有(4+1)个;最后,乘客的目的地在(Y、R、G、B)四个中的一个,所以乘客的目的地共有4个选择,图示如下:
在这里插入图片描述

综上,我们用以下向量表示States

State = [x_pos_taxi, y_pos_taxi, pos_passenger, dest_passenger]

进而,我们agentStates一共有5X5X5X4=500个,可以被编码为0到499之间的整数。其实,实际可用的状态的数量略小于500,例如,乘客将永远不会有相同的乘车点和目的地。由于建模的复杂性,我们通常关注完整的状态空间。

4. 举个栗子

上述文字讲完后,有些同学还是有很多不理解的东东,那我们来找个中间过程来看看,如下:
在这里插入图片描述
上图中,STATE:(2,1,0,1)表示,当前出租车在grid map中的第二行第一列,同时乘客的状态选择为0表示位于乘客位于红色格子里等待乘车,同时乘客的目的地状态选择为1表示乘客的目的地为绿色格子。
进而,下图中的STATE:(3,4,4,0)表示,当前出租车在grid map中的第三行第四列,同时乘客的状态选择为4表示此时乘客位于出租车里,同时乘客的目的地状态选择为0 ,表示乘客的目的地为红色格子。

看到这里的童鞋,请仔细理解上述两个例子。

5. Action space

至于该环境Envagent的动作空间Action space,我们可以想象,代理agent可以使用以下离散动作来与环境交互:向前、向后、向右、向左、接乘客和送乘客。这使得总共有6个可能的动作,这些动作依次以0到5的数字编码,以便于编程。动作和数字之间的对应关系如图1所示。
在这里插入图片描述

6. Rewards

至于agent执行的每一步action所获得的奖励reward,做如下约定:

  • 移动:-1, 表示每一步都会受到一点惩罚,以鼓励从出发地到目的地走最短的路。
  • 错误运送:-10, 表示当乘客被送到到错误的位置时,乘客自然会不高兴,所以惩罚大一些是合适的。
  • 成功送达:20,表示出租车司机成功完成了任务,鼓励相应的行为,因此产生了正向的reward

7. Initialization

在数学上定义了这个问题之后,我们接着将着手用代码实现。首先,我们安装必要的库,然后导入它们。显然,我们需要安装gym 环境。除此之外,我们只需要一些可视化的东西和常见的数据处理库。

"""install libraries"""
!pip install cmake 'gym[atari]' scipy pygame

"""Import libraries"""
import gym
import numpy as np
import matplotlib.pyplot as plt
import random
from IPython.display import clear_output
from time import sleep
from matplotlib import animation

接着,我们使用以下代码来创建和渲染Taxi-v3环境。

"""Initialize and validate the environment"""
env = gym.make("Taxi-v3", render_mode="rgb_array").env
state, _ = env.reset()

# Print dimensions of state and action space
print("State space: {}".format(env.observation_space))
print("Action space: {}".format(env.action_space))

# Sample random action
action = env.action_space.sample(env.action_mask(state))
next_state, reward, done, _, _ = env.step(action)

# Print output
print("State: {}".format(state))
print("Action: {}".format(action))
print("Action mask: {}".format(env.action_mask(state)))
print("Reward: {}".format(reward))

# Render and plot an environment frame
frame = env.render()
plt.imshow(frame)
plt.axis("off")
plt.show()

结果如下:
在这里插入图片描述

8. 测试随机agent

在上述环境Env按照预期开始工作后,此时我们可以随机让代理疯狂运行了。我们不妨让我们的agent在任何时刻都会采取随机行动,来看看会产生怎样的效果。

"""Simulation with random agent"""
epoch = 0
num_failed_dropoffs = 0
experience_buffer = []
cum_reward = 0

done = False

state, _ = env.reset()

while not done:
    # Sample random action
    "Action selection without action mask"
    action = env.action_space.sample()

    "Action selection with action mask"
    #action = env.action_space.sample(env.action_mask(state))

    state, reward, done, _, _ = env.step(action)
    cum_reward += reward

    # Store experience in dictionary
    experience_buffer.append({
        "frame": env.render(),
        "episode": 1,
        "epoch": epoch,
        "state": state,
        "action": action,
        "reward": cum_reward,
        }
    )

    if reward == -10:
        num_failed_dropoffs += 1

    epoch += 1

# Run animation and print console output
run_animation(experience_buffer)

print("# epochs: {}".format(epoch))
print("# failed drop-offs: {}".format(num_failed_dropoffs))

运行上述代码后,得到结果如下:
在这里插入图片描述
为什么要看完上述冗长的动画?好吧,这确实给人一种印象,一个未经训练的RL模型下的agent是如何表现的,以及需要多长时间才能获得有意义的reward

9. 训练agent

接着,我们来尝试训练我们的agent,我们知道Q值是在进行观测之后使用以下等式进行更新的。
在这里插入图片描述

请注意,对于500个状态和6个动作,我们必须填写一个大小为500*6=3000的Q表,每个状态-动作二元组需要多次观察才能学到有用的知识。相应的训练代码如下:

"""Training the agent"""
q_table = np.zeros([env.observation_space.n, env.action_space.n])

# Hyperparameters
alpha = 0.1  # Learning rate
gamma = 1.0  # Discount rate
epsilon = 0.1  # Exploration rate
num_episodes = 10000  # Number of episodes

# Output for plots
cum_rewards = np.zeros([num_episodes])
total_epochs = np.zeros([num_episodes])

for episode in range(1, num_episodes+1):
    # Reset environment
    state, info = env.reset()
    epoch = 0 
    num_failed_dropoffs = 0
    done = False
    cum_reward = 0

    while not done:
        
        if random.uniform(0, 1) < epsilon:
            "Basic exploration [~0.47m]"
            action = env.action_space.sample() # Sample random action (exploration)
            
            "Exploration with action mask [~1.52m]"
          # action = env.action_space.sample(env.action_mask(state)) "Exploration with action mask"
        else:      
            "Exploitation with action mask [~1m52s]"
           # action_mask = np.where(info["action_mask"]==1,0,1) # invert
           # masked_q_values = np.ma.array(q_table[state], mask=action_mask, dtype=np.float32)
           # action = np.ma.argmax(masked_q_values, axis=0)

            "Exploitation with random tie breaker [~1m19s]"
          #  action = np.random.choice(np.flatnonzero(q_table[state] == q_table[state].max()))
            
            "Basic exploitation [~47s]"
            action = np.argmax(q_table[state]) # Select best known action (exploitation)
 
        next_state, reward, done, _ , info = env.step(action) 

        cum_reward += reward
        
        old_q_value = q_table[state, action]
        next_max = np.max(q_table[next_state])
        
        new_q_value = (1 - alpha) * old_q_value + alpha * (reward + gamma * next_max)
        
        q_table[state, action] = new_q_value
        
        if reward == -10:
            num_failed_dropoffs += 1

        state = next_state
        epoch += 1
        
        total_epochs[episode-1] = epoch
        cum_rewards[episode-1] = cum_reward

    if episode % 100 == 0:
        clear_output(wait=True)
        print(f"Episode #: {episode}")

print("\n")
print("===Training completed.===\n")

# Plot reward convergence
plt.title("Cumulative reward per episode")
plt.xlabel("Episode")
plt.ylabel("Cumulative reward")
plt.plot(cum_rewards)
plt.show()

# Plot epoch convergence
plt.title("# epochs per episode")
plt.xlabel("Episode")
plt.ylabel("# epochs")
plt.plot(total_epochs)
plt.show()

在2000个episode 之后,我们似乎学到了一个相当好的模型,如下:
在这里插入图片描述
上图图像中,横坐标表示我们一共训练了10000Episode,纵坐标表示每一个Episode下,出租车司机将乘客送达目的地所需要的移动步数epochs

10. 验证训练效果

最后,让我们看看我们的模型学到了什么。根据我们所处的状态,我们在Q表中查找相应的Q值(即,每个状态对应于动作的六个值),并选择具有最高相关Q值的动作。代码如下:

"""Test policy performance after training"""

num_epochs = 0
total_failed_deliveries = 0
num_episodes = 1
experience_buffer = []
store_gif = True

for episode in range(1, num_episodes+1):
    # Initialize experience buffer

    my_env = env.reset()
    state = my_env[0]
    epoch = 1 
    num_failed_deliveries =0
    cum_reward = 0
    done = False

    while not done:
        action = np.argmax(q_table[state])
        state, reward, done, _, _ = env.step(action)
        cum_reward += reward

        if reward == -10:
            num_failed_deliveries += 1

        # Store rendered frame in animation dictionary
        experience_buffer.append({
            'frame': env.render(),
            'episode': episode,
            'epoch': epoch,
            'state': state,
            'action': action,
            'reward': cum_reward
            }
        )

        epoch += 1

    total_failed_deliveries += num_failed_deliveries
    num_epochs += epoch

    if store_gif:
        store_episode_as_gif(experience_buffer)

# Run animation and print output
run_animation(experience_buffer)

# Print final results
print("\n") 
print(f"Test results after {num_episodes} episodes:")
print(f"Mean # epochs per episode: {num_epochs / num_episodes}")
print(f"Mean # failed drop-offs per episode: {total_failed_deliveries / num_episodes}")

结果如下:
在这里插入图片描述

可以看到,在执行了足够多的迭代之后,我们可以发现出租车总是直接驶向乘客,走最短的路到达目的地,并成功地将乘客放下。

11. 总结

本文通过具体的应用,来对前篇Q-learning的理论知识用代码进行了详细的说明,主要通过业内知名的Taxi-v3环境进行了讲解,并给出了完整的代码示例。

您学废了嘛?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/631581.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文讲完Java常用设计模式(23种)

介绍 设计模式的起源可以追溯到20世纪80年代&#xff0c;当时面向对象编程开始流行。在这个时期&#xff0c;一些软件开发者开始注意到他们在不同的项目中遇到了相同的问题&#xff0c;并且他们开始寻找可重用的解决方案。这些解决方案被称为设计模式。最早提出设计模式的人是…

centos7的docker安装与简单介绍

docker的基本组成&#xff08;三要素&#xff09; 镜像容器仓库 理解&#xff1a;镜像可以理解成一个类&#xff0c;容器就是用这个类new出来的对象&#xff0c;仓库就是放镜像文件的。docker本身是容器运行载体或管理引擎 安装 安装gcc yum -y install gcc安装需要的软件…

Vcpkg介绍及使用

Vcpkg用于在Windows、Linux、Mac上管理C和C库&#xff0c;极大简化了第三方库的安装&#xff0c;它由微软开源&#xff0c;源码地址&#xff1a;https://github.com/Microsoft/vcpkg&#xff0c;最新发布版本为2023.04.15 Release&#xff0c;它的license为MIT。 在windows上安…

[解决方案]springboot怎么接受encode后的参数(参数通过=拼接)

springboot怎么接受encode后的参数(拼接& springboot怎么接受encode后的参数(拼接&)问题出现原因发送encode后的值在postman里面的情况这个时候该如何接受呢&#xff08;encode后的值接受&#xff09;controller层的代码用到的工具类CRequest springboot怎么接受encode…

软考A计划-系统架构师-官方考试指定教程-(14/15)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

数组删除元素使用remove最优的方法

Array.prototype.remove function(from, to) { var rest this.slice((to || from) 1 || this.length); this.length from < 0 ? this.length from : from; return this.push.apply(this, rest); };

Anaconda安装及入门教程(Windows、Ubuntu)

文章目录 安装Anaconda3UbuntuWindows 使用换源设置不自动启用conda环境显示环境创建环境激活环境查找某个包的版本安装某个版本的包 虚拟环境中安装包删除虚拟环境删除特定的包复制环境设置代理UbuntuWindows 使用 conda-pack 离线导入、导出环境安装conda-pack导出导入 安装A…

简单使用nacos、openFeign和Sentinel(建议看源码和截图一起看)

1、Nacos 1、下载nacos&#xff0c;可以从结尾获取压缩包和源码 2、下方图例是两个服务程序注册成功到注册中心&#xff0c;并且配置从配置中心拉取&#xff0c;成功访问数据库 3、nacos中配置项里的内容 spring:datasource:driver-class-name: com.mysql.cj.jdbc.Driv…

在linux上做移动开发必须知道这五个

导读随着越来越多的人依靠手机进行各种业务&#xff0c;移动应用开发的重要性也在不断增加。虽然他们与桌面应用程序有很多相似之处&#xff0c;但移动应用程序本身也具有一系列挑战和特殊性。因此&#xff0c;希望在当前市场找到有利就业的程序员将需要利用和发展当前需求的技…

第五章 结构化设计

结构化设计的概念 1. 设计的定义 一种软件开发活动&#xff0c;定义实现需求规约所需的软件结构。 结构化设计分为&#xff1a; (1)总体设计&#xff1a;确定系统的整体模块结构&#xff0c;即系统实现所需要的软件模块以及这些模块之间的调用关系。 (2)详细设计&#xff1a;…

从Referer到XMLHttpRequest:探究Web安全中的重要知识点

目录 Referer 概念 Referrer-policy&#xff08;可以一定程度上防御CSRF攻击&#xff09; 同源 iframe sandbox(沙箱): cookie的原理&#xff1a; 如何设置Referrer&#xff1f; 盗链 盗链的工作原理 三种情况下可以引用图片&#xff1a; XMLHTTPRequest AJAX&…

初出茅庐的小李博客之STM32F103实现CAN通信

CAN通信基础知识 参考上一篇博客 https://editor.csdn.net/md/?articleId131026450 原理图 转换芯片是 TJA1050 代码实现思路 发送思路&#xff1a;定时发送 按键测试发送 接收思路&#xff1a;中断接收 CAN代码实现 第一步 定义了两个全局变量TxMessage和RxMessage&am…

Unity如何实现Microphone判断录入音频的频率是低音还是高音

前言 Unity中使用Microphone可以通过麦克风录制AudioClip音频,我们可以通过它实现录音功能,然后可以通过录入的音频数据对音频进行分析,比如音量大小,频率高低,等等。 我们今天就来分析一下音频的高音低音。 如何判断高音低音 科普:一般人们习惯将音响划分一定的频段…

参数估计(点估计和区间估计)

参数估计&#xff08;点估计和区间估计&#xff09; 1.1 点估计 点估计的理解示意图 下图中样本均值就是对总体均值的点估计 1.1.1 矩估计 关于什么是矩&#xff1f;可以参考马同学。传送门&#xff1a;如何理解概率论中的“矩”&#xff1f; 根据大数定律&#xff0c;样本…

007: vue中修改el-select选中颜色不生效的办法

第007个 查看专栏目录: 按照VUE知识点 ------ 按照element UI知识点 echarts&#xff0c;openlayers&#xff0c;cesium&#xff0c;leaflet&#xff0c;mapbox&#xff0c;d3&#xff0c;canvas 免费交流社区 专栏目标 在vue和element UI联合技术栈的操控下&#xff0c;本专栏…

apm装机教程(一):无人车

文章目录 前言一、下载固件二、设置参数 前言 APM4.2.3 差速小车 pix2.4.8 MP地面站 一、下载固件 pix2.4.8使用的是fmuv3的固件&#xff0c;可以在官网上下载 https://firmware.ardupilot.org/Rover/stable-4.2.3/fmuv3/ 或者从我的网盘下载 链接&#xff1a;https://pan.b…

【算法题】面试题 01.01. 判定字符是否唯一

题目来源&#xff1a;《程序员面试金典&#xff08;第 6 版&#xff09;》 1、Problem: 面试题 01.01. 判定字符是否唯一 文章目录 面试题 01.01. 判定字符是否唯一一、题目描述二、解决方案&#xff08;一&#xff09;方案一1、解题思路2、解题方法3、复杂度4、代码实现 &…

数据分析手册-R语言

1 数据准备 1.1 数据录入 1.1.1 R中导入数据 &#xff08;1&#xff09;导入excel 1.2 修改工作路径 https://wenku.baidu.com/view/2785ff37b4360b4c2e3f5727a5e9856a561226c9.html 1.3 如何计算样本量 https://www.bilibili.com/video/BV19t411u7TW?spm_id_from333.…

亚马逊云科技Serverless构建的实时数仓解决方案,助力猎豹降低30%成本

也许你也听过这样一句话&#xff1a;“21世纪什么最贵&#xff1f;人才&#xff01;”当数字经济全面席卷而来&#xff0c;这个问题的答案不可置否地变为了“数据”。通过数据分析获取近乎实时的洞察&#xff0c;以驱动业务的全流程&#xff0c;是企业数字化转型的必经之路。借…

【网络编程】计算机网络基础知识总结 | 运输层 |TCP协议

文章目录 前言一、计算机网络层次结构二、网络层三、运输层3.1、TCP/IP协议介绍3.2、端口&#xff08;协议端口号&#xff09;3.3、套接字3.4、TCP实现原理3.4.1、TCP的特点3.4.2、停止等待协议3.4.3、滑动窗口协议3.4.4、拥塞控制3.4.5、TCP连接的三个阶段 3.5、UDP实现原理 前…