人工智能学习:倒立摆强化学习控制-Policy Gradient(11)

news2024/11/23 0:28:10

相对于DQN输出采取动作的Q值,Policy Gradient网络输出采取动作的概率,根据概率来判断需要采取的动作,并在训练过程不断修正网络,使输出的概率更好的符合最优的采取动作的策略。关于Policy Gradient方法的详细原理,可以参考

https://blog.csdn.net/ygp12345/article/details/109009311

应用到倒立摆控制,可以通过构建一个前向网络和一个学习策略来实现。

1 载入模块
载入需要的模块,代码如下

import gym
import numpy as np
import math

import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from matplotlib import animation

animation模块用于生成倒立摆控制的gif动图。

2 定义前向网络
代码如下

# prediction model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 2)
        self.fc1.weight.data.normal_(0,0.1)
        self.fc2.weight.data.normal_(0,0.1)
        
    def forward(self, state):
        x = self.fc1(state)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        output = nn.functional.softmax(x)
        
        return output

这里采用两层全连接层,中间通过relu函数激活,采用softmax函数输出采取动作(0和1)的概率。

3 定义Policy Gradient策略
代码如下

# define Policy Gradient
class PolicyGradient(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = Net()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.01)
        
        self.history_log_probs = []
        self.history_rewards = []
        
        self.gamma = 0.99

    def choose_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = self.net(state)
        ctgr = torch.distributions.Categorical(probs)
        action = ctgr.sample()
        
        self.history_log_probs.append(ctgr.log_prob(action))
        
        return action.item()
    
    def choose_best_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = self.net(state)
        action = int(torch.argmax(probs))
        
        return action
    
    def get_reward(self, state):
        pos, vel, ang, avel = state
        
        pos1 = 2.0
        ang1 = math.pi/6
        
        r1 = 5-10*abs(pos/pos1)
        r2 = 5-10*abs(ang/ang1)

        r1 = max(r1, -5)
        r2 = max(r2, -5)
            
        return r1+r2
    
    def gg(self, state):
        pos, vel, ang, avel = state

        bad = abs(pos) > 2.0 or abs(ang) > math.pi/4
        
        return bad
    
    def store_transition(self, reward):
        self.history_rewards.append(reward)
        
    def learn(self):
        # backward calculate rewards
        R = 0
        
        rewards = []
        for r in self.history_rewards[::-1]:
            R = r + self.gamma*R
            rewards.insert(0,R)
        rewards = torch.tensor(rewards)
        rewards = (rewards-rewards.mean())/rewards.std()
        
        loss = 0
        for i in range(len(rewards)):
            loss += -self.history_log_probs[i]*rewards[i]
            
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        self.history_log_probs.clear()
        self.history_rewards.clear()

# define some functions
def print_red(string):
    print('\033[0;31m', end='')
    print(string, end='')
    print('\033[0m')

def save_gif(frames, filename):
    figure = plt.imshow(frames[0])
    plt.axis('off')
    
    # callback function
    def animate(i):
        figure.set_data(frames[i])
        
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
    anim.save(filename, writer='pillow', fps=30)

其中包含用于动作决策的前向网络,以及动作决策函数(choose_action),奖励记录函数(store_transition),学习函数(learn)等。在Policy Gradient方法中,动作的决策由网络输出的概率来实行,概率高的动作具有较高的概率被执行。每次执行过程,奖励被记录下来用于后面的学习。学习过程通过对概率的对数log(prob)和奖励(reward)的乘积和求导,进行梯度下降学习,使奖励高的动作采取概率增加,奖励低的动作采取概率减小。

4 仿真训练
仿真训练通过CartPole对象模拟来实现

# create cartpole model
env = gym.make('CartPole-v1', render_mode='human')

# reset state of env
state, _ = env.reset()

# crate Policy Gradient model
model = PolicyGradient()

# step of learning
learn_step = 0

# flag of train ok
train_ok = False
episode = 0

# play and train
while not train_ok:
    state, _ = env.reset()

    play_step = 0
    total_rewards = 0
    
    episode += 1
    print(f'\nEpisode {episode} ...')

    while True:
        env.render()
        
        action = model.choose_action(state)
    
        state, reward, done, _, info = env.step(action)
        pos, vel, a, a_vel = state # position, velocity, angle, angular velocity
    
        reward = model.get_reward(state)
        if model.gg(state):
            reward += -10

        model.store_transition(reward)
        
        total_rewards += reward
        play_step += 1
        
        if play_step%1000 == 0 or model.gg(state):
            model.learn()
            learn_step += 1
            print(f'play step {play_step} rewards {total_rewards:.2f} learn {learn_step}')
    
        if model.gg(state):
            break
            
        if play_step >= 20000:
            train_ok = True
            break

# train ok, save model
save_file = 'policy_gradient.ptl'
torch.save(model, save_file)
print_red(f'\nmodel trained ok, saved to {save_file}')

# close env
env.close()

程序在循环中,不断的根据网络的决策对倒立摆进行控制,每次倒立摆控制失败,进行下一次尝试控制和学习。一直到倒立摆控制步数能够大于一定数值(10000)训练完成,表示达到了稳定控制倒立摆的能力。然后对控制模型进行保存。其中倒立摆对象的奖励和结束采用这里采用自己定义的函数。

5 进行验证
从保存的模型中载入数据,对一个新的对象进行控制

# create game model
env = gym.make('CartPole-v1', render_mode='rgb_array')

# load trained model
model = torch.load('policy_gradient.ptl')

# frames to store game play
frames = []

state, _ = env.reset()

# play a period of time
for i in range(400):
    frames.append(env.render())
    action = model.choose_best_action(state)
    state, reward, done, _, info = env.step(action)
    
    if model.gg(state):
        break

#save frames to gif file
save_gif(frames, 'cart_pole_policy_gradient.gif')
    
env.close()

如上,这里采用choose_best_action函数来选择采取的动作,和choose_action的区别在于choose_action按照概率来选择采取的动作,概率高的动作有更高的概率被选择,概率小的动作有较小的概率被选择。choose_best_action函数则直接选择概率高的动作,表示是在所在情况下最好的选择。控制过程记录倒立摆响应的画面,并写到gif文件。

最后效果如下
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/90044.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[附源码]计算机毕业设计的汽车租赁系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis MavenVue等等组成,B/S模式…

配置虚拟主机

配置虚拟主机 虚拟主机在一台Web 服务器上,可以为多个独立的IP地址、域名或端口号提供不同的Web 站点。对于访问量不大的站点来说,这样做可以降低单个站点的运营成本。 子任务1 配置基于IP地址的虚拟主机 基于IP地址的虚拟主机的配置需要在服务器上…

详解STM32启动文件

本文对STM32启动文件startup_stm32f10x_hd.s的代码进行讲解,此文件的代码在任何一个STM32F10x工程中都可以找到。 启动文件使用的ARM汇编指令汇总 Stack——栈 Stack_Size EQU 0x00000400 AREA STACK, NOINIT, READWRITE, ALIGNStack_Mem SPACE Stack_Size__initi…

nodejs安装及环境配置

node.js下载 地址:https://nodejs.org/en/download/ 如果要下载指定的版本,可以点击下面的链接。 开始安装 双击msi,开始安装node.js。 点击【Next】按钮 勾选复选框,点击【Next】按钮 修改好目录后,点击【Nex…

大三学生HTML期末作业,网页制作作业——HTML+CSS+JavaScript饮品饮料茶(7页)

🎀 精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

vuex配置项(核心概念),vuex的触发流程

1.vuex中有哪些配置项(核心概念)作用是什么 state作用: 负责存储数据 getters作用:state计算属性(有缓存) mutaions作用:负责同步更新state数据 mutaions是唯一可以修改state数据的方式 actions作用:负责…

当AI学会创作,是否应该感到担忧?

当AI学会创作,是否应该感到担忧?0. 前言1. 人工智能与 AIGC1.1 人工智能简介1.2 人工智能与 AIGC1.3 步入 AIGC 时代2. 文本生成模型3. 代码生成模型4. 图像生成模型小结与展望0. 前言 近来,随着 Open AI 发布的新一代 AI 聊天机器人 ChatGP…

【Redis】Redis 持久化(RDB和AOF)

文章目录概述RDB触发机制如何恢复 rdb 文件?RDB 优点:RDB 缺点:AOFAOF 优点AOF 缺点概述 Redis 是内存数据库,如果不将内存中的数据库状态保存到磁盘,那么服务器一旦进程退出,服务器中的数据库状态也会消失…

软件供应链安全状况报告

根据 ReversingLabs 于 2022 年 12 月 5 日发布的报告《软件供应链安全状况》(文末提供报告阅读地址) ,在 2020 年至 2022 年初供应链攻击呈指数级增长之后,企业在整个2022 年出现了缓慢但稳定的增长。 ReversingLabs的研究基于上…

深证L2接口是如何运营的?

深证L2接口是现在大数据时代常用的一种程序接口,它集程序与协议于一体,为用户提供他们所需要的数据,这大大提高了用户的交易效率。 深证L2接口是如何运营的? 首先就是获取股票数据运用得最频繁的领域,股票交易市场领…

看完这篇还不懂 MySQL 的 MVCC 机制算我输

前言 MySQL中大名鼎鼎的MVCC机制想必大家都有所耳闻吧,虽然在平时MySQL使用过程中基本上用不到,但是面试中出场率十分高,而且作为架构师的你也是需要知道它的工作机制。那么你对MVCC机制了解多少呢?MVCC机制是用来干嘛的呢&#…

判断链表是否有环,如果有返回环的入口,即链表有环证明,和找到环的入口证明(非常清晰的证明过程)

有环链表 判断一个链表有环,如果有环,返回起点 使用快慢指针的方式,两个指针同时指向头节点,慢指针low一次走一步,快指针fast一次走两步,只要low和fast相遇即说明链表有环 只要快指针和慢指针有相差的步…

Maven Javafx javafx-maven-plugin打包项目,添加dll文件

在pom.xml的 javafx-maven-plugin内添加 additionalAppResources&#xff0c;并进行重新加载 <plugin><groupId>com.zenjava</groupId><artifactId>javafx-maven-plugin</artifactId><version>8.8.3</version><configuration&g…

2022,软件测试真的有说的这么容易吗?

前言 大家都说软件测试入门容易&#xff0c;似乎软件测试成了跳进互联网生态圈的最佳途径。但是不少小伙伴在入门软件测试后&#xff0c;却变的相当的迷茫&#xff0c;不知道自己应该做什么&#xff0c;似乎点点点就成了工作中唯一的事情了。趁现在负能量还没有缠身的时候&…

服务优化实践

性能分析常用方法 1. top top指令默认用来监控cpu使用情况&#xff0c;根据cpu使用情况&#xff0c;分析整个系统运作情况&#xff08;大多数系统cpu密集型&#xff09;top指令查询的进程&#xff0c;将会根据cpu使用率大小进行排序&#xff0c;使用的比较多的排在前面&#x…

重生强化【Reincarnating RL】论文梳理

重生强化【Reincarnating RL】论文梳理 文章目录重生强化【Reincarnating RL】论文梳理前言&#xff1a;文章链接&#xff1a;作者团队介绍&#xff1a;沈向洋老师的论文十问&#xff1a;联系方式&#xff1a;前言&#xff1a; 好久没写文章速读了&#xff0c;最近群友推荐了两…

html5期末大作业:自适应网站开发——公司网站7页 ,响应式页面

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

这10张图拿去,别再说学不会RecyclerView的缓存复用机制了

ViewPager2是在RecyclerView的基础上构建而成的&#xff0c;意味着其可以复用RecyclerView对象的绝大部分特性&#xff0c;比如缓存复用机制等。 作为ViewPager2系列的第一篇&#xff0c;本篇的主要目的是快速普及必要的前置知识&#xff0c;而内容的核心&#xff0c;正是前面…

SQL概述以及MySQL常用语句总结

目录数据库概述SQL与NoSQL对比关系型数据库管理系统的常用实例MySQL介绍安装数据库的连接SQLDDLDMLDQL单表查询多表查询多表关系连接查询连接分类内连接 JOIN外连接左外连接 LEFT JOIN右外连接 RIGHT JOIN自连接 JOIN联合查询 UNION子查询标量子查询列子查询行子查询表子查询DC…

大二Web课程设计——家乡主题网页设计(web前端网页制作课作业) 四川旅游网页设计制作

家乡旅游景点网页作业制作 网页代码运用了DIV盒子的使用方法&#xff0c;如盒子的嵌套、浮动、margin、border、background等属性的使用&#xff0c;外部大盒子设定居中&#xff0c;内部左中右布局&#xff0c;下方横向浮动排列&#xff0c;大学学习的前端知识点和布局方式都有…