Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

news2025/1/10 16:41:50

目录

  • 0 专栏介绍
  • 1 Q-Learning算法原理
  • 2 强化学习基本框架
  • 3 机器人走迷宫算法
    • 3.1 迷宫环境
    • 3.2 状态、动作和奖励
    • 3.3 Q-Learning算法实现
    • 3.4 完成训练
  • 4 算法分析
    • 4.1 Q-Table
    • 4.2 奖励曲线

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 Q-Learning算法原理

在Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)介绍到时序差分强化学习是动态规划与蒙特卡洛的折中

Q π ( s t , a t ) = n 次增量 Q π ( s t , a t ) + α ( R t − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) ⏟ 采样 \begin{aligned}Q^{\pi}\left( s_t,a_t \right) &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( R_t-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+\gamma R_{t+1}-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}{ \underset{\text{采样}}{\underbrace{Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }}}\end{aligned} Qπ(st,at)n次增量 Qπ(st,at)+α(RtQπ(st,at))n次增量 Qπ(st,at)+α(rt+1+γRt+1Qπ(st,at))n次增量 采样 Qπ(st,at)+α(rt+1+γQπ(st+1,at+1)Qπ(st,at))

其中 r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) r_{t+1}+\gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) -Q^{\pi}\left( s_t,a_t \right) rt+1+γQπ(st+1,at+1)Qπ(st,at)称为时序差分误差。基于离轨策略的时序差分强化学习的代表性算法是Q-learning算法,其算法流程如下所示。具体的策略改进算法推导请见之前的文章,本文重点在于应用Q-learning算法解决实际问题

在这里插入图片描述

我们先来看看最终实现的效果

训练前
在这里插入图片描述

训练后

在这里插入图片描述

接下来详细讲解如何一步步实现这个智能体

2 强化学习基本框架

强化学习(Reinforcement Learning, RL)在潜在的不确定复杂环境中,训练一个最优决策 π \pi π指导一系列行动实现目标最优化的机器学习方法。在初始情况下,没有训练数据告诉强化学习智能体并不知道在环境中应该针对何种状态采取什么行动,而是通过不断试错得到最终结果,再反馈修正之前采取的策略,因此强化学习某种意义上可以视为具有“延迟标记信息”的监督学习问题。

在这里插入图片描述

强化学习的基本过程是:智能体对环境采取某种行动 a a a,观察到环境状态发生转移 s 0 → s s_0\rightarrow s s0s,反馈给智能体转移后的状态 s s s和对这种转移的奖赏 r r r。综上所述,一个强化学习任务可以用四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E=S,A,P,R表征

  • 状态空间 S S S:每个状态 s ∈ S s \in S sS是智能体对感知环境的描述;
  • 动作空间 A A A:每个动作 a ∈ A a \in A aA是智能体能够采取的行动;
  • 状态转移概率 P P P:某个动作 a ∈ A a \in A aA作用于处在某个状态 s ∈ S s \in S sS的环境中,使环境按某种概率分布 P P P转换到另一个状态;
  • 奖赏函数 R R R:表示智能体对状态 s ∈ S s \in S sS下采取动作 a ∈ A a \in A aA导致状态转移的期望度,通常 r > 0 r>0 r>0为期望行动, r < 0 r<0 r<0为非期望行动。

所以,程序上也需要依次实现四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E=S,A,P,R

3 机器人走迷宫算法

3.1 迷宫环境

我们创建的迷宫包含障碍物、起点和终点

class Maze(tk.Tk, object):
    '''
    * @breif: 迷宫环境类
    * @param[in]: None
    '''    
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze game')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self.buildMaze()

    '''
    * @breif: 创建迷宫
    '''
    def buildMaze(self):
        self.canvas = tk.Canvas(self, bg='white', height=MAZE_H * UNIT, width=MAZE_W * UNIT)
        # 网格地图
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 创建原点坐标
        origin = np.array([20, 20])

        # 创建障碍
        barrier_list = [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0),
                        (0, 6), (1, 6), (2, 6), (3, 6), (4, 6), (5, 6), (6, 6),
                        (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 1), (6, 2),
                        (6, 3), (6, 4), (6, 5), (1, 2), (2, 2), (4, 1), (5, 4),
                        (1, 4), (3, 3)]
        self.barriers = [self.creatObject(origin, *index) for index in barrier_list]

        # 创建终点
        self.terminus = self.creatObject(origin, 5, 5, 'blue')

3.2 状态、动作和奖励

机器人的状态可以设置为当前的位置坐标

s = self.canvas.coords(self.agent)

机器人的动作可以设为上、下左、右

if action == 0:   # up
      if s[1] > UNIT:
          base_action[1] -= UNIT
  elif action == 1:   # down
      if s[1] < (MAZE_H - 1) * UNIT:
          base_action[1] += UNIT
  elif action == 2:   # right
      if s[0] < (MAZE_W - 1) * UNIT:
          base_action[0] += UNIT
  elif action == 3:   # left
      if s[0] > UNIT:
          base_action[0] -= UNIT

机器人的奖励设置为以下几种:

  • 碰到障碍物:-10分,并进入终止状态
  • 成功到达终点: +50分,并进入终止状态
  • 未到达终点:-1分,能量耗散惩罚,防止机器人原地振荡
if s_ in [self.canvas.coords(barrier) for barrier in self.barriers]:
   reward = -10
   done = True
   s_ = 'terminal'
elif s_ == self.canvas.coords(self.terminus):
   reward = 50
   done = True
   s_ = 'terminal'
else:
   reward = -1
   done = False

3.3 Q-Learning算法实现

根据算法流程,实现下面的Q-Learning训练函数

def train(self, env, episodes=1000, reward_curve=[], file=None):
	with tqdm(range(episodes)) as bar:
	    for _ in bar:
	        # 初始化环境和该幕累计奖赏
	        state = env.reset()
	        acc_reward = 0
	        while True:
	            # 刷新环境
	            env.render()
	            # 采样一个动作并进行状态转移
	            action = self.policySample(str(state))
	            next_state, reward, done = env.step(action)
	            acc_reward += reward
	            # 智能体学习策略
	            self.learn(str(state), action, reward, str(next_state))
	            state = next_state
	            if done:
	                reward_curve.append(acc_reward)
	                break
	# 保存策略
	if not file:
	    self.q_table.to_csv(file)
	env.destroy()

3.4 完成训练

训练过程如下所示,完成后保存权重文件

if __name__ == "__main__":
    env = Maze()
    agent = Agent(actions=list(range(env.n_actions)))
    reward_curve = []

    # 训练智能体
    env.after(100, agent.train, env, 50, reward_curve, './weight/csv')

    # 主循环
    env.mainloop()

4 算法分析

4.1 Q-Table

在Q-Learning算法中,我们需要维护一个Q-Table,用来记录各种状态和动作的价值。Q-Table是一个二维表格,其中每一行表示一个状态,每一列表示一个动作。Q-Table中的值表示某个状态下执行某个动作所获得的回报(或者预期回报)。Q-Table的更新是Q-Learning算法的核心。在每次执行动作后,我们会根据当前状态、执行的动作、获得的奖励和下一个状态,来更新Q-Table中对应的值,更新方式是

Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) Q^{\pi}\left( s_t,a_t \right) ={ {Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }} Qπ(st,at)=Qπ(st,at)+α(rt+1+γQπ(st+1,at+1)Qπ(st,at))

对应代码

self.q_table.loc[state, action] += self.lr * (q_target - q_predict)

在这里插入图片描述

保存的权重文件正是Q-Table,我们可以直观地看一下,其中0-3指的是上下左右四个动作,每行行首则是状态值,其余数是Q-Value

,0,1,2,3
"[45.0, 45.0, 75.0, 75.0]",-3.764746051087998,-4.129632180625153,2.070923999854885,-4.129632180625153
terminal,0.0,0.0,0.0,0.0
"[85.0, 45.0, 115.0, 75.0]",-3.7017636879676745,-3.2427095093971663,6.341493354722148,-2.4376270354451357
"[125.0, 45.0, 155.0, 75.0]",-2.822694674017249,12.009385340227768,-3.10550914130922,-1.7370066390489591
"[125.0, 85.0, 155.0, 115.0]",-1.018256983413196,-2.3765728565289628,19.23732307528551,-2.602996266117196
"[165.0, 85.0, 195.0, 115.0]",-2.063857163563445,27.370237164958994,-0.7307141976318489,0.14330394709222574
"[205.0, 85.0, 235.0, 115.0]",-0.4546075907459214,-0.45498153729692925,-0.490099501,0.3662096391980347
"[165.0, 125.0, 195.0, 155.0]",0.9791630128216775,35.427315495348594,-0.28782126600827374,-1.7383137616441329
"[205.0, 45.0, 235.0, 75.0]",-0.3940399,-0.38288265597631166,-0.3940399,-0.3940399
"[205.0, 125.0, 235.0, 155.0]",-0.31765122402993484,-0.3940399,-0.3940399,1.5298899806741253
...

4.2 奖励曲线

训练过程的奖励曲线如下所示

在这里插入图片描述

完整代码联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1321274.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AntDB-T提升查询性能的关键之查询优化解析

查询优化器是提升查询效率非常重要的手段&#xff0c;本文将主要介绍AntDB-T数据库查询优化的相关设计。AntDB-T数据库是一款企业级通用分布式关系型数据库&#xff0c;而查询是AntDB-T数据库管理系统中最关键、最吸引人的功能之一。每个生产数据库系统每天都需要处理大量的查询…

python【matplotlib】鼠标拖动滚动缩放坐标范围和拖动图例共存

背景 根据前面的博文&#xff1a; python【matplotlib】画图鼠标缩放拖动动态改变坐标轴范围 和Python【Matplotlib】图例可拖动改变位置 两个博文&#xff0c;博主考虑了一下&#xff0c;如何将两者的功能结合起来&#xff0c;让二者共存。 只需根据Python【Matplotlib】鼠标…

【音视频 | AAC】AAC音频编码详解

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

【GD32307E-START】06 ST7735 SPI-LCD显示模块移植

软硬件平台 GD32F307E-START Board开发板GCC Makefile1.8寸TFTLCD 分辨率128*160 驱动IC ST7735S 接口定义 序号引脚标号说明1GND接地2VCC5V/3.3V电源输入3SCKSPI总线时钟信号4SDASPI总线写数据信号5RESET液晶屏复位信号&#xff0c;低电平复位6DC液晶屏寄存器/数据选择信…

LinuxCNC系统安装

首先我们需要准备一个U盘来安装系统&#xff0c;然后进入Debian官网。操作系统处&#xff0c;点击“下载Debian”。 如果需要下载其他比较全版本&#xff0c;可以点击“其他下载链接”&#xff0c;选择DVD的安装&#xff0c;因为是国外的网站&#xff0c;最好不要选择网络安装。…

【每日OJ—有效的括号(栈)】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 1、有效的括号题目&#xff1a; 1.1方法讲解&#xff1a; 1.2代码实现&#xff1a; 总结 前言 世上有两种耀眼的光芒&#xff0c;一种是正在升起的太阳&#…

机器学习:增强式学习Reinforcement learning

收集有标签数据比较困难的时候同时也不知道什么答案是比较好的时候可以考虑使用强化学习通过互动&#xff0c;机器可以自己知道什么结果是好的&#xff0c;什么结果是坏的 Outline 什么是RL Action就是一个functionEnvironment就是告诉这个Action是好的还是坏的 例子 Space i…

2023年度IT168技术卓越奖名单:亚信安慧AntDB数据库

信创卓越贡献奖&#xff1a;湖南亚信安慧科技有限公司 一句话点评&#xff1a;亚信安慧的核心交易数据库AntDB具有应用时间久&#xff08;15年&#xff09;、运行节点多&#xff08;2000&#xff09;、数据规模大&#xff08;PB级&#xff09;、产品稳定可靠&#xff08;500项目…

L1-050:倒数第N个字符串

题目描述 给定一个完全由小写英文字母组成的字符串等差递增序列&#xff0c;该序列中的每个字符串的长度固定为 L&#xff0c;从 L 个 a 开始&#xff0c;以 1 为步长递增。例如当 L 为 3 时&#xff0c;序列为 { aaa, aab, aac, ..., aaz, aba, abb, ..., abz, ..., zzz }。这…

Spring Cloud + Vue前后端分离-第6章 通用代码生成器开发

Spring Cloud Vue前后端分离-第6章 通用代码生成器开发 6-1 代码生成器原理介绍 1.增加generator模块&#xff0c;用于代码生成 2.集成freemarker 通用代码生成器开发 FreeMarker 是一款模版引擎&#xff0c;通过模板生成文件&#xff0c;包括html页面&#xff0c;excel …

【经典LeetCode算法题目专栏分类】【第5期】贪心算法:分发饼干、跳跃游戏、模拟行走机器人

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能AI、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 分发饼干 class Solutio…

万兆网络之线路测速

网络测速有很多种方式&#xff0c;建议使用开源的iperf搭建测试 官方&#xff1a;iperf3&#xff08;技术网站一般不被和谐&#xff0c;有部分可能被污染&#xff09; Windows下载后解压即可运行 小技巧&#xff1a;如果你用的笔记本只有一个C盘&#xff0c;最好将免安装的软…

Zotero插件安装、问题、bug大全(随时更新)

Zotero插件安装、问题、bug大全&#xff08;随时更新&#xff09; 1. 插件安装2. 茉莉花&#xff08;Jasminum&#xff09;插件使用tips及可能遇到的问题2.1 更新2.2 未找到PDFtk Server的可执行文件 问题解决方法 3. Zotero Sci-hub插件相关问题3.1 Zotero Sci-hub插件有时抓取…

iOS问题记录 - iOS 17通过NSUserDefaults设置UserAgent无效

文章目录 前言开发环境问题描述问题分析解决方案最后 前言 最近维护一个老项目时遇到的问题。说起这老项目我就有点头疼&#xff0c;一个快十年前的项目&#xff0c;这么说你可能不觉得有什么&#xff0c;但是你想想Swift也才发布不到十年&#xff08;2014年6月发布&#xff0…

DS排序--快速排序

Description 给出一个数据序列&#xff0c;使用快速排序算法进行从小到大的排序 排序方式&#xff1a;以区间第一个数字为枢轴记录 输出方式&#xff1a;每一步区间排序&#xff0c;都输出整个数组 –程序要求– 若使用C只能include一个头文件iostream&#xff1b;若使用C…

深度学习笔记_7经典网络模型LSTM解决FashionMNIST分类问题

1、 调用模型库&#xff0c;定义参数&#xff0c;做数据预处理 import numpy as np import torch from torchvision.datasets import FashionMNIST import torchvision.transforms as transforms from torch.utils.data import DataLoader import torch.nn.functional as F im…

2000年AMC8数学竞赛中英文真题典型考题、考点分析和答案解析

今天是2023年12月19日&#xff0c;距离2024年的AMC8正式考试倒计时一个月。 从战争中学习战争最有效。前几天&#xff0c;六分成长分析了2023年、2022年、2020、2019、2018、2017的AMC8真题的典型考题、考点和详细答案解析。 今天我们不再从2016年分析&#xff0c;来看看更早…

pytorch文本分类(三)模型框架(DNNtextCNN)

pytorch文本分类&#xff08;三&#xff09;模型框架&#xff08;DNN&textCNN&#xff09; 原任务链接 目录 pytorch文本分类&#xff08;三&#xff09;模型框架&#xff08;DNN&textCNN&#xff09;1. 背景知识深度学习 2. DNN2.1 从感知器到神经网络2.2 DNN的基本…

避坑指南:uni-forms表单在uni-app中的实践经验

​&#x1f308;个人主页&#xff1a;前端青山 &#x1f525;系列专栏&#xff1a;uni-app篇 &#x1f516;人终将被年少不可得之物困其一生 依旧青山,本期给大家带来JavaScript篇专栏内容:uni-app中forms表单的避坑指南篇 该篇章已被前端圈子收录,点此处进入即可查看更多优质内…

Pytorch nn.Linear()的基本用法与原理详解及全连接层简介

主要引用参考&#xff1a; https://blog.csdn.net/zhaohongfei_358/article/details/122797190 https://blog.csdn.net/weixin_43135178/article/details/118735850 nn.Linear的基本定义 nn.Linear定义一个神经网络的线性层&#xff0c;方法签名如下&#xff1a; torch.nn.Li…