MMoE: 基于多门专家混合的多任务学习任务关系建模

news2024/11/20 7:27:02

文章链接:Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts

发表会议: KKD 2018 (Knowledge Discovery and Data Mining,数据挖掘领域顶会)

目录

  • 1.背景介绍
    • Recommendation System
    • Multi-task Learning
    • Synthetic Data Generation
    • MMoE
  • 2.内容摘要
    • 关键技术
      • MMoE Modeling
      • Synthetic Dataset
    • 实验结果
      • Census-income Data
      • 大规模内容推荐
  • 3.文章总结

1.背景介绍

目前应用的MoE模型结构普遍采用单门控设计,在单门控(Single-Gate) 网络的基础上可以分为单栅(Single-Gated) 网络和多栅(Multi-Gated) 网络。这两类门控网络的主要区别在于每次选择专家的数量,前者每次只选择概率分布最大的一个专家,后者则选择概率分布上Top-K个专家。
那么有没有一种MoE架构采用的不是单门控网络,而是采用多门控(Multi-Gate) 网络呢?
Google scholar查找发现,在2018年KDD伦敦会议上,有学者提出了基于多门专家混合的多任务学习任务关系建模,也就是本次学习的文章。

在这里插入图片描述


Recommendation System

推荐系统(Recommendation System,RS)是一类利用算法和数据分析技术来预测用户对物品(如商品、服务、信息等)喜好程度的系统。其主要目的是帮助用户在海量信息中快速找到符合其个性化需求的内容,提升用户体验和用户满意度。

推荐算法:推荐系统采用各种不同的算法来生成推荐结果。主要的算法包括:

- 协同过滤:基于用户行为历史和用户群体行为的相似性来进行推荐。

- 内容过滤:基于物品的属性和用户的兴趣进行匹配,通常需要提前对物品进行特征提取。

- 矩阵分解:通过将用户-物品交互矩阵分解为两个低秩矩阵,来学习用户和物品的表示。

- 深度学习方法:通过深度学习模型来学习用户和物品的表示。

采用深度神经网络模型的推荐系统通常需要同时优化多个目标。例如,在向用户推荐电影时,我们不但希望用户不仅购买并观看电影,还希望他们在看完电影后做出评价或反馈。也就是说,我们想要创建一个模型来同时预测用户的购买行为和他们的评分,即多任务学习。


Multi-task Learning

多任务学习(Multi-Task Learning, MTL)是一种机器学习方法,其主要目标是通过同时学习多个相关任务来提高模型的性能。相比于单一任务学习,MTL 在训练过程中利用了任务之间的共享知识和信息,从而可以更好地泛化到新的任务或数据。

多任务学习模型的一般流程和关键要素:

要素解释(例子)
多个相关任务在CV领域,算法可以同时学习物体检测、物体分类和语义分割等任务
共享的特征表示模型通常会共享一些底层的特征表示来处理不同任务
通过在神经网络中共享部分层次或通过共享权重来实现
任务损失函数为每个任务定义一个相应的损失函数,用于衡量模型在该任务上的性能
总损失函数将所有任务的损失函数进行加权组合
以训练过程中,模型将会在多个任务上同时优化
权衡任务重要性通过分配不同的损失权重来调整模型对不同任务的关注程度

事实上,许多基于DNN的多任务学习模型对数据分布差异、任务间关系等因素都很敏感。来自任务差异的内在冲突实际上可能会损害至少部分任务的预测,即优化冲突(optimization conflict),特别是当所有任务广泛共享模型参数时。

为了提高任务间的差异化程度,一个简单的方式是为每个任务添加更多的参数模型。特定于任务的参数与共享的参数,两者相互配合,可以在训练时取得更好的性能,但这些额外参数会带来额外的计算成本。


Synthetic Data Generation

合成数据生成(Synthetic Data Generation,SDG)是一种通过模拟或生成数据来扩充现有数据集的方法。这些生成的数据在统计上与真实数据集相似,但是并非来自实际观测或采集,而是通过模型或规则生成的。

通过使用合成数据,我们可以很容易地测量和控制任务相关性。

要素解释(例子)
生成模型合成数据生成通常依赖于生成模型,这些模型可以是基于概率分布的模型
如生成对抗网络(GANs)、变分自动编码器(VAEs)等
数据模拟通过对已有数据进行统计建模,可以从模型中生成新的合成数据
数据扩增通过对现有数据进行变换或扩增来生成更多的样本。
在图像处理中,可以对图像进行平移、旋转、缩放等操作
数据合成将不同源的数据组合或融合在一起,从而生成新的数据集

MMoE

多门专家混合(Multi-gate Mix-of-Experts, MMoE)模型,由多个专家和多个门控网络组成。MMoE显式地对任务关系建模,并学习特定于任务的功能,以利用共享表示。它允许自动分配参数来捕获共享的任务信息特定于任务的信息,从而避免了为每个任务添加许多新参数的需要。

在这里插入图片描述
MMoE的主干是建立在最常用的共享底部多任务DNN结构上的。共享底层模型结构如图(a)所示,其中输入层之后的几个底层在所有任务之间共享,然后每个任务在底层表示的顶部都有一个单独的网络“塔”。

MMoE模型,图(c)所示,有一组底层网络,但不是所有任务共享一个底层网络,底层网络被划分为多个专家,每个专家是一个前馈网络。

每个任务引入一个门控网络。门控网络采用输入特征和输出softmax门,将具有不同权重的专家组装在一起,允许不同的任务以不同的方式利用专家。然后将组装的专家的结果传递到特定于任务的塔网络中。通过这种方式,针对不同任务的门控网络可以学习专家集合的不同混合模式,从而捕获任务关系。


2.内容摘要

文章提出了一种新的多任务学习方法,多门专家混合(MMoE),它显式地学习从数据中建立任务关系的模型。文章将MMoE结构应用于多任务学习,在所有任务之间共享专家子模型,同时训练一个门控网络来优化每个任务。
为了在具有不同任务相关性级别的数据上验证MMoE的有效性,文章首先将其应用于一个合成数据集,在该数据集上我们控制任务相关性。当任务相关性较低时,所提出的方法比基线方法表现得更好。 根据训练数据和模型初始化的不同程度的随机性,MMoE结构还带来了额外的可训练性优势。

关键技术

MMoE Modeling

在这里插入图片描述
给定 K K K (上图 K = 2 K=2 K=2)个任务,模型由一个共享底层网络(表示为函数 f f f,设 i i i 为专家编号, k k k 为其中一个任务,则共享层输出可表示为 f k ( x ) = ∑ i = 1 n g k ( x ) i ⋅ f i ( x ) f^{k}(x)=\sum_{i=1}^{n} g^{k}(x)_{i}\cdot f_{i}(x) fk(x)=i=1ngk(x)ifi(x) 其中门控网络 g k g^{k} gk 是一个通过softmax激活的简单线性变换, g k ( x ) = s o f t m a x ( W g k x ) g^{k}(x)=softmax(W_{gk}x) gk(x)=softmax(Wgkx) W g k ∈ R n × d W_{gk} \in \mathbb{R}^{n\times d} WgkRn×d 是一个可训练的矩阵。 n n n 为专家数量, d d d 为特征维数。
在通过门控网络和共享层计算后,对应的任务 k k k 会进入到网络“塔” h k h^{k} hk 进行最后计算。对于任务 k,模型输出可以表示为:
y k = h k ( f k ( x ) ) y_{k}=h^{k}(f^{k}(x)) yk=hk(fk(x))

每个门控网络 g k g^{k} gk 可以学习“选择”一个专家子集。在MTL的情境下充当参数共享的概率分布。举个例子,如果每个门控网络只选择一个得分最高的专家,那么每个门网络实际上将输入空间线性划分为n个区域,每个区域对应一个专家。

MMoE能够以一种复杂的方式对任务关系进行建模,来确定不同的门所导致的分离如何相互重叠。如果任务的相关性较低,那么共享专家将受到惩罚,这些任务的控制网络将学会使用不同的专家。

与共享底层模型相比,MMoE只有几个附加的门控网络,门控网络中模型参数的数量可以忽略不计。因此,整个模型仍然尽可能地享受多任务学习中知识转移的好处。


Synthetic Dataset

多任务学习模型的性能在很大程度上依赖于数据中固有的任务关联性。在实际应用中,很难直接研究任务关联性如何影响多任务模型,因为我们无法轻易地改变任务间的关联性并观察其效果。

为了建立对这一关系的实证研究,本文使用合成数据,可以很容易地测量和控制任务相关性。本文生成了两个回归任务,并使用这两个任务标签的皮尔逊相关性作为任务关系的定量指标。

合成数据步骤如下:

  1. 给定输入特征维数 d d d,生成两个正交的单位向量 u 1 , u 2 ∈ R d u_{1},u_{2} \in \mathbb{R^{d}} u1,u2Rd
    u 1 T u 2 = 0 , ∣ ∣ u 1 ∣ ∣ 2 = 1 , ∣ ∣ u 2 ∣ ∣ 2 = 1 u^{T}_{1}u_{2}=0,||u_{1}||_{2}=1,||u_{2}||_{2}=1 u1Tu2=0,∣∣u12=1,∣∣u22=1
  2. 给定一个比例常数 c c c 和一个皮尔逊相关值 − 1 ≤ p ≤ 1 -1\le p \le 1 1p1 生成两个权重向量 w 1 w_{1} w1 w 2 w_{2} w2:
    w 1 = c ⋅ u 1 , w 2 = c ⋅ ( p u 1 + ( 1 − p 2 ) u 2 ) w_{1}=c\cdot u_{1},w_{2}=c \cdot (pu_{1}+\sqrt{(1-p^{2})u_{2}} ) w1=cu1,w2=c(pu1+(1p2)u2 )
  3. 对输入数据点 x ∈ R … … d x∈R……{d} xR……d 的每个元素从 N ( 0 , 1 ) N(0,1) N(0,1) 中随机抽样
  4. 为两个回归任务生成两个标签 y 1 , y 2 y1,y2 y1,y2:
    y 1 = w 1 T x + ∑ i = 1 m s i n ( α w 1 T x + β i ) + ϵ 1 y_{1}=w^{T}_{1}x+\sum_{i=1}^{m} sin(\alpha w^{T}_{1}x+\beta_{i})+\epsilon _{1} y1=w1Tx+i=1msin(αw1Tx+βi)+ϵ1 y 2 = w 2 T x + ∑ i = 1 m s i n ( α w 2 T x + β i ) + ϵ 2 y_{2}=w^{T}_{2}x+\sum_{i=1}^{m} sin(\alpha w^{T}_{2}x+\beta_{i})+\epsilon _{2} y2=w2Tx+i=1msin(αw2Tx+βi)+ϵ2其中 α i , β i , i = 1 , 2 , . . . , m \alpha_{i},\beta_{i},i=1,2,...,m αi,βi,i=1,2,...,m 是控制正弦函数形状的给定参数,并且
    ϵ 1 , ϵ 2 ∼ i . i . d N ( 0 , 0.01 ) \epsilon _{1},\epsilon _{2}\overset{i.i.d}{\sim} \mathcal{N} (0,0.01) ϵ1,ϵ2i.i.dN(0,0.01)
  5. 重复(3)和(4),直到生成足够的数据

实验结果

Census-income Data

本次测试数据集为”人口普查收入“数据集,UCI人口普查收入数据集是从1994年人口普查数据库中提取的。它包含了299285个美国成年人的人口统计信息实例。

本文构建了两个多任务学习问题,将其中的一些特征设置为预测目标,并计算了超过10,000个随机样本的任务标签的Pearson相关性绝对值:

  1. Group1:
    任务 1: 预测收入是否超过50K$;
    任务 2: 预测这个人的婚姻状况是否从未结婚。
    P = 0.1768
  2. Group2:
    任务 1: 预测教育水平是否至少为大学;
    任务 2: 预测这个人的婚姻状况是否从未结婚。
    P = 0.2373

在这里插入图片描述
表1和表2显示了两组的结果,报告了运行400次以上的平均AUC。

考虑到任务关联性(粗略地用皮尔逊相关性来衡量)在两组中都不是很强,共享底层模型几乎总是多任务模型中最差的。

L2 - restricted和Cross-Stitch都为每个任务提供了单独的模型参数,并为如何学习这些参数添加了约束,因此比Shared-Bottom执行得更好。

然而,模型参数学习的约束很大程度上依赖于任务关系假设,与MMoE使用的参数调制机制相比,任务关系假设的灵活性较差。


大规模内容推荐

在谷歌Inc.的一个大型内容推荐系统上进行实验,该系统中的推荐来自于为数十亿用户生成的数亿个独立条目。具体来说,给定用户当前消费商品的行为,该推荐系统的目标是向用户显示下一个要消费的相关商品的列表。

设置的DNN被训练为优化两种类型的排名目标:

  1. 优化与参与度相关的目标,如点击率和参与度时间;
  2. 优化满意度相关目标,如满意率等。

在这里插入图片描述

上表显示了训练200万步(100亿个示例,批量大小为1024)、400万步和600万步之后的结果。MMoE在这两个指标上都优于其他模型。L2-Constrained 和 Cross-Stitch比 Shared-Bottom 更糟糕。这可能是因为这两个模型是建立在两个单独的单任务模型之上的,并且有太多的模型参数需要很好的约束。

在这里插入图片描述上图显示了每个任务的softmax门的分布。MMoE学习了这两个任务之间的差异,并自动平衡了共享和非共享参数。

由于满意度子任务的标签比参与度子任务的标签更稀疏,因此满意度子任务的门更集中于单个专家。


3.文章总结

本文提出了一种新的多任务学习方法——多门混合专家模型(MMoE),它可以从数据中显式地学习模型任务关系。通过合成数据的控制实验表明,所提出的方法可以更好地处理任务相关性较低的场景。通过对基准数据集和一个实际的大规模推荐系统的实验,本文证明了该方法在几种最先进的基线多任务学习模型上的成功。
此外,MMoE模型在很大程度上保留了计算优势,因为门控网络通常是轻量级的,并且专家网络在所有任务中都是共享的。如果模型通过使门控网络成为一个稀疏的Top-K门,有可能实现更好的计算效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1112150.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

wps excel js编程

定义全局变量 const a "dota" function test() {Debug.Print(a) }获取表格中单元格内容 function test() {Debug.Print("第一行第二列",Cells(1,2).Text)Debug.Print("A1:",Range("A1").Text) }写单元格 Range("C1").Val…

【UE】两步实现“从UI中拖出Actor放置到场景中”

效果 步骤 1. 创建两个actor蓝图 在两个蓝图中分别添加立方体和球体形状的静态网格体组件,注意移动性设置为“可移动” 设置碰撞预设为“NoCollsion” 2. 先创建一个控件蓝图 打开控件蓝图,在画布面板中添加两个按钮 为按钮添加“按压时”和“松开时”的…

手工测试的迷茫:除了重复劳动,到底还有什么?

我是在2008年毕业的,三本的学校,不上不下的专业水平,毕业的时候,恰好遇到了金融危机。校园招聘里阴差阳错的巧合,让我走上了软件测试工程师的道路。 入职第一天,来了个高大上的讲师,记得他是这…

easyphoto 妙鸭相机

AIGC专栏7——EasyPhoto 人像训练与生成原理详解-CSDN博客如何训练一个高品质的人像Lora与应用高品质Lora的链路对于写真生成而言非常重要。由《LoRA: Low-Rank Adaptation of Large Language Models》 提出的一种基于低秩矩阵的对大参数模型进行少量参数微调训练的方法&#x…

【牛客网】HJ91.走方格的方案数

题目 思路 考虑特殊情况,假设行数为m1,列数为n 则最短路径为mn 假设行数为m,列数n1,则最短路径为mn 考虑普遍情况 假设行数为m,列数为n 则总路经数为行数为m-1列数为n和行数为m列数为n-1的两个的和 根据上述条件,可以考虑使用递归的方式进行解决 代码 import java.util.Scan…

springmvc视图格式——模板引擎freemarker输出HTML文本

目录 1. freemarker 介绍创建测试工程2.2.2) 配置文件2.2.3) 创建模型类2.2.4) 创建模板2.2.5) 创建controller2.2.6) 创建启动类2.2.7) 测试 2.3) freemarker基础2.3.1) 基础语法种类2.3.2) 集合指令(List和Map)2.3.3) if指令2.3.4) 运算符2.3.5) 空值处…

Java设计模式 | 基于订单批量支付场景,对策略模式和简单工厂模式进行简单实现

基于订单批量支付场景,对策略模式和简单工厂模式进行简单实现 文章目录 策略模式介绍实现抽象策略具体策略1.AliPayStrategy2.WeChatPayStrategy 环境 使用简单工厂来获取具体策略对象支付方式枚举策略工厂接口策略工厂实现 测试使用订单实体类对订单进行批量支付结…

VS code中使用code Runner插件直接运行Typescript

使用VS code运行ts 运行问题 我们知道,在VS code中运行.ts文件,是不能直接运行的,需要在修改代码之后,都重复执行两个命令,才能运行ts代码 tsc 文件名.ts (tsc 文件名.ts -w 可以监视ts文件(监视模…

【数字人】5、RAD-NeRF | 通过解耦 audio-spatial 编码来实现基于 NeRF 的高效数字人合成

文章目录 一、背景二、方法2.1 问题定义2.2 Decomposed audio-spatial encoding module2.3 Pseudo-3D Deformable Module 用于控制 torso2.4 训练细节 三、效果3.1 实验设置3.2 对比 论文:Real-time Neural Radiance Talking Portrait Synthesis via Audio-spatial …

【LittleXi】【MIT6.S081-2022Fall】Lab: syscall

【LittleXi】【MIT6.S081-2022Fall】Lab: syscall 文章目录 lab2实验1:Process counting实验思路实验过程 实验2:Free Memory Cou实验思路实验过程 实验3:System call tracin实验思路实验过程 实验4:流程概述1.请概述用户从发出系…

嵌入式养成计划-44----QT--消息对话框(QMessageBox)--字体对话框--颜色对话框--文件对话框

一百一十三、消息对话框 (QMessageBox) 消息对话框给用户提供一个交互式的弹窗,该类提供两种实现版本, 基于属性版本基于静态成员函数版本 基于属性版本 需要用消息对话框这样的类 实例化对象 用该对象调用类里的相关成员函数进…

web:[MRCTF2020]Ez_bypass

题目 点进题目 调整一下 进行代码审计,先看第一段 if(isset($_GET[gg])&&isset($_GET[id])) {$id$_GET[id];$gg$_GET[gg];if (md5($id) md5($gg) && $id ! $gg) {echo You got the first step; get参数传参,后判断md5后的值是否相等&…

2023前端面试题总结

给大家推荐一个实用面试题库 1、前端面试题库 (面试必备) 推荐:★★★★★ 地址:web前端面试题库 Html5和CSS3 常见的水平垂直居中实现方案 最简单的方案当然是flex布局 .father {display: flex;justify-content…

手部关键点检测4:Android实现手部关键点检测(手部姿势估计)含源码 可实时检测

目录 1. 前言 2.手部关键点检测(手部姿势估计)方法 (1)Top-Down(自上而下)方法 (2)Bottom-Up(自下而上)方法: 3.手部关键点检测模型训练 4.手部关键点检测模型Android部署 (1) 将Pytorch模型转换ONNX模型 (2) …

嘉立创使用技巧

立创社区:电子工程师交流社区_电子发烧友论坛_嘉立创&立创商城旗下专业电子论坛【立创社区】 (szlcsc.com) 嘉立创官网使用教程:立创EDA使用教程 (lceda.cn) 嘉立创是国产软件对新手友好,中国人更懂中国人。下面介绍我在使用中用到的技巧…

【unity小技巧】适用于任何 2d 游戏的钥匙门系统和buff系统——UnityEvent的使用

文章目录 每篇一句前言开启配置门的开启动画代码调用,控制开启门动画 新增CollisionDetector 脚本,使用UnityEvent ,控制钥匙和门的绑定多把钥匙控制多个门一把钥匙控制多个门 BUFF系统扩展参考源码完结 每篇一句 人总是害怕去追求自己最重要…

堆-----数据结构

引言 什么是堆?堆是一种特殊的数据结构(用数组表示的树)。 为什么要使用到堆?比如一场比赛,如果使用擂台赛的方式来决出冠军(实力第一),就很难知道实力第二的队伍是什么了。 但是…

Simulink 最基础教程(三)常用模块

3.1源模块 1)clock 这个模块的输出是 y(t)t。很多信号都是和时间 t 相关的,例如正弦波信号,可以写成 sin(w*t) 的形式。虽然软件也提供了正弦波模块,但如果用 clock 模块三角运算模块,对初学者而言,也是很好…

QT_day3

完善对话框,点击登录对话框,如果账号和密码匹配,则弹出信息对话框,给出提示”登录成功“,提供一个Ok按钮,用户点击Ok后,关闭登录界面,跳转到新的界面中 如果账号和密码不匹配&#…

科技资讯|2023全球智能手表预估出货1.3亿块,智能穿戴提升AI功能

根据集邦咨询公布的最新报告,受全球经济低迷影响,2023 年全球智能手表出货量预估为 1.3 亿块。苹果以超过 30% 的份额领先,其次是三星(接近 10%)、华为、Garmin、Fitbit 等。 报告认为苹果、三星和华为等主要智能手表…