(表征学习论文阅读)A Simple Framework for Contrastive Learning of Visual Representations

news2024/11/25 14:30:08

Chen T, Kornblith S, Norouzi M, et al. A simple framework for contrastive learning of visual representations[C]//International conference on machine learning. PMLR, 2020: 1597-1607.

1. 前言

本文作者为了了解对比学习是如何学习到有效的表征,对本文所提出的三大组件进行了全面的研究:

  1. 各种数据增强手段的组合在表征学习中起到了重要作用;
  2. 在表征和对比损失之间引入非线性变换能够有效提高表征质量;
  3. 对比学习相较于监督学习需要更大的batch size和更多的训练步数。

在没有人类标注或者监督的情况下学习数据的有效表征是一个长期存在的难题,目前的主要工作可以分为两类:

  1. 基于生成模型的方法
    例如VQ-VAE,MAE,BERT
  2. 基于判别模型的方法
    例如MoCo,CLIP

2. 方法

本文提出了一个框架SimCLR,通过最大化同一数据的不同数据增强处理后的两个视角之间的相似度来学习有效表征。
在这里插入图片描述

  1. 如图所示,本文首先将数据 x x x进行两个不同的增强,这里作者使用了三种简单的数据增强方法:随机裁剪后再调整到原始大小、随机颜色失真、高斯模糊。
  2. f ( ∙ ) f(\bullet) f()代表编码器,这里作者使用的是同一个编码器来对两个视角数据进行编码
  3. 最后编码器输出的结果通过非线性变换 g ( ∙ ) g(\bullet) g()得到 z i z_i zi z j z_j zj,两个向量构成了一组正例,进行相似度计算,也就是简单的单位向量内积计算出余弦相似度。目标就是最大化两者的余弦相似度。同时,一个batch中其他的数据构成了负例,最小化与负例的相似度。注意最终训练完成的编码器我们是需要舍弃掉非线性变换的。
    本文使用的损失函数就是最基本的InfoNCE损失,具体可以参考我的另一篇讲解InfoNCE的博文。
    在这里插入图片描述
    在这里插入图片描述

3. 代码

这里仅提供文章提到的两个点的代码:

  1. 数据增强
    高斯模糊
import numpy as np
import torch
from torch import nn
from torchvision.transforms import transforms

np.random.seed(0)


class GaussianBlur(object):
    """blur a single image on CPU"""
    def __init__(self, kernel_size):
        radias = kernel_size // 2
        kernel_size = radias * 2 + 1
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = kernel_size
        self.r = radias

        self.blur = nn.Sequential(
            nn.ReflectionPad2d(radias),
            self.blur_h,
            self.blur_v
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        img = self.pil_to_tensor(img).unsqueeze(0)

        sigma = np.random.uniform(0.1, 2.0)
        x = np.arange(-self.r, self.r + 1)
        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
        x = x / x.sum()
        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)

        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))

        with torch.no_grad():
            img = self.blur(img)
            img = img.squeeze()

        img = self.tensor_to_pil(img)

        return img

组合各类增强手段

class ContrastiveLearningDataset:
    def __init__(self, root_folder=r"D:\pyproject\representation_learning\data"):
        self.root_folder = root_folder

    @staticmethod
    def get_simclr_pipeline_transform(size, s=1):
        """Return a set of data augmentation transformations as described in the SimCLR paper."""
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.RandomApply([color_jitter], p=0.8),
                                              transforms.RandomGrayscale(p=0.2),
                                              GaussianBlur(kernel_size=int(0.1 * size)),
                                              transforms.ToTensor()])
        return data_transforms

    def get_dataset(self, name, n_views):
        valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,
                                                              transform=ContrastiveLearningViewGenerator(
                                                                  self.get_simclr_pipeline_transform(32),
                                                                  n_views),
                                                              download=True),

                          'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',
                                                          transform=ContrastiveLearningViewGenerator(
                                                              self.get_simclr_pipeline_transform(96),
                                                              n_views),
                                                          download=True)}

        try:
            dataset_fn = valid_datasets[name]
        except KeyError:
            raise InvalidDatasetSelection()
        else:
            return dataset_fn()
  1. 非线性变换
class ResNetSimCLR(nn.Module):

    def __init__(self, base_model, out_dim):
        super(ResNetSimCLR, self).__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),
                            "resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        # 修改resnet最后一层的全连接层即可
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
        except KeyError:
            raise InvalidBackboneError(
                "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")
        else:
            return model

    def forward(self, x):
        return self.backbone(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1578878.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vscode 中调试Django程序

调试介绍: ​​​​​​​Explore the debugger Debug/调试 可以让我们在特定的代码行上暂停程序的运行。当程序暂停时,我们可以查看变量的数值,在“Debug控制台”中运行代码,或利用“Debug”工具提供的其他功能。启动Debugger/调试器会自动…

MAC苹果电脑如何使用Homebrew安装iperf3

一、打开mac终端 找到这个终端打开 二、终端输入安装Homebrew命令 Homebrew官网地址:https://brew.sh/ 复制这个命令粘贴到mac的终端窗口,然后按回车键 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/in…

rsync实时同步(上行同步)

目录 一、实现实时同步 1. 定期同步的不足 2. 实时同步的优点 3. Linux内核的inotify机制 4. 发起端配置rsyncinotify 4.1 修改rsync源服务器配置文件 4.2 调整inotify内核参数 4.3 安装inotify-tools 4.4 在另一个终端编写触发式同步脚本 4.5 验证 二、使用rsync实现…

电脑与多台罗克韦尔AB PLC无线通讯的搭建方法分为几步?

在实际系统中,同一个车间里分布多台PLC,通过上位机集中控制。通常所有设备距离在几十米到上百米不等。在有通讯需求的时候,如果布线的话,工程量较大耽误工期,这种情况下比较适合采用无线通信方式。本方案以组态王和2台…

WinForm用微软打包工具打包

WinForm用微软打包工具打包 1. 安装扩展 下载扩展Microsoft Visual Studio Installer Projects 点击扩展 —> 管理扩展 安装完之后重启VS就好了。 2. 新建Set up项目 点击解决方案 —> 添加 —> 新建项目 选择这个Setup Project 创建打包项目 安装项目&…

nandgame中的asm编程PUSH_VALUE、ADD、SUB、NEG、AND、OR

参考:https://zhuanlan.zhihu.com/p/613188641 PUSH_VALUE题目说明及答案 将值推送到堆栈上。 提示:该值将作为宏的替换值提供,但在测试时,您可以在“测试工具”框中设置该值。现在,我们引入了一个使用占位符的宏。宏…

nginx This request has been blocked; the content must be served over HTTPS问题处理

This request has been blocked; the content must be served over HTTPS问题处理 1.问题现象2.解决问题3.解决后的现象4.proxy_set_header x-forwarded-proto 作用 1.问题现象 Mixed Content: The page at https://www.ssjxx.cn/ssjy/viy-edu/index.html?systemCodeTW0010#/…

iOS-获取Xcode工程中文件的路径

1、使用Create folder references的Add folders的方式把文件或者文件夹拖到Xcode工程中 拖入时的设置参考下图 注意拖入到工程之后文件夹是蓝色的(Xcode10.1环境) 2、代码具体实现: 使用NSBundle的API,然后拼接具体路径即可 NS…

区块链与数字身份:探索Facebook的新尝试

在数字化时代,随着区块链技术的崛起,数字身份成为了一个备受关注的话题。作为全球最大的社交媒体平台之一,Facebook一直在探索如何利用区块链技术来改善数字身份管理和用户数据安全。本文将深入探讨Facebook在这一领域的新尝试,探…

【ArcGIS微课1000例】0109:ArcGIS计算归一化水体指数(NDWI)

文章目录 一、加载数据二、归一化水体指数介绍三、归一化水体指数计算四、注意事项一、加载数据 加载配套数据0108.rar(本实验的数据与0108的一致)中的Landsat8的8个单波段数据,如下所示: Landsat8波段信息对照表如下表所示: 接下来学习在ArcGIS平台上,基于Landsat8数据…

GPT-4对多模态大模型在多模态预训练、 理解生成上的启发

传统人工智能 模型往往依赖大量有标签数据的监督训练,而且一个模型一般只能解决一个任务,仅适用于单一场景, 这使得人工智能的研发和应用成本高,场景适应能力弱,难以规模化应用。 常见的多模态任务大致可以分为两类: 多模态理解任务,如视频 分类、视觉问答、跨模态检索、指代…

5G如何助力物流智能化转型

导语 大家好,我是智能仓储物流技术研习社的社长,你的老朋友,老K。行业群 新书《智能物流系统构成与技术实践》人俱乐部 整版PPT和更多学习资料,请球友到知识星球 【智能仓储物流技术研习社】自行下载 智能制造-话题精读 1、西门子…

Maven的scope详解

依赖范围介绍 maven 项目不同的阶段引入到classpath中的依赖是不同的,例如,编译时,maven 会将与编译相关的依赖引入classpath中,测试时,maven会将测试相关的的依赖引入到classpath中,运行时,mav…

菜狗学前端之JS高级笔记

老样子。复制上来的图片都没了,想看原版可以移步对应资源下载(资源刚上传,还在审核中) (免费) JS高级笔记https://download.csdn.net/download/m0_58355897/89102910 一些前提概念 一 什么是js高级 js高级是对js基础语法的一个补充说明,本质…

JetBrains IDE(IDEA/WebStorm)配置GitHub Copilot

关于 GitHub Copilot 和 JetBrains IDE GitHub Copilot 在编写代码时提供 AI 对程序员的自动完成样式的建议。 有关详细信息,请参阅“关于 GitHub Copilot Individual”。 如果使用 JetBrains IDE,可以直接在编辑器中查看并合并来自 GitHub Copilot 的…

asp.net网上水果销售平台 水果购物商城系统+sqlserver

网上水果销售平台 说明文档 运行前附加数据库.mdf(或sql生成数据库) 主要技术: 基于asp.net架构和sql server数据库 功能模块: asp.net网上水果销售平台 水果购物商城系统 用户功能有 网站首页 全部水果 我的订单 购物车用户…

python-flask后端知识点

anki 简单介绍: 在当今信息爆炸的时代,学习已经不再仅仅是获取知识,更是一项关于有效性和持续性的挑战。幸运的是,我们有幸生活在一个科技日新月异的时代,而ANKI(Anki)正是一款旗舰级的学习工具…

探索算力(云计算、人工智能、边缘计算等):数字时代的引擎

引言 在数字时代,算力是一种至关重要的资源,它是推动科技创新、驱动经济发展的关键引擎之一。简而言之,算力即计算能力,是计算机系统在单位时间内完成的计算任务数量或计算复杂度的度量。随着科技的不断发展和应用范围的不断扩大…

在ubuntu系统上安装ffmpeg支持rrweb使用rrvideo对视频文件转mp4格式遇到的一些问题及解决办法

在ubuntu系统上安装ffmpeg支持rrweb使用rrvideo对视频文件转mp4格式遇到的一些问题及解决办法 1,ubuntu系统上安装ffmpeg4.4.1稳定版本1,ubuntu系统上安装ffmpeg4.4.1稳定版本 按照ChatGPT3.5来 sudo apt updatesudo apt install build-essential git sudo apt-get instal…

Qt 中的项目文件解析和命名规范

🐌博主主页:🐌​倔强的大蜗牛🐌​ 📚专栏分类:QT❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、Qt项目文件解析 1、.pro 文件解析 2、widget.h 文件解析 3、main.cpp 文件解析 4、widget.cpp…