深度学习 - 45.MMOE Gate 简单实现 By Keras

news2024/12/23 0:25:46

目录

一.引言

二.MMoE 模型分析

三.MMoE 逻辑实现

• Input

• Expert Output

• Gate Output

• Weighted Sum

• Sigmoid Output

• 完整代码

四.总结


一.引言

上一篇文章介绍了 MMoE 借鉴 MoE 的思路,为每一类输出构建一个 Gate 并最终加权多个 Expert 的输出,提高了相关性不高的多任务问题,下面根据思路简易实现下 MMoE 逻辑。

 

二.MMoE 模型分析

上图为 BaseLine、MoE 与 MMoE,下面我们专门了解 MMoE 的实现细节,其中 bs 为 BatchSize,Hidden_units 代表隐层输出维度,E0、E1、E2 代表多个 Expert,Tower A、Tower B 代表 N 个输出。

这里 Gate 与 Expert 都可以理解为浅层模型。 下面针对 bs 里的一个样本,分析下 MMoE 执行过程。

• Input 

输入一个 embedding_size 长度的向量,这里也可以是 Filed 个变量,进入对应 Embedding 层 Lookup 再做 pooling,为了简单,这里直接取 1 x F,F 为 Field Size。

• Expert Output

输入向量分别经过 E1、E2、... 的 K 个 Expert 计算,得到 Expert 个输出,每个输出维度为 1 x Hidden。

• Gate output

输入向量分别经过 G1、G2、... 的 K 个 Gate 计算,得到 Expert 个输出,维度为 1 x expert,每一维代表对应 Expert Output 的权重,注意这里 Gate 最终输出需经过一层 softmax。

• Weighted Sum

将当前样本输出的 expert 分别与对应 Expert Output 加权求和作为每个任务的 Tower 的输入向量,合并后维度为 1 x hidden_units。

• Sigmoid Output

每个 Tower 最终为 sigmoid 输出的二分类深度模型,输入为 Expert x Hidden 的加权求和,维度仍然为 1 x Hidden。

Tips:

实际 BatchSize 执行时,将上述 1 x ... 换为 bs x ... 即对应 Batch 的逻辑。

三.MMoE 逻辑实现

参数设置为:

    num_field = 4
    hidden_units = 8
    num_expert = 3
    num_output = 2

4 个 Field 域、8 维输出、3 个 Expert、2 个任务。

• Input

这里 N 就是上面的 bs 即 Batch Size,F 为 Field Size,这里实现逻辑比较简单,实践场景下可以先将 Field Lookup 得到 Embedding 再 pooling 输入到后面的 Expert 和 Gate:

    # 1.构造 Input => N x F
    num_samples = 10
    inputs = np.array([np.ones(shape=num_field) for i in range(num_samples)])
    print("Input Shape:", inputs.shape)
Input Shape: (10, 4)

• Expert Output

    # 2.构建专家 kernel [(filed * hidden) x expert]
    expert_kernels = np.random.random(size=(num_field, hidden_units, num_expert))
    print("Expert Shape:", expert_kernels.shape)

    # 3.获取 Expert 输出 => [N x F] * [F x Hidden x Expert] = N x Hidden x Expert
    outputs_by_expert = tf.tensordot(inputs, expert_kernels, axes=1)
    print("Output By Expert:", outputs_by_expert.shape)
Expert Shape: (4, 8, 3)
Output By Expert: (10, 8, 3)

• Gate Output

Gate:在 Dense 输出基础上,增加 softmax 逻辑。

class Gate(Layer):

    def __init__(self, expert_num, **kwargs):
        self.expert_num = expert_num
        self.gate = None

        super(Gate, self).__init__(**kwargs)

    def build(self, input_shape):
        self.gate = Dense(self.expert_num, activation='relu', kernel_initializer=glorot_normal_initializer)

        super(Gate, self).build(input_shape)

    def call(self, _inputs, **kwargs):
        weight = self.gate(_inputs)
        _output = tf.nn.softmax(weight)
        return _output

 根据 num_output 任务输出数,决定构造 Gate 的数量,这是 MMoE 与 MoE 的区别之一。

    # 4.构建 Gate
    gate = [Gate(num_expert) for i in range(num_output)]

    # 5.获取每个任务的 Gate 输出权重 output x N x expert
    outputs_by_gate = np.array([gate[i](inputs) for i in range(num_output)])
    print("Output By Gate:", outputs_by_gate.shape)
Output By Gate: (2, 10, 3)

 

• Weighted Sum

    # 6.获取最终输出 N x output
    part_input = []
    for output_by_gate in outputs_by_gate:
        # N x Expert => N x 1 x Expert
        expand_output_by_gate = tf.expand_dims(output_by_gate, axis=1)

        # N x 1 x Expert => N x Hidden x Expert
        repeat_gate_weight = K.repeat_elements(expand_output_by_gate, hidden_units, axis=1)

        # N x Hidden x Expert
        weighted_expert_output = tf.cast(outputs_by_expert, dtype='float32') * repeat_gate_weight

        # N x Hidden
        weighted_expert_sum = tf.reduce_sum(weighted_expert_output, axis=2)
        part_input.append(weighted_expert_sum)

    print("Part Input:", np.array(part_input).shape)

原始输入样本个数 N=10,输出 hidden_size=8,任务有2个,所以每个任务获得 BS x hidden_size 即 10 x 8 的 batch 样本。 

Part Input: (2, 10, 8)

 

• Sigmoid Output

这里两个任务对应两个 Tower,之前介绍了多输出模型:TF x Keras 之多输出模型

任务架构基于 Shared-bottom Multi-task Model,实现了同时预测年龄、收入、性别的多分类问题,有兴趣的同学可以把 Shared-bottom 的架构切换为多个 Expert 再加入 Gate 即可实现基础的 MMoE,这里就不再展开了。

 

• 完整代码

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import *
from tensorflow.keras.layers import Layer
from tensorflow.python.ops.init_ops import glorot_normal_initializer
from tensorflow.keras import backend as K


class Gate(Layer):

    def __init__(self, expert_num, **kwargs):
        self.expert_num = expert_num
        self.gate = None

        super(Gate, self).__init__(**kwargs)

    def build(self, input_shape):
        self.gate = Dense(self.expert_num, activation='relu', kernel_initializer=glorot_normal_initializer)

        super(Gate, self).build(input_shape)

    def call(self, _inputs, **kwargs):
        weight = self.gate(_inputs)
        _output = tf.nn.softmax(weight)
        return _output


def MMOE(num_field, hidden_units, num_expert, num_output):
    # 1.构造 Input => N x F
    num_samples = 10
    inputs = np.array([np.ones(shape=num_field) for i in range(num_samples)])
    print("Input Shape:", inputs.shape)

    # 2.构建专家 kernel [(filed * hidden) x expert]
    expert_kernels = np.random.random(size=(num_field, hidden_units, num_expert))
    print("Expert Shape:", expert_kernels.shape)

    # 3.获取 Expert 输出 => [N x F] * [F x Hidden x Expert] = N x Hidden x Expert
    outputs_by_expert = tf.tensordot(inputs, expert_kernels, axes=1)
    print("Output By Expert:", outputs_by_expert.shape)

    # 4.构建 Gate
    gate = [Gate(num_expert) for i in range(num_output)]

    # 5.获取每个任务的 Gate 输出权重 output x N x expert
    outputs_by_gate = np.array([gate[i](inputs) for i in range(num_output)])
    print("Output By Gate:", outputs_by_gate.shape)

    # 6.获取最终输出 N x output
    part_input = []
    for output_by_gate in outputs_by_gate:
        # N x Expert => N x 1 x Expert
        expand_output_by_gate = tf.expand_dims(output_by_gate, axis=1)

        # N x 1 x Expert => N x Hidden x Expert
        repeat_gate_weight = K.repeat_elements(expand_output_by_gate, hidden_units, axis=1)

        # N x Hidden x Expert
        weighted_expert_output = tf.cast(outputs_by_expert, dtype='float32') * repeat_gate_weight

        # N x Hidden
        weighted_expert_sum = tf.reduce_sum(weighted_expert_output, axis=2)
        part_input.append(weighted_expert_sum)

    print("Part Input:", np.array(part_input).shape)


if __name__ == '__main__':
    num_field, hidden_units, num_expert, num_output = 4, 8, 3, 2
    MMOE(num_field, hidden_units, num_expert, num_output)

四.总结

MMoE 几个 Expert 最终输出维度相同,Gate 输出维度与 Expert 数量相同,通过分析 Gate 的输出概率可以看出不同 Expert 对不同 Output 的测出,也可以控制 loss_weights 显式的指定某个 Output 占据主导地位。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/467089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

http---HTTP缓存

目录 1、缓存介绍 2、http缓存 3、强缓存 4、协商缓存 1、缓存介绍 缓存:存储将被用的数据,让数据访问更快。 缓存相关术语 命中:在缓存中找到了请求的数据不命中/穿透:缓存中没有需要的数据命中率:命中次数/总…

Yarn(Yet Another Reource Negotiator)另一个资源协调者

官网引用 总结性 产生的需求 YARN工作逻辑 通用的资源管理系统,为上一层应用提供统一的资源管理和调度。解决集群资源利用率,数据共享,资源管理统一问题,yarn取代Job Tracker角色 组件说明 Client 向RM提交任务,终…

1、软件测试概述

1、软件测试概述 一、软件生命周期二、软件开发模型1、瀑布模型2、增量模型3、原型模型4、敏捷开发 三、软件质量1、软件质量概念2、影响软件质量的因素 一、软件生命周期 软件生命周期分为多个阶段,每个阶段有明确的任务,通常,可将软件生命…

ARM寄存器组织

ARM有37个32位长的寄存器: 1个用做PC(Program Counter); 1个用做CPSR(Current Program Status Register); 5个用做SPSR(Saved Program Status Registers); 30个通用寄存器。 AR…

Unity之OpenXR+XR Interaction Toolkit实现 射线和物体交互事件回调

前言 前面我们介绍了如何抓取物体,今天我们来说一下如何和3D的物体进行交互,得到接触的事件回调。 交互的两种方式: 1.直接抓取或者射线抓取物体,得到接触回调 2.射线或者手部触摸物体后,得到接触回调 准备工作 有了…

Android 10.0 设置默认launcher后安装另外launcher后默认Launcher失效的功能修复

1.前言 在10.0的系统rom定制化开发中,在系统中有多个launcher的时候,会在开机进入launcher的时候弹窗launcher列表,让用户选择进入哪个launcher,这样显得特别的不方便 所以产品开发中,要求用RoleManager的相关api来设置默认Launcher,但是在设置完默认Launcher以后,在安…

嵌入式软考备考_3 嵌入式操作系统概述

嵌入式操作系统概述 工作在嵌入式环境中的操作系统 Embedded Operating System。 嵌入式和一般操作系统区别: 非通用操作系统,用于完成特定功能;性能实时性能源可靠性要求高;占用资源少;可剪裁,可配置。…

渗透测试 | Web信息收集

0x00 免责声明 本文仅限于学习讨论与技术知识的分享,不得违反当地国家的法律法规。对于传播、利用文章中提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,本文作者不为此承担任何责任,一旦造成后果请自行承担…

《程序员面试金典(第6版)》面试题 16.04. 井字游戏(棋盘类问题,C++)

题目描述 设计一个算法,判断玩家是否赢了井字游戏。输入是一个 N x N 的数组棋盘,由字符" ",“X"和"O"组成,其中字符” "代表一个空位。 以下是井字游戏的规则: 玩家轮流将字符放入空位…

专门为麻醉科和手术室开发的:手术麻醉系统源码,系统稳定,功能完整,支持二次开发

手术麻醉系统源码:C# .net 桌面软件 C/S版 系统极其稳定,扩展性强,已在多家医院运营。 文末获取联系 手术麻醉信息管理系统是专门为麻醉科和手术室开发的围手术期临床信息管理系统,具备以下功能: 1.手术程管理系统整合了手术室、…

人工智能实践: 基于T-S 模型的模糊推理

模糊推理是一种基于行为的仿生推理方法, 主要用来解决带有模糊现象的复杂推理问题。由于模糊现象的普遍存在, 模糊推理系统被广泛的应用。模糊推理系统主要由模糊化、模糊规则库、模糊推理方法以及去模糊化组成, 其基本流程如图1所示。

C++(继承下)

目录: 1.继承与有元 2.继承与静态成员 3.单继承、多继承 4.如何定义一个不能被继承的类?? 5.分享有意思的一道题 6.菱形继承及菱形虚拟继承 --------------------------------------------------------------------------------------------…

【c语言】全局变量 | 局部变量的生命周期与作用域

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 给大家跳段街舞感谢支持&#xff01;ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ…

亚马逊云科技赋能客户,为海思科打造安全高效的营销业务中台系统

羽翼渐丰&#xff0c;翱翔云端 携手亚马逊云科技&#xff0c;打造互联网级企业解决方案 秉承“为客户创造价值”的理念&#xff0c;在公司发展过程中&#xff0c;博智信息先后服务了众多知名企业&#xff0c;客户行业覆盖制造、零售、餐饮、科技、电子等。经过近20年的发展&a…

AutoGPT 是 prompt 工程的下一个前沿

前言 最近了解到Auto GPT的上线&#xff0c;下面我来整理一下整个体验过程&#xff0c;希望对大家有所帮助和启发。 首先Auto GPT是 OpenAI 的 Andrej Karpathy 都在大力宣传的一个开源项目&#xff0c;他认为 AutoGPT 是 prompt 工程的下一个前沿。 近日&#xff0c;AI 界貌…

自助式数据分析平台:jvs数据智仓-统计报表的使用条件及界面介绍

统计报表界面介绍 统计报表是指利用表格和报表等形式&#xff0c;将数据以清晰的结构和布局的方式呈现出来&#xff0c;以便用户进行数据分析和决策制定的一种BI统计方法。表格式的BI统计通常采用交叉表格、分组表、报表等形式&#xff0c;对数据进行整合、分析和展示&#xff…

【数据库数据恢复】ndf文件损坏的SQL SERVER数据库数据恢复案例

数据库数据恢复环境&#xff1a; 某公司存储上部署SQL SERVER数据库&#xff0c;数据库中有1000多个文件&#xff0c;该SQL SERVER数据库每10天生成一个NDF文件&#xff0c;数据库包含两个LDF文件。 数据库故障&分析&#xff1a; 存储设备出现故障导致SQL SERVER数据库异常…

IDEA实用设置

1、设置全局编码统一为UTF-8 file>setting中搜索框输入file encoding修改格式为UTF-8 2、设置文字大小 file>setting中搜索框输入font修改字体大小 3、配置maven file>setting中搜索框输入maven修改maven的路径、conf文件、文件仓库 4、idea中实现Serializable提示…

RabbitMQ通讯方式

RabbitMQ通讯方式 RabbitMQ提供了很多中通讯方式&#xff0c;依然可以去官方查看&#xff1a;https://rabbitmq.com/getstarted.html 七种通讯方式 1 RabbitMQ提供的通讯方式 Hello World!&#xff1a;为了入门操作&#xff01;Work queues&#xff1a;一个队列被多个消费者…

三十五、垃圾回收器

一、GC分类于性能指标 垃圾回收器的分类 1.串行回收指的是在同一时间段内只允许有一个CPU用于执行垃圾回收操作&#xff0c;此时工作线程被暂停&#xff0c;直至垃圾收集工作结束。 1)在诸如单CPU处理器或者较小的应用内存等硬件平台不是特别优越的场合&#xff0c;串行回收器…