深度学习应用篇-元学习[14]:基于优化的元学习-MAML模型、LEO模型、Reptile模型

news2024/12/23 18:19:14

在这里插入图片描述
【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等

在这里插入图片描述
专栏详细介绍:【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等

本专栏主要方便入门同学快速掌握相关知识。后续会持续把深度学习涉及知识原理分析给大家,让大家在项目实操的同时也能知识储备,知其然、知其所以然、知何由以知其所以然。

声明:部分项目为网络经典项目方便大家快速学习,后续会不断增添实战环节(比赛、论文、现实应用等)

专栏订阅:

  • 深度学习入门到进阶专栏
  • 深度学习应用项目实战篇

深度学习应用篇-元学习[14]:基于优化的元学习-MAML模型、LEO模型、Reptile模型

1.Model-Agnostic Meta-Learning

Model-Agnostic Meta-Learning (MAML):
与模型无关的元学习,可兼容于任何一种采用梯度下降算法的模型。
MAML 通过少量的数据寻找一个合适的初始值范围,从而改变梯度下降的方向,
找到对任务更加敏感的初始参数,
使得模型能够在有限的数据集上快速拟合,并获得一个不错的效果。
该方法可以用于回归、分类以及强化学习。

该模型的Paddle实现请参考链接:PaddleRec版本

1.1 MAML

MAML 是典型的双层优化结构,其内层和外层的优化方式如下:

1.1.1 MAML 内层优化方式

内层优化涉及到基学习器,从任务分布 p ( T ) p(T) p(T) 中随机采样第 i i i 个任务 T i T_{i} Ti。任务 T i T_{i} Ti 上,基学习器的目标函数是:

min ⁡ ϕ L T i ( f ϕ ) \min _{\phi} L_{T_{i}}\left(f_{\phi}\right) ϕminLTi(fϕ)

其中, f ϕ f_{\phi} fϕ 是基学习器, ϕ \phi ϕ 是基学习器参数, L T i ( f ϕ ) L_{T_{i}}\left(f_{\phi}\right) LTi(fϕ) 是基学习器在 T i T_{i} Ti 上的损失。更新基学习器参数:

θ i N = θ i N − 1 − α [ ∇ ϕ L T i ( f ϕ ) ] ϕ = θ i N − 1 \theta_{i}^{N}=\theta_{i}^{N-1}-\alpha\left[\nabla_{\phi} L_{T_{i}}\left(f_{\phi}\right)\right]_{\phi=\theta_{i}^{N-1}} θiN=θiN1α[ϕLTi(fϕ)]ϕ=θiN1

其中, θ \theta θ 是元学习器提供给基学习器的参数初始值 ϕ = θ \phi=\theta ϕ=θ,在任务 T i T_{i} Ti 上更新 N N N ϕ = θ i N − 1 \phi=\theta_{i}^{N-1} ϕ=θiN1.

1.1.2 MAML 外层优化方式

外层优化涉及到元学习器,将 θ i N \theta_{i}^{N} θiN 反馈给元学匀器,此时元目标函数是:

min ⁡ θ ∑ T i ∼ p ( T ) L T i ( f θ i N ) \min _{\theta} \sum_{T_{i}\sim p(T)} L_{T_{i}}\left(f_{\theta_{i}^{N}}\right) θminTip(T)LTi(fθiN)

元目标函数是所有任务上验证集损失和。更新元学习器参数:

θ ← θ − β ∑ T i ∼ p ( T ) ∇ θ [ L T i ( f ϕ ) ] ϕ = θ i N \theta \leftarrow \theta-\beta \sum_{T_{i} \sim p(T)} \nabla_{\theta}\left[L_{T_{i}}\left(f_{\phi}\right)\right]_{\phi=\theta_{i}^{N}} θθβTip(T)θ[LTi(fϕ)]ϕ=θiN

1.2 MAML 算法流程

  1. randomly initialize θ \theta θ
  2. while not done do:
  3. sample batch of tasks T i ∼ p ( T ) T_i \sim p(T) Tip(T)
  4. for all T i T_i Ti do:
    1. evaluate ∇ ϕ L T i ( f ϕ ) \nabla_{\phi}L_{T_{i}}\left(f_{\phi}\right) ϕLTi(fϕ) with respect to K examples
    2. compute adapted parameters with gradient descent: $\theta_{i}{N}=\theta_{i}{N-1} -\alpha\left[\nabla_{\phi}L_{T_{i}}\left(f_{\phi}\right)\right]{\phi=\theta{i}^{N-1}} $
  5. end for
  6. update $\theta \leftarrow \theta-\beta \sum_{T_{i} \sim p(T)} \nabla_{\theta}\left[L_{T_{i}}\left(f_{\phi}\right)\right]{\phi=\theta{i}^{N}} $
  7. end while

MAML 中执行了两次梯度下降 (gradient by gradient),分别作用在基学习器和元学习器上。图1给出了 MAML 中特定任务参数 θ i ∗ \theta_{i}^{*} θi 和元级参数 θ \theta θ 的更新过程。

图1 MAML 示意图。灰色线表示特定任务所产生的梯度值(方向);黑色线表示元级参数选择更新的方向(黑色线方向是几个特定任务产生方向的平均值);虚线代表快速适应,不同的方向代表不同任务更新的方向。

1.3 MAML 模型结构

MAML 是一种与模型无关的元学习方法,可以适用于任何基于梯度优化的模型结构。

基准模型:4 modules with a 3 × \times × 3 convolutions and 64 filters,
followed by batch normalization,
a ReLU nonlinearity,
and 2 × \times × 2 max-pooling。

1.4 MAML 分类结果

表1 MAML 在 Omniglot 上的分类结果。
Method5-way 1-shot5-way 5-shot20-way 1-shot20-way 5-shot
MANN, no conv (Santoro et al., 2016)82.8 % \% %94.9 % \% %
MAML, no conv89.7 ± \pm ± 1.1 % \% %97.5 ± \pm ± 0.6 % \% %
Siamese nets (Koch, 2015)97.3 % \% %98.4 % \% %88.2 % \% %97.0 % \% %
matching nets (Vinyals et al., 2016)98.1 % \% %98.9 % \% %93.8 % \% %98.5 % \% %
neural statistician (Edwards & Storkey, 2017)98.1 % \% %99.5 % \% %93.2 % \% %98.1 % \% %
memory mod. (Kaiser et al., 2017)98.4 % \% %99.6 % \% %95.0 % \% %98.6 % \% %
MAML98.7 ± \pm ± 0.4 % \% %99.9 ± \pm ± 0.1 % \% %95.8 ± \pm ± 0.3 % \% %98.9 ± \pm ± 0.2 % \% %
表1 MAML 在 miniImageNet 上的分类结果。
Method5-way 1-shot5-way 5-shot
fine-tuning baseline28.86 ± \pm ± 0.54 % \% %49.79 ± \pm ± 0.79 % \% %
nearest neighbor baseline41.08 ± \pm ± 0.70 % \% %51.04 ± \pm ± 0.65 % \% %
matching nets (Vinyals et al., 2016)43.56 ± \pm ± 0.84 % \% %55.31 ± \pm ± 0.73 % \% %
meta-learner LSTM (Ravi & Larochelle, 2017)43.44 ± \pm ± 0.77 % \% %60.60 ± \pm ± 0.71 % \% %
MAML, first order approx.48.07 ± \pm ± 1.75 % \% %63.15 ± \pm ± 0.91 % \% %
MAML48.70 ± \pm ± 1.84 % \% %63.11 ± \pm ± 0.92 % \% %

1.5 MAML 的优缺点

优点

  • 适用于任何基于梯度优化的模型结构。

  • 双层优化结构,提升模型精度和泛化能力,避免过拟合。

缺点

  • 存在二阶导数计算

1.6 对 MAML 的探讨

  • 每个任务上的基学习器必须是一样的,对于差别很大的任务,最切合任务的基学习器可能会变化,那么就不能用 MAML 来解决这类问题。

  • MAML 适用于所有基于随机梯度算法求解的基学习器,这意味着参数都是连续的,无法考虑离散的参数。对于差别较大的任务,往往需要更新网络结构。使用 MAML 无法完成这样的结构更新。

  • MAML 使用的损失函数都是可求导的,这样才能使用随机梯度算法来快速优化求解,损失函数中不能有不可求导的奇异点,否则会导致优化求解不稳定。

  • MAML 中考虑的新任务都是相似的任务,所以没有对任务进行分类,也没有计算任务之间的距离度量。对每一类任务单独更新其参数初始值,每一类任务的参数初始值不同,这些在 MAML 中都没有考虑。

  • 参考文献

[1] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.

2.Latent Embedding Optimization

Latent Embedding Optimization (LEO) 学习模型参数的低维潜在嵌入,并在这个低维潜在空间中执行基于优化的元学习,将基于梯度的自适应过程与模型参数的基础高维空间分离。

2.1 LEO

在元学习器中,使用 SGD 最小化任务验证集损失函数,
使得模型的泛化能力最大化,计算元参数,元学习器将元参数输入基础学习器,
继而,基础学习器最小化任务训练集损失函数,快速给出任务上的预测结果。
LEO 结构如图1所示。

图1 LEO 结构图。 D t r D^{\mathrm{tr}} Dtr 是任务 ε \varepsilon ε 的 support set,
D v a l D^{\mathrm{val}} Dval 是任务 ε \varepsilon ε 的 query set,
z z z 是通过编码器计算的 N N N 个类别的类别特征, f θ f_{\theta} fθ 是基学习器,
θ \theta θ 是基学习器参数,
L t r = f θ ( D t r ) L^{\mathrm{tr}}=f_{\theta}\left( D^{\mathrm{tr}}\right) Ltr=fθ(Dtr), L v a l = f θ ( D v a l ) L^{\mathrm{val}}=f_{\theta}\left( D^{\mathrm{val}}\right) Lval=fθ(Dval)

LEO 包括基础学习器和元学习器,还包括编码器和解码器。
在基础学习器中,编码器将高维输入数据映射成特征向量,
解码器将输入数据的特征向量映射成输入数据属于各个类别的概率值,
基础学习器使用元学习器提供的元参数进行参数更新,给出数据标注的预测结果。
元学习器为基础学习器的编码器和解码器提供元参数,
元参数包括特征提取模型的参数、编码器的参数、解码器的参数等,
通过最小化所有任务上的泛化误差,更新元参数。

2.2 基础学习器

编码器和解码器都在基础学习器中,用于计算输入数据属于每个类别的概率值,
进而对输入数据进行分类。
元学习器提供编码器和解码器中的参数,基础学习器快速的使用编码器和解码器计算输入数据的分类。
任务训练完成后,基础学习器将每个类别数据的特征向量和任务 ε \varepsilon ε 的基础学习器参数 θ ε \boldsymbol{\theta}_{\varepsilon} θε 输入元学习器,
元学习器使用这些信息更新元参数。

2.2.1 编码器

编码器模型包括两个主要部分:编码器和关系网络。

编码器 g ϕ e g_{\phi_{e}} gϕe ,其中 ϕ e \phi_{e} ϕe 是编码器的可训练参数,
其功能是将第 n n n 个类别的输入数据映射成第 n n n 个类别的特征向量。

关系网络 g ϕ r g_{\phi_{r}} gϕr ,其中 ϕ r \phi_{r} ϕr 是关系网络的可训练参数,
其功能是计算特征之间的距离。

n n n 个类别的输入数据的特征记为 z n z_{n} zn
对于输入数据,首先,使用编码器 g ϕ e g_{\phi_{e}} gϕe 对属于第 n n n 个类别的输入数据进行特征提取;
然后,使用关系网络 g ϕ r g_{\phi_r} gϕr 计算特征之间的距离,
综合考虑训练集中所有样本点之间的距离,计算这些距离的平均值和离散程度;
n n n 个类别输入数据的特征 z n z_{n} zn 服从高斯分布,
且高斯分布的期望是这些距离的平均值,高斯分布的方差是这些距离的离散程度,
具体的计算公式如下:

μ n e , σ n e = 1 N K 2 ∑ k n = 1 K ∑ m = 1 N ∑ k m = 1 K g ϕ r [ g ϕ e ( x n k n ) , g ϕ e ( x m k m ) ] z n ∼ q ( z n ∣ D n t r ) = N { μ n e , diag ⁡ ( σ n e ) 2 } \begin{aligned} &\mu_{n}^{e}, \sigma_{n}^{e}=\frac{1}{N K^{2}} \sum_{k_{n}=1}^{K} \sum_{m=1}^{N} \sum_{k_{m}=1}^{K} g_{\phi_{r}}\left[g_{\phi_{e}}\left(x_{n}^{k_{n}}\right), g_{\phi_{e}}\left(x_{m}^{k_{m}}\right)\right] \\ &z_{n} \sim q\left(z_{n} \mid D_{n}^{\mathrm{tr}}\right)=N\left\{\mu_{n}^{e}, \operatorname{diag}\left(\sigma_{n}^{e}\right)^{2}\right\} \end{aligned} μne,σne=NK21kn=1Km=1Nkm=1Kgϕr[gϕe(xnkn),gϕe(xmkm)]znq(znDntr)=N{μne,diag(σne)2}

其中, N N N 是类别总数, K K K 是每个类别的图片总数,
D n t r {D}_{n}^{\mathrm{tr}} Dntr 是第 n n n 个类别的训练数据集。
对于每个类别的输入数据,每个类别下有 K K K 张图片,
计算这 K K K 张图片和所有已知图片之间的距离。
总共有 N N N 个类别,通过编码器的计算,形成所有类别的特征,
记为 z = ( z 1 , ⋯   , z N ) z=\left(z_{1}, \cdots, z_{N}\right) z=(z1,,zN)

2.2.2 解码器

解码器 g ϕ d g_{\phi_{d}} gϕd ,其中 ϕ d \phi_{d} ϕd 是解码器的可训练参数,
其功能是将每个类别输入数据的特征向量 z n z_{n} zn
映射成属于每个类别的概率值 w n \boldsymbol{w}_{n} wn

μ n d , σ n d = g ϕ d ( z n ) w n ∼ q ( w ∣ z n ) = N { μ n d , diag ⁡ ( σ n d ) 2 } \begin{aligned} &\mu_{n}^{d}, \sigma_{n}^{d}=g_{\phi_{d}}\left(z_{n}\right) \\ &w_{n} \sim q\left(w \mid z_{n}\right)=N\left\{\mu_{n}^{d}, \operatorname{diag}\left(\sigma_{n}^{d}\right)^{2}\right\} \end{aligned} μnd,σnd=gϕd(zn)wnq(wzn)=N{μnd,diag(σnd)2}

其中,任务 ε \varepsilon ε 的基础学习器参数记为 θ ε \theta_{\varepsilon} θε
基础学习器参数由属于每个类别的概率值组成,
记为 θ ε = ( w 1 , w 2 , ⋯   , w N ) \theta_{\varepsilon}=\left(w_{1}, w_{2}, \cdots, w_{N}\right) θε=(w1,w2,,wN)
基础学习器参数 w n \boldsymbol{w}_{n} wn 指的是输入数据属于第 n n n 个类别的概率值,
g ϕ d g_{\phi_{d}} gϕd 是从特征向量到基础学习器参数的映射。

图2 LEO 基础学习器工作原理图。

2.2.3 基础学习器更新过程

在基础学习器中,任务 ε \varepsilon ε 的交叉熵损失函数是:

L ε t r ( f θ ε ) = ∑ ( x , y ) ∈ D ε t r [ − w y x + log ⁡ ∑ j = 1 N e w j x ] L_{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right)=\sum_{(x, y) \in D_{\varepsilon}^{\mathrm{tr}}}\left[-w_{y} \boldsymbol{x}+\log \sum_{j=1}^{N} \mathrm{e}^{w_{j} x}\right] Lεtr(fθε)=(x,y)Dεtr[wyx+logj=1Newjx]

其中, ( x , y ) (x, y) (x,y) 是任务 ε \varepsilon ε 训练集 D ε t r D_{\varepsilon}^{\mathrm{tr}} Dεtr 中的样本点, f θ ε f_{\theta_{\varepsilon}} fθε 是任务 ε \varepsilon ε 的基础学习器,
最小化任务 ε \varepsilon ε 的损失函数更新任务专属参数 θ ε \theta_{\varepsilon} θε
在解码器模型中,任务专属参数为 w n ∼ q ( w ∣ z n ) w_{n} \sim q\left(w \mid z_{n}\right) wnq(wzn)
更新任务专属参数 θ ε \theta_{\varepsilon} θε 意味着更新特征向量 z n z_{n} zn

z n ′ = z n − α ∇ z n L ε t r ( f θ ε ) , z_{n}^{\prime}=z_{n}-\alpha \nabla_{z_{n}} L_{\varepsilon}^{t r}\left(f_{\theta_{\varepsilon}}\right), zn=znαznLεtr(fθε),

其中, z n ′ \boldsymbol{z}_{n}^{\prime} zn 是更新后的特征向量,
对应的是更新后的任务专属参数 θ ε ′ \boldsymbol{\theta}_{\varepsilon}^{\prime} θε
基础学习器使用 θ ε ′ \theta_{\varepsilon}^{\prime} θε 来预测任务验证集数据的标注,
将任务 ε \varepsilon ε 的验证集 D ε v a l \mathrm{D}_{\varepsilon}^{\mathrm{val}} Dεval
损失函数 L ε v a l ( f θ ε ′ ) L_{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right) Lεval(fθε)
更新后的特征向量 z n ′ z_{n}^{\prime} zn
更新后的任务专属参数 θ ε ′ \theta_{\varepsilon}^{\prime} θε 输入元学习器,
在元学习器中更新元参数。

2.3 元学习器更新过程

在元学习器中,最小化所有任务 ε \varepsilon ε 的验证集的损失函数的求和,
最小化任务上的模型泛化误差:

min ⁡ ϕ e , ϕ r , ϕ d ∑ ε [ L ε v a l ( f θ ε ′ ) + β D K L { q ( z n ∣ D n t r ) ∥ p ( z n ) } + γ ∥ s ( z n ′ ) − z n ∥ 2 2 ] + R \min _{\phi_{e}, \phi_{r}, \phi_{d}} \sum_{\varepsilon}\left[L_{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right)+\beta D_{\mathrm{KL}}\left\{q\left(z_{n} \mid {D}_{n}^{\mathrm{tr}}\right) \| p\left(z_{n}\right)\right\}+\gamma\left\|s\left(\boldsymbol{z}_{n}^{\prime}\right)-\boldsymbol{z}_{n}\right\|_{2}^{2}\right]+R ϕe,ϕr,ϕdminε[Lεval(fθε)+βDKL{q(znDntr)p(zn)}+γs(zn)zn22]+R

其中, L ε v a l ( f θ ε ′ ) L_{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right) Lεval(fθε) 是任务 ε \varepsilon ε 验证集的损失函数,
衡量了基础学习器模型的泛化误差,损失函数越小,模型的泛化能力越好。
p ( z n ) = N ( 0 , I ) p\left(z_{n}\right)=N(0, I) p(zn)=N(0,I) 是高斯分布, D K L { q ( z n ∣ D n t r ) ∥ p ( z n ) } D_{\mathrm{KL}}\left\{q\left(z_{n} \mid {D}_{n}^{\mathrm{tr}}\right) \| p\left(z_{n}\right)\right\} DKL{q(znDntr)p(zn)} 是近似后验分布 q ( z n ∣ D n tr  ) q\left(z_{n} \mid D_{n}^{\text {tr }}\right) q(znDntr ) 与先验分布 p ( z n ) p\left(z_{n}\right) p(zn) 之间的 KL 距离 (KL-Divergence),
最小化 K L \mathrm{KL} KL 距离可使后验分布 q ( z n ∣ D n tr ) q\left(z_{n} \mid {D}_{n}^{\text {tr}}\right) q(znDntr) 的估计尽可能准确。
最小化距离 ∥ s ( z n ′ ) − z n ∥ \left\|s\left(z_{n}^{\prime}\right)-z_{n}\right\| s(zn)zn 使得参数初始值 z n z_{n} zn 和训练完成后的参数更新值 z n ′ z_{n}^{\prime} zn 距离最小,
使得参数初始值和参数最终值更接近。
R R R 是正则项, 用于调控元参数的复杂程度,避免出现过拟合,正则项 R R R 的计算公式如下:

R = λ 1 ( ∥ ϕ e ∥ 2 2 + ∥ ϕ r ∥ 2 2 + ∥ ϕ d ∥ 2 2 ) + λ 2 ∥ C d − I ∥ 2 R=\lambda_{1}\left(\left\|\phi_{e}\right\|_{2}^{2}+\left\|\phi_{r}\right\|_{2}^{2}+\left\|\phi_{d}\right\|_{2}^{2}\right)+\lambda_{2}\left\|C_{d}-\mathbb{I}\right\|_{2} R=λ1(ϕe22+ϕr22+ϕd22)+λ2CdI2

其中, ∥ ϕ r ∥ 2 2 \left\|\phi_{r}\right\|_{2}^{2} ϕr22 指的是调控元参数的个数和大小,
C d {C}_{d} Cd 是参数 ϕ d \phi_{d} ϕd 的行和行之间的相关性矩阵,
超参数 λ 1 , λ 2 > 0 \lambda_{1},\lambda_{2}>0 λ1,λ2>0
∥ C d − I ∥ 2 \left\|C_{d}-\mathbb{I}\right\|_{2} CdI2 使得 C d C_{d} Cd 接近单位矩阵,
使得参数 ϕ d \phi_{d} ϕd 的行和行之间的相关性不能太大,
每个类别的特征向量之间的相关性不能太大,
属于每个类别的概率值之间的相关性也不能太大,分类要尽量准确。

2.4 LEO 算法流程

LEO 算法流程

  1. randomly initialize ϕ e , ϕ r , ϕ d \phi_{e}, \phi_{r}, \phi_{d} ϕe,ϕr,ϕd
  2. let ϕ = { ϕ e , ϕ r , ϕ d , α } \phi=\left\{\phi_{e}, \phi_{r}, \phi_{d}, \alpha\right\} ϕ={ϕe,ϕr,ϕd,α}
  3. while not converged do:
    1. for number of tasks in batch do:
      1. sample task instance T i ∼ S t r \mathcal{T}_{i} \sim \mathcal{S}^{t r} TiStr
      2. let ( D t r , D v a l ) = T i \left(\mathcal{D}^{t r}, \mathcal{D}^{v a l}\right)=\mathcal{T}_{i} (Dtr,Dval)=Ti
      3. encode D t r \mathcal{D}^{t r} Dtr to z using g ϕ e g_{\phi_{e}} gϕe and g ϕ r g_{\phi_{r}} gϕr
      4. decode z \mathbf{z} z to initial params θ i \theta_{i} θi using g ϕ d g_{\phi_{d}} gϕd
      5. initialize z ′ = z , θ i ′ = θ i \mathbf{z}^{\prime}=\mathbf{z}, \theta_{i}^{\prime}=\theta_{i} z=z,θi=θi
      6. for number of adaptation steps do:
        1. compute training loss L T i t r ( f θ i ′ ) \mathcal{L}_{\mathcal{T}_{i}}^{t r}\left(f_{\theta_{i}^{\prime}}\right) LTitr(fθi)
        2. perform gradient step w.r.t. z ′ \mathbf{z}^{\prime} z:
        3. z ′ ← z ′ − α ∇ z ′ L T i t r ( f θ i ′ ) \mathbf{z}^{\prime} \leftarrow \mathbf{z}^{\prime}-\alpha \nabla_{\mathbf{z}^{\prime}} \mathcal{L}_{\mathcal{T}_{i}}^{t r}\left(f_{\theta_{i}^{\prime}}\right) zzαzLTitr(fθi)
        4. decode z ′ \mathbf{z}^{\prime} z to obtain θ i ′ \theta_{i}^{\prime} θi using g ϕ d g_{\phi_{d}} gϕd
      7. end for
      8. compute validation loss L T i v a l ( f θ i ′ ) \mathcal{L}_{\mathcal{T}_{i}}^{v a l}\left(f_{\theta_{i}^{\prime}}\right) LTival(fθi)
    2. end for
    3. perform gradient step w.r.t ϕ \phi ϕ ϕ ← ϕ − η ∇ ϕ ∑ T i L T i v a l ( f θ i ′ ) \phi \leftarrow \phi-\eta \nabla_{\phi} \sum_{\mathcal{T}_{i}} \mathcal{L}_{\mathcal{T}_{i}}^{v a l}\left(f_{\theta_{i}^{\prime}}\right) ϕϕηϕTiLTival(fθi)
  4. end while

(1) 初始化元参数:编码器参数 ϕ e \phi_{e} ϕe、关系网络参数 ϕ r \phi_{r} ϕr、解码器参数 ϕ d \phi_{d} ϕd,
在元学习器中更新的元参数包括 ϕ = { ϕ e , ϕ r , ϕ d } \phi=\left\{\phi_e, \phi_r,\phi_d \right\} ϕ={ϕe,ϕr,ϕd}

(2) 使用片段式训练模式,
随机抽取任务 ε \varepsilon ε, D ε t r {D}_{\varepsilon}^{\mathrm{tr}} Dεtr 是任务 ε \varepsilon ε 的训练集,
D ε v a l {D}_{\varepsilon}^{\mathrm{val}} Dεval 是任务 ε \varepsilon ε 的验证集。

(3) 使用编码器 g ϕ e g_{\phi_{e}} gϕe 和关系网络 g ϕ r g_{\phi_{r}} gϕr 将任务 ε \varepsilon ε 的训练集 D ε t r D_{\varepsilon}^{\mathrm{tr}} Dεtr 编码成特征向量 z z z
使用 解码器 g ϕ d g_{\phi_{d}} gϕd 从特征向量映射到任务 ε \varepsilon ε 的基础学习器参数 θ ε {\theta}_{\varepsilon} θε
基础学习器参数指的是输入数据属于每个类别的概率值向量;
计算任务 ε \varepsilon ε 的训练集的损失函数 L ε t r ( f θ ε ) L_{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right) Lεtr(fθε)
最小化任务 ε \varepsilon ε 的损失函数,更新每个类别的特征向量:

z n ′ = z n − α ∇ z n L ε t r ( f θ ε ) z_{n}^{\prime}=z_{n}-\alpha \nabla_{z_{n}} L_{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right) zn=znαznLεtr(fθε)

使用解码器 g ϕ d g_{\phi_{d}} gϕd 从更新后的特征向量映射到更新后的任务 ε \varepsilon ε 的基础学习器参数 θ ε ′ {\theta}_{\varepsilon}^{\prime} θε
计算任务 ε \varepsilon ε 的验证集的损失函数 L ε val ( f θ s ′ ) L_{\varepsilon}^{\text {val}}\left(f_{\theta_{s}^{\prime}}\right) Lεval(fθs)
基础学习器将更新后的参数和验证集损失函数值输入元学习器。

(4) 更新元参数, ϕ ← ϕ − η ∇ ϕ ∑ ε L ε val ( f θ ε ′ ) \phi \leftarrow \phi-\eta \nabla_{\phi} \sum_{\varepsilon} L_{\varepsilon}^{\text {val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right) ϕϕηϕεLεval(fθε)
最小化所有任务 ε \varepsilon ε 的验证集的损失和,
将更新后的元参数输人基础学习器,继续处理新的分类任务。

2.5 LEO 模型结构

LEO 是一种与模型无关的元学习,[1] 中给出的各部分模型结构及参数如表1所示。

表1 LEO 各部分模型结构及参数。
Part of the modelArchitectureHiddenlayerShape of the output
Inference model ( f θ f_{\theta} fθ)3-layer MLP with ReLU40(12, 5, 1)
Encoder3-layer MLP with ReLU16(12, 5, 16)
Relation Network3-layer MLP with ReLU32(12, 2 × 16 2\times 16 2×16)
Decoder3-layer MLP with ReLU32(12, 2 × 1761 2\times 1761 2×1761)

2.6 LEO 分类结果

表1 LEO 在 miniImageNet 上的分类结果。
Model5-way 1-shot5-way 5-shot
Matching networks (Vinyals et al., 2016)43.56 ± \pm ± 0.84 % \% %55.31 ± \pm ± 0.73 % \% %
Meta-learner LSTM (Ravi & Larochelle, 2017)43.44 ± \pm ± 0.77 % \% %60.60 ± \pm ± 0.71 % \% %
MAML (Finn et al., 2017)48.70 ± \pm ± 1.84 % \% %63.11 ± \pm ± 0.92 % \% %
LLAMA (Grant et al., 2018)49.40 ± \pm ± 1.83 % \% %
REPTILE (Nichol & Schulman, 2018)49.97 ± \pm ± 0.32 % \% %65.99 ± \pm ± 0.58 % \% %
PLATIPUS (Finn et al., 2018)50.13 ± \pm ± 1.86 % \% %
Meta-SGD (our features)54.24 ± \pm ± 0.03 % \% %70.86 ± \pm ± 0.04 % \% %
SNAIL (Mishra et al., 2018)55.71 ± \pm ± 0.99 % \% %68.88 ± \pm ± 0.92 % \% %
(Gidaris & Komodakis, 2018)56.20 ± \pm ± 0.86 % \% %73.00 ± \pm ± 0.64 % \% %
(Bauer et al., 2017)56.30 ± \pm ± 0.40 % \% %73.90 ± \pm ± 0.30 % \% %
(Munkhdalai et al., 2017)57.10 ± \pm ± 0.70 % \% %70.04 ± \pm ± 0.63 % \% %
DEML+Meta-SGD (Zhou et al., 2018)58.49 ± \pm ± 0.91 % \% %71.28 ± \pm ± 0.69 % \% %
TADAM (Oreshkin et al., 2018)58.50 ± \pm ± 0.30 % \% %76.70 ± \pm ± 0.30 % \% %
(Qiao et al., 2017)59.60 ± \pm ± 0.41 % \% %73.74 ± \pm ± 0.19 % \% %
LEO61.76 ± \pm ± 0.08 % \% %77.59 ± \pm ± 0.12 % \% %
表1 LEO 在 tieredImageNet 上的分类结果。
Model5-way 1-shot5-way 5-shot
MAML (deeper net, evaluated in Liu et al. (2018))51.67 ± \pm ± 1.81 % \% %70.30 ± \pm ± 0.08 % \% %
Prototypical Nets (Ren et al., 2018)53.31 ± \pm ± 0.89 % \% %72.69 ± \pm ± 0.74 % \% %
Relation Net (evaluated in Liu et al. (2018))54.48 ± \pm ± 0.93 % \% %71.32 ± \pm ± 0.78 % \% %
Transductive Prop. Nets (Liu et al., 2018)57.41 ± \pm ± 0.94 % \% %71.55 ± \pm ± 0.74 % \% %
Meta-SGD (our features)62.95 ± \pm ± 0.03 % \% %79.34 ± \pm ± 0.06 % \% %
LEO66.33 ± \pm ± 0.05 % \% %81.44 ± \pm ± 0.09 % \% %

2.7 LEO 的优点

  • 新任务的初始参数以训练数据为条件,这使得任务特定的适应起点成为可能。
    通过将关系网络结合到编码器中,该初始化可以更好地考虑所有输入数据之间的联合关系。

  • 通过在低维潜在空间中进行优化,该方法可以更有效地适应模型的行为。
    此外,通过允许该过程是随机的,可以表达在少数数据状态中存在的不确定性和模糊性。

  • 参考文献
    [1] Meta-Learning with Latent Embedding Optimization

3.Reptile

Reptil 是 MAML 的特例、近似和简化,主要解决 MAML 元学习器中出现的高阶导数问题。
因此,Reptil 同样学习网络参数的初始值,并且适用于任何基于梯度的模型结构。

在 MAML 的元学习器中,使用了求导数的算式来更新参数初始值,
导致在计算中出现了任务损失函数的二阶导数。
在 Reptile 的元学习器中,参数初始值更新时,
直接使用了任务上的参数估计值和参数初始值之间的差,
来近似损失函数对参数初始值的导数,进行参数初始值的更新,从而不会出现任务损失函数的二阶导数。

Peptile 有两个版本:Serial Version 和 Batched Version,两者的差异如下:

3.1 Serial Version Reptile

单次更新的 Reptile,每次训练完一个任务的基学习器,就更新一次元学习器中的参数初始值。

(1) 任务上的基学习器记为 f ϕ f_{\phi} fϕ ,其中 ϕ \phi ϕ 是基学习器中可训练的参数,
θ \theta θ 是元学习器提供给基学习器的参数初始值。
在任务 T i T_{i} Ti 上,基学习器的损失函数是 L T i ( f ϕ ) L_{T_{i}}\left(f_{\phi}\right) LTi(fϕ)
基学习器中的参数经过 N N N 次迭代更新得到参数估计值:

θ i N = SGD ⁡ ( L T i , θ , N ) \theta_{i}^{N}=\operatorname{SGD}\left(L_{T_{i}}, {\theta}, {N}\right) θiN=SGD(LTi,θ,N)

(2) 更新元学习器中的参数初始值:

θ ← θ + ε ( θ i N − θ ) \theta \leftarrow \theta+\varepsilon\left(\theta_{i}^{N}-\theta\right) θθ+ε(θiNθ)

Serial Version Reptile 算法流程

  1. initialize θ \theta θ, the vector of initial parameters
  2. for iteration=1, 2, … do:
    1. sample task T i T_i Ti, corresponding to loss L T i L_{T_i} LTi on weight vectors θ \theta θ
    2. compute θ i N = SGD ⁡ ( L T i , θ , N ) \theta_{i}^{N}=\operatorname{SGD}\left(L_{T_{i}}, {\theta}, {N}\right) θiN=SGD(LTi,θ,N)
    3. update θ ← θ + ε ( θ i N − θ ) \theta \leftarrow \theta+\varepsilon\left(\theta_{i}^{N}-\theta\right) θθ+ε(θiNθ)
  3. end for

3.2 Batched Version Reptile

批次更新的 Reptile,每次训练完多个任务的基学习器之后,才更新一次元学习器中的参数初始值。

(1) 在多个任务上训练基学习器,每个任务从参数初始值开始,迭代更新 N N N 次,得到参数估计值。

(2) 更新元学习器中的参数初始值:

θ ← θ + ε 1 n ∑ i = 1 n ( θ i N − θ ) \theta \leftarrow \theta+\varepsilon \frac{1}{n} \sum_{i=1}^{n}\left(\theta_{i}^{N}-\theta\right) θθ+εn1i=1n(θiNθ)

其中, n n n 是指每次训练完 n n n 个任务上的基础学习器后,才更新一次元学习器中的参数初始值。

Batched Version Reptile 算法流程

  1. initialize θ \theta θ
  2. for iteration=1, 2, … do:
    1. sample tasks T 1 T_1 T1, T 2 T_2 T2, … , T n T_n Tn,
    2. for i=1, 2, … , n do:
      1. compute θ i N = SGD ⁡ ( L T i , θ , N ) \theta_{i}^{N}=\operatorname{SGD}\left(L_{T_{i}}, {\theta}, {N}\right) θiN=SGD(LTi,θ,N)
    3. end for
    4. update θ ← θ + ε 1 n ∑ i = 1 n ( θ i N − θ ) \theta \leftarrow \theta+\varepsilon \frac{1}{n} \sum_{i=1}^{n}\left(\theta_{i}^{N}-\theta\right) θθ+εn1i=1n(θiNθ)
  3. end for

3.3 Reptile 分类结果

表1 Reptile 在 Omniglot 上的分类结果。
Algorithm5-way 1-shot5-way 5-shot20-way 1-shot20-way 5-shot
MAML + Transduction98.7 ± \pm ± 0.4 % \% %99.9 ± \pm ± 0.1 % \% %95.8 ± \pm ± 0.3 % \% %98.9 ± \pm ± 0.2 % \% %
1 s t 1^{st} 1st-order MAML + Transduction98.3 ± \pm ± 0.5 % \% %99.2 ± \pm ± 0.2 % \% %89.4 ± \pm ± 0.5 % \% %97.9 ± \pm ± 0.1 % \% %
Reptile95.32 ± \pm ± 0.05 % \% %98.87 ± \pm ± 0.02 % \% %88.27 ± \pm ± 0.30 % \% %97.07 ± \pm ± 0.12 % \% %
Reptile + Transduction97.97 ± \pm ± 0.08 % \% %99.47 ± \pm ± 0.04 % \% %89.36 ± \pm ± 0.20 % \% %97.47 ± \pm ± 0.10 % \% %
表1 Reptile 在 miniImageNet 上的分类结果。
Algorithm5-way 1-shot5-way 5-shot
MAML + Transduction48.70 ± \pm ± 1.84 % \% %63.11 ± \pm ± 0.92 % \% %
1 s t 1^{st} 1st-order MAML + Transduction48.07 ± \pm ± 1.75 % \% %63.15 ± \pm ± 0.91 % \% %
Reptile45.79 ± \pm ± 0.44 % \% %61.98 ± \pm ± 0.69 % \% %
Reptile + Transduction48.21 ± \pm ± 0.69 % \% %66.00 ± \pm ± 0.62 % \% %
  • 参考文献
    [1] Reptile: a Scalable Metalearning Algorithm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/644907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为认证 | HCIP-Datacom-Core 考试大纲

01 考试概况 02 考试内容 HCIP-Datacom-Core Technology V1.0考试覆盖数据通信领域各场景通用核心知识,包括路由基础、OSPF、 IS-IS、BGP、路由和流量控制、以太网交换技术、组播、IPv6、网络安全、网络可靠性、网络服务与管理、 WLAN、网络解决方案。 ★路由基础 …

【MySQL 函数】:一文彻底搞懂 MySQL 函数(一)

前言 ✨欢迎来到小K的MySQL专栏,本节将为大家带来MySQL字符串函数和数学函数的讲解✨ 目录 前言一、字符串函数二、数学函数三、总结 一、字符串函数 函数作用UPPER(列|字符串)将字符串每个字符转为大写LOWER(列|字符串)将字符串每个字符转为小写CONCAT(str1,str2,…

阿里云企业邮箱购买流程

阿里云企业邮箱购买流程,企业邮箱分为免费版、标准版、集团版和尊享版,阿里云百科分享企业邮箱版本区别,企业邮箱收费标准价格表,以及阿里企业邮箱详细购买流程: 目录 阿里云企业邮箱购买流程 一、阿里云账号注册及…

OpenAI 刚刚宣布了海量更新

OpenAI 刚刚宣布了海量更新,增加函数调用,支持更长上下文,价格更低! ​新模型上架 1、gpt-4-0613 2、gpt-4-32k-0613 3、gpt-3.5-turbo-0613 4、gpt-3.5-turbo-16k 部分模型降价 1、text-embedding-ada-002:$0.00…

DevExpress WinForms功能区组件,让业务应用创建更轻松!(上)

DevExpress WinForms的Ribbon(功能区)组件灵感来自于Microsoft Office,并针对WinForms开发人员进行了优化,它可以帮助开发者轻松地模拟当今最流行的商业生产应用程序。 PS:DevExpress WinForm拥有180组件和UI库&#…

Linux安装SQLServer数据库

Linux安装SQLServer数据库 文章目录 Linux安装SQLServer数据库SQLServer是什么SQLServer的安装安装要求安装步骤安装配置安装命令行工具 SQLServer是什么 美国 Microsoft 公司推出的一种关系型数据库系统。SQL Server 是一个可扩展的、高性能的、为分布式客户机/服务器计算所设…

[PostgreSQL-16新特性之EXPLAIN的GENERIC_PLAN选项]

随着PostgreSQL-16beta1版本的发布,我们可以发现,对于我们时常使用的explain增加了一个GENERIC_PLAN选项。这个选项是允许了包含参数占位符的语句,如select * from tab01 where id$1;等等这种语句,让其生成不依赖于这些参数值的通…

两个HC-05蓝牙之间的配对

两个HC-05蓝牙之间的配对 文章目录 两个HC-05蓝牙之间的配对1.进入AT指令模式后,先确定是否为AT模式:2.获取模块A,B的地址3.将蓝牙A配置为主模式,将蓝牙B配置为从模式:4.设置模块通信波特率,蓝牙模块A和B的配置需要相同6.验证 买了…

目标检测数据集---玻璃瓶盖工业缺陷数据集

✨✨✨✨✨✨目标检测数据集✨✨✨✨✨✨ 本专栏提供各种场景的数据集,主要聚焦:工业缺陷检测数据集、小目标数据集、遥感数据集、红外小目标数据集,该专栏的数据集会在多个专栏进行验证,在多个数据集进行验证mAP涨点明显,尤其是小目标、遮挡物精度提升明显的数据集会在该…

儿童遗留监测成为「加分项」,多种技术路线「争夺战」一触即发

儿童遗留密闭车内,温度可以在短短15分钟内达到临界水平,从而可能导致中暑和死亡,尤其是当汽车在太阳底下暴晒。 按照Euro NCAP给出的指引,与车祸相比,儿童因车辆中暑而死亡的情况较少,但却是完全可以避免的…

计算机网络开荒4-网络层

文章目录 一、网络层概述1.1 路由转发1.2 建立连接1.3 网络服务类型 二、虚拟电路与数据报网络2.1 虚电路Virtual circuits VC网络2.1.1 VC 实现2.1.1 虚电路信令协议(signaling protocols) 2.2 数据报网络2.3 对比 三、Internet网络的网络层 IP协议3.1 IP分片3.1.1 最大传输单…

中创|数据中心集聚,算力企业环绕,郑州:力争打造中部最强数据中心集群

信息时代,算力就是生产力。从田间到车间、从陆地到天空,算力的应用已经在方方面面“大显身手”。不仅是在存储领域,在具体的应用服务领域,算力也无处不在。 手机支付、网上购物、精准导航、人脸识别……这些人们熟悉的生活场景&a…

如何在telnet连接的情况下下载上传文件

1.下载tftp文件 TFTP下载-TFTP正式版下载[电脑版]-华军软件园 2.选择自己PC机所在的IP 3.telnet登录到设备 4.上传下载 //上传: 从Clinet(设备)上传文件到Server(PC机)时, 使用下面的命令 tftp –p –…

cesium学习笔记

cesium入门笔记 一、下载源码,源码介绍二、html案例体验三、cesium中的类介绍1.它们分别是:2.四大类的完整演示代码: 四、cesium的坐标与转换五、相机系统介绍六、地图、地形的加载七、建筑体添加和使用八、空间数据加载1、加载数据2、对加载…

python中变量与字符串详解!!

❄️作者介绍:奇妙的大歪❄️ 🎀个人名言:但行前路,不负韶华!🎀 🐽个人简介:云计算网络运维专业人员🐽 前言 初学者经常会遇到的困惑是,看书上或者是听课都懂…

数据库:mysqldump用法详解

目录 一、命令格式详解 1.1 语法格式 1.2 常用参数 1.3 完整参数 二、mysqldump常见的几个错误 2.1、提示命令未找到 -bash: mysqldump: 未找到命令 2.2、 the command line interface can be insecure 2.3、Gotpacket bigger than ‘max_allowed_packet‘ bytes 一、命令格式详…

GPT-4官方使用经验都在里面;Stability AI联合Clipdrop推出一键改变图像比例

🦉 AI新闻 🚀 Stability AI联合Clipdrop推出扩图工具Uncrop,一键改变图像比例 摘要:Stability AI联合Clipdrop推出的Uncrop Clipdrop是一个终极图像比例编辑器。它可以补充任何现有照片或图像,来更改任何图像的比例。…

apple pencil二代平替笔哪个好用?ipad平替笔合集

现在很多人都在用IPAD记录,或者用IPAD画图。还有就是,大部分的IPAD用户,都是以实用为主,他们觉得,要想让IPAD更加实用,就一定要有一款好用的电容笔。其实,如果只是用来做笔记,或者只…

43从零开始学Java之一文详解初学者难以理解的多态

作者:孙玉昌,昵称【一一哥】,另外【壹壹哥】也是我哦 千锋教育高级教研员、CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者 前言 我们知道,面向对象有三大特征:封装、继承和多态。现在我们已经了解了封装和继…

IBM Spectrum LSF 针对要求苛刻、任务关键型计算环境的全面工作负载管理

IBM Spectrum LSF 针对要求苛刻、任务关键型计算环境的全面工作负载管理 亮点 通过卓越的可重复性能加快求解时间; 使用可靠且可扩展的架构管理大量作业; 面向管理员和用户的直观界面提高工作效率; IBM Spectrum LSF 系列是一套完整的工作负载管理解决方案组合 &#xff0…