ReID的各种Loss的pytorch实现

news2025/1/13 2:49:16

为了提高ReID的性能通常会采用softmax loss 联合 Triplet Loss和Center Loss来提高算法的性能

本文对Triplet Loss和Cnetr Loss做一个总结,以简洁的方式帮助理解。

Triplet Loss和Center Loss都是从人脸识别领域里面提出来的,后面在各种图像检索任务中被广泛应用。

想要了解Triplet Loss和Center Loss算法原文的可以看《FaceNe: Triplet Loss》、 《Center Loss》,对论文做了详细翻译。

1. Triplet Loss

1.1 原理

如上图所示,Triplet Loss 是有一个三元组<a, p, n>构成,其中

a: anchor 表示训练样本。

p: positive 表示预测为正样本。

n: negative 表示预测为负样本。

    triplet loss的作用:用于减少positive(正样本)与anchor之间的距离,扩大negative(负样本)与anchor之间的距离。基于上述三元组,可以构建一个positive pair <a, p>和一个negative pair <a, n>。triplet loss的目的是在一定距离(margin)上把positive pair和negative pair分开。

  所以我们希望:D(a, p) < D(a, n)。进一步希望在一定距离上(margin) 满足这个情况:D(a, p)  + margin  <  D(a, n)

对于一个样本经过网络有: 

 训练时有这么几种情况:

(a)easy triplets:loss = 0,D(a, p) + margin < D(a, n),positive pair 的距离远远小于于negative pair的距离。即,类内距离很小,类间很大距离,这种情况不需要优化。

(b)hard tripletsD(a, n)   <  D(a, p) ,positive pair 的距离大于于negative pair的距离,即类内距离大于类间距离。这种情况比较难优化。

(c)semi-hard tripletsD(a, p) < D(a, n) < D(a, p) + margin。positive pair的距离和negative pair的距离比较高近。即,<a, p>和<a, n>很近,但都在一个margin内。比较容易优化。

当为 semi-hard triplets 时, D(a, p) + margin -  D(a, n) > 0产生loss。得到要优化的损失函数。
 

对于Triplet Loss的梯度: 

训练的时候:早期为了网络loss平稳,一般选择easy triplets进行优化,后期为了优化训练关键是要选择hard triplets,他们是活跃的,因此可以帮助改进模型。

1.2 代码实现

class TripletLoss(nn.Module):
    """
    Triplet loss with hard positive/negative mining.
    
    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    
    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.
    
    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """

    def __init__(self,margin = 0.3,gloal_feat,labels):
        super(TripletLoss,self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin = margin)

    def forward(self,inputs,targets):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (num_classes).
        """
        n = inputs.size(0)
        
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs,2).sum(dim = 1,keepdim = True).expand(n,n)
        dist = dist + dist.t()
        dist.addmn_(1,-2,inputs,inputs.t())
        dist = dist.clamp(min = 1e - 12).sqrt()  # for numerical stability       
        
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n,n).eq(targets.expand(n,n).t())
        dist_ap,dist_an = [],[]
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an,dist_ap,y)

训练的时候对每一个样本选择hardest triplet进行训练。

2. Triplet Loss

2.1 原理

center loss是在triplet之后提出来的。triplet学习的是样本间的相对距离,没有学习绝对距离,尽管考虑了类间的离散性,但没有考虑类内的紧凑性。对于triplet loss举一个例子。设margin = 0.3,D(a, p) = 0.3 , D(a, n) = 0.5 得triplet loss = 0.1。而当D(a, p) = 1.3 D(a, n) = 1.5时,triplet loss仍然等于0.1,这相当于,内类之间不够紧凑(距离还不够小)。

所以Center Loss希望可以通过学习每个类的类中心,使得类内的距离变得更加紧凑。

表示深度特征的第类中心。理想情况下, 应该随着深度特征的变化而更新。

训练时:

第一是基于mini-batch执行更新。在每次迭代中,计算中心的方法是平均相应类的特征(一些中心可能不会更新)。

第二,避免大扰动引起的误标记样本,用一个标量 α 控制中心的学习速率,一般这个α 很小(如,0.005)。

计算 相对于的梯度和的更新方程为

2.2 代码实现

class CenterLoss(nn.Module):
    """Center loss.
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
   def __init__(self,num_classes = 751,feat_dim = 2048,use_gpu = True):
        super(CenterLoss,self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes,self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes,self.feat_dim))

    def forward(self,x,labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (num_classes).
        """
        assert x.label(0) == labels.size(0) "features.size(0) is not equal to labels.size(0)"
        batch_size = x.size(0)
        dismat = torch.pow(x,2).sum(dim = 1,keepdim = True).expand(batch_size,self.num_classes) + \
                 torch.pow(self.centers,2).sum(dim  = 1, keepdim = True).expand(self.num_classes,batch_size).t()
        dismat.addmm_(1,-2,x,self.centers.t()) 

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size,self.num_classes)
        mask = labels.eq(classes.expand(batch_size,self.num_classes))
        print(mask)

        dist = []
        for i in range(batch_size):
            print(mask[i])
            value = dismat[i][mask[i]]
            value = value.clamp(min = 1e - 12,max = 1e +12) #for numerical stability
            dist.append(value)
        dist = torch.cat(dist)
        loss = dist.mean()
        return loss      

3. OIM Loss

3.1 原理

4. Circle Loss

Circle Loss是Triplet Loss的改进版

4.1 原理

5. ArcFace Loss

5.1 原理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1096945.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

笔记本Win10系统一键重装操作方法

笔记本电脑已经成为大家日常生活和工作中必不可少的工具之一&#xff0c;如果笔记本电脑系统出现问题了&#xff0c;那么就会影响到大家的正常操作。这时候就可以考虑给笔记本电脑重装系统了。接下来小编给大家介绍关于一键重装Win10笔记本电脑系统的详细步骤方法。 推荐下载 系…

遗传算法------微生物进化算法(MGA)

前言 该文章写在GA算法之后&#xff1a;GA算法 遗传算法 (GA)的问题在于没有有效保留好的父母 (Elitism), 让好的父母不会消失掉. Microbial GA (后面统称 MGA) 就是一个很好的保留 Elitism 的算法. 一句话来概括: 在袋子里抽两个球, 对比两个球, 把球大的放回袋子里, 把球小…

Qt中各个功能模块遵循的协议

Qt 中各个模块的协议&#xff0c;是在变化的&#xff0c;并不是一成不变 不同版本&#xff0c;协议有可能会变。同一版本&#xff0c;在不同时间期间&#xff0c;协议也可能会变 具体以官网为准

搜索引擎站群霸屏排名源码系统+关键词排名 前后端完整的搭建教程

开发搜索引擎站群霸屏排名系统是一项重要的策略&#xff0c;通过在搜索引擎中获得多个高排名站点&#xff0c;可以大大提高企业的品牌知名度&#xff0c;从而吸引更多的潜在客户和消费者。而且当潜在客户在搜索结果中看到多个与您的品牌相关的站点时&#xff0c;他们可能会认为…

EtherCAT报文-BRD(广播读)抓包分析

0.工具准备 1.EtherCAT主站 2.EtherCAT从站&#xff08;本文使用步进电机驱动器&#xff09; 3.Wireshark1.EtherCAT报文帧结构 EtherCAT使用标准的IEEE802.3 Ethernet帧结构&#xff0c;帧类型为0x88A4。EtherCAT数据包括2个字节的数据头和44-1498字节的数据。数据区由一个或…

ST-SSL:基于自监督学习的交通流预测模型

文章信息 文章题为“Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction”&#xff0c;是一篇发表于The Thirty-Seventh AAAI Conference on Artificial Intelligence (AAAI-23)的一篇论文。该论文主要针对交通流预测任务&#xff0c;结合自监督学习&#…

EtherCAT报文-BWR(广播写)抓包分析

0.工具准备 1.EtherCAT主站 2.EtherCAT从站&#xff08;本文使用步进电机驱动器&#xff09; 3.Wireshark1.EtherCAT报文帧结构 EtherCAT使用标准的IEEE802.3 Ethernet帧结构&#xff0c;帧类型为0x88A4。EtherCAT数据包括2个字节的数据头和44-1498字节的数据。数据区由一个或…

【2023研电赛】全国技术竞赛一等奖:基于FPGA的超低时延激光多媒体终端

该作品参与极术社区组织的研电赛作品征集活动&#xff0c;欢迎同学们投稿&#xff0c;获取作品传播推广&#xff0c;并有丰富礼品哦~ 基于FPGA的超低时延激光多媒体终端 参赛单位&#xff1a;华东师范大学 指导老师&#xff1a;刁盛锡 参赛队员&#xff1a;王泽宇 谢祖炜 秦子淇…

解读 | 自动驾驶系统中的多视点三维目标检测网络

原创 | 文 BFT机器人 01 背景 多视角三维物体检测网络&#xff0c;用于实现自动驾驶场景高精度三维目标检测&#xff0c;该网络使用激光雷达点云和RGB图像进行感知融合&#xff0c;以预测定向的三维边界框&#xff0c;相比于现有技术&#xff0c;取得了显著的精度提升。同时现…

【重要!合规政策更新】英国,儿童玩具相关产品卖家,请及时关注!EN71

合规政策更新&#xff01; 尊敬的卖家&#xff1a; 您好&#xff01; 我们此次联系您是因为您正在销售需要审批流程的商品。为此&#xff0c;亚马逊正在实施审批流程&#xff0c;以确认我们网站上提供的商品类型须符合指定的认证标准。要在亚马逊商城销售这些商品&#xff0c;您…

第六篇Android--ImageView、Bitmap

ImageView&#xff0c;和前面介绍的TextView、EditText&#xff0c;都继承自View都是View的子类。 ImageView 是用于呈现图片的视图。View可以理解为一个视图或控件。 1.简单使用 在drawable-xxhdpi文件夹下放一张图片&#xff1a; xml中把这张图片设置给ImageView&#xff0…

MySQL单表查询基础综合练习

一、单表查询 素材&#xff1a; 表名&#xff1a;worker-- 表中字段均为中文&#xff0c;比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float(8,2) NOT NULL, 政治面貌 v…

三勾知识付费(PHP+vue3)微信小程序平台+SAAS+前后端源码

项目介绍 三勾小程序商城基于thinkphp8element-plusuniapp打造的面向开发的小程序商城&#xff0c;方便二次开发或直接使用&#xff0c;可发布到多端&#xff0c;包括微信小程序、微信公众号、QQ小程序、支付宝小程序、字节跳动小程序、百度小程序、android端、ios端。 软件架…

BUUCTF学习(二):一起来撸猫

1、介绍 2、解题 &#xff08;1&#xff09;查看网页源代码 &#xff08;2&#xff09;解读代码内容 &#xff08;3&#xff09;得出结论 网址&#xff1a;一起来撸猫http://df4c147d-c7f4-4aac-a9d6-fdce2606ee18.node4.buuoj.cn:81/?catdog 第二题结束

PyTorch入门教学——在虚拟环境中安装Jupyter

1、简介 Jupyter Notebook是一个开源的web应用程序&#xff0c;可以使用它来创建和共享包含实时代码、方程、可视化和文本的文档。Jupyter Notebook是一个交互式笔记本&#xff0c;可以当作python编译器来使用。 2、安装 在安装Anaconda时是自带了Jupyter Notebook的&#x…

Unity第一人称移动和观察

创建一个可以自由移动的第一人称视角 人物通过WSAD进行前后左右移动&#xff0c;通过鼠标右键进行旋转 Step1:创建一个Player玩家&#xff0c;在节点下加两个子物体&#xff0c;一个摄像头和一个Capsule充当身体 Step2:创建一个脚本挂载在Player节点下&#xff0c;再在这个Pl…

ThreeJS-3D教学十-有宽度的line

webgl中线是没有宽度的&#xff0c;现实的应用中一般做法都是将线拓宽成面来绘制。默认threejs的线宽是无法调节的&#xff0c;需要用有厚度的线 THREE.Line2。 先看效果图&#xff1a; 看下代码&#xff1a; <!DOCTYPE html> <html lang"en"> <he…

2022年03月 Python(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

Python编程&#xff08;1~6级&#xff09;全部真题・点这里 C/C编程&#xff08;1~8级&#xff09;全部真题・点这里 一、单选题&#xff08;共25题&#xff0c;每题2分&#xff0c;共50分&#xff09; 第1题 关于Python中的列表&#xff0c;下列描述错误的是?&#xff08; …

BUUCTF在线评测简介

1、网站 BUUCTF在线评测简介 BUUCTF在线评测 BUUCTF在线评测BUUCTF 是一个 CTF 竞赛和训练平台&#xff0c;为各位 CTF 选手提供真实赛题在线复现等服务。https://buuoj.cn/challenges 2、介绍 3、学习步骤 学习从这里开始&#xff01;

深入剖析 深度学习中 __init()__函数和forward()函数

目录 前言1. __init()__函数2. forward()函数3. 两者关系 前言 再看代码时&#xff0c;发现init函数和forward函数都有参数&#xff0c;具体是怎么传参的呢&#xff1f; 为了更方便的讲解&#xff0c;会举简单的代码例子结合讲解。 forward() 和 __init__() 是神经网络模型类…