0.简介
DQN算法敲开深度强化学习大门,但是其存在着一些问题,有进一步改进的空间。因此在DQN后出现大量改进算法。在此介绍DQN算法改进算法之一Double DQN,其在DQN算法基础上稍加修改实现在一定程度上改善DQN效果。
普通DQN算法会导致对Q值的过高估计,即overestimation,传统DQN优化的TD误差目标为
而Double DQN的优化目标为:
显然DQN 与Double DQN差别在于计算状态s'下的Q值如何选择动作。
DQN优化目标可以写成,动作的选取依靠目标网络。
Double DQN优化目标可以写成,动作选取依靠训练网络。
1.必要库导入
import random
import gym
import numpy as np
import torch
import matplotlib.pyplot as plt
import collections
import rl_utils
from tqdm import tqdm
2.Q网络和经验回放池实现
class ReplayBuffer:
'''经验回放池'''
def __init__(self,capacity):
self.buffer=collections.deque(maxlen=capacity)#队列,先进先出
def add(self,state,action,reward,nextstate,done):
self.buffer.append((state,action,reward,nextstate,done))
def sample(self,batch_size):
transitions=random.sample(self.buffer,batch_size)
state,action,reward,nextstate,done=zip(*transitions)
return np.array(state),action,reward,np.array(nextstate),done
def size(self):
return len(self.buffer)
class Qnet(torch.nn.Module):
""" 只有一层隐藏层的Q网络 """
def __init__(self,state_dim,hidden_dim,action_dim):
super(Qnet,self).__init__()
self.fc1=torch.nn.Linear(state_dim,hidden_dim)
self.fc2=torch.nn.Linear(hidden_dim,action_dim)
def forward(self,x):
x=torch.nn.functional.relu(self.fc1(x))
return self.fc2(x)
3.Double DQN算法实现
class DQN:
""" DQN算法,包括Double DQN """
def __init__(self,state_dim,hidden_dim,action_dim,learning_rate,gamma,epsilon,target_update,device,dqntype='VanillaDQN'):
self.action_dim=action_dim
self.gamma=gamma
self.epsilon=epsilon
self.target_update=target_update
self.device=device
self.dqntype=dqntype
self.count=0
self.qnet=Qnet(state_dim,hidden_dim,self.action_dim).to(self.device)
self.targetqnet=Qnet(state_dim,hidden_dim,self.action_dim).to(self.device)
self.optimizer=torch.optim.Adam(self.qnet.parameters(),lr=learning_rate)
def takeaction(self,state):
if np.random.random()<self.epsilon:
action=np.random.randint(self.action_dim)
else:
state=torch.tensor([state],dtype=torch.float).to(self.device)
action=self.qnet(state).argmax().item()
return action
def maxqvalue(self,state):#计算最大Q值
state=torch.tensor([state],dtype=torch.float).to(self.device)
return self.qnet(state).max().item()
def update(self,transition_dict):#更新Q值
states=torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)
actions=torch.tensor(transition_dict['actions']).view(-1,1).to(self.device)
rewards=torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1,1).to(self.device)
nextstates=torch.tensor(transition_dict['nextstates'],dtype=torch.float).to(self.device)
dones=torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1,1).to(self.device)
qvalues=self.qnet(states).gather(1,actions)#在 gather(1, actions) 中:第一个参数 1 表示沿着维度 1 (列维度)进行收集操作。actions 是一个张量,作为收集的索引。
if self.dqntype=='DoubleDQN':
maxaction=self.qnet(nextstates).max(1)[1].view(-1,1)#max(1)表示沿着维度1求最大值,返回两个值,第一个值为每行最大值,第二个值为每行最大值所在的索引。view(-1,1)将其转化为列向量
maxnextqvalues=self.targetqnet(nextstates).gather(1,maxaction)
else:#DQN
maxnextqvalues=self.targetqnet(nextstates).max(1)[0].view(-1,1)
qtargets=rewards+self.gamma*maxnextqvalues*(1-dones)
dqnloss=torch.mean(torch.nn.functional.mse_loss(qvalues,qtargets))
self.optimizer.zero_grad()
dqnloss.backward()
self.optimizer.step()#优化器根据计算得到的梯度进行参数更新
if self.count%self.target_update==0:
self.targetqnet.load_state_dict(self.qnet.state_dict())
self.count+=1
4.设置相应的超参数
lr=1e-2
episodesnum=200
hidden_dim=128
gamma=0.98
epsilon=0.01
target_update=50
buffersize=5000
minimalsize=1000
batch_size=64
pbarnum=10
printreturnnum=10
5.实现倒立摆环境中连续动作转化为离散动作函数
device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
env=gym.make('Pendulum-v1')
env.reset(seed=10)
# env.render()
state_dim=env.observation_space.shape[0]#print(env.observation_space):Box([-1. -1. -8.], [1. 1. 8.], (3,), float32);print(env.observation_space.shape):(3,)
action_dim=11#将连续动作分成11个离散动作
def dis_to_con(actionid,env,action_dim):
action_lowbound=env.action_space.low[0]#连续动作的最小值
action_upbound=env.action_space.high[0]#连续动作的最大值
return action_lowbound+(action_upbound-action_lowbound)/(action_dim-1)*actionid
6.定义一个DQN算法训练过程的函数,以便后续重复多次调用。训练过程中会记录每个状态下的最大Q值,训练后可结果可进行可视化,观测这些Q值存在的过高估计的情况,以此对比DQN与Double DQN差异。
def train_DQN(agent,env,num_episodes,pbarnum,printreturnnum,replay_buffer,minimal_size,batch_size):
returnlist=[]
maxqvaluelist=[]
maxqvalue=0
for i in range(pbarnum):
with tqdm(total=int(num_episodes/pbarnum),desc='Iteration %d'%i) as pbar:
for episode in range(int(num_episodes/pbarnum)):
episodereturn=0
state,info=env.reset(seed=10)
done=False
while not done:
action=agent.takeaction(state)
# env.render()
maxqvalue=agent.maxqvalue(state)*0.005+maxqvalue*0.995#平滑处理
maxqvaluelist.append(maxqvalue)
action_continuous=dis_to_con(action,env,agent.action_dim)
nextstate,reward,done,truncated,_=env.step([action_continuous])
done=done or truncated
replay_buffer.add(state,action,reward,nextstate,done)
state=nextstate
episodereturn+=reward
if replay_buffer.size()>minimal_size:
bs,ba,br,bns,bd=replay_buffer.sample(batch_size)
transition_dict={'states':bs,'actions':ba,'rewards':br,'nextstates':bns,'dones':bd}
agent.update(transition_dict)
returnlist.append(episodereturn)
if(episode+1)%printreturnnum==0:
pbar.set_postfix({'episode':'%d'%(num_episodes/pbarnum*i+episode+1),'return':'%.3f'%np.mean(returnlist[-printreturnnum:])})
pbar.update(1)
return returnlist,maxqvaluelist
7.训练DQN并打印打印其学习过程中最大Q值的情况
random.seed(10)
np.random.seed(10)
torch.manual_seed(10)
replaybuffer = ReplayBuffer(buffersize)
agent = DQN(state_dim=state_dim, hidden_dim=hidden_dim, action_dim=action_dim, learning_rate=lr, gamma=gamma,
epsilon=epsilon, target_update=target_update, device=device, dqntype='VanillaDQN')
returnlist, maxqvaluelist= train_DQN(agent=agent, env=env, num_episodes=episodesnum, pbarnum=pbarnum,
printreturnnum=printreturnnum, replay_buffer=replaybuffer,
minimal_size=minimalsize, batch_size=batch_size)
episodelist = list(range(len(returnlist)))
mvreturn = moving_average(returnlist, 5)
plt.plot(episodelist, mvreturn)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('DQN on {}'.format(env.spec.name))
plt.show()
frameslist = list(range(len(maxqvaluelist)))
plt.plot(frameslist, maxqvaluelist)
plt.axhline(y=0, c='orange', ls='--')
plt.axhline(y=10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title('DQN on {}'.format(env.spec.name))
plt.show()
env.close()
下面则是一个平滑移动计算回报的函数实现
def moving_average(a, window_size):
cumulative_sum = np.cumsum(np.insert(a, 0, 0))
middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
r = np.arange(1, window_size-1, 2)
begin = np.cumsum(a[:window_size-1])[::2] / r
end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
return np.concatenate((begin, middle, end))
8.DQN训练结果可视化以及结论
Iteration 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00, 2.61it/s, episode=20, return=-1195.945]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.91it/s, episode=40, return=-870.678]
Iteration 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.90it/s, episode=60, return=-442.819]
Iteration 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.17it/s, episode=80, return=-371.271]
Iteration 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.99it/s, episode=100, return=-419.501]
Iteration 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.06it/s, episode=120, return=-388.722]
Iteration 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00, 2.47it/s, episode=140, return=-947.871]
Iteration 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00, 2.40it/s, episode=160, return=-287.627]
Iteration 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.03it/s, episode=180, return=-747.758]
Iteration 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.86it/s, episode=200, return=-577.417]
我们发现DQN算法在倒立摆环境中能取得不错的回报,最后的期望回报在-200左右,但是不少Q值超过了0,有些还超过10,这一现象表明了DQN算法会对Q值过高估计。
9.训练Double DQN 并打印学习过程中最大Q值情况
random.seed(10)
np.random.seed(10)
torch.manual_seed(10)
replaybuffer = ReplayBuffer(buffersize)
agent = DQN(state_dim=state_dim, hidden_dim=hidden_dim, action_dim=action_dim, learning_rate=lr, gamma=gamma,
epsilon=epsilon, target_update=target_update, device=device, dqntype='DoubleDQN')
returnlist, maxqvaluelist= train_DQN(agent=agent, env=env, num_episodes=episodesnum, pbarnum=pbarnum,
printreturnnum=printreturnnum, replay_buffer=replaybuffer,
minimal_size=minimalsize, batch_size=batch_size)
episodelist = list(range(len(returnlist)))
mvreturn = moving_average(returnlist, 5)
plt.plot(episodelist, mvreturn)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(' Double DQN on {}'.format(env.spec.name))
plt.show()
frameslist = list(range(len(maxqvaluelist)))
plt.plot(frameslist, maxqvaluelist)
plt.axhline(y=0, c='orange', ls='--')
plt.axhline(y=10, c='red', ls='--')
plt.xlabel('Frames')
plt.ylabel('Q value')
plt.title(' Double DQN on {}'.format(env.spec.name))
plt.show()
env.close()
10.Double DQN训练结果可视化以及结论
Iteration 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00, 2.64it/s, episode=20, return=-1113.004]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.90it/s, episode=40, return=-731.259]
Iteration 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.89it/s, episode=60, return=-487.520]
Iteration 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.22it/s, episode=80, return=-346.293]
Iteration 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.11it/s, episode=100, return=-353.339]
Iteration 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.20it/s, episode=120, return=-313.014]
Iteration 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.20it/s, episode=140, return=-433.954]
Iteration 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.18it/s, episode=160, return=-356.992]
Iteration 8: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.08it/s, episode=180, return=-320.773]
Iteration 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00, 2.18it/s, episode=200, return=-386.794]
发现与普通DQN相比,Double DQN比较少出现Q值大于0的情况,说明Q值过高估计的问题得到了很大缓解。