强化学习_06_pytorch-PPO实践(Hopper-v4)

news2025/1/11 21:47:42

一、PPO优化

PPO的简介和实践可以看笔者之前的文章 强化学习_06_pytorch-PPO实践(Pendulum-v1)
针对之前的PPO做了主要以下优化:

  1. batch_normalize: 在mini_batch 函数中进行adv的normalize, 加速模型对adv的学习
  2. policyNet采用beta分布(0~1): 同时增加MaxMinScale 将beta分布产出值转换到action的分布空间
  3. 收集多个episode的数据,依次计算adv,后合并到一个dataloader中进行遍历:加速模型收敛

1.1 PPO2 代码

详细可见 Github: PPO2.py

class PPO2:
    """
    PPO2算法, 采用截断方式
    """
    def __init__(self,
                state_dim: int,
                actor_hidden_layers_dim: typ.List,
                critic_hidden_layers_dim: typ.List,
                action_dim: int,
                actor_lr: float,
                critic_lr: float,
                gamma: float,
                PPO_kwargs: typ.Dict,
                device: torch.device,
                reward_func: typ.Optional[typ.Callable]=None
                ):
        dist_type = PPO_kwargs.get('dist_type', 'beta')
        self.dist_type = dist_type
        self.actor = policyNet(state_dim, actor_hidden_layers_dim, action_dim, dist_type=dist_type).to(device)
        self.critic = valueNet(state_dim, critic_hidden_layers_dim).to(device)
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.gamma = gamma
        self.lmbda = PPO_kwargs['lmbda']
        self.k_epochs = PPO_kwargs['k_epochs'] # 一条序列的数据用来训练的轮次
        self.eps = PPO_kwargs['eps'] # PPO中截断范围的参数
        self.sgd_batch_size = PPO_kwargs.get('sgd_batch_size', 512)
        self.minibatch_size = PPO_kwargs.get('minibatch_size', 128)
        self.action_bound = PPO_kwargs.get('action_bound', 1.0)
        self.action_low = -1 * self.action_bound 
        self.action_high = self.action_bound
        if 'action_space' in PPO_kwargs:
            self.action_low = self.action_space.low
            self.action_high = self.action_space.high
        
        self.count = 0 
        self.device = device
        self.reward_func = reward_func
        self.min_batch_collate_func = partial(mini_batch, mini_batch_size=self.minibatch_size)
    
    def _action_fix(self, act):
        if self.dist_type == 'beta':
            # beta 0-1 -> low ~ high
            return act * (self.action_high - self.action_low) + self.action_low
        return act 
    
    def _action_return(self, act):
        if self.dist_type == 'beta':
            # low ~ high -> 0-1 
            act_out = (act - self.action_low) / (self.action_high - self.action_low)
            return act_out * 1 + 0
        return act 

    def policy(self, state):
        state = torch.FloatTensor(np.array([state])).to(self.device)
        action_dist = self.actor.get_dist(state, self.action_bound)
        action = action_dist.sample()
        action = self._action_fix(action)
        return action.cpu().detach().numpy()[0]
    
    def _one_deque_pp(self, samples: deque):
        state, action, reward, next_state, done = zip(*samples)
        state = torch.FloatTensor(np.stack(state)).to(self.device)
        action = torch.FloatTensor(np.stack(action)).to(self.device)
        reward = torch.tensor(np.stack(reward)).view(-1, 1).to(self.device)
        if self.reward_func is not None:
            reward = self.reward_func(reward)

        next_state = torch.FloatTensor(np.stack(next_state)).to(self.device)
        done = torch.FloatTensor(np.stack(done)).view(-1, 1).to(self.device)
        
        old_v = self.critic(state)
        td_target = reward + self.gamma * self.critic(next_state) * (1 - done)
        td_delta = td_target - old_v
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta, done).to(self.device)
        # recompute
        td_target = advantage + old_v
        action_dists = self.actor.get_dist(state, self.action_bound)
        old_log_probs = action_dists.log_prob(self._action_return(action))
        return state, action, old_log_probs, advantage, td_target
        
    def data_prepare(self, samples_list: List[deque]):
        state_pt_list = []
        action_pt_list = []
        old_log_probs_pt_list = []
        advantage_pt_list = []
        td_target_pt_list = []
        for sample in samples_list:
            state_i, action_i, old_log_probs_i, advantage_i, td_target_i = self._one_deque_pp(sample)
            state_pt_list.append(state_i)
            action_pt_list.append(action_i)
            old_log_probs_pt_list.append(old_log_probs_i)
            advantage_pt_list.append(advantage_i)
            td_target_pt_list.append(td_target_i)
            
        state = torch.concat(state_pt_list) 
        action = torch.concat(action_pt_list) 
        old_log_probs = torch.concat(old_log_probs_pt_list) 
        advantage = torch.concat(advantage_pt_list) 
        td_target = torch.concat(td_target_pt_list)
        return state, action, old_log_probs, advantage, td_target
        
    def update(self, samples_list: List[deque]):
        state, action, old_log_probs, advantage, td_target = self.data_prepare(samples_list)
        if len(old_log_probs.shape) == 2:
            old_log_probs = old_log_probs.sum(dim=1)
        d_set = memDataset(state, action, old_log_probs, advantage, td_target)
        train_loader = DataLoader(
            d_set,
            batch_size=self.sgd_batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=self.min_batch_collate_func
        )
        
        for _ in range(self.k_epochs):
            for state_, action_, old_log_prob, adv, td_v in train_loader:
                action_dists = self.actor.get_dist(state_, self.action_bound)
                log_prob = action_dists.log_prob(self._action_return(action_))
                if len(log_prob.shape) == 2:
                    log_prob = log_prob.sum(dim=1)
                # e(log(a/b))
                ratio = torch.exp(log_prob - old_log_prob.detach())
                surr1 = ratio * adv
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * adv

                actor_loss = torch.mean(-torch.min(surr1, surr2)).float()
                critic_loss = torch.mean(
                    F.mse_loss(self.critic(state_).float(), td_v.detach().float())
                ).float()

                self.actor_opt.zero_grad()
                self.critic_opt.zero_grad()
                actor_loss.backward()
                critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5) 
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5) 
                self.actor_opt.step()
                self.critic_opt.step()

        return True

    def save_model(self, file_path):
        if not os.path.exists(file_path):
            os.makedirs(file_path)

        act_f = os.path.join(file_path, 'PPO_actor.ckpt')
        critic_f = os.path.join(file_path, 'PPO_critic.ckpt')
        torch.save(self.actor.state_dict(), act_f)
        torch.save(self.critic.state_dict(), critic_f)

    def load_model(self, file_path):
        act_f = os.path.join(file_path, 'PPO_actor.ckpt')
        critic_f = os.path.join(file_path, 'PPO_critic.ckpt')
        self.actor.load_state_dict(torch.load(act_f, map_location='cpu'))
        self.critic.load_state_dict(torch.load(critic_f, map_location='cpu'))
        self.actor.to(self.device)
        self.critic.to(self.device)
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

    def train(self):
        self.training = True
        self.actor.train()
        self.critic.train()

    def eval(self):
        self.training = False
        self.actor.eval()
        self.critic.eval()

二、 Pytorch实践

2.1 智能体构建与训练

PPO2主要是收集多轮的结果序列进行训练,增加训练轮数,适当降低学习率,稍微增Actor和Critic的网络深度
详细可见 Github: test_ppo.Hopper_v4_ppo2_test

import os
from os.path import dirname
import sys
import gymnasium as gym
import torch
# 笔者的github-RL库
from RLAlgo.PPO import PPO
from RLAlgo.PPO2 import PPO2
from RLUtils import train_on_policy, random_play, play, Config, gym_env_desc

env_name = 'Hopper-v4'
gym_env_desc(env_name)
print("gym.__version__ = ", gym.__version__ )
path_ = os.path.dirname(__file__) 
env = gym.make(
    env_name, 
    exclude_current_positions_from_observation=True,
    # healthy_reward=0
)
cfg = Config(
    env, 
    # 环境参数
    save_path=os.path.join(path_, "test_models" ,'PPO_Hopper-v4_test2'), 
    seed=42,
    # 网络参数
    actor_hidden_layers_dim=[256, 256, 256],
    critic_hidden_layers_dim=[256, 256, 256],
    # agent参数
    actor_lr=1.5e-4,
    critic_lr=5.5e-4,
    gamma=0.99,
    # 训练参数
    num_episode=12500,
    off_buffer_size=512,
    off_minimal_size=510,
    max_episode_steps=500,
    PPO_kwargs={
        'lmbda': 0.9,
        'eps': 0.25,
        'k_epochs': 4, 
        'sgd_batch_size': 128,
        'minibatch_size': 12, 
        'actor_bound': 1,
        'dist_type': 'beta'
    }
)
agent = PPO2(
    state_dim=cfg.state_dim,
    actor_hidden_layers_dim=cfg.actor_hidden_layers_dim,
    critic_hidden_layers_dim=cfg.critic_hidden_layers_dim,
    action_dim=cfg.action_dim,
    actor_lr=cfg.actor_lr,
    critic_lr=cfg.critic_lr,
    gamma=cfg.gamma,
    PPO_kwargs=cfg.PPO_kwargs,
    device=cfg.device,
    reward_func=None
)
agent.train()
train_on_policy(env, agent, cfg, wandb_flag=False, train_without_seed=True, test_ep_freq=1000, 
                 online_collect_nums=cfg.off_buffer_size,
                 test_episode_count=5)

2.2 训练出的智能体观测

最后将训练的最好的网络拿出来进行观察

agent.load_model(cfg.save_path)
agent.eval()
env_ = gym.make(env_name, 
                exclude_current_positions_from_observation=True,
                render_mode='human'
                ) # , render_mode='human'
play(env_, agent, cfg, episode_count=3, play_without_seed=True, render=True)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1476991.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【中国善网ESG周报】企业ESG报告分析之“华大”

引言: ESG(环境、社会和治理)是一个越来越受到关注的话题,它涉及到企业在经营过程中如何平衡利润、人们和地球的利益。随着全球气候变化加剧、社会不平等问题日益突出,以及公司治理和道德标准的日益重要,E…

ISO_IEC_18598-2016自动化基础设施管理(AIM)系统国际标准解读(一)

██ ISO_IEC_18598-2016是什么标准? ISO/IEC 18598国际标准是由ISO(国际标准化组织)/IEC(国际电工委员会)联合技术委员会1-信息技术的第25分委员会-信息技术设备互连小组制定的关于信息基础设施自动化管理的国际标准&…

微信小程序固定头部-CSS实现

效果图 代码逻辑:设置头部的高度,浮动固定后,再加个这个高度的大小的外边距 .weui-navigation-bar {position: fixed;top: 0px;left: 0px;right: 0px;height:90px; } .weui-navigation-bar_bottom{height:90px; }

Chapter 8 - 19. Congestion Management in TCP Storage Networks

Queue Depth Monitoring and Microburst Detection Queue depth monitoring and microburst detection capture the events that may cause congestion at a lower granularity but are unnoticed by other means due to long polling intervals. 队列深度监控和微爆检测可捕捉…

人工智能到深度学习:药物发现的机器智能方法(综述学习)

Artificial intelligence to deep learning: machine intelligence approach for drug discovery - PubMed (nih.gov) 人工神经网络、深度神经网络、支持向量机、分类和回归、生成对抗网络、符号学习和元学习是应用于药物设计和发现过程的算法的例子。人工智能已应用于药物设计…

Tkinter.Text控件中,文本存在某个关键字的将被高亮显示(标记颜色+字体加粗)

在Tkinter的Text控件中,要标记某个关键字并改变其颜色,你可以使用tag_add方法来给包含关键字的文本添加标签,然后使用tag_config方法来配置该标签的显示样式,包括前景色(字体颜色)和背景色等。以下是一个完…

深度测试:指定DoC ID对ES写入性能的影响

在[[使用python批量写入ES索引数据]]中已经介绍了如何批量写入ES数据。基于该流程实际测试一下指定文档ID对ES性能的影响有多大。 一句话版 指定ID比不指定ID的性能下降了63%,且加剧趋势。 以下是测评验证的细节。 百万数据量 索引默认使用1分片和1副本。 指定…

Phoncent博客:探索AI写作与编程的无限可能

Phoncent博客,一个名为Phoncent的创新AIGC博客网站,于2023年诞生。它的创始人是庄泽峰,一个自媒体人和个人站长,他在网络营销推广领域有着丰富的经验。庄泽峰深知人工智能技术在内容创作和编程领域的潜力和创造力,因此…

SpringCloud 基本概念

开篇 学习springcloud的前提我已经认为你已经具备: 微服务的基本概念具备springboot的基本用法 eurake server:注册中心,对标zookeeper eurake client:服务,对标dubbo ribbon:负载均衡,对标nginx feign:与ribbon类似,目前项目没有使用,暂时就不写 hystrix:断路…

面试笔记系列四之SpringBoot+SpringCloud+计算机网络基础知识点整理及常见面试题

什么是 Spring Boot? Spring Boot 是 Spring 开源组织下的子项目,是 Spring 组件一站式解决方案,主要是简化了使用 Spring 的难度,简省了繁重的配置,提供了各种启动器,开发者能快速上手。 Spring Boot 有哪…

阿里云2核4G服务器租用价格85元一年,30元3个月

阿里云2核4G服务器多少钱一年?2核4G服务器1个月费用多少?2核4G服务器30元3个月、85元一年,轻量应用服务器2核4G4M带宽165元一年,本文阿里云服务器网整理的2核4G参加活动的主机是ECS经济型e实例和u1云服务器,阿里云服务…

REVERSE-COMPETITION-VNCTF-2024

REVERSE-COMPETITION-VNCTF-2024 前言TBXObaby_c2yunobfuseko 前言 ko的随机数算法没看出来,可惜~ 这里给自己打个广告:东南网安研二在读,求实习,求内推,求老板们多看看我QAQ TBXO 通过字符串定位到main函数汇编视…

K8S之使用Deployment实现滚动更新

滚动更新 滚动更新简介使用Deployment实现滚动更新相关字段介绍测试滚动更新观察滚动更新查看历史版本 自定义滚动更新策略自定义配置建议实践自定义策略通过 RollingUpdateStrategy 字段来设置滚动更新策略使用Recreate更新策略 滚动更新简介 滚动更新是一种自动化程度较高的…

C++入门06 数据的共享与保护

图源:文心一言 听课笔记简单整理,供小伙伴们参考,内容包含“🐋5.2 变量的生存期与可见性、🐋5.5 静态成员与静态函数、🐋5.6 友元函数与友元类、🐋5.7 共享数据的保护 / const关键字、&#x1…

使用Fragments(片段)提升你的Vue.js开发体验

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

JS进阶——深入对象

版权声明 本文章来源于B站上的某马课程,由本人整理,仅供学习交流使用。如涉及侵权问题,请立即与本人联系,本人将积极配合删除相关内容。感谢理解和支持,本人致力于维护原创作品的权益,共同营造一个尊重知识…

docker安装单机版canal和使用

说明:我安装的组件架构如下: 1、准备一台虚拟机,192.168.2.223,我安装的时候,docker只支持canal1.1.6版本,1.1.7无法使用docker安装.还有一点要补充,就是1.1.6好像不支持es8.0以上版本&#x…

基于springboot实现的牙科诊所系统

一、系统架构 前端:html | layui | js | css 后端:springboot | mybatis 环境:jdk1.8 | mysql | maven 二、 代码及数据库 三、功能介绍 01. web端-首页 02. web端-医生介绍 03. web端-新闻资讯 04. web端-关于我们 05. web…

mini-spring|关于Bean对象作用域以及FactoryBean的实现和使用

需求 FactoryBean 直接配置FactoryBean 获取FactoryBean中的Bean对象 FactoryBean的getObject方法通过反射获取Bean对象 由此省去对实体Dao类的定义 解决方法 对外提供一个可以二次从 FactoryBean 的 getObject 方法中获取对象的功能即可 整体架构 整个的实现过程包括了两部…

Spark Bloom Filter Join

1 综述 1.1 目的 Bloom Filter Join,或者说Row-level Runtime Filtering(还额外有一条Semi-Join分支),是Spark 3.3对运行时过滤的一个最新补充   之前运行时过滤主要有两个:动态分区裁剪DPP(开源实现&am…