LLaVA-MoLE:解决多模态大型语言模型指令微调中的数据冲突问题

news2024/9/21 18:44:00

人工智能咨询培训老师叶梓 转载标明出处

多模态大模型(MLLMs)通过指令微调(instruction finetuning),能够执行各种任务,如理解图表、处理文档和回答基于图像的问题。但是,当从不同领域混合指令数据进行微调时,模型在特定领域的任务上可能会出现性能下降。这种现象被称为数据冲突,它限制了通过增加新领域训练数据来扩展MLLM能力的可能性。为了应对这一挑战,来自美团公司的研究者们提出了一种新颖的方法——LLaVA-MoLE,即稀疏混合LoRA专家(Sparse Mixture of LoRA Experts)。

该模型基于LLaVA-1.5,通过在Transformer层中引入一组LoRA(Low-Rank Adaption)专家,并为每个token选择最适合的专家进行处理。这种设计允许模型根据不同领域的token激活不同的专家,从而扩展了MLLM处理多领域数据的能力。

论文链接:https://arxiv.org/pdf/2401.16160

方法

低秩适应(LoRA)是一种针对大模型(LLMs)的参数高效微调方法。它能够应用于任意线性层。具体来说,对于一个输入为 和权重矩阵 ​ 的线性层 h=Wx,LoRA 学习一个低秩分解的更新:

其中,是低秩矩阵,r 是远小于d_o​ 和 d_i​ 的秩,α 控制对原始W 的变化幅度。在学习LoRA模块过程中,只有矩阵A 和 B 会被更新。

图 2 展示了 LLaVA-MoLE 模型的整体框架,该模型基于 LLaVA-1.5 构建,采用了稀疏混合 LoRA 专家(Sparse Mixture of LoRA Experts)的方法来训练。

  1. 输入图像处理:输入图像首先通过 CLIP ViT(Vision Transformer)进行处理,CLIP ViT 是一种视觉编码器,能够将图像转换成一系列的视觉嵌入(visual embeddings)。之后,这些视觉嵌入通过一个两层的多层感知器(MLP)进行进一步的映射。

  2. 文本输入处理:文本输入首先被分词(tokenized),然后通过词嵌入矩阵转换成嵌入表示,这些嵌入与视觉输入一起被串联(concatenated),形成最终输入到大型语言模型(LLM)的混合嵌入序列。

  3. 稀疏混合 LoRA 专家:在 LLaVA-MoLE 模型中,每个 Transformer 层都采用了提出的稀疏混合 LoRA 专家进行训练。具体来说,每个全连接层(FFN)都会根据路由器(router)的输出分布选择并结合一个 LoRA 专家来进行计算。

  4. 路由器(Router)的作用:路由器负责为每个 token 分配一个最合适的 LoRA 专家。路由器的输出分布决定了 FFN 应该选择哪个专家来处理当前的 token。

  5. 自注意力(Self-Attention)训练:自注意力机制同样采用 LoRA 进行训练,但在这个框架中没有应用专家混合(MoE)。

  6. 计算并行化:对于每个 LoRA 专家,相同子序列的 token 可以并行计算,这提高了模型训练的效率。

如图 2 所示,一个MLLM可以被表述为:

其中 是视觉编码器和适配器,将输入图像映射成一系列视觉嵌入,将输入问题 T_q​ 进行标记化并用词嵌入矩阵嵌入离散标记,而 ∣∣ 是序列连接操作。因此,MLLM的输入实际上是一个混合嵌入序列。训练MLLM的指令数据被组织成三元组 (),不同的指令数据集可能有不同的分布,导致训练出的MLLM表现出不同的行为或专长。

为了缓解混合不同类型的指令数据时产生的冲突,研究者引入了一组LoRA专家和一个路由器。在每个输入token上,路由器学习选择最合适的专家激活,使模型具有额外的能力来处理不同类型的输入。假设每层有K 个专家,选择具有最高路由函数值的专家:

然后激活选定的专家来执行实际计算,而忽略当前token的其他专家。例如,对于现代LLMs中的FFN层通常是多层的,每一层的FFN都会有一个单独的MoE,但它们共享相同的路由器。通过只激活top-1专家,实际计算成本与原始FFN中的plain-LoRA大致相同。

为了确保模型的高效运行,研究者还引入了负载平衡损失,以避免专家分配的严重不平衡。负载平衡损失的公式为:

其中 cj​ 是分配给第j 个专家的token数量,pj​ 是第 j 个专家的总路由概率。通过最小化,专家的分配趋于均匀,从而避免了某些专家过载而另一些专家闲置的问题。

通过上述方法,LLaVA-MoLE模型能够有效地解决数据冲突问题,同时保持了计算成本的可控性,为多模态大型语言模型的微调提供了一种有效的解决方案。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加助理微信提供直播链接:amliy007,29.9元即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory,关注享粉丝福利,限时免费CSDN听直播后的录播讲解。
LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

基本模型架构遵循LLaVA1.5的设计,其中使用了CLIP ViT-L作为视觉编码器,输入图像分辨率为336x336,补丁大小为14。适配器是一个两层的MLP,用于转换来自ViT的576个token。大型语言模型(LLM)是Vicuna-7B-v1.5。在所有实验的训练过程中,ViT和Vicuna的权重都被冻结。除非特别说明,否则应用于LLM的LoRA秩是32。

模型在两个阶段进行训练:预训练和指令微调。预训练阶段使用了ShareGPT4V数据集,包含由GPT4V生成的数据训练的标题器产生的130万个详细的字幕数据。指令微调阶段,采用了来自三个不同领域的多模态指令数据集:一般多任务、文档和生物医学。M3IT和ShareGPT4V Instruct是两个一般多任务指令数据集,而UReader收集的文档导向指令数据集包含来自多个公共数据集的图像和指令。还使用了PathVQA作为生物医学领域的指令数据。所有这些数据集都是公开的,并且按照UReader的数据划分进行训练和测试。表格1列出了预训练(PT)和监督指令微调(SFT)阶段的训练参数。

表 2 展示了在不同数据和MoE配置下训练的模型的实验结果。首先提供了官方LLaVA-1.5和LLaVA-Med模型在每个基准测试上的结果。然后,通过在不同数据集上单独训练plain-LoRA模型,并将其命名为LLaVA-1.5、LLaVA-Doc和LLaVA-Med。这些模型在与其训练数据集相对应的基准测试上的性能被视为该基准的基线性能。例如,特别重现的LLaVA-1.5†专门在一般多任务指令数据上训练,在Tiny LVLM-eHub上实现了与官方LLaVA-1.5 (307.2)相当的306.3的总分。通过混合不同数据集,发现LLaVA-Mix在eHub的整体性能比LLaVA1.5†降低了7-9分。这表明一般多任务数据与这些数据类型之间存在冲突,这种冲突可能会损害模型的一般多任务QA能力。提出的LLaVA-MoLE成功地解决了上述冲突。通过比较LLaVAMoLE[1,1,0]与LLaVA-Mix[1,1,0],可以观察到eHub的整体性能显著提高,与基线LLaVA-1.5†相当,而UReader基准测试的性能甚至超过了基线LLaVA-Doc†,例如在ChartQA上绝对性能提高了6.4。这可以证明混合专家已经学会了处理不同类型的指令数据并减少潜在的数据冲突。

表 3 展示了在不同LoRA秩下训练的模型的实验结果。可以看到,对于LoRA秩32、64和96,将文档指令数据与一般多任务指令数据混合都会导致eHub基准测试的性能下降。但通过比较实验LLaVA-Mix[1,1]-R32、LLaVA-Mix[1,1]-R64和LLaVA-Mix[1,1]-R96的结果,也发现增加LoRA秩,即增加模型容量,可以在一定程度上缓解数据冲突问题:eHub的总分从R32的298.8增加到R96的301.1。此外,如果将LoRA秩增加到128,似乎解决了这个问题。然而,作者认为简单地提高模型容量是一种昂贵的解决方案,会导致训练过程中的计算和内存增加。而提出的LLaVA-MoLE可以在不增加太多额外成本的情况下解决这个问题。值得注意的是,对于较小(32)和较大(128)的LoRA秩,LLaVAMoLE在两个基准测试上都显著优于LLaVA-Mix。

图 3 展示了在所有三个数据集的混合上训练的LLaVA-MoLE模型的路由选择的粗略分析。通过计算每个基准测试中分配给每个专家的token比例的均值和标准差,对第0层、第2层、第10层和第28层的结果进行了可视化。对于某些层,例如第2层和第10层,不同类型数据的专家选择模式相似,但在不同层之间有所不同。也有一些层(第10层和第28层),每种类型的数据都有自己的专家选择模式。没有观察到明显的模式表明某个特定专家在其他专家中一直更受青睐。但某些专家可能在特定数据集上比其他专家更倾向于被选择,例如,专家0在所有层的PathVQA样本中更频繁地被激活。

通过这些详细的实验设置和结果分析,证明了LLaVA-MoLE模型在解决多模态大型语言模型指令微调中的数据冲突问题方面是有效的,并且能够在保持计算成本可控的同时提高模型性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2094650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法——K-means算法和算法改进

简介:个人学习分享,如有错误,欢迎批评指正。 一、什么是K-means算法? K-means算法是一种无监督的聚类算法,用于将一组数据点分为K个簇(cluster)。其核心目标是将数据点划分到K个不同的簇中&…

CAS单点登录安装文档

CAS单点登录安装文档 目录 1、 下载CAS 2、 下载xmlsectool 3、 安装xmlsectool 4、 打包CAS 5、 部署CAS 6、 访问CAS 1.下载CAS 在CAS官方Github下载:https://codeload.github.com/apereo/cas/zip/v5.3.0 2.下载xmlsectool 在MVNREPOSITORY下载xm…

JavaWeb JavaScript ⑨ 正则表达式

生命的价值在于你能够镇静而又激动的欣赏这过程的美丽与悲壮 —— 24.8.31 一、正则表达式简介 正则表达式是描述字符模式的对象。正则表达式用简单的API对字符串模式匹配及检索替换,是对字符串执行模式匹配的强大工具。 1.语法 var pattnew RegExp(pattern,modi…

【软考】IO软件

目录 1. 说明2. 读硬盘文件3. IO 系统的层次结构与每层的主要功能4. 例题4.1 例题1 1. 说明 1.设备管理软件的设计水平决定了设备管理的效率。2.从事I0 设备管理软件的结构,其基本思想是分层构造,也就是说把设备管理软件组织成为一系列的层次。3.低层与…

【机器学习】任务二:波士顿房价的数据与鸢尾花数据分析及可视化

目录 1.实验知识准备 1.1 NumPy 1.2 Matplotlib 库 1.3 scikit-learn 库: 1.4 TensorFlow 1.5 Keras 2.波士顿房价的数据分析及可视化 2.1波士顿房价的数据分析 2.1.1 步骤一:导入所需的模块和包 2.1.2 步骤二:从 Keras 库中加载波…

Linux驱动开发基础(DS18B20温度模块)

所学来自百问网 目录 1.DS18B20 简介 2.硬件设计 3.软件设计 3.1 存储器介绍 3.2 通信时序 3.2.1 初始化时序 3.2.2 写时序 3.2.3 读时序 3.3 常用命令 4. 示例代码 4.1 驱动代码 4.2 应用代码 4.3 Makefile 4.4 实验效果 1.DS18B20 简介 DS18B20 温度传感器具…

[线程]阻塞队列

文章目录 阻塞队列生产者消费者模型通过BlockingQueue理解阻塞队列自己实现阻塞队列 阻塞队列 我们之前学的队列, 其实是最基础的队列, 实际开发中, 针对队列还有很多种变种 普通队列优先级队列阻塞队列 先进先出, 线程安全, 并且带有阻塞功能 阻塞功能指: 如果队列为空, 尝试…

23种设计模式之模板模式

一.什么是模板模式 ‌‌模板模式是一种行为型设计模式,它定义了一个算法的骨架,而将一些步骤留给子类实现。‌这种模式允许子类在不改变算法结构的基础上,重新定义算法的某些步骤。模板模式属于行为型设计模式,主要用于处理那些需…

excel透视图、看板案例(超详细)

一、简介 Excel透视图(Pivot Table) 功能:透视图是一种强大的数据分析工具,用于汇总、分析和展示数据。它允许用户对数据进行重新排列和分类,从而更容易发现数据中的模式和趋势。用途:可以用来生成动态报表…

python07-单元测试框架unittest1-3

当测试用例数量增加,一个一个执行效率低下,需要将工程下的,case收集并按顺序执行将对应的代码放入run_tests.py run_tests.py:运行程序目的 收集所有的测试用例执行生成测试报告 运用测试用例的收集器或测试用例的加载器 7 Tes…

2.4梯度下降与量化策略优化

1. 梯度下降法的基本原理 欢迎来到“梯度下降”的世界!听上去有点像在爬山对吧?其实,这个算法的灵感确实来自爬山。想象你在一个山谷中迷路了,周围雾蒙蒙的,看不清楚路,只能摸着石头一步一步往下走。每走一…

短效ip—互联网利器

《瞬息万变:短效IP在网络世界的奇幻之旅》 在浩瀚无垠的数字宇宙中,互联网如同一条奔腾不息的河流,携带着无数创新与技术的浪花。在这片日新月异的疆域里,短效IP以其独有的魅力,悄然成为网络探险家们手中的魔法钥匙。它…

编译原理概述

编译原理概述 编译原理是计算机科学的重要领域,主要研究编译器如何将高级编程语言转换为机器可执行代码。编译器的工作流程可以分为多个阶段,每个阶段都有特定的功能和目标。理解编译原理对于编写高效的代码、优化程序性能以及开发新语言或编译器非常重…

Java 线程实现暂停、中止

需求:用户可以开启任务,暂停任务和中止任务。 用户开启任务后,可以随时暂停或者中止。暂停后又可以回到原进度继续运行。 这里写目录标题 demo版-使用废弃的stop、suspend、resume实现为什么废弃了?不用stop,如何销毁线…

MySQL5.7.36之主从复制部署安装-centos7

主库是192.168.31.209:3306 从库是192.168.31.210:3308、192.168.31.209:3307、192.168.31.210:3309、192.168.31.211:3310、192.168.31.211:3311 切记:不管是主库还是从库,server_id一定不能重复 1、主库创建复制账号及授权 create user repl% iden…

Linux驱动开发基础(IRDA 红外遥控模块)

所学来自百问网 目录 1.红外遥控简介 2.硬件设计 3.软件设计 4. 示例代码 4.1 驱动代码 4.2 Makefile 4.3 实验效果 1.红外遥控简介 红外遥控被广泛应用于家用电器、工业控制和智能仪器系统中,像我们熟知的有电视机盒子遥控器、空调遥控器。红外遥控器系统…

分类预测|基于灰狼GWO优化BP神经网络的数据分类预测Matlab程序GWO-BP 含基础BP对比模型

分类预测|基于灰狼GWO优化BP神经网络的数据分类预测Matlab程序GWO-BP 含基础BP对比模型 文章目录 一、基本原理1. 灰狼优化算法(GWO)简介GWO的基本步骤 2. BP神经网络简介**BP网络的基本结构****训练过程** 3. GWO-BP分类预测的结合**结合流程** 4. GWO-…

苹果mac数据恢复概率大吗 mac数据恢复专业软件哪个好用

一般情况下,当我们把电脑中的数据删掉后,都会保存在回收站里面,但如果回收站被清空了或者数据在回收站中没有找到的话,那么,之前被删掉的数据还能恢复吗?恢复的概率有多大呢? 答案是可以的&…

Hive 案例分析(B站用户行为大数据分析)

Hive 案例分析(B站用户行为大数据分析) 一、案例需求二、设计数据表结构2.1 user 表结构2.2 video 表结构 三、创建数据表3.1 创建 video 数据库3.2 创建外表3.1.2 创建 external_user3.1.3 创建 external_video 3.2 创建内表3.2.1 创建 orc_user3.2.2 创…

Atlas阿特拉斯wordpress主题

Atlas阿特拉斯是一个专为WordPress平台设计的多功能主题,该主题由简站wordpress主题开发,旨在为用户提供一个强大而灵活的工具,以构建各种类型的网站。以下是对Atlas阿特拉斯WordPress主题的简介: Atlas阿特拉斯WordPress主题简介…