强化学习之Double DQN算法与DQN算法对比学习——以倒立摆(Inverted Pendulum)环境为例

news2024/11/23 8:54:07

0.简介

DQN算法敲开深度强化学习大门,但是其存在着一些问题,有进一步改进的空间。因此在DQN后出现大量改进算法。在此介绍DQN算法改进算法之一Double DQN,其在DQN算法基础上稍加修改实现在一定程度上改善DQN效果。

 普通DQN算法会导致对Q值的过高估计,即overestimation,传统DQN优化的TD误差目标为

TD_{error}=r+\gamma \underset{​{a}'}{max}(Q_{\omega ^{-} }{({s}',{a}'}))=r+\gamma Q_{\omega ^{-} }({s}',{arg\; } \underset{​{a}'}{max}Q_{\omega^{-} }{({s}',{a}'}))

而Double DQN的优化目标为:

TD_{error}=r+\gamma Q_{\omega ^{-} }({s}',{arg\; } \underset{​{a}'}{max}Q_{\omega }{({s}',{a}'}))

显然DQN 与Double DQN差别在于计算状态s'下的Q值如何选择动作。

DQN优化目标可以写成r+\gamma Q_{\omega ^{-} }({s}',{arg\; } \underset{​{a}'}{max}Q_{\omega^{-} }{({s}',{a}'})),动作的选取依靠目标网络Q_{\omega^{-}}

Double DQN优化目标可以写成TD_{error}=r+\gamma Q_{\omega ^{-} }({s}',{arg\; } \underset{​{a}'}{max}Q_{\omega }{({s}',{a}'})),动作选取依靠训练网络Q_{\omega}

1.必要库导入

import random
import gym
import numpy as np
import torch
import matplotlib.pyplot as plt
import collections
import rl_utils
from tqdm import tqdm

2.Q网络和经验回放池实现

class ReplayBuffer:
    '''经验回放池'''
    def __init__(self,capacity):
        self.buffer=collections.deque(maxlen=capacity)#队列,先进先出
    def add(self,state,action,reward,nextstate,done):
        self.buffer.append((state,action,reward,nextstate,done))
    def sample(self,batch_size):
        transitions=random.sample(self.buffer,batch_size)
        state,action,reward,nextstate,done=zip(*transitions)
        return np.array(state),action,reward,np.array(nextstate),done
    def size(self):
        return len(self.buffer)
class Qnet(torch.nn.Module):
    """ 只有一层隐藏层的Q网络 """
    def __init__(self,state_dim,hidden_dim,action_dim):
        super(Qnet,self).__init__()
        self.fc1=torch.nn.Linear(state_dim,hidden_dim)
        self.fc2=torch.nn.Linear(hidden_dim,action_dim)
    def forward(self,x):
        x=torch.nn.functional.relu(self.fc1(x))
        return self.fc2(x)

3.Double DQN算法实现

class DQN:
    """ DQN算法,包括Double DQN """
    def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,epsilon,target_update,device,dqntype='VanillaDQN'):
        self.action_dim=action_dim
        self.gamma=gamma
        self.epsilon=epsilon
        self.target_update=target_update
        self.device=device
        self.dqntype=dqntype
        self.count=0
        self.qnet=Qnet(state_dim,hidden_dim,self.action_dim).to(self.device)
        self.targetqnet=Qnet(state_dim,hidden_dim,self.action_dim).to(self.device)
        self.optimizer=torch.optim.Adam(self.qnet.parameters(),lr=learning_rate)
    def takeaction(self,state):
        if np.random.random()<self.epsilon:
            action=np.random.randint(self.action_dim)
        else:
            state=torch.tensor([state],dtype=torch.float).to(self.device)
            action=self.qnet(state).argmax().item()
        return action
    def maxqvalue(self,state):#计算最大Q值
        state=torch.tensor([state],dtype=torch.float).to(self.device)
        return self.qnet(state).max().item()
    def update(self,transition_dict):#更新Q值
        states=torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)
        actions=torch.tensor(transition_dict['actions']).view(-1,1).to(self.device)
        rewards=torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1,1).to(self.device)
        nextstates=torch.tensor(transition_dict['nextstates'],dtype=torch.float).to(self.device)
        dones=torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1,1).to(self.device)
        qvalues=self.qnet(states).gather(1,actions)#在 gather(1, actions) 中:第一个参数 1 表示沿着维度 1 (列维度)进行收集操作。actions 是一个张量,作为收集的索引。
        if self.dqntype=='DoubleDQN':
            maxaction=self.qnet(nextstates).max(1)[1].view(-1,1)#max(1)表示沿着维度1求最大值,返回两个值,第一个值为每行最大值,第二个值为每行最大值所在的索引。view(-1,1)将其转化为列向量
            maxnextqvalues=self.targetqnet(nextstates).gather(1,maxaction)
        else:#DQN
            maxnextqvalues=self.targetqnet(nextstates).max(1)[0].view(-1,1)
        qtargets=rewards+self.gamma*maxnextqvalues*(1-dones)
        dqnloss=torch.mean(torch.nn.functional.mse_loss(qvalues,qtargets))
        self.optimizer.zero_grad()
        dqnloss.backward()
        self.optimizer.step()#优化器根据计算得到的梯度进行参数更新
        if self.count%self.target_update==0:
            self.targetqnet.load_state_dict(self.qnet.state_dict())
        self.count+=1

4.设置相应的超参数

lr=1e-2
episodesnum=200
hidden_dim=128
gamma=0.98
epsilon=0.01
target_update=50
buffersize=5000
minimalsize=1000
batch_size=64
pbarnum=10
printreturnnum=10

5.实现倒立摆环境中连续动作转化为离散动作函数

device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
env=gym.make('Pendulum-v1')
env.reset(seed=10)
# env.render()
state_dim=env.observation_space.shape[0]#print(env.observation_space):Box([-1. -1. -8.], [1. 1. 8.], (3,), float32);print(env.observation_space.shape):(3,)
action_dim=11#将连续动作分成11个离散动作
def dis_to_con(actionid,env,action_dim):
    action_lowbound=env.action_space.low[0]#连续动作的最小值
    action_upbound=env.action_space.high[0]#连续动作的最大值
    return action_lowbound+(action_upbound-action_lowbound)/(action_dim-1)*actionid

6.定义一个DQN算法训练过程的函数,以便后续重复多次调用。训练过程中会记录每个状态下的最大Q值,训练后可结果可进行可视化,观测这些Q值存在的过高估计的情况,以此对比DQN与Double DQN差异。

def train_DQN(agent,env,num_episodes,pbarnum,printreturnnum,replay_buffer,minimal_size,batch_size):
    returnlist=[]
    maxqvaluelist=[]
    maxqvalue=0
    for i in range(pbarnum):
        with tqdm(total=int(num_episodes/pbarnum),desc='Iteration %d'%i) as pbar:
            for episode in range(int(num_episodes/pbarnum)):
                episodereturn=0
                state,info=env.reset(seed=10)
                done=False
                while not done:
                    action=agent.takeaction(state)
                    # env.render()
                    maxqvalue=agent.maxqvalue(state)*0.005+maxqvalue*0.995#平滑处理
                    maxqvaluelist.append(maxqvalue)
                    action_continuous=dis_to_con(action,env,agent.action_dim)
                    nextstate,reward,done,truncated,_=env.step([action_continuous])
                    done=done or truncated
                    replay_buffer.add(state,action,reward,nextstate,done)
                    state=nextstate
                    episodereturn+=reward
                    if replay_buffer.size()>minimal_size:
                        bs,ba,br,bns,bd=replay_buffer.sample(batch_size)
                        transition_dict={'states':bs,'actions':ba,'rewards':br,'nextstates':bns,'dones':bd}
                        agent.update(transition_dict)
                returnlist.append(episodereturn)
                if(episode+1)%printreturnnum==0:
                    pbar.set_postfix({'episode':'%d'%(num_episodes/pbarnum*i+episode+1),'return':'%.3f'%np.mean(returnlist[-printreturnnum:])})
                pbar.update(1)
    return returnlist,maxqvaluelist

7.训练DQN并打印打印其学习过程中最大Q值的情况

random.seed(10)
np.random.seed(10)
torch.manual_seed(10)
replaybuffer = ReplayBuffer(buffersize)
agent = DQN(state_dim=state_dim, hidden_dim=hidden_dim, action_dim=action_dim, learning_rate=lr, gamma=gamma,
             epsilon=epsilon, target_update=target_update, device=device, dqntype='VanillaDQN')
returnlist, maxqvaluelist= train_DQN(agent=agent, env=env, num_episodes=episodesnum, pbarnum=pbarnum,
                                      printreturnnum=printreturnnum, replay_buffer=replaybuffer,
                                      minimal_size=minimalsize, batch_size=batch_size)
episodelist = list(range(len(returnlist)))
mvreturn = moving_average(returnlist, 5)
plt.plot(episodelist, mvreturn)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env.spec.name))
plt.show()

frameslist = list(range(len(maxqvaluelist)))
plt.plot(frameslist, maxqvaluelist)
plt.axhline(y=0, c='orange', ls='--')
plt.axhline(y=10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('DQN on {}'.format(env.spec.name))
plt.show()
env.close()

下面则是一个平滑移动计算回报的函数实现

def moving_average(a, window_size):
    cumulative_sum = np.cumsum(np.insert(a, 0, 0)) 
    middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
    r = np.arange(1, window_size-1, 2)
    begin = np.cumsum(a[:window_size-1])[::2] / r
    end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
    return np.concatenate((begin, middle, end))

8.DQN训练结果可视化以及结论

Iteration 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.61it/s, episode=20, return=-1195.945]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.91it/s, episode=40, return=-870.678] 
Iteration 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.90it/s, episode=60, return=-442.819] 
Iteration 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.17it/s, episode=80, return=-371.271] 
Iteration 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.99it/s, episode=100, return=-419.501] 
Iteration 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.06it/s, episode=120, return=-388.722] 
Iteration 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00,  2.47it/s, episode=140, return=-947.871] 
Iteration 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00,  2.40it/s, episode=160, return=-287.627] 
Iteration 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.03it/s, episode=180, return=-747.758] 
Iteration 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.86it/s, episode=200, return=-577.417]

 

我们发现DQN算法在倒立摆环境中能取得不错的回报,最后的期望回报在-200左右,但是不少Q值超过了0,有些还超过10,这一现象表明了DQN算法会对Q值过高估计。 

9.训练Double DQN 并打印学习过程中最大Q值情况

random.seed(10)
np.random.seed(10)
torch.manual_seed(10)
replaybuffer = ReplayBuffer(buffersize)
agent = DQN(state_dim=state_dim, hidden_dim=hidden_dim, action_dim=action_dim, learning_rate=lr, gamma=gamma,
             epsilon=epsilon, target_update=target_update, device=device, dqntype='DoubleDQN')
returnlist, maxqvaluelist= train_DQN(agent=agent, env=env, num_episodes=episodesnum, pbarnum=pbarnum,
                                      printreturnnum=printreturnnum, replay_buffer=replaybuffer,
                                      minimal_size=minimalsize, batch_size=batch_size)
episodelist = list(range(len(returnlist)))
mvreturn = moving_average(returnlist, 5)
plt.plot(episodelist, mvreturn)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(' Double DQN on {}'.format(env.spec.name))
plt.show()

frameslist = list(range(len(maxqvaluelist)))
plt.plot(frameslist, maxqvaluelist)
plt.axhline(y=0, c='orange', ls='--')
plt.axhline(y=10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title(' Double DQN on {}'.format(env.spec.name))
plt.show()
env.close()

10.Double DQN训练结果可视化以及结论

Iteration 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.64it/s, episode=20, return=-1113.004]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.90it/s, episode=40, return=-731.259] 
Iteration 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00,  1.89it/s, episode=60, return=-487.520] 
Iteration 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.22it/s, episode=80, return=-346.293] 
Iteration 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.11it/s, episode=100, return=-353.339] 
Iteration 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.20it/s, episode=120, return=-313.014] 
Iteration 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.20it/s, episode=140, return=-433.954] 
Iteration 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.18it/s, episode=160, return=-356.992] 
Iteration 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.08it/s, episode=180, return=-320.773] 
Iteration 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.18it/s, episode=200, return=-386.794]

 

发现与普通DQN相比,Double DQN比较少出现Q值大于0的情况,说明Q值过高估计的问题得到了很大缓解。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1987127.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leetcode75-5 反转字符串的元音字母

本质上来说就是反转字符串 一部分需要反转 一部分不动 思路: 1.用String字符串倒序拼接 就是过滤掉不是元音字符 然后把所有的字符&#xff08;非元音的直接复制过来 元音字母直接从反转的字符串里边复制即可&#xff09; 2.看了题解发现自己写的啰嗦了 就是一个双指针问题用…

螺旋矩阵

螺旋矩阵 思路&#xff1a; 这题是一个模拟的题目。 可以观察出一些性质&#xff1a;每次需要换方向的时候都是到达了边界&#xff08;长度和宽度的边界&#xff09;。 不知道怎么转化为代码&#xff01; 哭了 看看题解吧&#xff1a;真不会 看到一个太妙的方法了&#x…

《变形金刚》战斗力排名分析

Top1 天火擎天柱 作为博派的领袖&#xff0c;擎天柱本身实力不凡。然而&#xff0c;胜败乃兵家常事。在《变形金刚2》中&#xff0c;他虽然成功击败了破坏者、碾碎器和红蜘蛛&#xff0c;却不幸被威震天一炮穿心&#xff0c;阵亡。 不过&#xff0c;擎天柱是《变形金刚》系列…

Zero123 论文学习

论文链接&#xff1a;https://arxiv.org/abs/2303.11328 代码链接&#xff1a;https://github.com/cvlab-columbia/zero123 解决了什么问题&#xff1f; 人类通常能够仅凭一个相机视角来想象物体的三维形状和外观。这种能力对于日常任务非常重要&#xff0c;例如物体操纵和在…

快速掌握Vue:基础命令详解

目录 1. Vue概述 2. 快速入门 3. Vue指令 3.1 v-bind 3.2 v-model 3.3 v-on 3.4 v-if 3.5 v-show 3.6 v-for 3.7 案例 4. 生命周期 1. Vue概述 Vue.js&#xff08;读音 /vjuː/, 类似于 「view」&#xff09; 是一套构建用户界面的 「渐进式框架」。与其他重量级框…

【EI会议征稿通知】第六届光电科学与材料国际学术会议 (ICOSM 2024)

会议主要围绕“光电技术应用”“光电科学材料”“光电信号处理”“低温等离子体技术与应用” “激光技术与应用”“材料科学”等研究领域展开讨论。旨在为光电学、电子工程学等专家学者、工程技术人员、技术研发人员提供一个交流平台。拓展国内外光电科学与材料技术方面的研究范…

科普文:微服务之全文检索ElasticSearch忝删改查详细操作说明

一、Restful简介 RESTFul&#xff1a;Representational State Transfer&#xff0c;中文意思&#xff1a;表现层状态转化。变现层指的是资源的表现层&#xff0c;这里的资源是指网络上的信息&#xff0c;比如一张图片&#xff0c;一段文本&#xff0c;一步电影&#xff0c;那么…

数据结构(学习)2024.8.6

今天开始学习数据结构的相关知识&#xff0c;大概分为了解数据结构、算法&#xff1b;学习线性表&#xff1a;顺序表、链表、栈、队列的相关知识和树&#xff1a;二叉树、遍历、创建&#xff0c;查询方法、排序方式等。 目录 一、数据结构 数据 逻辑结构 1.线性结构 2.树…

JavaEE: wait(等待) / notify (通知)

文章目录 wait(等待) / notify (通知)总结 wait(等待) / notify (通知) 线程在操作系统上的调度是随机的~ 那么我们想要控制线程之间执行某个逻辑的先后顺序,那该咋办呢? 可以让后执行的逻辑,使用wait, 先执行的线程,在完成某些逻辑之后,通过notify来唤醒对应的wait. 另外,通…

谈谈如何优雅地封装 el-table

效果 像这样的表格我们可以这样划分一下区域&#xff1a; 1区域的渲染是通过取反插槽的条件 2区域的渲染是写在 slot 插槽的内部的&#xff0c;直接显示行数据3区域的渲染是通过具名插槽 bind 渲染 直接上代码&#xff1a; 子组件&#xff1a; <template><el-tabl…

为什么要用分布式锁

单应用中,如果要确保多线程修改同一个资源的安全性 加synchronized就可以了 但是性能不高 而mybatis-plus的乐观锁就可以很好的解决这类问题 但是这样的锁机制,只在单应用中有效 试想,在分布式下,有没有可能出现多个应用中的线程同时去修改同一个数据资源的并发问题 例如A …

Golang | Leetcode Golang题解之第328题奇偶链表

题目&#xff1a; 题解&#xff1a; func oddEvenList(head *ListNode) *ListNode {if head nil {return head}evenHead : head.Nextodd : headeven : evenHeadfor even ! nil && even.Next ! nil {odd.Next even.Nextodd odd.Nexteven.Next odd.Nexteven even.N…

65 生成器函数设计要点

包含 yield 语句的函数可以用来创建生成器对象&#xff0c;这样的函数也称为生成器函数。yield 语句与 return 语句的作用相似&#xff0c;都是用来从函数中返回值。return 语句一旦执行会立刻结束函数的运行&#xff0c;而每次执行到 yield 语句返回一个值之后会暂停或挂起后面…

无人机无线电监测设备技术分析

随着无人机技术的飞速发展&#xff0c;其在民用、军事、科研及娱乐等领域的广泛应用&#xff0c;对无线电频谱资源的有效管理和监测提出了更高要求。无人机无线电监测设备作为保障空域安全、维护无线电秩序的重要工具&#xff0c;集成了高精度定位、频谱扫描、信号分析、数据处…

stm32应用、项目

主要记录实际使用中的一些注意点。 1.LCD1602 电路图&#xff1a; 看手册&#xff1a;电源和背光可以使用5v或者3.3v&#xff0c;数据和控制引脚直接和单片机引脚连接即可。 单片机型号&#xff1a;stm32c031c6t6 可以直接使用推完输出连接D0--D7,RS,EN,RW引脚&#xff0c;3…

大数据面试SQL(二):每天最高峰同时直播人数

文章目录 每天最高峰同时直播人数 一、题目 二、分析 三、SQL实战 四、样例数据参考 每天最高峰同时直播人数 一、题目 有如下数据记录直播平台主播上播及下播时间&#xff0c;根据该数据计算出平台当天最高峰同时直播人数。 这里用主播名称做统计&#xff0c;前提是主…

Flask+LayUI开发手记(一):LayUI表格的前端数据分页展现

用数据表格table展示系统数据&#xff0c;是LayUI的基本功能&#xff0c;编码十分简单&#xff0c;就是通过table.render()渲染&#xff0c;把属性配置好就OK了&#xff0c;十分方便&#xff0c;功能也十分强大。 不过&#xff0c;在实现时&#xff0c;把table的有个功能却理解…

WPF MVVM实现TreeView层级显示

最近在写一个小工具的时候&#xff0c;遇到TreeView的层级显示&#xff0c;刚好我又用了MVVM模式&#xff0c;所以这里做个总结。 以前我是直接绑定XML数据到TreeView的&#xff0c;使用的XmlDataProvider&#xff0c;这次的数据是直接来自数据库的。 用到的都是Hierarchical…

Element学习(入门)(1)

1、Element官网&#xff1a;https://element.eleme.cn/#/zh-CN 2、来源与用处 3、Element的快速入门 &#xff08;1&#xff09; &#xff08;2&#xff09;在入口文件&#xff08;main.js&#xff09;中引入 &#xff08;3&#xff09; 4、快捷键ctrlc&#xff0c;在当前的项目…

【SpringBoot】自定义注解 I18n <约定式>国际化 (源码分享直接Copy)

0. 已做全新升级版 链接&#xff1a;【SpringBoot】自定义注解终极升级版&#xff1c;i18n国际化&#xff1e;方案源码Copy 链接&#xff1a;【SpringBoot】自定义注解终极升级版&#xff1c;i18n国际化&#xff1e;方案源码Copy 链接&#xff1a;【SpringBoot】自定义注解终…