AC的改进算法——TRPO、PPO

news2025/1/11 12:49:34

两类AC的改进算法

整理了动手学强化学习的学习内容

1. TRPO 算法(Trust Region Policy Optimization)

1.1. 前沿

策略梯度算法即沿着梯度方向迭代更新策略参数 。但是这种算法有一个明显的缺点:当策略网络沿着策略梯度更新参数,可能由于步长太长,策略突然显著变差,进而影响训练效果。

针对以上问题,考虑在更新时找到一块信任区域(trust region),在这个区域上更新策略时能够得到某种策略性能的安全性保证,这就是信任区域策略优化(trust region policy optimization,TRPO)算法的主要思想。

1.2. 一些推导

首先,最常规的动作价值函数,状态价值函数,优势函数定义如下:
在这里插入图片描述接着,一个策略的好坏可以期望折扣奖励 J ( π θ ) J(\pi_\theta) J(πθ)表示:
J ( π θ ) = E s 0 , a 0 , . . . [ ∑ t = 0 ∞ γ t r ( s t ) ] = E s 0 [ V π θ ( s 0 ) ] J(\pi_\theta)=E_{s_0,a_0,...}[\sum_{t=0}^{\infty}\gamma^tr(s_t)]=E_{s_0}[V^{\pi_\theta}(s_0)] J(πθ)=Es0,a0,...[t=0γtr(st)]=Es0[Vπθ(s0)]

其中, s 0 ∼ ρ 0 ( s 0 ) s_0 \sim \rho_0(s_0) s0ρ0(s0) a t ∼ π θ ( a t ∣ s t ) a_t \sim \pi_\theta(a_t|s_t) atπθ(atst) a t + 1 ∼ P ( s t + 1 ∣ s t , a t ) a_{t+1} \sim P(s_{t+1}|s_t,a_t) at+1P(st+1st,at)
由于初始状态 s 0 s_0 s0的分布 ρ 0 \rho_0 ρ0和策略无关,因此上述策略 π θ \pi_\theta πθ下的优化目标 J ( π θ ) J(\pi_\theta) J(πθ)可以写成在新策略 π θ ′ \pi_{\theta'} πθ的期望形式:
在这里插入图片描述从而,推导新旧策略的目标函数之间的差距:
deltaJ=A将时序差分残差定义为优势函数A:
在这里插入图片描述所以只要我们能找到一个新策略,使得 J ( θ ′ ) − J ( θ ) > = 0 J(\theta')-J(\theta)>=0 J(θ)J(θ)>=0,就能保证策略性能单调递增。

但是直接求解该式是非常困难的,因为 π θ ′ \pi_{\theta'} πθ是我们需要求解的策略,但我们又要用它来收集样本。把所有可能的新策略都拿来收集数据,然后判断哪个策略满足上述条件的做法显然是不现实的。

于是 TRPO 做了一步近似操作,对状态访问分布进行了相应处理。具体而言,忽略两个策略之间的状态访问分布变化,直接采用旧的策略的状态分布,定义如下替代优化目标:
在这里插入图片描述当新旧策略非常接近时,状态访问分布变化很小,这么近似是合理的。其中,动作仍然用新策略 π θ ′ \pi_{\theta'} πθ采样得到,我们可以用重要性采样对动作分布进行处理:
在这里插入图片描述为了保证新旧策略足够接近,TRPO 使用了KL散度来衡量策略之间的距离,并给出了整体的优化公式:
优化这里的不等式约束定义了策略空间中的一个 KL 球,被称为信任区域。在这个区域中,可以认为当前学习策略和环境交互的状态分布与上一轮策略最后采样的状态分布一致,进而可以基于一步行动的重要性采样方法使当前学习策略稳定提升。

1.3. 近似求解

直接求解上式带约束的优化问题比较麻烦,TRPO 在其具体实现中做了一步近似操作来快速求解。
对目标函数和约束在 θ k \theta_k θk进行泰勒展开,分别用 1 阶、2 阶进行近似:
在这里插入图片描述于是我们的优化目标变成了:
在这里插入图片描述此时,我们可以用KKT条件直接导出上述问题的解:
解

1.4. 共轭梯度

一般来说,用神经网络表示的策略函数的参数数量都是成千上万的,计算和存储黑塞矩阵的逆矩阵会耗费大量的内存资源和时间。

TRPO 通过共轭梯度法(conjugate gradient method)回避了这个问题,它的核心思想是直接计算 x = H − 1 g x=H^{-1}g x=H1g x x x即参数更新方向。假设满足 KL距离约束的参数更新时的最大步长为 β = θ ′ − θ \beta=\theta'-\theta β=θθ
于是,根据 KL 距离约束条件 1 2 ( θ ′ − θ k ) T H ( θ ′ − θ k ) < = δ \frac{1}{2}(\theta'-\theta_k)^TH(\theta'-\theta_k)<=\delta 21(θθk)TH(θθk)<=δ,有 1 2 ( β x ) T H ( β x ) = δ \frac{1}{2}(\beta x)^TH(\beta x)=\delta 21(βx)TH(βx)=δ。求解 β \beta β,得到 β = 2 δ x T H x \beta=\sqrt{\frac{2\delta}{x^THx}} β=xTHx2δ 。因此,此时参数更新方式为
θ k + 1 = θ k + 2 δ x T H x x \theta_{k+1}=\theta_k+\sqrt{\frac{2\delta}{x^THx}}x θk+1=θk+xTHx2δ x
因此,只要可以直接计算 x = H − 1 g x=H^{-1}g x=H1g,就可以根据该式更新参数,问题转化为解 H x = g Hx=g Hx=g。实际上 H H H为对称正定矩阵,所以我们可以使用共轭梯度法来求解。
共轭梯度法的具体流程如下:
在这里插入图片描述在共轭梯度运算过程中,直接计算 α k \alpha_k αk r k + 1 r_{k+1} rk+1需要计算和存储海森矩阵 H H H。为了避免这种大矩阵的出现,我们只计算 H x Hx Hx向量,而不直接计算和存储 H H H矩阵。这样做比较容易,因为对于任意的列向量 v v v,容易验证:
Hv即先用梯度和向量 v v v点乘后计算梯度。

    def hessian_matrix_vector_product(self, states, old_action_dists, vector):
        # 计算黑塞矩阵和一个向量的乘积
        new_action_dists = torch.distributions.Categorical(self.actor(states))
        kl = torch.mean(
            torch.distributions.kl.kl_divergence(old_action_dists,
                                                 new_action_dists))  # 计算平均KL距离
        kl_grad = torch.autograd.grad(kl,
                                      self.actor.parameters(),
                                      create_graph=True)
        kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])
        # KL距离的梯度先和向量进行点积运算
        kl_grad_vector_product = torch.dot(kl_grad_vector, vector)
        grad2 = torch.autograd.grad(kl_grad_vector_product,
                                    self.actor.parameters())
        grad2_vector = torch.cat([grad.view(-1) for grad in grad2])
        return grad2_vector

    def conjugate_gradient(self, grad, states, old_action_dists):  # 共轭梯度法求解方程
        x = torch.zeros_like(grad)
        r = grad.clone()
        p = grad.clone()
        rdotr = torch.dot(r, r)
        for i in range(10):  # 共轭梯度主循环
            Hp = self.hessian_matrix_vector_product(states, old_action_dists,
                                                    p)
            alpha = rdotr / torch.dot(p, Hp)
            x += alpha * p
            r -= alpha * Hp
            new_rdotr = torch.dot(r, r)
            if new_rdotr < 1e-10:
                break
            beta = new_rdotr / rdotr
            p = r + beta * p
            rdotr = new_rdotr
        return x

1.5. 线性搜索

由于 TRPO 算法用到了泰勒展开的 1 阶和 2 阶近似,这并非精准求解,因此, θ \theta θ可能未必比 θ k \theta_k θk好,或未必能满足 KL 散度限制。TRPO 在每次迭代的最后进行一次线性搜索,以确保找到满足条件。具体来说,就是找到一个最小的非负整数 i i i,使得按照
θ k + 1 = θ k + α i 2 δ x T H x x \theta_{k+1}=\theta_{k}+\alpha^i \sqrt{\frac{2\delta}{x^THx}}x θk+1=θk+αixTHx2δ x

求出的 θ k + 1 \theta_{k+1} θk+1依然满足最初的 KL 散度限制,并且确实能够提升目标函数,这KaTeX parse error: Undefined control sequence: \apha at position 1: \̲a̲p̲h̲a̲ ̲\in (0,1)其中是一个决定线性搜索长度的超参数。

1.6. 总结

至此,我们已经基本上清楚了 TRPO 算法的大致过程,它具体的算法流程如下:
在这里插入图片描述

2. PPO 算法(Trust Region Policy Optimization)

2.1. 前沿

PPO 算法作为TRPO算法的改进版,但是其算法实现更加简单。并且大量的实验结果表明,与TRPO相比,PPO能学习得一样好(甚至更快),这使得PPO成为非常流行的强化学习算法。如果我们想要尝试在一个新的环境中使用强化学习算法,那么 PPO 就属于可以首先尝试的算法。

PPO 的优化目标与 TRPO 相同,但 PPO用了一些相对简单的方法来求解(TRPO 使用泰勒展开近似、共轭梯度、线性搜索等方法直接求解)。具体来说,PPO 有两种形式,一是 PPO-惩罚,二是 PPO-截断,接下来对这两种形式进行介绍。

2.2. PPO-惩罚

PPO-Penalty拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,这就变成了一个无约束的优化问题,在迭代的过程中不断更新 KL 散度前的系数。即:
无约束的优化问题 d k = D K L π θ k ( π θ k , π θ ) d_k=D_{KL}^{\pi_{\theta_k}}(\pi_{\theta_k},\pi_{\theta}) dk=DKLπθk(πθk,πθ) β \beta β的更新规则如下:

  1. 如果 d k < δ / 1.5 d_k<\delta/1.5 dk<δ/1.5,那么 β k + 1 = β k / 2 \beta_{k+1}=\beta_k/2 βk+1=βk/2
  2. 如果 d k > δ × 1.5 d_k>\delta \times 1.5 dk>δ×1.5,那么 β k + 1 = β k × 2 \beta_{k+1}=\beta_k \times 2 βk+1=βk×2
  3. 否则 β k + 1 = β k \beta_{k+1}=\beta_k βk+1=βk

其中, δ \delta δ是事先设定的一个超参数,用于限制学习策略和之前一轮策略的差距。

2.3 PPO-截断

PPO的另一种形式 PPO-截断(PPO-Clip) 更加直接,它在目标函数中进行限制,以保证新的参数和旧的参数的差距不会太大,即:
在这里插入图片描述其中 c l i p ( x , l , r ) : = m a x ( m i n ( x , r ) , l ) clip(x,l,r):=max(min(x,r),l) clip(x,l,r):=max(min(x,r),l) ,即把 x x x限制在 [ l , r ] [l,r] [l,r]内。上式中 ϵ \epsilon ϵ是一个超参数,表示进行截断(clip)的范围。

如果 A π θ k ( s , a ) > 0 A^{\pi_{\theta_k}}(s,a)>0 Aπθk(s,a)>0,说明这个动作的价值高于平均,最大化这个式子会增大 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta (a|s)}{\pi_{\theta_k} (a|s)} πθk(as)πθ(as),但不会让其超过 1 + ϵ 1+\epsilon 1+ϵ。反之,如果 A π θ k ( s , a ) < 0 A^{\pi_{\theta_k}}(s,a)<0 Aπθk(s,a)<0,最大化这个式子会减小 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta (a|s)}{\pi_{\theta_k} (a|s)} πθk(as)πθ(as),但不会让其超过 1 − ϵ 1-\epsilon 1ϵ。如下图所示。
在这里插入图片描述

代码

最后,两个算法的代码可参考GitHub,Good Night!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/357342.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(考研湖科大教书匠计算机网络)第五章传输层-第五节:TCP拥塞控制

获取pdf&#xff1a;密码7281专栏目录首页&#xff1a;【专栏必读】考研湖科大教书匠计算机网络笔记导航 文章目录一&#xff1a;拥塞控制概述二&#xff1a;拥塞控制四大算法&#xff08;1&#xff09;慢开始和拥塞避免A&#xff1a;慢启动&#xff08;slow start&#xff09;…

CTFer成长之路之举足轻重的信息搜集

举足轻重的信息搜集CTF 信息搜集 常见的搜集 题目描述: 一共3部分flag docker-compose.yml version: 3.2services:web:image: registry.cn-hangzhou.aliyuncs.com/n1book/web-information-backk:latestports:- 80:80启动方式 docker-compose up -d 题目Flag n1book{in…

设计模式-代理模式

控制和管理访问 玩过扮白脸&#xff0c;扮黑脸的游戏吗&#xff1f;你是一个白脸&#xff0c;提供很好且很友善的服务&#xff0c;但是你不希望每个人都叫你做事&#xff0c;所以找了黑脸控制对你的访问。这就是代理要做的&#xff1a;控制和管理对象。 监视器编码 需求&…

数据挖掘,计算机网络、操作系统刷题笔记49

数据挖掘&#xff0c;计算机网络、操作系统刷题笔记49 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;可能很多算法学生都得去找开发&#xff0c;测开 测开的话&#xff0c;你就得学数据库&#xff0c;sql&#xff0c;orac…

Spring Cloud Alibaba 微服务简介

微服务简介 1 什么是微服务 2014年&#xff0c;Martin Fowler&#xff08;马丁福勒 &#xff09; 提出了微服务的概念&#xff0c;定义了微服务是由以单一应用程序构成的小服务&#xff0c;自己拥有自己的进程与轻量化处理&#xff0c;服务依业务功能设计&#xff0c;以全自动…

将Nginx 核心知识点扒了个底朝天(四)

为什么 Nginx 不使用多线程&#xff1f; Apache: 创建多个进程或线程&#xff0c;而每个进程或线程都会为其分配 cpu 和内存&#xff08;线程要比进程小的多&#xff0c;所以 worker 支持比 perfork 高的并发&#xff09;&#xff0c;并发过大会榨干服务器资源。 Nginx: 采用…

程序员35岁中年危机不是坎,是一把程序员自己设计的自旋锁

有时候&#xff0c;我会思考35岁这个程序员的诅咒&#xff0c;确切来说是中国程序员的独有的诅咒。 优秀的程序员思维逻辑严谨&#xff0c;弄清楚需求的本质是每天重复的工作&#xff0c;也是对工作的态度&#xff0c;那弄清楚诅咒的来源&#xff0c;义不容辞。 被诅咒的35岁 …

【爬虫】自动获取showdoc指定项目中的所有文档

▒ 目录 ▒&#x1f6eb; 导读需求1️⃣ 格式分析官方下载文件内容prefix_info.json文件格式2️⃣ 封包分析/api/page/info/api/item/info3️⃣ 编码代码特点问题&#x1f4d6; 参考资料&#x1f6eb; 导读 需求 showdoc是一个API文档、技术文档工具网站&#xff0c;经常能搜到…

String intern方法理解

1、原理 参考学习视频&#xff1a; https://www.bilibili.com/video/BV1WK4y1M77t/?spm_id_from333.337.search-card.all.click&vd_source4dc3f886f5ce1d43363b603935f02bd1 String s1 “hello”; String s1 "hello"; 代码原理解释如下图String s1 new Str…

进程章节总结性实验

进程实验课笔记 本节需要有linux基础&#xff0c;懂基本的linux命令操作即可。 Ubuntu镜像下载 https://note.youdao.com/s/VxvU3eVC ubuntu安装 https://www.bilibili.com/video/BV1j44y1S7c2/?spm_id_from333.999.0.0 实验环境ubuntu22版本&#xff0c;那个linux环境都可以…

Linux-VMware常用设置(时间+网络)及网络连接激活失败解决方法-基础篇②

目录一、设置时间二、网络设置1. 激活网卡方法一&#xff1a;直接启动网卡&#xff08;仅限当此&#xff09;方法二&#xff1a;修改配置文件&#xff08;永久&#xff09;2. 将NAT模式改为桥接模式什么是是NAT模式&#xff1f;如何改为桥接模式&#xff1f;三、虚拟机网络连接…

20230219 质心和重心的区别和性质

质心&#xff1a;&#xff08;无需重力场的前提&#xff09;所有质点的位置关于它们的质量的加权平均数。 重心&#xff1a;&#xff08;需要重力场的前提&#xff09;重力对系统中每个质点关于重心的力矩之和为零。 质心&#xff1a; xˉ∑i1nmixi∑i1nmi,yˉ∑i1nmiyi∑i1nmi…

Fiddler的报文分析

目录 1.Statistics请求性能数据 2.检测器&#xff08;Inspectors&#xff09; 3.自定义响应&#xff08;AutoResponder&#xff09; 1.Statistics请求性能数据 报文分析&#xff1a; Request Count: 1 请求数&#xff0c;该session总共发的请求数 Bytes …

vue3.0 生命周期

目录前言&#xff1a;vue3.0生命周期图例1.beforeCreate2.created3.beforeMount/onBeforeMount4.mounted/onMounted5.beforeUpdate/onBeforeUpdate6.updated/onUpdated7.beforeUnmount/onBeforeUnmount8.unmounted/onUnmounted案例&#xff1a;总结前言&#xff1a; 每个Vue组…

智慧城市应急指挥中心数字化及城市驾驶舱建设方案

目 录 第一章 项目概述 1.1 项目背景 1.2 项目范围 第二章 建设内容 2.1 三维可视化平台 2.1.1 多源数据接入 2.1.2 可视化编排 2.1.3 三维可视化编辑 2.1.4 空间数据可视化 2.1.5 集成框架支持 2.2 可视化场景定制开发 2.2.1 城市驾驶总舱 2.2.2 城市安全分舱 2.…

PLT/PDF转CAD:scViewerX 8.1 Crack

scViewerX是一个功能强大的 ActiveX 控件&#xff0c;允许您查看、打印和转换 PLT、Adobe PDF、Autodesk DWF、CGM、Calcomp、HPGL/2、Gerber、TIF、CALS 和其他几种格式。 ScViewerX 可以将您的文件转换为多种不同的输出文件格式&#xff0c;包括 PDF、PDF/A、TIFF、DXF、DWF、…

【人工智能AI】三、NoSQL 实战《NoSQL 企业级基础入门与进阶实战》

帮我写一篇介绍NoSQL的技术文章&#xff0c;文章标题是《NoSQL 实战》&#xff0c;不少于3000字。这篇文章的目录是 3.NoSQL 实战 3.1 MongoDB 入门 3.1.1 MongoDB 基本概念 3.1.2 MongoDB 安装与配置 3.1.3 MongoDB 数据库操作 3.2 Redis 入门 3.2.1 Redis 基本概念 3.2.2 Red…

windows微软商店下载应用失败/下载故障的解决办法;如何在网页上下载微软商店的应用

一、问题背景 设置惠普打印机时&#xff0c;需要安装hp smart&#xff0c;但是官方只提供微软商店这一下载渠道。 点击安装HP Smart&#xff0c;确定进入微软商店下载。 完全加载不出来&#xff0c;可能是因为开了代理。 把代理关了&#xff0c;就能正常打开了。 但是点击“…

IsADirectoryError: [Errno 21] Is a directory【已解决】

问题描述 生成数据&#xff0c;存储时候报错。 IsADirectoryError: [Errno 21] Is a directory: /home/LIST_2080Ti/njh/CHB-MIT-DATA/epilepsy_eeg_classification/data_processing/chb28/520.csv 问题分析 按我的认知&#xff0c;python执行的时候&#xff0c;比如这句 d…

FLAT:Flat-LAttice Transformer

中文NLP的一个问题&#xff0c;就是中文的字除了句句之间有标点符号之外都是连在一起的&#xff0c;不像英文词语是单独分割的。中文NLP处理一般会有2种方式&#xff1a;基于字的&#xff0c;char-level。现在比较常用的方法&#xff0c;但会缺少词组的语义信息。基于词的&…