PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)

news2025/3/17 23:44:32

在之前的文章中,我们介绍了神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、Transformer 等多种深度学习模型,并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念,并使用 PyTorch 实现一个经典的深度 Q 网络(DQN)来解决强化学习中的经典问题——CartPole。

一、强化学习基础

强化学习(Reinforcement Learning, RL)是机器学习的一个重要分支,它通过智能体(Agent)与环境(Environment)的交互来学习策略,以最大化累积奖励。强化学习的核心思想是通过试错来学习,智能体在环境中采取行动,观察结果,并根据奖励信号调整策略。

1. 强化学习的基本要素

  • 智能体(Agent):学习并做出决策的主体。

  • 环境(Environment):智能体交互的外部世界。

  • 状态(State):环境在某一时刻的描述。

  • 动作(Action):智能体在某一状态下采取的行动。

  • 奖励(Reward):智能体采取动作后,环境返回的反馈信号。

  • 策略(Policy):智能体在给定状态下选择动作的规则。

  • 价值函数(Value Function):评估在某一状态下采取某一动作的长期回报。

2. Q-Learning 与深度 Q 网络(DQN)

Q-Learning 是一种经典的强化学习算法,它通过学习一个 Q 函数来评估在某一状态下采取某一动作的长期回报。Q 函数的更新公式为:

深度 Q 网络(DQN)将 Q-Learning 与深度学习结合,使用神经网络来近似 Q 函数。DQN 通过经验回放(Experience Replay)和目标网络(Target Network)来稳定训练过程。

二、CartPole 问题实战

CartPole 是强化学习中的经典问题,目标是控制一个小车(Cart)使其上的杆子(Pole)保持直立。我们将使用 PyTorch 实现一个 DQN 来解决这个问题。

1. 问题描述

CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。

2. 实现步骤

  1. 安装并导入必要的库。

  2. 定义 DQN 模型。

  3. 定义经验回放缓冲区。

  4. 定义 DQN 训练过程。

  5. 测试模型并评估性能。

3. 代码实现

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
​
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
​
# 1. 安装并导入必要的库
env = gym.make('CartPole-v1')
​
# 2. 定义 DQN 模型
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)
​
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
​
# 3. 定义经验回放缓冲区
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
​
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
​
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)
​
    def __len__(self):
        return len(self.buffer)
​
# 4. 定义 DQN 训练过程
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
model = DQN(state_size, action_size)
target_model = DQN(state_size, action_size)
target_model.load_state_dict(model.state_dict())
optimizer = optim.Adam(model.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)
​
def train(batch_size, gamma=0.99):
    if len(buffer) < batch_size:
        return
    state, action, reward, next_state, done = buffer.sample(batch_size)
    state = torch.FloatTensor(state)
    next_state = torch.FloatTensor(next_state)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)
​
    q_values = model(state)
    next_q_values = target_model(next_state)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)
​
    loss = nn.MSELoss()(q_value, expected_q_value.detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
​
# 5. 测试模型并评估性能
def test(env, model, episodes=10):
    total_reward = 0
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            state = torch.FloatTensor(state).unsqueeze(0)
            action = model(state).max(1)[1].item()
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            state = next_state
    return total_reward / episodes
​
# 训练过程
episodes = 500
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
rewards = []
​
for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0
​
    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = model(state_tensor).max(1)[1].item()
​
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
​
        train(batch_size, gamma)
​
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    rewards.append(total_reward)
​
    if (episode + 1) % 50 == 0:
        avg_reward = test(env, model)
        print(f"Episode: {episode + 1}, Avg Reward: {avg_reward:.2f}")
​
# 6. 可视化训练结果
plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("DQN 训练过程")
plt.show()

三、代码解析

  1. 环境与模型定义

    • 使用 gym 创建 CartPole 环境。

    • 定义 DQN 模型,包含三个全连接层。

  2. 经验回放缓冲区

    • 使用 deque 实现经验回放缓冲区,存储状态、动作、奖励等信息。

  3. 训练过程

    • 使用 epsilon-greedy 策略进行探索与利用。

    • 通过经验回放缓冲区采样数据进行训练,更新模型参数。

  4. 测试过程

    • 在测试环境中评估模型性能,计算平均奖励。

  5. 可视化

    • 绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。

  • 训练结束后,绘制训练过程中的总奖励曲线。

五、总结

本文介绍了强化学习的基本概念,并使用 PyTorch 实现了一个深度 Q 网络(DQN)来解决 CartPole 问题。通过这个例子,我们学习了如何定义 DQN 模型、使用经验回放缓冲区、训练模型以及评估性能。

在下一篇文章中,我们将探讨更复杂的强化学习算法,如 Actor-Critic 和 Proximal Policy Optimization (PPO)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。

  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:model = model.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解强化学习的基础知识!如果有任何问题,欢迎在评论区留言讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2316873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python学习第十九天

Django-分页 后端分页 Django提供了Paginator类来实现后端分页。Paginator类可以将一个查询集&#xff08;QuerySet&#xff09;分成多个页面&#xff0c;每个页面包含指定数量的对象。 from django.shortcuts import render, redirect, get_object_or_404 from .models impo…

Windows环境下安装部署dzzoffice+onlyoffice的私有网盘和在线协同系统

安装前需要准备好Docker Desktop环境&#xff0c;可查看我的另一份亲测安装文章 https://blog.csdn.net/qq_43003203/article/details/146283915?spm1001.2014.3001.5501 1、安装配置onlyoffice 1、Docker 拉取onlyoffice容器镜像 管理员身份运行Windows PowerShell&#x…

ChatPromptTemplate的使用

ChatPromptTemplate 是 LangChain 中专门用于管理多角色对话结构的提示词模板工具。它的核心价值在于&#xff0c;开发者可以预先定义不同类型的对话角色消息&#xff08;如系统指令、用户提问、AI历史回复&#xff09;&#xff0c;并通过数据绑定动态生成完整对话上下文。 1.…

Blender插件NodeWrangler导入贴图报错解决方法

Blender用NodeWrangler插件 CtrlShiftT 导入贴图 直接报错 解决方法: 用CtrlshiftT打开需要导入的材质文件夹时&#xff0c;右边有一个默认勾选的相对路径&#xff0c;取消勾选就可以了。 开启node wrangler插件&#xff0c;然后在导入贴图是取消勾选"相对路径"&am…

java项目之基于ssm的药店药品信息管理系统(源码+文档)

项目简介 药店药品信息管理系统实现了以下功能&#xff1a; 个人信息管理 负责管理个人用户的信息。 员工管理 负责管理药店或药品管理机构的员工信息。 药品管理 负责管理药品的详细信息&#xff0c;可能包括药品名称、成分、剂量、价格、库存等。 进货管理 负责管理药品…

论文分享 | HE-Nav: 一种适用于复杂环境中空地机器人的高性能高效导航系统

阿木实验室始终致力于通过开源项目和智能无人机产品&#xff0c;为全球无人机开发者提供强有力的技术支持&#xff0c;并推出了开源项目校园赞助活动&#xff0c;助力高校学子在学术研究与技术创新中取得更大突破。近日&#xff0c;香港大学王俊铭同学&#xff0c;基于阿木实验…

ubuntu 24 安装 python3.x 教程

目录 注意事项 一、安装不同 Python 版本 1. 安装依赖 2. 下载 Python 源码 3. 解压并编译安装 二、管理多个 Python 版本 1. 查看已安装的 Python 版本 2. 配置环境变量 3. 使用 update-alternatives​ 管理 Python 版本 三、使用虚拟环境为项目指定特定 Python 版本…

【sql靶场】第13、14、17关-post提交报错注入保姆级教程

目录 【sql靶场】第13、14、17关-post提交报错注入保姆级教程 1.知识回顾 1.报错注入深解 2.报错注入格式 3.使用的函数 4.URL 5.核心组成部分 6.数据编码规范 7.请求方法 2.第十三关 1.测试闭合 2.列数测试 3.测试回显 4.爆出数据库名 5.爆出表名 6.爆出字段 …

93.HarmonyOS NEXT窗口管理基础教程:深入理解WindowSizeManager

温馨提示&#xff1a;本篇博客的详细代码已发布到 git : https://gitcode.com/nutpi/HarmonyosNext 可以下载运行哦&#xff01; HarmonyOS NEXT窗口管理基础教程&#xff1a;深入理解WindowSizeManager 文章目录 HarmonyOS NEXT窗口管理基础教程&#xff1a;深入理解WindowSiz…

Python----数据分析(Pandas一:pandas库介绍,pandas操作文件读取和保存)

一、Pandas库 1.1、概念 Pandas是一个开源的、用于数据处理和分析的Python库&#xff0c;特别适合处理表格类数 据。它建立在NumPy数组之上&#xff0c;提供了高效的数据结构和数据分析工具&#xff0c;使得数据操作变得更加简单、便捷和高效。 Pandas 的目标是成为 Python 数据…

基于WebRTC技术的EasyRTC嵌入式音视频SDK:多平台兼容与性能优化

在当今数字化、智能化的时代背景下&#xff0c;实时音视频通信技术已成为众多领域不可或缺的关键技术。基于WebRTC技术的EasyRTC嵌入式音视频SDK&#xff0c;凭借其在ARM、Linux、Windows、安卓、iOS等多平台上的兼容性&#xff0c;为开发者提供了强大的工具&#xff0c;推动了…

【快速入门】MyBatis

一.基础操作 1.准备工作 1&#xff09;引入依赖 一个是mysql驱动包&#xff0c;一个是mybatis的依赖包&#xff1a; <dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><vers…

交互式可视化进阶(Plotly Dash构建疫情仪表盘)

这里写目录标题 交互式可视化进阶(Plotly Dash构建疫情仪表盘)1. 引言2. 项目背景与意义3. 数据集生成与介绍4. GPU加速在数据处理中的应用5. 交互式仪表盘构建与Plotly Dash6. PyQt GUI集成与美化7. 工程整体架构8. 部分代码实现9. 代码自查与BUG排查10. 总结与展望交互式可…

如何选择适合您智能家居解决方案的通信协议?

如何选择适合您智能家居解决方案的通信协议&#xff1f; 在开发智能家居产品时&#xff0c;选择合适的通信协议对于设备的高效运行及其在智能家居系统中的互操作性至关重要。市面上协议众多&#xff0c;了解它们的特性并在做决定前考虑各种因素是非常必要的。以下是一些帮助您…

RabbitMQ可靠性进制

文章目录 1.生产者可靠性生产者重连生产者确认小结 2. MQ的可靠性数据持久化LazyQueue小结 3. 消费者的可靠性消费者确认机制消费者失败处理方案业务幂等性唯一消息ID业务判断 兜底方案业务判断 兜底方案 1.生产者可靠性 生产者重连 在某些场景下由于网络波动&#xff0c;可能…

版本控制器Git(5)

文章目录 前言一、理解标签二、创建标签三、操作标签四、多人协作场景一五、多人协作场景二总结 前言 本篇是最后一篇&#xff0c;主要介绍标签管理有关的内容 一、理解标签 标签定义&#xff1a;在Git中&#xff0c;标签&#xff08;tag&#xff09;是对某次提交&#xff08;c…

【数据分析】读取文件

3. 读取指定列 针对只需要读取数据中的某一列或多列的情况&#xff0c;pd.read_csv()函数提供了一个参数&#xff1a;usecols&#xff0c;将包含对应的columns的列表传入该参数即可。 上面&#xff0c;我们学习了读取 "payment" 和 "items_count" 这…

Dify使用部署与应用实践

最近在研究AI Agent&#xff0c;发现大家都在用Dify&#xff0c;但Dify部署起来总是面临各种问题&#xff0c;而且我在部署和应用测试过程中也都遇到了&#xff0c;因此记录如下&#xff0c;供大家参考。Dify总体来说比较灵活&#xff0c;扩展性比较强&#xff0c;适合基于它做…

Java 大视界 -- 基于 Java 的大数据机器学习模型的迁移学习应用与实践(129)

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…

1.Windows+vscode+cline+MCP配置

文章目录 1.简介与资源2.在windows中安装vscode及Cline插件1. 安装vscode2. 安装Cline插件3. 配置大语言模型3. 配置MCP步骤(windows) 1.简介与资源 MCP官方开源仓库 MCP合集网站 参考视频 2.在windows中安装vscode及Cline插件 1. 安装vscode 2. 安装Cline插件 Cline插件…