使用Deep Q-Network学习如何玩《飞行的小鸟》游戏

news2024/11/26 11:58:36

目录

  • 概述
  • 效果
  • 需要的依赖
  • 如何运行
  • 算法原理
  • 实验
    • 输入处理
    • 网络结构
    • 训练
  • 代码

概述

使用DQN实现《飞行的小鸟》游戏,代码可修改扩展为其他游戏,适合学习研究用。

效果

在这里插入图片描述

需要的依赖

Python 2.7 or 3
TensorFlow 0.7
pygame
OpenCV-Python

如何运行

运行主函数 deep_q_network.py即可。

算法原理

输入输出关系:深度强化学习是q学习的一个变种,其输入是原始像素,其输出是一个价值函数估计未来的回报。
Deep Q-Network Algorithm伪代码如下:

Initialize replay memory D to size N
Initialize action-value function Q with random weights
for episode = 1, M do
    Initialize state s_1
    for t = 1, T do
        With probability ϵ select random action a_t
        otherwise select a_t=max_a  Q(s_t,a; θ_i)
        Execute action a_t in emulator and observe r_t and s_(t+1)
        Store transition (s_t,a_t,r_t,s_(t+1)) in D
        Sample a minibatch of transitions (s_j,a_j,r_j,s_(j+1)) from D
        Set y_j:=
            r_j for terminal s_(j+1)
            r_j+γ*max_(a^' )  Q(s_(j+1),a'; θ_i) for non-terminal s_(j+1)
        Perform a gradient step on (y_j-Q(s_j,a_j; θ_i))^2 with respect to θ
    end for
end for

实验

输入处理

由于深层 Q 网络是在每个时间步骤观察到的游戏屏幕的原始像素值上进行训练的,发现,删除原始游戏中出现的背景可以使它收敛得更快。这个过程可以用下图表示:
在这里插入图片描述

网络结构

先对游戏输入图像进行了以下预处理步骤:

  1. 将图像转换为灰度图
  2. 将图像调整为 80x80 大小
  3. 将最后 4 个帧堆叠在一起,为网络生成 80x80x4 输入数组

网络的架构如下图所示。第一层使用 8x8x4x32 内核在 4 的步幅大小上卷积输入图像。输出然后通过 2x2 最大池层。第二层在 2 的步幅上使用 4x4x32x64 内核进行卷积。然后我们再次进行最大池。第三层在 1 的步幅上使用 3x3x64x64 内核进行卷积。然后我们再次最大池。最后一个隐藏层由 256 个完全连接的 ReLU 节点组成。
在这里插入图片描述
最终输出层的维度与游戏中可以执行的有效操作数量相同,其中 0 索引总是对应于什么也不做。这个输出层的值代表给定输入状态的 Q 函数的每个有效操作。在每个时间步,网络使用ϵ贪心策略执行与最高 Q 值对应的操作。

训练

首先,我使用标准差为 0.01 的正态分布随机初始化所有权重矩阵,然后将重放记忆大小设置为 500,000 次实验。

我开始训练,在最初的 10,000 个时间步中随机均匀选择操作,而不更新网络权重。这使得系统在训练开始前填充重放记忆。

注意,在接下来的 3000,000 个帧中将ϵ从 0.1 线性退火到 0.0001。我这样设置的原因是,在我们的游戏中,代理可以在每 0.03 秒内选择一个操作(FPS=30),高ϵ将使它抖动太多,从而使它保持在游戏屏幕的顶部,最终以笨拙的方式撞到管道。这种情况将使 Q 函数相对较慢地收敛,因为它只有在ϵ较低时才开始看到其他条件。但是,在其他游戏中,将ϵ初始化为 1 更合理。

在训练时间内,在每个时间步,网络从重放记忆中采样大小为 32 的小批量进行训练,并使用学习率为 0.000001 的 Adam 优化算法在上述损失函数上执行梯度步。在退火完成后,网络继续无限期地训练,ϵ固定为 0.001。

代码

主函数

import tensorflow as tf
import cv2
import random
import numpy as np
from collections import deque

#Game的定义类,此处Game是什么不重要,只要提供执行Action的方法,获取当前游戏区域像素的方法即可
class Game(object):
    def __init__(self):  #Game初始化
    # action是MOVE_STAY、MOVE_LEFT、MOVE_RIGHT
    # ai控制棒子左右移动;返回游戏界面像素数和对应的奖励。(像素->奖励->强化棒子往奖励高的方向移动)
        pass
    def step(self, action):
        pass
# learning_rate
GAMMA = 0.99
# 跟新梯度
INITIAL_EPSILON = 1.0
FINAL_EPSILON = 0.05
# 测试观测次数
EXPLORE = 500000
OBSERVE = 500
# 记忆经验大小
REPLAY_MEMORY = 500000
# 每次训练取出的记录数
BATCH = 100
# 输出层神经元数。代表3种操作-MOVE_STAY:[1, 0, 0]  MOVE_LEFT:[0, 1, 0]  MOVE_RIGHT:[0, 0, 1]
output = 3
MOVE_STAY =[1, 0, 0]
MOVE_LEFT =[0, 1, 0]
MOVE_RIGHT=[0, 0, 1]
input_image = tf.placeholder("float", [None, 80, 100, 4])  # 游戏像素
action = tf.placeholder("float", [None, output])           # 操作

#定义CNN-卷积神经网络
def convolutional_neural_network(input_image):
    weights = {'w_conv1':tf.Variable(tf.zeros([8, 8, 4, 32])),
               'w_conv2':tf.Variable(tf.zeros([4, 4, 32, 64])),
               'w_conv3':tf.Variable(tf.zeros([3, 3, 64, 64])),
               'w_fc4':tf.Variable(tf.zeros([3456, 784])),
               'w_out':tf.Variable(tf.zeros([784, output]))}

    biases = {'b_conv1':tf.Variable(tf.zeros([32])),
              'b_conv2':tf.Variable(tf.zeros([64])),
              'b_conv3':tf.Variable(tf.zeros([64])),
              'b_fc4':tf.Variable(tf.zeros([784])),
              'b_out':tf.Variable(tf.zeros([output]))}

    conv1 = tf.nn.relu(tf.nn.conv2d(input_image, weights['w_conv1'], strides = [1, 4, 4, 1], padding = "VALID") + biases['b_conv1'])
    conv2 = tf.nn.relu(tf.nn.conv2d(conv1, weights['w_conv2'], strides = [1, 2, 2, 1], padding = "VALID") + biases['b_conv2'])
    conv3 = tf.nn.relu(tf.nn.conv2d(conv2, weights['w_conv3'], strides = [1, 1, 1, 1], padding = "VALID") + biases['b_conv3'])
    conv3_flat = tf.reshape(conv3, [-1, 3456])
    fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights['w_fc4']) + biases['b_fc4'])

    output_layer = tf.matmul(fc4, weights['w_out']) + biases['b_out']
    return output_layer

#训练神经网络
def train_neural_network(input_image):
    argmax = tf.placeholder("float", [None, output])
    gt = tf.placeholder("float", [None])

    #损失函数
    predict_action = convolutional_neural_network(input_image)
    action = tf.reduce_sum(tf.mul(predict_action, argmax), reduction_indices = 1) #max(Q(S,:))
    cost = tf.reduce_mean(tf.square(action - gt))
    optimizer = tf.train.AdamOptimizer(1e-6).minimize(cost)

    #游戏开始
    game = Game()
    D = deque()
    _, image = game.step(MOVE_STAY)
    image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)
    ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
    input_image_data = np.stack((image, image, image, image), axis = 2)

    with tf.Session() as sess:
        #初始化神经网络各种参数
        sess.run(tf.initialize_all_variables())
        #保存神经网络参数的模块
        saver = tf.train.Saver()

        #总的运行次数
        n = 0
        epsilon = INITIAL_EPSILON
        while True:

            #神经网络输出的是Q(S,:)值
            action_t = predict_action.eval(feed_dict = {input_image : [input_image_data]})[0]
            argmax_t = np.zeros([output], dtype=np.int)

            #贪心选取动作
            if(random.random() <= INITIAL_EPSILON):
                maxIndex = random.randrange(output)
            else:
                maxIndex = np.argmax(action_t)

            #将action对应的Q(S,a)最大值提取出来
            argmax_t[maxIndex] = 1

            #贪婪的部分开始不断的增加
            if epsilon > FINAL_EPSILON:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

            #将选取的动作带入到环境,观察环境状态S'与回报reward
            reward, image = game.step(list(argmax_t))

            #将得到的图形进行变换用于神经网络的输出
            image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)
            ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
            image = np.reshape(image, (80, 100, 1))
            input_image_data1 = np.append(image, input_image_data[:, :, 0:3], axis = 2)

            #将S,a,r,S'记录的大脑中
            D.append((input_image_data, argmax_t, reward, input_image_data1))

            #大脑的记忆是有一定的限度的
            if len(D) > REPLAY_MEMORY:
                D.popleft()

            #如果达到观察期就要进行神经网络训练
            if n > OBSERVE:

                #随机的选取一定记忆的数据进行训练
                minibatch = random.sample(D, BATCH)
                #将里面的每一个记忆的S提取出来
                input_image_data_batch = [d[0] for d in minibatch]
                #将里面的每一个记忆的a提取出来
                argmax_batch = [d[1] for d in minibatch]
                #将里面的每一个记忆回报提取出来
                reward_batch = [d[2] for d in minibatch]
                #将里面的每一个记忆的下一步转台提取出来
                input_image_data1_batch = [d[3] for d in minibatch]

                gt_batch = []
                #利用已经有的求解Q(S',:)
                out_batch = predict_action.eval(feed_dict = {input_image : input_image_data1_batch})

                #利用bellman优化得到长期的回报r + γmax(Q(s',:))
                for i in range(0, len(minibatch)):
                    gt_batch.append(reward_batch[i] + GAMMA * np.max(out_batch[i]))

                #利用事先定义的优化函数进行优化神经网络参数
                print("gt_batch:", gt_batch, "argmax:", argmax_batch)
                optimizer.run(feed_dict = {gt : gt_batch, argmax : argmax_batch, input_image : input_image_data_batch})

            input_image_data = input_image_data1
            n = n+1
            print(n, "epsilon:", epsilon, " ", "action:", maxIndex, " ", "_reward:", reward)

train_neural_network(input_image)

注意:完整项目代码链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/153857.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

目标追踪综述

目标追踪综述 - 知乎目标跟踪是计算机视觉领域的一个重要问题&#xff0c;目前广泛应用在体育赛事转播、安防监控和无人机、无人车、机器人等领域。下面是一些应用的例子。 体育赛事转播 无人车 目标跟踪任务分类了解了目标跟踪的用途&#xff0c;我们接下…https://zhuanlan.z…

Java(SpringBoot)项目打包(构建)成`Docker`镜像的几种方式

前置说明 最为原始的打包方式spring-boot-maven-plugin插件jib-maven-plugin插件dockerfle-maven-plugin插件 最为原始的方式 也就是使用Docker的打包命令去打包&#xff0c;麻烦&#xff0c;我这里不多说。 spring-boot-maven-plugin插件打包 SpringBoot自己内置了一个Docker镜…

有了这些软件测试面试话术,offer想不拿到都难

软件测试是一个复杂且重要的技术岗位&#xff0c;因此&#xff0c;大多数互联网企业在面试时&#xff0c;都会严谨对待每一个面试者。而&#xff0c;作为即将去进行面试测试人来说&#xff0c;想要在面试中&#xff0c;沉着稳定地回答好面试官们提出的问题&#xff0c;前期的软…

P5 内积 -- 通讯原理

目录内积内积和傅里叶变换正交能量帕瑟瓦尔定理互能量一 内积定义&#xff1a;任意信号 内积定义为&#xff1a;如果都是实信号例&#xff1a;二 内积和傅里叶变换的关系傅里叶变换 和逆变换 本质上就是求两个函数的内积傅里叶变换傅里叶逆变换时域的内积等于频域的内积假设 则…

再获殊荣!维视智造斩获2022年度光能杯最具影响力“智造”企业奖

近日&#xff0c;由光伏行业权威媒体和机构——索比光伏网、索比咨询联合主办的2022年度“光能杯”影响力大奖榜单发布&#xff0c;维视智造凭借硬件与AI算法能力、凭借在光伏行业具有创新性的智能制造产品方案与落地的标杆案例&#xff0c;斩获“2022年最具影响力“智造”企业…

Windows下Canal.deployer-1.1.6安装部署

canal [kənl]&#xff0c;译意为水道/管道/沟渠&#xff0c;主要用途是基于 MySQL 数据库增量日志解析&#xff0c;提供增量数据订阅和消费 早期阿里巴巴因为杭州和美国双机房部署&#xff0c;存在跨机房同步的业务需求&#xff0c;实现方式主要是基于业务 trigger 获取增量变…

多线程之线程控制与互斥

1.线程的缺点有哪些&#xff1f; 第一点 健壮性低------ 一个线程挂了容易影响另外的线程 第二点 缺乏访问控制----- 不像进程是独立的&#xff0c;可以写时拷贝&#xff0c;线程随进随出有点危险哦 第三点 编写难度上升----- 编写一个多线程的代码和调试可比单线程难多了 ——…

strlen 的三种模拟方法

欢迎来到 Claffic 的博客 &#x1f49e;&#x1f49e;&#x1f49e; 前言&#xff1a; 在C/C 中&#xff0c;strlen函数是一种计算字符串长度的库函数&#xff0c;要模拟此函数有多种方法&#xff0c;这里总结三种模拟方法。 1. strlen 函数介绍 cplusplus - strlen strlen 函数…

正点原子-Linux嵌入式开发学习-第二期06

第十四讲&#xff1a;主频和时钟配置 分析一个芯片的时钟&#xff0c;肯定先知道它的时钟来源&#xff0c;一般来源于外部晶振&#xff0c;内部晶振很少使用 时钟来源分析 RTC的时钟并不是其他外设的晶振来源 24MHz 晶振是 I.MX6U 内核和其它外设的时钟源&#xff0c;也是我…

K8s入门

K8s入门K8s入门k8s介绍k8s功能概述k8s架构k8s核心概念服务器配置要求部署方式使用kubeadm搭建一个k8s集群所有节点安装 Docker/kubeadm/kubeletK8s入门 你好&#xff01; 这是你第一次使用 Markdown编辑器 所展示的欢迎页。如果你想学习如何使用Markdown编辑器, 可以仔细阅读这…

84.【Vue--初刷】

vue.js(一)、vue.js简介1.简介(1).MVVM模式的实现(2).为什么要使用Vue.js(3).为什么要使用MVVC2.应用场景3.JavaScipt框架(1).JQuery :(2).Angular(3).React(4).Vue(5).Axios4.UI框架【可视化】5.JavaScript 构建工具6.三端开发(1).混合开发(Hybrid App)(2).微信小程序7.后端技…

LeetCode题解 回溯(一):77 组合;216 组合总和III

回溯 从今天开始进入回溯&#xff0c;其实此前也接触过几道使用了该思想的题目 回溯的思想是“倒退到上一个状态”&#xff0c;通常结合递归&#xff0c;解决的问题多是“从众多组合中找出符合条件的组合”的问题&#xff0c;随想录中给出了题目大纲&#xff1a; 回溯算法解决…

Linux学习笔记——ZooKeeper集群安装部署

5.8、ZooKeeper集群安装部署 5.8.1、简介 Zookeeper是一个分布式的、开放源码的分布式应用程序协调服务&#xff0c;是Hadoop和HBase的重要组件。它是一个为分布式应用提供一致性服务的软件&#xff0c;提供的功能包括&#xff1a;配置维护、域名服务、分布式同步、组服务等。…

CHAPTER 2 Docker镜像

docker镜像2.1 docker image 获取2.1.1 命令格式&#xff08;pull&#xff09;2.1.2 层(layer)2.1.3 镜像重名2.2 查看镜像信息&#xff08;ls&#xff0c;tag&#xff0c;inspect&#xff0c;history&#xff09;2.2.1 使用images命令列出镜像&#xff08;ls&#xff09;2.2.2…

uni-app:小程序开发总结

内容持续更新中~~~&#x1f618;uniapp项目起步:工具下载在Dcloud 官网上下载 HBuilderX 开发工具,以及微信开发者工具.(同时你要在微信开发者文档进行小程序注册,拿到 ID, HBuilderX 和 微信开发者工具 你都要进行注册登录)项目创建我们可以通过HBuilderX 来进行基础版的项目创…

【阶段三】Python机器学习12篇:机器学习项目实战:朴素贝叶斯模型的算法原理与朴素贝叶斯分类模型

本篇的思维导图: 朴素贝叶斯模型的算法原理 朴素贝叶斯是贝叶斯模型当中最简单的一种,其算法核心为如下所示的贝叶斯公式: 其中P(A)为事件A发生的概率,P(B)为事件B发生的概率,P(A|B)表示在事件B发生的条件下事件A发生的概率,同理P(B|A)则表示在事件A发…

2023-01-10 clickhouse-聚合函数的源码再梳理

https://cloud.tencent.com/developer/article/1815441 1.IAggregateFunction接口梳理 话不多说&#xff0c;直接上代码&#xff0c;笔者这里会将所有聚合函数的核心接口代码全部列出&#xff0c;一一梳理各个部分&#xff1a; 构造函数 IAggregateFunction(const DataTypes …

Android设置本地字体文件ttf

目录 前言 ①使用typeface 方式 一、创建加载字体实例 二、使用步骤 1.在Application中加载字体 2.在xml中使用 ②使用fontFamily 方式 1、在res/font下导入ttf文件 2、在xml中使用 总结 前言 产品告诉UI设计设计图时要使用炫酷字体。因为Android不像网页项目可以使用…

如何使用Jasper导出用户列表数据?

场景说明在使用JasperjaspersoftStudio导出用户列表数据导出(如下图)是比较简单的&#xff0c;就是把用户列表数据&#xff0c;一个List集合放到 JRBeanCollectionDataSource中即可。但是如果有多个List集合需要导出呢&#xff0c;这个应该怎么办?比如&#xff1a;一个用户的集…

用Python发邮件(附完整源代码)

目录 一、背景 1.1、前言 1.2、说明 二、SMTP协议 2.1、SMTP协议作用 2.2、SSL作用 三、步骤 3.1、开启QQ邮箱SMTP 四、代码 4.1、完整源代码 五、结果 5.1、代码运行结果 六、总结 6.1、总结 一、背景 1.1、前言 写了一个简陋的2023年12306自动化购票程序&…