MMOE+ESSM

news2024/12/24 11:41:30

MMOE

动机
  1. 多个任务之间的相关性并不是很强,这个时候如果再用过去那种共享底座embedding的结构,往往会导致『跷跷板』现象。

  2. 当前学术界已经有很多工作意识到1中描述的问题并且尝试去解决,但大多数工作的套路都是『大力出奇迹』的路子,即加很多可学的参数去学习多个任务之间的difference,这在学术界跑跑数据,写写论文倒是没什么,但是在工业界场景下,增加这些参数会导致线上做infer时耗时增加,导致模型服务可用性大大下降,这是无法接受的。

结构

(a)展示了传统的MTL模型结构,即多个task共享底座(一般都是embedding向量),(b)则是论文中提到的一个gate的Mixture-of-Experts模型结构,(c)则是论文中的MMoE模型结构。

每一个expert和gate都是一个全连接网络(MLP),层数由在实际的场景下自己决定。

  1. Gate网络的数量取决于task数量,即与task数量相同。Gate网络最后一层全连接层的隐藏单元(即输出)size必须等于expert个数。另外,Gate网络最后的输出会经过softmax进行归一化。

  2. Gate网络最后一层全连接层经过softmax归一化后的输出,对应作用到每一个expert上(图2中GateA输出的红、紫、绿三条线分别作用与expert0,expert1,expert2),注意是通过广播机制作用到expert中的每一个隐藏单元,比如红线作用于expert0的2个隐藏单元。这里gate网络的作用非常类似于attention机制,提供了权重。

    (广播机制的一个常见例子是在全连接层(Dense Layer)中,当输入张量的维度与权重张量的维度不匹配时,TensorFlow 会自动应用广播机制来匹配维度。)

  3. towerA的输入size等于expert输出隐藏单元个数(在这个例子中,expert最后一层全连接层隐藏单元个数为2,因此towerA的输入维度也为2)

  4. expert每个网络的输入特征都是一样的,其网络结构也是一致的。

  5. 两个gate网络的输入也是一样的,gate网络结构也是一样的。

gate用来学习每个expert的权重,Gate是一个概率分布,控制每个Expert对task的贡献程度,比如task A的gate为(0.1, 0.2, 0.7),则代表Expert 0、Expert 1、Expert 2对task A的贡献程度分别为0.1、0.2和0.7;

一个task独立拥有一个Gate,tower是多任务目标的数量,ctr/cvr,两个塔,两个gate

代码
class MMOELayer(object):
    def __init__(self, name='mmoe'):
        self._name = name

    def __call__(self, inputs, expert_units, num_experts, num_tasks,
                 expert_act=tf.nn.leaky_relu, gate_act=tf.nn.softmax, temp=None):
        expert_outputs, final_outputs = [], []
        with tf.name_scope('experts_network'):
            for i in range(num_experts):
                weight_name_template = self._name + '_expert{}_'.format(i) + '_h{}_param'
                expert_layer = simple_dense_network(inputs, expert_units, '{}_experts'.format(self._name),
                                                    weight_name_template, act=expert_act)
                expert_outputs.append(tf.expand_dims(expert_layer, axis=2))
            expert_outputs = tf.concat(expert_outputs, 2)  # (batch_size, expert_units[-1], num_experts)

        with tf.name_scope('gates_network'):
            for i in range(num_tasks):
                weight_name_template = self._name + '_task_gate{}_'.format(i) + 'param'
                gate_mlps = simple_dense_network(inputs, [16*num_experts, 4*num_experts], '{}_gate_mlp'.format(self._name),
                                                 weight_name_template+'_h{}_mlps', act=expert_act)

                gate_layer  = mio_dense_layer(gate_mlps, num_experts, gate_act, '{}_gates'.format(self._name),
                                              weight_name_template+'_gate', temp)

                expanded_gate_output = tf.expand_dims(gate_layer, axis=1)  # (batch_size, ?, num_experts)
                weighted_expert_output = expert_outputs * repeat_elements(expanded_gate_output, expert_units[-1],
                                                                          axis=1)
                final_outputs.append(sum(weighted_expert_output, axis=2))

        return final_outputs  # (num_tasks, batch_size, expert_units[-1])

相关问题

【问】: expert网络结构一样,输入特征一样,是否会导致每个expert学出来的参数趋向于一致,从而失去了ensemble的意义?

【答】: 在网络参数随机初始化的情况下,不会发生问题中提到的问题。核心原因在于数据存在multi-view,只要每一个expert网络参数初始化是不一样的,就会导致每一个expert学到数据中不同的view(paddle官方实现就犯了这个致命错误)。微软的一篇论文中提到因为数据存在multi-view,训练多个DNN时,即使一样的特征,一样的超参数,只要简单的把参数初始化设置不一样, 这多个DNN也会有差le/details/123309660

multi-view:

在图像识别任务中,multi-view数据可能包括原始图像、边缘检测图像、颜色直方图图像等,每个视图都提供了不同的信息,有助于提高模型的性能和鲁棒性。在文本分析任务中,multi-view数据可能包括原始文本、TF-IDF向量、词嵌入向量等,每个视图都捕捉了文本的不同语义信息。

【问】: 是否应该强上MTL?
【答】: 如果task之间的相关性很弱,基本上都会发生negative transfer,所以MTL是绝对打不过single model的,不要盲目的为了显得高大上牛逼哄哄的一股脑MTL(Multi Task Learning多任务学习)。还是那句话,模型不重要,重要的是对数据及场景的理解。

Multi Task Learning

以往的做法可能会对不同的task分别建立对应的模型,但这样会导致几个问题

  1. 模型的数量会随着task的数量增加而增加,模型维护成本高;

  2. 生产环境中,需要同时多个模型进行计算,才能完成多个task的预估,存在性能问题;

  3. 忽略了不同task之间的关联。

多任务学习还有一个优点:

经常存在某个task的样本数量比较少的情况,导致模型的学习难度较高。多任务学习过程中多个task的样本数据是共享的,一定程度上减缓这个问题。

用于cvr训练的都是有点击的样本,这部分样本实在是太少了

原始的MTL--Shared-bottom

效果不好,原因:底层网络参数的共享

1、底层共享的参数容易偏向于某个task,或者说偏向于某个task的全局(局部)最优方向,如果task之间差异很大,那么不同task的模型参数的全局(局部)最优方向也会不同,那么其他task的效果肯定会大打折扣;

2、不同task的梯度冲突,存在参数撕扯,比如两个task的梯度是正负相反,那最终的梯度可能被抵消为0,导致共享参数更新缓慢。

MMOE和MTL对比
  1. 对比Shared-Bottom模型,MMoE将底层的共享网络层拆分为多个共享的Expert,并且通过引入Gate来学习每个Expert对不同task的贡献程度;

  2. 对应不同相关性的task,MMoE模型的效果比较稳定。这主要是因为相关性弱的task,可以通过Gate来利用不同的Expert。

  3. 减少直接参数共享

补充

MMoE在弱相关性task中表现地相对比较稳定,但由于底层的Expert仍然是共享的(虽然引入Gate来让task选择Expert),所以还是会存在**“跷跷板”**的情况:一个task的效果提升,会伴随着另一个task的效果降低。

腾讯在2020的论文中,就对MMoE进行改进,提出了CGC(Customized Gate Control)、PLE(Progressive Layered Extraction)

CGC(Customized Gate Control)

区别:

除了共享的Expert之外,还加入了每个task自己的Specific Expert

PLE(Progressive Layered Extraction)

相当于deep化mmoe

ESSM

ESSM全称Entire Space Multi-Task Model,也就是全样本空间的多任务模型,该模型有效地解决了CVR建模(转化率预估)中存在的两个非常重要的问题:样本选择偏差(SSB,sample selection bias)和数据稀疏。

动机

一条流量为公司创造商业收入的路径为:

请求 → 广告曝光 → 广告点击 -> 广告转化。 在这条路径中涉及到非常复杂繁多的算法与博弈,

(1)在【请求 → 广告曝光】阶段涉及到流量(算力)在线分配、沉重复杂的检索、广告填充率预估PVR、竞价博弈以及计费策略;

(2)在【广告曝光 → 广告点击】阶段就是必备的点击率预估PCTR;

(3)在【广告点击 → 广告转化】阶段则是经典的转化率预估CVR。

以上每一个环节的工程及策略优化都会为公司的商业收入带来巨大的提升,当然最终目的还是为了能够在用户,平台,广告主三者博弈之间达到全局最优三者皆赢的均衡状态。

样本选择偏差(SSB,sample selection bias)

目前业界在训练转化率(CVR)预估模型时,所采用数据集的正负样本分别为:点击未转化为负样本,点击转化为正样本。也就是整个样本集都是在有点击的样本上构建的。但在做在线infer时,是对整个样本空间进行预估,这就导致了样本选择偏差问题(即在离线样本空间有gap)。

数据稀疏

用于cvr训练的都是有点击的样本,这部分样本实在是太少了,对于广告而言,大盘的点击率也就在2%左右,其中能够转化的更加少之又少(正样本)。

通过共享底座embedding的方式也减缓这种情况。

ESMM模型结构细节

CVR模型的目的很明确,就是预估广告被点击之后的转化率(Post-Click Conversion Rate),

因此,cvr模型训练的时候只能用有点击的样本,这也直接导致了在离线的样本选择偏差(SSB)。如果我们想在训练的时候把样本空间扩大到整个有曝光的样本,那么需要怎么办呢?现在很明确的是CTR任务是用全部有曝光的样本,ESMM这里巧妙的做了转换,即训练CTCVR和CTR这两个任务,那么CTCVR和CTR、CVR之间的关系如何呢?下面就来看一下:

ESMM的出发点就是:既然CTCVR和CTR这两个任务训练是可以使用全部有曝光样本的,那我们通过这学习两个任务,隐式地来学习CVR任务。

  1. CTR与CVR这两个塔,共享底座embedding。 因此CVR样本数量太少了,也就是存在开头提到的两个问题中的数据稀疏问题,所以很难充分训练学到好的embedding表达,但是CTR样本很多,这样共享底座embedding,有点transfer learning的味道,帮助CVR的embedding向量训练的更充分,更准确。

  2. CVR这个塔其实个中间变量,他没有自己的损失函数也就意味着在训练期间没有明确的监督信号,在ESMM训练期间,主要训练的是CTR和CTCVR这两个任务,这一点从ESMM的loss函数设计也能看出来。

损失函数

loss = ctr loss + ctcvr loss 交叉熵

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2072867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

zigbee笔记、十五、组播通信原理

一、zigbee四种通讯 1、单播(略) 2、广播(略) 3、组播:在zigbee网络中,模块可以用分组来标记,发送的模块如果发送的组号和网络里面标记接收模块的组号相对应,那么这些模块就可以拿到…

深度剖析渗透测试:流程、规范与实战全指南

一、引言 在当今数字化时代,网络安全问题日益凸显。渗透测试作为一种主动的安全评估方法,能够帮助企业和组织发现潜在的安全漏洞,提高系统的安全性。本文将详细介绍渗透测试的实施流程、规范、不同类型的测试方法以及相关的 checklist 和报告…

Matlab处理H5文件

1.读取h5文件 filenamexxx.h5; h5disp(filename) 2.h5文件保存为mat文件 读取 HDF5 文件中的数据 % 指定 HDF5 文件的路径 filename xxx.h5;% 读取 HDF5 文件中的各个数据集 A241_P h5read(filename, /A241_P); A241_W h5read(filename, /A241_W); A242_P h5read(filen…

ensp 中 wlan 的配置过程和示例

一、拓朴: 要求:vlan20 用于笔记本上网,使用Huawei信号,vlan30 用于手机上网,使用 Huawei-5G 信号 二、配置过程: 1、SW1 基本配置: 起 vlan batch 10 20 30,10 为管理 vlan&#…

Acrobat Pro DC 2024 for mac/Win:跨平台PDF编辑与管理的巅峰之作

Adobe Acrobat Pro DC 2024是一款专为Mac和Windows用户设计的全面PDF解决方案软件,它集成了创建、编辑、转换、共享和签署PDF文件的强大功能,为用户带来前所未有的高效与便捷体验。 强大的PDF编辑功能 Acrobat Pro DC 2024在PDF编辑方面表现出色。用户…

JavaScript初级——DOM增删改

1、 document.createElement() —— 可以用于创建一个元素节点对象,他需要一个标签名作为参数,将会根据该标签名创建元素节点对象,并将创建好的对象作为返回值返回。 2、 document.createTextNode(&#…

职场达人必备!MyComputerManager助你轻松管理快捷方式

前言 你是否还在为硬盘管理界面上那一堆乱糟糟的快捷方式头疼不已?是不是每次打开‘此电脑’都像是在玩寻宝游戏,寻找那个被深埋的文件夹?想象一下,如果能在此电脑页面一键启动程序,是不是觉得整个人都轻松了许多&…

使用Tabs组件提升页面内容的聚焦与分类效率

当页面信息量较大时,为了提高用户的浏览效率,我们需要对页面内容进行有效的分类和展示。HarmonyOS提供的Tabs组件是一个理想的解决方案,可以在一个页面内快速切换视图内容,提升用户查找信息的效率,同时减少用户在单次操…

CSS定位与布局

一、display属性(元素如何显示) 网页上的每个元素都是一个​盒模型​。​display​属性决定了盒模型的​行为方式​,设置元素如何被显示。 display常用的属性共有​4个​值: ​display: none;​ -- 让标签消失(隐藏元素并脱离文档…

Mac M1Pro 安装Java性能监控工具VisualVM 2.1.9

本地已经安装了java8,在终端输入jvisualvm提示没有安装 zhiniansara ~ % jvisualvm The operation couldn’t be completed. Unable to locate a Java Runtime that supports jvisualvm. Please visit http://www.java.com for information on installing Java.官网…

Kafka事件(消息、数据)的存储

1、查看有关kafka日志配置文件的信息 2、查看kafka全部主题的日志文件 3、查看每个主题的日志文件 4、__consumer_offsets-xx文件夹的作用 package com.power;public class Test {public static void main(String[] args) {int partition Math.abs("myTopic".hashCo…

企业微信API对接文档【可向微信用户发消息】

目录 企业微信API对接文档 1.背景 2.获取微信第三方token 3.安装docker环境 4.打包与启动 4.1打包镜像 4.2启动容器(启动应用) 5.企业微信二维码验证 5.1 获取初始二维码 5.2 第1次二维码验证 5.3 第2次二维码验证 6. 企业微信三…

上博士为了毕业写学术论文头都大了,但更难受的是英语不咋地,投稿后经常会因为语言问题而惨遭拒稿,每每想起就令人心情郁郁,天台可期。

上博士为了毕业写学术论文头都大了,但更难受的是英语不咋地,投稿后经常会因为语言问题而惨遭拒稿,每每想起就令人心情郁郁,天台可期。有些审稿人也会直接告知需要专业的修改,那咋整呢,让润色呗,…

虚拟机virtualbox linux ubuntu使用usb串口

1.卸载brltty sudo apt remove brltty brltty是一个没啥用但是会抢占ch431的软件,所以卸载它 2.连接上串口,点击连接对应的usb串口 3.查看是否连接上 sudo dmesg -T | grep tty 查看tty组的最近日志,如果连接成功会显示连接的时间和串口…

基于数据挖掘的心力衰竭疾病风险评估系统

B站视频及代码下载:基于数据挖掘的心力衰竭疾病风险评估系统_哔哩哔哩_bilibili 1. 项目简介 心力衰竭是一种常见的心脏疾病,它严重影响患者的生活质量和预期寿命。早期识别和干预对于改善患者的预后至关重要。近年来,随着大数据技术和机器学…

eleme

设置主从从mysql57服务器 --配置主数据库 # systemctl stop firewalld # setenforce 0 # systemctl disable firewalld # ls anaconda-ks.cfg mysql-5.7.44-linux-glibc2.12-x86_64.tar.gz# tar -xf mysql-5.7.44-linux-glibc2.12-x86_64.tar.gz # cp -r mysql-5.7.44-linux-…

【题解】【循环】——[NOIP2010 普及组] 数字统计

【题解】【循环】——[NOIP2010 普及组] 数字统计 [NOIP2010 普及组] 数字统计题目描述输入格式输出格式输入输出样例输入 #1输出 #1输入 #2输出 #2 提示 1.题意解析2.AC代码 [NOIP2010 普及组] 数字统计 戳我查看题目(洛谷) 题目描述 请统计某个给定…

Spring cloud 网关信息

网关简绍 就是网络的关口&#xff0c;负责请求的路由、转发、身份校验。 引入网关依赖 <dependencies><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-gateway</artifactId></dependenc…

html+css 实现爱心跳动

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享htmlcss 实现爱心跳动&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f49…

速通教程:如何使用Coze+剪映,捏一个爆款悟空视频

程哥最近做了一个和黑神话悟空有关的视频&#xff0c;没想到就火了&#xff0c;视频主打一个玉石风格&#xff0c;就是下面这个视频。 视频请移步飞书观看&#xff1a;黑神话悟空玉石版 制作过程不算很复杂&#xff0c;全程只需要用到Coze智能体和剪映这两个工具。 智能体用…