强化学习玩flappy_bird

news2024/12/23 22:32:32

强化学习玩flappy_bird(代码解析)

游戏地址:https://flappybird.io/

该游戏的规则是:

  • 点击屏幕则小鸟立即获得向上速度。

  • 不点击屏幕则小鸟受重力加速度影响逐渐掉落。

  • 小鸟碰到地面会死亡,碰到水管会死亡。(碰到天花板不会死亡)

  • 小鸟通过水管会得分。

    img

    具体的网络结构如图所示,网络架构是拿到游戏状态(每个样本维度是 80 * 80 * 4),然后卷积(输出维度 20 * 20 * 32)、池化(输出 10 * 10 * 32)、卷积(输出 5 * 5 * 64)、卷积(输出 5 * 5 * 64)、reshape(1600)、全连接层(512)、输出层(2)

一、flappy_bird_utils.py

"""
游戏素材加载
"""
import pygame
import sys
import os

assets_dir = os.path.dirname(__file__)

def load():
    # 小鸟挥动翅膀的3个造型
    PLAYER_PATH = (
        assets_dir + '/assets/sprites/redbird-upflap.png',
        assets_dir + '/assets/sprites/redbird-midflap.png',
        assets_dir + '/assets/sprites/redbird-downflap.png'
    )

    # 游戏背景图,纯黑色是为了训练降低干扰
    BACKGROUND_PATH = assets_dir + '/assets/sprites/background-black.png'

    # 水管图片
    PIPE_PATH = assets_dir + '/assets/sprites/pipe-green.png'

    IMAGES, SOUNDS, HITMASKS = {}, {}, {}
    #初始化了三个空字典:IMAGES用于存储加载的图片资源,SOUNDS用于存储加载的声音资源,HITMASKS用于存储碰撞掩码(用于检测游戏中的碰撞)

    # 加载数字0~9的图片,类型是Surface图像
    #使用convert_alpha()方法将图片转换为带有透明度的格式
    IMAGES['numbers'] = (
        pygame.image.load(assets_dir + '/assets/sprites/0.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/1.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/2.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/3.png').convert_alpha(),    # convert/conver_alpha是为了将图片转成绘制用的像素格式,提高绘制效率
        pygame.image.load(assets_dir + '/assets/sprites/4.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/5.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/6.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/7.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/8.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/9.png').convert_alpha()
    )

    # 地面图片
    IMAGES['base'] = pygame.image.load(assets_dir + '/assets/sprites/base.png').convert_alpha()


    #根据操作系统类型,设置声音文件的扩展名。Windows系统使用.wav,其他系统使用.ogg
    if 'win' in sys.platform:
        soundExt = '.wav'
    else:
        soundExt = '.ogg'

    # 各种Sound对象
    #加载各种游戏音效,并将它们存储在SOUNDS字典中
    SOUNDS['die']    = pygame.mixer.Sound(assets_dir + '/assets/audio/die' + soundExt)
    SOUNDS['hit']    = pygame.mixer.Sound(assets_dir + '/assets/audio/hit' + soundExt)
    SOUNDS['point']  = pygame.mixer.Sound(assets_dir + '/assets/audio/point' + soundExt)
    SOUNDS['swoosh'] = pygame.mixer.Sound(assets_dir + '/assets/audio/swoosh' + soundExt)
    SOUNDS['wing']   = pygame.mixer.Sound(assets_dir + '/assets/audio/wing' + soundExt)

    # 加载背景图片
    IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()

    # 加载小鸟的3个姿态
    IMAGES['player'] = (
        pygame.image.load(PLAYER_PATH[0]).convert_alpha(),
        pygame.image.load(PLAYER_PATH[1]).convert_alpha(),
        pygame.image.load(PLAYER_PATH[2]).convert_alpha(),
    )

    # 加载水管图片,并使用rotate()方法将其旋转180度以创建上方的水管图片,然后将这两个图片存储在IMAGES字典中
    IMAGES['pipe'] = (
        pygame.transform.rotate(
            pygame.image.load(PIPE_PATH).convert_alpha(), 180),
        pygame.image.load(PIPE_PATH).convert_alpha(),
    )

    # 计算水管图片的bool掩码
    #为水管图片生成碰撞掩码,并将它们存储在HITMASKS字典中。
    HITMASKS['pipe'] = (
        getHitmask(IMAGES['pipe'][0]),
        getHitmask(IMAGES['pipe'][1]),
    )

    # 生成小鸟图片的bool掩码
    HITMASKS['player'] = (
        getHitmask(IMAGES['player'][0]),
        getHitmask(IMAGES['player'][1]),
        getHitmask(IMAGES['player'][2]),
    )

    return IMAGES, SOUNDS, HITMASKS

# 生成图片的bool掩码矩阵,true表示对应像素位置不是透明的部分
def getHitmask(image):
    """returns a hitmask using an image's alpha."""
    mask = []
    for x in range(image.get_width()):#遍历所有的像素点
        mask.append([])
        for y in range(image.get_height()):
            mask[x].append(bool(image.get_at((x,y))[3]))    # 像素点是RGBA,例如:(83, 56, 70, 255),最后是透明度(0是透明,255是不透明)
    #对于图像中的每一个像素点,使用 image.get_at((x,y)) 获取该点的颜色值。颜色值通常以 RGBA(红色、绿色、蓝色、透明度)格式存储,其中 A 代表 Alpha 通道,即透明度。image.get_at((x,y))[3] 就是获取该像素点的 Alpha 值。
    return mask

这里面需要解释的碰撞掩码是什么?

碰撞掩码(Collision Mask)是一种在计算机图形学和游戏开发中用于检测物体间碰撞的技术。它通常由一个布尔矩阵表示,其中每个像素点的值表示该点是否是物体的一部分。在处理碰撞检测时,通过比较两个物体的碰撞掩码可以判断它们是否重叠,从而确定是否发生了碰撞。

以下是碰撞掩码的一些关键点:

  1. 透明度判断:在许多游戏中,碰撞掩码是通过检查图像的透明度(Alpha通道)来生成的。如果图像的某个像素点是不透明的(例如,Alpha值为255),那么在碰撞掩码中对应的位置会被标记为True或实体部分;如果是透明的(Alpha值为0),则被标记为False或非实体部分。

  2. 简化碰撞检测:使用碰撞掩码可以避免直接对图像的每个像素点进行碰撞检测,这样可以显著提高碰撞检测的效率,尤其是在处理复杂图形或大规模场景时。

  3. 灵活性:碰撞掩码可以根据需要设计成不同的形状和大小,从而实现精确的碰撞检测。例如,一个角色的碰撞掩码可以是其轮廓的形状,而不仅仅是一个矩形或正方形。

  4. 性能优化:在游戏开发中,碰撞检测通常是一个计算密集型的过程。通过使用碰撞掩码,可以减少不必要的像素点比较,从而提高游戏性能。

  5. 应用场景:碰撞掩码不仅用于检测角色与障碍物之间的碰撞,还可以用于检测子弹与目标的碰撞、角色间的交互等。

getHitmask函数就是用来生成碰撞掩码的。它通过遍历图像的每个像素点,并检查其透明度来创建一个布尔矩阵。这个矩阵随后可以用于游戏中的碰撞检测逻辑,以判断小鸟是否与水管或其他物体发生了碰撞。

二、wrapped_flappy_bird.py

import numpy as np
import sys
import random
import pygame
from . import flappy_bird_utils
import pygame.surfarray as surfarray#用于将pygame的Surface对象转换为NumPy数组
from pygame.locals import *
from itertools import cycle#用于创建一个可循环的对象

# 屏幕宽*高
FPS = 30
SCREENWIDTH  = 288
SCREENHEIGHT = 512

# 初始化游戏,创建一个时钟对象来控制帧率,设置游戏窗口的尺寸和标题
pygame.init()
FPSCLOCK = pygame.time.Clock()  # FPS限速器
SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))   # 宽*高
pygame.display.set_caption('Flappy Bird')   # 标题

# 加载素材
IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()

PIPEGAPSIZE = 100 # 上下水管之间的距离是固定的100像素
BASEY = SCREENHEIGHT * 0.79 # 地面图片的y坐标

'''
地面图片在游戏窗口中的垂直位置。SCREENHEIGHT是游戏窗口的高度,
乘以0.79后得到一个值,这个值就是地面图片距离窗口顶部的像素距离。
因此,BASEY变量代表了地面图片在Y轴(垂直轴)上的位置,
它被设置在屏幕高度的79%的位置,这样地面图片会显示在屏幕的下半部分。
'''

# 小鸟图片的宽*高
PLAYER_WIDTH = IMAGES['player'][0].get_width()
PLAYER_HEIGHT = IMAGES['player'][0].get_height()
# 水管图片的宽*高
PIPE_WIDTH = IMAGES['pipe'][0].get_width()
PIPE_HEIGHT = IMAGES['pipe'][0].get_height()

# 背景图片的宽
BACKGROUND_WIDTH = IMAGES['background'].get_width()

# 创建一个循环对象,小鸟图片动画播放顺序
PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])

'''
0 表示第一张图片(翅膀上挥)
1 表示第二张图片(翅膀中挥)
2 表示第三张图片(翅膀下挥)
序列最后再次包含1,以实现翅膀的自然循环
'''

# Flappy bird游戏类
class GameState:
    def __init__(self):
        self.score = 0#初始化玩家的得分为 0
        self.playerIndex = 0#初始化玩家小鸟的当前动画索引为 0,这将决定小鸟显示哪一张动画图片
        self.loopIter = 0#初始化一个循环计数器,可能用于跟踪动画或游戏循环的次数

        # 玩家初始坐标
        self.playerx = int(SCREENWIDTH * 0.2)#设置玩家小鸟的初始 x 坐标,位于屏幕宽度的 20% 位置
        self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)#计算并设置玩家小鸟的初始 y 坐标,使得小鸟位于屏幕垂直居中的位置

        # 地面图片需要跑马灯效果,它比屏幕宽一点,每帧向左移动,当要耗尽时重新回到右边,如此往复
        self.basex = 0         # 地面图片的x坐标
        self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH  # 地面图片比屏幕宽度长多少像素,就是它可以移动的距离

        newPipe1 = getRandomPipe()  # 生成一对上下管子
        newPipe2 = getRandomPipe()  # 再生成一对上下管子

        # 上面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离
        self.upperPipes = [
            {'x': SCREENWIDTH, 'y': newPipe1[0]['y']},
            {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},
        ]
        # 下面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离
        self.lowerPipes = [
            {'x': SCREENWIDTH, 'y': newPipe1[1]['y']},
            {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},
        ]

        # 水管的水平移动速度,每次x-4实现向左移动
        self.pipeVelX = -4

        # 小鸟Y方向速度
        self.playerVelY    =  0
        # 小鸟Y方向重力加速度,每帧作用域playerVelY,令其Y速度向下加大
        self.playerAccY    =   1
        # 点击后,小鸟Y方向速度重置为-9,也就是开始向上移动
        self.playerFlapAcc =  -9

        # 小鸟Y方向速度限制
        self.playerMaxVelY =  10   # Y向下最大速度10

    # 执行一次操作,返回操作后的画面、本次操作的奖励(活着+0.1,死了-1,飞过水管+1)、游戏是否结束
    def frame_step(self, input_actions):
        # 给pygame对积累的事件做一下默认处理
        pygame.event.pump()

        # 活着就奖励0.1分
        reward = 0.01
        # 是否死了
        terminal = False

        # 必须传有效的action,[1,0]表示不点击,[0,1]表示点击,全传0是不对的
        if sum(input_actions) != 1:#检查 input_actions 确保只有一个动作被执行
            raise ValueError('Multiple input actions!')

        # 每3帧换一次小鸟造型图片,loopIter统计经过了多少帧
        if (self.loopIter + 1) % 3 == 0:
            self.playerIndex = next(PLAYER_INDEX_GEN)
        self.loopIter += 1

        # 让地面向左移动,游戏开始的时候地面x=0,逐步减小x
        if self.basex + self.pipeVelX <= -self.baseShift:
            self.basex = 0
        else: # 图片即将滚动耗尽,重置x坐标
            self.basex += self.pipeVelX

        # 点击了屏幕
        if input_actions[1] == 1:
            self.playerVelY = self.playerFlapAcc # 将小鸟y方向速度重置为-9,也就是向上移动
            #SOUNDS['wing'].play()   # 播放扇翅膀的声音
        elif self.playerVelY < self.playerMaxVelY:  # 没点击屏幕并且没达到最大掉落速度,继续施加重力加速度
            self.playerVelY += self.playerAccY

        # 将速度施加到小鸟的y坐标上
        self.playery += self.playerVelY
        if self.playery < 0:    # 撞到上边缘不算死
            self.playery = 0 # 限制它别飞出去
        elif self.playery + PLAYER_HEIGHT >= BASEY: # 小鸟碰到地面
            self.playery = BASEY - PLAYER_HEIGHT # 限制它别穿地

        # 让上下水管都向左移动一次
        for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
            uPipe['x'] += self.pipeVelX
            lPipe['x'] += self.pipeVelX

        # 判断小鸟是否穿过了一排水管,因为上下水管x一样,只需要用上排水管判断
        playerMidPos = self.playerx + PLAYER_WIDTH / 2  # 小鸟中心的x坐标(这个是固定值,小鸟实际不会动,是水管在动)
        for pipe in self.upperPipes:    # 检查与上排水管的关系
            pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 # 水管中心的x坐标
            if pipeMidPos <= playerMidPos < pipeMidPos + abs(self.pipeVelX): # 小鸟x坐标刚刚飞过了水管x中心(4是水管的移动速度)
                self.score += 1 # 游戏得分+1
                #SOUNDS['point'].play()
                reward = 100  # 产生强化学习的动作奖励10分

        # 最左侧水管马上离开屏幕,生成新水管
        if 0 < self.upperPipes[0]['x'] < 5:
            newPipe = getRandomPipe()
            self.upperPipes.append(newPipe[0])
            self.lowerPipes.append(newPipe[1])

        # 最左侧水管彻底离开屏幕,删除它的上下2根水管
        if self.upperPipes[0]['x'] < -PIPE_WIDTH:
            self.upperPipes.pop(0)
            self.lowerPipes.pop(0)

        # 检查小鸟是否碰到水管
        isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 'index': self.playerIndex}, self.upperPipes, self.lowerPipes)
        if isCrash:  # 死掉了
            #SOUNDS['hit'].play()
            #SOUNDS['die'].play()
            reward = -10 # 负向激励分
            terminal = True # 本次操作导致游戏结束了

        ##### 进入重绘 #######

        # 贴背景图
        SCREEN.blit(IMAGES['background'], (0,0))
        # 画水管
        for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
            SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))
            SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))
        # 画地面
        SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
        # 画得分(训练时候别打开,造成干扰了)
        #showScore(self.score)
        # 画小鸟
        SCREEN.blit(IMAGES['player'][self.playerIndex], (self.playerx, self.playery))
        # 重绘
        pygame.display.update()
        # 留存游戏画面(截图是列优先存储的,需要转行行优先存储)
        # https://stackoverflow.com/questions/34673424/how-to-get-numpy-array-of-rgb-colors-from-pygame-surface
        image_data = pygame.surfarray.array3d(pygame.display.get_surface()).swapaxes(0,1)
        # 死亡则重置游戏状态
        if terminal:
            self.__init__()
        # 控制FPS
        FPSCLOCK.tick(FPS)
        return image_data, reward, terminal

# 生成一对水管,放到屏幕外面
def getRandomPipe():
    gapY = random.randint(70, 140)#生成一个介于 70 到 140 之间的随机整数,并将其赋值给变量 gapY。这个随机数决定了水管之间缝隙的上边缘的 y 坐标

    # 注:每一对水管的缝隙高度都是一样的PIPEGAPSIZE,gayY决定的是缝隙的上边缘y坐标
    pipeX = SCREENWIDTH + 10    # 水管出现在屏幕右侧之外

    return [
        {'x': pipeX, 'y': gapY - PIPE_HEIGHT},  # 计算上面水管图片的y坐标,就是缝隙上边缘y减去水管本身高度
        {'x': pipeX, 'y': gapY + PIPEGAPSIZE},  # 计算下面水管图片的y坐标,就是缝隙上边缘y加上缝隙本身高度
    ]

# 检查小鸟是否碰到水管或者地面(天花板不算)
def checkCrash(player, upperPipes, lowerPipes):
    pi = player['index']    # 小鸟的第几张图片

    # 图片的宽*高
    player['w'] = IMAGES['player'][pi].get_width()
    player['h'] = IMAGES['player'][pi].get_height()

    # 小鸟碰到了地面
    if player['y'] + player['h'] >= BASEY - 1:
        return True
    else: # 小鸟与水管进行碰撞检测
        # 小鸟图片的矩形区域
        playerRect = pygame.Rect(player['x'], player['y'], player['w'], player['h'])

        # 每一对水管
        for uPipe, lPipe in zip(upperPipes, lowerPipes):
            # 上面水管的矩形
            uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
            # 下面水管的矩形
            lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)

            # 小鸟图片的非透明像素掩码
            pHitMask = HITMASKS['player'][pi]
            # 上水管的非透明像素掩码
            uHitmask = HITMASKS['pipe'][0]
            # 下水管的非透明像素掩码
            lHitmask = HITMASKS['pipe'][1]

            # 检测小鸟与上面水管的碰撞
            uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)
            # 检测小鸟与下面水管的碰撞
            lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)

            if uCollide or lCollide:
                return True
    return False


# 2个矩形区域的碰撞检测
def pixelCollision(rect1, rect2, hitmask1, hitmask2):
    '''
    rect1 和 rect2 是参与碰撞检测的两个矩形区域,通常是游戏中对象的位置和大小
    hitmask1 和 hitmask2 是与这两个矩形关联的碰撞掩码,它们是布尔数组,表示相应对象的哪些部分是实体(非透明)
    '''
    # 计算两个矩形的交集,即它们重叠的区域。如果没有重叠(即两个矩形没有碰撞),则 clip 方法返回一个宽度或高度为 0 的矩形
    rect = rect1.clip(rect2)

    # 相交面积为0
    if rect.width == 0 or rect.height == 0:
        return False

    # 相交矩形x,y相对于2个矩形左上角的距离
    x1, y1 = rect.x - rect1.x, rect.y - rect1.y
    #计算交集区域相对于 rect1 的相对位置
    x2, y2 = rect.x - rect2.x, rect.y - rect2.y#同理

    # 检查相交矩形内的每个点,是否在2个矩形内同时是非透明点,那么就碰撞了
    for x in range(rect.width):
        for y in range(rect.height):
            if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:
                return True
    return False

# 展示得分,传入一个整数得分
def showScore(score):
    # 转成单个数字的列表
    scoreDigits = [int(x) for x in list(str(score))]
    #将得分 score 转换成字符串,然后将其每个字符(即每个单独的数字)转换成整数,并存储在列表 scoreDigits 中。这样,得分就被分解成了单个数字的列表

    # 计算展示所有数字要占多少像素宽度
    totalWidth = 0
    for digit in scoreDigits:
        totalWidth += IMAGES['numbers'][digit].get_width()
        '''
        遍历 scoreDigits 列表中的每个数字
        将每个数字图像的宽度累加到 totalWidth
        '''

    # 计算绘制起始x坐标
    Xoffset = (SCREENWIDTH - totalWidth) / 2

    # 逐个数字绘制
    for digit in scoreDigits:
        SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, 20))    # y坐标贴近屏幕上边缘
        Xoffset += IMAGES['numbers'][digit].get_width() # 移动绘制x坐标

三、q_game.py

"""
强化学习q learning flappy bird
"""
from game.wrapped_flappy_bird import GameState
import time
import numpy as np 
import skimage.color
import skimage.transform
import skimage.exposure
import tensorflow as tf 
import random 
import argparse

# 命令行参数
parser = argparse.ArgumentParser()#创建一个 ArgumentParser 对象,用于定义命令行参数
parser.add_argument("--model-only", help="加载已有模型,不随机探索,仍旧训练", action='store_true')
args = parser.parse_args()#解析命令行输入的参数,并将它们存储在 args 变量中

# 测试用代码
def _test_save_img(img):
    # 把每一帧图片存储到文件里,调试用
    from PIL import Image
    im = Image.fromarray((img*255).astype(np.uint8), mode='L') # 图片已经被处理为0~1之间的亮度值,所以*255取整数变灰度展示
    im.save('./img.jpg')

# 构建卷积神经网络
def build_model():
    # 卷积神经网络:https://blog.csdn.net/FontThrone/article/details/76652753
    model = tf.keras.models.Sequential([#创建一个 Sequential 模型,它是 tf.keras 中用于线性堆叠网络层的模型类
        tf.keras.layers.Input(shape=(80,80,4)),
        tf.keras.layers.Conv2D(filters=32, kernel_size=(8, 8), padding='same',strides=4, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(4, 4), padding='same',strides=2, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same',strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Flatten(),#将三维的卷积层输出展平为一维,以便传入到全连接层
        tf.keras.layers.Dense(256, activation='relu'),#定义一个具有 256 个单元的全连接层,并使用 ReLU 激活函数
        tf.keras.layers.Dense(2), # 对应2个action未来总回报预期
    ])
    model.compile(loss='mse', optimizer='adam')#编译模型,指定均方误差(MSE)作为损失函数,使用 Adam 优化器

    # 尝试加载之前保存的模型参数
    try:
        model.load_weights('./weights.h5')
        print('加载模型成功...................')
    except:
        pass
    return model

# 创建游戏
game = GameState()
# 卷积模型
model = build_model()

# 执行1帧游戏
def run_one_frame(action):
    global game 
    # image_data:执行动作后的图像(288*512*3的RGB三维数组)
    # reward:本次动作的奖励
    # terminal:游戏是否失败
    img, reward, terminal = game.frame_step(action)
    # RGB转灰度图
    img = skimage.color.rgb2gray(img)
    # 压缩到80*80的图片(根据RGB算出来的亮度,其数值很小)
    img = skimage.transform.resize(img, (80,80))
    # 把亮度标准化到0~1之间,用作模型输入
    img = skimage.exposure.rescale_intensity(img, out_range=(0,1))
    return img,reward,terminal

# 强化学习初始化状态
def reset_stat():
    # 执行第一帧,不点击
    img_t,_,_ =  run_one_frame([1,0])
    '''
    使用 numpy.stack 函数将首帧图像 img_t 重复四次,沿着第三个维度堆叠,
    形成初始状态 stat_t。这是因为卷积神经网络需要连续几帧的图像作为输入
    '''
    stat_t = np.stack([img_t] * 4, axis=2)
    return stat_t 

# 初始状态
stat_t = reset_stat()
# 训练样本
transitions = []#用于存储训练过程中的状态转换样本

# 时刻
t = 0

# 随机探索的概率控制,定义了随机探索概率的初始值、最终值和每次更新的步长。
INIT_EPSILON = 0.1
FINAL_EPSILON = 0.005
EPSLION_DELTA = 1e-6
# 最大留存样本个数
TRANS_CAP =  20000
# 至少有多少样本才训练
TRANS_SIZE_FIT = 10000
# 训练集大小
BATCH_SIZE = 32
# 未来激励折扣
GAMMA = 0.99

# 随机探索概率
if args.model_only: # 不随机探索(极低概率)
    epsilon = FINAL_EPSILON
else:
    epsilon = INIT_EPSILON

# 打印一些进度信息
rand_flap =0    # 随机点击次数
rand_noflap = 0 # 随机不点击次数
model_flap=0    # 模型点击次数
model_noflap=0  # 模型不点击次数
model_train_times = 0   # 模型训练次数

# 游戏启动
while True:    
    # 动作
    action_t = [0,0]

    action_type = '随机'#设置动作类型默认为 '随机',这将在选择动作时用于判断动作是随机选择的还是基于模型经验选择的。

    # 随着学习,降低随机探索的概率,让模型趋于稳定
    if (t <= TRANS_SIZE_FIT and not args.model_only) or random.random() <= epsilon:
        '''判断是否应该进行随机探索。如果在观察期内(t <= TRANS_SIZE_FIT)或者随机数小于或等于 epsilon,则执行随机探索'''
        n = random.random()
        if n <= 0.95:
            action_index = 0
            rand_noflap+=1
        else:
            action_index = 1
            rand_flap+=1
        #print('[随机探索] t时刻进行随机动作探索...')
    else: # 模型预测2个操作的未来累计回报
        action_type = '经验'
        Q_t = model.predict(np.expand_dims(stat_t, axis=0))[0]
        #使用当前的模型和状态 stat_t 来预测两个动作的未来总回报
        action_index = np.argmax(Q_t)   # 回报最大的action下标
        if action_index==0:
            model_noflap+=1
        else:
            model_flap+=1
        #print('[已有经验] 预测t时刻2个动作的未来总回报 -- 不点击:{} 点击:{}'.format(Q_t[0], Q_t[1]))

    action_t[action_index] = 1
    #print('时刻t将执行的动作为{}'.format(action_t))

    # 执行当前动作,返回操作后的图片、本次激励、游戏是否结束
    img_t1, reward, terminal = run_one_frame(action_t)
    _test_save_img(img_t1)
    img_t1 = img_t1.reshape((80,80,1)) # 增加通道维度,因为我们要最近4帧作为4通道图片,用作卷积模型输入
    stat_t1 = np.append(stat_t[:,:,1:], img_t1, axis=2) # 80*80*4,淘汰当前的第0通道,添加最新t1时刻到第3通道

    # 收集训练样本(保留有限的)
    transitions.append({
        'stat_t': stat_t,   # t时刻状态
        'stat_t1': stat_t1, # t1时刻状态
        'reward': reward,   # 本次动作的激励得分
        'terminal': terminal,   # 执行动作后游戏是否结束(ps: 结束意味着没有未来激励了)
        'action_index': action_index,   # 执行了什么动作(0:不点击,1:点击)
    })
    if len(transitions) > TRANS_CAP:
        transitions.pop(0)
    
    # 游戏结束则重置stat_t
    if terminal:
        stat_t = reset_stat()
        #print('死了!!!!!!! 状态t重置为初始帧...')
    else:   # 否则切为新的状态
        stat_t = stat_t1
        #print('没死~~~ 状态t切换为状态t1...')

    # 过了观察期,开始训练
    if t >= TRANS_SIZE_FIT and t % 10 == 0:
        minibatch = random.sample(transitions, BATCH_SIZE)
        # 模型训练的输入:t时刻的状态(最近4帧图片)
        inputs_t = np.concatenate([tran['stat_t'].reshape((1,80,80,4)) for tran in minibatch])
        #print('inputs_t shape', inputs_t.shape)
        ######################################################
        # 模型训练的输出:t时刻的未来总激励(Q_t = reward+gamma*Q_t1)
        # 1,让模型预测t时刻2种action的未来总激励
        Q_t = model.predict(inputs_t, batch_size=len(minibatch))
        # 2,让模型预测t1时刻2种action的未来总激励
        input_t1 = np.concatenate([tran['stat_t1'].reshape((1,80,80,4)) for tran in minibatch])
        Q_t1 = model.predict(input_t1, batch_size=len(minibatch))
        # 3,保留t1时刻2个action中最大的未来总激励
        Q_t1_max = [max(q) for q in Q_t1]
        # 4,t时刻进行action_index动作得到真实激励
        reward_t = [tran['reward'] for tran in minibatch]
        # 5,t时刻进行了什么action
        action_index_t = [tran['action_index'] for tran in minibatch]
        # 6,t1时刻是否死掉了
        terminal = [tran['terminal'] for tran in minibatch]
        # 7,修正训练的目标Q_t=reward+gamma*Q_t1
        # (t时刻action_index的未来总激励=action_index真实激励+t1时刻预测的最大未来总激励)
        for i in range(len(minibatch)):
            if terminal[i]:
                Q_t[i][action_index_t[i]] = reward_t[i] # 因为t1时刻已经死了,所以没有t1之后的累计激励
            else:
                Q_t[i][action_index_t[i]] = reward_t[i] + GAMMA*Q_t1_max[i] # Q_t=reward+Q_t1
        # print('Q_t shape', Q_t.shape)
        # 训练一波
        #print(inputs_t)
        #print(Q_t)
        model.fit(inputs_t, Q_t, batch_size=len(minibatch))
        model_train_times += 1
        # 训练1次则降低些许的随机探索概率
        if epsilon > FINAL_EPSILON:
            epsilon -= EPSLION_DELTA
        
        # 每5000次batch保存一次模型权重(不适用saved_model,后续加载只会加载权重,模型结构还是程序构造,因为这样可以保持keras model的api)
        if model_train_times % 5000 == 0:
            model.save_weights('./weights.h5')

        ######################################################
    if t % 100 == 0:
        print('总帧数:{} 剩余探索概率:{}% 累计训练次数:{} 累计随机点:{} 累计随机不点:{} 累计模型点:{} 累计模型不点:{} 训练集:{} '.format(
            t, round(epsilon * 100, 4), model_train_times, rand_flap, rand_noflap, model_flap, model_noflap,
            len(transitions)))
    t = t + 1
    #time.sleep(1)

四、text_game.py

"""
演示pygame制作的flappy bird如何逐帧调用执行
"""
from game.wrapped_flappy_bird import GameState
from random import random
import time

# 创建游戏
game = GameState()

# 游戏启动
while True:
    r = random()
    if r <= 0.92:  # 92%的概率不点击屏幕
        game.frame_step([1,0]) # 动作:[1,0] 表示不点击
    else: # 8%的概率点击屏幕
        game.frame_step([0,1]) # 动作:[0,1] 表示点击

五、训练结果

请添加图片描述

代码源自:强化学习Deep Q-Network自动玩flappy bird | 鱼儿的博客 (yuerblog.cc)

仅想具体看一下工作原理和代码,仅供学习使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1646271.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vscode连接服务器的docker步骤

进入容器之后&#xff0c;操作方式与本地windows系统操作逻辑一样&#xff1b;容器内部结构都能任意查看和使用&#xff0c;创建文件及编写python脚本都可以直接使用vs code编辑器进行编辑和调试&#xff0c;从而避免使用命令行及vim编辑文件&#xff0c;非常直观且方便~

4. RedHat认证-进程管理

4. RedHat认证-进程管理 1.进程概念 进程就是正在运行中的程序或者命令 每一个进程都是运行的实体&#xff0c;都有自己的地址空间&#xff0c;并占有一定的资源空间 程序消耗的是磁盘资源、进程消耗的是内存和CPU资源 进程会占用四类资源&#xff08;CPU 、内存、磁盘、网…

python爬虫(一)之 抓取极氪网站汽车文章

极氪汽车文章爬虫 闲来没事&#xff0c;将极氪网站的汽车文章吃干抹尽&#xff0c;全部抓取到本地&#xff0c;还是有点小小的难度。不能抓取太快&#xff0c;太快容易被封禁IP&#xff0c;不过就算被封了问题也不大&#xff0c;大不了重启路由器&#xff0c;然后你的IP里面又…

i.MX 6ULL 裸机 IAR 环境安装

一. IAR 的安装请自行搜索 二. 使用最新版本的 IAR&#xff0c;需要修改 SDK 1. 在 SDK 的 core_ca7.h 加上 #include "intrinsics.h" /* IAR Intrinsics */ 2. debug 时需要修改每个工程下的 ddr_init.jlinkscript&#xff0c;参考链接 Solved: How to conn…

使用C语言实现杨氏矩阵并找出数字

前言 过了五一假期&#xff0c;咋们经过了一个假期的休息&#xff0c;要继续学习了&#xff0c;不能偷懒哦&#xff01;&#xff01; 今天让我们来看看如何在一个杨氏矩阵中找出自己想找到的数字。 首先&#xff0c;我们要了解一下杨氏矩阵到底是什么&#xff0c;如果一个矩阵中…

[redis] redis为什么快

1. Redis与Memcached的区别 两者都是非关系型内存键值数据库&#xff0c;现在公司一般都是用 Redis 来实现缓存&#xff0c;而且 Redis 自身也越来越强大了&#xff01;Redis 与 Memcached 主要有以下不同&#xff1a; (1) memcached所有的值均是简单的字符串&#xff0c;red…

ACPWorkbench_for_BP10

一、菜单 文件菜单包含导入导出所有参数&#xff0c;导出flashbin文件和退出操作。文件菜单显示如下&#xff1a; Import Audio Settings&#xff1a;从音频配置文件中导入音频参数。 Export Audio Settings&#xff1a;将音频设置导出为音频配置文件。 Export Flash Binary Fi…

OpenNJet:下一代云原生应用引擎

OpenNJet&#xff1a;下一代云原生应用引擎 前言一、技术架构二、新增特性1. 透明流量劫持2. 熔断机制3. 遥测与故障注入 三、Ubuntu 发行版安装 OpentNJet1. 添加gpg 文件2. 添加APT 源3. 安装及启动4. 验证 总结 前言 OpenNJet&#xff0c;是一款基于强大的 NGINX 技术栈构建…

设置定位坐标+请按任意键继续

设置定位坐标 目的 在编程和游戏开发中&#xff0c;设置定位坐标的目的是为了确定对象在屏幕或游戏世界中的具体位置。坐标通常由一对数值表示&#xff0c;例如 (x, y)&#xff0c;其中 x 表示水平位置&#xff0c;y 表示垂直位置。设置定位坐标的目的包括&#xff1a; 1. **精…

【JavaScript】数据类型转换

JavaScript 中的数据类型转换主要包括两种&#xff1a;隐式类型转换&#xff08;Implicit Type Conversion&#xff09;和显式类型转换&#xff08;Explicit Type Conversion&#xff09;。 1. 隐式类型转换&#xff08;自动转换&#xff09;&#xff1a; js 是动态语言&…

CNN笔记详解

CNN(卷积神经网络) 计算机视觉&#xff0c;当你们听到这一概念的是否好奇计算机到底是怎样知道这个图片是什么的呢&#xff1f;为此提出了卷积神经网络&#xff0c;通过卷积神经网络&#xff0c;计算机就可以识别出图片中的特征&#xff0c;从而识别出图片中的物体。看到这里充…

XYCTF2024 RE ez unity 复现

dll依然有加壳 但是这次global-metadata.dat也加密了&#xff0c;原工具没办法用了&#xff0c;不过依然是可以修复的 a. 法一&#xff1a;frida-il2cpp-bridge 可以用frida-il2cpp-bridge GitHub - vfsfitvnm/frida-il2cpp-bridge: A Frida module to dump, trace or hijac…

深度剖析muduo网络库1.1---面试提问(阻塞、非阻塞、同步、异步)

在面试过程中&#xff0c;如果被问到关于IO的阻塞、非阻塞、同步、异步时&#xff0c;我们应该如何回答呢&#xff1f; 结合最近学习的课程&#xff0c;我作出了以下的总结&#xff0c;希望能与大家共同探讨&#xff01; 先给出 陈硕大神原话&#xff1a;在处理IO的时候&…

存储故障后oracle报—ORA-01122/ORA-01207故障处理---惜分飞

客户存储异常,通过硬件恢复解决存储故障之后,oracle数据库无法正常启动(存储cache丢失),尝试recover数据库报ORA-00283 ORA-01122 ORA-01110 ORA-01207错误 以前处理过比较类似的存储故障case:又一起存储故障导致ORA-00333 ORA-00312恢复存储故障,强制拉库报ORA-600 kcbzib_kcr…

计算机毕设

随着社会和国家的重视&#xff0c;大学对于大学生毕业设计越来越重视。 做软件设计设计方面&#xff0c;前后端分离是必不可少的&#xff0c;代码管理工具&#xff0c;前后端接口测试是项目中必须要用到的工具。做大数据设计方面&#xff0c;主要是要用到爬虫进行数据爬取&…

(二)JSP教程——taglib指令

创建标签文件 首先创建一个Web项目&#xff0c;在webapp/WEB-INF目录下创建一个tags文件夹 在tags文件夹中创建一个oddNumberSum.tag文件&#xff0c;Tag文件时扩展名为.tag的文本文件&#xff0c;其结构和JSP文件非常相似&#xff0c;该文件的目录结构如图所示 创建Tag文件的…

Altium Designer——检查原理图库正确性并生成报告

一、方法&#xff1a; 1.打开原理图库&#xff1a; 2.点击菜单栏的报告选项&#xff1a; 3.选择器件规则检查&#xff1a; 根据需求勾选&#xff0c;一般都是全部勾选&#xff1a; 二、问题&#xff1a; 1.缺少封装会导致什么问题&#xff1a; 1.首先&#xff1a; 封装是…

方法的入栈和出栈

一.作用域问题 1.全局作用域 在全局都能进行访问的变量 var a 10;function fn() {var b 20;return a b;}console.log(fn()); 2.局部的作用域 只能在限定的范围内进行访问 function fn() {var b 20;}console.log(b); b is not defined 打印的结果是b这个变量没用定义 3…

9.Admin后台系统

9. Admin后台系统 Admin后台系统也称为网站后台管理系统, 主要对网站的信息进行管理, 如文字, 图片, 影音和其他日常使用的文件的发布, 更新, 删除等操作, 也包括功能信息的统计和管理, 如用户信息, 订单信息和访客信息等. 简单来说, 它是对网站数据库和文件进行快速操作和管…

Xinlinx FPGA如何降低Block RAM的功耗

FPGA中降低Block RAM的功耗有两种方式&#xff0c;分别是选择合适的写操作模式以及Block RAM的实现算法及综合设置。我们知道对于采用IP核生成对应的RAM时&#xff0c;会有最小面积算法、低功耗算法以及固定原语&#xff0c;但是采用最小功耗算法有时由于级联长度导致无法实现&…