This article is a study note recorded by the author while learning in the field of reinforcement learning (RL), intended for personal study, research, or appreciation, and based on the author's own understanding of the related topics. If anything is inappropriate or infringes on rights, it will be corrected immediately once pointed out; your understanding is appreciated. The article is filed under the Reinforcement Learning column:
Reinforcement Learning (5) --- "[MADRL] Monotonic Value Function Factorisation (QMIX) Algorithm"
[MADRL] Monotonic Value Function Factorisation (QMIX) Algorithm
Contents
0. Preface
1. Background and Challenges
2. QMIX Algorithm Architecture
3. Training Procedure
4. Advantages of QMIX
5. Applications of QMIX
6. Limitations and Improvements
[Python] QMIX Implementation (Portable)
0. Preface
QMIX (Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning) is a multi-agent reinforcement learning algorithm that is particularly suited to cooperative multi-agent settings such as distributed control and team combat. It was proposed by Rashid et al. in 2018. Its core idea is to combine the agents' individual Q-values through a mixing network, in a non-linear but monotonic way, to obtain a global Q-value.
Original paper: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning
Algorithm example code
The code that the author has successfully ported, together with comments, is given below.
1. Background and Challenges
In multi-agent reinforcement learning, each agent must learn its policy from its own observations and experience. In a cooperative environment the agents' decisions affect one another, so considering each agent's Q-value in isolation is not sufficient; yet directly modelling a single Q-value for the whole system is computationally infeasible, because the joint state and action spaces grow exponentially with the number of agents.
2. QMIX Algorithm Architecture
QMIX consists of the following core components:
2.1 Individual Q Networks
- Each agent has its own individual Q network, which takes that agent's local observation and action as input and outputs the agent's individual Q-value.
- The individual Q network can use any deep neural network architecture, such as a convolutional neural network (CNN) or a feed-forward network (FNN), chosen according to the task at hand (a minimal sketch follows this list).
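Before the full, portable implementation later in this article, here is a minimal sketch of such an individual Q network as a plain feed-forward network. The class name AgentQNet and the dimensions used (obs_dim=15, action_dim=9, hidden_dim=64) are illustrative placeholders, not part of the original code:
import torch
import torch.nn as nn

class AgentQNet(nn.Module):
    """Minimal per-agent Q network: maps a local observation to Q-values over all actions."""
    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # obs: (batch, obs_dim) -> Q-values: (batch, action_dim)
        return self.net(obs)

if __name__ == "__main__":
    q_net = AgentQNet(obs_dim=15, action_dim=9)
    print(q_net(torch.randn(4, 15)).shape)  # torch.Size([4, 9])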
2.2 Mixing Network
- The mixing network combines the agents' individual Q-values into a global Q-value. It is a fully connected neural network determined by a set of parameterised weights and biases.
- Its inputs are all agents' individual Q-values $Q_1, \dots, Q_n$ together with the global state $s$ (which is used during training); its output is the global Q-value $Q_{tot}$.
- Monotonicity constraint: the mixing network is designed so that the global Q-value $Q_{tot}$ is monotonically non-decreasing in every individual Q-value $Q_i$, i.e. $\partial Q_{tot} / \partial Q_i \ge 0$. Increasing any individual Q-value therefore never decreases the global Q-value. The constraint is enforced by keeping the mixing weights non-negative.
2.3 Computing the Global Q-value
The mixing network computes the global Q-value according to
$$Q_{tot}(\mathbf{o}, \mathbf{a}, s) = f\big(Q_1(o_1, a_1), \dots, Q_n(o_n, a_n);\, s\big)$$
where $f$ is the mapping implemented by the mixing network, $n$ is the number of agents, and $s$ is the global state. A minimal sketch of this computation is given below.
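To make the monotonicity mechanism concrete, here is a minimal sketch of a mixing network in which hypernetworks generate the mixing weights from the global state and torch.abs keeps those weights non-negative. It mirrors the structure of the QMIX_Net class given later in this article, but the class name TinyMixer and all sizes are illustrative assumptions:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMixer(nn.Module):
    """Single-hidden-layer monotonic mixer: Q_tot = f(Q_1, ..., Q_n; s)."""
    def __init__(self, n_agents: int, state_dim: int, hidden_dim: int = 32):
        super().__init__()
        self.n_agents = n_agents
        self.hidden_dim = hidden_dim
        # Hypernetworks: the global state generates the mixing weights and biases
        self.hyper_w1 = nn.Linear(state_dim, n_agents * hidden_dim)
        self.hyper_b1 = nn.Linear(state_dim, hidden_dim)
        self.hyper_w2 = nn.Linear(state_dim, hidden_dim)
        self.hyper_b2 = nn.Linear(state_dim, 1)

    def forward(self, q_agents: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # q_agents: (batch, n_agents), state: (batch, state_dim)
        batch = q_agents.size(0)
        # abs() keeps the weights non-negative, so dQ_tot/dQ_i >= 0 (monotonicity)
        w1 = torch.abs(self.hyper_w1(state)).view(batch, self.n_agents, self.hidden_dim)
        b1 = self.hyper_b1(state).view(batch, 1, self.hidden_dim)
        hidden = F.elu(torch.bmm(q_agents.unsqueeze(1), w1) + b1)  # (batch, 1, hidden_dim)
        w2 = torch.abs(self.hyper_w2(state)).view(batch, self.hidden_dim, 1)
        b2 = self.hyper_b2(state).view(batch, 1, 1)
        q_tot = torch.bmm(hidden, w2) + b2                         # (batch, 1, 1)
        return q_tot.view(batch, 1)

if __name__ == "__main__":
    mixer = TinyMixer(n_agents=3, state_dim=10)
    print(mixer(torch.randn(4, 3), torch.randn(4, 10)).shape)  # torch.Size([4, 1])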
3. Training Procedure
QMIX is trained within the Q-learning framework, following these steps:
3.1 Experience Collection
At every time step, all agents select actions according to the current policy, interact with the environment, and collect experience samples $(\mathbf{o}, \mathbf{a}, r, s')$, where $\mathbf{o}$ is the set of all agents' observations, $\mathbf{a}$ is the set of all agents' actions, $r$ is the global reward, and $s'$ is the next state.
3.2 Target Q Calculation
The target Q-value for the next state $s'$ is computed as
$$y = r + \gamma \max_{\mathbf{a}'} Q_{tot}(\mathbf{o}', \mathbf{a}', s'; \theta^{-})$$
where $\gamma$ is the discount factor and $\theta^{-}$ are the parameters of the target network (updated on a delayed schedule).
3.3 Loss Function and Optimization
The parameters of the mixing network and of the individual Q networks are updated jointly by minimising the TD error:
$$\mathcal{L}(\theta) = \mathbb{E}\big[\big(y - Q_{tot}(\mathbf{o}, \mathbf{a}, s; \theta)\big)^{2}\big]$$
The networks are updated with backpropagation and stochastic gradient descent (SGD); a tensor-level sketch of this computation follows.
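As an illustration, the following sketch reproduces the TD target and the masked TD loss as tensor operations, using the same shapes as the train() method of the full code further down; all tensors here are random placeholders chosen only to demonstrate the computation:
import torch

# Placeholder tensors with shape (batch_size, max_episode_len, 1), as in QMIX_SMAC.train()
batch_size, max_episode_len, gamma = 32, 50, 0.99
q_total_eval   = torch.randn(batch_size, max_episode_len, 1, requires_grad=True)
q_total_target = torch.randn(batch_size, max_episode_len, 1)
r      = torch.randn(batch_size, max_episode_len, 1)
dw     = torch.zeros(batch_size, max_episode_len, 1)  # 1 where the episode truly ended (dead or win)
active = torch.ones(batch_size, max_episode_len, 1)   # 0 on padded steps beyond the episode length

# TD target: y = r + gamma * (1 - dw) * Q_tot(o', a', s'; theta^-)
targets = r + gamma * (1 - dw) * q_total_target
td_error = q_total_eval - targets.detach()

# Mean-squared TD error averaged over valid (non-padded) steps only
loss = ((td_error * active) ** 2).sum() / active.sum()
loss.backward()
print(loss.item())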
3.4 Target Network Update
To stabilise training, QMIX uses target networks whose parameters are copied from the current (online) networks at a lower frequency; the code below also supports a soft (Polyak) update as an alternative, as sketched next.
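Both update schemes appear in the train() method of the full code; the snippet below isolates them with two small placeholder networks so the copy operations are easy to see (the layer sizes and tau value are arbitrary):
import torch
import torch.nn as nn

eval_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
target_net.load_state_dict(eval_net.state_dict())  # start from identical parameters

def hard_update(target: nn.Module, source: nn.Module) -> None:
    # Hard update: copy the online parameters every target_update_freq training steps
    target.load_state_dict(source.state_dict())

def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    # Soft update: move the target parameters toward the online ones with rate tau
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.data.copy_(tau * s_param.data + (1 - tau) * t_param.data)

hard_update(target_net, eval_net)
soft_update(target_net, eval_net, tau=0.005)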
4. Advantages of QMIX
- Cooperation: by optimising the global Q-value, QMIX effectively captures the cooperative relationships between agents.
- Scalability: thanks to the mixing-network design, QMIX scales to environments with more agents without suffering from the exponential growth of the joint action space.
- Flexibility: the non-linear mixing network lets QMIX handle complex cooperative tasks rather than being limited to linear combination strategies.
5. Applications of QMIX
- Distributed robot control: when several robots must cooperate to complete a task, QMIX can learn effective cooperative policies.
- Team game AI: QMIX is widely used to train multi-agent AI for games that require team cooperation.
- Resource allocation and management: in smart grids or multi-UAV systems, QMIX can handle resource coordination among multiple agents.
6. Limitations and Improvements
- Limited representational power: because of the monotonicity constraint, QMIX cannot represent certain complex (non-monotonic) joint value functions and the coordination strategies they imply.
- Sample efficiency: in high-dimensional environments QMIX requires many samples and long training times.
- Improvements: follow-up algorithms such as QTRAN and QPLEX attempt to address these limitations to varying degrees and further improve multi-agent RL performance.
[Python] QMIX Implementation (Portable)
If you have difficulty reproducing the code below or run into problems, feel free to leave a comment. If you need the code as a complete project, leave your email address in the comments so it can be shared with you promptly (private messages are hard to answer in time).
Main file:
"""
@content: QMIX
@author: 不去幼儿园
@Timeline: 2024.08.21
"""
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from env_base import Env  # porting note: import your environment
import argparse
from replay_buffer import ReplayBuffer  # porting note: import the other helper classes
from qmix_smac import QMIX_SMAC
from normalization import Normalization
class Runner_QMIX_SMAC:
def __init__(self, args, env_name, number, seed):
self.args = args
self.env_name = env_name
self.number = number
self.seed = seed
# Set random seed
np.random.seed(self.seed)
torch.manual_seed(self.seed)
# Create env
# self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed)
"""@移植事项
1.环境声明
2.环境参数设置:注意相关参数的格式
"""
self.env = Env() # @移植事项:环境声明
# self.env_info = self.env.get_env_info()
self.args.N = 3 # The number of agents
self.args.obs_dim = 15 # The dimensions of an agent's observation space
self.args.state_dim = 100+9*3+2*3 # The dimensions of global state space
self.args.action_dim = 9 # The dimensions of an agent's action space
self.args.episode_limit = 50  # Maximum number of steps per episode
print("number of agents={}".format(self.args.N))
print("obs_dim={}".format(self.args.obs_dim))
print("state_dim={}".format(self.args.state_dim))
print("action_dim={}".format(self.args.action_dim))
print("episode_limit={}".format(self.args.episode_limit))
# Create N agents
self.agent_n = QMIX_SMAC(self.args)
self.replay_buffer = ReplayBuffer(self.args)
# Create a tensorboard
self.writer = SummaryWriter(log_dir='./runs/{}/{}_env_{}_number_{}_seed_{}'.format(self.args.algorithm, self.args.algorithm, self.env_name, self.number, self.seed))
self.epsilon = self.args.epsilon # Initialize the epsilon
self.win_rates = [] # Record the win rates
self.total_steps = 0
if self.args.use_reward_norm:
print("------use reward norm------")
self.reward_norm = Normalization(shape=1)
def run(self, ):
evaluate_num = -1 # Record the number of evaluations
while self.total_steps < self.args.max_train_steps:
if self.total_steps // self.args.evaluate_freq > evaluate_num:
self.evaluate_policy() # Evaluate the policy every 'evaluate_freq' steps
evaluate_num += 1
_, _, episode_steps = self.run_episode_smac(evaluate=False) # Run an episode
self.total_steps += episode_steps
if self.replay_buffer.current_size >= self.args.batch_size:
self.agent_n.train(self.replay_buffer, self.total_steps) # Training
self.evaluate_policy()
# self.env.close()
def evaluate_policy(self, ):
win_times = 0
evaluate_reward = 0
goal_num_buffer__ = []
for _ in range(self.args.evaluate_times):
win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True)
"""获取其他状态数据"""
goal_num_buffer_ = self.env.get_state_data() # @移植事项:其他状态获取
goal_num_buffer_ = np.array(goal_num_buffer_)
goal_num_buffer__.append(goal_num_buffer_)
if win_tag:
win_times += 1
evaluate_reward += episode_reward
goal_num_buffer = np.sum(goal_num_buffer__[:], axis=0) / self.args.evaluate_times
log_flag = ["state/target_num", "state/target_num", "state/crash_num",
"state/ratio"]
for i in range(4):
goal_num = goal_num_buffer[i]
goal_num = {log_flag[i]: goal_num}
log_state(name=log_flag[i], state=goal_num, step=self.total_steps)
win_rate = win_times / self.args.evaluate_times
evaluate_reward = evaluate_reward / self.args.evaluate_times
reward_total = {"state/reward_total": evaluate_reward}
log_state(name="state/reward_total", state=reward_total, step=self.total_steps)
self.win_rates.append(win_rate)
print("total_steps:{}\tepisode:{}\tevaluate_reward:{:.3f}\t"
"target_num:{:.3f}\ttarget_num:{:.3f}\tcrash_num:{:.3f}\tratio:{:.3f}"
.format(self.total_steps, int(self.total_steps / 1250 + 1), evaluate_reward,
goal_num_buffer[0], goal_num_buffer[1], goal_num_buffer[2], goal_num_buffer[3]))
# self.writer.add_scalar('win_rate_{}'.format(self.env_name), win_rate, global_step=self.total_steps)
# Save the win rates
np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.args.algorithm, self.env_name, self.number, self.seed), np.array(self.win_rates))
def run_episode_smac(self, evaluate=False):
win_tag = False
episode_reward = 0
"""移植事项:环境运行
1.环境重置函数设置
2.环境状态返回函数设置:注意格式
3.环境下一步更新:注意返回值
"""
self.env.reset() # @移植事项:环境重置函数
if self.args.use_rnn: # If use RNN, before the beginning of each episode,reset the rnn_hidden of the Q network.
self.agent_n.eval_Q_net.rnn_hidden = None
last_onehot_a_n = np.zeros((self.args.N, self.args.action_dim)) # Last actions of N agents(one-hot)
for episode_step in range(self.args.episode_limit):
obs_n = self.env.get_obs()  # obs_n.shape=(N,obs_dim)  # porting note: get the agents' observations
s = self.env.get_state()  # s.shape=(state_dim,)  # porting note: get the global state
# avail_a_n = self.env.get_avail_actions() # Get available actions of N agents, avail_a_n.shape=(N,action_dim)
avail_a_n = [[1] * 9 for _ in range(3)]
epsilon = 0 if evaluate else self.epsilon
a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, avail_a_n, epsilon)
last_onehot_a_n = np.eye(self.args.action_dim)[a_n] # Convert actions to one-hot vectors
_, r_, done_, info = self.env.step(a_n)  # porting note: step the environment
done = done_[0]
r = sum(list(np.array(r_).flatten()))
win_tag = True if done and 'battle_won' in info and info['battle_won'] else False
episode_reward += r
if not evaluate:
if self.args.use_reward_norm:
r = self.reward_norm(r)
""""
When dead or win or reaching the episode_limit, done will be Ture, we need to distinguish them;
dw means dead or win,there is no next state s';
but when reaching the max_episode_steps,there is a next state s' actually.
"""
if done and episode_step + 1 != self.args.episode_limit:
dw = True
else:
dw = False
# Store the transition
self.replay_buffer.store_transition(episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw)
# Decay the epsilon
self.epsilon = self.epsilon - self.args.epsilon_decay if self.epsilon - self.args.epsilon_decay > self.args.epsilon_min else self.args.epsilon_min
if done:
break
if not evaluate:
# An episode is over, store obs_n, s and avail_a_n in the last step
obs_n = self.env.get_obs()  # porting note: get the agents' observations
s = self.env.get_state()  # porting note: get the global state
# avail_a_n = self.env.get_avail_actions()
avail_a_n = [[1] * 9 for _ in range(3)]
self.replay_buffer.store_last_step(episode_step + 1, obs_n, s, avail_a_n)
return win_tag, episode_reward, episode_step + 1
# Logging of the running results
from tensorboardX import SummaryWriter
writer = SummaryWriter()
def log_state(name, state, step):
writer.add_scalars(name, state, step)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX and VDN in SMAC environment")
parser.add_argument("--max_train_steps", type=int, default=int(1e6), help=" Maximum number of training steps")
parser.add_argument("--evaluate_freq", type=float, default=1250, help="Evaluate the policy every 'evaluate_freq' steps")
parser.add_argument("--evaluate_times", type=float, default=5, help="Evaluate times")
parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency")
parser.add_argument("--algorithm", type=str, default="QMIX", help="QMIX or VDN")
parser.add_argument("--epsilon", type=float, default=1.0, help="Initial epsilon")
parser.add_argument("--epsilon_decay_steps", type=float, default=50000, help="How many steps before the epsilon decays to the minimum")
parser.add_argument("--epsilon_min", type=float, default=0.05, help="Minimum epsilon")
parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)")
parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
parser.add_argument("--qmix_hidden_dim", type=int, default=32, help="The dimension of the hidden layer of the QMIX network")
parser.add_argument("--hyper_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the hyper-network")
parser.add_argument("--hyper_layers_num", type=int, default=1, help="The number of layers of hyper-network")
parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN")
parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP")
parser.add_argument("--use_rnn", type=bool, default=True, help="Whether to use RNN")
parser.add_argument("--use_orthogonal_init", type=bool, default=True, help="Orthogonal initialization")
parser.add_argument("--use_grad_clip", type=bool, default=True, help="Gradient clip")
parser.add_argument("--use_lr_decay", type=bool, default=False, help="use lr decay")
parser.add_argument("--use_RMS", type=bool, default=False, help="Whether to use RMS,if False, we will use Adam")
parser.add_argument("--add_last_action", type=bool, default=True, help="Whether to add last actions into the observation")
parser.add_argument("--add_agent_id", type=bool, default=True, help="Whether to add agent id into the observation")
parser.add_argument("--use_double_q", type=bool, default=True, help="Whether to use double q-learning")
parser.add_argument("--use_reward_norm", type=bool, default=False, help="Whether to use reward normalization")
parser.add_argument("--use_hard_update", type=bool, default=True, help="Whether to use hard update")
parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network")
parser.add_argument("--tau", type=int, default=0.005, help="If use soft update")
args = parser.parse_args()
args.epsilon_decay = (args.epsilon - args.epsilon_min) / args.epsilon_decay_steps
env_names = ['3m', '8m', '2s3z']
env_index = 0
runner = Runner_QMIX_SMAC(args, env_name=env_names[env_index], number=1, seed=0)
runner.run()
from replay_buffer import ReplayBuffer
replay_buffer.py file:
import numpy as np
import torch
import copy
class ReplayBuffer:
def __init__(self, args):
self.N = args.N
self.obs_dim = args.obs_dim
self.state_dim = args.state_dim
self.action_dim = args.action_dim
self.episode_limit = args.episode_limit
self.buffer_size = args.buffer_size
self.batch_size = args.batch_size
self.episode_num = 0
self.current_size = 0
self.buffer = {'obs_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.obs_dim]),
's': np.zeros([self.buffer_size, self.episode_limit + 1, self.state_dim]),
'avail_a_n': np.ones([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]), # Note: We use 'np.ones' to initialize 'avail_a_n'
'last_onehot_a_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]),
'a_n': np.zeros([self.buffer_size, self.episode_limit, self.N]),
'r': np.zeros([self.buffer_size, self.episode_limit, 1]),
'dw': np.ones([self.buffer_size, self.episode_limit, 1]), # Note: We use 'np.ones' to initialize 'dw'
'active': np.zeros([self.buffer_size, self.episode_limit, 1])
}
self.episode_len = np.zeros(self.buffer_size)
def store_transition(self, episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw):
self.buffer['obs_n'][self.episode_num][episode_step] = obs_n
self.buffer['s'][self.episode_num][episode_step] = s
self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n
self.buffer['last_onehot_a_n'][self.episode_num][episode_step + 1] = last_onehot_a_n
self.buffer['a_n'][self.episode_num][episode_step] = a_n
self.buffer['r'][self.episode_num][episode_step] = r
self.buffer['dw'][self.episode_num][episode_step] = dw
self.buffer['active'][self.episode_num][episode_step] = 1.0
def store_last_step(self, episode_step, obs_n, s, avail_a_n):
self.buffer['obs_n'][self.episode_num][episode_step] = obs_n
self.buffer['s'][self.episode_num][episode_step] = s
self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n
self.episode_len[self.episode_num] = episode_step # Record the length of this episode
self.episode_num = (self.episode_num + 1) % self.buffer_size
self.current_size = min(self.current_size + 1, self.buffer_size)
def sample(self):
# Randomly sampling
index = np.random.choice(self.current_size, size=self.batch_size, replace=False)
max_episode_len = int(np.max(self.episode_len[index]))
batch = {}
for key in self.buffer.keys():
if key == 'obs_n' or key == 's' or key == 'avail_a_n' or key == 'last_onehot_a_n':
batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len + 1], dtype=torch.float32)
elif key == 'a_n':
batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.long)
else:
batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.float32)
return batch, max_episode_len
from qmix_smac import QMIX_SMAC
qmix_smac.py file:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mix_net import QMIX_Net, VDN_Net
# orthogonal initialization
def orthogonal_init(layer, gain=1.0):
for name, param in layer.named_parameters():
if 'bias' in name:
nn.init.constant_(param, 0)
elif 'weight' in name:
nn.init.orthogonal_(param, gain=gain)
class Q_network_RNN(nn.Module):
def __init__(self, args, input_dim):
super(Q_network_RNN, self).__init__()
self.rnn_hidden = None
self.fc1 = nn.Linear(input_dim, args.rnn_hidden_dim)
self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
self.fc2 = nn.Linear(args.rnn_hidden_dim, args.action_dim)
if args.use_orthogonal_init:
print("------use_orthogonal_init------")
orthogonal_init(self.fc1)
orthogonal_init(self.rnn)
orthogonal_init(self.fc2)
def forward(self, inputs):
# When 'choose_action', inputs.shape=(N,input_dim)
# When 'train', inputs.shape=(batch_size*N,input_dim)
x = F.relu(self.fc1(inputs))
self.rnn_hidden = self.rnn(x, self.rnn_hidden)
Q = self.fc2(self.rnn_hidden)
return Q
class Q_network_MLP(nn.Module):
def __init__(self, args, input_dim):
super(Q_network_MLP, self).__init__()
self.rnn_hidden = None
self.fc1 = nn.Linear(input_dim, args.mlp_hidden_dim)
self.fc2 = nn.Linear(args.mlp_hidden_dim, args.mlp_hidden_dim)
self.fc3 = nn.Linear(args.mlp_hidden_dim, args.action_dim)
if args.use_orthogonal_init:
print("------use_orthogonal_init------")
orthogonal_init(self.fc1)
orthogonal_init(self.fc2)
orthogonal_init(self.fc3)
def forward(self, inputs):
# When 'choose_action', inputs.shape=(N,input_dim)
# When 'train', inputs.shape=(batch_size,max_episode_len,N,input_dim)
x = F.relu(self.fc1(inputs))
x = F.relu(self.fc2(x))
Q = self.fc3(x)
return Q
class QMIX_SMAC(object):
def __init__(self, args):
self.N = args.N
self.action_dim = args.action_dim
self.obs_dim = args.obs_dim
self.state_dim = args.state_dim
self.add_last_action = args.add_last_action
self.add_agent_id = args.add_agent_id
self.max_train_steps=args.max_train_steps
self.lr = args.lr
self.gamma = args.gamma
self.use_grad_clip = args.use_grad_clip
self.batch_size = args.batch_size  # here batch_size is the number of episodes per training batch
self.target_update_freq = args.target_update_freq
self.tau = args.tau
self.use_hard_update = args.use_hard_update
self.use_rnn = args.use_rnn
self.algorithm = args.algorithm
self.use_double_q = args.use_double_q
self.use_RMS = args.use_RMS
self.use_lr_decay = args.use_lr_decay
# Compute the input dimension
self.input_dim = self.obs_dim
if self.add_last_action:
print("------add last action------")
self.input_dim += self.action_dim
if self.add_agent_id:
print("------add agent id------")
self.input_dim += self.N
if self.use_rnn:
print("------use RNN------")
self.eval_Q_net = Q_network_RNN(args, self.input_dim)
self.target_Q_net = Q_network_RNN(args, self.input_dim)
else:
print("------use MLP------")
self.eval_Q_net = Q_network_MLP(args, self.input_dim)
self.target_Q_net = Q_network_MLP(args, self.input_dim)
self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())
if self.algorithm == "QMIX":
print("------algorithm: QMIX------")
self.eval_mix_net = QMIX_Net(args)
self.target_mix_net = QMIX_Net(args)
elif self.algorithm == "VDN":
print("------algorithm: VDN------")
self.eval_mix_net = VDN_Net()
self.target_mix_net = VDN_Net()
else:
print("wrong!!!")
self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())
self.eval_parameters = list(self.eval_mix_net.parameters()) + list(self.eval_Q_net.parameters())
if self.use_RMS:
print("------optimizer: RMSprop------")
self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=self.lr)
else:
print("------optimizer: Adam------")
self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)
self.train_step = 0
def choose_action(self, obs_n, last_onehot_a_n, avail_a_n, epsilon):
with torch.no_grad():
if np.random.uniform() < epsilon: # epsilon-greedy
# Only available actions can be chosen
a_n = [np.random.choice(np.nonzero(avail_a)[0]) for avail_a in avail_a_n]
else:
inputs = []
obs_n = torch.tensor(obs_n, dtype=torch.float32) # obs_n.shape=(N,obs_dim)
inputs.append(obs_n)
if self.add_last_action:
last_a_n = torch.tensor(last_onehot_a_n, dtype=torch.float32)
inputs.append(last_a_n)
if self.add_agent_id:
inputs.append(torch.eye(self.N))
inputs = torch.cat([x for x in inputs], dim=-1) # inputs.shape=(N,inputs_dim)
q_value = self.eval_Q_net(inputs)
avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32) # avail_a_n.shape=(N, action_dim)
q_value[avail_a_n == 0] = -float('inf') # Mask the unavailable actions
a_n = q_value.argmax(dim=-1).numpy()
return a_n
def train(self, replay_buffer, total_steps):
batch, max_episode_len = replay_buffer.sample() # Get training data
self.train_step += 1
inputs = self.get_inputs(batch, max_episode_len)  # inputs.shape=(batch_size,max_episode_len+1,N,input_dim)
if self.use_rnn:
self.eval_Q_net.rnn_hidden = None
self.target_Q_net.rnn_hidden = None
q_evals, q_targets = [], []
for t in range(max_episode_len): # t=0,1,2,...(episode_len-1)
q_eval = self.eval_Q_net(inputs[:, t].reshape(-1, self.input_dim)) # q_eval.shape=(batch_size*N,action_dim)
q_target = self.target_Q_net(inputs[:, t + 1].reshape(-1, self.input_dim))
q_evals.append(q_eval.reshape(self.batch_size, self.N, -1)) # q_eval.shape=(batch_size,N,action_dim)
q_targets.append(q_target.reshape(self.batch_size, self.N, -1))
# Stack them according to the time (dim=1)
q_evals = torch.stack(q_evals, dim=1) # q_evals.shape=(batch_size,max_episode_len,N,action_dim)
q_targets = torch.stack(q_targets, dim=1)
else:
q_evals = self.eval_Q_net(inputs[:, :-1]) # q_evals.shape=(batch_size,max_episode_len,N,action_dim)
q_targets = self.target_Q_net(inputs[:, 1:])
with torch.no_grad():
if self.use_double_q: # If use double q-learning, we use eval_net to choose actions,and use target_net to compute q_target
q_eval_last = self.eval_Q_net(inputs[:, -1].reshape(-1, self.input_dim)).reshape(self.batch_size, 1, self.N, -1)
q_evals_next = torch.cat([q_evals[:, 1:], q_eval_last], dim=1) # q_evals_next.shape=(batch_size,max_episode_len,N,action_dim)
q_evals_next[batch['avail_a_n'][:, 1:] == 0] = -999999
a_argmax = torch.argmax(q_evals_next, dim=-1, keepdim=True)  # a_argmax.shape=(batch_size,max_episode_len, N, 1)
q_targets = torch.gather(q_targets, dim=-1, index=a_argmax).squeeze(-1) # q_targets.shape=(batch_size, max_episode_len, N)
else:
q_targets[batch['avail_a_n'][:, 1:] == 0] = -999999
q_targets = q_targets.max(dim=-1)[0] # q_targets.shape=(batch_size, max_episode_len, N)
# batch['a_n'].shape(batch_size,max_episode_len, N)
q_evals = torch.gather(q_evals, dim=-1, index=batch['a_n'].unsqueeze(-1)).squeeze(-1) # q_evals.shape(batch_size, max_episode_len, N)
# Compute q_total using QMIX or VDN, q_total.shape=(batch_size, max_episode_len, 1)
if self.algorithm == "QMIX":
q_total_eval = self.eval_mix_net(q_evals, batch['s'][:, :-1])
q_total_target = self.target_mix_net(q_targets, batch['s'][:, 1:])
else:
q_total_eval = self.eval_mix_net(q_evals)
q_total_target = self.target_mix_net(q_targets)
# targets.shape=(batch_size,max_episode_len,1)
targets = batch['r'] + self.gamma * (1 - batch['dw']) * q_total_target
td_error = (q_total_eval - targets.detach())
mask_td_error = td_error * batch['active']
loss = (mask_td_error ** 2).sum() / batch['active'].sum()
self.optimizer.zero_grad()
loss.backward()
if self.use_grad_clip:
torch.nn.utils.clip_grad_norm_(self.eval_parameters, 10)
self.optimizer.step()
if self.use_hard_update:
# hard update
if self.train_step % self.target_update_freq == 0:
self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())
self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())
else:
# Softly update the target networks
for param, target_param in zip(self.eval_Q_net.parameters(), self.target_Q_net.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.eval_mix_net.parameters(), self.target_mix_net.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
if self.use_lr_decay:
self.lr_decay(total_steps)
def lr_decay(self, total_steps): # Learning rate Decay
lr_now = self.lr * (1 - total_steps / self.max_train_steps)
for p in self.optimizer.param_groups:
p['lr'] = lr_now
def get_inputs(self, batch, max_episode_len):
inputs = []
inputs.append(batch['obs_n'])
if self.add_last_action:
inputs.append(batch['last_onehot_a_n'])
if self.add_agent_id:
agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(self.batch_size, max_episode_len + 1, 1, 1)
inputs.append(agent_id_one_hot)
# inputs.shape=(batch_size,max_episode_len+1,N,input_dim)
inputs = torch.cat([x for x in inputs], dim=-1)
return inputs
def save_model(self, env_name, algorithm, number, seed, total_steps):
torch.save(self.eval_Q_net.state_dict(), "./model/{}/{}_eval_rnn_number_{}_seed_{}_step_{}k.pth".format(env_name, algorithm, number, seed, int(total_steps / 1000)))
from normalization import Normalization
normalization.py file:
import numpy as np
class RunningMeanStd:
# Dynamically calculate mean and std
def __init__(self, shape): # shape:the dimension of input data
self.n = 0
self.mean = np.zeros(shape)
self.S = np.zeros(shape)
self.std = np.sqrt(self.S)
def update(self, x):
x = np.array(x)
self.n += 1
if self.n == 1:
self.mean = x
self.std = x
else:
old_mean = self.mean.copy()
self.mean = old_mean + (x - old_mean) / self.n
self.S = self.S + (x - old_mean) * (x - self.mean)
self.std = np.sqrt(self.S / self.n)
class Normalization:
def __init__(self, shape):
self.running_ms = RunningMeanStd(shape=shape)
def __call__(self, x, update=True):
# Whether to update the mean and std,during the evaluating,update=False
if update:
self.running_ms.update(x)
x = (x - self.running_ms.mean) / (self.running_ms.std + 1e-8)
return x
class RewardScaling:
def __init__(self, shape, gamma):
self.shape = shape # reward shape=1
self.gamma = gamma # discount factor
self.running_ms = RunningMeanStd(shape=self.shape)
self.R = np.zeros(self.shape)
def __call__(self, x):
self.R = self.gamma * self.R + x
self.running_ms.update(self.R)
x = x / (self.running_ms.std + 1e-8) # Only divided std
return x
def reset(self): # When an episode is done,we should reset 'self.R'
self.R = np.zeros(self.shape)
from mix_net import QMIX_Net, VDN_Net
mix_net.py file:
import torch
import torch.nn as nn
import torch.nn.functional as F
class QMIX_Net(nn.Module):
def __init__(self, args):
super(QMIX_Net, self).__init__()
self.N = args.N
self.state_dim = args.state_dim
self.batch_size = args.batch_size
self.qmix_hidden_dim = args.qmix_hidden_dim
self.hyper_hidden_dim = args.hyper_hidden_dim
self.hyper_layers_num = args.hyper_layers_num
"""
w1:(N, qmix_hidden_dim)
b1:(1, qmix_hidden_dim)
w2:(qmix_hidden_dim, 1)
b2:(1, 1)
因为生成的hyper_w1需要是一个矩阵,而pytorch神经网络只能输出一个向量,
所以就先输出长度为需要的 矩阵行*矩阵列 的向量,然后再转化成矩阵
"""
if self.hyper_layers_num == 2:
print("hyper_layers_num=2")
self.hyper_w1 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim),
nn.ReLU(),
nn.Linear(self.hyper_hidden_dim, self.N * self.qmix_hidden_dim))
self.hyper_w2 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim),
nn.ReLU(),
nn.Linear(self.hyper_hidden_dim, self.qmix_hidden_dim * 1))
elif self.hyper_layers_num == 1:
print("hyper_layers_num=1")
self.hyper_w1 = nn.Linear(self.state_dim, self.N * self.qmix_hidden_dim)
self.hyper_w2 = nn.Linear(self.state_dim, self.qmix_hidden_dim * 1)
else:
print("wrong!!!")
self.hyper_b1 = nn.Linear(self.state_dim, self.qmix_hidden_dim)
self.hyper_b2 = nn.Sequential(nn.Linear(self.state_dim, self.qmix_hidden_dim),
nn.ReLU(),
nn.Linear(self.qmix_hidden_dim, 1))
def forward(self, q, s):
# q.shape(batch_size, max_episode_len, N)
# s.shape(batch_size, max_episode_len,state_dim)
q = q.view(-1, 1, self.N) # (batch_size * max_episode_len, 1, N)
s = s.reshape(-1, self.state_dim) # (batch_size * max_episode_len, state_dim)
w1 = torch.abs(self.hyper_w1(s)) # (batch_size * max_episode_len, N * qmix_hidden_dim)
b1 = self.hyper_b1(s) # (batch_size * max_episode_len, qmix_hidden_dim)
w1 = w1.view(-1, self.N, self.qmix_hidden_dim) # (batch_size * max_episode_len, N, qmix_hidden_dim)
b1 = b1.view(-1, 1, self.qmix_hidden_dim) # (batch_size * max_episode_len, 1, qmix_hidden_dim)
# torch.bmm: 3 dimensional tensor multiplication
q_hidden = F.elu(torch.bmm(q, w1) + b1) # (batch_size * max_episode_len, 1, qmix_hidden_dim)
w2 = torch.abs(self.hyper_w2(s)) # (batch_size * max_episode_len, qmix_hidden_dim * 1)
b2 = self.hyper_b2(s) # (batch_size * max_episode_len,1)
w2 = w2.view(-1, self.qmix_hidden_dim, 1) # (batch_size * max_episode_len, qmix_hidden_dim, 1)
b2 = b2.view(-1, 1, 1) # (batch_size * max_episode_len, 1, 1)
q_total = torch.bmm(q_hidden, w2) + b2 # (batch_size * max_episode_len, 1, 1)
q_total = q_total.view(self.batch_size, -1, 1) # (batch_size, max_episode_len, 1)
return q_total
class VDN_Net(nn.Module):
def __init__(self, ):
super(VDN_Net, self).__init__()
def forward(self, q):
return torch.sum(q, dim=-1, keepdim=True) # (batch_size, max_episode_len, 1)
If anything in this article is inappropriate or incorrect, your understanding and corrections are appreciated. Some of the text and images come from the Internet and their original sources cannot be verified; if any dispute arises, please contact the author and they will be removed. For mistakes, questions, or copyright concerns, leave a comment, or follow the WeChat official account Rain21321 to contact the author.