文章目录
- 一、CLIP 模型
- 二、准备
- 三、加载模型
- 四、查看图片处理器
- 五、文本分词
- 六、输入图片和文本,并可视化
- 七、将图片和文字 encode 生成特征
- 八、计算 cosine 相似度
- 九、零样本进行图片分类
- 十、编写函数进行图片分类
- 十一、测试自己的函数
- 十二、编写函数对多图片进行分类
项目地址:
https://github.com/biluko/Paper_Codes_for_fun/tree/master/CLIP
一、CLIP 模型
CLIP(Contrastive Language-Image Pretraining)是由OpenAI开发的一个深度学习模型,用于处理图像和文本之间的联合表示。它的目标是将图像和文本嵌入到一个共享的向量空间中,使得相似的图像和文本在这个空间中距离较近,而不相似的图像和文本距离较远。
CLIP模型的特点在于它可以通过对图像和文本之间进行对比学习,来学习到一个通用的特征表示。在训练过程中,CLIP通过最大化相似图像和文本的相似性,并最小化不相似图像和文本的相似性来调整模型参数。这种对比学习的方法使得CLIP能够在多个任务上进行迁移学习,如图像分类、文本分类、图像生成等。
CLIP模型的应用非常广泛。通过将图像和文本映射到共享的向量空间,CLIP可以实现图像和文本之间的多模态检索和匹配。例如,通过将一张图片和一个描述该图片内容的文本查询进行编码,可以计算它们在向量空间中的距离,并找到与之相似的图片或文本。这为图像搜索、商品推荐、智能问答等应用提供了新的可能性。
CLIP模型的优势在于它不需要大量标注的训练数据,而是通过对比学习来学习通用的特征表示。这使得CLIP在跨领域和跨语言的应用上具有良好的泛化能力。此外,CLIP还能够理解和生成自然语言描述的图像,以及生成图像描述的文本,具备了一定的语义理解和生成能力。
总之,CLIP是一个强大的深度学习模型,能够将图像和文本嵌入到共享的向量空间中,并实现多模态的检索和匹配。它在图像和文本处理、多模态应用以及迁移学习等方面有着广泛的应用前景。
二、准备
包括下载 CLIP 依赖和将设置改为 GPU:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
import numpy as np
import torch
from pkg_resources import packaging
print("Torch version:", torch.__version__)
三、加载模型
展示可选择的不同图片特征提取器:
import clip
clip.available_models()
加载模型和图片处理器:
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
print("模型参数:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("输入图片尺寸:", input_resolution)
print("文本长度:", context_length)
print("词表大小:", vocab_size)
四、查看图片处理器
这里调整图片大小 224 × 224 224 \times 224 224×224,中心裁剪,然后使用均值和标准差进行归一化,最后输出tensor向量:
preprocess
五、文本分词
clip.tokenize("Hello World!")
六、输入图片和文本,并可视化
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
%matplotlib inline
%config InlineBackend.figure_format='retina'
# images in skimage to use and their textual descriptions
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}
original_images=[]
images=[]
texts=[]
plt.figure(figsize=(16,5))
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue
image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
plt.subplot(2, 4, len(images) + 1)
plt.imshow(image)
plt.title(f"{filename}\n{descriptions[name]}")
plt.xticks([])
plt.yticks([])
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])
plt.tight_layout()
七、将图片和文字 encode 生成特征
image_input = torch.tensor(np.stack(images)).cuda()
print(image_input.shape)
text_tokens = clip.tokenize(['This is '+ desc for desc in texts]).cuda()
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
print(image_features.shape)
print(text_features.shape)
八、计算 cosine 相似度
image_features /= image_features.norm(dim = -1,keepdim = True)
text_features /= text_features.norm(dim = -1,keepdim = True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize = (20, 14))
plt.imshow(similarity, vmin = 0.1, vmax = 0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize = 18)
plt.xticks([])
for i, image in enumerate(original_images):
plt.imshow(image, extent = (i - 0.5, i + 0.5, -1.6, -0.6), origin = "lower")
for x in range(similarity.shape[1]):
for y in range(similarity.shape[0]):
plt.text(x, y, f"{similarity[y, x]:.2f}", ha = "center", va = "center", size = 12)
for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])
plt.title("Cosine similarity between text and image features", size = 20)
九、零样本进行图片分类
数据集CIFAR100,就是使用相似度计算得分,然后softmax一下:
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform = preprocess, download = True)
加上prompt 提示模板进行分类:
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
计算相似度得分:
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim = -1, keepdim = True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
可视化结果:
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace = 0.5)
plt.show()
十、编写函数进行图片分类
输入图片和供选择标签进行分类:
def show_result(image, probs, labels, label_name):
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(1, 2, 2)
y = np.arange(probs.shape[-1])
plt.grid()
plt.barh(y, probs[0])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [label_name[index] for index in labels[0].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace = 0.5)
plt.show()
def clip_classifier(image_path, choice_label, top_k = 5):
# top_k小于choice_label数
if top_k > len(choice_label):
raise Exception('top_k大于候选标签数')
# 读取图片
image = Image.open(image_path).convert("RGB")
# 输入特征
text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
text_tokens = clip.tokenize(text_descriptions).cuda()
image_input = preprocess(image)
image_input = image_input.clone().detach().cuda()
with torch.no_grad():
image_features = model.encode_image(image_input.unsqueeze(0)).float()
text_features = model.encode_text(text_tokens).float()
image_features /= image_features.norm(dim = -1, keepdim = True)
text_features /= text_features.norm(dim = -1, keepdim = True)
#相似度得分
text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
show_result(image, top_probs, top_labels, choice_label)
十一、测试自己的函数
clip_classifier('R.jpg',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])
clip_classifier('Holmes.jpg',['Holmes','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])
十二、编写函数对多图片进行分类
def clip_classifier_m(image_dir, choice_label, top_k = 5):
# image_dir不为文件夹
if not os.path.isdir(image_dir):
raise Exception(image_dir + ' 应该为一个图片文件夹')
# top_k小于choice_label数
if top_k > len(choice_label):
raise Exception('top_k大于候选标签数')
#读取图片
original_images = []
images = []
for filename in [filename for filename in os.listdir(image_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
image = Image.open(os.path.join(image_dir, filename)).convert("RGB")
original_images.append(image)
images.append(preprocess(image))
# 输入特征
text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
text_tokens = clip.tokenize(text_descriptions).cuda()
image_input = torch.tensor(np.stack(images)).cuda()
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
image_features /= image_features.norm(dim = -1, keepdim = True)
text_features /= text_features.norm(dim = -1, keepdim = True)
# 相似度得分
text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
show_result_m(original_images, top_probs, top_labels, choice_label)
def show_result_m(images, probs, labels, label_name):
length = len(images)
num_row = length // 2
plt.figure(figsize = (16, 16))
for i, image in enumerate(images):
plt.subplot(num_row, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(num_row, 4, 2 * i + 2)
y = np.arange(probs.shape[-1])
plt.grid()
plt.barh(y, probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [label_name[index] for index in labels[i].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace = 1)
plt.show()
clip_classifier_m('img',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','Holmes'])