Reinforcement Learning with Code 【Code 2. Tabular Sarsa】

2025/1/11 5:44:57

This note records how the author began learning RL and presents both theoretical understanding and code practice. Many materials are referenced, such as Zhao Shiyu's Mathematical Foundation of Reinforcement Learning.
The code is adapted from Mofan's reinforcement learning course.

Contents

  • Reinforcement Learning with Code 【Code 2. Tabular Sarsa】
    • 2.1 Problem and result
    • 2.2 Environment
    • 2.3 Tabular Sarsa Algorithm
    • 2.4 Run this main
    • 2.5 Check the Q table
    • Reference

2.1 Problem and result

Consider the following problem: a little mouse (the red block) wants to reach the cheese (the yellow circle) while avoiding the traps (the black blocks), as shown in the figure.

[Figure: the 4×4 maze with the mouse, the two traps, and the cheese]

This chapter implements the tabular Sarsa algorithm to solve this problem.

2.2 Environment

We use Python's tkinter package to build the environment that the agent interacts with.

import numpy as np
import time
import sys
import tkinter as tk
# if sys.version_info.major == 2: # check whether the Python version is Python 2
#     import Tkinter as tk
# else:
#     import tkinter as tk


UNIT = 40   # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # Action Space
        self.action_space = ['up', 'down', 'right', 'left'] # action space 
        self.n_actions = len(self.action_space)

        # build the GUI
        self.title('Maze env')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))   # set the window size "width x height"
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                           height=MAZE_H * UNIT,
                           width=MAZE_W * UNIT)     # create the background canvas

        # create grids
        for c in range(UNIT, MAZE_W * UNIT, UNIT): # draw the column separators
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(UNIT, MAZE_H * UNIT, UNIT): # draw the row separators
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin: the center of the first (top-left) grid cell
        origin = np.array([UNIT/2, UNIT/2]) 

        # hell1
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - (UNIT/2 - 5), hell1_center[1] - (UNIT/2 - 5),
            hell1_center[0] + (UNIT/2 - 5), hell1_center[1] + (UNIT/2 - 5),
            fill='black')
        # hell2
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - (UNIT/2 - 5), hell2_center[1] - (UNIT/2 - 5),
            hell2_center[0] + (UNIT/2 - 5), hell2_center[1] + (UNIT/2 - 5),
            fill='black')

        # create oval: draw the terminal circle (the cheese)
        oval_center = origin + np.array([UNIT*2, UNIT*2])
        self.oval = self.canvas.create_oval(
            oval_center[0] - (UNIT/2 - 5), oval_center[1] - (UNIT/2 - 5),
            oval_center[0] + (UNIT/2 - 5), oval_center[1] + (UNIT/2 - 5),
            fill='yellow')

        # create red rect: draw the agent as a red square, initially in the top-left cell
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')

        # pack all: display the canvas
        self.canvas.pack()


    def get_state(self, rect):
        # convert the canvas coordinates of a rectangle to a (column, row) state tuple,
        # using the normalized center of its grid cell, e.g.
        # |(1,1)|(2,1)|(3,1)|...
        # |(1,2)|(2,2)|(3,2)|...
        # |(1,3)|(2,3)|(3,3)|...
        # |....
        x0, y0, x1, y1 = self.canvas.coords(rect)
        x_center = (x0 + x1) / 2
        y_center = (y0 + y1) / 2
        state = ((x_center - (UNIT / 2)) / UNIT + 1, (y_center - (UNIT / 2)) / UNIT + 1)
        return state


    def reset(self):
        self.update()
        self.after(500) # delay 500ms
        self.canvas.delete(self.rect)   # delete origin rectangle
        origin = np.array([UNIT/2, UNIT/2])
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')
        # return observation 
        return self.get_state(self.rect)   

    

    def step(self, action):
        # one interaction step between the agent and the environment
        s = self.get_state(self.rect)   # the agent's current grid position
        base_action = np.array([0, 0])
        reach_boundary = False
        if action == self.action_space[0]:   # up
            if s[1] > 1:
                base_action[1] -= UNIT
            else: # hitting the boundary: reward = -1 and the agent stays in place
                reach_boundary = True

        elif action == self.action_space[1]:   # down
            if s[1] < MAZE_H:
                base_action[1] += UNIT
            else:
                reach_boundary = True   

        elif action == self.action_space[2]:   # right
            if s[0] < MAZE_W:
                base_action[0] += UNIT
            else:
                reach_boundary = True

        elif action == self.action_space[3]:   # left
            if s[0] > 1:
                base_action[0] -= UNIT
            else:
                reach_boundary = True

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        s_ = self.get_state(self.rect)  # next state

        # reward function
        if s_ == self.get_state(self.oval):     # reach the terminal
            reward = 1
            done = True
            s_ = 'success'
        elif s_ == self.get_state(self.hell1): # reach the block
            reward = -1
            s_ = 'block_1'
            done = False
        elif s_ == self.get_state(self.hell2):
            reward = -1
            s_ = 'block_2'
            done = False
        else:
            reward = 0
            done = False
            if reach_boundary:
                reward = -1

        return s_, reward, done

    def render(self):
        time.sleep(0.15)
        self.update()




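if __name__ == '__main__':
    def test():
        for t in range(10):
            s = env.reset()
            print(s)
            # cap the episode length: an always-'right' policy never reaches the cheese
            for _ in range(100):
                env.render()
                a = np.random.choice(env.action_space)   # random action
                s, r, done = env.step(a)
                print(s, r)
                if done:
                    break
    env = Maze()
    env.after(100, test)      # call test() after a 100 ms delay
    env.mainloop()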



An important part of the environment is the reward function, which is designed as follows:

$$
\text{reward} = \begin{cases}
1, & \text{if the mouse reaches the cheese} \\
-1, & \text{if the mouse reaches a trap or hits the boundary} \\
0, & \text{otherwise}
\end{cases}
$$
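The piecewise reward can be sketched as a small pure function. This is a hypothetical helper (reward, GOAL and TRAPS are illustrative names, not part of the original code); it assumes cells are the (column, row) tuples produced by get_state, placed as in _build_maze (hell1 at (3, 2), hell2 at (2, 3), the cheese at (3, 3)), and it checks the cases in the same order as step. Float tuples such as (3.0, 3.0) also work, since 3.0 == 3 in Python.

```python
GOAL = (3, 3)              # cell of the cheese (the oval)
TRAPS = {(3, 2), (2, 3)}   # cells of hell1 and hell2


def reward(next_cell, hit_boundary):
    """Return (reward, done) following the piecewise reward function."""
    if next_cell == GOAL:
        return 1, True                 # reached the cheese: episode ends
    if next_cell in TRAPS or hit_boundary:
        return -1, False               # trap or wall: penalty, episode goes on
    return 0, False                    # any other move


print(reward((3, 3), False))   # (1, True)
print(reward((1, 1), True))    # (-1, False)
```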

Some functions of the class Maze deserve explanation.

  • First, _build_maze draws the initial maze: the grid lines, the two traps, the cheese, and the mouse in the top-left cell.
    The state of each cell is derived from the center of the rectangle drawn in it (see get_state).
  • Second, get_state converts the canvas coordinates of a rectangle into a grid index $(\text{column}, \text{row})$ such as $(1,1), (1,2), \cdots$.
  • Third, reset restores the initial state, i.e., it places the mouse back in the starting cell and returns that state.
  • Then, step lets the agent interact with the environment for one step and returns the next state, the reward, and whether the episode is done.
  • Finally, render refreshes the window, pausing 0.15 s so that the movement is visible.
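The coordinate conversion in get_state can be checked without opening a window. The hypothetical helper coords_to_state below repeats its arithmetic on plain numbers; the sample rectangles are the ones _build_maze draws for the mouse's starting cell (origin ± (UNIT/2 - 5)) and for the cheese.

```python
UNIT = 40  # pixels per grid cell, as in the environment code


def coords_to_state(x0, y0, x1, y1, unit=UNIT):
    """Map a rectangle's canvas coordinates to a (column, row) grid index.

    Same arithmetic as Maze.get_state: take the rectangle's center,
    shift it by half a cell, divide by the cell size, and count from 1.
    """
    x_center = (x0 + x1) / 2
    y_center = (y0 + y1) / 2
    return ((x_center - unit / 2) / unit + 1, (y_center - unit / 2) / unit + 1)


# The mouse's starting rectangle spans (5, 5) to (35, 35).
print(coords_to_state(5, 5, 35, 35))       # (1.0, 1.0)
# The cheese is drawn around the center (100, 100).
print(coords_to_state(85, 85, 115, 115))   # (3.0, 3.0)
```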

2.3 Tabular Sarsa Algorithm

import numpy as np
import pandas as pd


class RL():
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # action list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy # probability of choosing the greedy action (exploration rate is 1 - e_greedy)
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table, use the coordinate as the observation

            # self.q_table = self.q_table.append(       # DataFrame.append was removed in pandas 2.0; use pd.concat instead
            #     pd.Series(
            #         [0]*len(self.actions),
            #         index=self.q_table.columns,
            #         name=state,
            #     )
            # )

            self.q_table = pd.concat(
                [
                self.q_table,
                pd.DataFrame(
                        data=np.zeros((1,len(self.actions))),
                        columns = self.q_table.columns,
                        index = [state]
                    )
                ]
            )

    def choose_action(self, observation):
        """
            Choose an action for the given state with the epsilon-greedy policy
        """
        self.check_state_exist(observation)
        # action selection: epsilon-greedy
        # (self.epsilon is the probability of acting greedily, not of exploring)
        if np.random.uniform() < self.epsilon:
            # choose the best action
            state_action = self.q_table.loc[observation, :]
            # several actions may share the maximum value, so choose randomly among them;
            # state_action == np.max(state_action) generates a boolean mask
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        pass



class SarsaTable(RL):
    """
        Implement Sarsa algorithm which is on-policy
    """
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable,self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'success' :
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

We store the Q-table as a pandas DataFrame, with one row per state and one column per action. The functions are explained as follows.

  • First, check_state_exist checks whether a state is already in the Q-table and, if not, appends a row of zeros for it. States are thus added lazily, the first time they are visited.
  • Second, choose_action follows the $\epsilon$-greedy policy

$$
\pi(a|s) = \begin{cases}
1 - \dfrac{\epsilon}{|\mathcal{A}(s)|}\big(|\mathcal{A}(s)|-1\big), & \text{for the greedy action} \\
\dfrac{\epsilon}{|\mathcal{A}(s)|}, & \text{for the other } |\mathcal{A}(s)|-1 \text{ actions}
\end{cases}
$$

    Note that in the code the parameter e_greedy (stored as self.epsilon, default 0.9) is the probability of acting greedily, so the $\epsilon$ above corresponds to 1 - e_greedy = 0.1. Because the random branch samples from all four actions, including the greedy one, the greedy action is chosen with probability $0.9 + 0.1/4 = 0.925 = 1 - \frac{0.1}{4}\cdot 3$, which matches the formula (ties aside, which the code breaks at random).

  • Third, learn updates the Q-value with the Sarsa rule, which relies on the sample $(s_t,a_t,r_{t+1},s_{t+1},a_{t+1})$: the current state, current action, immediate reward, next state, and next action, respectively. Sarsa is on-policy: $a_{t+1}$ is the action the $\epsilon$-greedy policy actually takes next. In the code, learning_rate plays the role of $\alpha_t$, and when $s_{t+1}$ is the terminal state 'success' the target is just $r_{t+1}$.

$$
\text{Sarsa}: \begin{cases}
q_{t+1}(s_t,a_t) = q_t(s_t,a_t) - \alpha_t(s_t,a_t)\Big[q_t(s_t,a_t) - \big(r_{t+1} + \gamma\, q_t(s_{t+1},a_{t+1})\big)\Big] \\
q_{t+1}(s,a) = q_t(s,a), \quad \text{for all } (s,a) \ne (s_t,a_t)
\end{cases}
$$
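The two learning pieces can be sketched with a plain dict instead of a DataFrame, assuming a missing (state, action) entry means 0, which is what check_state_exist provides. The names epsilon_greedy, greedy_probs and sarsa_update are hypothetical: epsilon_greedy mirrors RL.choose_action, greedy_probs evaluates the policy formula with $\epsilon$ = 1 - e_greedy, and sarsa_update mirrors SarsaTable.learn.

```python
import random

ACTIONS = ['up', 'down', 'right', 'left']


def epsilon_greedy(q_row, e_greedy=0.9, rng=random):
    """Sample an action like RL.choose_action: greedy with probability e_greedy."""
    if rng.random() < e_greedy:
        best = max(q_row[a] for a in ACTIONS)
        # break ties at random among the best actions
        return rng.choice([a for a in ACTIONS if q_row[a] == best])
    return rng.choice(ACTIONS)  # explore: uniform over all actions


def greedy_probs(q_row, e_greedy=0.9):
    """pi(a|s) from the epsilon-greedy formula with epsilon = 1 - e_greedy.

    Ties are ignored here: the first maximal action counts as greedy.
    """
    eps = 1 - e_greedy
    n = len(ACTIONS)
    best = max(ACTIONS, key=lambda a: q_row[a])
    return {a: (1 - eps / n * (n - 1)) if a == best else eps / n for a in ACTIONS}


def sarsa_update(q, s, a, r, s_, a_, lr=0.01, gamma=0.9, terminal='success'):
    """One tabular Sarsa update, like SarsaTable.learn; missing entries count as 0."""
    q_predict = q.get((s, a), 0.0)
    if s_ != terminal:
        q_target = r + gamma * q.get((s_, a_), 0.0)  # bootstrap from q(s', a')
    else:
        q_target = r                                  # terminal: no bootstrap
    q[(s, a)] = q_predict + lr * (q_target - q_predict)
    return q[(s, a)]


q = {}
print(sarsa_update(q, 'block_2', 'right', 1, 'success', 'up'))   # 0.01
print(sarsa_update(q, (4, 4), 'right', -1, (4, 4), 'up'))        # -0.01
```

The dict keeps the sketch dependency-free; the DataFrame in the original has the advantage that printing it shows the whole table at once.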

2.4 Run this main

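Running the main script below executes all of the code. It imports the environment from maze_env_custom and the agent from RL_brain, so the two listings above are assumed to be saved as maze_env_custom.py and RL_brain.py, respectively.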

from maze_env_custom import Maze
from RL_brain import SarsaTable

MAX_EPISODE = 30


def update():
    for episode in range(MAX_EPISODE):
        # initial observation: the agent's grid position, e.g. (1.0, 1.0)
        observation = env.reset()   

        # RL choose action based on observation ['up', 'down', 'right', 'left']
        action = RL.choose_action(str(observation))

        while True:
            # refresh the window
            env.render()

            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)
            

            action_ = RL.choose_action(str(observation_))


            # RL learn from this transition
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # swap observation
            observation = observation_
            action = action_

            # break while loop when end of this episode
            if done:
                break

        # show q_table
        print(RL.q_table)
        print('\n')

    # end of game
    print('game over')
    env.destroy()

if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(env.action_space)

    env.after(100, update)
    env.mainloop()
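For running without a display, the same loop can be sketched without tkinter or pandas. This is a hypothetical re-implementation, not part of the original code: the names step, choose and train, the dict-based Q-table, and the max_steps cap per episode (which the original loop does not have) are assumptions. The maze layout, the rewards, the non-terminal traps, and the default hyper-parameters follow the listings above.

```python
import random

MAZE_W, MAZE_H = 4, 4
START, GOAL = (1, 1), (3, 3)
TRAPS = {(3, 2): 'block_1', (2, 3): 'block_2'}
ACTIONS = ['up', 'down', 'right', 'left']
MOVES = {'up': (0, -1), 'down': (0, 1), 'right': (1, 0), 'left': (-1, 0)}


def step(cell, action):
    """Return (next_cell, next_state, reward, done) for one move."""
    col, row = cell[0] + MOVES[action][0], cell[1] + MOVES[action][1]
    if not (1 <= col <= MAZE_W and 1 <= row <= MAZE_H):
        return cell, cell, -1, False          # boundary: stay in place, reward -1
    nxt = (col, row)
    if nxt == GOAL:
        return nxt, 'success', 1, True        # cheese: terminal, reward 1
    if nxt in TRAPS:
        return nxt, TRAPS[nxt], -1, False     # trap: reward -1, not terminal
    return nxt, nxt, 0, False


def choose(q, state, e_greedy, rng):
    """Epsilon-greedy with random tie-breaking, as in RL.choose_action."""
    if rng.random() < e_greedy:
        best = max(q.get((state, a), 0.0) for a in ACTIONS)
        return rng.choice([a for a in ACTIONS if q.get((state, a), 0.0) == best])
    return rng.choice(ACTIONS)


def train(episodes=30, lr=0.01, gamma=0.9, e_greedy=0.9, max_steps=1000, seed=0):
    """Tabular Sarsa; returns a dict mapping (state, action) to its Q-value."""
    rng = random.Random(seed)
    q = {}
    for _ in range(episodes):
        cell, state = START, START
        action = choose(q, state, e_greedy, rng)
        for _ in range(max_steps):            # safety cap (not in the original)
            cell, state_, r, done = step(cell, action)
            action_ = choose(q, state_, e_greedy, rng)
            target = r if done else r + gamma * q.get((state_, action_), 0.0)
            old = q.get((state, action), 0.0)
            q[(state, action)] = old + lr * (target - old)
            state, action = state_, action_
            if done:
                break
    return q


if __name__ == '__main__':
    q_table = train()
    for (s, a), v in sorted(q_table.items(), key=str):
        print(s, a, round(v, 4))
```

Since every reward lies in [-1, 1] and only the terminal transition pays +1, all learned values stay within [-1/(1-γ), 1] = [-10, 1], which is a quick sanity check on any run.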

2.5 Check the Q table

After training, we can inspect the Q-table to judge whether the learning is reasonable. The Q-table we obtained is as follows:

                      up      down     right          left
(1.0, 1.0) -6.837352e-02 -0.000135 -0.000266 -2.970185e-02
(2.0, 1.0) -4.901299e-02 -0.000334 -0.000484 -6.039572e-04
(2.0, 2.0) -3.988164e-04 -0.049010 -0.038785 -2.737623e-04
block_1     0.000000e+00  0.049010  0.000000  0.000000e+00
(4.0, 2.0) -2.646359e-04  0.001314 -0.019900 -1.000000e-02
(4.0, 1.0) -4.900994e-02  0.000014 -0.010000 -3.128178e-06
(3.0, 1.0) -2.970450e-02 -0.029433 -0.000516 -2.078845e-04
(1.0, 2.0) -4.933690e-04 -0.000374 -0.000951 -3.940947e-02
block_2    -1.979099e-07  0.000000  0.010000 -1.531800e-07
(1.0, 3.0) -3.525635e-04 -0.000056 -0.010000 -3.940439e-02
(1.0, 4.0) -7.194310e-07 -0.010000  0.000591 -1.990000e-02
(2.0, 4.0) -1.000000e-02 -0.019900  0.012381  0.000000e+00
(3.0, 4.0)  1.654862e-01  0.000000  0.000000  0.000000e+00
(4.0, 4.0)  0.000000e+00  0.000000 -0.010000  0.000000e+00
(4.0, 3.0)  0.000000e+00  0.000000  0.000000  5.851985e-02
success     0.000000e+00  0.000000  0.000000  0.000000e+00

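For example, in the starting cell (1.0, 1.0), moving up or left hits the boundary and yields reward $-1$, so these two actions have clearly negative values (-0.068 for up and -0.030 for left), much lower than down (-0.000135) and right (-0.000266). Conversely, the actions that lead into the cheese have the largest values, such as up from (3.0, 4.0) (0.165) and left from (4.0, 3.0) (0.059).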
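As a rough sanity check, two entries can be reproduced by hand from the Sarsa update with the default $\alpha = 0.01$ and $\gamma = 0.9$, under the assumption (not recorded in the table) that each pair was updated exactly once from its initial value 0. From block_2, the trap at (2, 3), moving right enters the cheese, a terminal transition with reward 1; from (4.0, 4.0), moving right hits the boundary with reward $-1$, and on a first visit $q_t(s_{t+1},a_{t+1}) = 0$:

$$
\begin{aligned}
q(\text{block\_2}, \text{right}) &= 0 - 0.01\,\big[0 - 1\big] = 0.01, \\
q\big((4,4), \text{right}\big) &= 0 - 0.01\,\big[0 - (-1 + 0.9 \times 0)\big] = -0.01,
\end{aligned}
$$

which agree with the entries 0.010000 and -0.010000 in the table.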


Reference

Zhao Shiyu's course, Mathematical Foundation of Reinforcement Learning
Mofan's Reinforcement Learning course
