强化学习A3C算法

news2025/2/21 19:55:14

强化学习A3C算法

效果:
在这里插入图片描述

a3c.py

import  matplotlib
from    matplotlib import pyplot as plt
matplotlib.rcParams['font.size'] = 18
matplotlib.rcParams['figure.titlesize'] = 18
matplotlib.rcParams['figure.figsize'] = [9, 7]
matplotlib.rcParams['font.family'] = ['KaiTi']
matplotlib.rcParams['axes.unicode_minus']=False

plt.figure()

import os
import  threading
import  gym
import  multiprocessing
import  numpy as np
from    queue import Queue

import  tensorflow as tf
from    tensorflow import keras
from    tensorflow.keras import layers,optimizers,losses


# os.environ["CUDA_VISIBLE_DEVICES"] = "0" #使用GPU
# 按需占用GPU显存
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # 设置 GPU 显存占用为按需分配,增长式
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e :
        # 异常处理
        print(e)



SEED_NUM = 1234
tf.random.set_seed(SEED_NUM)
np.random.seed(SEED_NUM)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# 互斥锁,用于线程同步数据
g_mutex = threading.Lock()


class ActorCritic(keras.Model):
    """ Actor-Critic模型 """
    def __init__(self, state_size, action_size):
        super(ActorCritic, self).__init__()
        self.state_size = state_size # 状态向量长度
        self.action_size = action_size # 动作数量
        # 策略网络Actor
        self.dense1 = layers.Dense(128, activation='relu')
        self.policy_logits = layers.Dense(action_size)
        # V网络Critic
        self.dense2 = layers.Dense(128, activation='relu')
        self.values = layers.Dense(1)

    def call(self, inputs):
        # 获得策略分布Pi(a|s)
        x = self.dense1(inputs)
        logits = self.policy_logits(x)
        # 获得v(s)
        v = self.dense2(inputs)
        values = self.values(v)
        return logits, values


def record(episode,
           episode_reward,
           worker_idx,
           global_ep_reward,
           result_queue,
           total_loss,
           num_steps):
    """ 统计工具函数  """
    if global_ep_reward == 0:
        global_ep_reward = episode_reward
    else:
        global_ep_reward = global_ep_reward * 0.99 + episode_reward * 0.01
    print(
        f"{episode} | "
        f"Average Reward: {int(global_ep_reward)} | "
        f"Episode Reward: {int(episode_reward)} | "
        f"Loss: {int(total_loss / float(num_steps) * 1000) / 1000} | "
        f"Steps: {num_steps} | "
        f"Worker: {worker_idx}"
    )
    result_queue.put(global_ep_reward) # 保存回报,传给主线程
    return global_ep_reward

class Memory:
    """ 数据 """
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []

    def store(self, state, action, reward):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []

class Agent:
    """ 智能体,包含了中央参数网络server """
    def __init__(self):
        # 服务模型优化器,client不需要,直接从server拉取参数
        self.opt = optimizers.Adam(1e-3)
        # 服务模型(状态向量,动作数量)
        self.server = ActorCritic(4, 2) 
        self.server(tf.random.normal((2, 4)))

    def train(self):
        # 共享队列,线程安全,不需要加锁同步
        res_queue = Queue() 
        # 根据cpu线程数量创建多线程Worker
        workers = [Worker(self.server, self.opt, res_queue, i)
                   for i in range(10)] #multiprocessing.cpu_count()
        # 启动多线程Worker
        for i, worker in enumerate(workers):
            print("Starting worker {}".format(i))
            worker.start()

        # 统计并绘制总回报曲线
        returns = []
        while True:
            reward = res_queue.get()
            if reward is not None:
                returns.append(reward)
            else: # 结束标志
                break
            
        # 等待线程退出 
        [w.join() for w in workers] 

        print(returns)

        plt.figure()
        plt.plot(np.arange(len(returns)), returns)
        # plt.plot(np.arange(len(moving_average_rewards)), np.array(moving_average_rewards), 's')
        plt.xlabel('回合数')
        plt.ylabel('总回报')
        plt.savefig('a3c-tf-cartpole.svg')


class Worker(threading.Thread): 
    def __init__(self,  server, opt, result_queue, idx):
        super(Worker, self).__init__()
        self.result_queue = result_queue # 共享队列
        self.server = server # 服务模型
        self.opt = opt # 服务优化器
        self.client = ActorCritic(4, 2) # 线程私有网络
        self.worker_idx = idx # 线程id
        self.env = gym.make('CartPole-v1').unwrapped #私有环境
        self.ep_loss = 0.0

    def run(self): 
        
        # 每个worker自己维护一个memory
        mem = Memory() 
        # 1回合最大500步
        for epi_counter in range(500): 
            # 复位client游戏状态
            current_state,info = self.env.reset(seed=SEED_NUM) 
            mem.clear()
            ep_reward = 0.0
            ep_steps = 0  
            done = False
            while not done:
                # 输入AC网络状态获得Pi(a|s),未经softmax
                logits, _ = self.client(tf.constant(current_state[None, :],dtype=tf.float32))
                # 归一化概率
                probs = tf.nn.softmax(logits)
                # 随机采样动作
                action = np.random.choice(2, p=probs.numpy()[0])
                # 交互 
                new_state, reward, done, truncated, info = self.env.step(action) 
                # 累加奖励
                ep_reward += reward 
                # 记录
                mem.store(current_state, action, reward) 
                # 计算回合步数
                ep_steps += 1
                # 刷新状态 
                current_state = new_state 

                # 最长500步或者规则结束,回合结束
                if ep_steps >= 500 or done: 
                    # 计算当前client上的误差
                    with tf.GradientTape() as tape:
                        total_loss = self.compute_loss(done, new_state, mem) 
                    # 计算梯度
                    grads = tape.gradient(total_loss, self.client.trainable_weights)
                    # 梯度提交到server,在server上更新梯度
                    global g_mutex
                    g_mutex.acquire()
                    self.opt.apply_gradients(zip(grads,self.server.trainable_weights))
                    g_mutex.release()
                    # 从server拉取最新的梯度
                    g_mutex.acquire()
                    self.client.set_weights(self.server.get_weights())
                    g_mutex.release()
                    # 清空Memory 
                    mem.clear() 
                    # 统计此回合回报
                    self.result_queue.put(ep_reward)
                    print(f"thread worker_idx : {self.worker_idx}, episode reward : {ep_reward}")
                    break
        # 线程结束
        self.result_queue.put(None) 

    def compute_loss(self,
                     done,
                     new_state,
                     memory,
                     gamma=0.99):
        if done:
            reward_sum = 0. # 终止状态的v(终止)=0
        else:
            # 私有网络根据新状态计算回报
            reward_sum = self.client(tf.constant(new_state[None, :],dtype=tf.float32))[-1].numpy()[0]
        # 统计折扣回报
        discounted_rewards = []
        for reward in memory.rewards[::-1]:  # reverse buffer r
            reward_sum = reward + gamma * reward_sum
            discounted_rewards.append(reward_sum)
        discounted_rewards.reverse()
        # 输入AC网络环境状态获取 Pi(a|s) v(s) 预测值
        logits, values = self.client(tf.constant(np.vstack(memory.states), dtype=tf.float32))
        # 计算advantage = R() - v(s) = 真实值 - 预测值
        advantage = tf.constant(np.array(discounted_rewards)[:, None], dtype=tf.float32) - values
        # Critic网络损失
        value_loss = advantage ** 2
        # 归一化概率预测值Pi(a|s)
        policy = tf.nn.softmax(logits)
        # 真实动作a 概率预测值Pi(a|s) 交叉熵
        policy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=memory.actions, logits=logits)
        # 计算策略网络损失时,并不会计算V网络
        policy_loss = policy_loss * tf.stop_gradient(advantage)
        # 动作概率测值Pi(a|s) 熵
        entropy = tf.nn.softmax_cross_entropy_with_logits(labels=policy, logits=logits)
        policy_loss = policy_loss - 0.01 * entropy
        # 聚合各个误差
        total_loss = tf.reduce_mean((0.5 * value_loss + policy_loss))
        return total_loss


if __name__ == '__main__':
    master = Agent()
    master.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/888634.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

谓语动词(动词不定式(短语)、动名词、分词(短语))作定语

动词不定式(短语)作定语 动名词作定语 分词(短语)作定语 着重记忆及物动词

剑指offer-2.2字符串

字符串 C/C中每个字符串都以字符"\0"作为结尾,这样我们就能很方便地找到字符串的最后尾部。但由于这个特点,每个字符串中都有一个额外字符的开销,稍不留神就会造成字符串的越界。比如下面的代码: char str [10]; strc…

UWB现场安装通常涉及以下步骤

UWB现场安装通常涉及以下步骤: 1.确定区域需求:首先,确定需要进行UWB定位的区域和目标。这可能是一个室内环境、仓库、工厂或其他特定的工作场所。 2.设计系统布局:根据区域的特点和目标定位需求,设计系统的布局和基…

关于Firmae缺失binwalk模块

问题 david707:~/FirmAE$ sudo ./run.sh -c weyow ./WAM_9900-20.06.03V.trx [*] ./WAM_9900-20.06.03V.trx emulation start!!! Traceback (most recent call last):File "./sources/extractor/extractor.py", line 19, in <module>import binwalk ModuleNot…

红帽8.2版本CSA题库:第十题配置用户帐户

红帽8.2版本CSA题库&#xff1a;第十题配置用户帐户 useradd -u 3533 manalo #传创建用户指定uid为3533 echo flectrag | passwd --stdin manalo #设置密码 tail -1 /etc/passwd #查看

globals()与locals()函数

在Python中&#xff0c;globals()和locals()是两个内置函数&#xff0c;用于获取当前作用域内的全局和局部命名空间中的变量和对象。 一、globals() :这个函数返回一个包含当前全局作用域中所有变量和对象的字典。在函数内部调用globals()将返回全局命名空间中的变量&#xf…

密码湘军,融合创新!麒麟信安参展2023商用密码大会,铸牢数据安全坚固堡垒

2023年8月9日至11日&#xff0c;商用密码大会在郑州国际会展中心正式开幕。本次大会由国家密码管理局指导&#xff0c;中国密码学会支持&#xff0c;郑州市人民政府、河南省密码管理局主办&#xff0c;以“密码赋能美好发展”为主题&#xff0c;旨在推进商用密码创新驱动、前沿…

2023年第四届全国人工智能大赛初赛晋级复赛名单公示

由深圳市科技创新委员会、鹏城实验室共同主办&#xff0c;新一代人工智能产业技术创新战略联盟&#xff08;AITISA&#xff09;承办&#xff0c;华为技术有限公司、中国工商银行股份有限公司深圳市分行、中国农业银行股份有限公司深圳市分行、中国建设银行股份有限公司深圳市分…

无线测温产品在半导体制造项目的应用

摘 要&#xff1a;半导体被誉为“制造业的大脑”&#xff0c;在关系国家安全和国民经济命脉的主要行业和关键领域占据支配地位&#xff0c;是国民经济的重要支柱。 随着数字技术的发展和数字经济在国民经济中所占比重越来越高&#xff0c;半导体产业的重要性还会进一步提升。安…

【制作npm包4】api-extractor 学习

制作npm包目录 本文是系列文章&#xff0c; 作者一个橙子pro&#xff0c;本系列文章大纲如下。转载或者商业修改必须注明文章出处 一、申请npm账号、个人包和组织包区别 二、了解 package.json 相关配置 三、 了解 tsconfig.json 相关配置 四、 api-extractor 学习 五、npm包…

高等数学教材啃书汇总重难点(三)微分中值定理与导数的应用

本章节包含多个知识点&#xff0c;一些列微分中值定理是考研证明题的重头戏&#xff0c;而洛必达和泰勒展开则是方法论的天花板难度&#xff0c;虽然对于小题的考察难度较低&#xff0c;整体上仍需重点复习 1.费马引理 2.罗尔定理 3.拉格朗日定理 4.柯西中值定理 5.洛必达法则 …

算法竞赛入门【码蹄集新手村600题】(MT1160-1180)C语言

算法竞赛入门【码蹄集新手村600题】(MT1160-1180&#xff09;C语言 目录MT1161 N的零MT1162 数组最大公约数MT1163 孪生质数MT1164 最大数字MT1165 卡罗尔数MT1166 自守数MT1167自守数IIMT1168 阶乘数MT1169 平衡数MT1170 四叶玫瑰数MT1171 幻数MT1172 完美数字MT1173 魔数MT11…

ODB++资料解析

ODB文件是由VALOR提出的一种ASCII码&#xff0c;双向传输文件。奥宝公司和康代公司的设备都是用的ODB格式进行PCB的生产和检测。 对ODB文件进行解析把数据栅格化很重要&#xff0c;查了网上找不到一个成熟能用的ODB文件解析代码。自己上手写了一个。 当前解析一些载板&#x…

SNMP简单介绍

SNMP SNMP是广泛应用于TCP/IP网络的网络管理标准协议&#xff0c;该协议能够支持网络管理系统&#xff0c;用以监测连接到网络上的设备是否有任何引起管理上关注的情况。SNMP采用轮询机制&#xff0c;提供最基本的功能集&#xff0c;适合小型、快速、低价格的环境使用&#xf…

灵异事件!程序里发现了新Bug但是它正常运行啦!

人生处处有Bug&#xff0c;有些令人困惑&#xff0c;有些令人崩溃&#xff0c;而有些则会让你觉得发现了一件奇奇怪怪的事情。今天&#xff0c;我就来分享一个我在程序中发现的令人惊奇的Bug。 这个Bug出现在我负责维护的一个大型软件系统中。这个系统是用来管理一个电商平台的…

VS2022如何显示Class View窗口

点击菜单栏的“视图”选项 > “类视图”&#xff0c;即可打开Class View。

Stable Diffusion核心算法DDPM解析

DDPM&#xff1a;Denoising Diffusion Probabilistic Model&#xff0c;去噪扩散概率模型 本文参考&#xff1a;一个视频看懂扩散模型DDPM原理推导|AI绘画底层模型_哔哩哔哩_bilibili 1、大概原理 从右往左为正向加噪过程&#xff0c;从左往右为逆向降噪过程。 在正向过程中不…

安装软件包

安装软件包 创建一个名为 /home/curtis/ansible/packages.yml 的 playbook : 将 php 和 mariadb 软件包安装到 dev、test 和 prod 主机组中的主机上 将 RPM Development Tools 软件包组安装到 dev 主机组中的主机上 将 dev 主机组中主机上的所有软件包更新为最新版本 vim packa…

低代码实操演示 | 如何快速构建企微、钉钉、飞书消息推送服务

8月15日&#xff0c;万应低代码培训总监胡杰为大家带来了一场低代码实操直播&#xff0c;这场直播同时在抖音和微信视频号两个平台进行&#xff0c;吸引了众多关注者的参与。 为了更好地帮助大家快速上手&#xff0c;我们将直播的主题内容做了文字梳理&#xff0c;感兴趣的小伙…

【git clone error:no matching key exchange method found】

拉起项目代码报错 git clone ssh://uidxxxgerrit-xxxxxxxx Cloning into ‘xxxxx’… Unable to negotiate with xxx.xx.xxx.ip port xxxxx: no matching key exchange method found. Their offer: diffie-hellman-group14-sha1,diffie-hellman-group1-sha1 fatal: Could not …