深度学习应用篇-元学习[15]:基于度量的元学习:SNAIL、RN、PN、MN

news2024/12/24 21:46:57

在这里插入图片描述
【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等

在这里插入图片描述
专栏详细介绍:【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等

本专栏主要方便入门同学快速掌握相关知识。后续会持续把深度学习涉及知识原理分析给大家,让大家在项目实操的同时也能知识储备,知其然、知其所以然、知何由以知其所以然。

声明:部分项目为网络经典项目方便大家快速学习,后续会不断增添实战环节(比赛、论文、现实应用等)

专栏订阅:

  • 深度学习入门到进阶专栏
  • 深度学习应用项目实战篇

深度学习应用篇-元学习[15]:基于度量的元学习:SNAIL、RN、PN、MN

1.Simple Neural Attentive Learner(SNAIL)

元学习可以被定义为一种序列到序列的问题,
在现存的方法中,元学习器的瓶颈是如何去吸收同化利用过去的经验。
注意力机制可以允许在历史中精准摘取某段具体的信息。

Simple Neural Attentive Learner (SNAIL)
组合时序卷积和 soft-attention,
前者从过去的经验整合信息,后者精确查找到某些特殊的信息。

1.1 Preliminaries

1.1.1 时序卷积和 soft-attention

时序卷积 (TCN) 是有因果前后关系的,即在下一时间步生成的值仅仅受之前的时间步影响。
TCN 可以提供更直接,高带宽的传递信息的方法,这允许它们基于一个固定大小的时序内容进行更复杂的计算。
但是,随着序列长度的增加,卷积膨胀的尺度会随之指数增加,需要的层数也会随之对数增加。
因此这种方法对于之前输入的访问更粗略,且他们的有限的能力和位置依赖并不适合元学习器,
因为元学习器应该能够利用增长数量的经验,而不是随着经验的增加,性能会被受限。

soft-attention 可以实现从超长的序列内容中获取准确的特殊信息。
它将上下文作为一种无序的关键值存储,这样就可以基于每个元素的内容进行查询。
但是,位置依赖的缺乏(因为是无序的)也是一个缺点。

TCN 和 soft-attention 可以实现功能互补:
前者提供高带宽的方法,代价是受限于上下文的大小,后者可以基于不确定的可能无限大的上下文提供精准的提取。
因此,SNAIL 的构建使用二者的组合:使用时序卷积去处理用注意力机制提取过的内容。
通过整合 TCN 和 attention,SNAIL 可以基于它过去的经验产出高带宽的处理方法且不再有经验数量的限制。
通过在多个阶段使用注意力机制,端到端训练的 SNAIL 可以学习从收集到的信息中如何摘取自己需要的信息并学习一个恰当的表示。

1.1.2 Meta-Learning

在元学习中每个任务 T i \mathcal{T}_{i} Ti 都是独立的,
其输入为 x t x_{t} xt ,输出为 a t a_{t} at ,损失函数是 L i ( x t , a t ) \mathcal{L}_{i}\left(x_{t}, a_{t}\right) Li(xt,at)
一个转移分布 P i ( x t ∣ x t − 1 , a t − 1 ) P_{i}\left(x_{t} \mid x_{t-1}, a_{t-1}\right) Pi(xtxt1,at1) ,和一个输出长度 H i H_i Hi
一个元学习器(由 θ \theta θ 参数化)建模分布:

π ( a t ∣ x 1 , … , x t ; θ ) \pi\left(a_{t} \mid x_{1}, \ldots, x_{t} ; \theta\right) π(atx1,,xt;θ)

给定一个任务的分布 T = P ( T i ) \mathcal{T}=P\left(\mathcal{T}_{i}\right) T=P(Ti)
元学习器的目标是最小化它的期待损失:

min ⁡ θ E T i ∼ T [ ∑ t = 0 H i L i ( x t , a t ) ]  where  x t ∼ P i ( x t ∣ x t − 1 , a t − 1 ) , a t ∼ π ( a t ∣ x 1 , … , x t ; θ ) \begin{aligned} &\min _{\theta} \mathbb{E}_{\mathcal{T}_{i} \sim \mathcal{T}}\left[\sum_{t=0}^{H_{i}} \mathcal{L}_{i}\left(x_{t}, a_{t}\right)\right] \\ &\text { where } x_{t} \sim P_{i}\left(x_{t} \mid x_{t-1}, a_{t-1}\right), a_{t} \sim \pi\left(a_{t} \mid x_{1}, \ldots, x_{t} ; \theta\right) \end{aligned} θminETiT[t=0HiLi(xt,at)] where xtPi(xtxt1,at1),atπ(atx1,,xt;θ)

元学习器被训练去针对从 T \mathcal{T} T 中抽样出来的任务 (或一个 mini-batches 的任务) 优化这个期望损失。
在测试阶段,元学习器在新任务分布 T ~ = P ( T ~ i ) \widetilde{\mathcal{T}}=P\left(\widetilde{\mathcal{T}}_{i}\right) T =P(T i) 上被评估。

1.2 SNAIL

1.2.1 SNAIL 基础结构

两个时序卷积层(橙色)和一个因果关系层(绿色)的组合是 SNAIL 的基础结构,
如图1所示。
在监督学习设置中,
SNAIL 接收标注样本 ( x 1 , y 1 ) , … , ( x t − 1 , y t − 1 ) \left(x_{1}, y_{1}\right), \ldots,\left(x_{t-1}, y_{t-1}\right) (x1,y1),,(xt1,yt1) 和末标注的 ( x t , − ) \left(x_{t},-\right) (xt,)
然后基于标注样本对 y t y_{t} yt 进行预测。

图1 SNAIL 基础结构示意图。

1.2.2 Modular Building Blocks

对于构建 SNAIL 使用了两个主要模块:
Dense Block 和 Attention Block。

图1 SNAIL 中的 Dense Block 和 Attention Block。(a) Dense Block 应用因果一维卷积,然后将输出连接到输入。TC Block 应用一系列膨胀率呈指数增长的 Dense Block。(b) Attention Block 执行(因果)键值查找,并将输出连接到输入。

Densen Block
用了一个简单的因果一维卷积(空洞卷积),
其中膨胀率 (dilation)为 R R R 和卷积核数量 D D D ([1] 对于所有的实验中设置卷积核的大小为2),
最后合并结果和输入。
在计算结果的时候使用了一个门激活函数。
具体算法如下:

  1. function DENSENBLOCK (inuts, dilation rate R R R, number of filers D D D):
    1. xf, xg = CausalConv (inputs, R R R, D D D), CausalConv (inputs, R R R, D D D)
    2. activations = tanh (xf) * sigmoid (xg)
    3. return concat (inputs, activations)

TC Block
由一系列 dense block 组成,这些 dense block 的膨胀率 R R R 呈指数级增长,直到它们的接受域超过所需的序列长度。具体代码实现时,对序列是需要填充的为了保持序列长度不变。具体算法如下:

  1. function TCBLOCK (inuts, sequence length T T T, number of filers D D D):
    1. for i in 1 , … , [ l o g 2 T ] 1, \ldots, \left[log_2T\right] 1,,[log2T] do
      1. inputs = DenseBlock (inputs, 2 i 2^i 2i, D D D)
    2. return inputs

Attention Block
[1] 中设计成 soft-attention 机制,
公式为:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V \mathrm{ Attention }(Q, K, V)=\mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V Attention(Q,K,V)=softmax(dk QKT)V

  1. function ATTENTIONBLOCK (inuts, key size K K K, value size V V V):
    1. keys, query = affine (inputs, K K K), affine (inputs, K K K)
    2. logits = matmul (query, transpose (keys))
    3. probs = CausallyMaskedSoftmax ( l o g i t s / K \mathrm{logits} / \sqrt{K} logits/K )
    4. values = affine (inputs, V V V)
    5. read = matmul (probs, values)
    6. return concat (inputs, read)

1.3 SNAIL 分类结果

表1 SNAIL 在 Omniglot 上的分类结果。
Method5-way 1-shot5-way 5-shot20-way 1-shot20-way 5-shot
Santoro et al. (2016)82.8 % \% %94.9 % \% %
Koch (2015)97.3 % \% %98.4 % \% %88.2 % \% %97.0 % \% %
Vinyals et al. (2016)98.1 % \% %98.9 % \% %93.8 % \% %98.5 % \% %
Finn et al. (2017)98.7 ± \pm ± 0.4 % \% %99.9 ± \pm ± 0.3 % \% %95.8 ± \pm ± 0.3 % \% %98.9 ± \pm ± 0.2 % \% %
Snell et al. (2017)97.4 % \% %99.3 % \% %96.0 % \% %98.9 % \% %
Munkhdalai & \& & Yu (2017)98.9 % \% %97.0 % \% %
SNAIL99.07 ± \pm ± 0.16 % \% %99.78 ± \pm ± 0.09 % \% %97.64 ± \pm ± 0.30 % \% %99.36 ± \pm ± 0.18 % \% %
表1 SNAIL 在 miniImageNet 上的分类结果。
Method5-way 1-shot5-way 5-shot
Vinyals et al. (2016)43.6 % \% %55.3 % \% %
Finn et al. (2017)48.7 ± \pm ± 1.84 % \% %63.1 ± \pm ± 0.92 % \% %
Ravi & \& & Larochelle (2017)43.4 ± \pm ± 0.77 % \% %60.2 ± \pm ± 0.71 % \% %
Snell et al. (2017)46.61 ± \pm ± 0.78 % \% %65.77 ± \pm ± 0.70 % \% %
Munkhdalai & \& & Yu (2017)49.21 ± \pm ± 0.96 % \% %
SNAIL55.71 ± \pm ± 0.99 % \% %68.88 ± \pm ± 0.92 % \% %
  • 参考文献

[1] A Simple Neural Attentive Meta-Learner

2.Relation Network(RN)

Relation Network (RN) 使用有监督度量学习估计样本点之间的距离,
根据新样本点和过去样本点之间的距离远近,对新样本点进行分类。

2.1 RN

RN 包括两个组成部分:嵌入模块和关系模块,且两者都是通过有监督学习得到的。
嵌入模块从输入数据中提取特征,关系模块根据特征计算任务之间的距离,
判断任务之间的相似性,找到过去可借鉴的经验进行加权平均。
RN 结构如图1所示。

图1 RN 结构。

嵌入模块记为 f φ f_{\varphi} fφ,关系模块记为 g ϕ g_{\phi} gϕ
支持集中的样本记为 x i \boldsymbol{x}_{i} xi
查询集中的样本记为 x j \boldsymbol{x}_{j} xj

  • x i \boldsymbol{x}_{i} xi x j \boldsymbol{x}_{j} xj 输入 f φ f_{\varphi} fφ
    产生特征映射 f φ ( x i ) f_{\varphi}\left(\boldsymbol{x}_{i}\right) fφ(xi)
    f φ ( x j ) f_{\varphi}\left(\boldsymbol{x}_{j}\right) fφ(xj)

  • 通过运算器 C ( . , . ) C(.,.) C(.,.) f φ ( x i ) f_{\varphi}\left(\boldsymbol{x}_{i}\right) fφ(xi)
    f φ ( x j ) f_{\varphi}\left(\boldsymbol{x}_{j}\right) fφ(xj) 结合,
    得到 C ( f φ ( x i ) , f φ ( x j ) ) C(f_{\varphi}\left(\boldsymbol{x}_{i}\right),f_{\varphi}\left(\boldsymbol{x}_{j}\right)) C(fφ(xi),fφ(xj))

  • C ( f φ ( x i ) , f φ ( x j ) ) C(f_{\varphi}\left(\boldsymbol{x}_{i}\right),f_{\varphi}\left(\boldsymbol{x}_{j}\right)) C(fφ(xi),fφ(xj)) 输入 g ϕ g_{\phi} gϕ
    得到 [ 0 , 1 ] [0, 1] [0,1] 范围内的标量,
    表示 x i \boldsymbol{x}_{i} xi x j \boldsymbol{x}_{j} xj 之间的相似性,记为关系得分 r i , j r_{i, j} ri,j
    x i \boldsymbol{x}_{i} xi x j \boldsymbol{x}_{j} xj 相似度越高, r i , j r_{i, j} ri,j 越大。

r i , j = g ϕ ( C ( f φ ( x i ) , f φ ( x j ) ) ) ,   i = 1 , 2 , . . . , C r_{i, j}=g_{\phi}\left(C\left(f_{\varphi}\left(\boldsymbol{x}_{i}\right), f_{\varphi}\left(\boldsymbol{x}_{j}\right)\right)\right), \ i = 1, 2, ..., C ri,j=gϕ(C(fφ(xi),fφ(xj))), i=1,2,...,C

2.2 RN 目标函数

ϕ , φ ← arg ⁡ min ⁡ ϕ , φ ∑ i = 1 m ∑ j = 1 n ( r i , j − 1 ( y i = = y j ) ) 2 \phi, \varphi \leftarrow \underset{\phi, \varphi}{\arg \min } \sum_{i=1}^{m} \sum_{j=1}^{n}\left(r_{i, j}-1\left(\boldsymbol{y}_{i}==\boldsymbol{y}_{j}\right)\right)^{2} ϕ,φϕ,φargmini=1mj=1n(ri,j1(yi==yj))2

其中, 1 ( y i = y j ) 1\left(\boldsymbol{y}_{i}=\boldsymbol{y}_{j}\right) 1(yi=yj) 用来判断 x i \boldsymbol{x}_{i} xi x j \boldsymbol{x}_{j} xj 是否属于同一类别。
y i = y j \boldsymbol{y}_{i}=\boldsymbol{y}_{j} yi=yj 时, 1 ( y i = = y j ) = 1 1\left(\boldsymbol{y}_{i}==\boldsymbol{y}_{j}\right)=1 1(yi==yj)=1
y i ≠ y j \boldsymbol{y}_{i} \neq \boldsymbol{y}_{j} yi=yj 时, 1 ( y i = = y j ) = 0 1\left(\boldsymbol{y}_{i}==\boldsymbol{y}_{j}\right)=0 1(yi==yj)=0

2.3 RN 网络结构

嵌入模块和关系模块的选取有很多种,包括卷积网络、残差网络等。

图2给出了 [1] 中使用的 RN 模型结构。

图2 RN 模型结构。

2.3.1 嵌入模块结构

  • 每个卷积块分别包含 64 个 3 × \times × 3 滤波器进行卷积,一个归一化层、一个 ReLU 非线性层。

  • 总共有四个卷积块,前两个卷积块包含 2 × \times × 2 的最大池化层,后边两个卷积块没有池化层。

3.2 关系模块结构

  • 有两个卷积块,每个卷积模块中都包含 2 × \times × 2 的最大池化层。

  • 两个全连接层,第一个全连接层是 ReLU 非线性变换,最后的全连接层使用 Sigmoid 非线性变换输出 r i , j r_{i,j} ri,j

2.4 RN 分类结果

表1 RN 在 Omniglot 上的分类结果。
ModelFine Tune5-way 1-shot5-way 5-shot20-way 1-shot20-way 5-shot
MANNN82.8 % \% %94.9 % \% %
CONVOLUTIONAL SIAMESE NETSN96.7 % \% %98.4 % \% %88.0 % \% %96.5 % \% %
CONVOLUTIONAL SIAMESE NETSY97.3 % \% %98.4 % \% %88.1 % \% %97.0 % \% %
MATCHING NETSN98.1 % \% %98.9 % \% %93.8 % \% %98.5 % \% %
MATCHING NETSY97.9 % \% %98.7 % \% %93.5 % \% %98.7 % \% %
SIAMESE NETS WITH MEMORYN98.4 % \% %99.6 % \% %95.0 % \% %98.6 % \% %
NEURAL STATISTICIANN98.1 % \% %99.5 % \% %93.2 % \% %98.1 % \% %
META NETSN99.0 % \% %97.0 % \% %
PROTOTYPICAL NETSN98.8 % \% %99.7 % \% %96.0 % \% %98.9 % \% %
MAMLY98.7 ± \pm ± 0.4 % \% %99.9 ± \pm ± 0.1 % \% %95.8 ± \pm ± 0.3 % \% %98.9 ± \pm ± 0.2 % \% %
RELATION NETN99.6 ± \pm ± 0.2 % \% %99.8 ± \pm ± 0.1 % \% %97.6 ± \pm ± 0.2 % \% %99.1 ± \pm ± 0.1 % \% %
表1 RN 在 miniImageNet 上的分类结果。
ModelFT5-way 1-shot5-way 5-shot
MATCHING NETSN43.56 ± \pm ± 0.84 % \% %55.31 ± \pm ± 0.73 % \% %
META NETSN49.21 ± \pm ± 0.96 % \% %
META-LEARN LSTMN43.44 ± \pm ± 0.77 % \% %60.60 ± \pm ± 0.71 % \% %
MAMLY48.70 ± \pm ± 1.84 % \% %63.11 ± \pm ± 0.92 % \% %
PROTOTYPICAL NETSN49.42 ± \pm ± 0.78 % \% %68.20 ± \pm ± 0.66 % \% %
RELATION NETN50.44 ± \pm ± 0.82 % \% %65.32 ± \pm ± 0.70 % \% %
  • 参考文献

[1] Learning to Compare: Relation Network for Few-Shot Learning

3.Prototypical Network(PN)

Prototypical Network (PN) 利用支持集中每个类别提供的少量样本,
计算它们的嵌入中心,作为每一类样本的原型 (Prototype),
接着基于这些原型学习一个度量空间,
使得新的样本通过计算自身嵌入与这些原型的距离实现最终的分类。

3.1 PN

在 few-shot 分类任务中,
假设有 N N N 个标记的样本 S = ( x 1 , y 1 ) , … , ( x N , y N ) S=\left(x_{1}, y_{1}\right), \ldots,\left(x_{N}, y_{N}\right) S=(x1,y1),,(xN,yN)
其中, x i ∈ x_{i} \in xi R D \mathbb{R}^{D} RD D D D 维的样本特征向量,
y ∈ 1 , … , K y \in 1, \ldots, K y1,,K 是相应的标签。
S K S_{K} SK 表示第 k k k 类样本的集合。

PN 计算每个类的 M M M 维原型向量 c k ∈ R M c_{k} \in \mathbb{R}^{M} ckRM
计算的函数为 f ϕ : R D → R M f_{\phi}: \mathbb{R}^{D} \rightarrow \mathbb{R}^{M} fϕ:RDRM
其中 ϕ \phi ϕ 为可学习参数。
原型向量 c k c_{k} ck 即为嵌入空间中该类的所有 支持集样本点的均值向量

c k = 1 ∣ S K ∣ ∑ ( x i , y i ) ∈ S K f ϕ ( x i ) c_{k}=\frac{1}{\left|S_{K}\right|} \sum_{\left(x_{i}, y_{i}\right) \in S_{K}} f_{\phi}\left(x_{i}\right) ck=SK1(xi,yi)SKfϕ(xi)

给定一个距离函数 d : R M × R M → [ 0 , + ∞ ) d: \mathbb{R}^{M} \times \mathbb{R}^{M} \rightarrow[0,+\infty) d:RM×RM[0,+)
不包含任何可训练的参数,
PN 通过在嵌入空间中对距离进行 softmax 计算,
得到一个针对 x x x 的样本点的概率分布

p ϕ ( y = k ∣ x ) = exp ⁡ ( − d ( f ϕ ( x ) , c k ) ) ∑ k ′ exp ⁡ ( − d ( f ϕ ( x ) , c k ′ ) ) p_{\phi}(y=k \mid x)=\frac{\exp \left(-d\left(f_{\phi}(x), c_{k}\right)\right)}{\sum_{k^{\prime}} \exp \left(-d\left(f_{\phi}(x), c_{k^{\prime}}\right)\right)} pϕ(y=kx)=kexp(d(fϕ(x),ck))exp(d(fϕ(x),ck))

新样本点的特征离类别中心点越近,
新样本点属于这个类别的概率越高;
新样本点的特征离类别中心点越远,
新样本点属于这个类别的概率越低。

通过在 SGD 中最小化第 k k k 类的负对数似然函数 J ( ϕ ) J(\phi) J(ϕ) 来推进学习

J ( ϕ ) = argmin ⁡ ϕ ( ∑ k = 1 K − log ⁡ ( p ϕ ( y = k ∣ x k ) ) ) J(\phi)= \underset{\phi}{\operatorname{argmin}}\left(\sum_{k=1}^{K}-\log \left(p_{\phi}\left(\boldsymbol{y}=k \mid \boldsymbol{x}_{k}\right)\right)\right) J(ϕ)=ϕargmin(k=1Klog(pϕ(y=kxk)))

PN 示意图如图1所示。

图1 PN 示意图。

3.2 PN 算法流程

Input: Training set D = { ( x 1 , y 1 ) , … , ( x N , y N ) } \mathcal{D}=\left\{\left(\mathbf{x}_{1}, y_{1}\right), \ldots,\left(\mathbf{x}_{N}, y_{N}\right)\right\} D={(x1,y1),,(xN,yN)}, where each y i ∈ { 1 , … , K } y_{i} \in\{1, \ldots, K\} yi{1,,K}. D k \mathcal{D}_{k} Dk denotes the subset of D \mathcal{D} D containing all elements ( x i , y i ) \left(\mathbf{x}_{i}, y_{i}\right) (xi,yi) such that y i = k y_{i}=k yi=k.

Output: The loss J J J for a randomly generated training episode.

  1. select class indices for episode: V ←  RANDOMSAMPLE  ( { 1 , … , K } , N C ) V \leftarrow \text { RANDOMSAMPLE }\left(\{1, \ldots, K\}, N_{C}\right) V RANDOMSAMPLE ({1,,K},NC)
  2. for k k k in { 1 , … , N C } \left\{1, \ldots, N_{C}\right\} {1,,NC} do
    1. select support examples: S k ←  RANDOMSAMPLE  ( D V k , N S ) S_{k} \leftarrow \text { RANDOMSAMPLE }\left(\mathcal{D}_{V_{k}}, N_{S}\right) Sk RANDOMSAMPLE (DVk,NS)
    2. select query examples: Q k ←  RANDOMSAMPLE  ( D V k \ S k , N Q ) Q_{k} \leftarrow \text { RANDOMSAMPLE }\left(\mathcal{D}_{V_{k}} \backslash S_{k}, N_{Q}\right) Qk RANDOMSAMPLE (DVk\Sk,NQ)
    3. compute prototype from support examples: c k ← 1 N C ∑ ( x i , y i ) ∈ S k f ϕ ( x i ) c_k \leftarrow \frac{1}{N_{C}} \sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in S_{k}} f_{\phi}\left(\mathbf{x}_{i}\right) ckNC1(xi,yi)Skfϕ(xi)
  3. end for
  4. J ← 0 J \leftarrow 0 J0
  5. for k k k in { 1 , … , N C } \left\{1, \ldots, N_{C}\right\} {1,,NC} do
    1. for x , y x, y x,y in Q k Q_{k} Qk do
    2. update loss J ← J + 1 N C N Q [ d ( f ϕ ( x ) , c k ) ) + log ⁡ ∑ k ′ exp ⁡ ( − d ( f ϕ ( x ) , c k ′ ) ) ] \left.J \leftarrow J+\frac{1}{N_{C} N_{Q}}\left[d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k}\right)\right)+\log \sum_{k^{\prime}} \exp \left(-d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k^{\prime}}\right)\right)\right] JJ+NCNQ1[d(fϕ(x),ck))+logkexp(d(fϕ(x),ck))]
  6. end for
  7. end for

其中,

  • N N N 是训练集中的样本个数;
  • K K K 是训练集中的类个数;
  • N C ≤ K N_{C} \leq K NCK 是每个 episode 选出的类个数;
  • N S N_{S} NS 是每类中 support set 的样本个数;
  • N Q N_{Q} NQ 是每类中 query set 的样本个数;
  • R A N D O M S A M P L E ( S , N ) \mathrm{RANDOMSAMPLE}(S, N) RANDOMSAMPLE(S,N) 表示从集合 S \mathrm{S} S 中随机选出 N \mathrm{N} N 个元素。

3.3 PN 分类结果

表1 PN 在 Omniglot 上的分类结果。
ModelDist.Fine Tune5-way 1-shot5-way 5-shot20-way 1-shot20-way 5-shot
MATCHING NETWORKSCosineN98.1 % \% %98.9 % \% %93.8 % \% %98.5 % \% %
MATCHING NETWORKSCosineY97.9 % \% %98.7 % \% %93.5 % \% %98.7 % \% %
NEURAL STATISTICIAN-N98.1 % \% %99.5 % \% %93.2 % \% %98.1 % \% %
MAML-N98.7 % \% %99.9 % \% %95.8 % \% %98.9 % \% %
PROTOTYPICAL NETWORKSEuclid.N98.8 % \% %99.7 % \% %96.0 % \% %98.9 % \% %
表1 PN 在 miniImageNet 上的分类结果。
ModelDist.Fine Tune5-way 1-shot5-way 5-shot
BASELINE NEAREST NEIGHBORSCosineN28.86 ± \pm ± 0.54 % \% %49.79 ± \pm ± 0.79 % \% %
MATCHING NETWORKSCosineN43.40 ± \pm ± 0.78 % \% %51.09 ± \pm ± 0.71 % \% %
MATCHING NETWORKS (FCE)CosineN43.56 ± \pm ± 0.84 % \% %55.31 ± \pm ± 0.73 % \% %
META-LEARNER LSTM-N43.44 ± \pm ± 0.77 % \% %60.60 ± \pm ± 0.71 % \% %
MAML-N48.70 ± \pm ± 1.84 % \% %63.15 ± \pm ± 0.91 % \% %
PROTOTYPICAL NETWORKSEuclid.N49.42 ± \pm ± 0.78 % \% %68.20 ± \pm ± 0.66 % \% %
  • 参考文献

[1] Prototypical Networks for Few-shot Learning

4.Matching Network(MN)

Matching Network (MN)
结合了度量学习 (Metric Learning) 与记忆增强神经网络 (Memory Augment Neural Networks),
并利用注意力机制与记忆机制加速学习,同时提出了 set-to-set 框架,
使得 MN 能够为新类产生合理的测试标签,且不用网络做任何改变。

4.1 MN

将支持集 S = { ( x i , y i ) } i = 1 k S=\left\{\left(x_{i}, y_{i}\right)\right\}_{i=1}^{k} S={(xi,yi)}i=1k
映射到一个分类器 c S ( x ^ ) c_{S}(\hat{x}) cS(x^)
给定一个测试样本 x ^ \hat{x} x^ c S ( x ^ ) c_{S}(\hat{x}) cS(x^) 定义一个关于输出 y ^ \hat{y} y^ 的概率分布,即

S → c S ( x ^ ) : = P ( y ^ ∣ x ^ , S ) S \rightarrow c_{S}\left(\hat{x}\right):= P\left(\hat{y} \mid \hat{x}, S\right) ScS(x^):=P(y^x^,S)

其中, P P P 被网络参数化。
因此,当给定一个新的支持集 S ′ S^{\prime} S 进行小样本学习时,
只需使用 P P P 定义的网络来预测每个测试示例 x ^ \hat{x} x^ 的适当标签分布
P ( y ^ ∣ x ^ , S ′ ) P\left(\hat{y} \mid \hat{x}, S^{\prime}\right) P(y^x^,S) 即可。

4.1.1 注意力机制

模型以最简单的形式计算 y ^ \hat{y} y^ 上的概率:

P ( y ^ ∣ x ^ , S ) = ∑ i = 1 k a ( x ^ , x i ) y i P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i} P(y^x^,S)=i=1ka(x^,xi)yi

上式本质是将一个输入的新类描述为支持集中所有类的一个线性组合,
结合了核密度估计KDE( a a a 可以看做是一种核密度估计)和 KNN 。
其中, k k k 表示支持集中样本类别数,
a ( x ^ , x i ) a\left(\hat{x}, x_{i}\right) a(x^,xi) 是注意力机制,
类似 attention 模型中的核函数,
用来度量 x ^ \hat{x} x^ 和训练样本 x i x_{i} xi 的匹配度。

a a a 的计算基于新样本数据与支持集中的样本数据的嵌入表示的余弦相似度以及softmax函数:

a ( x ^ , x i ) = e c ( f ( x ^ ) , g ( x i ) ) ∑ j = 1 k e c ( f ( x ^ ) , g ( x j ) ) a\left(\hat{x}, x_{i}\right)=\frac{e^{c\left(f(\hat{x}), g\left(x_{i}\right)\right)}}{\sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_{j}\right)\right)}} a(x^,xi)=j=1kec(f(x^),g(xj))ec(f(x^),g(xi))

其中, c ( ⋅ ) c(\cdot) c() 表示余弦相似度,
f f f g g g 表示施加在测试样本与训练样本上的嵌入函数 (Embedding Function)。

如果注意力机制是 X × X X \times X X×X 上的核,
则上式类似于核密度估计器。
如果选取合适的距离度量以及适当的常数,
从而使得从 x i x_{i} xi x ^ \hat{x} x^ 的注意力机制为 0 ,
则上式等价于 KNN 。

图1是 MN 的网络结构示意图。

图1 MN 示意图。

4.1.2 Full Context Embeddings

为了增强样本嵌入的匹配度,
[1] 提出了 Full Context Embeeding (FCE) 方法:
支持集中每个样本的嵌入应该是相互独立的,
而新样本的嵌入应该受支持集样本数据分布的调控,
其嵌入过程需要放在整个支持集环境下进行,
因此 [1] 采用带有注意力的 LSTM 网络对新样本进行嵌入。

在对余弦注意力定义时,
每个已知标签的输入 x i x_i xi 通过 CNN 后的 embedding ,
因此 g ( x i ) g(x_i) g(xi) 是独立的,前后没有关系,
然后与 f ( x ^ ) f\left(\hat{x}\right) f(x^) 进行逐个对比,
并没有考虑到输入任务 S S S 改变 embedding x ^ \hat{x} x^ 的方式,
f ( ⋅ ) f(\cdot) f() 应该是受 g ( S ) g(S) g(S) 影响的。
为了实现这个功能,[1] 采用了双向 LSTM 。

在通过嵌入函数 f f f g g g 处理后,
输出再次经过循环神经网络进一步加强 context 和个体之间的关系。

f ( x ^ , S ) = a t t L S T M ( f ′ ( x ^ ) , g ( S ) , K ) f\left(\hat{x},S\right)=\mathrm{attLSTM}\left(f'\left(\hat{x}\right),g(S),K\right) f(x^,S)=attLSTM(f(x^),g(S),K)

其中, S S S 是相关的上下文, K K K 为网络的 timesteps 。

因此,经过 k k k 步后的状态为:

h ^ k , c k = LSTM ⁡ ( f ′ ( x ^ ) , [ h k − 1 , r k − 1 ] , c k − 1 ) h k = h ^ k + f ′ ( x ^ ) r k − 1 = ∑ i = 1 ∣ S ∣ a ( h k − 1 , g ( x i ) ) g ( x i ) a ( h k − 1 , g ( x i ) ) = e h k − 1 T g ( x i ) / ∑ j = 1 ∣ S ∣ e h k − 1 T g ( x j ) \begin{aligned} & \hat{h}_{k}, c_{k} =\operatorname{LSTM}\left(f^{\prime}(\hat{x}),\left[h_{k-1}, r_{k-1}\right], c_{k-1}\right) \\ & h_{k} =\hat{h}_{k}+f^{\prime}(\hat{x}) \\ & r_{k-1} =\sum_{i=1}^{|S|} a\left(h_{k-1}, g\left(x_{i}\right)\right) g\left(x_{i}\right) \\ & a\left(h_{k-1}, g\left(x_{i}\right)\right) =e^{h_{k-1}^{T} g\left(x_{i}\right)} / \sum_{j=1}^{|S|} e^{h_{k-1}^{T} g\left(x_{j}\right)} \end{aligned} h^k,ck=LSTM(f(x^),[hk1,rk1],ck1)hk=h^k+f(x^)rk1=i=1Sa(hk1,g(xi))g(xi)a(hk1,g(xi))=ehk1Tg(xi)/j=1Sehk1Tg(xj)

4.2 网络结构

特征提取器可采用常见的 VGG 或 Inception 网络,
[1] 设计了一种简单的四级网络结构用于图像分类任务的特征提取,
每级网络由一个 64 通道的 3 × \times × 3 卷积层,一个批规范化层,
一个 ReLU 激活层和一个 2 × \times × 2 的最大池化层构成。
然后将最后一层输出的特征输入到 LSTM 网络中得到最终的特征映射
f ( x ^ , S ) f\left(\hat{x},S\right) f(x^,S) g ( x i , S ) g\left({x_i},S\right) g(xi,S)

4.3 损失函数

θ = arg ⁡ max ⁡ θ E L ∼ T [ E S ∼ L , B ∼ L [ ∑ ( x , y ) ∈ B log ⁡ P θ ( y ∣ x , S ) ] ] \theta=\arg \max _{\theta} E_{L \sim T}\left[E_{S \sim L, B \sim L}\left[\sum_{(x, y) \in B} \log P_{\theta}(y \mid x, S)\right]\right] θ=argθmaxELT ESL,BL (x,y)BlogPθ(yx,S)

4.4 MN 算法流程

  • 将任务 S S S 中所有图片 x i x_i xi (假设有 K K K 个)和目标图片 x ^ \hat{x} x^(假设有 1 个)
    全部通过 CNN 网络,获得它们的浅层变量表示。

  • 将( K + 1 K+1 K+1 个)浅层变量全部输入到 BiLSTM 中,获得 K + 1 K+1 K+1 个输出,
    然后使用余弦距离判断前 K K K 个输出中每个输出与最后一个输出之间的相似度。

  • 根据计算出来的相似度,按照任务 S S S 中的标签信息 y 1 , y 2 , … , y K y_1, y_2, \ldots, y_K y1,y2,,yK
    求解目标图片 x ^ \hat{x} x^ 的类别标签 y ^ \hat{y} y^

4.5 MN 分类结果

表1 MN 在 Omniglot 上的分类结果。
ModelMatching FnFine Tune5-way 1-shot5-way 5-shot20-way 1-shot20-way 5-shot
PIXELSCosineN41.7 % \% %63.2 % \% %26.7 % \% %42.6 % \% %
BASELINE CLASSIFIERCosineN80.0 % \% %95.0 % \% %69.5 % \% %89.1 % \% %
BASELINE CLASSIFIERCosineY82.3 % \% %98.4 % \% %70.6 % \% %92.0 % \% %
BASELINE CLASSIFIERSoftmaxY86.0 % \% %97.6 % \% %72.9 % \% %92.3 % \% %
MANN (NO CNOV)CosineN82.8 % \% %94.9 % \% %
CONVOLUTIONAL SIAMESE NETCosineY96.7 % \% %98.4 % \% %88.0 % \% %96.5 % \% %
CONVOLUTIONAL SIAMESE NETCosineY97.3 % \% %98.4 % \% %88.1 % \% %97.0 % \% %
MATCHING NETSCosineN98.1 % \% %98.9 % \% %93.8 % \% %98.5 % \% %
MATCHING NETSCosineY97.9 % \% %98.7 % \% %93.5 % \% %98.7 % \% %
表1 MN 在 miniImageNet 上的分类结果。
ModelMatching FnFine Tune5-way 1-shot5-way 5-shot
PIXELSCosineN23.0 % \% %26.6 % \% %
BASELINE CLASSIFIERCosineN36.6 % \% %46.0 % \% %
BASELINE CLASSIFIERCosineY36.2 % \% %52.2 % \% %
BASELINE CLASSIFIERCosineY38.4 % \% %51.2 % \% %
MATCHING NETSCosineN41.2 % \% %56.2 % \% %
MATCHING NETSCosineY42.4 % \% %58.0 % \% %
MATCHING NETSCosine (FCE)N44.2 % \% %57.0 % \% %
MATCHING NETSCosine (FCE)Y46.6 % \% %60.0 % \% %

4.6 创新点

  • 采用匹配的形式实现小样本分类任务,
    引入最近邻算法的思想解决了深度学习算法在小样本的条件下无法充分优化参数而导致的过拟合问题,
    且利用带有注意力机制和记忆模块的网络解决了普通最近邻算法过度依赖度量函数的问题,
    将样本的特征信息映射到更高维度更抽象的特征空间中。

  • one-shot learning 的训练策略,一个训练任务中包含支持集和 Batch 样本。

4.7 算法评价

  • MN 受到非参量化算法的限制,
    随着支持集 S S S 的增长,每次迭代的计算量也会随之快速增长,导致计算速度降低。

  • 在测试时必须提供包含目标样本类别在内的支持集,
    否则它只能从支持集所包含的类别中选择最为接近的一个输出其类别,而不能输出正确的类别。

  • 参考文献

[1] Matching Networks for One Shot Learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/648771.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

免费的Outlook邮箱备份的方法!

关于Outlook邮箱备份 Outlook是Microsoft Office微软办公软件套装的组件之一,利用一套Microsoft应用程序和服务,与Office工具共享与协作,和各种设备时刻保持联系。用户可以通过登录邮箱首页申请Outlook为域名后缀的邮箱。 长时间使用Outl…

使用海外社交媒体的九种方法(和五个专业提示!)

社交媒体是推广您的产品、吸引新客户以及围绕您的业务建立社区的有效方式。您可以使用社交媒体做的事情几乎没有限制,但无限的可能性会带来让自己过于分散的风险。好消息是,通过遵循一些最佳实践,您的在线商店可以取得成功。 目录 使用…

3ds MAX 灯光

简单展示一下3dsMAX中灯光的不同效果: 先简单画一个房间出来,展示效果就不添加材质了 在选项卡中添加灯光: 首先是直接渲染的结果,没有添加光照,可以看出也没有阴影等 采用【目标聚光灯】,这有点类似在摄…

【MySQL新手入门系列二】:手把手教你入门MySQL - 数据库及数据表操作

如果您是一位刚刚开始学习MySQL的新手,本文将为您提供一些实用的入门知识和技巧,帮助您快速上手。 【MySQL新手入门系列一】:手把手教你入门MySQL 前面我们已经大致讲了一下mysql的安装等介绍,本篇文章将以windows为例&#xff0c…

聚类分析(文末送书)

目录 聚类分析是什么 一、 定义和数据类型 聚类应用 聚类分析方法的性能指标 聚类分析中常用数据结构有数据矩阵和相异度矩阵 聚类分析方法分类 二、K-means聚类算法 划分聚类方法对数据集进行聚类时包含三个要点 K-Means算法流程: K-means聚类算法的特点 三、k-med…

【JVM系列】垃圾收集器介绍

文章目录 垃圾收集器Serial收集器ParNew收集器Parallel收集器CMS收集器G1收集器 常用的收集器组合 垃圾收集器 Serial收集器 串行收集器是最古老,最稳定以及效率高的收集器,可能会产生较长的停顿,只使用一个线程去回收。新生代、老年代使用…

不逆向解决5s盾之cloudscraper

一、背景 经常写爬虫的同学,肯定知道 Cloud Flare 的五秒盾。当你没有使用正常的浏览器访问网站的时候,它会返回如下这段文字: Checking your browser before accessing xxx. This process is automatic. Your browser will redirect to your…

020+limou+C语言内存管理

0.在Linux下验证C语言地址空间排布 这里是limou3434的博文系列。接下来,我会带您了解在C语言程序视角下的内存分布,会涉及到一点操作系统的知识,但是不多,您无需担忧。 注意:只能在Linux下验证,因为Windo…

外观模式(十三)

每天都是全新的一天,感谢今日努力的自己。 上一章简单介绍了组合模式(十二), 如果没有看过, 请观看上一章 一. 外观模式 引用 菜鸟教程里面的外观模式介绍: https://www.runoob.com/design-pattern/facade-pattern.html 外观模式(Facade Pattern&…

商品编号篡改测试-业务安全测试实操(7)

商品编号篡改测试,邮箱和用户篡改测试 手机号码篡改测试-业务安全测试实操(6)_luozhonghua2000的博客-CSDN博客 邮箱和用户篡改测试 测试原理和方法 在发送邮件或站内消息时,篡改其中的发件人参数,导致攻击者可以伪造发信人进行钓鱼攻击等操作,这也是一种平行权限绕过漏洞…

2023年CPSM-3中级项目管理专业人员认证招生简章

CPSM-3中级项目管理专业人员认证,是中国标准化协会(全国项目管理标准化技术委员会秘书处)联合中国国际人才交流基金会,面向社会开展项目管理专业人员能力的等级证书。旨在构建多层次从业人员培养培训体系,建立健全人才…

软体机器人对工业应用的影响

原创 | 文 BFT机器人 软机器人模仿生物体的运动和动作,使它们具有高度的多功能性和迷人性。 软机器人领域正在迅速发展。它旨在为各种行业创造灵活的设备,包括医疗保健、太空探索、食品生产、地理、物流、康复、国防和家庭应用。 软机器人的独特之处在于…

windows下使用cmake编译c++

好久没有更新博客了 最近在做c相关的,编译起来确实很痛苦。 所以心血来潮,继续更新一下 主要还是一些跨平台的库,比如zlib、libpng、opencv、ffmpeg 编译工具使用mingw作为主要编译环境支持,使用msys进行编译。 一、下载mingw…

Python--输入和输出

Python--输入和输出 <font colorblue>一、输入&#xff1a;input()函数<font colorblue>二、输出&#xff1a;print()函数<font colorblue>1.print函数说明<font colorblue>2.格式化输出<font colorblue>方法一&#xff1a;使用占位符&#xff0…

MySQL哈希索引

介绍 建表时存储引擎选择 MEMORY&#xff0c;则创建索引就是哈希索引&#xff1a;如create index nameidx on student(name);哈希索引底层数据结构就是链式哈希表&#xff0c;链式就是指冲突时用链表法解决哈希表中的元素没有任何顺序可言&#xff0c;则只能进行等值比较。如果…

PinYin4j库的使用

一、PinYin4j库简介 1、PinYin4j简介 Pinyin4j 是一个流行的 Java 库&#xff0c;支持汉字和大多数流行的拼音系统之间的转换&#xff08;汉语拼音&#xff0c;罗马拼音等&#xff09;。可自定义拼音输出格式&#xff0c;功能强大。 官网地址&#xff1a;http://pinyin4j.sou…

道岔表示故障电路如何进行检查

一、分线盘区分提速道岔表示电路故障的方法 定位时可以通过测量X1、X2 (或者反位时X1、X3)端子间的交直流电压和BD1-7的3号端子上的电流&#xff0c;来判断表示电路的故障和范围。表示电路正常工作时&#xff0c;在分线盘端子X1、X2之间可以测到电压交流60V左右&#xff0c;直…

小白也能学会的电脑C盘空间释放技巧大集合

引言 电脑C盘快装满了怎么办&#xff1f;这是很多人使用电脑时面临的困扰。电脑的运行速度会变得很慢&#xff0c;甚至出现蓝屏等问题。那么&#xff0c;如何解决电脑C盘快装满的问题呢&#xff1f;接下来&#xff0c;本文将详细介绍解决电脑C盘快装满的几种方法。 先记录一下…

浪涌保护器的工作原理

浪涌保护器&#xff08;SPD&#xff09;旨在通过限制瞬态电压和转移浪涌电流来保护电气系统和设备免受浪涌事件的影响。 浪涌可能来自外部&#xff0c;最强烈的是雷击&#xff0c;也可能来自内部的电气负载切换。这些内部浪涌的来源占所有瞬变的65%&#xff0c;可能包括负载打…

CSS | CSS中height:100vh和height:100%的区别

目录 1、对于设置height:100%;有下面几种情况 2、对于设置height:100vh时有如下的情况 首先&#xff0c;我们得知道1vh它表示的是当前屏幕可见高度的1/100&#xff0c;而1%它表示的是父元素长或者宽的1% 1、对于设置height:100%;有下面几种情况 &#xff08;1&#xff09;当…