CLIP模型原理与代码实现详解

news2024/11/28 14:54:29

文章目录

  • 前言
  • 一、CLIP模型原理
    • 1.背景介绍
    • 2.对比训练方式
    • 3.prompt推理方式
    • 4.图像与文本编码结构
    • 5.特征CLS token结构
      • vit划分patch原理
      • cls token原理
  • 二、CLIP环境安装
    • 1.官方环境安装
    • 2.CLIP环境安装
    • 3.CLIP运行结果
  • 三.CLIP的Transformer结构代码解读
  • 四、CLIP模型主函数代码解读
  • 五、CLIP的image encode代码解读
    • 1、主函数代码解读
    • 2、VisionTransformer结构代码解读
    • 3、图像patch方法代码解读
    • 3、图像cls token编码代码解读
    • 4、图像位置编码代码解读
    • 5、图像cls token特征表达代码解读
    • 6、图像特殊结构代码解读
  • 六、CLIP的text encode代码解读
    • 1、主函数代码解读
    • 2、文本token代码解读
    • 3、文本位置编码代码解读
    • 4、文本特殊结构代码解读
  • 七、CLIP多模态融合代码解读
  • 八、CLIP推理结构解读
  • 九、CLIP训练结构解读
  • 总结


前言

目前,大模型十分活跃,openai公司呈现GPT系列,特别是Chat-GPT给人深刻印象,意识到大模型厉害之处,随后推出GPT4模型,更是将大模型进一步推到一个高度,并将多模态融合技术留下深刻印象,同时,学者也对多模态融合技术研究呈现百花齐放之势。然而,多模态模型大多以CLIP所提方法或思路实现多模态融合。为此,本文将重新回顾CLIP论文相关理论,也重点梳理其源码,并附其代码供读者参考(本文会涉及VIT与BERT代码解读)。


提示:代码环境安装、重点部分代码解释(如:image encode(VIT),text encode(BERT)等)

论文地址:点击这里
官网源代码:点击这里
我的代码:点击这里 名称为:CLIP模型.zip 提取码:r63z

一、CLIP模型原理

1.背景介绍

CLIP算是在跨模态训练无监督中的开创性工作,作者提到早在2017年之后就陆续有工作提出和本文类似的想法,但数据量太少,而无好结果。本文收集4亿数据的大数据集,才得到很好的效果。这种现象最近好像在机器学习领域越来越突出。本文采用对比方式,图像使用vit结构编码、文本使用bert编码,实现视觉与语言多模态融合。

2.对比训练方式

本文并非像图像caption方式,而是通过对比学习实现模型训练,我想也是这种对比学习才被目前多模态融合方法所借鉴。其采用对比学习原因如下:

  1. OpenAI是不愁计算资源的公司,喜欢将一切都gpt化(就是做生成式模型);
  2. 以往工作在1000类ImageNet数据训练方法,非常耗费资源,而CLIP要做的是开发世界的视觉识别任务,所以训练的效率对于自监督的模型至关重要;
  3. 如果任务改为给定一张图片去预测一个文本(或者给定一个文本去预测一张图片),那么训练效率将会非常低下(因为一个图片可能对应很多种说法,一个文本也对应着很多种场景);
  4. 与其做默写古诗词,不如做选择题!(只要判断哪一个文本与图片配对即可);
  5. 通过从预测任务改为只预测某个单词到只选出配对的答案,模型的训练效率一下提升了4倍;

为此,本文训练阶段使用对比学习,让模型学习文本-图像对的匹配关系,也就是下面模型原理图中,蓝色对角线为匹配的图文对。训练集用的他们自己采集的包含4亿个图文对的 WIT数据集。

在这里插入图片描述

3.prompt推理方式

使用某种固定prompt结构,正如训练获得特征,通过图像与prompt特征相似度匹配,实现clip分类,如:图像猫、狗二分类,可分别输入 “ A photo of cat ” 和 “ A photo of dog ”,分别与图像特征算相似度,确定其图像类被。

4.图像与文本编码结构

CLIP为多模态模型是指图像维度与文本维度融合,那么需要对图像特征化与文本特征化,本文选择图像编码结构为VIT,文本编码结构为BERT。后面,代码讲解,我将有大量笔墨说明。

5.特征CLS token结构

对于图像数据而言,其数据格式为[H, W, C],分别代表的是图片的通道数Channel,图片的高Height和宽Width。但很明显的是三维数据并不是Transformer所需要的。所以需要通过使用一个Embedding层来对原始的图片数据进行变换。

vit划分patch原理

vit论文做法为将给定的一堆图片按照给定的大小分成一堆Patches。本文将输入的图片尺寸为(224×224)按照16×16大小的Patch进行划分。其中(224×224)/(16×16)=196,因此我们会得到196个patches。到这里我们可以知道每一个Patches数据的shape为[16, 16, 3]。为了满足Transformer的需求,在这里,对每个Patch进行投影变化,映射到一维向量中。即完成如下转化。[16, 16, 3]->[768],那么这样一来,就将原始的[224, 224, 3]转化为[196, 768]。

cls token原理

在输入Transformer Encoder之前,值得注意的是需要加上[class] token。在原论文中,作者的意思是参考BERT,在上述得到的一堆tokens中插入一个专门用于分类操作的[class] token,这个[class] token是一个可训练的参数,数据格式和其他token保持一致,均为一个向量。
以本文为例,其维度大小为[1, 768]。注意的是,这里采取的是Concat操作。即cat cls token [1, 768]与图像pathch [196, 768] -> [197, 768],此时正好变成了二维矩阵。最终将图像patch变成维度是[197, 768],而本文是将cls token放在第一位,后面分类也是通过cls token给出,如下图。

在这里插入图片描述
注:cls token是一个可学习参数。

二、CLIP环境安装

本小节介绍如何使用官网代码安装环境,而不同电脑或cuda版本不一样,所安装也有所不同,但基本不影响,我的电脑相关属性:
gpu:RTX 3060显卡
CUDA:11.1

1.官方环境安装

官网代码安装如下命令:

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git

2.CLIP环境安装

构建虚拟环境:

conda create -n clip python=3.8

安装torch相关包:

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html  -i https://pypi.mirrors.ustc.edu.cn/simple/

安装相关依赖包:

pip install ftfy regex tqdm  -i https://pypi.mirrors.ustc.edu.cn/simple/

运行源码setup.py,其一为install运行,该操作是一个包安装虚拟环境,其二为develop运行,该操作是开发安装,指向了源代码而不是安装它的位置,方便调试,其命令如下:

# 方法一安装命令
python setup.py install
# 方法二安装命令
python setup.py develop  # 我采用该命令

注:建议使用方法二指向源码

3.CLIP运行结果

以上安装即可运行检测命令,可测试安装成功,其结果如下:
在这里插入图片描述

三.CLIP的Transformer结构代码解读

无论是文本text或图像image的编码encode均大量使用Transformer结构(以VIT与BERT编码),其实质是Q K V结构,可参考文章点击这里,为此我将单独使用一小节介绍。

改代码在源码model.py文件中,其调用类如下代码:

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

以上代码可知,该类为一个包装结构,重点是重复调用ResidualAttentionBlock结构,其结构如下代码:

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)  # n_head 头,d_model 表示维度。
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]  # 三个x表示Q K V计算值,x最后维度=n_head*d_model

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

从上面forward代码结构可知。
首先使用 x = x + self.attention(self.ln_1(x)),类似残差方式x+transform后的结果,该结构类似进行了attention方法,等同于transform结构的attention,该结构也被torch所集成,可直接调用其源码,如下:

self.attn = nn.MultiheadAttention(d_model, n_head)  # n_head 头,d_model 表示维度。

其次又调用 x = x + self.mlp(self.ln_2(x)),类似FFN结构,进行nn.Linear常规线性操作,在来一个激活GELU结构,最后在来一次线性操作,符合mlp结构,具体如下:

self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))

其中GELU使用QuickGELU方法,其代码如下:

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

注:该部分结构类似transformer结构,并n次使用于image与text的编码。

四、CLIP模型主函数代码解读

CLIP模型主函数也在源码model.py文件中,如下图所示:
在这里插入图片描述

其中forward为模型流走向,其代码如下:

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features,# 每一行sqr(a1^2+a2^2+...)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)  # [batch_img,512]
        text_features = text_features / text_features.norm(dim=1, keepdim=True)  # [batch_text,512]

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()  # 可学习参数
        logits_per_image = logit_scale * image_features @ text_features.t()  # 特征相乘获得相似度
        logits_per_text = logits_per_image.t()  # 变成文本

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

以上可知,CLIP实现多模态融合,实际是对图像编码与文本编码,使其分别获得对应的特征表达,在将表达特征进行norm(我的理解减小偏差,是一个常规操作),随后将图像特征与对应文本特相差,便可获得相似值。
假设以2个图像与3个文本表示,其图像特征获得对应文本特征得到相似值,简易说明如下:
在这里插入图片描述
将其转职获得文本特征获得对应图像特征相似值,简易说明如下:
在这里插入图片描述
其中,每个图像与文本特征表达维度为512(CLIP使用此维度),获得对应相似值如上图V**,每一行的最大值分别是CLIP模型认为最相似的,也得到图像获得文本标签,或文本获得匹配的图像。

五、CLIP的image encode代码解读

图像编码使用VIT编码结构,将图片划分为多个patch,然后使用transformer结构编码提取特征,最终获得特征表达。接下来,我将详细阐述。

1、主函数代码解读

CLIP使用encode_image函数调用,如下:

image_features = self.encode_image(image)

而encode_image函数如下:

def encode_image(self, image):
    return self.visual(image.type(self.dtype))

CLIP使用图像编码有ResNet结构与VisionTransformer,前者是CNN方式,后者是transformer方式,我将以transformer方式解读,如下代码:

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

2、VisionTransformer结构代码解读

该类是图像encode的所有精华所在,代码已有我的注释,其代码如下:

class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        # width相当于transform中的d_model
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        # x=[1,3,224,224]
        x = self.conv1(x)  # shape = [*, width, grid, grid] # 将图片分成[32,32]个patch [1,768,7,7]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2],合并高宽 [1,768,49]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width] ,更换位置 [1,49,768]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width],添加cls token[1,50,768]
        x = x + self.positional_embedding.to(x.dtype)  # 这里位置编码是可学习的参数,可能是切了path顺序让模型自己学习吧  [1,50,768]
        x = self.ln_pre(x)  # [1,50,768]

        x = x.permute(1, 0, 2)  # NLD -> LND  # [pixel,b,d_model]=[50,1,768]
        x = self.transformer(x)  # 多头transformer [50,1,768]
        x = x.permute(1, 0, 2)  # LND -> NLD  # [1,50,768]

        x = self.ln_post(x[:, 0, :])  # x[:, 0, :] 将所有信息汇聚到cls token中,只需前面来做下游任务 [1,768]

        if self.proj is not None:  # self.proj是可学习参数,维度为[768,512]
            x = x @ self.proj  # 通过学习参数将维度再次融合变成512特征,最终为[1,512]

        return x

以上可知,图片首先切成patch块,然后转成transformer能使用的结构,该结构可参考这里,同时,代码也有位置编码模块与特征结合,随后将所有信息汇聚到cls token,可实现下游任务,最后也通过可学习参数实现最终图像特征提取。我将在下面具体解读。

3、图像patch方法代码解读

将图像划分patch实际是VIT最重要思想,意在解决训练和推理速度问题,代码层面处理,实际为卷积核与步长来处理,代码如下:

self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

以上代码简单一句,即可将如[1,3,224,224]的一个图片分成3232尺寸(vit使用1616,这个根据模型而定,仅是一个参数而已)化成768个patch,高宽分别为7,格式为[1,768,7,7]:

# x=[1,3,224,224]
x = self.conv1(x)  # shape = [*, width, grid, grid] # 将图片分成[32,32]个patch [1,768,7,7]

结果如图:
在这里插入图片描述
768来源:VIT模型将输入224224尺寸化成1616像素的patch,那么每个patch为16163=768,其中3为图像通道,将每个patch投影为768维度表示,也就是本文中self.conv1通道为768的缘故。
196与49区别:196也是来源VIT将224变成16尺寸的patch,那么共有224224/(1616)=196,而本文的patch尺寸为32,变成224224/(3232)=49。

最终图像使用reshape将宽高7*7合并转为49的像素,成为[1,49,768],可理解1为batch在NLP中表示一句话,49为像素在NLP中表示文字,768为每个patch投影表达在NLP中表示d_model为每个文字使用d_model表达特征。其代码如下:

x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2],合并高宽 [1,768,49]
x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width] ,更换位置 [1,49,768]

3、图像cls token编码代码解读

cls token为VIT较为特殊设置,是一个可学习参数,我已在上面原理中介绍,不在细说,只解读实现方式,实现代码如下:

scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))

将cls token嵌入,原来[1,49,768]变为[1,50,768],其代码中如下:

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width],添加cls token[1,50,768]

若在VIT模型cls token嵌入,将[1,196,768]变成[1,197,768]。

4、图像位置编码代码解读

位置编码也是一个可学习参数,实现代码如下:

self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))

将位置编码嵌入,实际是x加上了位置信息,和我之前attention is all you need文章解释类似,该结构代码如下:

x = x + self.positional_embedding.to(x.dtype)  # 这里位置编码是可学习的参数,可能是切了path顺序让模型自己学习吧  [1,50,768]

5、图像cls token特征表达代码解读

最终每张图像特征表达直接使用cls token来代替,直接取前第一个,如下图显示:
在这里插入图片描述

6、图像特殊结构代码解读

proj特殊结构,该结构若使用将进一步将图像特征表达进行变换,该变换的self.proj是可学习参数,代码如下:

self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

将该结构嵌入,我理解可进一步特征混合整合或组合获得图像特征表达,该结构代码如下:

if self.proj is not None:  # self.proj是可学习参数,维度为[768,512]
   x = x @ self.proj  # 通过学习参数将维度再次融合变成512特征,最终为[1,512]

代码运行图像显示如下:
在这里插入图片描述
我个人觉得该结构可被借鉴。

六、CLIP的text encode代码解读

文本编码使用BERT编码结构,显然使用transformer结构编码提取文本特征,最终获得特征表达。接下来,我将详细阐述。

1、主函数代码解读

CLIP使用encode_text函数调用,如下:

text_features = self.encode_text(text)

而encode_text函数如下:

def encode_text(self, text):
    # x 每个句子前面有值,有2个特殊符号[CLS][Seq]
    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model][3,77,512]
    x = x + self.positional_embedding.type(self.dtype)  # 位置编码直接赋可学习位置,添加位置信息[3,77,512]
    x = x.permute(1, 0, 2)  # NLD -> LND,[77,3,512]
    x = self.transformer(x)  # 共11个 和图像encode结构一致 [77,3,512]
    x = x.permute(1, 0, 2)  # LND -> NLD,[3,77,512]
    x = self.ln_final(x).type(self.dtype)
    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    # text.argmax(dim=-1) 句子最后有一个seq字段,是最大的,因此能获得句子个数数量
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    return x

2、文本token代码解读

文本编码和我之前文章点击这里解释transform的encode基本相同,读者可查看。很多与我之前文章相同内容将不在解释,该小节说明如何使用文本token。首先文本为text_language = ["a diagram", "a dog", "a black cat"],也就是三句话,每句话大概几个词,其转码为下图计算机可识别符号方法,查阅我的博客点击这里。其代码如下:

x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model][3,77,512]

其结果如下图:
在这里插入图片描述
以上可知,文本变成[3,77]结构,如输入text第一行文本为"a diagram",理论映射只有2个,但有四个数字,其中第一个为[CLS]值,最后一个为[Seq]值,本文设置每个句子长度为77,不足使用0表示,最终变成[3,77]表示为3个句子有77个文字(不足用0表示)。最终使用512维度表达,成为[3,77,512]结构,该部分与我之前文章内容一致,详情可参考之前文章。

3、文本位置编码代码解读

位置编码也是一个可学习参数,实现代码如下:

self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))

将位置编码嵌入,实际是x加上了位置信息,和我之前attention is all you need文章解释类似,该结构代码如下:

x = x + self.positional_embedding.type(self.dtype)  # 位置编码直接赋可学习位置,添加位置信息[3,77,512]

4、文本特殊结构代码解读

self.text_projection特殊结构,该结构若使用将进一步将文本特征表达进行变换,该变换的self.text_projection是可学习参数,代码如下:

self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

将该结构嵌入,与图像变啊特殊结构类似,该结构代码如下:

# text.argmax(dim=-1) 句子最后有一个seq字段,是最大的,因此能获得句子个数数量
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

注:x[torch.arange(x.shape[0]), text.argmax(dim=-1)]改代码表达取x为[3,77,512]维度索引分别[0,3],[1,3],[2,4],得到三个句子512维度特征表达,而每个句子都是取第二个维度77文字最大那一个,我的理解是每句话都是从第一个文字[CLS]叠加到最后一个文字[Seq],因此使用最后一个就有时序表达该句话的特征。

代码运行图像显示如下:

在这里插入图片描述

至于文本encode过程可参考代码走向,因其过于简单,我不在说明。

七、CLIP多模态融合代码解读

在上面小节中我们已然知晓图像编码与文本编码方式,该小节说明获得图像、文本特征表达融合方式,其代码如下:

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features,# 每一行sqr(a1^2+a2^2+...)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)  # [batch_img,512]
        text_features = text_features / text_features.norm(dim=1, keepdim=True)  # [batch_text,512]

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()  # 可学习参数
        logits_per_image = logit_scale * image_features @ text_features.t()  # 特征相乘获得相似度
        logits_per_text = logits_per_image.t()  # 变成文本

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

从代码可知,图像特征与文本特征进行norm(其作用在上面已说明),然后求解其相似度获得图像与文本匹配结果。其过程也较为简单,可直接参考以上源码,其图示如下:
在这里插入图片描述
图像特征为[1,512]表示一个图像被512维度表达;
文本特征[3,512]表示3个句子分别被512维度表达;

八、CLIP推理结构解读

推理代码官网也有提供,直接官网下载权重便可实现,我使用VIT-B-32模型结构,实现推理分类任务。该模型使用对比学习,可定义很多文本,让每个图像与多个文本特征相似匹配,匹配值越高,自然就是那个类。如同,我在上面CLIP模型主函数代码解读说明一样。其代码如下:

import torch
import clip
from PIL import Image
import numpy as np

def class_demo():
    # 测试分类的demo
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # 模型选择['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16'],对应不同权重
    model, preprocess = clip.load("../ViT-B-32.pt", device=device)  # 载入模型
    image = preprocess(Image.open("../CLIP.png")).unsqueeze(0).to(device)
    text_language = ["a diagram", "a dog", "a black cat"]
    text = clip.tokenize(text_language).to(device)

    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)  # 第一个值是图像,第二个是第一个的转置
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        idx = np.argmax(probs, axis=1)
        for i in range(image.shape[0]):
            id = idx[i]
            print('image {}\tlabel\t{}:\t{}'.format(i, text_language[id],probs[i,id]))
            print('image {}:\t{}'.format(i, [v for v in zip(text_language,probs[i])]))


if __name__ == '__main__':
    class_demo()

其结果如下:
在这里插入图片描述

九、CLIP训练结构解读

分类的CLIP训练实际是交叉熵方法,我们获得匹配值,可看成每个图像分别与不同文本相似值为预测类别值,进行类似交叉熵运算即可,另外反过来也可看成每个文本与分别与不同图像相似值为预测值,亦可进行交叉熵运算。我大概查了github其它训练方法,可供参考,其代码如下:

images, texts = batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)

data_time_m.update(time.time() - end)
optimizer.zero_grad()

if args.accum_freq == 1:
    with autocast():
        model_out = model(images, texts)
        logit_scale = model_out["logit_scale"]
        if args.distill:
            with torch.no_grad():
                dist_model_out = dist_model(images, texts)
            model_out.update({f'dist_{k}': v for k, v in dist_model_out.items()})
        losses = loss(**model_out, output_dict=True)

        total_loss = sum(losses.values())
        losses["loss"] = total_loss

    backward(total_loss, scaler)

以上源码地址为:点击这里


总结

CLIP为多模态融合奠定了基准,也是通过对比训练可实现无监督大模型预训练。个人觉得还是比较重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1096441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

相似性搜索:第 4 部分--分层可导航小世界 (HNSW)

SImilarity 搜索是一个问题,给定一个查询的目标是在所有数据库文档中找到与其最相似的文档。 一、介绍 在数据科学中,相似性搜索经常出现在NLP领域,搜索引擎或推荐系统中,其中需要检索最相关的文档或项目以进行查询。在大量数据中…

从0开始学go第八天

gin获取URL路径参数 package main//获取path(URL)参数 import ("net/http""github.com/gin-gonic/gin" )func main() {r : gin.Default()r.GET("/:name/:age", func(c *gin.Context) {//获取路径参数name : c.Param(&quo…

15 | JPA 对 Web MVC 开发者做了哪些支持

我们使用 Spring Data JPA 的时候,一般都会用到 Spring MVC,Spring Data 对 Spring MVC 做了很好的支持,体现在以下几个方面: 支持在 Controller 层直接返回实体,而不使用其显式的调用方法;对 MVC 层支持标…

如何通过Photoshop将视频转换成GIF图片

一、应用场景 1、将视频转有趣动图发朋友圈 2、写CSDN无法上传视频,而可以用GID动图替代 3、其他 二、实现步骤 1、打开Photoshop APP 2、点击文件——导入——视频帧到图层 3、选择视频文件 4、配置视频信息,按照图片提示配置完毕之后,…

gma 2.0.2 (2023.10.15) 更新日志

安装 gma 2.0.2 pip install gma2.0.2新增 0.1、矢量提取(重要更新) (见简单示例)   现在,你可以像 numpy 或 pandas 一样直接对 Layer 进行切片提取。 0.2、修改属性表(重要更新) &#xff…

数电第一次实验

四选一,信号选择器 三位4选1多路选择器 要求输入信号有4个,且每个信号宽3位 如果是直接根据选择信号选 选择的是信号,选择的是编号,与信号具体是什么内容无关,信号的内容与其是否被选择无关,信号的编号…

E034-服务漏洞利用及加固-利用CVE-2016-5195漏洞实现Linux系统本地提权

实验等级: 中级 任务场景: 【任务场景】 小王接到磐石公司的邀请,对该公司内部网络进行渗透测试,经过对局域网被操作系统进行全面的维护中,发现了一台内核版本为4.2.0-27的Linux服务器,低权限用户利用该漏洞技术可以在全版本L…

【笔记整理】软考-软件设计师

一、计算机系统 计算机基本单位 单位名称简称换算位bitb字节byteB1B8b千字节KB1KB1024B兆字节MB1MB1024KB吉字节GB1GB1024MB太字节TB1TB1024GB 带宽单位Mbps的b是指Bit(位) 速度单位MB/s的B是指Byte(字节) 1MB/s=8M…

Android 10.0 禁止弹出系统simlock的锁卡弹窗功能实现

1.前言 在10.0的系统开发中,在一款产品中,需要实现simlock锁卡功能,在系统实现锁卡功能以后,在开机的过程中,或者是在插入sim卡 后,当系统检测到是禁用的sim卡后,就会弹出simlock锁卡弹窗,要求输入puk 解锁密码,功能需求禁用这个弹窗,所以就需要看是 哪里弹的,禁用…

04-React脚手架

04-React脚手架 1. react脚手架入门 1).脚手架的介绍 xxx脚手架: 用来帮助程序员快速创建一个基于xxx库的模板项目 包含了所有需要的配置(语法检查、jsx编译、devServer…)下载好了所有相关的依赖可以直接运行一个简单效果 react提供了一个用于创建rea…

R/d2及S/C4估计总体标准差,比较其CPK及规格限概率的差异

R/d2 和 S/C4 是用于估计总体标准差的无偏估计方法,通常用于控制图中。这些估计方法的主要目的是通过样本数据来估计总体标准差,以便监测过程的稳定性和变异性,而不需要收集整个总体的数据。 具体来说: R图中的 R/d2 和 S图中的…

【JAVA】有关包的概念

个人主页:【😊个人主页】 系列专栏:【❤️初识JAVA】 前言 Java包是用于组织和管理Java类的方式。它们提供了一种命名空间,以避免名称冲突,并使程序的组织更加有效和可维护。今天我们接着来学习有关包的概念。 包 …

faster lio 回环 加入GTSAM优化的记录

首先感谢这位博主的文章:https://blog.csdn.net/weixin_41281151/article/details/125371285,其中部分代码参考于改博主中的github: https://github.com/kahowang/FAST_LIO_SAM 不同的是,我使用的是faster lio进行更改&#xff0c…

vscode键盘输入不进去

二话不说,直接把输入切换到终端输出即可! 打开设置,搜索terminal,切换到run in terminal 即可!

C语言-指针相关使用

指针是 C语言的重要组成部分,是 C语言的核心、精髓。 在 C语言中,指针使用得当,能显著提高某些程序的效率,使用不当,则很容易造成系统错误、 一、指针使用 编译系统为每个变量都分配了一个能满足其类型大小的内存单…

vqvae简单实战,利用vqvae来提升模型向量表达

最近CV领域各种大模型在图像生成领域大发异彩,比如这两年大火的dalle系列模型。在这些模型中用到一个基础模型vqvae,今天我们写个简单实现来了解一下vqvae的工作原理。vqvae原始论文连接https://arxiv.org/pdf/1711.00937.pdf 1,代码 首先我们…

机器学习——奇异值分解二(特征分解+SVD纯理解)

矩阵的特征分解 特征值和特征向量的定义 抄来的:奇异值分解 困惑1:特征值和特征向量,和原矩阵是怎样的关系,需要一个栗子进行更具象的认识 困惑2:为什么多个特征向量组合成的矩阵,可以构成矩阵A的特征分解…

项目管理之实施关键步骤

项目管理已成为当代企业运营和发展过程中不可或缺的重要环节。如何实现高效、有序和可控的项目管理,一直是企业领导和项目团队追求的目标。本文将结合项目管理七招制胜内容,详细阐述项目管理实战中的具体做法。 如何分析项目 了解项目的背景和目的&…

网工记背配置命令(3)----POE配置示例

POE 供电就是通过以太网供电,这种方式仅凭借那根连接通信终端的网线就可完成为它们供电。POE提供的是-53V~0v 的直流电,供电距离最长可达 100m。PoE 款型的交换机的软件大包天然支持 POE,无需 license,通过执行 poe-enable 命令使…

【力扣1844】将所有数字用字符替换

👑专栏内容:力扣刷题⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、题目描述二、题目分析 一、题目描述 给你一个下标从 0 开始的字符串 s ,它的偶数下标处为小写英文字母&am…