【强化学习】常用算法之一 “A3C”

news2024/10/4 10:26:06

 

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?type=blog个人简介:打工人。

持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。

如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666@foxmail.com 

        强化学习是一种机器学习的方法,旨在通过与环境进行交互学习来最大化累积奖励。强化学习研究的核心问题是“智能体(agent)在不断与环境交互的过程中如何选择行为以最大化奖励”。其中,A3C算法(Asynchronous Advantage Actor-Critic)是一种基于策略梯度的强化学习方法,通过多个智能体的异步训练来实现快速而稳定的学习效果。

本文将详细讲解强化学习常用算法之一“A3C”


目录

一、A3C算法的简介

二、A3C算法的发展历程

三、A3C算法的公式和原理讲解

        1. A3C算法的公式

        2. A3C算法的原理

四、A3C算法的功能

五、A3C算法的示例代码

        分解代码

        完整代码

六、总结


一、A3C算法的简介

        A3C(Asynchronous Advantage Actor-Critic)算法是一种在强化学习领域中应用广泛的算法,它结合了策略梯度方法和价值函数的学习,用于近似解决马尔可夫决策过程(Markov Decision Process)问题。A3C算法在近年来备受关注,因为它在处理大规模连续动作空间和高维状态空间方面具有出色的性能。 

二、A3C算法的发展历程

        A3C算法是对DQN(Deep Q Network)算法在强化学习领域的一个重要延伸和改进。DQN算法在2013年被DeepMind团队首次提出,并在很多任务上取得了令人瞩目的效果。然而,DQN算法在处理连续动作空间、高维状态空间等复杂问题上面临着困难。为了解决这些问题,研究人员开始关注基于策略梯度的方法,并提出了A3C算法。

三、A3C算法的公式和原理讲解

        1. A3C算法的公式

        A3C算法的目标是最大化累积奖励,将这一目标表示为优化问题,可以用如下的公式表示:

L(θ) = -E[logπ(a|s;θ)A(s,a)]

        其中,L(θ)表示损失函数,θ表示模型参数,π(a|s;θ)表示在状态s下选择动作a的概率,A(s,a)表示在状态s选择动作a相对于平均回报的优势函数。A3C算法的优化目标是最小化损失函数L(θ)

        2. A3C算法的原理

        A3C算法采用Actor-Critic结构,由Actor和Critic两个网络组成。Actor网络的目标是学习策略函数,即在给定状态下选择动作的概率分布。Critic网络的目标是学习状态值函数或者状态-动作值函数,用于评估不同状态或状态-动作对的价值。

        A3C算法的训练过程可以分为以下几个步骤:

  • 初始化神经网络参数。
  • 创建多个并行的训练线程,每个线程独立运行一个智能体与环境交互,并使用Actor和Critic网络实现策略和价值的近似。
  • 每个线程根据当前的策略网络选择动作,并观测到新的状态和奖励,将这些信息存储在经验回放缓冲区中。
  • 当一个线程达到一定的时间步数或者轨迹结束时,该线程将经验回放缓冲区中的数据抽样出来,并通过计算优势函数进行梯度更新。
  • 每个线程进行一定次数的梯度更新后,将更新的参数传递给主线程进行整体参数更新。
  • 重复上述步骤直到达到预定的训练轮次或者达到终止条件为止。

        A3C算法采用了Asynchronous(异步)的训练方式,每个线程独立地与环境交互,并通过参数共享来实现梯度更新。这种异步训练的方式可以提高训练的效率和稳定性,并且能够学习到更好的策略和价值函数。

四、A3C算法的功能

        A3C算法具有以下功能和特点:

  • 支持连续动作空间和高维状态空间的强化学习;
  • 通过多个并行的智能体实现快速而稳定的训练;
  • 利用Actor和Critic两个网络分别学习策略和价值函数,具有更好的学习效果和收敛性;
  • 通过异步训练的方式提高了训练的效率和稳定性。

五、A3C算法的示例代码

        下面是一个简单的A3C算法的示例代码

        分解代码

        首先,导入需要的库和模块:

import gym
import torch
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

        然后定义A3C算法中的Actor和Critic网络: 

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.actor = nn.Linear(128, 2)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        state_value = self.critic(x)
        return action_probs, state_value

        定义ActorCritic类来实现Actor和Critic网络的结构。

        然后定义训练函数,该函数将使用A3C算法进行智能体的训练:

def train(global_model, rank):
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    model.load_state_dict(global_model.state_dict())
    optimizer = optim.Adam(global_model.parameters(), lr=0.01)
    torch.manual_seed(123 + rank)
    max_episode_length = 1000

    for episode in range(2000):
        state = env.reset()
        done = False
        episode_length = 0
        while not done and episode_length < max_episode_length:
            episode_length += 1
            state = torch.from_numpy(state).float()
            action_probs, _ = model(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            next_state, reward, done, _ = env.step(action.item())

            if done:
                reward = -1

            next_state = torch.from_numpy(next_state).float()
            next_state_value = model(next_state)[-1]
            model_value = model(state)[-1]

            delta = reward + (0.99 * next_state_value * (1 - int(done))) - model_value

            actor_loss = -dist.log_prob(action) * delta.detach()
            critic_loss = delta.pow(2)

            loss = actor_loss + critic_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            state = next_state.numpy()

        global_model.load_state_dict(model.state_dict())

        在训练函数中,每个进程都会拷贝全局模型,并在自己的进程中进行模型的训练。训练过程中,智能体使用Actor模型根据当前状态选择动作,然后与环境进行交互,得到下一状态和奖励。根据奖励和下一状态的估值更新网络参数,得到一个损失函数。然后使用反向传播算法更新网络参数。

        最后,运行训练函数:

if __name__ == '__main__':
    num_processes = 4
    global_model = ActorCritic()
    global_model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(global_model, rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

        完整代码

# -*- coding: utf-8 -*-
import gym
import torch
import torch.optim as optim
from torch.distributions import Categorical
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda")

class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.actor = nn.Linear(128, 2)
        self.critic = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        state_value = self.critic(x)
        return action_probs, state_value

def train(global_model, rank):
    env = gym.make('CartPole-v1')
    model = ActorCritic().to(device)
    model.load_state_dict(global_model.state_dict())
    optimizer = optim.Adam(global_model.parameters(), lr=0.01)
    torch.manual_seed(123 + rank)
    max_episode_length = 1000

    for episode in range(2000):
        state = env.reset()
        done = False
        episode_length = 0
        while not done and episode_length < max_episode_length:
            episode_length += 1
            state = torch.from_numpy(state).float().to(device)
            action_probs, _ = model(state)
            dist = Categorical(action_probs)
            action = dist.sample().to(device)
            next_state, reward, done, _ = env.step(action.item())

            if done:
                reward = -1

            next_state = torch.from_numpy(next_state).float().to(device)
            next_state_value = model(next_state)[-1].cpu()
            model_value = model(state)[-1].cpu()

            delta = reward + (0.99 * next_state_value * (1 - int(done))) - model_value

            actor_loss = -dist.log_prob(action) * delta.detach().to(device)
            critic_loss = delta.pow(2).to(device)

            loss = actor_loss + critic_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            state = next_state.cpu().numpy()

        global_model.load_state_dict(model.state_dict())

if __name__ == '__main__':
    num_processes = 4
    global_model = ActorCritic()
    global_model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(global_model, rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

        因为分解代码是cpu运行...的,速度很慢,所以我在完整代码里加了cuda去运行,但是我的电脑...唉,自己看吧。。。

        上述代码使用了OpenAI Gym中的’CartPole-v1’环境作为示例,通过A3C算法训练智能体在该环境中尽可能长时间地保持杆的平衡。 

六、总结

        A3C算法是一种基于策略梯度的强化学习算法,通过多个并行的智能体异步地与环境交互,并利用Actor和Critic网络实现策略和价值的近似,从而实现快速而稳定的强化学习训练。A3C算法具有良好的学习效果和收敛性,并且适用于处理连续动作空间和高维状态空间的问题。本文讲解了A3C算法的介绍、详细讲解其发展历程、算法公式、原理、功能和示例代码,希望能让读者对A3C算法有了更加深入和全面的理解。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/699653.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端JS一维数组转树状数组并获取当前节点的所有父级名称或id

注意一维数组里面必须要有属性跟父级关联 test(){const list2 [{id: 1,pid: 0,name: 湖南},{id: 2,pid: 1,name: 长沙},{id: 3,pid: 2,name: 雨花区},];// 参数一:需要转树状数组的数组// 参数二:父id// 参数三:当前递归的父级节点name// 参数三:当前递归的父级节点idconst ar…

移远通信再推模组新品,全新5G智能模组SG530C-CN智创全景智慧生活

6月28日&#xff0c;在2023 MWC上海展会首日&#xff0c;移远通信再次宣布推出模组新品。 此次推出的全新5G智能模组SG530C-CN在连接能力、算力、多媒体性能与成本效益等层面都呈现较高水平。该模组将在智慧零售、车载后装、娱乐/直播、手持终端、工业AI等行业与应用场景上大有…

大数据需要一场硬件革命

光子盒研究院 计算领域的进步往往集中在软件上&#xff1a;华丽的应用程序和软件可以跟踪人和生态系统的健康状况、分析大数据&#xff0c;并在智力竞赛中击败人类冠军。与此同时&#xff0c;对支撑所有这些创新的硬件进行全面改革的努力相对来说&#xff0c;略显小众。 自2020…

如何实现MySQL的读写分离?

其实很简单&#xff0c;就是基于主从复制架构&#xff0c;简单来说&#xff0c;就搞一个主库&#xff0c;挂多个从库&#xff0c;然后我们就单 单只是写主库&#xff0c;然后主库会自动把数据给同步到从库上去。 MySQL主从复制原理的是啥&#xff1f; 主库将变更写入 binlog …

架构图的实现过程

项目需求架构图 实现代码 index.vue <template><!-- 外层div --><div class"topu-container" :style"{ minWidth: ${functionDomainList.length * 330}px }"><!-- 头部显示 --><div class"topu-heard"><!-- …

vue3高德地图点击标点

1.首先如果没有key的话需要在高德开发平台申请key。 2.安装 npm i amap/amap-jsapi-loader --save cnpm i amap/amap-jsapi-loader --save3.容器&#xff1a; <template><div><div class"info"><h4>获取地图级别与中心点坐标</h4>&l…

git常用命令之Cherry-pick

8. Cherry-pick 8.1 基本用法 命令作用延展阅读git cherry-pick 125a1d将提交125a1d应用于当前分支. 在当前分支会产生一个新的提交.链接git cherry-pick bugfix将分支bugfix应用于当前分支. 在当前分支会产生一个新的提交. 场景1&#xff1a;提交125a1d应用到master分支 命…

玖章算术与百度智能云达成合作,「NineData SQL 开发」成为百度智能云主推的数据库工具

2023 年 6 月 19 日&#xff0c;玖章算术&#xff08;浙江&#xff09;科技有限公司旗下的多云数据管理平台 NineData 正式入驻百度智能云市场&#xff0c;双方的深度技术融合将为客户提供智能高效、安全可靠的数据库开发服务。通过适配百度智能云数据库&#xff0c;NineData 为…

计算机网络中的安全

计算机网络中的安全 1 什么是网络安全2 加密的方式——机密性2.1 对称密钥加密2.2 公开密钥加密 3 报文鉴别码——报文完整性4 数字签名——报文完整性、端点鉴别4.1 数字签名技术的基础4.2 公钥认证 5 案例——设计安全电子邮件系统 《计算机网络—自顶向下方法》&#xff08;…

Postman中读取外部文件

目录 前言&#xff1a; 一、postman中读取外部文件的格式 二、Postman中如何导入文件 三、在Postman读取导入的数据文件 前言&#xff1a; 在Postman中&#xff0c;您可以使用"数据文件"功能来读取外部文件&#xff0c;如CSV、JSON或Excel文件。这使得在测试中使用…

如何应用Nginx Rewrit实现网页跳转

目录 一、Nginx Rewrite 二、Rewrite功能 Rewrite跳转场景 Rewrite跳转实现 Nginx 跳转 pcre支持 重写模块 Rewrite实际场景 Rewrite命令/语法格式 flag标记说明 location分类 location优先级 rewrite和location相比 三、跳转案例 实现域名跳转 第一步 修改指…

一文详解gRPC框架

目录 RPC框架简介 简介 各种序列化协议优缺点 gRPC调用模式 gRPC跟ProtocolBuffers的关系 ProtocolBuffers协议 gRPC桩代码生成 gRPC线程模型 gRPC分层 gRPC开发经验 官网及快速开始 常见状态码 适用场景 适用 不适用 手写简易RPC框架 Dubbo学习笔记 一文详解…

【python】数据表转csv

文章目录 1 基本结构1.1 数据1.2 数据结构 2 代码3 tip 1 基本结构 1.1 数据 1.2 数据结构 2 代码 代码&#xff1a; import mysql.connector import csvdef getPerson():# 数据库初始化cnx mysql.connector.connect(userroot, passwordroot, databasetest)cursor cnx.cur…

IDEA启动tomcat控制台中文乱码问题

IntelliJ IDEA是很多程序员必备且在业界被公认为最好的Java开发工具&#xff0c;有很多小伙伴在安装完IDEA并且tomcat之后&#xff0c;启动tomcat会出现控制台中文乱码问题&#xff0c;如下图所示&#xff1a; 具体解决步骤&#xff1a; 一、修改当前 Web 项目 Tomcat Server…

SAP ALV批量修改列的数据

导语&#xff1a;最近在给ALV增加批量修改列的功能&#xff0c;需求是修改多列&#xff0c;以前经常自己画屏幕来实现&#xff0c;研究了一下&#xff0c;SAP有标准的函数&#xff0c;可以自动带出选择列的字段属性&#xff0c;搜索帮助等等&#xff0c;大大提高了便捷性。 函…

本地同步远程yum源,并保存到本地

1.修改本地/etc/yum.repos.d/内容为远程yum repo配置&#xff1b; # 1&#xff09;.备份原yum配置 mkdir -p /home/yum-bak && mv /etc/yum.repos.d/* /home/yum-bak/* # 2&#xff09;.修改目标yum配置 2.执行缓存&#xff0c;查看相关repoid是否正确 yum clean all …

QT学习笔记2--对象树

对象树 可以看到QWidet这几个类的父亲是QObject&#xff0c;在析构的时候是从下往上析构。 实例 创建类 验证的话&#xff0c;要先创建一个类&#xff0c;命名为pushbotton。 点击choose创建&#xff0c;类。 编写相关函数 构造函数 pushbotton::pushbotton(QWidget *pare…

华为云专家出品《深入理解边缘计算》电子书上线

华为开发者大会PaaS生态电子书推荐&#xff0c;助你成为了不起的开发者&#xff01; 什么是边缘计算&#xff1f;边缘计算的应用场景有哪些&#xff1f; 华为云出品《深入理解边缘计算》电子书上线 带你系统理解云、边、端协同的相关原理 了解开源项目的源码分析流程 学成能…

【Python】字符串格式化前世今生

▒ 目录 ▒ &#x1f6eb; 问题描述环境 1️⃣ 《%》方式格式化语法%后面的参数说明 2️⃣ str.format优点指定位置&#xff1a;参数可以不按顺序关键字参数列表索引对象数字格式化 3️⃣ f-string 语法语法示例格式化一个表达式转义符号格式化 datetime 对象 &#x1f6ec; 结…

C#传Bitmap到C++dll出现灰色图片的问题

如果直接将内存中的Bitmap 传给C,原图会失去颜色&#xff0c;如下&#xff1a; 代码如下&#xff1a; ImageCodecInfo jpgEncoder GetEncoder(ImageFormat.Jpeg);System.Drawing.Imaging.Encoder myEncoder System.Drawing.Imaging.Encoder.Quality;EncoderParameters myEncod…