CLIP模型原理

news2024/12/25 21:30:15

CLIP模型

CLIP(Contrastive Language-Image Pre-Training) 模型是 OpenAI 在 2021 年初发布的用于匹配图像文本的预训练神经网络模型,是近年来在多模态研究领域的经典之作。OpenAI 收集了 4 亿对图像文本对(一张图像和它对应的文本描述),分别将文本和图像进行编码,使用 metric learning进行训练。希望通过对比学习,模型能够学习到文本-图像对的匹配关系。

CLIP的论文地址

CLIP模型共有3个阶段:1阶段用作训练,2、3阶段用作推理。

  1. Contrastive pre-training:预训练阶段,使用图片 - 文本对进行对比学习训练;
  2. Create dataset classifier from label text:提取预测类别文本特征;
  3. Use for zero-shot predictiion:进行 Zero-Shot 推理预测;

在这里插入图片描述

1、训练阶段

通过计算文本和目标图像的余弦相似度从而获取预测值。CLIP模型主要包含以下两个模型;

  • Text Encoder:用来提取文本的特征,可以采用NLP中常用的text transformer模型;
  • Image Encoder:用来提取图像的特征,可以采用常用CNN模型或者vision transformer模型;

在这里插入图片描述
这里举例一个包含N个文本-图像对的训练batch,对提取的文本特征和图像特征进行训练的过程:

  1. 输入图片 —> 图像编码器 —> 图片特征向量;输入文字 —> 文字编码器 —> 文字特征向量;并进行线性投射,得到相同维度;
  2. N个文本特征和N个图像特征两两组合,形成一个具有N2个元素的矩阵;
  3. CLIP模型会预测计算出这N2个文本-图像对的相似度(文本特征和图像特征的余弦相似性即为相似度);
  4. 对角线上的N个元素因为图像-标签对应正确被作为训练的正样本,剩下的N2-N个元素作为负样本;
  5. CLIP的训练目标就是最大化N个正样本的相似度,同时最小化N2-N个负样本的相似度;

2、推理过程

CLIP的预测推理过程主要有以下两步:

  1. 提取预测类别的文本特征:由于CLIP 预训练文本端的输出输入都是句子,因此需要将任务的分类标签按照提示模板 (prompt template)构造成描述文本(由单词构造成句子):A photo of {object}.,然后再送入Text Encoder得到对应的文本特征。如果预测类别的数目为N,那么将得到N个文本特征。
  2. 进行 zero-shot 推理预测:将要预测的图像送入Image Encoder得到图像特征,然后与上述的N个文本特征计算余弦相似度(和训练过程一致),然后选择相似度最大的文本对应的类别作为图像分类预测结果。进一步地,可以将这些相似度看成输入,送入softmax后可以得到每个类别的预测概率。

在这里插入图片描述

3、补充:zero-shot 零样本学习

zero-shot :零样本学习,域外泛化问题。利用训练集数据训练模型,使得模型能够对测试集的对象进行分类,但是训练集类别和测试集类别之间没有交集,期间需要借助类别的描述,来建立训练集和测试集之间的联系,从而使得模型有效。

可以发现CLIP其实就是两个模型:视觉模型 + 文本模型。

在计算机视觉中,即便想迁移VGGMobileNet这种预训练模型,也需要经过预训练、微调等手段,才能学习数据集的数据特征,而CLIP可以直接实现zero-shot的图像分类,即不需要任何训练数据,就能在某个具体下游任务上实现分类,这也是CLIP亮点和强大之处。

我的猜测:CLIP的zero-shot能力是依赖于它预训练的4亿对图像-文本对,样本空间涵盖的太大,并不是真正的零样本学习,和解决域外泛化问题。和人脸比对的原理相似,依靠大量样本来学习分类对象的特征空间。人脸比对是image-to-image,CLIP是 image-to-text。

4、代码: CLIP实现zero-shot分类

OpenAI有关CLIP的代码链接地址

环境:

pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git

Torch version: 1.9.0+cu102

4.1、模型加载

import clip

clip.available_models()

model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

4.2、图像、文本数据处理

向模型提供8个示例图像及其文本描述,并比较相应特征之间的相似性

# images in skimage to use and their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}

在这里插入图片描述

4.3、建立图片特征

对图像进行归一化,对每个文本输入进行标记,并运行模型的前向传递以获得图像和文本特征

image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()

4.4、计算余弦相似度

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

count = len(descriptions)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)

在这里插入图片描述

4.5、Zero-Shot图像分类

from torchvision.datasets import CIFAR100

cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)

text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

plt.figure(figsize=(16, 16))

for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")

    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1106177.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UPS设备的最新管理方法,简单高效!

随着信息技术的快速发展,UPS监控系统变得至关重要。系统用于监视、管理和维护机房中的UPS设备,以确保稳定的电力供应,保护敏感的电子设备和数据中心运营。 UPS监控系统提供了对电力系统的关键可见性,使运维团队能够预测和解决潜在…

家政系统预约小程序具备哪些功能?

预约家政小程序有这么大的市场需求加上这么多的好处,相信未来发展前景不错。也必将吸引很多商家投资者着手开发属于自己的上门家政APP小程序软件,在实际的开发过程中需要具备哪些功能呢? 一、用户端功能: 1. 用户注册登录&#x…

绕过防火墙

1.pikuchu靶场 位置 该文件 然后找到install 文件 初始化 创建 就ok了 安全狗搭建 bug解决 进入apache目录 命令 安全狗安装 重启web服务 绕过 前端 绕过 waf 阻拦 脏数据 成功 二.yakit 中国版的bp

源码编译安装部署lnmp

源码编译安装部署lnmp 文章目录 源码编译安装部署lnmp1.简介:2.环境说明:3.部署前的准备工作4.安装nginx4.1.进入官网拉取nginx源码包4.2.通过IP地址访问nginx的web页面 5.安装mysql5.1.安装依赖包5.2.创建用户和组5.3.下载源码包并解压到/usr/local/5.4…

在数组中合并相同id数据,并且数据中某一字段不一致也统一合并进去

封装的合并的函数 function formateArray(data:any){// ts-ignorelet res data.reduce((ac,a) > {// ts-ignorelet index ac.findIndex(x > x.id a.id);index -1 ? ac.push({...a}) : ac[index] {...ac[index],...a};return ac;},[])return res;}使用 allData 原始…

忆联分布式数据库存储解决方案,助力MySQL实现高性能、低时延

据艾瑞咨询研究院《2022 年中国数据库研究报告》显示,截止2021年,中国分布式数据库占比达到 20%左右,主要以 MySQL 和 PostgreSQL 为代表的开源数据库为主。MySQL 作为备受欢迎的开源数据库,当前已广泛应用于互联网、金融、交通、…

(Python)在Matplotlib中对图像坐标轴进行log转换

对于跨度很大其分布离散的数据,常用log转换来缩写其差距,呈现在图上的效果也更好,比如在绘制转录组的表达量数据时,常用log转换之后的值进行绘制。在matplotlib中,支持在绘图时对数据进行log转换,根据log转…

解密代理IP:加速互联网业务的利器

众所周知,代理IP是一类常见的互联网服务。借助代理IP,一个终端可以通过远程的服务器(即代理IP对应的服务器)访问另一个终端,从而非直接地接触。代理IP在日常互联网应用中的应用场景十分广泛,包括但不限于&a…

C++:模板初阶

本篇文章主要对模板有个简单的认识,方便我们后面对模板进行更加深入的学习。 目录 1.泛型编程 2.函数模板 2.1 函数模板的概念 2.2 函数模板格式 2.3 函数模板的原理 2.4 函数模板的实例化 2.5 模板参数的匹配原则 3.类模板 3.1 类模板的格式定义 3.2 类模…

2023年【北京市安全员-C3证】最新解析及北京市安全员-C3证作业考试题库

题库来源:安全生产模拟考试一点通公众号小程序 北京市安全员-C3证最新解析是安全生产模拟考试一点通总题库中生成的一套北京市安全员-C3证作业考试题库,安全生产模拟考试一点通上北京市安全员-C3证作业手机同步练习。2023年【北京市安全员-C3证】最新解…

缺失找不到msvcr71.dll无法执行代码,应用程序无法启动的解决方法

最近我在使用电脑时遇到了一个问题,提示我缺少 msvcr71.dll 这个文件。这个文件是系统中的一个动态链接库文件,常用于支持一些运行在 Windows 系统上的程序。 当我发现这个问题时,我感到有点困惑和焦虑。因为我需要使用的软件要求系统中必须…

SimpleCG图像操作基础

上一篇我们介绍了程序的交互功能,就可以编写一些简单的游戏了,例如贪吃蛇、扫雷、俄罗斯方块、五子棋等,都可以使用图形函数直接绘制,在后续文章中将逐一展示。不过编写画面丰富游戏离不开图像,所以本篇我们介绍一下基…

零信任身份管理平台,构建下一代网络安全体系

随着数字化时代的到来,网络安全已成为企业和组织面临的一项重要挑战。传统的网络安全方法已经无法满足不断演变的威胁和技术环境。近期,中国信息通信研究院(简称“中国信通院”)发布了《零信任发展研究报告( 2023 年&a…

全球领先的即时通讯厂家,为企业提供卓越沟通解决方案

不同部门的协同合作是企业内部高效运作的关键,然而,传统的沟通方式往往会受到时间、空间以及信息传递效率的限制,给企业带来不必要的困扰。随着科技的不断进步,解决这一问题的新利器应运而生——WorkPlus,一款基于即时…

链表增删操作问题及解决方法

目录 链表增加元素首部中间尾部 链表删除元素首部中间尾部 链表是一种常用的数据结构,用于存储和组织数据。在链表中,增加和删除元素是常见的操作。然而,在进行链表的增删操作时,对于首部、中间和尾部位置的元素,都存在…

UWB安全数据通讯STS-加密、身份认证

DW3000系列才能支持UWB安全数据通讯,DW1000不支持 IEEE 802.15.4a没有数据通讯安全保护机制,IEEE 802.15.4z中指定的扩展得到增强(在PHY/RF级别):增添了一个重要特性“扰频时间戳序列(STS)”&a…

mysql修改root用户的密码

mysql修改root用户的密码 方法1: 用SET PASSWORD命令方法2:用mysqladmin方法3:用UPDATE直接编辑user表方法4:在忘记root密码的时候,可以这样以windows为例: 连接mysql问题 mysql备份工具之mysqldump 方法1&…

反转链表(java)

大家好我是苏麟今天说一说链表常见的简单题目 . BM1 反转链表 牛客BM1 反转链表 : 描述 : 给定一个单链表的头结点(该头节点是有值的,比如在下图,它的val是1),长度为n,反转该链表后,返回新链表的表头。 分析 : …

AP5101C 高压线性恒流 LED电源驱动IC 3D打印机显示灯驱动器

1,产品描述 AP5101C 是一款高压线性 LED 恒流芯片 , 简单 、 内置功率管 , 适用于6- 100V 输入的高精度降压 LED 恒流驱动芯片。电流2.0A。AP5101C 可实现内置MOS 做 2.0A,外置 MOS 可做 3.0A 的。AP5101C 内置温度保护功能 ,温度…

3. 实战入门

3. 实战入门 文章目录 3. 实战入门3.1 Namespace3.1.1测试两个不同的名称空间之间的 Pod 是否连通性 3.2 Pod3.3 Label3.4 Deployment3.5 Service 本章节将介绍如何在kubernetes集群中部署一个nginx服务,并且能够对其进行访问。 3.1 Namespace Namespace是kubernet…