A2C原理和代码实现

news2024/11/15 3:52:32

参考王树森《深度强化学习》课程和书籍


1、A2C原理:

在这里插入图片描述


Observe a transition: ( s t , a t , r t , s t + 1 ) (s_t,{a_t},r_t,s_{t+1}) (st,at,rt,st+1)

TD target:
y t = r t + γ ⋅ v ( s t + 1 ; w ) . y_{t} = r_{t}+\gamma\cdot v(s_{t+1};\mathbf{w}). yt=rt+γv(st+1;w).
TD error:
δ t = v ( s t ; w ) − y t . \quad\delta_t = v(s_t;\mathbf{w})-y_t. δt=v(st;w)yt.
Update the policy network (actor) by:
θ ← θ − β ⋅ δ t ⋅ ∂ ln ⁡ π ( a t ∣ s t ; θ ) ∂ θ . \mathbf{\theta}\leftarrow\mathbf{\theta}-\beta\cdot\delta_{t}\cdot\frac{\partial\ln\pi(a_{t}\mid s_{t};\mathbf{\theta})}{\partial \mathbf{\theta}}. θθβδtθlnπ(atst;θ).


def compute_value_loss(self, bs, blogp_a, br, bd, bns):
    # 目标价值。
    with torch.no_grad():
        target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()
        # torch.logical_not 对输入张量取逻辑非

    # 计算value loss。
    value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)
    return value_loss

Update the value network (critic) by:
w ← w − α ⋅ δ t ⋅ ∂ v ( s t ; w ) ∂ w . \mathbf{w}\leftarrow\mathbf{w}-\alpha\cdot\delta_{t}\cdot{\frac{\partial{v(s_{t}};\mathbf{w})}{\partial\mathbf{w}}}. wwαδtwv(st;w).


def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
    # 建议对比08_a2c.py,比较二者的差异。
    with torch.no_grad():
        value = self.V(bs).squeeze()

    policy_loss = 0
    for i, logp_a in enumerate(blogp_a):
        policy_loss += -logp_a * value[i]
    policy_loss = policy_loss.mean()
    return policy_loss

2、A2C完整代码实现:

参考后修改注释:最初的代码在https://github.com/wangshusen/DRL

"""8.3节A2C算法实现。"""
import argparse
import os
from collections import defaultdict
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ValueNet(nn.Module):
    def __init__(self, dim_state):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class PolicyNet(nn.Module):
    def __init__(self, dim_state, num_action):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_action)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        prob = F.softmax(x, dim=-1)
        return prob


class A2C:
    def __init__(self, args):
        self.args = args
        self.V = ValueNet(args.dim_state)
        self.V_target = ValueNet(args.dim_state)
        self.pi = PolicyNet(args.dim_state, args.num_action)
        self.V_target.load_state_dict(self.V.state_dict())

    def get_action(self, state):
        probs = self.pi(state)
        m = Categorical(probs)
        action = m.sample()
        logp_action = m.log_prob(action)
        return action, logp_action

    def compute_value_loss(self, bs, blogp_a, br, bd, bns):
        # 目标价值。
        with torch.no_grad():
            target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()

        # 计算value loss。
        value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)
        return value_loss

    def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
        # 目标价值。
        with torch.no_grad():
            target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()

        # 计算policy loss。
        with torch.no_grad():
            advantage = target_value - self.V(bs).squeeze()
        policy_loss = 0
        for i, logp_a in enumerate(blogp_a):
            policy_loss += -logp_a * advantage[i]
        policy_loss = policy_loss.mean()
        return policy_loss

    def soft_update(self, tau=0.01):
        def soft_update_(target, source, tau_=0.01):
            for target_param, param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - tau_) + param.data * tau_)

        soft_update_(self.V_target, self.V, tau)


class Rollout:
    def __init__(self):
        self.state_lst = []
        self.action_lst = []
        self.logp_action_lst = []
        self.reward_lst = []
        self.done_lst = []
        self.next_state_lst = []

    def put(self, state, action, logp_action, reward, done, next_state):
        self.state_lst.append(state)
        self.action_lst.append(action)
        self.logp_action_lst.append(logp_action)
        self.reward_lst.append(reward)
        self.done_lst.append(done)
        self.next_state_lst.append(next_state)

    def tensor(self):
        bs = torch.as_tensor(self.state_lst).float()
        ba = torch.as_tensor(self.action_lst).float()
        blogp_a = self.logp_action_lst
        br = torch.as_tensor(self.reward_lst).float()
        bd = torch.as_tensor(self.done_lst)
        bns = torch.as_tensor(self.next_state_lst).float()
        return bs, ba, blogp_a, br, bd, bns


class INFO:
    def __init__(self):
        self.log = defaultdict(list)
        self.episode_length = 0
        self.episode_reward = 0
        self.max_episode_reward = -float("inf")

    def put(self, done, reward):
        if done is True:
            self.episode_length += 1
            self.episode_reward += reward
            self.log["episode_length"].append(self.episode_length)
            self.log["episode_reward"].append(self.episode_reward)

            if self.episode_reward > self.max_episode_reward:
                self.max_episode_reward = self.episode_reward

            self.episode_length = 0
            self.episode_reward = 0

        else:
            self.episode_length += 1
            self.episode_reward += reward


def train(args, env, agent: A2C):
    V_optimizer = torch.optim.Adam(agent.V.parameters(), lr=3e-3)
    pi_optimizer = torch.optim.Adam(agent.pi.parameters(), lr=3e-3)
    info = INFO()

    rollout = Rollout()
    state, _ = env.reset()
    for step in range(args.max_steps):
        action, logp_action = agent.get_action(torch.tensor(state).float())
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        info.put(done, reward)

        rollout.put(
            state,
            action,
            logp_action,
            reward,
            done,
            next_state,
        )
        state = next_state

        if done is True:
            # 模型训练。
            bs, ba, blogp_a, br, bd, bns = rollout.tensor()

            value_loss = agent.compute_value_loss(bs, blogp_a, br, bd, bns)
            V_optimizer.zero_grad()
            value_loss.backward(retain_graph=True)
            V_optimizer.step()

            policy_loss = agent.compute_policy_loss(bs, blogp_a, br, bd, bns)
            pi_optimizer.zero_grad()
            policy_loss.backward()
            pi_optimizer.step()

            agent.soft_update()

            # 打印信息。
            info.log["value_loss"].append(value_loss.item())
            info.log["policy_loss"].append(policy_loss.item())

            episode_reward = info.log["episode_reward"][-1]
            episode_length = info.log["episode_length"][-1]
            value_loss = info.log["value_loss"][-1]
            print(f"step={step}, reward={episode_reward:.0f}, length={episode_length}, max_reward={info.max_episode_reward}, value_loss={value_loss:.1e}")

            # 重置环境。
            state, _ = env.reset()
            rollout = Rollout()

            # 保存模型。
            if episode_reward == info.max_episode_reward:
                save_path = os.path.join(args.output_dir, "model.bin")
                torch.save(agent.pi.state_dict(), save_path)

        if step % 10000 == 0:
            plt.plot(info.log["value_loss"], label="value loss")
            plt.legend()
            plt.savefig(f"{args.output_dir}/value_loss.png", bbox_inches="tight")
            plt.close()

            plt.plot(info.log["episode_reward"])
            plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
            plt.close()


def eval(args, env, agent):
    agent = A2C(args)
    model_path = os.path.join(args.output_dir, "model.bin")
    agent.pi.load_state_dict(torch.load(model_path))

    episode_length = 0
    episode_reward = 0
    state, _ = env.reset()
    for i in range(5000):
        episode_length += 1
        action, _ = agent.get_action(torch.from_numpy(state))
        next_state, reward, terminated, truncated, info = env.step(action.item())
        done = terminated or truncated
        episode_reward += reward

        state = next_state
        if done is True:
            print(f"episode reward={episode_reward}, length={episode_length}")
            state, _ = env.reset()
            episode_length = 0
            episode_reward = 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")

    parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
    args = parser.parse_args()

    env = gym.make(args.env)
    agent = A2C(args)

    if args.do_train:
        train(args, env, agent)

    if args.do_eval:
        eval(args, env, agent)


3、torch.distributions.Categorical()

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs) # 用probs构造一个分布
action = m.sample() # 按照probs进行采样
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward # log_prob 计算log(probs[action])的值
loss.backward()

Probability distributions - torch.distributions — PyTorch 2.0 documentation

next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward # log_prob 计算log(probs[action])的值
loss.backward()


[Probability distributions - torch.distributions — PyTorch 2.0 documentation](https://pytorch.org/docs/stable/distributions.html)

[【PyTorch】关于 log_prob(action) - 简书 (jianshu.com)](https://www.jianshu.com/p/06a5c47ee7c2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/855245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Rocketmq 5.0 任意时间定时消息(RIP-43) 原理详解 源码解析

1. 背景 1.1 概念和应用场景 延迟消息(定时消息)即消息到达消息队列服务端后不会马上投递,而是到达某个时间才投递给消费者。它在在当前的互联网环境中有非常大的需求。 例如电商/网约车等业务中都会出现的订单场景,客户下单后…

用C语言构建一个数字识别深度神经网络

接上一篇: 用C语言构建一个数字识别卷积神经网络 1. 深度神经网络 按照深度学习的理论,随着神经网络层数的增加,网络拟合复杂问题的能力也会增强,对事物特征的挖掘也会更加深入.这里尝试构建一个5层深度的神经网络&am…

靶形数独

题目描述 小城和小华都是热爱数学的好学生,最近,他们不约而同地迷上了数独游戏,好胜的他们想用数独来一比高低。但普通的数独对他们来说都过于简单了,于是他们向 Z 博士请教,Z 博士拿出了他最近发明的“靶形数独”&am…

使用AI工具Lama Cleaner一键去除水印、人物、背景等图片里的内容

使用AI工具Lama Cleaner一键去除水印、人物、背景等图片里的内容 前言前提条件相关介绍Lama Cleaner环境要求安装Lama Cleaner启动Lama CleanerCPU方式启动GPU方式启动 使用Lama Cleaner测试结果NO.1 检测框NO.2 水印NO.3 广州塔NO.4 人物背景 参考 前言 由于本人水平有限&…

springcloud3 bus+springconfig 实现配置文件的动态刷新(了解)

一 springcloud Bus的作用 1.1 springcloud的作用 spring cloud bus是用来将分布式系统的节点与轻量级消息系统链接起来的框架。 它整合了java的事件处理机制和消息中间件的功能。其中目前支持RabbitMQ和kafka 简介: bus实现多个服务的配置文件动态刷新。 1.2 …

【算法|数组】快慢指针

算法|数组——快慢指针 引入 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺序可以改变。你…

QT QLCDNumber 使用详解

本文详细的介绍了QLCDNumber控件的各种操作,例如:新建界面、源文件、设置显示位数、设置进制、设置外观、设置小数点、设置溢出、显示事件、其它文章等等操作。 实际开发中,一个界面上可能包含十几个控件,手动调整它们的位置既费时…

一些日常问题的简单总结

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 maven生命周期二方包maven的配置 Java内存管理堆jvm内存调优GC流程G1垃圾回收器 CPU负载及使用率docker二进制安装及配置nexusnginx做反向代理 k8spod生命周期探针l…

linux remoteproc驱动中elf解析函数实现分析

linux remoteproc驱动中elf解析函数实现分析 1 ELF文件组织结构2 ELF_GEN_FIELD_GET_SET3 elf 各种header解析接口以及其实现3.1 elf header3.1.1 elf header解析接口3.1.2 elf header各个解析函数为:3.1.2.1 ELF_GEN_FIELD_GET_SET(hdr, e_entry, u64)3.1.2.2 ELF_…

[OnWork.Tools]系列 07-Web浏览器

简介 简易的web浏览器,适合临时使用 组件安装 第一次使用时可能需要安装相关组件 点击确定 会打开官方地址 WebView2 - Microsoft Edge Developer 点击立即下载 跳转到新的地址 WebView2 - Microsoft Edge Developer 有外网的选择第一个,无网络的在有网络的电脑打开后选择…

SpringCloud实用篇4——MQ RabbitMQ SpringAMQP

目录 1 初识MQ1.1 同步和异步通讯1.1.1 同步通讯1.1.2 异步通讯 1.2 技术对比 2.快速入门2.1 安装RabbitMQ2.1.1 单机部署2.1.2集群部署 2.2 RabbitMQ消息模型2.3.导入Demo工程2.4 入门案例2.4.1 publisher实现2.4.2 consumer实现 3 SpringAMQP3.1 Basic Queue 简单队列模型3.1…

【源码分析】Nacos如何是现在CP模式下基于Raft协议的节点注册逻辑

而对于持久节点,有一个Raft协议的实现 我们知道Raft算法作为一个CP协议,它通过的是Leader节点来向各个节点进行数据的同步。 所以会先判断当前节点是否是Leader节点,如果不是则将请求转发到Leader节点进行处理。 而如果就是Leader节点&am…

二、Linux中权限、shell命令及运行原理

shell命令及运行原理 我们使用Linux时,并不是直接访问操作系统,为什么不是直接访问操作系统呢? 如果用户直接访问操作系统,不仅使用难度大,而且不安全,容易把系统文件损坏。 那么我们通常是如何访问操作系统…

全网最牛,接口自动化测试实现详细总结,23年测试进阶之路...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 为什么要做接口自…

Java String类【超详细】

文章目录 1. 字符串构造2. String对象的比较2.1 比较是否引用同一个对象2. 2boolean equals(Object anObject) 方法:按照字典序比较2.3 int compareTo(String s) 方法: 按照字典序进行比较2.4 int compareToIgnoreCase(String str) 方法:与compareTo方式…

图像 处理 - 开源算法集合

图像 处理 - 开源算法集合 1. 图像 检测 - MMDetection 简介2. 图像 分割 - MMSegmentation 简介3. 图像 其他 - MMPreTrain 以下介绍的每个 开源算法集合 均包含多种 开源算法 1. 图像 检测 - MMDetection 简介 简介:MMDetection 是一个基于 PyTorch 的目标检测开…

【C++进阶之路】map与set的基本使用

文章目录 一、set系列1.set①insert②find③erase④lower_bound与upper_bound 2.multiset①count②equal_range 二、map系列1.map①insert1.插入pair的四种方式2.常用两种方式 ②[]2.multimap①count②equal_range 一、set系列 1.set ①insert 函数分析(C98&…

解决Windows:Call to undefined function exif_imagetype()

很明显,是php安装时没有打开某些扩展,以致不能执行exif_imagetype()这个方法,因此需要打开。 网上很多人说需要打开下面这两个扩展: extension=php_exif.dll extension=php_mbstring.dll 但只说对了一半,我一开始也按照网上文章说的打开这两个扩展,但是还是同样错误。…

2. 软件需求 面向对象分析

目录 1. 软件需求 1.1 需求分类 1.2 需求获取 1.3 需求分析 2. 面向对象分析(OOA) 2.1 统一建模语言 UML 2.2 用例模型 2.2.1 用例图的元素 2.2.2 识别参与者 2.2.3 合并需求获得用例 2.2.4 细化用例描述 2.3 分析模型 2.3.1 定义概念类 …

3.1线程之间共享数据的问题

线程之间共享数据的问题 从整体上来看,所有线程之间共享数据的问题,都是修改数据导致的。如果所有的共享数据都是只读的,就没有问题,因为一个线程所读取的数据不受另一个线程是否正在读取相同的数据而影响。然而,如果…