基于强化学习的Deep-Qlearning网络玩cartpole游戏

news2024/9/30 3:34:37

1、环境准备,gym的版本为0.26.2

2、编写网络代码

# 导入必要的库
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random


# 定义DQN网络
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        # 定义三层全连接网络
        self.fc1 = nn.Linear(state_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# 定义DQN智能体
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # 经验回放池
        self.gamma = 0.95  # 折扣因子
        self.epsilon = 1.0  # 探索率
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_size, action_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def remember(self, state, action, reward, next_state, done):
        # 将经验存储到经验回放池中
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # ε-贪婪策略选择动作
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        act_values = self.model(state)
        return np.argmax(act_values.cpu().data.numpy())

    def replay(self, batch_size):
        # 从经验回放池中随机采样进行学习
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_state = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
                target = (reward + self.gamma * np.amax(self.model(next_state).cpu().data.numpy()))
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            target_f = self.model(state)
            target_f[0][action] = target
            self.optimizer.zero_grad()
            loss = nn.MSELoss()(self.model(state), target_f)
            loss.backward()
            self.optimizer.step()
        # 更新探索率
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        # 加载模型
        self.model.load_state_dict(torch.load(name))

    def save(self, name):
        # 保存模型
        torch.save(self.model.state_dict(), name)


# 训练函数
def train_dqn():
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    episodes = 1000
    batch_size = 32

    for e in range(episodes):
        state, _ = env.reset() #重置环境,返回初始观察值和初始奖励
        for time in range(500):
            action = agent.act(state)
            next_state, reward, done, _, _ = env.step(action) # 执行动作,返回5个数值
            reward = reward if not done else -10 # 如果游戏结束,给予负奖励
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"episode: {e}/{episodes}, score: {time}, epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        if e % 100 == 0:
            agent.save(f"cartpole-dqn-{e}.pth")  # 每100回合保存一次模型


# 使用训练好的模型玩游戏
def play_cartpole():
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    agent.load("cartpole-dqn-900.pth")  # 加载训练好的模型

    for e in range(10):  # 玩10局
        state, _ = env.reset()
        for time in range(500):
            env.render()
            action = agent.act(state)
            next_state, reward, done, _, _= env.step(action)
            state = next_state
            if done:
                print(f"episode: {e}, score: {time}")
                break
    env.close()

if __name__ == '__main__':
    # 如果要训练模型,取消下面这行的注释
    # train_dqn()
    # 如果要使用训练好的模型玩游戏,取消下面这行的注释
    play_cartpole()

更多解析请参考:https://zhuanlan.zhihu.com/p/29283993

 https://zhuanlan.zhihu.com/p/29213893

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1980125.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《深入浅出WPF》学习笔记五.Mvvm设计模式

《深入浅出WPF》学习笔记五.Mvvm设计模式 背景 在通过视频学习wpf的过程中&#xff0c;讲师花了不少篇幅来讲Mvvm。特地在此用自己的语言总结一番,方便以后面试回答&#xff0c;如有理解不对&#xff0c;欢迎指正哈。 Mvvm结构 Mvvm指的是ModelViewViewModel 为什么要使用…

《网络安全自学教程》- MySQL匿名用户的原理分析与实战研究

《网络安全自学教程》 低版本的MySQL数据库在安装时会创建一个用户名和密码为空的账户&#xff0c;也就是匿名账户。即使升级到高版本&#xff0c;匿名账户仍然会存在。 MySQL匿名账户 1、检查是否存在匿名账户2、检查用户权限3、创建匿名账户4、使用匿名账户登录5、删除匿名账…

医院管理系统

医院管理系统 本文所涉及所有资源均在传知代码平台可获取 文章目录 医院管理系统概述使用技术核心功能1. 登录与注册2. 管理员系统3. 患者系统&#xff08;医院电子平台&#xff09;4. 医生系统&#xff08;坐诊系统&#xff09; 部署与启动适用场景 概述 本项目是一个专为大学…

读零信任网络:在不可信网络中构建安全系统09用户信任

1. 用户信任 1.1. 将设备身份和用户身份混为一谈会导致一些显而易见的问题 1.1.1. 特别是当用户拥有多台设备时&#xff0c;而这种情况很普遍 1.1.2. 应该针对不同类型的设备提供相匹配的凭证 1.1.3. 在存在共用终端设备的情况下&#xff0c;所有的这些问题将更加凸显 1.2…

打造未来交互新篇章:基于AI大模型的实时交互式流媒体数字人项目

在当今数字化浪潮中,人工智能(AI)正以前所未有的速度重塑我们的交互体验。本文将深入探讨一项前沿技术——基于AI大模型的实时交互式流媒体数字人项目,该项目不仅集成了多种先进数字人模型,还融合了声音克隆、音视频同步对话、自然打断机制及全身视频拼接等前沿功能,为用…

Python中使用正则表达式

摘要&#xff1a; 正则表达式&#xff0c;又称为规则表达式&#xff0c;它不是某种编程语言所特有的&#xff0c;而是计算机科学的一个概念&#xff0c;通常被用来检索和替换某些规则的文本。 一.正则表达式的语法 ①行定位符 行定位符就是用来描述字符串的边界。"^&qu…

第十三节、人物属性及伤害计算

一、碰撞器层级剔除 选中player和敌人&#xff0c;即可去除 若勾选触发器&#xff0c;则会取消掉碰撞效果&#xff0c;物体掉落 二、人数属性受伤计算 1、创建代码 将两个代码挂载到玩家和敌人身上 2、调用碰撞物体的方法 3、伤害值 开始&#xff1a;最大血量即为当前血量…

Arduino PID库 (6):初始化

Arduino PID库 &#xff08;6&#xff09;&#xff1a;初始化 参考&#xff1a;手把手教你看懂并理解Arduino PID控制库——初始化 Arduino PID库 &#xff08;5&#xff09;&#xff1a;开启或关闭 PID 控制的影响 问题 在上一节中&#xff0c;我们实现了关闭和打开PID的功…

最小二乘法求解线性回归问题

本文章记录通过矩阵最小二乘法&#xff0c;求解二元方程组的线性回归。 假设&#xff0c;二维平面中有三个坐标&#xff08;1&#xff0c;1&#xff09;、&#xff08;2&#xff0c;2&#xff09;、&#xff08;3&#xff0c;2&#xff09;&#xff0c;很显然该三个坐标点不是…

React(三):PDF文件在线预览(简易版)

效果 依赖下载 https://mozilla.github.io/pdf.js/getting_started/ 引入依赖 源码 注意&#xff1a;pdf文件的预览地址需要配置代理后才能显示出来 import ./index.scss;function PreviewPDF() {const PDF_VIEWER_URL new URL(./libs/pdfjs-4.5.136-dist/web/viewer.html, im…

12.SpringDataRedis

介绍 SpringData是Spring中数据操作的模块&#xff0c;包含对各种数据库的集成&#xff0c;其中redis的集成模块就叫做SpringDataRedis。 spring的思想从来都不是重新生产&#xff0c;而是整合其他技术。 SpringDataRedis的特点 1.提供了对不同redis客户端的整合&#xff08…

8.4 day bug

bug1 忘记给css变量加var 复制代码到通义千问&#xff0c;解决 bug2 这不是我的bug&#xff0c;是freecodecamp的bug 题目中“ 将 --building-color2 变量的颜色更改为 #000” “ 应改为” 将 #000 变量的颜色更改为 --building-color2 “ bug3 又忘记加var(–xxx) 还去问…

渗透小游戏,各个关卡的渗透实例

Less-1 首先&#xff0c;可以看见该界面&#xff0c;该关卡主要是SQL注入&#xff0c;由于对用户的输入没有做过滤&#xff0c;使查询语句进入到了数据库中&#xff0c;查询到了本不应该查询到的数据 首先&#xff0c;如果想要进入内部&#xff0c;就要绕过&#xff0c;首先是用…

C#中的TCP和UDP

TcpClient TCP客户端 UDP客户端 tcp和udp的区别 TCP&#xff08;传输控制协议&#xff09;和UDP&#xff08;用户数据报协议&#xff09;是两种在网络通信中常用的传输层协议&#xff0c;它们在C#或任何其他编程语言中都具有相似的特性。下面是TCP和UDP的主要区别&#xff1a;…

MySQL的基本使用

文章目录 MySQL的基本使用什么是SQLSQL学习目标SQL的SELECT语句SQL的INSERT INTO语句 SQL的UPDATE语句SQL的DELETE语句 SQL的WHERE子句可在WHERE子句中使用的运算符SQL的AND和OR运算符SQL的ORDER BY子句SQL的COUNT(*)函数 在项目中操作数据库的步骤安装mysql模块配置mysql模块测…

微服务设计原则——易维护

文章目录 1.充分必要2.单一职责3.内聚解耦4.开闭原则5.统一原则6.用户重试7.最小惊讶8.避免无效请求9.入参校验10.设计模式11.禁用 flag 标识12.分页宜小不宜大参考文献 1.充分必要 不是随便一个功能都需要开发个接口。 虽然一个接口应该只专注一件事&#xff0c;但并不是每个…

摩托罗拉刷机包和固件下载地址

发现了一个非常好的摩托罗拉刷机包和固件下载地址&#xff1a;https://firmware.center/ 里面包含了所有的摩托罗拉的刷机包和软件、电路图等等&#xff0c;非常多&#xff0c;我想镜像到本地网盘&#xff0c;但不知道怎么操作&#xff0c;有没有懂得朋友教我全部镜像到国内的…

Kafka生产者(二)

1、生产者消息发送流程 1.1 发送原理 在消息发送的过程中&#xff0c;涉及到了两个线程——main 线程和 Sender 线程。在 main 线程中创建了一个双端队列 RecordAccumulator。main 线程将消息发送给 RecordAccumulator&#xff0c;Sender 线程不断从 RecordAccumulator 中拉取…

Gamma AI:一键生成专业级PPT的智能工具

1. Gamma 简介 Gamma 是一个致力于通过非常简单的ai交互&#xff0c;制作好的视觉体验作品&#xff0c;它始终站在作者的视角新增功能&#xff0c;同时注重观众视角呈现作品。 突破了以往演示文档&#xff08;ppt、pdf、网站&#xff09;表现形式&#xff0c;能够借助ai的力量…

informer中的WorkQueue机制的实现分析与源码解读(1)

背景 client-go中的workqueue包里主要有三个队列&#xff0c;分别是普通队列Queue&#xff0c;延时队列DelayingQueue&#xff0c;限速队列RateLimitingQueue&#xff0c;后一个队列以前一个队列的实现为基础&#xff0c;层层添加新功能。 workqueue是整个client-go源码的重点…