Reinforcement Learning from Basics to Advanced - Cases and Practice [5.1]: Policy Gradient - CartPole Game Demo

  • Reinforcement learning (RL) is a branch of machine learning that, unlike supervised and unsupervised learning, focuses on how to act within an environment so as to maximize the expected cumulative reward.
  • Basic loop: the agent learns in an environment; based on the environment's state (or the observation it receives), it executes an action, and the reward fed back by the environment guides it toward better actions.

For example, in the CartPole mini-game used in this project, the agent is the cart-and-pole system shown in the animation, and it has two actions: push left or push right.
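
In code, the interaction loop looks roughly like this (a minimal sketch with a random policy; it assumes the same gym CartPole-v0 environment and old-style gym API used later in this project):

import gym

env = gym.make('CartPole-v0')
obs = env.reset()                       # initial observation: [cart pos, cart vel, pole angle, pole angular vel]
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()  # random action: 0 = push cart left, 1 = push cart right
    obs, reward, done, info = env.step(action)  # environment feeds back the next observation and a reward
    total_reward += reward
env.close()
print(total_reward)  # a purely random policy usually keeps the pole up for only around 20 steps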

1. Introduction to Policy Gradient

  • Reinforcement learning has two broad families of methods: value-based and policy-based.

    • Typical value-based algorithms are Q-learning and SARSA: they optimize the Q function toward the optimum and then derive the optimal policy from it.
    • The typical policy-based algorithm is Policy Gradient, which optimizes the policy function directly.
  • When a neural network is used to fit the policy function, the policy gradient must be computed in order to optimize the policy network.

    • The optimization objective is the expected return under the policy π(s,a): the sum of each trajectory's return R weighted by that trajectory's probability p. When N is large enough, it can be approximated by sampling N episodes and averaging.

    • Differentiating this objective with respect to the parameters θ gives the policy gradient, as written below.
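
Written out in the standard REINFORCE form, which matches the description above, the sampled objective and its gradient are:

$$\bar{R}_\theta = \sum_{\tau} R(\tau)\, p_\theta(\tau) \approx \frac{1}{N}\sum_{n=1}^{N} R(\tau^{n})$$

$$\nabla_\theta \bar{R}_\theta \approx \frac{1}{N}\sum_{n=1}^{N}\sum_{t=1}^{T_n} R(\tau^{n})\, \nabla_\theta \log \pi_\theta\left(a_t^{n} \mid s_t^{n}\right)$$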

## Install dependencies
!pip install pygame
!pip install gym
!pip install atari_py
!pip install parl
import gym
import os
import random
import collections

import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F

2. Model

The model here can be built from different neural network components depending on your needs.

PolicyGradient defines the forward network; you are free to customize your own network structure.

class PolicyGradient(nn.Layer):
    def __init__(self, act_dim):
        super(PolicyGradient, self).__init__()
        self.act_dim = act_dim
        hid1_size = act_dim * 10

        # CartPole observations have 4 dimensions (cart position/velocity, pole angle/angular velocity)
        self.linear1 = nn.Linear(in_features=4, out_features=hid1_size)
        self.linear2 = nn.Linear(in_features=hid1_size, out_features=act_dim)

    def forward(self, obs):
        out = self.linear1(obs)
        out = paddle.tanh(out)
        out = self.linear2(out)
        out = F.softmax(out)  # probability distribution over actions
        return out
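
As a quick sanity check (a minimal sketch, assuming the imports above), the network maps a single 4-dimensional observation to a probability distribution over the two CartPole actions:

# Quick sanity check: the output should be a valid probability distribution
model = PolicyGradient(act_dim=2)
obs = paddle.to_tensor(np.random.rand(1, 4), dtype='float32')
probs = model(obs)
print(probs.numpy(), probs.numpy().sum())  # e.g. [[0.47 0.53]] and a sum of 1.0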

3. The agent's learning functions

This covers two parts: exploration with the model and training the model.

The Agent handles the interaction between the algorithm and the environment; the data generated during this interaction is passed to the Algorithm to update the Model. Data preprocessing is usually defined here as well.

def sample(obs, MODEL):
    # Sample an action from the policy's output distribution (exploration during training)
    global ACTION_DIM
    obs = np.expand_dims(obs, axis=0)  # add a batch dimension
    obs = paddle.to_tensor(obs, dtype='float32')
    act = MODEL(obs)  # action probabilities, shape (1, act_dim)
    act_prob = np.squeeze(act, axis=0)
    act = np.random.choice(range(ACTION_DIM), p=act_prob.numpy())
    return act


def learn(obs, action, reward, MODEL):
    # One REINFORCE update over a whole episode of (obs, action, reward-to-go) data
    obs = np.array(obs).astype('float32')
    obs = paddle.to_tensor(obs)
    act_prob = MODEL(obs)
    action = paddle.to_tensor(action.astype('int32'))
    # Negative log-probability of the action actually taken at each step
    log_prob = paddle.sum(-1.0 * paddle.log(act_prob) * F.one_hot(action, act_prob.shape[1]), axis=1)
    reward = paddle.to_tensor(reward.astype('float32'))
    # Loss = sum_t( -log pi(a_t|s_t) * G_t ): return-weighted negative log-likelihood
    cost = log_prob * reward
    cost = paddle.sum(cost)

    opt = paddle.optimizer.Adam(learning_rate=LEARNING_RATE,
                                parameters=MODEL.parameters())  # optimizer (dynamic graph)
    cost.backward()
    opt.step()
    opt.clear_grad()
    return cost.numpy()
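
Note that learn() creates a new Adam optimizer on every call, so Adam's moment estimates are reset after each episode. A common alternative (a hypothetical sketch, not part of the original project) builds the optimizer once and reuses it across updates:

# Hypothetical variant: build the optimizer once so Adam's moving averages
# persist across episodes (names below are illustrative, not from the article)
def make_learn_fn(model, learning_rate):
    opt = paddle.optimizer.Adam(learning_rate=learning_rate,
                                parameters=model.parameters())

    def learn_fn(obs, action, reward):
        obs = paddle.to_tensor(np.array(obs).astype('float32'))
        act_prob = model(obs)
        action = paddle.to_tensor(np.array(action).astype('int32'))
        log_prob = paddle.sum(-1.0 * paddle.log(act_prob) * F.one_hot(action, act_prob.shape[1]), axis=1)
        reward = paddle.to_tensor(np.array(reward).astype('float32'))
        cost = paddle.sum(log_prob * reward)
        cost.backward()
        opt.step()       # Adam state is kept between episodes
        opt.clear_grad()
        return cost.numpy()

    return learn_fn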

4. Model gradient update algorithm

def run_train(env, MODEL):
    MODEL.train()
    obs_list, action_list, total_reward = [], [], []
    obs = env.reset()

    while True:
        # Sample an action from the current policy and step the game
        obs_list.append(obs)
        action = sample(obs, MODEL)  # sample an action from the policy
        action_list.append(action)
        
        obs, reward, isOver, info = env.step(action)
        total_reward.append(reward)
        
        # end of the episode
        if isOver:
            break
    return obs_list, action_list, total_reward


def evaluate(model, env, render=False):
    model.eval()
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            obs = np.expand_dims(obs, axis=0)
            obs = paddle.to_tensor(obs, dtype='float32')
            action = model(obs)
            action = np.argmax(action.numpy())
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)

5. Training and validation functions

Set the hyperparameters:

LEARNING_RATE = 0.001  # learning rate

OBS_DIM = None
ACTION_DIM = None

# Given the per-step reward list of one episode, compute the return G_t of every step
def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # G_t = r_t + γ·r_{t+1} + ... = r_t + γ·G_{t+1}
        reward_list[i] += gamma * reward_list[i + 1]  # G_t
    return np.array(reward_list)
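
For example, with gamma=0.9 and three unit rewards, the function returns the per-step returns G_t:

print(calc_reward_to_go([1.0, 1.0, 1.0], gamma=0.9))
# [2.71 1.9  1.  ]   since G_2 = 1, G_1 = 1 + 0.9*1 = 1.9, G_0 = 1 + 0.9*1.9 = 2.71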
    
def main():
    global OBS_DIM
    global ACTION_DIM

    train_step_list = []
    train_reward_list = []
    evaluate_step_list = []
    evaluate_reward_list = []

    # Initialize the game environment
    env = gym.make('CartPole-v0')
    # Observation dimension and action dimension
    action_dim = env.action_space.n
    obs_dim = env.observation_space.shape[0]
    OBS_DIM = obs_dim
    ACTION_DIM = action_dim
    max_score = -int(1e4)

    # Create the policy models (TARGET_MODEL is defined but not used in this simple example)
    MODEL = PolicyGradient(ACTION_DIM)
    TARGET_MODEL = PolicyGradient(ACTION_DIM)

    # Start training
    print("start training...")
    # Train for 1000 episodes; the evaluation runs are not counted toward this number
    for i in range(1000):
        obs_list, action_list, reward_list = run_train(env, MODEL)
        if i % 10 == 0:
            print("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)
        cost = learn(batch_obs, batch_action, batch_reward, MODEL)

        if (i + 1) % 100 == 0:
            total_reward = evaluate(MODEL, env, render=False)  # render=True shows the rendering; requires a local run, AI Studio cannot display it
            print("Test reward: {}".format(total_reward))


if __name__ == '__main__':
    main()
W0630 11:26:18.969960   322 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0630 11:26:18.974581   322 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.


start training...
Episode 0, Reward Sum 37.0.
Episode 10, Reward Sum 27.0.
Episode 20, Reward Sum 32.0.
Episode 30, Reward Sum 20.0.
Episode 40, Reward Sum 18.0.
Episode 50, Reward Sum 38.0.
Episode 60, Reward Sum 52.0.
Episode 70, Reward Sum 19.0.
Episode 80, Reward Sum 27.0.
Episode 90, Reward Sum 13.0.
Test reward: 42.8
Episode 100, Reward Sum 28.0.
Episode 110, Reward Sum 44.0.
Episode 120, Reward Sum 30.0.
Episode 130, Reward Sum 28.0.
Episode 140, Reward Sum 27.0.
Episode 150, Reward Sum 47.0.
Episode 160, Reward Sum 55.0.
Episode 170, Reward Sum 26.0.
Episode 180, Reward Sum 47.0.
Episode 190, Reward Sum 17.0.
Test reward: 42.8
Episode 200, Reward Sum 23.0.
Episode 210, Reward Sum 19.0.
Episode 220, Reward Sum 15.0.
Episode 230, Reward Sum 59.0.
Episode 240, Reward Sum 59.0.
Episode 250, Reward Sum 32.0.
Episode 260, Reward Sum 58.0.
Episode 270, Reward Sum 18.0.
Episode 280, Reward Sum 24.0.
Episode 290, Reward Sum 64.0.
Test reward: 116.8
Episode 300, Reward Sum 54.0.
Episode 310, Reward Sum 28.0.
Episode 320, Reward Sum 44.0.
Episode 330, Reward Sum 18.0.
Episode 340, Reward Sum 89.0.
Episode 350, Reward Sum 26.0.
Episode 360, Reward Sum 57.0.
Episode 370, Reward Sum 54.0.
Episode 380, Reward Sum 105.0.
Episode 390, Reward Sum 56.0.
Test reward: 94.0
Episode 400, Reward Sum 70.0.
Episode 410, Reward Sum 35.0.
Episode 420, Reward Sum 45.0.
Episode 430, Reward Sum 117.0.
Episode 440, Reward Sum 50.0.
Episode 450, Reward Sum 35.0.
Episode 460, Reward Sum 41.0.
Episode 470, Reward Sum 43.0.
Episode 480, Reward Sum 75.0.
Episode 490, Reward Sum 37.0.
Test reward: 57.6
Episode 500, Reward Sum 40.0.
Episode 510, Reward Sum 85.0.
Episode 520, Reward Sum 86.0.
Episode 530, Reward Sum 30.0.
Episode 540, Reward Sum 68.0.
Episode 550, Reward Sum 25.0.
Episode 560, Reward Sum 82.0.
Episode 570, Reward Sum 54.0.
Episode 580, Reward Sum 53.0.
Episode 590, Reward Sum 58.0.
Test reward: 147.2
Episode 600, Reward Sum 24.0.
Episode 610, Reward Sum 78.0.
Episode 620, Reward Sum 62.0.
Episode 630, Reward Sum 58.0.
Episode 640, Reward Sum 50.0.
Episode 650, Reward Sum 67.0.
Episode 660, Reward Sum 68.0.
Episode 670, Reward Sum 51.0.
Episode 680, Reward Sum 36.0.
Episode 690, Reward Sum 69.0.
Test reward: 84.2
Episode 700, Reward Sum 34.0.
Episode 710, Reward Sum 59.0.
Episode 720, Reward Sum 56.0.
Episode 730, Reward Sum 72.0.
Episode 740, Reward Sum 28.0.
Episode 750, Reward Sum 35.0.
Episode 760, Reward Sum 54.0.
Episode 770, Reward Sum 61.0.
Episode 780, Reward Sum 32.0.
Episode 790, Reward Sum 147.0.
Test reward: 123.0
Episode 800, Reward Sum 129.0.
Episode 810, Reward Sum 65.0.
Episode 820, Reward Sum 73.0.
Episode 830, Reward Sum 54.0.
Episode 840, Reward Sum 60.0.
Episode 850, Reward Sum 71.0.
Episode 860, Reward Sum 54.0.
Episode 870, Reward Sum 74.0.
Episode 880, Reward Sum 34.0.
Episode 890, Reward Sum 55.0.
Test reward: 104.8
Episode 900, Reward Sum 41.0.
Episode 910, Reward Sum 111.0.
Episode 920, Reward Sum 33.0.
Episode 930, Reward Sum 49.0.
Episode 940, Reward Sum 62.0.
Episode 950, Reward Sum 114.0.
Episode 960, Reward Sum 52.0.
Episode 970, Reward Sum 64.0.
Episode 980, Reward Sum 94.0.
Episode 990, Reward Sum 90.0.
Test reward: 72.2

Project link (fork it and it is ready to run):

https://www.heywhale.com/mw/project/649e7dc170567260f8f12d54
