MM-LLM:使用Llava类构建图文多模态大模型实践

news2024/10/6 19:13:42

在这里插入图片描述
多模态大模型的结构如上,llava是用两层MLP作为连接器。该模式也是后续很多工作的基础。

本文主要参考了https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava的工作,最初是在b站看到的,讲解的很细致。

基础模型

大语言模型:Qwen2-1.5B-Instruct
视觉模型:clip-vit-large-patch14-336
连接器:MLP
框架:llava模型

1.LLM的处理

下载模型权重到本地后,修改Qwen2-1.5B-Instruct/tokenizer_config.json的added_tokens_decoder的值,添加

"151646": {
      "content": "<image>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }

additional_special_tokens添加 "<image>"

2.初始化llava模型

# 模型权重路径
modify_qwen_tokenizer_dir = "autodl-tmp/Qwen2-1.5B-Instruct"
clip_model_name_or_path = (
    "autodl-tmp/clip-vit-large-patch14-336"
)

# 加载qwen2
qwen_tokenizer = AutoTokenizer.from_pretrained(modify_qwen_tokenizer_dir)
qwen_model = AutoModelForCausalLM.from_pretrained(
                                            modify_qwen_tokenizer_dir, 
                                            device_map='cuda:0', 
                                            torch_dtype=torch.bfloat16
                                            )


# 加载clip
clip_model = AutoModel.from_pretrained(clip_model_name_or_path, device_map="cuda:0")
processor = AutoProcessor.from_pretrained(clip_model_name_or_path)

# 将clip模型和llm_model模型的config拿出来,初始化一个llava model
# Initializing a CLIP-vision config
vision_config = clip_model.vision_model.config
# Initializing a Llama config
text_config = qwen_model.config
# Initializing a Llava llava-1.5-7b style configuration
configuration = LlavaConfig(vision_config, text_config)
# Initializing a model from the llava-1.5-7b style configuration
model = LlavaForConditionalGeneration(configuration)

输出:

LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=1536, bias=True)
    (act): GELUActivation()
    (linear_2): Linear(in_features=1536, out_features=1536, bias=True)
  )
  (language_model): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(151936, 1536)
      (layers): ModuleList(
        (0-27): 28 x Qwen2DecoderLayer(
          (self_attn): Qwen2SdpaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
            (k_proj): Linear(in_features=1536, out_features=256, bias=True)
            (v_proj): Linear(in_features=1536, out_features=256, bias=True)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm()
          (post_attention_layernorm): Qwen2RMSNorm()
        )
      )
      (norm): Qwen2RMSNorm()
    )
    (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
  )
)

这样得到了llava模型的结构,但是旧有的权重参数还没迁移过来,要将其移动到新model里。

# 权重复制
model.vision_tower.vision_model = clip_model.vision_model
model.language_model = qwen_model

然后保存到本地,注意要将autodl-tmp/processor的preprocessor_config.json复制到autodl-tmp/vlm_1

# 保存模型
model.save_pretrained("autodl-tmp/vlm_1")
qwen_tokenizer.save_pretrained("autodl-tmp/vlm_1")
processor.save_pretrained("autodl-tmp/processor")

3.数据集加载代码

采用该数据集:https://huggingface.co/datasets/OpenGVLab/ShareGPT-4o

主要代码:

class LlavaDataset(Dataset):
    def __init__(self, dataset_dir: str) -> None:
        super().__init__()

        self.chat_data, self.image_dir = self.build_dataset(dataset_dir)

    def build_dataset(self, data_dir: str) -> Tuple[List[Dict], Path]:
        # 得到对话文件和图像文件的路径
        data_dir = Path(data_dir) # 父文件夹路径
        chat_file = data_dir.joinpath("final_data.jsonl") # 对话文件
        image_dir = data_dir.joinpath("image") # 图像文件夹
        # 读取为记录,转为dict
        chat_data = pd.read_json(chat_file, lines=True).to_dict(orient="records")

        return chat_data, image_dir

    def __len__(self):
        return len(self.chat_data)

    def __getitem__(self, index) -> Tuple[str, str, Path]:
        # 根据索引定位到记录
        cur_data = self.chat_data[index] # 定位

        conversations = cur_data.get("conversations") # 字典格式获取到对话记录

        human_input = conversations[0].get("value") # 查询
        chatbot_output = conversations[1].get("value") # 回复
        image_path = self.image_dir.joinpath(cur_data.get("image")) # 图片的路径,由图片文件夹+图片名构成

        return human_input, chatbot_output, image_path

4.训练

使用deepseed训练,主要代码

def train():

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model, processor = load_model_processor(model_args)
    data_collator = TrainLLavaModelCollator(processor, -100)
    train_dataset = load_dataset(data_args)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)

5.推理

没有训练的模型进行推理的结果:

很抱歉,我无法看到或描述图片,因为我是一个文本生成模型,无法处理图像。如果您需要帮助,可以提供文字描述,我会尽力帮助您。

训练后的模型推理:

The image depicts a scene of a person sitting on a chair with their
legs crossed. The person is wearing a white shirt and dark blue jeans.
The person’s hair is styled in a messy, tousled manner, which adds to
the casual and relaxed atmosphere of the image. The person’s eyes are
closed, and they appear to be in a state of deep thought or
contemplation.

In the background, there is a small, white, rectangular object that
appears to be a piece of paper or a piece of writing. The object is
positioned in a manner that suggests it might be part of a document or
a note. The background is a light beige color, which contrasts with
the person’s clothing and the white object.

The chair is a wooden chair with a simple design, featuring a single
armrest and a backrest. The chair is positioned on a dark wooden
floor, which adds to the overall casual and comfortable feel of the
scene. The floor is also light beige, which complements the background
and the person’s clothing.

The lighting in the image is soft and diffused, giving the scene a
warm and inviting atmosphere. The person’s posture suggests they are
in a relaxed position, possibly after a long day or a moment of
reflection.

In summary, the image captures a person sitting on a chair with their
legs crossed, wearing casual clothing, and in a relaxed position. The
background includes a small white object, and the lighting is soft and
diffused, creating a warm and inviting atmosphere.

我仅仅训练了三轮,使用了不到300条数据。虽然结果不是很好,但是可以看出来是有成效的。
在这里插入图片描述

在我查找的多模态大模型实现中性价比是最高的,不用重写LLM的forward函数什么的。

相关代码放在https://github.com/stay-leave/enhance_llm。

参考:
https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava
https://github.com/OpenGVLab/InternVL/blob/main/internvl_chat
https://github.com/AviSoori1x/seemore
https://github.com/alexander-moore/vlm
https://github.com/WatchTower-Liu/VLM-learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1886971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【BES2500x系列 -- RTX5操作系统】深入探索CMSIS-RTOS RTX -- 同步与通信篇 -- 消息队列和邮箱处理 --(四)

&#x1f48c; 所属专栏&#xff1a;【BES2500x系列】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f49…

容器内存

一、容器内存概述 容器本质上还是一个进程&#xff0c;是一个被隔离和限制的进程。因此容器内存和进程内存在表现形式上其实是一样的&#xff0c;这块主要涉及三部分内容&#xff1a;RSS&#xff0c;page cache和swap这三部分&#xff0c;容器基于memory Cgroup对内存进行限制…

k8s部署单机版mysql8

一、创建命名空间 # cat mysql8-namespace.yaml apiVersion: v1 kind: Namespace metadata:name: mysql8labels:name: mysql8# kubectl apply -f mysql8-namespace.yaml namespace/mysql8 created# kubectl get ns|grep mysql8 mysql8 Active 8s二、创建mysql配…

某网页gpt的JS逆向

原网页网址 (base64) 在线解码 aHR0cHM6Ly9jbGF1ZGUzLmZyZWUyZ3B0Lnh5ei8 逆向效果图 调用代码&#xff08;复制即用&#xff09; 把倒数第三行换成下面的base64解码 aHR0cHM6Ly9jbGF1ZGUzLmZyZWUyZ3B0Lnh5ei9hcGkvZ2VuZXJhdGU import hashlib import time import reques…

Python学习篇:PyCharm的基本使用教程(二)

目录 1 前言 2 创建Python项目 3 创建Python文件 4 编写 Hello World 并运行 5 PyCharm界面简介 1 前言 PyCharm的使用贯穿整个Python的学习&#xff0c;所以单独拿出来出教程不合适&#xff0c;说多了对于新手来说也还是不明白&#xff0c;这里我们先从学习开始前大家需…

【仪器仪表】 矢量网络分析仪 Vector Network Analyzer

主要功能&#xff1a; 测量S参数&#xff1a; S11&#xff08;输入反射系数&#xff09;&#xff1a;测量输入端口的反射。S21&#xff08;正向传输系数&#xff09;&#xff1a;测量从输入端口到输出端口的传输。S12&#xff08;反向传输系数&#xff09;&#xff1a;测量从输…

【后端面试题】【中间件】【NoSQL】MongoDB的配置服务器、复制机制、写入语义和面试准备

MongoDB的配置服务器 引入了分片机制之后&#xff0c;MongoDB启用了配置服务器(config server) 来存储元数据&#xff0c;这些元数据包括分片信息、权限控制信息&#xff0c;用来控制分布式锁。其中分片信息还会被负责执行查询mongos使用。 MongoDB的配置服务器有一个很大的优…

全网小视频去水印接口使用说明

一、请求地址&#xff1a; https://www.lytcreate.com/api/qsy/ 二、请求方式&#xff1a;POST 三、请求体&#xff1a;JSON body {"token": "个人中心的token","url": "视频分享地址"} token获取地址&#xff0c;访问&#xff…

DP:子序列问题

文章目录 什么是子序列子序列的特点举例说明常见问题 关于子序列问题的几个例题1.最长递增子序列2.摆动序列3.最长递增子序列的个数4.最长数对链5.最长定差子序列 总结 什么是子序列 在计算机科学和数学中&#xff0c;子序列&#xff08;Subsequence&#xff09;是指从一个序列…

继承QAbstractListModel,结合QListView

这里想要写一个QAbstractListModel的子类&#xff0c;学习一下如何实例化QAbstractListModel。 QAbstractListModel子类化-CSDN博客 QVariant与自定义类型互转之奇巧淫技_qt 类型转 qvariant-CSDN博客 #pragma once#include <QStyledItemDelegate> #include <qmeta…

012-GeoGebra基础篇-构造圆的切线

前边文章对于基础内容已经悉数覆盖了&#xff0c;这一篇我就不放具体的细节&#xff0c;若有需要可以复刻一下 目录 一、成品展示二、算式内容三、正确性检查五、文章最后 一、成品展示 二、算式内容 A(0,0) B(3,0) c: Circle(A,B) C(5,4) sSegment(A,C) DMidpoint(s) d: Circ…

javaEE——Servlet

1.web开发概述 所谓web开发,指的是从网页中向后端程序发送请求,与后端程序进行交互 2.java后端开发环境搭建 web后端(javaEE)程序需要运行在服务器中的&#xff0c;这样前端才可以访问得到 3.服务器是什么&#xff1f; ①服务器就是一款软件&#xff0c;可以向其发送请求&#…

分解+降维+预测!多重创新!直接写核心!EMD-KPCA-Transformer多变量时间序列光伏功率预测

分解降维预测&#xff01;多重创新&#xff01;直接写核心&#xff01;EMD-KPCA-Transformer多变量时间序列光伏功率预测 目录 分解降维预测&#xff01;多重创新&#xff01;直接写核心&#xff01;EMD-KPCA-Transformer多变量时间序列光伏功率预测效果一览基本介绍程序设计参…

嵌入式学习——硬件(Linux系统在2440上的启动)——day57

1. Linux2.6系统在s3c2440上的启动过程分三个阶段 1.1 启动u-boot 1.2 启动Linux内核 1.3 挂载根文件系统 2. bootloader 2.1 定义 bootloader的本质是一个裸机程序&#xff0c;bootlood专门是为了能够正确地启动linux操作系 统&#xff0c;在系统初上电时需要对系统做一些…

中霖教育怎么样?咨询工程师备考技巧

中霖教育怎么样?咨询工程师备考技巧 在备考咨询工程师的过程中&#xff0c;掌握正确的方式方法能够少走很多弯路&#xff0c;所以想取得好成绩采用恰当的备考技巧是非常重要的。 1、了解题型及考试结构 在准备阶段&#xff0c;理解各类型题目的特征和作答要求&#xff0c;确…

github仓库的基本使用-创建、上传文件、删除

1.第一步 先点击左侧菜单栏的远程仓库 2.点击NEW 3.创建仓库 然后点击右下角的 CREATE 4.点击code 点击SSH,然后我出现了You don’t have any public SSH keys in your GitHub account. You can add a new public key, or try cloning this repository via HTTPS. 1&#xff…

【JavaEE精炼宝库】多线程进阶(2)synchronized原理、JUC类——深度理解多线程编程

一、synchronized 原理 1.1 基本特点&#xff1a; 结合上面的锁策略&#xff0c;我们就可以总结出&#xff0c;synchronized 具有以下特性(只考虑 JDK 1.8)&#xff1a; 开始时是乐观锁&#xff0c;如果锁冲突频繁&#xff0c;就转换为悲观锁。 开始是轻量级锁实现&#xff…

成人职场商务英语学习柯桥外语学校|邮件中的“备注”用英语怎么说?

在英语中&#xff0c;"备注"通常可以翻译为"Notes" 或 "Remarks"。 这两个词在邮件中都很常用。例如: 1. Notes Notes: 是最通用和最常见的表达&#xff0c;可以用在各种情况下&#xff0c;例如&#xff1a; 提供有关电子邮件内容的附加信息 列…

Mysql并发控制和日志

文章目录 一、并发控制锁机制事务&#xff08;transactions&#xff09;事务隔离级别 二、日志事务日志错误日志通用日志慢查询日志二进制日志 备份在线查看二进制离线查看二进制日志 一、并发控制 锁机制 锁类型&#xff1a; 读锁&#xff1a;共享锁&#xff0c;也称为 S 锁…

ANSYS新能源汽车动力电池仿真应用案例

燃料电池是一种非燃烧过程的电化学能转换装置&#xff0c;将氢气&#xff08;等燃料&#xff09;和氧气的化学能连续不断地转换为电能&#xff0c;是发电设备而非储能设备。 根据电解质的不同&#xff0c;分为碱性燃料电池AFC、磷酸燃料电池PAFC、熔融碳酸盐燃料电池MCFC、固体…