强化学习从基础到进阶-案例与实践[4.2]:深度Q网络DQN-Cart pole游戏展示

news2024/11/24 3:47:15

在这里插入图片描述
【强化学习原理+项目专栏】必看系列:单智能体、多智能体算法原理+项目实战、相关技巧(调参、画图等、趣味项目实现、学术应用项目实现

在这里插入图片描述
专栏详细介绍:【强化学习原理+项目专栏】必看系列:单智能体、多智能体算法原理+项目实战、相关技巧(调参、画图等、趣味项目实现、学术应用项目实现

对于深度强化学习这块规划为:

  • 基础单智能算法教学(gym环境为主)
  • 主流多智能算法教学(gym环境为主)
    • 主流算法:DDPG、DQN、TD3、SAC、PPO、RainbowDQN、QLearning、A2C等算法项目实战
  • 一些趣味项目(超级玛丽、下五子棋、斗地主、各种游戏上应用)
  • 单智能多智能题实战(论文复现偏业务如:无人机优化调度、电力资源调度等项目应用)

本专栏主要方便入门同学快速掌握强化学习单智能体|多智能体算法原理+项目实战。后续会持续把深度学习涉及知识原理分析给大家,让大家在项目实操的同时也能知识储备,知其然、知其所以然、知何由以知其所以然。

声明:部分项目为网络经典项目方便大家快速学习,后续会不断增添实战环节(比赛、论文、现实应用等)

  • 专栏订阅(个性化选择):

    • 强化学习原理+项目专栏大合集-《推荐订阅☆☆☆☆☆》

    • 强化学习单智能体算法原理+项目实战《推荐订阅☆☆☆☆》

    • 强化学习多智能体原理+项目实战《推荐订阅☆☆☆☆☆》

    • 强化学习相关技巧(调参、画图等《推荐订阅☆☆☆》)

    • tensorflow_gym-强化学习:免费《推荐订阅☆☆☆☆》

    • 强化学习从基础到进阶-案例与实践:免费《推荐订阅☆☆☆☆☆》

强化学习从基础到进阶-案例与实践[4.2]:深度Q网络DQN-Cart pole游戏展示

  • 强化学习(Reinforcement learning,简称RL)是机器学习中的一个领域,区别与监督学习和无监督学习,强调如何基于环境而行动,以取得最大化的预期利益。
  • 基本操作步骤:智能体agent在环境environment中学习,根据环境的状态state(或观测到的observation),执行动作action,并根据环境的反馈reward(奖励)来指导更好的动作。

比如本项目的Cart pole小游戏中,agent就是动图中的杆子,杆子有向左向右两种action

## 安装依赖
!pip install pygame
!pip install gym
!pip install atari_py
!pip install parl
import gym
import os
import random
import collections

import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F

1.经验回放部分

经验回放主要做的事情是:把结果存入经验池,然后经验池中随机取出一条结果进行训练。

这样做有两个好处:

  1. 减少样本之间的关联性
  2. 提高样本的利用率

之所以加入experience replay是因为样本是从游戏中的连续帧获得的,这与简单的reinforcement learning问题相比,样本的关联性大了很多,如果没有experience replay,算法在连续一段时间内基本朝着同一个方向做gradient descent,那么同样的步长下这样直接计算gradient就有可能不收敛。因此experience replay是从一个memory pool中随机选取了一些expeirence,然后再求梯度,从而避免了这个问题。

class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    # 增加一条经验到经验池中
    def append(self, exp):
        self.buffer.append(exp)

    # 从经验池中选取N条经验出来
    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []

        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)

        return np.array(obs_batch).astype('float32'), np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')

    def __len__(self):
        return len(self.buffer)

2.DQN

DQN算法较普通算法在经验回放和固定Q目标有了较大的改进,主要原因:

  • 经验回放:他充分利用了off-colicp的优势,通过训练把结果(成绩)存入Q表格,然后随机从表格中取出一条结果进行优化。这样子一方面可以:减少样本之间的关联性另一方面:提高样本的利用率 注:训练结果会存进Q表格,当Q表格满了以后,存进来的数据会把最早存进去的数据“挤出去”(弹出)
  • 固定Q目标他解决了算法更新不平稳的问题。 和监督学习做比较,监督学习的最终值要逼近实际结果,这个结果是固定的,但是我们的DQN却不是,他的目标值是经过神经网络以后的一个值,那么这个值是变动的不好拟合,怎么办,DQN团队想到了一个很好的办法,让这个值在一定时间里面保持不变,这样子这个目标就可以确定了,然后目标值更新以后更加接近实际结果,可以更好的进行训练。

3.模型Model

这里的模型可以根据自己的需求选择不同的神经网络组建。

DQN用来定义前向(Forward)网络,可以自由的定制自己的网络结构。

class DQN(nn.Layer):
    def __init__(self, outputs):
        super(DQN, self).__init__()
        self.linear1 = nn.Linear(in_features=4, out_features=128)
        self.linear2 = nn.Linear(in_features=128, out_features=24)
        self.linear3 = nn.Linear(in_features=24, out_features=outputs)

    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

4.智能体Agent的学习函数

这里包括模型探索与模型训练两个部分

Agent负责算法与环境的交互,在交互过程中把生成的数据提供给Algorithm来更新模型(Model),数据的预处理流程也一般定义在这里。

def sample(obs, MODEL):
    global E_GREED
    global ACTION_DIM
    global E_GREED_DECREMENT
    sample = np.random.rand()  # 产生0~1之间的小数
    if sample < E_GREED:
        act = np.random.randint(ACTION_DIM)  # 探索:每个动作都有概率被选择
    else:
        obs = np.expand_dims(obs, axis=0)
        obs = paddle.to_tensor(obs, dtype='float32')
        act = MODEL(obs)
        act = np.argmax(act.numpy())  # 选择最优动作
    E_GREED = max(0.01, E_GREED - E_GREED_DECREMENT)  # 随着训练逐步收敛,探索的程度慢慢降低
    return act


def learn(obs, act, reward, next_obs, terminal, TARGET_MODEL, MODEL):
    global global_step
    # 每隔200个training steps同步一次model和target_model的参数
    if global_step % 50 == 0:
        TARGET_MODEL.load_dict(MODEL.state_dict())
    global_step += 1

    obs = np.array(obs).astype('float32')
    next_obs = np.array(next_obs).astype('float32')
    # act = np.expand_dims(act, -1)
    cost = optimize_model(obs, act, reward, next_obs,
                          terminal, TARGET_MODEL, MODEL)  # 训练一次网络
    return cost


def optimize_model(obs, action, reward, next_obs, terminal, TARGET_MODEL, MODEL):
    """
    使用DQN算法更新self.model的value网络
    """
    # 从target_model中获取 max Q' 的值,用于计算target_Q
    global E_GREED
    global ACTION_DIM
    global E_GREED_DECREMENT
    global GAMMA
    global LEARNING_RATE
    global opt

    opt = paddle.optimizer.Adam(learning_rate=LEARNING_RATE,
                                parameters=MODEL.parameters())  # 优化器(动态图)

    obs = paddle.to_tensor(obs)
    next_obs = paddle.to_tensor(next_obs)

    next_pred_value = TARGET_MODEL(next_obs).detach()
    best_v = paddle.max(next_pred_value, axis=1)
    target = reward + (1.0 - terminal) * GAMMA * best_v.numpy()
    target = paddle.to_tensor(target)
    pred_value = MODEL(obs)  # 获取Q预测值
    # 将action转onehot向量,比如:3 => [0,0,0,1,0]
    action = paddle.to_tensor(action.astype('int32'))
    action_onehot = F.one_hot(action, ACTION_DIM)
    action_onehot = paddle.cast(action_onehot, dtype='float32')
    # 下面一行是逐元素相乘,拿到action对应的 Q(s,a)
    pred_action_value = paddle.sum(paddle.multiply(action_onehot, pred_value), axis=1)
    # 计算 Q(s,a) 与 target_Q的均方差,得到loss
    cost = F.square_error_cost(pred_action_value, target)
    cost = paddle.mean(cost)
    avg_cost = cost
    cost.backward()
    opt.step()
    opt.clear_grad()

    return avg_cost.numpy()

5.模型梯度更新算法

def run_train(env, rpm, TARGET_MODEL, MODEL):
    MODEL.train()
    TARGET_MODEL.train()
    total_reward = 0
    obs = env.reset()

    global global_step
    while True:
        global_step += 1
        # 获取随机动作和执行游戏
        action = sample(obs, MODEL)

        next_obs, reward, isOver, info = env.step(action)

        # 记录数据
        rpm.append((obs, action, reward, next_obs, isOver))

        # 在预热完成之后,每隔LEARN_FREQ步数就训练一次
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (global_step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward, batch_next_obs, batch_isOver) = rpm.sample(BATCH_SIZE)
            train_loss = learn(batch_obs, batch_action, batch_reward,
                               batch_next_obs, batch_isOver, TARGET_MODEL, MODEL)

        total_reward += reward
        obs = next_obs.astype('float32')

        # 结束游戏
        if isOver:
            break
    return total_reward


def evaluate(model, env, render=False):
    model.eval()
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            obs = np.expand_dims(obs, axis=0)
            obs = paddle.to_tensor(obs, dtype='float32')
            action = model(obs)
            action = np.argmax(action.numpy())
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)

6.训练函数与验证函数

设置超参数

LEARN_FREQ = 5  # 训练频率,不需要每一个step都learn,攒一些新增经验后再learn,提高效率
MEMORY_SIZE = 20000  # replay memory的大小,越大越占用内存
MEMORY_WARMUP_SIZE = 200  # replay_memory 里需要预存一些经验数据,再开启训练
BATCH_SIZE = 32  # 每次给agent learn的数据数量,从replay memory随机里sample一批数据出来
LEARNING_RATE = 0.001  # 学习率大小
GAMMA = 0.99  # reward 的衰减因子,一般取 0.9 到 0.999 不等

E_GREED = 0.1  # 探索初始概率
E_GREED_DECREMENT = 1e-6  # 在训练过程中,降低探索的概率
MAX_EPISODE = 20000  # 训练次数
SAVE_MODEL_PATH = "models/save"  # 保存模型路径
OBS_DIM = None
ACTION_DIM = None
global_step = 0
def main():
    global OBS_DIM
    global ACTION_DIM

    train_step_list = []
    train_reward_list = []
    evaluate_step_list = []
    evaluate_reward_list = []

    # 初始化游戏
    env = gym.make('CartPole-v0')
    # 图像输入形状和动作维度
    action_dim = env.action_space.n
    obs_dim = env.observation_space.shape
    OBS_DIM = obs_dim
    ACTION_DIM = action_dim
    max_score = -int(1e4)

    # 创建存储执行游戏的内存
    rpm = ReplayMemory(MEMORY_SIZE)
    MODEL = DQN(ACTION_DIM)
    TARGET_MODEL = DQN(ACTION_DIM)
    # if os.path.exists(os.path.dirname(SAVE_MODEL_PATH)):
    #     MODEL_DICT = paddle.load(SAVE_MODEL_PATH+'.pdparams')
    #     MODEL.load_dict(MODEL_DICT)  # 加载模型参数
    print("filling memory...")
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_train(env, rpm, TARGET_MODEL, MODEL)
    print("filling memory done")

    # 开始训练
    episode = 0

    print("start training...")
    # 训练max_episode个回合,test部分不计算入episode数量
    while episode < MAX_EPISODE:
        # train part
        for i in range(0, int(50)):
            # First we need a state
            total_reward = run_train(env, rpm, TARGET_MODEL, MODEL)
            episode += 1
        
        # print("episode:{}    reward:{}".format(episode, str(total_reward)))

        # test part
        # print("start evaluation...")
        eval_reward = evaluate(TARGET_MODEL, env)
        print('episode:{}    e_greed:{}   test_reward:{}'.format(episode, E_GREED, eval_reward))

        evaluate_step_list.append(episode)
        evaluate_reward_list.append(eval_reward)

        # if eval_reward > max_score or not os.path.exists(os.path.dirname(SAVE_MODEL_PATH)):
        #     max_score = eval_reward
        #     paddle.save(TARGET_MODEL.state_dict(), SAVE_MODEL_PATH+'.pdparams')  # 保存模型


if __name__ == '__main__':
    main()

filling memory…
filling memory done
start training…
episode:50 e_greed:0.0992949999999993 test_reward:9.0
episode:100 e_greed:0.0987909999999988 test_reward:9.8
episode:150 e_greed:0.09827199999999828 test_reward:10.0
episode:200 e_greed:0.09777599999999778 test_reward:8.8
episode:250 e_greed:0.09726999999999728 test_reward:9.0
episode:300 e_greed:0.09676199999999677 test_reward:10.0
episode:350 e_greed:0.0961919999999962 test_reward:14.8

项目链接fork一下即可运行

https://www.heywhale.com/mw/project/649e7d3f70567260f8f11d2b

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703225.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于PaddleDetection fairmot目标跟踪 C++ 部署

1 源码下载 PaddleDetection 2 工程编译 参考&#xff1a;paddle 目标检测C部署流程 3 导出模型 python tools/export_model.py -c configs/mot/fairmot/fairmot_dla34_30e_576x320.yml --output_dir ./inference -o weightshttps://paddledet.bj.bcebos.com/models/mot/…

AutoSAR系列讲解(入门篇)4.7-BSW的Diagnostics功能

一、架构与术语解释 首先简单介绍以下诊断&#xff08;Diagnostics&#xff09;&#xff0c;由于百度百科中就有很好的解释&#xff0c;这里直接引用一下&#xff1a; 汽车诊断技术是凭借仪器设备对汽车进行性能测试和故障检查的方法和手段&#xff0c;它能够测试出汽车各项工作…

CMU 15-445 -- 存储篇 - 02

CMU 15-445 -- 存储篇 - 02 引言Database StorageDisk Manager 简介计算机存储体系为什么不使用 OS 自带的磁盘管理模块磁盘管理模块的核心问题在文件中表示数据库File StorageDatabase PagesHeap File OrganizationPage LayoutData LayoutTuple LayoutTuple Storage Data Repr…

Windows Update当前无法检查更新怎么办?

当进行Windows更新或升级时&#xff0c;可能会提示“Windows Update当前无法检查更新&#xff0c;因为未运行服务。您可能需要重新启动计算机”。而当重启也无法解决问题时&#xff0c;我们该怎么办呢&#xff1f;下面我们就来了解一下。 1、删除Software Distribution文件夹中…

Hyperledger Fabric交易流程分析

1、交易流程 客户端利用受支持的SDK(Golang、Java、Node、Python)提供的API构建交易提案请求&#xff0c;将交易事务提案打包成为一个正确的格式。交易提案包含如下要素&#xff1a; ①channelID:通道信息。 ②chaincodelD:要调用的链码信息。 ③timestamp:时间戳。 ④sign:客户…

百度智能车竞赛丝绸之路2——手柄控制

百度智能车竞赛丝绸之路1——智能车设计与编程实现控制 百度智能车竞赛丝绸之路2——手柄控制 一、机器人设计 二、实现原理 本教程使用Python的Serial库和Struct二进制数据解析库去实现Xbox手柄百度大脑学习开发板&#xff08;上位机&#xff09;和机器人控制器&#xff08;…

9小时通关 黑马新教程发布,含重磅项目~

随着测试行业的蓬勃发展&#xff0c;对从业者的要求越来越高&#xff0c;自动化测试已经成为软件测试中一个重要组成部分&#xff0c;广泛应用于各行各业。甚至&#xff0c;在圈子中还流传着这样一句话&#xff1a;学好测试自动化&#xff0c;年薪30万不在话下&#xff01; 今…

Qt读写文件

一、界面 项目文件结构 样例文件 中芯国际近期做出了两个重要改变&#xff1a;第一个是调整财报披露方式&#xff0c;不再公布芯片制程的营收占比&#xff0c;而只公布晶圆尺寸的营收占比&#xff1b;第二个是撤消14nm工艺的官方展示&#xff0c;只有28nm、40nm及以上的芯片工…

LeNet基础

目录 1.LeNet简介 1.1基本介绍 1.2网络结构 2.LetNet在pytorch中的使用 2.1首先定义模型 2.2初始化数据集&#xff0c;初始化模型&#xff0c;同时训练数据。 2.3 训练结果​编辑 2.4绘制曲线 1.LeNet简介 1.1基本介绍 LeNet&#xff08;LeNet-5&#xff09;是历史上第…

磁盘阵列(RAID)

什么是磁盘阵列 磁盘阵列&#xff08;RAID&#xff09;是一种将多个物理硬盘组合成一个逻辑存储单元的技术。这种技术可以提高数据存储的可靠性、性能或容量&#xff0c;并且可以在某些情况下提供备份和灾难恢复功能。 RAID技术可以通过在多个硬盘之间分配数据来提高性能。例…

事务处理相关

目录 步骤1.创建一个数据表 步骤2:创建项目导入jar包 步骤3:根据表创建模型类 步骤5:创建Service接口和实现类 步骤6:添加jdbc.properties文件 步骤7:创建JdbcConfig配置类 步骤8:创建MybatisConfig配置类 步骤9:创建SpringConfig配置类 步骤10:编写测试类 开启事务 1…

电磁阀原理精髓

一、引用 电磁阀在液/气路系统中&#xff0c;用来实现液路的通断或液流方向的改变&#xff0c;它一般具有一个可以在线圈电磁力驱动下滑动的阀芯&#xff0c;阀芯在不同的位置时&#xff0c;电磁阀的通路也就不同。 阀芯在线圈不通电时处在甲位置&#xff0c;在线圈通电时处在…

算法与数据结构-链表

文章目录 链表和数组的区别常见的链表类型单链表循环链表双向链表 总结 链表和数组的区别 相比数组&#xff0c;链表是一种稍微复杂一点的数据结构。对于初学者来说&#xff0c;掌握起来也要比数组稍难一些。这两个非常基础、非常常用的数据结构&#xff0c;我们常常会放到一块…

Python基础 - global nonlocal

global global作为全局变量的标识符&#xff0c;修饰变量后标识该变量是全局变量 global关键字可以用在任何地方&#xff0c;包括最上层函数中和嵌套函数中 实例1&#xff1a;如下代码&#xff0c;定义了两个x&#xff0c;并且赋值不同 直接调用print(x) 打印的是全局变量x的…

号外!MyEclipse 2023.1.1已发布,更好支持Vue框架

MyEclipse 2023.1.1是之前发布的2023.1.0的一个小错误修复版本&#xff0c;如果您已经安装了MyEclipse 2023&#xff0c;只需检查产品中的更新 (Help > Check for Updates…) 就可以选择这个新版本。或者&#xff0c;下载我们更新的离线安装程序来安装2023.1.1。 MyEclipse…

C# WPF应用使用visual studio的安装程序类的一些坑

重写installer实现自定义安装程序时&#xff0c;项目类型要选择 类库(.NET Framework) 否则会出现命名空间System.Configuration不存在Install的报错 有些可能想实现安装完自动启动应用的功能&#xff0c;就需要获取installer安装路径 var s Context.Parameters["assem…

【Java】网络编程与Socket套接字、UDP编程和TCP编程实现客户端和服务端通信

网络编程客户端和服务器Socket套接字流套接字TCP数据报套接字UDP对比TCP与UDP UDP编程DatagramSocket构造方法:普通方法&#xff1a; DatagramPacket构造方法:普通方法&#xff1a; 实现 TCP编程ServerSocket构造方法普通方法 Socket构造方法普通方法 实现 网络编程 为什么需要…

MyBatis-Plus 实现PostgreSQL数据库jsonb类型的保存

文章目录 在 handle 包下新建Jsonb处理类方式一方式二 PostgreSQL jsonb类型示例新建数据库表含有jsonb类型创建实体类创建Control 发起请求 在 handle 包下新建Jsonb处理类 方式一 import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.serializer.SerializerFea…

低代码开发平台到底省掉了哪些成本?可能大家一直错了

低代码到底是否真正可以降低研发成本&#xff1f;是否每个团队都适合&#xff1f;如果能降低&#xff0c;到底是降低的什么成本&#xff1f;其实我觉得这个是我们每个技术交付团队应该在使用任何产品之前都要考虑的问题。 在我们考虑低代码是否能降低成本的问题前&#xff0c;…

【Python】一文带你学会数据结构中的字典、集合

作者主页&#xff1a;爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?typeblog个…