强化学习Q-Learning算法和简单迷宫代码

news2024/12/25 1:38:07

使用到的符号:

agent 代理
reward 奖励
state(s) 状态
action(a) 行为
R reward 矩阵
Q 矩阵:表示从经验中学到的知识
episode:表示 初始→目标 一整个流程

贝尔曼方程(迭代公式):

 
Q ( s , a ) ← Q ( s , a ) + α [ R ( s , a ) + γ max ⁡ a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a) \leftarrow Q(s,a) + \alpha [R(s,a) + \gamma \mathop {\max }\limits_{a'} Q(s',a') - Q(s,a)] Q(s,a)Q(s,a)+α[R(s,a)+γamaxQ(s,a)Q(s,a)]

α = 1 \alpha = 1 α=1

Q ( s , a ) ← R ( s , a ) + γ max ⁡ a ′ Q ( s ′ , a ′ ) Q(s,a) \leftarrow R(s,a) + \gamma \mathop {\max }\limits_{a'} Q(s',a') Q(s,a)R(s,a)+γamaxQ(s,a)

其中, α \alpha α 是学习率, γ \gamma γ 是超参数, Q ( s ′ , a ′ ) Q(s',a') Q(s,a) 表示下一个状态和行为。

以一个简单迷宫为例:设 agent 在房子内任意一个房间(0-4),迷宫出口为 5,即走出迷宫的条件是到房子外。
在这里插入图片描述
建模成 (“状态”对应节点,“行为”对应边):
在这里插入图片描述
构建 R 矩阵(R 固定不变):
在这里插入图片描述
初始化 Q 为零矩阵, γ = 0.8 \gamma = 0.8 γ=0.8,以一个 episode 具体说明。

在这里插入图片描述

假设初始状态为 2。

① 当前状态 2 的下一步行为只能选 3,根据迭代公式,考虑下一个状态和行为,状态 3 可能的行为:1、2 或 4。

在这里插入图片描述
② 当前状态为 3,随机地,选取转至状态 4。下一个状态和行为:状态 4 可能的行为:0、3、5。

在这里插入图片描述
③ 当前状态为 4,随机地,选取转至状态 5。下一个状态和行为:状态 5 可能的行为:1、4、5。

在这里插入图片描述

状态 5 为目标状态,故一次 episode 完成。

一次 episode 后的 Q 矩阵

在这里插入图片描述

具体 Q-Learning 算法的计算步骤:

在这里插入图片描述

迷宫实例及代码:

红色方块是 agent ,黄色圆圈和黑色方块都是目标状态,其中,黄色圆圈的奖励为 1,黑色方块的奖励为 -1。该迷宫一共有 16 个状态,每个状态可能的行为:u(上),d(下),l(左),r(右)。

在这里插入图片描述

程序主循环
from Q_Learning.maze_env import Maze
from Q_Learning.RL_brain import DQN
import time


def run_maze():
    print("====Game Start====")
    step = 0
    max_episode = 500
    for episode in range(max_episode):
        state = env.reset()  # 重置智能体位置
        step_every_episode = 0
        epsilon = episode / max_episode  # 动态变化随机值
        while True:
            if episode < 10:
                time.sleep(0.001)
            if episode > 480:
                time.sleep(0.002)
            env.render()  # 显示新位置
            action = model.choose_action(state, epsilon)  # 根据状态选择行为
            # 环境根据行为给出下一个状态,奖励,是否结束。
            next_state, reward, terminal = env.step(action)
            model.store_transition(state, action, reward, next_state)  # 模型存储经历
            # 控制学习起始时间(先积累记忆再学习)和控制学习的频率(积累多少步经验学习一次)
            if step > 200 and step % 5 == 0:
                model.learn()
            # 进入下一步
            state = next_state
            if terminal:
                print("episode=", episode, end=",")
                print("step=", step_every_episode)
                break
            step += 1
            step_every_episode += 1
    # 游戏环境结束
    print("====Game Over====")
    env.destroy()


if __name__ == "__main__":
    env = Maze()  # 环境
    model = DQN(
        n_states=env.n_states,
        n_actions=env.n_actions
    )  # 算法模型
    run_maze()
    env.mainloop()
    model.plot_cost()  # 误差曲线
环境模块 maze_env.py
import tkinter as tk
import sys
import numpy as np

UNIT = 40  # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        print("<env init>")
        super(Maze, self).__init__()
        # 动作空间(定义智能体可选的行为),action=0-3
        self.action_space = ['u', 'd', 'l', 'r']
        # 使用变量
        self.n_actions = len(self.action_space)
        self.n_states = 2
        # 配置信息
        self.title('maze')
        self.geometry("160x160")
        # 初始化操作
        self.__build_maze()

    def render(self):
        # time.sleep(0.1)
        self.update()

    def reset(self):
        # 智能体回到初始位置
        # time.sleep(0.1)
        self.update()
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        # return observation
        return (np.array(self.canvas.coords(self.rect)[:2]) - np.array(self.canvas.coords(self.oval)[:2])) / (MAZE_H * UNIT)

    def step(self, action):
        # 智能体向前移动一步:返回next_state,reward,terminal
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:  # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:  # left
            if s[0] > UNIT:
                base_action[0] -= UNIT

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        next_coords = self.canvas.coords(self.rect)  # next state

        # reward function
        if next_coords == self.canvas.coords(self.oval):
            reward = 1
            print("victory")
            done = True
        elif next_coords in [self.canvas.coords(self.hell1)]:
            reward = -1
            print("defeat")
            done = True
        else:
            reward = 0
            done = False
        s_ = (np.array(next_coords[:2]) - np.array(self.canvas.coords(self.oval)[:2])) / (MAZE_H * UNIT)
        return s_, reward, done

    def __build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)

        # create grids
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        origin = np.array([20, 20])
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        self.canvas.pack()
DQN模型 RL_brain.py
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt


class Net(nn.Module):
    def __init__(self, n_states, n_actions):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_states, 10)
        self.fc2 = nn.Linear(10, n_actions)
        self.fc1.weight.data.normal_(0, 0.1)
        self.fc2.weight.data.normal_(0, 0.1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        out = self.fc2(x)
        return out


class DQN:
    def __init__(self, n_states, n_actions):
        print("<DQN init>")
        # DQN有两个net:target net和eval net,具有选动作,存经历,学习三个基本功能
        self.eval_net, self.target_net = Net(n_states, n_actions), Net(n_states, n_actions)
        self.loss = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.01)
        self.n_actions = n_actions
        self.n_states = n_states
        # 使用变量
        self.learn_step_counter = 0  # target网络学习计数
        self.memory_counter = 0  # 记忆计数
        self.memory = np.zeros((2000, 2 * 2 + 2))  # 2*2(state和next_state,每个x,y坐标确定)+2(action和reward),存储2000个记忆体
        self.cost = []  # 记录损失值

    def choose_action(self, x, epsilon):
        # print("<choose_action>")
        x = torch.unsqueeze(torch.FloatTensor(x), 0)  # (1,2)
        if np.random.uniform() < epsilon:
            action_value = self.eval_net.forward(x)
            action = torch.max(action_value, 1)[1].data.numpy()[0]
        else:
            action = np.random.randint(0, self.n_actions)
        # print("action=", action)
        return action

    def store_transition(self, state, action, reward, next_state):
        # print("<store_transition>")
        transition = np.hstack((state, [action, reward], next_state))
        index = self.memory_counter % 200  # 满了就覆盖旧的
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        # print("<learn>")
        # target net 更新频率,用于预测,不会及时更新参数
        if self.learn_step_counter % 100 == 0:
            self.target_net.load_state_dict((self.eval_net.state_dict()))
        self.learn_step_counter += 1

        # 使用记忆库中批量数据
        sample_index = np.random.choice(200, 16)  # 2000个中随机抽取32个作为batch_size
        memory = self.memory[sample_index, :]  # 抽取的记忆单元,并逐个提取
        state = torch.FloatTensor(memory[:, :2])
        action = torch.LongTensor(memory[:, 2:3])
        reward = torch.LongTensor(memory[:, 3:4])
        next_state = torch.FloatTensor(memory[:, 4:6])

        # 计算loss,q_eval:所采取动作的预测value,q_target:所采取动作的实际value
        q_eval = self.eval_net(state).gather(1, action) # eval_net->(64,4)->按照action索引提取出q_value
        q_next = self.target_net(next_state).detach()
        # torch.max->[values=[],indices=[]] max(1)[0]->values=[]
        q_target = reward + 0.9 * q_next.max(1)[0].unsqueeze(1) # label
        loss = self.loss(q_eval, q_target)
        self.cost.append(loss)
        # 反向传播更新
        self.optimizer.zero_grad()  # 梯度重置
        loss.backward()  # 反向求导
        self.optimizer.step()  # 更新模型参数

    def plot_cost(self):
        plt.plot(np.arange(len(self.cost)), self.cost)
        plt.xlabel("step")
        plt.ylabel("cost")
        plt.show()

代码存在的一点缺陷:存在个别 episode 需要经过大量状态才能找到目标状态(主要表现在程序在两个状态间来回跳动:7→8→7→8→7→8…,需要很长时间才能跳出这个局限)

致谢:https://blog.csdn.net/itplus/article/details/9361915
   https://www.cnblogs.com/nrocky/p/14496252.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1263282.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络割接,用VRRP替换HSRP,你得这么割才行!

组网需求 如图3-11所示&#xff0c;C6500作为核心层设备上行连接出口路由器NE40E-X3&#xff0c;下行连接接入层设备CE6800。C6500上配置HSRP实现冗余备份网关&#xff0c;同时在二层网络部署MSTP破除环路。 总体思路 HSRP为CISCO私有协议&#xff0c;CE系列交换机&#xff08…

RabbitMQ之延迟消息

文章目录 前言一、死信交换机二、延迟消息死信交换机实现延迟消息图解流程 DelayExchange插件实现延迟消息安装插件声明延迟交换机发送延迟消息 总结 前言 死信交换机、延迟消息 一、死信交换机 当一个队列中的消息满足下列情况之一时&#xff0c;可以成为死信&#xff08;dea…

微信小程序推送服务号消息(一)【Go+微信小程序+微信服务号+微信开放平台】

一、需求场景 业务需要给微信小程序用户在某些场景推送微信服务号消息&#xff0c;例如&#xff1a;订单即将超时&#xff0c;电子合同签约超时等&#xff1b; 二、开发准备 1、开通微信服务号 入口&#xff1a;微信公众平台 1.1 在服务号中获取推送消息所需的配置信息&#…

头插法巧解任意链表区间反转

题目链接&#xff1a;https://leetcode.cn/problems/reverse-linked-list-ii/description/?envTypestudy-plan-v2&envIdtop-interview-150 头节点插入法&#xff0c;设置一个虚拟节点&#xff0c;不断循环到要反转的前面一个位置&#xff0c;哪怕是1也能包含进去。接着反…

Python UUID 完全指南

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com UUID&#xff08;Universally Unique Identifier&#xff0c;通用唯一标识符&#xff09;是一种全局唯一标识符生成方式&#xff0c;用于创建独一无二的标识符。Python的 uuid 模块提供了多种方法用于生成各种类…

四丶openlayer之瓦片地图

瓦片地图源于一种大地图解决方案&#xff0c;针对一整块非常大的地图进行切片&#xff0c;分成很多相同大小的小块地图&#xff0c;在用户访问的时候&#xff0c;再一块一块小地图加载&#xff0c;拼接在一起&#xff0c;从而还原成一整块大的地图。这样做的优点在于&#xff0…

【算法】FFT-1(递归实现)(不包括IFFT)

FFT 多项式多项式乘法复数及运算导数泰勒公式及展开式欧拉公式单位根 FFTCode IFFT 多项式 我们从课本中可以知道&#xff0c;一个 n − 1 n-1 n−1 次的多项式可以写成 a 0 a 1 x a 2 x 2 a 3 x 3 ⋯ a n − 1 x n − 1 a_{0}a_{1}xa_{2}x^2a_{3}x^3\dotsa_{n-1}x^{n-…

从 0 搭建 Vite 3 + Vue 3 Js版 前端工程化项目

之前分享过一篇vue3+ts+vite构建工程化项目的文章,针对小的开发团队追求开发速度,不想使用ts想继续使用js,所以就记录一下从0搭建一个vite+vue3+js的前端项目,做记录分享。 技术栈 Vite 3 - 构建工具 Vue 3 Vue Router - 官方路由管理器 Pinia - Vue Store你也可以选择vue…

2023年【G2电站锅炉司炉】考试试卷及G2电站锅炉司炉模拟试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 G2电站锅炉司炉考试试卷是安全生产模拟考试一点通生成的&#xff0c;G2电站锅炉司炉证模拟考试题库是根据G2电站锅炉司炉最新版教材汇编出G2电站锅炉司炉仿真模拟考试。2023年【G2电站锅炉司炉】考试试卷及G2电站锅炉…

Unity 接入TapADN播放广告时闪退 LZ4JavaSafeCompressor

通过跟踪安卓日志&#xff0c;发现报如下错误 Didnt find class "com.tapadn.lz4.LZ4JavaSafeCompressor" 解决方案&#xff1a; 去掉Minify这边的勾选&#xff0c;再打包即可。

香港优才计划是什么意思?一文详解2023年最新政策!

香港优才计划是什么意思&#xff1f;一文详解2023年最新政策&#xff01; 目前香港优才计划申请火热&#xff0c;但是还是有很多新手不太了解这个项目&#xff0c;跟风申请绝对不是什么好事&#xff0c;先了解清楚再考虑也是对自己对家人的一种交代。这篇文章就再来科普下。 优…

AWVS 使用方法归纳

1.首先确认扫描的网站&#xff0c;以本地的dvwa为例 2.在awvs中添加目标 输入的地址可以是域名也可以是ip&#xff0c;只要本机可以在浏览器访问的域名或ip即可 添加地址及描述之后&#xff0c;点击保存&#xff0c;就会展现出目标设置选项 business criticality译为业务关键…

SpringCloud之服务网关Gateway组件使用——详解

目录 一、网关介绍 1.什么是服务网关 2. 为什么需要网关 3.网关组件在微服务中架构 二、服务网关组件 1. zuul 1.x 2.x(netflix 组件) 1.1 zuul版本说明 2. gateway (spring) 2.1 特性 2.2 开发网关动态路由 2.2.1.创建项目引入网关依赖 2.2.2 快捷方式配置路由 2.2…

C++学习——类和对象(上)

C学习——类和对象 一、面向对象和面向过程的初步认识二、什么是类 一、面向对象和面向过程的初步认识 我们之前学习了C语言&#xff0c;我们知道 ① C语言&#xff1a;C语言是一门面向过程的语言&#xff0c;关注的是过程&#xff0c;分析出求解问题的步骤&#xff0c;通过函…

Web安全之SQL注入:明明设置了强密码,为什么还会被别人登录?

一、背景 让我们先来看一个案例。某天&#xff0c;当你在查看应用的管理后台时&#xff0c;发现有很多异常的操作。接着&#xff0c;你很快反应过来了&#xff0c;这应该是黑客成功登录了管理员账户。于是&#xff0c;你立刻找到管理员&#xff0c;问他是不是设置了弱密码。管…

jenkins 代码执行 (CVE-2017-1000353)漏洞复现

jenkins 代码执行 (CVE-2017-1000353)漏洞复现 名称: jenkins 代码执行 &#xff08;CVE-2017-1000353&#xff09; 描述: ​Jenkins 可以通过其网页界面轻松设置和配置,其中包括即时错误检查和内置帮助。 插件 通过更新中心中的 1000 多个插件,Jenkins 集成了持续集成和持续…

分享一套MES源码,可以直接拿来搞钱的好项目

目前国内智能制造如火如荼&#xff0c;工厂信息化、数字化是大趋势。如果找到一个工厂&#xff0c;搞定一个老板&#xff0c;搞软件的朋友就能吃几年。 中国制造业发达&#xff0c;工厂林立&#xff0c;但是普遍效率不高&#xff0c;需要信息化提高效率。但是矛盾的地方在于&a…

删除链表的倒数第N个节点,剑指offerII(21),力扣

目录 题目地址&#xff1a; 题目&#xff1a; 相似类型题&#xff1a; 我们直接看本题题解吧&#xff1a; 解题方法&#xff1a; 难度分析&#xff1a; 解题分析&#xff1a; 解题思路&#xff08;双指针&#xff09;&#xff1a; 代码实现&#xff1a; 代码说明&#xff1a; 代…

算法训练 第九周

一、移动零 1.双指针 我们可以设定两个指针i和j&#xff0c;其中i用来遍历整个数组&#xff0c;j用来遍历存放不为零的数的位置&#xff0c;当nums[i]不等于零时&#xff0c;就让nums[i]和nums[j]进行交换&#xff0c;然后j,当我们遍历完整个数组之后就完成了操作&#xff0c;…

基于Python的网络爬虫设计与实现

基于Python的网络爬虫设计与实现 摘要&#xff1a;从互联网时代开始&#xff0c;网络搜索引擎就变得越发重要。大数据时代&#xff0c;一般的网络搜索引擎不能满足用户的具体需求&#xff0c;人们更加注重特定信息的搜索效率&#xff0c;网络爬虫技术应运而生。本设计先对指定…