Pytorch深度强化学习2-1:基于价值的强化学习——DQN算法

news2025/3/17 11:31:08

目录

  • 0 专栏介绍
  • 1 基于价值的强化学习
  • 2 深度Q网络与Q-learning
  • 3 DQN原理分析
  • 4 DQN训练实例

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 基于价值的强化学习

根据不动点定理,最优策略和最优价值函数是唯一的(对该经典理论不熟悉的请看Pytorch深度强化学习1-4:策略改进定理与贝尔曼最优方程详细推导),通过优化价值函数间接计算最优策略的方法称为基于价值的强化学习(value-based)框架。设状态空间为 n n n维欧式空间 S = R n S=\mathbb{R} ^n S=Rn,每个维度代表状态的一个特征。此时状态-动作值函数记为

Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)

其中 s \boldsymbol{s} s是状态向量, a \boldsymbol{a} a是动作空间中的动作向量, θ \boldsymbol{\theta } θ是神经网络的参数向量。深度学习完成了从输入状态到输出状态-动作价值的映射

s → Q ( s , a ; θ ) [ Q ( s , a 1 ) Q ( s , a 2 ) ⋯ Q ( s , a m ) ] T    ( a 1 , a 2 , ⋯   , a m ∈ A ) \boldsymbol{s}\xrightarrow{Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}\left[ \begin{matrix} Q\left( \boldsymbol{s},a_1 \right)& Q\left( \boldsymbol{s},a_2 \right)& \cdots& Q\left( \boldsymbol{s},a_m \right)\\\end{matrix} \right] ^T\,\, \left( a_1,a_2,\cdots ,a_m\in A \right) sQ(s,a;θ) [Q(s,a1)Q(s,a2)Q(s,am)]T(a1,a2,,amA)

相当于对无穷维Q-Table的一次隐式查表,对经典Q-learing算法不熟悉的请看Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)、Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫。设目标价值函数为 Q ∗ Q^* Q,若采用最小二乘误差,可得损失函数为

J ( θ ) = E [ 1 2 ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) 2 ] J\left( \boldsymbol{\theta } \right) =\mathbb{E} \left[ \frac{1}{2}\left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) ^2 \right] J(θ)=E[21(Q(s,a)Q(s,a;θ))2]

采用梯度下降得到参数更新公式为

θ ← θ + α ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) ∂ Q ( s , a ; θ ) ∂ θ \boldsymbol{\theta }\gets \boldsymbol{\theta }+\alpha \left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) \frac{\partial Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}{\partial \boldsymbol{\theta }} θθ+α(Q(s,a)Q(s,a;θ))θQ(s,a;θ)

随着迭代进行, Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)将不断逼近 Q ∗ Q^* Q,由 Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)进行的策略评估和策略改进也将迭代至最优。

2 深度Q网络与Q-learning

Q-learning和深度Q学习(Deep Q-learning, DQN)是强化学习领域中两种重要的算法,它们在解决智能体与环境之间的决策问题方面具有相似之处,但也存在一些显著的异同。这里进行简要阐述以加深对二者的理解。

  • Q-learning是一种基于值函数的强化学习算法。它通过使用Q-Table来表示每个状态和动作对的预期回报。Q值函数用于指导智能体在每个时间步选择最优动作。通过不断更新Q值函数来使其逼近最优的Q值函数
  • DQN是对Q-learning的深度网络版本,它将神经网络引入Q-learning中,以处理具有高维状态空间的问题。通过使用深度神经网络作为函数逼近器,DQN可以学习从原始输入数据(如像素值)直接预测每个动作的Q值

在这里插入图片描述

3 DQN原理分析

深度Q网络(Deep Q-Network, DQN)的核心原理是通过

  • 经验回放池(Experience Replay):考虑到强化学习采样的是连续非静态样本,样本间的相关性导致网络参数并非独立同分布,使训练过程难以收敛,因此设置经验池存储样本,再通过随机采样去除相关性;
  • 目标网络(Target Network):考虑到若目标价值 与当前价值 是同一个网络时会导致优化目标不断变化,产生模型振荡与发散,因此构建与 结构相同但慢于 更新的独立目标网络来评估目标价值,使模型更稳定。

拟合了高维状态空间,是Q-Learning算法的深度学习版本,算法流程如表所示

在这里插入图片描述

4 DQN训练实例

最简单的例子是使用全连接网络来构造DQN

class DQN(nn.Module):
	def __init__(self, input_dim, output_dim):
	    super(DQN, self).__init__()
	    self.input_dim = input_dim
	    self.output_dim = output_dim
	    
	    self.fc = nn.Sequential(
	        nn.Linear(self.input_dim[0], 128),
	        nn.ReLU(),
	        nn.Linear(128, 256),
	        nn.ReLU(),
	        nn.Linear(256, self.output_dim)
	    )
	
	def __str__(self) -> str:
	    return "Fully Connected Deep Q-Value Network, DQN"
	
	def forward(self, state):
	    qvals = self.fc(state)
	    return qvals

基于贝尔曼最优原理的损失计算如下

def computeLoss(self, batch):
    states, actions, rewards, next_states, dones = batch
    states = torch.FloatTensor(states).to(self.device)
    actions = torch.LongTensor(actions).to(self.device)
    rewards = torch.FloatTensor(rewards).to(self.device)
    next_states = torch.FloatTensor(next_states).to(self.device)
    dones = (1 - torch.FloatTensor(dones)).to(self.device)

    # 根据实际动作提取Q(s,a)值
    curr_Q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    next_Q = self.target_model(next_states)
    max_next_Q = torch.max(next_Q, 1)[0]
    expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q * dones

    loss = self.criterion(curr_Q, expected_Q.detach())
    return loss

基于经验回放池和目标网络的参数更新如下

def update(self, batch_size):
	batch = self.replay_buffer.sample(batch_size)
	loss = self.computeLoss(batch)
	self.optimizer.zero_grad()
	loss.backward()
	self.optimizer.step()
	
	# 更新target网络
	for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
	    target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)
	
	# 退火
	self.epsilon = self.epsilon + self.epsilon_delta \
	    if self.epsilon < self.epsilon_max else self.epsilon_max

基于DQN可以实现最基本的智能体,下面给出一些具体案例

  • Pytorch深度强化学习案例:基于DQN实现Flappy Bird游戏与分析

在这里插入图片描述

完整代码联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1336343.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

swing快速入门(二十九)播放器工具条

注释很详细&#xff0c;直接上代码 上一篇 新增内容 1.工具条按钮添加响应及图标 2.为控件添加滚动条&#xff08;通用&#xff09; 3.在工具按钮之间添加自动间隔 4.设置工具条的可否移动状态&#xff08;默认可移动&#xff09; package swing21_30;import javax.swing…

Spring系列学习二、Spring框架的环境配置

Spring框架的环境配置 一、Java环境配置二、 Spring框架的安装与配置三、Maven与Gradle环境的配置四、IDE环境配置&#xff08;Eclipse与IntelliJ IDEA&#xff09;五、结语 一、Java环境配置 所有编程旅程总是得从基础开始&#xff0c;如同乐高积木大作的基座&#xff0c;首先…

Ubuntu20.04-查看GPU的使用情况及输出详解

1. 查看GPU的使用情况 1.1 nvidia-smi # 直接在终端得到显卡的使用情况 # 不会自动刷新 nvidia-smi# 重定向到文件中 nvidia-smi > nvidia_smi_output.txt# 如果输出的内容部分是以省略号表示的&#xff0c;可以-q nvidia-smi -q 1.2 nvidia-smi -l # 会自动刷新&#x…

Python高级用法:enumerate(枚举)

enumerate&#xff08;枚举&#xff09; 在编写代码时&#xff0c;为了遍历列表并获取每个元素在列表中的索引&#xff0c;我们可以使用Python中的enumerate函数。下面是一个简单的例子&#xff0c;演示了如何使用enumerate函数实现相同的功能。 原始代码片段&#xff1a; i…

keepalived高可用 | 部署Ceph分布式存储

keepalived高可用 | 部署Ceph分布式存储 keepalived高可用1. 配置第二台haproxy代理服务器部署HAProxy 2.为两台代理服务器配置keepalived配置第一台代理服务器proxy (192.168.4.5)配置第二台代理服务器proxy (192.168.4.6)修改DNS服务器 部署ceph分布式存储准备硬件实验环境准…

10个值得收藏的机器视觉标注工具

推荐&#xff1a;用 NSDT编辑器快速搭建可编程3D场景 我们知道寻找良好的图像标记和注释工具对于创建准确且有用的数据集的重要性。 随着图像注释空间的增长&#xff0c;我们看到开源工具的可用性激增&#xff0c;这些工具使任何人都可以免费标记他们的图像并从强大的功能中受益…

Java代理设计模式--静态代理和动态代理

文章目录 代理设计模式概念代理模式的定义与特点代理模式的结构与实现代理模式的应用场景静态代理实例代理模式的扩展动态代理实现方式JDK动态代理与实例Cglib动态代理JDK动态代理与CGLIB对比 代理设计模式 概念 在有些情况下&#xff0c;一个客户不能或者不想直接访问另一个…

【中小型企业网络实战案例 二】配置网络互连互通

​【中小型企业网络实战案例 一】规划、需求和基本配置-CSDN博客 热门IT技术视频教程&#xff1a;https://xmws-it.blog.csdn.net/article/details/134398330?spm1001.2014.3001.5502 配置接入层交换机 1.以接入交换机ACC1为例&#xff0c;创建ACC1的业务VLAN 10和20。 <…

NFC物联网一次性口令认证解决方案

物联网是由无线传感器网络、射频识别(RadioFrequency Identificalion&#xff0c;RFID)网络、互联网等构成的一种复合型网络&#xff0c;具有部分终端设备体积小、存储和计算处理能力弱的特点。顾名思义&#xff0c;物联网就是“物物相连的互联网”&#xff0c;也就是说,物联网…

Visual Studio2022配置ReSharper C++ 常用设置

如需安装免费的可以在下面留言&#xff0c;看到即回复 文章目录 Visual Studio2022配置ReSharper C 常用设置配置Visual Studio2022&#xff0c;使其能够按回车进行补全配置ReSharper C 设置自动弹出配置ReSharper C 的快捷键ReSharper C 去掉注释拼写使用中文注释 如何关闭新版…

OAuth2.0 四种授权方式讲解

一、OAuth2.0 的理解 OAuth2是一个开放的授权标准&#xff0c;允许第三方应用程序以安全可控的方式访问受保护的资源&#xff0c;而无需用户将用户名和密码信息与第三方应用程序共享。OAuth2被广泛应用于现代Web和移动应用程序开发中&#xff0c;可以简化应用程序与资源服务器之…

在国内如何在速卖通上买东西(在速卖通aliexpress上付款)??

一、速卖通aliexpress上购物流程 1. 登录速卖通aliexpress网站&#xff0c;点击“注册”按钮。 2. 输入您的邮箱地址&#xff0c;然后单击“验证/联系”按钮&#xff1b; 3. 使用您的信用卡支付订单金额&#xff0c;点击获取信用卡 4. 在“我的订单管理器”中查看订单信息。 …

学习笔记14——Springboot以及SSMP项目

SpringBoot Springboot项目 IDEA2023只能创建jdk17和21的springboot项目解决 - 嘿嘿- - 博客园 (cnblogs.com)解决IntelliJ IDEA2022.03创建包时&#xff0c;包结构不自动分级显示的问题_idea建包不分级-CSDN博客IDEA调出maven项目窗口_idea maven窗口-CSDN博客 相比于spring的…

【2023下算法课设】Gray码的分治构造算法

Gray码是一个长度为2ⁿ的序列&#xff0c;序列中无相同元素&#xff0c;且每个元素都是长度为n位的二进制位串&#xff0c;相邻元素恰好只有1位不同。例如长度为2的格雷码为&#xff08;000,001,011,010,110,111,101,100&#xff09;&#xff0c;设计分治算法对任意的n值构造相…

如何使用设计模式来解决类与类之间调用过深的问题。

我们将使用责任链模式和装饰者模式的组合。 考虑一个简化的餐厅订单处理系统&#xff0c;其中包括服务员&#xff08;Waiter&#xff09;、厨师&#xff08;Chef&#xff09;和收银员&#xff08;Cashier&#xff09;。订单从服务员开始&#xff0c;然后传递给厨师&#xff0c…

python区块链简单模拟【05】

新增内容&#xff1a;构建去中心化网络 import socket #套接字&#xff0c;利用三元组【ip地址&#xff0c;协议&#xff0c;端口】可以进行网络间通信 import threading #线程 import pickle# 定义一个全局列表保存所有节点 NODE_LIST []class Node(threading.Thread…

目标检测-Two Stage-RCNN

文章目录 前言一、R-CNN的网络结构及步骤二、RCNN的创新点候选区域法特征提取-CNN网络 总结 前言 在前文&#xff1a;目标检测之序章-类别、必读论文和算法对比&#xff08;实时更新&#xff09;已经提到传统的目标检测算法的基本流程&#xff1a; 图像预处理 > 寻找候选区…

手术麻醉临床信息系统源码,客户端可以接入监护仪、麻醉机、呼吸机

一、手术麻醉临床信息管理系统介绍 1、手术麻醉临床信息管理系统是数字化手段应用于手术过程中的重要组成部分&#xff0c;用数字形式获取并存储手术相关信息&#xff0c;既便捷又高效。既然是管理系统&#xff0c;那就是一整套流程&#xff0c;管理患者手术、麻醉的申请、审批…

NVIDIA Jetson Nano 2GB 系列文章(9):调节 CSI 图像质量

NVIDIA英伟达中国 ​在本系列上一篇文章中&#xff0c;我们为大家展示了如何执行常见机器视觉应用。在本篇文章中&#xff0c;我们将带领大家调节 CSI 图像质量。 前面两篇文章在 Jetson Nano 2GB 上使用 CSI 摄像头做了几个实验&#xff0c;效果很不错&#xff0c;并且很容易…

分布式系统架构设计之分布式通信机制

二、分布式通信机制&#xff1a;保障系统正常运行基石 在分布式系统中&#xff0c;各个组件之间的通信是保障系统正常运行的基石&#xff0c;直接影响到系统的性能、可扩展性以及整体的可维护性。接下来我们就一起看看通信在分布式系统中的重要性&#xff0c;以及一些常用的技…