VLM训练——Trainer源码解读

news2025/1/24 22:47:12

本文将以LLaVa源码为例,解析如何使用Trainer训练/微调一个VLM。

  • 1. 参数解析
      • ModelArguments
      • DataArguments
      • TrainingArguments
  • 2. 加载模型
  • 3. 加载数据
  • 4. 创建Trainer开始训练

在这里插入图片描述

1. 参数解析

VLM 和 LLM 相关训练框架都会引入 ModelArgumentsDataArgumentsTrainingArgumentsGeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,然后再用parse_args_into_dataclasses()方法解析成 hf 的标准形式model_args, data_args, training_args,实现了两行代码处理训练全程的参数问题。这些命令行参数会从.sh的Shell 代码文件中导入。

from typing import Optional
from dataclasses import dataclass, field
import transformers
 
 
...
 
    添加上述的 Argument Class
 
...
 
 
if __name__ == '__main__':
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
    model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()
 
    print(model_args)
    print(data_args)
    print(training_args)
    print(generate_args)

ModelArguments

ModelArguments 通常包含模型路径,以及一些架构上的参数。
在这里插入图片描述
在这里插入图片描述

DataArguments

DataArguments 通常包含 数据路径,以及一些预处理参数。
在这里插入图片描述
在这里插入图片描述

TrainingArguments

TrainingArguments 通常包含模型训练的一些必要参数,如优化器、学习率等参数。
在这里插入图片描述
在这里插入图片描述

2. 加载模型

对于我们不仅要加载 LLM 还需要加载 Image EncoderProjector,因此我们可以直接写一个VLM Model(继承transformer库中的LLM)

		model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                attn_implementation=attn_implementation,
                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
                **bnb_model_from_pretrained_args
            )

LlavaLlamaForCausalLM 继承了LLM(transformer.LlamaForCausalLM)VLM抽象类(LlavaMetaForCausalLM)LlavaLlamaForCausalLM中的Visual ModulesLlavaLlamaModel 用于加载 Image Encoder 和 Projector。其多模态forward的流程就是,先对 image 和 text 计算 embedding,然后将其多模态的 tokens 拼接在一起送入LLM。

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

另外,我们还需要加载Tokenizer,并设置其词表:

		tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )

3. 加载数据

在开始构造Trainer勋训练之前,我们还需要创建dataset和data collator:

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                data_path=data_args.data_path,
                                data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

4. 创建Trainer开始训练

    trainer = LLaVATrainer(model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    **data_module)
	trainer.train()
    trainer.save_state()

构造VLM的Trainer,继承Trainer,重写_get_train_samplercreate_optimizer_save_checkpoint_save即可。

class LLaVATrainer(Trainer):

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            if self.args.mm_projector_lr is not None:
                projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.mm_projector_lr,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "lr": self.args.mm_projector_lr,
                    },
                ]
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer

    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['mm_projector', 'vision_resampler']
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(['embed_tokens', 'embed_in'])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        else:
            super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'tune_mm_mlp_adapter', False):
            pass
        else:
            super(LLaVATrainer, self)._save(output_dir, state_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1815957.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高考志愿填报秘籍:个人篇

选择适合自己的大学和专业,对广大考生来说至关重要。从某种程度上来说,决定了考生未来所从事的行业和发展前景。为了帮助广大考生更加科学、合理地填报志愿,选择适合自己的大学和专业,本公众号将推出如何用AI填报高考志愿专栏文章…

Linux环境各种软件安装配置

安装Java 官网 找个喜欢的版本 下载好了传到linux里,xshell的xftp直接拖过去就可以传 #安装rpm包管理 yum install -y rpm or apt-get install rpm #查找Java rpm -qa | grep java\|jdk\|gcj\|jre #卸载java rpm -e --nodeps jdk-1.8-1.8.0_401-10.x86_64 #安装 …

明基的台灯值得入手吗?书客、柏曼真实横向测评对比

近年来人们在工作、学习、娱乐等方面对电子设备的依赖程度也越来越高,长时间使用电子设备会对眼睛造成一定的伤害,如眼疲劳、干涩、近视等。人们对于能够缓解眼疲劳的照明产品的需求逐渐增加。护眼台灯能够更好地模拟自然光,提供更加柔和舒适…

AD24设计步骤

一、元件库的创建 1、AD工程创建 然后创建原理图、PCB、库等文件 2、电阻容模型的创建 注意:防止管脚时设置栅格大小为100mil,防止线段等可以设置小一点,快捷键vgs设置栅格大小。 1.管脚的设置 2.元件的设置 3、IC类元件的创建 4、排针类元件模型创建…

机器学习笔记 - 用于3D数据分类、分割的Point Net简述

一、简述 在本文中,我们将了解Point Net,目前,处理图像数据的方法有很多。从传统的计算机视觉方法到使用卷积神经网络到Transformer方法,几乎任何 2D 图像应用都会有某种现有的方法。然而,当涉及到 3D 数据时,现成的工具和方法并不那么丰富。3D 空间中一个工具就是Point …

14、modbus poll 使用教程小记1

开发平台:Win10 64位 Modbus Slave版本:64位 7.0.0 Modbus Poll版本:64位 7.2.2 因为项目中经常会用到modbus协议,所以就避免不了的要使用modbus测试工具,Modbus Slave/Poll无疑是众多测试工具中应用最广泛的。 文章目…

dll文件丢失了要如何处理?教你一键修复所有dll缺失的方法

dll文件丢失了要如何处理?其实dll文件的丢失还是比较常见的,它的丢失会引起一些程序无法启动,所以我们必须要去修复dll文件,这点是毋容置疑的,其修复方法也是有很多种的,今天就来给大家详细的聊一下dll文件…

BUAA-2024年春-OO第四单元总结

正向建模与开发 在本单元中,我们需要模拟一个小型的图书管理系统,完成图书馆所支持的相关业务,并遵守一定的规章制度。与前几次不同的是,本单元中,我们需要预先将自己的设计思路用UML来实现,然后进行编程。…

Coze+Discord:打造你的免费AI助手(教您如何免费使用GPT-4o/Gemini等最新最强的大模型/Discord如何正确连接Coze)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 文章内容 📒📝 准备Discord📝 准备Coze🔌 连接💡 测试效果⚓️ 相关链接 ⚓️📖 介绍 📖 你是否想免费使用GPT-4o/Gemini等最新最强的大模型,但又不想花费高昂的费用?本文将教你如何通过Coze搭建Bot,并将其转发…

RAG系统进阶(五)文本分割优化技巧及代码

背景 前边在介绍RAG系统时提到了文本分割(或分段)的作用和重要性。也提到了分段后所带来的一些问题,比如由于分段导致检索出来的TOP-n的结果可能未包含完整的答案。 粒度太大可能导致检索不精准,粒度太小可能导致信息不全面问题的…

教你一段代码激活计算机系统

方法简单粗暴,再也不用遭受未激活的烦恼了! 新建文本 输入代码,把文件后缀.txt改.bat slmgr /skms kms.03k.org slmgr /ato

2024-2025最新软考系统架构设计师的复习资料教材,解决如何快速高效通过该考试,试题的重点和难点在哪里?案例分析题和论文题的要点和踩坑点分析

目录 引言考试概述 考试结构考试内容 复习策略 制定复习计划学习资源 知识点详解 系统架构基础设计原则与模式系统分析与设计软件开发过程项目管理系统集成性能与优化安全性设计新兴技术 试题解析 选择题案例分析题论文题 重点与难点分析模拟试题与答案参考资料总结 引言 系…

DeepSpeed Pipeline并行

DeepSpeed为了克服一般Pipeline并行的forward时weights,和backward时计算梯度的weights, 二者不相同的问题,退而求其次,牺牲性能,采用gradient-accumulate方式,backward时只累积梯度至local,并不更新weights&#xff1…

css display:grid布局,实现任意行、列合并后展示,自适应大小屏幕

现有6X7列的一个布局&#xff0c;如下图所示 想要用户能组成任意矩形盒子&#xff0c;并展示内容&#xff0c;具体效果如下&#xff08;仅为一个示例&#xff0c;其实可以任意组合矩形&#xff09;&#xff1a; html代码&#xff1a; <div class"grid-container"…

SQL 截取函数

目录 1、substring 2、left 3、right 4、substring_index 1、substring 用途&#xff1a;字段截取从指定开始的字符开始&#xff0c;截取要的数&#xff1b;指定开始的字符数字可以用负的&#xff0c;指定开始的字符从后往前(向左)数&#xff0c;截取要的数不能为负。 语…

BoxSizer布局

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 在前面的实例中&#xff0c;使用了文本和按钮等控件&#xff0c;并将这些控件通过pos参数布置在pannel画板上。虽然这种设置坐标的方式很容易理解&am…

GitLab教程(四):分支(branch)和合并(merge)

文章目录 1.分支&#xff08;branch&#xff09;&#xff08;1&#xff09;分支的概念&#xff08;2&#xff09;branch命令 2.合并&#xff08;merge&#xff09;&#xff08;1&#xff09;三个命令pullfetchmergegit fetchgit mergegit pull &#xff08;2&#xff09;合并冲…

C++开源软件:跨平台本地密码管理器KeePassXC/KeePassDX

KeePassXC、KeePass和KeePassDX在功能、平台和特点上有所区别&#xff0c;以下是对这三款密码管理器的清晰区分&#xff1a; KeePassXC&#xff1a; 平台&#xff1a;跨平台&#xff0c;支持Windows、macOS和Linux等主流操作系统。 安全性&#xff1a;使用AES加密算法&#x…

路虽远,行则将至 - 附暑期实习、秋招历程经验分享

前言 大家好 许久没有时间静下心来打开编辑器写文章了 忙碌暂过&#xff0c;难得一闲时 求学三年&#xff0c;终到离别时 回忆过往&#xff0c;枯燥且多彩 有一点经验&#xff0c;以文字形式分享&#xff0c;希望帮助到大家 可能是这段时间事多且杂&#xff0c;加上很长一…

NSS题目练习8

[SWPUCTF 2022 新生赛]numgame 打开发现不能直接更改数值&#xff0c;会变成负数&#xff0c;快捷键不能用&#xff0c;输入view-source查看源代码&#xff0c;发现js文件 点开后发现最下面有个酷似flag的东西 提交后是错的&#xff0c;看着像是base64&#xff0c;解码后得到另…