推荐系统里面的多任务学习概述

news2025/3/1 23:10:02

1. 概述

多任务学习(multi-task learning),本质上是希望使用一个模型完成多个任务的建模,在推荐系统中,多任务学习一般即指多目标学习(multi-label learning),不同目标输入相同的feature进行联合训练,是迁移学习的一种。他们之间的关系如图:

2.分类

通过对任务关系的建模,可以将基于多任务学习的推荐系统分为以下几种:

2.1并行任务建模

在这种MTL就是把这些任务分开来,单独建模,不用考虑它们之间是不是有先后顺序的影响。这种模型一般会把目标函数设成损失的加权和,而且这些权重都是固定的。还有些研究用上了注意力机制,来抓取一些可以在不同任务之间共享的特征。

比较有代表模型有:Rank and Rate (RnR)、Multiple Relational Attention Network (MRAN)。

2.2级联任务建模

这些MTL的建模,它们会考虑任务之间是有先后顺序的,就像是多米诺骨牌一样,一个任务的结果会影响下一个任务。这种模型在电商、广告和金融这些领域挺常见的,它们通常会根据用户的行为模式来设定一个序列,比如“先展示商品,然后用户点击,最后购买”。在这个类别里,有几个做得不错的。

代表模型有:Entire Space Multi-task Model (ESMM) Adaptive Information Transfer Multi-task (AITM)

 ESMM模型结构如下:

 AITM模型结构如下:

                                                                        

2.3辅助任务学习

这种MTL建模技术里,挑一个任务当主角,其他的都算是配角,它们存在的意义就是帮主角提升表现。在好几个任务一起优化的时候,很难做到每个任务都获益。所以,有的MTL技术就是以提高主要任务的性能为目标,哪怕牺牲一些辅助任务的性能。用上整个空间的辅助任务,能在预测主要任务的时候提供更丰富的背景信息。

代表模型有:Multi-gate Mixture-of-Experts (MMoE)、Progressive Layered Extraction (PLE)

3.优点

减少过拟合:在多任务学习框架下,模型通过共享的表示层学习编码更加通用的特征,而不是仅针对单一任务的特征表示。一般来说,神经网络能够从输入数据中提取出有用的特征,这些特征随后将被用于执行特定的任务(如分类或回归等)。如果我们有多个相关的任务,就没有必要重复提取特征,而只需要一次性提取出这些特征,然后将其输入到各个任务专用的模型中进行处理即可。这正是多任务架构的核心思想所在。

提高效率:通过单一模型同时执行多个任务,多任务架构能极大地加快推理过程,对于“效能要求”苛刻的边缘应用场景尤为重要。提升速度和效率的一个常被忽视的好处是,可以通过减少训练和推理阶段的整体计算量来缓解服务器成本的压力

基于正迁移提高效率:在多任务学习中,存在这样一种情况:当将某些任务一同学习时,会导致各个任务的性能都得到提升,这种现象被称为“正迁移”(Positive Transfer)。正迁移的发生源于不同任务之间存在一定的共性和相关性。当模型通过共享表示层同时学习这些相关任务时,任务间的共性知识会在底层得到很好的提炼和内化,从而形成通用的特征表示。

4.使用技巧

1. 整合损失函数

最简单的办法,我们可以整合不同tasks的loss function,然后简单求和。这种方法存在一些不足,比如当模型收敛时,有一些task的表现比较好,而另外一些task的表现却惨不忍睹。其背后的原因是不同的损失函数具有不同的尺度,某些损失函数的尺度较大,从而影响了尺度较小的损失函数发挥作用。这个问题的解决方案是把多任务损失函数“简单求和”替换为“加权求和”。加权可以使得每个损失函数的尺度一致,但也带来了新的问题:加权的超参难以确定。幸运的是,有一篇论文《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》通过“不确定性(uncertainty)”来调整损失函数中的加权超参,使得每个任务中的损失函数具有相似的尺度。该算法的keras版本实现,详见github:https://github.com/yaringal/multi-task-learning-example/blob/master/multi-task-learning-example.ipynb

2, 学习率选择

在神经网络的参数中,learning rate是一个非常重要的参数。在实践过程中,我们发现某一个learnig rate=0.001能够把任务A学习好,而另外一个learning rate=0.1能够把任务B学好。选择较大的learning rate会导致某个任务上出现dying relu;而较小的learning rate会使得某些任务上模型收敛速度过慢。怎么解决这个问题呢?对于不同的task,我们可以采用不同的learning rate。

all_variables = shared_vars + a_vars + b_vars
all_gradients = tf.gradients(loss, all_variables)

shared_subnet_gradients = all_gradients[:len(shared_vars)]
a_gradients = all_gradients[len(shared_vars):len(shared_vars + a_vars)]
b_gradients = all_gradients[len(shared_vars + a_vars):]

shared_subnet_optimizer = tf.train.AdamOptimizer(shared_learning_rate)
a_optimizer = tf.train.AdamOptimizer(a_learning_rate)
b_optimizer = tf.train.AdamOptimizer(b_learning_rate)

train_shared_op = shared_subnet_optimizer.apply_gradients(zip(shared_subnet_gradients, shared_vars))
train_a_op = a_optimizer.apply_gradients(zip(a_gradients, a_vars))
train_b_op = b_optimizer.apply_gradients(zip(b_gradients, b_vars))

train_op = tf.group(train_shared_op, train_a_op, train_b_op)

3. 任务A的评估作为其他任务的特征

当我们构建了一个MTL的神经网络时,该模型对于任务A的估计可以作为任务B的一个特征。在前向传播时,这个过程非常简单,因为模型对于A的估计就是一个tensor,可以简单的将这个tensor作为另一个任务的输入。但是后向传播时,存在着一些不同。因为我们不希望任务B的梯度传给任务A。幸运的是,Tensorflow提供了一个API tf.stop_gradient。当计算梯度时,可以将某些tensor看成是constant常数,而非变量,从而使得其值不受梯度影响。

all_gradients = tf.gradients(loss, all_variables, stop_gradients=stop_tensors)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2257522.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Springboot技术的实验室管理系统【附源码】

基于Springboot技术的实验室管理系统 效果如下: 系统登录页面 实验室信息页面 维修记录页面 轮播图管理页面 公告信息管理页面 知识库页面 实验课程页面 实验室预约页面 研究背景 在科研、教育等领域,实验室是进行实验教学和科学研究的重要场所。随着…

Abaqus断层扫描三维重建插件CT2Model 3D V1.1版本更新

更新说明 Abaqus AbyssFish CT2Model3D V1.1版本更新新增对TIF、TIFF图像文件格式的支持。本插件用户可免费获取升级服务。 插件介绍 插件说明: Abaqus基于CT断层扫描的三维重建插件CT2Model 3D 应用案例: ABAQUS基于CT断层扫描的细观混凝土三维重建…

【开源】A066—基于JavaWeb的农产品直卖平台的设计与实现

🙊作者简介:在校研究生,拥有计算机专业的研究生开发团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看项目链接获取⬇️,记得注明来意哦~🌹 赠送计算机毕业设计600个选题ex…

ChatGPT Pro是什么

ChatGPT Pro 和 ChatGPT Plus 的区别主要体现在功能范围、适用场景和目标用户上。 ChatGPT Plus 功能 • 价格:20美元/月。 • 目标用户:针对个人用户设计。 • 主要特点: • 在高峰期响应速度更快。 • 使用高级模型(如 GPT-4…

增加数据长度——提高频率分辨率

由于运算方式和存储容量的限制,计算机只能处理离散且有限长的数据,故“不得不”将无限长的采样序列在时域截断,再进行后续处理。由数据在时域截断引起失真。 分析余弦序列 x ( n ) cos ⁡ ( ω 0 n ) x(n) \cos(\omega_0 n) x(n)cos(ω0​…

天喻InteKEY加密软件卸载

1 概述 有些小伙伴向我求助,说他们的电脑上被迫安装了天喻InteKEY加密软件,现在所有的office文档、代码等文件都会自动加密,传给别人,都是乱码,无法打开。 如下图所示: 请求我能不能帮他们把这些加密的文…

【报错】新建springboot项目时缺少resource

1.问题描述 在新建springboot项目时缺少resources,刚刚新建时的目录刚好就是去掉涂鸦的resources后的目录 2.解决方法 步骤如下:【文件】--【项目结构】--【模块】--【源】--在main文件夹右击选择新建文件夹并命名为resources--在test文件夹右击选择新建文件夹并命名…

【PlantUML系列】流程图(四)

目录 目录 一、基础用法 1.1 开始和结束 1.2 操作步骤 1.3 条件判断 1.4 并行处理 1.5 循环 1.6 分区 1.7 泳道 一、基础用法 1.1 开始和结束 开始一般使用start关键字;结束一般使用stop/end关键字。基础用法包括: start ... stopstart ...…

计算机网络:传输层、应用层、网络安全、视频/音频/无线网络、下一代因特网

目录 (五)传输层 1.传输层寻址与端口 2.无连接服务与面向连接服务 3. 传输连接的建立与释放 4. UDP 的优点 5. UDP 和 TCP 报文段报头格式 6. TCP 的流量控制 7.TCP 的拥塞控制 8. TCP 传送连接的管理 &#…

【cpp/c++ summary 语法总结】细节(作为参数时) 数组退化

在C语言中,参数传递通常是通过值传递(pass by value)的方式进行的,这意味着当调用函数时,实际参数的值会被复制到对应的形参中。因此,函数内部操作的是这些值的副本,而不是原始变量本身。这种方…

Python生成对抗神经网络GAN预测股票及LSTMs、ARIMA对比分析ETF金融时间序列可视化

全文链接:https://tecdat.cn/?p38528 本文聚焦于利用生成对抗网络(GANs)进行金融时间序列的概率预测。介绍了一种新颖的基于经济学驱动的生成器损失函数,使 GANs 更适用于分类任务并置于监督学习环境中,能给出价格回…

常用环境部署(二十四)——Docker部署开源物联网平台Thingsboard

1、Docker和Docker-compose安装 参考网址如下: CENTOS8.0安装DOCKER&DOCKER-COMPOSE以及常见报错解决_centos8安装docker-compose-CSDN博客 2、 Thingsboard安装 (1)在/home目录下创建docker-compose.yml文件 vim /home/docker-com…

Mind 爱好者周刊 第6期 | 关于假设检验的贝叶斯因子(含R包)、高阶冥想期间的神经现象学、大脑中广泛的 β 网络、视觉和听觉审美具有不同的神经机制……

所有的研究由我的独断和偏见选出,单位仅标注第一单位/通讯单位;本篇为 12.3~12.10 期间我感兴趣的研究摘要;取名创意来自「科技爱好者周刊」 注:相比前几期以认知神经研究为主,本期收录了很多有趣的行为实验&#xff0…

太速科技-488-基于3U VPX的ZYNQ XC7Z100 计算主控板

基于3U VPX的ZYNQ XC7Z100 计算主控板 一、板卡概述 本板卡基于3U VPX结构 使用FPGA XC7Z100 FFG 9000 芯片。产品类似计算机主控板,包含以太网、USB、HDMI、EMMC\M.2存储接口。同时又有自定义的IO扩展,包括高速PCIe、RapidIO,普通LV…

【Devops】Python运维自动化之集合Set

集合Set 集合,简称集。由任意个元素构成的集体。高级语言都实现了这个非常重要的数据结构类型。 Python中,它是可变的、无序的、不重复的元素的集合。 hash表 Python中的集合(set)是基于哈希表(Hash Table&#xff…

x64dbg 安装使用教程

x64dbg的安装与配置 x64dbg官网地址:https://x64dbg.com/#start x64dbg界面介绍 1.反汇编窗口 这个位置显示的是需要分析的程序的反汇编代码。在第一个区域的最左侧例如“7712EAA3”这一列就是内存地址区域,接着“E8 07”就是汇编指令的opcode&#xff…

CH343等第3代USB串口芯片常见问题解答

一、概述 CH343、CH9101、CH9102等系列芯片,是沁恒推出的第三代USB转单串口产品,基于经典版CH340系列芯片进行技术革新,实现USB转高速异步串口,波特率支持最高6Mbps。芯片内部高度集成,外围精简,均提供VIO…

npm安装-详细教程

npm安装教程 第一章 Vue学习入门之 Node.js 的使用 文章目录 npm安装教程 [TOC] 前言一、npm是什么?二、安装、配置环境变量 1.下载并安装NodeJS2.npm配置 前言 随着时代的不断发展,前端学习这门技术也越来越重要,很多人都开启了学习前端…

【Web】2024“国城杯”网络安全挑战大赛题解

目录 Ez_Gallery 法一:shell盲注 法二:反弹shell 法三:响应钩子回显 Easy Jelly 法一:无回显XXE 法二:Jexl表达式RCE signal 法一:SSRF 法二:filterchain RCE Ez_Gallery 用这个bp验证…

【模型对比】ChatGPT vs Kimi vs 文心一言那个更好用?数据详细解析,找出最适合你的AI辅助工具!

在这个人工智能迅猛发展的时代,AI聊天助手已经深入我们的工作与生活。你是否曾在选择使用ChatGPT、Kimi或是百度的文心一言时感到一头雾水?每款AI都有其独特的魅力与优势,那么,究竟哪一款AI聊天助手最适合你呢?本文将带…