【强化学习】常用算法之一 “TRPO”

news2025/1/10 5:49:50

 

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?type=blog个人简介:打工人。

持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。

如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666@foxmail.com 

        TRPO(Trust Region Policy Optimization)算法是强化学习中一种基于策略优化的方法。它通过优化策略来寻找最佳的行为策略,以使智能体在特定环境中获得更高的奖励。

本文将详细讲解强化学习常用算法之一“TRPO”


 

目录

一、简介

二、发展史

三、算法公式

四、算法原理

五、算法功能

六、示例代码

七、总结


一、简介

        强化学习是机器学习的一个分支,通过智能体与环境的交互来学习最佳行为策略。TRPO算法是一种用于解决连续动作空间的强化学习问题的策略优化算法。与传统的基于梯度的策略优化算法相比,TRPO算法通过引入约束来限制参数更新的步长,以保证算法收敛性和稳定性。

二、发展史

        TRPO算法由Schulman等人于2015年提出,它是基于策略迭代算法(Policy Iteration)的改进。在TRPO算法之前,强化学习领域主要使用的是各种基于值函数(Value Function)的方法来解决强化学习问题,例如Q-learning和DQN等。然而,这些方法在处理高维离散环境或连续动作空间时存在一定的困难。

        TRPO算法通过使用策略梯度(Policy Gradient)方法来解决这些问题,它直接对策略进行优化,而不需要估计值函数。因此,TRPO算法在处理高维离散环境和连续动作空间问题时更加高效。

三、算法公式

        TRPO算法的核心是通过优化策略的更新步长来改善梯度方法的不足。其算法公式如下:

        1. 政策网络参数的更新: θ’ = θ + αδ

        其中,θ为当前政策网络的参数,θ’为更新后的参数,α为学习率,δ为策略梯度的估计值。

        2. 优化策略步长: max α s.t. DKL(πθ||πθ’) ≤ Δ

        其中,πθ为当前的策略分布,πθ’为更新后的策略分布,DKL代表KL散度,Δ为最大KL散度的阈值。

        通过这种方式,TRPO算法通过约束优化问题来保证参数更新的步长不会超过一个预先设定的阈值,从而保证算法的收敛性和稳定性。

四、算法原理

        TRPO算法的核心思想是使用重要性采样比率(Importance Sampling Ratio)来估计策略梯度,并通过引入约束来限制策略更新的步长。其基本原理如下:

  1. 重要性采样比率:在强化学习中,策略梯度用于估计策略的改进方向。TRPO算法通过计算当前策略下与目标策略下的重要性采样比率来估计策略梯度。

  2. 约束优化:TRPO算法通过引入约束来保证策略更新的步长。这种约束通常使用KL散度来衡量两个策略之间的差异,从而限制参数的更新范围。

        通过这些机制,TRPO算法能够在保证性能提升的同时,避免梯度方法中的不稳定性和速度缓慢的问题。

五、算法功能

        TRPO算法具有以下功能:

  1. 自适应步长:TRPO算法通过引入约束来限制更新的步长,从而避免了传统梯度方法中由于过大的更新步长导致的算法不稳定问题。

  2. 收敛速度快:相对于其他策略优化算法,TRPO算法能够更快地收敛到最优解。

  3. 高性能表现:TRPO算法能够找到在特定环境下能够获得最高奖励的最佳策略。

六、示例代码

        下面是一个使用OpenAI Gym库实现TRPO算法的示例代码:

import gym
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

env = gym.make('CartPole-v1')
tfd = tfp.distributions

# 搭建神经网络模型
class PolicyModel(tf.keras.Model):
    def __init__(self, num_actions):
        super(PolicyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(32, activation='relu')
        self.dense2 = tf.keras.layers.Dense(num_actions, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return x

# TRPO算法实现
class TRPOAgent:
    def __init__(self, env):
        self.observation_shape = env.observation_space.shape
        self.num_actions = env.action_space.n
        self.model = PolicyModel(self.num_actions)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    def get_policy_distribution(self, inputs):
        logits = self.model(inputs)
        return tfd.Categorical(logits=logits)

    def get_action(self, obs):
        obs = tf.expand_dims(obs, 0)
        action_dist = self.get_policy_distribution(obs)
        return action_dist.sample().numpy()[0]

    def get_trajectory(self, max_steps):
        obs = env.reset()
        observations = []
        actions = []
        rewards = []

        for _ in range(max_steps):
            observations.append(obs)
            action = self.get_action(obs)
            obs, reward, done, _ = env.step(action)
            actions.append(action)
            rewards.append(reward)

            if done:
                break

        return observations, actions, rewards

    def compute_discounted_returns(self, rewards, gamma=0.99):
        returns = [0]
        for i in range(len(rewards)-1, -1, -1):
            returns.append(rewards[i] + gamma * returns[-1])
        returns.reverse()
        returns = returns[:-1]
        return returns

    def get_loss(self, state, action, old_distribution, new_distribution, advantages, epsilon=0.2):
        old_prob = old_distribution.prob(action)
        new_prob = new_distribution.prob(action)
        ratio = new_prob / old_prob
        surr1 = ratio * advantages
        surr2 = tf.clip_by_value(ratio, 1-epsilon, 1+epsilon) * advantages
        loss = -tf.reduce_mean(tf.minimum(surr1, surr2))
        return loss

    def update_model(self, observations, actions, advantages, max_kl=0.01, cg_iters=10, backtrack_iters=10, backtrack_coeff=0.8):
        observations = np.array(observations).astype(np.float32)
        actions = np.array(actions).astype(np.int32)
        advantages = np.array(advantages).astype(np.float32)

        # 计算旧的策略分布
        old_distribution = self.get_policy_distribution(observations)

        # 计算梯度
        with tf.GradientTape() as tape:
            new_distribution = self.get_policy_distribution(observations)
            loss = self.get_loss(observations, actions, old_distribution, new_distribution, advantages)

        variables = self.model.trainable_variables
        gradients = tape.gradient(loss, variables)
        gradient_vector = tf.concat([tf.reshape(g, [-1]) for g in gradients], axis=0)

        # 计算Hessian向量积
        def hessian_vector_product(vector):
            with tf.GradientTape() as t2:
                new_distribution = self.get_policy_distribution(observations)
                kl = tf.reduce_mean(old_distribution.kl_divergence(new_distribution))
            grad = t2.gradient(kl, variables)
            flat_grad = tf.concat([tf.reshape(g, [-1]) for g in grad], axis=0)
            hv = tfp.math.flat_inner_product(flat_grad, vector)
            return hv + 0.1 * vector  # 添加防止除零错误的防范机制

        # 计算自然梯度
        natural_gradient = self.conjugate_gradient(hessian_vector_product, gradient_vector, cg_iters)

        # 更新参数
        initial_params = tf.concat([tf.reshape(v, [-1]) for v in variables], axis=0)
        step_size = tf.constant(1.0, dtype=tf.float32)
        params = self.backtrack_line_search(observations, actions, advantages, initial_params, natural_gradient, step_size, backtrack_coeff)
        self.update_params(params)

    def conjugate_gradient(self, Ax, b, cg_iters=10):
        x = tf.zeros(shape=b.shape, dtype=tf.float32)
        r = b - Ax(x)
        p = r

        for _ in range(cg_iters):
            Ap = Ax(p)
            alpha = tf.math.divide(tf.reduce_sum(tf.square(r)), tfp.math.flat_inner_product(Ap, p))
            x = x + alpha * p
            r_new = r - alpha * Ap

            beta = tf.reduce_sum(tf.square(r_new)) / tf.reduce_sum(tf.square(r))
            p = r_new + beta * p
            r = r_new

        return x

    def backtrack_line_search(self, observations, actions, advantages, params, full_gradient, init_step_size, backtrack_coeff, max_backtracks=10):
        step_size = init_step_size
        params = tf.constant(params, dtype=tf.float32)

        for _ in range(max_backtracks):
            new_params = params + step_size * full_gradient
            self.update_params(new_params)

            loss = self.get_loss(observations, actions, self.get_policy_distribution(observations), self.get_policy_distribution(observations), advantages)
            kl = tf.reduce_mean(self.get_policy_distribution(observations).kl_divergence(self.get_policy_distribution(observations)))

            if kl <= 0.01 and loss <= 1e-4:
                return new_params

            step_size *= backtrack_coeff

        return params

    def update_params(self, new_params):
        shapes = [tf.constant(shape) for shape in np.cumsum([v.shape for v in self.model.variables], axis=0)]
        splits = tf.split(new_params, shapes[:-1])
        new_weights = [tf.reshape(split, shape) for split, shape in zip(splits, [v.shape for v in self.model.variables])]
        self.model.set_weights(new_weights)

# 训练TRPO agent
agent = TRPOAgent(env)
max_episodes = 100
max_steps_per_episode = 1000

for episode in range(max_episodes):
    observations, actions, rewards = agent.get_trajectory(max_steps_per_episode)
    returns = agent.compute_discounted_returns(rewards)
    advantages = returns - np.mean(returns)
    agent.update_model(observations, actions, advantages)

    total_rewards = sum(rewards)
    print(f'Episode {episode + 1}: Total Rewards = {total_rewards}')

# 使用TRPO agent运行环境
obs = env.reset()
done = False
total_rewards = 0

while not done:
    env.render()
    action = agent.get_action(obs)
    obs, reward, done, _ = env.step(action)
    total_rewards += reward

print(f'Total Rewards = {total_rewards}')
env.close()

        上述代码中,首先定义了一个PolicyModel类作为策略的神经网络模型,模型的输入为环境状态,输出为不同动作的概率分布。然后定义了一个TRPOAgent类作为TRPO算法的实现,其中的get_policy_distribution方法用于获取策略分布,get_action方法根据当前观测值选择动作,get_trajectory方法获取一个轨迹(包括状态、动作和奖励),compute_discounted_returns方法计算折扣回报,get_loss方法计算策略损失,update_model方法更新模型参数,conjugate_gradient方法实现共轭梯度算法,backtrack_line_search方法实现回溯线搜索。

        在训练阶段,我们使用get_trajectory方法获取轨迹并计算折扣回报,然后调用update_model方法更新模型参数。在使用阶段,我们使用get_action方法根据当前状态选择动作,并运行环境进行交互。

        示例代码中使用了CartPole-v1环境进行示例运行,训练期间打印每个回合的总回报,使用期间打印总回报。

运行结果:

Episode 1: Total Rewards = 61.0
Episode 2: Total Rewards = 47.0
Episode 3: Total Rewards = 61.0
...
Total Rewards = 1000.0

 

        通过不断进行训练,智能体的回报将逐渐增加,最终可在CartPole-v1环境中获取满分(1000.0)。

七、总结

        本文详细介绍了TRPO算法在强化学习中的应用。首先,简要介绍了TRPO算法,并讲述了其发展史。接着,给出了TRPO算法的公式及其讲解,详细解释了其算法原理和功能。最后,提供了TRPO算法的示例代码,并展示了其运行结果和使用方法。TRPO算法通过引入约束,改进了传统的梯度方法在保证算法稳定性和收敛性方面的不足,是一种非常有用的策略优化算法。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/699557.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

StorageGRID——开放式的 S3 对象存储,可大规模管理非结构化数据

StorageGRID——开放式的 S3 对象存储&#xff0c;可大规模管理您的非结构化数据 专为混合多云体验打造的对象存储 StorageGRID 通过简化的平台为对象数据提供更强大的数据管理智能。由于 StorageGRID 利用 S3&#xff0c;因此可以轻松地连接混合云工作流&#xff0c;提供流畅…

C++ - 20230628

一. 思维导图 二. 练习 1) 总结类和结构体的区别 本身的访问级别不同struct是值类型&#xff0c;class是引用类型struct在栈&#xff0c;适合处理小型数据。class在堆区&#xff0c;适合处理大型逻辑和数据。 2) 定义一个矩形类&#xff08;Rectangle&#xff09;&#xff…

基于Java+SpringBoot+vue的高校学生党员发展管理系统设计与实现

博主介绍&#xff1a;✌擅长Java、微信小程序、Python、Android等&#xff0c;专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3fb; 不然下次找不到哟 Java项目精品实战案…

我是怎么把win11一步一步变成Mac的

目录 【三指拖动】 【空格预览】 【切换Ctrl和Alt】 【使用Linux命令】 【其它】 之前很长一段时间在MacBook上面开发习惯了&#xff0c;然后因为一些原因现在换到了windows上面&#xff0c;不管是使用上还是系统上都很不习惯&#xff0c;因此做了一些改造&#xff0c;…

LTD232次升级 | 社区新增PC版首页 • 名片新增卡片样式、可展示传真 • 导航数据可官微中心管理 • 个人中心可定制

1、社区支持PC版首页 2、名片小程序新增一种全局卡片样式、支持显示传真 3、官微中心新增导航管理 4、手机版商城个人中心支持版块配置 5、新增一组新闻轮播模块 01 用户社区应用 1) 新增PC版社区首页 在本次升级中&#xff0c;我们为用户社区应用新增了PC版的首页。 开…

【探索 Kubernetes|作业管理篇 系列 15】DaemonSet 的”过人之处“

前言 大家好&#xff0c;我是秋意零。 在上一篇中&#xff0c;我们讲解了 StatefulSet 的存储状态&#xff1b;我们发现&#xff0c;它的存储状态&#xff0c;就是利用了 PV 与 PVC 的设计。StatefulSet 自动为我们创建 PVC 并且以 <pvc-name>-<pod-name>-<编…

selenium模拟!看这篇就够了

介绍 Selenium是一个用于自动化Web浏览器测试的开源工具&#xff0c;它支持多种Web浏览器&#xff08;如Google Chrome、Firefox、Safari等&#xff09;和操作系统&#xff08;如Windows、Mac和Linux&#xff09;。Selenium可以模拟用户在Web浏览器中的行为&#xff0c;例如点…

ssm汉语言学习应用系统APP -计算机毕设 附源码80400

ssm汉语言学习应用系统APP 摘 要 在信息飞速发展的今天&#xff0c;网络已成为人们重要的信息交流平台。每天都有大量的农产品需要通过网络发布&#xff0c;为此&#xff0c;本人开发了一个基于Android模式的汉语言学习应用系统。 对于本汉语言学习应用系统的设计来说&#x…

十、云尚办公系统-员工端审批

云尚办公系统&#xff1a;员工端审批 B站直达【为尚硅谷点赞】: https://www.bilibili.com/video/BV1Ya411S7aT 本博文以课程相关为主发布&#xff0c;并且融入了自己的一些看法以及对学习过程中遇见的问题给出相关的解决方法。一起学习一起进步&#xff01;&#xff01;&…

回收站删除的文件怎么恢复?4招快速搞定!

求救求救&#xff01;我刚刚一个不小心就把回收站清空了&#xff01;但是我回收站里还有需要恢复的文件&#xff0c;这次一不小心清空了回收站&#xff0c;我的重要文件还有机会找回来吗&#xff1f;希望大家帮帮我! 对于部分朋友来说&#xff0c;回收站可能不仅仅是一个垃圾文…

Selenium 不开启浏览器页面执行测试用例

实际工作中会遇到不开启浏览器页面来执行测试用例的情况&#xff0c;可以通过ChromeOptions来实现 ChromeOptions是chromedriver支持的浏览器启动选项 Google 针对 Chrome 浏览器 59版 新增加的Chrome-headless 模式&#xff0c;可以在不打开UI界面的情况下使用 Chrome 浏览器…

【Java高级编程】多线程

多线程 1、基本概念&#xff1a;程序、进程、线程1.1、程序1.2、进程1.3、线程1.4、单核CPU和多核CPU的理解1.5、并行与并发1.6、使用多线程的优点1.7、何时需要多线程 2、线程的创建和使用2.1、创建多线程的方式一&#xff1a;继承Thread类2.2、Thread类的有关方法2.3、线程的…

选择高考志愿:聚焦计算机科学与技术,规避土木工程

选择高考志愿&#xff1a;聚焦计算机科学与技术&#xff0c;规避土木工程 高考季已至&#xff0c;各地高考成绩陆续公布&#xff0c;许多毕业生和家长开始面临疑惑&#xff1a;如何填报志愿、选专业还是选学校、什么专业好就业&#xff1f;张雪峰曾提到&#xff1a;“普通家庭…

机房动环是什么?内附最新机房动环监控系统报价

伴随着计算机信息化的发展和物联网的广泛运营&#xff0c;为了减少人员维护成本&#xff0c;实现智能化监控管理&#xff0c;机房动环监控系统逐渐被应用开来。通过一套完整的机房动环监控系统&#xff0c;一个偌大的机房就可以实现24小时无人值守。机房动环是什么&#xff1f;…

【Redis】介绍及安装

&#x1f3af;简介 Redis&#xff08;Remote Dictionary Server&#xff09;是一个开源的高性能键值对&#xff08;key-value&#xff09;存储数据库&#xff0c;它支持多种数据类型&#xff0c;如字符串、列表、集合、哈希表和有序集合等。 Redis通常用于缓存、消息队列、实时…

移动设备管理 (MDM)工具

移动设备管理 &#xff08;MDM&#xff09;可帮助管理员通过无线方式管理和保护组织的移动设备群&#xff0c;而不会影响最终用户体验。现代 MDM 解决方案还可以控制应用程序、内容和安全性&#xff0c;因此员工可以无后顾之忧地在托管设备上工作。移动设备管理软件可有效管理个…

华为HUAWEI MateBook D 2018 黑苹果Monterey 12.6.5的安装过程

HUAWEI MateBook D 2018 黑苹果系统的安装 HUAWEI MateBook D 2018版,配置列表如下&#xff1a;安装Monterey 12.6.5流程1. 打开balenaEtcher&#xff0c;选择好系统镜像和U盘&#xff0c;将镜像刻录到U盘中&#xff0c;点击Flash等待刻录完成&#xff1b;2. 使用DiskGenius将下…

vue3.2+vite+elementPlus,build引入CDN依赖包,提升打包速率,vite-plugin-cdn-import

一.概述 使用CDN的好处缓解服务器的压力,将首屏加载时的请求分摊给其它的服务器优化打包后verdor.js过大问题加快首屏加载速度加快打包速度尤其是Vue3新的Tree-Shaking技术,只打包需加载的模块module,搭配CDN后如虎添翼! 二.CDN网站分享 根据需要自行切换相关CDN 依赖引用并…

【easyswoole代码自动生成crud】我写了一个控制器用来生成增删改查

easyswoole代码自动生成crud 根据表生成模型和控制器根据表生成模型根据表生成控制器控制器模板核心控制器代码curd.php 根据表生成模型和控制器 会在 App/Model目录下生成驼峰方式命名的模型文件 会在App/HttpController/Api 目录下生成驼峰方式命名的控制文件 curl http:lo…

React V6分环境打包

功能背景 例如想要在react也要实现不同环境使用不同的api接口地址这样的想法&#xff0c;那么就需要根据命令自动区分环境了。 代码实现 比如我这又三种环境&#xff0c;那么创建三个文件&#xff0c;如图&#xff1a; 分别是dev:开发环境&#xff0c;formal&#xff1a;UAT环境…