SimSiam-Exploring Simple Siamese Pepresentation Learning

news2024/11/15 7:27:40

SimSiam

Abstract

模型坍塌,在siamese中主要是输入数据经过卷积激活后收敛到同一个常数上,导致无论输入什么图像,输出结果都能相同。

而He提出的simple Siamese networks在没有采用之前的避免模型坍塌那些方法:

  • 使用负样本
  • large batches
  • momentum encoders(论文直接用的encoder)

实验表明对于损失和结构确实存在坍塌解,但stop-gradient操作在防止坍塌方面起着至关重要的作用。

Method

如图为simsiam 的结构,输入是训练集中随机选取的一个图像,使用随机数据增强生成两个图像;左右两个encoder是完全一样的,包含卷积和全连接,将图像进行编码(特征提取);perdictor 是一般的MLP,左右都是有predictor模块的(看伪代码),只右侧是没画出来,用来转换视图的输出,并将其与另一个视图相匹配,(encoder是一样的,x1和x2即使经过数据增强大小也是一样的,那为啥要再加一个predictor模块使两个视图相匹配呢?);

similarity是对比predictor输出的特征向量,loss为经过encoder的p和predictor的输出z,p1和z2对比,p2和z1的负余弦相似度 如 D ( p 1 , z 2 ) = − p 1 ∣ ∣ p 1 ∣ ∣ 2 z 2 ∣ ∣ z 2 ∣ ∣ 2 D(p_1,z_2)=-\frac{p_1}{||p_1||_2} \frac{z_2}{||z_2||_2} D(p1,z2)=p12p1z22z2 (论文中说这个与l2正则化的mse相同?)

总的网络的loss 为 L = D ( p 1 , z 2 ) / 2 + D ( p 2 , z 1 ) / 2 L=D(p_1, z_2)/2 + D(p_2, z_1)/2 L=D(p1,z2)/2+D(p2,z1)/2

在这里插入图片描述

# f: backbone + projection mlp
# h: prediction mlp 
for x in loader: # load a minibatch x with n samples
    x1, x2 = aug(x), aug(x) # random augmentation对图像进行随机数据增强,这样就生成 
    z1, z2 = f(x1), f(x2) # projections, n-by-d encodeer的计算
    p1, p2 = h(z1), h(z2) # predictions, n-by-d predictor的计算
    L = D(p1, z2)/2 + D(p2, z1)/2 # loss  两个向量的负余弦相似度
    L.backward() # back-propagate
    update(f, h) # SGD update
def D(p, z): # negative cosine similarity
    z = z.detach() # stop gradient
    p = normalize(p, dim=1) # l2-normalize
    z = normalize(z, dim=1) # l2-normalize
    return -(p*z).sum(dim=1).mean()

在backward()时,如果y是标量,则不需要为backward()传入任何参数;否则,需要传入一个与y同形的Tensor。

如果不想要被继续追踪,可以调用.detach()将其从追踪记录中分离出来,这样就可以防止将来的计算被追踪,这样梯度就传不过去了。还可以用with torch.no_grad()将不想被追踪的操作代码块包裹起来,这种方法在评估模型的时候很常用,因为在评估模型时,我们并不需要计算可训练参数(requires_grad=True)的梯度。

上面将z给detach了, z 2 ∣ ∣ z 2 ∣ ∣ 2 \frac{z_2}{||z_2||_2} z22z2所以会被看成为常数只有 p 1 ∣ ∣ p 1 ∣ ∣ 2 \frac{p_1}{||p_1||_2} p12p1会产生梯度,

为了进一步确认那一部分的设计在本文的框架中是至关重要的,作者设计了以下的消融实验。


Empirical Study

stop grad
在这里插入图片描述

显然如果使两侧的梯度都进行传递网络的loss是非常小的,因为两个网络的参数是接近一模一样的所以两个网络很容易就达到一致了。而且这样的性能表现是非常差的,因为很容易达到两个网络参数一样,最后导致模型坍塌。实际上并不能学到什么有效的特征。


在这里插入图片描述

使用不同的predictor的结果

如果没有predictor模型不work(原因作者没说);

如果预测MLP头模块h固定为随机初始化,该模型同样不再有效,这是因为模型不收敛,loss太高;

当预测MLP头模块采用常数学习率时,该模型甚至可以取得比基准更好的结果,作者也提出了一个可能的解释:h应当适应最新的表征,所以不需要在表征充分训练之前使用降低学习率的方法迫使其收敛。

不同Batch Size

在这里插入图片描述

探究了不同的batch对精度的影响,虽然基础 l r lr lr是0.05,但是学习率会随着batch的变化做线性缩放 l r × B a t c h S i z e / 256 lr×BatchSize/256 lr×BatchSize/256 ,对于batch大于1024时,会采用10个epoch的warm-up学习率。

作者探究了SGD在较大batch上会导致性能退化,但同时也证明了优化器不是防止崩溃的必要条件。


Batch Normalization

在这里插入图片描述

移除BN之后可能因为难优化造成了性能下降,但是并没有造成collapsing,只加在隐层精度会提高到67.4%,如果在投影MLP中也加上BN则会提升到68.1%。但是如果把BN加到预测MLP上,就不work了,作者探究了这也不是崩溃问题,而是训练不稳定,loss震荡。

总结下来就是,BN在监督学习和非监督学习中都会使模型易于优化,但是并不能防止collapsing。


Similarity Function

除了余弦相似函数之外,该方法在交叉熵相似函数下也work,这里的softmax是channel维度的,softmax的输出可以认为是属于d个类别中每个类别的概率。

(img-DQyi1Tgo-1670137723538)(https://gitee.com/lizheng0219/picgo_img/raw/master/img/image-20221130170302429.png)]

在这里插入图片描述

可以看出使用交叉熵相似性依然可以很好地收敛,并没有崩溃,所以避免collapsing与余弦相似性无关。

结果比较

如下图7所示,SimSiam小的batch和没有负样本、momentum encoder的情况下仍然能取得较好的效果。

在这里插入图片描述

Hypothesis

为什么这样简单的网络能够work呢?作者提出了一种猜想:SimSiam实际上是一种Expectation-Maximization(EM)的算法。——最大期望算法。

我们最熟悉的最大期望算法就是k-means算法。

L ( θ , η ) = E x , T [ ∥ F θ ( T ( x ) ) − η x ∥ 2 2 ] L(\theta,\eta)=\mathbb{E}_{x,\mathcal{T} }[\|\mathcal{F} _\theta(\mathcal{T}(x)) - \eta_x\|_2 ^2 ] L(θ,η)=Ex,T[Fθ(T(x))ηx22]

这里x输入图像 T \mathcal{T} T是图像的一种增强, F θ \mathcal{F} _\theta Fθ是encoder, η x \eta _x ηx不一定局限于图像表征,在训练网络时我们希望找到一个 θ \theta θ,找到一个 η \eta η,使得loss的期望是最小的。

在每一步中首先会确定一个 θ \theta θ使得 loss 最小,这时使用的是一个固定的 η \eta η,从而得到 θ t \theta^t θt

θ t ← arg ⁡ min ⁡ θ L θ η t − 1 \theta^t \gets \mathop{\arg\min}_{\theta} \mathcal{L}\theta\eta^{t-1} θtargminθLθηt1(公式 2)

锁定 θ \theta θ,寻找一个使 loss 达到最小的 η \eta η

η t ← arg ⁡ min ⁡ η L ( θ t \eta^t \gets \mathop{\arg \min}_\eta \mathcal{L}(\theta^t%2C \eta ηtargminηL(θt))

反复进行以上两步最终使训练得到一个满意的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/60024.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

K_A08_003 基于 STM32等单片机驱动L9110模块按键控制直流电机正反转加减速启停

目录 一、资源说明 二、基本参数 1、参数 2、引脚说明 三、驱动说明 L9110模块驱动时序 对应程序: PWM信号 四、部分代码说明 接线说明 1、STC89C52RCL9110模块 2、STM32F103C8T6L9110模块 五、基础知识学习与相关资料下载 六、视频效果展示与程序资料获取 七、项目主要…

【Android工具】群晖安卓客户端基础套件:Drive、video、Photos和DS video安卓TV客户端...

微信关注公众号 “DLGG创客DIY”设为“星标”,重磅干货,第一时间送达。最近终于把all in one搞起来了,all in one就是把一堆功能一堆软件装一台主机里。。all in one(以后简称AIO)相关内容回头慢慢聊。今天先聊聊群晖&…

从一个demo说elf文件

本文的demo是在linux环境下编译解析的,cpu是x86-64 首先我们先写一个功能简单的demo-SimpleSection.c。这个demo中有一个func1函数用来打印数据,一个已经初始化的全局变量global_init_var和未初始化的全局变量global_uninit_var,一个已初始化…

使用TS 封装 自定义hooks,实现不一样的 CRUD

文章目录使用TS 封装 自定义hooks,实现不一样的 CRUD自定义 hooks文件结构type.tsuseDelData.ts使用useFetchList.ts使用useInsert.ts使用部分的接口方法使用TS 封装 自定义hooks,实现不一样的 CRUD 这一篇主要是记录 查缺补漏,提升自己的 强…

三、内存管理 (一)存储器管理

目录 1.1程序运行的基本过程 1.1.1 编辑、编译、链接、装入 1.1.2链接的三种方式 1.1.3装入的三种方式 1.2内存管理基本概念 1.2.1内存保护 1.2.2内存空间扩充 1.2.3地址转换功能 1.2.4内存空间的分配与回收 1.2.4.1连续分配管理方式 1.2.4.1.1单一连续分配 1.2.4.1…

Http协议和Https协议

Http是不安全的,你的数据容易被黑客拦截,篡改,攻击 https要求对数据加密(不能明文传输), 用抓包工具抓http请求,抓出来的都是明文的,你能看得懂的,抓https请求,抓出来的…

网站域名被QQ拦截提示:当前网页非官方页面拦截的解决办法

今天网友提醒,星空站长网的链接被QQ屏蔽拦截了。提示:当前页面非官方页面,请复制到浏览器打开。 如图: 原因:这是因为QQ方面的诈骗信息特别多,所以腾讯官方索性就直接屏蔽了所有的外部链接。让站长们通过申…

Python源码剖析笔记1-整数对象PyIntObject

1、PyIntObject 对象 [intobject.h] typedef struct {PyObject_HEADlong ob_ival; } PyIntObjectPyIntObject是一个不可变(immutable)对象。Python内部也大量的使用整数对象,我们在自己的代码中也会有大量的创建销毁整型对象的操作&#xff…

SVM 用于将数据分类为两分类或多分类(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

CMake中install的使用

CMake中的install命令用于指定安装时要运行的规则&#xff0c;其格式如下&#xff1a; install(TARGETS targets... [EXPORT <export-name>][RUNTIME_DEPENDENCIES args...|RUNTIME_DEPENDENCY_SET <set-name>][[ARCHIVE|LIBRARY|RUNTIME|OBJECTS|FRAMEWORK|BUNDL…

基于单片机的电压电流表设计

原理图&#xff1a; 部分程序&#xff1a; #include "stc15.h" #include "delay.h" #include "timer.h" #include "TM7707.h" #include "LCD1602.h" #include "eeprom.h" #include "stdio.h" #include…

【数学】双根号求值域问题

∣双根号求值域问题NightguardSeries.∣\begin{vmatrix}\Huge{\textsf{ 双根号求值域问题 }}\\\texttt{ Nightguard Series. }\end{vmatrix}∣∣∣∣∣​ 双根号求值域问题 Nightguard Series. ​∣∣∣∣∣​ 求 f(x)3x−63−xf(x)\sqrt{3x-6}\sqrt{3-x}f(x)3x−6​3−x​ 的…

开发工具——gdb

开发工具gdb gdb在Linux下负责程序的调试。 gdb相较于vs2019的调试&#xff0c;是不方便的。图形化界面调试确实是一种进步的现象。 先编写一个简单的程序&#xff0c;如果不支持在for循环中定义变量&#xff0c;要在编译指令后面加上-stdc99选项。 要编译的文件和要生成的文…

Spring 源码编译

Spring 源码编译&#xff0c;一定要选好版本&#xff01;&#xff01;&#xff01; Spring 源码编译&#xff0c;一定要选好版本&#xff01;&#xff01;&#xff01; Spring 源码编译&#xff0c;一定要选好版本&#xff01;&#xff01;&#xff01; 重要的事说三遍。 Spri…

MYSQL用函数请三思

背景&#xff1a;最近公司有个同事遇到个需求需要用到mysql sleep函数&#xff0c;但结果却大出意料. 测试如下&#xff1a; 表&#xff1a; CREATE TABLE test_sleep ( id int NOT NULL AUTO_INCREMENT, a int NOT NULL, b int NOT NULL, PRIMARY KEY (id), KEY a (a) ) ENGIN…

电子学会2021年3月青少年软件编程(图形化)等级考试试卷(一级)答案解析

青少年软件编程&#xff08;图形化&#xff09;等级考试试卷&#xff08;一级&#xff09; 分数&#xff1a;100.00 题数&#xff1a;37 一、单选题&#xff08;共25题&#xff0c;每题2分&#xff0c;共50分&#xff09; 1. 花花幼儿园有三个班。根据下面三句话&…

CentosLinux 7 字符安装教程

打开VMware虚拟机,点击文件 — 新建虚拟机选项。在弹出的对话框中选择自定义(高级)选项。单机下一步。 以下步骤根据自己的所需自行配置

[附源码]Python计算机毕业设计Django酒店在线预约咨询小程序

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

健身中心管理系统/健身房管理系统

摘 要 随着信息技术和网络技术的飞速发展&#xff0c;人类已进入全新信息化时代&#xff0c;传统管理技术已无法高效&#xff0c;便捷地管理信息。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而生&#xff0c;各行各业相继进入信息管理时代&…

【OpenCV-Python】教程:3-16 利用Grabcut交互式前景提取

OpenCV Python Grabcut分割 【目标】 Grabcut 算法创建一个交互程序 【理论】 从用户角度是如何工作的呢&#xff1f;用户在需要的目标上初始绘制一个矩形&#xff0c;前景目标必须完全在矩形内部&#xff0c;算法迭代的去分割然后得到更好的效果&#xff0c;但是有些情况下…