Supervised Contrastive 损失函数详解

news2025/1/12 15:56:55

在这里插入图片描述
有什么不对的及时指出,共同学习进步。(●’◡’●)

有监督对比学习将自监督批量对比方法扩展到完全监督设置,能够有效地利用标签信息。属于同一类的点簇在嵌入空间中被拉到一起,同时将来自不同类的样本簇推开。这种损失显示出对自然损坏很稳健,并且对优化器和数据增强等超参数设置更稳定。

有监督对比学习论文的贡献

  1. 提出了对比损失函数一种新的扩展,允许每个锚点都有多个正样本,使对比学习适应完全监督设置。
  2. 该损失为很多数据集的top-1的准确率带来了提升,对自然损坏有稳健性。
  3. 损失函数的梯度鼓励从硬正样本和硬的负样本中学习。(硬的正样本与锚点图像不相似的正样本,硬的负样本就是与锚点图像相似的负样本,都是难以学习的那种)
  4. 对比损失函数不如交叉熵损失函数对超参数敏感。

自监督对比学习损失
在这里插入图片描述
有监督对比学习损失
在这里插入图片描述
文中对交叉熵损失训练,自监督对比损失训练和有监督对比损失训练进行比较
在这里插入图片描述
推理模型中的参数个数始终保持不变,应该是推理的时候就是编码器+分类头都一样。
上图是训练的时候,交叉熵损失不必说。
自监督损失一般采用的是个体判别代理任务,正样本是自身经过数据增强后的图像(一般一个正样本),其他的都是负样本,训练编码器的时候让正样本和锚点图像经过编码器得到的特征尽可能接近,与负样本之间的特征尽可能拉远。
有监督对比学习,有标签信息,正样本除了自身数据增强后的之外还有这个类别中的其他样本(一般这个batch_size中)。
stage1就是训练编码器。
stage2是训练分类头,作者指出不需要训练线性分类器,并且先前的工作已经使用k -最近邻分类或原型分类来评估分类任务上的表示。线性分类器也可以与编码器联合训练,只要不将梯度传播回编码器即可,就是分类头和编码器之间训练要分开。
有监督对比学习损失代码
对比学习对比的是特征,所以损失函数的输入是特征,有监督对比学习损失还要输入标签信息。
损失函数就是模型的输出和标签(这里是mask)之间的差距,输出和标签差距越大,那么loss就越大。
输出这里是编码器的输出就是特征,标签就是类别标签。标签是如何起作用的呢?就是让损失函数区分这个batchsize中的正负样本,属于同一类就是正样本,其他都是负样本。
其中标签mask怎么获得,一个是通过label,另一个直接输入。label是每个数据的类别信息,label.view(1,-1)变成列向量然后再与它的转置进行torch.eq(),得到一个矩阵mask,mask(i,j)如果第i个数据和第j个数据类别相同那么这个位置是True,否则为False,float就变成0,1。后面乘了一个对角线元素为0,其他位置元素为1的矩阵,就是不让每个feature与自身对比。
我们看它self.contrast_mode="one"的时候只是比较feature中第0个特征(也就是平常的第一个特征),那么锚点特征就是所有数据的第0个特征;"all"就是所有的特征都要对比;锚点特征就是所有数据的所有特征。 torch.cat(torch.unbind(features, dim=1), dim=0)把feature按照第1维拆开,然后在第0维上cat,然后比较的feature的形式就是每一个数据的第1个特征|每个数据的第2个特征|…|每个数据的第n个特征,排列,这些特征是排在一起的在一个维度上。锚点特征要么是输入特征组的每个数据的第0个特征要么就是这些比较的特征。(不太理解为什么one的时候比较特征还是所有的)
锚点特征与比较特征的转置相乘,得到的就是batch_size*channel个相似矩阵,每两个数据在这个特征下的相似度。然后这个相似度矩阵要和我们得到的mask进行比较,就是上面的第二个式子。
下面是详细解释。

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode#设置对比的模式有one和all两种,代表对比一个channel还是所有,个人理解
        self.base_temperature = base_temperature #设置的温度

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')#设置设备
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:# batch_size, channel,H,W,平铺变成batch_size, channel, (H,W)
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:#只能存在一个
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:#如果两个都没有就是无监督对比损失,mask就是一个单位阵
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:#有标签,就把他变成mask
            labels = labels.contiguous().view(-1, 1)#contiguous深拷贝,与原来的labels没有关系,展开成一列,这样的话能够计算mask,否则labels一维的话labels.T是他本身捕获发生转置
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask =  torch.eq(labels, labels.T).float().to(device)#label和label的转置比较,感觉应该是广播机制,让label和label.T都扩充了然后进行比较,相同的是1,不同是0.
            #这里就是由label形成mask,mask(i,j)代表第i个数据和第j个数据的关系,如果两个类别相同就是1, 不同就是0
        else:
            mask = mask.float().to(device)#有mask就直接用mask,mask也是代表两个数据之间的关系

        contrast_count = features.shape[1]#对比数是channel的个数
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)#把feature按照第1维拆开,然后在第0维上cat,(batch_size*channel,h*w..)#后面就是展开的feature的维度
        #这个操作就和后面mask.repeat对上了,这个操作是第一个数据的第一维特征+第二个数据的第一维特征+第三个数据的第一维特征这样排列的与mask对应
        if self.contrast_mode == 'one':#如果mode=one,比较feature中第1维中的0号元素(batch, h*w)
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':#all就(batch*channel, h*w)
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),#两个相乘获得相似度矩阵,乘积值越大代表越相关
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)#计算其中最大值
        logits = anchor_dot_contrast - logits_max.detach()#减去最大值,都是负的了,指数就小于等于1

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)#repeat它就是把mask复制很多份
        # mask-out self-contrast cases
        logits_mask = torch.scatter(#生成一个mask形状的矩阵除了对角线上的元素是0,其他位置都是1, 不会对自身进行比较
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask#定义其中的相似度
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))#softmax

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point. 
        # Edge case e.g.:- 
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan] 
        mask_pos_pairs = mask.sum(1)#mask的和
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)#满足返回1,不满足返回mask_pos_pairs.保证数值稳定
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos#类似蒸馏temperature温度越高,分布曲线越平滑不易陷入局部最优解,温度低,分布陡峭
        loss = loss.view(anchor_count, batch_size).mean()#计算平均

        return loss

使用的化就是下面这段:

loss = criterion(features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1411670.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

支付宝AES如何加密

继之前给大家介绍了 V3 加密解密的方法之后&#xff0c;今天给大家介绍下支付宝的 AES 加密。 注意&#xff1a;以下说明均在使用支付宝 SDK 集成的基础上&#xff0c;未使用支付宝 SDK 的小伙伴要使用的话老老实实从 AES 加密原理开始研究吧。 什么是AES密钥 AES 是一种高级加…

k8s实例

k8s实例举例 &#xff08;1&#xff09;Kubernetes 区域可采用 Kubeadm 方式进行安装。 &#xff08;2&#xff09;要求在 Kubernetes 环境中&#xff0c;通过yaml文件的方式&#xff0c;创建2个Nginx Pod分别放置在两个不同的节点上&#xff0c;Pod使用动态PV类型的存储卷挂载…

虚幻UE 插件-像素流送实现和优化

本笔记记录了像素流送插件的实现和优化过程。 UE version&#xff1a;5.3 文章目录 一、像素流送二、实现步骤1、开启像素流送插件2、设置参数3、打包程序4、打包后的程序进行像素流参数设置5、下载NodeJS6、下载信令服务器7、对信令服务器进行设置8、启动像素流送 三、优化1、…

路飞项目--03

总页面 二次封装Response模块 # drf提供的Response&#xff0c;前端想接收到的格式 {code:xx,msg:xx} 后端返回&#xff0c;前端收到&#xff1a; APIResponse(tokneasdfa.asdfas.asdf)---->{code:100,msg:成功,token:asdfa.asdfas.asdf} APIResponse(code101,msg用户不存…

数据结构排序算详解(动态图+代码描述)

目录 1、直接插入排序&#xff08;升序&#xff09; 2、希尔排序&#xff08;升序&#xff09; 3、选择排序&#xff08;升序&#xff09; 方式一&#xff08;一个指针&#xff09; 方式二&#xff08;两个指针&#xff09; 4、堆排序&#xff08;升序&#xff09; 5、冒…

精酿啤酒:啤酒花的选择与处理方法

啤酒花在啤酒的酿造过程中起着重要的作用&#xff0c;它不仅赋予啤酒与众不同的苦味和香味&#xff0c;还为啤酒的稳定性提供了帮助。对于Fendi Club啤酒来说&#xff0c;啤酒花的选择和处理方法更是重要。下面&#xff0c;我们将深入探讨Fendi Club啤酒在啤酒花的选择和处理方…

一文详解C++拷贝构造函数

文章目录 引入一、什么是拷贝构造函数&#xff1f;二、什么情况下使用拷贝构造函数&#xff1f;三、使用拷贝构造函数需要注意什么&#xff1f;四、深拷贝和浅拷贝浅拷贝深拷贝 引入 在现实生活中&#xff0c;可能存在一个与你一样的自己&#xff0c;我们称其为双胞胎。 相当…

【并发编程】 synchronized的普通方法,静态方法,锁对象,锁升级过程,可重入锁,非公平锁

目录 1.普通方法 2.静态方法 3.锁对象 4.锁升级过程 5.可重入的锁 6.不公平锁 非公平锁的 lock 方法&#xff1a; 1.普通方法 将synchronized修饰在普通同步方法&#xff0c;那么该锁的作用域是在当前实例对象范围内,也就是说对于 SyncDemosdnewSyncDemo();这一个实例对象…

el-table 动态渲染多级表头;一级表头根据数据动态生成,二级表头固定

一、表格需求&#xff1a; 实现一个动态表头&#xff0c;一级表头&#xff0c;根据数据动态生成&#xff0c;二级表头固定&#xff0c;每列的数据不一样&#xff0c;难点在于数据的处理。做这种表头需要两组数据&#xff0c;一组数据是实现表头的&#xff0c;另一组数据是内容…

【洛谷】P1135奇怪的电梯(DFS)

这题利用 dfs 解决&#xff0c;编程实现比较简单。 具体来说&#xff0c;每层楼有两种可能&#xff0c;上楼或下楼&#xff0c;因此可以形成一个以 a 楼为根的二叉树&#xff0c;因此只需一个 for 循环遍历某个父节点的两个子节点&#xff0c;之后递归就行。 易错点&#xff…

马尔可夫预测(Python)

马尔科夫链&#xff08;Markov Chains&#xff09; 从一个例子入手&#xff1a;假设某餐厅有A&#xff0c;B&#xff0c;C三种套餐供应&#xff0c;每天只会是这三种中的一种&#xff0c;而具体是哪一种&#xff0c;仅取决于昨天供应的哪一种&#xff0c;换言之&#…

灰度转换及修改尺寸

文章目录 主要内容一.OpenCVPycharm1.读取图片及灰度转换代码如下&#xff08;示例&#xff09;: 2.修改尺寸代码如下&#xff08;示例&#xff09;: 总结 主要内容 读取图片及灰度转换修改尺寸 一.OpenCVPycharm 1.读取图片及灰度转换 代码如下&#xff08;示例&#xff09…

C++ 程序使用 OpenCV 生成两个黑色的灰度图像,并添加随机特征点,然后将这两个图像合并为一张图像并显示

文章目录 源码文件功能解读编译文件 源码文件 #include <iostream> #include <vector> #include <opencv2/opencv.hpp>std::vector<cv::KeyPoint> generateRandomKeyPoints(const cv::Mat& image, int numPoints) {std::vector<cv::KeyPoint&g…

Flume1.9基础学习

文章目录 一、Flume 入门概述1、概述2、Flume 基础架构2.1 Agent2.2 Source2.3 Sink2.4 Channel2.5 Event 3、Flume 安装部署3.1 安装地址3.2 安装部署 二、Flume 入门案例1、监控端口数据官方案例1.1 概述1.2 实现步骤 2、实时监控单个追加文件2.1 概述2.2 实现步骤 3、实时监…

体感大屏互动游戏开发

体感大屏互动游戏是一种结合了体感技术和大屏幕显示的游戏形式&#xff0c;旨在通过玩家的身体动作和互动&#xff0c;提供更加身临其境的游戏体验。这种类型的游戏常常采用各种体感设备&#xff0c;如深度摄像头、体感控制器、传感器等&#xff0c;使玩家能够通过真实的动作来…

C++算法学习心得六.回溯算法(3)

1.子集II&#xff08;90题&#xff09; 题目描述&#xff1a; 给定一个可能包含重复元素的整数数组 nums&#xff0c;返回该数组所有可能的子集&#xff08;幂集&#xff09;。 说明&#xff1a;解集不能包含重复的子集。 示例: 输入: [1,2,2]输出: [ [2], [1], [1,2,2], …

centos 安装mysql5.7教程

一&#xff0c;配置yum mysql5.7安装源 配置yum mysql5.7安装源 yum localinstall https://dev.mysql.com/get/mysql57-community-release-el7-11.noarch.rpm 配置mysql5.7安装源成功 查看配置成功的安装源 yum repolist enabled | grep "mysql*" 执行后看到已配…

大模型|基础——长短时记忆网络

文章目录 LSTM遗忘门输入门整合信息特点实现神经单元的内部计算门控控制——可以动态选择信息在大数据量的情况下&#xff0c;可有效缓解梯度 LSTM 遗忘门 遗忘门&#xff0c;是否进行遗忘。 如果通过计算&#xff0c;计算出来的结果为0&#xff0c;就选择遗弃。 如果遗忘&…

14.4.2 Flash读取与修改数据库中的数据

14.4.2 Flash读取与修改数据库中的数据 计数器是网站必不可少的统计工具&#xff0c;使用计数器可以使网站管理者对网站的访问情况有一个清晰的了解。如果仅仅是统计首页访问量的话&#xff0c;用文本文件来存储数据就可以了&#xff0c;但如果统计的数据量比较大的话(如文章系…

MySQL和Redis的事务有什么异同?

MySQL和Redis是两种不同类型的数据库管理系统&#xff0c;它们在事务处理方面有一些重要的异同点。 MySQL事务&#xff1a; ACID属性&#xff1a; MySQL是一个关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;支持ACID属性&#xff0c;即原子性&#xff08;Ato…