简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风

news2024/11/26 22:44:13

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):
    def __init__(self, num_patch = 4):
    
        self.num_patch = num_patch
      
    def __call__(self, x):
    
    
        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GBlur(p=0.1),
            transforms.RandomApply([Solarization()], p=0.1),
            transforms.ToTensor(),  
            normalize
        ])
        augmented_x = [aug_transform(x) for i in range(self.num_patch)]
     
        return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):
    net.zero_grad()
    opt.zero_grad()
        
    data = torch.cat(data, dim=0) 
    data = data.cuda()
    z_proj = net(data)
            
    z_list = z_proj.chunk(num_patches, dim=0)
    z_avg = chunk_avg(z_proj, num_patches)
            
    # Contractive Loss
    loss_contract, _ = contractive_loss(z_list, z_avg)
    loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):
    x_list = x.chunk(n_chunks,dim=0)
    x = torch.stack(x_list,dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):
    def __init__(self, ):
        super().__init__()
        pass

    def forward(self, z_list, z_avg):
        z_sim = 0
        num_patch = len(z_list)
        z_list = torch.stack(list(z_list), dim=0)
        z_avg = z_list.mean(dim=0)
        
        z_sim = 0
        for i in range(num_patch):
            z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()
            
        z_sim = z_sim/num_patch
        z_sim_out = z_sim.clone().detach()
                
        return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):
    z_list = z.chunk(num_patches,dim=0)
    loss = 0
    for i in range(num_patches):
        loss += criterion(z_list[i])
    loss = loss/num_patches
    return loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):
    def __init__(self, eps=0.01):
        super(TotalCodingRate, self).__init__()
        self.eps = eps
        
    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape  #[d, B]
        I = torch.eye(p,device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.
    
    def forward(self,X):
        return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/877673.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

图解结构体大小和位域例子

struct A {short a; char b; int c : 1; char d : 4; short e : 7; }; 备注:蓝色:表示占一个符号位空间红色:表示补齐其他颜色:实际最大值所占空间 (1)图解例1 st…

ASPICE学习笔记

文章目录 1. ASPICE是什么?2. ASPICE能干什么?2.1 过程参考模型2.2 过程评估模型参考1. ASPICE是什么? ASPICE的全称是Automotive SPICE。很明显的看出ASPICE是由SPICE发展而来。而SPICE是由国际标准化组织ISO、国际电工委员会IEC、信息技术委员会JTC1发起制定的ISO15504标…

VSCODE[配置ssh免密远程登录]

配置ssh免密远程登录 本文摘录于:https://blog.csdn.net/qq_44571245/article/details/123031276只是做学习备份之用,绝无抄袭之意,有疑惑请联系本人! 这里要注意如下几个地方: 1.要进入.ssh目录创建文件: 2.是拷贝带"ssh-…

Android面试官:“来给我讲讲View绘制?”

前言 迎面走来的一位中年男子,他一手拿着保温杯,一手抱着笔记本电脑,顶着惺忪的睡眼,不紧不慢地走着,不多的几根头发在他头顶自由飞翔。过了一会,他面对着我坐下,放下电脑和保温杯,…

【腾讯云Cloud Studio实战训练营】用Vue+Vite快速构建完成律师H5页面

👀前置了解:(官网 https://cloudstudio.net/) 什么是Cloud Studio? Cloud Studio 是基于浏览器的集成式开发环境(IDE),为开发者提供了一个永不间断的云端工作站。用户在使用 Cloud Studio 时无需安装&#…

计算机设计大赛国赛一等奖项目分享——基于多端融合的化工安全生产监管可视化系统

文章目录 一、计算机设计大赛国赛一等奖二、项目背景三、项目简介四、系统架构五、系统功能结构六、项目特色(1)多端融合(2)数据可视化(3)计算机视觉(目标检测) 七、系统界面设计&am…

源码断点分析Spring的占位符(Placeholder)是怎么工作的

项目中经常需要使用到占位符来满足多环境不同配置信息的需求&#xff0c;比如&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <beans xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xmlns"http://www.springframe…

老师怎么发分班录取情况?这个快捷方式可以搞定

关于分班录取的发布方式&#xff0c;老师们可以参考下面的步骤来发布&#xff1a; 1. 班级通知&#xff1a;老师可以在学校的通知栏、公告栏或班级群里发布分班录取情况。在通知中&#xff0c;老师可以列出每个班级的学生名单&#xff0c;包括学生的姓名和分班结果。 2. 班级…

纽扣电池寿命和功率增强器

近日&#xff0c;基础半导体器件领域的高产能生产专家Nexperia&#xff08;安世半导体&#xff09;宣布推出NBM7100和NBM5100。这两款IC采用了具有突破意义的创新技术&#xff0c;是专为延长不可充电的典型纽扣锂电池寿命而设计的新型电池寿命增强器&#xff0c;相比于同类解决…

抖店商品详情API接口(关键词搜索商品列表API接口)

联盟商品和非联盟商品是抖店平台上的两种不同类型的商品。 联盟商品是指与抖店平台达成合作关系的商家提供的商品。这些商家通常是经过严格筛选和审核的合作伙伴&#xff0c;与抖店平台有合作协议&#xff0c;并享受一定的运营支持和优惠政策。联盟商品通常具有较高的品质和可…

【大数据】一些基本概念

一、数据库、数据仓库、数据湖 1.什么是数据库 (Database, DB) 数据库是指长期储存在计算机中的有组织的, 可共享的数据集合 就是存储数据的仓库 数据库有三个特点: 永久存储, 有组织, 可共享 数据库是一种结构化数据存储技术&#xff0c;用于存储和管理有组织的数据。数据库…

[HDLBits] Exams/m2014 q4a

Implement the following circuit: Note that this is a latch, so a Quartus warning about having inferred a latch is expected. module top_module (input d, input ena,output q);always(*) beginif(ena)qd;end endmodule

面试热题(合并K个升序链表)

给定一个链表数组&#xff0c;每个链表都已经按升序排列。 请将所有链表合并到一个升序链表中&#xff0c;返回合并后的链表。 输入&#xff1a;lists [[1,4,5],[1,3,4],[2,6]] 输出&#xff1a;[1,1,2,3,4,4,5,6] 解释&#xff1a;链表数组如下&#xff1a; [1->4->5,1…

Java多款线程池,总有一款适合你。

线程池的选择 一&#xff1a;故事背景二&#xff1a;线程池原理2.1 ThreadPoolExecutor的构造方法的七个参数2.1.1 必须参数2.1.2 可选参数 2.2 ThreadPoolExecutor的策略2.3 线程池主要任务处理流程2.4 ThreadPoolExecutor 如何做到线程复用 三&#xff1a;四种常见线程池3.1 …

企业数字化转型与股利分配(2007-2021年)

参照李滟&#xff08;2023&#xff09;的做法&#xff0c;本团队对来自西南大学学报&#xff08;社会科学版&#xff09;《企业数字化转型与股利分配》一文中的基准回归部分进行复刻。 企业数字化转型已成为我国经济增长的新引擎和新动力。为探究数字化转型对企业财务决策的影…

Spring 依赖注入和自动装配

DI&#xff08;依赖注入&#xff09; DI&#xff1a;Dependency Injection 共有三种方式 构造器注入 在前面IOC容器创建对象的方式中已经提到&#xff0c;无参构造器和有参构造器都可以。 Set方式注入&#xff08;重点&#xff09; 依赖注入&#xff1a;本质是Set注入 依赖…

【Linux】高级IO

目录 IO的基本概念 钓鱼五人组 五种IO模型 高级IO重要概念 同步通信 VS 异步通信 阻塞 VS 非阻塞 其他高级IO 阻塞IO 非阻塞IO IO的基本概念 什么是IO&#xff1f; I/O&#xff08;input/output&#xff09;也就是输入和输出&#xff0c;在著名的冯诺依曼体系结构当中…

LinuxC编程——进程间通信(二)(信号、共享内存)

目录 一、信号1.1 概念1.2 信号的响应方式⭐⭐⭐1.3 几种常见的信号1.4 函数练习 二、共享内存2.1 共享内存的特点2.2 共享内存创建步骤⭐⭐2.3 共享内存创建所需函数 信号主要用来通知进程异步事件的发生。最初信号设计的目的是为了处理错误&#xff0c;它们也用来作为最基本的…

【EI/SCOPUS检索】第二届能源与动力工程国际学术会议(EPE 2023)

第二届能源与动力工程国际学术会议&#xff08;EPE 2023&#xff09; 2023 2nd International Conference on Energy and Power Engineering 能源是人类社会发展的重要推动力量。如何安全、清洁、高效地存储、转化和利用能源&#xff0c;实现人类可持续发展&#xff0c;一直…

比例方向阀控制多功能放大器

适用于控制无电位置反馈的三位四通比例方向阀&#xff0c;两路独立工作的比例放大器&#xff0c;可组合成并联工作方式&#xff0c;0到10V输入接口&#xff0c;可切换为0(4)到20mA输入&#xff0c;工作电压24VDC&#xff0c;允许工作温度范围0~45℃&#xff0c;放大器只有在断电…