Noisy DQN 跑 CartPole-v1

news2025/1/12 12:26:25

gym 0.26.1
CartPole-v1
NoisyNet DQN

NoisyNet 就是把原来Linear里的w/b 换成 mu + sigma * epsilon, 这是一种非常简单的方法,但是可以显著提升DQN的表现。
和之前最原始的DQN相比就是改了两个地方,一个是Linear改成了NoisyLinear,另外一个是在agenttake_action的时候策略 由ε-greedy改成了直接取argmax。详细见下面的代码。

本文的实现参考王树森的深度强化学习。

引用书上的一段话, 噪声DQN本身就带有随机性,可以鼓励探索,起到与ε-greedy策略相同的作用,直接用a_t = argmax Q(s,a,epsilon; mu,sigma), 作为行为策略,效果比ε-greedy更好。

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
import math

class ReplayBuffer:
    """经验回放池"""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出
    
    def add(self, state, action, reward, next_state, done): # 将数据加入buffer
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size): # 从buffer中采样数据,数量为batch_size
        transition = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transition)
        return np.array(state), action, reward, np.array(next_state), done
    
    def size(self): # 目前buffer中数据的数量
        return len(self.buffer)

class NoisyLinear(nn.Linear):
    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        super().__init__(in_features, out_features, bias)
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))
            self.register_buffer("epsilon_bias", torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        std = math.sqrt(3 / self.in_features)
        self.weight.data.uniform_(-std, std)
        self.bias.data.uniform_(-std, std)
        
    def forward(self, x, is_training=True):
        self.epsilon_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.epsilon_bias.normal_()
            bias = bias + self.sigma_bias * self.epsilon_bias.data
        if is_training:
            return F.linear(x, self.weight + self.sigma_weight * self.epsilon_weight.data, bias)
        else:
            return F.linear(x, self.weight, bias)

class Q(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = NoisyLinear(state_dim, hidden_dim)
        self.fc2 = NoisyLinear(hidden_dim, action_dim)
    def forward(self, x, is_training=True):
        x = F.relu(self.fc1(x, is_training)) # 隐藏层之后使用ReLU激活函数
        return self.fc2(x, is_training)

class DQN:
    """DQN算法"""
    def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, target_update, device):
        self.action_dim = action_dim
        self.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q网络
        self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # 目标网络
        self.target_q.load_state_dict(self.q.state_dict())  # 加载参数
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma
        self.target_update = target_update # 目标网络更新频率
        self.count = 0 # 计数器,记录更新次数
        self.device = device
    
    def take_action(self, state): # 这个地方就不用epsilon-贪婪策略
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        action = self.q(state).argmax().item()
        return action
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)
        
        q_values = self.q(states).gather(1, actions) # Q值
        # 下个状态的最大Q值
        max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)
        q_targets = rewards + self.gamma * max_next_q_values * (1- dones) # TD误差
        loss = F.mse_loss(q_values, q_targets) # 均方误差
        self.optimizer.zero_grad() # 梯度清零,因为默认会梯度累加
        loss.mean().backward() # 反向传播
        self.optimizer.step() # 更新梯度
        
        if self.count % self.target_update == 0:
            self.target_q.load_state_dict(self.q.state_dict())
        self.count += 1
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = d2l.try_gpu()
print(device)

env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, target_update, device)
return_list = []

for i in range(10):
    with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            state = env.reset()[0]
            done, truncated= False, False
            while not done and not truncated :
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # 当buffer数据的数量超过一定值后,才进行Q网络训练
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode+1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
            
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

这次是在pycharm上运行jupyter file,结果如下:




效果对比之前的DQN 详细参考这篇 表现是显著提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1351375.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第二十七章 正则表达式

第二十七章 正则表达式 1.正则快速入门2.正则需求问题3.正则底层实现14.正则底层实现25.正则底层实现36.正则转义符7.正则字符匹配8.字符匹配案例19.字符匹配案例211.选择匹配符(|)12.正则限定符{n}{n,m}(1个或者多个)*(0个或者多…

创建x11vnc系统进程

为方便使用vnc,所以寻找到一个比较好用的vnc服务端那就是x11vnc,索性就创建了一个系统进程 一、环境 系统:银河麒麟v4-sp2-server 软件:x11vnc【linux下】、VNCviewer【win下】 二、安装x11vnc 1、挂载光盘源并修改apt源 mou…

生态系统服务构建生态安全格局中的实践技术应用

生态安全是指生态系统的健康和完整情况。生态安全的内涵可以归纳为:一,保持生态系统活力和内外部组分、结构的稳定与持续性;二,维持生态系统生态功能的完整性;三,面临外来不利因素时,生态系统具…

Linux用shell脚本执行乘法口诀表的两种方式

#!/bin/bash # *********************************************************# # # # * Author : 藻头男 # # * QQ邮箱 : 2322944912qq.com # …

【SpringBoot3】1.SpringBoot入门的第一个完整小项目(新手保姆版+教会打包)

目录 1 SpringBoot简单介绍1.1 SpringBoot是什么1.2 主要优点1.3 术语1.3.1 starter(场景启动器) 1.4 官方文档 2 环境说明3 实现代码3.1 新建工程与模块3.2 加入依赖3.3 主程序文件3.4 业务代码3.5 运行测试3.6 部署打包3.7 命令行运行 1 SpringBoot简单…

[足式机器人]Part2 Dr. CAN学习笔记-自动控制原理Ch1-8Lag Compensator滞后补偿器

本文仅供学习使用 本文参考: B站:DR_CAN Dr. CAN学习笔记-自动控制原理Ch1-8Lag Compensator滞后补偿器 从稳态误差入手(steady state Error) 误差 Error : E ( s ) R ( s ) − X ( s ) R ( s ) − E ( s ) ⋅ K G …

再见2023,你好2024!

大家好,我是老三,本来今天晚上打算出去转一转,陆家嘴打车实在太艰难了,一公里多的路,司机走了四十分钟,还没到,再加上身体不适,咳嗽地比较厉害,所以还是宅在酒店里&#…

.NET Core SkiaSharp 替代 System.Drawing.Common 的一些用法

在.NET 6中,微软官方建议把 System.Drawing.Common 迁移到 SkiaSharp 库。因为System.Drawing.Common 被设计为 Window 技术的精简包装器,因此其跨平台实现欠佳。 SkiaSharp是一个基于谷歌的Skia图形库(Skia.org)的用于.NET平台的…

机器学习与深度学习——使用paddle实现随机梯度下降算法SGD对波士顿房价数据进行线性回归和预测

文章目录 机器学习与深度学习——使用paddle实现随机梯度下降算法SGD对波士顿房价数据进行线性回归和预测一、任务二、流程三、完整代码四、代码解析五、效果截图 机器学习与深度学习——使用paddle实现随机梯度下降算法SGD对波士顿房价数据进行线性回归和预测 随机梯度下降&a…

深度学习 Day23——J3DenseNet算法实战与解析

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 文章目录 前言1 我的环境2 pytorch实现DenseNet算法2.1 前期准备2.1.1 引入库2.1.2 设…

GitHub Copilot 最佳免费平替:阿里通义灵码

之前分享了不少关于 GitHub Copilot 的文章,不少粉丝都评论让我试试阿里的通义灵码,这让我对通义灵码有了不少的兴趣。 今天,阿七就带大家了解一下阿里的通义灵码,我们按照之前 GitHub Copilot 的顺序分享通义灵码在相同场景下的…

RabbitMQ基础知识

一.什么是RabbitMQ RabbitMQ是一个开源的、高性能的消息队列系统,用于在应用程序之间实现异步通信。它实现了AMQP(Advanced Message Queuing Protocol)协议,可以在分布式系统中传递和存储消息。 消息队列是一种将消息发送者和接收…

六、Redis 分布式系统 —— 超详细操作演示!

六、Redis 分布式系统 —— 超详细操作演示! 六、Redis 分布式系统6.1 数据分区算法6.1.1 顺序分区6.1.2 哈希分区 6.2 系统搭建与运行6.2.1 系统搭建6.2.2 系统启动与关闭 6.3 集群操作6.3.1 连接集群6.3.2 写入数据6.3.3 集群查询6.3.4 故障转移6.3.5 集群扩容6.3…

Android 12.0 禁用插入耳机时弹出的保护听力对话框

1.前言 在12.0的系统rom定制化开发中,在某些产品中会对耳机音量调节过高限制,在调高到最大音量的70%的时候,会弹出音量过高弹出警告,所以产品 开发的需要要求去掉这个音量弹窗警告功能,接下来具体实现相关功能 2.禁用插入耳机时弹出的保护听力对话框的核心类 frameworks\b…

JAVA对象、List、Map和JSON之间的相互转换

JAVA对象、List、Map和JSON之间的相互转换 1.Java中对象和json互转2.Java中list和json互转3.Java中map和json互转 1.Java中对象和json互转 Object obj new Object(); String objJson JSONObject.toJSONString(obj);//java对象转json Object newObj JSONObject.parseObject(…

Pytorch详细安装过程

1、安装anaconda 官网(https://www.anaconda.com/products/distribution#Downloads)下载,使用管理员身份运行(不使用似乎也没事) 这里选择Just me(至于为啥,咱也不是很清楚) 更改路…

玩转贝启科技BQ3588C开源鸿蒙系统开发板 —— 开发板详情与规格

本文主要参考: BQ3588C_开发板详情-开源鸿蒙技术交流-Bearkey-开源社区 BQ3588C_开发板规格-开源鸿蒙技术交流-Bearkey-开源社区 厦门贝启科技有限公司-Bearkey-官网 1. 开发板详情 RK3588 核心板是一款由贝启科技自主研发的基于瑞芯微 RK3588 AI 芯片的智能核心…

强大的隐藏应用 Hides 5中文 for mac

Hides 5是一款Mac上的应用程序,旨在帮助用户隐藏其他应用程序并专注于当前任务,从而提高工作效率。其主要功能包括对焦模式、隐藏所有打开的应用程序、隐藏除当前活动应用之外的所有打开的应用程序、支持全局热键、可定制性、支持多种显示方式等。 Hide…

vue3项目创建

安装node.js vue --version (4.5.0以上) npm install -g vue/cli vue create 项目名称 npm run dev 启动 npm run build 打包 ———————— vite 创建工程 npm create vuelatest npm i npm run dev 启动 npm run build 打包 项目结构…

WeNet语音识别+Qwen-72B-Chat Bot+Sambert-Hifigan语音合成

WeNet语音识别Qwen-72B-Chat Bot👾Sambert-Hifigan语音合成 简介 利用 WeNet 进行语音识别,使用户能够通过语音输入与系统进行交互。接着,Qwen-72B-Chat Bot作为聊天机器人接收用户的语音输入或文本输入,提供响应并与用户进行对话…