当我们面临具有挑战性的图像分类任务时,我们希望通过分解part来解释推理。每一类别的更多原型证据有助于做出最终分类决策。作者提出一种深度网络架构:Prototypical Part网络即ProtoPNet。网络通过寻找原型part来解释图像,并基于原型part进行分类。网络仅使用图像级标签进行训练,并在推理时表现出与专家一样的水平(比如CUB上可以比肩鸟类学家)。并且当ProtoPNet组合到更大的网络中时,可以实现与性能最好的深度模型相当的精度。
来自:This Looks Like That: Deep Learning for Interpretable Image Recognition
目录
- 背景概述
- 案例一:鸟类识别
- 网络架构
- 训练
背景概述
如何描述为什么图1中的图像看起来像一只clay colored sparrow?也许这只鸟的头和翅膀看起来像典型的clay colored sparrow。当我们描述如何对图像进行分类时,我们可能会关注图像的part,并将其与给定类别的图像的prototypical part进行比较。这种推理方法通常用于困难的识别任务:例如医学图像分类,细粒度自然图像分类。因此,该工作的目标是定义一种图像处理中的可解释性,希望网络与人类在分类任务中描述自己思维的方式一致。
作者引入了一种网络架构ProtoPNet,它适应了可解释性的定义。给定如图1所示的鸟类图像,模型能够识别图像的几个part,并认为图像的这一部分看起来像某个类别的原型部分,并基于图像部分和学习到的原型之间的相似性得分的加权组合进行预测。通过这种方式,模型是可解释的,因为它在进行预测时有一个透明的推理过程。
- 图1:clay colored sparrow的图像,以及它的part可以成为分类clay colored sparrow的原型。
案例一:鸟类识别
作者在鸟类物种识别的背景下介绍了ProtoPNet的架构和训练程序,并详细介绍了网络如何对新的鸟类图像进行分类并解释其预测。实验在CUB200-2011数据集上对200种鸟类进行了训练和评估。
网络架构
图2概述了ProtoPNet的体系结构。网络包括标准卷积神经网络 f f f,参数为 w c o n v w_{conv} wconv,和prototype层 g p g_{\textbf{p}} gp,然后是全连接层 h h h,参数为 w h w_{h} wh。对于 f f f,可以使用VGG-16、VGG-19、ResNet-34、ResNet-152、DenseNet-121或DenseNet-161等(最好是在ImageNet上训练)的卷积层,然后接两个额外的 1 × 1 1\times 1 1×1卷积层。使用ReLU作为所有卷积层的激活函数,除了最后一个卷积层使用sigmoid激活函数。
给定输入图像 x x x(如图2中的clay colored sparrow),卷积层提取有用的特征 f ( x ) f(x) f(x)。假设 f ( x ) f(x) f(x)的形状为 H × W × D H\times W\times D H×W×D。对于输入图像为 ( 224 , 224 , 3 ) (224,224,3) (224,224,3),有 H = W = 7 H=W=7 H=W=7, D D D可以是128,256,512。网络需要学习 m m m个原型 P = { p j } j = 1 m \textbf{P}=\left\{\textbf{p}_{j}\right\}_{j=1}^{m} P={pj}j=1m,原型形状为 H 1 × W 1 × D H_{1}\times W_{1}\times D H1×W1×D。在实验中,使用 H 1 = W 1 = 1 H_1=W_1=1 H1=W1=1。由于每个原型的深度与卷积输出的深度相同,但每个原型的高度和宽度小于整个卷积输出的高度和宽度,因此每个原型用于表示卷积输出的patch中的一些原型模式,这些patch代表原始像素空间中的图像区域。因此,每个原型 p j \textbf{p}_j pj可以被理解为某些鸟类图像的原型part的表示。作为示意图,图2中的第一个原型 p 1 \textbf{p}_1 p1对应于clay colored sparrow的头部,第二个原型 p 2 \textbf{p}_2 p2对应于Brewer’s sparrow的头部。
- 图2:ProtoPNet架构。
令 z = f ( x ) z=f(x) z=f(x),原型层 g p g_{\textbf{p}} gp中的第 j j j个原型单元 g p j g_{\textbf{p}_j} gpj计算了第 j j j个原型 p j \textbf{p}_j pj和具有与 p j \textbf{p}_j pj相同形状的 z z z的所有patch之间的平方 L 2 L^{2} L2距离,并将距离转换为相似性得分。结果是相似性得分的激活图,其值指示原型part在图像中存在的强度。该激活图保留了卷积输出的空间关系,并且可以被上采样到输入图像的大小,以产生热图,该热图识别输入图像的哪个部分与该原型最相似。
然后,使用全局最大池化将每个原型单元 g p j g_{\textbf{p}_j} gpj产生的相似性得分的激活图缩减为单个相似性得分,其可以理解为原型part在输入图像中存在的强度:
- 在图2中,第一个原型是clay colored sparrow的头部原型,和clay colored sparrow输入图像中最活跃的(右上)patch之间的相似性得分为3:954,第二个原型是Brewer’s sparrow头部原型,和输入图像中最大活跃的patch之间的相似性得分为1.447。这表明,在输入图像中,clay colored sparrow的头部比Brewer’s sparrow的头部存在性更强。
在数学上, g p j g_{\textbf{p}_{j}} gpj计算: g p j ( z ) = m a x z ~ ∈ p a t c h e s ( z ) l o g ( ∣ ∣ z ~ − p j ∣ ∣ 2 2 + 1 ∣ ∣ z ~ − p j ∣ ∣ 2 2 + ϵ ) g_{\textbf{p}_{j}}(z)=max_{\widetilde{z}\in patches(z)}log(\frac{||\widetilde{z}-\textbf{p}_{j}||_{2}^{2}+1}{||\widetilde{z}-\textbf{p}_{j}||_{2}^{2}+\epsilon}) gpj(z)=maxz ∈patches(z)log(∣∣z −pj∣∣22+ϵ∣∣z −pj∣∣22+1)因此,如果第 j j j个原型单元 g p j g_{\textbf{p}_j} gpj的输出很大,那么在卷积输出中存在一个patch,该patch非常接近潜在空间中的第 j j j个原型,这反过来意味着在输入图像中有一个patch具有与第 j j j个原型所表示的相似的语义。
在ProtoPNet中,为每个类 k ∈ { 1 , . . . , K } k\in\left\{1,...,K\right\} k∈{1,...,K}分配预先确定的原型数量 m k m_{k} mk,实验中每类都为10。类别 k k k对应的原型集为 P k \textbf{P}_{k} Pk。
最后,将 m m m个相似性得分乘以全连接层 h h h中的权重矩阵,以产生输出logits,使用softmax对其进行归一化,以产生属于各种类别的预测概率。
训练
ProtoPNet的训练分为:
- 最后一层之前的层的随机梯度下降;
- 原型投影;
- 最后一层的凸优化;
对于第一部分:第一个训练阶段,目标是学习一个有意义的潜在空间,其中用于分类图像的最重要的patch被聚类在图像的真实类的语义相似的原型周围,并且以来自不同类的原型为中心的簇被很好地分离。为了实现该目标,首先使用SGD优化 w c o n v w_{conv} wconv和原型集 P \textbf{P} P,并保持 w h w_{h} wh固定。
设 D = [ X , Y ] = { ( x i , y i ) } i = 1 n D=[X,Y]=\left\{(x_{i},y_{i})\right\}_{i=1}^{n} D=[X,Y]={(xi,yi)}i=1n为训练集,优化目标是: m i n w c o n v , P 1 n ∑ i = 1 n C E ( h ∘ g p ∘ f ( x i ) , y i ) + λ 1 C l s t + λ 2 S e p C l s t = 1 n ∑ i = 1 n m i n j : p j ∈ P y i m i n z ∈ p a t c h e s ( f ( x i ) ) ∣ ∣ z − p j ∣ ∣ 2 2 S e p = − 1 n ∑ i = 1 n m i n j : p j ∉ P y i m i n z ∈ p a t c h e s ( f ( x i ) ) ∣ ∣ z − p j ∣ ∣ 2 2 min_{w_{conv},\textbf{P}}\frac{1}{n}\sum_{i=1}^{n}CE(h\circ g_{\textbf{p}}\circ f(x_{i}),y_{i})+\lambda_{1}Clst+\lambda_{2}Sep\\Clst=\frac{1}{n}\sum_{i=1}^{n}min_{j:\textbf{p}_{j}\in\textbf{P}_{y_{i}}}min_{z\in patches(f(x_{i}))}||z-\textbf{p}_{j}||_{2}^{2}\\Sep=-\frac{1}{n}\sum_{i=1}^{n}min_{j:\textbf{p}_{j}\notin\textbf{P}_{y_{i}}}min_{z\in patches(f(x_{i}))}||z-\textbf{p}_{j}||_{2}^{2} minwconv,Pn1i=1∑nCE(h∘gp∘f(xi),yi)+λ1Clst+λ2SepClst=n1i=1∑nminj:pj∈Pyiminz∈patches(f(xi))∣∣z−pj∣∣22Sep=−n1i=1∑nminj:pj∈/Pyiminz∈patches(f(xi))∣∣z−pj∣∣22交叉熵损失 C E CE CE惩罚对训练数据的错误分类。 C l s t Clst Clst的最小化鼓励每个训练图像具有接近其自己类的至少一个原型的一些潜在patch,而 S e p Sep Sep的最小化则鼓励训练图像的每个潜在patch远离不属于其自身类的原型。这些项将潜在空间塑造成语义上有意义的聚类结构,这有助于网络基于L2距离分类。
对于 w h w_{h} wh,令 w h ( k , j ) w_{h}^{(k,j)} wh(k,j)表示连接第 j j j个原型和第 k k k个类别的权重,给定类别 k k k,设置符合 p j ∈ P k \textbf{p}_{j}\in \textbf{P}_{k} pj∈Pk的 j j j有 w h ( k , j ) = 1 w_{h}^{(k,j)}=1 wh(k,j)=1,而符合 p j ∉ P k \textbf{p}_{j}\notin \textbf{P}_{k} pj∈/Pk的 j j j有 w h ( k , j ) = − 0.5 w_{h}^{(k,j)}=-0.5 wh(k,j)=−0.5。可以发现,其实 w h w_{h} wh跟随类别的改变而改变,但这种改变是人工干预的。
对于第二部分:为了能够将原型可视化为训练图像patch,作者将每个原型 p j \textbf{p}_j pj投影(push)到与 p j \textbf{p}_j pj相同类的最近的潜在训练patch上,对于属于类别 k k k的原型 p j \textbf{p}_{j} pj, p j ∈ P k \textbf{p}_{j}\in\textbf{P}_{k} pj∈Pk,有: p j ← a r g m i n z ∈ Z j ∣ ∣ z − p j ∣ ∣ 2 Z j = { z ~ : z ~ ∈ p a t c h e s ( f ( x i ) ) ∀ i s . t . y i = k } \textbf{p}_{j}\leftarrow argmin_{z\in Z_{j}}||z-\textbf{p}_{j}||_{2}\\ Z_{j}=\left\{\widetilde{z}:\widetilde{z}\in patches(f(x_{i}))\forall i\thinspace s.t.\thinspace y_{i}=k\right\} pj←argminz∈Zj∣∣z−pj∣∣2Zj={z :z ∈patches(f(xi))∀is.t.yi=k}
对于第三部分:对最后一层 h h h的权重矩阵 w h w_h wh进行凸优化。该优化是凸的,因为固定了来自卷积层和原型层的所有参数。该阶段在不改变学习到的潜在空间或原型的情况下可以进一步提高准确性。优化目标为: m i n w h 1 n ∑ i = 1 n C E ( h ∘ g p ∘ f ( x i ) , y i ) + λ ∑ k = 1 K ∑ j : p j ∉ P k ∣ w h ( k , j ) ∣ min_{w_{h}}\frac{1}{n}\sum_{i=1}^{n}CE(h\circ g_{\textbf{p}}\circ f(x_{i}),y_{i})+\lambda\sum_{k=1}^{K}\sum_{j:\textbf{p}_{j}\notin\textbf{P}_{k}}|w_{h}^{(k,j)}| minwhn1i=1∑nCE(h∘gp∘f(xi),yi)+λk=1∑Kj:pj∈/Pk∑∣wh(k,j)∣该阶段的目标是调整最后一层 w h ( k , j ) w_{h}^{(k,j)} wh(k,j),正则化项可以提高负类别推理的稀疏性,这种稀疏性可以降低以下负面推理形式:该鸟属于 k ′ k' k′类,因为它不是 k k k类(它包含了一个不是 k k k类原型的patch)。