MTN模型LOSS均衡相关论文解读

news2024/10/5 21:22:29

一、综述

MTN模型主要用于两个方面,1.将多个模型合为一个显著降低车载芯片负载。2.将多个任务模型合为一个,有助于不同模型在共享层的特征可以进行互补,提高模型泛化性能的同时,也有可能提高指标。传统的方法是直接不同任务loss相加或者人为设置权重,这样很费时,也很难找到最优解。接下来的论文将会为大家介绍一些更优秀的MTN方法。

二、依据任务不确定性加权多任务损失

论文地址: https://arxiv.org/pdf/1705.07115.pdf

 文章通过对损失函数求最大似然估计,同时引入不同任务的不确定性(可以理解为噪声),最大似然估计的推理结果如下:

1.两个任务都为回归任务

 2.一个任务为回归任务,一个任务为分类任务

 可以看到损失是由不确定性估计的倒数来加权的,后面的log(不确定性)是为了防止不确定性变得太大(类似于正则项)。当模型不确定性变小后,任务权重会增大,造成无效学习,所以论文里使用的(annealing the lr with a power law) 不会翻译。。。。。同时论文里也表明这个不确定性的初始化很鲁棒,都可以收敛的很好。

代码:

log_vars = nn.Parameter(torch.zeros((2)))

def criterion(y_pred, y_true, log_vars):
  loss = 0
  for i in range(len(y_pred)):
    weight = torch.exp(-log_vars[i])
    diff = (y_pred[i]-y_true[i])**2.
    loss += torch.sum(weight * diff + log_vars[i], -1)
return torch.mean(loss)

其中diff表示一个任务的loss,log_vars是可学习的参数。权重为weight = torch.exp(-log_vars[i]),后面加个log_vars[i]是一个惩罚项,防止任务不确定性变得太大,导致权重很小不更新。大家跟论文里对比一下会发现,模型学习参数log_vars[i] = log(不确定性**2),使用log是为了让其更加平滑稳定,便于学习。

在实际应用中,训mtn有的时候loss会变负,就是因为log_vars有可能为负的,把loss带成负的了。

所以后续有论文对其进行了改进,论文地址:https://arxiv.org/pdf/1805.06334.pdf

 将正则项变为log(1+log_vars**2),这样正则项就不会为负,同时还能起到正则化效果。具体代码如下:

class AutomaticWeightedLoss(nn.Module):
    def __init__(self, num=2):
        super(AutomaticWeightedLoss, self).__init__()
        params = torch.ones(num, requires_grad=True)
        self.params = torch.nn.Parameter(params)

    def forward(self, *x):
        loss_sum = 0
        for i, loss in enumerate(x):
            loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
        return loss_sum

Reference 

多任务权重自动学习论文介绍和代码实现 - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/514565.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

谷歌打响全面反击战!官宣AI重构搜索、新模型比肩GPT-4,朝着ChatGPT微软开炮

明敏 丰色 发自 凹非寺 量子位 | 公众号 QbitAI 万众瞩目,谷歌的反击来了。 现在,谷歌搜索终于要加入AI对话功能了,排队通道已经开放。 当然这还只是第一步。 大的还在后面: 全新大语言模型PaLM 2正式亮相,谷歌声称…

vue解决跨域的几种办法

当我们遇到请求后台接口遇到 Access-Control-Allow-Origin 时,那说明跨域了。 跨域是因为浏览器的同源策略所导致,同源策略(Same origin policy)是一种约定,它是浏览器最核心也最基本的安全功能,同源是指&…

【C#】本地下载附件(Excel模板)

系列文章 【C#】单号生成器(编号规则、固定字符、流水号、产生业务单号) 本文链接:https://blog.csdn.net/youcheng_ge/article/details/129129787 【C#】日期范围生成器(开始日期、结束日期) 本文链接:h…

腾讯云发布金融国产化战略、《腾讯云银行核心系统分布式转型白皮书》

5月11日,在腾讯金融云国产化战略峰会上,腾讯云发布金融国产化战略:腾讯云将持续加大技术投入打造新一代国产化精品产品,并依托产品构建全栈领先的国产数字化基座。同时,腾讯云还将携手伙伴,共同构建国产数字…

PD虚拟机增加CentOS虚拟机磁盘空间

mac环境下安装了PD虚拟机,近期由于需求需要,扩容了其中一台CentOS7的磁盘空间。 做以下记录: 1.PD虚拟机操作: 2. CentOS7内部操作: 2.1 lsblk -f 查看 2.2 fdisk -f 查看,物理磁盘已增加容量 2.3 fdis…

leetcode 1035. Uncrossed Lines(不交叉的线)

把数组nums1和nums2的元素排成2行, 从左到右把相同的元素连线,但是任意2条线不能交叉。 比如Example1中,可以连2个4,也可以连2个2,但是不能同时连,因为会交叉。 找出最多的连线数。 思路: 一开…

Leetcode2383. 赢得比赛需要的最少训练时长

Every day a Leetcode 题目来源:2383. 赢得比赛需要的最少训练时长 解法1:模拟 可以分开考虑在比赛开始前,需要最少增加的精力和经验数量。 每次遇到一个对手,当前精力值都需要严格大于当前对手,否则需要增加精力值…

设计测试用例(万能思路 + 六种设计用例方法)(详细 + 图解 + 实例)

目录 一、设计测试用例的万能思路 二、设计用例的方法 1. 等价类 2. 边界值 3. 判定表法 4. 正交法 5. 场景设计法 6. 错误猜测法 一、设计测试用例的万能思路 针对某个物品/功能进行测试。 万能思路:功能测设 界面测试 性能测试 兼容性测试 易用性测试…

NASM 编译器 - 产生机器码“66”,导致无法正确打印

【问题描述】 代码hello-DOS.asm,实现功能:打印“hello world” ; hello-DOS.asm - single-segment, 16-bit "hello world" program ; ; assemble with "nasm -f bin -o hi.com hello-DOS.asm" [BITS 32]org 0x100 ; .com…

el-table多级嵌套列表,菜单使用el-switch代替

需求:根据el-table实现多级菜单复选,并且只要是菜单就不再有复选框,也没有全选按钮,一级菜单使用el-switch代替原有的列复选框,子级如果全部选中,那么父级的el-switch也会被选中,如下图&#xf…

2023年微单相机市场电商数据分析(京东数据查询分析)

5月10日,尼康发布了Z8微单相机,首发价格27999元。规格、性能等都可以看到官方的详细讲解。不过从目前业内人士以及数码爱好者的评价来看,Z8的配置匹配27999元的价格是比较有优势的。 并且有很多人表示,Z8一经推出很有可能会对自身…

6. N 字形变换

将一个给定字符串 s 根据给定的行数 numRows ,以从上往下、从左到右进行 Z 字形排列。 比如输入字符串为 "PAYPALISHIRING" 行数为 3 时,排列如下: P A H NA P L S I I GY I R 之后,你的输出需要从左往右逐…

mysql查询列添加序号

添加序号查询结果 # 每次值1 # 值从0开始 SELECT (i:i1) AS 序号,user.* FROM user, (SELECT i:0) AS itable;

【Java多线程编程】解决线程的不安全问题

前言: 当我们进行多线程编程时候,多个线程抢占系统资源就会造成程序运行后达不到想要的需求。我们可以通过 synchronized 关键字对某个代码块或操作进行加锁。这样就能达到多个线程安全的执行,因此我把如何使用 synchronized 进行加锁的操作…

ChatGPT插件推荐,效率提升100倍!

在浏览器上使用ChatGPT时,借助一些插件可以帮助我们更便捷的获取消息,比如: 在搜索引擎搜索东西的同时和ChatGPT对话; 同一个问题同时向ChatGPT、newBing、Claude 等多个模型提问获取结果; 让ChatGPT可以联网获取最新…

实时聊天如何做,让客户眼前一亮(一)

网站上的实时聊天功能应该非常有用,因为它允许客户支持立即帮助用户。在线实时聊天可以快速轻松地访问客户服务部门,而它也代表着企业的门面。 让我们讨论一下如何利用SaleSmartly(ss客服)在网站中的实时聊天视图如何提供出色的实…

Yolov8改进:小目标到大目标一网打尽,轻骨干重Neck的轻量级目标检测器GiraffeDet

1.GiraffeDet介绍 论文:https://arxiv.org/abs/2202.04256 🏆🏆🏆🏆🏆🏆Yolov8魔术师🏆🏆🏆🏆🏆🏆 ✨✨✨魔改网络、复现前沿论文,组合优化创新 🚀🚀🚀小目标、遮挡物、难样本性能提升 🍉🍉🍉定期更新不同数据集涨点情况 本文是阿里巴…

【Autoware】Open Planner代码分析

目录 包结构op_global_plannerop_global_planner_core.cpp中代码的主要逻辑 op_local_plannerop_trajectory_generatorop_behavior_selectorop_common_paramsop_motion_predictorop_trajectory_evaluator 本篇主要对Open Planner的代码进行分析,主要包括op_global_p…

FSS对象存储挂载到windows云服务器操作方法

FSS对象存储可以挂载到云主机中用于存储视频、备份等不需要 经常读写的大文件。不适合存放数据库等对IO需求较高、经常读写的场景。 1、远程登陆服务器,打开控制面板,然后点击“打开或关闭windows功能”。 windows2008系统: 选择“功能”-- …

图可视化工具Gephi使用教程

图可视化工具Gephi使用教程 操作界面介绍在Gephi界面完成图的绘制键盘输入导入CSV文件直接在概览界面鼠标点击创建自己创建一个红楼梦关系网络图用一个Web of Science上的数据创建一个有向关系图 静态随机数据使用动态数据的使用Gephi的可视化处理节点移动节点放大&缩小单个…