深度Q网络(DQN)算法技术博客

news2024/10/7 9:16:39

深度Q网络(DQN)是一种将深度学习与强化学习相结合的算法,用于解决高维状态空间的强化学习问题。本文将详细介绍DQN算法的基本原理,关键公式以及具体的代码实现。

一、DQN算法的基本原理

DQN算法是Q学习的一种扩展,利用神经网络来逼近Q值函数。其核心思想是通过不断地与环境交互,从而学习到一个策略,使得在每个状态下的累积奖励最大化。Q值函数的定义如下:

Q(s, a) = \mathbb{E}[r_t + \gamma \max_{a'} Q(s_{t+1}, a') | s_t = s, a_t = a]

其中:

  • s 是状态
  • a 是动作
  • r 是奖励
  • \gammaγ是折扣因子(0 <= \gamma < 1)

DQN通过使用两个神经网络来稳定训练过程:

  1. 策略网络(Policy Network):用来生成动作的Q值。
  2. 目标网络(Target Network):用来生成目标Q值,其参数定期从策略网络复制。

二、DQN算法的关键步骤

  1. 经验回放(Experience Replay):为了打破数据之间的相关性,DQN使用了经验回放技术,将经验存储在一个记忆库中,并从中随机采样一批用于训练。

  2. 目标Q值的计算y_i = r_i + \gamma \max_{a'} Q'(s_{i+1}, a'; \theta^{-})其中 \theta^{-} 是目标网络的参数, \theta 是策略网络的参数。

  3. 损失函数的定义L(\theta) = \mathbb{E}_{(s, a, r, s') \sim D} [(y_i - Q(s, a; \theta))^2]通过最小化上述损失函数,来更新策略网络的参数。

三、DQN算法的代码实现

以下是一个简单的DQN算法在OpenAI Gym的CartPole环境中的实现。

import gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, action_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# DQN算法类
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95  # 折扣因子
        self.epsilon = 1.0  # 探索率
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = QNetwork(state_size, action_size)
        self.target_model = QNetwork(state_size, action_size)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.update_target_model()
    
    def update_target_model(self):
        self.target_model.load_state_dict(self.model.state_dict())
    
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
    
    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state)
        act_values = self.model(state)
        return np.argmax(act_values.detach().numpy())
    
    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = self.model(torch.FloatTensor(state)).detach().numpy()
            if done:
                target[action] = reward
            else:
                t = self.target_model(torch.FloatTensor(next_state)).detach().numpy()
                target[action] = reward + self.gamma * np.amax(t)
            target_f = self.model(torch.FloatTensor(state))
            target_f[action] = torch.FloatTensor([target[action]])
            self.model.zero_grad()
            loss = nn.MSELoss()(target_f, torch.FloatTensor(target))
            loss.backward()
            self.optimizer.step()
        
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

# 训练DQN模型
if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    episodes = 1000
    batch_size = 32
    
    for e in range(episodes):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        for time in range(500):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                agent.update_target_model()
                print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

四、总结

DQN算法通过结合Q学习与深度神经网络,解决了高维状态空间下的强化学习问题。本文详细介绍了DQN的基本原理、关键步骤和具体的代码实现,希望能够帮助读者更好地理解和应用这一算法。如果在实际应用中遇到问题,可以参考相关文献和开源代码库,进一步优化和改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1893330.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

昇思25天学习打卡营第7天|保存与加载

Python语言 AI框架&#xff1a;Mindspore 1.模型构建 class Network(nn.Cell):def __init__(self):super().__init__()self.flatten nn.Flatten()self.dense_relu_sequential nn.SequentialCell(nn.Dense(28*28, 512, weight_init"normal", bias_init"zeros…

确认下单:购物车页面点击 去结算 按钮发起两个请求trade(显示购物车的商品信息和计算商品的总金额)findUserAddressList

文章目录 1、确认下单&#xff1a;购物车页面点击去结算1.1、在OrderController类中创建 trade 方法1.2、在CartController类中创建 checkedCartInfos1.3、CartServiceImpl 实现 checkedCartInfos的业务功能1.4、在service-cart-client模块下定义远程openFeign接口1.5、在SpzxO…

vue-org-tree搜索到对应项高亮展开

效果图&#xff1a; 代码&#xff1a; <template><div class"AllTree"><el-form :inline"true" :model"formInline" class"demo-form-inline"><el-form-item><el-input v-model"formInline.user&quo…

2024年【天津市安全员B证】最新解析及天津市安全员B证实操考试视频

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年【天津市安全员B证】最新解析及天津市安全员B证实操考试视频&#xff0c;包含天津市安全员B证最新解析答案和解析及天津市安全员B证实操考试视频练习。安全生产模拟考试一点通结合国家天津市安全员B证考试最新大…

C++必修:深入理解继承与虚继承

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C学习 贝蒂的主页&#xff1a;Betty’s blog 1. 继承的概念与定义 1.1. 继承的概念 继承(inheritance)机制是面向对象程序设计…

【12321骚扰电话举报受理中心-短信验证安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…

【HTML入门】第一课 - 网页标签框架

这一节&#xff0c;我们说一下学习前端开发的话&#xff0c;最入门的也是非常重要的一门可成&#xff0c;也就是HTML。HTML标签&#xff0c;是网页的重要组成部分&#xff0c;可以说&#xff0c;你看到网页上的内容&#xff0c;都是基于HTML标签呈现出来的。 这一小节呢&#…

Windows系统安装SSH服务结合内网穿透配置公网地址远程ssh连接

前言 在当今的数字化转型时代&#xff0c;远程连接和管理计算机已成为日常工作中不可或缺的一部分。对于 Windows 用户而言&#xff0c;SSH&#xff08;Secure Shell&#xff09;协议提供了一种安全、高效的远程访问和命令执行方式。SSH 不仅提供了加密的通信通道&#xff0c;…

SalesForce集成案例-获取联系人信息

SalesForce本身比较复杂&#xff0c;涉及的东西比较多&#xff0c;下面以使用REST API接口为例&#xff0c;介绍与SalesForce集成的过程&#xff0c;集成案例&#xff1a;获取联系人信息。 首先需要注册一个免费的开发者帐号&#xff0c;具有完全操作SalesForce的权限。 1、注…

【vmbox centos7 网络配置】【centos7 glances 安装】【centos7 安装MySQL5.7】

文章目录 vmbox centos7 网络配置centos7 修改镜像地址centos7 安装 glancesCentOS 7 上安装 MySQL 5.7 并进行基本的安全配置使用 firewalld 开放 3306 端口 可以远程连接mysql vmbox centos7 网络配置 目前 能组建集群 虚拟机网络互通&#xff0c;虚拟机能访问外网 创建一个…

OpenCV库Windows端编译方法

编译前提 &#xff08;1&#xff09;下载好所需版本的OpenCV源码&#xff0c;点击进入下载地址&#xff0c;此处以OpenCV-2.4.13.6为例&#xff0c;下载页面截图如下图所示&#xff1a; 解压后如下图所示&#xff1a; &#xff08;2&#xff09;安装好CMake软件&#xff0c;点…

数据结构(一)C语言补

数据结构 内存空间划分 一个进程启动后&#xff0c;会生成4G的内存空间 0~3G是用户空间(应用层) 3~4G是内核空间(底层) 0~3G 3~4G 所有的进程都会共享3G~4G的内核空间&#xff0c; 但是每个进程会独立拥有0~3G的用户空间。 栈区 存放数据特点 栈区存放数据的申请空间的先后…

Ecology Letters | 植物多样性-生产力关系的正反馈机制:基于BEF-China的7年大规模实验数据

本文首发于“生态学者”微信公众号&#xff01; 森林提供了丰富的生态系统功能和服务&#xff0c;如生物质生产、碳固存、气候调节、水过滤和防止土壤侵蚀。森林生物多样性丧失的空前速度可能会严重损害世界森林提供基本生态系统功能和服务的能力。因此&#xff0c;了解物种丧失…

【Linux进阶】Linux目录配置,FHS

在了解了每个文件的相关种类与属性&#xff0c;以及了解了如何修改文件属性与权限的相关信息后&#xff0c;再来要了解的就是&#xff0c;为什么每个Linux发行版它们的配置文件、执行文件、每个目录内放置的东西&#xff0c;其实都差不多&#xff1f;原来是有一套标准依据&…

CentralCache中心缓存

目录 一.CentralCache基本结构 1.CentralCache任务 2.基本结构 二.函数调用层次结构/.h文件 三.Span和SpanList的封装 Span:大块内存跨度 PAGE_ID _pageId size_t _objSize _useCount SpanList:管理Span的双链表(桶锁) 四.获取大块内存GetOneSpan 五.FetchRangeObj输…

C语言作业笔记

1. 要找俩个数使其相加等于一个数&#xff0c;那么俩个数从头尾出发&#xff0c;先动一边&#xff0c;假设是尾先动&#xff0c;一开始俩个数相加大于sum&#xff08;小于的话就动头&#xff09;&#xff0c;那么总有一时刻俩数相加小于sum&#xff0c;则就在那一刻停下来&…

MySQL高可用(MHA高可用)

什么是 MHA MHA&#xff08;MasterHigh Availability&#xff09;是一套优秀的MySQL高可用环境下故障切换和主从复制的软件。 MHA 的出现就是解决MySQL 单点的问题。 MySQL故障切换过程中&#xff0c;MHA能做到0-30秒内自动完成故障切换操作。 MHA能在故障切换的过程中最大…

机器学习与模式识别_清华大学出版社

contents 前言第1章 绪论1.1 引言1.2 基本术语1.3 假设空间1.4 归纳偏好1.5 发展历程1.6 应用现状 第2章 模型评估与选择2.1 经验误差与过拟合2.2 评估方法2.3 性能度量2.3.1 回归任务2.3.2 分类任务 2.4 比较检验2.5 偏差与方差2.5.1 偏差-方差分解2.5.2 偏差-方差窘境 第3章 …

In Ictu Oculi: Exposing AI Created Fake Videos by Detecting Eye Blinking

文章目录 In Ictu Oculi: Exposing AI Created Fake Videos by Detecting Eye Blinking背景关键点内容预处理Long-Term Recurrent CNNsLSTM-RNN模型训练实验data启示In Ictu Oculi: Exposing AI Created Fake Videos by Detecting Eye Blinking 会议:2018 IEEE International…

用Vue3和Rough.js绘制一个交互式3D图

本文由ScriptEcho平台提供技术支持 项目地址&#xff1a;传送门 基于Rough.js和GSAP创建交互式SVG图形卡片 应用场景 本代码适用于需要创建动态交互式SVG图形卡片的场景&#xff0c;例如网页设计、数据可视化和交互式艺术作品。 基本功能 该代码利用Rough.js和GSAP库&…