【笔记整理】元学习笔记
文章目录
- 【笔记整理】元学习笔记
- 一、元学习基础概念
- 1、概述(“多任务,推理,快速学习”)
- 1)Meta-learning(“学习如何学习” + “老千层饼”)
- 2)Transfer learning(”当下先验”)
- 3)Multi-task learning(”In 迁移/meta“)
- 2、分类 & 应用(“优化 + 度量 + 循环”)
- 1)优化器元学习(“相似任务,公用模型,特定任务”)
- 2)度量元学习
- 3)循环模型元学习
- 4)其他方面的应用
- 3、比较(“对象 + 采样 + 目的”)
- 1)机器学习 vs 元学习("完成一个任务" vs "掌握学习能力")
- 2)元学习 vs 迁移学习 ("任务空间 vs 任务迁移" + "关注训练集 vs 关注测试集")
- 3)元学习 vs 交叉验证
- 4)元学习 vs 增量学习(“迅速学习能力 vs 类别增量学习”)
- 5)元学习 vs 传统机器学习(★★★★★)
- a)元学习术语
- b)本质差异
- c)小总结
- 二、元学习具体算法
- 1、MAML算法的推导流程
- 1)算法介绍
- 2)MAML推导流程(“meta梯度,一阶近似 FOMAML“)
- 2、MAML与Reptile的比较
- 3、与迁移学习的比较
- 4、实验分析(”迁移 vs 元学习“)
- 1)回归问题
- 2)分类问题
- 3)MAML代码实现
参考
- 元学习——MAML论文详细解读
- 迁移学习概述(Transfer Learning)
- 一文入门元学习(Meta-Learning)
- 基于度量的元学习和基于优化的元学习 - 知乎 (zhihu.com)
- 元学习(Meta Learning)与迁移学习(Transfer Learning)的区别联系是什么? - 许铁-巡洋舰科技的回答 - 知乎
一、元学习基础概念
1、概述(“多任务,推理,快速学习”)
元学习目的:探索针对多个任务有效的学习策略。
1)Meta-learning(“学习如何学习” + “老千层饼”)
学习如何学习 = 学习泛化能力 = 关注测试集 (不局限于单个任务中的训练集)
meta learning 俗称元学习,目标是 learn to learn,即学会如何学习。听起来有点绕,大白话解释就是通过之前任务的学习使得模型具备一些先验知识或学习技巧,从而在面对新任务的学习时,不至于一无所知。这更接近于人的学习过程,我们人在过去的经历中,会不断地积累学习经验,使自己的知识积累变得越来越丰富,所以在面对新问题的时候,并不是一无所知的,可以自动借鉴之前相似问题的经验来解决新问题。所以元学习也被称为是机器实现通用人工智能的关键技术。
meta-learning 学习的对象是 Tasks,而不是 Samples 样本点,因为 meta-learning 最终要解决的问题是在新的 task 上可以更好的学习,所以要迁移之前 task 上的学习经验。那么在训练阶段,输入的就是不同的 tasks,如下图所示,所有的 task 都是五分类任务,每个 task 仍然有训练集和测试集,训练集是 5 类不同的图片,该 task 的测试集是这 5 类中没有出现过的样本。不同的 task 对应的 5 个类别是不一样的,那么在若干个这样的 task 上训练之后,需要在一个新的任务上进行 meta 的测试/推理,测试任务是从未见过的 5 个类别的样本,让模型在这些样本上进行微调的训练,只不过这时候的模型在训练时就已经具备了之前学习到的 “经验”,从而可以快速适应测试任务。
举个现实的例子:老师在教课期间,如何衡量学生当前学的好不好? 如何衡量学生学习能力强不强?学的好不好就是通过当前科目上的考试成绩来判断,学习能力强不强,则是通过学习时间来判断,比如学习七天考到90分和学习一天就考到90分,是两种不同的学习能力。
meta-learning希望教会模型学习的能力,基于过去训练的经验提高模型的学习能力。给定一系列训练任务,期望模型在新任务上快速调整和学习。元学习同样不涉及保留过去任务的知识,防止遗忘。
分享一个关于元学习的搞笑的图。。。
2)Transfer learning(”当下先验”)
transfer learning 迁移学习,同样也是迁移之前学习到的 “经验”,在新的数据上进行微调,比如用在 ImageNet 大数据集上预训练的 VGG 等模型,在自己的图片数据集上微调 VGG 进行特征提取,不过这里和 meta-learning 有本质的区别,稍后会详细说明。
3)Multi-task learning(”In 迁移/meta“)
multi-task learning 是多任务学习,多个任务一起进行训练,以达到相互辅助训练的作用,这里的多任务可以是同一数据多个目标任务,也可以是多个数据同一个目标任务,如在人脸识别数据集上,既进行人脸识别任务,又要预测出该人脸的性别和年龄等。
迁移学习和元学习都属于多任务学习。
Few shot learning(meta元学习在CV中的应用)
少样本学习,是指一份数据中可用来训练的样本很少,比如只有 10 条或者 5 条样本,那么这时候用常规的训练方式,是学不出什么的,因为可用信息太少了,那么自然就会想到用 meta-learning 的方式来训练,借助之前任务的先验经验来学习少样本的任务。
few shot learning 可以说是 meta learning 在监督学习中的一个典型应用,而 meta-learning 个人觉得是一个思想框架,可以用在少样本数据上也可以用在多样本数据上,只不过在 few-shot 的场景下,更能发挥出它的威力。比如一个10条样本的分类数据,用普通的训练方式,可能只取得 10% 的准确率,但用 meta-learning 的方式训练可以取得 70% 的准确率。样本数量比较多的时候,用普通的训练方式就可以取得不错的效果,比如准确率 95%,用 meta-learning 的方式可能取得 97% 的准确率,但预训练过程就比较麻烦了。
few shot learning 中还有两个比较特殊的场景,就是 one shot 和 zero shot,即只有1个样本,甚至是零训练样本的场景,不过不在这次的讨论范围之内,如果大家比较感兴趣可以自行查找这方面的论文,few shot learning 目前也是学术上的研究热点。
2、分类 & 应用(“优化 + 度量 + 循环”)
参考
- 走进元学习:概述不同类型的元学习方法_模型 (sohu.com)
- 最前沿:百家争鸣的Meta Learning/Learning to learn
人类在学习时,会根据具体情况采用不同的方法。同样,并非所有的元学习模型都采用相同的技术。一些元学习模型关注的是优化神经网络结构(NAS),而另一些模型(如Reptile)则更注重于寻找合适的数据集来训练特定的模型。
加州大学伯克利分校人工智能实验室最近发表了一篇研究论文,文中全面列举了不同类型的元学习。以下是笔者最喜欢的一些类型:
1)优化器元学习(“相似任务,公用模型,特定任务”)
论文题目[Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks]
文章开头提到 meta-learning 的研究共有三个方向,第一个方向就是 optimization based meta-learning,而 MAML 是这个方向的开山之作,所以要想知道 MAML 是怎么做的,首先要知道这个方向是如何实现 meta learning 的。
思考一下,我们平时普通 learn 的模式是怎样训练模型的?以 DNN 网络模型为例,首先是搭建一个网络模型,接着对模型中每层的参数进行初始化,然后不断的进行“前向计算 loss -> 反向传播更新参数”的过程,直到 loss 收敛。这个过程中,模型初始时对当前数据是一无所知的,所以要通过随机初始化的方式对参数进行赋值,尽管用多种初始化方式,但总归都是随机的。那么有没有方法可以让模型从一个给定的位置开始训练呢,并且这个初始位置给的好的话,比如就在全局最优解附近,可能只需要迭代几次模型就收敛了。答案是肯定的,这个方向的 meta-leaning 就是来做这个事情的。简单总结下就是**「optimization based meta-learning 是通过之前大量的相似任务的学习,给网络模型学习到一组不错的/有潜力的/比较万金油的参数,使用这组参数作为初始值,在特定任务上进行训练,只需要微调几次就可以在当前的新任务上收敛了」**,这句话有几个值得注意的地方或者使用要求:
- 相似任务:并不是随便找一个数据就可以拿来进行训练。
- 共用一个网络模型:我们的最终目的是使用 DNN 模型在任务 A 上进行训练,为了避免随机初始化的方式,故而采用 meta- learning 方式,对这个 DNN 模型先进行预训练,这个预训练的过程就是 meta 训练,其结果是得到一组不错的 DNN 参数,然后在任务A上进行微调。所以从始至终,就只有一个相同的 DNN 模型。
- 特定任务:这个特定任务是我们实际关心的任务,也是 meta 的推理任务,所以要和 meta 训练阶段的大量任务具有一定的相似性,如果差异性太大,那么 meta 学到的这组参数可能不起作用甚至还不如随机初始化的参数。「乍一看是不是觉得和迁移学习有点像,最终形式都是从一组已知参数开始微调,但是这两个方式是有本质的区别的,这个后面还会再讲到。」
优化器元学习模型的重点是学习如何优化神经网络从而更好地完成任务。这些模型通常包括一个神经网络,该神经网络将不同的优化应用于另一个神经网络的超参数,从而改进目标任务。那些专注于改进梯度下降技术的模型就是优化器元学习很好的体现,就像该研究中发布的那些模型。
- 模型无关:模型无关不是说任意模型都行,这个范围太广了,是指任意的可以通过梯度下降进行优化训练的模型,这个模型一般都是网络模型,也可以是支持随机梯度下降的机器学习模型。
- 快速适应:意思是模型经过较少的迭代次数就可以在当前任务上收敛
- 多场景:论文的方法可以适用到分类、回归、强化学习等场景
2)度量元学习
度量元学习的目标是确定一个高效率学习的度量空间。该方法可以看作是小样本元学习的一个子集,通过使用学习度量空间来评价学习质量并举例说明。该研究论文向读者展示如何将度量元学习应用于分类问题。
3)循环模型元学习
该类型的元学习模型适用于循环神经网络(RNNs),比如长短期记忆网络(LSTM)。在这种架构中,元学习器算法将训练RNN模型依次处理数据集,然后再处理任务中新输入的数据。在图像分类设置中,这可能涉及到依次传递数据集(图像、标签)对的集合,然后是必须分类的新示例。元强化学习就是这种方法的一个例子。
4)其他方面的应用
参考
- learn2learn库
- Meta-Learning in Neural Networks: A Survey
粗略认识:
1)元学习能够从单族多任务场景中学习到与任务无关的知识,用于学习该族的新的任务
2)元学习已成功应用在少样本的图像检测,无监督学习,强化学习,超参数优化,和网络体系结构探索
3)虽然元学习给出的定义是:学习如何学习(提高从指定任务中学习的能力,并应用在部分看到的任务上),
但是这个概念包括了迁移学习,多任务学习,特征选择和模型的集成学习,并不适合元学习精准定义:
1)内部是传统机器学习(base),外部是元学习,内外可以使用不同的损失函数。根据数学公式可知,通过最小化元学习Loss,则meta-learning
通过测试任务集合来学习它内部的模型,寻找改模型的较优参数w,其中w为元知识分类:
元学习就是在传统的学习方法上,套上一层模型(我愿称其为“反思”模型),因此作为模型,肯定离不开如下三个方面
1)元学习的表征:关于元知识w的元学习,该知识可以是初始化模型的参数(元知识就是对整个模型的抽象,而元数据则是对集群各节点(工作节点,文件系统)等信息的抽象)
2)元学习的优化器:关于在外层如何学习到元知识,这时就需要优化器来对知识进行优化
3)元学习的目标函数:根据元学习的目的,可以选择不同的损失函数(提高少样本学习的准确率,快速多样本的优化,领域迁移到鲁棒性问题研究,标签噪声和对抗攻击等)
元学习的实现方式可以多种多样:可以用以往任务的学习来预测loss,或者用来预测梯度,可以对模型梯度进行二次计算(MAML模型公式)
3、比较(“对象 + 采样 + 目的”)
1)机器学习 vs 元学习(“完成一个任务” vs “掌握学习能力”)
参考一文入门元学习(Meta-Learning)(附代码) - 知乎 (zhihu.com)
目的 | 输入 | 函数 | 输出 | 流程 | |
---|---|---|---|---|---|
Machine learning | 通过训练数据,学习到输入X与输出Y之间的映射,找到函数f | X | f | Y | 1.初始化f参数 |
2.喂数据<X,Y> | |||||
3.计算loss,优化f参数 | |||||
4.得到:y = f(x) | |||||
meta learning | 通过(很多) 训练任务T及对应的训练数据D,找到函数F。F可以输出一个函数f,f可用于新的任务 | (很多)训练任务及对应的训练数据 | F(学习能力) | f(要学习的任务) | 1. 初始化F参数 2.喂训练任务T及对应的训练数据D,优化F参数 3.得到:f=F* 4. 新任务中:y=f(x) |
在机器学习中,训练单位是一条数据,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。在元学习中,训练单位分层级了,第一层训练单位是任务,也就是说,元学习中要准备许多任务来进行学习,第二层训练单位才是每个任务对应的数据。
二者的目的都是找一个Function,只是两个Function的功能不同,要做的事情不一样。机器学习中的Function直接作用于特征和标签,去寻找特征与标签之间的关联;而元学习中的Function是用于寻找新的f,新的f才会应用于具体的任务。有种不同阶导数的感觉。又有种**老千层饼的感觉,**你看到我在第二层,你把我想象成第一层,而其实我在第五层。。。
机器学习学习某个数据分布X到另一个分布Y的映射。 而元学习学习的是某个任务集合D到每个任务对应的最优函数 f ( x ) f(x) f(x)的映射(任务到学习函数的映射)
我理解元学习输出的f是任务与任务间的关联关系。举个例子,好比要学好物理,需要先掌握数学基础知识。其中一个任务相当于一个包含train,test的batch,每个任务=某个知识,任务间的关系可以理解成知识间的拓扑结构。
2)元学习 vs 迁移学习 (“任务空间 vs 任务迁移” + “关注训练集 vs 关注测试集”)
参考
- 元学习(Meta Learning)与迁移学习(Transfer Learning)的区别联系是什么? - 许铁-巡洋舰科技的回答 - 知乎
- 一文入门元学习(Meta-Learning)
迁移关注当下,元学习关注潜力
从目标上看元学习和迁移学习并无本质区分都是增加学习器在多任务的范化能力, 但元学习更偏重于任何和数据的双重采样, 任务和数据一样是需要采样的,而学习到的 F ( x ) F(x) F(x)可以帮助在未见过的任务 f ( x ) f(x) f(x)里迅速建立mapping。 而迁移学习更多是指从一个任务到其它任务的能力迁移,不太强调任务空间的概念。
机器学习围绕一个具体的任务展开, 然而生物体及其一生, 学习的永远不只是一个任务。 与之相对应的叫做元学习, 元学习旨在掌握一种学习的能力, 使得智能体可以掌握很多任务。 如果用数学公式表达, 这就好比先学习一个函数 F ( x ) F(x) F(x),代表一种抽象的学习能力, 再此基础上学习 f ( x ) f(x) f(x)对应具体的任务, 如下图所示 。
3)元学习 vs 交叉验证
元学习大致可以理解了:就是通过采样,获得多个任务(每个任务包含train,test)的任务空间,上层学习任务间的关系,下游学习该任务下的模型参数。模型在学完这个任务的基础上,通过“元学习”分配的另一个相似任务(知识拓扑),提高模型的学习能力。
“交叉验证”参考
- 交叉验证–关于最终选取模型的疑问 - 简书 (jianshu.com)
- sklearn之交叉验证
Q:交叉验证是什么?
A:如果数据集很小,把数据集划分成三部分的话,训练模型的数据就大大减少了,并且结果会取决于(训练集,验证集)的随机选择。因此就出现了交叉验证(cross-validation,简称CV)。一般是用k-fold CV,也就是k折交叉验证。训练集被划分成k个子集,每次训练的时候,用其中k-1份作为训练数据,剩下的1份作为验证,按这样重复k次。交叉验证计算复杂度比较高,但是充分利用了数据,对于数据集比较小的情况会有明显优势。
Q:对于交叉验证部分,不说K-fold的K是多少,那么最后会产生多(K)个模型,一般来说,每次训练出来的模型的参数都是不一样的,所以最后既然已经训练完成了,那么我应该选取那个模型呢?
A:交叉验证并不是为了去获取一个模型,他的主要目的是为了去评估这个模型的一些特性。交叉验证主要在于验证模型的泛化能力。
4)元学习 vs 增量学习(“迅速学习能力 vs 类别增量学习”)
元学习强调的是掌握迅速学习的能力,类别可能是不变的;而增量学习强调学习更多的类别。这里要区分好两者,弄清楚增量学习的概念即可。
参考
- 增量学习-学习总结(上) - 知乎 (zhihu.com)
- 增量学习(Incremental Learning)小综述 - 知乎 (zhihu.com)
5)元学习 vs 传统机器学习(★★★★★)
参考中国计算机学会通讯-第8期:群智智能制造 - 从机器学习到元学习的方法论 演变
a)元学习术语
先补充一下元学习的术语:
- 训练任务集 :由一组或多组训练集 - 测试集(
验证集)对构成。由于有时测试集(验证集)体现的是预期目标域(域内,跨域)的泛化信息,又被称为元数据集。其中每个训练集与测试集均对应一组带标记数据集。为了避免元学习中的**“训练 - 测试”任务集与这些训练 - 测试数据集产生混淆,通常也称某个任务中的训练集为支撑集**(support set),测试集为查询集(query set)。相应地,测试任务通常由仅包含有监督信息的支撑集构成,用以基于元学习机产生对此数据集匹配的超参赋值,即制订其合适的学习方法论。(理解:
元学习的训练集(每个任务包括支撑集,查询集) 和测试集,当前的元学习机的预测参数对下一次迭代的机器学习模型进行超参赋值,这里我理解的超参是指可以控制模型 w w w参数自动化调整的参数,比如lr,momentum,weight-decay
等。 这里理解有误,之前理解错了,元学习包括支撑集和查询集,其中查询集起到测试集的作用,之前把查询集误解成验证集了)
- 元学习机 :具有参数化结构的备选元决策函数, 其输入为一个学习任务相应的表达特征(可简单理解为训练数据集所传达的特征),输出为机器学习某个环节的超参赋值,如图(a)所示。
(理解:输入训练数据集(支撑集,查询集)表达的特征,输出机器学习的超参赋值(机器学习模型的初始化参数))
- 元表现度量 :指导对元学习机参数进行优化学习的量化目标函数。在元学习的层次上,元表现度量与传统机器学习的表现度量具有本质区别。传统机器学习对某个学习机参数表现度量构建的基本依据是该参数对应的决策函数对测试集中数据标记获得的预测精度。而在元学习的框架下,其元学习机参数的表现度量需要相对较为复杂的运算获得:对于每个训练任务中的支撑集,可利用该元学习参数获得相应机器学习相关环节的超参赋值,也就是得到一个确定的机器学习执行方法,然后将该方法应用于查询集中,可获得相应的验证精度。该验证精度即可表达该元学习机参数的表现度量。
(理解:传统机器学习参数表现度量通过测试集获得预测精度,而元表现度量是通过查询集获得验证精度)
- 优化算法 :以元学习机参数为优化参数,以元表现度量为优化目标,运用优化工具获得元学习机参数的合理估计,从而得到用以超参赋值的元学习者基本形式。
当获得该元学习机之后,面对新的学习任务,只需利用该元学习机对机器学习方法各环节进行超参赋值,利用该确定性机器学习执行模式,即可习得相应的标记预测函数,如图 (b)所示。这体现了元学习的“学会学习”功能,即其能够学习到学习方法论 的核心内涵
b)本质差异
接着简要概述元学习与传统机器学习的本质差异:
拟合对象:
- 元学习用以拟合的对象不是带标记的数据集, 而是成对的学习任务集,这意味着不同于传统机器学习凝练数据预测共有规律的学习方式,元学习的目标转变为从多个任务中总结其有效执行方法论的共有规律(理解:元学习是总结多个任务中共有的规律,即学习方法论)。
输出参数:
元学习训练输出的参数是用以预测机器学习超参的元学习机参数,而不是传统机器学习中用以预测数据标记的学习机参数,因而体现了**取代“学会标记预测”的“学会学习方法论”**的元学习内涵。
(理解:元学习输出的参数用以预测机器学习模型的超参(下一步超参赋值),而机器学习模型输出的参数用以预测数据标记)
元学习表现度量:
元学习表现度量不再是可直接算出的对训练数据标记预测的准确度,而是元学习机在支撑集所获的超参赋值下,学习方法在查询集上获得的学习效果和表现。这是由于元学习的学习对象,即学习方法论的特殊性决定的,方法论的优劣只能通过其在具体任务上的执行结果来度量。
(理解:元表现度量是在支撑集上对机器学习进行超参赋值,并通过查询集上的验证精度进行度量。在具体任务上的执行结果来度量方法论的优劣)
泛化对象:
- 最为重要的,元学习用以泛化的对象是机器学习超参赋值的共有规则, 即对于具体任务的执行方法论。在理想学习效果的前提下,这一方法论可以自然诱导具有不同数据模态、 数据尺寸、网络模型、优化算法等的不同学习任务之间优质方法论知识的折中利用与迁移借鉴,从而有望带来比传统机器学习更为强大的学习功能。
c)小总结
总结来说,这种超越传统机器学习层次的新型学习模式有望实质性增强机器学习自身的自动化与智能化程度,在诸多应用领域的多个层面取得重要进展。其变革性特征体现为:
- 将传统机器学习面向“数据集”进行“数据有效预测规律”学习的基本模式,转变为面向“任务集”进行“任务有效执行规律”学习的全新模式。其目标为通过探索针对多个任务有效的学习策略, 归纳多任务预测策略的共性规律,从而将其快速迁移用于对新任务的学习过程。该元学习模式突破了传统机器学习针对单个任务进行数据预测规律学习的传统思维,革新性地针对多个任务进行其共有“学习方法论”的学习,对于进一步深入探索机器学习理论内涵, 挖掘机器学习潜在能力,扩展机器学习有效应用边界, 均具有重要价值与作用。
- 特别地,其有望全面改善原有机器学习存在的本质问题。如:
- 针对“大数据”问题, 可通过学习大量小数据任务的学习规律,从而习得保证依赖于小数据进行有效预测的共有方法论,进而将其应用于实际的小样本学习任务中;
- 针对“大算力”问题,由于理想的学习**“学习方法论”模式避免了过度的人工调参**,超参赋值转变为自动化的元学习预测问题, 因而有望大量节省盲目调参带来的算力浪费;
- 针对“大模型”问题,理想的通过元学习调参而进行自动搭建网络结构(被称为网络结构搜索,NAS)的想法可以减少网络结构设计的困难,而更为先进的方法通过在已有性能优良的网络结构上引入超参,利用元学习进行结构微调,可获得更加快速有效的网络结构调整策略。
另外需要强调的一点是,当把训练任务集里的支撑集 D s u p D^{sup} Dsup 设置为“源”任务,而将查询集 D q u e D^{que} Dque 设置为“目标”任务,执行以上双边优化过程,可以实现传统机器学习中类似领域迁移的层次更高的任务迁移。其合理性在于 :这种学习方式将在源任务上执行的方法论规律,在目标任务的指导下学习,从而将目标任务的内在信息实质性嵌入源任务学习的方法论之中,从而达到减少学习结果对于目标任务产生过度偏差的目的。从这种意义上讲,此时的查询集起到了更高层的指导作用,其信息反映了所求元学习机旨在正确泛化的方向与趋势,因此亦常被称为元数据 。以带偏差数据的鲁棒性机器学习问题为例,此时 D s u p D^{sup} Dsup 代表包含有噪声标记或具有显著类不均衡性的现实数据,而 D q u e D^{que} Dque 代表一小部分理想条件下收集的高质量标记数据。则通过以上元学习的方式,可以实现利用元数据指导合理超参调整,在偏差数据上实现鲁棒学习的目标。该方法采用的元学习机格式,能够使得所学习的学习方法论制定规律 直接泛化用于新的偏差数据任务的鲁棒方法论设计之中。这种学习模式,对于拓宽传统迁移学习的应用范畴,实现更为高层的机器学习泛化目标, 具有一定程度的启发意义。
二、元学习具体算法
参考
- Meta-Learning: Learning to Learn Fast
- 元学习——MAML论文详细解读
1、MAML算法的推导流程
Model-Agnostic Meta-Learning(MAML,即与模型无关的元学习)
1)算法介绍
以监督学习中少样本分类场景为例,MAML算法流程如下(其中要注意支撑集和查询集的类别空间相同,而且样本可能存在重复,但是测试集和前面两个集合无论是在样本还是类别,都不重复):
-
第一个 Require 很关键,也往往容易忽略,就是要求所有的任务都服从一个分布 P ( T ) P(T) P(T),每个任务是从这个分布中采样得到的。这就说明了不能随意拿一个 Task 放到 MAML 框架里,必须满足某种相似性。(参考正态分布采样及Python实现 )
这里即反复随机抽取task T T T 组成的task池,作为MAML的训练集。有的小伙伴可能要纳闷了,训练样本就这么多,要组合形成那么多的task,岂不是不同task之间会存在样本的重复?或者某些task的query set会成为其他task的support set?没错!就是这样!我们要记住,MAML的目的,在于fast adaptation,即通过对大量task的学习,获得足够强的泛化能力,从而面对新的、从未见过的task时,通过fine-tune就可以快速拟合
这里的query set是就是测试集,之前把query Set误解成验证集了。在MAML中,query集和support集可以类别相同。)
-
第二个 Require,这两个是模型超参数,我们的设置是,有一个用于多分类的网络模型 DNN,现在要通过元学习的方式得到网络模型的一组初始参数,为了得到这组参数,MAML 设计了两层训练,一个是 Task 内部的训练更新,更新的学习率是 α α α,一个是外部的网络模型的训练更新,更新的学习率是 β β β。可以把内部任务的训练,想象成前面例子中学生在每个科目上的学习过程,而外部的更新,则是老师调整学生学习方向的过程。
- 训练前首要先随机初始化网络模型所有参数 θ \theta θ
- 设置一个外层训练结束条件,例如迭代10000步等
- 外层训练就是 Meta 训练过程,上面讲过 Meta 训练的对象是 task,所以 meta 每一次的迭代都要从任务分布 P ( T ) P(T) P(T)中随机抽取一个 batch 的 tasks
- 针对 3 中的每个 task,执行下面的过程:
- 在一个具体的任务 T i T_i Ti中随机抽取包含 K K K个数据点的训练样本集 D D D
- 使用交叉熵损失函数或者均方差损失函数,在 D D D 上计算出损失函数 L t i L_{t_i} Lti和损失函数对 θ \theta θ的梯度。
- 使用 6 计算出的梯度,在该任务上使用一次梯度下降更新模型参数 θ \theta θ: θ i ′ = θ − a v ‾ \theta^{'}_i = \theta - a\overline{v} θi′=θ−av
- 从该任务 T i T_i Ti中抽样出一个测试集 D ′ D' D′ ,用于 meta 的参数更新
- 每个 task 都执行 6,7,8 三步,直到这一个 batch 的 task 都执行完
- 这一步就是 meta 的更新过程, 虽然 7 也有更新模型参数,但是那是任务内的局部更新,并没有改变外面网络模型的参数,这一步就是要改变网络模型的参数了,仍然是使用梯度下降法进行更新,只不过更新用到的梯度是每个任务 T i T_i Ti在各自的测试集 D ′ D' D′上计算出的batch 个梯度的平均梯度,更新的步长是 β \beta β。可以把这一步的更新,想象成上面例子中老师根据学生的平均考试成绩来调整其学习方向的过程,考试题目自然是各科上没见过的题目,这也就是第 8 步抽取测试集 $D’ 的作用,然后 ∗ ∗ 每个任务用自己更新过的模型参数 的作用,然后**每个任务用自己更新过的模型参数 的作用,然后∗∗每个任务用自己更新过的模型参数\theta^{'}_i$在 D ′ D' D′上进行前向计算得到一个 loss**,用这个 loss 再对 θ i ′ \theta^{'}_i θi′进行求导得到测试集 D ′ D' D′上的梯度,这个梯度就相当于是该任务上的考试成绩(用测试集loss求 θ i ′ \theta^{'}_i θi′梯度是为了更新meta阶段的参数)。
这里要事先对6,7步说明一点:
上面的过程中有一个问题需要**「事先说明」:可以看到每个任务内部只更新了一次参数,也就是 6,7 两步只做了一次梯度更新**,但其实也可以进行多次的梯度更新,就是把 6,7 两步重复执行几次。那作者这里为什么只写一次呢?这就是作者高明的地方了,那就是做了一个**「最大化假设」。我们的最终目的是希望 MAML 训练出的参数,在新的任务上进行少量几次的微调就可以收敛,那最好的结果就是只更新一次就收敛了,所以在 MAML 训练过程中,作者就特意设计每个任务内部只更新一次参数,以此来训练这个模型 “「更新一次就可以最大化性能」**”的能力。类比到上面讲的例子,那就是老师希望该学生具备强大的学习能力,在新的没有见过的科目上只学习一天就可以考出好成绩,为了训练该学生的这个能力,就让他在训练的每个科目上都学习一天然后考试一次,老师根据平均考试成绩调整学生的学习方向,不断地重复这个过程,直到平均考试成绩可以到 90 分以上就结束训练,此时老师就认为这个学生具备了“在新科目上学习一天就能考出好成绩”的能力。
算法简化版本为(在meta-update时,测试集中的样本一定要是支撑集中没出现过的(不一定是novel cls,可以是相同类别),而查询集没有这个约束(其样本可以存在于支撑集中),查询集就是测试集,下面红字的意思是在进行meta-update更新时query集必须是不同于support集的set,即再次采样,但是却没说support集中的样本是否可以重复存在于query集中):
可以发现在更新 θ ∗ \theta^* θ∗时,需要计算多次梯度(对应不同更新的模型)
2)MAML推导流程(“meta梯度,一阶近似 FOMAML“)
参数的更新策略为:根据全局参数 ϕ \phi ϕ,在进行每一个 t a s k i task_i taski时进行内部参数 θ i \theta_i θi的局部更新; 根据更新后的内部参数 θ i ′ \theta^{'}_i θi′对测试集计算loss, 并逐一对全局参数 ϕ \phi ϕ计算梯度,得到新的全局参数 ϕ ′ \phi^{'} ϕ′。
上图是 MAML 训练时模型参数的更新过程,其中 ϕ \phi ϕ是网络模型的初始参数,也就是伪算法中的 1, 那一步 θ i ′ \theta^{'}_i θi′是任务内部在 ϕ \phi ϕ上更新一次后的参数,也就是伪算法的第 7 步, L ( ϕ ) L(\phi) L(ϕ)是所有task 在各自测试集 D ′ D' D′上的 loss 和,用 L ( ϕ ) L(\phi) L(ϕ)对模型参数进行求导得出梯度,来进行meta的参数更新,也就是真正更新网络模型的参数。图中右边的过程就是把 meta 梯度下降更新的数学过程展开,其中最关键的一步是蓝色弯箭头标出的那个变换,就是第二个等号到第三个等号的那一步,其它步骤还都比较好理解,下面来详细看下关键这步的变换,其中主要是 l i ( θ ’ i ) l^{i}(\theta’_i) li(θ’i)对 ϕ \phi ϕ求导不好求,如果这个可以算出来,剩下的步骤就好说了。
上图就是计算 l i ( θ ’ i ) l^{i}(\theta’_i) li(θ’i)对 ϕ \phi ϕ求导的过程,因为 θ ’ i \theta’_i θ’i是由 ϕ \phi ϕ经过一次梯度下降更新得到的, 其实 ϕ \phi ϕ是一组参数向量,代表网络模型的各个参数,所以可以将求导展开成向量形式,向量每个元素 l i ( θ ’ i ) l^{i}(\theta’_i) li(θ’i)是对 ϕ \phi ϕ的求导,也就是上图中的红框1,那如何计算 l i ( θ ’ i ) l^{i}(\theta’_i) li(θ’i)对 ϕ \phi ϕ求导呢?我们知道是由经过梯度下降公式得到的,那么和的关系就是下面这样( θ ’ i \theta’_i θ’i是一个向量):
也就是 ϕ i \phi_i ϕi和 θ i j ′ \theta'_{ij} θij′每个都是有关系的, θ i ′ \theta'_{i} θi′又是由多个 θ i j ′ \theta'_{ij} θij′组成的,所以 l ( θ i ′ ) l(\theta'_{i}) l(θi′)对 ϕ i \phi_i ϕi的求导就是对上面的链路求导的和,每个路径的求导则是 l ( θ i ′ ) l(\theta'_{i}) l(θi′)对 θ i j ′ \theta'_{ij} θij′求导结果和 θ i j ′ \theta'_{ij} θij′对 ϕ i \phi_i ϕi的求导结果相乘,也就是上图中红框2所在的公式,其中的关键是红框2的位置,也就是 θ i j ′ \theta'_{ij} θij′对 ϕ i \phi_i ϕi的求导, θ i j ′ \theta'_{ij} θij′是 ϕ i j \phi_{ij} ϕij经过梯度下降公式变过来的,也就是图中的红色5标记的地方, 所以 θ i j ′ \theta'_{ij} θij′对 ϕ i j \phi_{ij} ϕij的求导就有两种情况, i = j i=j i=j和 i ≠ j i \neq j i=j, 时,计算结果就是红框4所处的公式, i = j i=j i=j时就是红框3的公式,可以看到这两个公式中都出现了二阶的偏导,二阶偏导求起来比较麻烦会影响到计算速度,所以作者使用了一阶近似的方法 first-order approximation,也就是把公式中的二阶偏导近似为0,这样近似后就简单很多,即 θ i j ′ \theta'_{ij} θij′对 ϕ i j \phi_{ij} ϕij的求导在 i = j i=j i=j时约等于1,在 i ≠ j i \neq j i=j时约等于0。然后顺着图中的蓝色箭头一步步带入,最后就会得到 l ( θ i ′ ) l(\theta'_{i}) l(θi′)对 ϕ i \phi_i ϕi的求导近似等于 l ( θ i ′ ) l(\theta'_{i}) l(θi′)对 θ i ′ \theta'_{i} θi′的求导,再回到更新 meta 参数的公式来看就简单了:
上图红框标出的公式就是 meta 更新参数时实际做的事情,这个式子可以这样看
这是什么意思呢? g i g_i gi是第 i i i个任务在其测试集上计算出的梯度方向,从几何上看,这个式子的更新过程是这样的:
蓝色点表示网络模型真正的参数,绿色第一个箭头表示在其训练集 D 上计算的梯度,绿色第二个箭头表示在其测试集D’上计算出的梯度,蓝色箭头表示 meta 模型网络模型参数的方向,可以看到它就是在每个任务的测试集的梯度方向上不断的去做更新。从这个过程中可以看出来,MAML 真正更新网络模型参数时,「关心的是测试集上的梯度,而不是每个任务上训练集的梯度」,也就是说,它更新的每一步的目标,都是使得更新后的参数能在以后的测试集上表现的更好,正是因为这样,才能说明 meta 停止更新时的参数具有很好的潜力/学习能力,这个能力使得这组参数在之后新的任务上微调几次就可以在该任务上取得很好的性能,当然理想情况还是微调一次就能取得不错的成绩,如果一次微调更新效果不好,那还可以再继续多次的微调更新。这也与我们最初希望的目标,即能在新任务上快速适应相吻合,即使该任务只有少量的训练样本,比如10条或者5条,甚至是1条样本,也能快速的学习到一些有效特征。
该图是不是验证了该小章节的开头的一句小总结:
参数的更新策略为:根据全局参数 ϕ \phi ϕ,在进行每一个 t a s k i task_i taski时进行内部参数 θ i \theta_i θi的局部更新; 根据更新后的内部参数 θ i ′ \theta^{'}_i θi′对测试集计算loss, 并逐一对全局参数 ϕ \phi ϕ计算梯度,得到新的全局参数 ϕ ′ \phi^{'} ϕ′。(这难道不和蚁群算法中局部信息素,全局信息素的更新很相似吗?目的就是要尽快找到全局最优解?)
基于上面的分析,可以用一个流程图来表示通用 MAML 的训练更新过程:
2、MAML与Reptile的比较
下面是Reptile模型(简化版的MAML)
具体步骤:
- sampling a task,
- training on it by multiple gradient descent steps,
- and then moving the model weights towards the new parameters.
Retitle vs MAML
共同点:
Reptile和MAML元学习优化框架都要保存多个模型
不同点:
- MAML对于每个任务(支撑集)都需要采样查询集,为inner每个模型计算loss,在meta-update阶段通过累加的loss求梯度,接着对meta模型的参数进行更新。
- Reptile无需通过查询集来计算误差,直接通过多次在支撑集上训练,得到内部每个模型的参数并保存。在meta-update阶段,通过累加计算inner每个模型的参数和当前meta-iter元模型初始化参数的差再求平均,该阶段无需计算梯度,直接通过meta-lr来更新元模型初始化参数即可。
基于上面的分析,可以用一个流程图来表示通用 MAML 的训练更新过程:
这个图中是以 task 内部更新 k 次参数为例的,当 k=1 的时候 Reptile 就和 MAML 一模一样了。
3、与迁移学习的比较
前面讲过,元学习和迁移学习有相似的地方,形式上都是在之前的任务上进行预训练,然后获得一组参数,然后用这组参数在新的任务继续微调,但它们是有本质的区别的。想想迁移学习的预训练是怎么训练的,比如在 ImageNet 大数据集上预训练的 ResNet、VGG 这些网络模型,它们在训练的时候是用在 ImageNet 训练集上的 loss 算出来的梯度来更新模型参数的,以训练集上的 loss 为准,关心的是当前模型参数在训练集上的性能如何。而元学习 MAML 在训练期间是用测试集上的 loss 算出的梯度来更新模型参数的,以测试集上的 loss 为准,不关心在当前训练集上的性能,而是关心这组参数在之后的测试集上的性能如何,也就是这组参数的潜力。换句话说,在 MAML 这篇论文中,是看这组参数在更新一次后的模型参数在测试集上能够表现多好,而不是训练期间能够多好,这种潜力也与元学习的大目标相符,即 Learn to learn 学会如何学习从而具备某种学习能力或学习技巧,可以在新的任务上快速学习。类比到上面老师和学生的例子,也很好理解,老师每次都是以学生的平均考试成绩为方向进行调整,这个考试成绩自然是每门功课上没有见过的题目,只有这样才能训练出该学生的学习能力。从几何上来看,迁移学习预训练模型的参数更新过程是这样的:
这就能看出和 MAML 不一样的地方了,迁移学习的预训练每次更新参数时,都是在当前任务上训练集的梯度方向上进行更新。
4、实验分析(”迁移 vs 元学习“)
1)回归问题
已有预训练模型(学习旧曲线),在给定新曲线的少数样本下,拟合新曲线
论文中关于回归问题的例子是,拟合正弦函数曲线,所有任务的分布 p ( T ) p(T) p(T)就是正弦函数分布 y = a ∗ s i n ( x + b ) y=a*sin(x+b) y=a∗sin(x+b),不同的任务只需要抽样不同的 a 和 b 即可,按照上面讲的 MAML 训练过程,在若干个不同a和b的正弦函数上进行预训练,然后用预训练出的网络模型在新的正弦函数样本上进行测试,这个新的正弦函数是训练期间没有见过的一组a和b,只给出少量的训练样本,如5个或10个。论文中对比了 MAML 模型和迁移学习预训练模型,在这个新的正弦函数上的预测性能,注意不管是哪种模型在这个新的任务上都还是要进行训练的,只不过这个训练是在之前参数的基础上微调,这个新任务对于 meta 来说就是推理任务,而在任务内部还是需要微调更新的。下图就是 MAML 模型和预训练模型在新的正弦函数上训练之后,在其测试集上的表现。
左边两个图是 MAML 模型的结果,左边第一个图是用 MAML 的思路训练出的模型,在新正弦函数的 5 个样本上微调之后,进行预测的结果。可以看到新的正弦函数,在训练时只给了分布在右半部分的 5 个点,其中红色线是真实分布,浅绿色线是不进行微调直接用预训练参数进行预测的结果,可以看出来预训练参数跑出的结果已经有了初步的形状。深绿色线是微调一次参数后进行预测的结果,此时预测出的曲线已经基本拟合真实的正弦函数了,在包含训练样本的右半侧可以完全拟合,在左半边的曲线,模型虽然没有见过这部分的样本但也可以学习出它的周期性质,在形状上基本拟合。左边第二个图不同的是,给出了新的正弦函数的10个训练样本,可以看到 MAML 在进行一次微调后,基本就可以拟合全部曲线了,在进行十次微调后,拟合程度更进一步。
右边两个图是同样的设置下,迁移学习预训练模型的表现,浅蓝色曲线是直接进行预测的结果,可以看到和真实分布相差甚远,尤其是波峰的位置,完全没对上,在微调1次和10次之后,相比于不微调,有一点进步,但和真实分布相比,依然相差较大。并且模型发生了过拟合现象,如果样本点只在右半部分,那模型在右半部分的拟合上表现还行,在另一半的位置上表现更差。如果迁移学习预训练的任务足够多的话,它训练出的模型对应的曲线应该是一条接近水平的直线,因为每个任务都以训练集上的 loss 为主,这么多任务的 loss 加起来更新参数时,梯度应该接近于0。从几何上理解就是,很多个正弦函数叠加在一起,其趋势就是一个水平线,同一个点,可能是波峰也可能是波谷,中间水平线的位置才能让所有任务上的 loss 最小,这就是迁移学习预训练和元学习的质的差别。
上面的图可能不是很明显,有第三方的作者复现了这个回归实验,并且重新绘制了这部分的图,如下所示,这个图看起来更明显些。
论文中也对微调次数进行了实验,结果如下图:
红色线是 oracle 设置组的结果,oracle 就是在训练时加入了该任务真实的a和b作为特征,相当于提前知道了真实分布,所以在这个设置下训练的模型,在新任务上的 mse loss 基本为 0 ,绿色线是 MAML 的模型,横轴是微调次数,可以看到微调一次的模型,就可以得到很低的 mse 误差,而随着微调次数增加,性能也逐渐提升,不过由微调1次变为2次,提升还比较明显,后面的提升就不明显了,尤其是在5次微调之后,基本就没有提升了。蓝色线是迁移学习的预训练模型,可以看到不管是微调几次,其 mse 值都很大,与 MAML 的模型相比,更是相差甚远。
2)分类问题
分类问题场景是两个少样本学习中常见的基准数据集:MiniImagenet 和 Omniglot,下图是在 Omniglot 数据集上的结果:
其中 5-way 是表示5分类,1-shot 表示训练时每个类别下只有一个样本,5-shot就是每个类别下只有5个样本,可以看到不管在哪个设置下,MAML 模型的表现都是最好的。下图是在 MiniImagenet 数据集上的结果:
在这个数据上,作者还对比了使用一阶微分近似和不使用的结果,还记得一阶微分近似是啥吗?不记得的话,请往上翻看数学公式部分,可以看到使用了一阶微分近似,在效果上相差不大,但是作者证明在速度上可以提高 33% 左右,这可是一个性价比很高的改进。
3)MAML代码实现
参考
- Model-Agnostic Meta-Learning在少样本上的应用
- Model-Agnostic Meta-Learning在强化学习上的应用