Sarsa算法讲解及实现

news2025/2/27 21:11:04

Sarsa算法讲解及实现

1. Q表格

我们使用表格来存储每一个状态 state, 和在这个 state 每个行为 action 所拥有的 Q 值。

Q即为Q(s,a)就是在某一时刻的 s 状态下(s∈S),采取动作a (a∈A)动作能够获得收益的期望,环境会根据agent的动作反馈相应的回报reward r,所以算法的主要思想就是将State与Action构建成一张Q-table来存储Q值,然后根据Q值来选取能够获得最大的收益的动作。

例子:

Q-Tablea1a2
s1q(s1,a1)q(s1,a2)
s2q(s2,a1)q(s2,a2)
s3q(s3,a1)q(s3,a2)

2. Sarsa算法讲解

在强化学习中,Sarsa是一种对Q表格进行更新的算法,由于在强化学习环境最开始的时候,也可以认为是游戏刚开始的时候,Q表格是随机初始化的,所以需要在智能体不断与环境进行交互的时候不断地更新Q表格。

Sarsa表示的是State-Action-Reward-State-Action,是一个学习马尔可夫决策过程策略的算法,通常应用于机器学习和强化学习学习领域中。

State-Action-Reward-State-Action:这个名称清楚地反应了其学习更新函数依赖的5个值:分别是当前状态S1,当前状态选中的动作A1,获得的奖励Reward,S1状态下执行A1后取得的状态S2及S2状态下将会执行的动作A2。我们取这5个值的首字母串起来可以得出一个词SARSA。

Sarsa算法的更新公式:

Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + γ Q ( s t + 1 , a t + 1 ) − Q ( s t , a t ) ] Q(s_{\boldsymbol{t}},a_{t})\leftarrow Q(s_{\boldsymbol{t}},a_{t})+\alpha [r_{\boldsymbol{t}}+\gamma Q(s_{t+1},a_{\boldsymbol{t}+1})-Q(s_{\boldsymbol{t}},a_{t})] Q(st,at)Q(st,at)+α[rt+γQ(st+1,at+1)Q(st,at)]

Sarsa算法伪代码:
在这里插入图片描述

算法中各个参数的意义:

    1. alpha是学习率, 来决定这次的误差有多少是要被学习的, alpha是一个小于1 的数.
    1. gamma 是对未来 reward 的衰减值. 我们可以这样想象.
    1. Q表示的是Q表格.
    1. Epsilon greedy 是用在决策上的一种策略, 比如 epsilon = 0.9 时, 就说明有90% 的情况我会按照 Q 表的最优值选择行为, 10% 的时间使用随机选行为. 【这也是结合了强化学习中探索和利用的概念】

3. 代码


# agent.py

import numpy as np


class SarsaAgent(object):
    def __init__(self,
                 obs_n,
                 act_n,
                 learning_rate=0.01,
                 gamma=0.9,
                 e_greedy=0.1):
        self.obs_n = obs_n  # 状态维度
        self.act_n = act_n  # 动作维度
        self.learning_rate = learning_rate  # 学习率
        self.gamma = gamma  # 奖励衰减率
        self.e_greedy = e_greedy  # 按一定概率随机选动作
        self.Q = np.zeros((obs_n, act_n))  # Q表格 todo:嵌套一层有什么作用?

    def sample(self, obs):
        if np.random.sample() < (1 - self.e_greedy):  # 强化概念 #根据table的Q值选动作
            return self.predict(obs)
        else:
            # 随机选择一个
            return np.random.choice(self.act_n)

    def predict(self, obs):
        # 进行预测,直接选择Q值最高的那个动作

        # 拉出该状态的那一行动作
        Q_list = self.Q[obs]

        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]

        # 防止有多个最大值,所以随机选择一个
        return np.random.choice(action_list)

    def learn(self, obs, act, reward, obs_next, act_next, done):
        """
            obs: 交互前的obs, s_t
            action: 本次交互选择的action, a_t
            reward: 本次动作获得的奖励r
            next_obs: 本次交互后的obs, s_t+1
            next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1
            done: episode是否结束
        """
        # s a r s a

        if done:
            target = reward
        else:
            target = reward + self.gamma * self.Q[obs_next][act_next]

        self.Q[obs][act] += self.learning_rate * (target - self.Q[obs][act])

    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')

# train.py

import gym
from agent import SarsaAgent
import time

"""
    相当于是跑一个回合呗
"""


def run_episode(env, agent, render=False):
    total_steps = 0
    total_reward = 0

    obs = env.reset()
    action = agent.sample(obs)

    while True:
        # 采取动作,获取环境的反馈值
        next_obs, reward, done, _ = env.step(action)
        next_action = agent.sample(next_obs)

        # 训练sarsa算法
        agent.learn(obs, action, reward, next_obs, next_action, done)

        action = next_action
        obs = next_obs

        total_reward += reward
        total_steps += 1

        if render:
            env.render()  # 渲染新的一帧图形
        if done:
            break

    return total_reward, total_steps


"""
    用于在训练好Q表格后对其进行测试
"""


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward

        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f' % total_reward)
            break


def main():
    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left

    agent = SarsaAgent(obs_n=env.observation_space.n,
                       act_n=env.action_space.n,
                       learning_rate=0.1,
                       gamma=0.9,
                       e_greedy=0.1)

    is_render = False

    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        # 每隔20个episode渲染一下看看效果
        if episode % 20 == 0:
            print('it is in ' + str(episode) + 'round')
            is_render = True
        else:
            is_render = False
        # 训练结束,查看算法效果
    test_episode(env, agent)


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/347268.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java嵌入式持久化消息队列SMQ,改造自FQueue

一、说明之前项目中一直使用ConcurrentLinkedQueue做为缓冲队列&#xff08;主要是单个项目内&#xff0c;单条改批量的场景&#xff0c;多个项目间使用的是rocketmq&#xff09;&#xff0c;虽然用着方便但是是纯内存的&#xff0c;如果项目发生异常崩溃内存队列中的数据就会全…

JavaSE学习day6 进制转换和idea的调试

1.进制 1.1 常见的进制分类(掌握) 学过计算机组成原理的同学可以跳过这里。 二进制 十进制 八进制 十六进制 1.2 二进制 计算机数据在底层存储和运算的时候&#xff0c;都是以二进制的形式操作的&#xff0c;了解不同的进制&#xff0c;便于我们对数据的运算过程理解的更…

个人博客推出了更多功能

背景 Web2.0的典型代表博客&#xff0c;吸引着粉丝们打造属于自己的个人博客&#xff0c;分享自己的学习经验&#xff0c;记录自己的日常生活。随着大厂的入局&#xff0c;我们可以很容易的申请自己的个人博客&#xff0c;但是弊端就是往往会被他们控制&#xff0c;甚至封号。…

汕头市农村生活污水治理“十四五”规划行动方案

汕头&#xff0c;简称“汕”&#xff0c;广东省辖地级市&#xff0c;北接潮州&#xff0c;西邻揭阳&#xff0c;南濒南海&#xff0c;东与台湾隔海相望&#xff0c;境内韩江、榕江、练江三江入海&#xff0c;是中国大陆唯一拥有内海湾的城市。今天就来为大家介绍&#xff0c;汕…

Windows系统实现命令行(CMD)关闭指定的IIS网站

一、需求说明我们部署在Windows服务器上的IIS网站&#xff0c;需要在特定的时间停止一会后在进行重新启动该网站。二、思路分析由于需要特定的时间停止后重启网站&#xff0c;则手动操作肯定是不行的&#xff0c;需要实现自动化操作&#xff1a;①特定时间操作可以使用Windows系…

聚观早报|王慧文要做「中国版 OpenAI」;Temu斥资近亿元赞助超级碗

点击蓝字 / 关注我们今日要闻&#xff1a;王慧文要做「中国版 OpenAI」&#xff1b;Temu斥资近亿元赞助超级碗&#xff1b;新东方在线股价收跌2.8%&#xff1b;ChatGPT带动的AIGC创业热潮要来了&#xff1b;传谷歌拆分其AR部门王慧文要做「中国版 OpenAI」 2 月 13 日&#xff…

CSS中的常见单位(px,%,em,rem,vw,vh,vmax,vmin,calc)

像素(px)&百分比(%) 像素(Pixel) 长度单位&#xff0c;相对于显示器屏幕分辨率而言&#xff0c;通常在不定义显示缩放比例的情况下&#xff0c;1px对应显示器屏幕上的一个像素点。早年的pc端展示的页面基本都用这个单位。 百分比(%) 相对长度单位&#xff0c;指占用的父…

电源模块 DC-DC直流升压正负高压输出12v24v转±110V±150V±220V±250V±300V±600V

特点效率高达80%以上1*2英寸标准封装电源正负双输出稳压输出工作温度: -40℃~85℃阻燃封装&#xff0c;满足UL94-V0 要求温度特性好可直接焊在PCB 上应用HRA 1~40W系列模块电源是一种DC-DC升压变换器。该模块电源的输入电压分为&#xff1a;4.5~9V、9~18V、及18~36VDC标准&…

对比Hashtable、HashMap、TreeMap有什么不同?

第9讲 | 对比Hashtable、HashMap、TreeMap有什么不同&#xff1f; Map 是广义 Java 集合框架中的另外一部分&#xff0c;HashMap 作为框架中使用频率最高的类型之一&#xff0c;它本身以及相关类型自然也是面试考察的热点。 今天我要问你的问题是&#xff0c;对比 Hashtable、…

HTTP协议——详细讲解

目录 一、HTTP协议 1.http 2.url url的组成&#xff1a; url的保留字符&#xff1a; 3.http协议格式​编辑 ①http request ②http response 4.对request做出响应 5.GET与POST方法 ①GET ②POST 7.HTTP常见Header ①Content-Type:: 数据类型(text/html等)在上文…

JavaSE系列 打基础版

JavaSE 笔记记录P1 Java概述1.1 java编译1.2 认识JDK、JRE1.3 下载jdk和配置环境变量1.4 开发注意事项和开发细节1.5 学习java之我的需求1.6 转义字符1.7 注释1.8 代码规范1.9 dos命令 了解P2 变量数据类型变量基本使用数据类型转换P3运算符P4 控制结构P5 数组、排序和查找P6面…

突破压缩极限的AI语音编解码器

I. Speech Codecs语音编码的目的是在保持语音质量的前提下尽可能地减少传输所用的带宽&#xff0c;主要是利用人的发声过程中存在的冗余度和人的听觉特性达到压缩的目的。经过了多年的发展&#xff0c;目前语音编解码器大致可以分为以下几类&#xff1a;波形编码&#xff0c;将…

c++ 指针、引用和常量

指针、引用和常量的关系_夜悊的博客-CSDN博客 1. ① 指针是对象&#xff0c;引用不是对象&#xff08;在此可以理解为变量&#xff0c;一个变量是一个对象&#xff09; 指针不必须初始化引用只是为一个已经存在的对象所起的另一个名字&#xff08;别名&#xff09;&#xff…

亚马逊云科技汽车行业解决方案

当今&#xff0c;随着万物智联、云计算等领域的高速发展&#xff0c;创新智能网联汽车和车路协同技术正在成为车企加速发展的关键途径&#xff0c;推动着汽车产品从出行代步工具向着“超级智能移动终端”快速转变。 挑战无处不在&#xff0c;如何抢先预判&#xff1f; 随着近…

安装 GPU 版本的 tensorflow 完整版本

前言&#xff1a; 之前安装的 CPU 版本的 tensorflow 一直出问题&#xff0c;索性就直接安装 GPU 版本的 tensorflow 了&#xff08;有了GPU 就不能浪费&#xff09;。 安装过程&#xff1a; 1&#xff09;看自己有无 GPU&#xff0c;找到对应 GPU 的版本&#xff1a;任务管理…

C生万物 | 常量指针和指针常量的感性理解

文章目录&#x1f4da;引言✒常量指针&#x1f50d;介绍与分析&#x1f4f0;小结与记忆口诀✒指针常量&#x1f50d;介绍与分析&#x1f4f0;小结与记忆口诀&#x1f449;一份凉皮所引发的故事&#x1f448;总结与提炼&#x1f4da;引言 本文我们来说说大家很困惑的两个东西&am…

【蒸滴C】C语言指针入门很难?看这一篇就够了

目录 一、前言 二、指针是什么 小结&#xff1a; 三、指针变量是什么 小结&#xff1a; 四、指针在32位机器和64位机器中的差别 32位机器&#xff1a; 64位机器: 小结&#xff1a; 五、指针和指针类型 &#xff08;1&#xff09;指针的意义 &#xff08;2&#xff…

springboot自动配置原理以及spring.factories文件的作用详解

一、springboot 自动配置原理先说说我们自己的应用程序中Bean加入容器的办法&#xff1a;bean加入容器我们在应用程序的入口设置了 SpringBootApplication标签&#xff0c;默认情况下他会扫描所有次级目录。如果增加了 scanBasePackages属性&#xff0c;就会扫描所有被指定的路…

SAP ABAP根据事务码查找增强最直接的方法

下面是为任意事务代码查找用户出口的步骤&#xff1a; 方法一&#xff1a; 第 1 步&#xff1a;使用 事务代码&#xff1a;SE93。输入您要搜索用户出口的 事务代码。 在我们的场景中&#xff0c;我们将使用 CO11N。 第 2 步&#xff1a;点击显示&#xff1a; 第 3 步&#xf…

2023年浙江安全员精选真题题库及答案

百分百题库提供建筑安全员考试试题、安全员证考试真题、安全员证考试题库等,提供在线做题刷题&#xff0c;在线模拟考试&#xff0c;助你考试轻松过关。 268.注册执业人员未执行法律法规和工程质量强制性标准,造成重大安全事故的,(). A.停止执业 B.5年不予注册 C.10年不予注…