CLIP在Github上的使用教程

news2025/1/16 7:45:20

CLIP的github链接:https://github.com/openai/CLIP

CLIP

Blog,Paper,Model Card,Colab
CLIP(对比语言-图像预训练)是一个在各种(图像、文本)对上进行训练的神经网络。可以用自然语言指示它在给定图像的情况下预测最相关的文本片段,而无需直接对任务进行优化,这与 GPT-2 和 3 的零镜头功能类似。我们发现,CLIP 无需使用任何 128 万个原始标注示例,就能在 ImageNet "零拍摄 "上达到原始 ResNet50 的性能,克服了计算机视觉领域的几大挑战。

Usage用法

首先,安装 PyTorch 1.7.1(或更高版本)和 torchvision,以及少量其他依赖项,然后将此 repo 作为 Python 软件包安装。在 CUDA GPU 机器上,完成以下步骤即可:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git

将上面的 cudatoolkit=11.0 替换为机器上相应的 CUDA 版本,如果在没有 GPU 的机器上安装,则替换为 cpuonly

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

CLIP 模块提供以下方法:

clip.available_models()

返回可用 CLIP 模型的名称。例如下面就是我执行的结果。
在这里插入图片描述

clip.load(name, device=..., jit=False)

返回模型和模型所需的 TorchVision 变换(由 clip.available_models() 返回的模型名称指定)。它将根据需要下载模型。name参数也可以是本地检查点的路径。
可以选择指定运行模型的设备,默认情况下,如果有第一个 CUDA 设备,则使用该设备,否则使用 CPU。当 jitFalse 时,将加载模型的非 JIT 版本。

clip.tokenize(text: Union[str, List[str]], context_length=77)

返回包含给定文本输入的标记化序列的 LongTensor。这可用作模型的输入。

clip.load() 返回的模型支持以下方法:

model.encode_image(image: Tensor)

给定一批图像,返回 CLIP 模型视觉部分编码的图像特征。

model.encode_text(text: Tensor)

给定一批文本标记,返回 CLIP 模型语言部分编码的文本特征。

model(image: Tensor, text: Tensor)

给定一批图像和一批文本标记,返回两个张量,其中包含与每张图像和每个文本输入相对应的 logit 分数。这些值是相应图像和文本特征之间的余弦相似度乘以 100。

More Examples更多实例

Zero-Shot预测

下面的代码使用 CLIP 执行零点预测,如论文附录 B 所示。该示例从 CIFAR-100 数据集中获取一张图片,并预测数据集中 100 个文本标签中最有可能出现的标签。

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

输出结果如下(具体数字可能因计算设备而略有不同):

Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

请注意,本示例使用的 encode_image()encode_text() 方法可返回给定输入的编码特征。

Linear-probe evaluation线性探针评估

下面的示例使用 scikit-learn 对图像特征进行逻辑回归。

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

请注意,C 值应通过使用验证分割进行超参数扫描来确定。

See Also

OpenCLIP:包括更大的、独立训练的 CLIP 模型,最高可达 ViT-G/14
Hugging Face implementation of CLIP:更易于与高频生态系统集成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1287583.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JS箭头函数

箭头函数 1. 基本语法 // // 一般函数const fn function() {console.log(123);}// 箭头函数const fn () > {console.log(123);}fn()const fn (x) > {console.log(x);}fn(1)// 只有一个形参的时候可以省略小括号const fn x > {console.log(x);}fn(1)// 只有一行代…

基于c++版本链队列改-Python版本链队列基础理解

##基于链表的队列实现 可以将链表的“头节点”和“尾节点”分别视为“队首”和“队尾”,规定队尾仅可添加节点,队首仅可删除节点。 ##图解 ##基于链表的队列实现代码 class ListNode:"""定义链表"""def __init__(self)…

nodejs微信小程序+python+PHP本科生优秀作业交流网站的设计与实现-计算机毕业设计推荐

通过软件的需求分析已经获得了系统的基本功能需求,根据需求,将本科生优秀作业交流网站功能模块主要分为管理员模块。管理员添加系统首页、个人中心、用户管理、作业分类管理、作业分享管理、论坛交流、投诉举报、系统管理等操作。 随着信息化社会的形成…

Mybatis XML 配置文件

我们刚开始就有说Mybatis 的开发有两种方式: 1.注释 2.XML 注解和 XML 的方式是可以共存的 我们前面说的都是注释的方式,接下来是XML方式 XML的方式分为三步 : 1.配置数据库(配在 application.yml 里面) 这个跟注释的配置是一样的,username应该都是一样的,password记得写…

git常用命令指南

目录 一、基本命令 1、创建分支 2、切换分支 3、合并分支 4、初始化空git仓库 二、文件操作 1、创建文件 2、添加多个文件 3、查看项目的当前状态 4、修改文件 5、删除文件 6、提交项目 三、实际操作 1、创建目录 2、进入新目录 3、初始化空git仓库 4、创建文…

【OpenGauss源码学习 —— (RowToVec)算子】

VecToRow 算子 概述ExecInitRowToVec 函数ExecRowToVec 函数VectorizeOneTuple 函数 ExecEndRowToVec 函数总结 声明:本文的部分内容参考了他人的文章。在编写过程中,我们尊重他人的知识产权和学术成果,力求遵循合理使用原则,并在…

使用AWS Glue与AWS Kinesis构建的流式ETL作业(二)——数据处理

大纲 2 数据处理2.1 架构2.2 AWS Glue连接和创建2.2.1 创建AWS RedShift连接2.2.2 创建AWS RDS连接(以PG为例) 2.3 创建AWS Glue Job2.4 编写脚本2.4.1 以AWS RedShift为例2.4.2 以PG为例 2.5 运行脚本 2 数据处理 2.1 架构 2.2 AWS Glue连接和创建 下…

DSSS技术和OFDM技术

本内容为学习笔记,内容不一定正确,请多处参考进行理解 https://zhuanlan.zhihu.com/p/636853588 https://baike.baidu.com/item/OFDM/5790826?frge_ala https://zhuanlan.zhihu.com/p/515701960?utm_id0 一、 DSSS技术 信号替代:DSSS技术为…

谈一谈内存池

文章目录 一,什么是内存池二,进程地址空间中是如何解决内存碎片问题的三,malloc的实现原理四,STL中空间配置器的实现原理五,高并发内存池该内存池的优势在哪里内存池的设计框架内存申请流程ThreadCache层CentreCache层…

论文阅读:一种通过降低噪声和增强判别信息实现细粒度分类的视觉转换器

论文标题: A vision transformer for fine-grained classification by reducing noise and enhancing discriminative information 翻译: 一种通过降低噪声和增强判别信息实现细粒度分类的视觉转换器 摘要 最近,已经提出了几种基于Vision T…

Linux 中用户与权限

1.添加用户 useradd 1)创建用户 useradd 用户名 2)设置用户密码 passwd 用户名 设置密码是便于连接用户时使用到,如我使用物理机链接该用户 ssh 用户名 ip 用户需要更改密码的话,使用 passwd 指令即可 3)查看用户信息 id 用…

【数据结构(六)】希尔排序、快速排序、归并排序、基数排序的代码实现(3)

文章目录 1. 希尔排序1.1. 简单插入排序存在的问题1.2. 相关概念1.3. 应用实例1.3.1. 交换法1.3.1.1. 逐步推导实现方式1.3.1.2. 通用实现方式1.3.1.3. 计算时间复杂度 1.3.2. 移动法 2. 快速排序2.1. 相关概念2.2. 实例应用2.2.1. 思路分析2.2.2. 代码实现 2.3. 计算快速排序的…

Spatial Data Analysis(三):点模式分析

Spatial Data Analysis(三):点模式分析 ---- 1853年伦敦霍乱爆发 在此示例中,我将演示如何使用 John Snow 博士的经典霍乱地图在 Python 中执行 KDE 分析和距离函数。 感谢 Robin Wilson 将所有数据数字化并将其转换为友好的 G…

(04730)电路分析基础之电阻、电容及电感元件

04730电子技术基础 语雀(完全笔记) 电阻元件、电感元件和电容元件的概念、伏安关系,以及功率分析是我们以后分析电 路的基础知识。 电阻元件 电阻及其与温度的关系 电阻 电阻元件是对电流呈现阻碍作用的耗能元件,例如灯泡、…

基于STM32驱动的压力传感器实时监测系统

本文介绍了如何使用STM32驱动压力传感器进行实时监测。首先,我们会介绍压力传感器的工作原理和常见类型。然后,我们将介绍如何选择合适的STM32单片机和压力传感器组合。接下来,我们会详细讲解如何使用STM32驱动压力传感器进行数据采集和实时监…

根文件系统软件运行测试

一. 简介 前面几篇文章学习了制作一个可以在开发板上运行的,简单的根文件系统。 本文在上一篇文章学习的基础上进行的,文章地址如下: 完善根文件系统-CSDN博客 本文对根文件系统软件运行进行测试。 我们使用 Linux 的目的就是运行我们自…

vue3 setup语法糖 多条件搜索(带时间范围)

目录 前言: setup介绍: setup用法: 介绍: 前言: 不管哪个后台管理中都会用到对条件搜索带有时间范围的也不少见接下来就跟着我步入vue的多条件搜索(带时间范围) 在 Vue 3 中,你…

[JavaScript前端开发及实例教程]计算器井字棋游戏的实现

计算器&#xff08;网页内实现效果&#xff09; HTML部分 <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>My Calculator&l…

Unity3D对CSV文件操作(创建、读取、写入、修改)

系列文章目录 Unity工具 文章目录 系列文章目录前言一、Csv是什么&#xff1f;二、创建csv文件2-1、构建表数据2-2、创建表方法2-3、完整的脚本&#xff08;第一种方式&#xff09;2-4、运行结果2-5、完整的脚本&#xff08;第二种方式&#xff09;2-6、运行结果2-7、想用哪种…

基于STM32 HAL库的光电传感器驱动程序实例

本文将使用STM32 HAL库编写一个光电传感器的驱动程序示例。首先&#xff0c;我们会介绍光电传感器的工作原理和应用场景。然后&#xff0c;我们将讲解如何选择合适的STM32芯片和光电传感器组合。接下来&#xff0c;我们会详细介绍使用STM32 HAL库编写光电传感器驱动程序的基本步…