Code
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque
# Define the Actor network
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        action = self.max_action * torch.tanh(self.fc3(x))  # output lies in [-max_action, max_action]
        return action
# Define the Critic network
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)  # the critic takes the state-action pair as input
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.fc3(x)
        return q_value
# The DDPG algorithm
class DDPG:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-3)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)  # experience replay buffer
        self.batch_size = 256
        self.gamma = 0.99   # discount factor
        self.tau = 0.005    # soft-update coefficient for the target networks

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def add_to_buffer(self, state, action, reward, next_state, done):
        self.replay_buffer.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample a mini-batch from the replay buffer
        batch = random.sample(self.replay_buffer, self.batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        state = torch.FloatTensor(state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        done = torch.FloatTensor(done.astype(np.float32)).unsqueeze(1).to(device)  # bool flags -> float

        # Update the Critic network
        with torch.no_grad():
            next_action = self.actor_target(next_state)
            target_q = self.critic_target(next_state, next_action)
            target_q = reward + (1 - done) * self.gamma * target_q
        current_q = self.critic(state, action)
        critic_loss = nn.MSELoss()(current_q, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Update the Actor network
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Soft-update the target networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
# Train the DDPG agent
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('MountainCarContinuous-v0')  # MountainCarContinuous-v0 (classic gym API: reset() returns obs, step() returns a 4-tuple)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
ddpg = DDPG(state_dim, action_dim, max_action)

max_episodes = 1000
max_steps = 200

for episode in range(max_episodes):
    state = env.reset()
    episode_reward = 0
    for step in range(max_steps):
        env.render()  # render the environment
        action = ddpg.select_action(state)
        next_state, reward, done, _ = env.step(action)
        ddpg.add_to_buffer(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward
        ddpg.train()
        if done:
            break
    print(f"Episode {episode + 1}, Reward: {episode_reward}")

env.close()
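One detail the listing above leaves out is exploration: select_action returns the deterministic actor output, and on a sparse-reward task such as MountainCarContinuous-v0 a purely deterministic policy rarely stumbles onto the goal. Standard DDPG adds noise to the action during data collection (the original paper uses Ornstein-Uhlenbeck noise; zero-mean Gaussian noise is a simpler, widely used alternative). The helper below is a minimal sketch of that idea, not part of the original listing: the name noisy_action and the 0.1 noise scale are assumptions for illustration, and it reuses ddpg, max_action and action_dim as defined above.

import numpy as np

def noisy_action(ddpg, state, max_action, action_dim, expl_noise=0.1):
    """Deterministic actor output plus zero-mean Gaussian noise, clipped to the valid action range.
    expl_noise=0.1 is an assumed scale, chosen only for illustration."""
    action = ddpg.select_action(state)
    noise = np.random.normal(0.0, expl_noise * max_action, size=action_dim)
    return np.clip(action + noise, -max_action, max_action)

# Inside the training loop this would replace `action = ddpg.select_action(state)`:
# action = noisy_action(ddpg, state, max_action, action_dim)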
Introduction
In reinforcement learning, many practical problems involve continuous action spaces, such as controlling a robot's joint angles or the speed and steering of an autonomous vehicle. Traditional policy-gradient algorithms (e.g., A2C, A3C) and value-based algorithms (e.g., DQN and its variants) often run into difficulties in continuous action spaces; DDPG (Deep Deterministic Policy Gradient) is designed to learn policies effectively in exactly this setting.
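Concretely, DDPG couples a deterministic actor \mu_\phi with a Q-value critic Q_\theta. Averaged over a replay mini-batch, the listing above implements the following updates (\gamma and \tau correspond to gamma and tau in the code, \theta' and \phi' to the target networks); this is only a notational restatement of the code, written in LaTeX:

% Critic target, cf. target_q in train():
y = r + \gamma\,(1 - d)\, Q_{\theta'}\big(s',\, \mu_{\phi'}(s')\big)
% Critic loss, cf. critic_loss:
L(\theta) = \big(Q_\theta(s, a) - y\big)^2
% Actor loss, cf. actor_loss:
J(\phi) = -\, Q_\theta\big(s,\, \mu_\phi(s)\big)
% Soft target updates, cf. the final two loops in train():
\theta' \leftarrow \tau\,\theta + (1 - \tau)\,\theta', \qquad \phi' \leftarrow \tau\,\phi + (1 - \tau)\,\phi'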