强化学习_06_pytorch-PPO实践(Pendulum-v1)

news2025/1/15 19:42:44

一、PPO简介

TRPO(Trust Range Policy Optimate)算法每一步更新都需要大量的运算,于是便有其改进版本PPO在2017年被提出。PPO 基于 TRPO 的思想,但是其算法实现更加简单。TRPO 使用泰勒展开近似、共轭梯度、线性搜索等方法直接求解。PPO 的优化目标与 TRPO 相同,但 PPO 用了一些相对简单的方法来求解。具体来说, PPO 有两种形式,一是PPO-惩罚,二是PPO-截断,我们接下来对这两种形式进行介绍。

二、PPO两种形式

2.1 PPO-Penalty

用拉格朗日乘数法直接将KL散度的限制放入目标函数,变成一个无约束的优化问题。同时还需要更新KL散度的系数。
a r g m a x θ E a − v π θ k E a − π θ k ( ⋅ ∣ s ) [ π θ ( a ∣ s ) π θ k ( a ∣ s ) A π θ k ( s , a ) − β D K L [ π θ k ( ⋅ ∣ s ) , π θ ( ⋅ ∣ s ) ] ] arg max_{\theta} E_{a- v^{\pi_{\theta_k}}}E_{a-\pi_{\theta_k}}( \cdot|s)[\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)}A^{\pi_{\theta_k}}(s, a) - \beta D_{KL}[\pi_{\theta_k}(\cdot|s), \pi_{\theta}(\cdot|s)]] argmaxθEavπθkEaπθk(s)[πθk(as)πθ(as)Aπθk(s,a)βDKL[πθk(s),πθ(s)]]
d k = D K L v π θ k [ π θ k ( ⋅ ∣ s ) , π θ ( ⋅ ∣ s ) ] d_k=D^{v^{\pi_{\theta_k}}}_{KL}[\pi_{\theta_k}(\cdot|s), \pi_{\theta}(\cdot|s)] dk=DKLvπθk[πθk(s),πθ(s)]

  1. 如果 d k < δ / 1.5 d_k < \delta /1.5 dk<δ/1.5, 那么 β k + 1 = β k / 2 \beta_{k+1} = \beta_k/2 βk+1=βk/2
  2. 如果 d k > δ ∗ 1.5 d_k > \delta *1.5 dk>δ1.5, 那么 β k + 1 = β k ∗ 2 \beta_{k+1} = \beta_k * 2 βk+1=βk2
  3. 否则 β k + 1 = β k \beta_{k+1} = \beta_k βk+1=βk

相对PPO-Clip来说计算还是比较复杂,我们来看PPO-Clip的做法

2.2 PPO-Clip

ppo-Clip直接在目标函数中进行限制,保证新的参数和旧的参数的差距不会太大。
a r g m a x θ E a − v π θ k E a − π θ k ( ⋅ ∣ s ) [ m i n ( π θ ( a ∣ s ) π θ k ( a ∣ s ) A π θ k ( s , a ) , c l i p ( π θ ( a ∣ s ) π θ k ( a ∣ s ) , 1 − ϵ , 1 + ϵ ) A π θ k ( s , a ) ) ] arg max_{\theta} E_{a- v^{\pi_{\theta_k}}}E_{a-\pi_{\theta_k}}( \cdot|s)[min(\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)}A^{\pi_{\theta_k}}(s, a), clip(\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)}, 1-\epsilon, 1+\epsilon )A^{\pi_{\theta_k}}(s, a))] argmaxθEavπθkEaπθk(s)[min(πθk(as)πθ(as)Aπθk(s,a),clip(πθk(as)πθ(as),1ϵ,1+ϵ)Aπθk(s,a))]

就是将新旧动作的差异限定在 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [1ϵ,1+ϵ]如果A > 0,说明这个动作的价值高于平均,最大化这个式子会增大 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} πθk(as)πθ(as),但是不会让超过 1 + ϵ 1+\epsilon 1+ϵ。反之,A<0,最大化这个式子会减少 π θ ( a ∣ s ) π θ k ( a ∣ s ) \frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} πθk(as)πθ(as),但是不会让超过 1 − ϵ 1-\epsilon 1ϵ
可以简单绘制如下
在这里插入图片描述

绘图脚本


def plot_pip_clip(td_delta):
    gamma = 0.8
    lmbda = 0.95
    epsilon = 0.2
    A = []
    pi_pi = []
    adv = 0
    for delta in td_delta[::-1]:
        adv = gamma * lmbda * adv + delta
        # A > 0 
        A.append(adv)
        # Pi_pi > 1
        pi_pi.append((1+delta)/1)


    A = np.array(A)
    pi_pi = np.array(pi_pi)
    clip_ = np.clip(pi_pi, 1-epsilon, 1+epsilon) 
    L = np.where(pi_pi * A < clip_ * A, pi_pi * A,  clip_ * A)
    print(clip_)
    fig, axes = plt.subplots(figsize=(4, 4))
    axes2 = axes.twinx()
    axes.plot(pi_pi, clip_, label='clip_', color='darkred', alpha=0.7)
    axes2.plot(pi_pi, L, label='L', color='steelblue', linestyle='--')
    axes.set_title(f'A > 0 (gamma={gamma}, lmbda={lmbda}, epsilon={epsilon})')
    axes.set_xlim([1, 2])
    axes.set_ylim([min(clip_)-max(clip_), max(clip_)+min(clip_)])
    if pi_pi[-1] < pi_pi[0]:
        # axes.set_xlim([1, 0])
        axes.set_xlim([0, 1])
        axes.set_ylim([max(clip_)+min(clip_), min(clip_)-max(clip_)])
        axes.set_title(f'A < 0 (gamma={gamma}, lmbda={lmbda}, epsilon={epsilon})')


    axes.legend()
    axes2.legend(loc="upper left")
    plt.show()


td_delta = np.linspace(0, 2, 100)[::-1]
plot_pip_clip(td_delta)
td_delta = -np.linspace(0, 2, 100)[::-1]
plot_pip_clip(td_delta)

三、Pytorch实践

PPO-Clip更加简洁,同时大量的实验也表名PPO-Clip总是比PPO-Penalty 效果好。所以我们就用PPO-Clip去倒立钟摆中实践。
我们这次用Pendulum-v1action 也同样用连续变量。
这里我们需要做一个转化,一个连续的action力矩(一维的连续遍历)。

  • 将连续变量的每个dim都拟合为一个正态分布
  • 训练的时候训练action每个维度的均值和方差
  • 最终进行action选择的时候基于均值和方差所再拟合正态曲线进行抽样。

所以我们的策略网络按如下方法构造。

class policyNet(nn.Module):
    """
    continuity action:
    normal distribution (mean, std) 
    """
    def __init__(self, state_dim: int, hidden_layers_dim: typ.List, action_dim: int):
        super(policyNet, self).__init__()
        self.features = nn.ModuleList()
        for idx, h in enumerate(hidden_layers_dim):
            self.features.append(nn.ModuleDict({
                'linear': nn.Linear(hidden_layers_dim[idx-1] if idx else state_dim, h),
                'linear_action': nn.ReLU(inplace=True)
            }))

        self.fc_mu = nn.Linear(hidden_layers_dim[-1], action_dim)
        self.fc_std = nn.Linear(hidden_layers_dim[-1], action_dim)

    def forward(self, x):
        for layer in self.features:
            x = layer['linear_action'](layer['linear'](x))
        
        mean_ = 2.0 * torch.tanh(self.fc_mu(x))
        # np.log(1 + np.exp(2))
        std = F.softplus(self.fc_std(x))
        return mean_, std

3.1 构建智能体(PPO-Clip)

完整脚本可以参看笔者的github: PPO_lr.py

网络如下方法构造

def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().numpy()
    adv_list = []
    adv = 0
    for delta in td_delta[::-1]:
        adv = gamma * lmbda * adv + delta
        adv_list.append(adv)
    adv_list.reverse()
    return torch.FloatTensor(adv_list)


class PPO:
    """
    PPO算法, 采用截断方式
    """
    def __init__(self,
                state_dim: int,
                hidden_layers_dim: typ.List,
                action_dim: int,
                actor_lr: float,
                critic_lr: float,
                gamma: float,
                PPO_kwargs: typ.Dict,
                device: torch.device
                ):
        self.actor = policyNet(state_dim, hidden_layers_dim, action_dim).to(device)
        self.critic = valueNet(state_dim, hidden_layers_dim).to(device)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        
        self.gamma = gamma
        self.lmbda = PPO_kwargs['lmbda']
        self.ppo_epochs = PPO_kwargs['ppo_epochs'] # 一条序列的数据用来训练的轮次
        self.eps = PPO_kwargs['eps'] # PPO中截断范围的参数
        self.count = 0 
        self.device = device
    
    def policy(self, state):
        state = torch.FloatTensor([state]).to(self.device)
        mu, std = self.actor(state)
        action_dist = torch.distributions.Normal(mu, std)
        action = action_dist.sample()
        return [action.item()]
        
    def update(self, samples: deque):
        self.count += 1
        state, action, reward, next_state, done = zip(*samples)

        state = torch.FloatTensor(state).to(self.device)
        action = torch.tensor(action).view(-1, 1).to(self.device)
        reward = torch.tensor(reward).view(-1, 1).to(self.device)
        reward = (reward + 8.0) / 8.0  # 和TRPO一样,对奖励进行修改,方便训练
        next_state = torch.FloatTensor(next_state).to(self.device)
        done = torch.FloatTensor(done).view(-1, 1).to(self.device)
        
        td_target = reward + self.gamma * self.critic(next_state) * (1 - done)
        td_delta = td_target - self.critic(state)
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
                
        mu, std = self.actor(state)
        action_dists = torch.distributions.Normal(mu.detach(), std.detach())
        # 动作是正态分布
        old_log_probs = action_dists.log_prob(action)
        for _ in range(self.ppo_epochs):
            mu, std = self.actor(state)
            action_dists = torch.distributions.Normal(mu, std)
            log_prob = action_dists.log_prob(action)
            
            # e(log(a/b))
            ratio = torch.exp(log_prob - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage

            actor_loss = torch.mean(-torch.min(surr1, surr2)).float()
            critic_loss = torch.mean(
                F.mse_loss(self.critic(state).float(), td_target.detach().float())
            ).float()
            self.actor_opt.zero_grad()
            self.critic_opt.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_opt.step()
            self.critic_opt.step()

3.2 智能体训练



class Config:
    num_episode = 1200
    state_dim = None
    hidden_layers_dim = [ 128, 128 ]
    action_dim = 20
    actor_lr = 1e-4
    critic_lr = 5e-3
    PPO_kwargs = {
        'lmbda': 0.9,
        'eps': 0.2,
        'ppo_epochs': 10
    }
    gamma = 0.9
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    buffer_size = 20480
    minimal_size = 1024
    batch_size = 128
    save_path = r'D:\TMP\ac_model.ckpt'
    # 回合停止控制
    max_episode_rewards = 260
    max_episode_steps = 260
    
    
    def __init__(self, env):
        self.state_dim = env.observation_space.shape[0]
        try:
            self.action_dim = env.action_space.n
        except Exception as e:
            self.action_dim = env.action_space.shape[0]
        print(f'device={self.device} | env={str(env)}')



def train_agent(env, cfg):
    ac_agent = PPO(
        state_dim=cfg.state_dim,
        hidden_layers_dim=cfg.hidden_layers_dim,
        action_dim=cfg.action_dim,
        actor_lr=cfg.actor_lr,
        critic_lr=cfg.critic_lr,
        gamma=cfg.gamma,
        PPO_kwargs=cfg.PPO_kwargs,
        device=cfg.device
    )           
    tq_bar = tqdm(range(cfg.num_episode))
    rewards_list = []
    now_reward = 0
    bf_reward = -np.inf
    for i in tq_bar:
        buffer_ = replayBuffer(cfg.buffer_size)
        tq_bar.set_description(f'Episode [ {i+1} / {cfg.num_episode} ]')    
        s, _ = env.reset()
        done = False
        episode_rewards = 0
        steps = 0
        while not done:
            a = ac_agent.policy(s)
            n_s, r, done, _, _ = env.step(a)
            buffer_.add(s, a, r, n_s, done)
            s = n_s
            episode_rewards += r
            steps += 1
            if (episode_rewards >= cfg.max_episode_rewards) or (steps >= cfg.max_episode_steps):
                break

        ac_agent.update(buffer_.buffer)
        rewards_list.append(episode_rewards)
        now_reward = np.mean(rewards_list[-10:])
        if bf_reward < now_reward:
            torch.save(ac_agent.actor.state_dict(), cfg.save_path)
            bf_reward = now_reward
        
        tq_bar.set_postfix({'lastMeanRewards': f'{now_reward:.2f}', 'BEST': f'{bf_reward:.2f}'})
    env.close()
    return ac_agent



print('=='*35)
print('Training Pendulum-v1')
env = gym.make('Pendulum-v1')
cfg = Config(env)
ac_agent = train_agent(env, cfg)

四、训练出的智能体观测

最后将训练的最好的网络拿出来进行观察

ac_agent.actor.load_state_dict(torch.load(cfg.save_path))
play(gym.make('Pendulum-v1', render_mode="human"), ac_agent, cfg)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/142907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

可观测性--数据源

文章目录监控数据来源端上访问应用程序业务监控基础设施可观测性核心概念日志&#xff08;Logging&#xff09;统计指标&#xff08;Metrics&#xff09;链路追踪&#xff08;Tracing&#xff09;三者之间关系监控数据来源 我们一般讲的数据观测&#xff0c;其实观测的就是从发…

【Linux】计算机软硬件体系结构

文章目录冯诺依曼体系结构操作系统(Operator System)什么是操作系统为什么要有操作系统操作系统是怎么实现管理的系统调用接口和库函数总结冯诺依曼体系结构 谈到计算机的硬件结构&#xff0c;第一个想到的必然是经典的冯诺依曼体系结构&#xff1a; 我们常见的计算机&#xf…

在购买低代码产品时,源码是必需的吗?

编者按&#xff1a;企业在采购软件或者平台时&#xff0c;到底需不需要源码&#xff1f;本文分析了源码交付的对于不同规模和情况企业的意义&#xff0c;并介绍了源码交付的低代码平台。关键词&#xff1a;源码交付&#xff0c;可视化设计&#xff0c;私有化部署&#xff0c;多…

数据上线:首届6G智能无线通信系统大赛OPPO赛道评测正式开启

12月26日&#xff0c;首届6G智能无线通信系统大赛——面向小样本条件场景自适应及在线更新需求的无线AI设计赛题已经正式上线&#xff0c;数据集也已经在1月3日正式上线啦&#xff0c;评测同步开启&#xff0c;快来打擂冲榜&#xff01; 文末还将揭晓本赛题专属活动&#xff0…

图像锐化处理之一阶微分算子

图像锐化是通过增强图像的边缘和细节来提高图像的清晰度的操作。这种操作通常用于将模糊或不清晰的图像改进为更清晰的图像。由于微分是对函数局部变化率的一种描述&#xff0c;因此图像锐化算法的实现可基于空间微分。 一阶微分算子 对任意一阶微分的定义必须满足两点&#xf…

采用热电偶温度传感器实现超高精度温度跟踪控制的解决方案

摘要&#xff1a;针对温度跟踪控制中存在热电堆信号小致使控制器温度跟踪控制精度差&#xff0c;以及热电阻形式的温度跟踪控制中需要额外配置惠斯特电桥进行转换的问题&#xff0c;本文提出相应的解决方案。解决方案的核心是采用一个多功能的超高精度PID控制器&#xff0c;具有…

ubuntu18.04安装mysql5.7.32

目录一、下载mysql安装包二、下载依赖三、安装mysql四、导入sql一、下载mysql安装包 下载地址&#xff1a;https://downloads.mysql.com/archives/community/ 下载包 mysql-server_5.7.32-1ubuntu18.04_amd64.deb-bundle.tar 下载后解压&#xff0c;里面包含要安装的deb包 二、…

计算机编码

字符的表示原理 计算机内所有信息都是使用0和1进行表示的。 对于一个短路来说&#xff0c;0代表关&#xff0c;1代表开。那把这些电路组合起来就可以有长串0和1组成的二进制数字&#xff0c;我们对这些数字进行编码和解码&#xff0c;我们就能用它来表示我们想要表示的东西了…

蓝牙模块芯片串口透传的AT指令模式和波特率是什么意思

一、什么是蓝牙串口透传模块的模式 蓝牙串口模块&#xff0c;一般都会有两个模式&#xff0c;即AT指令模式&#xff0c;以及透传模式 1、有的模块&#xff0c;会通过一个GPIO口来选择当前是什么模式&#xff0c;比如将一个IO口拉低则进入透传模式&#xff0c;也就是不再识别A…

element-ui 表格el-table高度不是一个固定值时固定表头

elementui中为表格组件提供了height属性实现固定表头 height可以为数字或者字符串&#xff0c;当为一个数字时表示固定的高度&#xff0c;也可以为百分比等字符串。 当height不是一个固定值时&#xff0c;如期望表格可以填充完页面剩余空间&#xff0c;并且固定表头时&#x…

Mysql之常见可视化管理工具

mysql在日常开发中作为基础软件&#xff0c;对其数据的管理必不可少&#xff0c;除了系统自带的命令行管理工具之外&#xff0c;还有许多其他的图形化管理工具&#xff0c;下面介绍常见的mysql图形化管理工具。 1、Navicat Navicat 是一个桌面版 MySQL 数据库管理和开发工具。…

【Linux操作系统】程序的编译和动静态链接

文章目录一.编译写在前面1.预处理2.编译3.汇编二.(动静态)链接1.动态链接2.静态链接3.静态链接库的下载安装4.windows下动静态库的后缀一.编译 写在前面 编译这整个过程都只是在编译你自己写的代码,直到链接才让你的代码和库的代码关联起来,最终形成可执行程序 源程序到可执行…

靶机测试Os-ByteSec笔记

靶机测试Os-ByteSec笔记 靶机描述 Back to the Top Difficulty : Intermediate Flag : 2 Flag first user And second root Learning : exploit | SMB | Enumration | Stenography | Privilege Escalation Contact … https://www.linkedin.com/in/rahulgehlaut/ This w…

EXSi root密码忘记通过centos7镜像破解

1.安装软碟通UltraISO刻录U盘启动盘 下载阿里云centos7镜像&#xff0c;选择mini的链接&#xff1a;https://mirrors.aliyun.com/centos/7.9.2009/isos/x86_64/CentOS-7-x86_64-Minimal-2207-02.iso?spma2c6h.25603864.0.0.28f76aeapbXyYT 打开软碟通&#xff0c;打开下载的…

C语言宏定义立即数后缀U的含义

背景 在看开源的相关代码中&#xff0c;会有下面的宏定义用法 #define TEST_VALUE (0xFFFFFFFFU) 其和下面的宏定义区别是什么呢&#xff1f; #define TEST_VALUE (0xFFFFFFFF) 答疑 U表示 unsigned 无符号后缀&#xff0c;关于后缀的表述C99标准有如下定义&…

私有部署V3.8:自建内部应用库和预置应用

2022年12月27日&#xff0c;明道云私有部署V3.8正式发布。除了同步更新明道云SaaS版V7.8的功能以外&#xff0c;V3.8还将应用库功能下放至私有部署&#xff0c;私有部署用户可以自建企业内部应用库&#xff0c;并且给新创建的组织预置应用了。 注意&#xff1a;该功能仅面向明…

CSS知识点精学4-学成项目案例实现

根目录 先写好项目根目录 网站的首页,所有网站的首页都叫index.html,因为服务器找首页都是找index.html 一般网站页面配套的css文件与网站设置为相同的名字 比如index.html搭配index.css 准备工作 首先&#xff0c;我们发现没一个模块都是居中显示的&#xff0c;抓住一个边…

Python压缩模块gzip

文章目录初步认识压缩和解压缩函数初步认识 gzip是用于处理gzip格式的模块&#xff0c;相当于是zlib模块面向文件的一个应用&#xff0c;其最常用的函数为open。 有了open&#xff0c;那就得演示一下文件读写 import gzip with gzip.open(test.txt.gz, wb) as f:f.write(&qu…

Halcon笔记1

一、前言 最近来触碰一下halcon&#xff0c;一直以来作为ai算法工程师&#xff0c;虽然知道halcon&#xff0c;但是一直也没有用过 对于我们用户来说&#xff0c;halcon与opencv的差距主要在下面&#xff1a; &#xff08;1&#xff09;halcon是闭源的&#xff0c;商业的软件…

ATJ2158 LRADC的使用

LRADCLRADC对应引脚LRADC采样电压范围及位数使用LRADC涉及到的驱动文件如何使用不同的LRADC通道LRADC对应引脚 LRADC对应引脚备注LRACDC1WIO0/WIO1LRACDC2GPIO8/GPIO20LRACDC3GPIO9/GPIO21LRACDC4GPIO35LRACDC5GPIO5LRACDC6无没有找到相应的引脚LRACDC7GPIO63 LRADC采样电压范…