使用python运行网格世界环境下 TD算法

news2025/3/4 8:19:24

一、概述

本代码实现了在网格世界环境中使用 TD (0)(Temporal Difference (0))算法进行策略评估,并对评估结果进行可视化展示。通过模拟智能体在网格世界中的移动,不断更新状态值函数,最终得到每个状态的价值估计。

二、依赖库

  • numpy:用于进行数值计算,如数组操作、随机数生成等。
  • matplotlib.pyplot:用于绘制图形,将状态值函数的评估结果进行可视化。

三、代码结构与详细说明

1. 导入库

收起

python

import numpy as np
import matplotlib.pyplot as plt

导入 numpy 库并将其别名为 np,导入 matplotlib.pyplot 库并将其别名为 plt,以便后续使用。

2. 定义网格世界环境类 GridWorld

收起

python

# 定义网格世界环境
class GridWorld:
    def __init__(self, size=10):
        self.size = size
        self.terminal = (3, 3)  # 终止状态

  • 功能:初始化网格世界环境。
  • 参数
    • size:网格世界的大小,默认为 10x10 的网格,即网格的边长为 10。
    • self.terminal:终止状态的坐标,这里设置为 (3, 3),当智能体到达该状态时,一个回合结束。

收起

python

    def is_terminal(self, state):
        return state == self.terminal

  • 功能:判断给定的状态是否为终止状态。
  • 参数
    • state:要判断的状态,以坐标元组 (x, y) 的形式表示。
  • 返回值:如果 state 等于 self.terminal,返回 True;否则返回 False

收起

python

    def step(self, state, action):
        x, y = state
        if self.is_terminal(state):
            return state, 0  # 终止状态不再变化

  • 功能:模拟智能体在当前状态下采取指定动作后的状态转移和奖励获取。
  • 参数
    • state:当前状态,以坐标元组 (x, y) 的形式表示。
    • action:智能体采取的动作,用整数表示,0 表示向上,1 表示向下,2 表示向左,3 表示向右。
  • 处理逻辑:如果当前状态是终止状态,则直接返回当前状态和奖励 0,因为终止状态不会再发生变化。

收起

python

        # 定义动作:0=上, 1=下, 2=左, 3=右
        dx, dy = [(-1,0), (1,0), (0,-1), (0,1)][action]
        new_x = max(0, min(self.size - 1, x + dx))
        new_y = max(0, min(self.size - 1, y + dy))
        new_state = (new_x, new_y)

  • 处理逻辑
    • 根据动作编号 action 从列表 [(-1,0), (1,0), (0,-1), (0,1)] 中选取对应的偏移量 (dx, dy)
    • 计算新的 x 和 y 坐标,使用 max 和 min 函数确保新坐标在网格世界的边界内(范围是 0 到 self.size - 1)。
    • 将新的坐标组合成新的状态 new_state

收起

python

        reward = -1  # 每步固定奖励
        return new_state, reward

  • 处理逻辑:每走一步给予固定奖励 -1,然后返回新的状态和奖励。

3. TD (0) 策略评估函数 td0_policy_evaluation

收起

python

# TD(0) 策略评估
def td0_policy_evaluation(env, episodes=5000, alpha=0.1, gamma=1.0):
    V = np.zeros((env.size, env.size))  # 初始化状态值函数

  • 功能:使用 TD (0) 算法对策略进行评估,得到每个状态的价值估计。
  • 参数
    • env:网格世界环境对象,用于获取环境信息和执行状态转移。
    • episodes:训练的回合数,默认为 5000。
    • alpha:学习率,控制每次更新状态值函数时的步长,默认为 0.1。
    • gamma:折扣因子,用于权衡当前奖励和未来奖励的重要性,默认为 1.0。
  • 处理逻辑:初始化一个大小为 (env.size, env.size) 的二维数组 V,用于存储每个状态的价值估计,初始值都为 0。

收起

python

    for _ in range(episodes):
        state = (0, 0)  # 初始状态
        while True:
            if env.is_terminal(state):
                break

  • 处理逻辑
    • 进行 episodes 个回合的训练,每个回合从初始状态 (0, 0) 开始。
    • 使用 while 循环不断执行动作,直到智能体到达终止状态。

收起

python

            # 随机策略:均匀选择动作(上下左右各25%概率)
            action = np.random.randint(0, 4)
            next_state, reward = env.step(state, action)

  • 处理逻辑
    • 使用 np.random.randint(0, 4) 随机选择一个动作,每个动作的选择概率为 25%。
    • 调用环境的 step 方法,执行选择的动作,得到下一个状态 next_state 和奖励 reward

收起

python

            # TD(0) 更新公式
            td_target = reward + gamma * V[next_state]
            td_error = td_target - V[state]
            V[state] += alpha * td_error

  • 处理逻辑
    • 根据 TD (0) 算法的更新公式,计算目标值 td_target,即当前奖励加上折扣后的下一个状态的价值估计。
    • 计算 TD 误差 td_error,即目标值与当前状态的价值估计之差。
    • 使用学习率 alpha 乘以 TD 误差,更新当前状态的价值估计。

收起

python

            state = next_state  # 转移到下一状态

  • 处理逻辑:将当前状态更新为下一个状态,继续下一次循环。

收起

python

    return V

  • 返回值:返回经过 episodes 个回合训练后得到的状态值函数 V

4. 运行算法

收起

python

# 运行算法
env = GridWorld()
V = td0_policy_evaluation(env, episodes=1000)

  • 处理逻辑
    • 创建一个 GridWorld 类的实例 env,初始化网格世界环境。
    • 调用 td0_policy_evaluation 函数,对环境进行 1000 个回合的策略评估,得到状态值函数 V

5. 可视化结果函数 plot_value_function

收起

python

# 可视化结果
def plot_value_function(V):
    fig, ax = plt.subplots()
    im = ax.imshow(V, cmap='coolwarm')

  • 功能:将状态值函数 V 进行可视化展示。
  • 处理逻辑
    • 创建一个图形对象 fig 和一个坐标轴对象 ax
    • 使用 ax.imshow 函数将状态值函数 V 以图像的形式显示出来,使用 coolwarm 颜色映射。

收起

python

    for i in range(V.shape[0]):
        for j in range(V.shape[1]):
            text = ax.text(j, i, f"{V[i, j]:.1f}",
                           ha="center", va="center", color="black")

  • 处理逻辑:遍历状态值函数 V 的每个元素,在对应的图像位置上添加文本标签,显示该状态的价值估计,保留一位小数。

收起

python

    ax.set_title("TD(0) Estimated State Value Function")
    plt.axis('off')
    plt.colorbar(im)
    plt.show()

  • 处理逻辑
    • 设置图形的标题为 “TD (0) Estimated State Value Function”。
    • 关闭坐标轴显示。
    • 添加颜色条,用于显示颜色与数值的对应关系。
    • 显示绘制好的图形。

6. 调用可视化函数

收起

python

plot_value_function(V)

调用 plot_value_function 函数,将经过 TD (0) 算法评估得到的状态值函数 V 进行可视化展示。

四、注意事项

  • 可以根据需要调整 GridWorld 类的 size 参数和 terminal 属性,改变网格世界的大小和终止状态的位置。
  • 可以调整 td0_policy_evaluation 函数的 episodesalpha 和 gamma 参数,优化策略评估的效果。
  • 代码中的随机策略是简单的均匀随机选择动作,可根据实际需求修改为更复杂的策略。

完整代码

import numpy as np
import matplotlib.pyplot as plt

# 定义网格世界环境
class GridWorld:
    def __init__(self, size=10):
        self.size = size
        self.terminal = (3, 3)  # 终止状态
        
    def is_terminal(self, state):
        return state == self.terminal
    
    def step(self, state, action):
        x, y = state
        if self.is_terminal(state):
            return state, 0  # 终止状态不再变化
        
        # 定义动作:0=上, 1=下, 2=左, 3=右
        dx, dy = [(-1,0), (1,0), (0,-1), (0,1)][action]
        new_x = max(0, min(self.size - 1, x + dx))
        new_y = max(0, min(self.size - 1, y + dy))
        new_state = (new_x, new_y)
        
        reward = -1  # 每步固定奖励
        return new_state, reward

# TD(0) 策略评估
def td0_policy_evaluation(env, episodes=5000, alpha=0.1, gamma=1.0):
    V = np.zeros((env.size, env.size))  # 初始化状态值函数
    
    for _ in range(episodes):
        state = (0, 0)  # 初始状态
        while True:
            if env.is_terminal(state):
                break
            
            # 随机策略:均匀选择动作(上下左右各25%概率)
            action = np.random.randint(0, 4)
            next_state, reward = env.step(state, action)
            
            # TD(0) 更新公式
            td_target = reward + gamma * V[next_state]
            td_error = td_target - V[state]
            V[state] += alpha * td_error
            
            state = next_state  # 转移到下一状态
    
    return V

# 运行算法
env = GridWorld()
V = td0_policy_evaluation(env, episodes=1000)

# 可视化结果
def plot_value_function(V):
    fig, ax = plt.subplots()
    im = ax.imshow(V, cmap='coolwarm')
    for i in range(V.shape[0]):
        for j in range(V.shape[1]):
            text = ax.text(j, i, f"{V[i, j]:.1f}",
                           ha="center", va="center", color="black")
    ax.set_title("TD(0) Estimated State Value Function")
    plt.axis('off')
    plt.colorbar(im)
    plt.show()

plot_value_function(V)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2309381.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在Linux上使用APT安装Sniffnet的详细步骤

一、引言 Sniffnet 是一款开源的网络流量监控工具,适用于多种Linux发行版。如果你的Linux系统使用APT(Advanced Package Tool)作为包管理器,以下是如何通过APT安装Sniffnet的详细步骤。 二、系统要求 在开始安装之前&#xff0…

zookeeper-docker版

Zookeeper-docker版 1 zookeeper概述 1.1 什么是zookeeper Zookeeper是一个分布式的、高性能的、开源的分布式系统的协调(Coordination)服务,它是一个为分布式应用提供一致性服务的软件。 1.2 zookeeper应用场景 zookeeper是一个经典的分…

StableDiffusion本地部署 3 整合包猜想

本地部署和整合包制作猜测 文章目录 本地部署和整合包制作猜测官方部署第一种第二种 StabilityMatrix下载整合包制作流程猜测 写了这么多python打包和本地部署的文章,目的是向做一个小整合包出来,不要求有图形界面,只是希望一键就能运行。 但…

数据结构(初阶)(七)----树和二叉树(前中后序遍历)

实现链式结构的二叉树 实现链式结构的二叉树遍历前序遍历中序遍历后序遍历 节点个数叶子节点个数⼆叉树第k层结点个数⼆叉树的深度/⾼度查找值为X的节点二叉树的销毁 层序遍历判断二叉树是否为完全二叉树 ⽤链表来表⽰⼀棵⼆叉树,即⽤链来指⽰元素的逻辑关系。 通常…

科技赋能筑未来 中建海龙MiC建筑技术打造保障房建设新标杆

近日,深圳梅林路6号保障房项目顺利封顶,标志着国内装配式建筑领域又一里程碑式突破。中建海龙科技有限公司(以下简称“中建海龙”)以模块化集成建筑(MiC)技术为核心,通过科技创新与工业化建造深…

json介绍、python数据和json数据的相互转换

目录 一 json介绍 json是什么? 用处 Json 和 XML 对比 各语言对Json的支持情况 Json规范详解 二 python数据和json数据的相互转换 dumps() : 转换成json loads(): 转换成python数据 总结 一 json介绍 json是什么? 实质上是一条字符串 是一种…

计算机毕设JAVA——某高校宿舍管理系统(基于SpringBoot+Vue前后端分离的项目)

文章目录 概要项目演示图片系统架构技术运行环境系统功能简介 概要 网络上许多计算机毕设项目开发前端界面设计复杂、不美观,而且功能结构十分单一,存在很多雷同的项目:不同的项目基本上就是套用固定模板,换个颜色、改个文字&…

Spring Boot 测试:单元、集成与契约测试全解析

一、Spring Boot 分层测试策略 Spring Boot 应用采用经典的分层架构,不同层级的功能模块对应不同的测试策略,以确保代码质量和系统稳定性。 Spring Boot 分层架构: Spring Boot分层架构 A[客户端] -->|HTTP 请求| B[Controller 层] …

Oracle 数据库基础入门(四):分组与联表查询的深度探索(上)

在 Oracle 数据库的学习进程中,分组查询与联表查询是进阶阶段的重要知识点,它们如同数据库操作的魔法棒,能够从复杂的数据中挖掘出有价值的信息。对于 Java 全栈开发者而言,掌握这些技能不仅有助于高效地处理数据库数据&#xff0…

机器学习的起点:线性回归Linear Regression

机器学习的起点:线性回归Linear Regression 作为机器学习的起点,线性回归是理解算法逻辑的绝佳入口。我们从定义、评估方法、应用场景到局限性,用生活化的案例和数学直觉为你构建知识框架。 回归算法 一、线性回归的定义与核心原理 定义&a…

17、什么是智能指针,C++有哪几种智能指针【高频】

智能指针其实不是指针,而是一个(模板)类,用来存储指向某块资源的指针,并自动释放这块资源,从而解决内存泄漏问题。主要有以下四种: auto_ptr 它的思想就是当当一个指针对象赋值给另一个指针对…

PyCharm接入本地部署DeepSeek 实现AI编程!【支持windows与linux】

今天尝试在pycharm上接入了本地部署的deepseek,实现了AI编程,体验还是很棒的。下面详细叙述整个安装过程。 本次搭建的框架组合是 DeepSeek-r1:1.5b/7b Pycharm专业版或者社区版 Proxy AI(CodeGPT) 首先了解不同版本的deepsee…

PyCharm怎么集成DeepSeek

PyCharm怎么集成DeepSeek 在PyCharm中集成DeepSeek等大语言模型(LLM)可以借助一些插件或通过代码调用API的方式实现,以下为你详细介绍两种方法: 方法一:使用JetBrains AI插件(若支持DeepSeek) JetBrains推出了AI插件来集成大语言模型,不过截至2024年7月,官方插件主要…

【定昌Linux系统】部署了java程序,设置开启启动

将代码上传到相应的目录,并且配置了一个.sh的启动脚本文件 文件内容: #!/bin/bash# 指定JAR文件的路径(如果JAR文件在当前目录,可以直接使用文件名) JAR_FILE"/usr/local/java/xs_luruan_client/lib/xs_luruan_…

Java零基础入门笔记:(7)异常

前言 本笔记是学习狂神的java教程,建议配合视频,学习体验更佳。 【狂神说Java】Java零基础学习视频通俗易懂_哔哩哔哩_bilibili 第1-2章:Java零基础入门笔记:(1-2)入门(简介、基础知识)-CSDN博客 第3章…

【字符串】最长公共前缀 最长回文子串

文章目录 14. 最长公共前缀解题思路:模拟5. 最长回文子串解题思路一:动态规划解题思路二:中心扩散法 14. 最长公共前缀 14. 最长公共前缀 ​ 编写一个函数来查找字符串数组中的最长公共前缀。 ​ 如果不存在公共前缀,返回空字符…

react 中,使用antd layout布局中的sider 做sider的展开和收起功能

一 话不多说,先展示效果: 展开时: 收起时: 二、实现代码如下 react 文件 import React, {useState} from react; import {Layout} from antd; import styles from "./index.module.less"; // 这个是样式文件&#…

easyExcel使用案例有代码

easyExcel 入门,完成web的excel文件创建和导出 easyExcel官网 EasyExcel 的主要特点如下: 1、高性能:EasyExcel 采用了异步导入导出的方式,并且底层使用 NIO 技术实现,使得其在导入导出大数据量时的性能非常高效。 2、易于使…

苹果廉价机型 iPhone 16e 影像系统深度解析

【人像拍摄差异】 尽管iPhone 16e支持后期焦点调整功能,但用户无法像iPhone 16系列那样通过点击屏幕实时切换拍摄主体。前置摄像头同样缺失人像深度控制功能,不过TrueTone原彩闪光灯系统在前后摄均有保留。 很多人都高估了 iPhone 的安全性,查…

视觉图像坐标转换

1. 透镜成像 相机的镜头系统将三维场景中的光线聚焦到一个平面(即传感器)。这个过程可以用小孔成像模型来近似描述,尽管实际相机使用复杂的透镜系统来减少畸变和提高成像质量。 小孔成像模型: 假设有一个理想的小孔,…