Playing Flappy Bird with Deep Reinforcement Learning

Contents

Demo Video

Core Code


Demo Video

Video: Playing Flappy Bird with Deep Reinforcement Learning

Core Code

The code comes in four parts: the Q-network (src/deep_q_network.py), the game environment (src/flappy_bird.py), a test script, and a training script.

The Q-network (src/deep_q_network.py):

import torch.nn as nn

class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

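        # Three conv layers over a stack of 4 input frames (84x84 after preprocessing)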
        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
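        # One Q-value per action: index 0 = do nothing, index 1 = flap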
        self.fc2 = nn.Linear(512, 2)
        self._create_weights()

    def _create_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.uniform_(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)

        return output
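
The game environment (src/flappy_bird.py):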
from itertools import cycle
from numpy.random import randint
from pygame import Rect, init, time, display
from pygame.event import pump
from pygame.image import load
from pygame.surfarray import array3d, pixels_alpha
from pygame.transform import rotate
import numpy as np


class FlappyBird(object):
    init()
    fps_clock = time.Clock()
    screen_width = 288
    screen_height = 512
    screen = display.set_mode((screen_width, screen_height))
    display.set_caption('Deep Q-Network Flappy Bird')
    base_image = load('assets/sprites/base.png').convert_alpha()
    background_image = load('assets/sprites/background-black.png').convert()

    pipe_images = [rotate(load('assets/sprites/pipe-green.png').convert_alpha(), 180),
                   load('assets/sprites/pipe-green.png').convert_alpha()]
    bird_images = [load('assets/sprites/redbird-upflap.png').convert_alpha(),
                   load('assets/sprites/redbird-midflap.png').convert_alpha(),
                   load('assets/sprites/redbird-downflap.png').convert_alpha()]
    # number_images = [load('assets/sprites/{}.png'.format(i)).convert_alpha() for i in range(10)]

    bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]
    pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]

    fps = 30
    pipe_gap_size = 100
    pipe_velocity_x = -4

    # parameters for bird
    min_velocity_y = -8
    max_velocity_y = 10
    downward_speed = 1
    upward_speed = -9

    bird_index_generator = cycle([0, 1, 2, 1])

    def __init__(self):

        self.iter = self.bird_index = self.score = 0

        self.bird_width = self.bird_images[0].get_width()
        self.bird_height = self.bird_images[0].get_height()
        self.pipe_width = self.pipe_images[0].get_width()
        self.pipe_height = self.pipe_images[0].get_height()

        self.bird_x = int(self.screen_width / 5)
        self.bird_y = int((self.screen_height - self.bird_height) / 2)

        self.base_x = 0
        self.base_y = self.screen_height * 0.79
        self.base_shift = self.base_image.get_width() - self.background_image.get_width()

        pipes = [self.generate_pipe(), self.generate_pipe()]
        pipes[0]["x_upper"] = pipes[0]["x_lower"] = self.screen_width
        pipes[1]["x_upper"] = pipes[1]["x_lower"] = self.screen_width * 1.5
        self.pipes = pipes

        self.current_velocity_y = 0
        self.is_flapped = False

    def generate_pipe(self):
        x = self.screen_width + 10
        gap_y = randint(2, 10) * 10 + int(self.base_y / 5)
        return {"x_upper": x, "y_upper": gap_y - self.pipe_height, "x_lower": x, "y_lower": gap_y + self.pipe_gap_size}

    def is_collided(self):
        # Check if the bird touches the ground
        if self.bird_height + self.bird_y + 1 >= self.base_y:
            return True
        bird_bbox = Rect(self.bird_x, self.bird_y, self.bird_width, self.bird_height)
        pipe_boxes = []
        for pipe in self.pipes:
            pipe_boxes.append(Rect(pipe["x_upper"], pipe["y_upper"], self.pipe_width, self.pipe_height))
            pipe_boxes.append(Rect(pipe["x_lower"], pipe["y_lower"], self.pipe_width, self.pipe_height))
            # Check whether the bird's bounding box overlaps with any pipe's bounding box
            if bird_bbox.collidelist(pipe_boxes) == -1:
                return False
            for i in range(2):
                cropped_bbox = bird_bbox.clip(pipe_boxes[i])
                min_x1 = cropped_bbox.x - bird_bbox.x
                min_y1 = cropped_bbox.y - bird_bbox.y
                min_x2 = cropped_bbox.x - pipe_boxes[i].x
                min_y2 = cropped_bbox.y - pipe_boxes[i].y
                # Pixel-level check: AND the bird and pipe alpha masks over the overlapping region
                bird_mask = self.bird_hitmask[self.bird_index][min_x1:min_x1 + cropped_bbox.width,
                                                               min_y1:min_y1 + cropped_bbox.height]
                pipe_mask = self.pipe_hitmask[i][min_x2:min_x2 + cropped_bbox.width,
                                                 min_y2:min_y2 + cropped_bbox.height]
                if np.any(bird_mask * pipe_mask):
                    return True
        return False

    def next_frame(self, action):
        pump()
        reward = 0.1
        terminal = False
        # Check input action
        if action == 1:
            self.current_velocity_y = self.upward_speed
            self.is_flapped = True

        # Update score
        bird_center_x = self.bird_x + self.bird_width / 2
        for pipe in self.pipes:
            pipe_center_x = pipe["x_upper"] + self.pipe_width / 2
            if pipe_center_x < bird_center_x < pipe_center_x + 5:
                self.score += 1
                reward = 1
                break

        # Update animation index and frame counter (advance the wing-flap animation every 3 frames)
        if (self.iter + 1) % 3 == 0:
            self.bird_index = next(self.bird_index_generator)
        self.iter = (self.iter + 1) % self.fps
        self.base_x = -((-self.base_x + 100) % self.base_shift)

        # Update bird's position
        if self.current_velocity_y < self.max_velocity_y and not self.is_flapped:
            self.current_velocity_y += self.downward_speed
        if self.is_flapped:
            self.is_flapped = False
        self.bird_y += min(self.current_velocity_y, self.bird_y - self.current_velocity_y - self.bird_height)
        if self.bird_y < 0:
            self.bird_y = 0

        # Update pipes' position
        for pipe in self.pipes:
            pipe["x_upper"] += self.pipe_velocity_x
            pipe["x_lower"] += self.pipe_velocity_x
        # Update pipes
        if 0 < self.pipes[0]["x_lower"] < 5:
            self.pipes.append(self.generate_pipe())
        if self.pipes[0]["x_lower"] < -self.pipe_width:
            del self.pipes[0]
        if self.is_collided():
            terminal = True
            reward = -1
            self.__init__()

        # Draw everything
        self.screen.blit(self.background_image, (0, 0))
        self.screen.blit(self.base_image, (self.base_x, self.base_y))
        self.screen.blit(self.bird_images[self.bird_index], (self.bird_x, self.bird_y))
        for pipe in self.pipes:
            self.screen.blit(self.pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))
            self.screen.blit(self.pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))

        image = array3d(display.get_surface())
        display.update()
        self.fps_clock.tick(self.fps)
        return image, reward, terminal
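
The test script: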
import argparse
import torch

from src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processing


def get_args():
    parser = argparse.ArgumentParser(
        description="Implementation of Deep Q Network to play Flappy Bird")
    parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
    parser.add_argument("--saved_path", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def q_test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if torch.cuda.is_available():
        model = torch.load("{}/flappy_bird".format(opt.saved_path))
    else:
        model = torch.load("{}/flappy_bird".format(opt.saved_path), map_location=lambda storage, loc: storage)
    model.eval()
    game_state = FlappyBird()
    image, reward, terminal = game_state.next_frame(0)
    image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
    image = torch.from_numpy(image)
    if torch.cuda.is_available():
        model.cuda()
        image = image.cuda()
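    # Build the initial state: the first frame repeated 4 times -> shape (1, 4, 84, 84)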
    state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]

    while True:
        prediction = model(state)[0]
        action = torch.argmax(prediction)

        next_image, reward, terminal = game_state.next_frame(action)
        next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
                                    opt.image_size)
        next_image = torch.from_numpy(next_image)
        if torch.cuda.is_available():
            next_image = next_image.cuda()
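        # Slide the frame window: drop the oldest of the 4 frames, append the newest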
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]

        state = next_state


if __name__ == "__main__":
    opt = get_args()
    q_test(opt)

The training script:

import argparse
import os
import shutil
from random import random, randint, sample

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # tensorboardX's SummaryWriter works too

from src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processing


def get_args():
    parser = argparse.ArgumentParser(
        description="Implementation of Deep Q Network to play Flappy Bird")
    parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")
    parser.add_argument("--batch_size", type=int, default=32, help="The number of images per batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=0.1)
    parser.add_argument("--final_epsilon", type=float, default=1e-4)
    parser.add_argument("--num_iters", type=int, default=2000000)
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Number of epoches between testing phases")
    parser.add_argument("--log_path", type=str, default="tensorboard")
    parser.add_argument("--saved_path", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    model = DeepQNetwork()
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    writer = SummaryWriter(opt.log_path)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.MSELoss()
    game_state = FlappyBird()
    image, reward, terminal = game_state.next_frame(0)
    image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)
    image = torch.from_numpy(image)
    if torch.cuda.is_available():
        model.cuda()
        image = image.cuda()
    state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]

    replay_memory = []
    iter = 0
    while iter < opt.num_iters:
        prediction = model(state)[0]
        # Exploration or exploitation: epsilon is annealed linearly from initial_epsilon to final_epsilon
        epsilon = opt.final_epsilon + (
                (opt.num_iters - iter) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)
        u = random()
        random_action = u <= epsilon
        if random_action:
            print("Perform a random action")
            action = randint(0, 1)
        else:
            action = torch.argmax(prediction)

        next_image, reward, terminal = game_state.next_frame(action)
        next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
                                    opt.image_size)
        next_image = torch.from_numpy(next_image)
        if torch.cuda.is_available():
            next_image = next_image.cuda()
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, terminal])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)

        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.from_numpy(
            np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state for state in next_state_batch))

        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            action_batch = action_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

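        # TD target: r for terminal transitions, r + gamma * max_a' Q(s', a') otherwise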
        y_batch = torch.cat(
            tuple(reward if terminal else reward + opt.gamma * torch.max(prediction) for reward, terminal, prediction in
                  zip(reward_batch, terminal_batch, next_prediction_batch)))

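        # Q(s, a) of the action actually taken, selected via the one-hot action encoding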
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        # Detach the TD target so gradients do not flow through the next-state predictions
        y_batch = y_batch.detach()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        iter += 1
        print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
            iter,
            opt.num_iters,
            action,
            loss,
            epsilon, reward, torch.max(prediction)))
        writer.add_scalar('Train/Loss', loss, iter)
        writer.add_scalar('Train/Epsilon', epsilon, iter)
        writer.add_scalar('Train/Reward', reward, iter)
        writer.add_scalar('Train/Q-value', torch.max(prediction), iter)
        if iter % 1000000 == 0:
            torch.save(model, "{}/flappy_bird_{}".format(opt.saved_path, iter))
    torch.save(model, "{}/flappy_bird".format(opt.saved_path))


if __name__ == "__main__":
    opt = get_args()
    train(opt)
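
Both scripts import pre_processing from src/utils.py, which the post never shows. A minimal sketch, assuming the usual OpenCV pipeline for this kind of DQN input (resize to image_size x image_size, grayscale, binary threshold, channel-first float32); the exact threshold value is an assumption:

import cv2
import numpy as np


def pre_processing(image, width, height):
    # Resize the raw frame and collapse it to a single grayscale channel
    image = cv2.cvtColor(cv2.resize(image, (width, height)), cv2.COLOR_BGR2GRAY)
    # Binarize so sprites stand out from the black background (threshold value assumed)
    _, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
    # Add a leading channel axis so four frames can be stacked into a (4, H, W) state
    return image[None, :, :].astype(np.float32)

With the files laid out this way, training runs with python train.py and evaluation with python test.py (assuming the test script above is saved as test.py).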
