【经典论文】打通文本图像的里程碑--clip

news2025/1/15 16:50:04

Git|Paper|Colab|

CLIP 论文逐段精读【论文精读】_哔哩哔哩_bilibili

clip是openai团队在4亿对文本图像对上训练出来的。它的训练方法简单,但效果缺出奇的好。是打通图片文本的里程碑式的模型。

目录

一.模型结构​编辑

1.为什么选择用自然语言的监督信号去训练视觉模型?

2.为什么用对比学习?

二.伪代码实现

三.实验

1.zero-shot Transfer

2.Prompt Engineering and Ensemble

3.特征学习

4.和人进行对比

5. Data Overlap Analysis

6.局限

四.如何调用clip

1.需要安装的包

2.zero-shot  Prediction demo

3.Linear-probe evaluation


一.模型结构

 如图所示:

clip整体分为了三个部分,

1)通过对比学习进行预训练

clip整体的原理非常简单,就是将图片-文本pair对进行对比学习。图片经过图片编码器(resnet50/vit),文本经过文本编码器,同一对文本-图片的结果为1,不同pair的文本-图片为0。通过对比学习,联合训练图片和文本编码器。

clip通过将一个图像和一段文本输入到同一个神经网络,将它们映射到同一个嵌入空间,从而实现了图像和文本跨模态的语义对齐。

2)构建文本标签 3)进行zero-shot 预测

对于图像分类任务,对于已知的标签,openai通过prompt emsemble和promt enginering构建了一批模板,然后用label对模板进行填充。

举个例子,对于标签dog,将其改写为一个句子:this is a picture of a dog,以便和预训练时保持一致(都是句子)。

将所有可能的标签用模板得到对应的句子,然后使用text_encoder得到文本向量。

将待预测的图片使用image_encoder得到图片向量,和所有文本向量做cosine,计算出其中最相似的作为其label。

1.为什么选择用自然语言的监督信号去训练视觉模型?

1)不需要再标注了,且文本的自由度比多标签的形式的大了很多

2)将视觉特征和语言特征联系在一起后,提取出了多模态的特征,更容易去做zero-shot的学习。

2.为什么用对比学习?

用图片预测对应的文本,结果会非常多样,训练起来会非常慢。而使用对比学习,判断图片文本是否是一对,就简化了任务。

论文中也提到,将预测型的目标函数换为对比型的目标函数,训练速度快了4倍。

 

二.伪代码实现

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# joint multimodal embedding [n, d_e]
# 这里还有一个投射层,也就是W_i,W_t,是用来学习如何从单模态学习到多模态
# 多模态里非常常见的做法,fusion学习一个联合表征空间
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2

三.实验

1.zero-shot Transfer

动机:为什么要做zero-shot任务?

结果:之前预训练好的embedding,都需要用下游任务去微调,但是会遇到几个问题:

1)下游数据不好收集

2)distribution shift

用文本做引导,利用文本信号去很灵活地做zero-shot的迁移学习。

下面2张图分别是clip在zero-shot/few-shot的结果

 对于特别困难的任务,比如需要特定领域知识的任务,人类都无法很好识别,需要给clip几张图片学习一下。

右图就是clip进行few-shot learning的结果,可以看到,随着给的例子的增加,clip的效果越来越好。但是比较有意思的一点是,1-4 few-shot learning的结果是打不过左上角的zero-shot clip的效果,这也是作者在后面所讲的局限性之一。

2.Prompt Engineering and Ensemble

为什么要做提示工程?

1)文本本身的多样性,比如remote既可以做遥控器,也有遥远的意思,只给单独的一个单词,很有可能有歧义性。

2)推理时和训练是保持一致,避免distribution gap。

最简单的prompt就是a picture of [label]。

但是当你知道这个数据集是什么相关的,还可以给出额外的提示:a picture of [label],a type of pet,这样进一步缩小解的空间,帮助clip更好地选出正确的答案。

对于OCR任务,在你想要找的文本上打上双引号,模型会更明白你的意思(感觉和chatgpt一样,符号会帮助模型理解)

ensemble多个模板(clip里是80个)的结果,这个是可以根据自己的需求/下游数据集去编写一些模板的。

3.特征学习

在下游数据集上,使用全部的训练数据去训练。

训练的方式选择linear probe而不是fine-tune。

  • 一方面想要更直观地观察预训练的模型的好坏,使用fine-tune的话,有可能预训练模型不好,但是fine-tune过后得到一个比较好的结果。
  • linear-probe的训练比较简单,而fine-tune的话需要为每一个数据集调参,非常麻烦

结果如下:

clip的在多个数据集上要大幅优于在imageNet上预训练的最好的模型

 

4.和人进行对比

 5个参赛者在Pets数据集上vsclip的效果,clip大幅优于人类,人类在one-shot learning后,效果大幅提升,但是再多看一个例子并没有提升更多的准确率。

5. Data Overlap Analysis

很多人会质疑,是不是训练数据太大太好了,囊括了很多下游数据,导致clip的效果这么好。这一章也是做了一些去重实验,证明了clip本身泛化性比较好。

6.局限

1)目前clip只是远高于res50,和其他sota的结果还有一定差距(提升数据量可以提升clip的效果,但是需要扩大1000倍的数据量才可能弥补和sota之间十几个点的差距)

2)clip在某些细分领域、逻辑性较强的数据集上表现不好

3)对于out of distribution的数据,clip的效果也不太好(比如手写数字识别mnist,clip的效果连lr都打不过,后来利用去重方法查找,在4亿条训练数据中确实不存在和这些人造图片相似的图片)。

4)生成式模型不需要做prompt engineering,可以直接得到图片的标题,而clip还需要自己进行处理。(blip可以的)

5)clip对数据的利用并不是很高效,用了4亿条数据。

6)clip调参时其实用到了imageNet,在研发过程中,也是围绕着实验的27个数据集展开的,可能并不是真正意义上的zero-shot。

7)爬取的图片-文本对可能会带有偏见。

8)在1-4个样本上的few-shot结果反而不如zero-shot。

(有几个都是chatgpt也有的问题,哈哈哈)

四.如何调用clip

在git上,官方给出了使用的代码

1.需要安装的包

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git

cudatoolkit可以换成自己机器对应的版本,也可以不装,只使用cpu。

2.zero-shot  Prediction demo

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 处理将待预测的图片,这里可以使用自己的图片
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
# zero-shot,这里也可以根据自己的数据集去生成相应的label,让clip预测图片最可能是什么类别
# 最好在小数据集上尝试些prompt,看识别准确率是否有所上升,再在大数据集上预测。
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

使用CIFAR-100 dataset进行zero-shot Prediction

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

3.Linear-probe evaluation

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/551275.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“Shell“免交互

文章目录 一.免交互(Here Document)1.1Here Document 概述2.2Here Document 常规用法 二.Expect2.1Expect基本命令2.2Expect执行方式 一.免交互(Here Document) 1.1Here Document 概述 使用I/O重定向的方式将命今列表提供给交互式…

chatgpt赋能Python-pythonandor

Pythonandor:探索Python的异步编程方式 如果您是一个Python工程师,你可能已经听过Pythonandor。Pythonandor是一个Python异步框架,能够让你更高效地处理请求,并且提高应用程序的响应速度。 什么是Pythonandor? Pythonandor实际…

5。STM32裸机开发(4)

嵌入式软件开发学习过程记录,本部分结合本人的学习经验撰写,系统描述各类基础例程的程序撰写逻辑。构建裸机开发的思维,为RTOS做铺垫(本部分基于库函数版实现),如有不足之处,敬请批评指正。 &…

一篇文章打好SQL基础,熟悉数据库的基础操作和方法,以及安装MySQL软件包和Python操作MySQL基础使用

1.SQL的概述 SQL的全称:Structured Query Language,结构化查询语言,用于访问和处理数据库的标准计算机语言。 SQL语言1974年有Boyce和Chamberlin提出的,并且首先在IBM公司研制的关系数据库系统SystemR上实现。 经过多年发展&am…

KVM(二)命令行新建虚拟机

目录 一、准备工作 二、新建虚拟机 2.1 文件准备 2.2 正式安装 2.3 时区设置 2.4 安装设置 2.5 设置root用户密码 2.6 vm2安装完成 三、进入虚拟机vm2 四、网络设置 五、参考链接 若还未部署KVM,请参考第一节: KVM(一)…

如何让你的 Jmeter+Ant 测试报告更具吸引力?

目录 引言 一、安装apache-Ant 二、Jmeter准备 3、生成测试报告 4、JMeter动态参数处理逻辑是什么? 5、JMeter是怎么做API自动化测试的? 结语 引言 想象一下,你辛苦搭建了一个复杂的网站,投入了大量的时间和精力进行开发和…

java 对接国标摄像头流程、代码整合 springboot SIP -GB2818

java 对接设备的代码资料较少,这里介绍GB2818的基本对接流程,有用自取👇 java负责SIP信令的注册交互,推流、拉流鉴权摄像头负责推流、流媒体负责拉流、转码 wvp-GB28181-pro项目 ,如果java对接各种摄像头,这个项目很&a…

Java流程控制(二)

⭐ 循环结构⭐ 嵌套循环⭐ break 语句和 continue 语句⭐ 方法⭐ 方法的重载(overload)⭐ 递归结构 ⭐ 循环结构 循环结构分两大类,一类是当型,一类是直到型。 🐟 当型: 当布尔表达式条件为 true 时,反复执行某语句&a…

【eNSP】win11解决virtualbox5.2.44无法安装、不兼容的问题

问题描述: 本人大三学生一枚,这学期上计算机网络,老师要求安装华为eNSP软件,安装环节一切顺利,直到安装到依赖组件中VirtualBox-5.2.44时,发生了问题,Windows提示此应用无法在此设备上运行&…

AI工具第三期:本周超16款国内精选AI工具分享!

1. 未来百科 未来百科,是一个知名的AI产品导航网站——为发现全球优质AI工具而生。目前已聚集全球2500优质AI工具产品,旨在帮助用户发现全球最好的AI工具,同时为研发AI垂直应用的创业公司提供展示窗口,迎接未来的AI时代。未来百科…

RocketMq源码分析(七)--消息发送流程

文章目录 一、消息发送入口二、消息发送流程1、消息验证1)消息主题验证2)消息内容验证 2、查找路由3、消息发送1)选择消息队列2)消息发送-内核实现sendKernelImpl方法参数获取brokerAddr添加消息全局唯一id设置实例id设置系统标记…

Linux Audio (5) DAPM-2 Widget/Path/Route

DAPM-2 Widget/Path/Route WM8960结构图WidgetRoutePath总结 课程:韦东山音频专题 内核:Kernel 3.5 实例:WM8960 WM8960结构图 录音时的音频通路 抽象图为: Widget wm8960.c sound\soc\codecs static const struct snd_soc_dap…

C++进阶——mapset的实现

C进阶——map&set的实现 红黑树的迭代器 迭代器的好处是可以方便遍历,是数据结构的底层实现与用户透明。如果想要给红黑树增加迭代器,需要考虑以前问题: 迭代器的定义 begin()与end() STL明确规定,begin()与end()代表的是一…

OS7安装rabbitmq

1.卸载存在的rabbitmq 停止rabbitmq服务: systemctl stop rabbitmq-server 查看rabbitmq安装的相关列表: yum list | grep rabbitmq 卸载rabbitmq已安装的相关内容: yum -y remove rabbitmq-server.noarch 查看erlang安装的相关列表: yum list | grep erlang 卸…

shell脚本----免交互操作

文章目录 一、Here Document免交互1.1免交互概述1.2语法格式1.3操作实验1.4tee命令 二、expect命令 一、Here Document免交互 1.1免交互概述 使用I/O重定向的方式将命令列表提供给交互式程序或命令,比如 ftp、cat 或 read 命令。 是标准输入的一种替代品可以帮助脚…

【Unity3D】立方体纹理(Cubemap)和天空盒子(Skybox)

1 立方体纹理(Cubemap) 本文完整资源见 → 立方体纹理(Cubemap)和天空盒子(Skybox) 。 1)立方体纹理简介 立方体纹理是指由上、下、左、右、前、后 6 张纹理组成的立方体结构纹理,其…

X3运行paddle-lite Demo

仓库地址GitHub - PaddlePaddle/Paddle-Lite-Demo at master git clone直接下载到X3上 环境准备 $ sudo apt-get update $ sudo apt-get install gcc g make wget unzip libopencv-dev pkg-config $ wget https://www.cmake.org/files/v3.10/cmake-3.10.3.tar.gz $ tar -zxvf …

Node.js--》深入理解 PM2:Node.js 应用部署和管理利器

目录 pm2:进程自动化管理工具 pm2的安装与使用 pm2:进程自动化管理工具 PM2:是一个流行的Node.js进程管理器,它可以帮助您在生产环境中管理和保持Node.js应用程序运行。PM2的功能包括监视您的应用程序、自动重启您的应用程序、…

手撕代码——异步FIFO

手撕代码——异步FIFO 一、异步FIFO原理与设计读写地址指针控制读写地址指针跨时钟处理与空满信号判断读写地址与读写操作 二、完整代码与仿真文件三、仿真结果 一、异步FIFO原理与设计 在FIFO的设计中,无论是同步FIFO,还是异步FIFO,最最最最…

ChatGPT:4. 使用OpenAI API创建自己的AI网站:3. flask web框架将OpenAI 创作的图片显示在网页界面上

ChatGPT:4. 使用OpenAI API创建自己的AI网站:3. flask web框架将OpenAI 创作的图片显示在网页界面上 如果你还是一个OpenAI的小白,有OpenAI的账号,但想调用OpenAI的API搞一些有意思的事,那么这一系列的教程将仔细的为…