强化学习 PPO算法和代码

news2025/1/13 8:09:54

PPO 效果

在这里插入图片描述

字体找不到

ubuntu python findfont: Font family ‘Alibaba PuHuiTi 3.0’ not found.

shell 清除缓存:

rm ~/.cache/matplotlib -rf

到这里下载 阿里巴巴普惠体3.0 https://fonts.alibabagroup.com/
然后安装字体
在这里插入图片描述

PPO

import  matplotlib
from 	matplotlib import pyplot as plt
matplotlib.rcParams['font.size'] = 18
matplotlib.rcParams['figure.titlesize'] = 18
matplotlib.rcParams['figure.figsize'] = [9, 7]
matplotlib.rcParams['font.family'] = ['Alibaba PuHuiTi 3.0']
matplotlib.rcParams['axes.unicode_minus']=False

plt.figure()

import  gym,os
import  numpy as np
import  tensorflow as tf
from    tensorflow import keras
from    tensorflow.keras import layers,optimizers,losses
from    collections import namedtuple
# from    torch.utils.data import SubsetRandomSampler,BatchSampler

env = gym.make('CartPole-v0')  # 创建游戏环境
env.seed(2222)
tf.random.set_seed(2222)
np.random.seed(2222)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')



gamma = 0.98 # 激励衰减因子
epsilon = 0.2 # PPO误差超参数0.8~1.2
batch_size = 32 # batch size


# 创建游戏环境
env = gym.make('CartPole-v1').unwrapped
Transition = namedtuple('Transition', ['state', 'action', 'a_log_prob', 'reward', 'next_state'])


class Actor(keras.Model):
    """ 策略网络 """
    def __init__(self):
        super(Actor, self).__init__()
        # 策略网络,也叫Actor网络,输出为概率分布pi(a|s)
        self.fc1 = layers.Dense(100, kernel_initializer='he_normal')
        self.fc2 = layers.Dense(2, kernel_initializer='he_normal')

    def call(self, inputs):
        x = tf.nn.relu(self.fc1(inputs))
        x = self.fc2(x)
        x = tf.nn.softmax(x, axis=1) # 转换成概率
        return x

class Critic(keras.Model):
    """ 奖励预测网络 """
    def __init__(self):
        super(Critic, self).__init__()
        # 偏置b的估值网络,也叫Critic网络,输出为v(s)
        self.fc1 = layers.Dense(100, kernel_initializer='he_normal')
        self.fc2 = layers.Dense(1, kernel_initializer='he_normal')

    def call(self, inputs):
        x = tf.nn.relu(self.fc1(inputs))
        x = self.fc2(x)
        return x




class PPO():
    # PPO算法主体
    def __init__(self):
        super(PPO, self).__init__()
        self.actor = Actor() # 创建Actor网络
        self.critic = Critic() # 创建Critic网络
        self.buffer = [] # 数据缓冲池
        self.actor_optimizer = optimizers.Adam(1e-3) # Actor优化器
        self.critic_optimizer = optimizers.Adam(3e-3) # Critic优化器

    def select_action(self, s):
        # 送入状态向量,获取策略: [4]
        s = tf.constant(s, dtype=tf.float32)
        # s: [4] => [1,4]
        s = tf.expand_dims(s, axis=0)
        # 获取策略分布: [1, 2]
        prob = self.actor(s)
        # 从类别分布中采样1个动作, shape: [1]
        a = tf.random.categorical(tf.math.log(prob), 1)[0]
        a = int(a)  # Tensor转数字
        return a, float(prob[0][a]) # 返回动作及其概率

    def get_value(self, s):
        # 送入状态向量,获取策略: [4]
        s = tf.constant(s, dtype=tf.float32)
        # s: [4] => [1,4]
        s = tf.expand_dims(s, axis=0)
        # 获取策略分布: [1, 2]
        v = self.critic(s)[0]
        return float(v) # 返回v(s)

    def store_transition(self, transition):
        # 存储采样数据
        self.buffer.append(transition)

    def optimize(self):
        # 优化网络主函数
        # 从缓存中取出样本数据,转换成Tensor
        state = tf.constant([t.state for t in self.buffer], dtype=tf.float32)
        action = tf.constant([t.action for t in self.buffer], dtype=tf.int32)
        action = tf.reshape(action,[-1,1])
        reward = [t.reward for t in self.buffer]
        old_action_log_prob = tf.constant([t.a_log_prob for t in self.buffer], dtype=tf.float32)
        old_action_log_prob = tf.reshape(old_action_log_prob, [-1,1])
        # 通过MC方法循环计算R(st)
        R = 0
        Rs = []
        for r in reward[::-1]: #逆序计算每一时刻t步骤的回报
            R = r + gamma * R
            Rs.insert(0, R)
        Rs = tf.constant(Rs, dtype=tf.float32)
        # 对缓冲池数据大致迭代10遍
        for _ in range(round(10*len(self.buffer)/batch_size)):
            # 随机从缓冲池采样batch_size大小样本
            index = np.random.choice(np.arange(len(self.buffer)), batch_size, replace=False)
            # 构建梯度跟踪环境
            with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
                # 取出R(st)采样真实值,[b,1] ,shape=(batch_size,1)
                v_target = tf.expand_dims(tf.gather(Rs, index, axis=0), axis=1)
                # 计算v(s)预测值,也就是偏置b,我们后面会介绍为什么写成v,shape=(batch_size,1)
                v = self.critic(tf.gather(state, index, axis=0))
                delta = v_target - v # 计算优势值,采样真实值与预测值的误差,shape=(batch_size,1)
                advantage = tf.stop_gradient(delta) # 断开梯度连接 ,shape=(batch_size,1)

                # 由于TF的gather_nd与pytorch的gather功能不一样,需要构造
                # gather_nd需要的坐标参数,indices:[b, 2]
                # pi_a = pi.gather(1, a) # pytorch只需要一行即可实现
                
                a = tf.gather(action, index, axis=0) # 取出batch的动作at , shape=(batch_size,1)
                pi = self.actor(tf.gather(state, index, axis=0)) # batch的状态st预测动作a分布pi(a|st) shape=(batch_size,2)
                indices = tf.expand_dims(tf.range(a.shape[0]), axis=1) #shape=(batch_size,1)
                indices = tf.concat([indices, a], axis=1) #shape=(batch_size,2),第1列是索引indices,第2列是a,a∈[0,1]也相当于索引
                pi_a = tf.gather_nd(pi, indices)  # pi预测中按照indices索引取出动作的概率值pi(at|st), shape=(batch_size,)
                pi_a = tf.expand_dims(pi_a, axis=1) # shape=(batch_size,1)
                # 重要性采样 = 预测概率/采样真实概率
                ratio = (pi_a / tf.gather(old_action_log_prob, index, axis=0)) # shape=(batch_size,1)
                surr1 = ratio * advantage  # shape=(batch_size,1)
                surr2 = tf.clip_by_value(ratio, 1 - epsilon, 1 + epsilon) * advantage  # shape=(batch_size,1)
                # PPO误差函数
                policy_loss = -tf.reduce_mean(tf.minimum(surr1, surr2)) # shape=()
                # 对于偏置v来说,希望与MC估计的R(st)越接近越好
                value_loss = losses.MSE(v_target, v) # shape=(batch_size,)
            # 优化策略网络
            grads = tape1.gradient(policy_loss, self.actor.trainable_variables)
            self.actor_optimizer.apply_gradients(zip(grads, self.actor.trainable_variables))
            # 优化偏置值网络
            grads = tape2.gradient(value_loss, self.critic.trainable_variables)
            self.critic_optimizer.apply_gradients(zip(grads, self.critic.trainable_variables))

        self.buffer = []  # 清空已训练数据


def main():
    agent = PPO()
    returns = [] # 统计总回报
    total = 0 # 一段时间内平均回报
    for i_epoch in range(500): # 训练回合数
        state = env.reset() # 复位环境
        for t in range(500): # 最多考虑500步
            # 通过最新策略与环境交互
            action, action_prob = agent.select_action(state)
            next_state, reward, done, _ , _= env.step(action)
            # 构建样本并存储 
            trans = Transition(state, action, action_prob, reward, next_state)
            agent.store_transition(trans)
            state = next_state # 刷新状态
            total += reward # 累积激励
            if done: # 合适的时间点训练网络
                if len(agent.buffer) >= batch_size:
                    agent.optimize() # 训练网络
                break

        if i_epoch % 20 == 0: # 每20个回合统计一次平均回报
            returns.append(total/20)
            total = 0
            print(i_epoch, returns[-1])

    print(np.array(returns))
    plt.figure()
    plt.plot(np.arange(len(returns))*20, np.array(returns))
    plt.plot(np.arange(len(returns))*20, np.array(returns), 's')
    plt.xlabel('回合数')
    plt.ylabel('总回报')
    plt.savefig('ppo-tf-cartpole.svg')


if __name__ == '__main__':
    main()
    print("end")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/877470.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

​​C++多态​​

目录 1. 多态的概念 2. 多态的定义及实现 多态的构成条件 虚函数 虚函数的重写 特例 override 和 final 1. final:修饰虚函数,表示该虚函数不能再被重写 2.override: 检查派生类虚函数是否重写了基类某个虚函数,如果没有重写编译报错…

【数据结构】二叉树篇|超清晰图解和详解:二叉树的最近公共祖先

博主简介:努力学习的22级计算机科学与技术本科生一枚🌸博主主页: 是瑶瑶子啦每日一言🌼: 你不能要求一片海洋,没有风暴,那不是海洋,是泥塘——毕淑敏 目录 一、题目二、题解三、代码 一、题目 …

Stable Diffusion +EbSynth应用实践和经验分享

Ebsynth应用 1.安装ffmpeg 2.安装pip install transparent-background,下载模型https://www.mediafire.com/file/gjvux7ys4to9b4v/latest.pth/file 放到C:\Users\自己的用户名.transparent-background\加一个ckpt_base.pth文件 3.秋叶安装ebsynth插件,重启webui 填写项目基本…

线段树-模板-区间查询-区间修改

【模板】线段树 2 传送门:https://www.luogu.com.cn/problem/P3373 题单:https://www.luogu.com.cn/training/16376#problems 题目描述 如题,已知一个数列,你需要进行下面三种操作: 将某区间每一个数乘上 x x x&a…

FPGA学习——驱动WS2812光源并进行动态显示

文章目录 一、WS2812手册分析1.1 WS2812灯源特性及概述1.2 手册重点内容分析1.2.1 产品概述1.2.2 码型及24bit数据设计 二、系统设计2.1 模块设计2.2 模块分析2.2.1 驱动模块2.2.1 数据控制模块 三、IP核设置及项目源码3.1 MIF文件设计3.2 ROM IP核调用3.3 FIFO IP核调用3.4 项…

机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?

一、引言 在实际应用中,特征选择作为机器学习和数据挖掘领域的重要环节,对于提高模型性能和减少计算开销具有关键影响。特征选择是从原始特征集中选择最相关和最具区分力的特征子集,以提高模型的泛化能力和可解释性。 特征选择在实践中具有以…

算法笔试 java 输入输出练习

在线编程题刷题训练 所有答案 scancer函数的用法 输入输出总结top!!!! java如何调用函数(方法) java刷acm的各种输入输出 vscode配置java环境 子函数的调用,直接定义一个static子函数调用就…

人工智能任务1-【NLP系列】句子嵌入的应用与多模型实现方式

大家好,我是微学AI,今天给大家介绍一下人工智能任务1-【NLP系列】句子嵌入的应用与多模型实现方式。句子嵌入是将句子映射到一个固定维度的向量表示形式,它在自然语言处理(NLP)中有着广泛的应用。通过将句子转化为向量…

ARM(汇编指令)

.global _start _start:/*mov r0,#0x5mov r1,#0x6 bl LoopLoop:cmp r0,r1beq stopsubhi r0,r0,r1subcc r1,r1,r0mov pc,lr*/ mov r0,#0x1mov r1,#0x0mov r2,#0x64bl Loop Loop:cmp r0,r2bhi stopadd r1,r1,r0add r0,r0,#0x01mov pc,lr stop:B stop.end

Android Ble蓝牙App(五)数据操作

Ble蓝牙App(五)数据操作 前言正文一、操作内容处理二、读取数据① 概念② 实操 三、写入数据① 概念② 实操 四、打开通知一、概念二、实操三、收到数据 五、源码 前言 关于低功耗蓝牙的服务、特性、属性、描述符都已经讲清楚了,而下面就是使…

vue自定义指令动态绑定

在企业微信侧边栏应用中,给dialog添加了拖拽功能,但是因为dialog高度超过了页面高度,所以高度100%时拖拽有个bug--自动贴到窗口顶部而且企业侧边栏宽高都有限制,拖拽效果并不理想,所以就想缩小dialog再进行拖拽。 拖拽…

(一)掌握最基本的Linux服务器用法——了解Linux服务器基本的使用方法、常用命令。

1、掌握最基本的Linux服务器用法 1、了解Linux服务器基本的使用方法、常用命令。 1、Linux系统简介 略 2、服务器连接方法 1、SSH远程终端,Windows可以使用xshell软件。 2、PuTTY主要用来远程连接服务器,缺点是功能单一,只是一个客户端&…

Tomcat+Http+Servlet

文章目录 1.HTTP1.1 请求和响应HTTP请求:请求行请求头请求体HTTP响应:响应行(状态行)响应头响应体 2. Apache Tomcat2.1 基本使用2.2 IDEA中创建 Maven Web项目2.3 IDEA中使用Tomcat 3. Servlet3.1 Servlet快速入门3.2 Servlet执行…

Python 3 使用Hadoop 3之MapReduce总结

MapReduce 运行原理 MapReduce简介 MapReduce是一种分布式计算模型,由Google提出,主要用于搜索领域,解决海量数据的计算问题。 MapReduce分成两个部分:Map(映射)和Reduce(归纳)。…

yolov8训练进阶:从配置文件读入配置参数

yolov8官方教程提供了2种训练方式,一种是通过命令行启动训练,一种是通过写代码启动。 命令行的方式启动方便,通过传入参数可以方便的调整训练参数,但这种方式不方便记录训练参数和调试训练代码。 自行写训练代码的方式更灵活&am…

【Vue框架】用户和请求

前言 在上一篇 【Vue框架】Vuex状态管理 针对Vuex状态管理以getters.js进行说明,没有对其中state引入的对象进行详细介绍,因为整体都比较简单,也就不对全部做详细介绍了;但其中的user.js涉及到获取用户的信息、前后端请求的token…

今天来给大家聊一聊什么是Hierarchical-CTC模型

随着人工智能领域的不断发展,语音识别技术在日常生活和工业应用中扮演着越来越重要的角色。为了提高识别准确性和效率,研究人员不断探索新的模型和算法。在这个领域中,Hierarchical-CTC模型引起了广泛的关注和兴趣。本文将介绍什么是Hierarch…

JavaFx基础学习【二】:Stage

一、介绍 窗口Stage为图中标绿部分: 实际为如下部分: 不同的操作系统表现的样式不同,以下都是以Windows操作系统为例,为了使大家更清楚Stage是那部分,直接看以下图,可能更清楚: 有点潦草&…

MachineLearningWu_15/P70-P71_AdamAndConv

x.1 算法参数更新 我们使用梯度下降算法来自动更新参数,但是由于学习率的不好选择性,我们有时候会下降地很快,有时候下降地很慢,我们期望有一种方式能够自动调整学习率的变化,这里引入Adaptive Moment Estimation/Ada…

City Walk带动茶饮品牌售1200万,媒介盒子带你探究奥秘

年轻人生活趋势又出现了一个新鲜词——City Walk,简单来说,City Walk就是没有目的地,没有目标,只是出行,填充自己的生活。 其实这个词源于gap year,而这个说法一直是国外的一些毕业生,大多会在…