Reinforcement Learning with Code【Code 5. Policy Gradient Methods】

news2025/2/8 20:10:17

Reinforcement Learning with Code【Code 5. Policy Gradient Methods】

This note records how the author begin to learn RL. Both theoretical understanding and code practice are presented. Many material are referenced such as ZhaoShiyu’s Mathematical Foundation of Reinforcement Learning, .

文章目录

  • Reinforcement Learning with Code【Code 5. Policy Gradient Methods】
    • 1. Policy Gradient 回顾
    • 2. Policy Gradient Code
    • Reference

1. Policy Gradient 回顾

之前介绍的 Q-learning、DQN 及 DQN 改进算法都是基于价值(value-based)的方法,其中 Q-learning 是处理有限状态的算法,而 DQN 可以用来解决连续状态的问题。在强化学习中,除了基于值函数的方法,还有一支非常经典的方法,那就是基于策略(policy-based)的方法。对比两者,基于值函数的方法主要是学习值函数,然后根据值函数导出一个策略,学习过程中并不存在一个显式的策略;而基于策略的方法则是直接显式地学习一个目标策略。策略梯度是基于策略的方法的基础,本章从策略梯度算法说起。

由之前的学习参考Reinforcement Learning with Code 【Chapter 9. Policy Gradient Methods】,可以策略梯度有三个metric可以使用,分别是平均状态值(average state value),平均奖励(average reward)和从特定状态出发的平均状态值(state value of a specific starting state)。其中使用最多就是从特定状态出发的平均状态值,记为 v π ( s 0 ) v_\pi(s_0) vπ(s0),其中 s 0 s_0 s0表示初始状态。所有当我们使用从特点状态出发的平均状态值(state value of a specific starting state)作为优化目标函数的时候,我们的待优化函数可以写作
max ⁡ θ J ( θ ) = E [ v π θ ( s 0 ) ] \max_\theta J(\theta) = \mathbb{E}[v_{\pi_\theta}(s_0)] θmaxJ(θ)=E[vπθ(s0)]
再根据策略梯度定理,则有证明略(可以参考Hands on RL)
∇ θ J ( θ ) = E [ ∇ θ ln ⁡ π ( A ∣ S ; θ ) q π ( S , A ) ] \nabla_\theta J(\theta) = \mathbb{E}[\nabla_\theta \ln \pi(A|S;\theta)q_\pi(S,A)] θJ(θ)=E[θlnπ(AS;θ)qπ(S,A)]
这一梯度更新法则是不能使用的,这是因为 q π ( S , A ) q_\pi(S,A) qπ(S,A)是真值,我们不能获得,我们可以借助Monte-Carlo的思想来对此进行更新,用一个episode的return来代替这个Q-value,即
q π ( s t , a t ) = ∑ k = t + 1 T γ k − t − 1 r k q_\pi(s_t,a_t) = \sum^T_{k=t+1}\gamma^{k-t-1} r_k qπ(st,at)=k=t+1Tγkt1rk
那我们获得的梯度更新法则为
∇ θ J ( θ ) = E [ ∇ θ ln ⁡ π ( A ∣ S ; θ ) × ∑ k = t + 1 T γ k − t − 1 r k ) ] \nabla_\theta J(\theta) = \mathbb{E}[\nabla_\theta \ln \pi(A|S;\theta) \times \sum^T_{k=t+1}\gamma^{k-t-1} r_k)] θJ(θ)=E[θlnπ(AS;θ)×k=t+1Tγkt1rk)]
还原出待优化的目标函数即为
max ⁡ θ J ( θ ) = E [ ln ⁡ π ( A ∣ S ; θ ) × ∑ k = t + 1 T γ k − t − 1 r k ) ] \max_\theta J(\theta) = \mathbb{E}[\ln \pi(A|S;\theta) \times \sum^T_{k=t+1}\gamma^{k-t-1} r_k)] θmaxJ(θ)=E[lnπ(AS;θ)×k=t+1Tγkt1rk)]
这个应用了Monte-Carlo思想的算法又被称为REINFORCE。

2. Policy Gradient Code

智能体的交互环境采用的是gymCartPole-v1环境,已经在中Reinforcement Learning with Code 【Code 4. Vanilla DQN】进行过介绍,此处不再赘述。

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Define the policy network
class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, observation):
        x = F.relu(self.fc1(observation))
        x = F.softmax(self.fc2(x), dim=1)
        return x


# Implement REINFORCE algorithm
class REINFORCE():
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.device = device
    
    def choose_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action_probs_dist = torch.distributions.Categorical(probs)    # generate prob distribution according to probs
        action = action_probs_dist.sample().item()
        return action
    
    def learn(self, transition_dict):
        reward_list = transition_dict['rewards']
        state_list = transition_dict['states']
        action_list = transition_dict['actions']

        G = 0
        self.optimizer.zero_grad()
        for i in reversed(range(len(reward_list))):
            reward = reward_list[i]
            state = torch.tensor([state_list[i]], dtype=torch.float).to(self.device)
            action = torch.tensor([action_list[i]]).view(-1,1).to(self.device)
            log_prob = torch.log(self.policy_net(state).gather(dim=1,index=action))
            G = self.gamma * G + reward
            loss = -log_prob * G    # 计算每一步的损失函数,有负号是因为我们需要max这个loss
            loss.backward()     # 反向传播累计梯度
        self.optimizer.step()   # after one episode 梯度更新


def train_policy_net_agent(env, agent, num_episodes, seed):
    return_list = []
    for i in range(10):
        with tqdm(total = int(num_episodes/10), desc="Iteration %d"%(i+1)) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {
                    'states': [],
                    'actions': [],
                    'next_states': [],
                    'rewards': [],
                    'dones': []
                }
                observation, _ = env.reset(seed=seed)
                done = False
                while not done:
                    if render:
                        env.render()
                    action = agent.choose_action(observation)
                    observation_, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    # save one episode experience into a dict
                    transition_dict['states'].append(observation)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(observation_)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    # swap state
                    observation = observation_
                    # compute one episode return
                    episode_return += reward
                return_list.append(episode_return)
                agent.learn(transition_dict)
                if((i_episode + 1) % 10 == 0):
                    pbar.set_postfix({
                        'episode': '%d'%(num_episodes / 10 * i + i_episode + 1),
                        'return': '%.3f'%(np.mean(return_list[-10:]))
                    })
                pbar.update(1)
    env.close()
    return return_list

def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))

def plot_curve(return_list, mv_return, algorithm_name, env_name):
    episodes_list = list(range(len(return_list)))
    plt.plot(episodes_list, return_list, c='gray', alpha=0.6)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('{} on {}'.format(algorithm_name, env_name))
    plt.show()

if __name__=="__main__":
    learning_rate = 1e-3    # learning rate
    num_episodes = 1000     # episodes length
    hidden_dim = 128        # hidden layers dimension
    gamma = 0.98            # discounted rate
    device = torch.device('cuda' if torch.cuda.is_available() else 'gpu')

    env_name = 'CartPole-v1'    # gym env name  
    render = False              # render or not

    # reproducible
    seed_number = 0
    torch.manual_seed(seed_number)
    np.random.seed(seed_number)


    if render:
        env = gym.make(id=env_name, render_mode='human')
    else:
        env = gym.make(id=env_name, render_mode=None)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = REINFORCE(state_dim, hidden_dim, action_dim, learning_rate, gamma, device)
    return_list = train_policy_net_agent(env, agent, num_episodes, seed_number)

    mv_return = moving_average(return_list, 9)
    plot_curve(return_list, mv_return, 'REINFORCE', env_name)

REINFORCE的效果如下图所示

Image

Reference

赵世钰老师的课程
Reinforcement Learning with Code 【Chapter 9. Policy Gradient Methods】
Hands on RL
Reinforcement Learning with Code 【Code 4. Vanilla DQN】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/864760.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

想要延长Macbook寿命?这六个保养技巧你必须get!

Mac作为我们工作生活的伙伴,重要性不需要多说。但在使用的过程中,我们总会因不当操作导致Mac出现各种问题。 要想它长久的陪伴,平时的维护与保养自然不能少,Mac的保养很重要的两点就是硬件保养和电脑系统保养,硬件保养…

企业服务器中了Locked勒索病毒怎么办,勒索病毒解密有哪些步骤

随着网络技术的不断发展,勒索病毒攻击成为了企业面临的一种风险。近期,我们收到某医药公司的求助,企业的服务器数据库遭到了locked勒索病毒的攻击,导致企业服务器内的许多重要数据被加密无法正常读取,不仅影响到了企业…

智能合约 -- 常规漏洞分析 + 实例

1.重入攻击 漏洞分析 攻击者利用合约漏洞,通过fallback()或者receive()函数进行函数递归进行无限取钱。 刚才试了一下可以递归10次,貌似就结束了。 直接看代码: 银行合约:有存钱、取钱、查看账户余额等函数。攻击合约: 攻击、以及合约接…

ECG和PPG信号用于PTT、HRV和PRV研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

微信朋友圈置顶功能已大范围上线!

微信是目前全球最受欢迎的社交媒体应用之一,拥有数十亿的用户。作为一款持续发展和改进的应用,微信不断推出新的功能来提升用户体验。 近日,iOS微信8.0.41内测版迎来了更新,本次更新距离上个正式版间隔了大概10天的时间。 微信朋友…

BI技巧丨利用Index计算半累计

在实际的业务场景中,特别是财务模块和库存管理模块,经常需要我们针对每个月的期初期末进行相关指标计算,这也是我们之前曾经提到的Calculate基础应用——半累计计算。 现在我们也可以通过微软新推出的Index开窗函数来解决这一问题。 INDEX函…

健启星|医学营养的市场先行者

随着《“健康中国2030”规划纲要》、《国民营养计划(2017-2030年)》等政策的陆续发布,标志着以传统药物治疗为中心的医疗模式时代正式转型到以预防和康复为中心的新的医学营养时代。在此背景下,符合时代需求的特医食品成为“医学营…

HbuildX生成安卓签名证书

HbuildX生成安卓签名证书 安装和配置JRE环境 根据此链接安装和配置JRE环境 生成签名证书 keytool -genkey -alias testalias -keyalg RSA -keysize 2048 -validity 36500 -keystore test.keystoretestalias是证书别名,可修改为自己想设置的字符,建议…

闻道网络:2023宠物消费网络营销洞察数据报告(附下载)

关于报告的所有内容,公众【营销人星球】获取下载查看 核心观点 行业持续升级,增速放缓,正朝着多元化和专业化的方向发展;自公共事件以来,因,“猫不用遛”,养猫人士增速迅猛反超犬主人&#xf…

Qt在mac安装

先在app store下载好Xcode 打开Xcode 随便建个文件 给它取个名字 找个地方放 提醒没建立git link,不用理他 打开终端, 输入/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 开始安装啦 继续在终端…

MySQL中的用户管理

系列文章目录 MySQL常见的几种约束 MySQL中的函数 MySQL中的事务 MySQL中的视图 MySQL中的索引 文章目录 系列文章目录前言一、用户管理1、用户管理入门2、用户管理操作及示例 二、权限管理1.权限管理语法2.权限操作示例 三、角色管理1、角色管理入门2、角色操作示例 总结…

Chisel 语言 - 小记

文章目录 Chisel 一种硬件描述语言,类似 verilog 本质是 Scala编程语言的一个包,类似于 numpy 是 Python 的一个包。 官网 : https://www.chisel-lang.orggithub: https://github.com/chipsalliance/chisel 同名的还有个 Facebook…

【数学建模】--灰色关联分析

系统分析: 一般的抽象系统,如社会系统,经济系统,农业系统,生态系统,教育系统等都包含有许多种因素,多种因素共同作用的结果决定了该系统的发展态势。人们常常希望知道在众多的因素中,哪些是主要…

每天一道leetcode:516. 最长回文子序列(动态规划中等)

今日份题目: 给你一个字符串 s ,找出其中最长的回文子序列,并返回该序列的长度。 子序列定义为:不改变剩余字符顺序的情况下,删除某些字符或者不删除任何字符形成的一个序列。 示例1 输入:s "bbb…

Nginx环境搭建以及Docker环境部署

目录 Nginx环境搭建 1.首先创建Nginx的目录并进入 2.下载Nginx的安装包 可以通过FTP工具上传离线环境包,也可通过wget命令在线获取安装包 没有wget命令的可通过yum命令安装 3.解压Nginx的压缩包 4.下载并安装Nginx所需的依赖库和包 安装方式一 安装方式二 --- 也…

hello world, this is my time

case1 2023-08-11 00:19:12 其实我这个人吧, 没事做也会刷点b站和抖音, 而且我经常看罗翔老师讲, 什么是爱, 他说爱是责任, 爱是不离不弃, 爱是有耐心, 爱是安慰, 爱也是陪伴, 爱同时也是一种共生的关系, 两个人彼此之间共生, 互相都希望彼此可以好好的生活下去, 看见对方活的比…

Git全栈体系(六)

第十章 自建代码托管平台-GitLab 一、GitLab 简介 GitLab 是由 GitLabInc.开发,使用 MIT 许可证的基于网络的 Git 仓库管理工具,且具有 wiki 和 issue 跟踪功能。使用 Git 作为代码管理工具,并在此基础上搭建起来的 web 服务。GitLab 由乌克…

红帽停止公开Linux操作系统(RHEL)源代码,甲骨文等企业成立协会

根据报道,红帽(Red Hat)在8月11日宣布停止公开企业级Linux操作系统(RHEL)的源代码后,甲骨文、SUSE和CIQ昨日联合发布了一份声明。声明宣布成立了Open Enterprise Linux Association(OpenELA&…

安全测试中常见的业务安全问题

“在测试过程中,特殊的操作往往容易触发异常场景,而这些异常场景也很容易引起安全问题!” 常见的安全漏洞就不多说了,这里主要介绍常见的业务安全问题及修复建议。 01 刷短信 问题描述: 当发送短信的请求接口只需要…

用最少数量的箭引爆气球——力扣452

文章目录 题目描述解法一题目描述 解法一 int findMinArrowShots(vector<vector<int>>& nums){if(num