度量学习:使用多类N对损失改进深度度量学习

news2024/11/22 17:16:25

@度量学习系列

Author: 码科智能

使用多类N对损失改进深度度量学习

度量学习是ReID任务中常用的方式之一,今天来看下一篇关于如何改进度量学习的论文。来自2016年NeurIPS上的一篇论文,被引用超过900次。

论文:Improved Deep Metric Learning with Multi-class N-pair Loss Objective.
链接:论文.

1. 对比损失和三重损失 度量学习

  • 令 x ∈ X 为输入数据,f ∈ {1, …, L} 为其输出标签。
  • f+ 和 f- 分别表示 f 的正例和负例,意思是 f 和 f+ 属于同一类,f- 属于 f 的不同类。

1.1. 对比损失

  • 对比损失将成对的样本作为网络模型的输入,通过训练网络来预测两个输入是否来自同一类。

在这里插入图片描述

  • 其中 m 是一个边距参数,它强制来自不同类的样本之间的距离大于 m。

1.2. 三重损失

  • Triplet loss 与 contrastive loss 具有相似的原理,但其由三元组组成,每个三元组由一个查询、一个正例(同查询一个类别)和一个负例组成:

在这里插入图片描述

  • 与contrastive loss相比,triplet loss只需要正例与查询样本的相似度和负例与查询点的相似度之差大于margin即可(即上述的边距参数m)。

  • Triplet loss 的作用是拉近正样本 f+ ,同时推开负样本 f- 。

  • 对比损失或三元组损失已用于许多应用,例如人脸识别和图像检索,例如DrLIM、DeepFace、DeepID2、FaceNet。但此类框架通常存在收敛速度慢和局部最优值差的问题,部分原因是损失函数在每次更新时仅使用一个负样本,而不与其他负样本交互。

  • Hard negative data mining 可以缓解这个问题,但是 hard negative example search 在网络训练中带来额外的时间开销。

2. (N+1)-Tuplet Loss for Multiple Negative Examples

在这里插入图片描述

  • 如上所示,(N+1)-tuplet loss 根据它们与输入样本的相似性,一次性推送 N-1 个负样本。
  • f+ 是 f 的正例(蓝色圆圈),{f2, …, fN-1} 是负例(粉色圆圈)。 (N+1)-tuplet 损失为:

在这里插入图片描述

  • 当 N=2 时,对应的 (2+1)-tuplet loss 与 triplet loss 非常相似,因为每对输入和正例只有一个负例:

在这里插入图片描述

  • 当 N>2 时,进一步论证了 (N+1)-tuplet loss 相对于 triplet loss 的优势。 根据理想 (L+1)-tuplet 损失的分配函数估计,将 (N+1)-tuplet 损失与三重损失进行比较,其中 (L+1)-tuplet 损失与每个负类的单个样本相结合,可以写成如下:

在这里插入图片描述

  • 回想一下,L 是类别的总数,上面的等式类似于多类逻辑损失(即 softmax 损失)。在监督学习里指的是这个数据集一共有多少类别,比如CV的ImageNet数据集有1000类,L就是1000。在度量学习中每个样本都应该有一个类别,那么在扩大数据规模时,比如当向量的维度是几百万的时候,计算复杂度是相当高的。

  • 为了克服这个问题,提出了一种高效的批量构建方法,它只需要 2N 个示例而不是 (N+1)N 来构建长度为 N+1 的 N 个元组。

3. N-pair Loss as Efficient Batch Construction Method

在这里插入图片描述

  1. Triplet Loss:对于一个f,有一个f+和一个f-。 Batch size N,一个batch需要N个f,有N个f+和N个f-。
  2. (N+1)-Tuplet Loss:对于一个f,有一个f+和N-1个f-。 总共有 N+1 个例子。 当 SGD 的 batch size 为 N 时,一次更新有 N(N+1) 个样本要通过 f。由于每个批次要评估的示例数量以二次方方式增长,因此为非常深的卷积网络扩展训练再次变得不切实际。
  3. N-pair-mc 损失:多类 N-pair 损失 (N-pair-mc),可以表示为:

在这里插入图片描述

  • 提出的 N-pair-mc 损失是一个新颖的损失,由两个不可或缺的组成部分组成:(N+1)-tuplet 损失,作为构建块损失函数,以及 N-pair 构造,作为实现高度可扩展训练的关键。这意味着每个 f 的每个正 f+ 将变成另一个 f 的 f-,如上图 © 所示。

4. 难负类挖掘和正则化

  • 难负数据挖掘被认为是许多基于三元组的距离度量学习算法的重要组成部分。在这里,提出了负“类”挖掘,而不是负“实例”挖掘,后者以相对有效的方式贪婪地选择负类。
  • N-pair loss的负类挖掘可以按如下方式执行:
    1. Evaluate Embedding Vectors:随机选择大量的输出类C;对于每个类,随机传递一些(一个或两个)示例来提取它们的嵌入向量。
    2. 选择负类:从步骤 1 的 C 个类中随机选择一个类。接下来,贪婪地添加一个违反三重态约束的新类。选定的数量直到我们达到 N 个类别数。当出现平局时,我们随机选择一个平局类。
    3. 完成 N 对:从步骤 2 中选择的每个类中抽取两个示例。
    4. 此外,L2 范数正则化用于将嵌入向量的 L2 范数正则化为较小的。

5. 人脸验证和识别的实验结果

  • 人脸验证和识别是判断两张人脸图像是否为相同身份的问题(验证)和从具有许多负样本的图库中识别相同身份的人脸图像的问题(识别)。

  • 网络在 WebFace 数据库上进行训练,该数据库由来自 10,575 个身份的 494,414 张图像组成,并且使用不同度量学习目标训练的嵌入网络的质量在 Labeled Faces in the Wild (LFW) 数据库上进行评估。
    在这里插入图片描述

  • 上述几个指标分别为LFW 数据集上的平均验证准确度 (MRF)、Rank-1 准确度和DIR@FAR=1% 开集识别率

  • Triplet loss 模型显示了 95.88% 的验证准确率,但在识别任务上表现不佳。N-pair-mc 损失模型显着提高了性能。 此外,通过将 N 增加到 320,可以观察到额外的改进,获得 98.33% 的验证、90.17% 的封闭集和 71.76% 的开放集识别精度。

6. N-pair-mc Loss 代码

// N-pair loss
import torch
import torch.nn.functional as F

class NPairMCLoss(torch.nn.Module):
    def __init__(self, margin=0.1):
        super(NPairMCLoss, self).__init__()
        self.margin = margin

    def forward(self, anchors, positives, negatives):
        # 计算anchor和positive之间的距离
        pos_distance = F.pairwise_distance(anchors, positives)
        
        # 计算anchor和negative之间的距离
        neg_distance = F.pairwise_distance(anchors, negatives)

        # 计算损失函数
        loss = torch.mean(torch.relu(pos_distance - neg_distance + self.margin))
        return loss

// 调用示例

# 创建NPairMCLoss对象
loss_fn = NPairMCLoss(margin=0.1)

# 假设有一批输入数据 anchors, positives, negatives
anchors = torch.randn(16, 128)
positives = torch.randn(16, 128)
negatives = torch.randn(16, 128)

# 计算损失
loss = loss_fn(anchors, positives, negatives)

# 打印损失值
print("Loss:", loss.item())

请关注博主,一起玩转人工智能及深度学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/593730.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

针对蓝桥杯竞赛(Python)的基础知识 No.1

首先我们要知道Python是有着大量的库(模块、类、函数)的,所谓善借其器,善用其利 Q1、日期问题 掌握 datetime库 eg:小蓝每周六、周日都晨跑,每月的 1、11、21、31日也晨跑。其它时间不晨跑。已知 2022年1月1日是周六&…

Allegro输出光绘文件规范

光绘输出操作规范 1.1添加钻孔表 添加钻孔表的具体步骤为: 1.通过屏幕右边的Visibility选项的Views列表,将Drill层打开 2.将Visibility选项中的PIN和Via选项都选中,见下图所示: 1.2添加钻孔文件 参数设好之后关闭NC Drill/Parameters窗口,输出数控机床钻孔文件的命令…

1130 Infix Expression(34行代码+超详细注释)

分数 25 全屏浏览题目 切换布局 作者 CHEN, Yue 单位 浙江大学 Given a syntax tree (binary), you are supposed to output the corresponding infix expression, with parentheses reflecting the precedences of the operators. Input Specification: Each input fil…

练习Vue烘培坊项目

烘培坊项目 文章目录 烘培坊项目项目概述项目页面展示后台管理页面登录页面文章详情页面稿件发布页面 项目关键代码实现后台管理页面稿件管理页面内容列表页面文章详情页面烘培坊主页面注册页面登录页面个人信息页面稿件发布页面 项目概述 烘培坊(Bakery&#xff0…

WTI纽约原油CFD期货怎么交易?交易方法有哪些?

我们通常把未加工过的石油称为原油,原油也有“黑色黄金”之称。原油的用途无处不在,无论是工业制品或者日常生活用品等都离不开原油。原油一般以“桶(barrel)”作为单位,1桶约等于159升。在国际上影响力较大的基准原油…

教会你----如何烧录Arduinod代码进入ESP8266 MCU中,让你清楚的了解这个烧录方式的正确操作。

本次开发板为ESP8266 MUC 以下视频是烧录的操作,专给小白的视频 . .分隔符....................................................................................................... . .主要在 RST按久一点, 在放手的一瞬间接着按下 Flash …

电商网站的构建思维和技术

电商网站的架构及技术 3.1框架和技术 本系统主要以.net框架和C#语言位主要的开发工具,前端使用QUI前端框架。技术插件有Redis集群缓存、RabbitMQ 消息、MySql数据库。 实际上,在电商系统中,大部分数据都是可以缓存的,不能使用缓…

影响布伦特原油CFD期货行情的因素有哪些?

原油有很多种,其中比较有知名度的是布伦特原油,该原油是欧洲的原油产品,后来相继的有北海、地中海、也门以及非洲等诸多国家和地区以此为标准推出该产品。在国际金融市场中,布伦特原油特指洲际交易所(ICE)的…

flink1.17.0 集成kafka,并且计算

前言 flink是实时计算的重要集成组件,这里演示如何集成,并且使用一个小例子。例子是kafka输入消息,用逗号隔开,统计每个相同单词出现的次数,这么一个功能。 一、kafka环境准备 1.1 启动kafka 这里我使用的kafka版本…

驾校驾考APP开发功能有哪些?

随着汽车成为越来越多人出行的代步工具之后,需要考驾照的人也是越来越多。小编记得我当初考驾照的时候还是抱着一个小本本每天刷题练习,小本本都快翻烂了。移动互联网的普及让驾考也开始走向线上,刷题、模拟、甚至是考试都可以通过驾考APP小程…

【csdn AI写作助手能帮助我们做什么呢?】

CSDN AI写作助手上线了!InsCode AI 创作助手不仅能够帮助用户高效创作文章,而且能够作为对话式AI回答你想知道的问题。成倍提高生产力! 一、你平时会使用这类AI工具吗?你对这类型的工具有什么看法? 提示:根…

Ubuntu离线安装Vsftp

这是资源包:(14条消息) unbuntu-vsftp.server-Linux文档类资源-CSDN文库 一、安装vsftp 将包解压,然后在解压报的目录下一键安装 dpkg -i *.deb // 安装所有 systemctl status vsftpd #查看运行状态 systemctl restart vsftpd #重新启动vsftp 二、…

【论文阅读公式推导1】连续体机器人的哈密尔顿动力学推导

推导了一下论文哈密尔顿原理的表达,原论文的计算公式是对的,记录一下。 Gravagne I A, Rahn C D, Walker I D. Good vibrations: a vibration damping setpoint controller for continuum robots[C]//Proceedings 2001 ICRA. IEEE International Confer…

[网站分享]

Element-ui Element - The worlds most popular Vue UI frameworkElement,一套为开发者、设计师和产品经理准备的基于 Vue 2.0 的桌面端组件库https://element.eleme.cn/#/zh-CN Vant Weapp Vant Weapp - 轻量、可靠的小程序 UI 组件库轻量、可靠的小程序 UI 组件…

没有数学基础可以学编程吗?

一、为什么学编程 这里我并不是问大家,是因为兴趣啊还是就业学编程。 而是,我想要学Python为了量化交易,或者我要处理表格。我想要学Java我就想自己建站。是否有这种非常明确的目标,有目标才能明确学习路线。 如果在这里&#…

大数据:HDFS操作的客户端big data tools和NFS

大数据:HDFS操作的客户端 2022找工作是学历、能力和运气的超强结合体,遇到寒冬,大厂不招人,可能很多算法学生都得去找开发,测开 测开的话,你就得学数据库,sql,oracle,尤…

Vue--》Vue3打造可扩展的项目管理系统后台的完整指南(三)

今天开始使用 vue3 ts 搭建一个项目管理的后台,因为文章会将项目的每一个地方代码的书写都会讲解到,所以本项目会分成好几篇文章进行讲解,我会在最后一篇文章中会将项目代码开源到我的GithHub上,大家可以自行去进行下载运行&…

Pytorch入门(二)神经网络的搭建

torch.nn中的nn全称为neural network,意思是神经网络,是torch中构建神经网络的模块。 文章目录 一、神经网络基本骨架二、认识卷积操作三、认识最大池化操作四、非线性激活五、线性层及其它层介绍六、简单的神经网络搭建七、简单的认识神经网络中的数值计算八、损失…

mmdetection训练coco数据集(继跑通后的一些工具使用)

(仅做个人过程记录的笔记) 1、生成中间件 可以选择评估方式 --eval ,对于 COCO 数据集,可选 bbox 、segm、proposal 。可以得到result.bbox.json文件 生成pkl文件:faster_rcnn.pkl python tools/test.py config.py …

利用栈和队列共同解决迷宫问题

文章目录 什么是迷宫问题?如何解决迷宫问题?DFS(深度优先搜索)BFS(广度优先搜索) 总结 什么是迷宫问题? 迷宫问题是一道经典的算法问题,旨在寻找一条从起点到终点的最短路径。通常迷…