CLIP 论文解读

news2024/9/23 5:26:02

文章目录

    • 模型
      • 训练
      • 推理
    • 实验
      • 与Visual N-Grams 相比较
      • 分布Shift的鲁棒性
    • 不足
    • 参考

现有的计算机视觉系统用来预测一组固定的预订对象类别,比如ImageNet数据集有1000类,CoCo数据集有80类。这种受限的监督形式限制了模型的通用性和可用性。使用这种方法训练好的模型对训练过程中出现的对象类别有很好的识别效果,但是对训练过程中未出现的类别,识别效果很差。直接从图像的原始文本中学习是一种很有潜力的选择,这样可以获得更多的监督信号。OpenAI从互联网上收集 4亿(图像,文本)对,基于此数据集,训练一个预测哪一个文本描述当前哪一个图像的预训练任务, 此任务取得了SOTA的图像表示。预训练之后,可以通过自然语言匹配视觉概率从而实现zero-shot的迁移。

模型

CLIP算法的核心是利用自然语言包含的监督信号来训练视觉模型。相比于其他的训练方法,从自然语言中学习具有以下两个优势。首先,相比于标准的有标签图像分类数据集,这种方法无需标注,就很容易扩展数据集;其次,图像和文字配对训练,学习到的特征不单单是一个视觉特征,而是多模态的特征,有助于zero-shot的迁移学习。

图像和文本的配对训练需要一个足够大的数据集,然而现有的数据集要么数据量不够,要不标注质量太差,为了解决这个问题,OpenAI他们从互联网上的各种公开来源公收集了4亿(图像,文本)对。

最近视觉领域的算法训练都需要巨大的计算量。 在ImageNet数据集上训练ResNeXt101-32x48d模型需要19个GPU年,训练Noisy Student EfficientNet-L2模型需要33个TPUv3 core-years。那么从自然语言中直接学习视觉概念的开放集合这一任务所需要的计算资源令人望而生畏。经过一些列的努力,他们发现训练效率对多模型预训练是至关重要的。

最初的方法,与VirTex相似,联合训练一个图像CNN和文本transformer去预测图像所对应的文本。然而,一个6300万参数的transformer语言模型,比ResNet-50图像编码器多了两倍的计算量,学习识别ImageNet类别的速度比一个更简单预测同一文本的BOW编码器慢了三倍。它们试图预测每张图片所附文本的确切单词,这种做法显然是比较困难的。对于一张图片来说,可以有不同的描述。比如下面一张图,你可以描述“有一只狗在草地上”,也可以描述“这是一张小狗的图片”,还可以描述“一条小黑狗伸着小舌头”等等。
在这里插入图片描述
最近在图像对比表征学习方面的工作发现,contrastive objectives可以比同等的predictive objective学习更好的表征,而且其他工作发现,图像的生成模型相比同性能的对比模型相比,需要更多的计算量才能够学习到高质量的图像表示。基于此,将文本看做一个整体,而不是一个个单词,与图像进行配对。将基准BOW编码器的预测性的目标函数替换成对比性的目标函数,训练效率提高了4倍。

在这里插入图片描述

训练

下图是CLIP的伪代码。CLIP包含两个编码器,一个图像编码器和一个文本编码器,图像编码器的输入形式是 [ n , h , w , c ] [n,h,w,c] [n,h,w,c] n n n是批次大小, h , w , c h,w,c h,w,c是图像的大小,比如 224 × 224 × 3 224 \times 224 \times 3 224×224×3;文字编码器的输入形式是 [ n , l ] [n,l] [n,l],由于是图像文本对,所以文本编码器中的批次是与图像编码器中的批次是一样的, l l l是序列长度。图像编码器可以是深度卷积网络,也可以是Transformer;文本编码器可以是CBOW,也可以是Text Transformer。

在这里插入图片描述

图像输入和文本输入分别进入图像编码器和文本编码器之后,得到相应的特征表示;然后后接一个归一化的操作。归一化操作时,牵涉到线性投射层 n p . d o t ( I f , W i ) np.dot(I_f, W_i) np.dot(If,Wi) n p . d o t ( T f , W t ) np.dot(T_f, W_t) np.dot(Tf,Wt)。线性投射层将每个编码器的表示映射到多模态嵌入空间。归一化操作后得到用于对比的特征 I e I_e Ie T e T_e Te。接下来,计算 n n n个图像的特征和 n n n个文本的特征的余弦相似性。伪代码中的余弦相似性值为 l o g i t s logits logits,logits的形状大小为 n × n n \times n n×n,然后logits与分别与图像和文本的gt标签做交叉熵目标函数,最后的损失函数为对称损失函数。gt的计算使用了arrange 函数,labels的取值为 1 , 2 , 3 , 4 , 5 , ⋯   , n 1,2,3,4,5, \cdots ,n 1,2,3,4,5,,n,CLIP的正样本标签为对角线上的元素。

OpenAI收集了4亿图像文本对,数据量过于巨大,所以图像编码器和文本编码器无预训权重。由于文本是被当做一个句子,所以文本编码器对文本无增强操作;图像编码器中的图像增强方式也仅仅使用了resize和随机裁减。温度参数 τ \tau τ,控制softmax中logits的范围,在训练过程中直接被优化为标量,而不是作为一个超参数。

在这里插入图片描述

推理

当使用训练好的CLIP模型进行推理时,首先需要使用clip.load()加载模型,然后分别对图像和文本进行前处理。文字前处理是对数据集中的所有类别进行prompt engineering处理,将每个类别转换成句子,clip.tokenize()将句子长度padding到77个token长度。下面的代码中,图像选取了CIFAR100数据集中的其中一张。接下来,将图像和文字分别喂入图像编码器和文字编码器,提取图像特征和文字特征,分别将图像特征和文字特征正则化之后,计算图像特征和文字特征之间的相似度,并对相似度进行softmax操作。

# 下面代码使用CLIP进行zero shot预测,从CIFAR-100数据集中挑选一张图片,预测这张图像最有可能与此数据集中100个标签中哪一个标签最相似。
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device) # 加载模型,返回的preprocess中包含一个torchvision transform依次执行Resize,CenterCrop和Normalization等操作。

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device) # 图像前处理
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) # 文字前处理

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input) # 图像编码器提取输入图像的特征
    text_features = model.encode_text(text_inputs) # 文字编码器提取输入文字的特征

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True) # 正则化
text_features /= text_features.norm(dim=-1, keepdim=True) # 正则化
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) #计算图像特征与文字特征之间的相似度
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
The output will look like the following (the exact numbers may be slightly different depending on the compute device):
Top predictions:
           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

在这里插入图片描述
大多数标准的图像分类只会使用标签的数字id来注释图像,并包含一个将这些数字id映射回其英文名称的文件。这样就导致一个最常见的问题就是多义性。当类的名称作为唯一的信息提供给CLIP的文本编码器时,由于缺乏上下文,它无法区分词义。在某些情况下,同一个单词的多个含义可能做为不同的类包含在同一数据集中。比如在ImageNet数据集中,建筑起重机和飞行的鹤。而且,在CLIP的预训练数据集中,与图像配对的文本中只有一个单词的情况是相对罕见的。通常,文本都是一个描述图像的完整句子。为了弥合这一分布差距,作者团队发现使用提示模板“A photo of a {label}”是一个很好的选择,有助于指定文本是关于图像的内容。与仅使用标签文本的基线模型相比,通常会提升性能。

在官方提供的Prompt_Engineering_for_ImageNet.ipynb中,ImageNet数据集总共包含1000个类,CLIP团队陆陆续续地提供了80个提示模板。

实验

与Visual N-Grams 相比较

Visual N-Grams是17年的工作,它是首个将零样本迁移到现有图像分类数据集的算法。它使用一般预处理模型研究zero shot到标准图像分类数据集的迁移。

下表展示了Visual N-Grams和CLIP在三个数据集上的性能对比。在ImageNet数据集上,CLIP模型在没有使用这128万标注数据的前提下,将准确率从Visual N-Grams的11.5%提升到76.2%。在aYahoo和SUN数据集上,CLIP模型相比于Visual N-Grams模型的性能有所提升。

CLIP与Visual N-Grams比较相似,但是17年的时候,transformer还没有出世,上述的比较,主要是为了了解CLIP的性能。

在这里插入图片描述

分布Shift的鲁棒性

2015年时,深度学习模型在ImageNet测试数据集的性能超过人类,但是这些模型仍然会犯一些简单的错误,我们戏称为人工智障。我们知道,深度学习模型的训练和测试数据集遵从独立同分布IID假设,这样就导致模型对分布外的数据分类性能较差。最常见的一个解释就是:深度学习模型非常善于找到其训练数据集中的一些相关性和模式,而这些相关性和模式实际上是伪造的,不适用于其他分布,导致模型在其他数据集上的性能大幅下降。为了使模型泛化性能更好,有一些学者在研究分布外OOD(out-of-distribution)分类问题。

CLIP模型从一个不同于以往的角度研究这个问题,它在一个非常大的数据集上通过自然语言监督进行训练模型,能够实现zero-shot的高性能。直观上说,zero-shot 模型由于没有针对某个分布进行训练,它不能够学习到仅在特定分布上保持的虚假的相关性和模式,因此,zero-shot模型应当具有更高的鲁棒性。

ImageNetV2、ImageNet Sketch、ObjectNet, ImageNet Adversarial 和 ImageNet Rendition这几个数据集,为ImageNet数据集的分布偏移。训练集为ImageNet数据集训练得到的模型,在这五个数据集上的表现都低于在ImageNet测试数据集上的性能;而zero-shot CLIP模型则取得较好的性能,如下图所示,尤其实在ImageNet-R数据集上,有51.2%的性能提升。

在这里插入图片描述

不足

  1. Zero-shot的CLIP比基于ResNet-50特征的线性分类器相比具有优势,但在很多任务上,仍逊色于SOTA模型。
  2. CLIP在细分类数据集上表示不好;CLIP不擅长处理抽象任务,比如数一数图像中物体的个数;CLIP对一些不包含预训练集中的新型任务,表现也不好, 比如对一张图像中到最近汽车的距离进行分类。
  3. 对于一些真正的分布外的数据,CLIP的泛化性能很差。
  4. CLIP本质上还是在有限的类别中进行推理,相比于image caption直接能生成新的输出,还是具有局限性的。一个值得尝试的简单想法是将对比和生成目标进行联合训练,整合CLIP的有效性和caption模型的灵活性。
  5. CLIP仍然没有解决深度学习中的poor data efficiency问题。CLIP与自监督和自训练结合训练会是一个提高数据效率方面的方向。
  6. CLIP 虽然一直强调zero-shot,但是在训练过程中,也反复以数据集的validation performance指导CLIP的表现,并不算真实的zero shot。如果能够创造一个验证zero-shot的迁移能力的新数据集,将会解决这种问题。
  7. 4亿图像文本对,不论图像和文本都是从网上爬下来的。而是这些图像文本对没有进行过滤和处理,难,难免会携带一些社会性偏见。
  8. 很多复杂的任务和视觉概念很难仅仅通过文本指定。未来的工作需要进一步开发一种将CLIP强大的zero shot性能与few shot学习相结合的方法。

参考

  1. Learning Transferable Visual Models From Natural Language Supervision
  2. CLIP blog
  3. openai/CLIP

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/431378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《花雕学AI》02:人工智能挺麻利,十分钟就为我写了一篇长长的故事

ChatGPT最近火爆全网,上线短短两个多月,活跃用户就过亿了,刷新了历史最火应用记录,网上几乎每天也都是ChatGPT各种消息。国内用户由于无法直接访问ChatGPT,所以大部分用户都无缘体验。不过呢,前段时间微软正…

Nginx实现会话保持,集群模式下session域共享

前言 生产环境下,多数系统为了应对线上多种复杂情况而进行了集群架构的部署,保证系统的高性能、价格有效性、可伸缩性、高可用性等。通常将生产环境下的域名指向Nginx服务,通过它做HTTP协议的Web负载均衡。 session是什么 在计算机中&…

13.广度优先搜索

一、算法内容 1.简介 广度优先搜索BFS(Breadth First Search)按照广度优先的方式进行搜索,可以理解为“尝试所有下一步可能”地穷举所有可行的方案,并不断尝试,直到找到一种情况满足问题问题的要求。 BFS从起点开始…

C语言——学生信息管理系统(数组)

文章目录一、前言二、目的三、框架1.菜单1.1主菜单1.2子菜单2.流程图2.1总流程图2.2开始流程图2.3增加学生信息流程图2.4.删除学生信息流程图2.5修改学生信息流程图2.6查询学生信息流程图2.7对学生信息排序流程图3.思路四、代码五、演示视频一、前言 因为最近是在赶进度总结&a…

无人驾驶--工控机安装autoware

时隔好久,又来写文章了,这次有高人指点,要系统的学习一下无人驾驶了。 使用的是易咖的底盘车,工控机是米文动力Apex Xavier II,基于autoware框架 首先是在工控机上安装autoware,工控是ubuntu18环境。 参…

Python入门教程+项目实战-9.2节: 字符串的操作符

目录 9.2.1 字符串常用操作符 9.2.2 操作符:拼接字符串 9.2.3 *操作符:字符串的乘法 9.2.4 []操作符:索引访问 9.2.5 [:]操作符:分片字符串 9.2.6 in操作符:查找子串 9.2.7 %操作符:格式化字符串 9…

为什么要做软件测试

随着信息技术的发展和普及,人们对软件的使用越来越普及。但是在软件的使用过程中,软件的效果却不尽如人意。为了确保软件的质量,整个软件业界已经逐渐意识到测试的重要性,软件测试已经成为IT 领域的黄金行业。本篇文章将会带领大家…

使用Tensorboard多超参数随机搜索训练

文章目录1超参数训练代码2远端电脑启动tensorboard完整代码位置https://gitee.com/chuge325/base_machinelearning.git 这里还参考了tensorflow的官方文档 但是由于是pytorch训练的差别还是比较大的,经过多次尝试完成了训练 硬件是两张v100 1超参数训练代码 这个…

Android Studio升级Gradle Plugin升级导致项目运行失败问题

背景&错误 升级Android Studio 旧项目无法运行,奇奇怪怪什么错误都有 例如: java.lang.IllegalAccessError: class org.gradle.api.internal.tasks.compile.processing.AggregatingProcessingStrategy (in unnamed module 0x390ea9fb) cannot acce…

传智健康-day2

一.需求分析(预约管理功能开发) 预约管理功能,包括检查项管理、检查组管理、体检套餐管理、预约设置等、预约管理属于系统的基础功能,主要就是管理一些体检的基础数据。 检查组是检查项的集合 二.基础环境搭建 1导入预约管理模块数据表 需要用到的…

Ubuntu安装MySQL及常用操作

一、安装MySQL 使用以下命令即可进行mysql安装,注意安装前先更新一下软件源以获得最新版本: sudo apt-get update #更新软件源 sudo apt-get install mysql-server #安装mysql 上述命令会安装以下包: apparmor mysql-client-5.7 mysql-c…

不定期更新:我对 ChatGPT 进行多方位了解后的报告,超级全面,建议想了解的朋友看看

优质介绍视频: GPT4前端【AI编程新纪元】 【渐构】万字科普GPT4为何会颠覆现有工作流;为何你要关注微软Copilot、文心一言等大模型 此文章不定期更新(一周应该会更新一次) 最近一次更新:2023.4.16 12:00 ChatGPT 是什…

零基础搭建私人影音媒体平台【远程访问Jellyfin播放器】

文章目录1. 前言2. Jellyfin服务网站搭建2.1. Jellyfin下载和安装2.2. Jellyfin网页测试3.本地网页发布3.1 cpolar的安装和注册3.2 Cpolar云端设置3.3 Cpolar本地设置4.公网访问测试5. 结语1. 前言 随着移动智能设备的普及,各种各样的使用需求也被开发出来&#xf…

关于加强供水企业营销管理的几点思考

供水营销部门是供水企业最重要的职能部门之一,其工作职能直接与供水企业的经济利益和社会效益息息相关,具体来说,主要涉及到五个方面的指标内容:水费回收率、 水量漏损率(产销差率)、水表完好率、水价调整及…

《年会抽奖》:无人获奖的概率

目录 一、题目 二、思路 1、错排问题 2、n 的阶乘 3、输出格式要求 三、代码 一、题目 题目:年会抽奖 题目链接:年会抽奖 今年公司年会的奖品特别给力,但获奖的规矩却很奇葩: 1. 首先,所有人员都将…

SpringBoot起步依赖和自动配置

文章目录 1、起步依赖2、自动配置 1、起步依赖 概念 起步依赖本质上是一个Maven项目对象模型(Project Object Model,POM),定义了对其他库的传递依赖,这些东西加在一起支持某一功能。 简单的说,起步依赖就…

这才是后端API该有的样子

一般系统大致架构如下: 有些小伙伴会说,这个架构太简单太low了吧,什么网关、缓存、消息中间件都没有。 需要说明的是,因为我们主题是API接口(tbAPI,pinduoduo API接口调用)所以聚焦这一点上就行…

Java FileChannel文件的读写实例

一、概述: 文件通道FileChannel是用于读取,写入,文件的通道。FileChannel只能被InputStream、OutputStream、RandomAccessFile创建。使用fileChannel.transferTo()可以极大的提高文件的复制效率,他们读和写直接建立了通道&#x…

【Leetcode刷题】链表的中间结点和合并两个有序链表

生命如同寓言,其价值不在与长短,而在与内容。 ——塞涅卡 目录 一.链表的中间结点 1.快慢指针 二.合并两个有序链表 1.尾插法 一.链表的中间结点 给你单链表的头结点 head ,请你找出并返回链表的中间结…

Java——对象克隆(复制)

假如想复制一个简单变量。很简单: int apples 5; int pears apples; 不仅int类型,其它七种原始数据类型(boolean,char,byte,short,float,double.long)同样适用于该类情况。 但是如果你复制的是一个对象,情况就复杂了。 假设说我是一个b…