Improved Deep Metric Learning with Multi-class N-pair Loss Objective

news2024/11/23 3:25:24

Improved Deep Metric Learning with Multi-class N-pair Loss Objective

来源:

  • NIPS’2016
  • NEC Laboratories America

文章目录

  • Improved Deep Metric Learning with Multi-class N-pair Loss Objective
    • Distance Metric Learning
    • Deep Metric Learning with Multiple Negative Examples
      • N-pair loss for efficient deep metric learning
    • 总结
    • 参考

找到这篇论文是因为看了淘宝搜索出品的论文Rethinking the Role of Pre-ranking in Large-scale E-Commerce 1,文中就提到了传统的list-wise损失 不适用于列表中存在多个正样本的场景。从样本构造的角度来看,这种方式应该也适用于多标签分类。

度量学习一直是我想了解的一个领域,就拿这篇论文做一个开始吧。

Distance Metric Learning

度量学习(metric learning)2,简言之:学习数据的嵌入表示,嵌入具有这样的性质,相似的数据点距离近不相似的数据点距离远。度量学习中常见的两种损失:对比损失和三元组损失,二者形式化的表示:
L c o n t ( x i , x j ; f ) = 1 { y i = y j } ∣ ∣ f i − f j ∣ ∣ 2 2 + 1 { y i ≠ y j } m a x ( 0 , m − ∣ ∣ f i − f j ∣ ∣ 2 ) 2 \mathcal{L}_{cont}(x_i, x_j; f) = \mathbb{1}\{y_i = y_j\}||f_i - f_j||_2^2 + \mathbb{1}\{y_i \neq y_j\}max(0, m - ||f_i - f_j||_2)^2 Lcont(xi,xj;f)=1{yi=yj}∣∣fifj22+1{yi=yj}max(0,m∣∣fifj2)2

L t r i ( x , x + , x − ; f ) = m a x ( 0 , ∣ ∣ f − f + ∣ ∣ 2 2 − ∣ ∣ f − f − ∣ ∣ 2 2 + m ) \mathcal{L}_{tri}(x, x^+, x^-; f) = max(0, ||f - f^+||_2^2 - ||f - f^-||_2^2 + m) Ltri(x,x+,x;f)=max(0,∣∣ff+22∣∣ff22+m)

其中 L c o n t \mathcal{L}_{cont} Lcont为对比损失(现在火起来的对比学习), L t r i \mathcal{L}_{tri} Ltri为三元组损失, f f f表示样本的嵌入。在对比损失中,要求来自同类别的样本距离近,不同类别的样本距离远;三元组损失中要求正( x + x^+ x+)、负( x − x^- x)样本到锚点( x x x,如搜图场景中的查询图)的距离要大于一定的阈值。

度量学习有一些现在很常见的应用,例如人脸识别、搜图等。度量学习的样本中通常只有一个负样本,容易导致收敛速度慢和局部最优的问题。难负样本挖掘(提一嘴:随着更多的实践,愈发觉得数据质量的重要性,如何构造好的数据集是一个值得研究的问题)能够减轻这些问题,但是如何找到难负样本本身就是一个难题。

与常见的三元组损失(triplet loss)中一个锚样本、一个正样本和一个负样本不一样,论文提出了一个 ( N + 1 ) (N+1) (N+1)元组的损失,来使一个正样本与 N − 1 N-1 N1个负样本区分开来。

Deep Metric Learning with Multiple Negative Examples

在三元组损失中,如果要使得损失尽可能低,显然有这么几种情况:

  • 缩短正样本与锚样本的距离;
  • 增大负样本与锚样本的距离;
  • 以上二者的结合。

从三元组损失的计算方式上也可以看出,再一次更新中只会比较锚样本与一个负样本,忽略了其他类别的负样本。这就导致:每次只能使锚样本远离一种负类,或许又被推到其他负类那里去了。最终学习到的嵌入可能会出现这样的情况:锚样本离训练数据中出现较多的负类远,而离某些负类又很近

当然,我们可以为锚样本配很多个三元组,囊括不同类别的负样本,这样在多轮、充足的训练后嵌入能够具有理想的性质。这样做就面临了不稳定以及收敛速度慢的问题。因此,文中就提出了 N + 1 N+1 N+1元组的损失,二者的区别如下图所示:
Triplet loss and (N+1)-tuplet loss

Deep metric learning with (left) triplet loss and (right) (N+1)-tuplet loss.

上图中红色的圆表示负样本,蓝色的表示锚样本和正样本。从左侧可以看出, N + 1 N+1 N+1元组损失的一个很简单的出发点:既然一个负类的样本不够,那就每个负类都拿一个样本出来,组成一个 N + 1 N+1 N+1的元组。但是在类别很多的场景(比如人脸识别),计算的复杂度过高。文章的重点就在于如何设计这样一个计算上可行的损失函数。

下图是三元组损失(a)、 ( N + 1 ) (N+1) (N+1)元组损失(b)及其改进后的损失©的一个对比。 N N N-pair-mc loss(multi-class N-pair loss)损失就是文章最后提出的损失。

N-pair-mc loss
Triplet loss, (N+1)-tuplet loss, and multi-class N-pair loss with training batch construction.

( N + 1 ) (N+1) (N+1)元组损失可以定义如下:
L ( { x , x + , { x i } i = 1 N − 1 } ; f ) = l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) \mathcal{L}(\{x, x^+, \{x_i\}_{i=1}^{N-1}\}; f) = log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) L({x,x+,{xi}i=1N1};f)=log(1+i=1N1exp(fTfifTf+))
N N N等于2的时候该损失是与三元组损失等价的。提一嘴,这个形式和softplus的形式是一样的:
s o f t p l u s ( x ) = l o g ( 1 + e x p ( x ) ) softplus(x) = log(1 + exp(x)) softplus(x)=log(1+exp(x))
( N + 1 ) (N+1) (N+1)元组的损失可以写为如下形式:
l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) = − l o g e x p ( f T f + ) e x p ( f T f + ) + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) = - log \frac{exp(f^T f^+)} {exp(f^T f^+) + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+))} log(1+i=1N1exp(fTfifTf+))=logexp(fTf+)+i=1N1exp(fTfifTf+))exp(fTf+)
这样一看是不是就更顺眼了,这不就是多分类里的softmax loss嘛。

N-pair loss for efficient deep metric learning

论文提出了一种高效的批构造方法,以降低额外的计算开销。方法的名字叫multi-class N N N-pair loss( N N N-pair-mc),其构造方式如上图 ( c )所示。来个说文解字,道一道作者的解决方法。方法名中有个N-pair,就从这入手。假若我们有 N N N个pair:
{ ( x 1 , x 1 + ) , ⋯   , ( x N , x N + } ,   y i ≠ y j , ∀ i ≠ j \{(x_1, x_1^+), \cdots, (x_N, x_N^+\},\ y_i \neq y_j, \forall i \neq j {(x1,x1+),,(xN,xN+}, yi=yj,i=j
每个pair的样本来自不同的类别,在这 N N N个pair的基础上构建 N N N个元组 { S i } i = 1 N \{S_i\}_{i=1}^N {Si}i=1N,其中:
S i = { x i , x 1 + , x 2 + , ⋯   , x N + } S_i = \{x_i, x_1^+, x_2^+, \cdots, x_N^+\} Si={xi,x1+,x2+,,xN+}
其中 x i x_i xi就是锚样本。显然, S i S_i Si就是一个包含了一个 i i i类别正样本, N − 1 N-1 N1个其他类别负样本的 N + 1 N+1 N+1元组了。因此,对于一个由 N N N个查询组成的batch,只需要准备 2 N 2 N 2N个样本,即 N N N个锚样本和 N N N个对应类别的正样本,每个batch只需要** 2 N 2 N 2N次前向计算**样本的嵌入就可以了。而在三元组损失和 N + 1 N+1 N+1元组损失中分别是 3 N 3 N 3N ( N + 1 ) N (N+1) N (N+1)N。因此,对于 N N N个查询组成的batch,其损失可以如下计算:
L N − p a i r − m c ( { ( x i , x i + } i = 1 N ; f ) = 1 N ∑ i = 1 N l o g ( 1 + ∑ j ≠ i e x p ( f i T f j + − f i T f i + ) ) \mathcal{L}_{N-pair-mc}(\{(x_i, x_i^+\}_{i=1}^N ; f) = \frac{1} {N} \sum_{i=1}^N log (1 + \sum_{j \neq i} exp(f_i^T f_j^+ - f_i^T f_i^+)) LNpairmc({(xi,xi+}i=1N;f)=N1i=1Nlog(1+j=iexp(fiTfj+fiTfi+))
以上就是论文的主要内容了,当然论文中还提到了负类别挖掘,这个就暂且不提了。

总结

简言之,这篇论文将度量学习中常见的三元组损失中只有一个负样本扩展到每个样本中包含 N − 1 N-1 N1个负样本,并且为了计算的效率提出了 N N N-pair的batch构造方法以降低计算量。其实,如果在三元组损失的batch中精心设计各种类别样本的配比,比如每个batch只训练一个类别,是否也能达到类似的效果呢?

参考


  1. Rethinking the Role of Pre-ranking in Large-scale E-Commerce, KDD 2023. ↩︎

  2. 漫谈-Distance Metric Learning那些事儿:https://zhuanlan.zhihu.com/p/458114525. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/857634.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实战:使用Docker部署Hadoop集群

文章目录 Hadoop简介Hadoop优势Hadoop应用场景docker与docker-compose安装Hadoop集群搭建环境变量docker-compose环境文件树结构编排并运行容器运行wordcount例子 写在最后 Hadoop简介 Hadoop是一个由Apache基金会所开发的分布式系统基础架构。用户可以在不了解分布式底层细节…

ChatGLM2-6B在windows下的部署

2023-08-10 ChatGLM2-6B在windows下的部署 一、部署环境 1、Windows 10 专业版, 64位,版本号:22H2,内存:32GB 2、已安装CUDA11.3 3、已安装Anaconda3 64bit版本 4、有显卡NVIDIA GeForce RTX 3060 Laptop GPU …

AI Deep Reinforcement Learning Autonomous Driving(深度强化学习自动驾驶)

AI Deep Reinforcement Learning Autonomous Driving(深度强化学习自动驾驶) 背景介绍研究背景研究目的及意义项目设计内容算法介绍马尔可夫链及马尔可夫决策过程强化学习神经网络 仿真平台OpenAI gymTorcs配置GTA5 参数选择行动空间奖励函数 环境及软件…

8.10CPI决战日来临,黄金会意外走高吗?

近期有哪些消息面影响黄金走势?黄金多空该如何研判? ​黄金消息面解析:周四(8月10日)亚市早盘,美元指数在102.50维持多头走势,黄金避险情绪消散,金价跌至1916美元,下破1900美元前景深化。周三黄…

如何使用Audition生成固定频率的正弦波

一,简介 本文主要介绍如何使用Audition软件生成固定频率的正弦波进行相关测试验证工作。 二,准备工作 需要安装Audition软件,本次使用的是Adobe Audition CC 2018绿色版。其他版本也都可以,只是步骤上可能有细微的差别。 三&…

山西电力市场日前价格预测【2023-08-11】

日前价格预测 预测明日(2023-08-11)山西电力市场全天平均日前电价为367.15元/MWh。其中,最高日前电价为408.91元/MWh,预计出现在20: 00。最低日前电价为343.90元/MWh,预计出现在02: 30。 价差方向预测 1: 实…

2.UE数字人语音交互(UE数字人系统教程)

上一篇:1.Fay-UE5数字人工程导入 2.UE数字人语音交互(UE数字人系统教程) 1、启动ue数字人 2、下载Fay数字人控制器 Fay数字人控制器下载地址 3、依照说明配置运行Fay 4、启动Fay控制器 5、切换到UE界面开始说话 6、完成了&#xf…

(学习笔记-进程管理)进程调度

进程都希望自己能够占用CPU进行工作,那么这涉及到前面说过的进程上下文切换。 一旦操作系统把进程切换到运行状态,也就意味着该进程占用着CPU在执行,但是操作系统把进程切换到其他状态的时候,就不能在CPU中执行了,于是…

力扣真题:200. 岛屿数量(两种实现方法)

java代码实现: 第一种: 用了类似感染的方法,就是一个节点出发,如果此时这个节点没被感染,且是陆地,就可以进入遍历,将其邻接的陆地全部遍历一遍,标志数组sign相应位置至为1.然后一…

pdf怎么压缩到1m?这样做压缩率高!

PDF是目前使用率比较高的一种文档格式,因为它具有很高的安全性,还易于传输等,但有时候当文件体积过大时,会给我们带来不便,这时候简单的解决方法就是将其压缩变小。 想要将PDF文件压缩到1M,也要根据具体的情…

【LeetCode】丑数题目合辑

文章目录 263. 丑数思路代码 264. 丑数 II方法一:最小堆思路代码 方法二:动态规划(三指针法)思路代码 1201. 丑数 III方法:二分查找 容斥原理思路代码 313. 超级丑数方法:“多路归并”思路代码 总结参考资…

如何压缩照片?一看就会的压缩方法

压缩照片是再正常不过的需求了,比如上传个证件照,要求在20k以内,那么超过这个大小的照片我们就必须进行压缩处理,其实现在压缩照片的方法也特别多,不论是压缩软件、图片编辑软件,甚至在线网站都能搞定。 下…

SpringBoot 将项目打包成 jar 包

SpringBoot 将项目打包成 jar 包 一、项目打包成 jar 包 首先在 pom.xml 文件中导入 Springboot 的 maven 依赖 <!-- 将应用打包成一个可以执行的 jar 包 --> <build><plugins><plugin><groupId>org.springframework.boot</groupId><…

go-admin 使用开发

在项目中使用redis 作为数据缓存&#xff1a;首先引入该包 “github.com/go-redis/redis/v8” client : redis.NewClient(&redis.Options{Addr: config.QueueConfig.Redis.Addr, // Redis 服务器地址Password: config.QueueConfig.Redis.Password, // Redis 密码&…

【LeetCode 75】第二十五题(735)行星碰撞

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码运行结果&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 给一个数组&#xff0c;数组里的元素表示行星&#xff0c;元素的符号决定行星运动的方向&#xff0c;元素的绝对值决定行星的大小…

谷歌发布RT-2大模型,让机器人像人类那样思考

原创 | 文 BFT机器人 大语言模型是指基于深度学习技术的大规模预训练模型&#xff0c;它能够通过学习大量的文本数据来生成人类类似的语言表达&#xff0c;机器人可以通过对大量的语言数据进行学习&#xff0c;从中掌握人类的语言表达方式&#xff0c;进而能够更好地与人进行交…

冠达管理:股票注册制通俗理解?

目前我国A股商场正在进行股票注册制变革&#xff0c;相较之前的发行准则&#xff0c;股票注册制在理念上更为商场化&#xff0c;这意味着公司发行股票的门槛将下降&#xff0c;公司数量将添加&#xff0c;而股票流通的方式也将有所改变。那么股票注册制指的是什么&#xff0c;它…

日常报错记录

日常报错记录 Shorten the command line via JAR manifest or via a classpath file and rerun 解决方法如下&#xff1a;

代码分析:waitpid的使用,非阻塞轮回检测技术

wait 函数 wait函数的作用是父进程调用&#xff0c;等待子进程退出&#xff0c;回收子进程的资源&#xff1b; #include<sys/types.h> #include<sys/wait.h> pid_t wait(int*status);返回值&#xff1a; 成功返回被等待进程pid&#xff0c;失败返回-1。 参数&…

B2B2C小程序商城系统--跨境电商后台数据采集功能开发

搭建一个B2B2C小程序商城系统涉及到多个步骤和功能开发&#xff0c;其中包括跨境电商后台数据采集功能的开发。具体搭建步骤如下&#xff1a; 一、系统搭建 1. 确定需求和功能&#xff1a;根据B2B2C商城的需求&#xff0c;确定系统的功能和模块&#xff0c;包括商品管理、订单…