强化学习11——DQN算法

news2025/1/11 19:48:12

DQN算法的全称为,Deep Q-Network,即在Q-learning算法的基础上引用深度神经网络来近似动作函数 Q ( s , a ) Q(s,a) Q(s,a) 。对于传统的Q-learning,当状态或动作数量特别大的时候,如处理一张图片,假设为 210 × 160 × 3 210×160×3 210×160×3,共有 25 6 ( 210 × 60 × 3 ) 256^{(210×60×3)} 256(210×60×3)种状态,难以存储,但可以使用参数化的函数 Q θ Q_{\theta} Qθ 来拟合这些数据,即DQN算法。同时DQN还引用了经验回放和目标网络,接下来将以此介绍。

CartPole 环境

image.png

在车杆环境中,通过移动小车,让小车上的杆保持垂直,如果杆的倾斜度数过大或者车子偏离初始位置的距离过大,或者坚持了一定的时间,则结束本轮训练。该智能体的状态是四维向量,每个状态是连续的,但其动作是离散的,动作的工作空间是2。

维度意义最小值最大值
0车的位置-2.42.4
1车的速度-InfInf
2杆的角度~ -41.8°~ 41.8°
3杆尖端的速度-InfInf
标号动作
0向左移动小车
1向右移动小车

深度网络

我们通过神经网络将输入向量 x x x映射到输出向量 y y y,通过下式表示:
y = f θ ( x ) y=f_{\theta}(x) y=fθ(x)
神经网络可以理解为是一个函数,输入输出都是向量,并且拥有可以学习的参数 θ \theta θ ,通过梯度下降等方法,使得神经网络能够逼近任意函数,当然可以用来近似动作价值函数:
y ⃗ = Q θ ( s ⃗ , a ⃗ ) \vec{y}=Q_{\theta}(\vec{s},\vec{a}) y =Qθ(s ,a )
在本环境种,由于状态的每一维度的值都是连续的,无法使用表格记录,因此可以使用一个神经网络表示函数Q。当动作是连续(无限)时,神经网络的输入是状态s和动作a,输出一个标量,表示在状态s下采取动作a能获得的价值。若动作是离散(有限)的,除了采取动作连续情况下的做法,还可以只将状态s输入到神经忘了,输出每一个动作的Q值。

假设使用神经网络拟合w,则每一个状态s下所有可能动作a的Q值为 Q w ( s , a ) Q_w(s,a) Qw(s,a),我们称为Q网络:

image.png

我们在Q-learning种使用下面的方式更新:
Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a'\in\mathcal{A}}Q(s',a')-Q(s,a)\right] Q(s,a)Q(s,a)+α[r+γaAmaxQ(s,a)Q(s,a)]
即让 Q ( s , a ) Q(s,a) Q(s,a) r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a'\in\mathcal{A}}Q(s',a') r+γmaxaAQ(s,a)靠近,那么Q网络的损失函数为均方误差的形式:
ω ∗ = arg ⁡ min ⁡ ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max ⁡ a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_{\omega}\frac{1}{2N}\sum_{i=1}^{N}\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a'}Q_\omega\left(s_i',a'\right)\right)\right]^2 ω=argωmin2N1i=1N[Qω(si,ai)(ri+γamaxQω(si,a))]2

经验回访

将Q-learning过程中,每次从环境中采样得到的四元组数据(状态、动作、奖励、下一状态)存储到回放缓冲区中,之后在训练Q网络时,再从回访缓冲区中,随机采样若干数据进行训练。

image.png

在一般的监督学习中,都是假定训练数据是独立同分布的,而在强化学习中,连续的采样、交互所得到的数据有很强的相关性,这一时刻的状态和上一时刻的状态有关,不满足独立假设。通过在回访缓冲区采样,可以打破样本之间的相关性。另外每一个样本可以使用多次,也适合深度学习。

目标网络

构建两个网络,一个是目标网络,一个是当前网络,二者结构相同,都用于近似Q值。在实践中每隔若干步才把每步更新的当前网络参数复制给目标网络,这样做的好处是保证训练的稳定,当训练的结果不好时,可以不同步当前网络的值,避免Q值的估计发散。

image.png

在计算期望时,使用目标网络来计算:
Q 期望 = [ r t + γ max ⁡ a ′ Q ω ˉ ( s ′ , a ′ ) ] Q_\text{期望}=[r_t+\gamma\max_{a^{\prime}}Q_{\bar{\omega}}(s^{\prime},a^{\prime})] Q期望=[rt+γamaxQωˉ(s,a)]
具体流程如下所示:

  • 使用随机的网络参数 ω \omega ω初始化初始化当前网络 Q ω ( s , a ) Q_{\omega}(s,a) Qω(s,a)
  • 复制相同的参数初始化目标网络 ω ˉ ← ω \bar{\omega}\gets \omega ωˉω
  • 初始化经验回访池R
  • for 序列 e = 1 → E e=1\to E e=1E do
    • 获取环境初始状态 s 1 s_1 s1
    • for 时间步 t = 1 → T 时间步t=1\to T 时间步t=1T do
      • 根据当前网络 Q ω ( s , a ) Q_{\omega}(s,a) Qω(s,a) ϵ − g r e e d y \epsilon -greedy ϵgreedy策略选择动作 a t a_t at
      • 执行动作 a t a_t at,获得回报 r t r_t rt,环境状态变为 s t + 1 s_{t+1} st+1
      • ( s t , a t , r t , s t + 1 ) (s_t,a_t,r_t,s_{t+1}) (st,at,rt,st+1)存储进回池R
      • 若R中数据足够,则从R中采样N个数据 { ( s i , a i , r i , s i + 1 ) } i = 1 , … , N \{(s_i,a_i,r_i,s_{i+1})\}_{i=1,\ldots,N} {(si,ai,ri,si+1)}i=1,,N
      • 对每个数据,用目标网络计算 y = r i + γ max ⁡ a Q ω ˉ ( s i + 1 , a ) y=r_i+\gamma\max_aQ_{\bar{\omega}}(s_{i+1},a) y=ri+γmaxaQωˉ(si+1,a)
      • 最小化目标损失 L = 1 N ∑ i ( y i − Q ω ( s i , a i ) ) 2 L=\frac{1}{N}\sum_{i}(y_{i}-Q_{\omega}(s_{i},a_{i}))^{2} L=N1i(yiQω(si,ai))2,以更新当前网络 Q ω Q_{\omega} Qω
      • 更新目标网络
    • end for
  • end for
import random
from typing import Any
import gymnasium as gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import rl_utils

# 首先定义经验回收池的类,包括加入数据、采样数据
class ReplayBuffer:
    def __init__(self, capacity):
        # 创建一个队列,先进先出
        self.buffer=collections.deque(maxlen=capacity)
        
    def add(self,state,action,reward,next_state,done):
        # 加入数据
        self.buffer.append((state,action,reward,next_state,done))
        
    def sample(self,batch_size):
        # 随机采样数据
        mini_batch=random.sample(self.buffer,batch_size)
        # zip(*)取mini_batch中的每个元素(即取列),并返回一个元组
        state,action,reward,next_state,done=zip(*mini_batch)
        return np.array(state), action, reward, np.array(next_state), done
    
    def size(self):
        return len(self.buffer)
    
# 定义一个只有一层隐藏层的Q网络
class Qnet(torch.nn.Module):
    def __init__(self,state_dim,hidden_dim,action_dim):
        super(Qnet,self).__init__()
        # 定义一个全连接层,输入为state_dim维向量,输出为hidden_dim维向量
        self.fc1=torch.nn.Linear(state_dim,hidden_dim)
        # 定义一个全连接层,输入为hidden_dim维向量,输出为action_dim维向量
        self.fc2=torch.nn.Linear(hidden_dim,action_dim)
        
    def forward(self,state):
        x = F.relu(self.fc1(state))
        return self.fc2(x)
    
class DQN:
    def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,epsilon,target_update,device):
        self.action_dim=action_dim
        self.q_net=Qnet(state_dim,hidden_dim,action_dim).to(device)
        # 目标网络
        self.target_q_net=Qnet(state_dim,hidden_dim,action_dim).to(device)
        # 使用Adam优化器
        self.optimizer=torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)
        # 折扣因子
        self.gamma=gamma
        # 贪婪策略
        self.epsilon=epsilon
        # 目标网络更新频率
        self.target_update=target_update
        # 计数器
        self.count=0
        self.device=device
        
    def take_action(self,state):
        # 判断是否需要贪婪策略
        if np.random.random()<self.epsilon:
            action=np.random.randint(self.action_dim)
        else:
            state=torch.tensor([state],dtype=torch.float).to(self.device)
            action=self.q_net(state).argmax().item()
        return action

    def update(self,transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        # Q值
        q_values=self.q_net(states).gather(1,actions)
        # 下一个状态的最大Q值
        max_next_q_values=self.target_q_net(next_states).max(1)[0].view(-1, 1)
        q_targets=rewards+self.gamma*max_next_q_values*(1-dones)
        # 反向传播更新参数
        dqn_loss=torch.mean(F.mse_loss(q_values, q_targets)) # 均方误差损失函数
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()
        
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(
                self.q_net.state_dict())  # 更新目标网络
        self.count += 1
        
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,
            target_update, device)

return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
        for i_episode in range(int(num_episodes / 10)):
            episode_return = 0
            state = env.reset()[0]
            aa=state[0]
            print(state)
            done = False
            while not done:
                action = agent.take_action(state)
                next_state, reward, done,info, _ = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # 当buffer数据的数量超过一定值后,才进行Q网络训练
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {
                        'states': b_s,
                        'actions': b_a,
                        'next_states': b_ns,
                        'rewards': b_r,
                        'dones': b_d
                    }
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode + 1) % 10 == 0:
                pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1389623.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android性能优化 | DEX 布局优化和启动配置文件

Android性能优化 | DEX 布局优化和启动配置文件 引言 使用DEX布局优化和启动配置文件是优化Android应用性能的有效途径。DEX布局优化可以通过优化应用程序中的DEX文件布局&#xff0c;从而加快Android应用的启动速度和执行速度。启动配置文件则提供了一种灵活的方式来控制应用…

SDRAM小项目——命令解析模块

简单介绍&#xff1a; 在FPGA中实现命令解析模块&#xff0c;命令解析模块的用来把pc端传入FPGA中的数据分解为所需要的数据和触发命令&#xff0c;虽然代码不多&#xff0c;但是却十分重要。 SDRAM的整体结构如下&#xff0c;可以看出&#xff0c;命令解析模块cmd_decode负责…

知存科技助力AI应用落地:WTMDK2101-ZT1评估板实地评测与性能揭秘

文章目录 一、前言二、深入了解存算一体技术2.1 什么是存算一体2.2 存算一体技术发展历程2.3 基于不同存储介质的存内计算芯片性能比较 三、国产存算一体&#xff0c;重大进展3.1 知存科技&#xff1a;我国存算一体领域的研发领导者 四、知存科技新型 WTM2101 SOC 评估板使用评…

【RTOS】快速体验FreeRTOS所有常用API(9)中断管理

目录 九、中断管理9.1 基本概念9.2 两套API9.3 中断服务程序 九、中断管理 该部分暂无代码 9.1 基本概念 FreeRTOS中的中断管理主要涉及&#xff1a;两套API、中断服务编写 不能阻塞、不能处理耗时任务ISR的优先级高于任务&#xff1a;即使是优先级最低的中断&#xff0c;它的…

【小黑嵌入式系统第十五课】μC/OS-III程序设计基础(四)——消息队列(工作方式数据通信生产者消费者模型)、动态内存管理、定时器管理

上一课&#xff1a; 【小黑嵌入式系统第十四课】μC/OS-III程序设计基础&#xff08;三&#xff09;——信号量&#xff08;任务同步&资源同步&#xff09;、事件标记组&#xff08;与&或&多个任务&#xff09; 前些天发现了一个巨牛的人工智能学习网站&#xff0c…

美媒:OpenAI删除了禁止其技术被用于军事用途的条款

据美国调查新闻网站“拦截者”&#xff08;The Intercept&#xff09;1月12日报道&#xff0c;美国知名人工智能企业、ChatGPT母公司OpenAI近日悄悄修改了其产品的使用条款&#xff0c;删除了禁止将OpenAI技术用于军事用途的条文。 报道称&#xff0c;在今年1月10日之前&#…

Springboot+vue的智能无人仓库管理(有报告),Javaee项目,springboot vue前后端分离项目

演示视频&#xff1a; Springbootvue的智能无人仓库管理&#xff08;有报告&#xff09;&#xff0c;Javaee项目&#xff0c;springboot vue前后端分离项目 项目介绍&#xff1a; 本文设计了一个基于Springbootvue的前后端分离的智能无人仓库管理&#xff0c;采用M&#xff08…

基于Springboot的私人健身与教练预约管理系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的私人健身与教练预约管理系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三…

java基础:求数组的和以及平均数案例分析

/* * * 解题思路&#xff1a;首先定义一个包含数字的数组hens&#xff0c; * 然后使用循环遍历数组中的每个元素&#xff0c;并将其累加到sum变量中。 * 最后&#xff0c;将sum除以数组长度得到平均值avg。最终将结果打印输出到控制台。*/ 代码如下&#xff1a; package idea;…

机器学习算法实战案例:GRU 实现多变量多步光伏预测

文章目录 1 数据处理1.1 数据集简介1.2 导入库文件1.3 数据集处理1.4 训练数据构造 2 模型训练与预测2.1 模型训练2.2 模型多步预测2.3 预测可视化 答疑&技术交流机器学习算法实战案例系列 1 数据处理 1.1 数据集简介 实验数据集采用数据集7&#xff1a;常州普利司通光伏…

[linux]同步缓冲区数据到flash

一、需求 由于linux自身策略问题&#xff0c;在某些情况下需主动同步数据到flash。 二、方案 同步数据的两种方式&#xff1a;sync和fsync 2.1sync 将整个缓冲区同步至flash。性能较差。 2.2fsync -d [file] 将某一文件的数据同步至flash。 三、应用实例 3.1sync 3.2f…

虚拟服务器的监控和管理

IT 网络中虚拟环境的出现给 IT 管理员带来了一些挑战&#xff0c;虚拟环境降低了管理硬件和软件的成本和复杂性&#xff0c;同时&#xff0c;他们通常需要 IT 管理员管理更多的空间&#xff0c;以确保完全可见和快速解决问题。 虚拟服务器在现代 IT 基础架构中越来越普遍&…

NXP采用RS RTS测试系统,验证28纳米RFCMOS雷达单芯片 |百能云芯

Rohde & Schwarz的雷达目标模拟器R&S RTS&#xff0c;作为汽车雷达的颠覆性解决方案&#xff0c;尤其是其能够电子模拟非常近距离物体的能力&#xff0c;已被用于验证NXP半导体的下一代雷达传感器参考设计的性能。 这一合作使汽车行业在汽车雷达的发展上迈出了一步&…

.net core 6 集成nacos的服务注册和配置中心

1、安装nuget包 2、加上配置文件 "nacos": {"ServerAddresses": [ "http://127.0.0.1:8848" ],"GroupName": "DEFAULT_GROUP","ClusterName": "DEFAULT","ServiceName": "webapi"…

【Linux实用篇】项目部署 基于Shell脚本自动部署

目录 1. 项目部署 1.1 手动部署项目 1.2 基于Shell脚本自动部署 1.2.1 介绍 1.2.2 推送代码到远程 1.2.3 Git操作 1.2.4 Maven安装 1.2.5 Shell脚本准备 1.2.6 Linux权限 1.2.7 授权并执行脚本 1.2.8 设置静态IP 1. 项目部署 之前我们讲解Linux操作系统时&#xff0…

基于拓扑图与领导跟随法的编队控制算法

matlab2020可运行 https://download.csdn.net/download/ljjjjjjjjjjj/88750436

导航与定位技术已成为移动机器人的核心技术之一

随着移动机器人技术的不断发展和应用领域的扩大&#xff0c;导航与定位技术已成为移动机器人的核心技术之一。本文将介绍移动机器人导航与定位技术的发展现状、技术前沿和面临的挑战。 ​ 一、导航与定位技术的发展现状 移动机器人的导航与定位技术是实现自主移动的关键。目前…

Vscode——运行java项目

1、安装JDK&#xff08;已安装请忽略&#xff09; 教程&#xff1a;https://blog.csdn.net/qq812457115/article/details/117451011 验证是否已安装&#xff0c;命令符输入 java -version2、安装Maven&#xff08;已安装请忽略&#xff09; 教程&#xff1a;https://blog.csdn…

LeetCode刷题---反转链表II

LeetCode官方给出的解题思路 在需要反转的区间里&#xff0c;每遍历到一个节点&#xff0c;让这个新节点来到反转部分的起始位置。 使用了三指针的思想。 定义三个节点: curr&#xff1a;指向待反转区域的第一个节点 left&#xff1b; next&#xff1a;永远指向 curr 的下一个节…

怎样实现安全便捷的网间数据安全交换?

数据安全交换是指在数据传输过程中采取一系列措施来保护数据的完整性、机密性和可用性。网间数据安全交换&#xff0c;则是需要进行跨网络、跨网段甚至跨组织地进行数据交互&#xff0c;对于数据的传输要求会更高。 大部分企业都是通过网闸、DMZ区、VLAN、双网云桌面等方式实现…