Pytorch深度强化学习案例:基于DQN实现Flappy Bird游戏与分析

news2024/11/15 23:43:19

目录

  • 1 案例介绍
  • 2 构造深度Q网络
  • 3 经验回放与目标网络
  • 4 训练流程
  • 5 实验分析

1 案例介绍

Flappy Bird是一款由来自越南的独立游戏开发者Dong Nguyen所开发的作品,于2013年5月24日上线。

Flappy Bird中,玩家只需要用一根手指来操控:点击一次屏幕,小鸟就会往上飞一次,不断地点击就会使小鸟不断往高处飞。放松手指,小鸟则会快速下降。所以玩家要控制小鸟一直向前飞行,然后注意躲避途中高低不平的管子。小鸟每安全穿过一个水管得1分,若撞上水管则游戏失败。

如图所示是用强化学习模型DQN训练AI完成Flappy Bird游戏的案例,接下来具体分析如何实现这个案例

在这里插入图片描述

2 构造深度Q网络

深度Q网络(Deep Q-Network, DQN)的核心原理是通过

  • 经验回放池
  • 目标网络

拟合高维状态空间,是Q-Learning算法的深度学习版本。具体理论参考Pytorch深度强化学习(八):基于价值的强化学习——DQN算法

具体到Flappy Bird游戏,结构如图所示:设置网络输入为游戏的连续四帧图片,使用卷积神经网络提取状态特征,最后输出为一个布尔值,即小鸟选择的动作——向上飞或下降。

在这里插入图片描述
实现如下

class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 2)

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)

        return output

3 经验回放与目标网络

考虑到强化学习采样的是连续非静态样本,样本间的相关性导致网络参数并非独立同分布,使训练过程难以收敛,因此设置经验池存储样本,再通过随机采样去除相关性。经验回放池的设置、存储与采样如下

replay_memory = []

# 将<s, a, r, s'>添加到经验回放池
replay_memory.append([state, action, reward, next_state, terminal])
if len(replay_memory) > opt["replay_memory_size"]:
    del replay_memory[0]

# 采样一个batch的数据
batch = sample(replay_memory, min(len(replay_memory), opt["batch_size"]))
state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)

考虑到若目标价值与当前价值 是同一个网络时会导致优化目标不断变化,产生模型振荡与发散,因此构建结构相同但慢于更新的独立目标网络来评估目标价值,使模型更稳定

# 采用的网络
self.model = DQN(env.observation_space.shape, env.action_space.n).to(self.device)
self.target_model = DQN(env.observation_space.shape, env.action_space.n).to(self.device)
for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
    target_param.data.copy_(param)

# 更新target网络
for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
    target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)

4 训练流程

除了与环境的交互采样强化学习思想,其余步骤与深度学习训练相同

# 实例化DQN模型
model = DeepQNetwork()

# 设置优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=opt["lr"])
criterion = nn.MSELoss()

# 初始化环境
game_state = FlappyBird()
image, reward, terminal = game_state.step(0)
image = preProcessing(image[:game_state.screen_width, :int(game_state.base_y)], opt["image_size"], opt["image_size"])
image = torch.from_numpy(image)

# 获得状态, 将图片化为batch x in_channel x h x w
state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
replay_memory = []

# 开始迭代
with tqdm(range(opt["num_iters"])) as bar:
    for i in bar:
        prediction = model(state)[0]
        
        # 动态调整贪心概率并执行贪心算法
        epsilon = opt["final_epsilon"] + (
                (opt["num_iters"] - i) * (opt["initial_epsilon"] - opt["final_epsilon"]) / opt["num_iters"])
        action = randint(0, 1) if random() <= epsilon else torch.argmax(prediction)

        # 获取下一个状态(时序差分)
        next_image, reward, terminal = game_state.step(action)
        next_image = preProcessing(next_image[:game_state.screen_width, :int(game_state.base_y)],
                        opt["image_size"], opt["image_size"])
        next_image = torch.from_numpy(next_image)
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]

        # 将<s, a, r, s'>添加到经验回放池
        ...
        
        # 采样一个batch的数据
        ...

        # 目标网络为训练样本添加标注信息,并与当前值网络做损失
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)
        y_batch = torch.cat(
            tuple(reward if terminal else reward + opt["gamma"] * torch.max(prediction) for reward, terminal, prediction in
                zip(reward_batch, terminal_batch, next_prediction_batch)))
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)

        # 梯度优化
        optimizer.zero_grad()
        # y_batch = y_batch.detach()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state

5 实验分析

训练2000代的奖赏曲线如图所示,左侧是验证集曲线,右侧是训练集曲线,可见随着训练过程进行,模型得到的奖励在不断上升

在这里插入图片描述
刚开始训练时的效果可视化

在这里插入图片描述

模型收敛后的效果可视化(200万次迭代),AI已经可以很好地掌握这款游戏了

在这里插入图片描述

本文完整工程代码请联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《路径规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/181476.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

P49 BFC 块级格式化上下文 块级格式化上下文 BFC渲染区域: 创建BFC的元素,它的自动高度需要计算浮动元素. 高度塌陷例子

目录块级格式化上下文BFC渲染区域:创建BFC的元素&#xff0c;它的自动高度需要计算浮动元素.高度塌陷例子&#xff1a;第一种方法 clearfix::after第二种解决办法 :绝对定位第三种解决办法&#xff1a;overflow: scroll;第四种方法&#xff1a;clearfix hidden创建BFC的元素&am…

21版本FL Studio水果音乐制作软件下载

因为对音乐有一些了解&#xff0c;所以周边有不少朋友会问我很多关于音乐的问题&#xff0c;其中比较多是学习音乐到底用哪款软件比较好。每次遇到这样的问题&#xff0c;我都会告诉他们&#xff0c;就是我一直在用的音乐制作软件FL Studio。音乐制作软件FL Studio&#xff0c;…

【JavaGuide面试总结】Java集合篇·中

【JavaGuide面试总结】Java集合篇中1.Collection 子接口之 SetComparable 和 Comparator 的区别比较 HashSet、LinkedHashSet 和 TreeSet 三者的异同2.Collection 子接口之 QueueQueue 与 Deque 的区别ArrayDeque 与 LinkedList 的区别说一说 PriorityQueue3.Map 接口HashMap 的…

机器学习(八):深度学习简介

文章目录 深度学习简介 一、神经网络简介 二、深度学习各层负责内容 深度学习简介 一、神经网络简介 深度学习&#xff08;Deep Learning&#xff09;&#xff08;也称为深度结构学习【Deep Structured Learning】、层次学习【Hierarchical Learning】或者是深度机器学习【…

React中commit阶段发生了什么

对于commit阶段的主要工作是循环effectList链表去将有更新的fiber节点应用到页面上是commit的主要工作。 EffectList 什么是副作用&#xff1f; 函数在执行过程中对外部造成的影响可以称之为副作用&#xff0c;副作用包含的类型很多&#xff0c;比如说标记值为Placement时&a…

客快物流大数据项目(一百零九):Spring Boot概述

文章目录 Spring Boot概述 一、什么是SpringBoot 二、​​​​​​​为什么要学习Spring Boot

PHP转Go实践:xjson解析神器「开源工具集」

前言 近期会更新一系列开源项目的文章&#xff0c;新的一年会和大家做更多的开源项目&#xff0c;也欢迎大家加入进来。 xutil 今天分享的文章源自于开源项目jinzaigo/xutil的封装。 在封装过程中&#xff0c;劲仔将实现原理以及相关实践思考&#xff0c;写成文章分享出来&am…

Python3学习——条件控制、循环语句与迭代器

目录 一、编程第一步——斐波那契数列 二、条件控制 (一)if/else语句 判断狗狗的年龄&#xff1a; (二)多层if/else嵌套 判断数字能否被2或3整除&#xff1a; (三)match...case匹配——python3中新增 根据数字判断星期&#xff1a; 三、循环语句 (一)while循环 1.循环…

Java:Idea创建项目和Spring工程基本使用

一、创建项目 1、创建新的空的项目&#xff1a; Empty Project–next 2、定义项目的名称&#xff0c;并指定位置 3、对项目进行设置&#xff0c;JDK版本、编译版本 4、添加模块信息 5、修改maven路径 6、项目目录结构 二、搭建Spring的框架 1、在核心配置文件中添加Spring的j…

C++11 并发指南五(stdcondition_variable 详解)

C11 并发指南五(std::condition_variable 详解) 文章目录C11 并发指南五(std::condition_variable 详解)std::condition_variable 类介绍std::condition_variable_any 介绍std::cv_status 枚举类型介绍std::notify_all_at_thread_exit前面三讲《 C11 并发指南二(std::thread 详…

二叉树简单解析(1)

&#x1f340;本人简介&#xff1a; 吉师大一最爱逃课的混子、 华为云享专家、阿里云专家博主、腾讯云自媒体分享计划博主、 华为MindSpore优秀开发者、迷雾安全团队核心成员&#xff0c;CSDN2022年运维与安全领域第15名 &#x1f341;本人制作小程序以及资源分享地址&#x…

英语学习打卡day7

2023.1.27 1.ironically adv.具有讽刺意味的是;反讽地&#xff0c;讽刺地 Ironically, his cold got better on the last day of his holiday. 2.bequeath vt.遗赠;把…遗赠给;把… .传给 (比give更正式) bequeath sb sth bequeath sth to sb Don’t bequeath the problem …

JDK17 || JDK 8 完美 卸载 教程 (Windows版)

文章目录一、卸载jdk程序1 . 找到控制面板2. 卸载程序3. 找到JDK 相关的程序4. 右键 选择卸载程序5. 下一步 选择 是6.下一步 选择 是二、安装 新版 JDK三、如果不想再使用jdk环境结语一、卸载jdk程序 1 . 找到控制面板 2. 卸载程序 3. 找到JDK 相关的程序 4. 右键 选择卸载程…

IDEA界面和控制台的滚动条颜色不明显?赶快换一个吧!

前言 不知道大家是否和我一样有这么一个烦恼&#xff1a; IDEA自带的滚动条颜色很暗&#xff0c;配上一些主题颜色搭配很难发现。 所以今天就想着怎么可以修改滚动条颜色&#xff0c;首先去网上搜了搜都是什么鼠标滚轮加shift滚动&#xff0c;一点也不实用 偶然看到了个不错的…

【青训营】Go的BenchMark的使用

本文内容总结于 字节跳动青年训练营 第五届后端组 Go自带了一些性能测试工具&#xff0c;其中BenchMark是较为重要的一个。 我们以计算斐波那契数列的示例来展示BenchMark的使用 package Benchmarkimport "testing"func Fib(n int) int {if n < 2 {return n}ret…

OpenCV-PyQT项目实战(1)安装与环境配置

本系列从零开始实战解说基于 PyQt5 的 OpenCV 项目开发。 欢迎关注『OpenCV-PyQT项目实战 Youcans』系列&#xff0c;持续更新中 OpenCV-PyQT项目实战&#xff08;1&#xff09;安装与环境配置 OpenCV-PyQT项目实战&#xff08;2&#xff09;OpenCV导入图像 文章目录1. PyQt5 …

初识图像分类——K近邻法(cs231n assignment)

作者&#xff1a;非妃是公主 专栏&#xff1a;《计算机视觉》 个性签&#xff1a;顺境不惰&#xff0c;逆境不馁&#xff0c;以心制境&#xff0c;万事可成。——曾国藩 专栏系列文章 Cannot find reference ‘imread‘ in ‘init.py‘ error: (-209:Sizes of input arguments…

ppt神器islide 第1节 初步接触强大的工具-资源篇

ppt神器islide 第1节 初步接触强大的工具1 PPT大神的课程总结1.1 骨架篇1.2 色彩篇1.3 对齐篇1.4 对比篇1.5 修饰篇1.6 字体篇1.7 素材篇1.8 线条篇1.8.1 可以随意画线条&#xff0c;填充空白1.8.2 在字体上画线条&#xff0c;做成艺术字1.8.3 做对称线条&#xff0c;比如递进三…

[Vulnhub] DC-3

下载链接&#xff1a;DC-3 DC-3需要 把IDE里面的改成IDE 0:0 不然无法打开 知识点&#xff1a; Joomla cms 3.7.0 sql注入漏洞cmseek工具探测cms指纹john解密proc_popen反弹shellpython -m http.server开启http服务 & wget远程下载ubuntu16.0.4 Linux 4.4.0-21 系统漏…

使用OpenAI的Whisper 模型进行语音识别

语音识别是人工智能中的一个领域&#xff0c;它允许计算机理解人类语音并将其转换为文本。该技术用于 Alexa 和各种聊天机器人应用程序等设备。而我们最常见的就是语音转录&#xff0c;语音转录可以语音转换为文字记录或字幕。 wav2vec2、Conformer 和 Hubert 等最先进模型的最…