深度学习总结——用自己的数据集微调CLIP

news2025/1/18 4:40:33

CLIP概述

CLIP(Contrastive Language-Image Pretraining)是由OpenAI开发的一种深度学习模型,用于将图像和自然语言文本进行联合编码。它采用了多模态学习的方法,使得模型能够理解图像和文本之间的语义关系。

它的核心思想是将图像和文本视为同等重要的输入,并通过联合训练来学习它们之间的联系。CLIP模型使用了一个共享的编码器,它将图像和文本分别映射到一个共享的特征空间中。通过将图像和文本的编码向量进行比较,模型能够判断它们之间的相似性和相关性。

它在训练过程中使用了对比损失函数,以鼓励模型将相关的图像和文本对编码得更接近,而将不相关的图像和文本对编码得更远。这使得CLIP模型能够具有良好的泛化能力,能够在训练过程中学习到通用的图像和文本理解能力。

它的整体流程如下:
在这里插入图片描述

它展现了强大的zero-shot能力,在许多视觉与语言任务中表现出色,如图像分类、图像生成描述、图像问答等。它的多模态能力使得CLIP模型能够在图像和文本之间建立强大的语义联系,为各种应用场景提供了更全面的理解和分析能力。

正是因为它出色的zero-shot能力,因此训练的模型本身就含有很多可以利用的知识,因此在一些任务上,如分类任务,caption任务,可以尝试在自己的数据集上微调CLIP,或许通过这种操作就能获得不错的性能。但是目前如何微调CLIP网上并没有看到很详细的介绍,因此我整理了相关的知识并在此记录。
参考链接

微调代码

第三方库

  • clip-by-openai
  • torch

下面以我做的图像分类任务为例,介绍相关的步骤。

步骤介绍

1.构建数据集

构建自己的数据集,每次迭代返回的数据包括:RGB图像和图像的标签(a photo of {label})
代码示例如下:

import os
from PIL import Image
import numpy as np
import clip
class YourDataset(Dataset):
    def __init__(self,img_root,meta_root,is_train,preprocess):
        # 1.根目录(根据自己的情况更改)
        self.img_root = img_root
        self.meta_root = meta_root
        # 2.训练图片和测试图片地址(根据自己的情况更改)
        self.train_set_file = os.path.join(meta_root,'train.txt')
        self.test_set_file = os.path.join(meta_root,'test.txt')
        # 3.训练 or 测试(根据自己的情况更改)
        self.is_train = is_train
        # 4.处理图像
        self.img_process = preprocess
        # 5.获得数据(根据自己的情况更改)
        self.samples = []
        self.sam_labels = []
        # 5.1 训练还是测试数据集
        self.read_file = ""
        if is_train:
            self.read_file = self.train_set_file
        else:
            self.read_file = self.test_set_file
		# 5.2 获得所有的样本(根据自己的情况更改)
        with open(self.read_file,'r') as f:
            for line in f:
                img_path = os.path.join(self.img_root,line.strip() + '.jpg')
                label = line.strip().split('/')[0]
                label = label.replace("_"," ")
                label = "a photo of " + label
                self.samples.append(img_path)
                self.sam_labels.append(label)
        # 转换为token
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        token = self.tokens[idx]
        # 加载图像
        image = Image.open(img_path).convert('RGB')
        # 对图像进行转换
        image = self.img_process(image)
        return image,token

2.加载预训练CLIP模型和相关配置

首先使用第三方库加载预训练的CLIP模型,会返回一个CLIP模型和一个图像预处理函数preprocess,这将用于之后的数据加载过程。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("RN50",device=device,jit=False)

然后初始化优化器,损失函数,需要注意的是,如果刚开始你的损失很大或者出现异常,可以调整优化器的学习率和其他参数来进行调整,通常是调整的更小会有效果。

optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.1)

# 创建损失函数
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

3.加载数据

该步骤主要是调用第一步中创建的类,然后使用DataLoader函数加载自己的数据集。
代码如下:

your_dataset = YourDataset(img_root= '/images',
                                          meta_root= '/meta',
                                          is_train=True,preprocess=preprocess)
dataset_size_your = len(your_dataset)
your_dataloader = DataLoader(your_dataset,batch_size=4,shuffle=True,num_workers=4,pin_memory=False)

4.开始训练

训练代码按照模板来写即可,总共要训练epoches次,每次要将一个数据集里面的所有数据都训练一次,然后在每次训练完成的时候保存模型,这里分为两种:

  • 保存模型的参数
  • 保存模型的参数、优化器、迭代次数

该部分的代码如下:

phase = "train"
model_name = "your model name"
ckt_gap = 4
epoches = 30
for epoch in range(epoches):
    scheduler.step()
    total_loss = 0
    batch_num = 0
    # 使用混合精度,占用显存更小
    with torch.cuda.amp.autocast(enabled=True):
        for images,label_tokens in your_dataloader:
            # 将图片和标签token转移到device设备
            images = images.to(device)
            label_tokens = label_tokens.to(device)
            batch_num += 1
            # 优化器梯度清零
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                logits_per_image, logits_per_text = net(images, label_tokens)
                ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
                cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
                total_loss += cur_loss
                if phase == "train":
                    cur_loss.backward()
                    if device == "cpu":
                        optimizer.step()
                    else:
                        optimizer.step()
                        clip.model.convert_weights(net) 
            if batch_num % 4 == 0:
                logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))
        epoch_loss = total_loss / dataset_size_your
        torch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")
        logger.info(f"weights_{epoch} saved")
        if epoch % ckt_gap == 0:
            checkpoint_path = f"{model_name}_ckt.pth"
            checkpoint = {
                'it': epoch,
                'network': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"checkpoint_{epoch} saved")
        logger.info('{} Loss: {:.4f}'.format(
            phase, epoch_loss))

全部代码

import os
from PIL import Image
import numpy as np
import clip
from loguru import logger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn

class YourDataset(Dataset):
    def __init__(self,img_root,meta_root,is_train,preprocess):
        # 1.根目录(根据自己的情况更改)
        self.img_root = img_root
        self.meta_root = meta_root
        # 2.训练图片和测试图片地址(根据自己的情况更改)
        self.train_set_file = os.path.join(meta_root,'train.txt')
        self.test_set_file = os.path.join(meta_root,'test.txt')
        # 3.训练 or 测试(根据自己的情况更改)
        self.is_train = is_train
        # 4.处理图像
        self.img_process = preprocess
        # 5.获得数据(根据自己的情况更改)
        self.samples = []
        self.sam_labels = []
        # 5.1 训练还是测试数据集
        self.read_file = ""
        if is_train:
            self.read_file = self.train_set_file
        else:
            self.read_file = self.test_set_file
		# 5.2 获得所有的样本(根据自己的情况更改)
        with open(self.read_file,'r') as f:
            for line in f:
                img_path = os.path.join(self.img_root,line.strip() + '.jpg')
                label = line.strip().split('/')[0]
                label = label.replace("_"," ")
                label = "photo if " + label
                self.samples.append(img_path)
                self.sam_labels.append(label)
        # 转换为token
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path = self.samples[idx]
        token = self.tokens[idx]
        # 加载图像
        image = Image.open(img_path).convert('RGB')
        # 对图像进行转换
        image = self.img_process(image)
        return image,token
# 创建模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("RN50",device=device,jit=False)

optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.1)

# 创建损失函数
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# 加载数据集
your_dataset = YourDataset(img_root= '/images',
                                          meta_root= '/meta',
                                          is_train=True,preprocess=preprocess)
dataset_size_your = len(your_dataset)
your_dataloader = DataLoader(your_dataset,batch_size=4,shuffle=True,num_workers=4,pin_memory=False)

phase = "train"
model_name = "your model name"
ckt_gap = 4
for epoch in range(st,args.epoches):
    scheduler.step()
    total_loss = 0
    batch_num = 0
    # 使用混合精度,占用显存更小
    with torch.cuda.amp.autocast(enabled=True):
        for images,label_tokens in your_dataloader:
            # 将图片和标签token转移到device设备
            images = images.to(device)
            label_tokens = label_tokens.to(device)
            batch_num += 1
            # 优化器梯度清零
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                logits_per_image, logits_per_text = net(images, label_tokens)
                ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
                cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
                total_loss += cur_loss
                if phase == "train":
                    cur_loss.backward()
                    if device == "cpu":
                        optimizer.step()
                    else:
                        optimizer.step()
                        clip.model.convert_weights(net) 
            if batch_num % 4 == 0:
                logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))
        epoch_loss = total_loss / dataset_size_food101
        torch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")
        logger.info(f"weights_{epoch} saved")
        if epoch % ckt_gap == 0:
            checkpoint_path = f"{model_name}_ckt.pth"
            checkpoint = {
                'it': epoch,
                'network': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logger.info(f"checkpoint_{epoch} saved")
        logger.info('{} Loss: {:.4f}'.format(
            phase, epoch_loss))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/606866.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

皮卡丘SQL注入汇总

1.Sql Inject(SQL注入)概述 SQL注入漏洞主要形成的原因是在数据交互中,前端的数据传入到后台处理时,没有做严格的判断,导致其传入的“数据”拼接到SQL语句中后,被当作SQL语句的一部分执行。 从而导致数据库受损(被脱裤…

循环控制语句

文章目录 1.break——跳出循环1.1作用 2.continue——控制循环2.1作用 3.猜数字4.while语句4.1while语句的结构4.2算1-10求和:4.3使用while方式批量添加5个用户给这五个用户添加密码: 5.until语句6.拓展6.1 购物6.2shell 计算器实现加减乘除和取余6.3打印…

chatgpt赋能python:Python反转输出正整数-让计算更简单

Python反转输出正整数-让计算更简单 Python是一种高级编程语言,除了可以完成各种任务,还可以反转输出正整数。在本篇SEO文章中,我将介绍如何使用Python编程语言反转输出正整数,并且展现了这个方法是如何简化计算。 什么是Python…

chatgpt赋能python:Python列表指定元素的取出方法

Python列表指定元素的取出方法 在Python编程中,经常需要取出列表中的指定元素。本文将介绍几种常用的取出列表指定元素的方法。 1. 使用索引 列表中的元素可以通过索引来进行访问和修改,索引从0开始。下面的示例展示了如何使用索引来取出列表中的指定…

六、docker安装ngxin部署若以前端

1.第一次安装,不进行挂载数据卷, docker run \ -p 8060:80 \ --name nginx \ --privilegedtrue \ --restartalways \ -d nginx:1.17.82. 将配置信息复制到宿主机本地 # 将容器nginx.conf文件复制到宿主机 docker cp nginx:/etc/nginx/nginx.conf /data…

总投资300亿,南山前海南山村旧改城市更新

南山村 项目位于南山区南山街道南山村旧村片区,东临南新路,南临东滨路,西临前海路,北临南园村。地处联系前海、后海两大中心区的空间发展轴带上,区位交通条件优越。位于9号线延长线前海路站附近,也因地处大…

背包问题总结篇

背包问题总结篇 关于这几种常见的背包,其关系如下: 通过这个图,可以很清晰分清这几种常见背包之间的关系。 在讲解背包问题的时候,我们都是按照如下五部来逐步分析,相信大家也体会到,把这五部都搞透了&…

【C++刷题】【动态规划篇】(一)

动态规划篇(一) 一、1137. 第 N 个泰波那契数(easy)二、三步问题(easy)三、使用最小花费爬楼梯(easy)四、解码方法(medium)五、不同路径(medium&a…

如何利用地面控制点实现倾斜摄影三维模型数据的几何坐标变换和纠正?

如何利用地面控制点实现倾斜摄影三维模型数据的几何坐标变换和纠正? 倾斜摄影是一种在空中拍摄地表物体的技术,可以获得高分辨率、高精度的三维模型数据,广泛应用于城市规划、建筑设计、土地管理等领域。然而,由于航拍时无法避免姿…

ClassLoader源码

介绍 ClassLoader 顾名思义就是类加载器 ClassLoader 是一个抽象类 没有父类 作用 1.负责将 Class 加载到 JVM 中 2.审查每个类由谁加载(父优先的等级加载机制) 3.将 Class 字节码重新解析成 JVM 统一要求的对象格式 常量&变量 //注册本地方法…

chatgpt赋能python:Python实现动态排名:在SEO游戏中的使用

Python实现动态排名:在SEO游戏中的使用 搜索引擎优化(SEO)是一项必不可少的活动,可以提高网站在搜索结果中的排名和流量。其中之一是动态排名,它可以根据网站相应信息的变化而自动更新排名,使网站始终保持…

chatgpt赋能python:Python技巧:如何用Python去除文本中的头和尾

Python技巧:如何用Python去除文本中的头和尾 在任何文本处理任务中,去除文本数据的头和尾是非常常见的需求。这在搜索引擎优化(SEO)中尤其重要,因为头和尾中可能包含重复的内容,这会降低网页的排名。在这篇…

翻筋斗觅食海鸥优化算法-附代码

翻筋斗觅食海鸥优化算法 文章目录 翻筋斗觅食海鸥优化算法1.海鸥优化算法2. 改进海鸥优化算法2.1 非线性参数 A 策略2.2 翻筋斗觅食策略 3.实验结果4.参考文献5.Matlab代码6.python代码 摘要:针对基本海鸥优化算法(SOA)在处理复杂优化问题中存在低精度、…

Linux会替代Windows吗?

Windows用户们,去还是留? Windows 依然是高居榜首的桌面操作系统,占据 90% 以上的市场份额,远超 macOS 和 Linux 。 从数据来看,尽管 linux 并不是 Windows 的头号接班人,但近几年越来越多用户转向 Ubunt…

Vue嵌套表单的 Dialog精美模板分享

文章目录 🐒个人主页🏅Vue项目常用组件模板仓库📖前言:🎀源码如下: 🐒个人主页 🏅Vue项目常用组件模板仓库 📖前言: 本篇博客主要提供vue组件之嵌套表单的 D…

通用权限管理系统+vue3项目实战(一)

1.创建项目 在某个工程文件夹下创建项目 npm init vuelatest各种工具选择都选是,并且安装环境node_modules之后,显示如图: 1.1 引入element-plus 在vue2的时候常用的ui框架是element-ui,在vue3时应该使用它的继承者element-p…

chatgpt赋能python:Python匹配空白字符的完整指南

Python匹配空白字符的完整指南 在Python编程中,处理文本数据是一项常见任务。当我们需要从文本中提取数据时,通常需要从字符串中匹配特定的模式。这些模式可能包括空格、制表符和换行符等空白字符。本文将介绍如何使用Python正则表达式来匹配空白字符&a…

chatgpt赋能python:Python动态Import简介

Python动态Import简介 在Python中,Import语句用于导入其他Python模块中的函数和变量。通常在Python编程中,我们使用静态Import方法来导入模块。但是,Python也支持动态Import,即在运行时根据需要导入模块中的函数和变量。 在本文…

day 39:62. 不同路径63. 不同路径 II

动态规划 [62. 不同路径](https://leetcode.cn/problems/unique-paths/description/)1. dp数组以及下标名义2. 递归公式3. dp数组如何初始化4. 遍历顺序5.代码 [63. 不同路径 II:有障碍物](https://leetcode.cn/problems/unique-paths-ii/description/)1. dp数组以及…

银行从业——法律法规——经济基础知识

第一章、经济基础知识 第一节、宏观经济分析 【 知识点1】 宏观经济发展目标 宏观经济发展的总体目标一般包括四个: 宏观经济发展的总体目标 衡量指标1、经济增长国内生产总值(GDP)2、充分就业 失业率3、物价稳定通货膨胀率4、国际…