强化学习与ChatGPT:快速让AI学会玩贪食蛇游戏!

news2024/11/28 12:33:11

大家好,我是千寻哥,现在自动驾驶很火热,其实自动驾驶是一个很大的概念,主要涉及的领域包括强化学习以及计算机视觉。

今天给各位讲讲强化学习的入门知识,并且手把手和大家一起做一个强化学习的Demo。

一、 浅谈强化学习入门

说到强化学习,你可能会有一些陌生,但是说到Alpha Go的围棋对决,你可能一下子就明白了。是的,这就是强化学习的能力。

为了让大家更加直观的了解强化学习的能力以及效果,千寻自己开发了一个强化学习玩贪吃蛇的游戏!

alt

怎么样是不是十分的神奇!千寻今天和大家介绍一下,如何利用强化学习算法和ChatGPT让AI快速学会玩贪食蛇游戏。

我们将从理论基础出发,解释强化学习和深度强化学习的概念,并详细介绍使用本项目中所使用的DQN算法来训练AI玩贪食蛇的过程。

同时,我们将展示如何将ChatGPT与强化学习结合,以提供对游戏环境的实时解释和指导。

二、强化学习原理简介

强化学习是一种通过与环境交互学习最优行为策略的机器学习方法。在强化学习中,智能体通过观察环境的状态,并根据选择的动作获得奖励或惩罚来学习如何最大化累积奖励。

深度强化学习是将深度学习和强化学习相结合的方法,使用神经网络来近似值函数或策略函数,以解决高维状态空间和动作空间的问题。

在训练贪吃蛇的过程中使用的是PPO强化学习模型,以下是关于PPO算法的原理简介。

三、PPO算法训练智能体原理

定义状态表示

首先,需要定义贪吃蛇游戏的状态表示。状态可以包括蛇头位置、蛇身位置、食物位置等信息。这些信息将作为输入提供给PPO算法。

初始化PPO网络

使用神经网络作为策略函数,在PPO算法中,通常使用多层感知机(MLP)作为策略网络。策略网络的输入是状态表示,输出是在给定状态下选择每个可能动作的概率。

与环境交互

智能体与贪吃蛇环境进行交互。在每个时间步骤中,智能体观察当前状态,并根据策略网络选择一个动作。

收集经验

在与环境交互的过程中,记录智能体的状态、动作、奖励等信息,构成经验轨迹。一般通过多次游戏回合进行经验收集。

计算优势函数

根据收集到的经验,计算优势函数。优势函数用于估计每个动作相对于平均水平的优势程度,即衡量每个动作相对于当前策略的好坏程度。

更新策略网络

使用PPO算法的核心思想,对策略网络进行更新。更新的目标是最大化优势函数,同时限制更新幅度以保证策略的稳定性。

计算策略损失函数

根据收集到的经验和优势函数,计算策略损失函数。一般使用似然比优势函数作为损失函数,它包括策略网络输出动作的对数概率和优势函数的乘积。

计算策略更新

通过最小化策略损失函数来更新策略网络的参数。在PPO算法中,通常使用梯度下降方法,如Adam优化器,来最小化损失函数。

控制策略更新幅度

为了保证策略的稳定性,PPO算法使用了一种重要的技术,即通过限制策略更新幅度来防止太大的策略变化。这可以通过剪切或者概率比例等方式实现。

重复步骤3至步骤6

智能体与环境交互,收集经验,并更新策略网络。这个过程会进行多个迭代,直到达到预定的训练轮次或者策略收敛。

以下为部分训练深度强化学习的训练代码:

#!/usr/bin/python
# -*- coding: utf-8 -*-
from Agent import AgentDiscretePPO
from core import ReplayBuffer
from draw import Painter
from env4Snake import Snake
import random
import pygame
import numpy as np
import torch
import matplotlib.pyplot as plt

if __name__ == "__main__":

    #初始化超参数
    env = Snake()
    test_env = Snake()
    act_dim = 4
    obs_dim = 6
    agent = AgentDiscretePPO()
    agent.init(512, obs_dim, act_dim, if_use_gae=True)
    agent.state = env.reset()
    buffer = ReplayBuffer(2**12, obs_dim, act_dim, True)

    #设定训练迭代的轮数以及数据的批量大小
    MAX_EPISODE = 200
    batch_size = 64

    rewardList = []
    maxReward = -np.inf

    episodeList = []  # 存储训练轮数
    rewardArray = []  # 存储rewards得分

    for episode in range(MAX_EPISODE):

        # 进行强化学习模型的训练
        with torch.no_grad():
            trajectory_list = agent.explore_env(env, 2**12, 1, 0.99)

        # 反馈数据存入buffer缓存中
        buffer.extend_buffer_from_list(trajectory_list)

        # 根据缓存中的反馈数据更新网络结构
        agent.update_net(buffer, batch_size, 1, 2**-8)

        # 测试模型的代理获得贪吃蛇的得分
        ep_reward = testAgent(test_env, agent, episode)

        # 打印训练过程的信息
        print('Episode:', episode, 'Reward:%f' % ep_reward)

        rewardList.append(ep_reward)
        episodeList.append(episode)
        rewardArray.append(ep_reward)

        if episode > MAX_EPISODE / 3 and ep_reward > maxReward:
            maxReward = ep_reward
            print('保存模型!')
            torch.save(agent.act.state_dict(), 'model_weights/act_weight.pkl')

    pygame.quit()

代码的每一部分的功能,我已经在代码文件中进行了详细的注释。终端输出训练信息如下:

alt

为了进一步的对强化学习的模型训练过程,我们对训练过程的信息进行可视化。

添加如下代码:

    # 绘制训练轮数与rewards得分曲线
    plt.plot(episodeList, rewardArray, label='Actual Rewards')
    plt.plot(episodeList, fitted_rewards, label='Fitted Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.legend()

我们希望观察迭代训练的次数episode与最终强化学习的模型得分reward之间的关系,如下图所示:

alt

在曲线图像中的Actual Rewards标签为当前迭代轮数下的“实际得分数值”,Fittered Rewards标签为“拟合得分数值”。

通过以上的曲线,我们可以看出在约100轮左右,模型已经进入收敛状态,表示模型性能已经训练完成

四、ChatGPT与强化学习训练的结合

为了进一步优化强化学习的模型性能,将模型训练融入ChatGPT,ChatGPT是一种基于GPT-3.5架构的大型语言模型,具有强大的自然语言处理和生成能力。

那么ChatGPT语言生成模型与强化学习结合可以做什么呢?

必然是引入AI算法从而提供实时的游戏环境解释和指导。包括以下几点:

(1)游戏环境交互:通过ChatGPT与玩贪食蛇游戏的AI进行实时对话,AI可以向ChatGPT提问关于当前状态和最佳行动的问题。

(2)状态解释:AI可以将当前状态描述发送给ChatGPT,并从ChatGPT获得对状态的解释和建议。ChatGPT可以帮助AI理解游戏中的复杂状态和策略。

(3)行动建议:AI可以向ChatGPT询问最佳行动,并根据ChatGPT的建议选择下一步动作。ChatGPT可以基于其语言模型和先前的训练经验提供合理的建议。

(4)策略优化:AI可以根据ChatGPT提供的建议进行策略优化。AI在每个时间步骤中选择动作后,可以将结果反馈给ChatGPT,以便进行进一步的讨论和改进。

五、训练模型代理的验证

经过了ChatGPT的生成模型与PPO算法强化学习模型训练的AI玩贪吃蛇游戏,我们可以编写一个AI自动玩贪吃蛇游戏的推理代码:

定义Snake类的属性

class Snake:
    def __init__(self):
        self.snake_speed = 100 # 贪吃蛇的速度
        self.windows_width = 600
        self.windows_height = 600  # 游戏窗口的大小
        self.cell_size = 50  # 贪吃蛇身体方块大小,注意身体大小必须能被窗口长宽整除
        self.map_width = int(self.windows_width / self.cell_size)
        self.map_height = int(self.windows_height / self.cell_size)
        self.white = (255, 255, 255)
        self.black = (0, 0, 0)
        self.gray = (230, 230, 230)
        self.dark_gray = (40, 40, 40)
        self.DARKGreen = (0, 155, 0)
        self.Green = (0, 255, 0)
        self.Red = (255, 0, 0)
        self.blue = (0, 0, 255)
        self.dark_blue = (0, 0, 139)
        self.BG_COLOR = self.white  # 游戏背景颜色
        # 定义方向
        self.UP = 0
        self.DOWN = 1
        self.LEFT = 2
        self.RIGHT = 3

        self.HEAD = 0  # 贪吃蛇头部下标

        pygame.init()  # 模块初始化
        self.snake_speed_clock = pygame.time.Clock()  # 创建Pygame时钟对象

        [self.snake_coords,self.direction,self.food,self.state] = [None,None,None,None]

设置奖励条件与游戏终止条件

    # 判断蛇死了没
    def snake_is_alive(self,snake_coords):
        tag = True
        if snake_coords[self.HEAD]['x'] == -1 or snake_coords[self.HEAD]['x'] == self.map_width or snake_coords[self.HEAD]['y'] == -1 or \
                snake_coords[self.HEAD]['y'] == self.map_height:
            tag = False  # 蛇碰壁啦
        for snake_body in snake_coords[1:]:
            if snake_body['x'] == snake_coords[self.HEAD]['x'] and snake_body['y'] == snake_coords[self.HEAD]['y']:
                tag = False  # 蛇碰到自己身体啦
        return tag

    # 判断贪吃蛇是否吃到食物
    def snake_is_eat_food(self,snake_coords, food):  # 如果是列表或字典,那么函数内修改参数内容,就会影响到函数体外的对象。
        flag = False
        if snake_coords[self.HEAD]['x'] == food['x'] and snake_coords[self.HEAD]['y'] == food['y']:
            while True:
                food['x'] = random.randint(0, self.map_width - 1)
                food['y'] = random.randint(0, self.map_height - 1)  # 实物位置重新设置
                tag = 0
                for coord in snake_coords:
                    if [coord['x'],coord['y']] == [food['x'],food['y']]:
                        tag = 1
                        break
                if tag == 1: continue
                break
            flag = True
        else:
            del snake_coords[-1]  # 如果没有吃到实物, 就向前移动, 那么尾部一格删掉
        return flag

自动玩贪吃蛇游戏部署,设置贪吃蛇的复活命数为10次

if __name__ == "__main__":
    random.seed(100)
    env = Snake()
    env.snake_speed = 10
    agent = AgentDiscretePPO()
    agent.init(512,6,4)

    # 加载强化学习的训练模型
    agent.act.load_state_dict(torch.load('model_weights/act_weight.pkl'))

    # 设置贪吃蛇复活次数
    lifes = 10
    for _ in range(lifes):
        o = env.reset()
        while 1:
            env.render()
            for event in pygame.event.get():
                pass
            a,_ = agent.select_action(o)
            o2,r,d,_ = env.step(a)
            o = o2
            if d: break

最终在PyCharm环境中贪吃蛇的运行效果如图:

alt

实际的动图检验效果:

alt

通过ChatGPT指导的AI自动玩贪吃蛇,怎么样,效果不错吧,是不是很想尝试一下,没事,马上安排上!以下github地址链接,本地拷贝一下,就可以玩耍了!

Github:https://github.com/qianyuqianxun-DeepLearning/SnakeGameAI.git

我是千与千寻,一个只讲干货的码农,我们下期见~

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/576823.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SSM的酒店客房管理系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 酒店管理系统是一款高…

Java 集合 - 集合框架概述

文章目录 1.集合框架体系结构2.Collection 接口2.1 Iterator2.1.1 使用迭代器遍历集合2.1.2 使用迭代器删除集合元素2.1.3 Iterator 迭代器的 fail-fast 机制 2.2 Iterable2.3 List 集合2.4 Set 集合2.5 Queue 3.Map 集合 Java 集合框架(Java Collections Framework…

Java 集合 - Set 接口

文章目录 1.概述2.HashSet3.LinkedHashSet4.TreeSet5.选择合适的 Set 实现6.总结 1.概述 Set 接口的定义非常简单。它本质上是一个 Collection,但是要求该集合不能有重复的元素。换句话说,如果尝试将一个元素添加到 Set 中,而该元素已经存在…

FPGA实现ESP8266驱动且进行数据包收发

一. 简介 本次将使用正点原子的ESP8266 WIFI模块,来实现PC与FPGA之间的TCP通讯,其中ESP8266与FPGA之间的接口是UART。 二. 正点原子的ESP8266 WIFI模块介绍 模块实物图如下,到手就可以使用了,RST和IO_0两个IO口不接或者接高电平…

C++布隆过滤器和哈西切分

文章目录 一、布隆过滤器的提出二、布隆过滤器的概念三、布隆过滤器的实现布隆过滤器的插入布隆过滤器的判断在不在布隆过滤器的删除布隆过滤器的优点布隆过滤器的缺点 四、布隆过滤器的应用场景五、布隆过滤器的扩展[面试题]六、哈西切分 一、布隆过滤器的提出 我们在使用新闻…

GO 语言核心编程-全文版

第 1 章 1.1Golang的学习方向 Go语言,我们可以简单的写成Golang. Golang开山篇 1.2Golang的应用领域 1.2.1区块链的应用开发 1.2.2后台的服务应用 1.2.3云计算/云服务后台应用 1.3学习方法的介绍 1.4讲课的方式的说明 努力做到通俗易懂注重Go语言体系&#xff…

K8s之零故障升级Pod健康探测详解

文章目录 一、Pod健康探测介绍1、三种容器探测方法2、常用三种探测探针3、探针相关属性说明 二、探测案例1、Pod启动探测案例-startupProbe2、Pod存活探测案例-livenessProbe3、Pod就绪探测案例-readinessProbe4、启动、存活、就绪探测混合使用案例 三、总结 一、Pod健康探测介…

【MySQL新手到通关】第五章 多表查询

文章目录 1. 笛卡尔积1.1 避免笛卡尔积1.2 笛卡尔积(或交叉连接)的理解1.3 案例分析与问题解决笛卡尔积的错误会在下面条件下产生: 2. 多表查询分类讲解2.1 多表联查分类方式1:2.2 多表联查分类方式2:2.3 多表联查分类…

Eclipse教程 Ⅴ

Eclipse 创建 Java 类 打开新建 Java 类向导 你可以使用新建 Java 类向导来创建 Java 类,可以通过以下途径打开 Java 类向导: 点击 "File" 菜单并选择 New > Class在 Package Explorer 窗口中右击鼠标并选择 New > Class点击类的下拉…

c++输入输出文件操作stream

系列文章目录 C IO库 文章目录 系列文章目录前言一、文件IO概述coutcin其他istream类方法 文件输入和输出内核格式化总结 前言 一、文件IO 概述 c程序把输入和输出看作字节流。输入时,程序从输入流中抽取字节:输出时,程序将字节流插入到输…

springboot+ssm+java校园二手物品交易系统vxkyj

样需要经过市场调研,需求分析,概要设计,详细设计,编码,测试这些步骤,基于Java语言、Jsp技术设计并实现了校园二手物品交易系统。系统主要包括个人中心、商家管理、用户管理、商品分类管理、商品信息管理、商…

中间件SOME/IP简述

SOME/IP SOME/IP 不是广义上的中间件,严格的来讲它是一种通信协议,但中间件这个概念太模糊了,所以我们也一般称 SOME/IP 为通信中间件。 SOME/IP 全称是 Scalable service-Oriented MiddlewarE over IP。也就是基于 IP 协议的面向服务的可扩…

调用华为API实现身份证识别

调用华为API实现身份证识别 1、作者介绍2、调用华为API实现身份证识别2.1 算法介绍2.1.1OCR简介2.1.2身份证识别原理2.1.3身份证识别应用场景 2.2 调用华为API流程 3、代码实现3.1安装相关的包3.2代码复现3.3实验结果 1、作者介绍 雷千龙,男,西安工程大…

Spring Boot如何实现配置文件的自动加载和刷新?

Spring Boot如何实现配置文件的自动加载和刷新? 在使用Spring Boot开发应用程序时,配置文件是非常重要的组成部分。在不同的环境中,我们可能需要使用不同的配置文件,例如在开发、测试和生产环境中使用不同的配置文件。而且&#…

功能测试转到自动化测试,我的测试之路“狂飙”~20k...

前言 Python自动化测试:Python自动化测试,7天练完这60个实战项目,年薪过35w。 手动功能测试人员应该权衡测试自动化相对于功能测试的好处,并且即可开始行动。现在随着测试行业的发展,自动化测试已经成为每个测试人的标…

nodejs+vue大学生招聘网站应聘系统设计与实现5b14b

目前,伴随着Internet技术的日益成熟,互联网需要提供更多的服务,发达国家已形成以信息技术为核心,招聘网站支撑的现代化招聘公司技术格局。这便是今天为大家所熟悉的管理信息系统,网络发展为招聘网站实现信息化、自动化、智能化和集…

牛客小白月赛73DE

问题很好转化,但是对区间的处理没把握好,一直在纠结怎么o(n) 一开始想到二分了,但是没细想,结果看了讲解发现,其实就是一个前缀数组上对区间的查询的操作,以后再遇到此类问题直接向…

Git提交提交代码报错 Push failed unable to access

目录 场景 环境: Git配置 场景 Push failed unable to access https://github.com/1790753131/remotRepository3.git/: Failed to connect to github.com port 443 after 21114 ms: Couldnt connect to server Push failed unable to ac…

计算节点与存储设备是如何连接的?

本文是《数据存储通识课》合集的一部分,本合集希望通过一系列文章科普数据存储相关技术内容。同时,本系列文章不仅仅是科普,还会进行有深度解析,理论结合实现,从代码实现层面进行剖析。欢迎关注“数据存储张”,老张是深耕存储十几载,就业于存储No1公司的资深工程师。 无…

Keil 5 MDK 发律师函警告了,如何用STCubeIDE开发标准库的程序(STM32F103C8T6为例)

用STCubeIDE进行标准库开发 1、CubeIDE介绍 https://www.stmcu.com.cn/ecosystem/Cube/STM32CubeIDE 2、CubeIDE下载 点击上面的链接,登录即可下载 3、搭建Demo工程 新建一个工作空间 创建一个工程 选择芯片-STM32F103C8T6 填写工程信息 添加标准库到工程 标…