论文笔记:A Simple Framework for Contrastive Learning of Visual Representations

news2024/11/25 14:33:29

0 简介

论文:A Simple Framework for Contrastive Learning of Visual Representations
代码:https://github.com/google-research/simclr
发表:2020年发表在ICML会议上

1 核心思想

如何构建对比学习的比较对象?本文按如下方式进行构建:

  • 数据增强:输入 x x x,增强为 x ~ i \tilde{x}_i x~i( t ∼ T t \sim \mathcal{T} tT)和 x ~ j \tilde{x}_j x~j( t ′ ∼ T t^\prime \sim \mathcal{T} tT),获得两个相关的视角,这两个相关的视角的距离越近越好;
  • 和其他图片增强的视角进行对比: x x x的视角和其他图片增强得到的视角距离越远越好。
    在这里插入图片描述

1.1 总体步骤

具体步骤如下:

  • 输入图像 x x x,对其进行两种不同增强得到两张新图片 x ~ i \tilde{x}_i x~i( t ∼ T t \sim \mathcal{T} tT)和 x ~ j \tilde{x}_j x~j( t ′ ∼ T t^\prime \sim \mathcal{T} tT);
  • 将两张新图片输入ResNet,即 f ( ⋅ ) f(\cdot) f()提取特征,得到 h i , h j h_i, h_j hi,hj
  • 两个特征向量经过MLP网络,即 g ( ⋅ ) g(\cdot) g()处理,得到 z i , z j z_i, z_j zi,zj

假设batch size大小为 N N N,经过数据增强,可以得到 2 N 2N 2N张图像。
SimCLR在对比学习时,需要正负例:

  • z i , z j z_i, z_j zi,zj构成正例;
  • z i z_i zi与batch size中其他图像(包括数据增强后的图像)的特征向量组成负例对,因此一张图片将存在1个正例对, 2 N − 2 2N − 2 2N2个负例对。

一张图片的损失函数为:
ℓ i , j = − log ⁡ exp ⁡ ( sim ⁡ ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] exp ⁡ ( sim ⁡ ( z i , z k ) / τ ) \ell_{i, j}=-\log \frac{\exp \left(\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_j\right) / \tau\right)}{\sum_{k=1}^{2 N} \mathbb{1}_{[k \neq i]} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_k\right) / \tau\right)} i,j=logk=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)
其中 sim ⁡ ( z i , z j ) \operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_j\right) sim(zi,zj)表示余弦相似度, 1 [ k ≠ i ] ∈ { 0 , 1 } \mathbb{1}_{[k \neq i]} \in \{0, 1\} 1[k=i]{0,1},当 k ≠ i k \neq i k=i等于1, k = = i k == i k==i等于0, τ \tau τ为温度系数。
2 N 2N 2N张图像的损失函数之和求平均,得到最终的损失函数:
L = 1 2 N ∑ k = 1 N [ ℓ ( 2 k − 1 , 2 k ) + ℓ ( 2 k , 2 k − 1 ) ] . \mathcal{L} = \frac{1}{2N} \sum_{k = 1}^{N}\left[\ell(2k-1, 2k) + \ell(2k, 2k-1)\right]. L=2N1k=1N[(2k1,2k)+(2k,2k1)].

1.2 增强图片的方式

  • 随机裁剪(random cropping);
  • 随机颜色失真(random color distortions);
  • 随机高斯模糊(random Gaussian blur)。
    在这里插入图片描述
    实矩形是原始图像,虚线矩形是随机裁剪。通过随机裁剪图像,我们采样对比预测任务,包括全局到局部视图( B → A B \rightarrow A BA)或相邻视图( D → C D \rightarrow C DC)预测。
    在这里插入图片描述

1.2 特征提取

h i = f ( x ~ i ) = ResNet ⁡ ( x ~ i ) \boldsymbol{h}_{i}=f\left(\tilde{\boldsymbol{x}}_{i}\right)=\operatorname{ResNet}\left(\tilde{\boldsymbol{x}}_{i}\right) hi=f(x~i)=ResNet(x~i)
其中 h i ∈ R d \boldsymbol{h}_{i} \in \mathbb{R}^d hiRd
z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) \boldsymbol{z}_{i}=g\left(\boldsymbol{h}_{i}\right)=W^{(2)} \sigma\left(W^{(1)} \boldsymbol{h}_{i}\right) zi=g(hi)=W(2)σ(W(1)hi)
其中 σ \sigma σ就是一个ReLU非线性操作。

2 具体算法

在这里插入图片描述
总体分为三个重要的过程:

  • 数据增强,通过两个增强函数操作,图片成对存储 [ ( x ~ 1 , x ~ 2 ) , ( x ~ 3 , x ~ 4 ) , … , ( x ~ 2 k − 1 , x ~ 2 k ) , … , ( x ~ 2 N − 1 , x ~ 2 N ) [(\tilde{x}_1, \tilde{x}_2), (\tilde{x}_3, \tilde{x}_4), \dots, (\tilde{x}_{2k-1}, \tilde{x}_{2k}), \dots, (\tilde{x}_{2N-1}, \tilde{x}_{2N}) [(x~1,x~2),(x~3,x~4),,(x~2k1,x~2k),,(x~2N1,x~2N)
  • 特征提取,经过ResNet(对应 f ( ⋅ ) f(\cdot) f())操作和MLP(对应 g ( ⋅ ) g(\cdot) g())操作后,得到特征向量组 [ ( z 1 , z 2 ) , ( z 3 , z 4 ) , … , ( z 2 k − 1 , z 2 k ) , … , ( z 2 N − 1 , z 2 N ) [(z_1, z_2),(z_3, z_4), \dots, (z_{2k-1}, z_{2k}), \dots, (z_{2N-1}, z_{2N}) [(z1,z2),(z3,z4),,(z2k1,z2k),,(z2N1,z2N)
    在这里插入图片描述
  • 对比学习,先是 x ~ 2 k − 1 \tilde{x}_{2k-1} x~2k1和其它图片进行对比学习,然后是 x ~ 2 k \tilde{x}_{2k} x~2k和其它图片进行对比学习
    在这里插入图片描述

3 实验

本文的实验分析非常有用,讨论了模型在什么情况下更有效,有利于读者选择合适的参数。

3.1 数据增强方式对性能的影响

在这里插入图片描述
【问】:怎么理解这个图?
【答】:以左上角的33.1为例,第一次数据增强采用Crop方法,第二次数据增强采用Crop方法;以左上角33.9为例,第一次数据增强采用Crop方法,第二次数据增强采用Cutout方法。
得到如下三个结论:

  • 单独使用一种数据增强,对比学习的效果会很差;
  • 效果最好的组合:第一次数据增强采用Crop方法,第二次数据增强采用Color方法,得到的精度为56.3;效果次好的组合:第一次数据增强采用Color方法,第二次数据增强采用Crop方法,得到的精度为55.8;
  • 数据增强方式对对比学习的影响非常明显,这不是一个好的性质,很多时候我们需要进行穷举试错。

3.2 模型宽度和深度对性能的影响

在这里插入图片描述
【问】:怎么理解这个图?
【答】:以R18(4x)为例说明,R18表示18层的ResNet网络,4x表示模型宽度加宽4倍。
从图上可以得到如下结论:

  • 增大模型容量时,优先增加模型的深度,比如ResNet152比ResNet18性能高不少,参数量并没有增加多少;
  • 次选增加模型的宽度,比如ResNet18(4x)比ResNet18(2x)性能高一些,但参数量增加较多,导致训练速度变慢。

3.3 特征向量 z i z_i zi的长度对性能的影响

在这里插入图片描述
从图上可以得到如下结论:

  • 向量长度对性能影响不大;
  • 非线性MLP性能优于线性MLP;
  • SimCLR中可以用于线性分类的特征有两个,一是特征提取器的输出 h \boldsymbol{h} h,二是MLP层的输出 g ( h ) g(\boldsymbol{h} ) g(h),在线性分类中,使用 h \boldsymbol{h} h的性能要优于 g ( h ) g(\boldsymbol{h} ) g(h)(大于10%),可能是因为MLP过滤掉了一些有用的信息。

3.4 batch size对性能的影响

在这里插入图片描述
从图上可以得到如下结论:

  • 对于有正负例的对比学习算法而言,batch size越大,效果越好,并且提升显著;
  • 如果只有正例的对比学习算法而言(如BYOL、simsiam),batch size大小对性能影响没有如此显著;
  • 对于有正负例的对比学习算法和只有正例的对比学习算法,训练epoch越长,效果越好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/471992.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国民技术N32G430开发笔记(8)- 内部Flash的读写操作

N32G430 内部Flash的读写操作 1、主存储区最大为 64KB,也称作主闪存存储器,包含 32 个 Page,用于用户程序的存放和运行,以及数 据存储。 每一页的大小为2K字节 2、IAP 升级我们将64K的flash分区如下: Boot 0x800000…

scanf老是出错?带你详细解决输入缓冲区问题

文章目录 1.前言2.getchar 和 putchar3.缓冲区问题3.1先观察一个代码3.2输入缓冲区3.3清除缓冲区 结尾 1.前言 我们一般在进行输入输出的时候,就会用到 scanf / printf 。并且根据格式指定可以输入输出各种类型的数据。可以输入整形,字符,浮…

【元分析研究方法】学习笔记4.评估研究的质量

评估研究的质量 该步骤的作用该步骤中需要注意的问题该步骤中知识点1:判断编码的分类方式该步骤中知识点2:统计识别异常值 参考来源:库珀 (Cooper, H. M. )., 李超平, & 张昱城. (2020). 元分析研究方法: A step-by step approach. 中国人…

SpringBoot的事务与锁

在一人一单问题里,为什么加了事务还是会出现一人下多单呢? 本质的原因是,我们使用Java的对象锁,可以保证临界区只有一个线程访问,但是这和SpringBoot里加Transactional注解不是等价的。数据库里的事务保证的是要么全部…

ChatGPT+Word的智能化文字生成和应用

在Word中引入OpenAI代码需要使用VBA编辑器。以下是在Word中引入OpenAI代码的步骤: 打开Word文档,按下Alt F11键打开VBA编辑器。 在VBA编辑器中,选择“插入”菜单,然后选择“模块”。 在新建的模块中,将OpenAI代码粘…

【教学类-35-01】(256*256*256)RGB色卡图片

作品展示: 背景需求: 甲流传染病,班级来了三位孩子,他们玩折纸的时候讨论, 09号问:“绿色和蓝色混合是什么颜色?” 08号问:“绿色加蓝色加浅蓝合在一起是什么颜色” 17号说&…

逆向学习X64DBG

目标游戏:焰影神兵 目的:更改玩家名称(中文名称) 使用X64dbg可以快速搜索游戏人名,所以本次逆向使用该工具进行工作。 原来的名字:平家物语 现在我们想改成:源氏物语。所以打开X32/64dbg 附…

免费域名申请

title: 免费域名申请 20230428153405|left 🌈Description: ​ 本文将介绍如何免费申请域名,在最近的折腾中发现,域名真的很重要,不然好多服务是无法访问的。 备注:由于freenom基于技术原因,暂时…

感知机学习

定义 感知机:假设输入控件(特征空间)是 X ⊆ R n \mathcal{X} \subseteq \mathbb{R}^n X⊆Rn,输出空间是 Y { 1 , − 1 } \mathcal{Y}\left\{1, -1\right\} Y{1,−1},输入 x ∈ X \mathbf{x}\in\mathcal{X} x∈X表示实例的特征向量&#x…

使用Pano2VR实现背景音乐、放大/缩小、旋转、缩略图和直线/立体/鱼眼模式等

内容简介 本文在文章《使用Pano2VR实现客厅VR效果》基础上,增加背景音乐、放大/缩小、旋转、缩略图和直线/立体/鱼眼模式等;效果如下图(为了可以上传缩小屏幕,属于PC端运行): 实现过程 1. 运行Pano2VR软件后…

【初学人工智能原理】【1】一元一次函数:感知器如何描述直觉

前言 本文教程均来自b站【小白也能听懂的人工智能原理】,感兴趣的可自行到b站观看。 本文【原文】章节来自课程的对白,由于缺少图片可能无法理解,故放到了最后,建议直接看代码(代码放到了前面)。 代码实…

企业管理中,如何组建数据团队

数字化已经成为了当前时代的标志,也变为人们对未来社会发展的共识,一时间数字化相关技术、理念、应用都开始向各行各业普及。此时人工智能、云计算、大数据、互联网、物联网等的发展也越来越快,给人们的生活和企业的经营管理模式带来了深刻改…

设计模式 -- 原型模式

前言 月是一轮明镜,晶莹剔透,代表着一张白纸(啥也不懂) 央是一片海洋,海乃百川,代表着一块海绵(吸纳万物) 泽是一柄利剑,千锤百炼,代表着千百锤炼(输入输出) 月央泽,学习的一种过程,从白纸->吸收各种知识->不断输入输出变成自己的内容 希望大家一起坚持这个过程,也同…

【问题解决】RabbitMQ启动出现epmd error for host xx.xx: nxdomain (non-existing domain)

问题描述 【k8s】或【普通容器】或【Linux】部署的RabbitMQ启动时出现了 epmd error for host xx.xx: nxdomain (non-existing domain) 错误,MQ无法启动成功。 其中 xx.xx 为无法解析的域名。 RabbitMQ官方还提到报错 Error during startup: {error,no_epmd_port}…

回归区间预测 | Matlab基于分位数随机森林算法(QRF)的回归预测

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 Matlab基于分位数随机森林算法(QRF)的回归预测,matlab代码。 基于分位数随机森林算法(QRF)回归预测,matiab代码,单变量输入模型。 评价指标包括:R2、MAE、MSE、RMSE和区间覆盖率和区间平均宽度百分比等,代码质…

【软件测试】自动化测试日志问题该怎么解决?测试老鸟总结方案...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 Python自动化测试&…

Leetcode力扣秋招刷题路-0801

从0开始的秋招刷题路,记录下所刷每道题的题解,帮助自己回顾总结 801. 使序列递增的最小交换次数 我们有两个长度相等且不为空的整型数组 nums1 和 nums2 。在一次操作中,我们可以交换 nums1[i] 和 nums2[i]的元素。 例如,如果 …

为什么越来越多的人开始学习大数据了?

现在,在数字化转型的推动下,越来越多的企业意识到大数据的魅力,并不断在这个领域投入资金,Python大数据开发相关人才也备受青睐! 大数据从业领域很宽广,不管是科技领域还是食品产业,零售业等都…

大数据行业就业前景怎么样呢

就目前的前景来看,大数据的发展的确的非常不错的~ 既然回答大数据的问题,那就让我们到用数据的方式来回答一下。大数据需求越来越多,只有技术在手不愁找不到工作。 先来看几个招聘网站的报告数据:Boss直聘发布的,今年…

Zynq-7000、FMQL45T900的GPIO控制(七)---linux驱动层配置GPIO中断输入

本文使用的驱动代码 (1条消息) FMQL45T900linux驱动外部中断输入ZYNQ-7000linux驱动外部中断输入资源-CSDN文库 在Zynq-7000、FMQL45T900驱动层也时常会用到对GPIO的控制,这里就针对实际使用的情况进行说明,首先根据之前的帖子确实使用GPIO编号 这里采…