强化学习Agent系列(二)——PyGame虚拟环境创建与Python 贪吃蛇Agent制作实战教学

news2025/1/18 10:55:23

文章目录

  • 一、前言
  • 二、gymnasium 简单虚拟环境创建
    • 1、gymnasium介绍
    • 2、gymnasium 贪吃蛇简单示例
  • 三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent
    • 1、虚拟环境
    • 2、虚拟环境注册
    • 3、训练程序
    • 4、模型测试
  • 三、卷积虚拟环境
    • 1、卷积神经网络虚拟环境
    • 2、训练代码

一、前言

大家好,未来的开发者们请上座
随着人工智能的发展,强化学习基本会再次来到人们眼前,遂想制作一下相关的教程。强化学习第一步基本离不开虚拟环境的搭建,下面用大家耳熟能详的贪吃蛇游戏为基础,制作一个Agent,完成对这个游戏的绝杀。
万里长城第二步:用python开发贪吃蛇智能体****加粗样式

二、gymnasium 简单虚拟环境创建

1、gymnasium介绍

gymnasium(此前称为gym)是一个由 OpenAI 开发的 Python 库,用于开发和比较强化学习算法。它提供了一组丰富的环境,模拟了各种任务,包括但不限于经典的控制问题、像素级游戏、机器人模拟等。

以下是gymnasium库的一些主要特点:

  1. 环境多样性:gymnasium包含了一系列不同的环境,每个环境都有其独特的观察空间(输入)和动作空间(输出)。这些环境涵盖了从简单的文本控制任务到复杂的三维视觉任务的广泛范围。

  2. 标准化API:gymnasium库提供了一个简单且统一的API来与这些环境交互。这使得研究人员和开发人员可以轻松地用相同的代码测试和比较不同的强化学习算法。

  3. 扩展性:用户可以创建自定义环境并将其集成到gymnasium框架中,这使得库能够适应各种不同的研究需求和应用场景。

  4. 评估标准:gymnasium环境通常包括预定义的评估标准,如累积回报或任务完成时间,这有助于在不同算法间进行公平的比较。

  5. 社区支持:由于gymnasium是由OpenAI推出并得到了强化学习社区的广泛支持,因此有大量的教程、论坛讨论和第三方资源可供学习和参考。

  6. 可视化和监控:gymnasium提供了工具来可视化智能体的性能,并允许监控和记录实验过程,便于分析和调试。

使用gymnasium的基本步骤通常包括:

导入gymnasium库。
创建一个环境实例。
初始化环境。
在一个循环中,根据当前观察值选择动作,执行动作,并接收环境的反馈(新的观察值、奖励、完成状态等)。
结束实验并关闭环境。

2、gymnasium 贪吃蛇简单示例

下面是贪吃蛇虚拟环境的一个简单的示例
在本次示例中,暂未进行任何训练。一切行为主要是从状态空间中随机抽取一个动作并执行。
下面是gymnasium创建的虚拟环境的三个核心函数介绍

  • reset(): 这个函数用于重置环境到初始状态,并返回初始状态的观测值。在开始每个新的episode时,通常会调用这个函数来初始化环境。

  • step(action): 这个函数用于让Agent在环境中执行一个动作(action),并返回四个值:观测值(observation),奖励(reward),是否终止(done),以及额外信息(info)。Agent根据环境返回的信息来决定下一步的动作。

  • render(): 这个函数用于在屏幕上渲染当前环境的状态,通常用于可视化环境以便观察Agent的行为。不是所有的环境都支持渲染,具体取决于环境的实现。

下面是具体的示例代码

import time

import pygame
import sys
import random
import numpy as np
import gymnasium as gym
class SnakeEnv(gym.Env):
    def __init__(self):
        super().__init__()

        # 初始化Pygame
        pygame.init()
        # 屏幕宽高
        self.SCREEN_WIDTH=240
        self.SCREEN_HEIGHT=240
        #蛇的方块大小
        self.snakeCell=10
        # 创建窗口
        self.screen = pygame.display.set_mode((self.SCREEN_WIDTH,self.SCREEN_HEIGHT))
        pygame.display.set_caption('Snake_Game')
        self.action_space=gym.spaces.Discrete(4) #动作空间为4
        self.observation_space=gym.spaces.Box(low=0,high=7,shape=(self.SCREEN_WIDTH,self.SCREEN_HEIGHT),dtype=np.uint8)
    
    # 重启
    def reset(self):
        """
        重置蛇和食物的位置
        """
        # 蛇的初始位置
        self.snake_head=[100,50]
        self.snake_body=[[100,50],[100-self.snakeCell,50],[100-self.snakeCell*2,50]]
        self.len=3

        # 食物的初始位置
        self.food_pos=[random.randint(1,self.SCREEN_WIDTH//10-1)*10,random.randint(1,self.SCREEN_HEIGHT//10-1)*10]

        return self._get_observation()

    # 根据当前状态 和action 执行动作
    def step(self,action):
        # 定义动作到方向的映射
        directionDict={'LEFT':[1,0],'RIGHT':[-1,0],'UP':[0,-1],'DOWN':[0,1]}

        action_to_direction = {
            0: "UP",
            1: "DOWN",
            2: "LEFT",
            3: "RIGHT"
        }
        directionTarget=action_to_direction[action]
        nextPosDelay=np.array(directionDict[directionTarget])*self.snakeCell #加的位置
        self.snake_head=list(np.array(self.snake_body[0])+nextPosDelay)

        if self.snake_head in self.snake_body:
           return self._get_observation(), 0, True, False, {}
        self.snake_body.insert(0,self.snake_head)

        # 如果是吃到食物,就重新刷新果子,同时长度 +1
        if self.food_pos == self.snake_head:
            self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10)) * 10,
                             random.randrange(1, (self.SCREEN_HEIGHT // 10)) * 10]
            self.len+=1

        # 弹出
        while self.len<len(self.snake_body):
            self.snake_body.pop()
        # 奖励
        reward,done=self._get_reward()
        truncated = True

        return self._get_observation(), reward, truncated, done, {}

    # 渲染
    def render(self,mode="human"):
        # 实现可视化
        screen = self.screen
        # 颜色定义
        WHITE= (255,255,255)
        GREEN = (0,255,0)
        RED = (255,0,0)

        # 清空屏幕
        screen.fill(WHITE)

        # 画蛇和食物
        for pos in self.snake_body:
            pygame.draw.rect(screen,GREEN,pygame.Rect(pos[0],pos[1],self.snakeCell,self.snakeCell))
        pygame.draw.rect(screen,RED,pygame.Rect(self.food_pos[0],self.food_pos[1],self.snakeCell,self.snakeCell))

        pygame.display.update()

    # 获取奖励
    def _get_reward(self):

        # 计算奖励
        reward = 0
        done = False

        # 检查蛇是否吃到食物
        if self.snake_head:
            reward+=10

        # 检查蛇是否撞到墙壁或自身
        head=self.snake_head
        if head[0]<0 or head[0]>self.SCREEN_WIDTH-10 or head[1]<0 or head[1]>self.SCREEN_HEIGHT-10:
            reward = -10
            done = True

        return reward,done

    # 获取当前观察空间
    def _get_observation(self):
        # 获取窗口内容作为观察值
        observation = pygame.display.get_surface()
        # 将观察值调整为指定的宽度和高度
        # observation = pygame.transform.scale(observation, (self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
        return observation

def main():
    snakeEnv=Snake()
    snakeEnv.reset()
    done=False
    while not done:
        # 获取事件
        for event in pygame.event.get():
            # 处理退出事件
            if event.type == pygame.QUIT:
                pygame.quit()
                done = True
        # 从动作空间随机获取一个动作
        action= snakeEnv.action_space.sample()
        screen, reward, truncated, done,_=snakeEnv.step(action)
        snakeEnv.render()
        time.sleep(0.03)
if __name__=="__main__":
    main()

程序运行截图:
在这里插入图片描述

三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent

在上一步中,你已经创建出你要的虚拟环境了,现在让我们在这个创建好的环境中进行训练吧!

1、虚拟环境

SnakeEnv2.py

import time

import pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optional


class SnakeEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode="human"):
        super().__init__()

        # 初始化Pygame
        pygame.init()
        # 屏幕宽高
        self.SCREEN_WIDTH = 100
        self.SCREEN_HEIGHT = 100
        # 蛇的方块大小
        self.snakeCell = 10
        # 游戏速度
        self.speed = 12
        self.clock = pygame.time.Clock()
        # 创建窗口
        self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
        pygame.display.set_caption('Snake_Game')

        self.render_mode = render_mode
        self.action_space = gym.spaces.Discrete(4)  # 动作空间为4
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT),
                                                dtype=np.float32)
        # 初始化蛇和食物的位置等属性
        # ...
        self.num_timesepts = 0

    # 重启
    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        重置蛇和食物的位置
        """
        super().reset(seed=seed)
        self.curStep = 0  # 步数
        # 蛇的初始位置
        self.snake_next = [60, 50]
        self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]
        self.len = 3

        # 食物的初始位置
        self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,
                         random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]
        info = {}

        #
        return self._get_observation(), info

    # 根据当前状态 和action 执行动作
    def step(self, action):
        self.num_timesepts += 1  # 步骤统计加1
        # 定义动作到方向的映射
        directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}

        action_to_direction = {
            0: "UP",
            1: "DOWN",
            2: "LEFT",
            3: "RIGHT"
        }

        directionTarget = action_to_direction[action]
        nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置
        # 输入action 获取到 snake_next 下一步
        self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)
        if self.snake_next == self.snake_body[1]:
            return self._get_observation(), -0.5, False, False, {}

        # 奖励
        reward, terminated = self._get_reward()
        
        # 如果是吃到食物,就重新刷新果子,同时长度 +1
        if self.food_pos == self.snake_next:
            self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,
                             random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]
            self.len += 1




        truncated = False
        info = {}
        # if self.render_mode == "human":
        if self.render_mode == "human" and self.num_timesepts % 5000 > 4000 and self.num_timesepts > 10000:
            self.render()
            for event in pygame.event.get():
                if event == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
        
        if not terminated:
            self.snake_body.insert(0, self.snake_next)
            
        # 弹出
        while self.len < len(self.snake_body):
            self.snake_body.pop()
        return self._get_observation(), reward,terminated,  truncated,{}

    # 渲染
    def render(self):
        # 实现可视化
        screen = self.screen
        # 颜色定义
        WHITE = (255, 255, 255)
        GREEN = (0, 255, 0)
        RED = (255, 0, 0)

        # 清空屏幕
        screen.fill(WHITE)

        # 画蛇和食物

        snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)
        for i in range(len(self.snake_body)):
            pos = self.snake_body[len(self.snake_body) - i - 1]
            color = [int(round(component * snakecolor[i])) for component in GREEN]
            pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))
        pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))

        pygame.display.update()

        self.clock.tick(self.speed)
    def GetDic(self,p1,p2):
        return np.linalg.norm(np.array(p1) - np.array(p2))
    # 获取奖励
    def _get_reward(self):
    
        # 计算奖励
        self.curStep += 1  # 步数
        reward = 0
        terminated = False
        flag=0
        # 正向激励
        # 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励
        if self.snake_next == self.food_pos:
            reward += 500 + pow(5, self.len)
            self.curStep = 0
            #print(reward)
        # 负向激励
        # 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励
        head = self.snake_next
        if head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[
            1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:
            reward -= 100 / self.len
            terminated = True
            self.curStep = 0
        
        # 摸鱼步数超过一定值就开始负向奖励
        if self.curStep > 100 * self.len:
            reward -=  1 / self.len

      
      
        # 中向激励
        if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):
            reward += 2 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_body[0],self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.
        else:
            reward -= 1 / self.len
        #print(reward * 0.3)
        if reward<0:
            #print(reward * 0.2)
            pass
     
        #print(reward* 0.2)
        return reward * 0.2, terminated

    # 获取当前观察空间
    def _get_observation(self):
        # 返回观察空间,也就是一个二维数组
        obs = np.zeros((self.SCREEN_WIDTH, self.SCREEN_HEIGHT), dtype=np.float32)
        obs[tuple(np.transpose(self.snake_body))] = np.linspace(0.8, 0.2, len(self.snake_body), dtype=np.float32)
        obs[tuple(self.snake_body[0])] = 1.0
        obs[tuple(self.food_pos)] = -1.0
        return obs


def main():
    snakeEnv = SnakeEnv()
    snakeEnv.reset()
    done = False
    while not done:
        # 获取事件
        for event in pygame.event.get():
            # 处理退出事件
            if event.type == pygame.QUIT:
                pygame.quit()
                done = True
        # 从动作空间随机获取一个动作
        action = snakeEnv.action_space.sample()
        screen, reward, truncated, done, _ = snakeEnv.step(action)
        snakeEnv.render()


if __name__ == "__main__":
    main()

2、虚拟环境注册

打开当前项目的site-packages
在这里插入图片描述
找到gymnasium
将其SnakeEnv2.py放置如下,并在init.py中添加调用注册函数
在这里插入图片描述
到这里就注册完毕,可以进行训练了

3、训练程序

snake_train.py 具体代码如下

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):

    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        final_value = float(final_value)
        assert (initial_value > 0.0)

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler

# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnv-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-2, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)

model = PPO("MlpPolicy", env, verbose=1, device="cuda",

        n_steps=2048,
        batch_size=512,
        n_epochs=4,
        gamma=0.94,
            learning_rate=lr_schedule,
            clip_range=clip_range_schedule,
            )
# 4、训练模型:

# Set the save directory
num=1
save_dir="trained_models_mlp"
while True:
    save_dir = "trained_models_mlp_{}".format(num)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
        break
    else:
        num +=1

checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")

# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)


model.learn(
    total_timesteps=int(200000),
    callback=[checkpoint_callback]
)


# Restore stdout
sys.stdout = original_stdout

# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))

4、模型测试

对训练好的模型进行测试,可以用如下代码

import time
import random
from sb3_contrib import MaskablePPO
from stable_baselines3 import PPO
from snakecnn23 import SnakeEnv
import pygame

MODEL_PATH=r'H:\AILab\RL\Snaker2\trained_models_cnn\ppo_snake_final'
# Load the trained model
model = MaskablePPO.load(MODEL_PATH)

snakeEnv = SnakeEnv()
for i in range(10):
    obs,info=snakeEnv.reset()
    terminated = False

    while not terminated:
        # 获取事件
        for event in pygame.event.get():
            if event == pygame.QUIT:
                pygame.quit()
        # 从动作空间随机获取一个动作
        action ,_=  model.predict(obs, action_masks=snakeEnv.get_action_mask())

        prev_mask = snakeEnv.get_action_mask()
        action_value=int(action.item())

        obs, reward,  terminated, truncated, _ = snakeEnv.step(action_value)
        snakeEnv.render()

三、卷积虚拟环境

上面的是基于多层感知机,上限有限,可能效果不是很好,可以对其进行一点点改进
核心是修改
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),dtype=np.uint8)
和 _get_observation() 观察空间
改完这两个其他的基本不用变

1、卷积神经网络虚拟环境

import time

import pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optional


class SnakeEnv(gym.Env):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }

    def __init__(self, render_mode="human"):
        super().__init__()

        # 初始化Pygame
        pygame.init()
        # 屏幕宽高
        self.SCREEN_WIDTH = 84
        self.SCREEN_HEIGHT = 84
        # 蛇的方块大小
        self.snakeCell = 7
        # 游戏速度
        self.speed = 12
        self.clock = pygame.time.Clock()
        # 创建窗口
        self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))
        pygame.display.set_caption('Snake_Game')

        self.render_mode = render_mode
        self.action_space = gym.spaces.Discrete(4)  # 动作空间为4
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),
                                                dtype=np.uint8)
        # 初始化蛇和食物的位置等属性
        # ...
        self.num_timesepts = 0

    # 重启
    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        重置蛇和食物的位置
        """
        super().reset(seed=seed)
        self.curStep = 0  # 步数
        # 蛇的初始位置
        self.snake_next = [60, 50]
        self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]
        self.len = 3

        # 食物的初始位置
        self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,
                         random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]
        info = {}

        #
        return self._get_observation(), info

    # 根据当前状态 和action 执行动作
    def step(self, action):
        self.num_timesepts += 1  # 步骤统计加1
        # 定义动作到方向的映射
        directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}

        action_to_direction = {
            0: "UP",
            1: "DOWN",
            2: "LEFT",
            3: "RIGHT"
        }

        directionTarget = action_to_direction[action]
        nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置
        # 输入action 获取到 snake_next 下一步
        self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)
        if self.snake_next == self.snake_body[1]:
            return self._get_observation(), -0.5, False, False, {}

        # 奖励
        reward, terminated = self._get_reward(action)
        
        # 如果是吃到食物,就重新刷新果子,同时长度 +1
        if self.food_pos == self.snake_next:
            self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,
                             random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]
            self.len += 1




        truncated = False
        info = {}
        # if self.render_mode == "human":
        if self.render_mode == "human" and self.num_timesepts % 5000 > 4500 and self.num_timesepts > 10000:
            self.render()
            for event in pygame.event.get():
                if event == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
        
        if not terminated:
            self.snake_body.insert(0, self.snake_next)
            
        # 弹出
        while self.len < len(self.snake_body):
            self.snake_body.pop()
        return self._get_observation(), reward,terminated,  truncated,{}

    # 渲染
    def render(self):
        # 实现可视化
        screen = self.screen
        # 颜色定义
        WHITE = (255, 255, 255)
        GREEN = (0, 255, 0)
        RED = (255, 0, 0)

        # 清空屏幕
        screen.fill(WHITE)

        # 画蛇和食物

        snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)
        for i in range(len(self.snake_body)):
            pos = self.snake_body[len(self.snake_body) - i - 1]
            color = [int(round(component * snakecolor[i])) for component in GREEN]
            pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))
        pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))

        pygame.display.update()

        self.clock.tick(self.speed)
    def GetDic(self,p1,p2):
        return np.linalg.norm(np.array(p1) - np.array(p2))
    # 获取奖励
    def _get_reward(self,action):
    
        # 计算奖励
        self.curStep += 1  # 步数
        reward = 0
        terminated = False
        flag=0
        # 正向激励
        # 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励
        if self.snake_next == self.food_pos:
            reward += 400 + pow(5, self.len)
            self.curStep = 0
           # print(reward)
            #print(action)
           # print(self.snake_body,self.food_pos,self.snake_next)
        # 负向激励
        # 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励
        head = self.snake_next
        if head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[
            1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:
            reward -= 200 / self.len
            terminated = True
            self.curStep = 0
        
        # 摸鱼步数超过一定值就开始负向奖励
        if self.curStep > 250 * self.len:
            reward -=  1 / self.len

      
      
        # 中向激励
        if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):
            reward += 4 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_next,self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.
        elif self.curStep>50 and self.GetDic(self.snake_next,self.food_pos)>= self.GetDic(self.snake_body[0],self.food_pos):
            reward -= 2 / self.len * self.GetDic(self.snake_next,self.food_pos) /self.SCREEN_WIDTH
        #print(reward * 0.3)
        if reward<0:
            #print(reward * 0.2)
            pass
     
       # print(reward* 0.2)
        return reward * 0.2, terminated

    # 获取当前观察空间
    def _get_observation(self):
        obs = np.zeros((self.SCREEN_WIDTH//self.snakeCell, self.SCREEN_HEIGHT//self.snakeCell), dtype=np.uint8)

        # Set the snake body to gray with linearly decreasing intensity from head to tail.
        newsnake=np.array(self.snake_body)//7
        obs[tuple(np.transpose(newsnake))] = np.linspace(200, 50, len(newsnake), dtype=np.uint8)

        # Stack single layer into 3-channel-image.
        obs = np.stack((obs, obs, obs), axis=-1)

        # Set the snake head to green and the tail to blue
        obs[tuple(newsnake[0])] = [0, 255, 0]
        obs[tuple(newsnake[-1])] = [255, 0, 0]

        # Set the food to red
        obs[np.array(self.food_pos)//7] = [0, 0, 255]

        # Enlarge the observation to 84x84
        obs = np.repeat(np.repeat(obs, self.snakeCell, axis=0), self.snakeCell, axis=1)

        return obs
        


def main():
    snakeEnv = SnakeEnv()
    snakeEnv.reset()
    done = False
    while not done:
        # 获取事件
        for event in pygame.event.get():
            # 处理退出事件
            if event.type == pygame.QUIT:
                pygame.quit()
                done = True
        # 从动作空间随机获取一个动作
        action = snakeEnv.action_space.sample()
        screen, reward, truncated, done, _ = snakeEnv.step(action)
        snakeEnv.render()


if __name__ == "__main__":
    main()

2、训练代码

核心是修改算法名,之前用MlpPolicy,现在改为CnnPolicy
其余不变

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear scheduler

from stable_baselines3 import PPO
def linear_schedule(initial_value, final_value=0.0):

    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        final_value = float(final_value)
        assert (initial_value > 0.0)

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnvcnn-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-3, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)

model = PPO(  "CnnPolicy",
            env,
            device="cuda",
            verbose=1,
            n_steps=2048,
            batch_size=512,
            n_epochs=4,
            gamma=0.94,
            learning_rate=lr_schedule,
            clip_range=clip_range_schedule,
            tensorboard_log=LOG_DIR
            )
# 4、训练模型:

# Set the save directory
num=1
save_dir="trained_models_cnn"
while True:
    save_dir = "trained_models_cnn_{}".format(num)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
        break
    else:
        num +=1

checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")

# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)


model.learn(
    total_timesteps=int(200000),
    callback=[checkpoint_callback]
)


# Restore stdout
sys.stdout = original_stdout

# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))


# 5、测试训练好的模型:
obs = env.reset()

for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    observation, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated:
        obs = env.reset()
env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1475756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java多线程】面试常考——锁策略、synchronized的锁升级优化过程以及CAS(Compare and swap)

目录 1、锁的策略 1.1、乐观锁和悲观锁 1.2、轻量级锁和重量级锁 1.3、自旋锁和挂起等待锁 1.4、普通互斥锁和读写锁 1.5、公平锁和非公平锁 1.6、可重入锁和不可重入锁 2、synchronized 内部的升级与优化过程 2.1、锁的升级/膨胀 2.1.1、偏向锁阶段 2.1.2、轻量级锁…

如何调用GLM-4 API实现智能问答

诸神缄默不语-个人CSDN博文目录 GLM系列大模型是智谱AI提供的系列语言模型&#xff0c;GLM-4没有开源&#xff0c;只提供了API。本文介绍如何用Python语言调用GLM-4 API实现智能问答。 智谱AI为所有用户提供了18元免费额度&#xff0c;可以试用。 文章目录 1. 获得API key2. …

探索AI视频模型的无限可能:OpenAI的Sora引领创新浪潮

文章目录 &#x1f4d1;前言一、技术解析二、应用场景三、未来展望四、伦理与创意五、用户体验与互动&#x1f324;️总结 &#x1f4d1;前言 随着人工智能技术的蓬勃发展&#xff0c;AI视频模型正逐渐成为科技领域的新宠。在这个变革的浪潮中&#xff0c;OpenAI推出的首个AI视…

嵌入式C语言(四)

零长度数组 零长度数组、变长数组都是GNU C编译器支持的数组类型。 什么是零长度数组&#xff1f; 首先肯定长度是为0的数组 ANSI C规定定义一个数组长度必须为一个常数&#xff0c;那么就是这个数组的长度在编译的时候就确定了。 int a[10];但是在C99标准中规定可以定义一…

华为数通方向HCIP-DataCom H12-821题库(单选题:501-520)

第501题 三台交换机运行RSTP协议,拓扑和配置情况如图所示。那么以下关于根桥的描述,正确的是哪一项? A、根桥是SWA B、根桥是SWB C、根桥是SWC D、根桥无法确定 参考答案:A 第502题 在华为设备中,以下哪一个命令可以实现BFD与静态默认路由联动? A、ip route-static 0.…

加盟户用光伏赚钱吗?

光伏发电是一个新兴的投资领域,其中所蕴含的财富商机是非常多的,并且大多可靠效益显著,让无数创业投资者看到了生财的希望。 一、什么是户用光伏&#xff1f; 户用光伏&#xff0c;顾名思义&#xff0c;是在居民区安装的光伏发电系统。它利用太阳能电池板将光能转化为电能&…

InnoDB锁介绍

本文主要介绍MySQL InnoDB引擎中的各种锁策略和锁类别&#xff0c;并针对记录锁做演示以便于理解。 以下内容适用于MySQL 8.0版本。 读写锁 处理并发读/写访问的系统通常实现一个由两种锁类型组成的锁系统。这两种锁通常被称为共享锁(shared lock)和排他锁(exclusive lock)&…

网络编程-编码与解码(Protobuf)

编码与解码 下面的文字都来自于极客时间 为什么要编解码呢&#xff1f;因为计算机数据传输的是二进制的字节数据 解码&#xff1a;字节数据 --> 字符串&#xff08;字符数据&#xff09; 编码&#xff1a;字符串&#xff08;字符数据&#xff09;–> 字节数据 我们在编…

车载测试-常用adb命令和使用场景

app安装&#xff1a; 最常用 adb install apk地址 应用安装&#xff0c;常用于直接在系统上安装新包用于测试验证bug 常用参数&#xff08;一般直接使用这三个参数一起&#xff09; -t 允许测试包 -r 替换已存在的应用程序&#xff0c;也就是说强制安装 -d 允许进行将见状…

【.NET Core】深入理解IO之File类

【.NET Core】深入理解IO之File类 文章目录 【.NET Core】深入理解IO之File类一、概述二、File类2.1 File.AppendAllLines方法2.2 File.AppendAllText方法2.3 File.Copy 方法2.4 File.Create 方法2.5 File.Decrypt(String) 方法2.6 File.Delete(String) 方法2.7 File.Move 方法…

基于Python网络爬虫的IT招聘就业岗位数据分析可视化推荐系统

文章目录 基于Python网络爬虫的IT招聘就业岗位数据分析可视化推荐系统项目概述招聘岗位数据爬虫分析系统展示用户注册登录系统首页IT招聘数据开发岗-javaIT招聘数据开发岗-PythonIT招聘数据开发岗-AndroidIT招聘数据开发岗-其它招聘岗位数据分析算法方面运维方面测试方面招聘岗…

C语言 变量

变量其实只不过是程序可操作的存储区的名称。C 中每个变量都有特定的类型&#xff0c;类型决定了变量存储的大小和布局&#xff0c;该范围内的值都可以存储在内存中&#xff0c;运算符可应用于变量上。 变量的名称可以由字母、数字和下划线字符组成。它必须以字母或下划线开头…

Python入门到精通(九)——Python数据可视化

Python数据可视化 一、JSON数据格式 1、定义 2、python数据和JSON数据转换 二、pyecharts 三、折线图 四、地图 五、动态柱状图 一、JSON数据格式 1、定义 JSON是一种轻量级的数据交互格式。可以按照JSON指定的格式去组织和封装数据JSON本质上是一个带有特定格式的字符…

python|闲谈2048小游戏和数组的旋转及翻转和转置

目录 2048 生成数组 n阶方阵 方阵旋转 顺时针旋转 逆时针旋转 mxn矩阵 矩阵旋转 测试代码 测试结果 翻转和转置 2048 《2048》是一款比较流行​的数字游戏​&#xff0c;最早于2014年3月20日发行。原版2048由Gabriele Cirulli首先在GitHub上发布&#xff0c;后被移…

LeetCode142. 环形链表 II刷题详解

今天力扣刷到了一个特别有意思的题目&#xff0c;于是就写了下面的题解来加深以下理解。 142. 环形链表 II - 力扣&#xff08;LeetCode&#xff09; 这个可以分为两大步去写&#xff0c;首先要判断链表是否有环&#xff0c;然后如果有环就去找到环的入口&#xff0c;没有环返…

11 PLL IP核

PLL IP 核简介 锁相环&#xff08;PLL&#xff09;作为一种反馈控制电路&#xff0c;其特点是利用外部输入的参考信号来控制环路内部震荡信号的频率和相位。因为锁相环可以实现输出信号频率对输入信号频率的自动跟踪&#xff0c;所以锁相环通常用于闭环跟踪电路。锁相环在工作…

一文读懂压敏电阻原理,参数,选型

大家好&#xff0c;我是砖一。 压敏电阻并不是一般的电阻&#xff0c;而是一种具有瞬态电压抑制功能的元件&#xff0c;效果同TVS。 这篇文章介绍压敏电阻的一些基本知识&#xff0c;包括参数、选型、应用等。 一&#xff0c;基础知识 压敏电阻用MY表示&#xff0c;MY后…

东方博宜 1078. 求恰好使s=1+1/2+1/3+…+1/n的值大于X时n的值

东方博宜 1078. 求恰好使s11/21/3…1/n的值大于X时n的值 #include<iostream> using namespace std; int main() {int x ;cin >> x ;double s 0 ;int i 1 ;bool m true ;while(m){s 1.0/i ;if(s>x)break;i ;}cout << i ;return 0 ; }

【C++那些事儿】深入理解C++类与对象:从概念到实践(上)| 揭开this指针的神秘面纱

&#x1f4f7; 江池俊&#xff1a; 个人主页 &#x1f525;个人专栏&#xff1a; ✅数据结构冒险记 ✅C那些事儿 &#x1f305; 有航道的人&#xff0c;再渺小也不会迷途。 文章目录 1. 面向过程和面向对象初步认识2.类的引入3.类的定义4.类的访问限定符及封装4.1 访问限定符…

SpringBoot案例(黑马学习笔记)

这个案例呢&#xff0c;就是Tlias智能学习辅助系统。 参考接口文档完成后端功能的开 发&#xff0c;然后结合前端工程进行联调测试即可。 完成后的成品效果展示&#xff1a; 准备工作 需求&环境搭建 需求说明 部门管理 部门管理功能开发包括&#xff1a; ● 查询部门列…