CLIP微调方法总结

news2024/11/23 22:00:31

文章目录

  • 前言
  • 1️⃣ Tip-Adapter
    • 论文和源码
    • 原理介绍
  • 2️⃣Cross-modal Adaptation(跨模态适应)
    • 论文和源码
    • 原理介绍
  • 3️⃣ FD-Align(Feature Discrimination Alignment,特征判别对齐)
    • 论文和源码
    • 原理介绍
  • 总结


前言

在这里插入图片描述

本文主要介绍和总结了三种不错的 C L I P CLIP CLIP微调方法,包括原理和思想,并且按照自己的理解给出了相应的代码实现,相当于是一个简化版的code实现。
所有代码使用 j i t t o r jittor jittor框架实现,具体代码请请参考👇

Gitlink-Code 或者 Github-Code


1️⃣ Tip-Adapter

论文和源码

🔥 论文地址
🚀 代码地址

原理介绍

  • 本质上就是在 C L I P CLIP CLIP的预测结果 X X X上又加上了一个预测结果 Y Y Y,我们都知道结果 X X X是测试图像和所有分类文本的相似度之间的关系,而 Y Y Y就是测试图像和训练 C L I P CLIP CLIP时的训练图像之间的相似度关系,最终将 X X X Y Y Y加权求和便得到最终的预测结果,所以可以发现他的优势在于: Z e r o − s h o t   t r a n s f e r (无需额外训练) Zero-shot\ transfer(无需额外训练 ) Zeroshot transfer(无需额外训练)

  • 下面结合论文给的框架图就能很好理解这个方法(每个变量后面标出了 s h a p e shape shape大小,方便理解):

    T i p − A d a p t e r Tip-Adapter TipAdapter添加之前:假设分类类别数目是 N N N W c T W_{c}^{T} WcT N N N个文本标签经过 C L I P CLIP CLIP T e s t   E n c o d e r Test\ Encoder Test Encoder得到的文本特征,大小 N × 512 N×512 N×512
    输入一张测试图像 I t e s t I_{test} Itest → 经过 C L I P 模型的 V i s u a l   E n c o d e r 之后 \xrightarrow{经过CLIP模型的Visual\ Encoder之后} 经过CLIP模型的Visual Encoder之后 得到 f t e s t : 1 × 512 f_{test} :1×512 ftest1×512 → 和 C L I P 的 T e s t   F e a t u r e s 作相似度,也就是图中的 f t e s t ∗ W c T \xrightarrow{和CLIP的Test\ Features作相似度,也就是图中的f_{test}*W_{c}^{T}} CLIPTest Features作相似度,也就是图中的ftestWcT 得到分类结果(实际上就是和所有文本标签的相似度) X : 1 × N X:1×N X:1×N

    T i p − A d a p t e r Tip-Adapter TipAdapter添加之后:
    上面步骤同样完全相同,得到 X X X
    首先将所有的训练图像 I K I_{K} IK(假设共有 M M M张, M = C × N M=C×N M=C×N C C C是一个系数,因为训练时一般每个类别的图像会有多张) → 同样经过 C L I P 模型的 V i s u a l   E n c o d e r \xrightarrow{同样经过CLIP模型的Visual\ Encoder} 同样经过CLIP模型的Visual Encoder 得到 F t r a i n : M × 512 F_{train}:M×512 FtrainM×512 ,并作为缓存模型( c a c h e   m o d e l cache\ model cache model)的 k e y key key
    然后将所有训练图像的文本标签经过 O n e   H o t One\ Hot One Hot处理,得到 L t r a i n : M × N L_{train}:M×N LtrainM×N,并作为缓存模型的 v a l u e value value;到此便构建了一个缓存模型,相当于多了一份存储有训练样本特征的先验信息。
    接着将之前得到的 f t e s t : 1 × 512 f_{test} :1×512 ftest1×512 → 送入 c a c h e   m o d e l , 计算和训练图像之间的特征余弦相似度 \xrightarrow{送入cache\ model,计算和训练图像之间的特征余弦相似度} 送入cache model,计算和训练图像之间的特征余弦相似度 得到 A = e x p ( − β ( 1 − f t e s t F t r a i n T ) ) : 1 × M A=exp(-\beta(1-f_{test}F_{train}^{T})):1×M A=exp(β(1ftestFtrainT)):1×M → 和 c a c h e   m o d e l 的 v a l u e s 相乘,得到预测结果 Y \xrightarrow{和cache\ model的values相乘,得到预测结果Y} cache modelvalues相乘,得到预测结果Y Y = A L t r a i n : 1 × N Y=AL_{train}:1×N Y=ALtrain1×N
    最后将 T i p − A d a p t e r Tip-Adapter TipAdapter的预测结果 Y Y Y和原始 C L I P CLIP CLIP预测结果 X X X进行加权求和:
    logits = α A L train + f test W c T = α φ ( f t e s t F t r a i n T ) L t r a i n + f t e s t W c T , \begin{aligned} \begin{aligned} \text{logits}& =\alpha A\mathbf{L}_\text{train}+f_\text{test}W_c^T \\ &=\alpha\varphi(f_{\mathrm{test}}\mathbf{F}_{\mathrm{train}}^T)\mathbf{L}_{\mathrm{train}}+f_{\mathrm{test}}W_c^T, \end{aligned} \end{aligned} logits=αALtrain+ftestWcT=αφ(ftestFtrainT)Ltrain+ftestWcT,

在这里插入图片描述

2️⃣Cross-modal Adaptation(跨模态适应)

论文和源码

🔥 论文地址
🚀 代码地址

原理介绍

  • 原理图和伪代码在这里插入图片描述
    在这里插入图片描述
  • 该方法的核心思想就是将多种模态的信息融合在一起,并且论文假设 C L I P CLIP CLIP可以将不同模态的样本映射到同一个特征空间。比如对于文本-图像这种模态形式,在训练过程中,就可以引入这里的文本信息(也就是每个类别的标签),将其作为额外的训练样本,其实就是将每张图像的图像特征和文本特征视作同一个特征来进行训练。
  • 同上面一样,根据伪代码的内容,将维度变换显示出来也非常好理解整个实现过程:
    假设输入的 b a t c h _ s i z e batch\_size batch_size大小为 b b b,分类的类别数为 n u m _ c l a s s num\_class num_class

i m a g e _ e n c o d e r 输出的图像特征 i m _ f : b × 512 t e x t _ e n c o d e r 输出的文本特征 t x _ f : b × 512 在行维度上将两个特征拼接起来并归一化 f e a t u r e s : 2 b × 512 对应的标签也进行拼接 l a b e l s : 2 b × 512 将 f e a t u r e s 通过一个分类器得到每个类别的预测概率 l o g i t s : 2 b × n u m _ c l a s s 最后 l o g i t s 和 l a b e l s 之间作交叉熵损失,并更新分类器、图像编码器和文本编码器的参数 \begin{aligned} image\_encoder输出的图像特征 \quad im\_f:b×512\\ text\_encoder输出的文本特征 \quad tx\_f:b×512\\ 在行维度上将两个特征拼接起来并归一化\quad features:2b×512\\ 对应的标签也进行拼接\quad labels:2b×512\\ 将features通过一个分类器得到每个类别的预测概率 \quad logits:2b×num\_class\\ 最后logits和labels之间作交叉熵损失,并更新分类器、图像编码器和文本编码器的参数 \end{aligned} image_encoder输出的图像特征im_fb×512text_encoder输出的文本特征tx_fb×512在行维度上将两个特征拼接起来并归一化features2b×512对应的标签也进行拼接labels2b×512features通过一个分类器得到每个类别的预测概率logits:2b×num_class最后logitslabels之间作交叉熵损失,并更新分类器、图像编码器和文本编码器的参数

注意:在实现该代码进行训练的过程中发现如果按照伪代码中将cross_logits除以一个常量,loss反而会很难下降,相反乘上一个系数loss下降的更好一些。(直接loss=cross_entropy_loss(logits*3.0,labels)即可),否则loss值很难会下降。

在这里插入图片描述
在这里插入图片描述

3️⃣ FD-Align(Feature Discrimination Alignment,特征判别对齐)

论文和源码

🔥 论文地址
🚀 代码地址

原理介绍

  • 原理图:
    在这里插入图片描述
  • 论文中提出了一个概念:虚假关联性的鲁棒性,它指的是模型是否具有区分出样本中和类别相关信息(因果信息)以及(背景、风格等)类别无关信息(虚假信息)的能力。同时注意到全微调的CLIP的OOD性能会下降,因此提出了一种不影响模型对虚假特征识别能力的微调方法来保证微调后的模型对虚假关联性的鲁棒性。从模型框架图中看,实际上就是在微调的过程中通过约束微调后的CLIP模型和原始的CLIP模型对虚假特征的分布保持一致,从而在一定程度上避免微调过程中CLIP的OOD性能下降。
  • 该方法相对于前两个方法稍显复杂,先熟悉它定义的几个符号意义,再来结合框架图看一下它的整个模型原理:

首先假设存在一个小样本数据集 D ⊂ X × Y ,( X 表示图像, Y 表示标签) 有 M 个提示模板 ( P 1 , … , P M ) , C L I P 模型的 t e x t − e n c o d e r 和 i m a g e − e n c o d e r 分别表示为 g 0 和 f 0 ; 假设任意的一个类别 y ,那么 y 的原型表示为: μ y class  ,也被称为类的原型 首先假设存在一个小样本数据集D\subset X\times Y,(X表示图像,Y表示标签)\\ 有M个提示模板(P_1,\ldots,P_M),CLIP模型的text-encoder和image-encoder分别表示为g_{0}和f_{0};\\ 假设任意的一个类别y,那么y的原型表示为:\mu_y^\text{class },也被称为类的原型 首先假设存在一个小样本数据集DX×Y,(X表示图像,Y表示标签)M个提示模板(P1,,PM)CLIP模型的textencoderimageencoder分别表示为g0f0;假设任意的一个类别y,那么y的原型表示为:μyclass ,也被称为类的原型
μ y class  : = 1 M ∑ j = 1 M g 0 ( [ P j , y ] ) . \begin{aligned} \mu_y^\text{class }:=\frac{1}{M}\sum_{j=1}^Mg_0([P_j,y]). \end{aligned} μyclass :=M1j=1Mg0([Pj,y]).
因此第一个损失函数 L c l a s s \mathcal{L}_{\mathrm{class}} Lclass和clip模型中的损失函数本质上相同的,约束图像-文本之间的相似度,只不过这里的文本不在是单个的prompt,而是多个prompt取平均值得到的。
L class = − 1 ∣ D ∣ ∑ ( x i , y i ) ∈ D log ⁡ exp ⁡ ( s ( f t ( x i ) , μ y i class ) ) ∑ y ∈ Y exp ⁡ ( s ( f t ( x i ) , μ y class ) ) 其中, s ( : ) 表示余弦相似度 \begin{aligned} \mathcal{L}_{\text{class}}=-\frac{1}{|\mathcal{D}|}\sum_{(x_i,y_i)\in\mathcal{D}}\log\frac{\exp(s(f_t(x_i),\mu_{y_i}^{\text{class}}))}{\sum_{y\in\mathcal{Y}}\exp(s(f_t(x_i),\mu_y^{\text{class}}))}\\ 其中,s(:)表示余弦相似度 \end{aligned} Lclass=D1(xi,yi)DlogyYexp(s(ft(xi),μyclass))exp(s(ft(xi),μyiclass))其中,s(:)表示余弦相似度
紧接着,定义提示模板( p r o m p t )的原型:每个 P j 在所有类中的特征平均值,公式为: 紧接着,定义提示模板(prompt)的原型:每个P_{j}在所有类中的特征平均值,公式为: 紧接着,定义提示模板(prompt)的原型:每个Pj在所有类中的特征平均值,公式为:
μ P j spurious : = 1 ∣ Y ∣ ∑ y ∈ Y g 0 ( [ P j , y ] ) \begin{aligned} \mu_{P_j}^\text{spurious}:=\frac{1}{|\mathcal{Y}|}\sum_{y\in\mathcal{Y}}g_0([P_j,y]) \end{aligned} μPjspurious:=Y1yYg0([Pj,y]) 现在希望的是在微调过程中保持模型对虚假相关性的鲁棒性 , 即保持模型在微调前后提取的虚假特征不变。 所以需要知道模型在虚假特征上的分布——即将微调模型提取的特征与虚假原型之间的相似度定义为模型虚假特征的分布。 现在希望的是在微调过程中保持模型对虚假相关性的鲁棒性,即保持模型在微调前后提取的虚假特征不变。\\所以需要知道模型在虚假特征上的分布——即将微调模型提取的特征与虚假原型之间的相似度定义为模型虚假特征的分布。 现在希望的是在微调过程中保持模型对虚假相关性的鲁棒性,即保持模型在微调前后提取的虚假特征不变。所以需要知道模型在虚假特征上的分布——即将微调模型提取的特征与虚假原型之间的相似度定义为模型虚假特征的分布。

因此,计算由微调模型提取的特征和虚假原型之间的相似性,并且如下产生虚假特征的分布: 因此,计算由微调模型提取的特征和虚假原型之间的相似性,并且如下产生虚假特征的分布: 因此,计算由微调模型提取的特征和虚假原型之间的相似性,并且如下产生虚假特征的分布:
P spurious ( x ; f t ) = SoftMax [ s ( f t ( x ) , μ P 1 spurious ) , … , s ( f t ( x ) , μ P M spurious ) ] \begin{aligned} \mathcal{P}_\text{spurious}(x;f_t)=\text{SoftMax}\left[s\left(f_t(x),\mu_{P_1}^\text{spurious}\right),\ldots,s\left(f_t(x),\mu_{P_M}^\text{spurious}\right)\right] \end{aligned} Pspurious(x;ft)=SoftMax[s(ft(x),μP1spurious),,s(ft(x),μPMspurious)]
类似地,将 f t 换成 f 0 ,可以得到微调前模型的虚假特征分布: 类似地,将f_{t}换成f_{0},可以得到微调前模型的虚假特征分布: 类似地,将ft换成f0,可以得到微调前模型的虚假特征分布:
P spurious ( x ; f 0 ) = SoftMax [ s ( f 0 ( x ) , μ P 1 spurious ) , … , s ( f 0 ( x ) , μ P M spurious ) ] \begin{aligned} \mathcal{P}_{\text{spurious}}(x;f_0)=\text{SoftMax}\left[s\left(f_0(x),\mu_{P_1}^{\text{spurious}}\right),\ldots,s\left(f_0(x),\mu_{P_M}^{\text{spurious}}\right)\right] \end{aligned} Pspurious(x;f0)=SoftMax[s(f0(x),μP1spurious),,s(f0(x),μPMspurious)]

因此第二个损失函数的作用就是保持微调前后模型对虚假特征概率分布保持一致:
L spurious = 1 ∣ D ∣ ∑ ( x i , y i ) ∈ D KL ( P spurious ( x i ; f t ) ∣ ∣ P spurious ( x i ; f 0 ) ) \begin{aligned} \mathcal{L}_{\text{spurious}}=\frac{1}{|\mathcal{D}|}\sum_{(x_i,y_i)\in\mathcal{D}}\text{KL}\left(\mathcal{P}_{\text{spurious}}(x_i;f_t)\mid\mid\mathcal{P}_{\text{spurious}}(x_i;f_0)\right) \end{aligned} Lspurious=D1(xi,yi)DKL(Pspurious(xi;ft)∣∣Pspurious(xi;f0))
综上,最终的损失函数为:
L t o t a l = α ⋅ L c l a s s + β ⋅ L s p u r i o u s 论文中取 α = 1 , β = 20 \begin{aligned} \mathcal{L}_{\mathrm{total}}=\alpha\cdot\mathcal{L}_{\mathrm{class}}+\beta\cdot\mathcal{L}_{\mathrm{spurious}} \end{aligned}\\ 论文中取\alpha=1,\beta=20 Ltotal=αLclass+βLspurious论文中取α=1,β=20

更多细节的推导和更准确的表述请参考作者的原论文😀

总结

  • 本文介绍了三种CLIP微调方法的原理以及给出了对应的更加简化版代码实现,如果有问题的地方,欢迎评论区指正。
  • 三种方法相比较而言,Tip-Adapter最通用,无论是免训练版本还是训练版本,使用之后均有一定的提升效果;Cross-modal Adaptation思路最简单,但是要想有效果,尝试后发现需要针对自己的数据集不断调节参数大小;FD-Align方法在保持CLIP的zero-shot能力方面是几个方法当中最好的;
  • 觉得有帮助的话,给个赞吧👋👋👋

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2082499.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

教您用军团要塞2服务器开服联机教程

1、购买后登录服务器 进入控制面板后会出现正在安装的界面 2、下载连接工具 打开Steam库中搜索Source SDK Base 2013 Multiplayer并安装 3、下载游戏 以下三个链接均为同一个游戏 百度 通过百度网盘分享的文件:tf2classic.zip 链接:百度网盘 请输入提…

(最新)华为 2024 届秋招-硬件技术工程师-单板硬件开发—机试题—(共12套)(每套四十题)

(最新)华为 2024 届秋招-硬件技术工程师-单板硬件开发—机试题—(共12套)(每套四十题) 岗位——硬件技术工程师 岗位意向——单板硬件开发 真题题目分享,完整版带答案(有答案和解析&#xff0…

浅析WebRTC技术在智慧园区视频管理场景中的应用

随着科技的飞速发展,智慧园区作为城市智慧化的重要组成部分,正逐步成为现代化管理的重要方向。智慧园区的建设不仅涉及硬件设施的智能化升级,还离不开高效的视频管理和实时通信技术。在这一背景下,WebRTC(Web Real-Tim…

BackdoorLLM:一个针对生成性LLMs后门攻击的全面基准测试

大型语言模型(LLMs)在从自然语言理解到机器翻译等一系列任务上取得了显著的突破性进展。例如,GPT-4模型展示了在生成类人文本和解决复杂问题方面的前所未有的能力。然而,近期的研究表明,LLMs存在一个关键的脆弱性&…

sqli-labs靶场通关攻略(41-45关)

第41关 这关我们使用工具sqlmap练习一下(这里如果用本机的回环地址访问靶场的话只能在你的本机访问,因为我们是在虚拟机上进行扫描,所以不能使用127.0.0.1访问) 进入虚拟机kali,打开终端 查库 sqlmap -u 网址 -- curr…

uniapp生活记账小程序

Springboot vue uniapp生活记账小程序,前端采用vue uni-app设计开发,后端采用 Springboot 开发前端对应的数据接口,首页显示生活账单信息,我的野蛮统计记账信息和微信登录状况。记账页面可以,根据不同类别的日常消费记…

MySQL集群技术4——MySQL路由

mysql-route MySQL 路由(Routing)通常指的是在 MySQL 架构中如何处理客户端请求和数据流向的问题。在 MySQL 中,路由可以涉及多种不同的场景和技术,包括但不限于反向代理、负载均衡、读写分离等。下面我将详细介绍这些场景和技术…

如何反射获取类的全部信息?(java代码)

什么是反射? 反射是 Java 提供的一种机制,允许在运行时动态地获取类的信息(如类的名称、方法、字段等),以及创建对象和调用方法。反射利用了 java.lang.reflect 包中的类,如 Class、Method、Field 和 Cons…

仿BOSS招聘系统开发:构建高效、智能的在线招聘平台

在数字化时代,招聘行业正经历着前所未有的变革。BOSS直聘作为国内领先的招聘平台,以其高效的匹配机制、丰富的职位信息和便捷的用户体验,赢得了广泛的市场认可。本文将探讨如何开发一个仿照BOSS招聘系统的在线招聘平台,旨在为企业…

8月28复盘日记

8月28复盘日记 前言今日感恩今日知识今日反思今日名言 前言 今天早上是六点半起床嘻嘻,这两天因为生理期,皮质醇似乎有些高,入睡会有些困难。但是因为今天是开学第一天,意味着,健身房恢复晨练了!我可太喜欢晨练时间安安静静的健身…

怎么又快又好制作流程图?试试这2款流程图制作神器,专业!

推荐2款简单好用的流程图制作软件,帮你轻松搞定各种流程图。 1、GitMind 点点击链接直达官网>>gitmind.cn GitMind是一个在线流程图制作工具,模板类型全,支持免费下载,由国内团队研发,操作简单,使用…

功率器件和滤波器件的选型及测试方法

目录 一、功率器件的选型及测试方法 1.1功率器件的选型 1.2功率器件的测试方法 二、滤波器件的选型及测试方法 2.1滤波器件的选型 2.2滤波器件的测试方法 三、表格总结 一、功率器件的选型及测试方法 1.1功率器件的选型 在电子电路设计中,功率器件的选择是…

c++异常处理(c++11版)与智能指针 SmartPtr 的应用(主讲shared_ptr浅实现)

引子:找C语言的异常时,你是否会被奇怪的错误码或程序终止报错而感到无奈,对找不到错误在哪的心烦?在害怕内存泄漏时,你是否每一步,每一句代码都要仔细分析,在用完之后进行资源空间的释放&#x…

用powermock编写单元测试

1、pom文件引入jar包 <!-- 单元测试 start --> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope> </dependency> <dependency>&…

Jenkins服务安装配置

什么是 Jenkins Jenkins 是一个开源的自动化服务器&#xff0c;主要用于持续集成&#xff08;CI&#xff0c;Continuous Integration&#xff09;和持续交付/持续部署&#xff08;CD/CD&#xff0c;Continuous Delivery/Continuous Deployment&#xff09;。它帮助开发团队自动…

48.【C语言】结构体补充

承接20.【C语言】初识结构体&#xff08;重要&#xff09;中的结构体成员的访问 目录&#xff1a; 1.结构体创建 2.利用函数控制结构体 3.使用“结构体指针变量-->结构体成员变量”来修改结构体的数据 4.传值还是传址&#xff1f; 1.结构体创建 依据第20篇&#xff0c;可以…

Kafka【三】Windows下安装Kafka集群

前文&#xff1a;Kafka【一】Windows下安装单节点Kafka ① 安装ZooKeeper 在D:\创建文件夹kafkacluster&#xff0c;将kafka安装包kafka_2.12-3.6.1解压缩到kafka文件夹。 修改文件夹名为kafkazk 因为kafka内置了ZooKeeper软件&#xff0c;所以此处将解压缩的文件作为ZooKeep…

Ubuntu20.04可以同时安装ROS(Noetic)和ROS2(Humble)

Ubuntu系统确实可以同时安装ROS&#xff08;Robot Operating System&#xff09;和ROS2&#xff0c;但需要注意一些关键步骤和配置以确保两者能够顺利共存并独立运行。以下是在Ubuntu上同时安装ROS和ROS2的详细步骤和注意事项&#xff1a; 安装前准备 检查Ubuntu版本&#xff…

CTFHub SSRF靶场通关攻略

内网访问 首先进入环境 在url后面输入 http://127.0.0.1/flag.php访问&#xff0c;得出flag 伪协议读取文件 进入环境后再url后面拼接 file:///var/www/html/flag.php 访问后是&#xff1f;&#xff1f;&#xff1f;&#xff0c;那么我们F12检查源码得出flag 端口扫描 我们进行…

【苍穹外卖】Day2 员工接口 分类接口

1 新增员工 1.1 设计 前端表单&#xff1a; 路径&#xff1a;/admin/employee 方法&#xff1a;POST 本项目约定: 管理端发出的请求&#xff0c;统一使用 /admin 作为前缀 用户端发出的请求&#xff0c;统一使用 /user 作为前缀 存在数据库中的实体类对象&#xff1a; pac…