【MADRL】The Monotonic Value Function Factorisation (QMIX) Algorithm in MADRL

news2025/1/17 0:46:36

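        This article is part of the notes and excerpts I (the blogger) recorded while studying reinforcement learning (RL). It is for personal study, research, or appreciation and reflects my own understanding of the related fields. If anything is inappropriate or infringes on anyone's rights, please point it out and I will correct it immediately. This article belongs to the Reinforcement Learning column:

       Reinforcement Learning (5) --- 《【MADRL】The Monotonic Value Function Factorisation (QMIX) Algorithm in MADRL》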


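Contents

0. Preface

1. Background and Challenges

2. QMIX Architecture

3. Training Procedure

4. Advantages of QMIX

5. Applications of QMIX

6. Limitations and Improvements

 [Python] QMIX Implementation (Portable)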


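0. Preface

        QMIX comes from the paper "Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning". It is a multi-agent reinforcement learning algorithm built on monotonic value function factorisation and is particularly suited to cooperative multi-agent settings such as distributed control and team combat. Rashid et al. proposed it in 2018. Its core idea is to combine the agents' individual Q values into a global Q value through a mixing network. The combination may be non-linear, but it is constrained to be monotonic in every individual Q value, and the mixing weights are conditioned on the global state.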

Original paper: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning

Example code for the algorithm

My successfully ported implementation, with comments, is given below.


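1. Background and Challenges

        In multi-agent reinforcement learning, each agent must learn a policy from its own observations and experience. In a cooperative environment the agents' decisions affect one another, so looking at each agent's Q value in isolation is not enough.

        Modelling the Q value of the whole system directly is computationally infeasible, however, because the joint action space grows exponentially with the number of agents. With $n$ agents that each have $|A|$ actions, there are $|A|^n$ joint actions.

        In addition, each agent usually has to act on its own local observation at execution time. What is needed is a method that learns from the joint value during training but still lets every agent choose its action independently.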
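The exponential growth is easy to check numerically. The sketch below enumerates the joint actions for 9 actions per agent, which is the action count used in the code later in this post. It compares that with the $n \cdot |A|$ individual Q values that a factorised method evaluates when each agent maximises its own Q value. The helper names `joint_action_count` and `factorised_count` are hypothetical.

```python
from itertools import product


def joint_action_count(n_agents: int, n_actions: int) -> int:
    """Size of the joint action space |A|^n, counted by explicit enumeration."""
    return sum(1 for _ in product(range(n_actions), repeat=n_agents))


def factorised_count(n_agents: int, n_actions: int) -> int:
    """Individual Q values evaluated when every agent maximises its own Q_i: n * |A|."""
    return n_agents * n_actions


for n in (2, 3, 4):
    print(n, joint_action_count(n, 9), factorised_count(n, 9))
```

With 3 agents the joint space already has 729 entries against 27 individual Q values. Each additional agent multiplies the former by 9 but adds only 9 to the latter.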


2. QMIX Architecture

QMIX consists of the following core components:

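2.1 Individual Q Networks

  • Each agent has an individual Q network (agent network). It takes the agent's local observation $o_i$ as input and outputs the individual Q value $Q_i(o_i, a_i)$ for every action $a_i$. In the original paper, the input is the current observation plus the agent's previous action, and a recurrent layer summarises the history. In practice, and in the code later in this post, all agents share the parameters of one network, and a one-hot agent ID is appended to the input so that the agents can still behave differently.
  • The individual Q network can be any deep neural network suited to the task, such as a convolutional neural network (CNN) or a feed-forward neural network (FNN). The original paper uses a recurrent DRQN with a GRU layer; the code below offers both a GRU version and an MLP version.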
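The following is a minimal sketch of a parameter-shared recurrent agent network. It assumes PyTorch and the input layout used by `get_inputs` in `qmix_smac.py` below: observation, then one-hot previous action, then one-hot agent ID. The class `SharedAgentNet` and the helper `build_inputs` are hypothetical names for illustration. The dimensions (3 agents, 15-dimensional observations, 9 actions, 64 hidden units) are taken from the configuration in the main script.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedAgentNet(nn.Module):
    """Hypothetical DRQN-style agent network shared by all agents:
    input -> Linear -> ReLU -> GRUCell -> Linear -> Q_i(., a) for every action a."""

    def __init__(self, input_dim: int, hidden_dim: int, n_actions: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.rnn = nn.GRUCell(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)

    def forward(self, x, h):
        h = self.rnn(F.relu(self.fc1(x)), h)  # the hidden state summarises each agent's history
        return self.fc2(h), h


def build_inputs(obs, last_actions, n_actions):
    """Concatenate observation, one-hot previous action and one-hot agent ID for each agent."""
    n_agents = obs.shape[0]
    last_onehot = F.one_hot(last_actions, n_actions).float()
    agent_id = torch.eye(n_agents)
    return torch.cat([obs, last_onehot, agent_id], dim=-1)


n_agents, obs_dim, n_actions, hidden = 3, 15, 9, 64
net = SharedAgentNet(obs_dim + n_actions + n_agents, hidden, n_actions)
obs = torch.randn(n_agents, obs_dim)
last_a = torch.zeros(n_agents, dtype=torch.long)
h = torch.zeros(n_agents, hidden)
q, h = net(build_inputs(obs, last_a, n_actions), h)  # q: (3, 9), one row per agent
greedy = q.argmax(dim=-1)                              # each agent picks from its own row
```

Every agent only needs its own row of `q`. The greedy action can therefore be computed locally at execution time, without access to the global state or the mixing network.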

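2.2 Mixing Network

  • The mixing network combines the agents' individual Q values into the global Q value $Q_{tot}$. It is a feed-forward network; in the paper it has a single hidden layer with an ELU activation. Its weights and biases are not learned directly. Instead, separate hypernetworks that take the global state $s$ as input produce them.
  • The inputs to the mixing network are the individual Q values $Q_i$ of the actions the agents actually took. The global state $s$ enters only through the hypernetworks that generate the mixing weights. The output is the global Q value $Q_{tot}$. The mixing network and the state are needed only during centralised training; at execution time each agent acts on its own $Q_i$.
  • Monotonicity constraint: the mixing network is designed so that $Q_{tot}$ is monotonically non-decreasing in every $Q_i$, i.e. $\partial Q_{tot} / \partial Q_i \ge 0$. Increasing any individual Q value can therefore never decrease the global Q value. The constraint is enforced by making the mixing weights non-negative (the hypernetwork outputs are passed through an absolute value); the biases are left unconstrained. The key consequence is that the joint action maximising $Q_{tot}$ is simply the tuple of each agent's individually greedy actions. Every agent can act greedily on its own $Q_i$, and the team still selects the action that maximises $Q_{tot}$.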
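The following sketch, assuming PyTorch, shows how hypernetworks and an absolute value yield a mixing network whose output is non-decreasing in each $Q_i$. It then checks this property with autograd. `TinyMixer` is a hypothetical, stripped-down counterpart of `QMIX_Net` in `mix_net.py` below. For brevity, its final bias hypernetwork is a single linear layer, whereas `QMIX_Net` uses two layers with a ReLU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMixer(nn.Module):
    """Hypothetical monotonic mixer: hypernetworks map the state s to the mixing weights;
    torch.abs keeps the weights non-negative, while the biases stay unconstrained."""

    def __init__(self, n_agents: int, state_dim: int, hidden: int = 32):
        super().__init__()
        self.n, self.hidden = n_agents, hidden
        self.hyper_w1 = nn.Linear(state_dim, n_agents * hidden)
        self.hyper_b1 = nn.Linear(state_dim, hidden)
        self.hyper_w2 = nn.Linear(state_dim, hidden)
        self.hyper_b2 = nn.Linear(state_dim, 1)

    def forward(self, q, s):
        # q: (B, n) Q values of the chosen actions, s: (B, state_dim)
        w1 = torch.abs(self.hyper_w1(s)).view(-1, self.n, self.hidden)
        b1 = self.hyper_b1(s).view(-1, 1, self.hidden)
        h = F.elu(torch.bmm(q.unsqueeze(1), w1) + b1)  # (B, 1, hidden)
        w2 = torch.abs(self.hyper_w2(s)).view(-1, self.hidden, 1)
        b2 = self.hyper_b2(s).view(-1, 1, 1)
        return (torch.bmm(h, w2) + b2).view(-1)  # (B,)


torch.manual_seed(0)
mixer = TinyMixer(n_agents=3, state_dim=10)
q = torch.randn(64, 3, requires_grad=True)
s = torch.randn(64, 10)
q_tot = mixer(q, s)
# Samples are independent, so the gradient of the sum gives dQ_tot/dQ_i per sample.
(grad,) = torch.autograd.grad(q_tot.sum(), q)
print(bool((grad >= 0).all()))  # True
```

Removing the two `torch.abs` calls would typically produce some negative gradients, which is exactly what the monotonicity constraint rules out.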

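2.3 Computing the Global Q Value

        The mixing network computes the global Q value as:

 $$ Q_{tot} = f\left(Q_1(o_1, a_1), Q_2(o_2, a_2), \dots, Q_n(o_n, a_n); s\right) $$

        where $f$ is the mapping implemented by the mixing network, $n$ is the number of agents, and $s$ is the global state.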
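For a single state, the two-layer mixing network of the original paper can be written out explicitly, as implemented by `QMIX_Net` below with hidden width $d$ = `qmix_hidden_dim`. The restatement below treats $\mathbf{q}$ as a row vector, writes $|\cdot|$ for the element-wise absolute value applied to the hypernetwork outputs, and uses $z$ for the hidden pre-activation. It is a derivation for clarity, not an extra component of the method:

```latex
\begin{aligned}
\mathbf{q} &= \big[Q_1(o_1,a_1),\dots,Q_n(o_n,a_n)\big] \in \mathbb{R}^{1\times n},\\
z &= \mathbf{q}\,|W_1(s)| + \mathbf{b}_1(s), \qquad W_1(s)\in\mathbb{R}^{n\times d},\ \mathbf{b}_1(s)\in\mathbb{R}^{1\times d},\\
\mathbf{h} &= \operatorname{ELU}(z) \in \mathbb{R}^{1\times d},\\
Q_{tot} &= \mathbf{h}\,|\mathbf{w}_2(s)| + b_2(s), \qquad \mathbf{w}_2(s)\in\mathbb{R}^{d\times 1},\\
\frac{\partial Q_{tot}}{\partial Q_i} &= \sum_{k=1}^{d} |\mathbf{w}_2(s)|_k\,\operatorname{ELU}'(z_k)\,|W_1(s)|_{ik} \;\ge\; 0 .
\end{aligned}
```

Every factor in the sum is non-negative: $|\cdot| \ge 0$, and $\operatorname{ELU}'(z) = 1$ for $z > 0$ and $e^{z} > 0$ otherwise. The gradient is therefore non-negative for every state. The biases $\mathbf{b}_1(s)$ and $b_2(s)$ can take any sign without breaking monotonicity.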


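3. Training Procedure

QMIX is trained end to end within the (deep) Q-learning framework, as follows:

3.1 Experience Collection

        At every time step, each agent selects an action from its own Q network under the current exploration policy (ε-greedy in the code below), and the team interacts with the environment. This produces the sample $(s, \mathbf{o}, \mathbf{a}, r, s')$, where:

  • $\mathbf{o}$ is the set of all agents' observations;
  • $\mathbf{a}$ is the set of all agents' actions;
  • $r$ is the shared global reward;
  • $s'$ is the next state.

        The next observations $\mathbf{o}'$ are stored as well, because the target Q values are computed from them. The code below stores complete episodes, since the recurrent agent networks must be unrolled over time during training.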
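The following is a sketch of decentralised ε-greedy action selection with a linear ε schedule, assuming NumPy. The defaults mirror the `--epsilon`, `--epsilon_min` and `--epsilon_decay_steps` arguments of the main script; the function names are hypothetical. Like `choose_action` in `qmix_smac.py`, it flips one coin for the whole team and never selects an unavailable action.

```python
import numpy as np


def linear_epsilon(step, eps_start=1.0, eps_min=0.05, decay_steps=50_000):
    """Linearly anneal epsilon from eps_start to eps_min over decay_steps steps, then hold it."""
    return max(eps_min, eps_start - step * (eps_start - eps_min) / decay_steps)


def epsilon_greedy(q_values, avail, epsilon, rng):
    """Decentralised epsilon-greedy: each agent uses only its own row of Q values.
    q_values, avail: arrays of shape (N, n_actions); avail[i, a] == 1 if agent i may take a."""
    if rng.random() < epsilon:  # one coin flip for the whole team
        return np.array([rng.choice(np.flatnonzero(row)) for row in avail])
    masked = np.where(avail == 1, q_values, -np.inf)  # unavailable actions never win the argmax
    return masked.argmax(axis=-1)


rng = np.random.default_rng(0)
q = np.array([[0.1, 0.9, 0.3],
              [0.5, 0.2, 0.7]])
avail = np.array([[1, 0, 1],
                  [1, 1, 1]])
print(epsilon_greedy(q, avail, epsilon=0.0, rng=rng))  # [2 2]: agent 0 may not take action 1
```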

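3.2 Target Q Calculation

        Compute the target value for the next state $s'$:

$$ y = r + \gamma \max_{\mathbf{a}'} Q_{tot}(s', \mathbf{a}'; \theta^-) $$

        Here $\gamma$ is the discount factor and $\theta^-$ are the parameters of the target network, which is updated with a delay.

        Thanks to the monotonicity constraint, the maximisation over joint actions $\mathbf{a}'$ does not require enumerating all $|A|^n$ combinations. Each agent takes $\max_{a_i'} Q_i(o_i', a_i')$ with the target network, and the resulting values are fed into the target mixing network.

        At terminal transitions the bootstrap term is dropped; this is the `dw` flag in the code below. The code can also use Double Q-learning: it chooses $a_i'$ with the current network and evaluates that action with the target network.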
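The claim that the max over joint actions decomposes into per-agent maxima can be checked numerically. The sketch below, assuming NumPy, uses a toy single-state mixer with random weights; all names are hypothetical. It compares the decentralised target with a brute-force search over all $|A|^n$ joint actions. It uses the plain max rather than the Double Q-learning variant, and `done` plays the role of the `dw` flag.

```python
import itertools

import numpy as np


def monotonic_mix(q, w1, b1, w2, b2):
    """Toy two-layer monotonic mixer for one state: |w1|, |w2| >= 0 and ELU is increasing."""
    z = q @ np.abs(w1) + b1
    h = np.where(z > 0, z, np.exp(z) - 1.0)  # ELU
    return float(h @ np.abs(w2) + b2)


def td_target_decentralised(r, gamma, done, q_next, params):
    """y = r + gamma * (1 - done) * max Q_tot, using only per-agent maxima."""
    best = q_next.max(axis=1)  # each agent maximises its own Q_i
    return r + gamma * (1.0 - done) * monotonic_mix(best, *params)


def td_target_bruteforce(r, gamma, done, q_next, params):
    """The same target, but maximising Q_tot over all |A|^n joint actions."""
    n, n_actions = q_next.shape
    best = max(monotonic_mix(q_next[np.arange(n), list(a)], *params)
               for a in itertools.product(range(n_actions), repeat=n))
    return r + gamma * (1.0 - done) * best


rng = np.random.default_rng(1)
n, n_actions, d = 3, 4, 8
params = (rng.normal(size=(n, d)), rng.normal(size=d), rng.normal(size=d), rng.normal())
q_next = rng.normal(size=(n, n_actions))
y1 = td_target_decentralised(1.0, 0.99, 0.0, q_next, params)
y2 = td_target_bruteforce(1.0, 0.99, 0.0, q_next, params)
print(np.isclose(y1, y2))  # True
```

Here the brute-force version visits $4^3 = 64$ joint actions, while the decentralised one looks at only $3 \times 4 = 12$ individual Q values. Both return the same target.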

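3.3 Loss Function and Optimization

         The parameters of the mixing network (including its hypernetworks) and of the individual Q networks are updated jointly by minimising the squared TD error:

$$ L(\theta) = \mathbb{E}\left[\left(Q_{tot}(s, \mathbf{a}; \theta) - y\right)^2\right] $$

        The gradient is backpropagated through the mixing network into every agent network and applied with a stochastic gradient method. The original paper uses RMSprop; the code below defaults to Adam, with RMSprop available via `--use_RMS`.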
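The replay buffer stores whole episodes padded to the longest episode in the batch, so in practice the expectation is an average over real time steps only. Below is a sketch of that masked loss, assuming PyTorch. The function name is hypothetical; the body follows the `td_error`/`active` lines of `train` in `qmix_smac.py`.

```python
import torch


def masked_td_loss(q_tot, y, active):
    """Mean squared TD error over real (non-padded) time steps.
    q_tot, y, active: tensors of shape (batch, T, 1); active is 1 for real steps, 0 for padding."""
    td = (q_tot - y.detach()) * active  # y is a fixed target: no gradient flows into it
    return (td ** 2).sum() / active.sum()


q_tot = torch.tensor([[[1.0], [2.0], [5.0]]], requires_grad=True)
y = torch.tensor([[[0.0], [2.0], [0.0]]])
active = torch.tensor([[[1.0], [1.0], [0.0]]])  # the third step is padding
loss = masked_td_loss(q_tot, y, active)
print(loss.item())  # 0.5 = (1^2 + 0^2) / 2 real steps
```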

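3.4 Target Network Updates

        To stabilise training, QMIX uses target networks. By default, the target parameters $\theta^-$ are copied from the current parameters $\theta$ at a lower frequency. In the code below this hard update happens every `target_update_freq` = 200 training steps. The code also offers a soft update, $\theta^- \leftarrow \tau\theta + (1-\tau)\theta^-$, as an alternative.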
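Both update rules can be sketched in a few lines, assuming PyTorch modules. The helper names are hypothetical; `qmix_smac.py` performs the same operations inline for both the agent network and the mixing network.

```python
import torch
import torch.nn as nn


def hard_update(target: nn.Module, source: nn.Module) -> None:
    """theta^- <- theta: copy all parameters (and buffers) at once."""
    target.load_state_dict(source.state_dict())


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """theta^- <- tau * theta + (1 - tau) * theta^- (Polyak averaging), applied in place."""
    for p, p_targ in zip(source.parameters(), target.parameters()):
        p_targ.mul_(1.0 - tau).add_(tau * p)


src, tgt = nn.Linear(2, 1), nn.Linear(2, 1)
with torch.no_grad():
    src.weight.fill_(1.0)
    src.bias.fill_(1.0)
hard_update(tgt, src)  # tgt now holds exactly the same values as src
```

A hard update keeps the target fixed for `target_update_freq` steps and then jumps. A soft update moves the target a little on every training step; with the default `tau` = 0.005 it has a time constant of about $1/\tau = 200$ training steps.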


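4. Advantages of QMIX

  • Cooperation: because training optimises the global Q value, QMIX can capture the cooperative relationships between agents.
  • Scalability: thanks to the monotonic mixing network, greedy action selection and the max in the TD target decompose into one maximisation per agent. The cost therefore grows linearly rather than exponentially with the number of agents, which lets QMIX scale to environments with more agents.
  • Flexibility: the mixing network is non-linear and conditioned on the state. QMIX can therefore represent a richer class of joint value functions than purely additive decompositions such as VDN ($Q_{tot} = \sum_i Q_i$), which helps on complex cooperative tasks.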

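5. Applications of QMIX

  • Distributed robot control: when several robots must cooperate to complete a task, QMIX can learn effective cooperative policies.
  • Team game AI: QMIX is widely used to train multi-agent AI for games that require teamwork. It is a standard baseline on the StarCraft Multi-Agent Challenge (SMAC) micromanagement tasks.
  • Resource allocation and management: in smart grids or multi-UAV systems, QMIX can handle resource coordination among multiple agents.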

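6. Limitations and Improvements

  • Limited representational capacity: because of the monotonicity constraint, QMIX cannot represent joint value functions in which an agent's ranking of its own actions depends on what the other agents do (non-monotonic payoffs). On such tasks it may fail to learn the optimal joint policy.
  • Sample efficiency: in high-dimensional environments QMIX needs many samples and takes a long time to train.
  • Improvements: later algorithms such as QTRAN and QPLEX try, in different ways, to relax these limitations and further improve the performance of multi-agent reinforcement learning.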
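The limitation can be made concrete for two agents. Suppose $Q_{tot} = f(Q_1(a_1), Q_2(a_2))$ with $f$ non-decreasing in each argument. Then $Q(a, b) > Q(a, b')$ forces $Q_2(b) > Q_2(b')$, which in turn forces $Q(a', b) \ge Q(a', b')$ for every $a'$. If an agent's ranking of two of its actions flips depending on the other agent's action, no monotonic factorisation can reproduce the payoff exactly. The sketch below, assuming NumPy, checks a payoff matrix for such a flip. The function names and the 3×3 game are hypothetical illustrations.

```python
import itertools

import numpy as np


def _column_flip(payoff: np.ndarray) -> bool:
    """True if the column player's ranking of two of its actions reverses across rows."""
    for b, b_other in itertools.combinations(range(payoff.shape[1]), 2):
        diff = payoff[:, b] - payoff[:, b_other]
        if (diff > 0).any() and (diff < 0).any():
            return True
    return False


def has_preference_flip(payoff: np.ndarray) -> bool:
    """Check both agents of a 2-agent matrix game (rows = agent 1, columns = agent 2)."""
    return _column_flip(payoff) or _column_flip(payoff.T)


# Hypothetical cooperative game: (0, 0) is optimal, but mis-coordinating with it is heavily penalised.
payoff = np.array([[8.0, -12.0, -12.0],
                   [-12.0, 0.0, 0.0],
                   [-12.0, 0.0, 0.0]])
print(has_preference_flip(payoff))  # True
```

The check only shows that such a payoff cannot be represented exactly. How far QMIX's learned greedy policy then falls from the optimum depends on the task and on training.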

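 [Python] QMIX Implementation (Portable)

        If you have trouble reproducing the code below or run into problems, please leave a comment. If you need the code as a complete project, leave your email address in the comments so that I can share it with you promptly; I can't always reply to private messages in time.

Main script: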

"""
@content: QMIX
@author: 不去幼儿园
@Timeline: 2024.08.21
"""
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from env_base import Env  # @porting note: import your environment
import argparse
from replay_buffer import ReplayBuffer  # @porting note: import the other modules
from qmix_smac import QMIX_SMAC
from normalization import Normalization


class Runner_QMIX_SMAC:
    def __init__(self, args, env_name, number, seed):
        self.args = args
        self.env_name = env_name
        self.number = number
        self.seed = seed
        # Set random seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Create env
        # self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed)
        """@porting notes
        1. Declare the environment
        2. Set the environment parameters: mind their formats
        """
        self.env = Env()    # @porting note: declare the environment
        # self.env_info = self.env.get_env_info()
        self.args.N = 3  # The number of agents
        self.args.obs_dim = 15  # The dimensions of an agent's observation space
        self.args.state_dim = 100+9*3+2*3  # The dimensions of global state space
        self.args.action_dim = 9  # The dimensions of an agent's action space
        self.args.episode_limit = 50  # Maximum number of steps per episode
        print("number of agents={}".format(self.args.N))
        print("obs_dim={}".format(self.args.obs_dim))
        print("state_dim={}".format(self.args.state_dim))
        print("action_dim={}".format(self.args.action_dim))
        print("episode_limit={}".format(self.args.episode_limit))

        # Create N agents
        self.agent_n = QMIX_SMAC(self.args)
        self.replay_buffer = ReplayBuffer(self.args)

        # Create a tensorboard
        self.writer = SummaryWriter(log_dir='./runs/{}/{}_env_{}_number_{}_seed_{}'.format(self.args.algorithm, self.args.algorithm, self.env_name, self.number, self.seed))

        self.epsilon = self.args.epsilon  # Initialize the epsilon
        self.win_rates = []  # Record the win rates
        self.total_steps = 0
        if self.args.use_reward_norm:
            print("------use reward norm------")
            self.reward_norm = Normalization(shape=1)

    def run(self, ):
        evaluate_num = -1  # Record the number of evaluations
        while self.total_steps < self.args.max_train_steps:
            if self.total_steps // self.args.evaluate_freq > evaluate_num:
                self.evaluate_policy()  # Evaluate the policy every 'evaluate_freq' steps
                evaluate_num += 1

            _, _, episode_steps = self.run_episode_smac(evaluate=False)  # Run an episode
            self.total_steps += episode_steps

            if self.replay_buffer.current_size >= self.args.batch_size:
                self.agent_n.train(self.replay_buffer, self.total_steps)  # Training

        self.evaluate_policy()
        # self.env.close()

    def evaluate_policy(self, ):
        win_times = 0
        evaluate_reward = 0
        goal_num_buffer__ = []
        for _ in range(self.args.evaluate_times):
            win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True)

            """获取其他状态数据"""
            goal_num_buffer_ = self.env.get_state_data()  # @移植事项:其他状态获取
            goal_num_buffer_ = np.array(goal_num_buffer_)
            goal_num_buffer__.append(goal_num_buffer_)

            if win_tag:
                win_times += 1
            evaluate_reward += episode_reward

        goal_num_buffer = np.sum(goal_num_buffer__[:], axis=0) / self.args.evaluate_times
        log_flag = ["state/target_num", "state/target_num", "state/crash_num",
                    "state/ratio"]
        for i in range(4):
            goal_num = goal_num_buffer[i]
            goal_num = {log_flag[i]: goal_num}
            log_state(name=log_flag[i], state=goal_num, step=self.total_steps)

        win_rate = win_times / self.args.evaluate_times
        evaluate_reward = evaluate_reward / self.args.evaluate_times

        reward_total = {"state/reward_total": evaluate_reward}
        log_state(name="state/reward_total", state=reward_total, step=self.total_steps)

        self.win_rates.append(win_rate)
        print("total_steps:{}\tepisode:{}\tevaluate_reward:{:.3f}\t"
              "target_num:{:.3f}\ttarget_num:{:.3f}\tcrash_num:{:.3f}\tratio:{:.3f}"
              .format(self.total_steps, int(self.total_steps / self.args.evaluate_freq + 1), evaluate_reward,
                      goal_num_buffer[0], goal_num_buffer[1], goal_num_buffer[2], goal_num_buffer[3]))
        # self.writer.add_scalar('win_rate_{}'.format(self.env_name), win_rate, global_step=self.total_steps)
        # Save the win rates
        np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.args.algorithm, self.env_name, self.number, self.seed), np.array(self.win_rates))

    def run_episode_smac(self, evaluate=False):
        win_tag = False
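        episode_reward = 0
        """@porting notes: running the environment
        1. Set up the environment reset function
        2. Set up the functions that return the environment state: mind the format
        3. Step the environment: mind the return values
        """
        self.env.reset()  # @porting note: environment reset function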
        if self.args.use_rnn:  # If use RNN, before the beginning of each episode,reset the rnn_hidden of the Q network.
            self.agent_n.eval_Q_net.rnn_hidden = None
        last_onehot_a_n = np.zeros((self.args.N, self.args.action_dim))  # Last actions of N agents(one-hot)
        for episode_step in range(self.args.episode_limit):
            obs_n = self.env.get_obs()  # obs_n.shape=(N,obs_dim)  # @porting note: get the observations
            s = self.env.get_state()  # s.shape=(state_dim,)  # @porting note: get the global state
            # avail_a_n = self.env.get_avail_actions()  # Get available actions of N agents, avail_a_n.shape=(N,action_dim)
            avail_a_n = [[1] * self.args.action_dim for _ in range(self.args.N)]  # All actions are always available in this environment
            epsilon = 0 if evaluate else self.epsilon

            a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, avail_a_n, epsilon)

            last_onehot_a_n = np.eye(self.args.action_dim)[a_n]  # Convert actions to one-hot vectors
            _, r_, done_, info = self.env.step(a_n)  # @porting note: step the environment
            done = done_[0]
            r = sum(list(np.array(r_).flatten()))
            win_tag = True if done and 'battle_won' in info and info['battle_won'] else False
            episode_reward += r

            if not evaluate:
                if self.args.use_reward_norm:
                    r = self.reward_norm(r)
                """
                    When dead or win or reaching the episode_limit, done will be True, we need to distinguish them;
                    dw means dead or win, there is no next state s';
                    but when reaching the max_episode_steps, there is a next state s' actually.
                """
                if done and episode_step + 1 != self.args.episode_limit:
                    dw = True
                else:
                    dw = False

                # Store the transition
                self.replay_buffer.store_transition(episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw)
                # Decay the epsilon
                self.epsilon = self.epsilon - self.args.epsilon_decay if self.epsilon - self.args.epsilon_decay > self.args.epsilon_min else self.args.epsilon_min

            if done:
                break

        if not evaluate:
            # An episode is over, store obs_n, s and avail_a_n in the last step
            obs_n = self.env.get_obs()  # @porting note
            s = self.env.get_state()  # @porting note

            # avail_a_n = self.env.get_avail_actions()
            avail_a_n = [[1] * self.args.action_dim for _ in range(self.args.N)]

            self.replay_buffer.store_last_step(episode_step + 1, obs_n, s, avail_a_n)

        return win_tag, episode_reward, episode_step + 1

# Result logging (TensorBoard), reusing the SummaryWriter imported from torch.utils.tensorboard above
writer = SummaryWriter()
def log_state(name, state, step):
    writer.add_scalars(name, state, step)


if __name__ == '__main__':
    def str2bool(v):
        # argparse's type=bool turns any non-empty string, including "False", into True
        return str(v).lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX and VDN in SMAC environment")
    parser.add_argument("--max_train_steps", type=int, default=int(1e6), help=" Maximum number of training steps")
    parser.add_argument("--evaluate_freq", type=int, default=1250, help="Evaluate the policy every 'evaluate_freq' steps")
    parser.add_argument("--evaluate_times", type=int, default=5, help="Evaluate times")
    parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency")

    parser.add_argument("--algorithm", type=str, default="QMIX", help="QMIX or VDN")
    parser.add_argument("--epsilon", type=float, default=1.0, help="Initial epsilon")
    parser.add_argument("--epsilon_decay_steps", type=float, default=50000, help="How many steps before the epsilon decays to the minimum")
    parser.add_argument("--epsilon_min", type=float, default=0.05, help="Minimum epsilon")
    parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--qmix_hidden_dim", type=int, default=32, help="The dimension of the hidden layer of the QMIX network")
    parser.add_argument("--hyper_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the hyper-network")
    parser.add_argument("--hyper_layers_num", type=int, default=1, help="The number of layers of hyper-network")
    parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN")
    parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP")
    parser.add_argument("--use_rnn", type=str2bool, default=True, help="Whether to use RNN")
    parser.add_argument("--use_orthogonal_init", type=str2bool, default=True, help="Orthogonal initialization")
    parser.add_argument("--use_grad_clip", type=str2bool, default=True, help="Gradient clip")
    parser.add_argument("--use_lr_decay", type=str2bool, default=False, help="use lr decay")
    parser.add_argument("--use_RMS", type=str2bool, default=False, help="Whether to use RMS,if False, we will use Adam")
    parser.add_argument("--add_last_action", type=str2bool, default=True, help="Whether to add last actions into the observation")
    parser.add_argument("--add_agent_id", type=str2bool, default=True, help="Whether to add agent id into the observation")
    parser.add_argument("--use_double_q", type=str2bool, default=True, help="Whether to use double q-learning")
    parser.add_argument("--use_reward_norm", type=str2bool, default=False, help="Whether to use reward normalization")
    parser.add_argument("--use_hard_update", type=str2bool, default=True, help="Whether to use hard update")
    parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network")
    parser.add_argument("--tau", type=float, default=0.005, help="Soft-update coefficient (used when use_hard_update is False)")

    args = parser.parse_args()
    args.epsilon_decay = (args.epsilon - args.epsilon_min) / args.epsilon_decay_steps

    env_names = ['3m', '8m', '2s3z']
    env_index = 0
    runner = Runner_QMIX_SMAC(args, env_name=env_names[env_index], number=1, seed=0)
    runner.run()

from replay_buffer import ReplayBuffer 

replay_buffer.py file

import numpy as np
import torch
import copy


class ReplayBuffer:
    def __init__(self, args):
        self.N = args.N
        self.obs_dim = args.obs_dim
        self.state_dim = args.state_dim
        self.action_dim = args.action_dim
        self.episode_limit = args.episode_limit
        self.buffer_size = args.buffer_size
        self.batch_size = args.batch_size
        self.episode_num = 0
        self.current_size = 0
        self.buffer = {'obs_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.obs_dim]),
                       's': np.zeros([self.buffer_size, self.episode_limit + 1, self.state_dim]),
                       'avail_a_n': np.ones([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]),  # Note: We use 'np.ones' to initialize 'avail_a_n'
                       'last_onehot_a_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]),
                       'a_n': np.zeros([self.buffer_size, self.episode_limit, self.N]),
                       'r': np.zeros([self.buffer_size, self.episode_limit, 1]),
                       'dw': np.ones([self.buffer_size, self.episode_limit, 1]),  # Note: We use 'np.ones' to initialize 'dw'
                       'active': np.zeros([self.buffer_size, self.episode_limit, 1])
                       }
        self.episode_len = np.zeros(self.buffer_size)

    def store_transition(self, episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw):
        self.buffer['obs_n'][self.episode_num][episode_step] = obs_n
        self.buffer['s'][self.episode_num][episode_step] = s
        self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n
        self.buffer['last_onehot_a_n'][self.episode_num][episode_step + 1] = last_onehot_a_n
        self.buffer['a_n'][self.episode_num][episode_step] = a_n
        self.buffer['r'][self.episode_num][episode_step] = r
        self.buffer['dw'][self.episode_num][episode_step] = dw

        self.buffer['active'][self.episode_num][episode_step] = 1.0

    def store_last_step(self, episode_step, obs_n, s, avail_a_n):
        self.buffer['obs_n'][self.episode_num][episode_step] = obs_n
        self.buffer['s'][self.episode_num][episode_step] = s
        self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n
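        self.episode_len[self.episode_num] = episode_step  # Record the length of this episode
        self.buffer['active'][self.episode_num][episode_step:] = 0  # Clear stale 'active' flags left by an older, longer episode in this slot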
        self.episode_num = (self.episode_num + 1) % self.buffer_size
        self.current_size = min(self.current_size + 1, self.buffer_size)

    def sample(self):
        # Randomly sampling
        index = np.random.choice(self.current_size, size=self.batch_size, replace=False)
        max_episode_len = int(np.max(self.episode_len[index]))
        batch = {}
        for key in self.buffer.keys():
            if key == 'obs_n' or key == 's' or key == 'avail_a_n' or key == 'last_onehot_a_n':
                batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len + 1], dtype=torch.float32)
            elif key == 'a_n':
                batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.long)
            else:
                batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.float32)

        return batch, max_episode_len

from qmix_smac import QMIX_SMAC

qmix_smac.py file

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mix_net import QMIX_Net, VDN_Net


# orthogonal initialization
def orthogonal_init(layer, gain=1.0):
    for name, param in layer.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param, 0)
        elif 'weight' in name:
            nn.init.orthogonal_(param, gain=gain)


class Q_network_RNN(nn.Module):
    def __init__(self, args, input_dim):
        super(Q_network_RNN, self).__init__()
        self.rnn_hidden = None

        self.fc1 = nn.Linear(input_dim, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.action_dim)
        if args.use_orthogonal_init:
            print("------use_orthogonal_init------")
            orthogonal_init(self.fc1)
            orthogonal_init(self.rnn)
            orthogonal_init(self.fc2)

    def forward(self, inputs):
        # When 'choose_action', inputs.shape(N,input_dim)
        # When 'train', inputs.shape(batch_size*N,input_dim)
        x = F.relu(self.fc1(inputs))
        self.rnn_hidden = self.rnn(x, self.rnn_hidden)
        Q = self.fc2(self.rnn_hidden)
        return Q


class Q_network_MLP(nn.Module):
    def __init__(self, args, input_dim):
        super(Q_network_MLP, self).__init__()
        self.rnn_hidden = None

        self.fc1 = nn.Linear(input_dim, args.mlp_hidden_dim)
        self.fc2 = nn.Linear(args.mlp_hidden_dim, args.mlp_hidden_dim)
        self.fc3 = nn.Linear(args.mlp_hidden_dim, args.action_dim)
        if args.use_orthogonal_init:
            print("------use_orthogonal_init------")
            orthogonal_init(self.fc1)
            orthogonal_init(self.fc2)
            orthogonal_init(self.fc3)

    def forward(self, inputs):
        # When 'choose_action', inputs.shape(N,input_dim)
        # When 'train', inputs.shape(batch_size,max_episode_len,N,input_dim)
        x = F.relu(self.fc1(inputs))
        x = F.relu(self.fc2(x))
        Q = self.fc3(x)
        return Q


class QMIX_SMAC(object):
    def __init__(self, args):
        self.N = args.N
        self.action_dim = args.action_dim
        self.obs_dim = args.obs_dim
        self.state_dim = args.state_dim
        self.add_last_action = args.add_last_action
        self.add_agent_id = args.add_agent_id
        self.max_train_steps=args.max_train_steps
        self.lr = args.lr
        self.gamma = args.gamma
        self.use_grad_clip = args.use_grad_clip
        self.batch_size = args.batch_size  # Here batch_size is the number of episodes per batch
        self.target_update_freq = args.target_update_freq
        self.tau = args.tau
        self.use_hard_update = args.use_hard_update
        self.use_rnn = args.use_rnn
        self.algorithm = args.algorithm
        self.use_double_q = args.use_double_q
        self.use_RMS = args.use_RMS
        self.use_lr_decay = args.use_lr_decay

        # Compute the input dimension
        self.input_dim = self.obs_dim
        if self.add_last_action:
            print("------add last action------")
            self.input_dim += self.action_dim
        if self.add_agent_id:
            print("------add agent id------")
            self.input_dim += self.N

        if self.use_rnn:
            print("------use RNN------")
            self.eval_Q_net = Q_network_RNN(args, self.input_dim)
            self.target_Q_net = Q_network_RNN(args, self.input_dim)
        else:
            print("------use MLP------")
            self.eval_Q_net = Q_network_MLP(args, self.input_dim)
            self.target_Q_net = Q_network_MLP(args, self.input_dim)
        self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())

        if self.algorithm == "QMIX":
            print("------algorithm: QMIX------")
            self.eval_mix_net = QMIX_Net(args)
            self.target_mix_net = QMIX_Net(args)
        elif self.algorithm == "VDN":
            print("------algorithm: VDN------")
            self.eval_mix_net = VDN_Net()
            self.target_mix_net = VDN_Net()
        else:
            raise ValueError("args.algorithm must be 'QMIX' or 'VDN'")
        self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())

        self.eval_parameters = list(self.eval_mix_net.parameters()) + list(self.eval_Q_net.parameters())
        if self.use_RMS:
            print("------optimizer: RMSprop------")
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=self.lr)
        else:
            print("------optimizer: Adam------")
            self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

        self.train_step = 0

    def choose_action(self, obs_n, last_onehot_a_n, avail_a_n, epsilon):
        with torch.no_grad():
            if np.random.uniform() < epsilon:  # epsilon-greedy
                # Only available actions can be chosen
                a_n = [np.random.choice(np.nonzero(avail_a)[0]) for avail_a in avail_a_n]
            else:
                inputs = []
                obs_n = torch.tensor(obs_n, dtype=torch.float32)  # obs_n.shape=(N,obs_dim)
                inputs.append(obs_n)
                if self.add_last_action:
                    last_a_n = torch.tensor(last_onehot_a_n, dtype=torch.float32)
                    inputs.append(last_a_n)
                if self.add_agent_id:
                    inputs.append(torch.eye(self.N))

                inputs = torch.cat([x for x in inputs], dim=-1)  # inputs.shape=(N,inputs_dim)
                q_value = self.eval_Q_net(inputs)

                avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32)  # avail_a_n.shape=(N, action_dim)
                q_value[avail_a_n == 0] = -float('inf')  # Mask the unavailable actions
                a_n = q_value.argmax(dim=-1).numpy()
            return a_n

    def train(self, replay_buffer, total_steps):
        batch, max_episode_len = replay_buffer.sample()  # Get training data
        self.train_step += 1

        inputs = self.get_inputs(batch, max_episode_len)  # inputs.shape=(batch_size,max_episode_len+1,N,input_dim)
        if self.use_rnn:
            self.eval_Q_net.rnn_hidden = None
            self.target_Q_net.rnn_hidden = None
            q_evals, q_targets = [], []
            for t in range(max_episode_len):  # t=0,1,2,...(episode_len-1)
                q_eval = self.eval_Q_net(inputs[:, t].reshape(-1, self.input_dim))  # q_eval.shape=(batch_size*N,action_dim)
                q_target = self.target_Q_net(inputs[:, t + 1].reshape(-1, self.input_dim))
                q_evals.append(q_eval.reshape(self.batch_size, self.N, -1))  # q_eval.shape=(batch_size,N,action_dim)
                q_targets.append(q_target.reshape(self.batch_size, self.N, -1))

            # Stack them according to the time (dim=1)
            q_evals = torch.stack(q_evals, dim=1)  # q_evals.shape=(batch_size,max_episode_len,N,action_dim)
            q_targets = torch.stack(q_targets, dim=1)
        else:
            q_evals = self.eval_Q_net(inputs[:, :-1])  # q_evals.shape=(batch_size,max_episode_len,N,action_dim)
            q_targets = self.target_Q_net(inputs[:, 1:])

        with torch.no_grad():
            if self.use_double_q:  # If use double q-learning, we use eval_net to choose actions,and use target_net to compute q_target
                q_eval_last = self.eval_Q_net(inputs[:, -1].reshape(-1, self.input_dim)).reshape(self.batch_size, 1, self.N, -1)
                q_evals_next = torch.cat([q_evals[:, 1:], q_eval_last], dim=1) # q_evals_next.shape=(batch_size,max_episode_len,N,action_dim)
                q_evals_next[batch['avail_a_n'][:, 1:] == 0] = -999999
                a_argmax = torch.argmax(q_evals_next, dim=-1, keepdim=True)  # a_max.shape=(batch_size,max_episode_len, N, 1)
                q_targets = torch.gather(q_targets, dim=-1, index=a_argmax).squeeze(-1)  # q_targets.shape=(batch_size, max_episode_len, N)
            else:
                q_targets[batch['avail_a_n'][:, 1:] == 0] = -999999
                q_targets = q_targets.max(dim=-1)[0]  # q_targets.shape=(batch_size, max_episode_len, N)

        # batch['a_n'].shape(batch_size,max_episode_len, N)
        q_evals = torch.gather(q_evals, dim=-1, index=batch['a_n'].unsqueeze(-1)).squeeze(-1)  # q_evals.shape(batch_size, max_episode_len, N)

        # Compute q_total using QMIX or VDN, q_total.shape=(batch_size, max_episode_len, 1)
        if self.algorithm == "QMIX":
            q_total_eval = self.eval_mix_net(q_evals, batch['s'][:, :-1])
            q_total_target = self.target_mix_net(q_targets, batch['s'][:, 1:])
        else:
            q_total_eval = self.eval_mix_net(q_evals)
            q_total_target = self.target_mix_net(q_targets)
        # targets.shape=(batch_size,max_episode_len,1)
        targets = batch['r'] + self.gamma * (1 - batch['dw']) * q_total_target

        td_error = (q_total_eval - targets.detach())
        mask_td_error = td_error * batch['active']
        loss = (mask_td_error ** 2).sum() / batch['active'].sum()
        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.eval_parameters, 10)
        self.optimizer.step()

        if self.use_hard_update:
            # hard update
            if self.train_step % self.target_update_freq == 0:
                self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict())
                self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict())
        else:
            # Softly update the target networks
            for param, target_param in zip(self.eval_Q_net.parameters(), self.target_Q_net.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.eval_mix_net.parameters(), self.target_mix_net.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        if self.use_lr_decay:
            self.lr_decay(total_steps)

    def lr_decay(self, total_steps):  # Learning rate Decay
        lr_now = self.lr * (1 - total_steps / self.max_train_steps)
        for p in self.optimizer.param_groups:
            p['lr'] = lr_now

    def get_inputs(self, batch, max_episode_len):
        inputs = []
        inputs.append(batch['obs_n'])
        if self.add_last_action:
            inputs.append(batch['last_onehot_a_n'])
        if self.add_agent_id:
            agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(self.batch_size, max_episode_len + 1, 1, 1)
            inputs.append(agent_id_one_hot)

        # inputs.shape=(batch_size,max_episode_len+1,N,input_dim)
        inputs = torch.cat([x for x in inputs], dim=-1)

        return inputs

    def save_model(self, env_name, algorithm, number, seed, total_steps):
        torch.save(self.eval_Q_net.state_dict(), "./model/{}/{}_eval_rnn_number_{}_seed_{}_step_{}k.pth".format(env_name, algorithm, number, seed, int(total_steps / 1000)))

from normalization import Normalization

normalization.py file

import numpy as np


class RunningMeanStd:
    # Dynamically calculate mean and std
    def __init__(self, shape):  # shape:the dimension of input data
        self.n = 0
        self.mean = np.zeros(shape)
        self.S = np.zeros(shape)
        self.std = np.sqrt(self.S)

    def update(self, x):
        x = np.array(x)
        self.n += 1
        if self.n == 1:
            self.mean = x
            self.std = x
        else:
            old_mean = self.mean.copy()
            self.mean = old_mean + (x - old_mean) / self.n
            self.S = self.S + (x - old_mean) * (x - self.mean)
            self.std = np.sqrt(self.S / self.n)


class Normalization:
    def __init__(self, shape):
        self.running_ms = RunningMeanStd(shape=shape)

    def __call__(self, x, update=True):
        # Whether to update the mean and std,during the evaluating,update=False
        if update:
            self.running_ms.update(x)
        x = (x - self.running_ms.mean) / (self.running_ms.std + 1e-8)

        return x


class RewardScaling:
    def __init__(self, shape, gamma):
        self.shape = shape  # reward shape=1
        self.gamma = gamma  # discount factor
        self.running_ms = RunningMeanStd(shape=self.shape)
        self.R = np.zeros(self.shape)

    def __call__(self, x):
        self.R = self.gamma * self.R + x
        self.running_ms.update(self.R)
        x = x / (self.running_ms.std + 1e-8)  # Only divided std
        return x

    def reset(self):  # When an episode is done,we should reset 'self.R'
        self.R = np.zeros(self.shape)

from mix_net import QMIX_Net, VDN_Net

mix_net.py file

import torch
import torch.nn as nn
import torch.nn.functional as F


class QMIX_Net(nn.Module):
    def __init__(self, args):
        super(QMIX_Net, self).__init__()
        self.N = args.N
        self.state_dim = args.state_dim
        self.batch_size = args.batch_size
        self.qmix_hidden_dim = args.qmix_hidden_dim
        self.hyper_hidden_dim = args.hyper_hidden_dim
        self.hyper_layers_num = args.hyper_layers_num
        """
        w1:(N, qmix_hidden_dim)
        b1:(1, qmix_hidden_dim)
        w2:(qmix_hidden_dim, 1)
        b2:(1, 1)
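        hyper_w1 must produce a matrix, but a PyTorch linear layer can only output a vector,
        so the hypernetwork first outputs a vector of length rows*cols and then reshapes it into the matrix.
        """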
        if self.hyper_layers_num == 2:
            print("hyper_layers_num=2")
            self.hyper_w1 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(self.hyper_hidden_dim, self.N * self.qmix_hidden_dim))
            self.hyper_w2 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(self.hyper_hidden_dim, self.qmix_hidden_dim * 1))
        elif self.hyper_layers_num == 1:
            print("hyper_layers_num=1")
            self.hyper_w1 = nn.Linear(self.state_dim, self.N * self.qmix_hidden_dim)
            self.hyper_w2 = nn.Linear(self.state_dim, self.qmix_hidden_dim * 1)
        else:
            raise ValueError("hyper_layers_num must be 1 or 2")

        self.hyper_b1 = nn.Linear(self.state_dim, self.qmix_hidden_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(self.state_dim, self.qmix_hidden_dim),
                                      nn.ReLU(),
                                      nn.Linear(self.qmix_hidden_dim, 1))

    def forward(self, q, s):
        # q.shape: (batch_size, max_episode_len, N)
        # s.shape: (batch_size, max_episode_len, state_dim)
        batch_size = q.shape[0]  # read from the input instead of args, so evaluation or partial batches also work
        q = q.reshape(-1, 1, self.N)  # (batch_size * max_episode_len, 1, N)
        s = s.reshape(-1, self.state_dim)  # (batch_size * max_episode_len, state_dim)

        # torch.abs keeps the mixing weights non-negative, so Q_tot is non-decreasing in every Q_i
        w1 = torch.abs(self.hyper_w1(s))  # (batch_size * max_episode_len, N * qmix_hidden_dim)
        b1 = self.hyper_b1(s)  # (batch_size * max_episode_len, qmix_hidden_dim)
        w1 = w1.view(-1, self.N, self.qmix_hidden_dim)  # (batch_size * max_episode_len, N, qmix_hidden_dim)
        b1 = b1.view(-1, 1, self.qmix_hidden_dim)  # (batch_size * max_episode_len, 1, qmix_hidden_dim)

        # torch.bmm: batched matrix multiplication of 3-D tensors
        q_hidden = F.elu(torch.bmm(q, w1) + b1)  # (batch_size * max_episode_len, 1, qmix_hidden_dim)

        w2 = torch.abs(self.hyper_w2(s))  # (batch_size * max_episode_len, qmix_hidden_dim * 1)
        b2 = self.hyper_b2(s)  # (batch_size * max_episode_len, 1)
        w2 = w2.view(-1, self.qmix_hidden_dim, 1)  # (batch_size * max_episode_len, qmix_hidden_dim, 1)
        b2 = b2.view(-1, 1, 1)  # (batch_size * max_episode_len, 1, 1)

        q_total = torch.bmm(q_hidden, w2) + b2  # (batch_size * max_episode_len, 1, 1)
        q_total = q_total.view(batch_size, -1, 1)  # (batch_size, max_episode_len, 1)
        return q_total
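`QMIX_Net.forward` enforces monotonicity by taking `torch.abs` of the hypernetwork outputs for `w1` and `w2`; since ELU is increasing, Q_tot cannot decrease when any single Q_i increases. The NumPy sketch below mirrors the same computation so the property can be checked numerically without PyTorch. It is an illustration under assumptions: the sizes and random hypernetwork weights are hypothetical, every hypernetwork is a single linear map (the original `hyper_b2` is a two-layer MLP with ReLU), and `@` on 3-D arrays stands in for `torch.bmm`.

```python
import numpy as np

rng = np.random.default_rng(0)
N, STATE_DIM, HIDDEN = 3, 5, 8  # hypothetical: agents, state features, mixing width

# Hypothetical single-layer hypernetworks: linear maps from the global state
W_h1 = rng.normal(size=(STATE_DIM, N * HIDDEN))
W_h2 = rng.normal(size=(STATE_DIM, HIDDEN))
W_b1 = rng.normal(size=(STATE_DIM, HIDDEN))
W_b2 = rng.normal(size=(STATE_DIM, 1))


def elu(x):
    return np.where(x > 0, x, np.expm1(x))


def qmix_forward(q, s):
    """q: (B, N) chosen-action Q values, s: (B, STATE_DIM) states -> Q_tot of shape (B, 1)."""
    B = q.shape[0]                                 # batch size taken from the input
    w1 = np.abs(s @ W_h1).reshape(B, N, HIDDEN)    # abs -> non-negative weights
    b1 = (s @ W_b1).reshape(B, 1, HIDDEN)          # biases are unconstrained
    hidden = elu(q.reshape(B, 1, N) @ w1 + b1)     # batched matmul, like torch.bmm
    w2 = np.abs(s @ W_h2).reshape(B, HIDDEN, 1)
    b2 = (s @ W_b2).reshape(B, 1, 1)
    return (hidden @ w2 + b2).reshape(B, 1)


q = rng.normal(size=(4, N))
s = rng.normal(size=(4, STATE_DIM))
base = qmix_forward(q, s)

bumped = q.copy()
bumped[:, 1] += 0.5  # raise the second agent's Q value only
print(np.all(qmix_forward(bumped, s) >= base))
```

Because the biases depend only on the state, they can be negative without breaking monotonicity; only the weights multiplying the Q values need to be non-negative.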


class VDN_Net(nn.Module):
    def __init__(self):
        super(VDN_Net, self).__init__()

    def forward(self, q):
        return torch.sum(q, dim=-1, keepdim=True)  # (batch_size, max_episode_len, 1)
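`VDN_Net` simply sums the per-agent Q values along the last dimension. Because every term enters the sum with weight 1, each agent's own greedy action also maximises the joint value, which is the property that lets agents act in a decentralised way; QMIX generalises it with its monotonic mixing network. The snippet below checks this by brute force over all joint actions. The agent count, action count, and random Q table are hypothetical.

```python
import itertools

import numpy as np

rng = np.random.default_rng(1)
N_AGENTS, N_ACTIONS = 3, 4  # hypothetical sizes
q_i = rng.normal(size=(N_AGENTS, N_ACTIONS))  # hypothetical per-agent Q values for one state


def vdn_total(joint_action):
    """VDN mixing: Q_tot is the plain sum of each agent's Q for its chosen action."""
    return sum(q_i[i, a] for i, a in enumerate(joint_action))


# Decentralised execution: every agent takes the argmax of its own Q values
greedy = tuple(int(a) for a in q_i.argmax(axis=1))

# Centralised brute force over all N_ACTIONS ** N_AGENTS joint actions
best_value = max(vdn_total(ja) for ja in itertools.product(range(N_ACTIONS), repeat=N_AGENTS))

print(greedy, vdn_total(greedy) == best_value)
```

The brute-force search grows as N_ACTIONS ** N_AGENTS, while the per-agent argmax grows only linearly with the number of agents, which is why value-factorisation methods are attractive for many agents.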

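     If anything in this article is inappropriate or incorrect, please point it out. Some text and images come from the internet and their original sources could not be verified; if this raises any dispute, please contact the author to have them removed. For errors, questions, or infringement, leave a comment, or contact the author through the WeChat official account Rain21321.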

