深入理解策略梯度算法

news2025/1/12 4:10:45

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。

策略梯度算法的基本概念

在强化学习中,智能体通过与环境交互来学习一种策略(policy),该策略定义了在每个状态下采取哪种行动的概率分布。策略可以是确定性的或随机的。在策略梯度方法中,策略通常表示为参数化的概率分布,即 $\pi_\theta(a|s)$,其中$\theta$ 是策略的参数,$s$ 是状态,$a$ 是行动。

目标是找到最佳的策略参数 $\theta$ 使得智能体在环境中获得的期望回报最大。为此,我们需要定义一个目标函数$J(\theta)$,表示期望回报。然后,通过梯度上升法(或下降法)来优化该目标函数。

策略梯度的数学推导

假设我们的目标函数 $J(\theta)$ 定义为:

J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)]

其中$\tau$ 表示一个完整的轨迹(从初始状态到终止状态的状态-动作序列),$R(\tau)$ 是该轨迹的总回报。根据策略的定义,我们有:

\pi_\theta(\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)

因此,目标函数可以重写为:

J(\theta) = \sum_{\tau} \pi_\theta(\tau) R(\tau)

为了最大化$J(\theta)$,我们需要计算其梯度 $\nabla_\theta J(\theta)$

\nabla_\theta J(\theta) = \nabla_\theta \sum_{\tau} \pi_\theta(\tau) R(\tau) = \sum_{\tau} \nabla_\theta \pi_\theta(\tau) R(\tau)

使用概率分布的梯度性质,我们有:

\nabla_\theta \pi_\theta(\tau) = \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau)

因此,梯度可以表示为:

\nabla_\theta J(\theta) = \sum_{\tau} \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau) R(\tau) = \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(\tau) R(\tau)]

这个公式被称为策略梯度定理。为了估计这个期望值,我们通常使用蒙特卡洛方法,从策略 $\pi_\theta$ 中采样多个轨迹 $\tau$,然后计算平均值。

策略梯度算法的实现

我们以一个简单的环境为例,展示如何实现策略梯度算法。假设我们有一个离散动作空间的环境,我们使用一个神经网络来参数化策略$\pi_\theta(a|s)$

步骤 1:环境设置

首先,设置环境和参数:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

env = gym.make('CartPole-v1')
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]
步骤 2:策略网络定义

定义一个简单的策略网络:

class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, n_actions):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.softmax(x, dim=-1)

policy = PolicyNetwork(state_dim, n_actions)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
步骤 3:采样轨迹

编写函数来从策略中采样轨迹:

def sample_trajectory(env, policy, max_steps=1000):
    state = env.reset()
    states, actions, rewards = [], [], []
    for _ in range(max_steps):
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = policy(state)
        action = np.random.choice(n_actions, p=probs.detach().numpy()[0])
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        if done:
            break
        state = next_state
    return states, actions, rewards
步骤 4:计算回报和梯度

计算每个状态的回报,并使用策略梯度定理更新策略:

def compute_returns(rewards, gamma=0.99):
    returns = []
    G = 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    return returns

def update_policy(policy, optimizer, states, actions, returns):
    returns = torch.FloatTensor(returns)
    loss = 0
    for state, action, G in zip(states, actions, returns):
        state = state.squeeze(0)
        probs = policy(state)
        log_prob = torch.log(probs[action])
        loss += -log_prob * G
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
步骤 5:训练策略

将上述步骤组合在一起,训练策略网络:

num_episodes = 1000
for episode in range(num_episodes):
    states, actions, rewards = sample_trajectory(env, policy)
    returns = compute_returns(rewards)
    update_policy(policy, optimizer, states, actions, returns)
    if episode % 100 == 0:
        print(f"Episode {episode}, total reward: {sum(rewards)}")
总结

通过以上步骤,我们实现了一个基本的策略梯度算法。策略梯度方法通过直接优化策略来最大化智能体的期望回报,具有理论上的简洁性和实用性。本文详细推导了策略梯度的数学公式,并提供了具体的实现步骤,希望能够帮助读者更好地理解和应用这一重要的强化学习算法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1888965.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ajax异步请求 axios

Ajax异步请求 axios 1 axios介绍 原生ajax请求的代码编写太过繁琐,我们可以使用axios这个库来简化操作! 在后续学习的Vue(前端框架)中发送异步请求,使用的就是axios. 需要注意的是axios不是vue的插件,它可以独立使用. axios说明网站:(https://www.kancl…

猫头虎分享[可灵AI」官方推荐的驯服指南-V1.0

猫头虎分享[可灵AI」官方推荐的驯服指南-V1.0 猫头虎是谁? 大家好,我是 猫头虎,别名猫头虎博主,擅长的技术领域包括云原生、前端、后端、运维和AI。我的博客主要分享技术教程、bug解决思路、开发工具教程、前沿科技资讯、产品评…

云服务器安装部署LAMP网站Web环境教程

搭建网站如何安装LAMP环境,以腾讯云轻量应用服务器为例,应用模板直接选择“LAMP”镜像即可,打开腾讯云轻量应用服务器页面,在应用模板中选择LAMP即可,如下图: 轻量服务器“LAMP”镜像 腾讯云的LAMP应用镜像…

XLSX + LuckySheet + LuckyExcel实现前端的excel预览

文章目录 功能简介简单代码实现效果参考 功能简介 通过LuckyExcel的transformExcelToLucky方法, 我们可以把一个文件直接转成LuckySheet需要的json字符串, 之后我们就可以用LuckySheet预览excelLuckyExcel只能解析xlsx格式的excel文件,因此对…

海南云亿商务咨询有限公司抖音电商新引擎

在当今这个数字化高速发展的时代,电商已经成为推动经济增长的重要引擎。而在众多电商平台中,抖音以其独特的短视频形式和庞大的用户群体,迅速崛起为电商领域的新星。海南云亿商务咨询有限公司,作为一家专注于抖音电商服务的公司&a…

JavaScript实现注册页面的校验

完成注册页面的校验 1.目标 要求: 1、 用户名必须是6-10位的字母或者数字 2、 密码长度必须6位任意字符 3、 两次输入密码要一致 说明:只要有一个输入项不满足要求则阻止表单提交。都满足才可以提交表单。 2.实现 1.知识点 1.1js事件 【1】鼠标离…

网上最靠谱的改名大师颜廷利:世界顶级哲学家,东方伟大的思想家教育家

颜廷利教授,21世纪的东方哲学巨匠、科学探索者,以及中国教育界的领军人物,他在《升命学说》中阐述了一种深刻的生活哲学。这本书包含了四个主要理念:净化论、和合法则、唯悟主义与镜正理念,每一个都为我们如何理解生活…

我在手提电脑上将大模型训练成了语文老师

(图片由大模型生成,如有侵权,立删) 记得一年多以前,和不少商家交流大模型解决方案时,他们谈到内部有很多的资料,可以对大模型进行训练,让大模型变得更有智慧,从而为客户…

【项目日记(三)】搜索引擎-搜索模块

❣博主主页: 33的博客❣ ▶️文章专栏分类:项目日记◀️ 🚚我的代码仓库: 33的代码仓库🚚 🫵🫵🫵关注我带你了解更多项目内容 目录 1.前言2.项目回顾3.搜索流程3.1分词3.2触发3.3去重3.4排序3.5包装 4.总结 1.前言 在前…

全网新鲜出炉的Stable Diffusion 人物发型提示词大全,中英文列表!

前言 简介: 使用发型提示词能更精确描述所需图像的发型特征,如卷发、短发、颜色和风格。结合正负提示词,确保生成图片符合预期。尝试使用工具如PromptChoose来创建个性化图像描述,包含多种发型选项,如刘海、马尾、波浪…

6.5、函数的常见形式

代码 #include <iostream> using namespace std; #include <string>//函数的的常见延时 //1、无参无反 void test01() {cout << "this is test01" << endl; } //2、有参无反 void test02(int a) {cout << "this is test02 a &q…

QT学习积累——方法参数加const和不加const的区别

目录 引出方法参数加const和不加const的区别方法加static和不加static的区别Qt遍历list提高效率显示函数的调用使用&与不使用&除法的一个坑 总结自定义信号和槽1.自定义信号2.自定义槽3.建立连接4.进行触发 自定义信号重载带参数的按钮触发信号触发信号拓展 lambda表达…

【探索Linux】P.36(传输层 —— TCP协议段格式)

阅读导航 引言一、TCP段的基本格式二、控制位详细介绍三、16位接收窗口大小⭕窗口大小的作用⭕窗口大小的限制⭕窗口缩放选项⭕窗口大小的更新⭕窗口大小与拥塞控制 四、紧急指针温馨提示 引言 在上一篇文章中&#xff0c;我们深入探讨了一种无连接的UDP协议&#xff0c;它以其…

【数据结构】04.双向链表

一、双向链表的结构 注意&#xff1a;这里的“带头”跟前面我们说的“头节点”是两个概念&#xff0c;带头链表里的头节点&#xff0c;实际为“哨兵位”&#xff0c;哨兵位节点不存储任何有效元素&#xff0c;只是站在这里“放哨的”。 “哨兵位”存在的意义&#xff1a;遍历循…

以太坊DApp交易量激增83%的背后原因解析

引言 最近&#xff0c;以太坊网络上的去中心化应用程序&#xff08;DApp&#xff09;交易量激增83%&#xff0c;引发了广泛关注和讨论。尽管交易费用高达2.4美元&#xff0c;但以太坊仍在DApp交易量方面遥遥领先于其他区块链网络。本文将深入探讨导致这一现象的主要原因&#…

Redux 使用及基本原理

什么是Redux Redux 是用于js应用的状态管理库&#xff0c;通常和React一起用。帮助开发者管理应用中各个组件之间的状态&#xff0c;使得状态的变化变得更加可预测和易于调试。 Redu也可以不和React组合使用。&#xff08;通常一起使用&#xff09; Redux 三大原则 单一数据源…

图书馆书籍管理系统

项目名称与项目简介 图书馆书籍管理系统 本项目是一个计算机管理系统&#xff0c;也就是将传统手工的管理方式转变为智能化、标准化、规范化的管理管理模式&#xff0c;对图书馆中所有的图书、文献资料、音像资料、报刊、期刊等各种类型的资料实现采编、收集图书信息、检索、归…

【笔记】强化学习,gym的命令行图形化界面适配

搞了一大堆还是搞不出来放弃了 最后用matplotlib画出来看 import gym import matplotlib.pyplot as plt from IPython import display import numpy as np %matplotlib inlineenv gym.make(CartPole-v1, render_mode"rgb_array") observation env.reset() a 0 f…

JWT入门

JWT与TOKEN JWT&#xff08;JSON Web Token&#xff09;是一种基于 JSON 格式的轻量级安全令牌&#xff0c;通常用于在网络应用间安全地传递信息。而“token”一词则是一个更广泛的术语&#xff0c;用来指代任何形式的令牌&#xff0c;用于在计算机系统中进行身份验证或授权。J…

EIOT能源物联网平台在连锁门店的应用

在当今快节奏的商业环境中&#xff0c;连锁门店的管理和运营变得越来越具有挑战性。能源数据是连锁门店的管理中重要组成部分&#xff0c;为了提高门店的能源利用效率和管理水平&#xff0c;需要依赖先进的集团能源管理系统&#xff0c;进而实现节能减排&#xff0c;优化运营成…