【人脸识别】CurricularFace:自适应课程学习人脸识别损失函数

news2024/9/26 3:28:49

论文题目:《CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition》
论文地址:https://arxiv.org/pdf/2004.00288v1.pdf
代码地址:https://github.com/HuangYG123/CurricularFace

建议先了解下这篇文章:MV-softmax

1.背景

       人脸识别中常用损失函数主要包括两类,基于间隔和难样本挖掘,这两种方法损失函数的训练策略都存在缺陷。前一种方法是对所有样本都采用一个固定的间隔值,没有充分利用每个样本自身的难易信息,这可能导致在使用大边际时出现收敛问题;后一种方法则在整个网络训练周期都强调难样本,可能出现网络无法收敛问题。在本论文中,提出了一种新的自适应课程学习损失函数,称为CurricularFace,它能够很好地解决上述两类损失函数存在的问题。
       下图是CurricularFace跟ArcFace和 MV-Arc-Softmax两种方法的对比,可以看到CurricularFace的优势还是很明显的,通过自适应的方式实现,在早期突出易样本的作用(红色虚线),而在晚期突出难样本的作用(红色实线)

在这里插入图片描述

注:Curriculum Learning即课程学习,它是由Montreal大学的Bengio教授团队在2009年的ICML上提出的,其主要思想是模仿人类学习的特点,按照从简单到困难的程度来学习课程,这样容易使模型找到更好的局部最优,同时加快训练速度。

– MV-Sotamax存在的问题:从training起始阶段就开始强调semi-hard/hard-sample,可能会导致模型的收敛问题!

easy sample first, hard sample later!

2.方法

       论文中提出的一种新的自适应课程学习损失CurricularFace,是将课程学习的思想嵌入到损失函数中,以实现一种新的深度人脸识别训练策略。该策略主要针对早期训练阶段的易样本和后期训练阶段的难样本,使其在不同的训练阶段,通过一个课程表自适应地调整简单和困难样本的相对重要性。也就是说,在每个阶段,不同的样本根据其相应的困难程度被赋予不同的重要性。
       由于人类学习的本质是先易后难,CurricularFace是以一种适应性的方式将课程学习的理念融入到人脸识别中,这与传统的认知有两处明显不同:
       1)首先,课程设计的自适应性。在传统的课程学习中,样本是按照相应的难易程度排序的,这些难易程度往往是由先验知识定义的,然后固定下来建立课程。而在CurricularFace中,做法是由每个Batch随机抽取样本,通过在线挖掘难样本自适应地建立课程
       2)其次,难样本的重要性是自适应的。一方面,易样本和难样本的相对重要性是动态的,可以在不同的训练阶段进行调整。另一方面,当前Batch中每一个难样本的重要性取决于其自身的难易程度。
       具体来看,文中选择Batch中的被误分类样本作为难样本,通过调整样本与假类别中心向量之间的余弦相似度的调制系数来加权。为了在整个训练过程中实现自适应课程学习的目标,论文设计了一种新的系数函数,该函数包括以下两个因子:
       1)自适应估计参数t,该参数利用样本和其真类别间的Positive余弦相似度的移动平均值来实现自适应,以消除人工调整的负担。
       2)余弦角度参数,该参数定义难样本实现自适应分配的的难易性。

       上面介绍完了CurricularFace的基本原理,我们来看下其损失函数是如何定义的,如下:
在这里插入图片描述
其中,T(cos(θ_y)) = cos(θ_y + m), I (t, cos(θ_j))表示样本的权重函数,N(t, cos(θ_j))定义如下:

在这里插入图片描述

Adaptive Estimation of t.
       在不同的训练阶段决定一个恰当的t的值是十分重要的。理想情况下,t的值能够指示模型的训练阶段。我们通过经验发现正cosine相似度的平均值是一个好的指示器。可是min-batch的基于统计的方法往往面临一个问题:当许多极端数据被采样到一个mini-batch时,统计可能是一个很大的噪声,估计值可能很不稳定。Exponential Moving Average (EMA)方法是一个常用的解决该问题的方法,假设r(k)是第k个batch的正cosine相似度的平均值,r^(0) = 0,即:

在这里插入图片描述
则有(t^(k)随着k的增加,会呈现出单调递增的趋势):
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
Note : (a, b), a表示在训练过程中[某个时刻] curricular_loss和arcface-loss的比值;b表示max {cos(θ_j), j ≠ yi}

3.训练

3.1.训练步骤

在这里插入图片描述

3.2.训练曲线

在这里插入图片描述
在这里插入图片描述
1.x-axis : iterations, y-axis : 难样本的调整系数
2. t:adaptive parameter; M : MV-Arc-Softmax; M(ours) : gradient modulation coefficients
3.在训练早期,t --> 0,模型可以利用easy-sample加速收敛;在训练中后期t不断增大使得I(t, cos(θ_j)) > 1,这样模型可以更多地关注hard-smaples.

4.实验

       从Figure 4中可以看到,在整个训练阶段,CurricularFace对于难样本的决策边界从训练早期到后期自适应性的变化。
在这里插入图片描述
       最终,与其它方法相比,CurricularFace下的人脸识别效果得到明显改善(如Table4与Table6)
在这里插入图片描述
在这里插入图片描述

5.结论

       论文提出的自适应课程学习损失CurricularFace,将自适应课程学习的思想嵌入到人脸识别中。该方法易于实现,收敛性强,能够明显的提升人脸识别的准确率,而且它解决的是经常在训练过程中出现的问题(如:大边际和难样本),因而具备很高的实用价值。

pytorch代码:

class CurricularFace(nn.Module):
    """Implementation for "CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition".
    """
    def __init__(self, in_features, out_features, device_id=None, m = 0.5, s = 64., fp16 = False):
        super(CurricularFace, self).__init__()
        self.device_id = device_id
        self.fp16 = fp16

        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
     
        self.kernel = Parameter(torch.FloatTensor(out_features, in_features))
        self.register_buffer('t', torch.zeros(1))
        nn.init.xavier_uniform_(self.kernel)     
        #self.kernel = Parameter(torch.Tensor(in_features, out_features))
        #self.register_buffer('t', torch.zeros(1))
        #nn.init.normal_(self.kernel, std=0.01)

    def forward(self, feats, labels):
        #kernel_norm = F.normalize(self.kernel, dim=0)
        #feats = F.normalize(feats)
        #cos_theta = torch.mm(feats, kernel_norm)
        sub_weights = torch.chunk(self.kernel, len(self.device_id), dim=0)
        temp_x = feats.cuda(self.device_id[0])
        weight = sub_weights[0].cuda(self.device_id[0])
        cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))
        for i in range(1, len(self.device_id)):
            temp_x = x.cuda(self.device_id[i])
            weight = sub_weights[i].cuda(self.device_id[i])
            cos_theta = torch.cat((cos_theta, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])), dim=1)
        cos_theta = cos_theta.clamp(-1.0, 1.0)  # for numerical stability
        with torch.no_grad():
            origin_cos = cos_theta.clone()
        target_logit = cos_theta[torch.arange(0, temp_x.size(0)), labels].view(-1, 1)

        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m #cos(target+margin)
        mask = cos_theta > cos_theta_m
        if self.fp16:
            cos_theta_m = cos_theta_m.half()
        final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)

        hard_example = cos_theta[mask]
        with torch.no_grad():
            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
        if self.fp16:
            self.t = self.t.half()
        cos_theta[mask] = hard_example * (self.t + hard_example)
        if self.device_id != None:
            cos_theta = cos_theta.cuda(self.device_id[0])
        cos_theta.scatter_(1, labels.view(-1, 1).long(), final_target_logit)
        output = cos_theta * self.s
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/371911.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电子技术——频率补偿

电子技术——频率补偿 在本节我们介绍修改三极点或多极点放大器的开环增益函数 A(s)A(s)A(s) 的方法,使得闭环增益在我们希望的值上放大器是稳定的。这个过程称为频率补偿。 理论 最简单的频率补偿方法是引入新的极点,如图下面是一个放大器的伯德图&am…

windows安装Ubuntu子系统以及图形化界面记录

文章目录1. windows环境设置2. 开始安装3. ubuntu使用3.1 启动和退出 Linux 子系统3.2 安装位置3.3 更换源4. 安装图形化界面4.1 安装VcXsrv4.2 安装桌面环境(1)方法1:VcXsrv Gnome(2)方法2:VcXsrv Xfce4…

Python到底牛在哪?现在就业薪资高吗?

Python是什么呢?Python是一种全栈的开发语言,你如果能学好Python,前端,后端,测试,大数据分析,爬虫等这些工作你都能胜任。当下Python有多火我不再赘述,Python有哪些作用呢?据我多年P…

GoogleTest中gMock的使用

GoogleTest中的gMock是一个库,用于创建mock类并使用它们。 当你编写原型或测试(prototype or test)时,完全依赖真实对象通常是不可行或不明智的(not feasible or wise)。模拟对象(mock object)实现了与真实对象相同的接口,但是需要你在运行时指定它…

SpringCloud学习笔记 - Sentinel流控规则配置的持久化 - Sentinel

1. 为什么要将流控规则持久化 默认的的流控规则是配置在sentinel中的,又因为sentinel是懒加载的,只有当我们访问了一个请求的时候,sentinel才能监控到我们的簇点链路,我们才能对该链路进行流控配置,一旦我们重启应用s…

GNN专栏总览

文章目录图卷积神经网络1. 理论篇2. 模型篇3. 有关gnn的论文检索图卷积神经网络 1. 理论篇 原理:http://xtf615.com/2019/02/24/gcn/论文: 综述类: HOW POWERFUL ARE GRAPH NEURAL NETWORKS?Bridging the Gap between Spatial and Spectra…

PHP实现个人免签约微信支付接口原理+源码

什么是个人免签支付 个人免签支付就是给个人用的支付接口,一般的支付接口都需要营业执照才能申请,个人很难申请的到,或者是没有资质去申请,要和支付商进行签约的。免签,顾名思义就是不需要签约。那么个人免签支付就有…

企业数字化运营平台软件开发框架项目

【版权声明】本资料来源网络,知识分享,仅供个人学习,请勿商用。【侵删致歉】如有侵权请联系小编,将在收到信息后第一时间删除!完整资料领取见文末,部分资料内容: 目录 1 项目总体概述 1.1 项目…

Unity Avatar Camera Controller 第一、第三人称相机控制

文章目录简介Variables实现Target PositionTarget RotationOthers简介 本文介绍如何实现用于Avatar角色的相机控制脚本,支持第一人称、第三人称以及两种模式之间的切换,工具已上传至SKFramework框架的Package Manager中: Variables Avatar&…

51单片机入门 - 简短的位运算实现扫描矩阵键盘

介绍 例程使用 SDCC 编译、 stcgal 烧录,如果你想要配置一样的环境,可以参考本专栏的第一篇文章“51单片机开发环境搭建 - VS Code 从编写到烧录”,我的设备是 Windows 10,使用普中51单片机开发板(STC89C52RC&#xf…

Qt编写微信支付宝支付

文章目录一 微信支付配置参数二 支付宝支付配置参数三 功能四 Demo效果图五 体验地址一 微信支付配置参数 微信支付API,需要三个基本必填参数。 微信公众号或者小程序等的appid;微信支付商户号mchId;微信支付商户密钥mchKey; 具…

文件基础IO

目录 前言 用库进行文件操作 文件描述符 理解Linux一切皆文件 缓冲区 认识缓冲区 缓冲区缓冲策略 磁盘结构 磁盘分区 软链接和硬链接 硬链接本质 软连接本质 动态库和静态库进阶 写一个静态库 动态库的产生和使用 动静态库的加载 总结: 前言 在我们了…

SE | 哇哦!让人不断感叹真香的数据格式!~

1写在前面 最近在用的包经常涉及到SummarizedExperiment格式的文件,不知道大家有没有遇到过。🤒 一开始觉得这种格式真麻烦,后面搞懂了之后发现真是香啊,爱不释手!~😜 2什么是SummarizedExperiment 这种cla…

lighthouse的介绍和基本使用方法

Lighthouse简介 Lighthouse是一个开源的自动化性能测试工具,我们可以使用该功能检测我们的页面存在那些性能方面的问题,并会生成一个详细的性能报告来帮助我们来优化页面 使用方式 LH一共有四种使用方式 Chrome开发者工具Chrome扩展Node 命令行Node …

数据结构与算法(一)-软件设计(十七)

设计模式(十五)-面向对象概念https://blog.csdn.net/ke1ying/article/details/129171047 数组 存储地址的计算: 一维数组a[n],当a[2]的存储地址为:a2*len,如果每一个数组元素只占用一个字节,那…

Spring Batch 高级篇-分区步骤

目录 引言 概念 分区器 分区处理器 案例 转视频版 引言 接着上篇:Spring Batch 高级篇-并行步骤了解Spring Batch并行步骤后,接下来一起学习一下Spring Batch 高级功能-分区步骤 概念 分区:有划分,区分意思,在…

中国ETC行业市场规模及未来发展趋势

中国ETC行业市场规模及未来发展趋势编辑根据市场调研在线网发布的2023-2029年中国ETC行业发展策略分析及战略咨询研究报告分析:随着政府坚持实施绿色出行政策,ETC行业也受到了极大的支持。根据中国智能交通协会统计,2017年中国ETC行业市场规模…

浅析Linux内核进程间通信(信号量)

信号灯与其他进程间通信方式不大相同,它主要提供对进程间共享资源访问控制机制。相当于内存中的标志,进程可以根据它判定是否能够访问某些共享资源(临界区,类似于互斥锁),同时,进程也可以修改该…

FreeRTOS任务基础知识

单任务和多任务系统单任务系统单任务系统的编程方式,即裸机的编程方式,这种编程方式的框架一般都是在main()函数中使用一个大循环,在循环中顺序的执行相应的函数以处理相应的事务,这个大循环的部分可以视为…

Linux内核共享内存使用常见陷阱与分析

所谓共享内存就是使得多个进程可以访问同一块内存空间,是最快的可用IPC形式。是针对其他通信机制运行效率较低而设计的。往往与其它通信机制,如 信号量结合使用,来达到进程间的同步及互斥。其他进程能把同一段共享内存段“连接到”他们自己的…