【强化学习】策略梯度(Policy Gradient,PG)算法

news2025/1/21 21:52:03

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

       【强化学习】- 【单智能体强化学习】(5)---《策略梯度(Policy Gradient,PG)算法》

策略梯度(Policy Gradient,PG)算法

目录

一、概述

二、核心概念

三、PG算法的基本思想

四、策略梯度公式推导

[Python] Policy Gradient算法实现

参数解析部分

环境初始化与随机种子设置

Policy网络类定义

Policy对象和优化器初始化

 选择动作的函数

结束当前回合的函数

主要训练循环

程序入口

[Results] 运行结果

[Content]主要内容:

[Notice]  注意事项:

五、优缺点

六、总结


一、概述

        在强化学习中,Policy Gradient(策略梯度)算法是一类通过优化策略函数直接来求解最优策略的方法。与基于值函数(例如Q学习和SARSA)的方法不同,策略梯度方法直接对策略函数进行建模,目标是通过梯度下降的方法来最大化预期的累积奖励(即期望回报)。这些算法主要适用于连续的动作空间或高维问题,能够在复杂的环境中取得较好的性能。


二、核心概念

  1. 策略(Policy):策略是一个从状态空间到动作空间的映射。我们可以表示策略为 \pi(a | s) ,表示在状态 s 下采取动作 a的概率。

  2. 回报(Return):从某一时刻起,智能体在环境中继续执行策略时所能获得的累计奖励,通常我们用 G_t 来表示从时间 t开始的回报。

  3. 梯度(Gradient):在PG算法中,我们通过计算策略函数的梯度,来调整策略,使得在某个状态下选取最优的动作。


三、PG算法的基本思想

        PG算法的核心思想是通过策略梯度来优化策略函数。其目标是最大化累积奖励的期望,即:

J(\theta) = \mathbb{E}{\pi{\theta}} \left[ G_t \right]

其中,( J(\theta) )是目标函数,表示参数化策略的期望回报,( \pi_{\theta}(a | s) )是由参数( \theta ) 定义的策略。为了优化这个目标函数,我们需要通过梯度上升法来调整参数( \theta )


四、策略梯度公式推导

        为了最大化目标函数 ( J(\theta) ),我们首先需要计算目标函数对策略参数( \theta )的梯度( \nabla_{\theta} J(\theta) )。根据期望的定义,我们有:

J(\theta) = \mathbb{E}{\pi{\theta}} [G_t] = \sum_{s} p(s) \sum_{a} \pi_{\theta}(a|s) Q_{\pi}(s, a)

其中,( Q_{\pi}(s, a) )是在策略( \pi ) 下,智能体在状态 ( s )选择动作( a )后的状态-动作值函数。为了计算梯度,我们需要用到交换微分与期望的操作:

\nabla_{\theta} J(\theta) = \sum_{s} p(s) \sum_{a} \nabla_{\theta} \left( \pi_{\theta}(a|s) Q_{\pi}(s, a) \right)

我们可以分开对策略函数和价值函数求梯度:

\nabla_{\theta} J(\theta) = \sum_{s} p(s) \sum_{a} Q_{\pi}(s, a) \nabla_{\theta} \pi_{\theta}(a|s)

通过这个公式,我们可以得到策略参数 ( \theta ) 的更新方向,即梯度。然后,我们使用梯度上升法来调整策略参数:

\theta_{t+1} = \theta_t + \alpha \nabla_{\theta} J(\theta)

其中,( \alpha )是学习率。


[Python] Policy Gradient算法实现

        实现了一个基于 PyTorch 的强化学习算法 Policy Gradient算法,主要用于训练一个在 CartPole-v1 环境中平衡杆的智能体。

        🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。

"""《 Policy Gradient算法实现》
    时间:2024.12
    环境:CartPole-v1
    作者:不去幼儿园
"""
import argparse  # 导入命令行参数解析库
import gym  # 导入OpenAI Gym库,用于创建强化学习环境
import numpy as np  # 导入numpy库,用于处理数值计算
from itertools import count  # 导入count函数,用于生成整数序列

import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch的神经网络模块
import torch.nn.functional as F  # 导入PyTorch的神经网络功能函数模块
import torch.optim as optim  # 导入PyTorch的优化器模块
from torch.distributions import Categorical  # 导入PyTorch中类别分布的模块,用于离散动作选择

参数解析部分

parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')  # 创建一个ArgumentParser对象,描述REINFORCE算法示例
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',  # 添加一个命令行参数:gamma,表示折扣因子(默认0.99)
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',  # 添加一个命令行参数:seed,用于设置随机种子(默认543)
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',  # 添加一个命令行参数:render,用于决定是否渲染环境
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',  # 添加一个命令行参数:log-interval,表示训练状态日志的输出间隔(默认10)
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()  # 解析命令行参数并将结果保存到args对象中

环境初始化与随机种子设置

env = gym.make('CartPole-v1')  # 创建一个CartPole-v1环境,用于训练
torch.manual_seed(args.seed)  # 设置PyTorch的随机种子,以确保可复现性

Policy网络类定义

class Policy(nn.Module):  # 定义一个名为Policy的类,继承自nn.Module,这是一个神经网络模型
    def __init__(self):  # 初始化方法
        super(Policy, self).__init__()  # 调用父类构造方法
        self.affine1 = nn.Linear(4, 128)  # 第一层全连接层,将输入维度4映射到128维
        self.affine2 = nn.Linear(128, 2)  # 第二层全连接层,将输入维度128映射到2维(表示动作空间的大小)

        self.saved_log_probs = []  # 用于保存每一时刻的动作概率的对数
        self.rewards = []  # 用于保存每一时刻的奖励

    def forward(self, x):  # 定义前向传播方法
        x = F.relu(self.affine1(x))  # 经过第一层后使用ReLU激活函数
        action_scores = self.affine2(x)  # 通过第二层得到每个动作的得分
        return F.softmax(action_scores, dim=1)  # 使用softmax函数计算动作的概率分布

Policy对象和优化器初始化

policy = Policy()  # 创建Policy类的实例
optimizer = optim.Adam(policy.parameters(), lr=1e-2)  # 使用Adam优化器优化Policy模型,学习率为0.01
eps = np.finfo(np.float32).eps.item()  # 获取float32类型的最小正数,用于避免除零错误

  选择动作的函数

def select_action(state):  # 定义选择动作的函数
    state = torch.from_numpy(state).float().unsqueeze(0)  # 将输入的状态从numpy数组转换为PyTorch张量,并增加一个维度
    probs = policy(state)  # 通过Policy网络计算每个动作的概率分布
    m = Categorical(probs)  # 用Categorical分布定义一个概率分布对象
    action = m.sample()  # 从该分布中采样一个动作
    policy.saved_log_probs.append(m.log_prob(action))  # 保存该动作的对数概率
    return action.item()  # 返回动作的值

结束当前回合的函数

def finish_episode():  # 定义结束一个回合的函数
    R = 0  # 初始化回报R为0
    policy_loss = []  # 初始化用于保存每个动作损失的列表
    rewards = []  # 初始化保存所有回报的列表
    for r in policy.rewards[::-1]:  # 从后往前遍历奖励列表
        R = r + args.gamma * R  # 计算当前时刻的回报(折扣奖励)
        rewards.insert(0, R)  # 将回报插入到列表的开头
    rewards = torch.tensor(rewards)  # 将奖励转换为PyTorch张量
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)  # 对奖励进行标准化
    for log_prob, reward in zip(policy.saved_log_probs, rewards):  # 遍历每个动作的对数概率和奖励
        policy_loss.append(-log_prob * reward)  # 计算每个动作的损失(负对数概率与标准化奖励的乘积)
    optimizer.zero_grad()  # 清除梯度
    policy_loss = torch.cat(policy_loss).sum()  # 计算所有动作的总损失
    policy_loss.backward()  # 反向传播计算梯度
    optimizer.step()  # 执行一步优化
    del policy.rewards[:]  # 清空保存的奖励列表
    del policy.saved_log_probs[:]  # 清空保存的对数概率列表

主要训练循环

def main():  # 定义主函数
    running_reward = 10  # 初始化运行奖励
    for i_episode in count(1):  # 从1开始无限循环
        state, _ = env.reset()  # 重置环境并获取初始状态
        for t in range(10000):  # 限制每个回合的最大步数为10000
            action = select_action(state)  # 选择动作
            state, reward, done, _, _ = env.step(action)  # 执行动作并获取下一个状态、奖励等信息
            if args.render:  # 如果设置了渲染选项
                env.render()  # 渲染环境
            policy.rewards.append(reward)  # 保存奖励
            if done:  # 如果回合结束
                break  # 跳出循环

        running_reward = running_reward * 0.99 + t * 0.01  # 更新运行奖励(使用指数加权移动平均)
        finish_episode()  # 结束当前回合并进行学习
        if i_episode % args.log_interval == 0:  # 每log_interval步输出一次日志
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, t, running_reward))  # 打印当前回合、回合长度和平均长度
        if running_reward > env.spec.reward_threshold:  # 如果运行奖励超过环境的奖励阈值
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))  # 打印成功信息
            break  # 结束训练

程序入口

if __name__ == '__main__':  # 如果是直接运行该文件(而不是导入)
    main()  # 调用main函数,开始训练

[Results] 运行结果


[Content]主要内容:

  1. 环境初始化:使用 OpenAI Gym 库中的 CartPole-v1 环境。智能体的目标是通过选择合适的动作来保持杆子平衡。

  2. 策略网络

    • 采用一个简单的两层全连接神经网络 (Policy 类) 来表示策略。
    • 输入是环境的状态(4 维),输出是两个动作的概率分布(2 维)。
    • 使用 ReLU 激活函数对第一层的输出进行非线性转换,并用 Softmax 计算每个动作的概率。
  3. 选择动作:通过神经网络预测的动作概率,使用 Categorical 分布 来采样动作,并将该动作的对数概率保存下来用于后续更新。

  4. 回报和奖励

    • 在每个回合结束时,通过遍历奖励列表来计算每个时间步的累积回报(折扣奖励)。
    • 将奖励进行标准化,以帮助训练过程中梯度的稳定性。
  5. 优化器和更新:使用 Adam 优化器来更新策略网络的权重。每次回合结束时,计算每个时间步的损失,并使用反向传播来优化模型。

  6. 训练循环

    • 在每一回合中,智能体从环境中获得当前状态,选择动作,执行动作并获取奖励,直到回合结束。
    • 每隔一定的训练回合(log-interval),打印出当前回合的奖励情况。
    • 一旦平均奖励超过设定的阈值,训练结束并报告完成。

[Notice]  注意事项:

  1. 随机性和种子

    在环境初始化时,设置了随机种子(args.seed)。这确保了训练过程是可复现的。不同的种子可能会导致不同的训练结果,因此在实验中使用相同的种子可以帮助进行对比分析。
  2. 奖励标准化

    对奖励进行了标准化 (rewards - rewards.mean()) / (rewards.std() + eps),这是为了避免奖励的尺度差异对学习过程造成影响。eps 防止除零错误。
  3. 学习率选择

    学习率 (lr=1e-2) 设置为0.01,这可能需要根据训练的表现进行调整。过大的学习率可能导致训练不稳定,过小则可能导致训练进展缓慢。
  4. 回报折扣因子

    折扣因子 (gamma=0.99) 用来权衡当前奖励与未来奖励的关系。值越大,智能体更倾向于关注长期回报;值越小,则更注重即时奖励。
  5. 环境渲染

    通过 --render 命令行参数可以开启环境渲染,这对于观察训练过程很有帮助。但是开启渲染会导致训练速度下降,因此一般在调试或演示时才开启。
  6. 训练中的收敛

    由于REINFORCE算法是基于蒙特卡罗方法的,它依赖于完整的回合数据来估计梯度,因此收敛速度可能较慢。为了提高收敛速度,可以考虑使用 基准(Baseline) 或 优势函数(Advantage) 进行改进,或者使用 PPO 等更高效的强化学习算法。
  7. 性能问题

    由于每个回合都需要存储动作的对数概率以及奖励,这可能导致内存消耗较大。可以通过限制回合的最大步数或者使用更加高效的数据存储策略来减小内存负担。
  8. 终止条件

    训练会在达到环境的奖励阈值(env.spec.reward_threshold)时终止,这是判断智能体是否学会平衡杆的一个标准。环境的 reward_threshold 根据具体环境和任务而不同。
  9. 代码健壮性

    代码中使用了 env.reset() 返回的状态,实际上在 Gym v0.26 及之后版本中,reset() 方法返回一个元组 (state, info),但代码没有处理 info。为了兼容不同版本的 Gym,可以适配这个变化。
# 环境配置
Python                  3.11.5
torch                   2.1.0
torchvision             0.16.0
gym                     0.26.2

五、优缺点

优点:

  1. 直接优化策略:PG算法直接对策略进行建模,不需要显式地计算值函数。
  2. 适用于连续动作空间:与Q-learning等离散动作空间的算法相比,PG算法适合处理连续动作空间的强化学习问题。
  3. 可扩展性强:PG算法能够处理高维状态空间和复杂问题。

缺点:

  1. 高方差:REINFORCE等算法中,回报计算的高方差可能导致训练不稳定。
  2. 计算开销大:需要对每一条轨迹进行采样计算,并多次更新策略参数。
  3. 局部最优:策略梯度方法容易陷入局部最优解。

六、总结

        策略梯度方法通过优化策略函数来提高累积回报,其核心思想是直接优化策略而非间接估计值函数。REINFORCE算法是最基础的策略梯度方法,但由于其高方差问题,通常需要结合一些技术(如基线、经验回放)来改进性能。虽然策略梯度方法具有很好的理论支持,并且在很多复杂的强化学习任务中能够取得较好的结果,但也有一定的计算开销和稳定性问题。

 更多强化学习文章,请前往:【强化学习(RL)】专栏


        博客都是给自己看的笔记,如有误导深表抱歉。文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2280021.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Apache SeaTunnel 2.3.9 正式发布:多项新特性与优化全面提升数据集成能力

近日,Apache SeaTunnel 社区正式发布了最新版本 2.3.9。本次更新新增了Helm 集群部署、Transform 支持多表、Zeta新API、表结构转换、任务提交队列、分库分表合并、列转多行 等多个功能更新! 作为一款开源、分布式的数据集成平台,本次版本通过…

4 AXI USER IP

前言 使用AXI Interface封装IP,并使用AXI Interface实现对IP内部寄存器进行读写实现控制LED的demo,这个demo是非常必要的,因为在前面的笔记中基本都需哟PS端与PL端就行通信互相交互,在PL端可以通过中断的形式来告知PS端一些事情&…

B站评论系统的多级存储架构

以下文章来源于哔哩哔哩技术 ,作者业务 哔哩哔哩技术. 提供B站相关技术的介绍和讲解 1. 背景 评论是 B站生态的重要组成部分,涵盖了 UP 主与用户的互动、平台内容的推荐与优化、社区文化建设以及用户情感满足。B站的评论区不仅是用户互动的核心场所&…

电子科大2024秋《大数据分析与智能计算》真题回忆

考试日期:2025-01-08 课程:成电信软学院-大数据分析与智能计算 形式:开卷 考试回忆版 简答题(4*15) 1. 简述大数据的四个特征。分析每个特征所带来的问题和可能的解决方案 2. HDFS的架构的主要组件有哪些&#xff0…

多选multiple下拉框el-select回显问题(只显示后端返回id)

首先保证v-model的值对应options数据源里面的id <el-form-item prop"subclass" label"分类" ><el-select v-model"formData.subclass" multiple placeholder"请选择" clearable :disabled"!!formData.id"><e…

JavaWeb开发(十五)实战-生鲜后台管理系统(二)注册、登录、记住密码

1. 生鲜后台管理系统-注册功能 1.1. 注册功能 &#xff08;1&#xff09;创建注册RegisterServlet&#xff0c;接收form表单中的参数。   &#xff08;2&#xff09;service创建一个userService处理业务逻辑。   &#xff08;3&#xff09;RegisterServlet将参数传递给ser…

【MySQL系列文章】Linux环境下安装部署MySQL

前言 本次安装部署主要针对Linux环境进行安装部署操作,系统位数64 getconf LONG_BIT 64MySQL版本&#xff1a;v5.7.38 一、下载MySQL MySQL下载地址&#xff1a;MySQL :: Download MySQL Community Server (Archived Versions) 二、上传MySQL压缩包到Linuxx环境&#xff0c…

嵌入式硬件篇---基本组合逻辑电路

文章目录 前言基本逻辑门电路1.与门&#xff08;AND Gate&#xff09;2.或门&#xff08;OR Gate&#xff09;3.非门&#xff08;NOT Gate&#xff09;4.与非门&#xff08;NAND Gate&#xff09;5.或非门&#xff08;NOR Gate&#xff09;6.异或门&#xff08;XOR Gate&#x…

基于微信小程序的手机银行系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

C++ List 容器:实现原理深度解析

1.基本结构 1.1list底层 list底层是一个双向链表&#xff0c;每个节点包含三个主要部分&#xff1a;存储的数据&#xff0c;指向前一个节点和后一个节点的指针。我们首先定义一个 list_node 结构体来描述链表节点。 template <class T> struct list_node {T _data;lis…

在Spring Boot中使用SeeEmitter类实现EventStream流式编程将实时事件推送至客户端

&#x1f604; 19年之后由于某些原因断更了三年&#xff0c;23年重新扬帆起航&#xff0c;推出更多优质博文&#xff0c;希望大家多多支持&#xff5e; &#x1f337; 古之立大事者&#xff0c;不惟有超世之才&#xff0c;亦必有坚忍不拔之志 &#x1f390; 个人CSND主页——Mi…

Maven下载配置

目录 Win下载配置maven的环境变量 Mac下载安装配置环境变量 MavenSetting.xml文件配置 Win 下载 https://maven.apache.org/ 在主页面点击Download 点击archives 最好不要下载使用新版本&#xff0c;我使用的是maven-3.6.3&#xff0c;我们点击页面下方的archives&#xff0…

小程序获取微信运动步数

1、用户点击按钮&#xff0c;在小程序中触发getuserinfo方法&#xff0c;获取用户信息 <scroll-view class"scrollarea" scroll-y type"list"><view class"container"><button bind:tap"getLogin">获取</button&…

OSCP - Proving Grounds - BullyBox

主要知识点 如果发现有域名&#xff0c;则可以加入/etc/hosts后重新执行nmap,nikto等扫描dirsearch的时候可以使用完整一些的字典文件&#xff0c;避免漏掉信息.git dump 具体步骤 执行nmap 扫描&#xff0c;发现 80和22端口开放,访问后发现被重定向到 bullybox.local Star…

MIAOYUN信创云原生项目亮相西部“中试”生态对接活动

近日&#xff0c;以“构建‘中试’生态&#xff0c;赋能科技成果转化”为主题的“科创天府智汇蓉城”西部“中试”生态对接活动在成都高新区菁蓉汇隆重开幕。活动分为成果展览、“中试”生态主场以及成果路演洽谈对接三大板块。在成果展览环节&#xff0c;成都元来云志科技有限…

计算机网络 (47)应用进程跨越网络的通信

前言 计算机网络应用进程跨越网络的通信是一个复杂而关键的过程&#xff0c;它涉及多个层面和组件的协同工作。 一、通信概述 计算机网络中的通信&#xff0c;本质上是不同主机中的应用进程之间的数据交换。为了实现这种通信&#xff0c;需要借助网络协议栈中的各层协议&#x…

封装svg图片展示及操作组件——svgComponent——js技能提升

template部分 <template><div class"canvas-wrapper" ref"canvasWrapper"><svg:viewBox"computedViewBox"ref"svgCanvas"xmlns"http://www.w3.org/2000/svg"xmlns:xlink"http://www.w3.org/1999/xlink…

大数据,Hadoop,HDFS的简单介绍

大数据 海量数据&#xff0c;具有高增长率、数据类型多样化、一定时间内无法使用常规软件工具进行捕捉、管理和处理的数据集 合 大数据的特征: 4V Volume : 巨大的数据量 Variety : 数据类型多样化 结构化的数据 : 即具有固定格式和有限长度的数据 半结构化的数据 : 是…

EAMM: 通过基于音频的情感感知运动模型实现的一次性情感对话人脸合成

EAMM: 通过基于音频的情感感知运动模型实现的一次性情感对话人脸合成 1所有的材料都可以在EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model网站上找到。 摘要 尽管音频驱动的对话人脸生成技术已取得显著进展&#xff0c;但现有方法要么忽…

基于STM32的智能门锁安防系统(开源)

目录 项目演示 项目概述 硬件组成&#xff1a; 功能实现 1. 开锁模式 1.1 按键密码开锁 1.2 门禁卡开锁 1.3 指纹开锁 2. 功能备注 3. 硬件模块工作流程 3.1 步进电机控制 3.2 蜂鸣器提示 3.3 OLED显示 3.4 指纹与卡片管理 项目源代码分析 1. 主程序流程 (main…