【强化学习】常用算法之一 “DQN”

news2025/1/22 21:41:18

 

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?type=blog个人简介:打工人。

持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。

如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666@foxmail.com 

        强化学习是机器学习中的一大分支,通过与环境的交互,通过不断试错来学习如何做出最优决策。Deep Q-Network(DQN)算法是强化学习中的经典算法之一,它结合了深度学习和Q-learning算法,可以在宽泛的任务领域中学习和解决问题。

本文将详细讲解强化学习常用算法之一“DQN”


  

目录

一、简介

二、发展史

三、算法公式

        1. Q-learning算法公式:

        2. 深度神经网络:

        3. DQN算法公式:

四、算法原理

五、算法功能

六、示例代码

七、总结


一、简介

        DQN算法是深度学习领域首次广泛应用于强化学习的算法模型之一。它于2013年由DeepMind公司的研究团队提出,通过将深度神经网络与经典的强化学习算法Q-learning结合,实现了对高维、连续状态空间的处理,具备了学习与规划的能力。

二、发展史

        在DQN算法提出之前,强化学习中的经典算法主要是基于表格的Q学习算法。这些算法在处理简单的低维问题时表现出色,但随着状态和动作空间的增加,表格表示的存储和计算复杂度呈指数级增长。为了解决这个问题,研究人员开始探索使用函数逼近的方法,即使用参数化的函数代替表格。

        之后,逐步发展出了一系列将深度学习应用于强化学习的算法。DQN算法是其中的一种。它是由Alex Krizhevsky等人在2013年提出的,是首个将深度学习与强化学习相结合的算法。DQN算法引入了经验回放和固定Q目标网络等技术,极大地提升了深度神经网络在强化学习中的性能。随后,DQN算法在Atari游戏中取得了比人类玩家更好的成绩,引起了广泛的关注和研究。

  1. Q-learning:Q-learning是强化学习中的经典算法,由Watkins等人在1989年提出。它使用一个Q表格来存储状态和动作的价值,通过不断更新和探索来学习最优策略。然而,Q-learning算法在面对大规模状态空间时,无法扩展。

  2. Deep Q-Network(DQN):DQN算法在2013年由DeepMind团队提出,通过使用深度神经网络来逼近Q函数的值,解决了状态空间规模大的问题。该算法采用了两个关键技术:经验回放和固定Q目标网络。

  3. 经验回放:经验回放是DQN算法的核心思想之一,它的基本原理是将智能体的经验存储在一个回放记忆库中,然后随机从中抽样,利用这些经验进行模型更新。这样做的好处是避免了样本间的相关性,提高了模型的稳定性和收敛速度。

  4. 固定Q目标网络:DQN算法使用两个神经网络,一个是主网络(online network),用于选择动作,并进行模型更新;另一个是目标网络(target network),用于计算目标Q值。目标网络的参数固定一段时间,这样可以减少目标的波动,提高模型的稳定性。

三、算法公式

        DQN算法的核心是Q-learning算法和深度神经网络的结合。

        1. Q-learning算法公式:

        Q-learning算法通过不断更新Q值来学习最优策略,其更新公式如下:

        其中,s_t表示当前状态,a_t表示选择的动作,r_t表示立即回报,s_t+1表示下一个状态,α是学习率,γ是折扣因子。 

        2. 深度神经网络:

        DQN算法使用深度神经网络拟合Q函数的值。输入是状态s,输出是不同动作的Q值。常用的神经网络结构是多层感知机(MLP)或卷积神经网络(CNN),通过训练来优化网络参数。网络的输出大小与动作空间的维度相同。

        3. DQN算法公式:

        DQN算法通过最小化Q函数的均方差损失来进行模型更新。其更新公式如下:

        其中,θ是网络参数,Q(s_t, a, θ-)表示目标网络的输出。 

四、算法原理

        DQN算法的原理是通过利用深度神经网络逼近Q函数的值,实现对高维、连续状态空间的处理。其核心思想是通过不断更新神经网络的参数,使其的输出Q值逼近真实的Q值,从而学习最优策略。

DQN算法的工作原理如下:

  1. 初始化:初始化主网络和目标网络的参数。

  2. 选择动作:根据当前状态s,使用ε-greedy策略选择动作a。

  3. 执行动作并观察回报:采取动作a,与环境交互,观察下一个状态s’和立即回报r。

  4. 存储经验:将(s, a, r, s’)存储到经验回放记忆库中。

  5. 从经验回放记忆库中随机抽样:从记忆库中随机抽样一批经验。

  6. 计算目标Q值:使用目标网络计算目标Q值,即max(Q(s’, a, θ-))。

  7. 更新主网络:根据损失函数L(θ)进行模型参数更新。

  8. 更新目标网络:定期更新目标网络的参数。

  9. 重复步骤2-8,直到达到终止条件。

五、算法功能

        DQN算法具有以下功能:

  1. 处理高维、连续状态空间:通过深度神经网络的逼近能力,可以处理高维、连续状态空间的问题。

  2. 学习和规划能力:通过与环境的交互和不断试错,DQN算法可以学习到最优策略,并具备一定的规划能力。

  3. 稳定性和收敛速度高:DQN算法通过经验回放和固定Q目标网络等技术,提高了模型的稳定性和收敛速度。

六、示例代码

        以下是一个使用DQN算法解决经典的CartPole问题的示例代码:

# -*- coding: utf-8 -*-
import gym
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

env = gym.make('CartPole-v0')
n_actions = env.action_space.n
n_states = env.observation_space.shape[0]

def create_dqn_model():
    model = Sequential()
    model.add(Dense(32, input_shape=(n_states,), activation='relu'))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(n_actions, activation='linear'))
    model.compile(loss='mse', optimizer=Adam(lr=0.001))
    return model

def choose_action(state, epsilon):
    if np.random.rand() < epsilon:
        return np.random.choice(n_actions)
    else:
        q_values = model.predict(state)
        return np.argmax(q_values[0])

def train_dqn():
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    batch_size = 32
    replay_memory = []
    for episode in range(500):
        state = env.reset()
        state = np.reshape(state, [1, n_states])
        done = False
        steps = 0

        while not done:
            env.render()
            action = choose_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, n_states])
            replay_memory.append((state, action, reward, next_state, done))
            state = next_state
            steps += 1

            if done:
                print("Episode: %d, Steps: %d" % (episode, steps))
                break
            if len(replay_memory) > batch_size:
                minibatch = np.random.choice(replay_memory, batch_size, replace=False)
                states_mb = np.concatenate([mb[0] for mb in minibatch])
                actions_mb = np.array([mb[1] for mb in minibatch])
                rewards_mb = np.array([mb[2] for mb in minibatch])
                next_states_mb = np.concatenate([mb[3] for mb in minibatch])
                dones_mb = np.array([mb[4] for mb in minibatch])

                targets = rewards_mb + 0.99 * (np.amax(model.predict_on_batch(next_states_mb), axis=1)) * (1 - dones_mb)
                targets_full = model.predict_on_batch(states_mb)
                ind = np.array([i for i in range(batch_size)])
                targets_full[[ind], [actions_mb]] = targets

                model.fit(states_mb, targets_full, epochs=1, verbose=0)

            if epsilon > epsilon_min:
                epsilon *= epsilon_decay

    env.close()

if __name__ == '__main__':

    model = create_dqn_model()

    train_dqn()

        运行结果: 

Episode: 0, Steps: 14
Episode: 1, Steps: 26
Episode: 2, Steps: 16
Episode: 3, Steps: 12
Episode: 4, Steps: 12
...
Episode: 498, Steps: 160
Episode: 499, Steps: 200

 

        通过运行以上代码,可以看到DQN算法在解决CartPole问题上的表现。在经过多个episode的训练后,算法可以在游戏中坚持更久的时间,最终取得较高的得分。 

七、总结

        本文对DQN算法进行了详细的讲解,包括发展史、算法公式和原理、功能、示例代码以及如何使用。DQN算法通过结合深度学习和Q-learning算法,实现了对高维、连续状态空间的处理,具备了学习和规划的能力。通过示例代码的运行结果,我们可以看到DQN算法在解决CartPole问题上取得了较好的效果。然而,DQN算法也存在一些局限性,比如训练不稳定、样本相关性等问题。未来的研究可以进一步改进算法,并将其应用于更广泛的任务领域。

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703510.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5、多层感知机:过拟合解决方法:权重衰退、丢弃法

1、权重衰退 1. 基础概念 实际上&#xff0c;限制特征的数量是缓解过拟合的一种常用技术。然而&#xff0c;简单地丢弃特征对这项工作来说可能过于生硬。我们继续思考多项式回归的例子&#xff0c;考虑高维输入可能发生的情况。多项式对多变量数据的自然扩展称为单项式&#…

微信小程序学习记录3 案例分享<蓝牙小车UI>

效果 页面1 一键连接蓝牙 页面2 控制页面 思路 页面1 旋转动画一键连接蓝牙(写死了device id和uuid) 页面2 轮播图按键 按键绑定不同事件即可

多模态大语言模型综述来啦!一文带你理清多模态关键技术

夕小瑶科技说 原创 作者 | 智商掉了一地、Python 随着 ChatGPT 在各领域展现出非凡能力&#xff0c;多模态大型语言模型&#xff08;MLLM&#xff09;近来也成为了研究的热点&#xff0c;它利用强大的大型语言模型&#xff08;LLM&#xff09;作为“大脑”&#xff0c;可以执…

Linux卸载OpenJDK

1、JDK版本 java -version2、查看当前系统OPENJDK rpm -qa | grep jdk3、卸载 sudo rpm -e --nodeps java-1.8.0-openjdk sudo rpm -e --nodeps java-1.8.0-openjdk-headless sudo rpm -e --nodeps java-1.8.0-openjdk-devel sudo rpm -e --nodeps copy-jdk-configs4、其他 …

爬虫入门指南(3):Python网络请求及常见反爬虫策略应对方法

文章目录 引言HTTP协议与请求方法HTTP协议请求方法 使用Python进行网络请求安装Requests库发送GET请求发送POST请求 反爬虫与应对策略IP限制使用代理IP&#xff1a; 用户代理检测设置User-Agent头部&#xff1a; 验证码参考方案 动态页面请求频率限制未完待续.... 引言 在当今…

Python 基本数据类型(五)

文章目录 每日一句正能量List&#xff08;列表&#xff09;结语 每日一句正能量 营造良好的工作和学习氛围&#xff0c;时刻牢记宗旨&#xff0c;坚定信念&#xff0c;胸怀全局&#xff0c;埋头苦干&#xff0c;对同事尊重信任谅解&#xff0c;发扬团体协作精神&#xff0c;积极…

安装tomcat的步骤与部署服务详解

目录 一安装tomcat步骤 1.首先关闭防火墙跟安全机制 2.安装jdk把jdk包与tomcat的包下载到/opt目录 3.然后用命令查看安装好的版本号 4.安装好jdk后需要把他设置到环境变量中去以免系统找不到 5.下来装tomcat 6.首先把你的tomcat安装包解压 然后包名太长简写成tomcat 7.进…

【单片机】STM32单片机频率计程序,外部脉冲计数程序,基于脉冲计数的频率计程序,STM32F103

文章目录 定时器外部脉冲计数功能程序实现TIM1 定时一秒钟中断TIM2 外部脉冲计数配置TIM3 PWM产生总程序 定时器外部脉冲计数功能 两种方法用于在单片机中实现频率计的功能。 第一种方法是通过定时器来衡量信号的周期&#xff0c;然后将周期转换为频率。在这种方法中&#xf…

异步秒杀逻辑前后端

前端 Sckill.vue <template><div><h2>go语言从入门到放弃</h2><el-button type"danger" click"handleSckill">秒杀</el-button></div> </template><script> export default {name: "Sckill&q…

华为认证 | HCIP-Datacom考试费多少?

华为认证之前版本的是RS体系的&#xff0c;也是我们所称的路由交换宣布结束。 Datacom代替了RS&#xff0c;考试内容和形式也发生了变化&#xff0c;今天就让我们来详细了解一下。 01 HCIP-Datacom考什么 HCIP-Datacom系列认证包含1门核心认证Core Technology和6门子认证。 …

MySQL数据库的优化技术二

纵论 对mysql优化时一个综合性的技术&#xff0c;主要包括 表的设计合理化(符合3NF)添加适当索引(index) [ 四种: 普通索引、主键索引、唯一索引unique、全文索引 ]分表技术( 水平分割、垂直分割 ) 水平分割根据一个标准重复定义几个字段值相同&#xff0c;表名称不同的表&…

七.图像处理与光学之镜头LSC

七.图像处理与光学之镜头LSC Lens Shading视为镜头阴影/镜头暗影,此外,还有称Lens Shading为亮度均匀性的。 7.1 LSC(Lens Shading Correction)现象 具体现象如下: 如图所示: 拍摄纯灰色卡(正常所有像素值一样)时shading的具体现象 上侧称为Luma shading,下侧称为c…

Camtasia Studio 2023永久激活版免费下载

电脑屏幕录制工具Camtasia Studio2023是一款功能非常强大的电脑屏幕录像软件&#xff0c;这款软件目前在国内非常受欢迎。我们可以通过这款软件来录制各种软件使用教程和游戏攻略教程。这样你就可以将自己在电脑上的每一步操作全部录制下来&#xff0c;从而分享给其他人欣赏。 …

logback日志的分片压缩

logback-spring.xml <?xml version"1.0" encoding"UTF-8"?> <configuration debug"true"><springProperty name"LOG_PATH" source"shands.log.logPath" defaultValue"/var/delonix/logs/local"…

在 Jetpack Compose 中创建 AppBar

Jetpack Compose 是 Android 的现代 UI 工具库&#xff0c;使用声明性编程简化了 UI 的开发过程。在本文中&#xff0c;我们将学习如何使用 Jetpack Compose 创建 AppBar。 什么是 AppBar&#xff1f; AppBar&#xff0c;也就是我们常说的顶部应用栏&#xff0c;是用户界面的一…

基于YOLOv5系列【n/s/m/l】模型开发构建人体手势目标检测识别分析系统

人体手势检测识别是指通过计算机视觉和深度学习技术&#xff0c;自动地识别和理解人体的手势动作。这项技术可以应用于各种领域&#xff0c;如人机交互、虚拟现实、智能监控等。 下面是一般的人体手势检测识别流程&#xff1a; 数据采集&#xff1a;首先需要收集包含手势动作的…

AI自动生成代码,是时候冷静下来思考如何保障代码安全了

HDC期间可参与华为开发者大会Check新人抽奖活动&#xff0c;活动链接在文末。 华为开发者大会2023将于7月7日与各位开发者进行见面&#xff0c;本次大会的主题演讲内容为&#xff1a;AI重塑千行百业。 自从AI聊天被推出之后&#xff0c;其热度就一直是高居不下。身边的小伙伴们…

【python】—— 基础语法(二)

序言&#xff1a; 在上期&#xff0c;我们已经对python进行了初步的学习和了解。本期&#xff0c;我将继续带领大家学习关于python的基本知识&#xff01;&#xff01; 目录 &#xff08;一&#xff09;顺序语句 &#xff08;二&#xff09;条件语句 1、什么是条件语句 2、…

基于springboot实现的健身房系统(免费)

1.1 项目概述 开发语言&#xff1a;Java8 数据库&#xff1a;MySQL5.7以上版本 前端技术&#xff1a;template模板引擎 后端技术&#xff1a;Spring SpringMVC MyBaties shiro 数据库连接池&#xff1a;Druid 服务器&#xff1a;Tomcat 开发工具&#xff1a;idea na…

【C语言】十大经典排序算法-动图演示

目录 0、算法概述 0.1 算法分类 0.2 算法复杂度 1、冒泡排序&#xff08;Bubble Sort&#xff09; 1.1 算法描述 1.3 代码实现 2、选择排序&#xff08;Selection Sort&#xff09; 2.1 算法描述 2.2 动图演示 2.3 代码实现 2.4 算法分析 3、插入排序&#xff08;I…