强化学习------Qlearning算法

news2024/11/27 18:48:41

简介

Q learning 算法是一种value-based的强化学习算法,Q是quality的缩写,Q函数 Q(state,action)表示在状态state下执行动作actionquality, 也就是能获得的Q value是多少。算法的目标是最大化Q值,通过在状态state下所有可能的动作中选择最好的动作来达到最大化期望reward

Q learning算法使用Q table来记录不同状态下不同动作的预估Q值。在探索环境之前,Q table会被随机初始化,当agent在环境中探索的时候,它会用贝尔曼方程(ballman equation)来迭代更新Q(s,a), 随着迭代次数的增多,agent会对环境越来越了解,Q 函数也能被拟合得越来越好,直到收敛或者到达设定的迭代结束次数。
伪代码如下:
在这里插入图片描述
整个算法就是一直不断更新 Q table 里的值, 然后再根据新的值来判断要在某个 state 采取怎样的 action. Qlearning 是一个 off-policy 的算法, 因为里面的 max action 让 Q table 的更新可以不基于正在经历的经验(可以是现在学习着很久以前的经验,甚至是学习他人的经验). 不过这一次的例子, 我们没有运用到 off-policy, 而是把 Qlearning 用在了 on-policy 上, 也就是现学现卖, 将现在经历的直接当场学习并运用. On-policy 和 off-policy 的差别我们会在之后的 [Deep Q network (off-policy)] 学习中见识到. 而之后的教程也会讲到一个 on-policy (Sarsa) 的形式, 我们之后再对比.

算法实战

我们使用openAI的gym中的CliffWalking-v0作为环境

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import numpy as np
import gym
import time
import gridworld

#Sarsa算法
class QLearning():

    def __init__(self,num_states,num_actions,e_greed=0.1,lr=0.9,gamma=0.8):
        #建立Q表格
        self.Q = np.zeros((num_states,num_actions))
        self.e_greed = e_greed   #探索概率
        self.num_states = num_states
        self.num_actions = num_actions
        self.lr = lr   #学习率
        self.gamma = gamma #折扣因子


    def predict(self,state):
        """
        通过当前状态预测下一个动作
        :param state:
        :return:
        """
        #获取当前状态的所有动作的切片
        Q_list = self.Q[state,:]
        #随机选取其中最大值中的某一个(防止存在多个最大值时,总是选最前面的问题)
        action = np.random.choice(np.flatnonzero(Q_list == Q_list.max()))
        return  action

    def action(self,state):
        """
        选取动作
        :param state:
        :return:
        """
        #探索,随机选择一个动作
        if np.random.uniform(0,1) < self.e_greed:
            action = np.random.choice(self.num_actions)
        else:   #直接选取最大Q值的动作
            action = self.predict(state)
        return action

    def learn(self,state,action,reward,next_state,done):
        cur_Q = self.Q[state,action]
        # 当游戏结束时,不存在next_action和next_state
        target_Q = reward + (1-float(done))*self.gamma*self.Q[next_state,:].max()

        self.Q[state,action] += self.lr*(target_Q - cur_Q)

#训练
def train_episode(env,agent,is_render):
    total_reward = 0
    #初始化环境
    state,_ = env.reset()
    while True:
        action = agent.action(state)
        #执行动作返回结果
        next_state,reward,done,_,_ = env.step(action)
        #更新参数
        agent.learn(state,action,reward,next_state,done)
        #循环执行
        state = next_state
        total_reward += reward
        if is_render:
            env.render()
        if done:
            break

    return  total_reward
#测试
def test_episode(env,agent,is_render=False):
    total_reward = 0
    # 初始化环境
    state,_ = env.reset()
    while True:
        action = agent.predict(state)
        next_state, reward, done, _,_ = env.step(action)

        state = next_state
        total_reward += reward
        env.render()
        time.sleep(0.5)
        if done:
            break

    return total_reward
#训练
def train(env,episodes=500,lr=0.1,gamma=0.9,e_greed=0.1):
    agent = QLearning(
        num_states = env.observation_space.n,
        num_actions = env.action_space.n,
        lr = lr,
        gamma = gamma,
        e_greed = e_greed
    )
    is_render = False
    #先训练episodes次
    for e in range(episodes):
        ep_reward = train_episode(env,agent,is_render)
        print('Episode %s : reward= %.1f'%(e,ep_reward))
        #每执行50轮就显示一次
        if e%50 == 0:
            is_render = True
        else:
            is_render = False
    #训练结束后,我i们测试模型
    test_reward = test_episode(env,agent)
    print('test_reward= %.1f' % (test_reward))


if __name__ == '__main__':
    env = gym.make("CliffWalking-v0")
    env = gridworld.CliffWalkingWapper(env)
    train(env)

运行效果

在这里插入图片描述

另附工具类

用于可视化游戏界面

#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html

def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=False)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':  # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t == None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)

            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':
    # 环境1:FrozenLake, 可以配置冰面是否是滑的
    # 0 left, 1 down, 2 right, 3 up
    env = gym.make("FrozenLake-v0", is_slippery=False)
    env = FrozenLakeWapper(env)

    # 环境2:CliffWalking, 悬崖环境
    # env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    # env = CliffWalkingWapper(env)

    # 环境3:自定义格子世界,可以配置地图, S为出发点Start, F为平地Floor, H为洞Hole, G为出口目标Goal
    # gridmap = [
    #         'SFFF',
    #         'FHFF',
    #         'FFFF',
    #         'HFGF' ]
    # env = GridWorld(gridmap)

    env.reset()
    for step in range(10):
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(\
                step, action, obs, reward, done, info))
        env.render()  # 渲染一帧图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1070233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day58:ARMday5,GPIO流水灯实验

汇编指令&#xff1a; .text .global _start _start: 1.设置GPIOE GPIOF寄存器的时钟使能 RCC_MP_AHB4ENSETR[5:4]->1 0x50000a28 LDR R0,0x50000a28 LDR R1,[R0] ORR R1,R1,#(0x3<<4) STR R1,[R0]2.设置PE10、PF10、PE8管脚为输出模式&#xff0c;GPIOE_MODER[21…

【gcc】RtpTransportControllerSend学习笔记 1

本文是woder大神 的文章的学习笔记。主要是大神文章: webrtc源码分析(8)-拥塞控制(上)-码率预估 的学习笔记。大神的webrtc源码分析(8)-拥塞控制(上)-码率预估 详尽而具体,堪称神作。因为直接看大神的文章,自己啥也没记住,所以同时跟着看代码。跟着大神走一遍,不求甚解,…

SpringCloud学习笔记-注册微服务到Eureka注册中心

目录 1.在该Module的pom文件中引入eureka依赖2.在该module的src/main/resources/application.yml配置文件3.启动对应的微服务4.查看微服务是否启动成功 假如我有一个微服务名字叫user-service,我需要把它注册到Eureka注册中心,则具体步骤如下: 1.在该Module的pom文件中引入eure…

MQ - 38 Serverless : 基于Serverless架构实现流式数据处理

文章目录 导图Pre概述典型的数据流场景什么是 ServerlessServerless 的定义Serverless Function如何基于 Serverless 实现数据处理数据处理流程底层架构和技术原理两种方案的优劣势对比业务案例和场景分析日志清洗场景事件流处理其他case总结导图

geecg-uniapp 源码下载运行 修改端口号 修改tabBar 修改展示数据

APP体验&#xff1a; http://jeecg.com/appIndex技术官网&#xff1a; http://www.jeecg.com安装文档&#xff1a; 快速开始 JeecgBoot 开发文档 看云视频教程&#xff1a; 零基础入门视频官方支持&#xff1a; http://jeecg.com/doc/help 一&#xff0c;下载安装 源码下载…

【力扣面试题】URL化

&#x1f451;专栏内容&#xff1a;力扣刷题⛪个人主页&#xff1a;子夜的星的主页&#x1f495;座右铭&#xff1a;前路未远&#xff0c;步履不停 目录 一、题目描述二、题目分析1、使用String内部方法2、使用StringBuilder 一、题目描述 题目链接&#xff1a;URL化 编写一种…

【软考】5.2 传输介质/通信方式/IP地址/子网划分

《传输介质》 双绞线&#xff1a;网线&#xff1b;传输距离在100m以内 无屏蔽双绞线&#xff1a;UTP&#xff1b;可靠性相对较低屏蔽双绞线&#xff1a;STP&#xff1b;屏蔽怕干扰&#xff1b;可靠性相对较高&#xff1b;一般用于对传输可靠性要求很高的场合 网线&#xff1a…

【Java 进阶篇】HTML块级元素详解

HTML&#xff08;Hypertext Markup Language&#xff09;是用于创建网页的标记语言。在HTML中&#xff0c;元素被分为块级元素和内联元素两种主要类型。块级元素通常用于构建网页的结构&#xff0c;而内联元素则嵌套在块级元素内&#xff0c;用于添加文本和其他内容。本文将重点…

卷积层与池化层输出的尺寸的计算公式详解

用文字简单表述如下 卷积后尺寸计算公式&#xff1a; (图像尺寸-卷积核尺寸 2*填充值)/步长1 池化后尺寸计算公式&#xff1a; (图像尺寸-池化窗尺寸 2*填充值)/步长1 一、卷积中的相关函数的参数定义如下&#xff1a; in_channels(int) – 输入信号的通道 out_channels(int)…

ubuntu2204配置仓库为阿里源

官网上支持到2004&#xff0c;2204需要手动更改一下 deb https://mirrors.aliyun.com/ubuntu/ jammy main restricted universe multiverse deb-src https://mirrors.aliyun.com/ubuntu/ jammy main restricted universe multiversedeb https://mirrors.aliyun.com/ubuntu/ jam…

【数组】二分查找(减不减一,看初始化!)

一、力扣习题链接 704. 二分查找 - 力扣&#xff08;LeetCode&#xff09; 二、思路 这道题目的前提是数组为有序数组&#xff0c;同时题目还强调数组中无重复元素&#xff0c;因为一旦有重复元素&#xff0c;使用二分查找法返回的元素下标可能不是唯一的&#xff0c;这些都是…

四、二叉树-下(Binary tree)

文章目录 一、算法核心二、经典例题1.[226. 翻转二叉树](https://leetcode.cn/problems/invert-binary-tree/description/)&#xff08;1&#xff09;思想&#xff08;2&#xff09;代码&#xff08;3&#xff09;复杂度分析 2.[101. 对称二叉树](https://leetcode.cn/problems…

【JavaSE】Synchronized实现原理

我们通常来使用synchronized来保证原子性&#xff0c;保证线程的安全。 但其实synchronized的底层是由一对monitorenter/monitorexit指令实现&#xff0c;每一个对象都有一个监视器&#xff08;monitor&#xff09;&#xff0c;而synchronized是通过对象内部叫监听器&#xff…

11.3 读图举例

一、低频功率放大电路 图11.3.1所示为实用低频功率放大电路&#xff0c;最大输出功率为 7 W 7\,\textrm W 7W。其中 A \textrm A A 的型号为 LF356N&#xff0c; T 1 T_1 T1​ 和 T 3 T_3 T3​ 的型号为 2SC1815&#xff0c; T 4 T_4 T4​ 的型号为 2SD525&#xff0c; T 2…

一款超好用的开源内存剖析器,今天教你怎么用!

Memray是一个由彭博社开发的、开源内存剖析器&#xff1b;开源一个多月&#xff0c;已经收获了超8.4k的star&#xff0c;是名副其实的明星项目。今天我们就给大家来推荐这款python内存分析神器。 Memray可以跟踪python代码、本机扩展模块和python解释器本身中内存分配&#xff…

【C++】运算符重载 ⑫ ( 等于判断 == 运算符重载 | 不等于判断 != 运算符重载 | 完整代码示例 )

文章目录 一、数组类 等号 运算符重载1、等于判断 运算符重载2、不等于判断 ! 运算符重载 二、完整代码示例1、Array.h 数组头文件2、Array.cpp 数组实现类3、Test.cpp 测试类4、执行结果 一、数组类 等号 运算符重载 1、等于判断 运算符重载 使用 成员函数 实现 等于判断 …

盒子模型的基础

盒子模型 边框&#xff08;border&#xff09; border可以设置元素的边框&#xff0c;边框分成三部分&#xff0c;边框的&#xff08;粗细&#xff09;边框的样式&#xff0c;边框的颜色 <style>div {width: 100px;height: 100px;border-width: 200;border-style: 边框…

【运行时数据区和程序计数器】

文章目录 1. 运行时数据区2. 程序计数器(PC 寄存器) 1. 运行时数据区 当我们通过前面的&#xff1a;类的加载-> 验证 -> 准备 -> 解析 -> 初始化 这几个阶段完成后&#xff0c;就会用到执行引擎对我们的类进行使用&#xff0c;同时执行引擎将会使用到我们运行时数据…

你了解的SpringCloud核心组件有哪些?他们各有什么作用?

SpringCloud 1.什么是 Spring cloud Spring Cloud 为最常见的分布式系统模式提供了一种简单且易于接受的编程模型&#xff0c;帮助开发人员构建有弹性的、可靠的、协调的应用程序。Spring Cloud 构建于 Spring Boot 之上&#xff0c;使得开发者很容易入手并快速应用于生产中。…

px4仿真实现无人机自主飞行

一,确定消息类型 无人机通过即在电脑是现自主飞行:思路如下。 通过Mavros功能包,将ROS消息转换为Mavlink消息。实现对无人机的控制。 几种消息之间的关系如下: 对于ROS数据,就是我们机载电脑执行ROS系统的数据。 对于Mavros消息,就是Mavros功能包内部的消息。查询网站…