策略梯度 (Policy Gradient):直接优化策略的强化学习方法

news2025/1/31 5:51:20

策略梯度 (Policy Gradient) 是强化学习中的一种方法,用于优化智能体的策略,使其在给定环境中表现得更好。与值函数方法(如 Q-learning)不同,策略梯度方法直接对策略进行优化,而不是通过学习一个值函数来间接估计最优策略。

核心思想:

在策略梯度方法中,智能体的策略是一个参数化的函数(通常是神经网络),通过梯度上升法来优化该策略的参数,使得智能体在与环境互动时获得最大的预期奖励。该方法通过计算策略相对于策略参数的梯度来更新策略参数,从而改善智能体的行为。

实现方式:

  1. 收集经验: 智能体与环境互动,收集状态-动作对以及相应的奖励。
  2. 计算梯度: 基于当前策略和收集到的经验,计算梯度。
  3. 更新策略: 使用计算出的梯度更新策略参数。

优点:

  • 可以直接优化策略,适用于连续动作空间。
  • 不依赖于环境的价值函数,适用于部分可观测或高维的状态空间。

缺点:

  • 策略梯度的估计通常具有较高的方差,需要更多的样本来获得稳定的结果。
  • 收敛速度较慢,可能需要更多的计算资源。

简单例子:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 1D迷宫环境,目标是从位置0移动到位置10
class SimpleMazeEnv:
    def __init__(self):
        self.state = 0  # 初始位置
        self.target = 10  # 目标位置
        self.max_steps = 20  # 最大步数
        
    def reset(self):
        self.state = 0
        return self.state
    
    def step(self, action):
        if action == 0:  # 向左移动
            self.state = max(0, self.state - 1)
        elif action == 1:  # 向右移动
            self.state = min(self.target, self.state + 1)
        
        # 计算奖励,靠近目标位置时奖励更高
        reward = -abs(self.state - self.target)  # 离目标越远奖励越低
        done = (self.state == self.target)  # 到达目标时结束
        return self.state, reward, done

# 策略网络
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return self.softmax(x)

# 策略梯度算法(REINFORCE)
def reinforce(env, policy, optimizer, episodes=1000, gamma=0.99):
    episode_rewards = []
    best_reward = -float('inf')
    best_path = []
    
    for episode in range(episodes):
        state = env.reset()
        state = torch.tensor([state], dtype=torch.float32)
        done = False
        rewards = []
        log_probs = []
        path = []  # 记录当前回合的路径
        
        while not done:
            # 选择动作
            action_probs = policy(state)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            
            # 执行动作并观察结果
            next_state, reward, done = env.step(action.item())
            next_state = torch.tensor([next_state], dtype=torch.float32)
            
            # 保存奖励和动作的log概率
            rewards.append(reward)
            log_probs.append(dist.log_prob(action))
            path.append(state.item())  # 记录当前位置
            
            state = next_state
        
        # 计算回报
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        
        # 计算损失并更新模型
        returns = torch.tensor(returns, dtype=torch.float32)
        log_probs = torch.stack(log_probs)
        loss = -torch.sum(log_probs * returns)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_reward = sum(rewards)
        episode_rewards.append(total_reward)
        
        if total_reward > best_reward:
            best_reward = total_reward
            best_path = path
        
        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}, Total Reward: {total_reward}, Best Reward: {best_reward}")
    
    return episode_rewards, best_path

# 初始化环境和模型
env = SimpleMazeEnv()
input_dim = 1  # 状态是一个标量
output_dim = 2  # 动作是向左或向右
policy = PolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.001)

# 训练模型
episode_rewards, best_path = reinforce(env, policy, optimizer, episodes=1000)

# 可视化训练结果
plt.figure(figsize=(12, 6))

# 绘制奖励曲线
plt.subplot(1, 2, 1)
plt.plot(episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Training Progress')

# 绘制最优路径图
plt.subplot(1, 2, 2)
plt.plot(best_path, marker='o', markersize=5, label="Best Path")
for i, coord in enumerate(best_path):
    plt.text(i, coord, f"({i}, {coord})", fontsize=8)  # 显示坐标
plt.xlabel('Steps')
plt.ylabel('State')
plt.title('Best Path Taken')
plt.legend()

plt.tight_layout()
plt.show()
  1. 环境SimpleMazeEnv是一个非常简单的1D迷宫环境,智能体的目标是从位置0移动到目标位置10。每步,智能体可以选择向左或向右移动。
  2. 策略网络PolicyNetwork是一个简单的神经网络,输出的是两个动作的概率(向左和向右)。
  3. 训练过程:采用策略梯度算法(REINFORCE),在每一轮训练中,智能体根据当前策略选择动作,通过累积奖励(回报)来更新策略网络。
  4. 奖励:智能体的奖励是与目标位置的距离成反比,离目标越近奖励越高。

预期效果:

  • 训练过程:每个回合的奖励会逐渐增加,智能体会逐步学习到正确的动作。
  • 可视化:我们会看到训练过程中每个回合的奖励曲线,以及最优路径(即智能体最终到达目标位置时的移动轨迹)。

运行后的图:

  • 左图:训练过程中的奖励变化。
  • 右图:最优路径的轨迹图,标记了每一步的位置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2286775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

练习(复习)

大家好,今天我们来做几道简单的选择题目来巩固一下最近学习的知识,以便我们接下来更好的学习。 这道题比较简单,我们前面学过,在Java中,一个类只能继承一个父类,但是一个父类可以有多个子类,一个…

【原创改进】SCI级改进算法,一种多策略改进Alpha进化算法(IAE)

目录 1.前言2.CEC2017指标3.效果展示4.探索开发比5.定性分析6.附件材料7.代码获取 1.前言 本期推出一期原创改进——一种多策略改进Alpha进化算法(IAE)~ 选择CEC2017测试集低维(30dim)和高维(100dim)进行测…

56. 协议及端口号

协议及端口号 在计算机网络中,协议和端口号是两个重要的概念。它们共同确保了不同计算机和网络设备之间可以正确、有效地进行通信。 协议(Protocol) 协议是网络通信的一组规则或标准,它定义了如何在计算机网络中发送、接收和解释…

短链接项目02---依赖的添加和postman测试

文章目录 1.声明2.对于依赖的引入和处理2.1原有的内容说明2.2添加公共信息2.3dependencies和management区别说明2.4添加spring-boot依赖2.5数据库的相关依赖2.6hutool工具类的依赖添加2.7测试test 的依赖添加 3.core文件的代码3.1目录层级结构3.2启动类3.3testcontroller测试类…

ADC 精度 第二部分:总的未调整误差解析

在关于ADC精度的第一篇文章中,我们阐述了模拟-数字转换器(ADC)的分辨率和精度之间的区别。现在,我们可以深入探讨影响ADC总精度的因素,这通常被称为总未调整误差(TUE)。 你是否曾好奇ADC数据表…

密码强度验证代码解析:C语言实现与细节剖析

在日常的应用开发中,密码强度验证是保障用户账户安全的重要环节。今天,我们就来深入分析一段用C语言编写的密码强度验证代码,看看它是如何实现对密码强度的多维度检测的。 代码整体结构 这段C语言代码主要实现了对输入密码的一系列规则验证&a…

Vue - pinia

Pinia 是 Vue 3 的官方状态管理库,旨在替代 Vuex,提供更简单的 API 和更好的 TypeScript 支持。Pinia 的设计遵循了组合式 API 的理念,能够很好地与 Vue 3 的功能结合使用。 Pinia 的基本概念 Store: Pinia 中的核心概念,类似于…

JxBrowser 7.41.7 版本发布啦!

JxBrowser 7.41.7 版本发布啦! • 已更新 #Chromium 至更新版本 • 实施了多项质量改进 🔗 点击此处了解更多详情。 🆓 获取 30 天免费试用。

亚博microros小车-原生ubuntu支持系列:17 gmapping

前置依赖 先看下亚博官网的介绍 Gmapping简介 gmapping只适用于单帧二维激光点数小于1440的点,如果单帧激光点数大于1440,那么就会出【[mapping-4] process has died】 这样的问题。 Gmapping是基于滤波SLAM框架的常用开源SLAM算法。 Gmapping基于RBp…

Python 变量和简单数据类型思维导图_2025-01-30

变量和简单数据类型思维导图 下载链接腾讯云盘: https://share.weiyun.com/15A8hrTs

小麦重测序-文献精读107

Whole-genome sequencing of diverse wheat accessions uncovers genetic changes during modern breeding in China and the United States 中国和美国现代育种过程中小麦不同种质的全基因组测序揭示遗传变化 大豆重测序-文献精读53_gmsw17-CSDN博客 大豆重测序二&#xff…

Django基础之ORM

一.前言 上一节简单的讲了一下orm,主要还是做个了解,这一节将和大家介绍更加细致的orm,以及他们的用法,到最后再和大家说一下cookie和session,就结束了全部的django基础部分 二.orm的基本操作 1.settings.py&#x…

大模型知识蒸馏技术(2)——蒸馏技术发展简史

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl2006年模型压缩研究 知识蒸馏的早期思想可以追溯到2006年,当时Geoffrey Hinton等人在模型压缩领域进行了开创性研究。尽管当时深度学习尚未像今天这样广泛普及,但Hinton的研究已经为知识迁移和模…

android获取EditText内容,TextWatcher按条件触发

android获取EditText内容,TextWatcher按条件触发 背景:解决方案:效果: 背景: 最近在尝试用原生安卓实现仿element-ui表单校验功能,其中涉及到EditText组件内容的动态校验,初步实现功能后&#…

毕业设计--具有车流量检测功能的智能交通灯设计

摘要: 随着21世纪机动车保有量的持续增加,城市交通拥堵已成为一个日益严重的问题。传统的固定绿灯时长方案导致了大量的时间浪费和交通拥堵。为解决这一问题,本文设计了一款智能交通灯系统,利用车流量检测功能和先进的算法实现了…

[权限提升] 操作系统权限介绍

关注这个专栏的其他相关笔记:[内网安全] 内网渗透 - 学习手册-CSDN博客 权限提升简称提权,顾名思义就是提升自己在目标系统中的权限。现在的操作系统都是多用户操作系统,用户之间都有权限控制,我们通过 Web 漏洞拿到的 Web 进程的…

Qt Designer and Python: Build Your GUI

1.install pyside6 2.pyside6-designer.exe 发送到桌面快捷方式 在Python安装的所在 Scripts 文件夹下找到此文件。如C:\Program Files\Python312\Scripts 3. 打开pyside6-designer 设计UI 4.保存为simple.ui 文件,再转成py文件 用代码执行 pyside6-uic.exe simpl…

数据结构与算法之栈: LeetCode LCR 152. 验证二叉搜索树的后序遍历序列 (Ts版)

验证二叉搜索树的后序遍历序列 https://leetcode.cn/problems/er-cha-sou-suo-shu-de-hou-xu-bian-li-xu-lie-lcof/description/ 描述 请实现一个函数来判断整数数组 postorder 是否为二叉搜索树的后序遍历结果 示例 1 输入: postorder [4,9,6,5,8] 输出: false解释&#…

[STM32 - 野火] - - - 固件库学习笔记 - - -十三.高级定时器

一、高级定时器简介 高级定时器的简介在前面一章已经介绍过,可以点击下面链接了解,在这里进行一些补充。 [STM32 - 野火] - - - 固件库学习笔记 - - -十二.基本定时器 1.1 功能简介 1、高级定时器可以向上/向下/两边计数,还独有一个重复计…

IPhone13 Pro Max设备详情

目录 产品宣传图内部图——后设备详细信息 产品宣传图 内部图——后 设备详细信息 信息收集于HubWeb.cn