【计算机视觉】使用 notebook 展示如何下载和运行 CLIP models,计算图片和文本相似度,实现 zero-shot 图片分类

news2024/11/25 4:49:21

文章目录

  • 一、CLIP 模型
  • 二、准备
  • 三、加载模型
  • 四、查看图片处理器
  • 五、文本分词
  • 六、输入图片和文本,并可视化
  • 七、将图片和文字 encode 生成特征
  • 八、计算 cosine 相似度
  • 九、零样本进行图片分类
  • 十、编写函数进行图片分类
  • 十一、测试自己的函数
  • 十二、编写函数对多图片进行分类

项目地址:

https://github.com/biluko/Paper_Codes_for_fun/tree/master/CLIP

在这里插入图片描述

一、CLIP 模型

CLIP(Contrastive Language-Image Pretraining)是由OpenAI开发的一个深度学习模型,用于处理图像和文本之间的联合表示。它的目标是将图像和文本嵌入到一个共享的向量空间中,使得相似的图像和文本在这个空间中距离较近,而不相似的图像和文本距离较远。

CLIP模型的特点在于它可以通过对图像和文本之间进行对比学习,来学习到一个通用的特征表示。在训练过程中,CLIP通过最大化相似图像和文本的相似性,并最小化不相似图像和文本的相似性来调整模型参数。这种对比学习的方法使得CLIP能够在多个任务上进行迁移学习,如图像分类、文本分类、图像生成等。

CLIP模型的应用非常广泛。通过将图像和文本映射到共享的向量空间,CLIP可以实现图像和文本之间的多模态检索和匹配。例如,通过将一张图片和一个描述该图片内容的文本查询进行编码,可以计算它们在向量空间中的距离,并找到与之相似的图片或文本。这为图像搜索、商品推荐、智能问答等应用提供了新的可能性。

CLIP模型的优势在于它不需要大量标注的训练数据,而是通过对比学习来学习通用的特征表示。这使得CLIP在跨领域和跨语言的应用上具有良好的泛化能力。此外,CLIP还能够理解和生成自然语言描述的图像,以及生成图像描述的文本,具备了一定的语义理解和生成能力。

总之,CLIP是一个强大的深度学习模型,能够将图像和文本嵌入到共享的向量空间中,并实现多模态的检索和匹配。它在图像和文本处理、多模态应用以及迁移学习等方面有着广泛的应用前景。

二、准备

包括下载 CLIP 依赖和将设置改为 GPU:

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

在这里插入图片描述

import numpy as np
import torch
from pkg_resources import packaging

print("Torch version:", torch.__version__)

在这里插入图片描述

三、加载模型

展示可选择的不同图片特征提取器:

import clip
clip.available_models()

在这里插入图片描述
加载模型和图片处理器:

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("模型参数:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("输入图片尺寸:", input_resolution)
print("文本长度:", context_length)
print("词表大小:", vocab_size)

在这里插入图片描述

四、查看图片处理器

这里调整图片大小 224 × 224 224 \times 224 224×224,中心裁剪,然后使用均值和标准差进行归一化,最后输出tensor向量:

preprocess

在这里插入图片描述

五、文本分词

clip.tokenize("Hello World!")

在这里插入图片描述

六、输入图片和文本,并可视化

import os 
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format='retina'

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}
original_images=[]
images=[]
texts=[]
plt.figure(figsize=(16,5))

for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
  
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()

在这里插入图片描述

七、将图片和文字 encode 生成特征

image_input = torch.tensor(np.stack(images)).cuda()

print(image_input.shape)
text_tokens = clip.tokenize(['This is '+ desc for desc in texts]).cuda()

with torch.no_grad():
  image_features = model.encode_image(image_input).float()
  text_features = model.encode_text(text_tokens).float()
print(image_features.shape)
print(text_features.shape)

在这里插入图片描述

八、计算 cosine 相似度

image_features /= image_features.norm(dim = -1,keepdim = True)
text_features /= text_features.norm(dim = -1,keepdim = True)

similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize = (20, 14))
plt.imshow(similarity, vmin = 0.1, vmax = 0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize = 18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent = (i - 0.5, i + 0.5, -1.6, -0.6), origin = "lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha = "center", va = "center", size = 12)

for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size = 20)

在这里插入图片描述

九、零样本进行图片分类

数据集CIFAR100,就是使用相似度计算得分,然后softmax一下:

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform = preprocess, download = True)

在这里插入图片描述
加上prompt 提示模板进行分类:

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

计算相似度得分:

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim = -1, keepdim = True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)

可视化结果:

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace = 0.5)
plt.show()

在这里插入图片描述

十、编写函数进行图片分类

输入图片和供选择标签进行分类:

def show_result(image, probs, labels, label_name):
  plt.figure()
  plt.subplot(1, 2, 1)
  plt.imshow(image)
  plt.axis("off")

  plt.subplot(1, 2, 2)
  y = np.arange(probs.shape[-1])
  plt.grid()
  plt.barh(y, probs[0])
  plt.gca().invert_yaxis()
  plt.gca().set_axisbelow(True)
  plt.yticks(y, [label_name[index] for index in labels[0].numpy()])
  plt.xlabel("probability")

  plt.subplots_adjust(wspace = 0.5)
  plt.show()


def clip_classifier(image_path, choice_label, top_k = 5):
  # top_k小于choice_label数
  if top_k > len(choice_label):
    raise Exception('top_k大于候选标签数')

  # 读取图片
  image = Image.open(image_path).convert("RGB")
  # 输入特征
  text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
  text_tokens = clip.tokenize(text_descriptions).cuda()
  image_input = preprocess(image)
  image_input = image_input.clone().detach().cuda()

  with torch.no_grad():
    image_features = model.encode_image(image_input.unsqueeze(0)).float()
    text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)

  #相似度得分
  text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
  top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
  show_result(image, top_probs, top_labels, choice_label)

十一、测试自己的函数

clip_classifier('R.jpg',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])

在这里插入图片描述

clip_classifier('Holmes.jpg',['Holmes','pig','boy','girl','one piece','bleach','black','man','cartoon','red','detector'])

在这里插入图片描述

十二、编写函数对多图片进行分类

def clip_classifier_m(image_dir, choice_label, top_k = 5):
  # image_dir不为文件夹
  if not os.path.isdir(image_dir):
    raise Exception(image_dir + ' 应该为一个图片文件夹')

  # top_k小于choice_label数
  if top_k > len(choice_label):
    raise Exception('top_k大于候选标签数')


  #读取图片
  original_images = []
  images = []

  for filename in [filename for filename in os.listdir(image_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    image = Image.open(os.path.join(image_dir, filename)).convert("RGB")

    original_images.append(image)
    images.append(preprocess(image))

  # 输入特征
  text_descriptions = [f"This is a photo of a {label}" for label in choice_label]
  text_tokens = clip.tokenize(text_descriptions).cuda()
  image_input = torch.tensor(np.stack(images)).cuda()
  with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

    image_features /= image_features.norm(dim = -1, keepdim = True)
    text_features /= text_features.norm(dim = -1, keepdim = True)

  # 相似度得分
  text_probs = (100.0 * image_features @ text_features.T).softmax(dim = -1)
  top_probs, top_labels = text_probs.cpu().topk(5, dim = -1)
  show_result_m(original_images, top_probs, top_labels, choice_label)


def show_result_m(images, probs, labels, label_name):
  length = len(images)
  num_row = length // 2

  plt.figure(figsize = (16, 16))

  for i, image in enumerate(images):
    plt.subplot(num_row, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(num_row, 4, 2 * i + 2)
    y = np.arange(probs.shape[-1])
    plt.grid()
    plt.barh(y, probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [label_name[index] for index in labels[i].numpy()])
    plt.xlabel("probability")

  plt.subplots_adjust(wspace = 1)
  plt.show()
clip_classifier_m('img',['Luffy','pig','boy','girl','one piece','bleach','black','man','cartoon','red','Holmes'])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/627931.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面对职业发展“迷茫期”除了抱怨焦虑我们还能做什么?

关注“软件测试藏经阁”微信公众号,回复暗号【软件测试】,即可获取氪肝整理的全套测试资源 Java和Python做自动化测试,哪个更有优势?这两个语言都是很流行的语言,所以从技术上很难说谁好谁不好的。因为要说好不好得看…

linux安装homeassistant(智能设备远程控制开源框架)

1、安装docker 先切换到root 用户,先安装一些基本环境: yum install -y yum-utils device-mapper-persistent-data lvm2添加阿里云软件源 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo然后安装 D…

QT+OpenGL高级光照 Blinn-Phong和Gamma校正

QTOpenGL高级光照1 本篇完整工程见gitee:QtOpenGL 对应点的tag,由turbolove提供技术支持,您可以关注博主或者私信博主 Blinn-Phong 冯氏光照:视线与反射方向之间的夹角不小于90度,镜面光分量会变成0.0(不是很合理&am…

死信队列小结

死信队列是RabbitMQ中非常重要的一个特性。简单理解,他是RabbitMQ对于未能正常消费的消息进行的 一种补救机制。死信队列也是一个普通的队列,同样可以在队列上声明消费者,继续对消息进行消费处理。 对于死信队列,在RabbitMQ中主要…

Spring 是什么?IoC 和 DI的区别

1. Spring 是什么?2. IoC是什么? 2.DI概念说明 1. Spring 是什么? 我们通常讲的Spring指的是Spring Framework(Spring框架),它是一个开源的框架,有着活跃而庞大的社区,这也是它之所谓经久不衰的原因。官方的解读是:Spring官网 翻译过来就是:Spring使Java编程对每…

学会这5个步骤,就能轻轻松松地获取代码覆盖率报告

目录 前言: 1、创建main函数的test文件 2、插桩方式编译源码 3、运行主服务 4、执行测试用例 5、优雅退出主服务,并生成覆盖率报告 前言: 代码覆盖率报告可以帮助我们了解测试用例的质量和覆盖程度。 小编前期所测项目多为go语言研发&…

《C++高级编程》读书笔记(一:C++和标准库速成)

1、参考引用 C高级编程(第4版,C17标准)马克葛瑞格尔 2、建议先看《21天学通C》 这本书入门,笔记链接如下 21天学通C读书笔记(文章链接汇总) 1. C 基础知识 1.1 小程序 “hello world” // helloworld.cpp…

开源项目合集......

likeshop开源商城系统,公众号商城、H5商城、微信小程序商城、抖音小程序商城、字节小程序商城、头条小程序商城、安卓App商城、苹果App商城代码全开源,免费商用。 适用场景:B2C商城、新零售商城、社交电商商城、分销系统商城、小程序商城、商…

循环链表的创建

循环链表的介绍及创建(C语言代码实现) 点击打开在线编译器,边学边练 循环链表概念 对于单链表以及双向链表,其就像一个小巷,无论怎么样最终都能从一端走到另一端,然而循环链表则像一个有传送门的小巷&…

力扣 912. 排序数组

文章目录 一、题目描述二、题解1.快速排序2.堆排序3.二路归并排序 一、题目描述 给你一个整数数组 nums,请你将该数组升序排列。 示例 1: 输入:nums [5,2,3,1] 输出:[1,2,3,5]示例 2: 输入:nums [5,1,1…

精细消费 年轻人和父母的奇妙交汇

日本社会学家三浦展结合对日本“311”大地震后的社会观察,提出了“第四消费时代”,即人们在经历了消费社会充分的发展过程之后,社会上逐渐兴起了低欲望、乐于共享、重视环保的消费理念。 在当时,主流观点普遍认为中国还处于大众化…

JWT单点登录

单点登录 文章目录 单点登录零、用户模块内容以及设计一、问题的提出二、单点登录SSO1.1 什么是单点登录1.2 单点登录的技术实现机制 二、远程调用方式RPC三、JWT的使用3.1 session的使用原理3.2 JWT介绍3.3 JWT原理3.4 JWT的使用 四、CAS实现单点登录的原理四、CAS的安装和代码…

十二、进程间通信

目录 目录 零、前置知识 一、什么是进程间通信 (一)含义 (二)发展 (三)类型 1.管道 2.System V IPC 3.POSIX IPC 二、为什么要有进程间通信 三、怎么进行进程间通信 (一)…

Snipaste工具推荐

Snipaste Snipaste 不只是截图,善用贴图功能将帮助你提升工作效率! 新用户? 截图默认为 F1,贴图为 F3,然后请对照着 快捷键列表 按一遍,体会它们的用法,就入门啦! 遇到了麻烦&…

Java通过Ip2region实现IP定位

我们在一些短视频平台上可以看到,视频作者或评论区可以显示IP地址,这其实就是根据IP获取到的我们可以通过一些在线网站就可以看到我们当前的公网IP和IP定位,最近有个需求也需要通过请求获取客户端的IP和IP的定位,于是通过一系列的百度,最终选择使用Ip2region这个工具库来进行定…

flutter的自定义系列雷达图

自定义是flutter进阶中不可缺少的ui层知识点,这里我们来总结下: 在Flutter中,自定义绘制通常是通过使用CustomPaint和CustomPainter来实现的。 创建CustomPaint组件 首先,需要创建一个CustomPaint组件。CustomPaint是一个Widge…

MobPush 厂商通道申请指南

华为厂商申请 创建应用 登录华为开发者联盟,注册您的应用,在应用信息中获取APP ID和Client Secret 配置SHA256证书指纹 在华为开发者联盟配置SHA256证书指纹。获取及配置请参见华为官方文档配置AppGallery Connect 设置消息回执 集成华为厂商通道SDK…

带你了解二进制

目录 视频参考: 讲解:​编辑 运算: 1001(二进制) 9(十位数)1111(二进制) 15(十位数)11001(二进制) 25(…

第二章 搭建TS环境

搭建 TypeScript 的开发环境。一个舒适、便捷且顺手的开发环境,不仅能大大提高学习效率,也会对我们日常的开发工作有很大帮助。 这一节我们就来介绍 VS Code 下的 TypeScript 环境搭建:插件以及配置项。对于 TS 文件的执行,我们会…

设计模式(十):结构型之外观模式

设计模式系列文章 设计模式(一):创建型之单例模式 设计模式(二、三):创建型之工厂方法和抽象工厂模式 设计模式(四):创建型之原型模式 设计模式(五):创建型之建造者模式 设计模式(六):结构型之代理模式 设计模式…