Actor-Critic 跑 CartPole-v1

news2024/10/5 17:26:58

gym-0.26.1
CartPole-v1
Actor-Critic

这里采用 时序差分残差
ψ t = r t + γ V π θ ( s t + 1 ) − V π θ ( s t ) \psi_t = r_t + \gamma V_{\pi _ \theta} (s_{t+1}) - V_{\pi _ \theta}({s_t}) ψt=rt+γVπθ(st+1)Vπθ(st)
详细请参考 动手学强化学习
简单来说就是 reforce 是采用蒙特卡洛搜索方法来估计Q(s,a) ,然后这里先是把状态价值函数V作为基线, 然后利用Q = r + gamma * V 得到上式。

代码如下

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
from tqdm import tqdm

class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, X):
        X = F.relu(self.fc1(X))
        return F.softmax(self.fc2(X), dim=1)
        
 class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        
    def forward(self, X):
        X = F.relu(self.fc1(X))
        return self.fc2(X)
        
 class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device):
        # 策略网络
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        # 价值网络
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        # 策略网络优化器
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr = actor_lr)
        #价值网络优化器
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr = critic_lr)
        self.gamma = gamma
        self.device = device
        
    def take_action(self, state):
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards']).reshape(-1,1).to(device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)
        
        # 时分差分目标
        td_target = rewards + self.gamma * self.critic(next_states) * (1- dones)
        # 时分差序目标
        td_delta = td_target - self.critic(states)
        log_probs = torch.log(self.actor(states).gather(1, actions))
        actor_loss = torch.mean(-log_probs * td_delta.detach())
        # 均方误差
        critic_loss= torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        # 计算策略网络的梯度
        actor_loss.backward()
        # 计算价值网络的梯度
        critic_loss.backward()
        # 更新策略网络梯度
        self.actor_optimizer.step()
        # 跟新价值网络梯度
        self.critic_optimizer.step()
 
def train(env, agent, num_episodes):
    return_list = []
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
                state = env.reset()[0]
                done ,truncated = False, False
                while not done and not truncated:
                    action = agent.take_action(state)
                    next_state, reward, done, truncated, info = env.step(action)
                    transition_dict['states'].append(state)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(next_state)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    state = next_state
                    episode_return += reward
                return_list.append(episode_return)
                agent.update(transition_dict)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    return return_list
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()

env_name = 'CartPole-v1'
env = gym.make(env_name)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device)

return_list = train(env, agent, num_episodes)
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Return')
plt.title(f'Actor-Critic on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Return')
plt.title(f'Actor-Critic on {env_name}')
plt.show()

jupyter运行结果如下


reforce学习更加稳定,而且总体return也要高一些。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1314587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言--clock()时间函数【详细介绍】

一.clock()时间函数介绍 在 C/C 中,clock() 函数通常用于处理和测量程序运行时间(时钟时间)。它是一种数据类型,表示 CPU 执行指定任务所耗费的“时钟计数”数量,单位为“时钟周期”。 这个函数通常包含在 time.h 头文…

后缀数组模板

详细理解后缀数组求sa数组的函数,该函数可以看为主要分为三个部分,第一个部分是预处理;第二个部分是进行基数排序,首先根据第二关键词排序,然后根据第一关键字排序;第三个部分是根据排序后的结果重新为每个…

Bytebase 2.12.0 - 改进自动补全和布局导航

🚀 新功能 支持 MySQL 高级自动补全。支持从 UI 上导入分类分级配置。 🔔 重大变更 作废已有企业版试用证书。之后可以通过提交申请获取新的试用证书。 🎄 改进 改进整体布局和导航。 支持在 SQL 编辑器里显示以及查询 PostgreSQL 数据…

DDOS 攻击是什么?有哪些常见的DDOS攻击?

DDOS简介 DDOS又称为分布式拒绝服务,全称是Distributed Denial of Service。DDOS本是利用合理的请求造成资源过载,导致服务不可用,从而造成服务器拒绝正常流量服务。就如酒店里的房间是有固定的数量的,比如一个酒店有50个房间&am…

继续看回溯问题

关卡名 继续看回溯问题 我会了✔️ 内容 1.复习递归和N叉树,理解相关代码是如何实现的 ✔️ 2.理解回溯到底怎么回事 ✔️ 3.掌握如何使用回溯来解决二叉树的路径问题 ✔️ 1 复原IP地址 这也是一个经典的分割类型的回溯问题。LeetCode93.有效IP地址正好由四…

生产环境_Spark处理轨迹中跨越本初子午线的经度列

使用spark处理数据集,解决gis轨迹点在地图上跨本初子午线的问题,这个问题很复杂,先补充一版我写的 import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.func…

t-io 程序执行后,jvm不退出的原因

基于t-io 1.7.3 版本分析源码 1、设定当前时间,每10毫秒执行一次 (非守护线程) 2、对应线程池的核心线程在AioServer启动时全部激活,并且添加空任务到阻塞队列,让核心线程(非守护线程)一直存活

ArcGIS Pro SDK文件选择对话框

文件保存对话框 // 获取默认数据库var gdbPath Project.Current.DefaultGeodatabasePath;//设置文件的保存路径SaveItemDialog saveLayerFileDialog new SaveItemDialog(){Title "Save Layer File",OverwritePrompt true,//获取或设置当同名文件已存在时是否出现…

七. 使用ts写一个贪吃蛇小游戏

之前学习了几篇的ts基础,今天我们就使用ts来完成一个贪吃蛇的小游戏。 游戏拆解 我们将我们的任务进行简单拆解分析。 首先我们应该有一个窗口,我们叫做屏幕。让蛇在里面移动,所有我们应该想到要设计一个大盒子当作地图。考虑到食物以及蛇…

【Java代码审计】文件上传篇

【Java代码审计】文件上传篇 1.Java常见文件上传方式2.文件上传漏洞修复 1.Java常见文件上传方式 1、通过文件流的方式上传 public static void uploadFile(String targetURL, String filePath) throws IOException {File file new File(filePath);FileInputStream fileInpu…

【单调栈】【区间合并】LeetCode85:最大矩形

作者推荐 【动态规划】【广度优先搜索】LeetCode:2617 网格图中最少访问的格子数 本文涉及的知识点 单调栈 区间合并 题目 给定一个仅包含 0 和 1 、大小为 rows x cols 的二维二进制矩阵,找出只包含 1 的最大矩形,并返回其面积。 示例 1&#xff1…

遥感图像分割系统:融合空间金字塔池化(FocalModulation)改进YOLOv8

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义 遥感图像分割是遥感技术领域中的一个重要研究方向,它的目标是将遥感图像中的不同地物或地物类别进行有效的分割和识别。随着遥感技术的不断发展和遥感…

iOS_给View的部分区域截图 snapshot for view

文章目录 1.将整个view截图返回image:2.截取view的部分区域,返回image:3.旧方法:4.Tips参考: 1.将整个view截图返回image: 这些 api 已被废弃,所以需要判断 iOS 版本 写两套代码: R…

【Java】5分钟读懂Java虚拟机架构

5分钟读懂Java虚拟机架构 Java虚拟机(JVM)架构JVM是如何工作的?1. 类加载器子系统2. 运行时数据区3. 执行引擎 相关资料 本文阐述了JVM的构成和组件。每个Java开发人员都知道字节码经由JRE(Java运行时环境)执行。但他们…

php入门、安装wampserver教程

php声称是全世界最好的语言,今天这篇文章就带大家入门学习php,php和python、javasript一样,是一种弱类型的脚本语言。 一、php开发环境搭建 作为初学者,学习php建议安装wampserver,wampserver是包含了apache、php和mys…

oracle 锁表解决办法

相关表介绍 V$LOCKED_OBJECT(记录锁信息的表)v$session(记录会话信息的表)v$sql(记录 sql 执行的表)dba_objects(用来管理对象,表、库等等) 查询锁表的 SID select b.…

网络入门---可变参数原理和日志模拟实现

目录标题 前言有关函数的几个性质介绍可变参数的用法介绍可变参数的一个注意事项可变参数的底层原理va_listva_endva_startva_arg_INTSIZEOF 可变参数的注意事项日志的实现日志的测试 前言 在上一篇文章中我们介绍了TCP协议有关的函数,大致就是服务端先通过listen函…

Android多国语言翻译 国际化

语言目录详细对应关系 Arabic, Egypt (ar-rEG) —————————–阿拉伯语,埃及 Arabic, Israel (ar-rIL) ——————————-阿拉伯语,以色列 Bulgarian, Bulgaria (bg-rBG) ———————保加利亚语,保加利亚 Catalan, Spain (ca-r…

函数栈帧的创建和销毁(编程底层原理)

本篇的内容格外的难写,里面包含了许多的专业术语名和汇编指令等晦涩难懂的东西,既不利于讲解,也不利于读者的理解。但我会尽力去讲述出里面的底层逻辑,帮助大家去理解里面的过程,理解编程的底层原理可以为我们后续更为…

YOLOv8 | 代码逐行解析(一) | 项目目录构造分析

一、本文介绍 Hello,大家好这次给大家带来的不是改进,是整个YOLOv8项目的分析,整个系列大概会更新7-10篇左右的文章,从项目的目录到每一个功能代码的都会进行详细的讲解,同时YOLOv8改进系列也突破了三十篇文章&#x…