优先经验回放(prioritized experience replay)

news2025/1/18 4:44:40

prioritized experience replay 思路

优先经验回放出自ICLR 2016的论文《prioritized experience replay》。

prioritized experience replay的作者们认为,按照一定的优先级来对经验回放池中的样本采样,相比于随机均匀的从经验回放池中采样的效率更高,可以让模型更快的收敛。其基本思想是RL agent在一些转移样本上可以更有效的学习,也可以解释成“更多地训练会让你意外的数据”。

那优先级如何定义呢?作者们使用的是样本的TD error δ \delta δ 的幅值。对于新生成的样本,TD error未知时,将样本赋值为最大优先级,以保证样本至少将会被采样一次。每个采样样本的概率被定义为
P ( i ) = p i α ∑ k p k α P(i) = \frac {p_i^{\alpha}} {\sum_k p_k^{\alpha}} P(i)=kpkαpiα
上式中的 p i > 0 p_i >0 pi>0是回放池中的第i个样本的优先级, α \alpha α则强调有多重视该优先级,如果 α = 0 \alpha=0 α=0,采样就退化成和基础DQN一样的均匀采样了。

p i p_i pi如何取值,论文中提供了如下两种方法,两种方法都是关于TD error δ \delta δ 单调的:

  • 基于比例的优先级: p i = ∣ δ i ∣ + ϵ p_i = |\delta_i| + \epsilon pi=δi+ϵ ϵ \epsilon ϵ是一个很小的正数常量,防止当TD error为0时样本就不会被访问到的情形。(目前大部分实现都是使用的这个形式的优先级)
  • 基于排序的优先级: p i = 1 r a n k ( i ) p_i = \frac {1}{rank(i)} pi=rank(i)1, 式中的 r a n k ( i ) rank(i) rank(i)是样本根据 ∣ δ i ∣ |\delta_i| δi 在经验回放池中的排序号,此时P就变成了带有指数 α \alpha α的幂率分布了。

作者们定义的概率调整了样本的优先级,因此也就在数据分布中引入了偏差,为了弥补偏差,使用了重要性采样权重(importance-sampling (IS) weights):
w i = ( 1 N ⋅ 1 P ( i ) ) β w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta} wi=(N1P(i)1)β
上式权重中,当 β = 1 \beta=1 β=1时就完全补偿了非均匀概率采样引入的偏差,作者们提到为了收敛性考虑,最后让 β \beta β从0到1中的某个值开始,并逐渐增加到1。在Q-learning更新时使用这些权重乘以TD error,也就是使用 w i δ i w_i \delta_i wiδi而不是原来的 δ i \delta_i δi。此外,为了使训练更稳定,总是对权重乘以 1 / m a x i w i 1/\mathcal{max}_i{w_i} 1/maxiwi进行归一化。

以Double DQN为例,使用优先经验回放的算法(论文算法1)如下图:

在这里插入图片描述

prioritized experience replay 实现

直接实现优先经验回放池如下代码(修改自代码 )

class PrioReplayBufferNaive:
    def __init__(self, buf_size, prob_alpha=0.6, epsilon=1e-5, beta=0.4, beta_increment_per_sampling=0.001):
        self.prob_alpha = prob_alpha
        self.capacity = buf_size
        self.pos = 0
        self.buffer = []
        self.priorities = np.zeros((buf_size, ), dtype=np.float32)
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.epsilon = epsilon

    def __len__(self):
        return len(self.buffer)

    def size(self):  # 目前buffer中数据的数量
        return len(self.buffer)

    def add(self, sample):
        # 新加入的数据使用最大的优先级,保证数据尽可能的被采样到
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(sample)
        else:
            self.buffer[self.pos] = sample
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]
        probs = np.array(prios, dtype=np.float32) ** self.prob_alpha

        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)
        samples = [self.buffer[idx] for idx in indices]
        total = len(self.buffer)
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])
        weights = (total * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        return samples, indices, np.array(weights, dtype=np.float32)

    def update_priorities(self, batch_indices, batch_priorities):
        '''
        更新样本的优先级'''
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio + self.epsilon

直接实现的优先经验回放,在样本数很大时的采样效率不够高,作者们通过定义了sumtree的数据结构来存储样本优先级,该数据结构的每一个节点的值为其子节点之和,而样本优先级被放在树的叶子节点上,树的根节点的值为所有优先级之和 p t o t a l p_{total} ptotal,更新和采样时的效率为 O ( l o g N ) O(logN) O(logN)。在采样时,设采样批次大小为k,将 [ 0 , p t o t a l ] [0, p_{total}] [0,ptotal]均分为k等份,然后在每一个区间均匀的采样一个值,再通过该值从树中提取到对应的样本。python 实现如下(代码来源)

class SumTree:
    """
    父节点的值是其子节点值之和的二叉树数据结构
    """
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0

    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1

    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])


class PrioReplayBuffer:  # stored as ( s, a, r, s_ ) in SumTree
    epsilon = 0.01
    alpha = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        return (np.abs(error) + self.epsilon) ** self.alpha

    def add(self, error, sample):
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            a = segment * i
            b = segment * (i + 1)

            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = priorities / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()

        return batch, idxs, is_weight

    def update(self, idx, error):
      '''
      这里是一次更新一个样本,所以在调用时,写for循环依次更次样本的优先级
      '''
        p = self._get_priority(error)
        self.tree.update(idx, p)

参考资料

  1. Schaul, Tom, John Quan, Ioannis Antonoglou, and David Silver. 2015. “Prioritized Experience Replay.” arXiv: Learning,arXiv: Learning, November.

  2. sum_tree的实现代码

  3. 相关blog: 1 (对应的代码), 2, 3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1239382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第一百七十六回 如何创建渐变色边角

文章目录 1. 概念介绍2. 实现方法3. 代码与细节3.1 示例代码3.2 代码细节 4. 内容总结 我们在上一章回中介绍了"如何创建放射形状渐变背景"相关的内容&#xff0c;本章回中将介绍"如何创建渐变色边角".闲话休提&#xff0c;让我们一起Talk Flutter吧。 1.…

2023年【广东省安全员B证第四批(项目负责人)】报名考试及广东省安全员B证第四批(项目负责人)复审考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 广东省安全员B证第四批&#xff08;项目负责人&#xff09;报名考试是安全生产模拟考试一点通总题库中生成的一套广东省安全员B证第四批&#xff08;项目负责人&#xff09;复审考试&#xff0c;安全生产模拟考试一点…

智能安全帽作业记录仪赋能智慧工地人脸识别劳务实名制

需求背景 建筑工地是一个安全事故多发的场所。目前&#xff0c;工程建设规模不断扩大&#xff0c;工艺流程纷繁复杂&#xff0c;如何完善现场施工现场管理&#xff0c;控制事故发生频率&#xff0c;保障文明施工一直是施工企业、政府管理部门关注的焦点。尤其随着社会的不断进…

18.天气小案例

1►新增带Layout组件的页面 直接在views文件夹下面新增weather.vue。然后随便写一个123&#xff0c;现在先让我们页面能跳过去先。 让页面能跳过去&#xff0c;有好几种方法&#xff1a; 1、在菜单管理自己添加一个菜单&#xff0c;然后把菜单分配给某个角色&#xff0c;再把…

【Python进阶】近200页md文档14大体系第4篇:Python进程使用详解(图文演示)

本文从14大模块展示了python高级用的应用。分别有Linux命令&#xff0c;多任务编程、网络编程、Http协议和静态Web编程、htmlcss、JavaScript、jQuery、MySql数据库的各种用法、python的闭包和装饰器、mini-web框架、正则表达式等相关文章的详细讲述。 Python全套笔记直接地址…

【iOS】实现评论区展开效果

文章目录 前言实现行高自适应实现评论展开效果解决cell中的buttom的复用问题 前言 在知乎日报的评论区中&#xff0c;用到了Masonry行高自适应来实现评论的展开&#xff0c;这里设计许多控件的约束问题&#xff0c;当时困扰了笔者许久&#xff0c;特此撰写博客记录 实现行高自…

GB28181学习(十七)——基于jrtplib实现tcp被动和主动发流

前言 GB/T28181-2022实时流的传输方式介绍&#xff1a;https://blog.csdn.net/www_dong/article/details/134255185 基于jrtplib实现tcp被动和主动收流介绍&#xff1a;https://blog.csdn.net/www_dong/article/details/134451387 本文主要介绍下级平台或设备发流功能&#…

微信小游戏上线流程

微信小游戏上线是一个需要经过一系列步骤的过程。以下是一个一般性的微信小游戏上线流程&#xff0c;请注意&#xff0c;上述步骤可能会有微信平台的政策和规定的变化&#xff0c;因此建议在开发过程中及时查阅微信小游戏的官方文档和最新政策。北京木奇移动技术有限公司&#…

DB2—03(DB2中常见基础操作)

DB2—03&#xff08;DB2中常见基础操作&#xff09; 1. 前言1.1 oracle和mysql相关 2. db2中的"dual"2.1 SYSIBM.SYSDUMMY12.2 使用VALUES2.3 SYSIBM.SYSDUMMY1 "变" dual 3. db2中常用函数3.1 nvl()、value()、COALESCE()3.2 NULLIF() 函数3.3 LISTAGG() …

含分布式电源的配电网可靠性评估matlab程序

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 参考文献&#xff1a; 基于仿射最小路法的含分布式电源配电网可靠性分析——熊小萍 主要内容&#xff1a; 通过概率模型和时序模型分别进行建模&#xff0c;实现基于概率模型最小路法的含分布式电源配电网…

Python BDD之Behave测试报告

behave 本身的测试报告 behave 本身提供了四种报告格式&#xff1a; pretty&#xff1a;这是默认的报告格式&#xff0c;提供颜色化的文本输出&#xff0c;每个测试步骤的结果都会详细列出。plain&#xff1a;这也是一种文本格式的报告&#xff0c;但没有颜色&#xff0c;并且…

【C++】泛型编程 ⑫ ( 类模板 static 关键字 | 类模板 static 静态成员 | 类模板使用流程 )

文章目录 一、类模板使用流程1、类模板 定义流程2、类模板 使用3、类模板 函数 外部实现 二、类模板 static 关键字1、类模板 static 静态成员2、类模板 static 关键字 用法3、完整代码示例 将 类模板 函数声明 与 函数实现 分开进行编码 , 有 三种 方式 : 类模板 的 函数声明…

超详细!新手必看!STM32-通用定时器简介与知识点概括

一、通用定时器的功能 在基本定时器功能的基础上新增功能&#xff1a; 通用定时器有4个独立通道&#xff0c;且每个通道都可以用于下面功能。 &#xff08;1&#xff09;输入捕获&#xff1a;测量输入信号的周期和占空比等。 &#xff08;2&#xff09;输出比较&#xff1a;产…

电大搜题——让学习变得轻松高效

作为一名现代学者&#xff0c;您一定时刻关注着教育领域的进展和创新。今天&#xff0c;我将向大家介绍一个名为“电大搜题”的神奇工具&#xff0c;它将为您的学习之路带来一场完美的革命。 在快节奏的现代社会中&#xff0c;学习已经成为每个人追求成功的必经之路。然而&…

Linux超简单部署个人博客

1 安装halo 1.1 切换到超级用户 sudo -i 1.2 新建halo文件夹 mkdir ~/halo && cd ~/halo 1.3 编辑docker-compose.yml文件 vim ~/halo/docker-compose.yml 英文输入法下&#xff0c;按 i version: "3"services:halo:image: halohub/halo:2.10container_…

2023年【起重机械指挥】考试题及起重机械指挥找解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 起重机械指挥考试题考前必练&#xff01;安全生产模拟考试一点通每个月更新起重机械指挥找解析题目及答案&#xff01;多做几遍&#xff0c;其实通过起重机械指挥作业考试题库很简单。 1、【多选题】按照事故造成的人…

新手必看!!附源码!!STM32通用定时器输出PWM

一、什么是PWM? PWM&#xff08;脉冲宽度调制&#xff09;是一种用于控制电子设备的技术。它通过调整信号的脉冲宽度来控制电压的平均值。PWM常用于调节电机速度、控制LED亮度、产生模拟信号等应用。 二、PWM的原理 PWM的基本原理是通过以一定频率产生的脉冲信号&#xff0…

Java 编码

编码: 加密: 通过加密算法和密钥进行 也可通过码表进行加密 对称加密: 缺点:可被截获 元数据---加密算法密钥密文 ----> 解密算法密钥元数据 算法:DES(短 56位),AES(长 128位)破解时间加长 非对称加密: 元数据-加密算法加密密钥 密文 --->加密算法解密密钥元数据 …

亚马逊买家号用邮箱怎么注册

想要用邮箱注册亚马逊买家号&#xff0c;那么准备好能接受验证码的邮箱后打开相应的亚马逊官网即可。打开官网后点击注册——输入昵称——输入邮箱——输入密码——接受邮箱验证码并输入&#xff0c;如果遇到需要手机号验证就输入手机号&#xff0c;如果不需要验证&#xff0c;…

Mac M1 安装Docker打包arm64的python项目的镜像包

1、首先安装Docker&#xff0c;到官网下载&#xff0c;选择apple chip版 Docker中文网 官网 2、双击下载的dmg文件&#xff0c;在弹出框中之间拖拽到右边 3、打开docker&#xff0c;修改国内镜像源&#xff0c;位置在配置-DockerEngine "registry-mirrors": ["…