深度学习论文: Learning Transferable Visual Models From Natural Language Supervision

news2024/9/23 19:25:10

深度学习论文: Learning Transferable Visual Models From Natural Language Supervision
Learning Transferable Visual Models From Natural Language Supervision
PDF: https://arxiv.org/pdf/2103.00020.pdf
官方代码: https://github.com/OpenAI/CLIP
PyTorch代码: https://github.com/shanglianlm0525/CvPytorch
PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks

1 概述

CLIP(对比性语言-图像预训练)是一个在各种(图像,文本)对上进行训练的神经网络。它可以通过自然语言指令,在给定图像的情况下预测最相关的文本片段,而不是直接为任务进行优化,类似于GPT-2和GPT-3的零样本能力。发现CLIP在ImageNet的“零样本”上与原始的ResNet50的性能相匹配,而且没有使用任何原始的128万个标记示例,克服了计算机视觉中的几个重要挑战。

2 CLIP (Contrastive Language-Image Pre-Training)

首先是构建CLIP,CLIP实际上是一个预训练模型,包括文本编辑和图像编辑器两部分,分别计算文本向量和图像向量的相似度,以预测它们是否为一对,如图1所示。CLIP将图像和文本先分别输入一个图像编码器image_encoder和一个文本编码器text_encoder,得到图像和文本的向量表示 I f I_{f} If T f T_{f} Tf 。然后将图像和文本的向量表示映射到一个联合多通道空间,得到新的可直接进行比较的图像和文本的向量表示 I e I_{e} Ie T e T_{e} Te 。然后计算图像向量和文本向量之间的cosine相似度。最后,对比学习的目标函数就是让正样本对的相似度较高,负样本对的相似度较低。矩阵中非对角线上的元素都是负样本,n个正样本, n 2 − n n^{2}-n n2n个负样本,有了正负样本,模型就可以通过对比学习的方式去训练了,不需要任何手工标注。但是这种无监督的训练方式,是需要大量的训练数据的。
在这里插入图片描述

CLIP核心实现的伪代码:
在这里插入图片描述
推理代码:

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

在训练过程中,使用了5个不同的ResNet模型:ResNet-50,ResNet-101,4xResNet-50,16xResNet-50和64xResNet-50,以及3个ViT模型:ViT-B/32,ViT-B/16和ViT-L/14。训练时的迭代轮数为32,优化器使用Adam,并采用解耦权重衰减正则化技术,调度器采用余弦调度。初始超参数通过网格搜索确定。此外,训练过程中使用了半精度技术,并在训练完成后,再对稍大一些的336像素分辨率进行一个额外的训练轮数。最终,选择在336像素分辨率上表现最好的ViT-L/14作为后文所指的CLIP模型。

3 Experiments

3-1 Zero-Shot Transfer

CLIP实现零样本推理的方法是通过预训练获得文本和图像的特征,而没有分类头。为了解决这个问题,作者提出了一种利用自然语言的方法,即prompt template(提示模板)。对于ImageNet的类别,首先将其转化为句子的形式,例如"A photo of a {object}"。由于ImageNet有1000个类别,因此会生成1000个句子。然后,通过之前预训练好的文本编码器,这1000个句子可以得到1000个文本特征。

虽然可以直接使用类别单词来提取文本特征,但是在模型预训练时,与图像配对的都是句子,因此在推理时使用单词效果会下降。因此,使用句子作为提示更为有效。将待分类的图像送入图像编码器,得到其特征。然后,将图像特征与1000个文本特征计算余弦相似度,并选择最相似的文本特征对应的句子,从而完成分类任务。这种方法不仅限于这1000个类别,任何类别都可以使用。

通过这种方式,CLIP完全摆脱了分类标签的限制,无论是在训练还是推理过程中,都不需要预先定义好的标签列表。
在这里插入图片描述

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

3-2 Representation Learning

Fine-tuning 与 linear probe

  • Linear probe:把一个训练好的模型冻结住,只训练最后一层的分类头去做分类任务。
  • Fine tune:对整个模型参数进行端到端的训练

CLIP 的作者选择了linear probe的方法进行预训练模型在其他数据集上表现的对比。
在这里插入图片描述

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/908488.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vector(介绍)

目录 1.vector的介绍及使用 1.1 vector的介绍 1.2 vector的使用 1.2.1 vector的定义 1.2.2 vector iterator 的使用 1.2.3 vector 空间增长问题 1.2.4 vector 增删查改 1.2.5 vector 迭代器失效问题。(重点) 2.vector深度剖析及模拟实现 2.1 使用…

PHP“牵手”淘宝商品评论数据采集方法,淘宝API接口申请指南

淘宝天猫商品评论数据接口 API 是开放平台提供的一种 API 接口,它可以帮助开发者获取商品的详细信息,包括商品的标题、描述、图片等信息。在电商平台的开发中,详情接口API是非常常用的 API,因此本文将详细介绍详情接口 API 的使用…

深入理解Semaphore

Semaphore(信号量)是操作系统中PV操作的原语在java中的实现,它也是基于AQS实现的。其中PV操作是操作系统中一种实现进程互斥与同步的有效方法。PV操作与信号量(S)的处理有关,P表示通过,V表示释放…

2023.8 - java - 泛型

泛型问题的引出: jdk 1.5 引出泛型 // package 泛型; public class index {public static void main (String[] args){test t new test();t.setContent("aaa");int a (int) t.getContent();System.out.println(a);} }class test{Object content;publi…

分享图片 | 快速浏览网页资源,批量保存、一键分享图片

前言 小伙伴学习吉他,有时需要在互联网搜索曲谱资源,而多数曲谱均为图片,并且为多页,在电脑上显示练习很不方便,需要停下来点击鼠标进行翻页,影响练习的连贯性。 为了解决上述问题,通常把图片…

【数据分析入门】Jupyter Notebook

目录 一、保存/加载二、适用多种编程语言三、编写代码与文本3.1 编辑单元格3.2 插入单元格3.3 运行单元格3.4 查看单元格 四、Widgets五、帮助 Jupyter Notebook是基于网页的用于交互计算的应用程序。其可被应用于全过程计算:开发、文档编写、运行代码和展示结果。 …

产品流程图是什么?怎么做?

产品流程图是什么? 产品流程图是一种图形化的表达方式,用于描述产品开发、制造、销售、使用等各个阶段中涉及的流程、步骤和关系。它通过图形符号、箭头、文本等元素,展示了产品的各个环节之间的关联和顺序,通常被用于可视化产…

IT项目即将上线:项目经理的前夜清单

在IT项目的生命周期中,投产前的准备是至关重要的。作为项目经理,你需要确保所有的细节都已经准备好,以确保项目的顺利上线。以下是一份详细的清单,帮助项目经理在项目投产前进行全面的准备。 1. 项目的回顾 在项目即将上线之前&…

stm32的命令规则

stm32型号的说明:以STM32F103RBT6这个型号的芯片为例,该型号的组成为7个部分,其命名规则如下:

前端(十三)——JavaScript 闭包的奥秘与高级用法探索

😶博主:小猫娃来啦 😶文章核心:深入理解 JavaScript 中的闭包 文章目录 不理解闭包?这玩意很难?闭包的定义与原理闭包是什么创建一个闭包 闭包的应用场景闭包与作用域闭包与作用域之间的关系全局作用域、函…

python之Pandas

1.Pandas简介 Pandas 是 Python 语言的一个扩展程序库,用于数据分析。 Pandas 名字衍生自术语 “panel data”(面板数据)和 “Python data analysis”(Python 数据分析)。 Pandas 一个强大的分析结构化数据的工具集…

苹果手机桌面APP带云图标有个箭头,过一段时间经常要下载才能使用APP

环境: IPhone 11 IOS13.0 问题描述: 苹果手机桌面APP带云图标有个箭头,过一段时间经常要下载才能使用APP 解决方案: 1.打开设置,往下找到iTunes Store与App Store 2.找到下面卸载未使用的APP 关闭按钮

记录几个Hudi Flink使用问题及解决方法

前言 如题,记录几个Hudi Flink使用问题,学习和使用Hudi Flink有一段时间,虽然目前用的还不够深入,但是目前也遇到了几个问题,现在将遇到的这几个问题以及解决方式记录一下 版本 Flink 1.15.4Hudi 0.13.0 流写 流写…

一百六十三、Kettle——Linux上安装Kettle9.2(亲测有效,附截图)

一、目的 由于之前发现kettle8.2和kettle9.3这两个版本,或多或少的存在问题 比如kettle8.2的本地服务没问题,但在Linux上创建共享资源库时就有问题; 比如kettle9.3由于不自带shims驱动包,目前在新的下载官网上无法找到下载路径…

Gitbook超详细使用教程,搭建属于你自己的博客!

文章目录 简介与github同步1.创建space2.安装github插件3.同步github4.生成space的url 博客搭建指南1.自定义域名2.发表博客内容3.设置域名默认页面4.界面设置注意事项 End 简介 Gitbook 是一个平台,允许用户创建和分享内容丰富的在线书籍。它有一个用户友好的界面…

JDK JRE JVM 三者之间的详解

JDK : Java Development Kit JRE: Java Runtime Environment JVM : JAVA Virtual Machine JDK : Java Development Kit JDK : Java Development Kit【 Java开发者工具】,可以从上图可以看出,JDK包含JRE;java自己的一些开发工具中&#…

SpringBootWeb案例 Part 2

3. 员工管理 完成了部门管理的功能开发之后,我们进入到下一环节员工管理功能的开发。 基于以上原型,我们可以把员工管理功能分为: 分页查询 带条件的分页查询 删除员工 新增员工 修改员工 那下面我们就先从分页查询功能开始学习。 3.…

内存分布(以及new,delete)

今天给大家说下内存分布,我们都知道的是,像局部变量都在栈区,但是像我们自己有时候申请的空间都在堆区,当然,内存分布不只只是栈区和堆区,还有常量区,代码区等等。如下图: 这就是内存…

项目实战笔记4:敏捷

术语介绍 敏捷项目管理是一种以快速响应变化为核心的项目管理方法。与传统的瀑布模型不同,敏捷方法强调迭代开发和紧密的团队合作。其目的是尽可能快地交付可用的产品,然后在客户和团队之间进行反馈和迭代,以不断优化产品和开发过程。 在敏捷…

电商转化率是什么意思,怎么计算和提高电商转化率?

电商转化率是指访问电商网站的用户中,实际完成购买行为的比例。它可以衡量电商网站的销售能力和用户转化效果,是衡量电商运营效果的重要指标之一。 一、电商转化率的计算公式 电商转化率的计算公式为:转化率完成购买的用户数/访问网站的用户…