目录
实机演示
代码实现
实机演示
用深度强化学习来玩Chrome小恐龙快跑
代码实现
import os
import cv2
from pygame import RLEACCEL
from pygame.image import load
from pygame.sprite import Sprite, Group, collide_mask
from pygame import Rect, init, time, display, mixer, transform, Surface
from pygame.surfarray import array3d
import torch
from random import randrange, choice
import numpy as np
# Window geometry and game physics.
scr_size = (width, height) = (600, 150)
FPS = 60
gravity = 0.6

# RGB colours.
black, white = (0, 0, 0), (255, 255, 255)
background_col = (235, 235, 235)
high_score = 0

# pygame bootstrap: mixer.pre_init has to run before init() so the mixer
# picks up these settings; set_mode needs an initialised pygame.
mixer.pre_init(44100, -16, 2, 2048)
init()
screen = display.set_mode(scr_size)
display.set_caption("T-Rex Rush")
clock = time.Clock()
def load_image(
    name,
    sizex=-1,
    sizey=-1,
    colorkey=None,
):
    """Load ``assets/sprites/<name>`` as a display-format Surface.

    Args:
        name: file name inside ``assets/sprites``.
        sizex: target width in pixels; -1 keeps the image's own width.
        sizey: target height in pixels; -1 keeps the image's own height.
            The image is left unscaled when both are -1.
        colorkey: transparent colour; -1 samples the top-left pixel;
            None disables transparency.

    Returns:
        ``(image, rect)`` where ``rect`` is ``image.get_rect()`` after scaling.
    """
    fullname = os.path.join("assets/sprites", name)
    image = load(fullname).convert()
    if colorkey is not None:
        # `==`, not `is`: identity tests against int literals are
        # implementation-defined (SyntaxWarning since Python 3.8).
        if colorkey == -1:
            colorkey = image.get_at((0, 0))
        image.set_colorkey(colorkey, RLEACCEL)
    if sizex != -1 or sizey != -1:
        # A lone -1 used to reach transform.scale as a negative size;
        # substitute the image's own extent along that axis instead.
        target = (
            image.get_width() if sizex == -1 else sizex,
            image.get_height() if sizey == -1 else sizey,
        )
        image = transform.scale(image, target)
    return (image, image.get_rect())
def load_sprite_sheet(
    sheetname,
    nx,
    ny,
    scalex=-1,
    scaley=-1,
    colorkey=None,
):
    """Cut ``assets/sprites/<sheetname>`` into a flat list of frames.

    Args:
        sheetname: file name inside ``assets/sprites``.
        nx: tiles per row (int), or a list of tile counts. For a list, each
            entry divides the full sheet width into that many tiles, and the
            resulting strips are emitted one after another for every row.
        ny: number of rows.
        scalex: target frame width. With a list ``nx`` this may be a list
            holding one width per entry; a single value applies to all.
            -1 keeps the tile's own width.
        scaley: target frame height; -1 keeps the tile's own height.
        colorkey: transparent colour; -1 samples the top-left pixel of the
            first tile and reuses it for every tile; None disables it.

    Returns:
        ``(sprites, rect)``: the frames in row-major order and the rect of
        the first frame.
    """
    fullname = os.path.join("assets/sprites", sheetname)
    sheet = load(fullname).convert()
    sheet_rect = sheet.get_rect()
    sizey = sheet_rect.height / ny

    # (source rect, target width) for every tile, in the emission order.
    tiles = []
    if isinstance(nx, int):
        sizex = sheet_rect.width / nx
        for i in range(ny):
            for j in range(nx):
                tiles.append((Rect((j * sizex, i * sizey, sizex, sizey)), scalex))
    else:  # list
        # A scalar scalex used to crash in zip(); broadcast it instead.
        if isinstance(scalex, (list, tuple)):
            scalex_ls = scalex
        else:
            scalex_ls = [scalex] * len(nx)
        for i in range(ny):
            for i_nx, i_scalex in zip(nx, scalex_ls):
                sizex = sheet_rect.width / i_nx
                for j in range(i_nx):
                    tiles.append((Rect((j * sizex, i * sizey, sizex, sizey)), i_scalex))

    # `== -1` rather than `is -1` (identity on int literals is not guaranteed).
    sample_colorkey = colorkey is not None and colorkey == -1
    sprites = []
    for rect, target_w in tiles:
        image = Surface(rect.size).convert()
        image.blit(sheet, (0, 0), rect)
        if colorkey is not None:
            if sample_colorkey:
                # Sampled once from the first tile, then reused for the rest.
                colorkey = image.get_at((0, 0))
                sample_colorkey = False
            image.set_colorkey(colorkey, RLEACCEL)
        if target_w != -1 or scaley != -1:
            # Substitute the tile's own size for an axis left at -1.
            size = (
                image.get_width() if target_w == -1 else target_w,
                image.get_height() if scaley == -1 else scaley,
            )
            image = transform.scale(image, size)
        sprites.append(image)
    sprite_rect = sprites[0].get_rect()
    return sprites, sprite_rect
def extractDigits(number):
    """Split a non-negative integer into decimal digits, most significant first.

    The result is left-padded with zeros to at least five digits; numbers
    with more digits return all of them. Negative input returns None.

    >>> extractDigits(123)
    [0, 0, 1, 2, 3]
    """
    if number < 0:
        return None
    number = int(number)
    digits = []
    # divmod (floor division): the old `number / 10 != 0` test used true
    # division, only stopped at 0 and so appended a spurious leading zero.
    while True:
        number, digit = divmod(number, 10)
        digits.append(digit)
        if number == 0:
            break
    digits.extend([0] * (5 - len(digits)))
    digits.reverse()
    return digits
def pre_processing(image, w=84, h=84):
    """Turn a captured screen array into a binary single-channel network input.

    Args:
        image: 3-channel pixel array. When it comes from ``array3d`` (as in
            ChromeDino.step) the first axis is the screen x axis, so the
            ``[:300]`` crop keeps the 300 left-most pixel columns.
        w, h: output width and height passed to ``cv2.resize``.

    Returns:
        float32 array of shape ``(1, h, w)`` whose values are 0 or 255.
    """
    cropped = image[:300, :, :]
    resized = cv2.resize(cropped, (w, h))
    # NOTE(review): the capture is RGB but is converted with COLOR_BGR2GRAY;
    # kept unchanged so the preprocessing stays identical.
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    return np.expand_dims(binary, 0).astype(np.float32)
class Dino():
    """The player-controlled T-Rex.

    Frames in ``images`` (from ``dino.png``): 0/1 while jumping or
    idle-blinking, 2/3 while running, 4 when dead. ``images1`` holds the two
    ducking frames from ``dino_ducking.png``.
    """

    def __init__(self, sizex=-1, sizey=-1):
        self.images, self.rect = load_sprite_sheet("dino.png", 5, 1, sizex, sizey, -1)
        self.images1, self.rect1 = load_sprite_sheet("dino_ducking.png", 2, 1, 59, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width / 15
        self.image = self.images[0]
        self.index = self.counter = self.score = 0
        self.isJumping = self.isDead = self.isDucking = self.isBlinking = False
        self.movement = [0, 0]  # [dx, dy] applied each update
        self.jumpSpeed = 11.5
        # The rect width is swapped between these when ducking starts/stops.
        self.stand_pos_width, self.duck_pos_width = self.rect.width, self.rect1.width

    def draw(self):
        """Blit the current frame at the dino's rect."""
        screen.blit(self.image, self.rect)

    def checkbounds(self):
        """Clamp the dino to the ground line, ending any jump that reached it."""
        floor = int(0.98 * height)
        if self.rect.bottom > floor:
            self.rect.bottom = floor
            self.isJumping = False

    def update(self):
        """Advance one frame: gravity, animation frame, position, score."""
        tick = self.counter
        if self.isJumping:
            self.movement[1] += gravity
            self.index = 0
        elif self.isBlinking:
            # Eyes open for 400 frames, closed for 20.
            period = 400 if self.index == 0 else 20
            if tick % period == period - 1:
                self.index = (self.index + 1) % 2
        elif tick % 5 == 0:
            # Ducking alternates frames 0/1, running alternates frames 2/3.
            offset = 0 if self.isDucking else 2
            self.index = (self.index + 1) % 2 + offset

        if self.isDead:
            self.index = 4

        if self.isDucking:
            self.image = self.images1[self.index % 2]
            self.rect.width = self.duck_pos_width
        else:
            self.image = self.images[self.index]
            self.rect.width = self.stand_pos_width

        self.rect = self.rect.move(self.movement)
        self.checkbounds()
        # One point every 7 frames while alive and not blinking.
        if not (self.isDead or self.isBlinking) and self.counter % 7 == 6:
            self.score += 1
        self.counter += 1
class Cactus(Sprite):
    """Ground obstacle showing a randomly chosen frame of the small-cactus sheet.

    ``Cactus.containers`` must be assigned before construction (ChromeDino
    does this); the sprite adds itself to those groups.
    """

    def __init__(self, speed=5, sizex=-1, sizey=-1):
        Sprite.__init__(self, self.containers)
        self.images, self.rect = load_sprite_sheet("cacti-small.png", [2, 3, 6], 1, sizex, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        # Spawn one (first-frame) width beyond the right edge, as before.
        self.rect.left = width + self.rect.width
        # Range derived from the frames actually loaded rather than a literal 11.
        self.image = self.images[randrange(0, len(self.images))]
        # load_sprite_sheet returns the FIRST frame's rect; resize it to the
        # chosen frame so rect.right (used for pass rewards, spacing and
        # removal) matches the visible cactus. Left and bottom stay put.
        self.rect.width = self.image.get_width()
        self.movement = [-1 * speed, 0]

    def draw(self):
        """Blit the cactus at its rect."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Scroll left; remove the sprite once it is fully off-screen."""
        self.rect = self.rect.move(self.movement)
        if self.rect.right < 0:
            self.kill()
class Ptera(Sprite):
    """Flying obstacle: two flap frames, spawned at one of four random heights.

    ``Ptera.containers`` must be assigned before construction (ChromeDino
    does this); the sprite adds itself to those groups.
    """

    def __init__(self, speed=5, sizex=-1, sizey=-1):
        Sprite.__init__(self, self.containers)
        self.images, self.rect = load_sprite_sheet("ptera.png", 2, 1, sizex, sizey, -1)
        # Candidate vertical centres, as fractions of the screen height.
        self.ptera_height = [height * frac for frac in (0.82, 0.75, 0.60, 0.48)]
        self.rect.centery = self.ptera_height[randrange(0, 4)]
        self.rect.left = width + self.rect.width  # one sprite width past the right edge
        self.index = 0
        self.counter = 0
        self.image = self.images[self.index]
        self.movement = [-speed, 0]

    def draw(self):
        """Blit the pterodactyl at its rect."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Flap every 10 frames, scroll left, and remove once off-screen."""
        if not self.counter % 10:
            self.index ^= 1
        self.image = self.images[self.index]
        self.rect = self.rect.move(self.movement)
        self.counter += 1
        if self.rect.right < 0:
            self.kill()
class Ground():
    """Endlessly scrolling ground built from two copies of ``ground.png``.

    The copies sit end to end; whichever scrolls fully off the left edge is
    moved behind the other.
    """

    def __init__(self, speed=-5):
        self.image, self.rect = load_image("ground.png", -1, -1, -1)
        self.image1, self.rect1 = load_image("ground.png", -1, -1, -1)
        self.rect.bottom = self.rect1.bottom = height
        self.rect1.left = self.rect.right
        self.speed = speed  # horizontal pixels per update (negative = leftwards)

    def draw(self):
        """Blit both ground copies."""
        for img, area in ((self.image, self.rect), (self.image1, self.rect1)):
            screen.blit(img, area)

    def update(self):
        """Scroll both copies and recycle any that left the screen."""
        for area in (self.rect, self.rect1):
            area.left += self.speed
        if self.rect.right < 0:
            self.rect.left = self.rect1.right
        if self.rect1.right < 0:
            self.rect1.left = self.rect.right
class Cloud(Sprite):
    """Decorative cloud drifting left 1 px per update, removed once off-screen.

    ``Cloud.containers`` must be assigned before construction (ChromeDino
    does this); the sprite adds itself to those groups.
    """

    def __init__(self, x, y):
        Sprite.__init__(self, self.containers)
        cloud_width = int(90 * 30 / 42)  # 30 px tall at a 90:42 aspect ratio
        self.image, self.rect = load_image("cloud.png", cloud_width, 30, -1)
        self.speed = 1
        self.rect.left, self.rect.top = x, y
        self.movement = [-self.speed, 0]

    def draw(self):
        """Blit the cloud at its rect."""
        screen.blit(self.image, self.rect)

    def update(self):
        """Drift left; remove the sprite once it is fully off-screen."""
        self.rect = self.rect.move(self.movement)
        if self.rect.right < 0:
            self.kill()
class Scoreboard():
    """Five-digit score display rendered from the ``numbers.png`` glyph sheet.

    Args:
        x: left edge in pixels; -1 places it at 89% of the screen width.
        y: top edge in pixels; -1 places it at 10% of the screen height.
    """

    def __init__(self, x=-1, y=-1):
        self.score = 0
        # 12 glyphs scaled to 11 x int(13.2) px: digits 0-9 at indices 0-9
        # (indices 10/11 are the glyphs ChromeDino uses for its HI label).
        self.tempimages, self.temprect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.image = Surface((55, int(11 * 6 / 5)))  # room for five 11-px glyphs
        self.rect = self.image.get_rect()
        self.rect.left = width * 0.89 if x == -1 else x
        self.rect.top = height * 0.1 if y == -1 else y

    def draw(self):
        """Blit the rendered score board."""
        screen.blit(self.image, self.rect)

    def update(self, score):
        """Re-render ``score`` into the board surface.

        Only the lowest five digits are shown. Negative scores are drawn as 0
        (extractDigits returns None for them).
        """
        score_digits = extractDigits(max(score, 0))
        self.image.fill(background_col)
        # The old code only trimmed a 6-digit result, so 7+ digits overflowed
        # the 55-px surface; keep the last five in every case.
        glyph_w = self.temprect.width
        for slot, digit in enumerate(score_digits[-5:]):
            # Offset a copy instead of mutating the shared temprect.
            self.image.blit(self.tempimages[digit], self.temprect.move(slot * glyph_w, 0))
class ChromeDino(object):
    """The T-Rex runner exposed as a step-based reinforcement-learning environment.

    Each ``step`` applies one action, advances the game by one frame, redraws
    the pygame window and returns an observation, a reward and a done flag.
    When the dino dies the game is rebuilt in place (``__init__`` runs again),
    so callers can keep stepping without an explicit reset.
    """

    def __init__(self):
        self.gamespeed = 5
        self.gameOver = False
        self.gameQuit = False
        self.playerDino = Dino(44, 47)
        self.new_ground = Ground(-1 * self.gamespeed)
        self.scb = Scoreboard()
        self.highsc = Scoreboard(width * 0.78)
        self.counter = 0  # frames elapsed in the current game
        self.cacti = Group()
        self.pteras = Group()
        self.clouds = Group()
        self.last_obstacle = Group()  # most recently spawned obstacle; gates spacing
        # Sprite constructors add themselves to these groups.
        Cactus.containers = self.cacti
        Ptera.containers = self.pteras
        Cloud.containers = self.clouds
        # The high-score board, replay button, game-over banner and HI label
        # below are prepared but never drawn by step().
        self.retbutton_image, self.retbutton_rect = load_image("replay_button.png", 35, 31, -1)
        self.gameover_image, self.gameover_rect = load_image("game_over.png", 190, 11, -1)
        self.temp_images, self.temp_rect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.HI_image = Surface((22, int(11 * 6 / 5)))
        self.HI_rect = self.HI_image.get_rect()
        self.HI_image.fill(background_col)
        self.HI_image.blit(self.temp_images[10], self.temp_rect)
        self.temp_rect.left += self.temp_rect.width
        self.HI_image.blit(self.temp_images[11], self.temp_rect)
        self.HI_rect.top = height * 0.1
        self.HI_rect.left = width * 0.73

    def step(self, action, record=False):  # 0: Do nothing. 1: Jump. 2: Duck
        """Advance the game by one frame.

        Args:
            action: 0 = do nothing, 1 = jump, 2 = duck.
            record: when True, also return the raw frame for video output.

        Returns:
            ``(state, reward, done)``, or ``(state, raw_frame, reward, done)``
            when ``record`` is True. ``state`` is ``pre_processing`` of the
            frame as a tensor of shape (1, 84, 84); ``raw_frame`` is the frame
            as a (height, width, 3) BGR array. ``done`` is True exactly when
            the dino died this frame (reward -1); the game has then already
            been restarted.
        """
        reward = self._apply_action(action)
        reward = self._resolve_obstacles(reward)
        self._spawn_obstacles()
        self._advance_frame()
        if self.playerDino.isDead:
            self.gameOver = True
        self.counter = (self.counter + 1)
        done = self.gameOver
        # Grab the frame before any reset; __init__ never draws to the screen,
        # so this is the image of the frame just rendered.
        frame = array3d(display.get_surface())
        if self.gameOver:
            self.__init__()
        state = torch.from_numpy(pre_processing(frame))
        if record:
            raw_frame = np.transpose(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), (1, 0, 2))
            return state, raw_frame, reward, done
        return state, reward, done

    def _apply_action(self, action):
        """Apply ``action`` to the dino and return the base per-frame reward."""
        reward = 0.1
        dino = self.playerDino
        on_ground = dino.rect.bottom == int(0.98 * height)
        if action == 0:
            # Bonus only for idling (presumably discourages needless moves).
            reward += 0.01
            dino.isDucking = False
        elif action == 1:
            dino.isDucking = False
            if on_ground:
                dino.isJumping = True
                dino.movement[1] = -1 * dino.jumpSpeed
        elif action == 2:
            # Was `not (isJumping and isDead)`, which is true for a live
            # jumping dino; ducking needs both conditions false.
            if not dino.isJumping and not dino.isDead and on_ground:
                dino.isDucking = True
        return reward

    def _resolve_obstacles(self, reward):
        """Check every obstacle for a collision or a pass; return the new reward.

        A collision kills the dino and always yields -1, even if another
        obstacle was passed in the same frame (previously the ptera loop could
        overwrite the death reward with 1). Passing an obstacle, i.e. its right
        edge just slipped behind the dino's left edge, yields 1. Every obstacle
        is checked; the old loop stopped at the first passed cactus.
        """
        dino = self.playerDino
        passed = False
        for obstacle in list(self.cacti) + list(self.pteras):
            obstacle.movement[0] = -1 * self.gamespeed
            if collide_mask(dino, obstacle):
                dino.isDead = True
                return -1
            if obstacle.rect.right < dino.rect.left < obstacle.rect.right + self.gamespeed + 1:
                passed = True
        return 1 if passed else reward

    def _spawn_obstacles(self):
        """Randomly spawn cacti, pterodactyls and clouds."""
        if len(self.cacti) < 2:
            if len(self.cacti) == 0 and len(self.pteras) == 0:
                self.last_obstacle.empty()
                self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
            else:
                for l in self.last_obstacle:
                    if l.rect.right < width * 0.7 and randrange(0, 50) == 10:
                        self.last_obstacle.empty()
                        self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
        if len(self.pteras) == 0 and len(self.cacti) < 2 and randrange(0, 50) == 10 and self.counter > 500:
            for l in self.last_obstacle:
                if l.rect.right < width * 0.8:
                    self.last_obstacle.empty()
                    self.last_obstacle.add(Ptera(self.gamespeed, 46, 40))
        if len(self.clouds) < 5 and randrange(0, 300) == 10:
            # Integer bounds: randrange rejects float arguments on Python 3.12+
            # (height / 5 and height / 2 are floats).
            Cloud(width, randrange(height // 5, height // 2))

    def _advance_frame(self):
        """Update every sprite by one frame, redraw the window and throttle to FPS."""
        self.playerDino.update()
        self.cacti.update()
        self.pteras.update()
        self.clouds.update()
        self.new_ground.update()
        self.scb.update(self.playerDino.score)
        screen.fill(background_col)
        self.new_ground.draw()
        self.clouds.draw(screen)
        self.scb.draw()
        self.cacti.draw(screen)
        self.pteras.draw(screen)
        self.playerDino.draw()
        display.update()
        clock.tick(FPS)
import torch.nn as nn
class DeepQNetwork(nn.Module):
    """Convolutional Q-network: a stack of 84x84 frames -> one Q-value per action.

    The three convolutions shrink 84x84 to 20x20, 9x9 and finally 7x7, which
    is why ``fc1`` takes ``7 * 7 * 64`` features. Submodule names are
    unchanged, so existing state_dict checkpoints still load.

    Args:
        in_channels: number of stacked input frames (default 4).
        num_actions: size of the action space (default 3).
    """

    def __init__(self, in_channels=4, num_actions=3):
        super(DeepQNetwork, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))
        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, num_actions)
        self._initialize_weights()

    def _initialize_weights(self):
        """Set conv/linear weights to U(-0.01, 0.01) and biases to zero."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.uniform_(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """Map ``(batch, in_channels, 84, 84)`` input to ``(batch, num_actions)``."""
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.flatten(1)
        output = self.fc1(output)
        return self.fc2(output)
import argparse
import torch
from src.model import DeepQNetwork
from src.env import ChromeDino
import cv2
def get_args():
    """Parse command-line options for recording a trained agent to video.

    Returns:
        argparse.Namespace with ``saved_path`` (folder holding
        ``chrome_dino.pth``), ``fps`` and ``output`` (video file path).
    """
    # description= (the first positional argument of ArgumentParser is
    # `prog`, which previously turned this text into the program name).
    parser = argparse.ArgumentParser(
        description="""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--saved_path", type=str, default="trained_models",
                        help="folder containing chrome_dino.pth")
    parser.add_argument("--fps", type=int, default=60, help="frames per second")
    parser.add_argument("--output", type=str, default="output/chrome_dino.mp4", help="the path to output video")
    return parser.parse_args()
def q_test(opt):
    """Play one episode with a trained agent and save it as a video.

    Args:
        opt: namespace from ``get_args()`` providing ``saved_path`` (folder
            holding ``chrome_dino.pth``), ``fps`` and ``output`` (video path).

    Raises:
        RuntimeError: if OpenCV cannot open ``opt.output`` for writing.
    """
    import os  # local import keeps this function self-contained

    torch.manual_seed(123)  # seeds the CPU and every CUDA device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DeepQNetwork()
    checkpoint_path = "{}/chrome_dino.pth".format(opt.saved_path)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    env = ChromeDino()
    state, raw_state, _, _ = env.step(0, True)
    # Bootstrap the 4-frame stack by repeating the first frame.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :].to(device)

    output_dir = os.path.dirname(opt.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    frame_size = (raw_state.shape[1], raw_state.shape[0])  # OpenCV wants (width, height)
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps, frame_size)
    if not out.isOpened():
        raise RuntimeError("could not open video writer for {}".format(opt.output))
    try:
        done = False
        with torch.no_grad():  # inference only; no autograd graph needed
            while not done:
                action = torch.argmax(model(state)[0]).item()
                next_state, raw_next_state, _, done = env.step(action, True)
                out.write(raw_next_state)
                next_state = next_state.to(device)
                # Slide the window: drop the oldest frame, append the newest.
                state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
    finally:
        # Flush and finalize the video file, even if the episode errors out.
        out.release()
if __name__ == "__main__":
    # Script entry point: parse CLI options, then record one evaluation episode.
    opt = get_args()
    q_test(opt)
import argparse
import os
from random import random, randint, sample
import pickle
import numpy as np
import torch
import torch.nn as nn
from src.model import DeepQNetwork
from src.env import ChromeDino
def get_args():
    """Parse command-line hyper-parameters for DQN training.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    # description= (the first positional argument of ArgumentParser is
    # `prog`, which previously turned this text into the program name).
    parser = argparse.ArgumentParser(
        description="""Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Number of replay transitions per training batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam",
                        help="Optimizer used for training")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--initial_epsilon", type=float, default=0.1,
                        help="Exploration rate at the start of the decay")
    parser.add_argument("--final_epsilon", type=float, default=1e-4,
                        help="Exploration rate after the decay")
    parser.add_argument("--num_decay_iters", type=float, default=2000000,
                        help="Iterations over which epsilon decays linearly")
    parser.add_argument("--num_iters", type=int, default=2000000, help="Total training iterations")
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Maximum number of transitions kept in the replay memory")
    parser.add_argument("--saved_folder", type=str, default="trained_models",
                        help="Folder for the checkpoint and replay memory")
    return parser.parse_args()
def _make_optimizer(opt, model):
    """Return the optimizer selected by ``--optimizer`` (previously always Adam)."""
    if getattr(opt, "optimizer", "adam") == "sgd":
        return torch.optim.SGD(model.parameters(), lr=opt.lr)
    return torch.optim.Adam(model.parameters(), lr=opt.lr)


def _train_on_batch(model, optimizer, criterion, batch, gamma, device):
    """Take one Q-learning gradient step on ``batch`` and return the loss as a float.

    ``batch`` is a list of ``[state, action, reward, next_state, done]``
    transitions as stored in the replay memory.
    """
    states, actions, rewards, next_states, dones = zip(*batch)
    state_batch = torch.cat(states).to(device)
    next_state_batch = torch.cat(next_states).to(device)
    action_batch = torch.tensor(actions, dtype=torch.long, device=device)
    reward_batch = torch.tensor(rewards, dtype=torch.float32, device=device)
    done_batch = torch.tensor(dones, dtype=torch.bool, device=device)

    # Q(s, a) for the action that was actually taken.
    q_value = model(state_batch).gather(1, action_batch[:, None]).squeeze(1)
    # Target r + gamma * max_a' Q(s', a'), or just r on terminal transitions.
    # Computed without autograd: previously gradients also flowed through the
    # target, which is not the Q-learning update.
    with torch.no_grad():
        next_q = model(next_state_batch).max(dim=1)[0]
        y_batch = torch.where(done_batch, reward_batch, reward_batch + gamma * next_q)

    optimizer.zero_grad()
    loss = criterion(q_value, y_batch)
    loss.backward()
    optimizer.step()
    return loss.item()


def _save_training_state(model, optimizer, replay_memory, completed, checkpoint_path, memory_path):
    """Write the model/optimizer checkpoint and the replay memory after ``completed`` iterations."""
    # "iter" holds the 0-based index of the last completed iteration;
    # train() resumes at "iter" + 1.
    checkpoint = {"iter": completed - 1,
                  "model_state_dict": model.state_dict(),
                  "optimizer": optimizer.state_dict()}
    torch.save(checkpoint, checkpoint_path)
    with open(memory_path, "wb") as f:
        pickle.dump(replay_memory, f, protocol=pickle.HIGHEST_PROTOCOL)


def train(opt):
    """Train the DQN agent on ChromeDino with epsilon-greedy exploration and experience replay.

    Resumes from ``<saved_folder>/chrome_dino.pth`` and
    ``<saved_folder>/replay_memory.pkl`` when they exist. Both are written
    every 50000 iterations and once more when the loop ends. Resume with the
    same ``--optimizer`` that wrote the checkpoint, since its optimizer state
    is loaded into the newly built optimizer.

    Args:
        opt: namespace from ``get_args()``.
    """
    torch.manual_seed(123)  # seeds the CPU and every CUDA device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DeepQNetwork().to(device)
    optimizer = _make_optimizer(opt, model)
    os.makedirs(opt.saved_folder, exist_ok=True)
    checkpoint_path = os.path.join(opt.saved_folder, "chrome_dino.pth")
    memory_path = os.path.join(opt.saved_folder, "replay_memory.pkl")
    if os.path.isfile(checkpoint_path):
        # map_location lets a GPU checkpoint resume on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        iteration = checkpoint["iter"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Load trained model from iteration {}".format(iteration))
    else:
        iteration = 0
    if os.path.isfile(memory_path):
        # NOTE: pickle.load can execute arbitrary code; only load a replay
        # file written by this script.
        with open(memory_path, "rb") as f:
            replay_memory = pickle.load(f)
        print("Load replay memory")
    else:
        replay_memory = []
    criterion = nn.MSELoss()
    env = ChromeDino()
    state, _, _ = env.step(0)
    # Bootstrap the 4-frame stack by repeating the first frame.
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]
    start_iteration = iteration
    save_every = 50000
    while iteration < opt.num_iters:
        # Linear decay from initial_epsilon to final_epsilon over num_decay_iters.
        epsilon = opt.final_epsilon + (
            max(opt.num_decay_iters - iteration, 0) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_iters)
        if random() <= epsilon:
            action = randint(0, 2)
        else:
            with torch.no_grad():
                action = torch.argmax(model(state.to(device))[0]).item()
        next_state, reward, done = env.step(action)
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, done])
        # Slice-trim the oldest entries; also shrinks a loaded memory that is
        # larger than the configured size.
        overflow = len(replay_memory) - opt.replay_memory_size
        if overflow > 0:
            del replay_memory[:overflow]
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        loss = _train_on_batch(model, optimizer, criterion, batch, opt.gamma, device)
        state = next_state
        iteration += 1  # now the number of completed iterations
        print("Iteration: {}/{}, Loss: {:.5f}, Epsilon {:.5f}, Reward: {}".format(
            iteration,
            opt.num_iters,
            loss,
            epsilon, reward))
        if iteration % save_every == 0:
            _save_training_state(model, optimizer, replay_memory, iteration, checkpoint_path, memory_path)
    # Persist the tail of the run unless the last iteration already saved it.
    if iteration > start_iteration and iteration % save_every != 0:
        _save_training_state(model, optimizer, replay_memory, iteration, checkpoint_path, memory_path)
if __name__ == "__main__":
    # Script entry point: parse CLI hyper-parameters, then run training.
    opt = get_args()
    train(opt)