强化学习------DQN算法

news2024/11/27 13:51:00

简介

DQN,即深度Q网络(Deep Q-network),是指基于深度学习的Q-Learing算法。Q-Learing算法维护一个Q-table,使用表格存储每个状态s下采取动作a获得的奖励,即状态-价值函数Q(s,a),这种算法存在很大的局限性。在现实中很多情况下,强化学习任务所面临的状态空间是连续的,存在无穷多个状态,这种情况就不能再使用表格的方式存储价值函数。
为了解决这个问题,我们可以用一个函数Q(s,a;w)来近似动作-价值Q(s,a),称为价值函数近似Value Function Approximation,我们用神经网络来生成这个函数Q(s,a;w),称为Q网络(Deep Q-network),w是神经网络训练的参数。

Q-Learning参考:https://blog.csdn.net/niulinbiao/article/details/133659036

DQN相较于传统的强化学习算法(Q-learning)有三大重要的改进:

  • 引入深度学习中的神经网络,利用神经网络去拟合Q-learning中的Q表,解决了Q-learning中,当状态维数过高时产生的“维数灾难”问题;

  • 固定Q目标网络,利用延后更新的目标网络计算目标Q值,极大的提高了网络训练的稳定性和收敛性;

  • 引入经验回放机制,使得在进行网络更新时输入的数据符合独立同分布,打破了数据间的相关性。

本文还增加了动态探索概率,也就是随着模型的训练,我们有必要减少探索的概率

DQN的算法流程如下:

在这里插入图片描述

  • 首先,算法开始前随机选择一个初始状态,然后基于这个状态选择执行动作,这里需要进行一个判断,即是通过Q-Network选择一个Q值最大对应的动作,还是在动作空间中随机选择一个动作。
  • 在程序编程中,由于刚开始时,Q-Network中的相关参数是随机的,所以在经验池存满之前,通常将设置的很小,即初期基本都是随机选择动作。
  • 在动作选择结束后,agent将会在环境(Environment)中执行这个动作,随后环境会返回下一状态(S_)和奖励(R),这时将四元组(S,A,R,S_)存入经验池。
  • 接下来将下一个状态(S_)视为当前状态(S),重复以上步骤,直至将经验池存满。
  • 当经验池存满之后,DQN中的网络开始更新。即开始从经验池中随机采样,将采样得到的奖励(R)和下一个状态(S_)送入目标网络计算下一Q值(y),并将y送入Q-Network计算loss值,开始更新Q-Network。往后就是agent与环境交互,产生经验(S,A,R,S_),并将经验放入经验池,然后从经验池中采样更新Q-Network,周而复始,直到Q-Network完成收敛。

在这里插入图片描述

  • DQN中目标网络的参数更新是硬更新,即主网络(Q-Network)参数更新一定步数后,将主网络更新后的参数全部复制给目标网络(Target
    Q-Network)。
  • 在程序编程中,通常将设置成随训练步数的增加而递增,即agent越来越信任Q-Network来指导动作。

代码实现

1、环境准备

我们选择openAIgym环境作为我们训练的环境

  env1 = gym.make("CartPole-v0")

在这里插入图片描述

2、编写经验池函数

经验池的主要内容就是,存数据和取数据

import random
import collections
from torch import FloatTensor

class ReplayBuffer(object):
    # 初始化
    def __init__(self, max_size, num_steps=1 ):
        """
        
        :param max_size: 经验吃大小
        :param num_steps: 每经过训练num_steps次后,函数就学习一次
        """
        self.buffer = collections.deque(maxlen=max_size)
        self.num_steps  = num_steps

    def append(self, exp):
        """
        想经验池添加数据
        :param exp: 
        :return: 
        """
        self.buffer.append(exp)

    def sample(self, batch_size):
        """
        向经验池中获取batch_size个(obs_batch,action_batch,reward_batch,next_obs_batch,done_batch)这样的数据
        :param batch_size: 
        :return: 
        """
        mini_batch = random.sample(self.buffer, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*mini_batch)
        obs_batch = FloatTensor(obs_batch)
        action_batch = FloatTensor(action_batch)
        reward_batch = FloatTensor(reward_batch)
        next_obs_batch = FloatTensor(next_obs_batch)
        done_batch = FloatTensor(done_batch)
        return obs_batch,action_batch,reward_batch,next_obs_batch,done_batch

    def __len__(self):
        return len(self.buffer)

3、神经网络模型

我们简单地使用神经网络

import torch

class MLP(torch.nn.Module):

    def __init__(self, obs_size,n_act):
        super().__init__()
        self.mlp = self.__mlp(obs_size,n_act)

    def __mlp(self,obs_size,n_act):
        return torch.nn.Sequential(
            torch.nn.Linear(obs_size, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, n_act)
        )

    def forward(self, x):
        return self.mlp(x)

4、探索率衰减函数

随着训练过程,我们动态地减小探索率,因为训练到后面,模型会越来越收敛,没必要继续探索

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import numpy as np

class EpsilonGreedy():

    def __init__(self,n_act,e_greed,decay_rate):
        self.n_act = n_act
        self.epsilon = e_greed
        self.decay_rate = decay_rate


    def act(self,predict_func,obs):
        if np.random.uniform(0, 1) < self.epsilon:  # 探索
            action = np.random.choice(self.n_act)
        else:  # 利用
            action = predict_func(obs)
        self.epsilon = max(0.01,self.epsilon-self.decay_rate)   #是探索率最低为0.01
        return action

5、DQN算法

import copy

import numpy as np
import torch
from utils import torchUtils

# 添加探索值递减的策略
class DQNAgent(object):

    def __init__( self, q_func, optimizer, replay_buffer, batch_size, replay_start_size,update_target_steps, n_act,explorer, gamma=0.9):
        '''
        :param q_func: Q函数
        :param optimizer: 优化器
        :param replay_buffer: 经验回放器
        :param batch_size: 批次数量
        :param replay_start_size: 开始回放的次数
        :param update_target_steps: 经过多少步才会同步target网络
        :param n_act: 动作数量
        :param gamma: 收益衰减率
        :param e_greed: 探索与利用中的探索概率
        '''
        self.pred_func = q_func
        self.target_func = copy.deepcopy(q_func)
        self.update_target_steps = update_target_steps
        self.explorer = explorer

        self.global_step = 0  #全局

        self.rb = replay_buffer
        self.batch_size = batch_size
        self.replay_start_size = replay_start_size

        self.optimizer = optimizer
        self.criterion = torch.nn.MSELoss()

        self.n_act = n_act  # 动作数量
        self.gamma = gamma  # 收益衰减率

    # 根据经验得到action
    def predict(self, obs):
        obs = torch.FloatTensor(obs)
        Q_list = self.pred_func(obs)
        action = int(torch.argmax(Q_list).detach().numpy())
        return action

    # 根据探索与利用得到action
    def act(self, obs):
        return self.explorer.act(self.predict,obs)

    def learn_batch(self,batch_obs, batch_action, batch_reward, batch_next_obs, batch_done):

        # predict_Q
        pred_Vs = self.pred_func(batch_obs)
        action_onehot = torchUtils.one_hot(batch_action, self.n_act)
        predict_Q = (pred_Vs * action_onehot).sum(1)
        # target_Q
        next_pred_Vs = self.target_func(batch_next_obs)
        best_V = next_pred_Vs.max(1)[0]
        target_Q = batch_reward + (1 - batch_done) * self.gamma * best_V

        # 更新参数
        self.optimizer.zero_grad()
        loss = self.criterion(predict_Q, target_Q)
        loss.backward()
        self.optimizer.step()

    def learn(self, obs, action, reward, next_obs, done):
        self.global_step+=1
        self.rb.append((obs, action, reward, next_obs, done))
        #当经验池中到的数据足够多时,并且满足每训练num_steps轮就更新一次参数
        if len(self.rb) > self.replay_start_size and self.global_step%self.rb.num_steps==0:
            self.learn_batch(*self.rb.sample(self.batch_size))
        #我们每训练update_target_steps轮就同步目标网络
        if self.global_step%self.update_target_steps==0:
            self.sync_target()

    # 同步target网络
    def sync_target(self):
        for target_param,param in zip(self.target_func.parameters(),self.pred_func.parameters()):
            target_param.data.copy_(param.data)

6、训练代码


import dqn,modules,replay_buffers
import gym
import torch
from explorers import  EpsilonGreedy

class TrainManager():

    def __init__(self,
                 env,  #环境
                 episodes=1000,  #轮次数量
                 batch_size=32,  #每一批次的数量
                 num_steps=4,  #进行学习的频次
                 memory_size = 2000,  #经验回放池的容量
                 replay_start_size = 200,  #开始回放的次数
                 update_target_steps=200,  #经过训练update_target_steps次后将参数同步给target网络
                 lr=0.001,  #学习率
                 gamma=0.9,  #收益衰减率
                 e_greed=0.1,  #探索与利用中的探索概率
                 e_greed_decay=1e-6, #探索率衰减值
                 ):
        self.env = env
        self.episodes = episodes

        n_act = env.action_space.n
        n_obs = env.observation_space.shape[0]
        q_func = modules.MLP(n_obs, n_act)
        optimizer = torch.optim.AdamW(q_func.parameters(), lr=lr)
        rb = replay_buffers.ReplayBuffer(memory_size,num_steps)

        explorer = EpsilonGreedy(n_act,e_greed,e_greed_decay)

        self.agent = dqn.DQNAgent(
            q_func=q_func,
            optimizer=optimizer,
            replay_buffer = rb,
            batch_size=batch_size,
            update_target_steps=update_target_steps,
            replay_start_size = replay_start_size,
            n_act=n_act,
            explorer = explorer,
            gamma=gamma)

    # 训练一轮游戏
    def train_episode(self):
        total_reward = 0
        obs = self.env.reset()
        while True:
            action = self.agent.act(obs)
            next_obs, reward, done, _ = self.env.step(action)
            total_reward += reward
            self.agent.learn(obs, action, reward, next_obs, done)
            obs = next_obs
            if done: break
        print('e_greed=',self.agent.explorer.epsilon)
        return total_reward

    # 测试一轮游戏
    def test_episode(self):
        total_reward = 0
        obs = self.env.reset()
        while True:
            action = self.agent.predict(obs)
            next_obs, reward, done, _ = self.env.step(action)
            total_reward += reward
            obs = next_obs
            self.env.render()
            if done: break
        return total_reward

    def train(self):
        for e in range(self.episodes):
            ep_reward = self.train_episode()
            print('Episode %s: reward = %.1f' % (e, ep_reward))
            #每训练100轮我们就测试一轮
            if e % 100 == 0:
                test_reward = self.test_episode()
                print('test reward = %.1f' % (test_reward))


if __name__ == '__main__':
    env1 = gym.make("CartPole-v0")
    tm = TrainManager(env1)
    tm.train()

实现效果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1070320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构:链式二叉树

上一章讲了堆,堆是完全二叉树的顺序存储结构,本章将要全面讲解一下二叉树的链式存储结构即链式二叉树 我们已经学习了二叉树的概念和性质了,本章重点学习二叉树相关操作,可以更好的理解分治算法思想;也需要对递归有更深次的理解. 其实普通的链式二叉树的增删查改没有什么意义,…

COLLABORATIVE DESIGNER FOR SOLIDWORKS® 新功能

共享和标注 优点&#xff1a;收件人在浏览器中访问共享文 件&#xff0c;无需安装3DEXPERIENCE 平台应用程序。 • 与 SOLIDWORKS 中来自您组织内部或外部的任何人无缝 共享您的设计。 • 直接将评论和标注附加到您的设计作品中&#xff0c;便于立即获得 反馈。 支持 SOLIDWO…

深入理解强化学习——强化学习的基础知识

分类目录&#xff1a;《深入理解强化学习》总目录 在机器学习领域&#xff0c;有一类任务和人的选择很相似&#xff0c;即序贯决策&#xff08;Sequential Decision Making&#xff09;任务。决策和预测任务不同&#xff0c;决策往往会带来“后果”&#xff0c;因此决策者需要为…

Centos7安装MongoDB7.xxNoSQL数据库|设置开机启动(骨灰级+保姆级)

一: mongodb下载 MongoDB 社区免费下载版 MongoDB社区下载版 [rootwww tools]# wget https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel70-7.1.0-rc4.tgz 二: 解压到指定目录 [rootwww tools]# mkdir -p /usr/local/mongodb [rootwww tools]# tar -zxvf mongodb-…

Linux目录和文件查看命令

一、Linux 的目录结构 Linux 的目录结构是一个树状结构&#xff0c;以根目录&#xff08;/&#xff09;为起点&#xff0c;以下是常见的 Linux 目录结构的主要内容&#xff1a; / 根路径 ├── bin: 存放系统指令&#xff08;命令&#xff09;&#xff0c;如ls、cp、mv等&…

ARM-流水灯

.text .global _start _start: 1、设置GPIOE寄存器的时钟使能 RCC_MP_AHB$ENSETR[4]->1 0x50000a28LDR R0,0X50000A28 LDR R1,[R0] 从R0起始地址的4字节数据取出放在R1 ORR R1,R1,#(0X3<<4) 第4位设置为1 STR R1,[R0] 写回2、设置PE10、PE8、PF10管脚为输出模式 …

Observability:使用 OpenTelemetry 对 Node.js 应用程序进行自动检测

作者&#xff1a;Bahubali Shetti DevOps 和 SRE 团队正在改变软件开发的流程。 DevOps 工程师专注于高效的软件应用程序和服务交付&#xff0c;而 SRE 团队是确保可靠性、可扩展性和性能的关键。 这些团队必须依赖全栈可观察性解决方案&#xff0c;使他们能够管理和监控系统&a…

学习记忆——数学篇——案例——算术——记忆100内质数

文章目录 质数表歌诀记忆法100以内的质数歌谣质数口决一百以内质数口诀100以内素数歌 规律记忆法100以内6的倍数前、后位置上的两个数&#xff0c;只要不是5或7的倍数&#xff0c;就一定是质数个数没有用该数除以包括7在内的质数 分类记忆法数字编码法谐音记忆法 100以内的质数…

Matlab随机变量的数字特征

目录 1、均值&#xff08;数学期望&#xff09; 2、中位数 3、几何平均数 4、调和平均数 5、数据排序 6、众数 7、极差&#xff08;最大值和最小值之差&#xff09; 8、方差与均方差&#xff08;标准差&#xff09; 9、变异系数 10、常见分布的期望与方差的计算 11、协方…

ElasticSearch 学习8 :ik分词器的扩展,及java调用ik分词器的analyzer

1.前言&#xff1a; 上篇已经说过ik的集成&#xff0c;这篇说下ik的实际使用 2.2、IK分词器测试 IK提供了两个分词算法ik_smart 和 ik_max_word ik_smart&#xff1a;为最少切分ik_max_word&#xff1a;为最细粒度划分。 2.2.1、最小切分示例 #分词器测试ik_smart POST _…

互联网项目有哪些值得做的

互联网已经融入了我们生活的方方面面&#xff0c;从电商巨头到科技创新&#xff0c;互联网带来的变革和便利无处不在。而在这个信息广泛的时代&#xff0c;越来越多的人开始思考如何利用互联网去创造价值。现如今&#xff0c;互联网项目的形式多种多样&#xff0c;有些让我们的…

剑指offer——JZ79 判断是不是平衡二叉树 解题思路与具体代码【C++】

一、题目描述与要求 判断是不是平衡二叉树_牛客题霸_牛客网 (nowcoder.com) 题目描述 输入一棵节点数为 n 二叉树&#xff0c;判断该二叉树是否是平衡二叉树。 在这里&#xff0c;我们只需要考虑其平衡性&#xff0c;不需要考虑其是不是排序二叉树 平衡二叉树&#xff08;…

【Java 进阶篇】深入了解HTML表单标签

HTML&#xff08;Hypertext Markup Language&#xff09;表单标签是网页开发中的重要组成部分&#xff0c;用于创建各种交互式元素&#xff0c;允许用户输入、提交和处理数据。本文将深入探讨HTML表单标签&#xff0c;包括如何创建表单、各种输入元素、表单属性以及一些最佳实践…

C++学习day2

作业&#xff1a; 1> 思维导图 2>自己封装一个矩形类(Rect)&#xff0c;拥有私有属性:宽度(width)、高度(height)&#xff0c; 定义公有成员函数: 初始化函数:void init(int w, int h) 更改宽度的函数:set_w(int w) 更改高度的函数:set_h(int h) 输出该矩形的周长和…

jenkins工具系列 —— 插件 使用Changelog获取commit记录

文章目录 安装changelog插件重启jenkins配置 ChangelogExecute shell 使用 changelog邮件中html格式也可以使用构建测试&#xff08;查看构建项 -> 控制台输出&#xff09; 安装changelog插件 插件文件可通过 V 获取 点击 左侧的 Manage Jenkins —> Plugins ——> …

Docker安装——Ubuntu (Jammy 22.04)

一、为什么要用 Ubuntu&#xff1f;(centos和ubuntu有什么区别&#xff09; 使用lsb_release命令&#xff1a;lsb_release -a &#xff0c;即可查看ubantu的版本&#xff0c;但是为什么要使用ubantu 呢&#xff1f; 区别&#xff1a;1、centos基于EHEL开发&#xff0c;而ubunt…

2023年10月8日

三盏灯流水 .text .global _start _start: 1.设置GPIOE寄存器的时钟使能 RCC_MP_AHB4ENSETR[5:4]->1 0x50000a28 LDR R0,0X50000A28 LDR R1,[R0] 从r0为起始地址的4字节数据取出放在R1 ORR R1,R1,#(0x3<<4) 第4位设置为1 STR R1,[R0] 写回2.设置PE10管脚为…

SketchyCOCO数据集进行前景图像、背景图像和全景图像的分类

SketchyCOCO数据集进行前景图像、背景图像和全景图像的分类 import os import shutildef CopyFile(src, dst, filename):if not os.path.exists(dst):os.makedirs(dst)print(create dir: dst)try:shutil.copy(src\\filename, dst\\filename)except Exception as e:print(cop…

经典算法-----农夫过河问题(深度优先搜索)

目录 前言 农夫过河问题 1.问题描述 2.解决思路 位置编码 获取位置 判断是否安全 深度优先遍历&#xff08;核心算法&#xff09; 3.完整代码 前言 今天我们来解决一个有意思的问题&#xff0c;也就是农夫过河问题&#xff0c;可能这个问题我们小时候上学就听说过类似的…

分布式缓存-Redis集群

单点Redis的问题 数据丢失问题 Redis是内存存储&#xff0c;服务重启可能会丢失数据 并发能力问题 单节点Redis并发能力虽然不错&#xff0c;但也无法满足如618这样的高并发场景 故障恢复问题 如果Redis宕机&#xff0c;则服务不可用&#xff0c;需要一种自动的故障恢复手段…