【强化学习】常用算法之一 “PPO”

news2025/1/11 5:13:21

 

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?type=blog个人简介:打工人。

持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。

如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666@foxmail.com 

        强化学习(Reinforcement Learning)作为一种机器学习的分支,旨在让智能体通过与环境的交互来学习最优的行为策略。近年来,强化学习在各个领域取得了重要的突破,其中Proximal Policy Optimization(PPO)算法是一种重要的策略优化算法。

本文将详细讲解强化学习常用算法之一“PPO”


目录

一、简介

二、发展史

三、算法公式讲解

        1. 目标函数

        2. Surrogate目标函数

        3. 更新步骤

四、算法原理

五、算法功能

六、示例代码

七、总结


一、简介

        强化学习是一种通过智能体与环境的互动来学习最优行为策略的机器学习方法。相较于监督学习和无监督学习,强化学习的特点在于具有延迟奖赏和试错机制。在强化学习中,智能体通过选择动作来影响环境,并且从环境中获得奖励作为反馈。强化学习的目标是通过与环境的交互,使得智能体能够学会最优的行为策略。

        PPO算法属于策略优化(Policy Optimization)算法家族,是由OpenAI在2017年提出的。与其他策略优化算法相比,PPO算法具有较高的样本利用率和较好的收敛性能。该算法在分布式训练和大规模模型上都表现出了较好的性能,因此被广泛应用于各个领域,如机器人控制、自动驾驶、游戏等。

二、发展史

        在介绍PPO算法之前,需要先了解一些相关的算法。PPO算法是基于TRPO(Trust Region Policy Optimization)算法的改进。TRPO算法最初由Schulman等人于2015年提出,通过引入约束条件来保证每次更新的策略改变不会太大,从而确保策略的稳定性。然而,TRPO算法的计算复杂度较高,限制了其应用范围。

        为了解决TRPO算法的计算复杂度问题,Schulman等人在2017年提出了PPO算法。PPO算法通过引入一个修剪概率比率的约束,取代了TRPO算法中的相对熵约束。这样一来,PPO算法的计算复杂度大大降低,使得其在实际应用中更加高效。

三、算法公式讲解

        1. 目标函数

        PPO算法的目标是最大化预期回报函数。设状态为s,行动为a,策略函数为π(a|s),价值函数为V(s),回报函数为R。目标是最大化状态转换的总回报函数G。根据策略梯度定理,可以得到以下目标函数:

J(θ)=E[R(θ)] =E[∑t=0∞γt rt]

        其中,θ表示策略参数,γ表示折扣因子。

        2. Surrogate目标函数

        由于直接优化目标函数需要进行复杂的概率计算,PPO采用了一种近似的优化目标函数。引入一个由策略生成的新旧策略比率,即π(θ)/π(θ_old)。于是目标函数可以转化为:

J_surrogate(θ)=E[min(ratio(θ)A(θ), clip(ratio(θ), 1-ε, 1+ε)A(θ))]

        其中,A(θ)=Q(s,a)-V(s)表示优势函数,ratio(θ)=π(a|s)/π_old(a|s)表示比率,ε表示剪切范围。

        3. 更新步骤

        PPO算法通过交替地进行策略评估和策略改进来训练智能体。在每次迭代中,首先使用当前策略收集一批经验数据,然后使用这些数据来计算并更新策略。具体的更新步骤如下:

  • 收集经验数据;
  • 计算梯度并优化策略函数;
  • 更新价值函数。

四、算法原理

        PPO算法的核心原理是使用近端策略优化,即在每一次迭代中,通过利用大量采样数据来不断优化策略,同时限制策略的变化范围,避免过大的策略更新。

        PPO算法主要包括两个步骤:采样和优化。在采样阶段,算法通过与环境的交互来收集训练数据。在优化阶段,算法利用收集到的数据来更新策略参数,并根据目标函数的梯度信息来更新网络参数。

        PPO算法的基本思路是使用一个重要度采样比率来控制策略更新的范围。在每一次更新中,算法会计算新策略和旧策略之间的重要度采样比率,并利用该比率来限制策略更新的范围。通过引入一个剪切项来限制策略更迭过大,PPO算法可以有效地提高训练的稳定性和效率。

五、算法功能

        PPO算法具有以下几个功能:

  1. 基于策略的优化:PPO算法通过优化策略来提高智能体在环境中的性能,从而实现优化决策和行为。
  2. 高效稳定:PPO算法通过限制策略更新的范围,避免过大的更新,从而提高训练的稳定性和效率。
  3. 广泛适用性:PPO算法适用于解决连续动作空间和高维状态空间问题,可以应用于多个领域,如机器人控制、游戏智能等。

六、示例代码

        下面是一个简单的PPO算法示例代码,用于解决CartPole强化学习任务。

        首先,安装必要的依赖库:

pip install tensorflow
pip install gym

 

        接下来,编写PPO算法的代码: 

# -*- coding: utf-8 -*-
import tensorflow as tf
import gym
import numpy as np

env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
hidden_dim = 32
lr = 0.001

actor_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(hidden_dim, activation='relu', input_shape=(state_dim,)),
    tf.keras.layers.Dense(hidden_dim, activation='relu'),
    tf.keras.layers.Dense(action_dim, activation='softmax')
])

critic_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(hidden_dim, activation='relu', input_shape=(state_dim,)),
    tf.keras.layers.Dense(hidden_dim, activation='relu'),
    tf.keras.layers.Dense(1)
])

actor_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
critic_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

def choose_action(state):
    logits = actor_model.predict(state[np.newaxis, :])[0]
    action = np.random.choice(range(action_dim), p=logits)
    return action

def compute_return(rewards, gamma):
    returns = np.zeros_like(rewards)
    G = 0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        returns[t] = G
    return returns

def compute_advantage(states, rewards, values, gamma, lamda):
    returns = compute_return(rewards, gamma)
    values = np.append(values, 0)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = np.zeros_like(rewards)
    A = 0
    for t in reversed(range(len(rewards))):
        A = deltas[t] + gamma * lamda * A
        advantages[t] = A
    return returns, advantages

def train_actor(states, actions, advantages, old_probs, eps):
    with tf.GradientTape() as tape:
        logits_new = actor_model(states, training=True)
        probabilities_new = tf.reduce_sum(tf.one_hot(actions, action_dim) * logits_new, axis=1)
        ratios = tf.exp(tf.math.log(probabilities_new) - tf.math.log(old_probs))
        surrogate_obj1 = ratios * advantages
        surrogate_obj2 = tf.clip_by_value(ratios, 1-eps, 1+eps) * advantages
        surrogate_obj = tf.minimum(surrogate_obj1, surrogate_obj2)
        loss = -tf.reduce_mean(surrogate_obj)
    grads = tape.gradient(loss, actor_model.trainable_variables)
    actor_optimizer.apply_gradients(zip(grads, actor_model.trainable_variables))

def train_critic(states, returns):
    with tf.GradientTape() as tape:
        values = critic_model(states, training=True)
        mse = tf.keras.losses.MeanSquaredError()
        loss = mse(returns, tf.squeeze(values))
    grads = tape.gradient(loss, critic_model.trainable_variables)
    critic_optimizer.apply_gradients(zip(grads, critic_model.trainable_variables))

gamma = 0.99
lamda = 0.95
eps = 0.2
max_episodes = 200
max_steps_per_episode = 1000

for episode in range(max_episodes):
    state = env.reset()
    done = False
    episode_reward = 0
    states, actions, rewards, values, old_probs = [], [], [], [], []

    for step in range(max_steps_per_episode):
        action = choose_action(state)
        next_state, reward, done, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        values.append(critic_model.predict(state[np.newaxis, :])[0])
        old_probs.append(actor_model.predict(state[np.newaxis, :])[0][action])

        episode_reward += reward
        state = next_state

        if done:
            break

    states = np.array(states)
    actions = np.array(actions)
    rewards = np.array(rewards)
    values = np.array(values)
    old_probs = np.array(old_probs)

    returns, advantages = compute_advantage(states, rewards, values, gamma, lamda)
    returns = returns.astype('float32')
    advantages = advantages.astype('float32')

    train_actor(states, actions, advantages, old_probs, eps)
    train_critic(states, returns)

    print(f"Episode {episode+1}: Reward = {episode_reward}")

env.close()

        运行结果: 

Episode 1: Reward = 14.0
Episode 2: Reward = 13.0
Episode 3: Reward = 9.0
...
Episode 198: Reward = 500.0
Episode 199: Reward = 500.0
Episode 200: Reward = 500.0
 

        这个示例代码使用PPO算法来训练一个Actor模型和Critic模型,通过与环境交互收集训练数据并更新模型参数。最终,在CartPole任务中可以观察到奖励逐渐增加,达到最大奖励500的稳定水平。 

七、总结

        本文详细介绍了强化学习中的PPO算法,包括其简介、发展史、算法公式、算法原理、算法功能、示例代码和运行结果以及如何使用。PPO算法是一种基于策略的优化算法,通过最大化目标函数来优化策略,具有高效稳定和广泛适用性的特点。通过示例代码的讲解,读者可以了解PPO算法的具体实现和使用方法。希望本文对读者能够加深对PPO算法的理解,并能够运用到实际问题中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/699391.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android 操作系统日历完成提醒功能 附带开关闹钟 适配高版本安卓

Android 操作系统日历完成提醒功能 附带开关闹钟 如果想要一个稳定且不用担心生命周期的提醒方式,可以试试利用系统日历去完成任务的提醒或某个活动的预约。 项目仓库地址在文末 环境 Java 11 Android sdk 30 Gredle 7.1 minSdkVersion 23 targetSdkVersion 30测…

js 纯前端实现 重新部署 通知用户刷新网页

需求:有时候上完线,用户还停留在老的页面,用户不知道网页重新部署了,跳转页面的时候有时候js连接hash变了导致报错跳不过去,并且用户体验不到新功能,需要进行优化,每当打包发版后客户进入系统就…

F#奇妙游(1):F#浅尝

F#奇妙游(1):F#浅尝 是什么 F#是.NET平台的OCaml。 这句话很欠打,.NET和OCaml前者知道的人有一些,后者就很少了。.NET平台是一个开源的软件平台,早期由微软主导,目前已经开源,由.…

如何使用CSS Grid 居中 div

本文翻译自 How to Center a Div Using CSS Grid,作者:Fimber Elemuwa, Ralph Mason。 略有删改 在本文中,我们将介绍使用CSS Grid在水平和垂直方向上居中div的五种方法,当然这些技术可用于任何类型的元素。 初始化 我们首先创建…

ASP.Net Core Web API项目发布到IIS(二)

目录 一.启动并配置IIS环境 1.启用或关闭window功能 2.设置万维网服务 3.点击确定等待配置更改 二.创建新的Web网站并进行设置 1.打开IIS管理 2.配置默认的网站 3.创建新的网站 4.测试 三.可能出现的问题 1.404错误 前一篇已经记录了如何创建项目并发布到文件夹&#x…

配置管理数据库(CMDB)

什么是CMDB 配置管理数据库(Configuration Management Database,简称CMDB)是组织IT基础结构中配置项(Configuration Item)及其关系的数据库。CI指示了任何需要管理的、以确保成功交付服务的项目。CI可以是一个具体的实体,如服务器、交换机,也…

软件测试的自动化工具

在软件开发过程中,测试是必不可少的一个环节。而在测试中,测试人员需要花费大量的时间和精力进行手动测试,这不仅费时费力,而且效率较低。因此,自动化测试工具的出现为测试人员提供了更加便捷高效的测试方法。本文将介…

认识CSS

hi,大家好,今天我们来简单认识一下前端三剑客之一的CSS 目录 🐷CSS是什么🐷基本语法规范🐷CSS引入方式🥝内部样式🥝外部样式🥝内联样式 🐷认识选择器🍉标签选择器🍉类选…

最优化--坐标下降法--凸优化问题与凸集

目录 坐标下降法 概念 坐标下降法的步骤 案例演示 数值优化算法面临的问题 凸优化问题与凸集 凸优化问题 性质 优点 凸集 性质 坐标下降法 概念 坐标下降法是一种非梯度优化算法。算法在每次迭代中,在当前点处沿一个坐标方向 进行一维搜索以求得一个函…

Shell、Xshell以及两者的关系

编程语言分为编译型语言(需要使用编译器生成可执行的文件)和解释型语言(需要解释器,不需要编译器)。shell语言是一种解释型语言所使用的解释器有bash解释器或者sh解释器等。我们通过shell命令使之和操作系统交互&#…

漏洞复现-网康(奇安信)NGFW下一代防火墙远程命令执行

漏洞描述 网康下一代防火墙(NGFW)是网康科技推出的一款可全面应对网络威胁的高性能应用层防火墙。该NGFW存在远程命令执行漏洞,攻击者可通过构造特殊请求执行系统命令。凭借超强的应用识别能力,下一代防火墙可深入洞察网络流量中…

vscode python 自定义函数无法跳转到定义处,且定义处无法展示所有调用该函数的位置

问题描述 在vscode中编写python代码,在自定义类的forward函数中调用该类的成员函数,但在调用处无法通过ctrl鼠标左键直接跳转到该成员函数的定义中,系统显示找不到函数声明。同时,在该函数的定义处无法通过ctrl鼠标左键展示项目中…

React小项目-题解列表

1. 项目初始化 首先创建一个新项目 solution-app: npx create-react-app solution-app cd solution-app npm start先将 src 目录中除了 index.css 与 index.js 之外的文件删除,然后创建一个 components 目录,在该目录中创建一个 solution.j…

浅析舆情监测系统

舆情及内容简述 大家对于“舆情”应该有一个简单地概念,尤其是在现在微博、微信、知乎、抖音等平台普及化的今天,舆情的力量日渐凸显。比如最近萧敬腾的求婚、《消失的她》的热议、ikun的翻车等等,舆情既可以让明星塌房,也会让一…

Android Compose UI实战练手----Google Bloom登录页

目录 1.概述2.页面展示1.1 亮色主题1.2暗色主题 3.登录页面拆分以及编码实现3.1 登录页面拆分3.2 编码实现3.2.1 LoginPage3.2.2 LoginTitle3.2.3 LoginInoutBox3.2.4 LoginHintWithUnderLine3.2.5 LoginButton 4.源码地址 1.概述 在之前的章节中我们已经介绍了如何实现Google…

每个前端开发需要了解的10个强大的CSS属性

微信搜索 【大迁世界】, 我会第一时间和你分享前端行业趋势,学习途径等等。 本文 GitHub https://github.com/qq449245884/xiaozhi 已收录,有一线大厂面试完整考点、资料以及我的系列文章。 快来免费体验ChatGpt plus版本的,我们出的钱 体验地…

vue 启动项目报错:TypeError: Cannot set property ‘parent‘ of undefined异常解决

场景:从git上面拉下来一个项目 npm i 下载完依赖以后 npm run serve 去运行项目的时候 报错TypeError: Cannot set property ‘parent’ of undefined 如图所示 原因:首先排查发现判断得出是less解析失败导致 但是经过长时间的查询解决方案发现是因为v…

【Redis一】Redis简介及安装部署

Redis简介及安装部署 1.关系数据库 VS 非关系型数据库1.1 关系型数据库1.2 非关系型数据库1.3 关系型数据库和非关系型数据库区别1.4 非关系型数据库产生背景1.5 关系型数据库与非关系型数据库总结 2.Redis简介2.1 Redis概述2.2 Redis的优点2.3 Redis使用场景2.4 关于Redis的高…

nginx配置vue项目添加访问前缀

文章目录 前言实现需求Nginx配置访问前端正确配置注意点alias的含义举个栗子静态文件及js等404错误 前言 最近,在搞一个SASS系统,将原有的单服务,每次卖出一套啥软件就需要部署一套环境,使得运维人员有些捉襟见肘。产品调整为SAS…

链表理论基础

链表是一种通过指针串联在一起的线性结构,每一个节点由两部分组成,一个是数据域,一个是指针域(存放指向下一节点的指针)。 链表的类型 单链表 每一个节点由两部分组成,一个是数据域一个是指针域&#xf…