强化学习------Sarsa算法

news2024/11/16 11:28:31

简介

SARSA(State-Action-Reward-State-Action)是一个学习马尔可夫决策过程策略的算法,通常应用于机器学习和强化学习学习领域中。它由RummeryNiranjan在技术论文“Modified Connectionist Q-Learning(MCQL)” 中介绍了这个算法,并且由Rich Sutton在注脚处提到了SARSA这个别名。
State-Action-Reward-State-Action这个名称清楚地反应了其学习更新函数依赖的5个值,分别是当前状态S1,当前状态选中的动作A1,获得的奖励RewardS1状态下执行A1后取得的状态S2S2状态下将会执行的动作A2。我们取这5个值的首字母串起来可以得出一个词SARSA

算法的核心思想可以简化为:

Latex代码:
用伪代码可以表示为:
在这里插入图片描述

算法实战

我们使用openAI的gym中的CliffWalking-v0作为环境

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import numpy as np
import gym
import time
import gridworld

#Sarsa算法
class Sarsa():

    def __init__(self,num_states,num_actions,e_greed=0.1,lr=0.9,gamma=0.8):
        #建立Q表格
        self.Q = np.zeros((num_states,num_actions))
        self.e_greed = e_greed   #探索概率
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = lr   #学习率
        self.gamma = gamma #折扣因子


    def predict(self,state):
        """
        通过当前状态预测下一个动作
        :param state:
        :return:
        """
        #获取当前状态的所有动作的切片
        Q_list = self.Q[state,:]
        #随机选取其中最大值中的某一个(防止存在多个最大值时,总是选最前面的问题)
        action = np.random.choice(np.flatnonzero(Q_list == Q_list.max()))
        return  action

    def action(self,state):
        """
        选取动作
        :param state:
        :return:
        """
        #探索,随机选择一个动作
        if np.random.uniform(0,1) < self.e_greed:
            action = np.random.choice(self.num_actions)
        else:   #直接选取最大Q值的动作
            action = self.predict(state)
        return action

    def learn(self,state,action,reward,next_state,next_action,done):
        cur_Q = self.Q[state,action]
        # 当游戏结束时,不存在next_action和next_state
        if done:
            target_Q = reward
        else:
            target_Q = reward + self.gamma*self.Q[next_state,next_action]

        self.Q[state,action] += self.lr*(target_Q - cur_Q)

#训练
def train_episode(env,agent,is_render):
    total_reward = 0
    #初始化环境
    state,_ = env.reset()
    action = agent.action(state)
    while True:
        #执行动作返回结果
        next_state,reward,done,_,_ = env.step(action)
        #根据状态获取动作
        next_action = agent.action(next_state)
        #更新参数
        agent.learn(state,action,reward,next_state,next_action,done)
        #循环执行
        action = next_action
        state = next_state
        total_reward += reward
        if is_render:
            env.render()
        if done:
            break

    return  total_reward
#测试
def test_episode(env,agent,is_render=False):
    total_reward = 0
    # 初始化环境
    state,_ = env.reset()
    while True:
        action = agent.predict(state)
        next_state, reward, done, _,_ = env.step(action)

        state = next_state
        total_reward += reward
        env.render()
        time.sleep(0.5)
        if done:
            break

    return total_reward
#训练
def train(env,episodes=500,lr=0.1,gamma=0.9,e_greed=0.1):
    agent = Sarsa(
        num_states = env.observation_space.n,
        num_actions = env.action_space.n,
        lr = lr,
        gamma = gamma,
        e_greed = e_greed
    )
    is_render = False
    #先训练episodes次
    for e in range(episodes):
        ep_reward = train_episode(env,agent,is_render)
        print('Episode %s : reward= %.1f'%(e,ep_reward))
        #每执行50轮就显示一次
        if e%50 == 0:
            is_render = True
        else:
            is_render = False
    #训练结束后,我i们测试模型
    test_reward = test_episode(env,agent)
    print('test_reward= %.1f' % (test_reward))


if __name__ == '__main__':
    env = gym.make("CliffWalking-v0")
    env = gridworld.CliffWalkingWapper(env)
    train(env)

运行效果

在这里插入图片描述

另附工具类

#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html

def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=False)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':  # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)

            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':
    # 环境1:FrozenLake, 可以配置冰面是否是滑的
    # 0 left, 1 down, 2 right, 3 up
    env = gym.make("FrozenLake-v0", is_slippery=False)
    env = FrozenLakeWapper(env)

    # 环境2:CliffWalking, 悬崖环境
    # env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    # env = CliffWalkingWapper(env)

    # 环境3:自定义格子世界,可以配置地图, S为出发点Start, F为平地Floor, H为洞Hole, G为出口目标Goal
    # gridmap = [
    #         'SFFF',
    #         'FHFF',
    #         'FFFF',
    #         'HFGF' ]
    # env = GridWorld(gridmap)

    env.reset()
    for step in range(10):
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(\
                step, action, obs, reward, done, info))
        env.render()  # 渲染一帧图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1067102.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ueditor

下载文件在文件夹下npm install 安装依赖&#xff08;如果没有安装 grunt , 请先在全局安装 grunt&#xff09; 在文件下载目录下 npm install3. 在终端执行 grunt default grunt default

P1014 [NOIP1999 普及组] Cantor 表

#include <bits/stdc.h> using namespace std; int main() {int n,k1;cin>>n;while (n>k) {nn-k;k;}if(k%20) cout<<n<<"/"<<(k1-n);else cout<<k1-n<<"/"<<n;return 0; }

Angular学习笔记:路由

本文是自己的学习笔记&#xff0c;主要参考资料如下。 - B站《Angular全套实战教程》&#xff0c;达内官方账号制作&#xff0c;https://www.bilibili.com/video/BV1i741157Fj?https://www.bilibili.com/video/BV1R54y1J75g/?p32&vd_sourceab2511a81f5c634b6416d4cc1067…

到github上去学别人怎么写代码

线性回归是一种线性模型&#xff0c;例如&#xff0c;假设输入变量"(x) “与单一输出变量”(y) “之间存在线性关系的模型。更具体地说&#xff0c;输出变量”(y) “可以通过输入变量”(x) "的线性组合计算得出。单变量线性回归是一种线性回归&#xff0c;只有1个输入…

【2023研电赛】东北赛区一等奖作品:基于FPGA的小型水下无线光通信端机设计

本文为2023年第十八届中国研究生电子设计竞赛东北赛区一等奖作品分享&#xff0c;参加极术社区的【有奖活动】分享2023研电赛作品扩大影响力&#xff0c;更有丰富电子礼品等你来领&#xff01;&#xff0c;分享2023研电赛作品扩大影响力&#xff0c;更有丰富电子礼品等你来领&a…

环面上 FHE 的快速自举:LUT/Automata Blind Rotate

参考文献&#xff1a; [AP14] Alperin-Sheriff J, Peikert C. Faster bootstrapping with polynomial error[C]//Advances in Cryptology–CRYPTO 2014: 34th Annual Cryptology Conference, Santa Barbara, CA, USA, August 17-21, 2014, Proceedings, Part I 34. Springer B…

如何实现浏览器的前进和后退功能?

文章来源于极客时间前google工程师−王争专栏。 如何理解栈 后进者先出&#xff0c;先进者后出&#xff0c;这就是典型的“栈”结构。 从栈的操作特性来看&#xff0c;栈是一种“操作受限”的线性表&#xff0c;只允许在一端插入和删除数据。 当某个数据集合只涉及在一端插入…

【yolov系列:yolov7改进添加SIAM注意力机制】

yolo系列文章目录 文章目录 yolo系列文章目录一、SimAM注意力机制是什么&#xff1f;二、YOLOv7使用SimAM注意力机制1.在yolov7的models下面新建SimAM.py文件2.在common里面导入在这里插入图片描述 总结 一、SimAM注意力机制是什么&#xff1f; 论文题目&#xff1a;SimAM: A …

使用访问图像三种办法,指针,迭代器,动态地址计算

使用访问图像三种办法&#xff0c;指针&#xff0c;迭代器&#xff0c;动态地址计算 指针访问 方法一 #include <opencv2/opencv.hpp> #include <iostream> #include <opencv2/highgui/highgui.hpp> #include <opencv2/imgproc/imgproc.hpp>using n…

TS中interface接口的使用

接口用来定义一个类的结构&#xff0c;用来定义一个接口中应该包含哪些属性和方法 语法结构如下&#xff1a; interface 接口名 { // 属性和方法 } 一、使用接口进行类型声明 我们声明一个对象类型可以使用如下方法&#xff1a; // 定义一个对象类型 type Mytype {name: st…

show index 中部分字段的含义

show index from 表名 查看某张表的索引情况 另: SELECT * FROM information_schema.STATISTICS WHERE TABLE_NAME "t1" 与 show index from t1 作用相似,且会返回更多的字段信息 创建一张测试表t1: CREATE TABLE t1 ( id INT ( 11 ) NOT NULL AUTO_INCREMENT, nam…

【C++】STL详解(十二)—— 用哈希表封装出unordered_map和unordered_set

​ ​&#x1f4dd;个人主页&#xff1a;Sherry的成长之路 &#x1f3e0;学习社区&#xff1a;Sherry的成长之路&#xff08;个人社区&#xff09; &#x1f4d6;专栏链接&#xff1a;C学习 &#x1f3af;长路漫漫浩浩&#xff0c;万事皆有期待 上一篇博客&#xff1a;【C】STL…

Zookeeper 集群搭建

Zookeeper Zookeeper是一个开源的分布式的&#xff0c;为分布式框架提供协调服务的Apache项目 Zookeeper 工作机制 Zookeeper从设计模式角度来理解&#xff1a;是一个基于观察者模式设计的分布式服务管理框架 一旦这些数据的状态发生变化&#xff0c;Zookeeper就将负责通知…

【EI复现】基于同步发电机转动惯量和阻尼系数协同自适应控制策略(Simulink仿真实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

华为OD机试 - 计算最大乘积(2022Q4 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷&#…

分贝定义简介

一、什么是分贝 辅助单元Bel表示任何给定部件、电路或系统的输入和输出之间的对数比L,并且可以用电压、电流或功率来表示: 如果使用场量(电压或电流)代替功率量,则: 我们可以将增益或损耗因子相加为正或负dB值,而不是将其乘以比率。 分贝与功率转化的速读表如下所示:…

流程图设计制作都有哪些好用的工具

流程图是一种直观的图形表示方式&#xff0c;通常用于显示事物的过程、步骤和关系。在现代工作中&#xff0c;设计师经常需要绘制各种流程图来解释工作过程、产品设计等。本文将为您推荐7个流程图软件&#xff0c;以帮助您快速绘制高效的流程图&#xff0c;并提高工作效率。 即…

IDEA 2021.2.2设置自动热部署

1.导入包坐标 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-devtools</artifactId><scope>runtime</scope><optional>true</optional></dependency> 2.pom.xml添加piugins插…

Redis-双写一致性

双写一致性 双写一致性解决方案延迟双删&#xff08;有脏数据的风险&#xff09;分布式锁&#xff08;强一致性&#xff0c;性能比较低&#xff09;异步通知&#xff08;保证数据的最终一致性&#xff0c;高并发情况下会出现短暂的不一致情况&#xff09; 双写一致性 当修改了数…

光伏并网逆变器低电压穿越技术研究(Simulink仿真)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…