多任务学习MTL模型:多目标Loss优化策略

news2024/11/15 9:40:08

前言

之前的文章中多任务学习MTL模型:MMoE、PLE,介绍了针对多任务学习的几种模型,着重网络结构方面的优化,减缓task之间相关性低导致梯度冲突,模型效果差,以及task之间的“跷跷板”问题。

但其实多任务学习还存在另外一些棘手的问题:

1、不同task的loss量级不同,可能会出现loss较大的task主导的现象(loss较大的task,梯度也会较大,导致模型的优化方向很大程度上由该task决定);

2、不同task的学习速度不同,有的慢有的快;

3、不同的loss应该分配怎样的权重?直接平均?如何选出最优的loss权重组合?

Using Uncertainty to Weigh Losses

相关论文:《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene
Geometry and Semantics》

这篇论文指出了多任务学习模型的效果很大程度上由共享的权重决定,但训练这些权重是很困难。由此引出uncertainty的概念,来衡量不同的task的loss,使得可以同时学习不同类型的task。

图片

  1. 其中, α {\alpha} α为可学习参数,论文认为它是对应task建模的uncertainty(不确定性)。

  2. 容易看出,总的loss会惩罚loss大且 α {\alpha} α小的task,因为对于这种task, 1 2 α 2 L {\frac{1}{2\alpha^2}L} 2α21L这一项就会很大,SGD就会将它往小优化;

  3. 它代表着对于loss较大的task,意味着它的uncertainty(不确定性)也较高,为了避免模型往错误的方向“大步迈”,应该以较小的梯度去更新w;相反的,对于loss较小的task,它的uncertainty也就较低,以较大的梯度去更新w;

  4. 同时,这也能避免让较大loss的task主导的问题。

总结:大loss的task给予小权重,小loss的task给予大权重。

注意事项:这个方法由于后面的log项,可能会出现总loss为负的情况。

GradNorm

相关论文:《GradNorm: Gradient Normalization for Adaptive Loss Balancing in
Deep Multitask Networks》

这篇论文提出一个新的方法:梯度正则化gradient normalization (GradNorm),它能自动平衡多task不同的梯度量级,提升多任务学习的效果,减少过拟合。

首先,总loss的定义仍是不同task的loss加权平均:

image-20220104212733900

GradNorm设计了额外的loss来学习不同task loss的权重 w i {w_i} wi,但它不参与网络层的参数的反向梯度更新,目的在于不同task的梯度通过正则化能够变成同样的量级,使不同task可以以接近的速度进行训练:

image-20220104211807540

其中,t代表训练的步数;

W一般是取最后一层共享网络层shared layer的权重;

第i个task的正则化梯度,即loss对W的梯度,然后再做L2-norm:

image-20220104213519661

image-20220104213725845

第i个task的loss(第t步)与初设loss比率,用来代表学习速度:

image-20220104214027162

第i个task的相对学习速度:

image-20220104214059350

注意事项:

1、容易看出不同task的初设loss: L i ( 0 ) {L_i(0)} Li(0),对学习速度的计算影响很大。

如果所有网络层有着稳定的参数初设化,则可以直接使用(第一次的loss);

但如果 L i ( 0 ) {L_i(0)} Li(0)对参数初设化方式很敏感,在多分类中,则可以令 L i ( 0 ) = l o g ( C ) {L_i(0)}=log(C) Li(0)=log(C),C为分类数。

2、论文的流程是在每轮训练中,先通过反向传播进行不同task loss的权重 w i {w_i} wi,再进行网络参数的更新。

Dynamic Weight Average

相关论文:《End-to-End Multi-Task Learning with Attention》

这篇论文仍然致力于寻找平衡多个task训练的方法,提出了一种**Dynamic Weight Average (DWA)**的方法,它比较简单直接,与GradNorm不同,不需要计算梯度,而是只需要task的loss。

image-20220108094337220

  1. λ k {\lambda_k} λk为task的权重,即总的loss仍为所有task的loss加权平均: L t o t a l = ∑ k λ k L k {L_{total}=\sum_k\lambda_kL_k} Ltotal=kλkLk
  2. w_k则为上一轮以及上上轮的loss比率,代表不同task的学习速率
  3. T起到平滑task权重的作用,T越大,不同task的权重分布越均匀。甚至T足够大的话,则 λ k ≈ 1 {\lambda_k} \approx 1 λk1,每个task的权重相等
  4. K则是让所有task的权重加权求和后为K: ∑ k λ k = K {\sum_k\lambda_k=K} kλk=K。因为一般情况下不特殊处理的话,每个task的权重都相等为1,那么所有task加权之后便为K。

Pareto-Eficient

相关论文:《A Pareto-Efficient Algorithm for Multiple Objective Optimization
in E-Commerce Recommendation》

这篇论文对总loss的定义仍然是所有task的loss加权平均,但这个权重是经过正则化(scalarization)的:

image-20220108102209035

image-20220108102311650

不过,论文指出了不同task会有不同的优先级,比如一个task为点击预测,一个task转化预测,那肯定转化预测的task的优先级更高,因此,可以为不同task的权重增加了一个边界条件:

w i ≥ c i ,   0 < c i < 1 ,   ∑ K c i ≤ 1 {w_i \ge c_i},\ 0<c_i<1,\ \sum_Kc_i \le 1 wici, 0<ci<1, Kci1

我们的目标当然是让总loss即 L ( θ ) {L(\theta)} L(θ)最小,常规做法,对 L ( θ ) {L(\theta)} L(θ)求导,然后令其导数等于0,即为下式:

image-20220111205436509

满足这种条件的解法称为Pareto stationary(帕累托平稳)

以上式子可以转化为:(K个 w i ▽ θ L i ( θ ) {w_i\triangledown_\theta L_i(\theta)} wiθLi(θ) 二范数之和的最小值即为0,最小化就是在往0逼近)

image-20220111211139699

论文也给出了最优的task的权重组合解法:

w ^ i = w i − c i {\hat{w}_i = w_i - c_i} w^i=wici,则不等式变成:

image-20220111212541560

然后根据以上理论求出 w ^ {\hat{w}} w^的所有解 w ^ ∗ {\hat{w}}^* w^,但是可能会出现负数的解。

image-20220111212637026

由于上述求出的 w ^ ∗ {\hat{w}}^* w^可能为负,并且为了能够用上的解,最终转化为非负的最小二乘问题:

image-20220111213150907

总结

以上论文都是为了解决不同task的loss量级或者学习速度不同,求出最优的task权重组合;

大致流程都相同:先对模型参数进行反向传播进行更新,再使用各自的算法更新task权重

-------------------------------------------------------- END --------------------------------------------------------

以上的算法实现:github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2086394.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文件包含之session.upload_progress的使用

目录 原理 环境搭建 渗透 结果 一次项目经历复现 原理 session.auto_start顾名思义&#xff0c;如果开启这个选项&#xff0c;则PHP在接收请求的时候会自动初始化Session&#xff0c;不再需要执行session_start()。但默认情况下&#xff0c;也是通常情况下&#xff0c;这…

k8s声明式管理方式(yaml文件实现)

首先在/opt目录下创建 mkdir k8s-yaml cd k8s-yaml/ yaml文件 1.deployment的部署方式 首先 kubectl explain deployment 获取它的类型kind和标签version vim nginx-deploy.yaml apiVersion: apps/v1 #定义api版本的标签 kind: Deployment #定义资源的类型&#xff08;kin…

【数模修炼之旅】10 遗传算法 深度解析(教程+代码)

【数模修炼之旅】10 遗传算法 深度解析&#xff08;教程代码&#xff09; 接下来 C君将会用至少30个小节来为大家深度解析数模领域常用的算法&#xff0c;大家可以关注这个专栏&#xff0c;持续学习哦&#xff0c;对于大家的能力提高会有极大的帮助。 1 遗传算法介绍及应用 …

网络安全面试经验80篇

吉祥知识星球http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247485367&idx1&sn837891059c360ad60db7e9ac980a3321&chksmc0e47eebf793f7fdb8fcd7eed8ce29160cf79ba303b59858ba3a6660c6dac536774afb2a6330#rd 《网安面试指南》http://mp.weixin.qq.com/s…

《JavaEE进阶》----5.<SpringMVC②剩余基本操作(CookieSession)>

Cookie和Session简介。 Spring MVC的请求中 Cookie的设置和两种获取方式 Session的设置和三种获取方式。 三、&#xff08;接上文&#xff09;SpringMVC剩余基本操作 3.2postman请求 3.2.10 获取Cookie和Session 1.理解Cookie 我们知道HTTP协议自身是“无状态”协议。 &qu…

2024.8.28 C++

使用C手动封装一个顺序表&#xff0c;包含成员数组一个&#xff0c;成员变量N个 代码 #include <iostream> //使用C手动封装一个顺序表&#xff0c;包含成员数组一个&#xff0c;成员变量N个 using namespace std;using datatype int; struct Seqlist { private:datat…

flink 实战理解watermark,maxOutOfOrderness,allowedLateness

watermark watermark的作用 就是延迟触发窗口&#xff0c;让乱序到达的元素依然能够落在正确的窗口内。为啥能实现这个效果&#xff0c;一直通过公式更新watermark,如果乱序到的元素就不能更新watermark,相当于就是延迟触发计算操作。触发时间 watermark 大于窗口的最大值allo…

我的易经代码

本人从2000年起&#xff0c;就开始写一款算命软件&#xff0c;第一版用的是powerbuilder。后来改成企业版&#xff0c;名为“始皇预测”&#xff0c;用Java Swing编写&#xff0c;支持五大神数&#xff0c;三式&#xff0c;主要应用还是六爻、四柱、风水&#xff0c;其它如称骨…

2024118读书笔记|《岳阳楼记》——天高地迥,觉宇宙之无穷;兴尽悲来,识盈虚之有数

2024118读书笔记|《岳阳楼记》——天高地迥&#xff0c;觉宇宙之无穷&#xff1b;兴尽悲来&#xff0c;识盈虚之有数 爱莲说陋室铭小石潭记醉翁亭记赤壁赋桃花源记归去来兮辞木兰辞阿房宫赋滕王阁序岳阳楼记 《岳阳楼记》范仲淹&#xff0c;都是背过的古文&#xff0c;挺不错的…

【Qt窗口】—— 工具栏

前情摘要&#xff1a; 工具栏相当于菜单栏中的众多快捷方式&#xff0c;毕竟很多操作都是通过菜单栏来直接访问的&#xff0c;但是可能会查找很长时间&#xff0c;首先就是查找在哪个菜单里面&#xff0c;打开菜单才能进一步操作。而工具栏则是把一些常用的操作都给列举出来&am…

生产者与消费者模型

生产者与消费者模型 生产者&#xff1a;生产数据的线程&#xff0c;这类的线程负责从用户端、客户端接收数据&#xff0c;然后把数据Push到存储中介。 消费者&#xff1a;负责消耗数据的线程&#xff0c;对生产者线程生产的数据进行&#xff08;判断、筛选、使用、响应、存储&…

C++必修:布隆过滤器的提出与实现

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C学习 贝蒂的主页&#xff1a;Betty’s blog 1. 布隆过滤器的引入 在我们注册游戏或者社交账号时&#xff0c;我们可以自己设置…

科学重温柯南TV版:基于B站视频数据分析

麻鸭&#xff0c;四年过去了&#xff0c;失踪人口回归。 第一篇就决定是你了。 看了柯南M27剧场版后&#xff0c;萌生了重温TV版的念头&#xff0c;但是1191集(截止24/8/26)的体量太恐怖了&#xff0c;遂取点巧&#xff0c;综合大V建议(知乎&#xff1b;公众号)和视频网站数据…

基于asp.net的驾校管理系统附源码

这是一个基于asp.net的webform框架开发的BS架构的系统&#xff0c;详情如下&#xff1a; 项目下载链接 链接&#xff1a;https://pan.quark.cn/s/0679e783ef71

【设计模式】创建型模式——抽象工厂模式

抽象工厂模式 1. 模式定义2. 模式结构3. 实现3.1 实现抽象产品接口3.2 定义具体产品3.3 定义抽象工厂接口3.4 定义具体工厂3.5 客户端代码 4. 模式分析4.1 抽象工厂模式退化为工厂方法模式4.2 工厂方法模式退化为简单工厂模式 5. 模式特点5.1 优点5.2 缺点 6. 适用场景6.1 需要…

深入理解OJ编程中的输入输出:11个经典题目详解与技巧分享及stringstream,sort详解

文章目录 1.多组输入计算ab2.给定组数计算ab3.给定组数计算ab&#xff08;如果为0则结束&#xff09;4.计算一些列数的和(第一个数为0时结束)5.计算一些列数的和&#xff08;告诉了有几组&#xff09;6.计算一系列数的和&#xff08;不告知几组和何时结束&#xff0c;每一组第一…

如何评估云服务器提供商可靠性与信誉度

在云计算时代&#xff0c;选择一个可靠和信誉良好的云服务器提供商对于个人用户和企业来说至关重要。以下是评估云服务器提供商可靠性与信誉度的关键指标和方法&#xff1a; 1. 服务水平协议&#xff08;SLA&#xff09;&#xff1a; 可用性承诺&#xff1a; 查看云服务器提供…

服务器内存飙升分析小记

1. 写在最前面 这个繁忙的八月真的是转瞬即逝&#xff0c;我明明感觉似乎好像才八月刚开始&#xff0c;但是其实已经到了八月的尾巴。这个月本来想抽空整理一下学习 AI 模型相关的东西&#xff0c;奈何每天不是在查问题就是在查问题的路上&#xff0c;不是在修 Bug 就是在写 B…

AI Lossless Zoomer v3.1.0.0 — 超实用的AI无损图片放大工具

AI Lossless Zoomer 是一款基于腾讯开源 Real-ESRGAN 算法的 AI 图片无损放大工具&#xff0c;支持多线程和批量处理&#xff0c;具备自定义输出格式和路径等高级设置选项&#xff0c;并允许用户选择不同的 AI 引擎进行图片放大处理。此版本修复了一些小 bug&#xff0c;并增加…