【center-loss 中心损失函数】 原理及程序解释(更新中)

news2024/9/21 18:35:51

文章目录

  • 前言
  • 问题引出
    • open-set问题
    • 抛出
  • 解决方法
    • softmax函数、softmax-loss函数
    • 解决
    • 代码(center_loss.py)
      • 原理
      • 程序解释
    • 代码运用
  • 如何梯度更新
    • 首先了解一下基本的梯度下降算法
    • 然后
  • 补充:外围知识
    • 模型


前言

学习一下: 中心损失函数,用于用于深度人脸识别的特征判别方法
论文:https://ydwen.github.io/papers/WenECCV16.pdf
github代码:https://github.com/KaiyangZhou/pytorch-center-loss

参考:史上最全MNIST系列(三)——Centerloss在MNIST上的Pytorch实现(可视化)


问题引出

open-set问题

open-set 问题是一种模式识别中的问题,它指的是当训练集和测试集的类别不完全相同的情况。例如,如果训练集只包含 0 到 9 的数字,而测试集包含了 A 到 Z 的字母,那么就是一个 open-set 问题。这种情况下,分类器不仅要正确识别已知的类别,还要能够拒绝未知的类别,即将它们标记为 unknown 或 outlier。(后来我想这就是一个聚类问题)

open-set 问题与 closed-set 问题相对应,closed-set 问题是指训练集和测试集的类别完全相同的情况。例如,如果训练集和测试集都只包含 0 到 9 的数字,那么就是一个 closed-set 问题。这种情况下,分类器只需要正确识别已知的类别即可。

抛出

当我们要预测的人脸不在训练集中出现过时,我们需要让它不识别出,而不要因为与训练过的人脸相像而误判。

对于常见的图像分类问题,我们常常用softmax loss来求损失。以MNIST数据集为例,如果你的损失采用softmax loss,那么最后各个类别学出来的特征分布大概如下图:(在倒数第二层全连接层输出了一个2维的特征向量)
在这里插入图片描述
左图为训练集,右图为测试集,发现结果分类还算不错,但是每一类的界限太过模糊,若从中再加一列,则有可能出现误判。

解决方法

softmax函数、softmax-loss函数

已知softmax的函数为:
在这里插入图片描述
在深度学习的分类问题上,意为对应类的概率(符合每个类的概率相加和为1且0<每个类的概率<1)。(输出层后的结果)
其中 zi​ 是第 i 个输出节点的值,K是输出节点的个数,即分类的类别数。softmax函数的作用是将输出节点的值归一化为范围在 [0, 1] 且和为 1 的概率值,表示属于每个类别的可能性。

softmax的损失函数:(具体详见:一文详解Softmax函数)
在softmax的函数的基础上,我们要求正确对应类的概率最大,
在这里插入图片描述
即,损失函数为:(越小越好)
在这里插入图片描述
其中z = wTx+b,m是小批量的样本的个数,n是类别个数,w是全连接层的权重,b为偏置项,(一般来说w,b是要学习出来的),xi是第i个深层特征,属于第yi类。
在这里插入图片描述

解决

在分类的基础上,我们还要求,每个类,往自己的中心的特征靠拢,使类内间距减少,这样才能更加显出差别。
在这里插入图片描述
于是,论文作者提出如下新的损失函数:
在这里插入图片描述
其中 xi​ 是第 i 个样本的特征向量,cyi​​ 是第 yi​ 个类别的中心向量,m 是样本的个数。
小的注意点:这里用的是样本数,也就是说,是在小批量分类任务完成后再进行的聚类任务
在这里插入图片描述
我们不难发现,作者这里用的是欧式距离的平方,即在多维空间上的两点之间真实距离的平方。
实际上,我们很容易发现,这实际上是一个聚类问题,常见的聚类问题上用的是误差平方和(SSE)损失函数:
在这里插入图片描述

论文作者仅从此推出的欧式距离的平方,仅从形式上十分相像,当然可以作为此聚类问题的解决。
在这里插入图片描述
(在二维上,可以简单看成是直角三角形斜边的平方=两直角边平方和)
为何要加以平方:欧式距离的平方相比欧式距离有一些优点,例如:
1、计算更快,省去了开方的运算。
2、更加敏感,能够放大距离的差异,使得离群点更容易被发现。
3、更加方便,能够与其他平方项相结合,如方差、协方差等。

最后,作者沿用softmax损失函数与中心损失函数相加的方法(在损失函数上很常见),
在这里插入图片描述
这里的1/2很容易想到是作为梯度下降时与后面的平方用来抵消的项,这里λ 可以看作调节两者损失函数的比例

来作为总体的损失函数,作为此问题的解决。

代码(center_loss.py)

原理

首先确定三个事实:
1、在之前的学习中通过实践得知(最小二乘法那块),在拟合的最后结果上,在利用SSE损失函数与MSE损失函数作为损失值参与到程序中时,拟合出的结果并无差别,只有在结果出来后,作为评判模型的好坏时,才有数值上的差别。(两者区别在于是否求平均)。
2、1/2的乘或不乘,只对运算的过程的简便程度有影响,与结果无影响。
3、图像是二维的。

于是我们将作者的的中心损失函数稍加变形。就得到了如下形式(事实上,github上给出的代码就是这么写的)
在这里插入图片描述

程序解释

参考:Center loss-pytorch代码详解
这里只做补充:

import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Center loss.
    # 参考
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    # 参数
    Args:
        num_classes (int): number of classes. # 类别数
        feat_dim (int): feature dimension.  # 特征维度
    """
    # 初始化        默认参数:类别数为10 特征维度为2 使用GPU
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__() # 继承父类的所有属性
        self.num_classes = num_classes 
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu: # 如果使用GPU
            #nn.Parameter()将一个不可训练的tensor转换成可以训练的类型parameter,并将这个parameter绑定到一个module里面,参与到模型的训练和优化中。
            #nn.Parameter的对象的requires_grad属性的默认值是True,即是可被训练的,这与torth.Tensor对象的默认值相反。
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 
            # 初始化中心矩阵 .cuda()表示将数据放到GPU上
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels): # 前向传播
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim). # 特征矩阵
            labels: ground truth labels with shape (batch_size). # 真实标签
        """
        batch_size = x.size(0) #batch_size x的形式为tensor张量
        # .pow对x中的每一个元素求平方 dim=1表示按行求和 keepdim=True表示保持原来的维度 expand是扩展维度
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() #.t()表示转置 均转成batch_size x num_classes的形式
        # .addmm_()表示进行1*distmat + (-2)*x@self.centers.t()的运算    @表示矩阵乘法
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()# 生成一个从0到num_classes-1的整数序列 long表示数据类型
        if self.use_gpu: classes = classes.cuda() 
        #这里 .unsqueeze(0) 例[0,4,2] -> [[0,4,2]]  .unsqueeze(1) 例[0,4,2] -> [[0],[4],[2]] 
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)# .unsqueeze(1)表示在第1维增加一个维度  .expand()表示扩展维度
        mask = labels.eq(classes.expand(batch_size, self.num_classes)) #eq是比较两个tensor是否相等,相等返回1,不相等返回0
        # *表示对应元素相乘
        dist = distmat * mask.float() # mask.float()将mask转换为float类型
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size # clamp表示将dist中的元素限制在1e-12和1e+12之间

        return loss

代码运用

import torch
import torch.nn as nn
from center_loss import CenterLoss

criterion_cent =CenterLoss(10, 2,False) #10个类别,2维
torch.manual_seed(0) #设置随机种子
x = torch.randn(10, 2) # (32, 2) #32个样本,每个样本2维
labels = torch.randint(0, 10, (10,)) # (32,) #32个样本,每个样本一个标签
print(x)
print(labels)
print(criterion_cent(x, labels))

结果:

tensor([[-1.1258, -1.1524],
        [-0.2506, -0.4339],
        [ 0.5988, -1.5551],
        [-0.3414,  1.8530],
        [ 0.4681, -0.1577],
        [ 1.4437,  0.2660],
        [ 1.3894,  1.5863],
        [ 0.9463, -0.8437],
        [ 0.9318,  1.2590],
        [ 2.0050,  0.0537]])
tensor([2, 9, 1, 8, 8, 3, 6, 9, 1, 7])
tensor(2.9281, grad_fn=<DivBackward0>)

如何梯度更新

首先了解一下基本的梯度下降算法

【点云、图像】学习中 常见的数学知识及其中的关系与python实战
里的小标题:基于迭代的梯度下降算法,了解到超参数(人为设定的参数):
学习率,w,b等。此算法为拟合出一条线。

伪代码:
1、未达到设定迭代次数
2、迭代次数epoch+1
3、计算损失值
4、计算梯度
5、更新w、b
6、达到迭代次数,结束

然后

看下这里的更新方法
在这里插入图片描述
其中,卷积层中初始化的参数 θc,超参数 :λ、α 和学习率 μ ,迭代次数 t 。(λ、α、 μ均可调)
伪代码:
1、当未收敛时:
2、迭代次数t+1
3、计算损失函数 L=Ls+Lc
4、计算反向传播误差(即梯度)
5、更新w
6、更新cj
7、更新θc
8、结束

其中
在这里插入图片描述
其中,如果满足条件,则 δ(条件) = 1,如果不满足,则 δ(条件) = 0。

首先:第一个公式不用多说,为中心损失函数对特征xi求偏导,即求梯度。
其次:第二个公式:加入了判断,当条件满足时(即yi=j,即就是在同一类时)Δcj=cj-xi 的小批量的所有和/m+1。相当于累加了误差求平均,以此来作为梯度。
当条件不满足时,即Δcj =0,此时这个分母的+1的作用就体现出来了,分母不能为0嘛。

补充:
第5步后面一个等号成立是因为Lc中没有W项,所以Lc对W的求导为0
第7步也是一样,求梯度,只是写成了求导的链式法则。

补充:外围知识

模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1481291.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你真的了解C语言中的【柔性数组】吗~

柔性数组 1. 什么是柔性数组2. 柔性数组的特点3. 柔性数组的使用4. 柔性数组的优势 1. 什么是柔性数组 也许你从来没有听说过柔性数组这个概念&#xff0c;但是它确实是存在的。 C99中&#xff0c;结构体中的最后⼀个元素允许是未知大小的数组&#xff0c;这就叫做柔性数组成员…

2024年腾讯云优惠券领取入口、查看使用和常见问题解答FAQ

腾讯云代金券领取渠道有哪些&#xff1f;腾讯云官网可以领取、官方媒体账号可以领取代金券、完成任务可以领取代金券&#xff0c;大家也可以在腾讯云百科蹲守代金券&#xff0c;因为腾讯云代金券领取渠道比较分散&#xff0c;腾讯云百科txybk.com专注汇总优惠代金券领取页面&am…

力扣hot100题解(python版36-40题)

36、二叉树的中序遍历 给定一个二叉树的根节点 root &#xff0c;返回 它的 *中序 遍历* 。 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,3,2]示例 2&#xff1a; 输入&#xff1a;root [] 输出&#xff1a;[]示例 3&#xff1a; 输入&am…

Mysql主从备份

主从复制概述 将主服务器的binlog日志复制到从服务器上执行一遍&#xff0c;达到主从数据的一致状态&#xff0c;称之为主从复制。一句话表示就是&#xff0c;主数据库做什么&#xff0c;从数据库就跟着做什么。 为什么要使用主从复制 为实现服务器负载均衡/读写分离做铺垫&…

队列的结构概念和实现

文章目录 一、队列的结构和概念二、队列的实现三、队列的实现函数四、队列的思维导图 一、队列的结构和概念 什么是队列&#xff1f; 队列就是只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c;队列具有先进先出 如上图所示&#x…

Ansible的playbook的编写和解析

目录 什么是playbook Ansible 的脚本 --- playbook 剧本 实例部署&#xff08;使用playbook安装启动httpd服务&#xff09; 1.编写一个.yaml文件 在主机下载安装http&#xff0c;将配置文件复制到opt目录下 运行playbook 在192.168.17.77主机上查看httpd服务是否成功开启…

BY组态功能清单

演示地址 &#xff1a;http://www.byzt.net:60/sm/ 官网地址&#xff1a;http://www.hcy-soft.com BY组态是一款非常优秀的纯前端的【web组态插件工具】&#xff0c;可无缝嵌入到vue项目&#xff0c;react项目等&#xff0c;由于是原生js开发&#xff0c;对于前端的集成没有框架…

vue2 引入阿里图标库iconfont

有时候我们使用的 ui 里面图标是不够丰富的&#xff0c;不一定可以满足我们的需求。 这时我们可以引入阿里图标库来丰富自己的项目图标。 上步骤&#xff1a; 1、点击文字地址连接&#xff1a; iconfont-阿里巴巴矢量图标库 2、登录&#xff0c;没有账号需要先使用手机号注册…

【Java题】调整奇数位于偶数之前(超简单版)

题目&#xff1a; 调整数组顺序使得奇数位于偶数之前。调整之后&#xff0c;不关心大小顺序。 如数组&#xff1a;[1,2,3,4,5,6,7,8,9] 调整后可能是&#xff1a;[1, 9,3,7,5, 6, 4, 8, 2] 代码&#xff1a; import java.util.Arrays;public class Main {public static voi…

MySQL存储引擎及索引机制

MySQL技术——存储引擎和索引机制 一、存储引擎概述二、常见存储引擎的区别三、索引机制四、索引的底层实现原理五、InnoDB主键和二级索引六、聚集索引和非聚集索引七、哈希索引八、InnoDB的自适应哈希索引九、索引常见问题十、慢查询日志总结 一、存储引擎概述 插件式存储引擎…

b站小土堆pytorch学习记录——P7-P8 Tensorboard的使用

文章目录 一、前置知识1.Tensorboard是什么2.SummaryWriter3.add_scalar()4.add_image() 二、代码1.一次函数2.蚂蚁和蜜蜂图片 一、前置知识 1.Tensorboard是什么 TensorBoard 是 TensorFlow 的可视化工具&#xff0c;它允许开发者可视化模型的图&#xff08;graph&#xff0…

一线互联网大厂中高级Android面试真题收录,记一次字节跳动Android社招面试

在开始回答前&#xff0c;先简单概括性地说说Linux现有的所有进程间IPC方式&#xff1a; 1. **管道&#xff1a;**在创建时分配一个page大小的内存&#xff0c;缓存区大小比较有限&#xff1b; 2. 消息队列&#xff1a;信息复制两次&#xff0c;额外的CPU消耗&#xff1b;不合…

驱动高级--mknod

一、起源 仅devfs&#xff0c;导致开发不方便以及一些功能难以支持&#xff1a; 热插拔 不支持一些针对所有设备的统一操作&#xff08;如电源管理&#xff09; 不能自动mknod 用户查看不了设备信息 设备信息硬编码&#xff0c;导致驱动代码通用性差&#xff0c;即没有分离…

【Python】成功解决ValueError: not enough values to unpack (expected 2, got 1)

【Python】成功解决ValueError: not enough values to unpack (expected 2, got 1) &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&am…

本科毕业设计:计及并网依赖性的分布式能源系统优化研究。(C语言实现)(内包含NSGA II优化算法)(二)

目录 前言 1、sofc函数 2、光伏板函数 3、集热场函数 4、sofc电跟随策略函数 5、二分法找sofc运行点函数 6、目标函数&#xff1a;成本 7、目标函数&#xff1a;二氧化碳排放量 8、目标函数&#xff1a;并网依赖性 前言 本篇文章介绍的是我的毕业设计&#xff0c;我将C…

【前端素材】推荐优质后台管理系统Annex平台模板(附源码)

一、需求分析 1、系统定义 后台管理系统是一种用于管理网站、应用程序或系统的管理界面&#xff0c;通常由管理员和工作人员使用。它提供了访问和控制网站或应用程序后台功能的工具和界面&#xff0c;使其能够管理用户、内容、数据和其他各种功能。 2、功能需求 后台管理系…

基于java+springboot动物检疫信息管理系统设计和实现

基于java SSM springboot动物检疫信息管理系统设计和实现 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文…

3d模型版本转换器注意事项---模大狮模型网

在使用3D模型版本转换器时&#xff0c;有一些注意事项可以帮助您顺利完成模型转换并避免不必要的问题&#xff1a; 数据完整性&#xff1a;在进行模型转换之前&#xff0c;确保您的原始3D模型文件没有损坏或缺失数据。损坏的文件可能导致转换器无法正常处理或输出错误的结果。 …

公钥密码体制

公钥密码体制 一个系统中,n个用户之间要进行保密通信,为了确保安全性,两两用户之间的密钥不能一样。这种方式下,需要系统提供C2 n=n(n-1)/2把共享密钥。这样密钥的数量就大幅增加了,随之而来的产生、存储、分配、管理密钥的成本也大幅增加。而使用公钥密码体制可以大大减…

[unity] c# 扩展知识点其一 【个人复习笔记/有不足之处欢迎斧正/侵删】

.NET 微软的.Net既不是编程语言也不是框架,是类似于互联网时代、次时代、21世纪、信息时代之类的宣传口号,是一整套技术体系的统称&#xff0c;或者说是微软提供的技术平台的代号. 1.跨语言 只要是面向.NET平台的编程语言(C#、VB、 C、 F#等等)&#xff0c;用其中一种语言编写…