Mixture-of-Experts (MoE) Models

news2024/10/6 22:20:31


  • Mixture-of-Experts (MoE)
  • Mixture of Sequential Experts (MoSE)
  • Multi-gate Mixture-of-Experts (MMoE)


1. MoE Architecture

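An MoE (Mixture of Experts) layer consists of one gating network and n expert networks. For each input, the gating network dynamically selects k of the expert networks to activate. In practice, the number k of experts activated per input x is usually very small. For example, some experiments in the MoE paper use n=512 and k=2, meaning that only two of the 512 expert networks are activated for each input. This lets the parameter count of the model grow substantially while the computation (FLOPs) per input stays roughly constant.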
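
To make the "more parameters, roughly the same FLOPs" point concrete, here is a minimal counting sketch. It assumes each expert is a two-layer feed-forward network with biases and that the gate is a single `d_model x n_experts` matrix; the function name `moe_layer_params` and the sizes `d_model=1024`, `d_hidden=4096` are illustrative assumptions, not values taken from the paper.

```python
def moe_layer_params(n_experts, k, d_model, d_hidden):
    """Count total vs. per-input active parameters of a hypothetical MoE layer.

    Assumption: every expert is Linear(d_model, d_hidden) -> Linear(d_hidden, d_model)
    with biases, and the gate is one d_model x n_experts weight matrix.
    """
    expert = d_model * d_hidden + d_hidden + d_hidden * d_model + d_model
    gate = d_model * n_experts
    total = n_experts * expert + gate   # parameters that must be stored
    active = k * expert + gate          # parameters actually used for one input
    return total, active


# n=512 and k=2 follow the setting quoted above; the layer sizes are made up.
total, active = moe_layer_params(n_experts=512, k=2, d_model=1024, d_hidden=4096)
print(f"total parameters : {total:,}")
print(f"active per input : {active:,}")
print(f"active fraction  : {active / total:.4%}")
```

Under these assumptions well under 1% of the layer's parameters take part in any single forward pass, even though all of them still have to be stored.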

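The defining feature of the MoE architecture is a layer of expert networks inside the model. A routing function decides which experts to activate, so that different experts process the input independently, and their outputs are combined by a weighted sum to produce the final prediction.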
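
The routing and weighted combination just described can be sketched in plain Python. This is a toy illustration, not the implementation from section 4: the names `softmax` and `moe_forward`, the three scaling "experts" and the gate weights are all hypothetical, and renormalizing the kept weights so they sum to 1 is an assumption (it is equivalent to applying the softmax to the top-k logits only).

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def moe_forward(x, gate_weights, experts, k):
    """Route x to its top-k experts and return (weighted output, chosen expert ids)."""
    # One gate logit per expert: dot product of x with that expert's gate vector.
    logits = [sum(w_j * x_j for w_j, x_j in zip(w, x)) for w in gate_weights]
    probs = softmax(logits)
    # Indices of the k experts with the largest gate probability.
    top = sorted(range(len(experts)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)  # assumption: renormalize the kept weights
    output = None
    for i in top:
        weight = probs[i] / norm
        y = experts[i](x)  # only the selected experts are ever evaluated
        if output is None:
            output = [weight * v for v in y]
        else:
            output = [o + weight * v for o, v in zip(output, y)]
    return output, top


# Three toy experts that simply scale their input by different factors.
experts = [
    lambda x: [1.0 * v for v in x],
    lambda x: [10.0 * v for v in x],
    lambda x: [100.0 * v for v in x],
]
gate_weights = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
y, chosen = moe_forward([1.0, 2.0], gate_weights, experts, k=2)
print(chosen, y)
```

The third expert is never called for this input, which is where the compute saving comes from.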

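Sparse MoE is one way to scale up large language models. Take GLaM as an example: it has 1.2T parameters in total, but only 97B activated parameters per input, far fewer than GPT-3. In other words, it is a sparsely activated MoE. Like GPT-3, it is a decoder-only model, yet GLaM achieves better performance than GPT-3.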
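
As a quick check of the GLaM figures, the following lines compute the activated fraction. The GPT-3 size of 175B parameters is an assumption brought in from outside this article for comparison.

```python
total_params = 1.2e12   # GLaM total parameters, as quoted above
active_params = 97e9    # GLaM activated parameters per input, as quoted above
gpt3_params = 175e9     # assumption: GPT-3's dense parameter count, not stated in this article

active_fraction = active_params / total_params
relative_to_gpt3 = active_params / gpt3_params
print(f"GLaM activates {active_fraction:.1%} of its parameters per input")
print(f"which is {relative_to_gpt3:.1%} of GPT-3's parameter count")
```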


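  1. Larger parameter count: because only a subset of the parameters is activated for each input, the remaining parameters can stay "switched off", which reduces the computation (and the activation memory) needed per input even as the total parameter count grows.
  2. Higher efficiency: only the relevant expert modules are activated.
  3. Combined strengths: combining the strengths of different experts may yield better results.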

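Mixture-of-experts systems come in two architectures: competitive MoE and cooperative MoE. In competitive MoE, local regions of the data are forced into discrete partitions of the data space, whereas cooperative MoE imposes no such constraint.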

2. GateNet: Deciding Which Expert Handles Each Input Sample

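GateNet can be viewed as a dispatcher: based on the features of an input sample, it dynamically decides which expert should process it. This can be implemented as a softmax classifier in which each output neuron corresponds to one expert model, so the GateNet output gives the weight of each expert. Designing GateNet involves two key points: extracting features from the input sample and choosing the assignment strategy. For feature extraction, a convolutional neural network (CNN) or a Transformer is commonly used to obtain a feature representation of the input. For the assignment strategy, different attention mechanisms can be used, or prior knowledge can be introduced to guide the assignment.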

Traditionally, this kind of model is trained with Expectation Maximization (EM). The gating network may have a softmax output that gives each expert a probability-like confidence score.
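
To illustrate how the gate's probability-like scores enter EM, here is a minimal sketch of the E-step for a classical (dense) mixture of experts. It assumes each expert predicts a scalar mean with a fixed-variance Gaussian likelihood; the helper names `gaussian_pdf` and `responsibilities` and the numbers are hypothetical.

```python
import math


def gaussian_pdf(y, mean, std):
    """Density of a normal distribution N(mean, std^2) evaluated at y."""
    z = (y - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def responsibilities(gate_probs, expert_means, y, std=1.0):
    """E-step: posterior probability that each expert generated the target y.

    h_i = g_i * p(y | expert i) / sum_j g_j * p(y | expert j)
    """
    joint = [g * gaussian_pdf(y, m, std) for g, m in zip(gate_probs, expert_means)]
    total = sum(joint)
    return [j / total for j in joint]


# The gate is undecided (0.5 / 0.5), but expert 1 predicts 4.0, close to the target 3.5.
h = responsibilities(gate_probs=[0.5, 0.5], expert_means=[0.0, 4.0], y=3.5)
print(h)
```

In the M-step these responsibilities would serve as soft targets for the gate and as per-sample weights for each expert's loss. The sparsely gated layer in section 4 is trained end to end by backpropagation instead.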


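An MoE layer with n experts is computed as follows. First, the gate is computed for the sample x, where W denotes the gating weight matrix. A softmax then yields the weight with which x is assigned to each expert. Only the k largest weights are kept (k is usually 1 or 2), and the output of the whole MoE layer is the weighted sum of the outputs of the k selected expert networks.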
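
The computation described above can be written out as follows. This is a reconstruction from the prose, not the article's original formula: the symbols $g$, $G$ and $E_i$ are assumed names for the softmax gate output, the sparsified gate and the $i$-th expert.

```latex
\begin{aligned}
g(x)   &= \mathrm{Softmax}(x \cdot W), \\
G(x)_i &= \begin{cases}
            g(x)_i, & \text{if } g(x)_i \text{ is among the } k \text{ largest entries of } g(x), \\
            0,      & \text{otherwise},
          \end{cases} \\
y      &= \sum_{i=1}^{n} G(x)_i \, E_i(x).
\end{aligned}
```

The code in section 4 uses a closely related variant: it selects the top-k logits first and applies the softmax only to those, so the k kept weights sum to 1.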

3. Experts: Building and Training the Expert Models

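Expert models are the core components of the MoE architecture; they carry out the actual task on the input samples. Each expert is relatively independent, and its architecture can be chosen to suit the task. During training, the experts can be trained with conventional supervised learning. To improve results further, a master-slave style training strategy can also be used, that is, GateNet and the Experts are trained jointly so that the whole MoE architecture is optimized together.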


4. Code Implementation

# Sparsely-Gated Mixture-of-Experts Layers.
# See "Outrageously Large Neural Networks"
# https://arxiv.org/abs/1701.06538
# Author: David Rau
# The code is based on the TensorFlow implementation:
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py

import torch
import torch.nn as nn
from torch.distributions.normal import Normal
import numpy as np

class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.
    The purpose of this class is to create input minibatches for the
    experts and to combine the results of the experts to form a unified
    output tensor.
    There are two functions:
    dispatch - take an input Tensor and create input Tensors for each expert.
    combine - take output Tensors from each expert and form a combined output
      Tensor.  Outputs from different experts for the same batch element are
      summed together, weighted by the provided "gates".
    The class is initialized with a "gates" Tensor, which specifies which
    batch elements go to which experts, and the weights to use when combining
    the outputs.  Batch element b is sent to expert e iff gates[b, e] != 0.
    The inputs and outputs are all two-dimensional [batch, depth].
    Caller is responsible for collapsing additional dimensions prior to
    calling this class and reshaping the output to the original shape.
    See common_layers.reshape_like().
    Example use:
    gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
    inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
    experts: a list of length `num_experts` containing sub-networks.
    dispatcher = SparseDispatcher(num_experts, gates)
    expert_inputs = dispatcher.dispatch(inputs)
    expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
    outputs = dispatcher.combine(expert_outputs)
    The preceding code sets the output for a particular example b to:
    output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
    This class takes advantage of sparsity in the gate matrix by including in the
    `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher."""

        self._gates = gates
        self._num_experts = num_experts
        # sort experts
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # drop indices
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # get according batch index for each expert
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        # calculate num samples that each expert gets
        self._part_sizes = (gates > 0).sum(0).tolist()
        # expand gates to match with self._batch_index
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input Tensor for each expert.
        The `Tensor` for an expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.
        Args:
          inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
        Returns:
          a list of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """

        # assigns samples to experts whose gate is nonzero

        # expand according to batch index so we can just split by _part_sizes
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert output, weighted by the gates.
        The slice corresponding to a particular batch element `b` is computed
        as the sum over all experts `i` of the expert output, weighted by the
        corresponding gate values.  If `multiply_by_gates` is set to False, the
        gate values are ignored.
        Args:
          expert_out: a list of `num_experts` `Tensor`s, each with shape
            `[expert_batch_size_i, <extra_output_dims>]`.
          multiply_by_gates: a boolean
        Returns:
          a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
        """
        # apply exp to expert outputs, so we are no longer in log space
        stitched = torch.cat(expert_out, 0).exp()

        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device)
        # combine samples that have been processed by the same k experts
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nans when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        # back to log space
        return combined.log()

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.
        Returns:
          a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
              and shapes `[expert_batch_size_i]`
        """
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)

class MLP(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.soft = nn.Softmax(1)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.soft(out)
        return out

class MoE(nn.Module):

    """Call a Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
    Args:
    input_size: integer - size of the input
    output_size: integer - size of the output
    num_experts: an integer - number of experts
    hidden_size: an integer - hidden size of the experts
    noisy_gating: a boolean
    k: an integer - how many experts to use for each batch element
    """

    def __init__(self, input_size, output_size, num_experts, hidden_size, noisy_gating=True, k=4):
        super(MoE, self).__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.k = k
        # instantiate experts
        self.experts = nn.ModuleList([MLP(self.input_size, self.output_size, self.hidden_size) for i in range(self.num_experts)])
        self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
        self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        assert(self.k <= self.num_experts)

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 for an empty Tensor.
        Args:
        x: a `Tensor`.
        Returns:
        a `Scalar`.
        """
        eps = 1e-10
        # if only num_experts = 1

        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean()**2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is >0.
        Args:
        gates: a `Tensor` of shape [batch_size, n]
        Returns:
        a float32 `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        In the case of no noise, pass in None for noise_stddev, and the result will
        not be differentiable.
        Args:
        clean_values: a `Tensor` of shape [batch, n].
        noisy_values: a `Tensor` of shape [batch, n].  Equal to clean values plus
          normally distributed noise with standard deviation noise_stddev.
        noise_stddev: a `Tensor` of shape [batch, n], or None
        noisy_top_values: a `Tensor` of shape [batch, m].
           "values" Output of tf.top_k(noisy_top_values, m).  m >= k+1
        Returns:
        a `Tensor` of shape [batch, n].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
        # is each value currently in the top k.
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating.
          See paper: https://arxiv.org/abs/1701.06538.
          Args:
            x: input Tensor with shape [batch_size, input_size]
            train: a boolean - we only add noise at training time.
            noise_epsilon: a float
          Returns:
            gates: a Tensor with shape [batch_size, num_experts]
            load: a Tensor with shape [num_experts]
        """
        clean_logits = x @ self.w_gate
        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits

        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.noisy_gating and self.k < self.num_experts and train:
            load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, x, loss_coef=1e-2):
        """Args:
        x: tensor shape [batch_size, input_size]
        loss_coef: a scalar - multiplier on load-balancing losses

        Returns:
        y: a tensor with shape [batch_size, output_size].
        extra_training_loss: a scalar.  This should be added into the overall
        training loss of the model.  The backpropagation of this loss
        encourages all experts to be approximately equally used across a batch.
        """
        gates, load = self.noisy_top_k_gating(x, self.training)
        # calculate importance loss
        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)
        loss *= loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        gates = dispatcher.expert_to_gates()
        expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
        y = dispatcher.combine(expert_outputs)
        return y, loss
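
The auxiliary loss returned by `forward` is built from `cv_squared`, the squared coefficient of variation of the per-expert importance and load. The following plain-Python sketch mirrors that formula without PyTorch so the effect is easy to see. It uses the sample variance (which is what `Tensor.var()` computes by default), and the load vectors are made-up numbers.

```python
from statistics import mean, variance


def cv_squared(values, eps=1e-10):
    """Squared coefficient of variation: sample variance / (mean^2 + eps).

    Returns 0.0 for fewer than two values, mirroring the single-expert
    special case in the listing above.
    """
    if len(values) < 2:
        return 0.0
    return variance(values) / (mean(values) ** 2 + eps)


balanced = [4.0, 4.0, 4.0, 4.0]     # every expert receives the same number of samples
collapsed = [13.0, 1.0, 1.0, 1.0]   # one expert receives almost everything

print(cv_squared(balanced))
print(cv_squared(collapsed))
```

A perfectly balanced load gives a loss of zero, while a collapsed routing gives a large value, so minimizing this term pushes the gate to spread samples across experts.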




GitHub - XueFuzhao/awesome-mixture-of-experts: A collection of AWESOME things about mixture-of-experts




