PyTorch深度学习实战——使用深度Q学习进行Pong游戏

news2024/12/28 2:50:06

PyTorch深度学习实战——使用深度Q学习进行Pong游戏

    • 0. 前言
    • 1. 结合固定目标网络的深度 Q 学习模型
      • 1.1 模型输入
      • 1.2 模型策略
    • 2. 实现深度 Q 学习进行 Pong 游戏
    • 相关链接

0. 前言

我们已经学习了如何利用深度 Q 学习来进行 Gym 中的 CartPole 游戏。在本节中,我们将研究更复杂的 Pong 游戏,并了解如何结合深度 Q 学习与固定目标网络进行此游戏,同时利用基于卷积神经网络 (Convolutional Neural Networks, CNN) 的模型替代普通神经网络。

1. 结合固定目标网络的深度 Q 学习模型

1.1 模型输入

在本节中,我们的目标是构建一个可以与计算机进行乒乓球对战并击败它的智能体,该智能体预计能够获得 21 分。我们采用以下策略训练智能体用于进行 Pong 游戏:
裁剪图像的无关部分,获取游戏当前帧(状态):

模型输入

在上示图像中,我们获取了原始图像,并裁剪原始图像的顶部和底部像素。

1.2 模型策略

为了构建具有固定目标网络的深度 Q 学习模型,使用以下策略:

  • 堆叠四个连续的帧——智能体需要状态序列了解球是否正在向它靠近
  • 智能体在初始阶段采取随机动作进行游戏,并将当前状态、未来状态、采取的动作和奖励存储在内存中,只保留最近的 10,000 个动作的信息,并清除超过 10,000 的历史记录
  • 构建预测网络,从内存中获取状态样本并预测可能动作的值
  • 定义作为预测网络的副本——目标网络
  • 预测网络每更新 1000 次就更新一次目标网络
  • 1000epoch 结束时目标网络的权重与预测网络的权重相同
  • 利用目标网络计算下一个状态下最佳动作的 Q 值
  • 对于预测网络提议的动作,我们期望它预测即时奖励和下一个状态中最佳动作的 Q 值的总和
  • 最小化预测网络的 MSE 损失
  • 令智能体持续游戏,直到奖励最大化

2. 实现深度 Q 学习进行 Pong 游戏

根据上一小节的策略,使用 PyTorch 实现智能体,用于最大化 Pong 游戏奖励。

(1) 导入相关库,搭建游戏环境:

import gym
import numpy as np
import cv2
from collections import deque
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple, deque
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('PongDeterministic-v0')

(2) 获取状态空间和动作空间大小:

state_size = env.observation_space.shape[0]
action_size = env.action_space.n

(3) 定义预处理函数,以便删除无关的底部和顶部像素:

def preprocess_frame(frame): 
    bkg_color = np.array([144, 72, 17])
    img = np.mean(frame[34:-16:2,::2]-bkg_color, axis=-1)/255.
    resized_image = img
    return resized_image

(4) 定义用于堆叠四个连续游戏帧的函数。

函数将 stack_frames、当前状态 state 和标志 is_new_episode 作为输入:

def stack_frames(stacked_frames, state, is_new_episode):
    # Preprocess frame
    frame = preprocess_frame(state)
    stack_size = 4

如果为新一回合,则从初始帧开始:

    if is_new_episode:
        # Clear stacked_frames
        stacked_frames = deque([np.zeros((80,80), dtype=np.uint8) for i in range(stack_size)], maxlen=4)
        # Because we're in a new episode, copy the same frame 4x
        for i in range(stack_size):
            stacked_frames.append(frame) 
        # Stack the frames
        stacked_state = np.stack(stacked_frames, axis=2).transpose(2, 0, 1)

如果并非新一回合,则从 stacked_frames 中删除最旧的帧并添加最新的帧:

    else:
        # Append frame to deque, automatically removes the oldest frame
        stacked_frames.append(frame)
        # Build the stacked state (first dimension specifies different frames)
        stacked_state = np.stack(stacked_frames, axis=2).transpose(2, 0, 1) 
    return stacked_state, stacked_frames

(5) 定义网络架构 DQNetwork

class DQNetwork(nn.Module):
    def __init__(self, states, action_size):
        super(DQNetwork, self).__init__()
        
        self.conv1 = nn.Conv2d(4, 32, (8, 8), stride=4)
        self.conv2 = nn.Conv2d(32, 64, (4, 4), stride=2)
        self.conv3 = nn.Conv2d(64, 64, (3, 3), stride=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(2304, 512)
        self.fc2 = nn.Linear(512, action_size)
        
    def forward(self, state): 
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

(6) 定义 Agent 类。

定义 __init__ 方法:

class Agent():
    def __init__(self, state_size, action_size):
        
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(0)

        ## hyperparameters
        self.buffer_size = 10000
        self.batch_size = 32
        self.gamma = 0.99
        self.lr = 0.0001
        self.update_every = 4
        self.update_every_target = 1000 
        self.learn_every_target_counter = 0
        # Q-Network
        self.local = DQNetwork(state_size, action_size).to(device)
        self.target = DQNetwork(state_size, action_size).to(device)
        self.optimizer = optim.Adam(self.local.parameters(), lr=self.lr)

        # Replay memory
        self.memory = deque(maxlen=self.buffer_size) 
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        # Initialize time step (for updating every few steps)
        self.t_step = 0

__init__ 方法添加目标网络和目标网络的更新频率。

定义权重更新方法 step

    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.append(self.experience(state[None], action, reward, next_state[None], done))
        
        # Learn every update_every time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
   # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.sample_experiences()
                self.learn(experiences, self.gamma)

定义 act 方法,根据给定状态获取要执行的动作:

    def act(self, state, eps=0.):
        # Epsilon-greedy action selection
        if random.random() > eps:
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)
            self.local.eval()
            with torch.no_grad():
                action_values = self.local(state)
            self.local.train()
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

定义 learn 函数,训练预测模型:

    def learn(self, experiences, gamma):
        self.learn_every_target_counter+=1
        states, actions, rewards, next_states, dones = experiences
       # Get expected Q values from local model
        Q_expected = self.local(states).gather(1, actions)

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current state
        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)

        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        if self.learn_every_target_counter%1000 ==0:
            self.target_update() 

在以上代码中,Q_targets_next 是使用目标模型预测的,而非预测模型,每经过 1,000 步后更新目标网络,其中 learn_every_target_counter 是用于确定是否应该更新目标模型的计数器。

定义函数 target_update 用于更新目标模型:

    def target_update(self):
        print('target updating')
        self.target.load_state_dict(self.local.state_dict())

定义函数 sample_experiences 用于从内存中采样经验:

    def sample_experiences(self):
        experiences = random.sample(self.memory, k=self.batch_size)        
        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)        
        return (states, actions, rewards, next_states, dones)

(7) 定义智能体对象:

agent = Agent(state_size, action_size)

(8) 定义用于训练智能体的参数:

n_episodes=5000
max_t=5000
eps_start=1.0
eps_end=0.02
eps_decay=0.995
scores = [] # list containing scores from each episode
scores_window = deque(maxlen=100) # last 100 scores
eps = eps_start
stack_size = 4
stacked_frames = deque([np.zeros((80,80), dtype=np.int) for i in range(stack_size)], maxlen=stack_size) 

(9) 训练智能体:

for i_episode in range(1, n_episodes+1):
    state = env.reset()
    state, frames = stack_frames(stacked_frames, state, True)
    score = 0
    for i in range(max_t):
        action = agent.act(state, eps)
        next_state, reward, done, _ = env.step(action)
        next_state, frames = stack_frames(frames, next_state, False)
        agent.step(state, action, reward, next_state, done)
        state = next_state
        score += reward
        if done:
            break 
    scores_window.append(score) # save most recent score
    scores.append(score) # save most recent score
    eps = max(eps_end, eps_decay*eps) # decrease epsilon
    print('\rEpisode {}\tReward {} \tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode,score,np.mean(scores_window), eps), end="")
    if i_episode % 100 == 0:
        print('\rEpisode {}\tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode, np.mean(scores_window), eps))

随着训练回合的增加分数变化情况如下:

import matplotlib.pyplot as plt
plt.plot(scores)
plt.title('Scores over increasing episodes')
plt.show()

模型性能监测

从上图中可以看出,智能体逐渐学会了进行 Pong 游戏,能够获得较高奖励。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(45)——强化学习
PyTorch深度学习实战(46)——深度Q学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1940440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode 热题 HOT 100 (004/100)【宇宙最简单版】

【单调栈】No. 0739 每日温度 【中等】👉力扣对应题目指路 希望对你有帮助呀!!💜💜 如有更好理解的思路,欢迎大家留言补充 ~ 一起加油叭 💦 欢迎关注、订阅专栏 【力扣详解】谢谢你的支持&#…

python-多任务编程

2. 多任务编程 2.1 多任务概述 多任务 即操作系统中可以同时运行多个任务。比如我们可以同时挂着qq,听音乐,同时上网浏览网页。这是我们看得到的任务,在系统中还有很多系统任务在执行,现在的操作系统基本都是多任务操作系统,具备…

Typora 1.5.8 版本安装下载教程 (轻量级 Markdown 编辑器),图文步骤详解,免费领取(软件可激活使用)

文章目录 软件介绍软件下载安装步骤激活步骤 软件介绍 Typora是一款基于Markdown语法的轻量级文本编辑器,它的主要目标是为用户提供一个简洁、高效的写作环境。以下是Typora的一些主要特点和功能: 实时预览:Typora支持实时预览功能&#xff0…

MySQL 查询 limit 100000000, 10 和 limit 10 速度一样快吗?

MySQL 查询 limit 100000000, 10 和 limit 10 速度一样快吗? MySQL内部分为server层和存储引擎层。一般情况下存储引擎都用innodb。 server层有很多模块,其中需要关注的是执行器是用于跟存储引擎打交道的组件。 执行器可以通过调用存储引擎提供的接口&…

order by 索引优化

根据 add_time 顺序扫描数据,然后根据 where 过滤数据,order_status4的数据很稀疏,会导致扫描很多数据 add_time 0 表达式使得无法使用add_time索引,则会先使用order_status进行过滤数据,然后对add_time进行排序&…

【帆软报表开发】图表设计入门示例

效果展示 我们希望报表样式最终展示如图下所示 操作步骤 新建模板 新建一张普通报表 新建数据库查询 输入sql语句select * from 销量,点击确定 插入图表 合并一片单元格 插入图表 选择柱形图,点击确定 设计柱形图 图表类型 如下图,选中…

怎样录制游戏视频?一站式教你解决

在当今数字化的时代,录制视频已经成为了众多游戏爱好者和职业选手必不可少的技能。录制视频不仅可以帮助玩家保存游戏的精彩瞬间,还能用于制作游戏教程、宣传视频等。因此,了解怎样录制游戏视频是非常重要的。本文将详细介绍两种流行的录制视…

【无人机】测绘行业新时代

【无人机】测绘行业新时代 无人机测绘主要指的是依托无人机系统为主要的信息接收平台,通过无人机机载遥感信息采集和处理设备,将最终所获取的遥感信息传输到测绘中心,经过数据技术处理,形成立体化的数字模型,以满足行…

逻辑回归损失函数

文章目录 1.基础简析交叉熵损失函数(Cross-Entropy Loss)对数似然损失函数(Log-Likelihood Loss) 2.关键步骤3.案例 1.基础简析 逻辑回归(Logistic Regression)是一种广泛应用于分类问题的统计模型&#x…

选粮必看!安全又美味的主食罐头来啦!江小傲、希喂、迈格仕真实测评

开猫咖3年啦,店里有加菲,美短,布偶,暹罗,都是我一手带大的。店铺开在高校附近,顾客以学生为主,也有很多养猫人士会到店里来,和我交流选粮经验。不少铲屎官都希望选到安全有美味的罐头…

Inconsistent Query Results Based on Output Fields Selection in Milvus Dashboard

题意:在Milvus仪表盘中基于输出字段选择的不一致查询结果 问题背景: Im experiencing an issue with the Milvus dashboard where the search results change based on the selected output fields. Im working on a RAG project using text data conv…

Facebook Dating:社交平台的约会新体验

随着社交媒体的普及和技术的发展,传统的社交方式正在经历革新,尤其是在约会这个领域。Facebook作为全球领先的社交平台,推出了Facebook Dating,旨在为用户提供一个全新的约会体验。本文将探讨Facebook Dating如何重新定义社交平台…

自托管社交媒体管理软件Mixpost

本文软件应网友 ilikeit 的要求而折腾; 什么是 Mixpost ? Mixpost 是一款强大且多功能的社交媒体管理软件,旨在简化社交媒体操作并增强内容营销策略。可以让您轻松地在一个地方创建、安排、发布和管理社交媒体内容,没有任何限制或…

C语言 | Leetcode C语言题解之第242题有效的字母异位词

题目&#xff1a; 题解&#xff1a; bool isAnagram(char* s, char* t) {int len_s strlen(s), len_t strlen(t);if (len_s ! len_t) {return false;}int table[26];memset(table, 0, sizeof(table));for (int i 0; i < len_s; i) {table[s[i] - a];}for (int i 0; i &…

Java开发之Java线程池

#来自ゾフィー&#xff08;佐菲&#xff09; 1 简介 在需要异步或者并发编程中&#xff0c;常常使用线程池&#xff0c;所谓线程池&#xff0c;就是事先创建好一堆线程&#xff0c;装到一个池子里面&#xff0c;需要用到时候取出来&#xff0c;这样带来了以下的好处&#xff1a…

用ComfyUI安装可图Kolors大模型做手机壁纸

一、Kolors简介 国内科技公司快手在人工智能领域取得了显著进展&#xff0c;特别推出了「可图 Kolors」这一开源模型&#xff0c;它在图像生成质量上超越了SD3&#xff0c;与Midjourney v6模型相媲美&#xff0c;并支持中文提示词识别与生成中文字符&#xff0c;成为国产AI绘画…

【STM32】理解时钟树(图示分析)

文章目录 时钟系统什么是时钟时钟树简化图示类比示例时钟树详解时钟源系统时钟配置各总线时钟外设时钟 时钟系统 什么是时钟 时钟在电子和计算机系统中指的是生成周期性信号的电路或设备&#xff0c;这种周期性信号用于同步系统内的各种操作。时钟信号通常是方波&#xff0c;…

YOLO 模型基础入门及官方示例演示

文章目录 Github官网简介模式数据集Python 环境Conda 环境Docker 环境部署 CPU 版本官方 CLI 示例官方 Python 示例 任务目标检测姿势估计图像分类 Ultralytics HUB视频流示例 Github https://github.com/ultralytics/ultralytics 官网 https://docs.ultralytics.com/zhhttp…

图像生成(Text-to-Image)发展脉络

这篇博客对 图像生成&#xff08;image generation&#xff09; 领域的经典工作发展进行了梳理&#xff0c;包括重要的一些改进&#xff0c;目的是帮助读者对此领域有一个整体的发展方向把握&#xff0c;并非是对每个工作的详细介绍。 脉络发展&#xff08;时间顺序&#xff0…

13.5.【C语言】二维数组

接第13篇&#xff08;http://t.csdnimg.cn/TioJH&#xff09; 把一维数组做为数组的元素&#xff0c;这时候就是二维数组&#xff0c;二维数组作为数组元素的数组被称为三维数组&#xff0c;二维数组以上的数组统称为多维数组。 01.创建 格式&#xff1a; 数据类型 数组名[…