强化学习_06_pytorch-doubleDQN实践(Pendulum-v1)

news2025/1/18 2:08:15

环境描述

环境是倒立摆(Inverted Pendulum),该环境下有一个处于随机位置的倒立摆。环境的状态包括倒立摆角度的正弦值,余弦值,角速度;动作为对倒立摆施加的力矩(action = Box(-2.0, 2.0, (1,), float32))。每一步都会根据当前倒立摆的状态的好坏给予智能体不同的奖励,该环境的奖励函数为,倒立摆向上保持直立不动时奖励为 0,倒立摆在其他位置时奖励为负数。环境本身没有终止状态,所以训练的时候需要设置终止条件(笔者在本文设置了260)。

一、构建智能体

构建智能体:
policy是和之前一样的。探索和利用, 就是利用的时候基于nn模型的预测
主要核心:

  • QNet:
    • 就是一个多层的NN
    • update就是用MSELoss进行梯度下降
  • DQN: 支持DQN和doubleDQN
    • update:
      • 经验回放池R中的数据足够, 从R中采样N个数据 { (si, ai, ri, si+1) }+i=1,...,N
      • 对于DQN: 对每个数据计算 y i = r i + γ Q w − ( s ′ , a r g m a x ( Q w − ( s ′ , a ′ ) ) y_i = r_i + \gamma Q_{w^-}(s',arg max(Q_{w^-}(s', a')) yi=ri+γQw(s,argmax(Qw(s,a)) ,动作的选取依靠目标网络( Q w − Q_{w-} Qw)
      • 对于doubleDQN: 对每个数据计算 y i = r i + γ Q w − ( s ′ , a r g m a x ( Q w ( s ′ , a ′ ) ) y_i = r_i + \gamma Q_{w^-}(s',arg max(Q_{w}(s', a')) yi=ri+γQw(s,argmax(Qw(s,a)) ,动作的选取依靠训练网络( Q w Q_w Qw)
        • f m a x f_{max} fmax就是用 Q N e t ( s i + 1 ) QNet(s_{i+1}) QNet(si+1)计算出每个action对应的值,取最大值的index, 然后根据这个index去 T a g e t Q N e t ( s i + 1 ) TagetQNet(s_{i+1}) TagetQNet(si+1)中取最大值
          下面的代码可以看出差异
        # 下个状态的最大Q值
        if self.dqn_type == 'DoubleDQN': # DQN与Double DQN的区别
            max_action = self.q(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = n_actions_q.gather(1, max_action)
        else: # DQN的情况
            max_next_q_values = n_actions_q.max(1)[0].view(-1, 1)

        q_targets = reward + self.gamma * max_next_q_values * (1 - done) 
  • 最小化目标函数 L = 1 N ∑ ( y i − Q N e t ( s i , a i ) ) 2 L=\frac{1}{N}\sum (yi - QNet(s_i, a_i))^2 L=N1(yiQNet(si,ai))2
    • y i y_i yiq(states).gather(1, action) 计算损失并更新参数

代码实现

class QNet(nn.Module):
    def __init__(self, state_dim: int, hidden_layers_dim: typ.List, action_dim: int):
        super(QNet, self).__init__()
        self.features = nn.ModuleList()
        for idx, h in enumerate(hidden_layers_dim):
            self.features.append(
                nn.ModuleDict({
                    'linear': nn.Linear(state_dim if not idx else hidden_layers_dim[idx-1], h),
                    'linear_active': nn.ReLU(inplace=True)
                    
                })
            )
        self.header = nn.Linear(hidden_layers_dim[-1], action_dim)
    
    def forward(self, x):
        for layer in self.features:
            x = layer['linear_active'](layer['linear'](x))
        return self.header(x)

    def model_compelet(self, learning_rate):
        self.cost_func = nn.MSELoss()
        self.opt = torch.optim.Adam(self.parameters(), lr=learning_rate)

    def update(self, pred, target):
        self.opt.zero_grad()
        loss = self.cost_func(pred, target)
        loss.backward()
        self.opt.step()


class DQN:
    def __init__(self, 
                state_dim: int, 
                hidden_layers_dim, 
                action_dim: int, 
                learning_rate: float,
                gamma: float,
                epsilon: float=0.05,
                traget_update_freq: int=1,
                device: typ.AnyStr='cpu',
                dqn_type: typ.AnyStr='DQN'
                ):
        self.action_dim = action_dim
        # QNet & targetQNet
        self.q = QNet(state_dim, hidden_layers_dim, action_dim)
        self.target_q = copy.deepcopy(self.q)
        self.q.to(device)
        self.q.model_compelet(learning_rate)
        self.target_q.to(device)

        # iteration params
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        
        # target update freq
        self.traget_update_freq = traget_update_freq
        self.count = 0
        self.device = device
        
        # dqn类型
        self.dqn_type = dqn_type
        
    def policy(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        
        action = self.target_q(torch.FloatTensor(state))
        return np.argmax(action.detach().numpy())
    
    def update(self, samples: deque):
        """
        Q<s, a, t> = R<s, a, t> + gamma * Q<s+1, a_max, t+1>
        """
        self.count += 1
        state, action, reward, next_state, done = zip(*samples)
        
        states = torch.FloatTensor(state).to(self.device)
        action = torch.tensor(action).view(-1, 1).to(self.device)
        reward = torch.tensor(reward).view(-1, 1).to(self.device)
        next_states = torch.FloatTensor(next_state)
        done = torch.FloatTensor(done).view(-1, 1).to(self.device)
        
        actions_q = self.q(states)
        n_actions_q = self.target_q(next_states)
        q_values = actions_q.gather(1, action)
        # 下个状态的最大Q值
        if self.dqn_type == 'DoubleDQN': # DQN与Double DQN的区别
            max_action = self.q(next_states).max(1)[1].view(-1, 1)
            max_next_q_values = n_actions_q.gather(1, max_action)
        else: # DQN的情况
            max_next_q_values = n_actions_q.max(1)[0].view(-1, 1)

        q_targets = reward + self.gamma * max_next_q_values * (1 - done) 
        # MSELoss update
        self.q.update(q_values.float(), q_targets.float())
        if self.count % self.traget_update_freq == 0:
            self.target_q.load_state_dict(
                self.q.state_dict()
            )

二、智能体训练

2.1 注意点

在训练的时候:

  • 所有的参数都在Config中进行配置,便于调参
class Config:
    num_episode  = 300
    state_dim = None
    hidden_layers_dim = [10, 10]
    action_dim = 20
    learning_rate = 2e-3
    gamma = 0.95
    epsilon = 0.01
    traget_update_freq = 3
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    buffer_size = 2048
    minimal_size = 1024
    batch_size = 128
    render = False
    save_path =  r'D:\TMP\model.ckpt' 
    dqn_type = 'DoubleDQN'
    # 回合停止控制
    max_episode_rewards = 260
    max_episode_steps = 260

    def __init__(self, env):
        self.state_dim = env.observation_space.shape[0]
        try:
            self.action_dim = env.action_space.n
        except Exception as e:
            pass
        print(f'device = {self.device} | env={str(env)}')
  • 由于这次环境的action空间是连续的,我们需要有一个函数进行action的离散和连续的转换
def Pendulum_dis_to_con(discrete_action, env, action_dim):  # 离散动作转回连续的函数
    action_lowbound = env.action_space.low[0]  # 连续动作的最小值
    action_upbound = env.action_space.high[0]  # 连续动作的最大值
    action_range = action_upbound - action_lowbound
    return action_lowbound + (discrete_action / (action_dim - 1)) * action_range

2.2 训练

需要注意的是笔者的gym版本是0.26.2

def train_dqn(env, cfg, action_contiguous=False):
    buffer = replayBuffer(cfg.buffer_size)
    dqn = DQN(
        state_dim=cfg.state_dim,
        hidden_layers_dim=cfg.hidden_layers_dim, 
        action_dim=cfg.action_dim,
        learning_rate=cfg.learning_rate,
        gamma=cfg.gamma,
        epsilon=cfg.epsilon,
        traget_update_freq=cfg.traget_update_freq,
        device=cfg.device,
        dqn_type=cfg.dqn_type
    )
    tq_bar = tqdm(range(cfg.num_episode))
    rewards_list = []
    now_reward = 0
    bf_reward = -np.inf
    for i in tq_bar:
        tq_bar.set_description(f'Episode [ {i+1} / {cfg.num_episode} ]')
        s, _ = env.reset()
        done = False
        episode_rewards = 0
        steps = 0
        while not done:
            a = dqn.policy(s)
            # [Any, float, bool, bool, dict]
            if action_contiguous:
                c_a = Pendulum_dis_to_con(a, env, cfg.action_dim)
                n_s, r, done, _, _ = env.step([c_a])
            else:
                n_s, r, done, _, _ = env.step(a)
            buffer.add(s, a, r, n_s, done)
            s = n_s
            episode_rewards += r
            steps += 1
            # buffer update
            if len(buffer) > cfg.minimal_size:
                samples = buffer.sample(cfg.batch_size)
                dqn.update(samples)
            if (episode_rewards >= cfg.max_episode_rewards) or (steps >= cfg.max_episode_steps):
                break

        rewards_list.append(episode_rewards)
        now_reward = np.mean(rewards_list[-10:])
        if bf_reward < now_reward:
            torch.save(dqn.target_q.state_dict(), cfg.save_path)

        bf_reward = max(bf_reward, now_reward)
        tq_bar.set_postfix({'lastMeanRewards': f'{now_reward:.2f}', 'BEST': f'{bf_reward:.2f}'})
    env.close()
    return dqn


if __name__ == '__main__':
    print('=='*35)
    print('Training Pendulum-v1')
	p_env = gym.make('Pendulum-v1')
	p_cfg = Config(p_env)
	p_dqn = train_dqn(p_env, p_cfg, True)

三、训练出的智能体观测

最后将训练的最好的网络拿出来进行观察

p_dqn.target_q.load_state_dict(torch.load(p_cfg.save_path))
play(gym.make('Pendulum-v1', render_mode="human"), p_dqn, p_cfg, episode_count=2, action_contiguous=True)

从下图中我们可以看出,本次的训练成功还是可以的。
在这里插入图片描述

完整脚本查看笔者github: Doubledqn_lr.py 记得点Star

笔者后续会更深入的学习强化学习并对gym各个环境逐一进行训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/85126.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows11安装cuda+cudnn

安装Nvidia显卡驱动 如需安装显卡驱动&#xff0c;在官方驱动下载网站找到自己的显卡型号对应的驱动下载并安装:官方驱动 | NVIDIA 安装CUDA 前言 windows10 版本安装 CUDA &#xff0c;首先需要下载两个安装包 CUDA toolkit&#xff08;toolkit就是指工具包&#xff09;cu…

Qt扫盲-QLineEdit理论总结

QLineEdit理论总结1. 简述2. 输入模式3. 输入限制4. 文本操作槽函数3. 信号4. 快捷键5. 外观1. 简述 QLineEdit 是一个有用的编辑功能类&#xff0c;主要是处理输入和编辑单行纯文本 &#xff0c;主要是单行哦&#xff0c;就用来输入简单&#xff0c;短小的字符串。内部其实已…

极客时间Kafka - 09 Kafka Java Consumer 多线程开发实例

文章目录1. Kafka Java Consumer 设计原理2. 多线程方案3. 代码实现4. 问题思考目前&#xff0c;计算机的硬件条件已经大大改善&#xff0c;即使是在普通的笔记本电脑上&#xff0c;多核都已经是标配了&#xff0c;更不用说专业的服务器了。如果跑在强劲服务器机器上的应用程序…

JSP ssh科研管理系统myeclipse开发mysql数据库MVC模式java编程计算机网页设计

一、源码特点 JSP ssh科研管理系统是一套完善的web设计系统&#xff08;系统采用ssh框架进行设计开发&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myec…

Core Scheduling

Core Scheduling要解决什么问题&#xff1f; core scheduling是v5.14中新增的功能&#xff0c;下图是内核数据结构为该功能所添加的字段。 为什么有core scheduling呢&#xff1f;因为当开启超线程(HyperThreading)时&#xff0c;一个物理核就变成了两个逻辑核&#xff0c;但&…

postgres 源码解析43 元组的插入流程详解 heap_insert

本文讲解postgres中元组的插入流程&#xff0c;深入了解其实现原理。同时此过程涉及元组xmin/xmax与标识位的设置细节&#xff0c;与事务的可见性部分密切相关相关&#xff0c;借此复习一下。 heappage结构 执行流程框架图 heap_prepare_insert 该函数执行内容较为简单&#…

课设项目之——教学辅助系统(学生考试监考系统)

在考试场中为学生监考十分枯燥&#xff0c;因此&#xff0c;建立一个可靠的作弊检测系统来识别学生是否存在作弊行为。 使用一个名为 Yolo3 的训练模型和一个名为 coco 的数据集&#xff0c;我们测试了考场中学生的书籍和手机&#xff0c;并将他们标记为作弊者。 使用haarcasc…

如何将dxf或dwg等CAD文件与卫星影像地图叠加进行绘图设计?

引言&#xff1a; 在测绘、电力、水利、规划或道路设计等GIS相关行业中&#xff0c;通常会用AutoCAD进行矢量地图数据的绘制&#xff0c;而这些地图数据通常又是建立在投影平面坐标的基础上进行绘制的。 为了确保地图数据的准确性与精度的要求&#xff0c;这些地图数据经常会…

将一个乱序数组变为有序数组的最少交换次数

给定一个包含1-n的数列&#xff0c;通过交换任意两个元素给数列重新排序。求最少需要多少次交换&#xff0c;能把数组排成按1-n递增的顺序 总之就是将这个位置应该出现的元素和这个位置现在的元素交换位置 代码实现&#xff1a; 核心&#xff1a;记住一点&#xff0c;hashmap用…

【debug】时序预测的结果都是一个趋势

时序预测的结果都是一个趋势现象原因solutionother solutions现象 预测的是一个序列。 在测试集中随机取20个来看&#xff0c;所有的预测序列都是一个趋势&#xff0c;但是大小有所区别。 举例图片 原因 目前来看是数据的问题&#xff0c;应该是样本不均衡&#xff0c;某一…

简单个人网页制作 个人介绍网页模板 静态HTML留言表单页面网站模板 大学生个人主页网页

&#x1f389;精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业…

[ Linux ] 一篇带你理解Linux下线程概念

目录 1.Linux线程的概念 1.1什么是线程 1.1.1如何验证一个进程内有多个线程&#xff1f; 1.2线程的优点 1.3线程的缺点 1.4 线程异常 1.5 线程用途 2.Linux进程与线程 2.1进程和线程 2.2 进程和线程的关系 2.3如何看待之前学习的单进程&#xff1f; 1.Linux线程的概…

迪杰斯特拉算法求图的最短路径(java)

迪杰斯特拉算法 图的最短路径的解法 单源最短路径 从一个点开始&#xff0c;可以找到其中任意一个点的最短路径。 多源最短路径 从任何一个点开始&#xff0c;可以找到其中任何一个点的最短路径。 解题过程 给定一个带权有向图G(G, V), 另外&#xff0c;还给定 V 中的一…

力扣(LeetCode)1832. 判断句子是否为全字母句(C++)

哈希集合1 哈希集合记录 262626 个字母是否出现&#xff0c;一次遍历字符串&#xff0c;维护哈希集合&#xff0c;同时维护答案。遍历完成&#xff0c;仅当答案等于 262626 &#xff0c;句子是全字母句。 class Solution { public:bool checkIfPangram(string sentence) {boo…

轻松提高性能和并发度,springboot简单几步集成缓存

目录 1、缘由 2、技术介绍 2.1、技术调研 2.2、spring支持的cache 2.3、cache的核心注解 2.3.1 EnableCaching 2.3.2 Cacheable 2.3.3 CachePut 2.3.4 CacheEvict 2.4 cache的架构 2.5 cachemanager的实现类 3、搞个例子 3.1 为什么使用redis 作为缓存 3.2 代码走起…

【虚幻引擎】UE4/UE5数字孪生与前端Web页面匹配

一、数字孪生 数字孪生是一种多维动态的数字映射&#xff0c;可大幅提高效能。数字孪生是充分利用物理模型、传感器更新、运行历史等数据&#xff0c;集成多学科、多物理量、多尺度、多概率的仿真过程&#xff0c;在虚拟空间中完成对现实体的复制和映射&#xff0c;从而反映物理…

MySQL常用窗口函数

1、窗口函数概念 窗口的概念非常重要&#xff0c;它可以理解为记录集合&#xff0c;窗口函数也就是在满足某种条件的记录集合上执行的特殊函数对于每条记录都要在此窗口内执行函数&#xff0c;有的函数随着记录不同&#xff0c;窗口大小都是固定的&#xff0c;这种属于静态窗口…

c语言:枚举类型—enum

枚举类型一.常见形式二.枚举和宏定义三.枚举的意义四.插个小知识一.常见形式 这里举一个例子&#xff0c;我想要枚举颜色 注意一下细节&#xff0c;所有成员间用逗号隔开&#xff0c;最后一个成员后不加标点符号 这里看上去和定义结构体和联合体的样式一样&#xff0c;但其实前…

minio安装部署和minIO-Client的使用

minio安装部署和minIO-Client的使用 一、服务器安装minio 1.进行下载 下载地址&#xff1a; GNU/Linux https://dl.min.io/server/minio/release/linux-amd64/minio2.新建minio安装目录&#xff0c;执行如下命令 mkdir -p /home/minio/data把二进制文件上传到安装目录后&a…

【PAT甲级 - C++题解】1128 N Queens Puzzle

✍个人博客&#xff1a;https://blog.csdn.net/Newin2020?spm1011.2415.3001.5343 &#x1f4da;专栏地址&#xff1a;PAT题解集合 &#x1f4dd;原题地址&#xff1a; 题目详情 - 1128 N Queens Puzzle (pintia.cn) &#x1f511;中文翻译&#xff1a;皇后问题 &#x1f4e3;…