SARAS多步TD目标算法

news2025/1/21 0:56:55

SARAS多步TD目标算法

代码仓库:https://github.com/daiyizheng/DL/tree/master/09-rl
SARSA算法是on-policy 时序差分

在迭代的时候,我们基于 ϵ \epsilon ϵ-贪婪法在当前状态 S t S_t St 选择一个动作 A t A_t At ,然后会进入到下一个状态 S t + 1 S_{t+1} St+1 ,同时获得奖励 R t + 1 R_{t+1} Rt+1 ,在新的状态 S t + 1 S_{t+1} St+1 我们同样基于 ϵ \epsilon ϵ-贪婪法选择一个动作 A t + 1 A_{t+1} At+1 ,然后用它来更新我们的价值函数,更新公式如下:
Q ( S t , A t ) ← Q ( S t , A t ) + α [ R t + 1 + γ Q ( S t + 1 , A t + 1 ) − Q ( S t , A t ) ] Q\left(S_t, A_t\right) \leftarrow Q\left(S_t, A_t\right)+\alpha\left[R_{t+1}+\gamma Q\left(S_{t+1}, A_{t+1}\right)-Q\left(S_t, A_t\right)\right] Q(St,At)Q(St,At)+α[Rt+1+γQ(St+1,At+1)Q(St,At)]

  • 注意: 这里我们选择的动作 A t + 1 A_{t+1} At+1 ,就是下一步要执行的动作,这点是和Q-Learning算法的最大不同
  • 这里的 TD Target: δ t = R t + 1 + γ Q ( S t + 1 , A t + 1 ) \delta_t=R_{t+1}+\gamma Q\left(S_{t+1}, A_{t+1}\right) δt=Rt+1+γQ(St+1,At+1)
  • 在每一个非终止状态 S t S_t St
  • 进行一次更新,我们要获取 5 个数据, < S t , A t , R t + 1 , S t + 1 , A t + 1 > <S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}> <St,At,Rt+1,St+1,At+1>

那么n-step Sarsa如何计算

Q ( S t , A t ) ← Q ( S t , A t ) + α ( q t ( n ) − Q ( S t , A t ) ) Q\left(S_t, A_t\right) \leftarrow Q\left(S_t, A_t\right)+\alpha\left(q_t^{(n)}-Q\left(S_t, A_t\right)\right) Q(St,At)Q(St,At)+α(qt(n)Q(St,At))

其中 q ( n ) q_{(n)} q(n) 为:
q t ( n ) = R t + 1 + γ R t + 2 + ⋯ + γ n − 1 R t + n + γ n Q ( S t + n , A t + n ) q_t^{(n)}=R_{t+1}+\gamma R_{t+2}+\cdots+\gamma^{n-1} R_{t+n}+\gamma^n Q\left(S_{t+n}, A_{t+n}\right) qt(n)=Rt+1+γRt+2++γn1Rt+n+γnQ(St+n,At+n)

代码

  1. 构建环境
import gym


#定义环境
class MyWrapper(gym.Wrapper):

    def __init__(self):
        #is_slippery控制会不会滑
        env = gym.make('FrozenLake-v1',
                       render_mode='rgb_array',
                       is_slippery=False)

        super().__init__(env)
        self.env = env

    def reset(self):
        state, _ = self.env.reset()
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        over = terminated or truncated

        #走一步扣一份,逼迫机器人尽快结束游戏
        if not over:
            reward = -1

        #掉坑扣100分
        if over and reward == 0:
            reward = -100

        return state, reward, over

    #打印游戏图像
    def show(self):
        from matplotlib import pyplot as plt
        plt.figure(figsize=(3, 3))
        plt.imshow(self.env.render())
        plt.show()


env = MyWrapper()

env.reset()

env.show()
  1. 构建Q表
import numpy as np

#初始化Q表,定义了每个状态下每个动作的价值
Q = np.zeros((16, 4))

Q
  1. 构建数据
from IPython import display
import random


#玩一局游戏并记录数据
def play(show=False):
    state = []
    action = []
    reward = []
    next_state = []
    over = []

    s = env.reset()
    o = False
    while not o:
        a = Q[s].argmax()
        if random.random() < 0.1:
            a = env.action_space.sample()

        ns, r, o = env.step(a)

        state.append(s)
        action.append(a)
        reward.append(r)
        next_state.append(ns)
        over.append(o)

        s = ns

        if show:
            display.clear_output(wait=True)
            env.show()

    return state, action, reward, next_state, over, sum(reward)


play()[-1]
  1. 训练
#训练
def train():
    #训练N局
    for epoch in range(50000):

        #玩一局游戏,得到数据
        state, action, reward, next_state, over, _ = play()
        for i in range(len(state)):
            #计算value
            value = Q[state[i], action[i]]

            #计算target
            #累加未来N步的reward,越远的折扣越大
            #这里是在使用蒙特卡洛方法估计target
            reward_s = 0
            for j in range(i, min(len(state), i + 5)):
                reward_s += reward[j] * 0.9**(j - i)

            #计算最后一步的value,这是target的一部分,按距离给折扣
            target = Q[next_state[j]].max() * 0.9**(j - i + 1)

            #如果最后一步已经结束,则不需要考虑状态价值
            #最后累加reward就是target
            target = target + reward_s

            #更新Q表
            Q[state[i], action[i]] += (target - value) * 0.05

        if epoch % 5000 == 0:
            test_result = sum([play()[-1] for _ in range(20)]) / 20
            print(epoch, test_result)


train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1211691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从HTTP到Tomcat:揭秘Web应用的底层协议与高性能容器

WEB服务器 1. HTTP协议1.1 HTTP-概述1.1.1 介绍1.2.2 特点 2.2 HTTP-请求协议2.3 HTTP-响应协议2.3.1 格式介绍2.3.2 响应状态码 2.4 HTTP-协议解析 2. WEB服务器-Tomcat2.1 简介2.1.1 服务器概述2.1.2 Web服务器2.1.3 Tomcat 2.2 基本使用2.2.1 下载2.2.2 安装与卸载2.2.3 启动…

nestJs(二)node项目发送请求

整体演示 Get 请求参数 Get 请求的参数一般会放在 URL 上&#xff0c;这只需要Query 装饰器就行了。 Post 参数 Post 参数有些不同&#xff0c;会用到 DTO 的传输。因为数据通过 HTTP 传输是文本类型&#xff0c;因此需要将文本类型转化成代码可识别的变量。 新建 students…

ssm823基于ssm的心理预约咨询管理系统的设计与实现+vue

ssm823基于ssm的心理预约咨询管理系统的设计与实现vue 交流学习&#xff1a; 更多项目&#xff1a; 全网最全的Java成品项目列表 https://docs.qq.com/doc/DUXdsVlhIdVlsemdX 演示 项目功能演示&#xff1a; ————————————————

2013年11月26日 Go生态洞察:Go中的文本规范化

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

如何在jupyter 上安装Office365-REST-Python-Client

最近工作需要写python代码从sharepoint 上定期load 数据写入到SQL server 中&#xff0c; 首先需要安装 office365 的python库&#xff08;python库名&#xff1a; Office365-REST-Python-Client&#xff09;但是直接安装失败了。 !pip install Office365-REST-Python-Client…

HTML5学习系列之标题和正文、描述性信息

HTML5学习系列之标题和正文、描述性信息 标题和正文标题段落 描述性信息强调注解备选上下标术语代码预定义格式缩写词编辑提示引用引述换行显示修饰非文本注解 总结 标题和正文 标题 按语义轻重排列&#xff1a;h1\h2\h3\h4\h5\h6 <h1>诗词介绍</h1> <h2>…

μC/OS-II---消息队列管理1(os_q.c)

目录 消息队列的主要优点消息队列和消息邮箱消息队列相关操作消息队列创建消息队列删除在消息队列等待消息 消息队列的主要优点 消息队列的主要优点是解耦和异步通信。发送者和接收者之间不需要直接建立连接&#xff0c;它们只需要知道消息队列的名称或标识符即可。发送者将消…

Springboot项目中打印SQL语句日志

在项目中我想查看自己的SQL语句是什么&#xff0c;就是如下图的内容&#xff1a; 方法一&#xff1a;&#xff08;我常用的&#xff09; 可以在项目中的.yml配置文件中添加如下内容&#xff1a; logging:level:com.uyun.bankbranchalert.mapper: debug其中com.uyun.bankbran…

Postman+Newman+Jenkins实现接口测试持续集成

近期在复习Postman的基础知识&#xff0c;在小破站上跟着百里老师系统复习了一遍&#xff0c;也做了一些笔记&#xff0c;希望可以给大家一点点启发。 1.新建一个项目 2.设置自定义工作空间 3.执行windows的批处理命令 4.执行系统的Groovy脚本 5.生成的HTML的报告集成到Jenkin…

测试用例的书写方式以及测试模板大全

一个优秀的测试用例&#xff0c;应该包含以下信息&#xff1a; 1 &#xff09; 软件或项目的名称 2 &#xff09; 软件或项目的版本&#xff08;内部版本号&#xff09; 3 &#xff09; 功能模块名 4 &#xff09; 测试用例的简单描述&#xff0c;即该用例执行的目的或方法…

【nlp】2.3 LSTM模型

LSTM模型 1 LSTM介绍2 LSTM的内部结构图2.1 LSTM结构分析2.2 Bi-LSTM介绍2.3 使用Pytorch构建LSTM模型2.4 LSTM优缺点1 LSTM介绍 LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸…

立仪科技光谱共焦在半导体领域的应用

半导体技术在近年来以极快的速度发展&#xff0c;对质量和精密度的要求也不断提升。在这样的背景下&#xff0c;用于材料与设备研究的先进检测技术如光谱共焦成像将自然地找到一席之地。下面我们将详细探讨一下光谱共焦在半导体领域中的应用。 光谱共焦技术&#xff0c;通过在细…

互斥量保护资源

一、概念 在多数情况下&#xff0c;互斥型信号量和二值型信号量非常相似&#xff0c;但是从功能上二值型信号量用于同步&#xff0c; 而互斥型信号量用于资源保护。 互斥型信号量和二值型信号量还有一个最大的区别&#xff0c;互斥型信号量可以有效解决优先级反转现 象。 …

Spring Cloud Netflix微服务组件-Eureka

CAP理论 分区容忍是能容忍一个或一部分节点挂掉后&#xff0c;整体系统也能正常工作&#xff08;就是别的节点还是活着的&#xff09;&#xff0c;所以分布式系统中P是必须要有的。比如数据库主从架构&#xff0c;主从两个节点之间需要数据同步&#xff0c;主挂了&#xff0c;…

Mybatis中limit用法补充

limit a,b a是从第a1条数据开始&#xff0c;b是指读取几条数据 例如&#xff1a;select * from table limit 0,10 这句sql语句是说从表中获取第1条开始的10条记录 前端将page:页码    pageSize:每页多少条    这两个参数&#xff0c;传到后台。    通过这两个参数&am…

火焰图:链路追踪分析的可视化利器

什么是火焰图&#xff1f; 火焰图用于可视化分布式链路追踪&#xff0c;通过使用持续时间和不同颜色的水平条形来表示请求执行路径中的每个服务调用。分布式跟踪的火焰图包括错误、延迟数据等详情&#xff0c;帮助开发人员识别和解决应用程序中的瓶颈问题。 链路追踪与 Span …

易货:一种古老而有效的商业模式

在当今的商业世界中&#xff0c;我们常常听到关于电子商务、互联网和社交媒体等新技术的讨论。然而&#xff0c;尽管这些新技术为我们的日常生活带来了许多便利&#xff0c;但它们并没有完全取代传统的商业模式。其中&#xff0c;易货模式是一种古老而有效的商业模式&#xff0…

AI绘画工具汇总

目前市面上的AI绘画工具十分繁杂&#xff0c;以下工具可供参考&#xff1a; 1. Midjourney 添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; Midjourney&#xff1a;最主流的AI绘图工具之一&#xff0c;出图效果好&#xff0c;简单学习就可上手。需要在di…

Golang获取月份的第一天和最后一天

package mainimport ("fmt""strconv""strings""time" )func main() {month : "2023-11"result : GetMonthStartAndEnd(month)fmt.Println(result["start"] " - " result["end"]) }// 获取月…

Stable Diffusion WebUI使用AnimateDiff插件生成动画

AnimateDiff 可以针对各个模型生成的图片&#xff0c;一键生成对应的动图。 配置要求 GPU显存建议12G以上&#xff0c;在xformers或者sdp优化下显存要求至少6G以上。 要开启sdp优化&#xff0c;在启动参数加上--sdp-no-mem-attention 实际的显存使用量取决于图像大小&#…