深度学习论文: Learning Transferable Visual Models From Natural Language Supervision
Learning Transferable Visual Models From Natural Language Supervision
PDF: https://arxiv.org/pdf/2103.00020.pdf
官方代码: https://github.com/OpenAI/CLIP
PyTorch代码: https://github.com/shanglianlm0525/CvPytorch
PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks
1 概述
CLIP(对比性语言-图像预训练)是一个在各种(图像,文本)对上进行训练的神经网络。它可以通过自然语言指令,在给定图像的情况下预测最相关的文本片段,而不是直接为任务进行优化,类似于GPT-2和GPT-3的零样本能力。发现CLIP在ImageNet的“零样本”上与原始的ResNet50的性能相匹配,而且没有使用任何原始的128万个标记示例,克服了计算机视觉中的几个重要挑战。
2 CLIP (Contrastive Language-Image Pre-Training)
首先是构建CLIP,CLIP实际上是一个预训练模型,包括文本编辑和图像编辑器两部分,分别计算文本向量和图像向量的相似度,以预测它们是否为一对,如图1所示。CLIP将图像和文本先分别输入一个图像编码器image_encoder和一个文本编码器text_encoder,得到图像和文本的向量表示
I
f
I_{f}
If 和
T
f
T_{f}
Tf 。然后将图像和文本的向量表示映射到一个联合多通道空间,得到新的可直接进行比较的图像和文本的向量表示
I
e
I_{e}
Ie 和
T
e
T_{e}
Te 。然后计算图像向量和文本向量之间的cosine相似度。最后,对比学习的目标函数就是让正样本对的相似度较高,负样本对的相似度较低。矩阵中非对角线上的元素都是负样本,n个正样本,
n
2
−
n
n^{2}-n
n2−n个负样本,有了正负样本,模型就可以通过对比学习的方式去训练了,不需要任何手工标注。但是这种无监督的训练方式,是需要大量的训练数据的。
CLIP核心实现的伪代码:
推理代码:
import torch
import clip
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
logits_per_image, logits_per_text = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
在训练过程中,使用了5个不同的ResNet模型:ResNet-50,ResNet-101,4xResNet-50,16xResNet-50和64xResNet-50,以及3个ViT模型:ViT-B/32,ViT-B/16和ViT-L/14。训练时的迭代轮数为32,优化器使用Adam,并采用解耦权重衰减正则化技术,调度器采用余弦调度。初始超参数通过网格搜索确定。此外,训练过程中使用了半精度技术,并在训练完成后,再对稍大一些的336像素分辨率进行一个额外的训练轮数。最终,选择在336像素分辨率上表现最好的ViT-L/14作为后文所指的CLIP模型。
3 Experiments
3-1 Zero-Shot Transfer
CLIP实现零样本推理的方法是通过预训练获得文本和图像的特征,而没有分类头。为了解决这个问题,作者提出了一种利用自然语言的方法,即prompt template(提示模板)。对于ImageNet的类别,首先将其转化为句子的形式,例如"A photo of a {object}"。由于ImageNet有1000个类别,因此会生成1000个句子。然后,通过之前预训练好的文本编码器,这1000个句子可以得到1000个文本特征。
虽然可以直接使用类别单词来提取文本特征,但是在模型预训练时,与图像配对的都是句子,因此在推理时使用单词效果会下降。因此,使用句子作为提示更为有效。将待分类的图像送入图像编码器,得到其特征。然后,将图像特征与1000个文本特征计算余弦相似度,并选择最相似的文本特征对应的句子,从而完成分类任务。这种方法不仅限于这1000个类别,任何类别都可以使用。
通过这种方式,CLIP完全摆脱了分类标签的限制,无论是在训练还是推理过程中,都不需要预先定义好的标签列表。
import os
import clip
import torch
from torchvision.datasets import CIFAR100
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
# Calculate features
with torch.no_grad():
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_inputs)
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)
# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
3-2 Representation Learning
Fine-tuning 与 linear probe
- Linear probe:把一个训练好的模型冻结住,只训练最后一层的分类头去做分类任务。
- Fine tune:对整个模型参数进行端到端的训练
CLIP 的作者选择了linear probe的方法进行预训练模型在其他数据集上表现的对比。
import os
import clip
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
def get_features(dataset):
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
features = model.encode_image(images.to(device))
all_features.append(features)
all_labels.append(labels)
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)
# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)
# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")