Count-based exploration with neural density models论文笔记

news2025/1/10 20:32:36

Count-based exploration with neural density models[J]. International Conference on Machine Learning,International Conference on Machine Learning, 2017.

基于计数的神经密度模型探索

0、问题

这篇文章的关键在于弄懂pseudo-count的概念,以及是如何运用pseudo-count去进行探索的。pseudo-count主要用于生成探索奖励,即可以理解为生成内在奖励。

但是仍然保留一个疑问为在使用PixelCNN得到状态st的概率密度和状态st+1的概率密度后,为何不适用
N ^ n ( x ) = ρ n ( x ) ( 1 − ρ n ′ ( x ) ) ρ n ′ ( x ) − ρ n ( x ) \hat{\mathrm{N}}_n(x)=\frac{\rho_n(x)(1-\rho_n'(x))}{\rho_n'(x)-\rho_n(x)} N^n(x)=ρn(x)ρn(x)ρn(x)(1ρn(x))
这一公式来直接计算pseudo-count,这里的计算好像还用的是等号?而使用
N ^ n ( x ) ≈ ( e P G n ( x ) − 1 ) − 1 \hat{\mathrm{N}}_n(x)\approx\left(e^{\mathrm{PG}_n(x)}-1\right)^{-1} N^n(x)(ePGn(x)1)1
这一公式来近似计算pseudo-count?

1、Motivation

本文主要是针对强化学习中的智能体探索方面,提出了一种基于计数的探索方式。

在强化学习中,dynamics(动态模型)指的是对环境的模拟或建模。它描述了智能体与环境互动的方式,包括智能体采取行动后环境如何变化以及智能体所观察到的状态转换。

“pseudo-count”(伪计数)是一个在统计学和机器学习中常用的概念。它指的是一种人为引入的计数,用于对现有数据的不确定性进行建模。

密度模型(Density Model)是一种用于建模概率密度函数的数学模型,它可以用来描述或预测随机变量的分布。密度模型在统计学、概率论、信息论、机器学习等领域中得到了广泛应用。

2、Background

状态的概率密度模型p
ρ ( x ) = P ( X n + 1 = x ∣ X 1 … X n = x 1 : n ) = N ^ ( x ) n ^ \rho(x)=P(X_{n+1}=x|X_1\ldots X_n=x_{1:n})=\frac{\hat{N}(x)}{\hat{n}} ρ(x)=P(Xn+1=xX1Xn=x1:n)=n^N^(x)
prediction gain of ρ:
P G n ( x ) = log ⁡ ρ n ′ ( x ) − log ⁡ ρ n ( x ) \mathrm{PG}_n(x)=\log\rho_n^{\prime}(x)-\log\rho_n(x) PGn(x)=logρn(x)logρn(x)
pseudo-count:

  1. N ^ n ( x ) = ρ n ( x ) ( 1 − ρ n ′ ( x ) ) ρ n ′ ( x ) − ρ n ( x ) \hat{\mathrm{N}}_n(x)=\frac{\rho_n(x)(1-\rho_n'(x))}{\rho_n'(x)-\rho_n(x)} N^n(x)=ρn(x)ρn(x)ρn(x)(1ρn(x))

pseudo-count可以用PG来近似:
N ^ n ( x ) ≈ ( e P G n ( x ) − 1 ) − 1 \hat{\mathrm{N}}_n(x)\approx\left(e^{\mathrm{PG}_n(x)}-1\right)^{-1} N^n(x)(ePGn(x)1)1
tips:以上公式的具体推导引用了白辰甲老师的知乎回答强化学习中的探索与利用(count-based) - 知乎 (zhihu.com)

由此可以得到这篇文章中提出的内在奖励公式:
r + ( x ) : = ( N ^ n ( x ) ) − 1 / 2 r^+(x):=(\hat{\mathrm{N}}_n(x))^{-1/2} r+(x):=(N^n(x))1/2
本文估计期望回报采用mixed Monte-Carlo update (MMC)算法:
Q ( x , a ) ← Q ( x , a ) + α [ ( 1 − β ) δ ( x , a ) + β δ M C ( x , a ) ] Q(x,a)\leftarrow Q(x,a)+\alpha\left[(1-\beta)\delta(x,a)+\beta\delta_{\mathsf{MC}}(x,a)\right] Q(x,a)Q(x,a)+α[(1β)δ(x,a)+βδMC(x,a)]

其中:
δ ( x , a ) = r ( x , a ) + γ max ⁡ a ′ Q ( x ′ , a ′ ) − Q ( x , a ) \delta\left(x,a\right)=r(x,a)+\gamma\operatorname*{max}_{a^{\prime}}Q(x^{\prime},a^{\prime})-Q(x,a) δ(x,a)=r(x,a)+γamaxQ(x,a)Q(x,a)

δ MC ( x , a ) = ∑ t = 0 ∞ γ t r ( x t , a t ) − Q ( x , a ) \delta_{\text{MC}}( x , a )=\begin{aligned}\sum_{t=0}^{\infty}\gamma^{t}r(x_{t},a_{t})-Q(x,a)\end{aligned} δMC(x,a)=t=0γtr(xt,at)Q(x,a)

前者为TD算法中的目标值与实际值之差,后者为蒙特卡洛算法中实际回报与实际动作状态价值之差。

3、一些估计Return的算法

各种算法估计Return的利弊:

  1. TD(λ) with important sampling :可以保证收敛,但是重要性采样的系数引入了极大的方差,导致算法的收敛过程不稳定。
  2. Q(λ) :忽略重要性采样系数,直接乘以λ,能保证方差小,但是只有在采样策略和目标策略接近时才可以保证收敛,不安全。
  3. Retrace算法:低方差(控制了重要性采样系数的大小)、安全性高(总是能“安全”地利用各种行为策略采样得到的样本,当behavior policy和target policy差很多的时候,依然能保障收敛性)、样本效率高(对reward的压缩性没有那么高),但是Retrace(λ)算法在学习时过于谨慎,可能无法充分利用探索奖励,因为在计算重要性采样比率时采样的数据会被截断,只有那些足够接近当前策略的状态-行为轨迹才会被保留

估计Return的通用算子:
R Q ( x , a ) : = Q ( x , a ) + E μ [ ∑ t ≥ 0 γ t ( ∏ s = 1 t c s ) ( r t + γ E π Q ( x t + 1 , ⋅ ) − Q ( x t , a t ) ) ] \mathcal{R}Q(x,a):=Q(x,a)+\mathbb{E}_\mu\left[\sum_{t\geq0}\gamma^t(\prod_{s=1}^tc_s)(r_t+\gamma\mathbb{E}_\pi Q(x_{t+1},\cdot)-Q(x_t,a_t))\right] RQ(x,a):=Q(x,a)+Eμ[t0γt(s=1tcs)(rt+γEπQ(xt+1,)Q(xt,at))]
将TD(λ)、Q(λ)、Retrace等算法的不同归结为c_{s}的不同:

  1. TD(λ) with import sampling:
    c s = λ ⋅ π ( a s ∣ x s ) μ ( a s ∣ x s ) c_s=\lambda\cdot\frac{\pi(a_s|x_s)}{\mu(a_s|x_s)} cs=λμ(asxs)π(asxs)

  2. Q(λ):
    c s = λ   c_s=\lambda\ cs=λ 

  3. Retrace(λ):
    c s = λ ⋅ m i n ( 1 , π ( a s ∣ x s ) μ ( a s ∣ x s ) ) c_s=\lambda\cdot min{\left(1,\frac{\pi(a_s|x_s)}{\mu(a_s|x_s)}\right)} cs=λmin(1,μ(asxs)π(asxs))

由于Retrace(λ)使用了在1处截断的Importance Sampling,方差得到了降低。同时,因为
min ⁡ ( 1 , π ( a s ∣ x s ) μ ( a s ∣ x s ) ) ≥ π ( a s ∣ x s ) \min\left(1,\frac{\pi(a_s|x_s)}{\mu(a_s|x_s)}\right)\geq\pi(a_s|x_s) min(1,μ(asxs)π(asxs))π(asxs)
所以Retrace(λ)对回报的压缩幅度更弱(尤其是在两个policy接近时),从而提高了return的利用效率。

tips:对于Retrace(λ)算子的详细推导过程见【Typical RL 19】Retrace - 知乎 (zhihu.com)

4、方法过程

使用PixelCNN,将当前状态作为输入,输出对应状态的概率密度估计,通过对状态概率密度进行计数,计算出每个状态的探索奖励,即越少访问过的状态获得的奖励越高。这样,在选择下一个动作时,在智能体的策略中加入了探索奖励的权重,以鼓励更多地探索未知的状态,从而提高学习效率和收敛速度。

为了确保pseudo-counts与真实计数近似线性增长,PG应该以n^-1的速率衰减。于是将PG_{n}替换为c_{n}*PG_{n},其中c_{n}为:
c n = c n c_n=\frac{c}{\sqrt{n}} cn=n c
文章中通过实验确定c=0.1时结果最好。

由于当神经网络模型的优化器超过局部损失的最小值时,会出现负PG,因此需要给PG设定一个阈值为0,得到最终的伪计数公式为:
N ^ n ( x ) = ( exp ⁡ ( c ⋅ n − 1 / 2 ⋅ ( PG ⁡ n ( x ) ) + ) − 1 ) − 1 \begin{aligned}\hat{\mathrm{N}}_n(x)&=\left(\exp\left(c\cdot n^{-1/2}\cdot(\operatorname{PG}_n(x))_+\right)-1\right)^{-1}\end{aligned} N^n(x)=(exp(cn1/2(PGn(x))+)1)1
因此最终得到的组合探索奖励为:
r t = r ( x , a ) + ( N ^ n ( x ) ) − 1 / 2 r_t=\begin{aligned}r(x,a)+(\hat{\text{N}}_{n}(x))^{-1/2}\end{aligned} rt=r(x,a)+(N^n(x))1/2

总而言之,引入PixelCNN是为了计算PG,进而计算状态的伪计数,将伪计数转化为智能体的内在奖励

5、实验

1、文章通过实验表明探索奖励(exploration bonus)对智能体性能的影响比较均匀,可以在很多游戏中提高智能体的表现。特别是在Reactor-PixelCNN这个环境下,使用探索奖励的效果要比没有探索奖励的Reactor更好,表现为更高的样本利用效率。

2、文章还通过实验说明了在极度难以探索的游戏中,MMC和PixelCNN探索奖励的组合效果最好,两者相辅相成,加快了训练进展并促使智能体达到高性能水平。

3、文章说明了PixelCNN模型在估计探索奖励方面的有效性,这里与CTS模型进行了对比。

4、文章通过实验发现在一定范围内增加 PG scale 可以加快算法的探索速度,并在多次试验中获得记录峰值分数。但是增加 PG scale 也会导致一些问题。因为探索奖励是一个固定的值,如果过度注重探索奖励,可能会导致算法稳定性下降,从而影响长期性能。

6、结论

虽然目前的伪计数理论对密度模型提出了严格的要求,但文章证明PixelCNN可以在更简单和更一般的设置中使用,并且可以完全在线训练。它还被证明与值函数和基于策略的RL算法广泛兼容。

PixelCNN提高了基础RL算法的学习速度和稳定性。

6、算法伪代码(以DQN为base)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1188936.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Leetcode】202. 两数之和

给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你可以按任意顺序返回…

Java 身份证号校验,根据身份证号识别出生地

Java 身份证号校验: import org.apache.commons.lang.StringUtils;import java.util.Calendar; import java.util.Collections; import java.util.HashMap; import java.util.Map;/*** desc 身份证工具类* auth llp* date 2022/7/7 16:13*/ public class IdCardNum…

Java算法(三): 判断两个数组是否为相等 → (要求:长度、顺序、元素)相等

Java算法(三) 需求: 1. 定义一个方法,用于比较两个数组是否相同2. 需求:长度,内容,顺序完全相同package com.liujintao.compare;public class SameArray {public static void main (String[] a…

JAVA微信端医院3D智能导诊系统源码

医院智能导诊系统利用高科技的信息化手段,优化就医流程。让广大患者有序、轻松就医,提升医疗服务水平。 随着人工智能技术的快速发展,语音识别与自然语言理解技术的成熟应用,基于人工智能的智能导诊导医逐渐出现在患者的生活视角中…

小红书达人投放比例是多少合适?品牌方必看

品牌做小红书种草推广想要产生更好的效果,是需要素人和达人按照一定比例去进行投放的,素人铺量可以让产品产生迅速曝光的效果,少量达人投放可以让产品产生更好的转化效果。 小红书达人投放具有较高的互动性和口碑传播效果。达人通过自身的影…

打开pr提示找不到vcomp100.dll无法继续执行代码怎么办?5种dll问题解决方案全解析

vcomp100.dll是一个由Microsoft开发的动态链接库(DLL)文件,它对于许多基于图形的应用程序(如Photoshop)和多个游戏(如《巫师3》)至关重要。以下是关于vcomp100.dll的属性介绍以及找不到vcomp100…

小程序如何部署SSL证书

微信小程序开发前提必须拥有一本SSL证书,办理SSL证书之前确保好指定的微信小程序开发接口使用的域名,如果没有域名的提前申请好,并且到国内服务器提供商去办理备案。 了解微信小程序使用SSL证书的作用,包括以下三个方面&#xff1…

[C语言基础]文件读取模式简析

文件操作 打开方式介绍r / rb模式w / wb模式 打开方式介绍 函数fopen可打开一个文件,返回值是文件指针FILE * 第一个参数是文件路径,第二个参数是打开方式mode 参数可为以下几种: r/w/a/r/w/a/rb/wb/ab/rb/wb/ab 其中, r 为只读&…

求臻医学MRD产品喜获北京市新技术新产品(服务)证书

近日,北京市科学技术委员会、中关村科技园区管理委员会、北京市发展和改革委员会等五大部门联合公示了2023年度第一批(总第十八批)北京市新技术新产品(服务)名单。凭借领先的技术能力、产品创新能力及质量可靠性等优势…

大数据毕业设计选题推荐-河长制大数据监测平台-Hadoop-Spark-Hive

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

VM虚拟机安装

想编译一个 c 代码 windows 转成 linux 安装一个vm 准备一个虚拟机安装包,双击,开始安装 下一步 缸盖安装位置路径,添加PATH,下一步 下一步 添加到桌面,加入开始菜单,下一步 打开桌面的软件图标&#…

Panorama SCADA平台的警报通知功能配置详解

1. 前言 SCADA系统的主要目标是采集与监控工业过程数据,以确保工业生产正常运行。通过实时警报通知功能,操作人员可以立即获取有关潜在问题的信息,从而能够快速采取行动解决问题,防止进一步的损害或生产中断。因此,及…

小程序版本审核未通过,需在开发者后台「版本管理—提交审核——小程序订单中心path」设置订单中心页path,请设置后再提交代码审核

小程序版本审核未通过,需在开发者后台「版本管理—提交审核——小程序订单中心path」设置订单中心页path,请设置后再提交代码审核 因小程序尚未发布,订单中心不能正常打开查看,请先发布小程序后再提交订单中心PATH申请 初次提交…

03【远程协作开发、TortoiseGit、IDEA绑定Git插件的使用】

上一篇:02【Git分支的使用、Git回退、还原】 下一篇:【已完结】 目录:【Git系列教程-目录大纲】 文章目录 一、远程协作开发1.1 远程仓库简介1.1.1 Github1.1.2 Gitee1.1.3 其他托管平台 1.2 发布远程仓库1.2.1 创建项目1) 新…

deeplog中输出某个 event 的概率

1 实现之后效果 # import DeepLog and Preprocessor import numpy as np from deeplog import DeepLog import torch# Create DeepLog object deeplog DeepLog(input_size 10, # Number of different events to expecthidden_size 64 , # Hidden dimension, we suggest 64…

K8s----资源管理

目录 一、Secret 1、创建 Secret 1.1 用kubectl create secret命令创建Secret 1.2 内容用 base64 编码,创建Secret 2、使用方式 2.1 将 Secret 挂载到 Volume 中,以 Volume 的形式挂载到 Pod 的某个目录下 2.2 将 Secret 导出到环境变量中 二、Co…

大数据之LibrA数据库系统告警处理(ALM-12033 慢盘故障)

告警解释 系统每一秒执行一次iostat命令,监控磁盘I/O的系统指标,如果在60s内,svctm大于100ms的周期数大于30次则认为磁盘有问题,产生该告警。 更换磁盘后,告警自动恢复。 告警属性 告警ID 告警级别 可自动清除 1…

99% 用户都不知道的 Power BI / Power Query 隐藏功能

Power Query 有一个被糟糕的翻译耽误了的宝藏功能,我估计绝大多数的用户都没发现。 在 Power Query —— 视图 —— 数据预览 下,有几个奇怪的选项 “列分发”、“列配置文件”、“列质量”,从名字根本看不出来是做什么的! 看英文…

sm2加密算法

sm2是一种非对称加密算法。在非对称加密中,加密和解密使用的是不同的密钥对,分别是公钥和私钥。SM2算法是由中国国家密码管理局制定的一种椭圆曲线非对称加密算法,用于数字签名、密钥协商等安全通信场景。 这里使用hutool工具类 Hutool 支持对…

element-ui中el-table数据合并行和列,应该怎么解决

最近接到一个任务,要实现一个数据报表,涉及到很多合并问题,一开始想着原生会简单点,实际上很麻烦,最后还是用elemen-ui中table自带的合并方法. 最终的效果是要做成这种:1.数据处理,后端返回来的数据是,一个大对象,包含三个数组,既然合并,肯定是要处理成一个数组,并且要把相同的…