伪代码
如算法描述,dqn即深度q网络和记忆池
初始化记忆池和可以容纳的数量N
动作价值函数Q使用随机权重进行初始化。
目标动作价值函数Q′也使用相同的权重进行初始化,即Q′=Q。
循环训练M局
初始化和预处理观察到的状态
每局循环训练T步
采用e的概率随机选取动作
其他情况选择q值最大的动作
执行动作并且得到下一步的状态和奖励
更新下一步的状态,并进行预处理
将其s,r,a,s_存入记忆池
从记忆池随机选取minbatch的数据
对于每个转换,计算目标值yj。如果第j+1步是终止状态,则yj=rj;否则,使用贝尔曼方程计算yj=rj+γmaxaQ(sj+1,a;θ′),其中γ是折扣因子。
计算损失函数loss=loss=(yj−Q(sj,aj;θ))2,并通过反向传播更新网络参数θ。
每进行c步更新策略网络
游戏环境介绍
环境描述:CartPole-v1是一个物理模拟环境,其中有一个水平杆(pole)固定在一个移动的推车(cart)上。目标是保持杆子竖直,防止它倒向任何一侧或推车移动到轨道的尽头。
-
终止条件:环境会在以下情况下终止:
- 杆子的角度超过15度(±15°)。
- 推车的位置超出轨道的中心线一定距离(通常是2.4个单位)
-
动作空间(Action Space):在这个环境中,智能体可以选择两个离散动作之一:
- 向左推车(0)
- 向右推车(1)
代码复现
# describe : dqn算法训练流程
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from agent import QNetwork
from agent import ReplayBuffer
# 创建CartPole环境
env = gym.make('CartPole-v1')
# 超参数
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
learning_rate = 0.001 # 神经网络优化算法中使用的学习能力,用于调整网络权重的步长
gamma = 0.99 # reward discount
epsilon = 1.0 # 探索率
epsilon_decay = 0.995 # epsilon衰减率
epsilon_min = 0.01 # 最小探索率
episodes = 300
steps = 300
batch_size = 64
memory_size = 2000
# 初始化记忆池
replay_buffer = ReplayBuffer(memory_size)
# 初始化Q网络和优化器
q_network = QNetwork(state_size, action_size)
optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
# 初始化奖励数组
total_rewards = []
for episode in range(episodes):
state, info = env.reset()
state = torch.FloatTensor(state).unsqueeze(0) # 预处理,将 state 转换为 PyTorch 的 FloatTensor,这是神经网络处理所需的数据类型
total_reward = 0
for step in range(steps):
# epsilon-greedy策略选择动作
if np.random.rand() <= epsilon:
action = np.random.randint(action_size)
else:
with torch.no_grad():
q_values = q_network(state)
action = q_values.max(1)[1].item()
# 执行动作
next_state, reward, done, _, _ = env.step(action)
next_state = torch.FloatTensor(next_state).unsqueeze(0)
# 存储经验
replay_buffer.push(state, action, reward, next_state, done)
state = next_state
total_reward += reward
if done:
total_rewards.append(total_reward)
print(f"Episode: {episode}, Total Reward: {total_reward}")
break
# 经验回放
if len(replay_buffer) >= batch_size:
states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
# 将采样的状态序列 states 合并成一个张量,以便于批量处理。
states = torch.cat(states)
actions = torch.LongTensor(actions).unsqueeze(1)
rewards = torch.FloatTensor(rewards).unsqueeze(1)
next_states = torch.cat(next_states)
dones = torch.FloatTensor(dones).unsqueeze(1)
# 批量N,计算当前状态下不同动作的Q值,取选择的动作对应的Q值
current_q_values = q_network(states).gather(1, actions)
# 批量N,计算下一个状态取得的最大的Q值
next_q_values = q_network(next_states).max(1)[0].unsqueeze(1)
target_q_values = rewards + (gamma * next_q_values * (1 - dones))
# 使用损失函数计算当前 Q 值和目标 Q 值之间的差异。
loss = criterion(current_q_values, target_q_values)
# 清除之前的梯度,为新的梯度更新做准备。
optimizer.zero_grad()
# 计算损失函数关于网络参数的梯度。
loss.backward()
# 根据计算出的梯度更新网络的权重。
optimizer.step()
# epsilon衰减
if epsilon > epsilon_min:
epsilon *= epsilon_decay
# 绘制奖励图
env.close()
plt.plot(total_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Total Rewards per Episode in CartPole-v1')
plt.show()
# describe : 定义经验回放缓冲区和Q网络
import random
import torch
import torch.nn as nn
# 经验回放缓冲区
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = []
self.capacity = capacity
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
# 用于从缓冲区中随机采样一批经验。
def sample(self, batch_size):
return zip(*random.sample(self.buffer, batch_size))
def __len__(self):
return len(self.buffer)
# Q网络
class QNetwork(nn.Module):
def __init__(self, state_size, action_size):
super(QNetwork, self).__init__()
# 定义了第一个全连接层,将输入状态的特征从 state_size 映射到 24 个特征。
self.fc1 = nn.Linear(state_size, 24)
# 定义了第二个全连接层,将 24 个特征再次映射到 24 个特征。
self.fc2 = nn.Linear(24, 24)
# 定义了第三个全连接层,将 24 个特征映射到输出层,其大小等于可能的动作数量。
self.fc3 = nn.Linear(24, action_size)
def forward(self, x):
# 应用第一个全连接层并使用 ReLU 激活函数。
x = torch.relu(self.fc1(x))
# 应用第二个全连接层并使用 ReLU 激活函数。
x = torch.relu(self.fc2(x))
# 应用第三个全连接层,不使用激活函数,因为 Q 网络的输出通常不需要激活函数。
x = self.fc3(x)
# 返回q值
return x