一文实践强化学习训练游戏ai--doom枪战游戏实践

news2024/9/20 14:44:00

一文实践强化学习训练游戏ai–doom枪战游戏实践
上次文章写道下载doom的环境并尝试了简单的操作,这次让我们来进行对象化和训练、验证,如果你有基础,可以直接阅读本文,不然请你先阅读Doom基础知识,其中包含了下载、动作等等的基础知识。
本次与之前的马里奥训练不同,马里奥是已经有做好的step等函数的,而这个doom没有,但也因此我们可以更好的一窥训练的过程。
完整代码在最后,可以复制执行。

文章目录

  • 一、训练模型
    • 1、vizdoom_train类
      • 1)__init__
      • 2)step
      • 3)togray
      • 4)其他函数
    • 2、保存模型函数
    • 3、训练模型
  • 二、成果验收
  • 完整代码
    • 训练代码
    • 测试代码

一、训练模型

1、vizdoom_train类

这是我们的训练基类,由于我们想用openai gym环境,因此我们需要手写此环境必须的__init__、step等函数

1)init

在这个函数中,我们要定义训练的观察空间(即游戏图像)、动作空间(即ai可以执行的操作)和一些基础设置。

        VizDoom_basic_cfg = r"C:/Users/tttiger/Desktop/ViZDoom-master/ViZDoom-master/scenarios/basic.cfg"
        self.game = vizdoom.DoomGame()
        self.game.load_config(VizDoom_basic_cfg)

        if render == False: 
            self.game.set_window_visible(False)
        else:
            self.game.set_window_visible(True)
        self.game.init()

此处,我们先指定游戏文件的位置,然后加一个判断,决定是否允许游戏窗口显示出来,最后初始化游戏。
初始化后,我们就要规定观察空间和动作空间了

        self.observation_space = Box(low=0, high=255, shape=(100,160,1), dtype=np.uint8) 
        self.action_space = Discrete(3)

观察空间是我们游戏的界面,这是一个灰化后的图像,所以维度为1,大小为100*160
动作空间为离散空间3,即可以选取动作0、1、2,我们只需要在后续的代码中指定数字指代的动作就可以了。

2)step

step函数非常关键,这是ai执行动作的时候会调用的函数。
首先,我们定义我们的动作

        actions = np.identity(3,dtype=np.uint8)
        reward = self.game.make_action(actions[action], 4) 

这部分的详细解释在Doom基础知识中,事实上就是定义一个矩阵,调用游戏文件中的左移、右移和射击。
接下来,我们要获取当前的状态,比如得分,游戏图像等,这样我们才可以训练。

        try:
            state = self.game.get_state()
            img = state.screen_buffer
            img = self.togray(img)
            info = state.game_variables[0]
        except:
            img = np.zeros(self.observation_space.shape)
            info = 0 
        finally:
            info = {"info":info}
            done = self.game.is_episode_finished()
        #img_show(img)
        return img,reward,done,info

使用try,是因为gameover时有些内容获取不到,为了防止程序因此暂停,用try。最后,我们要把info变成字典形式,这是因为openai gym环境时这么要求的。为了方便理解,这里可以调用imgshow,查看目前的图像,其实现如下。

def img_show(img):
    plt.imshow(img)
    plt.show()
    time.sleep(5)

3)togray

灰度化图像,我们知道,彩色图像时由rgb三个颜色矩阵组成,但这么大的数据量给我们的训练增添的很多负担,于是我们采用灰度图。同时,我们缩小图像,这样可以训练的更快。

    def togray(self,observation):
        gray = cv2.cvtColor(np.moveaxis(observation, 0, -1), cv2.COLOR_BGR2GRAY)
        resize = cv2.resize(gray, (160,100), interpolation=cv2.INTER_CUBIC)
        state = np.reshape(resize, (100,160,1))
        return state

observation 是一个 NumPy 数组,通常表示图像数据。
np.moveaxis 是 NumPy 库中的一个函数,用于重新排列数组的轴。
参数 0 表示将第0轴(通常是颜色通道)移动到新位置的最后一个轴位置(即 -1)。
例如,如果 observation 的形状是 (C, H, W)(即颜色通道在第一个维度),经过 np.moveaxis 后,形状将变为 (H, W, C),这对于 OpenCV 处理图像更为常见,因为 OpenCV 期望颜色通道是图像的最后一个维度。

cv2.cvtColor 是 OpenCV 中用于颜色空间转换的函数。
第一个参数是输入图像(在这里是经过 np.moveaxis 处理后的图像)。
第二个参数 cv2.COLOR_BGR2GRAY 指定了将图像从 BGR 颜色空间转换为灰度图像。

4)其他函数

我们要定义一个关闭游戏的函数close

   def close(self):
        self.game.close()

以及一个reset函数,用于在结束一个游戏后,重置状态,继续下一轮训练

    def reset(self):
        state = self.game.new_episode()
        state = self.game.get_state()
        return self.togray(state.screen_buffer)

至此,我们已经成功地把这个独立游戏包装成了可以使用openai gym的环境的游戏。

2、保存模型函数

class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

我们使用这段代码来存档数据,这部分复制即可

3、训练模型

首先,我们指定训练结果保存的路径

    CHECKPOINT_DIR = './train/train_basic'
    LOG_DIR = './logs/log_basic'
    callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)    ```

然后我们调用训练函数进行训练

    env = vizdoom_train(render=False)
    model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, learning_rate=0.0001, n_steps=2048)
    model.learn(total_timesteps=100000, callback=callback)

这里我们使用PPO这个强化学习算法。

经过几十分钟的等待,就得到训练好的模型了

二、成果验收

训练好模型后,我们要使用模型,看看效果。
首先,我们载入训练好的模型

model = PPO.load('./train/train_basic/best_model_100000')

然后测试以下模型的平均得分

mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)

这样还不够直观,我们让ai玩给我们看


for episode in range(100): 
    obs = env.reset()
    done = False
    total_reward = 0
    while not done: 
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        time.sleep(0.20)
        total_reward += reward
    print('Total Reward for episode {} is {}'.format(total_reward, episode))
    time.sleep(2)

我们首先让ai模型预测,然后将预测的动作输入step函数,然后展示页面。

如图所示,平均得分非常高,基本能快速索敌,然后一枪秒杀。
至此,我们完成了ai的训练。

完整代码

训练代码

from gym import Env
from gym.spaces import Discrete,Box
import cv2
from vizdoom import *
import vizdoom
import random
import time
import numpy as np
#DIS离散空间 作用类似于random
#box用来装游戏图像
from matplotlib import pyplot as plt
class vizdoom_train(Env):
    def __init__(self, render=True):
        super().__init__()
        VizDoom_basic_cfg = r"C:/Users/tttiger/Desktop/ViZDoom-master/ViZDoom-master/scenarios/basic.cfg"
        self.game = vizdoom.DoomGame()
        self.game.load_config(VizDoom_basic_cfg)

        if render == False: 
            self.game.set_window_visible(False)
        else:
            self.game.set_window_visible(True)
        self.game.init()

        self.observation_space = Box(low=0, high=255, shape=(100,160,1), dtype=np.uint8) 
        self.action_space = Discrete(3)
    def step(self,action):
        actions = np.identity(3,dtype=np.uint8)
        reward = self.game.make_action(actions[action], 4) 
        try:
            state = self.game.get_state()
            img = state.screen_buffer
            img = self.togray(img)
            info = state.game_variables[0]
        except:
            img = np.zeros(self.observation_space.shape)
            info = 0 
        finally:
            info = {"info":info}
            done = self.game.is_episode_finished()
        #img_show(img)
        return img,reward,done,info
    
    def close(self):
        self.game.close()

    def reset(self):
        state = self.game.new_episode()
        state = self.game.get_state()
        return self.togray(state.screen_buffer)
    
    def togray(self,observation):
        gray = cv2.cvtColor(np.moveaxis(observation, 0, -1), cv2.COLOR_BGR2GRAY)
        resize = cv2.resize(gray, (160,100), interpolation=cv2.INTER_CUBIC)
        state = np.reshape(resize, (100,160,1))
        return state
    
def img_show(img):
    plt.imshow(img)
    plt.show()
    time.sleep(5)


# Import os for file nav
import os 
# Import callback class from sb3
from stable_baselines3.common.callbacks import BaseCallback
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True
    
if __name__ == "__main__":

    CHECKPOINT_DIR = './train/train_basic'
    LOG_DIR = './logs/log_basic'
    callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)    

    train = vizdoom_train()

    from stable_baselines3 import PPO

    env = vizdoom_train(render=False)
    model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, learning_rate=0.0001, n_steps=2048)
    model.learn(total_timesteps=100000, callback=callback)

测试代码

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
import time
model = PPO.load('./train/train_basic/best_model_100000')

from second_gym import vizdoom_train
env = vizdoom_train(render=True)

mean_reward, _ ,_= evaluate_policy(model, env, n_eval_episodes=5)

print(mean_reward)

for episode in range(100): 
    obs = env.reset()
    done = False
    total_reward = 0
    while not done: 
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        time.sleep(0.20)
        total_reward += reward
    print('Total Reward for episode {} is {}'.format(total_reward, episode))
    time.sleep(2)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1911790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++中的多重继承和虚继承:横向继承、纵向继承和联合继承;虚继承

多重继承 A.横向多重继承: B.纵向多重继承: C.联合多重继承: 因为 single 和 waiter 都继承了一个 worker 组件,因此 SingingWaiter 将包含两个 worker 组件,那么将派生类对象的地址赋给基类指针将出现二义性 那么如何…

AdaBoost集成学习算法理论解读以及公式为什么这么设计?

本文致力于阐述AdaBoost基本步骤涉及的每一个公式和公式为什么这么设计。 AdaBoost集成学习算法基本上遵从Boosting集成学习思想,通过不断迭代更新训练样本集的样本权重分布获得一组性能互补的弱学习器,然后通过加权投票等方式将这些弱学习器集成起来得到…

代码随想录——合并区间(Leecode LCR74)

题目链接 贪心 排序 class Solution {public int[][] merge(int[][] intervals) {ArrayList<int[]> res new ArrayList<>();// 先将数组按照左区间排序Arrays.sort(intervals, new Comparator<int[]>() {public int compare(int[] intervals1, int[] in…

CentOS 7:停止更新后如何下载软件?

引言 CentOS 7 是一个广受欢迎的 Linux 发行版&#xff0c;它为企业和开发者提供了一个稳定、安全、且免费的操作系统环境。然而&#xff0c;随着时间的推移&#xff0c;CentOS 7 的官方支持已经进入了维护阶段&#xff0c;这意味着它将不再收到常规的更新和新功能&#xff0c;…

第241题| 确定极限中参数问题 | 武忠祥老师每日一题

解题思路&#xff1a;确定极限中的参数的方法是求这个极限&#xff1b;求极限根据类型选方法。 形可以用到三种方法&#xff1a;洛必达&#xff0c;等价&#xff0c;泰勒。 先观察题目&#xff0c;将看成一个整体&#xff0c;同时,并令,整理之后如下&#xff1a; 这里也要想办…

MySQL架构优化及SQL优化

变更项目的整体架构是性能收益最大的方式。主要涉及两方面&#xff0c;一方面是从整个项目角度&#xff0c;引入一些中间件优化整体性能&#xff0c;另一方面是调整MySQL的部署架构&#xff0c;确保能承载更大的流量访问&#xff0c;提高数据层的整体吞吐。 1. 引入缓存中间件…

使用F1C200S从零制作掌机之USB游戏手柄

一、USB手柄 COIORVIS PC游戏手柄电脑USB FC模拟器经典游戏手柄 安卓手机有线连接单打格斗对打拳皇 经典有线手柄【黄色】 https://item.jd.com/10046453175183.html 插入USB即可自动识别。 # [ 1425.447643] usb 1-1: USB disconnect, device number 7 [ 1427.072155] usb …

方法引用 异常 file

目录 一.方法引用 1.方法引用概述 2.引用静态方法 3.引用成员方法 i.引用其他成员方法 ii.引用本类成员方法 iii.引用父类成员方法 4.引用构造方法 5.其他调用方式 i.使用类名引用成员方法 ii.引用数组的构造方法 二、异常 1.异常的作用 2.异常的处理方式 i.JVM…

Windows7彻底卸载mysql

1.控制面板卸载mysql 2.删除C:\Program Files\MySQL 3.删除C:\用户\Administrator\App Data\Roaming\MySQL”(App Data默认隐藏&#xff0c;需要在文件夹和搜索选项中勾选显示文件夹),为了删除的更彻底&#xff0c;可以直接在计算机全盘搜索MySQL关键字&#xff0c;将所有找到…

微信服务里底部的不常用功能如何优化的数据分析思路

图片.png 昨天下午茶时光&#xff0c;和闺蜜偶然聊起&#xff0c;其实在微信服务底部&#xff0c;有很多被我们忽略遗忘&#xff0c;很少点过用过的功能服务&#xff0c;往往进入服务只为了收付款或进入钱包&#xff0c;用完就走了&#xff0c;很少拉到底部&#xff0c;看到和用…

STM32的SPI接口详解

目录 1.SPI简介 2.SPI工作原理 3.SPI时序 3.1 CPOL&#xff08;Clock Polarity&#xff0c;时钟极性&#xff09;&#xff1a; 3.2 CPHA&#xff08;Clock Phase&#xff0c;时钟相位&#xff09;&#xff1a; 3.3 四种工作模式 4.相关代码 4.1使能片选信号 4.2使能通…

【WebGIS平台】传统聚落建筑科普数字化建模平台

基于上述概括出建筑单体的特征部件&#xff0c;本文利用互联网、三维建模和地理信息等技术设计了基于浏览器/服务器&#xff08;B/S&#xff09;的传统聚落建筑科普数字化平台。该平台不仅实现了对传统聚落建筑风貌从基础到复杂的数字化再现&#xff0c;允许用户轻松在线构建从…

谷粒商城-个人笔记(集群部署篇三)

前言 ​学习视频&#xff1a;​Java项目《谷粒商城》架构师级Java项目实战&#xff0c;对标阿里P6-P7&#xff0c;全网最强​学习文档&#xff1a; 谷粒商城-个人笔记(基础篇一)谷粒商城-个人笔记(基础篇二)谷粒商城-个人笔记(基础篇三)谷粒商城-个人笔记(高级篇一)谷粒商城-个…

Spark实现电商消费者画像案例

作者/朱季谦 故事得从这一张图开始说起—— 可怜的打工人准备下班时&#xff0c;突然收到领导发来的一份电商消费者样本数据&#xff0c;数据内容是这样的—— 消费者姓名&#xff5c;年龄&#xff5c;性别&#xff5c;薪资&#xff5c;消费偏好&#xff5c;消费领域&#x…

Sentinel-1 Level 1数据处理的详细算法定义(二)

《Sentinel-1 Level 1数据处理的详细算法定义》文档定义和描述了Sentinel-1实现的Level 1处理算法和方程&#xff0c;以便生成Level 1产品。这些算法适用于Sentinel-1的Stripmap、Interferometric Wide-swath (IW)、Extra-wide-swath (EW)和Wave模式。 今天介绍的内容如下&…

【1.3】动态规划-解码方法

一、题目 一条包含字母A-Z的消息通过以下映射进行了编码&#xff1a; A -> 1 B -> 2 ... Z -> 26 要解码已编码的消息&#xff0c;所有数字必须基于上述映射的方法&#xff0c;反向映射回字母&…

Nacos2.X 配置中心源码分析:客户端如何拉取配置、服务端配置发布客户端监听机制

文章目录 Nacos配置中心源码总流程图NacosClient源码分析获取配置注册监听器 NacosServer源码分析配置dump配置发布 Nacos配置中心源码 总流程图 Nacos2.1.0源码分析在线流程图 源码的版本为2.1.0 &#xff0c;并在配置了下面两个启动参数&#xff0c;一个表示单机启动&#…

C++初探究(2)

引用 对于一个常量&#xff0c;想要将其进行引用&#xff0c;则使用普通的引用相当于权限扩大&#xff08;常量为只读&#xff0c;但此处的引用参数为可读可写&#xff09;&#xff0c;C编译器会报错. 例如&#xff1a; const int a 10;int& ra a;//权限放大&#xff0…

使用Mplayer实现MP3功能

核心功能 1. 界面设计 项目首先定义了一个clearscreen函数&#xff0c;用于清空屏幕&#xff0c;为用户界面的更新提供了便利。yemian函数负责显示主菜单界面&#xff0c;提供了包括查看播放列表、播放控制、播放模式选择等在内的9个选项。 2. 文件格式支持 is_supported_f…

详解TCP和UDP通信协议

目录 OSI的七层模型的主要功能 tcp是什么 TCP三次握手 为什么需要三次握手&#xff0c;两次握手不行吗 TCP四次挥手 挥手会什么需要四次 什么是TCP粘包问题&#xff1f;发生的原因 原因 解决方案 UDP是什么 TCP和UDP的区别 网络层常见协议 利用socket进行tcp传输代…