reinforce 跑 CartPole-v1

news2024/11/26 8:45:30

gym版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。

然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。

代码

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了

class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, X):
        X = F.relu(self.fc1(X))
        return F.softmax(self.fc2(X),dim=1)

class REINFORCE:
    def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)
        self.gamma = gamma # 折扣因子
        self.device = device
    
    def take_action(self, state): # 根据动作概率分布随机采样
        state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)
        probs = self.policy_net(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    
    def update(self, transition_dict):  # 公式用的是简化推导
        reward_list = transition_dict['rewards']
        state_list = transition_dict['states']
        action_list = transition_dict['actions']
        
        G = 0
        self.optimizer.zero_grad()
        for i in reversed(range(len(reward_list))):  # 从最后一步算起
            reward = reward_list[i]
            state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)
            action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)
            log_prob = torch.log(self.policy_net(state).gather(1, action))
            G = self.gamma * G + reward 
            loss = -log_prob * G  # 因为梯度更新是减的,所以取个负号
            loss.backward()
        self.optimizer.step()
  
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()

env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

agent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}
            state = env.reset()[0]
            done, truncated= False, False
            while not done and not truncated :  # 主要是这部分和原始的有点不同
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                transition_dict['states'].append(state)
                transition_dict['actions'].append(action)
                transition_dict['next_states'].append(next_state)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(done)
                state = next_state
                episode_return += reward
            return_list.append(episode_return)
            agent.update(transition_dict)
            if (i_episode+1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
            
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

我是在jupyter里直接跑的,结果如下所示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1301113.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二十一章网络通信

计算机网络实现了堕胎计算机间的互联,使得它们彼此之间能够进行数据交流。网络应用程序就是再已连接的不同计算机上运行的程序,这些程序借助于网络协议,相互之间可以交换数据,编写网络应用程序前,首先必须明确网络协议…

用23种设计模式打造一个cocos creator的游戏框架----(八)适配器模式

1、模式标准 模式名称:适配器模式 模式分类:结构型 模式意图:适配器模式的意图是将一个类的接口转换成客户端期望的另一个接口。适配器模式使原本接口不兼容的类可以一起工作。 结构图: 适用于: 系统需要使用现有的…

力扣题:数字与字符串间转换-12.11

力扣题-12.11 [力扣刷题攻略] Re:从零开始的力扣刷题生活 力扣题1:506. 相对名次 解题思想:进行排序即可 class Solution(object):def findRelativeRanks(self, score):""":type score: List[int]:rtype: List[str]"&…

Java学习笔记——instanceof关键字

instanceof关键字: 作用:保证对象向下转型的安全性在对象向下转型前判断某一对象实例是否属于某个类 判断时,如果对象是null,则 instanceof 判断结果为 false

搜狗输入法v模式 | 爱莉希雅皮肤

搜狗输入法v模式 | 爱莉希雅皮肤 前言爱莉希雅皮肤v模式 前言 搜狗输入法有v模式,v模式是一个转换和计算的功能组合。拥有数字转换、日期转换、算式计算、函数计算等功能。本文介绍如何使用v模式,并附赠一个爱莉希雅的皮肤,可通过百度网盘下…

订单系统的设计与海量数据处理实战

概述 订单系统可以说是整个电商系统中最重要的一个子系统,因此订单数据可以算作电商企业最重要的数据资产。订单系统从代码上来说可分为两部分:订单程序和历史订单处理程序。数据存储进行分库分表。 订单系统业务分析 对于一个合格的订单系统&#xf…

人力资源服务大赛活动方案

为加快人力资源服务业发展,进一步提升人力资源服务从业人员的专业化、职业化水平,推动云南省人力资源服务产业创新发展,促进行业服务规范,为云南省经济社会发展提供人力服务保障。云南省人力资源和社会保障厅计划组织开展“云南省…

vite脚手架,配置动态生成路由,添加不同的layout以及meta配置

实现效果,配置了layout和对应的路由的meta 我想每个模块添加对应的layout,下边演示一层layout及对应的路由 约束规则: 每个模块下,添加对应的 layout.vue 文件 每个文件夹下的 index.vue 是要渲染的页面路由 每个渲染的页面路由对…

关于python一些惯用写法、语法

关于python一些惯用代码处理 程序入口(python从脚本第一行开始运行,没有统一的入口)if __name__ __main__: filename.split("/")[-1][:-4][split("/")\[-1\] 和 split("/",-1)的区别](https://blog.csdn.net/jialibang/a…

分层网络模型(OSI、TCP/IP)及对应的网络协议

OSI七层网络模型 OSI(Open System Interconnect),即开放式系统互连参考模型, 一般都叫OSI参考模型,是ISO组织于1985年研究的网络互连模型。OSI是分层的体系结构,每一层是一个模块,用于完成某种功…

TailwindCSS 如何处理RTL布局模式

背景 TikTok作为目前全世界最受欢迎的APP,需要考虑兼容全世界各个地区的本地化语言和阅读习惯。其中对于阿拉伯语、波斯语等语言的阅读书写习惯是从右向左的,在前端有一个专有名字RTL模式,即Right-to-Left。 其中以阿拉伯语作为第一语言的人…

Oracle(2-14)User-Managed Incomplete Recovery

文章目录 一、基础知识1、Incomplete Recovery Overview 不完全恢复概述2、Situations Requiring IR 需要不完全恢复的情况3、Types of IR 不完全恢复的类型4、IR Guidelines 不完全恢复指南5、User-Managed Procedures 用户管理程序6、RECOVER Command Overview 恢复命令概述7…

深度学习:自注意力机制(Self-Attention)

1 自注意力概述 1.1 定义 自注意力机制(Self-Attention),有时也称为内部注意力机制,是一种在深度学习模型中应用的机制,尤其在处理序列数据时显得非常有效。它允许输入序列的每个元素都与序列中的其他元素进行比较&a…

应用开发平台文件处理设计与实现系列之1——文件处理需求、方案、整体设计

需求 对于应用系统而言,数据主要分为两大类,结构化数据和非结构化数据。 结构化数据通常是指可以明确定义其数据结构及属性的对象,如组织机构、用户、合同、订单等,通常都会使用关系型数据库来存储,通过SQL来读写。 非…

区间合并|LeetCode100136:统计好分割方案的数目

作者推荐 【动态规划】【广度优先】LeetCode2258:逃离火灾 本文涉及的基础知识点 区间合并 题目 给你一个下标从 0 开始、由 正整数 组成的数组 nums。 将数组分割成一个或多个 连续 子数组,如果不存在包含了相同数字的两个子数组,则认为是一种 好分…

layui分页laypage结合Flask+Jinja2实现流程

Layui2.0普通用法<!DOCTYPE html> <html> <head><meta charset"utf-8"><meta name"viewport" content"widthdevice-width, initial-scale1"><title>Demo</title><!-- 请勿在项目正式环境中引用该 …

使用eXtplorer本地搭建文件管理器并内网穿透远程访问本地数据

文章目录 1. 前言2. eXtplorer网站搭建2.1 eXtplorer下载和安装2.2 eXtplorer网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1.Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1. 前言 通过互联网传输文件&#xff0c;是互联网最重要的应用之一&#xff0c;无论是…

ThingWorx 9.2 Windows安装

参考官方文档安装配置 1 PostgreSQL 13.X 2 Java, Apache Tomcat, and ThingWorx PTC Help Center 参考这里安装 数据库 C:\ThingworxPostgresqlStorage 设置为任何人可以full control 数据库初始化 pgadmin4 创建用户twadmin并记录口令password Admin Userpostgres Thin…

I.MX6ULL_Linux_驱动篇(46)linux LCD驱动

LCD 是很常用的一个外设&#xff0c;在Linux 下LCD 的使用更加广泛&#xff0c;在搭配 QT 这样的 GUI 库下可以制作出非常精美的 UI 界面。本章我们就来学习一下如何在 Linux 下驱动 LCD 屏幕。 Linux 下 LCD 驱动简析 Framebuffer 设备 先来回顾一下裸机的时候 LCD 驱动是怎…

Leetcode每日一题

https://leetcode.cn/problems/binary-tree-preorder-traversal/ 这道题目需要我们自行进行创建一个数组&#xff0c;题目也给出我们需要自己malloc一个数组来存放&#xff0c;这样能达到我们遍历的效果&#xff0c;我们来看看他的接口函数给的是什么。 可以看到的是这个接口函…