DQN基本概念和算法流程(附Pytorch代码)

news2024/12/23 17:27:24

❀DQN算法原理

DQN,Deep Q Network本质上还是Q learning算法,它的算法精髓还是让 Q 估计 Q_{估计} Q估计尽可能接近 Q 现实 Q_{现实} Q现实,或者说是让当前状态下预测的Q值跟基于过去经验的Q值尽可能接近。在后面的介绍中 Q 现实 Q_{现实} Q现实也被称为TD Target

再来回顾下DQN算法和核心思想
在这里插入图片描述

相比于Q Table形式,DQN算法用神经网络学习Q值。
在这里插入图片描述

我们可以理解为神经网络是一种估计方法,神经网络本身不是DQN的精髓,神经网络可以设计成MLP也可以设计成CNN等等,DQN的巧妙之处在于两个网络、经验回放等trick

下面介绍下DQN算法的一些trick,是希望帮助小伙伴们梳理区分两个网络的作用,阐述清楚经验回放等概念的本质,以及使用它们训练网络的技巧

Trick 1:两个网络

DQN算法采用了2个神经网络,分别是evaluate network(Q值网络)和target network(目标网络),两个网络结构完全相同

  • evaluate network用用来计算策略选择的Q值和Q值迭代更新,梯度下降、反向传播的也是evaluate network
  • target network用来计算TD Target中下一状态的Q值,网络参数更新来自evaluate network网络参数复制

设计target network目的是为了保持目标值稳定,防止过拟合,从而提高训练过程稳定和收敛速度

这里会有容易混淆的地方,梯度更新的是evaluate network的参数,不更新target network,然后每隔一段时间将evaluate network的网络参数复制给target network网络参数,那么优化器optimizer设置的时候用的也是evaluate network的parameters

Trick 2:基本框架

算法分成两个部分,分别是策略选择和策略评估,这也是强化学习算法基本的两个模块,梳理算法逻辑的时候从策略选择和策略评估两个方面入手,更容易弄清楚。策略选择部分,epsilon-greedy策略选择动作,策略评估部分使用贪婪策略

Trick 3:经验回放Experience Replay

DQN算法设计了一个固定大小的记忆库memory,用来记录经验,经验是一条一条的observation或者说是transition,它表示成 [ s , a , r , s ′ ] [s, a, r, s'] [s,a,r,s],含义是当前状态→当前状态采取的动作→获得的奖励→转移到下一个状态

一开始记忆库memory中没有经验,也没有训练evaluate network,积累了一定数量的经验之后,再开始训练evaluate network。记忆库memory中的经验可以是自己历史的经验(epsilon-greedy得到的经验),也可以学习其他人的经验。训练evaluate network的时候,是从记忆库memory中随机选择(划重点哦,是随机选择!)batch size大小的经验,喂给evaluate network

设计记忆库memory并且随机选择经验喂给evaluate network的技巧打破了相邻训练样本之间相关性,试着想下,状态→动作→奖励→下一个状态的循环是具有关联的,用相邻的样本连续训练evaluate network会带来网络过拟合泛化能力差的问题,而经验回放技巧增强了训练样本之间的独立性

❀算法流程图

每个episode流程是下面这样
在这里插入图片描述

其中choose_action、store_transition、learn是相互独立的函数模块,它们内部的算法逻辑是下面这样
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

❀Pytorch版本代码

采用Pytorch实现了DQN算法,完成了走迷宫Maze游戏,哈哈哈,这个游戏来自莫烦Python教程,代码嘛是自己修改过哒,代码贴在github上啦

ningmengzhihe/DQN_base: DQN algorithm by Pytorch - a simple maze game https://github.com/ningmengzhihe/DQN_base

(1)环境构建代码maze_env.py


import numpy as np
import time
import sys
if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk

UNIT = 40   # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.n_features = 2
        self.title('maze')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                           height=MAZE_H * UNIT,
                           width=MAZE_W * UNIT)

        # create grids
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin
        origin = np.array([20, 20])

        # hell
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        # hell
        # hell2_center = origin + np.array([UNIT, UNIT * 2])
        # self.hell2 = self.canvas.create_rectangle(
        #     hell2_center[0] - 15, hell2_center[1] - 15,
        #     hell2_center[0] + 15, hell2_center[1] + 15,
        #     fill='black')

        # create oval
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')

        # create red rect
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')

        # pack all
        self.canvas.pack()

    def reset(self):
        self.update()
        time.sleep(0.1)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        # return observation
        return (np.array(self.canvas.coords(self.rect)[:2]) - np.array(self.canvas.coords(self.oval)[:2]))/(MAZE_H*UNIT)

    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:   # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:   # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:   # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:   # left
            if s[0] > UNIT:
                base_action[0] -= UNIT

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        next_coords = self.canvas.coords(self.rect)  # next state

        # reward function
        if next_coords == self.canvas.coords(self.oval):
            reward = 1
            done = True
        elif next_coords in [self.canvas.coords(self.hell1)]:
            reward = -1
            done = True
        else:
            reward = 0
            done = False
        s_ = (np.array(next_coords[:2]) - np.array(self.canvas.coords(self.oval)[:2]))/(MAZE_H*UNIT)
        return s_, reward, done

    def render(self):
        # time.sleep(0.01)
        self.update()

(2)DQN算法代码,包括神经网络定义、Q值更新:RL_brain.py

"""
Deep Q Network off-policy
"""
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(42)
torch.manual_seed(2)


class Network(nn.Module):
    """
    Network Structure
    """
    def __init__(self,
                 n_features,
                 n_actions,
                 n_neuron=10
                 ):
        super(Network, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=n_features, out_features=n_neuron, bias=True),
            nn.Linear(in_features=n_neuron, out_features=n_actions, bias=True),
            nn.ReLU()
        )

    def forward(self, s):
        """

        :param s: s
        :return: q
        """
        q = self.net(s)
        return q


class DeepQNetwork(nn.Module):
    """
    Q Learning Algorithm
    """
    def __init__(self,
                 n_actions,
                 n_features,
                 learning_rate=0.01,
                 reward_decay=0.9,
                 e_greedy=0.9,
                 replace_target_iter=300,
                 memory_size=500,
                 batch_size=32,
                 e_greedy_increment=None):
        super(DeepQNetwork, self).__init__()

        self.n_actions = n_actions
        self.n_features = n_features
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon_max = e_greedy
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.epsilon_increment = e_greedy_increment
        self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max

        # total learning step
        self.learn_step_counter = 0

        # initialize zero memory [s, a, r, s_]
        # 这里用pd.DataFrame创建的表格作为memory
        # 表格的行数是memory的大小,也就是transition的个数
        # 表格的列数是transition的长度,一个transition包含[s, a, r, s_],其中a和r分别是一个数字,s和s_的长度分别是n_features
        self.memory = pd.DataFrame(np.zeros((self.memory_size, self.n_features*2+2)))

        # build two network: eval_net and target_net
        self.eval_net = Network(n_features=self.n_features, n_actions=self.n_actions)
        self.target_net = Network(n_features=self.n_features, n_actions=self.n_actions)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)

        # 记录每一步的误差
        self.cost_his = []


    def store_transition(self, s, a, r, s_):
        if not hasattr(self, 'memory_counter'):
            # hasattr用于判断对象是否包含对应的属性。
            self.memory_counter = 0

        transition = np.hstack((s, [a,r], s_))

        # replace the old memory with new memory
        index = self.memory_counter % self.memory_size
        self.memory.iloc[index, :] = transition

        self.memory_counter += 1

    def choose_action(self, observation):
        observation = observation[np.newaxis, :]

        if np.random.uniform() < self.epsilon:
            # forward feed the observation and get q value for every actions
            s = torch.FloatTensor(observation)
            actions_value = self.eval_net(s)
            action = [np.argmax(actions_value.detach().numpy())][0]
        else:
            action = np.random.randint(0, self.n_actions)
        return action

    def _replace_target_params(self):
        # 复制网络参数
        self.target_net.load_state_dict(self.eval_net.state_dict())

    def learn(self):
        # check to replace target parameters
        if self.learn_step_counter % self.replace_target_iter == 0:
            self._replace_target_params()
            print('\ntarget params replaced\n')

        # sample batch memory from all memory
        batch_memory = self.memory.sample(self.batch_size) \
            if self.memory_counter > self.memory_size \
            else self.memory.iloc[:self.memory_counter].sample(self.batch_size, replace=True)

        # run the nextwork
        s = torch.FloatTensor(batch_memory.iloc[:, :self.n_features].values)
        s_ = torch.FloatTensor(batch_memory.iloc[:, -self.n_features:].values)
        q_eval = self.eval_net(s)
        q_next = self.target_net(s_)

        # change q_target w.r.t q_eval's action
        q_target = q_eval.clone()

        # 更新值
        batch_index = np.arange(self.batch_size, dtype=np.int32)
        eval_act_index = batch_memory.iloc[:, self.n_features].values.astype(int)
        reward = batch_memory.iloc[:, self.n_features + 1].values

        q_target[batch_index, eval_act_index] = torch.FloatTensor(reward) + self.gamma * q_next.max(dim=1).values

        # train eval network
        loss = self.loss_function(q_target, q_eval)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.cost_his.append(loss.detach().numpy())

        # increasing epsilon
        self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max
        self.learn_step_counter += 1

    def plot_cost(self):
        plt.figure()
        plt.plot(np.arange(len(self.cost_his)), self.cost_his)
        plt.show()

(3)每个episode代码:run_this.py

from maze_env import Maze
from RL_brain import DeepQNetwork

def run_maze():
    step = 0  # 为了记录走到第几步,记忆录中积累经验(也就是积累一些transition)之后再开始学习
    for episode in range(200):
        # initial observation
        observation = env.reset()

        while True:
            # refresh env
            env.render()

            # RL choose action based on observation
            action = RL.choose_action(observation)

            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)

            # !! restore transition
            RL.store_transition(observation, action, reward, observation_)

            # 超过200条transition之后每隔5步学习一次
            if (step > 200) and (step % 5 == 0):
                RL.learn()

            # swap observation
            observation = observation_

            # break while loop when end of this episode
            if done:
                break
            step += 1

    # end of game
    print("game over")
    env.destroy()


if __name__ == "__main__":
    # maze game
    env = Maze()
    RL = DeepQNetwork(env.n_actions, env.n_features,
                      learning_rate=0.01,
                      reward_decay=0.9,
                      e_greedy=0.9,
                      replace_target_iter=200,
                      memory_size=2000)
    env.after(100, run_maze)
    env.mainloop()
    RL.plot_cost()

❀参考资料

https://zhuanlan.zhihu.com/p/614697168
这份参考资料清晰的解释了2个Q值网络,pytorch代码值得参考

https://www.bilibili.com/video/BV13W411Y75P?p=14&vd_source=1565223f5f03f44f5674538ab582448c
莫烦Python在B站上的DQN教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/424934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

提高工作效率必备,5款实用的Windows系统工具推荐

每次分享实用的软件,都会给人一种踏实和喜悦的感觉,这也是我热衷于搜集和推荐高效工具软件的原因。 音量控制——EarTrumpet EarTrumpet是一款音量控制工具&#xff0c;可以让你更方便地调节Windows系统中不同应用程序的音量。你可以使用EarTrumpet来替代系统自带的音量混合器…

表单设计器开源的定义和应用场景布局介绍

为了实现提质增效的办公自动化&#xff0c;表单设计器开源工具的应用变得广泛起来。在低代码开发市场昌盛发展的今天&#xff0c;不少企业期望通过快速、现成的快速配置表单工具实现高效率表单制作&#xff0c;那么&#xff0c;现在给大家介绍的这款开发易用性强、组件丰富、高…

设计模式 -- 门面模式

前言 月是一轮明镜,晶莹剔透,代表着一张白纸(啥也不懂) 央是一片海洋,海乃百川,代表着一块海绵(吸纳万物) 泽是一柄利剑,千锤百炼,代表着千百锤炼(输入输出) 月央泽,学习的一种过程,从白纸->吸收各种知识->不断输入输出变成自己的内容 希望大家一起坚持这个过程,也同…

stable-diffusion真的好用吗?

hi&#xff0c;各位大佬&#xff0c;今天尝试下diffusion大模型&#xff0c;也是CV领域的GPT&#xff0c;但需要prompt&#xff0c;我给了prompt结果并不咋滴&#xff0c;如下示例&#xff0c;并附代码及参考link 1、img2img 代码实现&#xff1a; import torch from PIL im…

PageHelper的使用

这个分页插件是在Mybatis的环境中使用的&#xff0c;所以项目需要导入Mybatis依赖 更加详细的用法看官方文档&#xff1a;PageHelper官网 在Mybatis中使用 前提条件 引入依赖 <dependency><groupId>com.github.pagehelper</groupId><artifactId>pa…

GANs和Generative Adversarial Nets和Vox2Vox: 3D-GAN for Brain Tumour Segmentation

参考&#xff1a; 各种生成模型&#xff1a;VAE、GAN、flow、DDPM、autoregressive models https://blog.csdn.net/zephyr_wang/article/details/126588478李沐GAN精度 x.1 生成模型家族 DGMs&#xff08;Deep Generatitve Models&#xff09;家族主要有&#xff1a;GAN&…

数据分析的目的和意义是什么?_光点科技

数据分析是一个越来越受到关注的领域&#xff0c;因为它可以帮助企业和组织利用数据来制定更明智的决策。数据分析的目的和意义是多方面的&#xff0c;例如&#xff1a; 1.了解客户需求 数据分析可以帮助企业更好地了解客户需求&#xff0c;从而制定更准确的市场营销策略。通过…

原生JS + HTML + CSS 实现快递物流信息 API 的数据链式展示

引言 全国快递物流查询 API 是一种提供实时、准确、可靠的快递物流信息查询服务的接口。它基于现有的物流信息系统&#xff0c;通过API接口的方式&#xff0c;向用户提供快递物流信息的查询、跟踪、统计等功能。使用全国快递物流查询 API&#xff0c;用户可以在自己的应用程序…

[2021 东华杯]bg3

Index介绍漏洞利用过程一.泄露Libc二.Tcache Bin Attack三.完整EXP介绍 [2021 东华杯]bg3 本题是C写的一道经典菜单堆题&#xff0c;拥有增删改查全部功能。 Bug DataBase - V3.0 - I think i am UnBeatAble 1. Upload A Bug 2. Change A Uploaded Bug 3. Get Uploaded Bug D…

企业大数据湖总体规划及大数据湖 一体化运营管理建设方案

背景&#xff1a;数据快速入湖&#xff0c;分析更加智能&#xff0c;应用更加多样&#xff0c;服务更加开放更多企业数据将进入数据湖&#xff0c;来自传统系统的数据和传感器等新型数据资源不断融合&#xff0c;数据孤岛将继续被打破。随着大数据分析能力的不断提高&#xff0…

借助Nacos配置中心实现一个动态线程池

目录 一、实现思路 二、实现说明概览 三、代码实现 DynamicThreadPool RejectedProxyInvocationHandler DynamicThreadPoolRegister DynamicThreadPoolRefresher 测试动态线程池 平常我们系统中定义的一些线程池如果要想修改的话&#xff0c;需要修改配置重启服务才能生…

『pyqt5 从0基础开始项目实战』05. 按钮点击事件之添加新数据 (保姆级图文)

目录导包和框架代码给按钮绑定一个点击事件获取输入框的数据多线程与界面更新&#xff08;新线程与UI更新的数据交互&#xff09;代码结构完整代码main文件Threads.py总结欢迎关注 『pyqt5 从0基础开始项目实战』 专栏&#xff0c;持续更新中 欢迎关注 『pyqt5 从0基础开始项目…

上海亚商投顾:沪指创年内新高 大金融、中字头集体走强

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 市场情绪 沪指今日低开高走&#xff0c;午后涨超1%&#xff0c;创出近10个月以来新高&#xff0c;创业板指走势较弱&#xf…

不走弯路,AI真的能提高生产效率

AI应用虽然取得了令人瞩目的成果&#xff0c;但是在实际应用中仍存在不少困境。市面上不乏有AI绘画、AI写作、AI聊天的相关产品&#xff0c;即使Chatgpt可以写代码、写论文&#xff0c;但由于技术的有限性&#xff0c;还需要不断地优化完善才能给出更精准的答复&#xff0c;也少…

契约锁与多家软件行业伙伴达成战略合作,携手助力组织数字化转型

近日&#xff0c;契约锁电子签章与天翼云、神州数码、同望科技、宏灿软件、甄零科技、正量科技等多家软件行业伙伴达成战略合作&#xff0c;充分发挥各自专业与资源优势&#xff0c;从产品、市场、销售、技术等多方面展开深度合作&#xff0c;共同为客户提供全程数字化解决方案…

zabbix创建自定义监控模板

目录 第一章先行配置zabbix 第二章配置自定义 2.1.案列&#xff1a;自定义监控客户端服务器登录的人数需求&#xff1a;限制登录人数不超过 3 个&#xff0c;超过 3 个就发出报警信息 2.2.在 Web 页面创建自定义监控项模板 2.3.zabbix 自动发现与自动注册 总结 自定义监控…

【论文精度(李沐老师)】Generative Adversarial Nets

Abstract 我们提出了一个新的framework&#xff0c;通过一个对抗的过程来估计生成模型&#xff0c;其中会同时训练两个模型&#xff1a;生成模型G来获取整个数据的分布&#xff0c;辨别模型D来分辨数据是来自于训练样本还是生成模型G。生成模型G的任务是尽量的让辨别模型D犯错…

DI依赖注入

DI依赖注入Setter注入setter注入引用类型setter注入简单类型&#xff08;基本数据类型和字符串&#xff09;构造器注入构造器注入引用类型自动装配集合注入首先我们明确一些观点1、注入的Bean的数据包括引用类型与简单类型&#xff08;基本数据类型和字符串&#xff09;2、通过…

HTML5 地理定位

HTML5 Geolocation&#xff08;地理定位&#xff09; HTML5 Geolocation&#xff08;地理定位&#xff09;用于定位用户的位置。 Geolocation 通过请求一个位置信息&#xff0c;用户同意后&#xff0c;浏览器会返回一个包含经度和维度的位置信息&#xff01; 定位用户的位置 …

【C语言数组部分】

数组部分综述引入&#xff1a;数组概念&#xff1a;一、一维数组1.1一维数组的创建&#xff1a;1.2一维数组的初始化&#xff1a;1.2.1初始化概念&#xff1a;1.2.2完全初始化&#xff1a;1.2.3不完全初始化&#xff1a;1.3字符数组的初始化&#xff1a;1.3.1用字符初始化&…