【论文笔记】Dual-Balancing for Multi-Task Learning

news2025/1/9 1:49:06

Abstract

多任务学习(Multi-task learning, MTL)中,任务平衡问题仍然是重要的挑战,损失、梯度尺度的不同,会导致性能的折中。
本文提出Dual-Balancing for Multi-Task Learning (DB-MTL),从损失和梯度两个角度缓解任务均衡问题。

DB-MTL通过对每个任务的损失进行对数变换,保证损失-尺度平衡,通过将所有任务梯度归一化到与最大梯度范数相同的幅度来保证梯度-幅度平衡。

1 Introduction

很多方法被提出,用来动态调整任务权重,可以粗略划分为损失平衡方法、梯度平衡方法。

本文重点关注同时平衡损失级别的损失尺度和梯度级别的梯度幅度,以减轻任务平衡问题。
不同任务的损失尺度和梯度幅度不同,大的一方可以左右模型更新的方向,导致部分任务的表现下降。

本文提出的DB-MTL简单且有效第平衡损失尺度和梯度规模。
① 对每个任务的损失施加对数变换 (logarithm transformation),保证所有任务的损失都有相同的尺度,是非参数的变换。对数变换有利于现有的梯度平衡方法,如图1所示。
② 将所有任务的梯度标准化成与最大梯度范数相同的幅度,这是免训练的,与GradNorm相比,所有梯度的大小都相同。归一化梯度大小对性能起着重要的作用,将其设置为任务中最大梯度范数效果最好,如图4所示。

总结贡献:

  • 提出DB-MTL方法,缓解任务平衡问题的双重平衡方法,包含损失尺度和梯度幅度平衡方法。
  • 大量实验证明DB-MTL在多个基准数据集上实现最先进的性能。
  • 实验结果表明,损失尺度平衡方法有利于现有的梯度平衡方法。

2 Related Works

给定 T T T个任务和每个任务 t t t的训练集 D t \mathcal{D}_t Dt,MTL的目标是在 { D t } t = 1 T \{\mathcal{D}_t\}_{t=1}^T {Dt}t=1T训练一个模型。MTL模型的参数包括两个部分:任务共享参数 θ \theta θ和任务独享参数 { ψ t } t = 1 T \{\psi_t\}_{t=1}^T {ψt}t=1T θ \theta θ占据MTL模型的大部分参数,这对性能至关重要。

ℓ t ( D t ; θ , ψ t ) \ell_t(\mathcal{D}_t;\theta,\psi_t) t(Dt;θ,ψt) ( θ , ψ t ) (\theta,\psi_t) (θ,ψt)下任务 t t t在数据集 D t \mathcal{D}_t Dt的平均损失。目标函数可表示为:
∑ t = 1 T γ t ℓ t ( D t ; θ , ψ t ) (1) \sum_{t=1}^T \gamma_t\ell_t(\mathcal{D}_t;\theta,\psi_t)\tag{1} t=1Tγtt(Dt;θ,ψt)(1)
其中 γ t \gamma_t γt是任务 t t t的任务权重。

Equal weighting (EW)是一种简单的MTL方法,设置所有任务 γ t = 1 \gamma_t=1 γt=1。但是EW会导致任务平衡问题,即某些任务执行不理想。因此后续提出了很多MTL方法,通过在训练过程中动态调整任务权重 { γ t } t = 1 T \{\gamma_t\}_{t=1}^T {γt}t=1T,来提高EW的性能。这可以归类为损失平衡(loss balancing)、梯度平衡(gradient balancing)、混合平衡(hybrid balancing)。

2.1 Loss Balancing Methods

这一类方法,通过不同的衡量方法动态地计算任务权重 { γ t } t = 1 T \{\gamma_t\}_{t=1}^T {γt}t=1T,如同方差不确定性 (homoscedastic uncertainty),学习速率 (learning speed),验证集性能 (validation performance)。不同于上述方法,IMTL-L期望所有任务的权重损失 { γ t ℓ t ( D t ; θ , ψ t ) } t = 1 T \{\gamma_t\ell_t(\mathcal{D}_t;\theta,\psi_t)\}_{t=1}^T {γtt(Dt;θ,ψt)}t=1T是常值,对每一个损失实施变换 e s t ℓ t ( D t ; θ , ψ t ) − s t e^{s_t}\ell_t(\mathcal{D}_t;\theta,\psi_t)-s_t estt(Dt;θ,ψt)st,其中 s t s_t st是第 t t t个任务上的可学习的参数,在每次迭代中通过梯度下降近似求解。

2.2 Gradient Balancing Methods

从梯度的角度,任何共享参数 θ \theta θ的更新取决于所有任务的梯度 { ∇ θ ℓ t ( D t ; θ , ψ t ) } t = 1 T \{\nabla_\theta\ell_t(\mathcal{D}_t;\theta,\psi_t)\}_{t=1}^T {θt(Dt;θ,ψt)}t=1T。梯度平衡方法旨在以不同的方式聚合所有任务梯度。例如MGDA将MTL表述为多目标优化问题,避免某些任务的梯度主导更新方向,目标是找到一个更新方向 d d d,使得所有任务的梯度在该方向上的投影尽可能小;CAGrad通过将聚合梯度约束到平均梯度周围来优化MGDA;MoCo通过引入类动量梯度估计和正则化项来缓解MGDA中的偏差;GradNorm通过学习任务权重来衡量任务梯度,使其具有接近的数量级;如果两个任务的梯度冲突,PCGrad会将一个任务的法平面投影到另一个任务的法平面;无论两个任务是否发生梯度冲突,GradVac都会将梯度对齐;GradDrop随机掩盖一些符号不一致的梯度值,IMTL-G学习任务权重,以确保聚合梯度在每个任务梯度上具有相等的投影;Nash-MTL将聚合梯度制定为纳什均衡。

2.3 Hybrid Balancing Method

Towards impartial multi-task learning, Liu et al.中,发现损失平衡和梯度平衡具有互补性,提出了IMTL-L和IMTL-G相结合的混合平衡法IMTL。

3 Proposed Method

3.1 Scale-Balancing Loss Transformation

不同类型的损失函数会带来不同的损失尺度,导致任务平衡问题。

假设损失规模的先验已知,可以选择 { s t ⋆ } t = 1 T \{s_t^\star\}_{t=1}^T {st}t=1T,使得 { s t ⋆ ℓ t ( D ; θ , ψ t ) } t = 1 T \{s_t^\star\ell_t(\mathcal{D};\theta,\psi_t)\}_{t=1}^T {stt(D;θ,ψt)}t=1T具有相同的尺度,并最小化 ∑ t = 1 T s t ⋆ ℓ t ( D ; θ , ψ t ) \sum_{t=1}^T s_t^\star\ell_t(\mathcal{D};\theta,\psi_t) t=1Tstt(D;θ,ψt)。之前的工作在学习任务权重 { γ t } t = 1 T \{\gamma_t\}_{t=1}^T {γt}t=1T时直接学习 { s t ⋆ } t = 1 T \{s_t^\star\}_{t=1}^T {st}t=1T,但由于训练过程无法获得最优的 { s t ⋆ } t = 1 T \{s_t^\star\}_{t=1}^T {st}t=1T,这种方法会导致结果不是最优。

对数变换(Logarithmic transformation)可以实现所有损失达到相同的尺度,而不需要 { s t ⋆ } t = 1 T \{s_t^\star\}_{t=1}^T {st}t=1T
由于 ∇ θ , ψ t log ⁡ ℓ t ( D ; θ , ψ t ) = ∇ θ , ψ t ℓ t ( D ; θ , ψ t ) ℓ t ( D ; θ , ψ t ) \nabla_{\theta,\psi_t}\log\ell_t(\mathcal{D};\theta,\psi_t)=\frac{\nabla_{\theta,\psi_t}\ell_t(\mathcal{D};\theta,\psi_t)}{\ell_t(\mathcal{D};\theta,\psi_t)} θ,ψtlogt(D;θ,ψt)=t(D;θ,ψt)θ,ψtt(D;θ,ψt)(普通的对log求导),这等价于对调整了尺度的任务损失 ℓ t ( D ; θ , ψ t ) stop-gradient ( ℓ t ( D ; θ , ψ t ) ) \frac{\ell_t(\mathcal{D};\theta,\psi_t)}{\text{stop-gradient}(\ell_t(\mathcal{D};\theta,\psi_t))} stop-gradient(t(D;θ,ψt))t(D;θ,ψt)取梯度,这一项对于任意的任务都有相同的尺度。

Discussion

尽管对数变换可以轻易实现尺度平衡,在MTL中使用得很少。本文对其进行深入研究,并将其集成到现有的梯度平衡方法(PCGrad、GradVac、IMTL-G、CAGrad、Nash-MTL、Aligned-MTL),大幅提升了他们的性能,如图1所示,证明了在MTL中平衡损失规模的有效性。
![[Pasted image 20240807220230.png]]

图1:现有梯度平衡方法+损失尺度平衡方法在NYUv2上的表现。

IMTL-L用一个转换后的损失来处理损失尺度问题: e s t ℓ t ( D t ; θ , ψ t ) − s t e^{s_t}\ell_t(\mathcal{D}_t;\theta,\psi_t)-s_t estt(Dt;θ,ψt)st,其中 s t s_t st是第 t t t个任务上的可学习的参数,在每次迭代中通过梯度下降近似求解。这不能保证每次迭代中所有的损失尺度是相同的,而对数变换却可以。

3.2 Magnitude-Balancing Gradient Normalization

除了任务损失中的尺度问题,任务梯度也存在尺度问题。通过均匀平均所有可能主导最终梯度的任务的更新方向,导致次优的性能。

一个简单的方法是将任务梯度归一化到相同的幅度。
对于任务的梯度,计算一个batch的梯度 ∇ θ log ⁡ ℓ t ( D t ; θ , ψ t ) \nabla_\theta \log\ell_t(\mathcal{D}_t;\theta,\psi_t) θlogt(Dt;θ,ψt)的计算开销很大,通常使用小批量随机梯度下降方法。在第 k k k次迭代中,从 D t \mathcal{D}_t Dt中采样一个小批量 B t , k \mathcal{B}_{t,k} Bt,k(Algorithm 1中的第5步),计算这个小批量的梯度 g t , k = ∇ θ k log ⁡ ℓ t ( B t , k ; θ k , ψ t , k ) g_{t,k}=\nabla_{\theta_k} \log\ell_t(\mathcal{B}_{t,k};\theta_k,\psi_{t,k}) gt,k=θklogt(Bt,k;θk,ψt,k)(Algorithm 1中的第6步)。在动态估计 E B t , k ∼ D t ∇ θ k log ⁡ ℓ t ( B t , k ; θ k , ψ t , k ) \mathbb{E}_{\mathcal{B}_{t,k}\sim\mathcal{D}_t}\nabla_{\theta_k} \log\ell_t(\mathcal{B}_{t,k};\theta_k,\psi_{t,k}) EBt,kDtθklogt(Bt,k;θk,ψt,k)中使用了指数移动平均(Exponential moving average, EMA):
g ^ t , k = β g ^ t , k − 1 + ( 1 − β ) g t , k \hat{g}_{t,k}=\beta\hat{g}_{t,k-1}+(1-\beta)g_{t,k} g^t,k=βg^t,k1+(1β)gt,k
其中 β ∈ ( 0 , 1 ) \beta\in(0,1) β(0,1)控制遗忘率。
获得任务梯度 { g ^ t , k } t = 1 T \{\hat{g}_{t,k}\}_{t=1}^T {g^t,k}t=1T后,进行标准化,使得具有相同的 ℓ 2 \ell_2 2范数,计算聚合梯度为:
g ~ k = α k ∑ t = 1 T g ^ t , k ∣ ∣ g ^ t , k ∣ ∣ 2 (2) \tilde{g}_k=\alpha_k\sum_{t=1}^T\frac{\hat{g}_{t,k}}{||\hat{g}_{t,k}||_2}\tag{2} g~k=αkt=1T∣∣g^t,k2g^t,k(2)
其中 α k \alpha_k αk是控制更新尺度的尺度因子。标准化后,所有任务对更新的方向提供相同的贡献。

α k \alpha_k αk的选择对于缓解任务均衡的问题至关重要。
当某些任务梯度范数较大,其他任务梯度范数较小时,意味着模型 θ k \theta_k θk接近于前者尚未收敛而后者已经收敛的点。这一点在MTL中是不令人满意的,会导致任务平衡问题,因为期望的是所有任务都能实现收敛。因此, α k \alpha_k αk需要足够大来躲避这种不令人满意的点。
当所有任务的梯度范数都很小,表明模型 θ k \theta_k θk对于所有任务都接近令人满意的点了, α k \alpha_k αk应当足够小,使得模型 θ k \theta_k θk可以捕捉到这个最好的点。
因此可选择 α k = max ⁡ 1 ≤ t ≤ T ∣ ∣ g ^ t , k ∣ ∣ 2 \alpha_k=\max_{1\leq t\leq T} ||\hat{g}_{t,k}||_2 αk=max1tT∣∣g^t,k2

图4展示了在NYUv2数据集上,使用不同的策略调整 α k \alpha_k αk的表现区别,实验设置在4.1节。由此可见,最大范数策略表现得更好。
![[Pasted image 20240807231351.png]]

图4:选择 α k \alpha_k αk的不同策略在NYUv2数据集的表现 Δ p \Delta_p Δp

对损失和梯度进行缩放后,任务共享参数由 θ k + 1 = θ k − η g ~ k \theta_{k+1}=\theta_k-\eta\tilde{g}_k θk+1=θkηg~k(Algorithm 1中第10步)来更新。
对于任务独享参数 { ψ t } t = 1 T \{\psi_t\}_{t=1}^T {ψt}t=1T,不同任务之间是独立更新的,因此不必进行梯度缩放。因此任务独享参数由 ψ t , k + 1 = ψ t , k − η ∇ ψ t , k log ⁡ ℓ t ( B t , k ; θ k , ψ t , k ) \psi_{t,k+1}=\psi_{t,k}-\eta\nabla_{\psi_{t,k}}\log\ell_t(\mathcal{B}_{t,k};\theta_k,\psi_{t,k}) ψt,k+1=ψt,kηψt,klogt(Bt,k;θk,ψt,k)(Algorithm 1中第11~13步)更新。
![[Pasted image 20240807231916.png]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2033224.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

成功解决:IDEA导入java项目 或 建包的时候com.不分开 【详细原理解释说明】

我 | 在这里 ⭐ 全栈开发攻城狮、全网10W粉丝、2022博客之星后端领域Top1、专家博主。 🎓擅长 指导毕设 | 论文指导 | 系统开发 | 毕业答辩 | 系统讲解等。已指导60位同学顺利毕业 ✈️个人公众号:热爱技术的小郑。回复 Java全套视频教程 或 前端全套视频…

子串 前缀和 | Java | (hot100) 力扣560. 和为K的子数组

560. 和为K的子数组 暴力法&#xff08;连暴力法都没想出来……&#xff09; class Solution {public int subarraySum(int[] nums, int k) {int count0;int len nums.length;for(int i0; i<len; i) {int sum0;for(int ji; j<len; j) {sumnums[j];if(sum k) {count;}…

C/C++复习 day2(模板,继承,多态)

C/C复习 day2 文章目录 C/C复习 day2前言一、模板1.模板的原理2.非类型模板参数3.模板的特化a. 函数模板的特化b. 类模板的特化1.全特化2.偏特化 4.模板的分离编译 二、继承1.继承的概念2.继承与派生类对象赋值转化3.隐藏1.成员变量的隐藏2. 成员函数的隐藏 4.继承中的友元5.继…

数据结构:栈(含源码)

目录 一、栈的概念和结构 二、栈的实现 2.1 头文件 2.2 各个功能的实现 初始化栈 入栈 出栈 获取栈顶元素和栈中有效个数 判断栈是否为空 栈的销毁 2.3 测试 完整源码 一、栈的概念和结构 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和…

[C++][opencv]基于opencv实现photoshop算法图像剪切

【测试环境】 vs2019 opencv4.8.0 【效果演示】 【核心实现代码】 //图像剪切 //参数&#xff1a;src为源图像&#xff0c; dst为结果图像, rect为剪切区域 //返回值&#xff1a;返回0表示成功&#xff0c;否则返回错误代码 int imageCrop(InputArray src, OutputArray dst,…

遥感影像-语义分割数据集:sar水体数据集详细介绍及训练样本处理流程

原始数据集详情 简介&#xff1a;该数据集由WHU-OPT-SAR数据集整理而来&#xff0c;覆盖面积51448.56公里&#xff0c;分辨率为5米。据我们所知&#xff0c;WHU-OPT-SAR是第一个也是最大的土地利用分类数据集&#xff0c;它融合了高分辨率光学和SAR图像&#xff0c;并进行了充…

Chromium编译指南2024 -Android篇:安装其他常用软件(三)

1.引言 在前面的章节中&#xff0c;我们详细讲解了编译 Chromium for Android 所需的系统和硬件要求&#xff0c;并介绍了如何配置开发环境&#xff0c;包括更改软件源和安装基本依赖。在完成这些基础配置之后&#xff0c;为了进一步提升开发和编译效率&#xff0c;您可能还需…

【Hot100】LeetCode—438. 找到字符串中所有字母异位词

目录 1- 思路哈希表 滑动窗口 2- 实现⭐438. 找到字符串中所有字母异位词——题解思路 3- ACM 实现 原题链接&#xff1a;438. 找到字符串中所有字母异位词 1- 思路 哈希表 滑动窗口 思路 哈希表&#xff1a;通过数组维护一个哈希表滑动窗口&#xff1a;通过控制数组的下标…

为何说本届巴黎奥运会中国金牌榜应排列第一?

为何说本届巴黎奥运会中国金牌榜应排列第一&#xff1f; 在奥运会上&#xff0c;金牌榜的排名一直是各国关注的焦点。然而&#xff0c;在历届奥运会中&#xff0c;关于金牌榜的统计方法和排名标准却存在一定的争议。尤其在中美两国之间&#xff0c;金牌榜的排名往往成为双方媒体…

制作好的excel报表设置打开密码或忘记密码怎么办?

excel工作表经常用来做数据统计、工资、报表等的文件格式&#xff0c;这些类型的文件都是很重要的数据资料&#xff0c;为此做这些数据的朋友们都会给他设置一个打开密码&#xff0c;不让其他人随便打开。但随着时间的流逝&#xff0c;我们做的数据报表越来越多了&#xff0c;做…

transformer(李宏毅老师系列)

自学参考&#xff1a; Transformer:Attention Is All You Need Transformer论文逐段精读 视频课 课件资料 笔记 一、引入 seq2seq&#xff1a;输入一个序列的向量作为input&#xff0c;output的长度由机器自己决定seq2seq model应用: 语音辨识 输入是声音讯号的一串vector 输出…

提高清晰度的全彩LED显示屏的关键要素

全彩LED显示屏作为现代广告宣传和信息传播的主要媒介&#xff0c;其清晰度在很大程度上决定了观众的视觉体验和信息传达的效果。随着人们对高清显示需求的不断提升&#xff0c;全彩LED显示屏也在向更高清、更细腻的显示效果迈进。那么&#xff0c;如何进一步提升全彩LED显示屏的…

6数字基石:掌握计算机语言、多媒体与系统工程

计算机语言 计算机语言是指用于人与计算机之间交流的一种语言&#xff0c;是人与计算机之间传递信息的媒介。计算机语言主要由一套指令组成&#xff0c;而这一种指令一般包括表达式、流程控制和集合三大部分内容。 表达式又包含变量、常量、字面量和运算符。 流程控制有分支…

善用 AI ,优化项目,保姆级简历写作指南第七弹

大家好&#xff0c;我是程序员鱼皮。做知识分享这些年来&#xff0c;我看过太多简历、也帮忙修改过很多的简历&#xff0c;发现很多同学是完全不会写简历的、会犯很多常见的问题&#xff0c;不能把自己的优势充分展示出来&#xff0c;导致错失了很多面试机会&#xff0c;实在是…

如何将TRIZ的“最终理想解”应用到机器人电机控制设计中?

TRIZ理论&#xff0c;作为一套系统的创新方法论&#xff0c;旨在帮助设计师和工程师突破思维惯性&#xff0c;解决复杂的技术难题。其核心思想之一便是“最终理想解”&#xff0c;它如同一盏明灯&#xff0c;指引着我们在技术创新的道路上不断前行。最终理想解追求的是产品或技…

“听到“温度 - 科学家发现人类感知的新层次

雷克曼大学&#xff08;IDC Herzliya&#xff09;伊夫切尔大脑、认知与技术研究所&#xff08;BCT Institute&#xff09;的研究人员发现了一种在很大程度上被忽视的感知能力&#xff0c;他们利用机器学习揭示了跨模态感知–不同感官模态之间的相互作用–的动态。在最近的一项研…

【HarmonyOS NEXT星河版开发学习】小型测试案例06-小红书卡片

个人主页→VON 收录专栏→鸿蒙开发小型案例总结​​​​​ 基础语法部分会发布于github 和 gitee上面&#xff08;暂未发布&#xff09; 前言 在鸿蒙&#xff08;HarmonyOS&#xff09;开发中&#xff0c;自适应伸缩是指应用程序能够根据不同设备的屏幕尺寸、分辨率和形态&…

2-63 基于matlab的GMPHD滤波器算法

基于matlab的GMPHD滤波器算法&#xff08;1&#xff09;本次仿真采用线性CV模型&#xff1b;&#xff08;2&#xff09;观测模型为线性条件下&#xff0c;观测值为X&#xff0c;Y轴坐标&#xff1b;&#xff08;3&#xff09;验证GMPHD算法对多目标跟踪的有效性&#xff1b;输出…

对于产品设计方面来说,3D 技术的应用有哪些优势?

3D技术在产品设计方面提供了许多优势&#xff0c;主要体现在以下几个方面&#xff1a; 1、可视化&#xff1a;设计师利用3D技术创建产品三维模型&#xff0c;使得产品在设计阶段就能被可视化&#xff0c;帮助团队更好地理解产品的外观和功能。 2、精确性&#xff1a;3D模型可…

人人都能搞定的大模型原理 - 神经网络

人工智能的发展起步于1950年&#xff0c;期间经历了各种里程碑和变革&#xff0c;与此相关的神经网络技术也从最初的单层感知到复杂的层级和卷积神经网络一路创新和变革&#xff0c;不断推动人工智能领域的发展&#xff0c;直到 2022 年 ChatGPT 的问世&#xff0c;彻底引爆了…