原型part学习NeurIPS2019

news2025/1/17 4:13:18

当我们面临具有挑战性的图像分类任务时,我们希望通过分解part来解释推理。每一类别的更多原型证据有助于做出最终分类决策。作者提出一种深度网络架构:Prototypical Part网络即ProtoPNet。网络通过寻找原型part来解释图像,并基于原型part进行分类。网络仅使用图像级标签进行训练,并在推理时表现出与专家一样的水平(比如CUB上可以比肩鸟类学家)。并且当ProtoPNet组合到更大的网络中时,可以实现与性能最好的深度模型相当的精度。

来自:This Looks Like That: Deep Learning for Interpretable Image Recognition

目录

  • 背景概述
  • 案例一:鸟类识别
    • 网络架构
    • 训练

背景概述

如何描述为什么图1中的图像看起来像一只clay colored sparrow?也许这只鸟的头和翅膀看起来像典型的clay colored sparrow。当我们描述如何对图像进行分类时,我们可能会关注图像的part,并将其与给定类别的图像的prototypical part进行比较。这种推理方法通常用于困难的识别任务:例如医学图像分类,细粒度自然图像分类。因此,该工作的目标是定义一种图像处理中的可解释性,希望网络与人类在分类任务中描述自己思维的方式一致。

作者引入了一种网络架构ProtoPNet,它适应了可解释性的定义。给定如图1所示的鸟类图像,模型能够识别图像的几个part,并认为图像的这一部分看起来像某个类别的原型部分,并基于图像部分和学习到的原型之间的相似性得分的加权组合进行预测。通过这种方式,模型是可解释的,因为它在进行预测时有一个透明的推理过程。

fig1

  • 图1:clay colored sparrow的图像,以及它的part可以成为分类clay colored sparrow的原型。

案例一:鸟类识别

作者在鸟类物种识别的背景下介绍了ProtoPNet的架构和训练程序,并详细介绍了网络如何对新的鸟类图像进行分类并解释其预测。实验在CUB200-2011数据集上对200种鸟类进行了训练和评估。

网络架构

图2概述了ProtoPNet的体系结构。网络包括标准卷积神经网络 f f f,参数为 w c o n v w_{conv} wconv,和prototype层 g p g_{\textbf{p}} gp,然后是全连接层 h h h,参数为 w h w_{h} wh。对于 f f f,可以使用VGG-16、VGG-19、ResNet-34、ResNet-152、DenseNet-121或DenseNet-161等(最好是在ImageNet上训练)的卷积层,然后接两个额外的 1 × 1 1\times 1 1×1卷积层。使用ReLU作为所有卷积层的激活函数,除了最后一个卷积层使用sigmoid激活函数。

给定输入图像 x x x(如图2中的clay colored sparrow),卷积层提取有用的特征 f ( x ) f(x) f(x)。假设 f ( x ) f(x) f(x)的形状为 H × W × D H\times W\times D H×W×D。对于输入图像为 ( 224 , 224 , 3 ) (224,224,3) (224,224,3),有 H = W = 7 H=W=7 H=W=7 D D D可以是128,256,512。网络需要学习 m m m个原型 P = { p j } j = 1 m \textbf{P}=\left\{\textbf{p}_{j}\right\}_{j=1}^{m} P={pj}j=1m,原型形状为 H 1 × W 1 × D H_{1}\times W_{1}\times D H1×W1×D。在实验中,使用 H 1 = W 1 = 1 H_1=W_1=1 H1=W1=1。由于每个原型的深度与卷积输出的深度相同,但每个原型的高度和宽度小于整个卷积输出的高度和宽度,因此每个原型用于表示卷积输出的patch中的一些原型模式,这些patch代表原始像素空间中的图像区域。因此,每个原型 p j \textbf{p}_j pj可以被理解为某些鸟类图像的原型part的表示。作为示意图,图2中的第一个原型 p 1 \textbf{p}_1 p1对应于clay colored sparrow的头部,第二个原型 p 2 \textbf{p}_2 p2对应于Brewer’s sparrow的头部。

fig2

  • 图2:ProtoPNet架构。

z = f ( x ) z=f(x) z=f(x),原型层 g p g_{\textbf{p}} gp中的第 j j j个原型单元 g p j g_{\textbf{p}_j} gpj计算了第 j j j个原型 p j \textbf{p}_j pj和具有与 p j \textbf{p}_j pj相同形状的 z z z的所有patch之间的平方 L 2 L^{2} L2距离,并将距离转换为相似性得分。结果是相似性得分的激活图,其值指示原型part在图像中存在的强度。该激活图保留了卷积输出的空间关系,并且可以被上采样到输入图像的大小,以产生热图,该热图识别输入图像的哪个部分与该原型最相似。

然后,使用全局最大池化将每个原型单元 g p j g_{\textbf{p}_j} gpj产生的相似性得分的激活图缩减为单个相似性得分,其可以理解为原型part在输入图像中存在的强度:

  • 在图2中,第一个原型是clay colored sparrow的头部原型,和clay colored sparrow输入图像中最活跃的(右上)patch之间的相似性得分为3:954,第二个原型是Brewer’s sparrow头部原型,和输入图像中最大活跃的patch之间的相似性得分为1.447。这表明,在输入图像中,clay colored sparrow的头部比Brewer’s sparrow的头部存在性更强。

在数学上, g p j g_{\textbf{p}_{j}} gpj计算: g p j ( z ) = m a x z ~ ∈ p a t c h e s ( z ) l o g ( ∣ ∣ z ~ − p j ∣ ∣ 2 2 + 1 ∣ ∣ z ~ − p j ∣ ∣ 2 2 + ϵ ) g_{\textbf{p}_{j}}(z)=max_{\widetilde{z}\in patches(z)}log(\frac{||\widetilde{z}-\textbf{p}_{j}||_{2}^{2}+1}{||\widetilde{z}-\textbf{p}_{j}||_{2}^{2}+\epsilon}) gpj(z)=maxz patches(z)log(∣∣z pj22+ϵ∣∣z pj22+1)因此,如果第 j j j个原型单元 g p j g_{\textbf{p}_j} gpj的输出很大,那么在卷积输出中存在一个patch,该patch非常接近潜在空间中的第 j j j个原型,这反过来意味着在输入图像中有一个patch具有与第 j j j个原型所表示的相似的语义。

在ProtoPNet中,为每个类 k ∈ { 1 , . . . , K } k\in\left\{1,...,K\right\} k{1,...,K}分配预先确定的原型数量 m k m_{k} mk,实验中每类都为10。类别 k k k对应的原型集为 P k \textbf{P}_{k} Pk

最后,将 m m m个相似性得分乘以全连接层 h h h中的权重矩阵,以产生输出logits,使用softmax对其进行归一化,以产生属于各种类别的预测概率。

训练

ProtoPNet的训练分为:

  • 最后一层之前的层的随机梯度下降;
  • 原型投影;
  • 最后一层的凸优化;

对于第一部分:第一个训练阶段,目标是学习一个有意义的潜在空间,其中用于分类图像的最重要的patch被聚类在图像的真实类的语义相似的原型周围,并且以来自不同类的原型为中心的簇被很好地分离。为了实现该目标,首先使用SGD优化 w c o n v w_{conv} wconv和原型集 P \textbf{P} P,并保持 w h w_{h} wh固定。

D = [ X , Y ] = { ( x i , y i ) } i = 1 n D=[X,Y]=\left\{(x_{i},y_{i})\right\}_{i=1}^{n} D=[X,Y]={(xi,yi)}i=1n为训练集,优化目标是: m i n w c o n v , P 1 n ∑ i = 1 n C E ( h ∘ g p ∘ f ( x i ) , y i ) + λ 1 C l s t + λ 2 S e p C l s t = 1 n ∑ i = 1 n m i n j : p j ∈ P y i m i n z ∈ p a t c h e s ( f ( x i ) ) ∣ ∣ z − p j ∣ ∣ 2 2 S e p = − 1 n ∑ i = 1 n m i n j : p j ∉ P y i m i n z ∈ p a t c h e s ( f ( x i ) ) ∣ ∣ z − p j ∣ ∣ 2 2 min_{w_{conv},\textbf{P}}\frac{1}{n}\sum_{i=1}^{n}CE(h\circ g_{\textbf{p}}\circ f(x_{i}),y_{i})+\lambda_{1}Clst+\lambda_{2}Sep\\Clst=\frac{1}{n}\sum_{i=1}^{n}min_{j:\textbf{p}_{j}\in\textbf{P}_{y_{i}}}min_{z\in patches(f(x_{i}))}||z-\textbf{p}_{j}||_{2}^{2}\\Sep=-\frac{1}{n}\sum_{i=1}^{n}min_{j:\textbf{p}_{j}\notin\textbf{P}_{y_{i}}}min_{z\in patches(f(x_{i}))}||z-\textbf{p}_{j}||_{2}^{2} minwconv,Pn1i=1nCE(hgpf(xi),yi)+λ1Clst+λ2SepClst=n1i=1nminj:pjPyiminzpatches(f(xi))∣∣zpj22Sep=n1i=1nminj:pj/Pyiminzpatches(f(xi))∣∣zpj22交叉熵损失 C E CE CE惩罚对训练数据的错误分类。 C l s t Clst Clst的最小化鼓励每个训练图像具有接近其自己类的至少一个原型的一些潜在patch,而 S e p Sep Sep的最小化则鼓励训练图像的每个潜在patch远离不属于其自身类的原型。这些项将潜在空间塑造成语义上有意义的聚类结构,这有助于网络基于L2距离分类。

对于 w h w_{h} wh,令 w h ( k , j ) w_{h}^{(k,j)} wh(k,j)表示连接第 j j j个原型和第 k k k个类别的权重,给定类别 k k k,设置符合 p j ∈ P k \textbf{p}_{j}\in \textbf{P}_{k} pjPk j j j w h ( k , j ) = 1 w_{h}^{(k,j)}=1 wh(k,j)=1,而符合 p j ∉ P k \textbf{p}_{j}\notin \textbf{P}_{k} pj/Pk j j j w h ( k , j ) = − 0.5 w_{h}^{(k,j)}=-0.5 wh(k,j)=0.5。可以发现,其实 w h w_{h} wh跟随类别的改变而改变,但这种改变是人工干预的。

对于第二部分:为了能够将原型可视化为训练图像patch,作者将每个原型 p j \textbf{p}_j pj投影(push)到与 p j \textbf{p}_j pj相同类的最近的潜在训练patch上,对于属于类别 k k k的原型 p j \textbf{p}_{j} pj p j ∈ P k \textbf{p}_{j}\in\textbf{P}_{k} pjPk,有: p j ← a r g m i n z ∈ Z j ∣ ∣ z − p j ∣ ∣ 2 Z j = { z ~ : z ~ ∈ p a t c h e s ( f ( x i ) ) ∀ i   s . t .   y i = k } \textbf{p}_{j}\leftarrow argmin_{z\in Z_{j}}||z-\textbf{p}_{j}||_{2}\\ Z_{j}=\left\{\widetilde{z}:\widetilde{z}\in patches(f(x_{i}))\forall i\thinspace s.t.\thinspace y_{i}=k\right\} pjargminzZj∣∣zpj2Zj={z :z patches(f(xi))is.t.yi=k}

对于第三部分:对最后一层 h h h的权重矩阵 w h w_h wh进行凸优化。该优化是凸的,因为固定了来自卷积层和原型层的所有参数。该阶段在不改变学习到的潜在空间或原型的情况下可以进一步提高准确性。优化目标为: m i n w h 1 n ∑ i = 1 n C E ( h ∘ g p ∘ f ( x i ) , y i ) + λ ∑ k = 1 K ∑ j : p j ∉ P k ∣ w h ( k , j ) ∣ min_{w_{h}}\frac{1}{n}\sum_{i=1}^{n}CE(h\circ g_{\textbf{p}}\circ f(x_{i}),y_{i})+\lambda\sum_{k=1}^{K}\sum_{j:\textbf{p}_{j}\notin\textbf{P}_{k}}|w_{h}^{(k,j)}| minwhn1i=1nCE(hgpf(xi),yi)+λk=1Kj:pj/Pkwh(k,j)该阶段的目标是调整最后一层 w h ( k , j ) w_{h}^{(k,j)} wh(k,j),正则化项可以提高负类别推理的稀疏性,这种稀疏性可以降低以下负面推理形式:该鸟属于 k ′ k' k类,因为它不是 k k k类(它包含了一个不是 k k k类原型的patch)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/528443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

同步 Swagger URL问题, 用这个插件就可解决

这个开源的 API 管理工具叫 Postcat, 支持从 Swagger URL 增量同步 API 数据到 Postcat。 使用 进入 API 模块,鼠标移动到主按钮加号,下拉看到从 Swagger 同步 URL 的选项。 填写完配置点击立即同步即可同步 API 数据。 同步规则 新的数据覆盖旧的数据…

PHP语言调用api接口,电商平台商品详情接口(封装可高并发)

PHP是为Web而生的语言,它提供了一些强大的内置函数来处理HTTP请求和响应。PHP为开发人员提供了一些Web开发工具,包括HTML、CSS、JavaScript以及各种数据库的连接和互动。与其他Web开发工具相比,PHP可以更加高效地运转与发挥作用。 PHP表现出…

Matlab 非线性迭代法(3)阻尼牛顿法 L-M

高斯牛顿法详解_我只是一只自动小青蛙的博客-CSDN博客 一、思想 先看一下牛顿高斯迭代法的缺点: 1、在计算的过程中可能会出现奇异矩阵(不满秩),比如:J(k)​)TJ(k) 为病态矩阵的时候就不能得到正确的解,或…

如何提升性能测试效能

上周六应邀在天津devops峰会的质量内建专场做了一次分享,主题是《稳定性保障利器:全链路压测》。 其中关于全链路压测对质量内建的意义,我做了一个总结,如下图所示。本文基于下图做了展开描述,仅供参考。 如何理解性能…

从零开始Vue3+Element Plus后台管理系统(八)——模仿禅道做一个Vue3版本的高级查询组件

暗黑模式 使用 Vue3element Plus 简单模仿了禅道系统的高级搜索组件,说简单也有点复杂,还没有完全开发完,但是大体架子有了,剩下一些功能点继续coding。边开发边记录吧,因为这个相比之前的内容确实复杂一些&#xff0c…

Java的基操,基操(一)

🔥二进制🔥二进制和十进制的转化🔥注释🔥标识符🔥关键字/保留字🔥变量(variable) 🔥二进制 二进制,是计算技术中广泛采用的一种数制,由德国数理哲学大师莱布尼茨于 1679 …

Helm chart 常用命令以及原理和生产实践

问: 到哪里去搜索helm package? 答: artifacthub.io Helm 的实质就是搞一些模版,最终依据这些模版生成k8s的系列yaml文件(deployemnt,service,secret,map等等),从而在k8s上能够简单部署出完整应用。可以用helm template查看最终生成的k8s部署文件。 helm version…

MQTT客户端应用编程及接口分析

MQTT客户端应用编程及接口分析 MQTT协议简介 MQTT是一个基于客户端-服务器的消息发布/订阅传输协议。MQTT协议是轻量、简单、开放和易于实现的,这些特点使它适用范围非常广泛。 客户端服务端安装 1.安装 sudo apt-add-repository ppa:mosquitto-dev/mosquitto-…

SpringCloud_服务注册中心_Consul(八)

SpringCloud_服务注册中心_Consul(八) 分为五部分 Consul简介 安装并运行Consul 服务提供者 服务消费者 三个注册中心异同点 Consul简介 官网:https://developer.hashicorp.com/consul/docs/intro 是Go语言写的 Consul是一套开源的分布式服务发现和配置管理系统&am…

GB50312-2016标准中需要检测的参数(AEMFLUKE)含双绞线和光

很多同学经常搞不清GB50312-2016标准的规定测试参数,或者说和测试设备对不上号。特意从标准中摘抄出来,供大家参考。 ACR-F(Attenuation to Crosstalk Ratio at the Far-end) 衰减远端串音比 ACR-N(Attenuation to Crosstalk Ratio at the Near-end)衰…

用于申威Alpha指令集处理器CModel裸机(不带操作系统)的CoreMark性能测试程序源码编译流程

CoreMark是一个综合基准,用于测量嵌入式系统中使用的中央处理器(CPU)的性能。它是在2009由eembc的shay gal-on开发的,并且试图将其发展成为工业标准,取代过时的dehrystone基准。代码用C编写,包含以下算法:列表处理(增删…

如何在Colab中使用gpu资源(附使用MMdet推理示例)

如何在Colab中“白嫖”gpu资源(附使用MMdet推理示例) Google Colab简介 当今,深度学习已经成为许多人感兴趣的话题,Google Colab(全称为Google Colaboratory)是Google推出的一个强大的云端 notebook&…

《微服务实战》 第七章 Spring Cloud 之 GateWay

前言 API 网关是一个搭建在客户端和微服务之间的服务,我们可以在 API 网关中处理一些非业务功能的逻辑,例如权限验证、监控、缓存、请求路由等。 1、通过API网关访问服务 客户端通过 API 网关与微服务交互时,客户端只需要知道 API 网关地…

UWB智慧工厂人员定位系统源码,人员在岗监控、车辆实时轨迹监控源码

近年来人员定位系统在工业领域的发展势头迅猛,工业识别与定位成为促进制造业数字化的关键技术。通过实时定位可以判断所有的人、物、车的位置。实时定位系统要适用于复杂工业环境,单一技术是很难实现的,需要融合多种不同的定位技术&#xff0…

【hive】hive grouping sets和GROUPING__ID的用法

前言​ GROUPING SETS,GROUPING__ID,CUBE,ROLLUP 这几个分析函数通常用于OLAP中,不能累加,而且需要根据不同维度上钻和下钻的指标统计,比如,分小时、天、月的UV数。 grouping sets根据不同的维度组合进行聚合,等价于…

从事网络安全工作,这五大证书是加分项!

对我们而言,无论从事什么工作,考取相关证书都有非常重要的作用,它是我们找工作时的加分项,同时也是对我们技术水平的验证,那么从事网络安全工作可以考哪些证书?本篇文章为大家介绍一下。 1、CISP 国家注册信息安全专业…

vue3【父子组件间的传值--setup语法糖】

这篇文章主要讲解vue3语法糖中组件传值的用法、 一、父组件给子组件传值 父组件 <template><div classmain>我是父组件<Child :msg"parentMsg"></Child></div></template><script setup> import Child from ./child im…

idea热部署插件JRebel激活

JRebel可以实现在idea中热部署项目&#xff0c;修改后不用重启项目&#xff0c;让开发更丝滑。 JRebel需要激活才可以正常使用。 不想安装服务的可以用我个人部署的服务器注册&#xff0c;不保证稳定哦&#xff0c;有问题可以留言。 安装完插件直接看激活。 http://121.5.183.2…

亲水性Sulfo-Cyanine3 NHS ester水溶性CY3标记活性脂

Sulfo-Cy3是一种荧光染料&#xff0c;可用于生物成像和细胞标记等应用。Sulfo-Cy3是一种含有硫酸基的Cy3染料&#xff0c;具有高度的水溶性和稳定性。Sulfo-Cy3可以与NHS&#xff08;N-羟基琥珀酰亚胺&#xff09;结合&#xff0c;形成Sulfo-Cy3 NHS&#xff0c;这种结合物可以…

微生物常见统计检验方法比较及选择

谷禾健康 微生物组经由二代测序分析得到庞大数据结果&#xff0c;其中包括OTU/ASV表&#xff0c;物种丰度表&#xff0c;alpha多样性、beta多样性指数&#xff0c;代谢功能预测丰度表等&#xff0c;这些数据构成了微生物组的变量&#xff0c;大量数据构成了高纬度数据信息。 针…