浅谈 EMP-SSL + 代码解读:自监督对比学习的一种极简主义风

news2025/1/20 3:52:27

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):
    def __init__(self, num_patch = 4):
    
        self.num_patch = num_patch
      
    def __call__(self, x):
    
    
        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GBlur(p=0.1),
            transforms.RandomApply([Solarization()], p=0.1),
            transforms.ToTensor(),  
            normalize
        ])
        augmented_x = [aug_transform(x) for i in range(self.num_patch)]
     
        return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):
    net.zero_grad()
    opt.zero_grad()
        
    data = torch.cat(data, dim=0) 
    data = data.cuda()
    z_proj = net(data)
            
    z_list = z_proj.chunk(num_patches, dim=0)
    z_avg = chunk_avg(z_proj, num_patches)
            
    # Contractive Loss
    loss_contract, _ = contractive_loss(z_list, z_avg)
    loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):
    x_list = x.chunk(n_chunks,dim=0)
    x = torch.stack(x_list,dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):
    def __init__(self, ):
        super().__init__()
        pass

    def forward(self, z_list, z_avg):
        z_sim = 0
        num_patch = len(z_list)
        z_list = torch.stack(list(z_list), dim=0)
        z_avg = z_list.mean(dim=0)
        
        z_sim = 0
        for i in range(num_patch):
            z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()
            
        z_sim = z_sim/num_patch
        z_sim_out = z_sim.clone().detach()
                
        return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):
    z_list = z.chunk(num_patches,dim=0)
    loss = 0
    for i in range(num_patches):
        loss += criterion(z_list[i])
    loss = loss/num_patches
    return loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):
    def __init__(self, eps=0.01):
        super(TotalCodingRate, self).__init__()
        self.eps = eps
        
    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape  #[d, B]
        I = torch.eye(p,device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.
    
    def forward(self,X):
        return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/884145.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从0到1:通用后台管理系统 Vue3使用wangEditor

那么这一节我们在编辑公司信息的弹窗中使用富文本插件wangEditor官网 Vue3使用wangEditor 安装wangEditor在弹窗中引入wangEditor结构api接口部分editor组件script部分怎么去修改富文本的编辑器? 案例内效果: 安装wangEditor npm install wangeditor/…

【D3.js 01】

D3.js 01 说在前面1 概述2 配置Web环境3 HTML4 SVG5 DOM6 JS7 常用接口8 D3语法基础9 使用D3查询SVG10 使用D3设置SVG中属性11 修改整组属性12 使用D3添加与删除SVG元素13 数据读取 —— CSV数据14 D3.js的数值计算15 比例尺Scale - LinearScale - Band 16 引入坐标轴17 DATA-J…

通过网络和SD卡连接开发板

SD卡 有时候相关代码改动以后想验证能否正常工作,如果编译代码又需要好久,所以可以通过SD卡拷贝到板子里验证: 将SD卡插入读卡器,将读卡器插入ubuntu主机上,将相关带动的代码文件拷贝到SD卡中。假设你的板子已经具备…

LLMs大模型plugin开发实战

一、概述 ChatGPT是通用语言大模型,如果用户想要在与大模型进行交互时能够使用到企业私有的数据,那么可以通过开发plugin(插件)的方式来实现,另外GPT3.5模型的训练数据是截止到2021年9月,如果想让模型能够…

leetcode228. 汇总区间

题目 给定一个 无重复元素 的 有序 整数数组 nums 。 返回 恰好覆盖数组中所有数字 的 最小有序 区间范围列表 。也就是说,nums 的每个元素都恰好被某个区间范围所覆盖,并且不存在属于某个范围但不属于 nums 的数字 x 。 列表中的每个区间范围 [a,b]…

python矩阵形状和乘法

python矩阵的形状 A np.array([[[1],[2],[3]],[[4],[5],[6]]])AA np.array([[1,2,3],[4,5,6]])print(A) print(A.shape) print(AA) print(AA.shape)python矩阵的乘法 A np.array([[1, 2, 3, 4],[1, 2, 3, 4],[1, 1, 1, 1],[1, 1, 1, 1]]) B np.array([[1],[2],[1],[2]])C …

由主机的IP地址计算主机所在子网的广播地址(子网划分)

子网掩码是一个与IP地址相对应的、长32bit的二进制串,它由一串1和跟随的一串0组成。 其中,1对应于IP地址中的网络号及子网号,而0对应于主机号。计算机只需将IP地址和其对应的子网掩码逐位“与”(逻辑AND运算)&#xff…

新物联网卡智能管理系统源代码下载

新物联网卡智能管理系统源代码现已开放供大家使用。该系统具有强大的物联网卡管理功能,可以帮助您实现自动化管理,提高效率。 相比其他同类系统,本系统具有更高的灵活性和可定制性,能够满足您的各种需求。 授权机制已删除&#…

【C++11保姆级教程】delete和default关键字

文章目录 前言一、delete关键字1.1 什么是delete关键字?1.2delete关键字的语法和用法1.3delete关键字的作用和优势 二、default关键字2.1 什么是default关键字?2.2default关键字的语法和用法2.3 default关键字的作用和优势 总结 前言 欢迎来到本教程&am…

谁是 “凶手” !

找“凶手” 解题方法!🕵️‍ 近日,日本米花町发生了一起凶杀案,警察通过排查确定杀人凶手必为4个嫌疑犯中的一个。 以下为4个嫌疑犯的供词: A说:不是我。 B说:是C。 C说:是D。 D说&…

Mybatis多表查询与动态SQL的使用

目录 1. Mybatis多表查询 1.1 添加文章表实体类 1.2 文章Interface 1.3 文章.xml 1.4 lombok的toString()有关对象打印的说明 1.5 场景: 一个用户查询多篇文章 2. 复杂情况: 动态SQL的使用 2.1 为什么要使用动态SQL? 2.2 <if>标签 2.3 <trim>标签 2.4 <where&g…

如何在 iOS 上安装并使用 ONLYOFFICE 文档

借助 iOS 版文档应用&#xff0c;您可在移动端设备上访问存储于 ONLYOFFICE 账户中的文件&#xff0c;查看和编辑现有文本文档、电子表格和演示文稿&#xff0c;创建新文档并对其进行整理&#xff0c;以及连接第三方云存储服务。您可与其他门户网站用户协作编辑文档&#xff0c…

【第三阶段】kotlin语言的可空性

1.kotlin语言默认是不可空类型&#xff0c;不能随意给null fun main() {var name:String"kotlin"namenull }执行结果 报错&#xff1a; Null can not be a value of a non-null type String2.声明可空类型 &#xff1f; fun main() {var name:String ?namenull…

语聚AI公测发布,大语言模型时代下新的生产力工具

语聚AI 公测发布 距离语聚AI内测上线已经过去近1个月。 这期间&#xff0c;我们共邀请了近百位资深用户与行业专家加入语聚AI产品体验。通过大家的热情参与积极反馈&#xff0c;我们不断优化并完善了语聚AI的功能与使用体验。 经过研发团队不懈的努力&#xff0c;今天语聚AI终…

[谦实思纪 01]整理自2023雷军年度演讲——《成长》(上篇)武大回忆(梦想与成长)

文章目录 [谦实思纪]整理自2023雷军年度演讲 ——《成长》&#xff08;上篇&#xff09;武大回忆&#xff08;梦想与成长&#xff09;0. 写在前面1. 梦开始的地方1.1 要有梦想&#xff0c;要用目标量化梦想 2. 在两年内修完所有的学分。2.1 别老自己琢磨&#xff0c;找个懂的人…

【C++】STL案例1-评委打分

0.前言 1.系统自动生成的评委评分代码&#xff1a; #include <iostream> using namespace std; #include <deque> #include <vector> #include <algorithm> #include <string>//选手类 class Player { public:Player(string name, float score)…

Python:LVGL与触摸屏的调试记录

在移远模块EC-600M上驱动电容触摸屏&#xff0c;触摸屏控制IC为FT6206。 一、接口 TP屏的管脚如下&#xff0c;有6PIN。使用I2C接口通讯 所以我们用模块的I2C1通道&#xff0c;模块的IO口电压也是1.8v 二、I2C从地址 FT6x06芯片相对于主机来说是一个I2C设备 因此需要一个I2C…

3.1 命名空集using声明

博主介绍&#xff1a;爱打游戏的计算机专业学生 博主主页&#xff1a;夏驰和徐策 所属专栏&#xff1a;夏驰和徐策带你从零开始学C 前言&#xff1a; 第2章介绍的内置类型是由C语言直接定义的。这些类型&#xff0c;比如数字和字符&#xff0c;体现了大 多数计算机硬件本身具…

Python弹球小游戏

给在校的小妹妹做个游戏玩&#xff1a;. 弹珠游戏主要是靠坐标xy&#xff0c;接板长度&#xff0c;球的半径等决定&#xff1a; # -*- coding: utf-8 -*- # Author : Codeooo # Time : 2022/04/29import sys import time import random import pygame as pgprint("&q…

TCP三次握手,四次挥手,SYN泛洪攻击

目录 一.三次握手 二.SYN泛洪攻击概念 三.四次挥手 一.三次握手 当客户端调用connect连接服务器时,底层会发生“三次握手”&#xff0c;握手成功&#xff0c;建立连接,connect解阻塞&#xff0c;继续执行。 TCP报头&#xff1a; 三次握手过程&#xff1a; 客户端发出SYN请求…