浅谈多任务学习

news2025/1/15 22:50:00

目录

一、前言及定义

二、多任务学习(MTL)的两种方法 

2.1 参数的硬共享机制(hard parameter sharing)

2.2 参数的软共享机制(soft parameter sharing)

三、多任务学习模型

3.1 MT-DNN

3.2 ERNIE 2.0

四、多任务学习与其他学习算法的关系


一、前言及定义

多任务学习(Multi-task learning)是和单任务学习(single-task learning)相对的一种机器学习方法。在机器学习领域,标准的算法理论是一次学习一个任务,也就是系统的输出为实数的情况。复杂的学习问题先被分解成理论上独立的子问题,然后分别对每个子问题进行学习,最后通过对子问题学习结果的组合建立复杂问题的数学模型。多任务学习是一种联合学习,多个任务并行学习,结果相互影响。

  用大家经常使用的school data做个简单的对比,school data是用来预测学生成绩的回归问题的数据集,总共有139个中学的15362个学生,其中每一个中学都可以看作是一个预测任务。单任务学习就是忽略任务之间可能存在的关系分别学习139个回归函数进行分数的预测,或者直接将139个学校的所有数据放到一起学习一个回归函数进行预测。而多任务学习则看重 任务之间的联系,通过联合学习,同时对139个任务学习不同的回归函数,既考虑到了任务之间的差别,又考虑到任务之间的联系,这也是多任务学习最重要的思想之一。

  

如果用含一个隐含层的神经网络来表示学习一个任务,单任务学习和多任务学习可以表示成下图所示:

二、多任务学习(MTL)的两种方法 

2.1 参数的硬共享机制(hard parameter sharing)

 

参数的硬共享机制是神经网络的多任务学习中最常见的一种方式,一般来讲,它可以应用到所有任务的所有隐层上,而保留任务相关的输出层。如上图所示,前几层dnn为各个任务共享,后面分离出不同任务的layers。这种方法有效降低了过拟合的风险: 模型同时学习的任务数越多,模型在共享层就要学到一个通用的嵌入式表达使得每个任务都表现较好,从而降低过拟合的风险。 直观来将,这一点是非常有意义的。越多任务同时学习,我们的模型就能捕捉到越多任务的同一个表示,从而导致在我们原始任务上的过拟合风险越小。

2.2 参数的软共享机制(soft parameter sharing)

在这种方法下,每个任务都有自己的模型,有自己的参数,但是对不同模型之间的参数是有限制的,不同模型的参数之间必须相似,由此会有个distance描述参数之间的相似度,会作为额外的任务加入到模型的学习中,类似正则化项。

三、多任务学习模型

3.1 MT-DNN

微软提出的 MT-DNN模型 (Multi-task Deep Neural Network)是一个简单有效的尝试。MT-DNN的模型结构如下图所示,模型主要包含两个部分,分别是多任务的共享编码层(与BERT一致)以及任务相关的输出层。MT-DNN模型考虑了四种不同类型的语言理解类任务,分别是单句文本分类(如CoLA、SST-2)、句对文本分类(如RTE、MNLI、MRPC)、文本相似度(回归问题)以及相关性排序(排序问题)。不同类型的任务对应不同的输出层结构与参数。模型输入的构造方式与BERT基本一致,即“[CLS]文本1 [SEP]文本2 [SEP]”的形式。

模型的训练过程分为两个阶段,首先对多任务共享的编码层进行预训练,方法与BERT模型一致;然后利用各个任务的标注数据以及相应的损失函数进行有监督的多任务学习。与T5模型类似,经过多任务学习的MT-DNN模型可以在特定任务上进一步精调,通常能够取得更好的效果。

3.2 ERNIE 2.0

除了利用相关NLP任务的有限标注数据,还可以从无标注以及弱标注数据中抽象出一系列任务联合学习,以进一步提升预训练模型的能力。百度的研究人员在ERNIE模型的基础之上做了改进,分别从词法、句法及语义层面构造了更加丰富的预训练任务,并通过连续多任务学习(Continual Multi-task Learning)的方式进行增量式预训练,从而得到了 ERNIE 2.0模型 。ERNIE 2.0模型框架如下图所示。
在模型的输入层,除了常用的词向量、块向量和位置向量,ERNIE2.0使用了一个额外的任务向量(Task embedding)。每一个预训练任务对应一个独立的任务编码(1,2,···)并被转化为连续向量表示,在训练过程中更新。使用任务向量是多任务学习中的常用手段,尤其是在任务较多的情况下。这与T5、GPT-3等生成模型中使用的任务提示(Prompt)思想是类似的。模型的输出层分别对应以下预训练任务。
1.词法相关预训练任务
  • ERNIE模型原有的单词、实体、短语掩码模型;
  • 单词的大写(Capitalization)预测;
  • 单词--文档关系预测(预测输入文本块中的词是否出现在同一文档的其他文本块)。
2.语法相关预训练任务
  • 句子重排序:对于随机打乱的文本块,恢复其原始顺序;
  • 句子距离预测:判断输入的两个句子是来自同一文档的两个相邻句子,或是同一文档的两个不相邻句子,或来自不同文档。因此是一个多分类问题。
3.语义相关预训练任务
  • 篇章关系(Discourse relation)预测:对句对间的修辞关系分类。这里用到了由无监督方法构建的篇章关系数据集
  • 信息检索相关性(IR Relevance)。这里需要用到搜索引擎的查询日志:取搜索引擎的查询与文档的标题作为模型的输入句对,如果该文档没有出现在搜索结果中,则认为两者不相关。否则,根据用户是否点击进一步分为强相关与弱相关。
关于模型的训练过程,ERNIE 2.0采用了连续多任务学习的方式,在训练过程中逐渐增加任务数量并进行多任务学习。在维持整体迭代次 数不变的条件下,自动为每个任务分配其在各个阶段多任务学习中的迭代次数。实验结果表明,这种训练方式既可以避免连续学习 (Continual learning)的知识遗忘问题,也能够使各个任务得到更有效的训练。ERNIE 2.0在中、英文各项任务上都取得了出色的表现,同时为预训练任务的设计及多任务学习的机制带来了很多启发。

四、多任务学习与其他学习算法的关系

  1. transfer learning:定义一个源域一个目标域,从源域学习,然后把学习的知识信息迁移到目标域中,从而提升目标域的泛化效果。迁移学习一个非常经典的案例就是图像处理中的风格迁移

  2. multi-task:训练模型的时候目标是多个相关目标共享一个表征,比如人的特征学习,一个人,既可以从年轻人和老人这方面分类,也可以从男人女人这方面分类,这两个目标联合起来学习人的特征模型,可以学习出来一个共同特征,适用于这两种分类结果,这就是多任务学习

  3. multi-label:打多个标签,或者说进行多种分类,还是拿人举例啊,一个人,他可以被打上标签{青年,男性,画家}这些标签。如果还有一个人他也是青年男性,但不是画家,那就只能打上标签{青年,男性}。它和多任务学习不一样,它的目标不是学习出一个共同的表示,而是多标签

  4. multi-class:多分类问题,可选类别有多个但是结果只能分到一类中,比如一个人他是孩子、少年、中年人还是老人

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/178598.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数学建模学习笔记(9)多元线性回归分析(非常详细)

多元线性回归分析1.回归分析的地位、任务和分类2.数据的分类3.对线性的理解、系数的解释和内生性4.取对数预处理、虚拟变量和交互效应5.使用Stata进行多元线性回归分析6.异方差7.多重共线性8.逐步回归法1.回归分析的地位、任务和分类 回归分析的地位:数据分析中最基…

cclow 面试心得

开源ccflow学习的一些心得目录概述需求:设计思路实现思路分析1.心得参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for change,challenge …

JavaEE-文件和IO(一)

目录一、文件1.1 认识文件1.2 树型结构组织和目录1.3 文件路径二、Java中操作文件2.1 文件系统相关的操作一、文件 1.1 认识文件 平时说的文件一般都是指存储再硬盘上的普通文件,形如txt,jpg,MP4,rar等这些文件都可以认为是普通…

Java集合常见面试题(四)

Map 接口 HashMap 的底层实现 JDK1.8 之前 JDK1.8 之前 HashMap 底层是 数组和链表 结合在一起使用也就是 链表散列,数组是 HashMap 的主体,链表则是主要为了解决哈希冲突而存在的。 HashMap 通过 key 的 hashcode 经过扰动函数(hash函数&…

JAVA基础知识08集合基础

目录 1. 集合 1.1 什么是集合? 1.2 ArrayList 1.2.1 ArrayList 长度可变原理 1.2.2 集合和数组的使用选择 1.2.3 ArrayList 集合常用成员方法 1. 集合 1.1 什么是集合? 集合是一种容器,用来装数据的,类似于数组。 其长度可…

线段树的懒标记与应用

目录 一、前言 二、Lazy-tag技术 1、update() 中的lazy-tag 三、例题 1、区间修改、区间查询(lanqiaoOJ 1133) 一、前言 本文主要讲了线段树的Lazy-tag技术和一道例题,建议自己要多练习线段树的题目。 二、Lazy-tag技术 背景&#xf…

水面漂浮物垃圾识别检测系统 YOlOv7

水面漂浮物垃圾识别检测系统通过PythonYOLOv7网络模型,实现对水面漂浮物以及生活各种垃圾等全天候24小时不间断智能化检测。Python是一种由Guido van Rossum开发的通用编程语言,它很快就变得非常流行,主要是因为它的简单性和代码可读性。它使…

Linux- 系统随你玩之--文本处理三剑客-带头一哥-awk

文章目录1、awk概述2、awk原理2.1、 awk 工作原理2.2、 与sed工作原理比较2.3、 awk与sed的区别3、使用方法及原理3.1、格式如下:3.2、 匹配规则3.3、 参数说明3.4、处理规则与流程控制3.5、 常用 awk 内置变量3.6、 awk 正则表达式解释4、操作实例4.1、 准备工作4.…

(十七)抽象队列同步器AQS

AQSAbstractQueuedSynchronizer抽象同步队列简称AQS,它是实现同步器的基础组件,并发包中锁的底层就是使用AQS实现。类图如下,AbstractQueuedLongSynchronizer与AbstractQueuedSynchronizer结构一模一样,只是AbstractQueuedSynchro…

Springboot+java师生交流答疑作业系统

,本系统拥有学生,教师,管理员三个角色,学生可以注册登陆系统,查看新闻,查看教学,在线提问答疑,提交作业,发布交流,留言反馈等功能,教师可以发布教…

恶意代码分析实战 14 反虚拟机技术

14.1 Lab17-01 题目 这个恶意代码使用了什么反虚拟机技术? 恶意代码用存在漏洞的x86指令来确定自己是否运行在虚拟机中。 如果你有一个商业版本IDAPro,运行第17章中代码清单17-4所示的IDAPython脚本(提供如jindAniM.py)&#…

spring boot前后端交互之数据格式转换

在前后端分离开发的项目种,前端获取数据的方式基本都是通过Ajax。请求方法也有所不同,常见的有POST,GET,PUT,DELETE等。甚至连请求的数据类型都不一样,x-www-form-urlencodeed,form-data,json等。 那么在前后端交互过程中,具体的数据该如何接…

ESP32设备驱动-8x8LED点阵驱动(基于Max7219+SPI)

8x8LED点阵驱动(基于Max7219+SPI) 1、Max7219介绍 MAX7219/MAX7221是紧凑型串行输入/输出共阴极显示驱动器,可将微处理器(Ps)连接到多达8位的7段数字LED显示器、条形图显示器或64个独立LED。片上包括一个 BCD 代码 B 解码器、多路扫描电路、段和数字驱动器,以及存储每个数字…

通信电子、嵌入式类面试题刷题计划04

文章目录036——看门狗电路的作用是什么?【社招】037——你了解CAN总线协议吗?说一说你的理解【社招】038——锁存器、触发器、寄存器三者的区别?【校招】039——D触发器和D锁存器的区别是什么?【校招】040——三极管和MOS管的区别…

Cadence PCB仿真使用Allegro PCB SI生成单网络EMI报告Single Net EMI Report及报告导读图文教程

🏡《Cadence 开发合集目录》   🏡《Cadence PCB 仿真宝典目录》 目录 1,概述2,生成报告3,报告导读4,总结1,概述 单网络EMI报告是值将差分模式下的网络视为单个网络,分析来自时钟上升沿的辐射影响。本文简单介绍使用Allegro PCB SI生成单网络EMI报告的方法,及Singl…

搜索引擎位置跟踪应用SerpBear

什么是 SerpBear ? SerpBear 是一款开源搜索引擎位置跟踪应用程序。它允许你跟踪你的网站在谷歌中的关键词位置,并得到他们的位置通知。 软件特点: 无限关键词:添加无限域名和无限关键词以跟踪其 SERP电子邮件通知:每天/每周/每…

车载以太网简介

车载以太网简介 基本概念 传统车载网络 LIN:用于通信速率低的场景,比如车窗、座椅等。CAN:目前车载网络首先,低成本高可靠。FlexRay :具备故障容错的车载总线系统。MOST:内置流媒体数据信道,…

2023年企业信息安全缺陷和解决方案,防止职员外泄信息

随着网络的发展和普及,信息安全与每个人息息相关,包含方方面。每个人既是独立个体又必须和社会交换资源。这就需要把控一个尺度。 要了解信息安全,首先需要对信息有个大体了解。从拥有者和使用者分类分为,个人,企业&a…

恶意代码分析实战 11 恶意代码的网络特征

11.1 Lab14-01 问题 恶意代码使用了哪些网络库?它们的优势是什么? 使用WireShark进行动态分析。 使用另外的机器进行分析对比可知,User-Agent不是硬编码。 请求的URL值得注意。 回答:使用了URLDownloadToCacheFileA函数&#…

JavaEE多线程-定时器

目录一、定时器1.1 什么是定时器?1.2 定时器的构成二、简单实现定时器一、定时器 1.1 什么是定时器? 定时器是多线程编码中的一个重要组件,它就好比一个闹钟,例如我们想去坐车,但是不想现在去坐车,想8:30去坐车,于是我们订了一个8点钟的闹钟,也就是说定…