DDPG算法详解

news2025/1/20 1:55:49

DQN算法详解

一.概述

概括来说,RL要解决的问题是:让agent学习在一个环境中的如何行为动作(act), 从而获得最大的奖励值总和(total reward)。
这个奖励值一般与agent定义的任务目标关联。
agent需要的主要学习内容:第一是行为策略(action policy), 第二是规划(planning)。
其中,行为策略的学习目标是最优策略, 也就是使用这样的策略,可以让agent在特定环境中的行为获得最大的奖励值,从而实现其任务目标。

行为(action)可以简单分为:

  • 连续的:如赛 车游戏中的方向盘角度、油门、刹车控制信号,机器人的关节伺服电机控制信号。
  • 离散的:如围棋、贪吃蛇游戏。 Alpha Go就是一个典型的离散行为agent。

DDPG是针对连续行为的策略学习方法。

二.DDPG的定义和应用场景

在RL领域,DDPG主要从:PG -> DPG -> DDPG 发展而来。

基本概念:

PG

R.Sutton 在2000年提出的Policy Gradient 方法,是RL中,学习连续的行为控制策略的经典方法,其提出的解决方案是:
通过一个概率分布函数 , 来表示每一步的最优策略, 在每一步根据该概率分布进行action采样,获得当前的最佳action取值;即:
在这里插入图片描述
生成action的过程,本质上是一个随机过程;最后学习到的策略,也是一个随机策略(stochastic policy).

DPG

Deepmind的D.Silver等在2014年提出DPG: Deterministic Policy Gradient, 即确定性的行为策略,每一步的行为通过函数μ 直接获得确定的值:
在这里插入图片描述
这个函数μ即最优行为策略,不再是一个需要采样的随机策略。

为何需要确定性的策略?简单来说,PG方法有以下缺陷:

  1. 即使通过PG学习得到了随机策略之后,在每一步行为时,我们还需要对得到的最优策略概率分布进行采样,才能获得action的具体值;而action通常是高维的向量,比如25维、50维,在高维的action空间的频繁采样,无疑是很耗费计算能力的;
  2. 在PG的学习过程中,每一步计算policy gradient都需要在整个action space进行积分:
    在这里插入图片描述

DDPG

Deepmind在2016年提出DDPG,全称是:Deep Deterministic Policy Gradient,是将深度学习神经网络融合进DPG的策略学习方法。
相对于DPG的核心改进是: 采用卷积神经网络作为策略函数μ 和Q 函数的模拟,即策略网络和Q网络;然后使用深度学习的方法来训练上述神经网络。

Q函数的实现和训练方法,采用了Deepmind 2015年发表的DQN方法 ,即 Alpha Go使用的Q函数方法。

1.DDQN解决问题方法

  1. 使用卷积神经网络来模拟策略函数和Q函数,并用深度学习的方法来训练,证明了在RL方法中,非线性模拟函数的准确性和高性能、可收敛;
    而DPG中,可以看成使用线性回归的机器学习方法:使用带参数的线性函数来模拟策略函数和Q函数,然后使用线性回归的方法进行训练。
  2. experience replay memory的使用:actor同环境交互时,产生的transition数据序列是在时间上高度关联(correlated)的,如果这些数据序列直接用于训练,会导致神经网络的overfit,不易收敛。
    DDPG的actor将transition数据先存入experience replay buffer, 然后在训练时,从experience replay buffer中随机采样mini-batch数据,这样采样得到的数据可以认为是无关联的。
  3. target 网络和online 网络的使用, 使的学习过程更加稳定,收敛更有保障。
    在这里插入图片描述

四、 代码详解

import gym
import paddle
import paddle.nn as nn
from itertools import count
from paddle.distribution import Normal
import numpy as np
from collections import deque
import random
import paddle.nn.functional as F
# 定义评论家网络结构
# DDPG这种方法与Q学习紧密相关,可以看作是连续动作空间的深度Q学习。
class Critic(nn.Layer):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(3, 256)
        self.fc2 = nn.Linear(256 + 1, 128)
        self.fc3 = nn.Linear(128, 1)
        self.relu = nn.ReLU()

    def forward(self, x, a):
        x = self.relu(self.fc1(x))
        x = paddle.concat((x, a), axis=1)
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 定义演员网络结构
# 为了使DDPG策略更好地进行探索,在训练时对其行为增加了干扰。 原始DDPG论文的作者建议使用时间相关的 OU噪声 ,
# 但最近的结果表明,不相关的均值零高斯噪声效果很好。 由于后者更简单,因此是首选。
class Actor(nn.Layer):
    def __init__(self, is_train=True):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.noisy = Normal(0, 0.2)
        self.is_train = is_train

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.tanh(self.fc3(x))
        return x

    def select_action(self, epsilon, state):
        state = paddle.to_tensor(state,dtype="float32").unsqueeze(0)
        with paddle.no_grad():
            action = self.forward(state).squeeze() + self.is_train * epsilon * self.noisy.sample([1]).squeeze(0)
        return 2 * paddle.clip(action, -1, 1).numpy()

# 重播缓冲区:这是智能体以前的经验, 为了使算法具有稳定的行为,重播缓冲区应该足够大以包含广泛的体验。
# 如果仅使用最新数据,则可能会过分拟合,如果使用过多的经验,则可能会减慢模型的学习速度。 这可能需要一些调整才能正确。
class Memory(object):
    def __init__(self, memory_size: int) -> None:
        self.memory_size = memory_size
        self.buffer = deque(maxlen=self.memory_size)

    def add(self, experience) -> None:
        self.buffer.append(experience)

    def size(self):
        return len(self.buffer)

    def sample(self, batch_size: int, continuous: bool = True):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        if continuous:
            rand = random.randint(0, len(self.buffer) - batch_size)
            return [self.buffer[i] for i in range(rand, rand + batch_size)]
        else:
            indexes = np.random.choice(np.arange(len(self.buffer)), size=batch_size, replace=False)
            return [self.buffer[i] for i in indexes]

    def clear(self):
        self.buffer.clear()


# 定义软更新的函数
def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.set_value(target_param * (1.0 - tau) + param * tau)

    # 定义环境、实例化模型


env = gym.make('Pendulum-v1')
actor = Actor()
critic = Critic()
actor_target = Actor()
critic_target = Critic()

# 定义优化器
critic_optim = paddle.optimizer.Adam(parameters=critic.parameters(), learning_rate=3e-5)
actor_optim = paddle.optimizer.Adam(parameters=actor.parameters(), learning_rate=1e-5)

# 定义超参数
explore = 50000
epsilon = 1
gamma = 0.99
tau = 0.001

memory_replay = Memory(50000)
begin_train = False
batch_size = 32

learn_steps = 0

#writer = LogWriter('logs')

# 训练循环
for epoch in count():
    state = env.reset()
    episode_reward = 0
    for time_step in range(200):
        action = actor.select_action(epsilon, state)
        next_state, reward, done, _ = env.step([action])
        episode_reward += reward
        reward = (reward + 8.1) / 8.1
        memory_replay.add((state, next_state, action, reward))
        if memory_replay.size() > 1280:

            learn_steps += 1
            if not begin_train:
                print('train begin!')
                begin_train = True
            experiences = memory_replay.sample(batch_size, False)
            batch_state, batch_next_state, batch_action, batch_reward = zip(*experiences)

            batch_state = paddle.to_tensor(batch_state, dtype="float32")
            batch_next_state = paddle.to_tensor(batch_next_state, dtype="float32")
            batch_action = paddle.to_tensor(batch_action, dtype="float32").unsqueeze(1)
            batch_reward = paddle.to_tensor(batch_reward, dtype="float32").unsqueeze(1)

            # 均方误差 y - Q(s, a) , y是目标网络所看到的预期收益, 而 Q(s, a)是Critic网络预测的操作值。
            # y是一个移动的目标,评论者模型试图实现的目标;这个目标通过缓慢的更新目标模型来保持稳定。
            with paddle.no_grad():
                Q_next = critic_target(batch_next_state, actor_target(batch_next_state))
                Q_target = batch_reward + gamma * Q_next

            critic_loss = F.mse_loss(critic(batch_state, batch_action), Q_target)

            critic_optim.clear_grad()
            critic_loss.backward()
            critic_optim.step()

            #writer.add_scalar('critic loss', critic_loss.numpy(), learn_steps)
            # 使用Critic网络给定值的平均值来评价Actor网络采取的行动。 我们力求使这一数值最大化。
            # 因此,我们更新了Actor网络,对于一个给定状态,它产生的动作尽量让Critic网络给出高的评分。
            critic.eval()
            actor_loss = - critic(batch_state, actor(batch_state))
            # print(actor_loss.shape)
            actor_loss = actor_loss.mean()
            actor_optim.clear_grad()
            actor_loss.backward()
            actor_optim.step()
            critic.train()
            #writer.add_scalar('actor loss', actor_loss.numpy(), learn_steps)

            soft_update(actor_target, actor, tau)
            soft_update(critic_target, critic, tau)
            env.render()
        if epsilon > 0:
            epsilon -= 1 / explore
        state = next_state

    #writer.add_scalar('episode reward', episode_reward, epoch)
    if epoch % 10 == 0:
        print('Epoch:{}, episode reward is {}'.format(epoch, episode_reward))

    if epoch % 200 == 0:
        paddle.save(actor.state_dict(), 'model/ddpg-actor' + str(epoch) + '.para')
        paddle.save(critic.state_dict(), 'model/ddpg-critic' + str(epoch) + '.para')
        print('model saved!')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/434765.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode刷题(5)

各位朋友们,大家好,今天是我leedcode刷题的第五篇,我们一起来看看吧。 文章目录 栈的压入,弹出序列题目要求用例输入提示做题思路代码实现C语言代码实现Java代码实现 最小栈题目要求用例输入提示做题思路代码实现Java代码实现 栈的…

QML自定义模块及qmldir的使用

前言 在开发QtQuick项目中,当项目文件很多的情况下,可能会分成多级文件夹来进行分类,还有一些通用类型文件,如公共组件,通用配置等等,需要在各个不同的文件中进行调用,这种情况下,一…

04、Cadence使用记录之器件连接的连线、网络、总线、差分(OrCAD Capture CIS)

04、Cadence使用记录之器件连接的连线、页内网络、总线、跨页网络、差分、电源(OrCAD Capture CIS原理图) 前置教程: 01、Cadence使用记录之新建工程与基础操作(原理图绘制:OrCAD Capture CIS) 02、Cadenc…

操作系统原理 —— 操作系统的四个特征:并发、共享、虚拟、异步 (二)

本章我们来聊一下操作系统的四个特征 在我们的操作系统中有四个特征:并发、共享、虚拟、异步,我们结合每一个特征来进行讲解,我们先来看并发。 并发 这里所说的并发,最好不联想到并发编程。咱们就简简单单理解一下,…

浙工商机器学习课程论文+代码分享(含数据集)

文章目录 一、论文总览二、摘要 & 目录三、数据集的展示四、部分代码4.1 降低内存4.2 部分特征生成4.3 热力图分析4.4 变量分布图4.5 聚类算法4.6 聚类结果的展示(部分)4.7 聚类后的特征图 完整版的论文代码数据集地址: https://mbd.pub…

leetcode刷题(7)二叉树(1)

哈喽大家好,这是我leetcode刷题的第七篇,这两天我将更新leetcode上关于二叉树方面的题目,如果大家对这方面感兴趣的话,欢迎大家持续关注,谢谢大家。 那么我们就进入今天的主题。 文章目录 1.二叉树的前序遍历题目要求示…

优先级队列

目录 前言: 1、PriorityQueue的特性 .2 PriorityQueue常用接口介绍 Ⅰ、PriorityQueue常见的构造方法 Ⅱ、常用的方法 Ⅲ、PriorityQueue的扩容方式: 3、应用 前言: 普通的队列是一种 先进先出的数据结构,元素在队列尾追加&am…

RC专题:无源滤波电路和有源滤波电路

什么是无源滤波电路和有源滤波电路 仅由无源器件(电阻、电容、电感)构成的滤波电路 称为无源滤波电路。如下图所示。 由无源器件和有源器件(双极型管,单极型管,集成运放)构成的滤波电路 称为有源滤波电路。…

什么是爬虫?

网络爬虫(又被称为网页蜘蛛,网络机器人,在FOAF社区中间,更经常的称为网页追逐者),是一种按照一定的规则,自动地抓取万维网信息的程序或者脚本。另外一些不常使用的名字还有蚂蚁、自动索引、模拟…

2023第十三届MathorCup高校数学建模挑战赛C题解析

2023第十三届MathorCup高校数学建模挑战赛C题解析 题目解析前言题目一题目二题目三题目四 题目 C 题 电商物流网络包裹应急调运与结构优化问题 电商物流网络由物流场地(接货仓、分拣中心、营业部等)和物流场地之间的运输线路组成,如图 1 所示…

LeetCode:1. 两数之和——哈希表~

🍎道阻且长,行则将至。🍓 🌻算法,不如说它是一种思考方式🍀 算法专栏: 👉🏻123 一、🌱1. 两数之和 题目描述:给定一个整数数组nums 和一个整数目…

QT 插件通信接口调用 CTK开发(四)

CTK 为支持生物医学图像计算的公共开发包,其全称为 Common Toolkit。为医学成像提供一组统一的基本功能;促进代码和数据的交互及结合;避免重复开发;在工具包(医学成像)范围内不断扩展到新任务,而不会增加现有任务的负担;整合并适应成功的解决方案。 本专栏文章较为全面…

leetcode python刷题记录(十)(91~100)

leetcode python刷题记录(十)(91~100) 91. 解码方法 class Solution:def numDecodings(self, s: str) -> int:if not s or s[0]0:return 0nlen(s)dp[0]*(n1)dp[0]1dp[1]1for i in range(1,n):if s[i]0:if s[i-1]1 or s[i-1]2:…

【算法系列之二叉树I】leetcode226.翻转二叉树

非递归实现前序遍历 力扣题目链接 解决思路 前序遍历&#xff0c;中左右。先放右节点&#xff0c;后放左节点。 Java实现 class Solution {public List<Integer> preorderTraversal(TreeNode root) {//中左右Stack<TreeNode> stack new Stack<>();List…

蓝桥杯:人物相关性分析

蓝桥杯&#xff1a;人物相关性分析https://www.lanqiao.cn/problems/198/learning/ 目录 题目描述 输入描述 输出描述 输入输出样例 输入 输出 输入 输出 运行限制 题目分析:(滑动窗口) AC代码&#xff08;JAVA&#xff09; 题目描述 小明正在分析一本小说中…

【ChatGPT】无需魔法打开即用的 AI 工具集锦

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;蚂蚁集团高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《EffectiveJava》独家解析》专栏作者。 热门文章推荐…

原理+配置+实战,Canal一套带走

前几天在网上冲浪的时候发现了一个比较成熟的开源中间件——Canal。在了解了它的工作原理和使用场景后&#xff0c;顿时产生了浓厚的兴趣。今天&#xff0c;就让我们跟随阿Q的脚步&#xff0c;一起来揭开它神秘的面纱吧。 简介 canal 翻译为管道&#xff0c;主要用途是基于 M…

【设计】【Redis】分布式限流与算法实现

目录 前言实现application.propertiesconfig.RedisConfigMainApplicationcontroller.TrafficLimitControlleraop.AccessLimiterAspectaop.annotation.AccessLimiter 项目结构运行限流脚本计数器滑动窗口令牌桶漏桶 参考资料 前言 服务的某些场景可能会出现短时间内的巨大访问流…

【C语言进阶:动态内存管理】柔性数组

本节重点内容&#xff1a; 柔性数组的特点柔性数组的使用柔性数组的优势 ⚡柔性数组 也许你从来没有听说过柔性数组&#xff08;flexible array&#xff09;这个概念&#xff0c;但是它确实是存在的。C99 中&#xff0c;结构中的最后一个元素允许是未知大小的数组&#xff0c…

java+ssm 社区超市网上商城果蔬(水果蔬菜)管理系统

在Internet高速发展的今天&#xff0c;我们生活的各个领域都涉及到计算机的应用&#xff0c;其中包括超市果蔬管理系统的网络应用&#xff0c;在外国超市果蔬管理系统已经是很普遍的方式&#xff0c;不过国内的超市果蔬管理系统可能还处于起步阶段。超市果蔬管理系统具有果蔬管…