强化学习------Policy Gradient算法

news2025/3/15 0:20:24

目录

    • 简介
    • PG算法原理
    • 效果:
    • 参考

简介

之前的QLearning DQN Sarsa都是通过计算动作得分来决策的,我们是在确定了价值函数的基础上采用某种策略,即Value-Based,通过先算出价值函数,再去做决策。而Policy Gradient算法是一种直接的方法,我们直接去评估策略的好坏,然后进行选择。即Policy-Base

智能体通过与环境的交互获得特定时刻的状态信息,并直接给出下一步要采取各种动作的概率,然后根据该状态动作的策略分布采取下一步的行动,所以每种动作都有可能被选中,只是选中的概率性不同。智能体直接学习状态动作的策略分布,在强化学习的训练中,用神经网络来表示状态动作分布,给一个状态,就会输出该状态下的动作分布。强化学习算法直接对策略进行优化,使指定的策略能够获得最大的奖励。

PG算法原理

在这里插入图片描述

强化学习主要目标强是最大化智能体在与环境交互的过程中获得的累积奖励的期望值
考虑一个随机参数化的策略

  • π(θ) 是一个参数化的策略函数,它接受当前的观测作为输入,输出一个动作的概率分布。
  • E 表示期望值,表示对所有可能的轨迹 τ 进行加权平均。
  • τ 表示一个轨迹,它包含了智能体在环境中与环境进行交互的一系列状态、动作和奖励。
  • R(τ) 表示轨迹 τ 的累积奖励,表示智能体在轨迹中获得的所有奖励的总和。

通过梯度上升法优化策略即有
在这里插入图片描述

  • θ_k 表示当前的参数值,即第 k 次迭代时的参数值。
  • θ_{k+1} 表示下一次迭代的参数值,即第 k+1 次迭代时的参数值。
  • α 表示学习率,它决定了每次迭代时参数更新的幅度。
  • ∇_θ J(π_θ) 表示性能指标 J(π_θ) 对参数 θ 的梯度,它表示了在当前参数值下,性能指标 J(π_θ) 变化最快的方向。
  • ∣ θ_k 表示在参数值为 θ_k 的情况下计算梯度。

在这里插入图片描述

先实现一个episode然后从后往前计算回报,损失函数是负的回报乘于log的该状态下采取该动作的概率。每个状态动作对对应算一次loss,然后反向传播计算梯度。最后整个episode完之后进行梯度下降。

本代码核心在于更新参数的部分
更新参数的步骤如下:

  • Step 1: 计算每一步的状态价值 首先,将每个时间步的奖励值进行折扣,得到折扣后的奖励值。这里使用的折扣因子为GAMMA。
    然后,从最后一个时间步开始,计算每个时间步的状态价值。状态价值等于当前时间步的奖励值加上下一个时间步的折扣后的状态价值。这个过程通过一个循环实现。
    最后,对计算得到的状态价值进行标准化处理,减去均值并除以标准差。
        for t in reversed(range(0, len(self.ep_rs))):
            running_add = running_add * GAMMA + self.ep_rs[t]
            discounted_ep_rs[t] = running_add

        # 标准化处理
        discounted_ep_rs -= np.mean(discounted_ep_rs)  # 减均值
        discounted_ep_rs /= np.std(discounted_ep_rs)  # 除以标准差
        discounted_ep_rs = torch.FloatTensor(discounted_ep_rs).to(device)
  • Step 2: 前向传播 将状态观测值作为输入,通过神经网络进行前向传播,得到每个动作的概率分布。
    这里使用了softmax函数将输出转化为概率分布。
# Step 2: 前向传播
softmax_input = self.network.forward(torch.FloatTensor(self.ep_obs).to(device))
# all_act_prob = F.softmax(softmax_input, dim=0).detach().numpy()
neg_log_prob = F.cross_entropy(input=softmax_input, target=torch.LongTensor(self.ep_as).to(device),
                                  reduction='none')
  • Step 3: 反向传播 计算负对数似然损失函数,该损失函数用于最大化选择正确动作的概率。
    然后,将负对数似然损失函数与折扣后的状态价值相乘,得到最终的损失函数。 最后,计算损失函数的均值,用于更新神经网络的参数。
# Step 3: 反向传播
loss = torch.mean(neg_log_prob * discounted_ep_rs)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

代码如下:

import os

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import time
from collections import deque

# Hyper Parameters for PG Network
GAMMA = 0.95  # discount factor
LR = 0.01  # learning rate

# Use GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# torch.backends.cudnn.enabled = False  # 非确定性算法


class PGNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PGNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 20)
        self.fc2 = nn.Linear(20, action_dim)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(out)
        return out


class PG(object):
    # dqn Agent
    def __init__(self, env):  # 初始化
        # 状态空间和动作空间的维度
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        # init N Monte Carlo transitions in one game
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []

        # init network parameters
        self.network = PGNetwork(state_dim=self.state_dim, action_dim=self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=LR)
        # 加载以前保存的模型(如果有的话)
        if os.path.exists("./model/model.pkl"):
            self.network.load_state_dict(torch.load("./model/model.pkl"))
            # 加载以前保存的优化器
            self.optimizer.load_state_dict(torch.load("./model/optimizer.pkl"))


        # init some parameters
        self.time_step = 0

    def choose_action(self, observation):
        observation = torch.FloatTensor(observation).to(device)
        network_output = self.network.forward(observation)

        prob_weights = F.softmax(network_output, dim=0).detach().numpy()
        action = np.random.choice(range(prob_weights.shape[0]),
                                  p=prob_weights)  # select action w.r.t the actions prob
        return action

    # 将状态,动作,奖励这一个transition保存到三个列表中
    def store_transition(self, s, a, r):
        self.ep_obs.append(s)
        self.ep_as.append(a)
        self.ep_rs.append(r)

    def learn(self):
        self.time_step += 1

        # Step 1: 计算每一步的状态价值
        discounted_ep_rs = np.zeros_like(self.ep_rs)
        running_add = 0
        # 注意这里是从后往前算的,所以式子还不太一样。算出每一步的状态价值
        # 前面的价值的计算可以利用后面的价值作为中间结果,简化计算;从前往后也可以
        for t in reversed(range(0, len(self.ep_rs))):
            running_add = running_add * GAMMA + self.ep_rs[t]
            discounted_ep_rs[t] = running_add

        discounted_ep_rs -= np.mean(discounted_ep_rs)  # 减均值
        discounted_ep_rs /= np.std(discounted_ep_rs)  # 除以标准差
        discounted_ep_rs = torch.FloatTensor(discounted_ep_rs).to(device)

        # Step 2: 前向传播
        softmax_input = self.network.forward(torch.FloatTensor(self.ep_obs).to(device))
        # all_act_prob = F.softmax(softmax_input, dim=0).detach().numpy()
        neg_log_prob = F.cross_entropy(input=softmax_input, target=torch.LongTensor(self.ep_as).to(device),
                                       reduction='none')

        # Step 3: 反向传播
        loss = torch.mean(neg_log_prob * discounted_ep_rs)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # 每次学习完后清空数组
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []


# ---------------------------------------------------------
# Hyper Parameters
ENV_NAME = 'CartPole-v0'
EPISODE = 3000  # Episode limitation
STEP = 300  # Step limitation in an episode
TEST = 10  # The number of experiment test every 100 episode


def main():
    # initialize OpenAI Gym env and dqn agent
    env = gym.make(ENV_NAME)
    agent = PG(env)

    for episode in range(EPISODE):
        # initialize task
        state = env.reset()
        # Train
        # 只采一盘?N个完整序列
        for step in range(STEP):
            action = agent.choose_action(state)  # softmax概率选择action
            next_state, reward, done, _ = env.step(action)
            agent.store_transition(state, action, reward)  # 新函数 存取这个transition
            state = next_state
            if done:
                # print("stick for ",step, " steps")
                agent.learn()  # 更新策略网络
                break

        # Test every 100 episodes
        if episode % 100 == 0:
            # 保存模型
            torch.save(agent.network.state_dict(), "./model/model.pkl")
            torch.save(agent.optimizer.state_dict(), "./model/optimizer.pkl")
            print("save model to /model/model.pkl")
            total_reward = 0
            for i in range(TEST):
                state = env.reset()
                for j in range(STEP):
                    env.render()
                    action = agent.choose_action(state)  # direct action for test
                    state, reward, done, _ = env.step(action)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:', ave_reward)


if __name__ == '__main__':
    time_start = time.time()
    main()
    time_end = time.time()
    print('The total time is ', time_end - time_start)


效果:

在这里插入图片描述

参考

策略梯度算法
policy gradient详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1103816.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云计算:shell脚本

shell脚本,会极大减少重复性工作,缩短很大时间。 脚本每个人都可以不一样,只要实现就可以。 注意:要多思考,把思路锻炼好。以后就可以写各种程序。 shell语言 学完shell之后,对Linux理解更深刻&#xff…

在调试器下看微信[如何耗电]

在今天这样干什么都离不开手机的时代里,手机的待机时间太重要了。特别是对于我这个不喜欢带充电宝出门的人来说,一旦看到手机电量低于20%,立刻就精神紧张了,因为一切信息都在手机里,如果手机没电,那么就失联…

[SQL | MyBatis] MyBatis 简介

目录 一、MyBatis 简介 1、MyBatis 简介 2、工作流程 二、入门案例 1、准备工作 2、示例 三、Mapper 代理开发 1、问题简介 2、工作流程 3、注意事项 4、测试 四、核心配置文件 mybatis-config.xml 1、environment 2、typeAilases 五、基于 xml 的查询操作 1、…

通过stream对list集合中对象的多个字段进行去重

记录下通过stream流对list集合中对象的多个字段进行去重! 举个栗子,对象book,我们要通过姓名和价格这两个字段的值进行去重,该这么做呢? distinct()返回由该流的不同元素组成的流。distinct&am…

第五届芜湖机器人展,正运动助力智能装备“更快更准”更智能!

■展会名称: 第十一届中国(芜湖)科普产品博览交易会-第五届机器人展 ■展会日期 2023年10月21日-23日 ■展馆地点 中国ㆍ芜湖宜居国际博览中心B馆 ■展位号 B029 正运动技术,作为国内领先的运动控制企业,将于2023年10月21日参加芜湖机…

查看双翌视觉软件版本号

查看双翌视觉软件版本号 MasterAlign视觉对位软件 MasterAlign视觉对位软件的版本号在软件界面的右下角,如下图所示: 进入界面查看右下角编号尾号为O的代表旧协议版本 而编号尾号为N的则为新协议版本。 WiseAlign视觉对位软件 打开WiseAlign视觉对位软…

靶机 Chill_Hack

Chill_Hack 信息搜集 存活检测 arp-scan -l 详细扫描 扫描结果 显示允许 ftp 匿名链接 FTP 匿名登录 匿名登陆 ftp 下载文件并查看 anonymous10.4.7.139下载命令 get note.txt查看文件 译 Anurodh告诉我,在命令 Apaar 中有一些字符串过滤后台扫描 扫描结果…

【算法挨揍日记】day16——525. 连续数组、1314. 矩阵区域和

525. 连续数组 525. 连续数组 题目描述: 给定一个二进制数组 nums , 找到含有相同数量的 0 和 1 的最长连续子数组,并返回该子数组的长度。 解题思路: 本题的元素只有0和1,根据题目意思,我们可以把题目看成找一段最…

Educational Codeforces Round 156 (Rated for Div. 2)

C. Decreasing String 分析&#xff1a;暴力做法是很容易想到的&#xff0c;但时间复杂度为O(n2) 这是我打cf以来看到的最好的题解。 #include<cstdio> #include<set> #include<list> #include<queue> #include<math.h> #include<stdlib.h&g…

5.DApp-前端网页怎么连接MetaMask

题记 在前端网页连接metamask&#xff0c;以下是全部操作流程和代码。 编写index.html文件 index.html文件如下&#xff1a; <!DOCTYPE html> <html> <head> <title>My DApp</title> <!--导入用于检测Metamask提供者的JavaScript库--> &l…

嵌入式开发学习之STM32F407定时器中断配置(四)

嵌入式开发学习之STM32F407定时器中断配置&#xff08;四&#xff09; 此次实现目的开发涉及工具一、TIM参数配置和中断配置二、TIM的中断服务函数 此次实现目的 1.配置一个TIM进行计时&#xff0c;让一颗LED以点亮500ms&#xff0c;熄灭500ms的方式闪烁&#xff1b; 有工程实…

【JVM】对象内存布局

对象内存布局 文章目录 对象内存布局1. 对象的内存布局2. 对象标记(Mark Word)3. 类元信息(类型指针)4. 实例数据和对象填充 1. 对象的内存布局 在Hotspot虚拟机里&#xff0c;对象在堆内存中的存储布局可以划分为三个部分&#xff1a;对象头(Header)、实例数据(Instance Data…

华为云云耀云服务器L实例评测|使用Benchmark工具对云耀云服务器Elasticsearch的性能测试

目录 引言 1 在centos上安装Elasticsearch 1.1在服务器上安装 Docker 1.2 查找Elasticsearch镜像 1.3 安装并运行 Elasticsearch 容器 2 性能测试 Elasticsearch 2.1 安装 Apache Benchmark 工具 2.2 使用Benchmark进行性能测试 3 性能分析 3.1 性能测试结果 3.2 性能…

堆/二叉堆详解[C/C++]

前言 堆是计算机科学中-类特殊的数据结构的统称。实现有很多,例如:大顶堆,小顶堆&#xff0c;斐波那契堆&#xff0c;左偏堆&#xff0c;斜堆等等。从子结点个数上可以分为二汊堆&#xff0c;N叉堆等等。本文将介绍的是二叉堆。 二叉堆的概念 1、引例 我们小时候&#xff0c;基…

网络安全常见问题隐患及其应对措施

随着数字化时代的到来&#xff0c;网络安全已经成为组织和个人面临的严重挑战之一。网络攻击日益普及&#xff0c;黑客和不法分子不断寻找机会侵入系统、窃取敏感信息、破坏服务和网络基础设施。在这种情况下&#xff0c;了解网络安全的常见问题隐患以及如何应对它们至关重要。…

Android 13 - Media框架(11)- MediaCodec(一)

MediaCodec 是 Android 平台上音视频编解码的标准接口&#xff0c;无论是使用软解还是硬解都要通过调用 MediaCodec来完成&#xff0c;是学习 Android 音视频不可跳过的重要部分。MediaCodec 部分的代码有几千行&#xff0c;光是头文件就有几百行&#xff0c;对于我这样的新手来…

OpenCV Series : TI - DSP - CCS

Code Composer Studio V5.5 https://www.ti.com/tool/download/CCSTUDIO https://www.ti.com/tool/download/CCSTUDIO/5.5.0.00077

vue中引入jquery解决跨域问题

1、vue 工程文件 package.json 中 引入 “dependencies”: { “jquery”:“^2.2.4” }, 2、控制台执行命令&#xff0c;当前工程文件夹下 cnpm install 3、修改的vue文件中 加入 import $ from ‘jquery’ 4、调用 ajax请求 $.ajax({url:http://192.168.0.10:9099/strutsJspA…

黑马JVM总结(三十六)

&#xff08;1&#xff09;CAS-概述 cas是配合volatile使用的技术 &#xff0c;对共享变量的安全性要使用synachonized加锁&#xff0c;但是CAS不加锁&#xff0c;它是使用where&#xff08;true&#xff09;的死循环&#xff0c;里面compareAndSwap尝试把结果赋值给共享变量&…

Leetcode 02.07 链表相交(链表)

Leetcode 02.07 链表相交&#xff08;链表&#xff09; 解法1 尾部对齐解法2&#xff1a;太厉害了&#xff0c;数学归纳推导的方法 很巧妙&#xff0c;这就是将链表的尾端对齐后再一起遍历&#xff0c;这样能满足题目的要求。因为相交之后两个链表到结束的所有节点都一样了&…