double DQN 跑 Pendulum-v1

news2024/11/13 23:47:00

gym-0.26.1
Pendulum-v1 环境详细信息

double DQN

实验环境 是为了体现 double DQN对高估的缓解, 因为 Pendulum-v1 reward最大是为0,可以有明显的对比。

相关论文 Deep Reinforcement Learning with Double Q-Learning

对动手深度强化学习里的代码做了一些修改。
代码如下:

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils

class ReplayBuffer:
    """经验回放池"""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出
    
    def add(self, state, action, reward, next_state, done): # 将数据加入buffer
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size): # 从buffer中采样数据,数量为batch_size
        transition = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transition)
        return np.array(state), action, reward, np.array(next_state), done
    
    def size(self): # 目前buffer中数据的数量
        return len(self.buffer)
        
class Q(nn.Module):
    """只有一层隐藏层的Q网络"""
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
    
    def forward(self, X):
        X = F.relu(self.fc1(X)) # 隐藏层之后使用ReLU激活函数
        return self.fc2(X)

class DQN:
    """DQN算法,包括double DQN"""
    def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device, dqn_type='DQN'):
        self.action_dim = action_dim
        self.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q网络
        self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # 目标网络
        self.target_q.load_state_dict(self.q.state_dict())  # 加载参数
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update # 目标网络更新频率
        self.count = 0 # 计数器,记录更新次数
        self.device = device
        self.dqn_type = dqn_type
    
    def take_action(self, state): # epsilon-贪婪策略
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
            action = self.q(state).argmax().item()
        return action
    
    def max_q_value(self, state):
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        return self.q(state).max().item()
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)
        
        q_values = self.q(states).gather(1, actions) # Q值
        if self.dqn_type == "Double DQN":
            max_action = self.q(next_states).max(1)[1].reshape(-1, 1)
            max_next_q_values = self.target_q(next_states).gather(1, max_action)
        else:
            max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)
        q_targets = rewards + self.gamma * max_next_q_values * (1- dones) # TD误差
        loss = F.mse_loss(q_values, q_targets) # 均方误差
        self.optimizer.zero_grad() # 梯度清零,因为默认会梯度累加
        loss.mean().backward() # 反向传播
        self.optimizer.step() # 更新梯度
        
        if self.count % self.target_update == 0:
            self.target_q.load_state_dict(self.q.state_dict())
        self.count += 1

def dis_to_con(discrete_action, env, action_dim): # 将离散动作转回连续的函数
    action_lowbound= env.action_space.low[0] # 连续动作的最小值
    action_upbound = env.action_space.high[0] # 连续动作的最大值
    return action_lowbound + (discrete_action / (action_dim - 1) * (action_upbound - action_lowbound))

def train_DQN(agent, env, num_episodes, reply_buffer, miniaml_size, batch_size):
    return_list = []
    max_q_value_list = []
    max_q_value = 0
    for i in range(10):
        with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
            for i_episode in range(int(num_episodes/10)):
                episode_return = 0
                state = env.reset()[0]
                done, truncated= False, False
                while not done and not truncated :
                    action = agent.take_action(state)
                    max_q_value = agent.max_q_value(state) * 0.005 + max_q_value * 0.995 # 平滑处理
                    max_q_value_list.append(max_q_value)  # 保存每个状态的最大Q值
                    action_continuous = dis_to_con(action, env, agent.action_dim)
                    next_state, reward, done, truncated, info = env.step([action_continuous])
                    replay_buffer.add(state, action, reward, next_state, done)
                    state = next_state
                    episode_return += reward
                    # 当buffer数据的数量超过一定值后,才进行Q网络训练
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                        transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode+1) % 10 == 0:
                    pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                      'return': '%.3f' % np.mean(return_list[-10:])})
                pbar.update(1)
    
    episodes_list = list(range(len(return_list)))
    mv_return = rl_utils.moving_average(return_list, 5)
    plt.plot(episodes_list, mv_return)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title(f'{agent.dqn_type} on {env_name}')
    plt.show()
    
    frames_list = list(range(len(max_q_value_list)))
    plt.plot(frames_list, max_q_value_list)
    plt.axhline(0, c='orange', ls='--')
    plt.axhline(10, c='red', ls='--')
    plt.xlabel('Frames')
    plt.ylabel('Q value')
    plt.title(f'{agent.dqn_type} on {env_name}')
    plt.show()
lr = 1e-2
num_episodes = 200
hidden_dim = 128
gamma = 0.98
epsilon = 0.01
target_update = 50
buffer_size = 5000
minimal_size = 1000
batch_size = 64
device = d2l.try_gpu()
print(device)

env_name = "Pendulum-v1"
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = 11 # 把连续动作分解成11个离散动作

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device)
train_DQN(agent, env, num_episodes, replay_buffer, minimal_size, batch_size)

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon, target_update, device,"Double DQN")
train_DQN(agent, env, num_episodes, replay_buffer, minimal_size, batch_size)

jupyter中运行,结果如下:



可以看到从double DQN 确实缓解了高估的问题,q-value整体下移了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java连接数据库实现用户登录和注册功能

目录 需求内容如下 示例代码 数据库studb Java代码 效果图 需求内容如下 1&#xff0c;创建数据库studb 2&#xff0c;库中添加用户表userinfo,包含如下字段 用户id ,用户名&#xff0c;用户密码&#xff0c;用户权限 &#xff08;数据类型和约束自己定义&#xff09…

SpringBoot2—开发实用篇2

目录 数据层解决方案 SQL NoSQL SpringBoot整合Redis SpringBoot整合MongoDB SpringBoot整合ES 数据层解决方案 SQL 回忆一下之前做SSMP整合的时候数据层解决方案涉及到了哪些技术&#xff1f;MySQL数据库与MyBatisPlus框架&#xff0c;后面又学了Druid数据源的配置&am…

TrustGeo代码理解(一)main.py

代码链接:https://github.com/ICDM-UESTC/TrustGeo 一、导入各种模块和数据库 # -*- coding: utf-8 -*- import torch.nnfrom lib.utils import * import argparse, os import numpy as np import random from lib.model import * import copy from thop import profile imp…

文本聚类——文本相似度(聚类算法基本概念)

一、文本相似度 1. 度量指标&#xff1a; 两个文本对象之间的相似度两个文本集合之间的相似度文本对象与集合之间的相似度 2. 样本间的相似度 基于距离的度量&#xff1a; 欧氏距离 曼哈顿距离 切比雪夫距离 闵可夫斯基距离 马氏距离 杰卡德距离 基于夹角余弦的度量 公式…

【从零开始学习JVM | 第七篇】深入了解 堆回收

前言&#xff1a; Java堆作为内存管理中最核心的一部分&#xff0c;承担着对象实例的存储和管理任务。堆内存的高效使用对于保障程序的性能和稳定性至关重要。因此&#xff0c;深入理解Java堆回收的原理、机制和优化策略&#xff0c;对于Java开发人员具有重要的意义。 本文旨在…

竞赛保研 python图像检索系统设计与实现

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; python图像检索系统设计与实现 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;4分 该项目较为新颖&#xff0c…

Python基础04-数据容器

零、文章目录 Python基础04-数据容器 1、了解字符串 &#xff08;1&#xff09;字符串的定义 字符串是 Python 中最常用的数据类型。我们一般使用引号来创建字符串。创建字符串很简单&#xff0c;只要为变量分配一个值即可。<class ‘str’>即为字符串类型。一对引号…

计算机毕业设计 基于SpringBoot的二手物品交易管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

【一文带你掌握Java中方法定义、调用和重载的技巧】

方法的定义和调用 方法的定义 方法&#xff08;method&#xff09;是一段用于实现特定功能的代码块&#xff0c;类似于其他编程语言中的函数&#xff08;function&#xff09;。方法被用来定义类或类的实例的行为特征和功能实现。方法是类和对象行为特征的抽象表示。方法与面向…

『PyTorch』张量和函数之gather()函数

文章目录 PyTorch中的选择函数gather()函数 参考文献 PyTorch中的选择函数 gather()函数 import torch a torch.arange(1, 16).reshape(5, 3) """ result: a [[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12],[13, 14, 15]] """# 定义两个index…

圆通速递查询,圆通速递单号查询,一键复制查询好的物流信息

批量查询圆通速递单号的物流信息&#xff0c;并将查询好的物流信息一键复制出来。 所需工具&#xff1a; 一个【快递批量查询高手】软件 圆通速递单号若干 操作步骤&#xff1a; 步骤1&#xff1a;运行【快递批量查询高手】软件&#xff0c;第一次使用的朋友记得先注册&…

002.Java实现两数相加

题意 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示两数之和的新链表。 示例 输入&#xff1a;l1[2,4,3],l2[5,6,4] 输出…

ChatGPT4 Excel 高级复杂函数案例实践

案例需求: 需求中需要判断多个条件进行操作。 可以让ChatGPT来实现这样的操作。 Prompt:有一个表格B2单元格为入职日期,C2单元格为员工等级(A,B,C),D2单元格为满意度分数(1,2,3,4,5)请给入职一年以上,员工等级为A级并且满意度在3分以上的人发4000元奖金,给入…

普冉(PUYA)单片机开发笔记(10): I2C通信-配置从机

概述 I2C 常用在某些型号的传感器和 MCU 的连接&#xff0c;速率要求不高&#xff0c;距离很短&#xff0c;使用简便。 I2C的通信基础知识请参见《基础通信协议之 IIC详细讲解 - 知乎》。 PY32F003 可以复用出一个 I2C 接口&#xff08;PA3&#xff1a;SCL&#xff0c;PA2&a…

计算机组成原理-函数调用的汇编表示(call和ret指令 访问栈帧 切换栈帧 传递参数和返回值)

文章目录 call指令和ret指令高级语言的函数调用x86汇编语言的函数调用call ret指令小结其他问题 如何访问栈帧函数调用栈在内存中的位置标记栈帧范围&#xff1a;EBP ESP寄存器访问栈帧数据&#xff1a;push pop指令访问栈帧数据&#xff1a;mov指令小结 如何切换栈帧函数返回时…

深度学习记录--矩阵维数

如何识别矩阵的维数 如下图 矩阵的行列数容易在前向和后向传播过程中弄错&#xff0c;故写这篇文章来提醒易错点 顺便起到日后查表改错的作用 本文仅作本人查询参考(摘自吴恩达深度学习笔记)

Flink系列之:SQL提示

Flink系列之&#xff1a;SQL提示 一、动态表选项二、语法三、例子四、查询提示五、句法六、加入提示七、播送八、随机散列九、随机合并十、嵌套循环十一、LOOKUP十二、进一步说明十三、故障排除十四、连接提示中的冲突案例十五、什么是查询块 SQL 提示可以与 SQL 语句一起使用来…

MyBatis Plus 大数据量查询优化

大数据量操作的场景大致如下&#xff1a; 数据迁移 数据导出 批量处理数据 在实际工作中当指定查询数据过大时&#xff0c;我们一般使用分页查询的方式一页一页的将数据放到内存处理。但有些情况不需要分页的方式查询数据或分很大一页查询数据时&#xff0c;如果一下子将数…

【单元测试】Junit 4--junit4 内置Rule

1.0 Rules ​ Rules允许非常灵活地添加或重新定义一个测试类中每个测试方法的行为。测试人员可以重复使用或扩展下面提供的Rules之一&#xff0c;或编写自己的Rules。 1.1 TestName ​ TestName Rule使当前的测试名称在测试方法中可用。用于在测试执行过程中获取测试方法名称…

MATLAB 计算两片点云间的最小距离(2种方法) (39)

MATLAB 计算两片点云间的最小距离 (39) 一、算法介绍二、算法实现1.常规计算方法2.基于KD树的快速计算一、算法介绍 假设我们现在有两片点云 1 和 2 ,需要计算二者之间的最小距离,这里提供两种计算方法,分别是常规计算和基于KD树近邻搜索的快速计算方法,使用的测试数据如…