Reinforcement Learning with Code 【Code 1. Tabular Q-learning】

news2025/1/18 19:02:38

Reinforcement Learning with Code 【Code 1. Tabular Q-learning】

This note records how the author begin to learn RL. Both theoretical understanding and code practice are presented. Many material are referenced such as ZhaoShiyu’s Mathematical Foundation of Reinforcement Learning.
This code refers to Mofan’s reinforcement learning course.

文章目录

  • Reinforcement Learning with Code 【Code 1. Tabular Q-learning】
    • 1.1 Problem and result
    • 1.2 Environment
    • 1.3 Tabular Q-learning Algorithm
    • 1.3 Run this main
    • Reference

1.1 Problem and result

Please consider the problem that a little mouse (denoted by red block) wants to avoid trap (denoted by black block) to get the cheese (denoted by yellow circle). As the figure shows.

Image

This chapter aims to realize tabular Q-learning algorithm sovle this problem.

1.2 Environment

We use the tkinter package of python to build our environment to interact with agent.

import numpy as np
import time
import sys
import tkinter as tk
# if sys.version_info.major == 2: # 检查python版本是否是python2
#     import Tkinter as tk
# else:
#     import tkinter as tk


UNIT = 40   # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # Action Space
        self.action_space = ['up', 'down', 'right', 'left'] # action space 
        self.n_actions = len(self.action_space)

        # 绘制GUI
        self.title('Maze env')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))   # 指定窗口大小 "width x height"
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                           height=MAZE_H * UNIT,
                           width=MAZE_W * UNIT)     # 创建背景画布

        # create grids
        for c in range(UNIT, MAZE_W * UNIT, UNIT): # 绘制列分隔线
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(UNIT, MAZE_H * UNIT, UNIT): # 绘制行分隔线
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin 第一个方格的中心,
        origin = np.array([UNIT/2, UNIT/2]) 

        # hell
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - (UNIT/2 - 5), hell1_center[1] - (UNIT/2 - 5),
            hell1_center[0] + (UNIT/2 - 5), hell1_center[1] + (UNIT/2 - 5),
            fill='black')
        # hell
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - (UNIT/2 - 5), hell2_center[1] - (UNIT/2 - 5),
            hell2_center[0] + (UNIT/2 - 5), hell2_center[1] + (UNIT/2 - 5),
            fill='black')

        # create oval 绘制终点圆形
        oval_center = origin + np.array([UNIT*2, UNIT*2])
        self.oval = self.canvas.create_oval(
            oval_center[0] - (UNIT/2 - 5), oval_center[1] - (UNIT/2 - 5),
            oval_center[0] + (UNIT/2 - 5), oval_center[1] + (UNIT/2 - 5),
            fill='yellow')

        # create red rect 绘制agent红色方块,初始在方格左上角
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')

        # pack all 显示所有canvas
        self.canvas.pack()


    def get_state(self, rect):
            # convert the coordinate observation to state tuple
            # use the uniformed center as the state such as 
            # |(1,1)|(2,1)|(3,1)|...
            # |(1,2)|(2,2)|(3,2)|...
            # |(1,3)|(2,3)|(3,3)|...
            # |....
            x0,y0,x1,y1 = self.canvas.coords(rect)
            x_center = (x0+x1)/2
            y_center = (y0+y1)/2
            state = ((x_center-(UNIT/2))/UNIT + 1, (y_center-(UNIT/2))/UNIT + 1)
            return state


    def reset(self):
        self.update()
        self.after(500) # delay 500ms
        self.canvas.delete(self.rect)   # delete origin rectangle
        origin = np.array([UNIT/2, UNIT/2])
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')
        # return observation 
        return self.get_state(self.rect)   

    

    def step(self, action):
        # agent和环境进行一次交互
        s = self.get_state(self.rect)   # 获得智能体的坐标
        base_action = np.array([0, 0])
        if action == self.action_space[0]:   # up
            if s[1] > 1:
                base_action[1] -= UNIT
        elif action == self.action_space[1]:   # down
            if s[1] < MAZE_H:
                base_action[1] += UNIT
        elif action == self.action_space[2]:   # right
            if s[0] < MAZE_W:
                base_action[0] += UNIT
        elif action == self.action_space[3]:   # left
            if s[0] > 1:
                base_action[0] -= UNIT

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        s_ = self.get_state(self.rect)  # next state

        # reward function
        if s_ == self.get_state(self.oval):     # reach the terminal
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.get_state(self.hell1), self.get_state(self.hell2)]: # reach the block
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False

        return s_, reward, done

    def render(self): # 显示函数
        self.after(300) # delay 300ms
        self.update()   # update the window




if __name__ == '__main__':
    def test():
        for t in range(10):
            s = env.reset()
            print(s)
            while True:
                env.render()
                a = 'right'
                s, r, done = env.step(a)
                print(s)
                if done:
                    break
    env = Maze()
    env.after(100, test)      # 在延迟100ms后调用函数test
    env.mainloop()


This part is important that the reward function design is include, which is as follows

reward = { 1 , if reach the cheese − 1 , if reach the trap 0 , others \text{reward} = \left \{ \begin{aligned} & 1, \quad \text{if reach the cheese} \\ & -1, \quad \text{if reach the trap} \\ & 0, \quad \text{others} \end{aligned} \right. reward= 1,if reach the cheese1,if reach the trap0,others

We need to explan some function of the class Maze.

  • First, the function _build_maze creates the inital maze location.
    In this example we use the left up coordination of each grid as the state of each block.
  • Second, the function get_state converts the coordination of each grid to numerical representation such as ( 1 , 1 ) , ( 1 , 2 ) , ⋯ (1,1),(1,2),\cdots (1,1),(1,2),.
  • Third, the function reset renew the state which means placing the mouse in the original grid.
  • Then, the function step we let the agent interact with envrionment for one step, ang get the reward after the action.
  • Then, the function render controls updating the window.

1.3 Tabular Q-learning Algorithm

import numpy as np
import pandas as pd


class QLearningTable():
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # action list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy # epsilon greedy update policy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table, use the coordinate as the observation

            # self.q_table = self.q_table.append(       # DataFrame.append is invalid
            #     pd.Series(
            #         [0]*len(self.actions),
            #         index=self.q_table.columns,
            #         name=state,
            #     )
            # )

            self.q_table = pd.concat(
                [
                self.q_table,
                pd.DataFrame(
                        data=np.zeros((1,len(self.actions))),
                        columns = self.q_table.columns,
                        index = [state]
                    )
                ]
            )

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
            # epsilon greedy algorithm
        if np.random.uniform() < self.epsilon:
            
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, randomly choose on in these actions
            # state_action == np.max(state_action) generate bool mask
            # choose best action
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

We store the Q-table as a DataFrame of pandas. The explanation of the functions are as follows.

  • First, the function check_state_exist check the existence of one state, if not we append it to the Q-table. This is because once the state-action pair is visited, then we update it into the Q-table.
  • Second, the function choose_action is following the ϵ \epsilon ϵ-greedy algorithm

π ( a ∣ s ) = { 1 − ϵ ∣ A ( s ) ∣ ( ∣ A ( s ) ∣ − 1 ) , for the geedy action ϵ ∣ A ( s ) ∣ , for the other  ∣ A ( s ) ∣ − 1  actions \pi(a|s) = \left \{ \begin{aligned} 1 - \frac{\epsilon}{|\mathcal{A}(s)|}(|\mathcal{A(s)}|-1), & \quad \text{for the geedy action} \\ \frac{\epsilon}{|\mathcal{A}(s)|}, & \quad \text{for the other } |\mathcal{A}(s)|-1 \text{ actions} \end{aligned} \right. π(as)= 1A(s)ϵ(A(s)1),A(s)ϵ,for the geedy actionfor the other A(s)1 actions

  • Third, the function learn is update the q value as Q-learning algorithm purposed.

Q-learning : { q t + 1 ( s t , a t ) = q t ( s t , a t ) − α t ( s t , a t ) [ q t ( s t , a t ) − ( r t + 1 + γ max ⁡ a ∈ A ( s t + 1 ) q t ( s t + 1 , a ) ) ] q t + 1 ( s , a ) = q t ( s , a ) , for all  ( s , a ) ≠ ( s t , a t ) \text{Q-learning} : \left \{ \begin{aligned} \textcolor{red}{q_{t+1}(s_t,a_t)} & \textcolor{red}{= q_t(s_t,a_t) - \alpha_t(s_t,a_t) \Big[q_t(s_t,a_t) - (r_{t+1}+ \gamma \max_{a\in\mathcal{A}(s_{t+1})} q_t(s_{t+1},a)) \Big]} \\ \textcolor{red}{q_{t+1}(s,a)} & \textcolor{red}{= q_t(s,a)}, \quad \text{for all } (s,a) \ne (s_t,a_t) \end{aligned} \right. Q-learning: qt+1(st,at)qt+1(s,a)=qt(st,at)αt(st,at)[qt(st,at)(rt+1+γaA(st+1)maxqt(st+1,a))]=qt(s,a),for all (s,a)=(st,at)

1.3 Run this main

Run this main script that we can run the all codes.

from maze_env_custom import Maze
from RL_brain import QLearningTable

MAX_EPISODE = 10


def update():
    for episode in range(MAX_EPISODE):
        # initial observation, observation is the rect's coordiante
        # observation is [x0,y0, x1,y1]
        observation = env.reset()   

        while True:
            # fresh env
            env.render()

            # RL choose action based on observation ['up', 'down', 'right', 'left']
            action = RL.choose_action(str(observation))

            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)

            # RL learn from this transition
            RL.learn(str(observation), action, reward, str(observation_))

            # swap observation
            observation = observation_

            # break while loop when end of this episode
            if done:
                break

        # show q_table
        print(RL.q_table)
        print('\n')

    # end of game
    print('game over')
    env.destroy()

if __name__ == "__main__":
    env = Maze()
    RL = QLearningTable(env.action_space)

    env.after(100, update)
    env.mainloop()


Reference

赵世钰老师的课程
莫烦ReinforcementLearning course

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/809256.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WIZnet W5100S-EVB-Pico 静态IP配置教程(二)

W5100S是一个硬连线TCP/IP封装以太网控制器W5100S支持间接并行总线和高速SPI接口2种方式与主机进行通信。其内部还集成了以太网数据链路层&#xff08;MAC&#xff09;和10Base -T/100Base -T 以太网物理层&#xff08;PHY&#xff09;&#xff0c;支持自动协商&#xff08;10/…

记录vue的一些踩坑日记

记录vue的一些踩坑日记 安装Jq npm install jquery --save vue列表跳转到详情页&#xff0c;再返回列表的时候不刷新页面并且保持原位置不变&#xff1b; 解决&#xff1a;使用keepAlive 在需要被缓存的页面的路由中添加&#xff1a;keepAlive: true, {path: /viewExamine,nam…

Docker安装 Mysql 8.x 版本

文章目录 Docker安装 Mysql 8.0.22Mysql 创建账号并授权Mysql 数据迁移同版本数据迁移跨版本数据迁移 Mysql 5.x 版本与 Mysql 8.x版本是两个大版本&#xff0c;这里演示安装Mysql 8.x版本 Docker安装 Mysql 8.0.22 # 下载mysql $ docker pull mysql 默认安装最新…

SolidWorks(1)

打开solidworks,选择零件选择草图、绘制草图选择上视基准面 最后完成草图 选择拉伸切除 最终成品 鼠标按住中键&#xff0c;进行旋转

优思学院:六西格玛的10大概念和特点

六西格玛是一种管理方法&#xff0c;致力于提高组织的运营效率和质量水平。它起源于20世纪80年代的美国&#xff0c;随后在全球范围内得到广泛应用。今天我们将探讨六西格玛的十大概念和特点&#xff0c;帮助您了解如何将这一管理方法应用于您的业务中。 1. 什么是六西格玛&am…

【前端知识】React 基础巩固(三十七)——自定义connect高阶组件

React 基础巩固(三十七)——自定义connect高阶组件 一、手撸一个自定义connect高阶组件 import { PureComponent } from "react"; import store from "../store";/*** connect的参数&#xff1a;* 参数一&#xff1a; 函数* 参数二&#xff1a; 函数* 返…

TiProxy 尝鲜

说明 最近发现 tidb 有个 GitHub - pingcap/TiProxy 仓库&#xff0c;抱着好奇的心态想试试这个组件的使用效果。于是按照文档的介绍在本地环境使用tiup做了一些实验&#xff0c;现在将实验过程和实验结果分享给大家。 TiProxy介绍 官方README介绍的已经很清楚了&#xff0c;…

QT 调用USB免驱摄像头

文章目录 前言一、界面布局二、QImageEncoderSettings类三、图像的显示总结 前言 本篇文章来讲解一下如何使用QT调用摄像头&#xff0c;这里我使用的是USB免驱动摄像头&#xff0c;使用不需要按照驱动QT就可以调用到摄像头。 一、界面布局 这里使用QT设计师进行界面的布局&a…

数据结构:树的存储结构

学习树之前&#xff0c;我们已经了解了二叉树的顺序存储和链式存储&#xff0c;哪么我们如何来存储普通型的树结构的数据&#xff1f;如下图1&#xff1a; 如图1所示&#xff0c;这是一颗普通的树&#xff0c;我们要如何来存储呢&#xff1f;通常&#xff0c;存储这种树结构的数…

vmware磁盘组使用率100%处理

今天在外办事时&#xff0c;有客户发过来一个截图&#xff0c;问vmware 磁盘组空间使用率100%咋办&#xff1f;如下图&#xff1a; 直接回复&#xff1a; 1、首先删除iso文件等 2、若不存在ISO文件等&#xff0c;找个最不重要的虚拟机直接删除&#xff0c;删除后稍等就会释放…

数据分析-关于指标和指标体系

一、电商指标体系 二、指标体系的作用 三、统计学中基本的分析手段

记一次解决FTPS上传的文件为空的问题

最近公司的vsftpd文件服务由之前的FTP传输改成了FTPS的&#xff0c;虽然代码做了相应的调整&#xff0c;但是始终有个问题&#xff0c;就是在服务器上文件创建成功了&#xff0c;名称也是正确的&#xff0c;可是一看大小确是0&#xff0c;于是查看日志 用项目上传的始终是失败&…

mybatis-config.xml-配置文件详解

文章目录 mybatis-config.xml-配置文件详解说明文档地址:配置文件属性解析properties 属性应用实例 settings 全局参数定义应用实例 typeAliases 别名处理器举例说明 typeHandlers 类型处理器environments 环境environment 属性应用实例 mappers配置 mybatis-config.xml-配置文…

vue除了子组件抛出的额外参数,父组件如何传递额外参数

以下为一个简单的demo,只为记录一下 很多时候如果我们多个地方使用同一函数时,往往就需要进行判断了,但是组件库返回的函数携带的参数没办法让我们做多余的判断 这时就需要传递多余的参数了 方法一 使用箭头函数 <template><div><el-switchv-model"value&…

2023年中秋节国庆节是几月几号开始休假国庆中秋双节放假几天?

2023年中秋节国庆节是几月几号开始休假国庆中秋双节放假几天&#xff1f; 2023年中秋节与国庆节双节相遇一起放假长大8天; 2023年中秋节与国庆节放假时间是从2023年9月29日开始至10月6日共计8天&#xff0c;10月7日、10月8日补班&#xff1b; 2023年中秋节怎么购买月饼省钱又…

Jenkins配置自动化构建的几个问题

在创建构建任务时&#xff0c;填写git远程仓库地址时&#xff0c;出现以下报错 解决此报错先排查一下linux机器上的git版本 git --version 如果git 版本过低&#xff0c;可能会导致拉取失败&#xff0c;此时需要下载更高的git版本。 参考 Git安装 第二个解决办法报错信息中…

Ansible的脚本 --- playbook 剧本

文章目录 一、playbook剧本的组成创建剧本运行playbook二、定义、引用变量三、指定远程主机sudo切换用户四、when条件判断五、迭代Templates 模块tags 模块 一、playbook剧本的组成 playbooks 本身由以下各部分组成 &#xff08;1&#xff09;Tasks&#xff1a;任务&#xff0…

SpringCloudAlibaba之整合openFeign

一.openFeign的入门使用&#xff08;4步&#xff09; 1.引入openFeign的依赖包&#xff0c;记得父项目中要加上SpringCloud依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-openfeign</artifactI…

VSCode新手快速下载、安装、使用

目录 下载 安装 1、许可协议 2、安装位置 3、开始菜单文件夹 4、附加任务 5、确认安装 6、完成 使用 1、汉化&#xff08;设置中文界面&#xff09; 2、设置 下载 进入VSCode官方页面&#xff0c;选择自己系统对应的下载链接VSCode默认提供的User Installer版本。但…

【解决】el-tree报Cannot read property ‘getCheckedKeys‘ of undefined

如果你报错 Cannot read property getCheckedKeys of undefined 或者 Cannot read property getCheckedNodes of undefined 只要在你的在<el-tree>上加个这个&#xff0c;就可以了 ref"tree"