PyTorch 新库 TorchMultimodal 使用说明:将多模态通用模型 FLAVA 扩展到 100 亿参数

news2024/11/27 14:43:15

先前的文章中,我们介绍了 TorchMultimodal,今天我们将从一个具体案例出发,演示如何在 Torch Distributed 技术加持下,在 TorchMultimodal 库中扩展多模态基础模型。

近年来,大模型已成为一个备受关注的研究领域。以自然语言处理为例,语言模型已经从几亿参数(BERT)发展到了几千亿参数(GPT-3),对下游任务的性能提升显示出重大作用。

业界对大规模语言模型如何扩展进行了广泛的研究。在视觉领域也可以观察到类似的趋势,越来越多的开发者开始转向基于 transformer 的模型(如 Vision Transformer、Masked Auto Encoders)。

显然,由于大规模模型的发展,单模态(如文本、图像、视频)相关研究不断改进,框架也迅速适应了更大的模型。

同时,随着图像-文本检索、视觉问答、视觉对话和文本到图像的生成等任务在现实世界中的应用,多模态越来越受到重视。

接下来就是训练大规模多模态模型。该领域也有了一些努力成果,如 OpenAI 的 CLIP,谷歌的 Parti 和 Meta 的 CM3。

本文将通过一个案例研究,展示如何使用 PyTorch Distributed 技术将 FLAVA 扩展到 100 亿参数。

补充阅读:HyperAI超神经:Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

FLAVA 是一个视觉和语言基础模型在 TorchMultimodal 中可用

FLAVA 在单模态和多模态 Benchmark 中都表现出了非常突出的性能优势。本文将结合相关代码示例,演示如何扩展 FLAVA 。

代码详见:

multimodal/examples/flava/native at main · facebookresearch/multimodal · GitHub

扩展 FLAVA 概览

FLAVA 是一个基础多模态模型,由基于 transformer 的图像和文本编码器以及基于 transformer 的多模态融合模块组成。

FLAVA 在单模态和多模态数据上都进行了预训练,且这些数据的损失 (loss) 各不相同,包括掩码的语言、图像和多模态模型 loss,要求模型从其上下文中重建原始输入(自监督学习)。

此外,它还使用了图像文本匹配损失 (image text matching loss),包括对齐图像-文本对的 positive 和 negative 示例,以及 CLIP 风格的对比损失。

除了多模态任务(如图像-文本检索),FLAVA 在单模态 Benchmark(如 NLP 的 GLUE 任务和视觉的图像分类)上也表现出极佳的性能。

最初 FLAVA 模型约有 3.5 亿参数,并使用 ViT-B16 配置,用于图像和文本编码器。

Reference:https://arxiv.org/pdf/2010.11929.pdf

多模态融合 transformer 沿用了单模态编码器,但层数只有之前的 1/2。PyTorch 开发团队一直在探索增加编码器的尺寸,以适应更大的 ViT 变量 (variant)。

扩展 FLAVA 的另一个方面,就是增加批尺寸。FLAVA 巧妙利用了 in-batch negative 的对比损失,这通常只在大批尺寸中才有。

Reference:https://openreview.net/pdfid=U2exBrf_SJh

一般来说,当操作接近最大可能的批尺寸时,也能实现最大的训练效率或吞吐量,这由可用的 GPU 内存数量决定(参见实验部分)。

下表演示了不同模型配置的输出,实验中已确定每个配置能够适应内存的最大批尺寸。

优化概览

PyTorch 提供了几种原生技术来有效地扩展模型。在下面的章节中会详细介绍三种方法,并演示如何应用这些技术,将 FLAVA 模型扩展到 100 亿参数。

分布式数据并行

分布式训练的一个常见起点是数据并行。数据并行在 GPU 之间复制模型,并进行数据集划分。不同的 GPU 会并行地处理不同的数据分区,并在模型权重更新前同步其梯度(通过 all reduce)。

下图展示了处理一个数据并行(正向迭代、反向迭代和权重更新步骤)的流程:

为了实现数据并行,PyTorch提供了一个原生 API,即 DistributedDataParallel(DDP),它可以作为一个模块封装器 (module wrapper) 使用,如下所示:

from torchmultimodal.models.flava.model import flava_model_for_pretraining
import torch
import torch.distributed as dist

model = flava_model_for_pretraining().cuda()
# Initialize PyTorch Distributed process groups
# Please see https://pytorch.org/tutorials/intermediate/dist_tuto.html for details
dist.init_process_group(backend=”nccl”)
# Wrap model in DDP
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

完全分片式数据并行

训练应用程序的 GPU 内存使用可以大致细分为模型输入、中间激活存储(intermediate activation,梯度计算需要用到)、模型参数、梯度和优化器状态。

扩展模型时通常会将这些元素同时增加。当单个 GPU 内存不足时,使用 DDP 扩展模型可能导致内存不足,因为它会在所有 GPU 上复制参数、梯度和优化器状态。

为了减少这种复制并节省 GPU 内存,可以将模型参数、梯度和优化器状态分片给所有 GPU,每个 GPU 只管理一个分片。这个方法参照了微软提出的 ZeRO-3。

这种方法的 PyTorch 原生实现可作为 FullyShardedDataParallel(FSDP)API,已在 PyTorch 1.12 中作为 beta 版功能发布。

在模块的正向和反向迭代过程中,FSDP 会根据计算需要对模型参数进行整合(使用 all-gather),并在计算后重新分片。它使用散射规约集合来同步梯度,以确保分片的梯度是全局平均的。FSDP 中模型的正向迭代和反向迭代流程如下:

使用 FSDP 时要用 API 封装模型的子模块,从而控制某一特定子模块何时被分片或不分片。FSDP 提供了一个开箱即用的 auto-wrapping API、几个封装策略 (wrapping policy) 以及编写策略的能力。

以下示例演示了如何用 FSDP 封装 FLAVA 模型。指定自动封装策略为:transformer_auto_wrap_policy 。这将把单个 transformer 层(TransformerEncoderLayer)、图像 transformer (ImageTransformer)、文本编码器 (BERTTextEncoder) 和多模态编码器 (FLAVATransformerWithoutEmbeddings)封装为单个 FSDP 单元。

这采用了一种递归封装的方法来进行有效的内存管理。例如,在单个 transformer 层的正向或反向迭代完成后,删除参数、释放内存从而减少了峰值内存使用。

FSDP 还提供了一些可配置的选项来调整应用程序的性能,如本例中 limit_all_gathers 的使用。它可以防止过早地收集所有模型参数,减轻应用程序的内存压力。

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.models.flava.text_encoder import BertTextEncoder
from torchmultimodal.models.flava.image_encoder import ImageTransformer
from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining().cuda()
dist.init_process_group(backend=”nccl”)

model = FSDP(
               model,
               device_id=torch.cuda.current_device(),
               auto_wrap_policy=partial(
                   transformer_auto_wrap_policy,
                   transformer_layer_cls={
                       TransformerEncoderLayer,
                       ImageTransformer,
                       BERTTextEncoder,
                       FLAVATransformerWithoutEmbeddings
                   },
               ),
               limit_all_gathers=True,
           )

activation checkpointing

如上,中间激活存储 (intermediate activation)、模型参数、梯度和优化器状态会影响 GPU 内存的使用。FSDP 可以减少后三者带来的内存消耗,但不能减少激活所消耗的内存。激活所使用的内存随着批尺寸或隐藏层数量的增加而增加。

activation checkpointing 通过在反向迭代过程中重新计算激活,而非将其保存在特定 checkpointed 模块的内存中,来减少内存的使用。

例如,通过对 27 亿参数模型应用 activation checkpointing,正向迭代后的活动内存峰值减少了 4 倍。

PyTorch 提供了一个基于 wrapper 的 activation checkpointing API。且 checkpoint_wrapper允许用户通过 check 封装单个模块,apply_activation_checkpointing 允许用户指定策略在整个模块中用 checkpointing 封装模块。

这两个 API 可以应用于大多数模型,因为它们不需要对模型定义代码进行任何修改。

然而,如果需要对 checkpointed segment 进行更细化的控制,如对模块内的特定功能进行 checkpointing,可以利用 torch.utils.checkpoint API,这需要修改模型代码。

activation checkpointing wrapper 对单个 FLAVA transformer 层(用 TransformerEncoderLayer 表示)的应用如下所示:

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer

model = flava_model_for_pretraining()
checkpoint_tformer_layers_policy = lambda submodule: isinstance(submodule, TransformerEncoderLayer)

apply_activation_checkpointing(
               model,
               checkpoint_wrapper_fn=checkpoint_wrapper,
               check_fn=checkpoint_tformer_layers_policy,
           )

如上所示,用 activation checkpointing 封装 FLAVA transformer 层,用 FSDP 封装整体模型,可以将 FLAVA 扩展到 100 亿参数。

实验

对于上文提到的不同优化方法,我们将进一步实验其对系统性能的影响。

背景介绍:

  • 使用含 8 个 A100 40 GB GPU 的单节点

  • 运行 1000 次迭代预训练

  • 使用 bfloat16 数据类型的 PyTorch 混合精度训练 (automatic mixed precision)

  • 启用 TensorFloat32 格式,提高 A100 上的 matmul 性能

  • 将吞吐量定义为每秒处理的平均项目数(测量吞吐量时忽略前 100 次迭代)

  • 训练收敛及其对下游任务指标的影响,会作为未来研究的新方向

图 1 显示了每个模型配置和优化的吞吐量,local batch size 为 8,在 1 个节点上可能的最大 batch size。优化的模型变体 (model variant) 没有数据点,说明该模型无法在单个节点上训练。

​图1:不同配置下的训练吞吐量

图 2 展示了所有 GPU 在每个优化中可能的最大批尺寸。

图2:不同配置下可能的最大本地批尺寸

从中可以观察到:

1. 扩展模型尺寸:

DDP 只能在一个节点上适应 350M 和 900M 的模型。使用 FSDP 可以节省内存,所以能够训练比 DDP 大 3 倍的模型(即 1.8B 和 2.7B 的变体)。将激活检查点(AC)与 FSDP 结合起来,可以训练更大的模型,约为 DDP 的 10 倍(如 4.8B 和 10B 变体)。

2. 吞吐量:

- 对于较小的模型,当批尺寸为 8 时,DDP 的吞吐量略高于或等于 FSDP,可以解释为 FSDP 需要额外的通信。FSDP 和 AC 结合在一起的吞吐量最低。这是因为 AC 在反向迭代的过程中,重新运行 checkpointed 正向迭代通道,为了节省内存牺牲了额外的计算。然而,对于 2.7B 模型,与单独的 FSDP 相比,FSDP + AC 实际上具有更高的吞吐量。这是因为带有 FSDP 的 2.7B 模型即使在批处尺寸为 8 的情况下也接近内存的极限,会触发 CUDA malloc retry,从而导致训练速度减慢。AC 有助于减少内存压力导致 no retry。

- 对于 DDP 和 FSDP + AC,模型的吞吐量会随着批尺寸的增加而增加。FSDP 对较小的变体也是如此。然而,对于 1.8B 和 2.7B 参数模型,当增加批尺寸时,吞吐量下降。一个潜在的原因是,在内存极限时,PyTorch 的 CUDA 内存管理可能不得不重试 cudaMalloc 调用或运行成本高昂的碎片整理 (defragmentation),以找到空闲的内存块来处理工作负载的内存需求,这可能导致训练速度减慢。

- 对于只能用 FSDP 训练的大模型 (1.8B,2.7B,4.8B) 而言,最高吞吐量的设置是用 FSDP+AC 扩展到最大批尺寸。对于 10B,可以观察到小批尺寸和最大批尺寸的吞吐量几乎相等。这是因为 AC 会导致计算量增加,而最大批尺寸可能会由于在 CUDA 内存限制下运行,导致成本高昂的碎片整理操作。然而,对于这些大模型,批尺寸的增加足以抵销这种开销。

3. 批尺寸:

与 DDP 相比,单独使用 FSDP 可以实现略高的批尺寸。对于 350M 参数模型,使用 FSDP+AC 可以实现比 DDP 高 3 倍的批尺寸,对于 900M 参数模型,可以实现 5.5 倍的批尺寸。即使是 10B,最大的批尺寸也约是 20,这相当不错。FSDP+AC 基本上可以用较少的 GPU 实现较大的全局批尺寸,对对比学习任务 (contrastive learning task) 特别有效。

结论

随着多模态基础模型的发展,扩展模型参数和高效训练正在成为一个重点领域。PyTorch 生态系统旨在通过提供不同的工具,加速训练和扩展多模态模型。

未来,PyTorch 将增加对其他类型模型的支持,比如多模态生成模型,以及提升相关技术的自动化。欢迎大家持续关注 PyTorch 开发者社区公众号,你也可以扫码备注「PyTorch」,加入 PyTorch 社群。

PyTorch 官方博客、教程

最新进展、最佳实践

扫码备注加入讨论组

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/51635.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GOM传奇引擎登录器商业版与免费版的区别

商业版与免费版登录器的区别: 1商业版自定义界面功能可以保存配置 2商业版登录器支持读取二次加密的Pak。需要购买Pak二次加密工具。 3商业版增加数字证书,防止杀毒软件误报 4商业版支持163博客远程列表,列表首尾需要$BEGIN $END关键字 5商业…

如何提高网站安全防护?

网站安全是网站建设后非常关键的一个问题,是不可以忽视的,一个安全性不高的网站对于网站的危害是很多的,网站安全性不高的网站容易被攻击、容易被挂马、容易造成网站数据泄露,下面安全狗小编来跟大家聊一下网站安全性该如何提高。…

Git常见操作

什么是Git 简单说,git就是版本管理工具。 Git解决的问题 试想一下,你是公司的设计人员。老板要求你设计一份海报。你花了3天时间,画好了,并命名为海报1.0.然后你给老板看,老板看了说“设计的很好,要是能…

将 Vue.js 项目部署至静态网站托管,并开启 Gzip 压缩

摘要:关于使用 Nginx 开启静态网站 Gzip 压缩的教程已经有很多了,但是好像没几个讲怎么在对象存储的静态网站中开启 Gzip 压缩。其实也不复杂,我们一起来看下~本文分享自华为云社区《将 Vue.js 项目部署至静态网站托管,并开启 Gzi…

基于模糊推理的滑膜控制

目录 前言 1.系统描述 2.控制器设计 3.模糊推理估计不确定f 3.1构造模糊系统 3.2模糊推理过程 3.3 自适应律设计 4.仿真分析 4.1仿真模型 4.2仿真结果 5.总结 前言 在一般的建模仿真中,我们假设模型都是可以用数学模型描述出来的是确定的,称…

Flink系列之Flink集群搭建

title: Flink系列 二、Flink集群搭建 2.1 Flink的Standalone模式集群安装 1、上传解压重命名 [roothadoop10 software]# tar -zxvf flink-1.14.3-bin-scala_2.12.tgz [roothadoop10 software]# mv flink-1.14.3 flink2、进入到解压之后的目录里面修改配置文件flink-conf.yam…

Hash 的定义

Hash,一般翻译做散列、杂凑,或音译为哈希。 这句话就是很多混乱的根源。 笔者还是比较时候直接使用 哈希这个翻译,或者干脆不翻译。 混乱来源 在查看很多资料的时候,经常会看到最多的一个词就是散列算法。 如果不深入追究下的…

PyQt5 数据库处理

PyQt5 数据库处理SQLite介绍连接数据库执行SQL语句创建SQLite数据库关闭窗口时断开SQLite连接数据库模型视图列表模式显示数据栅格模式分页显示数据SQLite介绍 SQLite是一个轻量级的数据库,实现了自给自足、无服务器、零配置、事务性的SQL数据库引擎。 下载地址 安…

深入理解ThreadLocal源码

1. 预备知识:强软弱虚引用 在Java中有四种引用的类型:强引用、软引用、弱引用、虚引用。 设计这四种引用的目的是可以用程序员通过代码的方式来决定对象的生命周期,方便GC。 强引用 强引用是程序代码中最广泛使用的引用,如下&a…

如何通过股票行情接口查询财务数据?

我们做交易,有时候还是需要用到一些上市公司的财务数据的,有什么板块可以快速获取财务数据呢?那肯定就是利用股票行情接口进行查询了,那具体要怎么做呢?下面这组代码可以了解一下: get_fundamentals - 查询…

Markdown格式表情包大全最新整理分享

Markdown表情包一、前言❤️二、Emoji表情大全👮People(人物)❄️Nature(自然)🔔Objects(物体)🏠Places(地点)🔟Symbols(符…

如何选择独立站ERP系统?

在选择ERP系统时所需要考虑以下几个问题,首先是看看ERP的操作流程是否简单明了。ERP最核心的作用就是提升工作效率,如果操作流程过于复杂,反倒是会增加学习成本,因此快速上手是先决条件。 其次便是需要看看功能是否符合卖家的需…

如何快速编辑图片?轻量级图片在线处理工具使用教程

不管在生活还是工作的时间里,图片都是经常会使用到的,但是可能在使用图片的时候,需要根据要求来做图片处理(在线ps 图片编辑制作工具 免费照片编辑器_压缩图)。比如我们常用的jpg、png、gif三种图片格式,经…

二叉树,平衡二叉树,B树,B+树,红黑树

1.普通树 A为整个树的根节点。而B,C,D可以看做子树的根节点,在下面分别长出三棵子树。 二、二叉树概念及结构 1.概念 一棵二叉树是结点的一个有限集合,该集合或者为空,或者是由一个根节点加上两棵别称为左子树和右子…

Python Pubg 武器自动识别与压枪 全过程记录

博文目录 文章目录环境准备压枪原理需求分析求两张图片的相似度背包检测 是否在背包界面武器识别名称识别 纯白计数法配件识别 瞄具/枪口/握把/枪托 相似对比法模式识别 全自动/半自动/单发姿态识别 站/蹲/爬余弹识别激活识别 是否持有武器/一号武器/二号武器 (未完成, 做不下去…

Qt通过ODBC连接openGauss数据库

文章目录前言一、Qt链接测试1.测试代码2.测试效果二、环境搭建1.通过ODBC连接openGauss数据库2. 环境测试三、Qt通过ODBC操作数据库1.查询数据1.插入数据3.更新数据总结前言 本文就介绍了Qt通过ODBC连接opengauss数据库的基础内容。 一、Qt链接测试 1.测试代码 在.pro文件中…

java(面向对象)的23种设计模式(11)——观察者模式

一、定义 观察者模式:指多个对象间存在一对多的依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都得到通知并被自动更新。 换种说法,定义两种对象,观察者和目标对象,多个观察者同时监听一个目标对…

高等数学基础概念的Python开发实现

一般的数学算式math函数库就可以解决了,如果是涉及到高等数学极限,微积分等知识,就需要用到sympy科学计算库,它是专门用来解决数学的运算问题的。 Sympy是一个符号计算的Python库。它的目标是成为一个全功能的计算机代数系统&…

你的团队是王者还是青铜(上)

(图片来源:https://unsplash.com/photos/RxOrX1iW15A) 4月18日早上9点30分,团队跟着大屏计时器整齐地喊出倒计时,“五、四、三、二、一”,Tech Lead 强哥和 PO 小楠相对看了一眼,一起按下了eart…

双端 Diff 算法原理解析及 snabbdom 简单实现

一、准备工作 先找个放猪的容器canvas,这里宽设置了1200&#xff0c;高设置了600 <canvas width"1200" height"600" id"canvas">当前浏览器不支持canvas元素</canvas> 然后获取它进行操作 const canvas document.getElementById(…