从零实现CLIP模型

news2024/11/17 21:38:03

1. 引言

CLIP代表语言图像对比预训练模型,是OpenAI于2021年开发的一个深度学习模型。CLIP模型中图像和文本嵌入共享相同的潜在特征空间,从而能够在两种模式之间直接进行对比学习。这是通过训练模型使相关的图像和文本更紧密地结合在一起,同时将不相关的图像在特征空间距离分开来实现的。

闲话少说,我们直接开始吧!

2. 相关应用

关于CLIP模型的一些应用总结如下:

  • 图像分类和检索:CLIP可以通过将图像与自然语言文本描述关联起来进而可用于图像分类任务。它允许更通用和灵活的图像检索系统,用户可以使用文本查询来在数据库里搜索图像。

  • 内容调节:CLIP可用于通过分析图像和附带文本来识别和过滤不适当或有害的内容,从而调节在线平台上的展示内容。

3. 核心思想

CLIP模型旨在预测一个batchN×N个潜在(img,text)配对具体哪些是实际匹配的。为了实现这一点,CLIP通过图像编码器和文本编码器的联合训练建立了一个多模态嵌入空间。CLIP的损失函数旨在最大化批处理中N个真实配对的图像和文本嵌入之间的余弦相似性,同时最小化N²−N个错误配对的余弦相似度。以下是伪代码(取自原始论文),概述了CLIP的核心实现。
在这里插入图片描述
接着我们将伪代码中每一行的逐步描述,将其转化为使用PyTorch来实现。

4. 网络结构

在进行代码实现之前,我们先来简单回顾下clip模型具体的网络结构:
在这里插入图片描述

ClIP模型使用两种独立的网络结构来作为图像编码和文本编码的主干,其中:

  • image_encoder:负责编码图像的神经网络主干(eg,ResNetVision Transformer等)。
  • text_encoder:表示负责编码文本信息的神经网络架构(eg,CBOWBERT等)。

原始CLIP模型是从零开始训练的,而没有使用预训练的权重来初始化图像编码器和文本编码器,因为它们用于训练其CLIP模型的数据集体量很大(4亿个图像-文本对)。在这篇博客文章的例子中,我们将采取一些不同的做法。我们将从resnet(用于图像)和distilbert(用于文本)模型的预训练权重开始初始化这些部分。

5. 数据输入

该模型每个批次以n个图像和文本对作为输入,其中:

  • I[n,h,w,c]:表示对齐的图像的小批次输入,其中n是batch大小,h是图像高度,w是图像宽度,c是通道数。
  • T[n,l]:表示对齐文本的小批次输入,其中n是batch大小,l是文本序列的长度。

我们的实现中,我们默认batch的大小为128,如下所示:
在这里插入图片描述

6. 特征提取

关于文本和图像的特征提取,这里使用resnet34distilbert来分别提取图像和文本的特征,如下:

  • I_f = image_encoder(I) : 从图像编码器中获取的图像特征表示I_fI_f的大小为[n,d_I],其中d_I是图像特征的维度。
  • T_f=text_encoder(T):从文本编码器中获取的文本特征表示T_fT_f的大小为[n,d_T],其中d_T是文本特征的维度。

在本文实现中,相应的代码如下:

# for encoding images
I_f = models.resnet34(pretrained=True)      
# for encoding captions
T_f= AutoModel.from_pretrained("distilbert-base-multilingual-cased") 

7. 特征映射

接着,我们将相应的文本和图像特征,映射到同一嵌入特征空间,如下:

  • W_i[d_i,d_e]:表示用于将图像特征i_f映射到嵌入特征空间i_e的投影矩阵。W_i的形状大小是[d_i,d_e],其中d_e表示的是联合嵌入特征空间的维度。
  • W_t[d_t,d_e]:表示用于将文本特征t_f映射到相同嵌入空间t_e的投影矩阵。W_t的形状大小是[d_t,d_e]

投影操作可以使用具有两个线性层的神经网络进行编码,其权重是学习的投影矩阵。在大多数情况下,投影权重是唯一可以在新数据集上需要训练的权重。此外,投影层在对齐图像和文本嵌入的尺寸方面发挥着至关重要的作用,确保它们具有相同的维度。

相应的代码实现如下:
在这里插入图片描述

8. 组合

在上一节中,我们将文本和图像特征分别统一到相同的维度,接着我们将上述相关组件进行整合:

  • I_e = l2_normalize(np.dot(I_f, W_i), axis=1) :在联合嵌入空间I_e中嵌入并归一化图像特征
  • T_e = l2_normalize(np.dot(T_f, W_t), axis=1) :在联合嵌入空间T_e中嵌入并归一化文本特征

接着我们使用以下Pytorch代码来描述图像和文本数据的处理次序。首先,相应的数据通过基本编码器进行处理,然后通过投影层进行处理。最后,为两种模态特征进行嵌入归一化化并返回。如下:

在这里插入图片描述

9. 余弦相似度

接着在嵌入空间,我们来计算文本图像嵌入特征的相似度:

  • logits = np.dot(I_e, T_e.T) * np.exp(t):用以计算图像和文本对在联合嵌入空间的特征余弦相似度,通过可学习的参数t进行缩放。

在我们的例子中,我们考虑暂不使用参数t,代码如下:

logits = T_e @ T_e.T

10. 损失函数

CLIP使用对比损失用以将相关图像和文本在嵌入特征空间拉近,同时将不相关的图像和文本距离拉远。

  • labels = np.arange(n): 用以生成表示batch索引的真值标签。
  • loss_i = cross_entropy_loss(logits, labels, axis=0):用以计算图像特征和真值标签的损失
  • loss_t = cross_entropy_loss(logits, labels, axis=1):用以计算文本特征和真值标签的损失
  • loss = (loss_i + loss_t)/2:计算图像和文本损失的加权平均值。

代码实现如下:

在这里插入图片描述

11. 构建完整模型

将所有不同的部件组合在一起,最终的自定义CLIP模型如下所示:

在这里插入图片描述

12. 构建数据集

我们的自定义CLIP模型将使用flickr30k数据集进行训练。该数据集包括31000多张图像,每张图像至少有5个独立的人工生成文本描述。在这个例子中,我们将为每个图像使用两个标题,总共有62000个图像和文本对用于训练。 代码实现如下:
在这里插入图片描述
上述模型关键常数包括用于学习表示特征空间的维度embed_dim, 用于transformer特征维度的transformer_embed_dim和用于文本输入长度的max_len。所选的text_modeldistilbert base multilanguage-cased。用以训练的模型的epoch为3,同时batch_size的大小为128,这些常数将用于模型构建和训练。如下所示:
在这里插入图片描述

13. 数据集测试用例

DataLoader是为训练期间的高效迭代而设置的,提供图像文本对的迭代访问。调用代码如下:

# Create the DataLoader
clip_dataloader = DataLoader(flickr30k_custom_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

以下是数据集中一个批次中的图像文本对的示例:

import numpy as np
import matplotlib.pyplot as plt
# Create an iterator from the dataloader
data_iter = iter(clip_dataloader)

# Get one batch
batch = next(data_iter)

image = batch["image"][0]  # get one image from the batch
caption = batch["caption"][0]  # get one text from the batch

# Convert the image tensor to a NumPy array and permute dimensions
image_np = np.transpose(image.numpy(), (1, 2, 0))

# Display the image and caption
plt.imshow(image_np)
plt.title(f"Caption: {caption}")
plt.show()

运行结果如下:
在这里插入图片描述

14. 优化器选择

此外,我们还需要指定在整个训练过程中需要优化的参数。上文中我们已经固定了文本和图像编码器的特征提取层,那么只有与投影层相关的参数才会在新的数据集上进行训练。

# Create an instance of your model
model = CustomModel().to(device)

# Define optimizer
optimizer = torch.optim.Adam([
    {'params': model.vision_encoder.parameters()},
    {'params': model.caption_encoder.parameters()}
], lr=model.lr)

15. 模型训练

我们使用Tesla T4的GPU机器进行3个epoch的训练,相应的训练代码如下:
在这里插入图片描述

执行上述训练代码,可以得到训练过程如下:
在这里插入图片描述

16. 总结

总之,这篇博客文章探讨了CLIP模型,揭示了其广泛应用的潜力。随着我们对CLIP应用的了解,很明显,它的影响远远超出了最初的预期,为不同领域的创新解决方案铺平了道路。

您学废了嘛?

完整代码:戳我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1363294.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL高级DBA的理论与实践,MySQL数据库管理员从入门到精通

一、教程描述 数据库管理员(Database Administrator),简称DBA,想要成为高级的MySQL DBA,就要耐得住寂寞,持续不断地学习,除了数据库专业知识外,还需要了解主机、系统、网络、存储、…

SSD固态硬盘的黄金原则:抱最高的希望,做最坏的打算-1

随着SSD固态硬盘日益普及,在个人电脑中已成为基本的配置选项。在体验SSD固态硬盘带来的性能优势的同时,你有没有想过一个问题,SSD的数据如果误删除或发生故障丢失,还有没有可能找回来呢?这也许是固态硬盘飞入寻常百姓家…

C++_命令行操作

命令行操作 介绍第一步编译 源码第二部 找到exe 可执行文件第三步看图操作代码测试源码测试结果 介绍 本文介绍命令行操作 1.argc 表示当前输入 参数个数 2.argv 表示当前输入 字符串内容 第一步编译 源码 #include<iostream> #include<string>using namespace st…

构建网络信息安全的中国方案 - 国密SSL协议介绍以及国密Nginx服务器部署

国密SSL协议 国密SSL协议指的是采用国密算法&#xff0c;符合国密标准的安全传输协议。简而言之&#xff0c;国密SSL就是SSL/TLS协议的国密版本。TLS协议定义有三个版本号&#xff0c;为0x0301、0x0302、0x0303&#xff0c;分别对应TLS 1.0、1.1、1.2。国密SSL为了避免冲突&am…

Go (一) 基础部分5 -- 单元测试,协程(goroutine),管道(channel)

一、单元测试 Go自带一个轻量级的"测试框架testing"和自带的"go test"命令来实现单元测试和性能测试。 1.确保每个函数时可运行&#xff0c;并且运行结果是正确的。 2.确保写出来的代码性能是好的。 3.单元测试能及时的发现程序设计或实现的逻辑错误&#…

Nginx 常用变量 与 防盗链

目录 1.常用变量 2. $http_referer 配置防盗链 2.1 referer 2.2 配置防盗链 1.常用变量 变量说明 $args 请求中的参数&#xff0c;也叫查询参数 $content_length HTTP响应信息里的"Content-Length" $document_root nginx虚拟主机配置文件中的root站点根目录…

八大算法排序@归并排序(C语言版本)

目录 归并排序概念算法思想第一步第二步第三步 算法步骤代码实现代码1代码优化 时间复杂度空间复杂度特性总结 归并排序 概念 归并排序&#xff08;Merge Sort&#xff09;是一种基于分治策略的经典排序算法。它的基本思想是将待排序的数组划分成两个子数组&#xff0c;分别对…

vue-springboot基于java的社区志愿者活动信息管理系统 e2y4d

社区志愿者信息管理系统的主要开发目标如下&#xff1a; &#xff08;1&#xff09;对零碎化、分布散的数据信息进行收纳、整理&#xff0c;通过网络服务平台使这些信息内容更加调理&#xff0c;更加方便化和清晰化&#xff0c;让访问该系统的每个用户享受浏览的过程。 &#x…

简单 Web Server 程序的设计与实现 (2024)

1.题目描述 Web 服务是 Internet 最方便与受用户欢迎的服务类型&#xff0c;它的影响力也远远超出了专业技术范畴&#xff0c; 已广泛应用于电子商务、远程教育、远程医疗与信息服务等领域&#xff0c;并且有继续扩大的趋势。目前很多 的 Internet 应用都是基于 Web 技术的&…

MySQL之数据类型建表以及约束

SELECT(查询) 查询操作用于从数据库中检索数据 查询可以基于不同的条件&#xff0c;如字段值、范围、排序等 查询结果可以返回单个记录或多个记录 查询指定列 select 列名 from 表名 列名&#xff1a;代表从指定的列名中查找 , 如果是查找对应的多列&#xff0c;则用英文…

word 常用功能记录

word手册 多行文字对齐标题调整文字间距打钩方框插入三线表插入参考文献自动生成目录插入页码&#xff08;罗马格式和阿拉伯数字格式&#xff09; 多行文字对齐 标题调整文字间距 打钩方框 插入三线表 插入一个最基本的表格把整个表格设置为无框线设置上框线【实线1.5磅】设置…

基于Springboot的Timo商城

​ 目录 ​前言 开发环境和工具 项目功能 基础模块 商城功能 手机端 设计详情 后台登录页面 后台 手机端页面 小程序端页面 视频展示 源码获取 前言 本项目是一个基于IDEA和Java语言开基于Springboot的Timo商城。应用包含网页管理端&#xff0c;手机端&#xff0…

Matlab三维绘图

绘制三维图plot3 t0:pi/50:10*pi; xsin(t); ycos(t); zt; plot3(x,y,z); 产生栅格数据点meshgrid 这个接口在绘制三维图像里面相当重要&#xff0c;很多时候要将向量变成矩阵才能绘制三维图。 x0:0.5:5; y0:1:10; [X,Y]meshgrid(x,y); plot(X,Y,o); x和y是向量&#xff0c;…

JavaWeb——后端案例

五、案例 1. 开发规范—Restful REST&#xff08;Representational State Transfer&#xff09;&#xff0c;表述性状态转换&#xff0c;是一种软件架构风格 注&#xff1a; REST是风格&#xff0c;是约定方式&#xff0c;不是规定&#xff0c;可以打破描述模块的功能通常使…

uniappVue3版本中组件生命周期和页面生命周期的详细介绍

一、什么是生命周期&#xff1f; 生命周期有多重叫法&#xff0c;有叫生命周期函数的&#xff0c;也有叫生命周期钩子的&#xff0c;还有钩子函数的&#xff0c;其实都是代表&#xff0c;在 Vue 实例创建、更新和销毁的不同阶段触发的一组钩子函数&#xff0c;这些生命周期函数…

每日一博 - 多租户技术及其三种数据存储策略

文章目录 概述应用程序隔离数据隔离小结 概述 多租户技术&#xff08;Multi-Tenant Technology&#xff09;是软件即服务&#xff08;SaaS&#xff09;架构中的一项核心技术&#xff0c;允许单一软件应用或服务同时服务于多个客户&#xff08;即“租户”&#xff09;&#xff…

STM32F4xx之库函数

一、库函数介绍 库函数与寄存器的区别 库函数&#xff1a;不需要自己写很多代码&#xff0c;可以利用软件生成代码。使用的时候必须添加库文件。库文件是芯片厂商写好了。占用空间大。 寄存器&#xff1a;自己写的代码量大&#xff0c;没有软件生成代码。使用的时候不需要库文件…

目标检测数据集大全「包含VOC+COCO+YOLO三种格式+划分脚本+训练脚本」(持续原地更新)

一、作者介绍&#xff1a;五年算法开发经验、AI 算法经理、阿里云开发社区专家博主、稀土掘金人工智能内容评审委员会成员。擅长&#xff1a;检测、分割、理解、AIGC 等算法训练与部署。 二、数据集介绍&#xff1a; 质量高&#xff1a;高质量图片、高质量标注数据&#xff0c;…

【LMM 011】MiniGPT-5:通过 Generative Vokens 进行交错视觉语言生成的多模态大模型

论文标题&#xff1a;MiniGPT-5: Interleaved Vision-and-Language Generation via Generative Vokens 论文作者&#xff1a;Kaizhi Zheng* , Xuehai He* , Xin Eric Wang 作者单位&#xff1a;University of California, Santa Cruz 论文原文&#xff1a;https://arxiv.org/ab…

[技术杂谈]使用VLC将视频转成一个可循环rtsp流

通过vlc播放器&#xff0c;将一个视频转成rtsp流&#xff0c;搭建一个rtsp服务器。rtsp客户端可访问这个视频的rtsp流。 1. 打开vlc播放器&#xff0c;使用的版本如下 2. 菜单&#xff1a;媒体 ---> 流 3. 添加视频文件&#xff0c;点击添加一个mp4 文件 4. 选择串流&…