【深度学习】最强算法之:深度Q网络(DQN)

news2024/10/7 10:16:51

深度Q网络

  • 1、引言
  • 2、深度Q网络
    • 2.1 定义
    • 2.2 原理
    • 2.3 实现方式
    • 2.4 算法公式
    • 2.5 代码示例
  • 3、总结

1、引言

小屌丝:鱼哥, 马上清明小长假了, 你这准备去哪里玩啊?
小鱼:哪也不去,在家待着
小屌丝:在家? 待着? 干啥啊?
小鱼:啥也不干,床上躺着
小屌丝:床上… 躺着… 做啥啊?
小鱼:啥也不做,睡觉
小屌丝:睡觉?? 这大白天的,确定睡觉?
小鱼:我擦… 你这wc~
小屌丝:我很正经的好不好。
小鱼:… 我有点事,待会说
小屌丝: 待会,没时间了哦
小鱼:那就在多几个待会的
小屌丝:这火急火燎的, 肯定"有事"。
在这里插入图片描述

2、深度Q网络

2.1 定义

深度Q网络(DQN)是一种结合了深度学习和Q-learning的强化学习算法。它通过深度神经网络逼近值函数,并利用经验回放和目标网络等技术,使得Q-learning能够在高维连续状态空间中稳定学习。

2.2 原理

DQN的核心原理是利用深度神经网络来估计Q值函数。
在每个时刻,DQN根据当前状态s和所有可能的动作a计算出一组Q值,然后选择Q值最大的动作执行。
执行动作后,环境会给出新的状态s’和奖励r,DQN将这些信息存储到经验回放缓存中。

在训练过程中,DQN从经验回放缓存中随机采样一批历史数据,利用这些数据进行梯度下降更新神经网络参数。

此外,DQN还引入了目标网络来稳定学习过程,即每隔一定步数将当前网络参数复制给目标网络,用于计算目标Q值。

2.3 实现方式

实现DQN主要包括以下步骤:

  • 初始化深度神经网络(Q网络)和目标网络(目标Q网络)。
  • 初始化经验回放缓存。
  • 对于每个训练回合:
    • 初始化状态s。
    • 对于每个时间步t:
      • 使用ε-贪婪策略选择动作a。
      • 执行动作a,观察奖励r和新状态s’。
      • 将经验(s, a, r, s’)存储到经验回放缓存中。
      • 从经验回放缓存中采样一批数据,计算损失函数并更新Q网络参数。
      • 每隔一定步数更新目标网络参数。
    • 重复上述步骤直至满足终止条件。

2.4 算法公式

DQN的损失函数通常采用均方误差(MSE)形式,即:

L ( θ ) = 1 / N ∗ Σ [ ( r + γ ∗ m a x a ′ Q ( s ′ , a ′ ; θ − ) − Q ( s , a ; θ ) ) 2 ] L(θ) = 1/N * Σ[(r + γ * max_a' Q(s', a'; θ⁻) - Q(s, a; θ))^2] L(θ)=1/NΣ[(r+γmaxaQ(s,a;θ)Q(s,a;θ))2]

其中,

  • θ θ θ Q Q Q网络参数,
  • θ − θ⁻ θ是目标网络参数,
  • N N N是采样数据批量大小,
  • γ γ γ是折扣因子,
  • r r r是奖励,
  • s s s a a a分别是当前状态和动作,
  • s ′ s' s是下一状态,
  • a ′ a' a是下一状态的所有可能动作。

2.5 代码示例

# -*- coding:utf-8 -*-
# @Time   : 2024-04-01
# @Author : Carl_DJ

'''
实现功能:
    使用PyTorch框架的简单DQN(Deep Q-Network)实现示例

'''
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# 创建一个简单的神经网络,作为Q网络
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.net(x)

# 经验回放
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), action, reward, np.array(next_state), done

    def __len__(self):
        return len(self.buffer)

# DQN算法实现
class DQNAgent:
    def __init__(self, input_dim, output_dim):
        self.model = DQN(input_dim, output_dim)
        self.target_model = DQN(input_dim, output_dim)
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters())
        self.buffer = ReplayBuffer(10000)
        self.steps_done = 0
        self.epsilon_start = 1.0
        self.epsilon_final = 0.01
        self.epsilon_decay = 500
        self.batch_size = 32

    def act(self, state):
        epsilon = self.epsilon_final + (self.epsilon_start - self.epsilon_final) * \
                  np.exp(-1. * self.steps_done / self.epsilon_decay)
        self.steps_done += 1
        if random.random() > epsilon:
            state = torch.FloatTensor(state).unsqueeze(0)
            q_value = self.model(state)
            action = q_value.max(1)[1].item()
        else:
            action = random.randrange(2)
        return action

    def update(self):
        if len(self.buffer) < self.batch_size:
            return
        state, action, reward, next_state, done = self.buffer.sample(self.batch_size)
        state = torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        done = torch.FloatTensor(done)

        q_values = self.model(state)
        next_q_values = self.target_model(next_state)

        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + 0.99 * next_q_value * (1 - done)

        loss = (q_value - expected_q_value.data).pow(2).mean()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())

# 训练环境设置
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)

# 训练循环
episodes = 100
for episode in range(episodes):
    state = env.reset()
    total_reward = 0
    done = False
    while not done:
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        agent.buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        agent.update()
    agent.update_target()
    print('Episode: {}, Total reward: {}'.format(episode, total_reward))


解析:

  • 首先定义了一个简单的神经网络DQN,
  • 然后定义了ReplayBuffer用于经验回放,
  • 接着定义了DQNAgent类封装了DQN的决策、学习和目标网络更新逻辑。
  • 最后,通过创建一个gym环境(这里使用的是CartPole-v1)并在该环境中运行DQNAgent来进行训练。
    在这里插入图片描述

3、总结

深度Q网络(DQN)通过将深度学习与强化学习相结合,解决了传统Q-learning在高维连续状态空间中的维度灾难问题。

DQN利用深度神经网络的强大表征能力来估计Q值函数,并通过经验回放和目标网络等技术来稳定学习过程。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)测评一、二等奖获得者

关注小鱼,学习【机器学习】&【深度学习】领域的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1577567.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何让阿里云AI001号员工帮我写代码(含IDEA插件使用)

国内首个AI程序员入职阿里云&#xff1a;专属工号AI001&#xff0c;KPI是一人写完公司20%代码。 不管是真是假&#xff0c;AI 程序员发展的趋势是无法改变的&#xff0c;小米汽车发布会上&#xff0c;雷军说到小米汽车工厂的自动化率达到90%以上&#xff0c;有些车间甚至100%的…

基于javassmJSP的家用电器销售网站

开发语言&#xff1a;Java 框架&#xff1a;ssm 技术&#xff1a;JSP JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclip…

面试字节被挂了

分享一个面试字节的经历。 1、面试过程 一面&#xff1a;上来就直接"做个题吧"&#xff0c;做完之后&#xff0c;对着简历上一个项目聊&#xff0c;一直聊到最后&#xff0c;还算比较正常。 二面&#xff1a;做自我介绍&#xff0c;花几分钟聊了一个项目&#xff…

江南大学酒科技馆OLED透明屏项目方案

一、项目概述 本项目旨在为无锡江南大学酒科技馆提供OLED透明屏解决方案&#xff0c;通过安装2x2的OLED透明屏&#xff0c;为参观者带来全新的视觉体验&#xff0c;同时提升酒科技馆的展示效果与科技感。 二、产品选型 本项目选用OLED透明屏&#xff0c;其具有高透明度、高对比…

Windows/Jerry

Jerry Enumeration nmap 扫描系统发现对外开放了 8080 端口&#xff0c;再次使用 nmap 扫描端口详细信息&#xff0c;发现运行着 Apache Tomcat ┌──(kali㉿kali)-[~/vegetable/HTB/Jerry] └─$ nmap -sV -sC -p 8080 -oA nmap 10.10.10.95 -Pn Starting Nmap 7.93 ( htt…

【分治算法】Strassen矩阵乘法Python实现

文章目录 [toc]问题描述基础算法时间复杂性 Strassen算法时间复杂性 问题时间复杂性Python实现 个人主页&#xff1a;丷从心. 系列专栏&#xff1a;Python基础 学习指南&#xff1a;Python学习指南 问题描述 设 A A A和 B B B是两个 n n n \times n nn矩阵&#xff0c; A A…

东方博宜 1426. 年龄与疾病

东方博宜 1426. 年龄与疾病 思路&#xff1a;1 读取数组 2 遍历数组并进行比较 遇到的坑是百分号且保留两位的输出方式&#xff0c;以及两个整数求商的时候要记得转换成小数形式 #include<iostream> #include<cstdio> using namespace std; int main() {int n ;cin…

第十四届蓝桥杯岛屿个数

题目描述&#xff1a; 小蓝得到了一副大小为 MN 的格子地图&#xff0c;可以将其视作一个只包含字符 0&#xff08;代表海水&#xff09;和 1&#xff08;代表陆地&#xff09;的二维数组&#xff0c;地图之外可以视作全部是海水&#xff0c;每个岛屿由在上/下/左/右四个方向上…

晶核养号攻略:如何轻松搬砖?两大要点!

晶核游戏中&#xff0c;想通过搬砖来养号并不是一件难事。本攻略将为你介绍两种主要的金币获取方式&#xff0c;让你轻松提升游戏财富&#xff0c;实现更多游戏目标。 一、刷深渊&#xff1a;稳定金币收入 深渊地图在晶核游戏中是一个稳定的金币来源。这张地图从55级开始可刷&…

【Segment Anything Model】十三:Meta的最新工作EfficientSAM,微调到自己的数据集,代码。

&#x1f349; 博主微信 cvxiayixiao 还有其他专栏点击头像查询 &#x1f353; 【Segment Anything Model】计算机视觉检测分割任务专栏。 &#x1f351; 【公开数据集预处理】特别是医疗公开数据集的接受和预处理&#xff0c;提供代码讲解。 &#x1f348; 【opencv图像处理】…

N4433A安捷伦N4433A电子校准件

181/2461/8938产品概述&#xff1a; 300 kHz至20 GHz频率范围标准3.5毫米接口通过单一连接实现快速完整的3或4端口校准NIST可追溯的精确校准减少连接器磨损用于直接控制PNA和ENA系列网络分析仪的USB接口可靠的固态开关提供混合3.5毫米公/母连接器选项 安捷伦N4433A微波电子校准…

代码随想录|Day34|动态规划03|343.整数拆分、96.不同的二叉搜索树

343.整数拆分 动规五步&#xff1a; 确定 dp[i] 含义&#xff1a;拆分数字 i&#xff0c;可以获得的最大乘积为 dp[i]。递推公式&#xff1a;dp[i] max(j * (i - j), j * dp[i - j])。i 可以被拆解为两个数&#xff08;j 和 i - j&#xff09;或者多个数&#xff08;j 和 dp[i…

app上架-您的应用存在最近任务列表隐藏风险活动的行为,不符合华为应用市场审核标准。

上架提示 您的应用存在最近任务列表隐藏风险活动的行为&#xff0c;不符合华为应用市场审核标准。 修改建议&#xff1a;请参考测试结果进行修改。 请参考《审核指南》第2.19相关审核要求&#xff1a;https://developer.huawei.com/consumer/cn/doc/app/50104-02 造成原因 …

数字电路基础(Digital Circuit Basis )

目录 一、什么是数字电路&#xff1f; &#xff08;Digital Circuit &#xff09; 1.概念 2.分类 3.优点 4.数电与模电的区别 二、数制 (十进制&#xff1a;Decimal) 1.概述 2.进位制 3.基数 4.位权 5.二进制的算术运算 三、编码 (二进制&#xff1a;Binary ) 1.什…

2024/4/1—力扣—按摩师

代码实现&#xff1a; 思路&#xff1a;打家劫舍题 int massage(int *nums, int numsSize) {if (nums NULL || numsSize 0) {return 0;}if (numsSize 1) {return nums[0];}int dp[numsSize];memset(dp, 0, sizeof(dp));dp[0] nums[0];dp[1] (nums[0] < nums[1] ? nums…

【大功率汽车大灯升压方案】LED恒流驱动芯片FP7208升压车灯调光应用,PWM内部转模拟,调光深度1%,无频闪顾虑,低亮无抖动

宝马X5前中排座椅宽大舒适&#xff0c;车厢内储物空间丰富。操控性能极佳&#xff0c;底盘稳扎精良。原车为氙气灯&#xff0c;其实宝马的氙气大灯配的比其他车型要好&#xff0c;照明效果是没得说的。但是不管什么灯久了都会出现光衰的情况。下面这辆宝马X5车灯已老化严重。 宝…

【Linux】安装+基本指令

&#x1f308;个人主页&#xff1a;秦jh__https://blog.csdn.net/qinjh_?spm1010.2135.3001.5343&#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/qinjh_/category_12625432.html 目录 Linux系统的安装 登录 XShell 下的复制粘贴 指令 pwd指令 ls指令 cd 指令 …

哪个好人,2024年还在做push攻略科普啊!

当拥有适当工具的时候&#xff0c;增加用户留存率的艰巨任务也能轻松解决。推送通知&#xff08;Push&#xff09;就是这样的宝藏工具&#xff0c;不用客户主动浏览&#xff0c;就可以触达客户。说起推送的优点&#xff08;Push&#xff09;&#xff0c;用户不需要主动触发&…

图像版PDF文件OCR识别转换为文本的3款免费工具软件

图像版PDF文件里面都是图片&#xff0c;要先通过OCR技术识别出文本&#xff0c;然后才能进行进一步处理编辑。下面是3个免费的PDF文件OCR识别软件工具&#xff1a; ●简可信PDF批量识别工具 简可信PDF批量识别工具是一款专门用于将PDF文件进行批量OCR&#xff08;光学字符识别…

React 集成三方登录按钮样式的插件库

按钮不提供任何社交逻辑。 效果如下&#xff1a; 原地址&#xff1a;https://www.npmjs.com/package/react-social-login-buttons 时小记&#xff0c;终有成。