[pai-diffusion]pai的easynlp的clip模型训练

news2024/11/25 23:43:13

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态理解的重要任务,…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/528476134

initialize_easynlp()->

train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/clip_chinese_roberta_base_vit_base"),
    data_file="MUGE_MR_train_base64_part.tsv",
    max_seq_length=32,
    input_schema="text:str:1,image:str:1",
    first_sequence="text",
    second_sequence="image",
    is_training=True)
valid_dataset = CLIPDataset()

model = get_application_model(app_name='clip',...)
- easynlp.appzoo.api.ModelMapping->CLIPApp
- easynlp.appzoo.clip.model.py->CLIPApp
- CHINESE_CLIP->
- self.visual = VisualTransformer()
- self.bert = BertModel()

trainer = Trainer(model,train_dataset,user_defined_parameters,  
                evaluator=get_application_evaluator(app_name="clip",valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))

trainer.train()
- for _epoch in range(self._first_epoch,int(args.epoch_num)):
      for _step,batch in enumerate(self._train_loader):    
          label_ids = batch.pop()
          forward_outputs = self._model(batch)
          loss_dict = self.model_module.compute_loss(forward_outputs,label_ids)
          _loss = loss_dict('loss')
          
          _loss.backward()

model = get_application_model_evaluation()
evaluator = get_application_evaluator()
evaluator.evaluate(model)

数据处理:

import os
import base64
import multiprocessing
from tqdm import tqdm


def process_image(image_path):
    # 从图片路径中提取中文描述
    image_name = os.path.basename(image_path)
    description = os.path.splitext(image_name)[0]

    # 将图片转换为 Base64 编码
    with open(image_path, 'rb') as f:
        image_data = f.read()
        base64_data = base64.b64encode(image_data).decode('utf-8')

    return description, base64_data


def generate_tsv(directory):
    image_paths = [os.path.join(directory, filename) for filename in os.listdir(directory) if
                   filename.endswith(('.jpg', '.png'))]

    with multiprocessing.Pool() as pool, tqdm(total=len(image_paths), desc='Processing Images') as pbar:
        results = []
        for result in pool.imap_unordered(process_image, image_paths):
            results.append(result)
            pbar.update(1)

    with open(
            '/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train.tsv',
            'w', encoding='utf-8') as f:
        for description, base64_data in results:
            line = f"{description}\t{base64_data}\n"
            f.write(line)


if __name__ == '__main__':
    target_directory = "/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train/img_download/"
    # import pdb;pdb.set_trace()
    generate_tsv(target_directory)

训练代码:

import torch.cuda
from easynlp.appzoo import CLIPDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, \
    get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parameters


def main():
    # /root/.easynlp/modelzoo中
    train_dataset = CLIPDataset(
        pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),
        data_file=args.tables.split(",")[0],
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        second_sequence=args.second_sequence,
        is_training=True)

    valid_dataset = CLIPDataset(
        # 预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"alibaba-pai/clip_chinese_roberta_base_vit_base"以得到其路径,并自动下载模型
        pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),
        data_file=args.tables.split(",")[-1],
        # "data/pai/MUGE_MR_valid_base64_part.tsv"
        max_seq_length=args.sequence_length,  # 文本最大长度,超过将截断,不足将padding
        input_schema=args.input_schema,  # 输入tsv数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等
        first_sequence=args.first_sequence,  # 用于说明input_schema中哪些字段作为第一/第二列输入数据
        second_sequence=args.second_sequence,
        is_training=False)  # 是否为训练过程,train_dataset为True,valid_dataset为False

    model = get_application_model(
        app_name=args.app_name,  # 任务名称,这里选择文本分类"clip"
        pretrained_model_name_or_path=get_pretrain_model_path(
            args.pretrained_model_name_or_path),
        user_defined_parameters=user_defined_parameters
        # user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters
    )

    trainer = Trainer(model=model,
                      train_dataset=train_dataset,
                      user_defined_parameters=user_defined_parameters,
                      evaluator=get_application_evaluator(app_name=args.app_name,
                                                          valid_dataset=valid_dataset,
                                                          user_defined_parameters=user_defined_parameters,
                                                          eval_batch_size=32))
    trainer.train()

    # 模型评估
    model = get_application_model_for_evaluation(app_name=args.app_name,
                                                 pretrained_model_name_or_path=args.checkpoint_dir,
                                                 user_defined_parameters=user_defined_parameters)

    evaluator = get_application_evaluator(app_name=args.app_name,
                                          valid_dataset=valid_dataset,
                                          user_defined_parameters=user_defined_parameters,
                                          eval_batch_size=32)
    model.to(torch.cuda.current_device())
    evaluator.evaluate(model=model)

    # 模型预测
    if test:
        predictor = get_application_predictor(app_name="clip",
                                              model_dir="./outputs/clip_model/",
                                              first_sequence="text",
                                              second_sequence="image",
                                              sequence_length=32,
                                              user_defined_parameters=user_defined_parameters)

        predictor_manager = PredictorManager(predictor=predictor,
                                             input_file="data/vcg_furnitures_text_image/vcg_furnitures_test.tsv",
                                             input_schema="text:str:1",
                                             output_file="text_feat.tsv",
                                             output_schema="text_feat",
                                             append_cols="text",
                                             batch_size=2)
        predictor_manager.run()


if __name__ == "__main__":
    initialize_easynlp()
    args = get_args()
    user_defined_parameters = parse_user_defined_parameters(
        'pretrain_model_name_or_path=alibaba-pai/clip_chinese_roberta_base_vit_base')
    args.checkpoint_dir = "./outputs/clip_model/"
    args.pretrained_model_name_or_path = "alibaba-pai/clip_chinese_roberta_base_vit_base"
    # args.n_gpu = 3
    # args.worker_gpu = "1,2,3"
    args.app_name = "clip"
    args.tables = "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"
    # "data/vcg_furnitures_text_image/vcg_furnitures_train.tsv," \
    #               "data/vcg_furnitures_text_image/vcg_furnitures_test.tsv"
    # "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"
    args.input_schema = "text:str:1,image:str:1"
    args.first_sequence = "text"
    args.second_sequence = "image"
    args.learning_rate = 1e-4
    args.epoch_num = 1000
    args.random_seed = 42
    args.save_checkpoint_steps = 200
    args.sequence_length = 32
    # args.train_batch_size = 2
    args.micro_batch_size = 32

    test = False

    main()

# python -m torch.distributed.launch --nproc_per_node 4 tools/train_pai_chinese_clip.py


说一点自己的想法,在我自己工作之初,我很喜欢去拆解一些框架,例如openmm系列,但其实大部分在训练过程上都是相似的,大可不必,在改动上,也没有必要对其进行流程上的大改动,兼具百家之长,了解整体pipeline,更加专注在pipeline实现和效果导向型的结果提交更加有效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1029363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

人脸识别技术应用安全管理规定(试行)|企业采用人脸打卡方式,这4条规定值得关注

近日,为规范人脸识别技术应用,国家互联网信息办公室起草了,并向全社会公开征求意见。该规定一共列举了25条,企业如借助人脸识别技术采集考勤打卡数据,以下4条规定值得关注。 第四条 只有在具有特定的目的和充分的必要…

【前端知识】Three 学习日志(四)—— 相机控件

Three 学习日志&#xff08;四&#xff09;—— 相机控件 一、引入相机控件 <!-- 引入相机控件 --> <script type"importmap">{"imports": {"three": "../build/three.module.js","three/addons/": "../…

idea中提示:error has occurred, please check your installation and try again

目录 报错原因解决总结 报错 idea中提示&#xff1a;error has occurred, please check your installation and try again 原因 1.起初我是把一个运行正常的java程序&#xff0c;放到了src下&#xff0c;新建的一个包&#xff08;包名为java.first&#xff09;中&#xff0c…

torch其他层和联合使用

recurrent layers一般是特定的结构&#xff0c;在语音识别和创作用的比较多&#xff0c;又RNN,LSTM,GRU一些东西。 transform 层nlp常用&#xff0c;在cv领域表现得很不错 线性层&#xff0c;infeature和outfeature还有一个偏置 dropout层&#xff0c;是为了防止过拟合&…

基于ssm扶贫产品和扶贫物资捐赠系统033

大家好✌&#xff01;我是CZ淡陌。一名专注以理论为基础实战为主的技术博主&#xff0c;将再这里为大家分享优质的实战项目&#xff0c;本人在Java毕业设计领域有多年的经验&#xff0c;陆续会更新更多优质的Java实战项目&#xff0c;希望你能有所收获&#xff0c;少走一些弯路…

软考考试多少分算通过?

软考证书取得需要达到总分45分&#xff0c;每门科目满分为75分。因此&#xff0c;不要小看45分&#xff0c;在考试中获得这个分数并不容易。此外&#xff0c;软考要求一次性通过&#xff0c;如果没有通过&#xff0c;成绩将不被保留。因此&#xff0c;必须在一次考试中成功通过…

改写paddledetection为cmake版(c++)

下载源代码 官方地址&#xff1a; https://gitee.com/paddlepaddle/PaddleDetection 网盘&#xff1a; paddledetection 链接&#xff1a;https://pan.baidu.com/s/1g0z5SYQNDR1pwe9iAtvR3A?pwdktl6 提取码&#xff1a;ktl6 paddleocr 链接&#xff1a;https://pan.baidu.c…

不理解路径问题的大坑记录

./表示当前目录 当前所在的目录 一直写的是…/老是访问不到 就像着人家组件有什么问题 ./了一下成功了 果然 有句话说的真的很棒 不报错才是最可怕的 谁知道你的错误是什么

No servers available for service: renren…。 Gateway 网关报503错误 ,已解决

目录 环境配置问题描述loadbalancer的作用 环境配置 问题描述 配置spring cloud gateway使用端口访问就可以&#xff0c;使用lb:// 就报503 gateway:routes:- id: admin_routeuri: lb://gulimall-admin # uri: http://localhost:8080predicates:- Path/api/**filter…

Start 方法源码深究——模板方法设计模式

目录 一. &#x1f981; 前言1.1 New状态1.2 Runnable1.3 Runing1.4 Block状态1.5 Terminated状态 二. &#x1f981; 线程 start 方法源码剖析2.1 虚拟机调用run方法执行线程2.2 最少有两个线程在执行2. 3 不可以重复执行2.4 start方法体 三. &#x1f981; 模板方法设计模式3…

李沐深度学习记录1:零碎知识记录、08线性回归

简要记录&#xff0c;以便查阅~ 一、零碎知识 x.numel()&#xff1a;看向量或矩阵里元素个数 A.sum()&#xff1a;向量或矩阵求和&#xff0c;axis参数可对某维度求和&#xff0c;keepdims参数设置是否保持维度不变 A.cumsum&#xff1a;axis参数设置沿某一维度计算矩阵累计和…

05_Bootstrap插件02

7 小标签 通过 .label 实现小标签&#xff0c;用于提示类。 <h1>h1标题 <span class"label label-default">标签</span></h1> <h2>h2标题<span class"label label-default">标签</span></h2> <h3&g…

精品Python思政素材数据库在线学习资源网

《[含文档PPT源码等]精品基于Python实现的思政素材数据库设计与实现》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程等 软件开发环境及开发工具&#xff1a; 开发语言&#xff1a;python 使用框架&#xff1a;Django 前端技术&#xff1a;JavaScri…

Linux:GlusterFS 集群

GlusterFS介绍 1&#xff09;Glusterfs是一个开源的分布式文件系统,是Scale存储的核心,能够处理千数量级的客户端.在传统的解决 方案中Glusterfs能够灵活的结合物理的,虚拟的和云资源去体现高可用和企业级的性能存储. 2&#xff09;Glusterfs通过TCP/IP或InfiniBand RDMA网络链…

2023年9月21日

完善登录界面的注册登录功能 头文件1 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QPushButton> #include <QLineEdit> #include <QLabel> #include <QMovie> #include <QDebug> #include <QMessage…

【计算机网络】深入理解TCP协议二(连接管理机制、WAIT_TIME、滑动窗口、流量控制、拥塞控制)

TCP协议 1.连接管理机制2.再谈WAIT_TIME状态2.1理解WAIT_TIME状态2.2解决TIME_WAIT状态引起的bind失败的方法2.3监听套接字listen第二个参数介绍 3.滑动窗口3.1介绍3.2丢包情况分析 4.流量控制5.拥塞控制5.1介绍5.2慢启动 6.捎带应答、延时应答 1.连接管理机制 正常情况下&…

记一次 .NET 某餐饮小程序 内存暴涨分析

一&#xff1a;背景 1. 讲故事 前些天有位朋友找到我&#xff0c;说他的程序内存异常高&#xff0c;用 vs诊断工具 加载时间又太久&#xff0c;让我帮忙看一下到底咋回事&#xff0c;截图如下&#xff1a; 确实&#xff0c;如果dump文件超过 10G 之后&#xff0c;市面上那些可…

ESP8266 WiFi物联网智能插座—项目简介

目录 1、项目背景 2、设备节点功能 3、上位机功能 物联网虽然能够使家居设备和系统实现自动化、智能化管理&#xff0c;但是依然需要依靠更为先进的终端插座作为根本保障&#xff0c;插座是所有家用电器需要使用的电源设备&#xff0c;插座的有序智能管理&#xff0c;对于实…

SpringMVC初级

文章目录 一、SpringMVC 概述二、springMVC步骤1、新建maven的web项目2、导入maven依赖3、创建controller4、创建spring-mvc.xml配置文件&#xff08;本质就是spring的配置件&#xff09;5、web.xml中配置前端控制器6、新建a.jsp文件7、配置tomcat8、启动测试 三、工作流程分析…

echart在折线显示横纵(横纵线沿着折线展示)

产品有个需求&#xff0c;需要在echart折线上展示横纵向坐标系&#xff0c;echart的axisPointer默认是展示在鼠标当前位置的&#xff0c;不符合需求&#xff0c;所以是使用markline实现的 在线例子和源码 先上效果图 实现思路 横纵线的x轴线是比较容易的&#xff0c;因为ech…