Reinforcement Learning with Code【Code 6. Advantage Actor-Critic(A2C)】

news2025/1/13 17:33:25

Reinforcement Learning with Code【Code 6. Advantage Actor-Critic(A2C)】

This note records how the author begin to learn RL. Both theoretical understanding and code practice are presented. Many material are referenced such as ZhaoShiyu’s Mathematical Foundation of Reinforcement Learning.

文章目录

  • Reinforcement Learning with Code【Code 6. Advantage Actor-Critic(A2C)】
    • 1. Actor-Criti's Various Forms
    • 2. Review Advantage Actor-Critic (A2C)
    • 3. A2C Code
    • Reference

1. Actor-Criti’s Various Forms

Image

2. Review Advantage Actor-Critic (A2C)

首先先了解传统的AC算法(详见 Reinforcement Learning with Code 【Chapter 10. Actor Critic】),A2C算法就是在传统的AC算法上增加了baseline减小了拟合的方差,正好这个增加的baseline是价值函 v π ( s ) v_\pi(s) vπ(s),那么就得到了优势函数的定义

δ π ( S , A ) = q π ( S , A ) − v π ( S ) \textcolor{red}{\delta_\pi(S,A) = q_\pi(S,A) - v_\pi(S)} δπ(S,A)=qπ(S,A)vπ(S)
描述的是,在当前状态选择的动作,相比于平均状态值的优劣程度。完整的A2C算法可以参考Reinforcement Learning with Code 【Chapter 10. Actor Critic】

Image
Image

3. A2C Code

在实现A2C时,仍然采用gym中的CartPole-v1环境。

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np


# Policy Network
class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, observation):
        x = F.relu(self.fc1(observation))
        return F.softmax(self.fc2(x), dim=1)

# State Value Network
class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
    
    def forward(self, observation):
        x = F.relu(self.fc1(observation))
        return self.fc2(x)

# # Q Value Network
# class QValueNet(nn.Module):
#     def __init__(self, state_dim, hidden_dim, action_dim):
#         super(QValueNet,self).__init__()
#         self.fc1 = nn.Linear(state_dim, hidden_dim)
#         self.fc2 = nn.Linear(hidden_dim, action_dim)
    
#     def forward(self, observation):
#         x = F.relu(self.fc1(observation))
#         return self.fc2(x)


# QAC & A2C
class ActorCritic():
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, ac_type, device):

        self.ac_type = ac_type
        if ac_type == "A2C":
            self.critic = ValueNet(state_dim, hidden_dim).to(device)
        elif ac_type == "QAC":
            self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)

        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.device = device
    
    def choose_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample().item()
        return action
    
    def learn(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1,1).to(self.device)
        actions = torch.tensor(transition_dict['actions'], dtype=torch.int64).view(-1,1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1,1).to(self.device)

        if self.ac_type == 'A2C':
            td_target = rewards + self.gamma * self.critic(next_states) * (1-dones)
            td_delta = td_target - self.critic(states)
            log_probs = torch.log(self.actor(states).gather(dim=1, index=actions))
            actor_loss = torch.mean(- log_probs * td_delta.detach())
            critic_loss = torch.mean(F.mse_loss(td_target.detach(), self.critic(states)))
        # elif self.ac_type == 'QAC':
        #     td_target = rewards + self.gamma * self.critic(next_states).gather(dim=1, index=actions) * (1-dones)
        #     td_delta = self.critic(states).gather(dim=1, index=actions)
        #     log_probs = torch.log(self.actor(states).gather(dim=1, index=actions))
        #     actor_loss = torch.mean(- log_probs * td_delta.detach())
        #     critic_loss = torch.mean(F.mse_loss(td_target, td_delta))

        # clear gradient cumulation
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        # calculate gradient
        actor_loss.backward()
        critic_loss.backward()
        # update parameters
        self.actor_optimizer.step()
        self.critic_optimizer.step()


def train_on_policy_agent(env, agent, num_episodes, seed):
    return_list = []
    for i in range(10):
        with tqdm(total = int(num_episodes/10), desc="Iteration %d"%(i+1)) as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                transition_dict = {
                    'states': [],
                    'actions': [],
                    'next_states': [],
                    'rewards': [],
                    'dones': []
                }
                observation, _ = env.reset(seed=seed)
                done = False
                while not done:
                    if render:
                        env.render()
                    action = agent.choose_action(observation)
                    observation_, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    # save one episode experience into a dict
                    transition_dict['states'].append(observation)
                    transition_dict['actions'].append(action)
                    transition_dict['next_states'].append(observation_)
                    transition_dict['rewards'].append(reward)
                    transition_dict['dones'].append(done)
                    # swap state
                    observation = observation_
                    # compute one episode return
                    episode_return += reward
                return_list.append(episode_return)
                agent.learn(transition_dict)
                if((i_episode + 1) % 10 == 0):
                    pbar.set_postfix({
                        'episode': '%d'%(num_episodes / 10 * i + i_episode + 1),
                        'return': '%.3f'%(np.mean(return_list[-10:]))
                    })
                pbar.update(1)
    env.close()
    return return_list

def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))

def plot_curve(return_list, mv_return, algorithm_name, env_name):
    episodes_list = list(range(len(return_list)))
    plt.plot(episodes_list, return_list, c='gray', alpha=0.6)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('{} on {}'.format(algorithm_name, env_name))
    plt.show()


if __name__ == "__main__":

    # reproducible
    seed_number = 0
    np.random.seed(seed_number)
    torch.manual_seed(seed_number)

    num_episodes = 1000     # episodes length
    hidden_dim = 256        # hidden layers dimension
    gamma = 0.98            # discounted rate
    device = torch.device('cuda' if torch.cuda.is_available() else 'gpu')
    env_name = 'CartPole-v1'
    ac_type = 'A2C'  # Actor-Critic Type: QAC or A2C

    # Attention Learning Rate Is Important 
    actor_lr = 1e-3         # learning rate of actor
    if ac_type == 'A2C':
        critic_lr = 1e-2        # learning rate of critic
    # elif ac_type == 'QAC':
    #     critic_lr = 1e-3

    render = False
    if render:
        env = gym.make(id=env_name, render_mode='human')
    else:
        env = gym.make(id=env_name)
    
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, ac_type, device)

    return_list = train_on_policy_agent(env, agent, num_episodes, seed_number)
    
    mv_return = moving_average(return_list, 9)
    plot_curve(return_list, mv_return, ac_type, env_name)

最终的学习曲线如图所示

Image

Reference

赵世钰老师的课程
Hands on RL
Reinforcement Learning with Code 【Chapter 10. Actor Critic】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/872115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你不得不懂的IT知识-《敏捷项目管理》

国林哥在IBM时,几乎每天都会收到关于“敏捷”相关的邮件,公司鼓励我们去学习邮件里的知识,参加敏捷相关的认证和培训。刚开始我和大多数同事一样不管不顾,后来随着PBC里要求加上成长目标,比如要获得一个认证&#xff0…

为什么要试用CRM系统?有什么优点?

对于那些正在进行CRM选型的企业来说,想要了解一款CRM系统是否好用,亲自试用无疑是最好的方法。那么,有没有可以免费试用的在线CRM系统 CRM系统免费试用的好处 体验产品功能:您可以亲自操作和测试CRM系统的各项功能,如…

从一个GPU到多个GPU

在多GPU运行应用程序时,需要正确设计GPU之间的通信,GPU间数据传输的效率取决于GPU是如何连接在一个节点上并跨集群的 在多GPU系统里有两种连接方式 多GPU通过单个节点连接到PCIe总线上 多GPU连接到集群中的网络交换机上 /* * 本示例演示了如何使用 Open…

【Vue-Router】路由模式

1. WebHashHistory index.ts import { createRouter, createWebHistory, RouteRecordRaw, createWebHashHistory } from "vue-router";// 路由模式 //vue2 mode history -> vue3 createWebHistory //vue2 mode hash -> vue3 createWebHashHistory //vue2 m…

电池的正极是带正电?

首先说明结论:电池正极带正电,负极带负电。 一个错误的实例: 如果说电流是从电池正极流动到电池负极,那么电子就是从负极流动到正极,那么正极就是带负电。----这个说法是错误的。这是因为,根据那么很出名…

verilog学习笔记5——进制和码制、原码/反码/补码

文章目录 前言一、进制转换1、十进制转二进制2、二进制转十进制3、二进制乘除法 二、原码、反码、补码1、由补码计算十进制数2、计算某个负数的补码 前言 2023.8.13 天气晴 一、进制转换 1、十进制转二进制 整数:除以2,余数倒着写 小数:乘…

难解的bug

android.app.RemoteServiceException: Context.startForegroundService() did not then call Service.startForeground(): ServiceRecord 【Android TimeCat】 解决 context.startforegroundservice() did not then call service.startforeground() | XiChens Blog http://www…

【Linux从入门到精通】文件I/O操作(C语言vs系统调用)

文章目录 一、C语言的文件IO相关函数操作 1、1 fopen与fclose 1、2 fwrite 1、3 fprintf与fscanf 1、4 fgets与fputs 二、系统调用相关接口 2、1 open与close 2、2 write和read 三、简易模拟实现cat指令 四、总结 🙋‍♂️ 作者:Ggggggtm 🙋‍…

JAVA多线程和并发基础面试问答(翻译)

JAVA多线程和并发基础面试问答(翻译) java多线程面试问题 1. 进程和线程之间有什么不同? 一个进程是一个独立(self contained)的运行环境,它可以被看作一个程序或者一个应用。而线程是在进程中执行的一个任务。Java运行环境是一个包含了不同的类和程序…

Shell编程之条件测试、if语句、case语句

条件语句 一、条件测试1.1 测试命令1.1 文件测试1.2 整数比较1.3 字符串比较1.4 逻辑测试1.4.1 逻辑与 &&1.4.2 逻辑或 || 1.4.3 组合应用1.5 多个命令组合执行 ( ) { } 二、if语句2.1单分支结构2.2 多分支结构2.4 if语句练习2.4.1 单分支2.4.2 简单的交互式分数反馈 三…

Shell编程之正则表达式(非常详细)

正则表达式 1.通配符和正则表达式的区别2.基本正则表达式2.1 元字符 (字符匹配)2.2 表示匹配次数2.4 位置锚定2.5 分组 和 或者 3.扩展正则表达式4.部分文本处理工具4.1 tr 命令4.2 cut命令4.3 sort命令4.4 uniq命令 1.通配符和正则表达式的区别 通配符一般用于文件…

部署Springboot项目注意事项

步骤 步骤 1:将数据库内容在云服务器上的数据库部署一份 我使用mariadb;会出现一些不兼容现象;我们需要把默认值删掉 2:配置文件你得修改地方 a:linux是磁盘区分(像我自己项目用来储存验证码的文件我们得换这个配置;…

DoIP诊断入门

简介 DoIP(Diagnosis over Internet Protocol)是一种用于车辆诊断的网络通信协议。它基于现代互联网技术,允许通过以太网或IP网络进行车辆诊断和通信。 DoIP的背景是现代车辆中使用的电子控制单元(ECU)数量不断增加&…

利用OpenSSL实现私有 CA 搭建和证书颁发

利用OpenSSL实现私有 CA 搭建和证书颁发 一、私有 CA 搭建1. 安装openssl2. 配置 openssl3. 生成 CA 自己的私钥4. 生成 CA 自己的自签证书5. 验证自签证书 二、向私有CA申请证书流程1. 生成应用私钥文件2. 根据应用私钥生成证书申请文件3. 向CA请求颁发证书4. 验证应用证书5. …

PS/LR2024专用智能磨皮插件Portraiture提高P图效率

Portraiture 4智能磨皮插件支持Photoshop和Lightroom!Portraiture是一款智能磨皮插件,为Photoshop和Lightroom添加一键磨皮美化功能,快速对照片中皮肤、头发、眉毛等部位进行美化,无需手动调整,大大提高P图效率。全新4…

Wlan——无线服务集和AP的基本概念以及AP的配置

目录 WLAN服务集的基本概念 AP的基本概念 AP的分类 AP模式的切换 胖(FAT)AP介绍 胖AP的工作模式 接入模式和路由模式的区别 胖AP的组网方式 瘦(FIT)AP介绍 瘦AP的工作模式 瘦AP的组网方式 胖AP和瘦AP的区别 AP的配置…

(leecode)错误的集合

最近听到的,还可以,试试吧~ 题目: 示例: 提示: 题解: 思路: 将数字大小的位置,然后遍历每个位置,大小为0的是缺失数字,大小为2的是重复数字 int* findErro…

2022年12月 C/C++(一级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题:加一 输入一个整数x,输出这个整数加1后的值,即x1的值。 时间限制:1000 内存限制:65536 输入 一个整数x(0 ≤ x ≤ 1000)。 输出 按题目要求输出一个整数。 样例输入 9 样例输出 10 以下是使用C语言编写的解决方案…

湘大 XTU OJ:1406 String Game、1098 素数个数 题解(非常详细)

1406 String Game 一、链接 1406 String Game 二、题目 题目描述 Alice和Bob正在玩一个基于字符串的游戏,一开始,Alice和Bob分别拥有一个等长的字符串S1和S2,且这两个字符串只包含小写字母。 在每个回合中,Alice和Bob必须分…

【Vue-Router】路由入门

路由(Routing)是指确定网站或应用程序中特定页面的方式。在Web开发中,路由用于根据URL的不同部分来确定应用程序中应该显示哪个内容。 构建前端项目 npm init vuelatest //或者 npm init vitelatest安装依赖和路由 npm install npm instal…