【阅读笔记】多任务学习之MMoE(含代码实现)

news2025/2/25 1:25:10

本文作为自己阅读论文后的总结和思考,不涉及论文翻译和模型解读,适合大家阅读完论文后交流想法。

MMoE

    • 一. 全文总结
    • 二. 研究方法
    • 三. 结论
    • 四. 创新点
    • 五. 思考
    • 六. 参考文献
    • 七. Pytorch实现⭐

一. 全文总结

提出了一种基于**多门混合专家(MMoE)**结构的多任务学习方法,验证了模型的有效性和可训练性。

在这里插入图片描述

二. 研究方法

构造了可以人为控制相关性的合成数据集,比较了Share-Bottom、OMoE、MMoE不同相关系数任务下的训练精度。最后,对真实的基准数据和具有数亿用户和项目的大规模生产推荐系统进行了实验,验证了MMoE在现实环境中的效率和有效性。

下图为三种模型在不同相关性任务中的表现:
三种模型在不同相关性任务中的表现
下图为不同模型在不同相关性任务中,重复实验200次最低loss的分布情况:
在这里插入图片描述

三. 结论

  1. MMoE明确地学习从数据中建模任务关系,可以更好地处理任务不太相关的场景。
  2. 与基线方法相比,MMoE 更容易训练
  3. MMoE 在很大程度上保留了计算优势(有更好的计算效率),因为门控网络通常是轻量级的,并且专家网络在所有任务中共享。

四. 创新点

  1. 提出了一种新颖的多门专家混合模型MMoE,该模型明确地对任务关系进行建模。通过调制和门控网络,MMoE自动调整建模共享信息和建模任务特定信息之间的参数化
  2. 对合成数据进行控制实验,报告了任务相关性如何影响多任务学习中的训练动态以及 MMoE 如何提高模型表达能力和可训练性
  3. 对真实的基准数据和具有数亿用户和项目的大规模生产推荐系统进行了实验,实验验证了MMoE在现实环境中的效率和有效性

五. 思考

  1. MMoE在任务相关性低时较其他模型有更好的效果,但是可能会”跷跷板“的情况:一个task的效果提升,会伴随着另一个task的效果降低。
  2. 门控网络一般由线性变换+softmax组成,这部分计算量非常小,几乎可以忽略,但有人实验表明门控网络多叠加几层会有更好的效果。
  3. 多门结构在解决由任务差异引起的冲突引起的不良局部最小值方面有效。

六. 参考文献

  1. 大厂技术实现 | 多目标优化及应用(含代码实现)
  2. 我要打十个:多任务学习模型MMoE解读
  3. 多目标学习(Multi-task Learning)-网络设计和损失函数优化
  4. 收藏|浅谈多任务学习(Multi-task Learning)

七. Pytorch实现⭐

class Expert_Gate(nn.Module):
    def __init__(self,feature_dim,expert_dim,n_expert,n_task,use_gate=True): #feature_dim:输入数据的维数  expert_dim:每个神经元输出的维数  n_expert:专家数量  n_task:任务数(gate数)  use_gate:是否使用门控,如果不使用则各个专家取平均
        super(Expert_Gate, self).__init__()
        self.n_task = n_task
        self.use_gate = use_gate
        
        '''专家网络'''
        p=0
        expert_hidden_layers = [64,32,expert_dim]
        self.expert_layer = nn.Sequential(
                            nn.Linear(feature_dim, expert_hidden_layers[0]),
                            nn.ReLU(),
                            nn.Dropout(p),
                            nn.Linear(expert_hidden_layers[0], expert_hidden_layers[1]),
                            nn.ReLU(),
                            nn.Dropout(p),
                            nn.Linear(expert_hidden_layers[1], expert_hidden_layers[2]),
                            nn.ReLU(),
                            nn.Dropout(p)
                            )  
        self.expert_layers = [self.expert_layer for i in range(n_expert)] #为每个expert创建一个DNN
        
        '''门控网络'''
        self.gate_layer = nn.Sequential(nn.Linear(feature_dim, n_expert),
                                        nn.Softmax(dim=1))
        self.gate_layers = [self.gate_layer for i in range(n_task)] #为每个gate创建一个lr+softmax
        
    def forward(self, x):
        if self.use_gate:
            # 构建多个专家网络
            E_net = [expert(x) for expert in self.expert_layers]
            E_net = torch.cat(([e[:,np.newaxis,:] for e in E_net]),dim = 1) # 维度 (bs,n_expert,expert_dim)

            # 构建多个门网络
            gate_net = [gate(x) for gate in self.gate_layers]     # 维度 n_task个(bs,n_expert)

            # towers计算:对应的门网络乘上所有的专家网络
            towers = []
            for i in range(self.n_task):
                g = gate_net[i].unsqueeze(2)  # 维度(bs,n_expert,1)
                tower = torch.matmul(E_net.transpose(1,2),g)# 维度 (bs,expert_dim,1)
                towers.append(tower.transpose(1,2).squeeze(1))           # 维度(bs,expert_dim)
        else:
            E_net = [expert(x) for expert in self.expert_layers]
            towers = sum(E_net)/len(E_net)
        return towers

上面Expert_Gate为下图中红框内的模型实现:
在这里插入图片描述

class MMoE(nn.Module):
	#feature_dim:输入数据的维数  expert_dim:每个神经元输出的维数  n_expert:专家数量  n_task:任务数(gate数)
    def __init__(self,feature_dim,expert_dim,n_expert,n_task,use_gate=True): 
        super(MMoE, self).__init__()
        
        self.use_gate = use_gate
        self.Expert_Gate = Expert_Gate(feature_dim=feature_dim,expert_dim=expert_dim,n_expert=n_expert,n_task=n_task,use_gate=use_gate)
        
        '''Tower1'''
        p1 = 0 
        hidden_layer1 = [64,32] #[64,32] 
        self.tower1 = nn.Sequential(
            nn.Linear(expert_dim, hidden_layer1[0]),
            nn.ReLU(),
            nn.Dropout(p1),
            nn.Linear(hidden_layer1[0], hidden_layer1[1]),
            nn.ReLU(),
            nn.Dropout(p1),
            nn.Linear(hidden_layer1[1], 1))
        '''Tower2'''
        p2 = 0
        hidden_layer2 = [64,32]
        self.tower2 = nn.Sequential(
            nn.Linear(expert_dim, hidden_layer2[0]),
            nn.ReLU(),
            nn.Dropout(p2),
            nn.Linear(hidden_layer2[0], hidden_layer2[1]),
            nn.ReLU(),
            nn.Dropout(p2),
            nn.Linear(hidden_layer2[1], 1))
        
    def forward(self, x):
        
        towers = self.Expert_Gate(x)
        if self.use_gate:            
            out1 = self.tower1(towers[0])
            out2 = self.tower2(towers[1]) 
        else:
            out1 = self.tower1(towers)
            out2 = self.tower2(towers)
        
        return out1,out2
    
Model = MMoE(feature_dim=112,expert_dim=32,n_expert=4,n_task=2,use_gate=True)

nParams = sum([p.nelement() for p in Model.parameters()])
print('* number of parameters: %d' % nParams)

输入数据格式为(batchsize,feature_dim),输出为(batchsize,2)

在原文中作者构造了可以控制任务相关性的人工数据集,我搜遍全网都没找到人工数据集的创建方式,于是自己写了一个分享给大家:MMoE论文中Synthetic Data生成代码(控制多任务学习中任务之间的相关性)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2589.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL数据库基础操作

目录 前言: 库的操作 创建数据库 显示所有数据库 选中数据库 删除数据库 MySQL数据类型 数值类型 字符串类型 日期类型 表的操作 创建表 显示数据库中所有表 查看表结构 删除表 小结: 前言: 🎉MySQL是关系型数据…

【jquery Ajax】接口的学习与Postcode插件的使用

✍️ 作者简介: 前端新手学习中。 💂 作者主页: 作者主页查看更多前端教学 🎓 专栏分享:css重难点教学 Node.js教学 从头开始学习 目录 接口 接口的概念 分析接口的请求过程 通过GET方式请求接口的过程 通过post方式请求接口的过程 接口…

基于CarSystemUI实现左侧导航栏NavigationBar及下拉面板定制开发——Android10智能座舱

系列文章目录 提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加 例如:第一章 Python 机器学习入门之pandas的使用 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目…

使用多阶段和多尺度联合通道协调注意融合网络进行单图去雨[2022论文]

这是篇2022年来自一区的International Journal of Intelligent Systems的贵州大学的去雨论文 论文链接:百度网盘 请输入提取码 提取码:zdje ✍介绍 作者提出的问题: 1、目前去雨方法不能对不同密度和方向的雨条纹信息进行有效的编码 2、…

ThreadLocal类详解

ThreadLocal类注释翻译 打开JDK中ThreadLocal类源码,翻译类上注释如下(提取重点部分): 每个访问ThreadLocal实例对象的线程都有其自己的关于ThreadLocal对象的变量副本(通过get和set方法),只要线程存活而且ThreadLocal对象也存活,则线程都保…

嵌入式和单片机开发模式的区别

一、 开发模式 单片机开发多为裸机,程序规模小,多为单个程序员独立开发。有些复杂产品也会使用高端单片机如STM32之类,并使用RTOS(uCOS、freeRTOS等)。嵌入式开发几乎全部基于嵌入式操作系统,目前使用最多的是 linux 和Android。…

公众号搜题接口系统使用方法

公众号搜题接口系统使用方法 本平台优点:免费查题接口搭建 多题库查题、独立后台、响应速度快、全网平台可查、功能最全! 1.想要给自己的公众号获得查题接口,只需要两步! 2.题库:题库后台http://daili.jueguangzhe.c…

html实现飞机小游戏(源码)

文章目录1.思路讲解1.1 游戏设计1.2 主界面1.3 倒计时进入游戏1.4 游戏效果1.3 游戏结束2.实现源码2.1 游戏动态效果2.2 游戏主代码2.3 源码目录源码下载作者:xcLeigh 文章说明 html实现飞机大战源码,酷炫的界面效果,有四款飞机大战背景&…

Elasticsearch:通过热、温、冷和冻结层管理数据自动化 — 无需编码!

如果你想完全按照本文标题的建议去做,那就别无所求。 这篇文章旨在指导如何使用 Kibana Dashboard 的 “堆栈管理(Stack Management)” 功能集通过热、温、冷和冻结层自动移动数据,而无需进行任何编码或执行命令行动作。 在下面的…

Cookie 和 Session

本文主要讲解一下 Cookie 和 Session 的关系和区别,大家都知道 Session 比 Cookie 安全,Session 是存储在服务器端的,Cookie 是存储在客户端的,然而更详细的说,恐怕就不太清楚了 文章目录1. 什么是 HTTP2. Cookie2.1 图…

​目标检测算法——YOLOv5/YOLOv7改进之结合Criss-Cross Attention

关注”PandaCVer“公众号 深度学习Tricks,第一时间送达 (一)前沿介绍 论文题目:CCNet: Criss-Cross Attention for Semantic Segmentation 论文地址:https://arxiv.org/pdf/1811.11721.pdf 代码地址:ht…

B树和B+树(平衡多路查找树)

文章目录为什么需要B树B 树的特点B树的查找B树的引入B树的删除链接:https://www.cs.usfca.edu/~galles/visualization/Algorithms.html 可以点击 Indexing 下的 B Trees 和 B Trees 去学习。 为什么需要B树 对 B 树的需求随着访问物理存储介质(如硬盘&…

【Java】反射, 枚举,Lambda表达式

✨博客主页: 心荣~ ✨系列专栏:【Java SE】 ✨一句短话: 难在坚持,贵在坚持,成在坚持! 文章目录一. 反射1. 反射的概述2. 反射的使用2.1 反射常用的类2.2 通过反射获取Class对象2.3 获得Class类相关的方法2.4 使用反射创建实例对象2.5 使用反射获取实例对象中的构造方法2.6 通过…

Spring学习第1篇:学习spring必备的概念知识

大家家好,我是一名网络怪咖,北漂五年。相信大家和我一样,都有一个大厂梦,作为一名资深Java选手,深知Spring重要性,现在普遍都使用SpringBoot来开发,面试的时候SpringBoot原理也是经常会问到&…

纸牌博弈问题

纸牌博弈问题 作者:Grey 原文地址: 博客园:纸牌博弈问题 CSDN:纸牌博弈问题 题目描述 有一个整型数组 A,代表数值不同的纸牌排成一条线。玩家 a 和玩家 b 依次拿走每张纸牌, 规定玩家 a 先拿&#xff…

win11开机音效设置的方法

微软为win11重做了开机音效,与我们一直以来使用的开机音效不太一样,听起来很不舒服,因此我们可以通过设置开机音效的方法来修改它,只要在个性化设置中就可以找到了,下面一起来试试看吧。 win11开机音效怎么设置&#…

wordpress图片压缩插件-免费批量wordpress图片压缩

wordpress图片压缩插件,相信每个人都知道图片的太大会影响到网站的加载速度。过多的图像会对服务器产生相应的压力。导致网站打开会越来越慢。而图片也是会被搜索引擎收录的,可以在百度图片里面能搜索的到,也算是增加了网站的宣传力度。今天给…

(附源码)计算机毕业设计SSM基于微信平台的匿名电子投票系统

(附源码)计算机毕业设计SSM基于微信平台的匿名电子投票系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。…

【微信小程序支付功能】uniapp实现微信小程序支付功能

支付实现流程 首先前端写一个页面,简单说就是有一个输入支付金额的 然后有一个按钮,点击可以支付。 点击按钮后触发支付方法,就是我下面写的这些代码,复制就可以了。 然后先请求后端的一个方法,把你的价格还有openid之…

在Vue中使用Swiper轮播图、同时解决点击轮播图左右切换按钮不生效的问题、同时将轮播图抽离出为一个公共组件

轮播图左右的切换按钮、如果点击没有反应,控制台也没有报错。很大可能是版本问题。如果不指定版本信息、默认安装的是最新的版本。版本过高或者过低都有可能导致无效。目前兼容性和稳定性比较好的是:5.4.5。 官网地址:https://www.swiper.com…