RL笔记:基于策略迭代求CliffWaking-v0最优解(python实现)

news2024/12/26 21:59:51

目录

1. 概要

2. 实现

3. 运行结果


1. 概要

        CliffWalking-v0是gym库中的一个例子[1],是从Sutton-RLbook-2020的Example6.6改编而来。不过本文不是关于gym中的CliffWalking-v0如何玩的,而是关于基于策略迭代求该问题最优解的实现例。

        CliffWalking-v0的游戏环境是一个4*12的网格(如上图【1】所示)。游戏规则如下:

        Agent从左下角出发,在每个网格中,可以采取{UP,DOWN,RIGHT,LEFT}中任意一个动作。但是,如果采取动作后会越出边界的话,就退回原地。到达右下角的网格的话,一局游戏结束。

        最下面一排网格中除了左下角(出发网格)和右下角(目标网格)以外,是所谓的悬崖网格,如果采取行动后掉入悬崖网格,会得到-100点的奖励(或者说惩罚),并且会被直接扔回出发点。其它情况下,每次行动有-1点的奖励(或者说惩罚)。Agent必需最小化到达目标网格的开销(最大化奖励,或者说最小化惩罚)。

        这个游戏非常简单,不用计算,直觉就可以知道,最优策略是:在出发点向上走一格;然后在第3行一路右行;到达最右侧后向下移动一格后即到达目标网格。总的奖励是-13点。

        以下给出基于策略迭代算法来求解这个问题的最优策略,看看能不能得出以上直觉上的最优策略。

 

2. 实现

        CliffWalking-v0游戏的环境设定类似于GridWorld,所以这里采用了类似于GridWorld的状态表示方法。环境类对象创建时,用一个二维数组表示网格环境中各cell的类型,“1”表示Terminate cell;“-1”表示Cliff cells;“0”表示其它cells。如下所示:

    grid = [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ]
    grid[3][11] = 1 # Terminate cell
    for k in range(1,11):
        grid[3][11] = -1  # Cliff cells

        环境的转移状态函数P(s’,r|s,a)用Environment::transit_func()实现,如下所示:

    def transit_func(self, state, action):
        """
        Prob(s',r|s,a) stored in one dict[(s',reward)].
        """
        transition_probs = {}
        if not self.can_action_at(state):
            # Already on the terminal cell.
            return transition_probs

        opposite_direction = Action(action.value * -1)

        for a in self.actions:
            prob = 0
            if a == action:
                prob = self.move_prob
            elif a != opposite_direction:
                prob = (1 - self.move_prob) / 2

            next_state = self._move(state, a)
            if next_state.row == (self.row_length - 1) and 0 < next_state.column < (self.column_length - 1):
                reward = -100
                next_state = State(self.row_length - 1, 0) # Return to start grid when falls into cliff grid.
            else:
                reward = -1
            
            if (next_state,reward) not in transition_probs:
                transition_probs[(next_state,reward)] = prob
            else:
                transition_probs[(next_state,reward)] += prob

        return transition_probs

    def can_action_at(self, state):
        '''
        Assuming:
            grid[i][j] = 1: Terminate grid
            grid[i][j] =-1: Cliff grids
            grid[i][j] = 0: Other grids
        '''
        if self.grid[state.row][state.column] == 0:
            return True
        else:
            return False

    def _move(self, state, action):
        """
        Predict the next state upon the combination of {state, action}
        {state, action} --> next_state
        Called in transit_func()
        """
        if not self.can_action_at(state):
            raise Exception("Can't move from here!")

        next_state = state.clone()

        # Execute an action (move).
        if action == Action.UP:
            next_state.row -= 1
        elif action == Action.DOWN:
            next_state.row += 1
        elif action == Action.LEFT:
            next_state.column -= 1
        elif action == Action.RIGHT:
            next_state.column += 1

        # Check whether a state is out of the grid.
        if not (0 <= next_state.row < self.row_length):
            next_state = state
        if not (0 <= next_state.column < self.column_length):
            next_state = state

        # Entering into cliff grids is related to the correspong penalty and 
        # reset to start grid, hence will be handled upper layer.

        return next_state

        Planner类实现一个规划基类,进一步PolicyIterationPlanner类作为Planner子类实现了基于策略迭代的规划器,其中核心就是PolicyIterationPlanner:: policy_evaluation() 和 PolicyIterationPlanner::plan()。策略迭代算法在上一篇(RL笔记:动态规划(2): 策略迭代)中已经介绍,此处不再赘述。

        PolicyIterationPlanner:: policy_evaluation()实现的是策略评估,如下所示:

    def policy_evaluation(self, gamma, threshold):
        V = {}
        for s in self.env.states:
            # Initialize each state's expected reward.
            V[s] = 0

        while True:
            delta = 0
            for s in V:
                expected_rewards = []
                for a in self.policy[s]:
                    action_prob = self.policy[s][a]
                    r = 0
                    for prob, next_state, reward in self.transitions_at(s, a):
                        r += action_prob * prob * \
                             (reward + gamma * V[next_state])
                    expected_rewards.append(r)
                value = sum(expected_rewards)
                delta = max(delta, abs(value - V[s]))
                V[s] = value
            if delta < threshold:
                break

        return V

        PolicyIterationPlanner::plan()则实现了完整的策略迭代算法(策略评估部分调用了policy_evaluation())代码如下所示:

   def plan(self, gamma=0.9, threshold=0.0001):
        """
        Implement the policy iteration algorithm
        gamma    : discount factor
        threshold: delta for policy evaluation convergency judge.
        """
        self.initialize()
        states  = self.env.states
        actions = self.env.actions

        def take_max_action(action_value_dict):
            return max(action_value_dict, key=action_value_dict.get)

        while True:
            update_stable = True
            # Estimate expected rewards under current policy.
            V = self.policy_evaluation(gamma, threshold)
            self.log.append(self.dict_to_grid(V))

            for s in states:
                # Get an action following to the current policy.
                policy_action = take_max_action(self.policy[s])

                # Compare with other actions.
                action_rewards = {}
                for a in actions:
                    r = 0
                    for prob, next_state, reward in self.transitions_at(s, a):
                        r += prob * (reward + gamma * V[next_state])
                    action_rewards[a] = r
                best_action = take_max_action(action_rewards)
                if policy_action != best_action:
                    update_stable = False

                # Update policy (set best_action prob=1, otherwise=0 (greedy))
                for a in self.policy[s]:
                    prob = 1 if a == best_action else 0
                    self.policy[s][a] = prob

            # Turn dictionary to grid
            self.V_grid = self.dict_to_grid(V)
            self.iters = self.iters + 1
            print('PolicyIteration: iters = {0}'.format(self.iters))
            self.print_value_grid()
            print('******************************')

            if update_stable:
                # If policy isn't updated, stop iteration
                break

3. 运行结果

        运行结果如下(右下角可以忽视,因为到达右下角后游戏结束了,不会再有进一步的行动了):

        由此可见,以上实现的确得出了跟直感相同的最优策略。

完整代码参见:reinforcement-learning/CliffWalking-v0.py

本强化学习之学习笔记系列总目录参见:强化学习笔记总目录

[1] Cliff Walking - Gym Documentation (gymlibrary.dev)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/389872.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Promise-异步回调

1.理解Promise promise是ES6提出的异步编程的新的解决方案&#xff0c;通过链式调用解决ajax回调地狱 从语法上看&#xff0c;promise是一个构造函数&#xff0c;自己身上有all、reject、resolve方法&#xff0c;原型上有then、catch方法 从功能上看&#xff0c;Promise对象用…

BloomFilter原理学习

文章目录BloomFilter简单介绍BloomFilter中的数学知识fpp(误判率/假阳性)的计算k的最小值公式总结编程语言实现golang的实现[已知n, p求m和k](https://github.com/bits-and-blooms/bloom/blob/master/bloom.go#L133)参考BloomFilter简单介绍 BloomFilter我们可能经常听到也在使…

瑞吉外卖——day2

目录 一、新增员工 二、查询分页数据 三、启用、禁用员工账户、编辑员工信息 一、新增员工 点击左上角新增员工 页面如下&#xff1a; 我们随便填数据 &#xff0c;点击保存&#xff0c;请求的地址如下 返回前端可以看到请求方式为Post 在employeeController中编写对应的代…

Elasticsearch:图片相似度搜索的 5 个技术组成部分

作者&#xff1a;Radovan Ondas&#xff0c;Bernhard Suhm 在本系列博文的第一部分中&#xff0c;我们介绍了图像相似度搜索&#xff0c;并回顾了一种可以降低复杂性并便于实施的高级架构。 此博客解释了实现图像相似性搜索应用程序所需的每个组件的基本概念和技术注意事项。 学…

Python采集本地二手房,一键知晓上万房源信息

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 所以今天教大家用Python来采集本地房源数据&#xff0c;帮助大家筛选好房。 话不多说&#xff0c;让我们开始愉快的旅程吧~ 更多精彩内容、资源皆可点击文章下方名片获取此处跳转 本文涉及知识点 采集基本流程 requests 发送…

【Java】Spring Boot整合WebSocket

【Java】Spring Boot整合WebSocket WebSocket简介 WebSocket是一种协议&#xff0c;用于实现客户端和服务器之间的双向通信。它可以在单个TCP连接上提供全双工通信&#xff0c;避免了HTTP协议中的请求-响应模式&#xff0c;从而实现更高效的数据交换。WebSocket协议最初由HTM…

【计算几何】贝塞尔曲线 B样条曲线简介及其离散化 + Python C++ 代码实现

文章目录一、贝塞尔曲线二、B样条曲线三、Python 代码实现B样条曲线离散化四、C 代码实现B样条曲线离散化4.1 主要代码4.2 其余类4.3 离散效果展示&#xff08;在CAD中展示&#xff09;本文只做简介&#xff0c;关于贝塞尔曲线和B样条曲线的详细介绍&#xff0c;请参考&#xf…

unity UGUI系统梳理 - 基本布局

偷懒了&#xff0c;部分节选unity API API 1、矩形工具 为了便于布局&#xff0c;每个 UI 元素都表示为矩形。可使用工具栏中的__矩形工具 (Rect Tool)__ 在 Scene 视图中操纵此矩形。矩形工具既可用于 Unity 的 2D 功能&#xff0c;也可用于 UI&#xff0c;实际上甚至还可用…

C/C++开发,无可避免的多线程(篇三).协程及其支持库

一、c20的协程概念 在c20标准后&#xff0c;在一些函数中看到co_await、co_yield、co_return这些关键词&#xff0c;这是c20为协程实现设计的运算符。 协程是能暂停执行以在之后恢复的函数。原来我们调用一个功能函数时&#xff0c;只要调用了以后&#xff0c;就要完整执行完该…

【Kettle-佛系总结】

Kettle-佛系总结Kettle-佛系总结1.kettle介绍2.kettle安装3.kettle目录介绍4.kettle核心概念1.转换2.步骤3.跳&#xff08;Hop&#xff09;4.元数据5.数据类型6.并行7.作业5.kettle转换1.输入控件1.csv文件输入2.文本文件输入3.Excel输入4.XML输入5.JSON输入6.表输入2.输出控件…

百度Apollo规划算法——轨迹拼接

百度Apollo规划算法——轨迹拼接引言轨迹拼接1、什么是轨迹拼接&#xff1f;2、为什么要进行轨迹拼接&#xff1f;3、结合Apollo代码为例理解轨迹拼接的细节。参考引言 在apollo的规划算法中&#xff0c;在每一帧规划开始时会调用一个轨迹拼接函数&#xff0c;返回一段拼接轨迹…

Kubernetes之服务发布

学了服务发现后&#xff0c;svc的IP只能被集群内部主机及pod才可以访问&#xff0c;要想集群外的主机也可以访问svc&#xff0c;就需要利用到服务发布。 NodePort Nodeport服务是外部访问服务的最基本方式。当我们创建一个服务的时候&#xff0c;把服务的端口映射到kubernete…

【大数据AI人工智能】常见的归一化函数有哪些?分别用数学公式详细介绍

常见的归一化函数有哪些?分别用数学公式详细介绍一下。 常见的归一化函数 常见的归一化函数包括: Min-Max 归一化Z-Score 归一化Log 归一化Sigmoid 归一化下面分别介绍这些归一化函数以及它们的数学公式。 1. Min-Max 归一化 Min-Max 归一化是将原始数据线性映射到 [0,1]…

dp模型——状态机模型C++详解

状态机定义状态机顾名思义跟状态有关系&#xff0c;但到底有什么关系呢。在实际解决的时候&#xff0c;通常把状态想成节点&#xff0c;状态的转换想成有向边的有向图&#xff0c;我们来举个例子。相信大家都玩过类似枪战的游戏&#xff08;没玩过的也听说过吧&#xff09;&…

4.创建和加入通道相关(network.sh脚本createChannel函数分析)[fabric2.2]

fabric的test-network例子有一个orderer组织、两个peer组织、每个组织一个节点&#xff0c;只有系统通道&#xff08;system-channel&#xff09;&#xff0c;没有其他应用通道。我们可以使用./network.sh createChannel命令来创建一个名为mychannel的应用通道。 一、主要概念 …

【Java开发】JUC进阶 04:线程池详解

1 线程池介绍由于频繁创建销毁线程要调用native方法比较消耗资源&#xff0c;为了保证内核的充分利用&#xff0c;所以引入了线程池的概念。&#x1f4cc; 线程池优点降低资源消耗提高响应速度方便管理&#x1f4cc; 创建线程池使用Executors创建使用ThreadPoolExecutor创建&am…

Git图解-为啥是Git?怎么装?

目录 零、学习目标 一、版本控制 1.1 团队开发问题 1.2 版本控制思想 1.2.1 版本工具 二、Git简介 2.1 简介 2.2 Git环境的搭建 三、转视频版 零、学习目标 掌握git的工作流程 熟悉git安装使用 掌握git的基本使用 掌握分支管理 掌握IDEA操作git 掌握使用git远程仓…

【教程】记录Typecho Joe主题升级与Joe魔改版

目录 升级Joe 其他魔改版 Joe主题挺好看的&#xff0c;很早之前我就装了。后来官方升级了主题&#xff0c;但没有给升级教程。这里记录一下我的升级过程&#xff0c;供大家参考。 Joe Github&#xff1a;GitHub - HaoOuBa/Joe: A Theme of Typecho 升级站点&#xff1a;小锋学…

WSL2使用Nvidia-Docker实现CUDA版本自由切换

众所周知&#xff0c;深度学习的环境往往非常麻烦&#xff0c;经常不同的项目所依赖的 torch、tensorflow 包对 CUDA 的版本也有不同的要求&#xff0c;Linux 下进行 CUDA 的管理比较麻烦&#xff0c;是一个比较头疼的问题。 随着 WSL2 对物理机显卡的支持&#xff0c;Nvidia-…

用二极管和电容过滤电源波动,实现简单的稳压 - 小水泵升压改装方案

简而言之&#xff0c;就是类似采样保持电路&#xff0c;当电源电压因为电机启动而骤降时&#xff0c;用二极管避免电容电压跟着降低&#xff0c;从而让电容上连接的低功耗芯片有一个比较稳定的供电电压。没什么特别的用处&#xff0c;省个LDO 吧&#xff0c;电压跌幅太大的时候…