深度学习 - 44.MMOE 与 Gate 之多目标学习

news2024/11/15 15:35:20

目录

一.引言

二.摘要 Abstract

三.介绍 Introduction

四.相关工作 RELATED WORK

1.DNN 中的多任务学习

2.SubNet 集成与 Expert 混合

3.多任务学习应用

五.建模方法 MODELING APPROACHES

1.Shared-bottom Multi-task Model

2.Mixture-of-Experts

3.Multi-gate Mixture-of-Experts

六.数据实验 REAL DATA EXPERIMENTS

1.基线模型 Baseline

2.参数调优 Hyper-Parameter

3.人口收入数据 Census-income Data

4.大规模内容推荐 Large-scale Content Recommendation

七.总结 CONCLUSION


 

一.引言

MMOE 是 Google 在 2018 年提出的一篇基于多任务学习的论文,全名为:Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts,其介绍了通过引入 Gate 实现不相关任务的多任务学习问题,下面对论文做简要回顾。

 

二.摘要 Abstract

基于神经网络的多任务学习已经成功地应用于许多现实世界的大规模应用,例如如推荐系统中我们可以建立深度模型,同时学习用户对 Item 的点击率、点赞率、收藏率等等。

dcf9473118b0497b9bfa03f9e53dc37d.png

现实场景下,多任务模型的预测质量往往对任务之间的关系很敏感。论文提出一种新的多任务学习方法:Multi-gate Mixture-of-Experts (MMoE) 模型,明确地学习从数据中建模任务关系。我们将 Mixture-of-Experts (MoE) 混合专家结构用于多任务学习,通过在所有任务中共享专家子模型,同时还拥有一个经过训练的 Gate 门控网络来优化每个 MoE 的输出占比。

 

三.介绍 Introduction

• 多任务学习问题

许多大型推荐系统已经使用 DNN模型进行多任务学习,多任务学习模型可以通过使用正则化和迁移学习来提高对所有任务的模型预测。然而许多基于 DNN 的多任务学习模型对数据分布差异和任务间关系等因素都很敏感。且任务差异带来的内在冲突实际上会损害至少部分任务的预测,特别是在所有任务之间广泛共享模型参数时。

 

• 原有解决方式

最早的研究通过假设每个任务的特定数据生成过程来研究多任务学习中的任务差异,根据假设测量任务差异,然后根据任务差异的大小提出建议。最近的一些研究提出了新的建模技术来处理多任务学习中的任务差异,而不依赖于明确的任务差异测量。然而,这些技术通常涉及到为每个任务添加更多的模型参数,以适应任务差异,由此新增的额外参数量级很大,额外的计算成本在生产环境中很难接受。

 

• 多门-混合专家模型

Multi-gate Mixture-of-Experts (MMoE) 结构的多任务学习方法,受到 Mixture-of-Experts (MoE)模型和最近的MoE层的启发。MMoE显式地对任务关系建模,并学习特定于任务的功能,以利用共享表示。它允许自动分配参数以捕获共享任务信息或特定于任务的信息,从而避免了为每个任务添加许多新参数的需要。

 

四.相关工作 RELATED WORK

1.DNN 中的多任务学习

多任务模型可以学习不同任务之间的共性和差异,这样做可以提高每个任务的效率和模型质量。

- 共享底层模型结构

模型具有共享底层模型结构,底层隐藏层在任务之间共享。这种结构极大地降低了过度调试的风险,但由于所有任务共享参数可能由于任务间的差异对优化条件造成影响。

 

- 合成数据生成

为了了解任务相关性对模型质量的影响,已有研究采用合成数据生成的方法,对不同类型的任务相关性进行操作,以评估多任务模型的有效性。

 

- 增加对应任务参数

最近的一些方法在特定于任务的参数上添加不同类型的约束。例如在两组参数之间添加 L2 约束、为每个任务学习特定的隐藏层嵌入组合、使用张量分解模型为每个任务生成隐层参数等等。相比于共享参数,该模式下的不同任务拥有更多特定参数,可以获得更好的性能。然而,大量的任务特定参数需要更多的训练数据以及工程师对业务任务的深刻理解,对大规模推荐模型不太友好。

 

2.SubNet 集成与 Expert 混合

在 DNN 中,将混合专家模型转化为基本的构建模块 (MoE层),并将它们堆叠在DNN中已被证明能够提高模型性能。MoE 层根据该层在训练时间和服务时间的输入选择 Expert 即 SubNet。通过引入门控网络的稀疏性,该模型不仅建模能力更强,而且降低了计算成本。通过使用 SubNet (专家)集成来实现迁移学习,同时节省计算量。

 

3.多任务学习应用

由于分布式机器学习系统的发展,许多大规模的现实应用已经采用了基于 DNN 的多任务学习算法,并观察到质量的显著提高。在多语言机器翻译任务中,由于模型参数共享,训练数据有限的翻译任务可以通过与训练数据量大的任务联合学习来改进。在构建推荐系统时,多任务学习被发现有助于提供上下文感知的推荐。与这些先前的工作类似,我们在现实世界的大规模推荐系统上评估了MMoE,该模型方法确实是可伸缩的,并且与其他最先进的建模方法相比具有良好的性能。

 

五.建模方法 MODELING APPROACHES

1fe3771f089749abae83a737f9ecdb15.png

1.Shared-bottom Multi-task Model

如图 a 所示,该模型架构在许多多任务学习中广泛采用,论文将模型视为多任务建模中具有代表性的基线方法。给定 K 个任务,该模型由一个共享底部网络 (表示为函数 f) 和 K 个 Tower 塔网络组成,其中 K = 1,2,... 模型共享底层网络,塔式网络建立在共享底层的输出上,然后每个任务 Tower 的淡出输出与 Output K 遵循对应的任务。对于任务 K,模型可表示为:

b7e758b1c36c4243ba17e57bf3dcf8ca.png

其中共享体现在公用一个 f(x),多任务体现在多个 gif.latex?h%5EK 上,以开头的点击率、点赞率的多任务为例:

600eb12931124bcd92964e83f018a7c6.png

 

2.Mixture-of-Experts

混合专家模型可以采用如下公式表示:

0f2f931eee334661950276eb1926b91f.png

其中 g(x)i 表示 g(x) 输出的第 i 个 logit,表示对应专家 fi 的概率。这里 f 是 n 个专家网络,可以理解为集成学习的多个基学习器,g 表示集合所有专家结果的门控网络。更具体地说,门控网络 g 根据输入产生了专家的概率分布,而最终输出是所有专家输出的加权和。MoE Layer 具有与 MoE 模型相同的结构,但接受前一层的输出作为输入和输出到后续层。然后以端到端的方式训练整个模型。

 

3.Multi-gate Mixture-of-Experts

多门专家混合 (MMoE) 模型,其关键思想史将共享底层网络替换为 MoE 层,同时为每个任务添加一个单独的门控网络 Gate K 用于捕捉不同任务时不同专家的贡献度:

1875e0c8bbba49b6940dfe292cb6b2e7.png

图 c 显示了 MMoE 模型结构,实现由具有 ReLU 激活的相同多层感知器组成。门控网络是输入的简单线性变换,带有softmax层:

31fff55efa414024b2e2950089918b44.png

 其中 gif.latex?y_k 为最终输出,gif.latex?h%5Ek 为 K 个任务,gif.latex?f%5Ek%28x%29 为第 K 个任务的多个 Expert 的混合输出,每个 gif.latex?f_i%28x%29 对应一个 Expert,gif.latex?g%5Ek%28x%29_i 代表门控网络基于第 K 个任务生成的 Expert 专家概率分布。

 

Tips:

• 相比于 Shared-bottom Multi-task Model

与 Shared-bottom Model 相比,MMoE 的 N 个 Expert  网络也是 K 个任务共享的,而门控网络通常是轻量级的所以 MMoE 与一些多任务 Baseline 在计算量和参数量上并无太多差异,适用于工业场景。

 

• 相比于 Mixture-of-Experts

相比于 MoE K 个任务共享一个 Gate 门控网络外,MMoE 为每个任务准备一个单独的 Gate 门控网络,比较符合 Expert 对不同任务的权重存在差异的假设,通过不同的 Gate 可以学习到不同的 Expert 组合方式,从而捕捉到任务之间的相关性与差异。

 

六.数据实验 REAL DATA EXPERIMENTS

1.基线模型 Baseline

除了Shared-Bottom 多任务模型,我们还将我们的方法与几个最先进的多任务深度神经网络模型进行了比较,这些模型试图从数据中学习任务关系。

- L2-Constrained 正则

这种方法是为具有两个任务的跨语言问题而设计的。在这种方法中,用于不同任务的参数 θ 由 L2 约束软共享。设 yk 为任务 k 的真值标记,k∈1、2,则任务 k 的预测表示为:

380aca146bec42d9a457d01f9e9d8ffb.png

θ 为模型参数,最终目标函数为:

691a32beafa24345b579524f3c19ce20.png

其中 y1、y2 是任务 1 和任务 2 的基本真值标签,α 是超参数。该方法利用 α 的大小对任务关联度进行建模。

 

- Cross-Stitch "十字绣" 网络

这种方法通过引入 "十字绣" 单元来共享两个任务之间的知识。十字绣单元从任务 1 和任务 2 中获取分离的隐藏层 x1 和 x2 的输入,分别通过以下公式输出:

edd555ee56f441c4841bb1e5f2531953.png

其中,gif.latex?%5Calpha_%7Bj%2Ck%7D  j, k = 1、2 是一个可训练参数,表示任务 k 到任务 j 的交叉传递。任务 1 和任务 2 分别向上级发送任务 x1 和任务 x2。

 

- Tensor-Factorization 张量因子分解

该方法将多个任务的权重建模为张量,并利用张量分解方法实现任务间参数共享。这里实现了 Tucker 分解来学习多任务模型,例如,给定输入隐藏层大小为 m,输出隐藏层大小为 n,任务数为 k,则 m × n × k 张量的权重 W 由下式导出:

c8278723704f421f805b086bade19507.png

其中大小为 r1 × r2 × r3 的张量S,大小为 m × r1的矩阵 U1,大小为 n × r2 的 U2,大小为 k ×  r3 的 U3 是可训练参数。所有这些都是通过标准反向传播一起训练的。r1、r2 和 r3 是超参数。

 

2.参数调优 Hyper-Parameter

为了使比较公平,我们通过为每层隐藏单元的数量设置相同的上界 (2048) 来约束所有方法的最大模型大小。对于MMoE,它是 "专家数量" × "每个专家的 Hidden"。我们的方法和所有基线方法都是使用 TensorFlow 实现的。我们调整了所有方法的学习率和训练步骤数。我们还调优了一些特定于方法的超参数:

• MMOE: 专家数量,每个专家隐藏单位的数量。

• L2-Constrained: 隐层的大小。L2约束的权值 α。

• Cross-Stitch: 隐层的大小。L2约束的权值 α。

• Tensor-Factorization: r1, r2, r3 表示 Tuck 分解,隐藏层大小。

 

3.人口收入数据 Census-income Data

UCI 人口普查收入数据集提取自 1994 年人口普查数据库。它包含 299285 个美国成年人的人口统计信息实例,总共有40个特性。通过设置一些特征作为预测目标,我们从这个数据集中构建了两个多任务学习问题,并计算了 10,000 个随机样本上任务标签的 Pearson 相关性绝对值:

- Multi-Task-A

Output1:预测收入是否超过$50K;

Output2:预测此人的婚姻状况是否从未结过婚。

绝对 Pearson 相关性:0.1768。

- Multi-Task-B

Output1:预测学历是否为大学以上;

Output2:预测此人的婚姻状况是否从未结过婚。

绝对 Pearson 相关性:0.2373。

bf23bd2a1efd41ed8796ddf004963b19.png

a06159abc90e4c2686e0077bcf146457.png

考虑到任务相关性 (通过 Pearson 相关性粗略测量) 在两组中都不是很强,Shared-Bottom 模型几乎总是多任务模型中最差的 (除了张量分解)。L2-Constrained 和 Cross-Stitch 对每个任务都有单独的模型参数,并对如何学习这些参数添加了约束,因此比 Shared-Bottom 表现得更好。在第二组中,MMoE 在所有方面都优于其他多任务模型。

 

4.大规模内容推荐 Large-scale Content Recommendation

模型在 Google 的大型内容推荐系统上进行实验,其中推荐是由数十亿用户的数亿个独特项目生成的。具体来说,给定用户当前消费某种商品的行为,该推荐系统的目标是向用户显示下一步消费的相关商品列表。我们设置的深度排名模型是针对两种类型的排名目标进行优化的:

(1) 优化与用户粘性相关的目标,如点击率和用户粘性时间。

(2) 对满意度相关目标进行优化,如相似率等。

我们的训练数据包括数千亿的用户隐式反馈,比如点击和点赞。如果单独训练,每个任务的模型需要学习数十亿个参数。因此,与单独学习多个目标相比,Shared-Bottom 架构具有较小的模型大小的好处。事实上,这种 Shared-Bottom 模型已经在生产中使用了。

 

• 模型试验

aa4194b9d7984b5d860c034ba052aca2.png

表中显示了参与度子任务的 AUC 分数和 R-Squared 分数。@2M 代表训练 200万步,其中包含 100 亿个样本、BatchSize = 1024,@4M、@6M 同理。MMoE 在这两个指标上都优于其他模型。

 

• Gate 理解

为了更好地理解门是如何工作的,下图展示了每个任务的 softmax 门的分布。可以看到,MMoE 学习了这两个任务之间的差异,并自动平衡了共享和非共享参数。由于满意度子任务的标签比参与度子任务的标签更稀疏,因此满意度子任务的大门更关注于单个专家:

1287724cc4d74469be8276f3ca088f2d.png

 

七.总结 CONCLUSION

多门专家混合模型 Multi-gate MoE, MMoE 明确地从数据中建模学习任务关系,该方法可以更好地处理任务相关性较低的情况。且 MMoE 的训练更容易,效果也优于常见的 BaseLine 多任务学习模型。

• Expert 模型

常规情况下每一个 Expert 是一个小规模的全连接神经网络,不同 Expert 都有着不同的预测方向与相同的输出维度,当然放在广义的集成学习上看,Expert 可以是任何模型,只要输出维度相同即可。

• Gate Layer

Gate Layer 生成 Expert 概率分布并进行加权求和,这里引入了集成的思想,类似于多个 Expert 贡献力量。极端情况下可以修改为只激活一个 Expert 的意见,此时退化为常规模型。我们可以 Gate 输出分析不同 Expert 对不同任务的偏向。

• 任务相关性

文中多次提高多任务之间的相关性,对于相关性类似的任务,共享参数可以优化效率并且可复用性高,此时 MoE 与 MMoE 效果近似,而当任务相关性较低时,MMoE 的多 Gate 机制优与单 Gate 的 MoE,说明 Multi-Gate 的模式对于任务相关性不同造成的参数冲突有一定效果。

• Softmax 权重

Gate 门输出的权重是通过端到端训练得到的,其学习了数据中蕴含的任务逻辑,如果我们有先验信息认为某个 Expert 的意见比较可靠,则可以手动修正 Softmax 处得到的多个 Expert 的权重。

 

论文参考:https://dl.acm.org/doi/10.1145/3219819.3220007

多任务学习讲解:多目标模型讲解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/464371.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

美团B端“加速度”

配图来自Canva可画 一提起本地生活服务,相信绝大多数人并不会感到陌生,人们经常使用的餐饮外卖,便是本地生活服务的重要组成部分之一。而在消费者线上消费习惯逐渐养成、本地生活服务需求日渐增长等多方因素的共同影响下,本地生活…

BUUCTF pwn——picoctf_2018_rop chain

checksec && 运行 ida main函数调用vuln函数 名为vuln的函数存在溢出 名为flag的函数,通过校验可直接getflag 具体校验过程看图,只有win1和win2均为真,并且a1的值等于0xDEADBAAD才能getflag 变量win2的真假性通过win_function2函…

三个练手的软件测试实战项目(附全套视频跟源码)偷偷卷死他们

项目一:12306抢票项目 项目测试目的 学会Selenium定位web元素的方法 熟练浏览器调试工具使用 项目主体步骤 1) 人工走一遍流程,对自动化的流程心中有数 2) 按步骤拆分,然后对每一个小步骤编写自动化脚本 3&#xff…

FreeRTOS(三)——应用开发(一)

文章目录 0x01 FreeRTOS文件夹FreeRTOSConfig.h文件内容上面定义的宏决定FreeRTOS.h文件中的定义0x02 创建任务创建静态任务过程configSUPPORT_STATIC_ALLOCATION创建动态任务过程configSUPPORT_DYNAMIC_ALLOCATION 0x03 FreeRTOS启动流程启动流程概述 0x04 任务管理任务调度器…

python基于轻量级YOLOv5的生猪检测+状态识别分析系统

在我之前的一篇文章中有过生猪检测盒状态识别相关的项目实践,如下: 《Python基于yolov4实现生猪检测及状态识》 感兴趣的话可以自行移步阅读,这里主要是基于同样的技术思想,将原始体积较大的yolov4模型做无缝替换,使…

关于python异常的总结

Python异常是在程序执行时发生的错误,可能会导致程序终止运行。 在Python中,异常处理是一种机制,它允许开发人员在程序发生异常时捕获、处理和报告这些异常,以便程序可以继续运行或在出现异常时进行优雅的退出。 在Python中&…

大数据之入门开发流程介绍

目录: 1、大数据的开发大致流程2、技术导图 1、大数据的开发大致流程 1.1 数据收集 大数据处理的第一步是数据的收集。现在的中大型项目通常采用微服务架构进行分布式部署,所以数据的采集需要在多台服务器上进行,且采集过程不能影响正常业务的…

Domino的线程ID和操作系统的进程ID对应关系

大家好,才是真的好。 很多时候,在Domino中运行的任务出现一些错误提示,如果能够准确定位到和提示信息相关任务时,对我们排错有着巨大的帮助,也能节省很多时间。 例如,我们可能在Domino实时控制台上看到以…

RedHat8配置本地YUM源

目录: RedHat8配置本地YUM源1、创建规则文件2、创建挂载点3、挂载ISO镜像(1).将iso镜像连接到虚拟机再进行挂载a.将ISO镜像连接虚拟机b.挂载镜像到挂载点c.使用df -h查看当前系统设备挂载情况 (2)将iso镜像上传至服务器再进行挂载a.将ISO镜像通过ftp工具上传b.挂载镜…

Spring Boot——优雅的参数校验

🎈 概述 当我们想提供可靠的 API 接口,对参数的校验,以保证最终数据入库的正确性,是 必不可少 的活。比如下图就是 我们一个项目里 新增一个菜单校验 参数的函数,写了一大堆的 if else 进行校验,或者基础校…

C#简单向:textbox添加提示内容

项目场景: 向C#窗体项目的textbox内添加提示内容,如下图所示效果: 具体实现: 首先: 1.到所要操作的文件(/xx.cs/xx.Designer.cs),这里我是到Form3.cs/Form3.Designer.cs文件 2.找到你所要操作的textBox&#xff0c…

数据结构与算法(一):基础数据结构(算法概念、数组、链表、栈、队列)

算法概念、数组、链表、栈、队列 判断一个数是否是2的N次方? N & (N-1) 0 (N > 0)算题: 力扣 https://leetcode.cn/POJ http://poj.org/ 算法 算法概念 算法代表: 高效率和低存储 内存占用小、CPU占用小、运算速度快 算法的高…

C# HttpClient使用JWT请求token调用接口,解决返回HTML网页的异常信息

一.项目目的: 1.使用JWT获取token,调用外部提供的接口,解决返回HTML错误信息。 错误缘由,接口服务器未能识别token,token信息不准确。 二.项目工具: Visual Studio(开发工具)&…

【Java|golang】1031. 两个非重叠子数组的最大和---前缀和+滑动窗口

给你一个整数数组 nums 和两个整数 firstLen 和 secondLen,请你找出并返回两个非重叠 子数组 中元素的最大和,长度分别为 firstLen 和 secondLen 。 长度为 firstLen 的子数组可以出现在长为 secondLen 的子数组之前或之后,但二者必须是不重…

专为Windows电脑和服务器设计的磁盘管理软件

关于Windows磁盘管理 磁盘管理是Windows自带工具,允许你对磁盘进行一些基本操作,Windows个人用户和Windows Server用户可以使用它来: 1. 创建一个新驱动器,如“新建简单卷”功能。 2. 将一个卷扩展到当前未被同一磁盘…

STM32CubeMX配置I2C通讯

1.如上图所示点击New Project 2.如上图所示选择自己所开发的新品最后双击芯片型号 3.配置RCC,我的芯片使用的是外部高速晶振。这里如图所选。 4.配置一下串口 5.配置I2C 6.根据自己的硬件选择时钟源和主频 6.①填写项目名②选择项目路径③选择开发环境④获取代码 …

Android build.gradle配置详解

Android Studio是采用gradle来构建项目的,gradle是基于groovy语言的,如果只是用它构建普通Android项目的话,是可以不去学groovy的。当我们创建一个Android项目时会包含两个Android build.gradle配置详解文件,如下图: …

2023 HDCTF --- Crypto wp

文章目录 Normal_RsaNormal_Rsa(revenge)爬过小山去看云Math_Rsa Normal_Rsa 题目: from Crypto.Util.number import * #from shin import flagmbytes_to_long(bHDCTF{****************}) e65537 pgetPrime(256) #qgetPrime(512) q67040062584277953042204504112809489262131…

Revit砌体排砖的几种方法对比

方法简介 传统砌体深化排砖是绘图者使用CAD 软件通过二维想象进行排布,在墙面转角、两面或多面墙相互咬砌的位置,门窗洞口过梁的位置,构造柱等位置由于二维图形的局限性很难观察出排布是否合理。然而复杂区域砌体排布若出错…

这个假期有这些游戏就不怕无聊了

1、塞尔达传说旷野之息 Switch端的优秀游戏体验不容错过! 人气王《塞尔达传说》! 被玩家誉为“唯一让人长大后有种回到童年的感觉的作品”。 豆瓣网友写道:“在雨夜,我在寺庙里看到了一条白龙划过天空,在岩壁上看到了…