1. Environment setup: this tutorial uses gym version 0.26.2
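A minimal install sketch, assuming pip (pygame is only needed for the rendering done in play_cartpole; versions other than gym==0.26.2 are up to you):

pip install gym==0.26.2 pygame torch numpy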
2. Writing the network and training code
# Import the required libraries
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
# Define the DQN network
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        # Three fully connected layers
        self.fc1 = nn.Linear(state_size, 24)
        self.fc2 = nn.Linear(24, 24)
        self.fc3 = nn.Linear(24, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
# Define the DQN agent
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # experience replay buffer
        self.gamma = 0.95                 # discount factor
        self.epsilon = 1.0                # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_size, action_size).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def remember(self, state, action, reward, next_state, done):
        # Store a transition in the replay buffer
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Epsilon-greedy action selection
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            act_values = self.model(state)
        return int(np.argmax(act_values.cpu().numpy()))
    def replay(self, batch_size):
        # Sample a random minibatch from the replay buffer and learn from it
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_state = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    # Bellman target: r + gamma * max_a' Q(s', a')
                    target = reward + self.gamma * self.model(next_state).max().item()
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            # Use the network's own Q-values as the regression target, replacing only
            # the taken action's value; detach() keeps gradients from flowing through
            # the target side of the loss
            target_f = self.model(state).detach().clone()
            target_f[0][action] = target
            self.optimizer.zero_grad()
            loss = nn.MSELoss()(self.model(state), target_f)
            loss.backward()
            self.optimizer.step()
        # Decay the exploration rate
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
    def load(self, name):
        # Load model weights
        self.model.load_state_dict(torch.load(name, map_location=self.device))

    def save(self, name):
        # Save model weights
        torch.save(self.model.state_dict(), name)
# Training function
def train_dqn():
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    episodes = 1000
    batch_size = 32
    for e in range(episodes):
        state, _ = env.reset()  # reset the environment; gym 0.26 returns (observation, info)
        for time in range(500):
            action = agent.act(state)
            # gym 0.26 step() returns 5 values: obs, reward, terminated, truncated, info
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            reward = reward if not terminated else -10  # penalize failure, not time-limit truncation
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"episode: {e}/{episodes}, score: {time}, epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > batch_size:
                agent.replay(batch_size)
        if e % 100 == 0:
            agent.save(f"cartpole-dqn-{e}.pth")  # save the model every 100 episodes
# Play the game with a trained model
def play_cartpole():
    # gym 0.26: pass render_mode here; frames are then shown automatically on each step
    env = gym.make('CartPole-v1', render_mode='human')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    agent.load("cartpole-dqn-900.pth")  # load the trained model
    agent.epsilon = 0.0  # act greedily during evaluation, no exploration
    for e in range(10):  # play 10 episodes
        state, _ = env.reset()
        for time in range(500):
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            state = next_state
            if terminated or truncated:
                print(f"episode: {e}, score: {time}")
                break
    env.close()
if __name__ == '__main__':
    # To train the model, uncomment the line below
    # train_dqn()
    # To play with a trained model, leave the line below enabled
    play_cartpole()
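As an optional sanity check (a minimal sketch, not part of the tutorial code above), the snippet below confirms the gym 0.26 API conventions the code relies on: reset() returns an (observation, info) pair and step() returns five values:

import gym
env = gym.make('CartPole-v1')
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(gym.__version__, obs.shape, reward, terminated, truncated)
env.close()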
For more detailed explanations, see: https://zhuanlan.zhihu.com/p/29283993
https://zhuanlan.zhihu.com/p/29213893