基于强化学习算法玩CartPole游戏

news2025/1/24 2:17:50

什么事CartPole游戏

CartPole(也称为倒立摆问题)是一个经典的控制理论和强化学习的基础问题,通常用于测试和验证控制算法的性能。具体来说,它是一个简单的物理模拟问题,其目标是通过在一个平衡杆(倒立摆)上安装在小车(或称为平衡车)上的水平移动,使杆子保持竖直直立的状态。

有两个动作(action):

左移(0)

右移(1)

四个状态(state): 1. 小车在轨道上的位置 2. 杆子与竖直方向的夹角 3. 小车速度 4. 角度变化率

神经网络设计

1、强化学习的训练网络cartpole_train.py

import  gym
import pygame
import time
import random
import torch
from torch.distributions import Categorical

from torch import nn, optim
import torch.nn.functional as F

def compute_policy_loss(n, log_p):
    r = list()
    #构造奖励r列表
    for i in range(n, 0 ,-1):
        r.append(i *1.0)
    r = torch.tensor(r)
    r = (r - r.mean()) / r.std() #进行标准化处理
    loss = 0
    #计算损失函数
    for pi, ri in zip(log_p, r):
        loss += -pi * ri
    return  loss

class CartPolePolicy(nn.Module):
    def __init__(self):
        super(CartPolePolicy, self).__init__()
        self.fc1 = nn.Linear(in_features = 4, out_features = 128)
        self.fc2 = nn.Linear(128, 2) #输出为神经元个数为2表示,向左和向向右
        self.drop = nn.Dropout(p=0.6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.drop(x)
        x = F.relu(x)
        x = self.fc2(x)
        #使用softmax决策最终的行动,是向左还是右
        return F.softmax(x, dim=1)


if __name__ == '__main__':
    env = gym.make("CartPole-v1") #启动环境
    env.reset(seed= 543)
    torch.manual_seed(543)
    policy = CartPolePolicy() #定义模型
    optimizer = optim.Adam(policy.parameters(), lr = 0.01) #优化器

    #我们一共最多训练1000个回合
    #每个回合最多行动10000次
    #当某一回合的游戏步数超过5000时,就认为完成训练
    max_episod = 1000 #最大游戏回合数
    max_action = 10000 #每回合最大行动数
    max_steps = 5000 #完成训练的步数
    for episod in range(1, max_episod + 1):
        # 对于每一轮循环,都要重新启动一次游戏环境
        state, _ = env.reset()
        step = 0
        log_p = list()
        for step in range(1, max_action + 1):
            state = torch.from_numpy(state).float().unsqueeze(0)
            probs = policy(state) #计算神经网络给出的行动概率
            # 基于网络给出的概率分布,随机选择行动
            m = Categorical(probs)
            # 这里并不是直接使用概率较大的行动,而是通过概率分布生成action, 这样可以进一步探索低概率行动
            action = m.sample()
            state, _, done, _, _ = env.step(action.item())
            if done:
                break #表示跳出该for循环
            log_p.append(m.log_prob(action)) #保存每次行动对应的概率分布
        if step > max_steps: #当step大于最大步数时
            print(f"Done! last episode {episod} Run steps {step}")
            break #跳出循序,结束训练

        #每一回合游戏,都会做一次梯度下降算法
        optimizer.zero_grad()
        loss = compute_policy_loss(step, log_p)
        loss.backward()
        optimizer.step()
        if episod % 10 ==0:
            print(f"Episode {episod} Run step {step}")
    #保存模型
    torch.save(policy.state_dict(), f"cartpole_policy.pth")

2、验证:cartpole_eval.py

import  gym
import pygame
import torch.nn as nn
import torch.nn.functional as F
import time
import torch
class CartPolePolicy(nn.Module):
    def __init__(self):
        super(CartPolePolicy, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 2)
        self.drop = nn.Dropout(p=0.6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.drop(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)


if __name__ == '__main__':
    pygame.init() #初始化pygame
    #使用gym, 创建一个artPole游戏的运行环境,这个环境是提供给人类玩家使用的
    env = gym.make('CartPole-v1', render_mode = "human")
    state, _ =env.reset()
    #使用env.reset重置环境后,会得到CartPole游戏中关键参数state
    cart_position = state[0] #小车位置
    cart_speed = state[1] #小车速度
    pole_angle = state[2] #杆的角度
    pole_speed = state[3] #杆的尖端速度

    #加载网络
    policy = CartPolePolicy()
    policy.load_state_dict(torch.load("cartpole_policy.pth"))
    policy.eval()

    start_time =time.time()
    max_action =1000 #设置游戏最大执行次数
    #最多执行1000次方向键,游戏就可以通关结束
    step = 0
    fail = False
    for step in range(1, max_action + 1):
        #首先使用time.sleep,使游戏暂停0.3s,用于人的反应,觉得自己反应慢可以设置更长时间
        # time.sleep(0.3)
        #小车的控制方式,通过神经网络,来决定小车的运动方向
        #将环境参数state转为张量
        state = torch.from_numpy(state).float().unsqueeze(0)
        #输入至网络模型,计算行动概率probs
        probs = policy(state)
        #选取行动概率最大的行动
        action =torch.argmax(probs, dim = 1).item()
        state, _, done, _, _ = env.step(action) #done为True,表示杆倒了
        if done:
            fail = True
            break
        print(f"step = {step} action = {action} angle = {state[2]:.2f}  position = {state[0]:.2f}")

    end_time = time.time()
    game_time = end_time - start_time
    if fail:
        print(f"Game over ,you play {game_time:.2f} seconds, {step} steps.")
    else:
        print(f"Congratulations! you play  {game_time:.2f} seconds, {step} steps.")
    env.close()

视频讲解:

什么是reinforce强化学习算法,基于强化学习玩CartPole游戏_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1978000.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cesium初探-坐标转换

Cesium的坐标系分三种:屏幕坐标、笛卡尔空间直角坐标、地理坐标。 屏幕坐标 屏幕坐标系是一个是平面直角坐标系,即二维笛卡尔坐标系,屏幕左上角为原点(0,0),单位为像素值,屏幕水平方向为X轴&a…

Python | SyntaxError: invalid syntax 深度解析

Python | SyntaxError: invalid syntax 深度解析 在Python编程中,SyntaxError: invalid syntax是一个常见的错误,它表明Python解释器在尝试解析代码时遇到了语法问题。这个错误通常是由于代码中存在拼写错误、缺少符号(如括号、冒号或逗号&a…

Java中的Map(如果想知道Java中有关Map的知识点,那么只看这一篇就足够了!)

前言:在Java编程语言中,集合框架(Collection Framework)提供了一系列用于存储和操作数据的接口和类。其中,Map和Set是两个非常重要的接口,分别用于存储键值对和无重复元素的集合。 ✨✨✨这里是秋刀鱼不做梦…

Nerd Fonts

文章目录 关于 Nerd Fonts重要告示TL;DR字体的各种下载选项 特点 Glyph Setsshell中的图标名称 修补字体Variations 字体安装Option 1: Release Archive DownloadOption 2: Homebrew FontsOption 3: Unofficial Chocolatey or Scoop RepositoriesOption 4: Arch Extra Reposito…

AI在医学领域:医学成像中针对深度神经网络(DNN)的对抗性攻击及其防御策略

关键词:对抗性攻击、医学图像、深度神经网络、模型安全、鲁棒性 机器学习(ML)是医学领域快速发展的一个分支,它利用计算机科学和统计学的方法来解决医学问题。众所周知,攻击者可能通过故意为机器学习分类器创建输入来…

C++11 包装器

1.function包装器 1.1 概念介绍 ret func(x); 上面 func 是什么呢?那么 func 可能是函数名,函数指针,函数对象 ( 仿函数对象 ), 也可能是lamber 表达式对象,这些都是可调用的类型。 函数包装器,也称为函…

comfyui老照片修复工作流,直接复制到comfyui中即可使用

ComfyUI是一个基于web的图形用户界面,用于直观地构建和运行AI模型流程。它特别适合于使用Stable Diffusion等模型进行图像生成任务。然而,ComfyUI本身并不直接提供老照片修复的功能,但你可以通过组合不同的节点来实现这一目标。 老照片修复通常涉及到几个关键步骤: 图像去…

人像修复-插件磨皮

破锤和DR5插件磨皮 破锤插件(更快磨皮)DR5(更好保留皮肤纹理) 破锤插件(更快磨皮) 打开方式:滤镜->Imagenomic->Portraiture 磨皮阈值一般控制在10-20之间若环境与肤色接近,容…

PYTHON专题-(3)你应该知道python内置函数

abs() 函数返回数字的绝对值。dict() 函数用于创建一个字典。help() 函数用于查看函数或模块用途的详细说明。min() 方法返回给定参数的最小值,参数可以为序列。max() 方法返回给定参数的最大值,参数可以为序列。round() 方法返回浮点数 x 的四舍五入值&…

【独家原创】基于APO-Transformer多变量回归预测【24年新算法】 (多输入单输出)Matlab代码

【独家原创】基于APO-Transformer多变量回归预测【24年新算法】 (多输入单输出)Matlab代码 目录 【独家原创】基于APO-Transformer多变量回归预测【24年新算法】 (多输入单输出)Matlab代码效果一览基本介绍程序设计参考资料 效果一…

中国数字孪生进入爆发期,平台级产品决定市场高度

MIT 教授 Geoffrey Parker在《平台革命》中认为,平台正在吞噬整个世界,平台赋予开放的参与式架构,设定合理的参与规则,通过创新的产品、服务为所有参与者创造价值。 与现实世界类似,在数字孪生世界中,数字…

分享5款.NET开源免费的Redis客户端组件库

前言 今天大姚给大家分享5款.NET开源、免费的Redis客户端组件库,希望可以帮助到有需要的同学。 StackExchange.Redis StackExchange.Redis是一个基于.NET的高性能Redis客户端,提供了完整的Redis数据库功能支持,并且具有多节点支持、异步编…

JavaScript基础——Date日期对象常见的用法

Date日期对象 查看Date日期对象的数据类型 创建Date日期对象的实例 获取Date日期对象的属性 设置Date日期对象的属性 日期和时间的比较 获取时间戳 比较时间戳 Date日期对象 JavaScript中的Date类型,提供了一种处理日期和时间的方法,用于创建表示…

OD C卷 - 多线段数据压缩

多段 线 数据压缩 (200) 如图中每个方格为一个像素(i,j),线的走向只能水平、垂直、倾斜45度;图中线段表示为(2, 8)、(3,7)、(3, 6)、&#xff08…

tcp westwood 比 reno,cubic 好在哪

今天说说 tcp 韦斯特伍德,和昨天 dctcp 的路子一样,主要还是一个观点,信息带来性能收益。 reno,cubic 仅做孤立 aimd,没有将 rtt 用到极致,信息相当于浪费掉了,而 westwood 却充分利用 ack 和 …

Python数值计算(21)——非扭结点三次样条曲线

前面介绍到紧固和自然三次样条曲线,这次介绍一下非扭结点三次样条曲线。所谓的非扭结点,是指由于最开始的两个子区间使用插值多项式相同,最后两个子区间所使用的插值多项式也相同,这就会导致在这段多项式上起不到扭结点的效果&…

E26.【C语言】练习:打印整数二进制的奇数位和偶数位

获取一个整数二进制序列中所有的偶数位和奇数位,分别打印出二进制序列 要会打印奇或偶序列,先学会打印二进制序列 下面我的这篇文章的代码稍作修改即可 E24.【C语言】练习:求一个整数存储在内存中的二进制中1的个数(两种方法&a…

一键体验Detectron2框架中的所有预训练模型

Detectron2是由Facebook AI Research (FAIR)推出的基于PyTorch的模块化物体检测库,发布于2019年10月10日。该平台原是2018年推出的Detectron的第二代版本,它完全重写于maskrcnn-benchmark,并采用了PyTorch语言实现。与原版相比,De…

(五)activiti-modeler 编辑器初步优化

最终效果: 1..首先去掉顶部的logo,没什么用,还占用空间。 修改modeler.html文件,添加样式: <style type="text/css"> #main-header{display: none; } #main{padding: 0px; } </style> 2.左边组件选择区域太宽了,一般用不到那么宽。 修改editor…

Linux驱动入门实验班day03-GPIO子系统概述

3.通用框架1——最简单方式1&#xff1a;执行命令cat /sys/kernel/debug/gpio查看串口信息 gpio4对应的下列 方式2&#xff1a; 对于按键GPIO4_14:对应第四组第14个引脚 gpiochip3 ,从96开始&#xff0c; 9614110&#xff1b;