【计算机视觉】CLIP实战:Zero-Shot Prediction(含源代码)

news2024/10/7 12:25:36

一、代码实战

下面的代码使用 CLIP 执行零样本预测。 此示例从 CIFAR-100 数据集中获取图像,并预测数据集中 100 个文本标签中最可能的标签。

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("./data/"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

最后的输出结果为:

在这里插入图片描述
我们不妨可视化一下这张图片:

import os
import pickle
from PIL import Image
import matplotlib.pyplot as plt

# Define the path to the CIFAR-100 dataset
dataset_path = os.path.expanduser('./data/cifar-100-python')

# Load the image
with open(os.path.join(dataset_path, 'test'), 'rb') as f:
    cifar100 = pickle.load(f, encoding='latin1')

# Select an image index to visualize
image_index = 3637

# Extract the image and its label
image = cifar100['data'][image_index]
label = cifar100['fine_labels'][image_index]

# Reshape and transpose the image to the correct format
image = image.reshape((3, 32, 32)).transpose((1, 2, 0))

# Create a PIL image from the numpy array
pil_image = Image.fromarray(image)

# Display the image
plt.imshow(pil_image, interpolation='bilinear')
plt.title('Label: ' + str(label))
plt.axis('off')
plt.show()

在这里插入图片描述
可以看到,很模糊的图片,这可能是因为 CIFAR-100 数据集本身就具有较低的图像分辨率,这是无法改变的。

二、代码逐行解读

2.1 预测

import os
import clip
import torch
from torchvision.datasets import CIFAR100

首先导入所需的库和模块,包括os、clip、torch和CIFAR100。

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

确定设备类型(使用GPU还是CPU),并加载预训练的 CLIP 模型(Vision Transformer - B/32)。clip.load()函数会返回加载的模型和数据预处理函数。

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("./data/"), download=True, train=False)

下载 CIFAR-100 数据集,并将其保存到指定的根目录中(“./data/”)。CIFAR100类从 torchvision.datasets 模块中导入,用于加载 CIFAR-100 数据集。

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

准备输入数据。首先,从 CIFAR-100 数据集中获取指定索引(3637)的图像和类别 ID。然后,对图像进行预处理,包括规范化和转换为模型所需的张量格式,并将其移动到设备上(GPU或CPU)。接下来,生成文本输入,其中包括 CIFAR-100 数据集中所有类别的文本描述,也转换为模型所需的张量格式,并移动到设备上。

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

计算图像和文本的特征向量。通过调用模型的encode_image()和encode_text()方法,将输入图像和文本转换为特征向量。由于不需要进行梯度计算,使用torch.no_grad()上下文管理器来禁止梯度计算。

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

选择图像最相似的前 5 个标签。首先,对图像特征向量和文本特征向量进行归一化。然后,计算图像特征向量与所有文本特征向量之间的相似度。通过执行矩阵乘法和 softmax 操作,得到每个文本描述与图像的相似度。最后,从相似度中选择最高的前 5 个值和对应的索引。

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

打印结果。将最相似的前 5 个标签及其对应的相似度打印出来,格式为类别名和百分比表示的相似度。

这段代码使用 CLIP 模型将图像与文本进行编码,并找到与图像最相似的文本标签。这可以用于图像分类或图像检索等任务。

2.2 可视化

import os
import pickle
from PIL import Image
import matplotlib.pyplot as plt

首先导入所需的库和模块,包括os、pickle、Image和matplotlib.pyplot。

# Define the path to the CIFAR-100 dataset
dataset_path = os.path.expanduser('./data/cifar-100-python')

定义 CIFAR-100 数据集的路径。os.path.expanduser()函数用于扩展用户目录中的路径。

# Load the image
with open(os.path.join(dataset_path, 'test'), 'rb') as f:
    cifar100 = pickle.load(f, encoding='latin1')

加载图像数据。使用open()函数打开 CIFAR-100 数据集中的图像文件(‘test’),并使用pickle.load()函数将图像数据加载到cifar100变量中。'latin1’是编码参数,用于指定加载数据的编码格式。

# Select an image index to visualize
image_index = 3637

选择一个图像的索引,用于可视化该图像。在这里,选择索引为 3637 的图像进行可视化。

# Extract the image and its label
image = cifar100['data'][image_index]
label = cifar100['fine_labels'][image_index]

提取所选图像和其标签。从cifar100字典中的’data’键中提取指定索引的图像数据,并从’fine_labels’键中提取相应的标签。

# Reshape and transpose the image to the correct format
image = image.reshape((3, 32, 32)).transpose((1, 2, 0))

调整图像的形状和排列顺序,使其与正确的格式匹配。reshape()函数将图像的形状从扁平的一维数组调整为(3, 32, 32)的三维数组,表示通道数、高度和宽度。然后,transpose()函数将维度重新排列,以将通道维度移至最后,得到(32, 32, 3)的图像格式。

# Create a PIL image from the numpy array
pil_image = Image.fromarray(image)

将 NumPy 数组转换为 PIL 图像对象。使用Image.fromarray()函数将 NumPy 数组image转换为 PIL 图像对象pil_image。

# Display the image
plt.imshow(pil_image, interpolation='bilinear')
plt.title('Label: ' + str(label))
plt.axis('off')
plt.show()

显示图像。使用plt.imshow()函数显示图像,通过设置interpolation参数为’bilinear’进行双线性插值,以改善图像的显示效果。plt.title()函数用于设置图像标题,标题中包含图像的标签。plt.axis(‘off’)用于关闭坐标轴的显示。最后,使用plt.show()函数显示图像。

这段代码加载 CIFAR-100 数据集中的图像数据,并可视化指定索引的图像及其标签。注意,通过使用双线性插值等图像显示选项,可以提高图像的清晰度和质量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/539395.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android Parceable 使用和原理

简介 在 Android 开发中,我们经常需要在不同的组件之间传递数据,比如在 Activity 之间传递数据、在 Service 和 Activity 之间传递数据等。为了实现数据的传递,Android 提供了两种常用的方式,一种是使用 Intent,另一种…

opencv_c++学习(十)

一、图像尺寸变化 图像插值原理 在图像变换的过程中往往需要对像素进行相关的操作。如上图(左)所示,我们会遇到两个相邻的像素块需要映射到同样的位置中,或者两个相邻的位置的像素中间需要映射出一个位置的像素块。这时候我们就需…

JavaEE(系列7) -- 多线程(wait 和 notify 的使用)

首先对上一章节的指令重排序,在进行解释一下; 假设现在有连个线程 t1 和 t2 t1频繁(速度特别快)读取主内存,效率比较低,就被优化成直接读自己的工作内存。 t2修改了主内存的结果,由于t1没有读主内存,导致修改不能被识别到。 1.什么…

全网最全最细的【设计模式】总目录,收藏起来慢慢啃,看完不懂砍我

文章目录 一、设计模式七大原则1、单一职责原则2、接口隔离原则3、依赖倒置原则4、里氏替换原则5、开闭原则6、迪米特法则7、合成复用原则 二、UML类图三、设计模式1、创建型模式(1)单例模式(常用)(2)原型模…

mybatis-config.xml文件中的mappers标签

前言 在MyBatis中&#xff0c;< mapper >标签非常重要&#xff0c;因为它对应着我们存放sql语句的xml文件&#xff0c;在之前的使用中我们都是使用resource来指定路径&#xff0c;但其实除了resource可以指定路径的还有url和class但路径形式有所不同&#xff0c;下面来讨…

用「明道云+ChatGPT+Weaviate」挑战零代码1小时实现ChatPDF

ChatGPT流行起来之后&#xff0c;快速的出现了一批基于ChatGPT的工具应用&#xff0c;ChatPDF就是其中比较受欢迎的一款。它是一个可以让你与PDF文件进行对话的工具&#xff0c;既可以帮助你快速提取PDF文件中的信息&#xff0c;例如手册、论文、合同、书籍等&#xff1b;也可以…

【计算机视觉】最后显示的CIFAR-100数据集照片很模糊怎么解决?

文章目录 一、前言二、如何解决2.1 使用图像增强技术2.2 使用插值方法2.3 使用更高分辨率的图像数据集2.4 手动调整图像尺寸 三、总结 一、前言 如果从CIFAR-100数据集加载的图像显示模糊&#xff0c;可能有几个可能的原因&#xff1a; 分辨率较低&#xff1a;CIFAR-100数据集…

全力押注预制菜,叮咚买菜或错失即时零售红利

实际上&#xff0c;叮咚买菜相比于美团、京东更适合抢分即时零售的市场红利。 目前美团进入即时零售的逻辑是&#xff0c;拥有几百万骑手的履约软硬件可以复用&#xff0c;同时从外卖场景延伸到其他消费场景比较丝滑&#xff0c;从平台几千万用户的温饱满足&#xff0c;延展到多…

计算机网络实验(ensp)-实验10:三层交换机实现VLAN间路由

目录 实验报告&#xff1a; 实验操作 1.建立网络拓扑图并开启设备 2.配置主机 1.打开PC机 配置IP地址和子网掩码 2.配置完成后点击“应用”退出 3.重复步骤1和2配置每台PC 3.配置交换机VLAN 1.点开交换机 2.输入命名&#xff1a;sys 从用户视图切换到系统视图…

网络工程师精选习题详解(一)

请点击↑关注、收藏&#xff0c;本博客免费为你获取精彩知识分享&#xff01;有惊喜哟&#xff01;&#xff01; 1.在IPv4地址192.168.2.0/24中&#xff0c;表示主机的二进制位数是&#xff08; &#xff09;位。 A.8 B.16 C.24 D.32 答案&#xff1a;A /24示网络…

面对职业焦虑,我们能做些什么?

目录 大环境分析&#xff1a;AI 发展汹涌而上温水煮青蛙&#xff1a;那些“被替代”的“我们”码农分类&#xff1a;程序员都在做些什么码农黑暗季&#xff1a;失业潮原因分析程序员短期真的可替代吗&#xff1f;AI 发展来势汹汹&#xff0c;如何顺势而为最后&#xff1a;纵观全…

SpringBoot整合Swagger2,让接口文档管理变得更简单

在软件开发的过程中&#xff0c;接口文档的编写往往是一个非常重要的环节&#xff0c;因为它是前端和后端沟通的桥梁&#xff0c;帮助团队更好地协作。然而&#xff0c;手动编写接口文档不仅耗费时间&#xff0c;还容易出错&#xff0c;因此我们需要一种简单的方法来管理接口文…

宝武中南钢铁借助飞桨让钢筋超限监控有了“火眼金睛”

现代钢铁工业生产过程是一个复杂而庞大的生产体系&#xff0c;涵盖数百道工序。 在70多年的发展历程中&#xff0c;炼钢、轧钢、连铸以及节能减排等各项技术不断进化&#xff0c;无一不印证了中国钢铁在技术创新之路上获得的持续性突破。如今&#xff0c;宝武中南钢铁&#xff…

Java websocket 使用

简介 WebSocket 是一种基于 TCP 协议的全双工通信协议&#xff0c;可以在浏览器和服务器之间建立实时、双向的数据通信。在 Java 中&#xff0c;我们可以使用 Java API for WebSocket&#xff08;JSR 356&#xff09;来实现 WebSocket。 WebSocket 的作用是在 Web 应用程序中…

基于html+css的图展示77

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…

Charles安装及抓取APP接口

一、Charles使用 Charles是一款代理服务器&#xff0c;通过过将自己设置成系统&#xff08;电脑或者浏览器&#xff09;的网络访问代理服务器&#xff0c;然后截取请求和请求结果达到分析抓包的目的。该软件是用Java写的&#xff0c;能够在Windows&#xff0c;Mac&#xff0c;…

STM32F4_DAC数模转换

目录 1. DAC简介 2. DAC框图 3. DAC功能介绍 3.1 DAC通道使能 3.2 DAC输出缓冲器使能 3.3 DAC数据格式 3.4 DAC转换 3.5 DAC输出电压 3.6 DAC触发选择 3.7 DMA请求 3.8 生成噪声 3.9 生成三角波 4. 相关寄存器 4.1 DAC控制寄存器&#xff1a;DAC_CR 4.2 DAC1通道…

1-《java基础》

1-《java基础》 一.java基本数据类型和引用类型1.基本数据类型&#xff1a;2.引用数据类型3.基本数据类型和引用数据类型区别3.1 存储位置3.2 传递方式 4.自动装箱&#xff0c;自动拆箱 二.equals和的区别三.static1.static关键字的用途2.static方法3.static变量4.static代码块…

Unity中级客户端开发工程师的进阶之路

上期UWA技能成长系统之《Unity高级客户端开发工程师的进阶之路》得到了很多Unity开发者的肯定。通过系统的学习&#xff0c;可以掌握游戏性能瓶颈定位的方法和常见的CPU、GPU、内存相关的性能优化方法。 UWA技能成长系统是UWA根据学员的职业发展目标&#xff0c;提供技能学习的…

加密解密软件VMProtect教程(六):主窗口之控制面板“项目”部分(3)

VMProtect 是新一代软件保护实用程序。VMProtect支持德尔菲、Borland C Builder、Visual C/C、Visual Basic&#xff08;本机&#xff09;、Virtual Pascal和XCode编译器。 同时&#xff0c;VMProtect有一个内置的反汇编程序&#xff0c;可以与Windows和Mac OS X可执行文件一起…