ALBEF算法解读

news2025/1/10 16:43:39

ALBEF论文全名Align before Fuse: Vision and Language Representation Learning with Momentum Distillation,来自于Align before Fuse,作者团队为Salesforce Research。
论文地址:https://arxiv.org/pdf/2107.07651.pdf
论文代码:GitCode - 开发者的代码家园

1、灵感来源

a)图 (a) 是VSE架构,它们的文本端就是直接抽一个文本特征,但是它们的视觉端非常大,需要的计算量非常多,因为它通常是一个目标检测器。当得到了文本特征和视觉特征之后,它最后只能做一个很简单的模态之间的交互,从而去做多模态的任务。

b)图(b)是CLIP的结构,视觉端和文本端都用同等复杂度的encoder进行特征提取,再做一个简单的模态交互,结构优点是对检索任务极其有效,因为它可以提前把特征都抽好,接下来直接算Similarity矩阵乘法就可以,极其适合大规模的图像文本的检索,非常具有商业价值。缺点是只计算Cosine Similarity无法做多模态之间深度融合,难一些的任务性能差。

c)图(c)是Oscar或者ViLBERT、Uniter采用的架构,因为对于多模态的任务,最后的模态之间的交互非常重要,只有有了模态之间更深层的交互,VQA、VR、VE这些任务效果才会非常好,所以他们就把最初的简单的点乘的模态之间的交互,变成了一个Transformer的Encoder,或者变成别的更复杂的模型结构去做模态之间的交互,所以这些方法的性能都非常的好,但是随之而来的缺点也很明显:所有的这一系列的工作都用了预训练的目标检测器,再加上这么一个更大的模态融合的部分,模型不论是训练还是部署都非常的困难。

d)图 (d) 是ViLT的架构。当Vision Transformer出来之后,ViLT这篇论文就应运而生了,因为在Vision Transformer里,基于Patch的视觉特征与基于Bounding Box的视觉特征没有太大的区别,它也能做图片分类或者目标检测的任务,因此就可以将这么大的预训练好的目标检测器换成一层Patch Embedding就能去抽取视觉的特征,所以大大的降低了运算复杂度,尤其是在做推理的时候。但是如果文本特征只是简单Tokenize,视觉特征也只是简单的Patch Embedding是远远不够的,所以对于多模态任务,后面的模态融合非常关键,所以ViLT就直接借鉴 c类里的模态融合的方法,用一个很大的Transformer Encoder去做模态融合,从而达到了还不错的效果。因为移除了预训练的目标检测器,换成了可学习的Patch Embedding Layer。

ViLT的优点:模型极其简单。它虽然是一个多模态学习的框架,但跟NLP的框架没什么区别,就是先Tokenized,然后送到一个Transformer去学习,所以非常的简单易学。
ViLT的缺点:1)它的性能不够高,ViLT在很多任务上是比不过 © 类里的这些方法的,原因是对于现有的多模态任务,需要更多的视觉能力(可能是由于数据集的bias),因此视觉模型需要比文本模型要大,最后的效果才能好,但是在ViLT里,文本端用的Tokenizer很好,但是Visual Embedding是Random Initialized,所以它的效果就很差2)ViLT虽然推理时间很快,但它的训练时间非常慢,在非常标准的一个4 million的数据集set上,ViLT需要64张32G的GPU训练三天,它训练的复杂度和训练的时间丝毫不亚于 © 类的方法,所以它只是结构上简化了多模态学习,但训练难度并没有降低。

从上面可以发现,a和b的模态融合是比较弱的,比如b的代表就是clip,他是使用余弦相似度计算文本和图像的相似性,这对于VQA,VE之类的任务效果就比较差。c的代表是ViLBERT,除了多模态融合使用较大模型外,在图像端也使用了较大的模型。d的代表是ViLT,只在多模态融合时使用较大模型。

那么我们发现,b类模型的模态融合方面较弱,适合做检索但不适合做推理问答;cd类模型在模态融合方面较强,适合做推理问答但不适合做检索。

因此,在模型的选择上:

        1、因为有图像的输入和文本的输入,模型应有两个分支分别抽取图像文本特征。
        2、在多模态学习里,视觉特征重要性远远要大于这个文本特征,所以应该使用更大更强的视觉模型。
        3、多模态学习模态之间的融合也非常关键,因此需要模态融合的模型尽可能大,所以好的多模态学习网络结构应该像 ©,也就是文本编码器比图像编码器小,多模态融合的部分尽可能大。

2、ALBEF预训练

ALBEF有三个损失函数:图像文本对比损失(ITC)、图像文本匹配损失(ITM)和掩码语言建模损失(MLM)。

图像和文本对在经过图像编码器和文本编码器后会算一个ITC,促使学习特征对齐的能力。对齐后的文本和图像特征再经过多模态编码器会算一个ITM,促使学习特征交互的能力。仿照BERT,文本mask后再单独过一遍文本编码器和多模态编码器来算MLM,促使学习语言和特征交互的能力。

另外,还有两个训练技巧:

        1、使用ITC相似度矩阵,挖掘困难的负样本,来提升ITM的难度,进而督促模型学习。

        2、由于数据中有噪声(有的负文本对更能表述图像),仿照MoCo,使用动量模型作为teacher模型,在学习的过程中student模型以teacher的软标签作为label进行学习,能够抵抗数据质量不高和一图多义的问题。

图像端:给定任何一张图片,按照Vision Transformer的做法,把它打成patch,然后通过patch embedding layer送给一个Vision Transformer。这里是一个非常标准的12层的Vision Transformer的base模型,如果图片是224x224,那这里的sequence length就是196,然后加上额外的一个CLS token就是197,它的特征维度是768,所以这里绿黄色的特征就是197乘以768。但论文在预训练阶段用的图片是256x256,所以这里绿色的sequence length就会相应的再长一些。它的预训练参数用的DEiT,也就是Data Efficient Vision Transformer在ImageNet 1K数据集上训练出来的初始化参数。
文本端:文本端为保持计算量与clip类似,并且增强模态融合的部分,用前六层做文本编码,剩下的六层transformer encoder作为multi-model fusion的过程。文本模型用BERT模型做初始化,它中间的特征维度也是768,它也有一个CLS token代表了整个句子的文本信息。
momentum model:ALBEF为了做momentum distillation,而且为了给ITC loss提供更多negative,增加了momentum model,这个模型参数由左边训练的模型参数通过moving average得到的(和MoCo一模一样),通过把moving average的参数设的非常高(论文里是0.995)来保证momentum model不会那么快更新,产生的特征更加稳定,不仅可以做更稳定的negative sample,而且还可以去做momentum distillation。

ITC

ITC目的是在融合前学习更好的单模态表示,并且对齐图像和文本特征。它学习一个相似度函数,使得相关联的图像-文本对具有更高的相似度得分。受MoCo的启发,维护了两个队列来存储来自动量单模态编码器的最新M个图像-文本表示。

ITC loss:对比学习就是首先定义一个正样本对,然后定义很多负样本对,对比使正样本对之间的距离越来越近,正负样本对之间的距离越来越远。在ALBEF里首先将图像I通过vision transformer之后得到图像的全局特征,图中黄色CLS token作为全局特征,也就是一个768×1的向量,文本这边先做tokenize,将文本text变成tokenization的序列,再输入BERT的前六层,得到了一系列的特征,文本端的CLS token作为文本的全局特征,也是一个768×1的向量。接下来与MoCo相同,图像的特征向量先做一下downsample和normalization,将768×1变成了256×1的向量,文本特征向量也是768变成256×1。正样本这两个特征距离尽可能的近,它的负样本全都存在一个q里,含有65536个负样本(因为它由momentum model产生的,没有gradient,所以它并不占很多内存),正负样本之间的对比学习,使得这两个特征距离尽可能的远。这个过程就是align before fuse的align,也就是说在图像特征和这个文本特征输入Multi-model Fusion的encoder之前,就已经通过对比学习的ITC loss让这个图像特征和文本特征尽可能的拉近,在同一个embedding space里,具体使用了cross entropy loss。

ITM

ITM预测一个图文对是正对(匹配)还是负对(不匹配)。我们使用多模态编码器的[CLS] token的输出嵌入作为图像-文本对的联合表示,并附加一个全连接(FC)层,然后是softmax来预测两类概率。

本文提出一种策略,为ITM任务采样较难的负样本,而且计算开销为零。如果图像-文本具有相似的语义,但在细节上存在差异,这就是一个比较难的图像-文本对。我们使用公式1中的对比相似性来找到批量内的困难的负对。对于小批量中的每个图像,我们按照对比相似性分布从同一批次中采样一个负文本,其中与图像更相似的文本有更高的机会被采样。同样,我们还为每个文本采样一个难的负图像。

ITM loss:Image Text Matching,属于一个二分类任务,就是给定一个图片,给定一个文本,图像文本通过ALBEF的模型之后输出一个特征,在这个特征之后加一个分类头,也就是一个FC层,然后去判断I和T是不是一个对,这个loss虽然很合理,但是实际操作的时候发现这个loss太简单,所以这个分类任务,很快它的准确度就提升得很高无法进一步优化。因此ALBEF通过某种方式选择最难的负样本(最接近于正样本的那个负样本),具体来说ALBEF的batch size是512,所以ITM loss正样本对就是512个,对于mini batch里的每一张图像,把这张图片和同一个batch里所有的文本都算一遍cos similarity,然后它在这里选择除了它自己之外相似度最高的文本当做negative,这样ITM loss就非常有难度。

MLM

MLM利用图像和对应的文本来预测被遮盖的token。我们以15%的概率随机遮盖输入token,并用特殊标记[mask]替换它们。

MLM Loss:Mask Language Modeling,它把原来完整的句子text T变成一个T’,也就是有些单词被mask掉了,然后它把缺失的句子和图片一起通过ALBEF的模型,最后把之前的完整的句子给预测出来,它这里也借助了图像的信息去更好的恢复被mask掉的单词。

momentum distillation动量蒸馏

用于预训练的图像-文本对大部分是从web上收集的,它们往往有噪声。正相关对通常是弱相关的:文本可能包含与图像无关的单词,或者图像可能包含文本中没有描述的实体。对于ITC学习,图像的负面文本也可能与图像内容相匹配。对于MLM,可能存在与标注不同的其他单词,这些单词同样好地描述了图像(或更好)。然而,ITC和MLM的one-hot标签会惩罚所有负面预测,无论其正确性如何

为解决这个问题,本文建议从动量模型生成的伪目标中学习。动量模型是一个不断进化的老师,由单模态和多模态编码器的指数移动平均版本组成。在训练过程中,我们训练基础模型,使其预测与动量模型的预测相匹配。

具体的,作者认为one hot label(就是图片和文本就是一对,其他跟它都不是一对)对于ITC和MLM这两个loss来说不好,因为有的负样本也包含了很多的信息。所以作者采取了自训练方式,先构建一个momentum model然后用这个动量模型去生成pseudo targets伪目标(其实就是一个softmax score),这样它就不再是一个one hot label。
动量模型在已有模型之上做exponential moving average EMA,目的是在原始模型训练的时候,不仅希望模型预测与ground truth的one hot label去尽可能的接近,还希望模型预测与动量模型出来的pseudo targets尽可能的匹配,这样就能达到一个比较好的折中点。因为当one hot label正确时,可以学习到很多信息,但当one hot label是错误的,或者是noisy的时候,作者希望稳定的momentum model能够提供一些改进。
以ITC loss为例,它是基于这个one hot label的,所以这里再算一个pseudo target loss去弥补它的一些缺陷和不足。这个loss跟前面equation1里的ITC loss的不同,这里将ground truth换成这个q,就是pseudo targets,q不再是one hot label,而是softmax score,所以这里计算KL divergence而不是cross entropy。

最终ALBEF的训练loss有五个:两个ITC、两个MLM、一个ITM,其中:
1)ITC有两个loss:一个是原始的ITC,一个是基于pseudo target的ITC,所以分别加权(1-α)和α的loss weight,最终得到momentum版本的ITC loss。
2)MLM loss有两个loss:一个是原始的MLM,一个是基于pseudo target的MLM,用新生成的pseudo target去代替了原来的ground truth,分别加权(1-α)和α的loss weight,最终得到momentum版本的MLM loss。
3)ITM有一个loss:ITM没有动量版本,因为本身它就是基于ground truth,它就是一个二分类任务,而且在ITM里又做了hard negative,这跟momentum model有冲突。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1460047.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

跑批SQL性能异常分析

跑批是串行进行的,同时伴随着在线业务,整个跑批过程中只卡在SQL(6hqva0h4awrxh)执行环节上,相关等待事件为gc current grant 2-way,其他环节及在线业务都不受影响,主机整体资源较为空闲。最后通…

dockerfile文件书写

1.dockerfile构建nginx镜像 1.1书写dockerfile文件 mkdir nginx #创建nginx目录 cd nginx vim dockerfile # 修改文件FROM centos # 基础镜像,默认最新的centos8操作系统 MAINTAINER xianchao # 指定镜像的作者信息 RUN rm -rf /etc/yum.repos.d/* # centos8默认…

HTTP的详细介绍

目录 一、HTTP 相关概念 二、HTTP请求访问的完整过程 1、 建立连接 2、 接收请求 3、 处理请求 3.1 常见的HTTP方法 3.2 GET和POST比较 4、访问资源 5、构建响应报文 6、发送响应报文 7、记录日志 三、HTTP安装组成 1、常见http 服务器程序 2、apache介绍和特点 …

基于机器学习的青藏高原高寒沼泽湿地蒸散发插补研究_王秀英_2022

基于机器学习的青藏高原高寒沼泽湿地蒸散发插补研究_王秀英_2022 摘要关键词 1 材料和方法1.1 研究区概况与数据来源1.2 研究方法 2 结果和分析2.1 蒸散发通量观测数据缺省状况2.2 蒸散发与气象因子的相关性分析2.3 不同气象因子输入组合下各模型算法精度对比2.4 随机森林回归模…

对象的接口

“类”,那个类具有自己的通用特征与行为。 因此,在面向对象的程序设计中,尽管我们真正要做的是新建各种各样的数据“类型”(Type),但几乎所有面向对象的程序设计语言都采用了“class”关键字。当您看到“ty…

多线程相关(1)

线程调度 线程状态:状态切换阻塞与唤醒阻塞唤醒 wait 与 sleep创建线程方式 线程是cpu任务调度的最小执行单位,每个线程拥有自己独立的程序计数器、虚拟机栈、本地方法栈。 线程状态: 线程状态包括:创建、就绪、运行、阻塞、死亡…

第六十四天 服务攻防-框架安全CVE复现Apache shiroApache Solr

第六十四天 服务攻防-框架安全&CVE复现Apache shiro&Apache Solr 知识点: 中间件及框架列表: IIS,Apache,Nginx,Tomcat,Docker,K8s,Weblogic.JBoos,WebSphere, Jenkins,GlassFish,Jetty,Jira,Struts2,Laravel,Solr,Shiro,Thinkphp,Spring, Flask,jQuery等 1、开发框…

【rust】7、命令行程序实战:std::env、clap 库命令行解析、anyhow 错误库、indicatif 进度条库

文章目录 一、解析命令行参数1.1 简单参数1.2 数据类型解析-手动解析1.3 用 clap 库解析1.4 收尾 二、实现 grep 命令行2.1 读取文件,过滤关键字2.2 错误处理2.2.1 Result 类型2.2.2 UNwraping2.2.3 不需要 panic2.2.4 ? 问号符号2.2.5 提供错误上下文-自定义 Cust…

Vue 使用 v-bind 动态绑定 CSS 样式

在 Vue3 中&#xff0c;可以通过 v-bind 动态绑定 CSS 样式。 语法格式&#xff1a; color: v-bind(数据); 基础使用&#xff1a; <template><h3 class"title">我是父组件</h3><button click"state !state">按钮</button&…

相机图像质量研究(38)常见问题总结:编解码对成像的影响--呼吸效应

系列文章目录 相机图像质量研究(1)Camera成像流程介绍 相机图像质量研究(2)ISP专用平台调优介绍 相机图像质量研究(3)图像质量测试介绍 相机图像质量研究(4)常见问题总结&#xff1a;光学结构对成像的影响--焦距 相机图像质量研究(5)常见问题总结&#xff1a;光学结构对成…

如何在同一个module里面集成多个数据库的多张表数据

确保本公司数据安全&#xff0c;通常对数据的管理采取很多措施进行隔离访问。 但是&#xff0c;Mendix应怎样访问散布于异地的多个数据库呢&#xff1f; 前几期我们介绍过出海跨境的大企业对于Mendix的技术、人才的诉求后&#xff0c;陆陆续续有其他客户希望更聚焦具体的实际场…

雷池社区版WAF:开源护网,共筑网络安全长城

雷池社区版WAF&#xff08;Web Application Firewall&#xff09;是一款开源的网络应用防火墙&#xff0c;旨在为网站和网络应用提供安全防护&#xff0c;以抵御各种网络攻击&#xff0c;如SQL注入、跨站脚本攻击&#xff08;XSS&#xff09;、文件包含、以及其他常见的安全威胁…

java日志框架总结(六、logback日志框架 扩展)

springboot推荐使用logback-spring.xml而不是logback.xml而logback-spring.xml文件与logback.xml文件还是有一定的区别&#xff0c;所以简单讲解一下。 一、logback-spring.xml 配置文件实例&#xff1a; <?xml version"1.0" encoding"UTF-8"?> …

刷LeetCode541引起的java数组和字符串的转换问题

起因是今天在刷下面这个力扣题时的一个报错 541. 反转字符串 II - 力扣&#xff08;LeetCode&#xff09; 这个题目本身是比较简单的&#xff0c;所以就不讲具体思路了。问题出在最后方法的返回值处&#xff0c;要将字符数组转化为字符串&#xff0c;第一次写的时候也没思考直…

美团外卖商超商品月销量数据

字段内容&#xff1a; shop_id varchar(50) NOT NULL, shop_id_str varchar(50) NOT NULL, shop_name varchar(400) DEFAULT NULL, shop_min_price varchar(10) DEFAULT NULL, shop_score varchar(10) DEFAULT NULL, shop_wm_score varchar(10) DEFAULT NU…

「WinCC报警系统专题」简述“消息系统”

WinCC通过报警给操作员提供了有关过程故障和错误的信息。它们有助于尽早检测重要情况和避免停机时间。 一、消息系统 消息&#xff08;报警&#xff09;系统由组态和运行系统组件组成。 1、组态系统 报警记录编辑器&#xff08;如图1所示&#xff09;是报警系统的组态组件。报…

2024年了,如何从 0 搭建一个 Electron 应用

简介 Electron 是一个开源的跨平台桌面应用程序开发框架&#xff0c;它允许开发者使用 Web 技术&#xff08;如 JavaScript、HTML 和 CSS&#xff09;来构建桌面应用程序。Electron 嵌入了 Chromium&#xff08;一个开源的 Web 浏览器引擎&#xff09;和 Node.js&#xff08;一…

【riscv】使用qemu运行riscv裸机freestanding程序

文章目录 1. 运行显示2. 工具准备3. 裸机代码和编译3.1 源码3.2 编译 4. 使用qemu仿真运行riscv裸机程序 1. 运行显示 详见左下角&#xff0c; 运行时串口输出的字符 A ; 2. 工具准备 # for riscv64-linux-gnu-gcc sudo apt-get install gcc-riscv64-linux-gnu# for qemu-s…

PROBIS铂思金融破产后续:ASIC牌照已注销

2024年1月31日&#xff0c;PROBIS铂思金融的澳大利亚ASIC牌照 (AFSL 338241) 被注销《差价合约经纪商PROBIS宣布破产&#xff0c;澳大利亚金融服务牌照遭暂停》&#xff0c;这也就意味着&#xff0c;PROBIS铂思金融目前已经没有任何金融牌照。 值得注意的是&#xff0c;时至今日…

摄像设备+nginx+rtmp服务器

前言 由于html中的video现在不支持rtmp协议(需要重写播放器框架&#xff0c;flash被一刀切&#xff0c;360浏览器还在支持flash),遂用rtmp作为桥梁,实际是hls协议在html中起作用. 在此推荐一款前端播放器,.ckplayer 简直了,写点页面,一直循环&#xff0c;洗脑神曲 dream it po…