MicroDiffusion——采用新的掩码方法和改进的 Transformer 架构,实现了低预算的扩散模型

news2025/1/2 17:38:17

介绍

论文地址:https://arxiv.org/abs/2407.15811
现代图像生成模型擅长创建自然、高质量的内容,每年生成的图像超过十亿幅。然而,从头开始训练这些模型极其昂贵和耗时。文本到图像(T2I)扩散模型降低了部分计算成本,但仍需要大量资源。

目前最先进的技术需要大约 18 000 个 A100 GPU 小时,而使用 8 个 H100 GPU 进行训练则需要一个多月的时间。此外,该技术通常依赖于大型或专有数据集,因此难以普及。

在这篇评论性论文中,我们开发了一种低成本、端到端文本到图像扩散建模管道,目的是在没有大型数据集的情况下显著降低成本。它侧重于基于视觉变换器的潜在扩散模型,利用其简单的设计和广泛的适用性。为了降低计算成本,可通过随机屏蔽输入标记来减少每幅图像需要处理的斑块数量。本文克服了现有遮蔽方法在高遮蔽率下性能下降的难题。

为了克服文本到图像扩散模型性能不佳的问题,本文提出了一种 "延迟掩蔽 "策略。通过在轻量级补丁混合器中处理补丁,然后将其输入扩散变换器,可以低成本实现可靠的训练,同时即使在高掩蔽率下也能保留语义信息。它还结合了变压器架构的最新发展,以提高大规模训练的性能。

该实验训练了一个 1.16 亿参数的稀疏扩散变换器,预算仅为 1,890 美元、3,700 万张图像和 75% 的屏蔽率。结果,在 COCO 数据集上零镜头生成的 FID 达到了 12.7。在一台 8×H100 GPU 机器上的训练时间仅为 2.6 天,与目前最先进的方法(37.6 天,GPU 成本 28,400 美元)相比缩短了 14 倍。

建议方法

延迟掩蔽

由于变换器的计算复杂度与序列的长度成正比,降低训练成本的一种方法是通过使用大尺寸补丁来减少序列,如图 1-b 所示。使用大尺寸补丁可使每幅图像的补丁数量呈二次方减少,但由于图像的大区域会被主动压缩成单个补丁,因此会显著降低性能。

一种方法是使用遮罩去除变换器输入层中的一些斑块,如图 1-c 所示,同时保持斑块大小。这种方法类似于卷积网络中的随机裁剪训练,但遮蔽补丁可以在图像的非连续区域进行训练。这种方法被广泛应用于视觉和语言领域。

图 1-d 中的MaskDiT还增加了补充自编码损失,以鼓励从遮蔽的斑块中学习表示法,从而促进遮蔽斑块的重建。这种方法屏蔽了 75% 的输入图像,大大降低了计算成本。

图 1.压缩补丁序列以降低计算成本。

然而,高遮罩率会大大降低转换器的整体性能:即使使用 MaskDiT,也只能看到与简单遮罩相比的微弱改进。这是因为即使采用这种方法,大部分图像斑块也会在输入层被去除。

本文引入了一个名为 "补丁混合器 "的预处理模块,用于在屏蔽之前处理补丁嵌入。这可确保未屏蔽的补丁保留整个图像的信息,从而提高学习效率。这种方法有可能提高性能,同时在计算上与现有的 MaskDiT 策略相当。

补丁混合器和学习障碍

补片混合器指的是任何能够融合单个补片嵌入的神经架构。在变换器模型中,这一目标自然可以通过注意机制和前馈层的结合来实现。因此,本文使用轻量级变压器(只有几层)作为补片混合器。输入序列标记经补丁混合器处理后,将被屏蔽(图 2e)。假定掩码为二进制 m,则使用以下损失函数对模型进行训练。

引入专家混合(MoE)和分层缩放的变换器架构

论文采用了先进变压器结构的创新技术,在计算受限的情况下提高了模型性能。

  • 专家混合物(MoE,Zhou 等人,2022 年):使用 MoE 层扩展模型的参数和表现力,同时避免训练成本的显著增加 简化的 MoE 层与专家选择路由允许额外的辅助损失函数无需调整负载。
  • 分层缩放 (Mehta 等人,2024 年 ):这种方法已被证明能提高大型语言模型的性能,其中变换器块的宽度(隐藏层的维度)随深度线性增加。更多的参数被分配给更深的层,以学习更复杂的特征。

整体架构如图 2 所示。

图 2:拟议方法的总体概览。

试验

验证延迟遮蔽和补丁混合器的有效性

当许多补丁被遮蔽时,遮蔽性能会下降;Zheng 等人(2024 年)指出,当遮蔽率超过 50%,MaskDiT 的性能会显著下降。本文评估了遮蔽率高达 87.5% 时的性能,并将其与不使用补丁混合器的传统天真遮蔽方法进行了比较。本文中的 "延迟掩蔽 "使用了一个四层变压器块贴片混合器,其参数小于主干变压器参数的 10%。两者都使用了设置完全相同的 AdamW 优化器。

图 3 对结果进行了总结。延迟掩蔽在所有指标上都明显优于天真掩蔽和MaskDiT,表明随着掩蔽率的增加,性能差异也在扩大。例如,在屏蔽率为 75% 时,原始屏蔽的 FID 得分为 80,MaskDiT 的FID 得分为16.5,而拟议方法的 FID 得分为 5.03,优于未屏蔽时的 3.79。

图 3:验证延迟屏蔽和贴片混频器的有效性。

验证 "专家混合 "和 "分层缩放 "的有效性

分层缩放: 使用 DiT-Tiny 架构进行的实验比较了分层缩放和恒宽变换器与天真屏蔽。两个模型都在相同的计算负荷下进行了相同时间的训练。在所有性能指标上,逐层缩放方法始终优于恒定宽度模型,而且在屏蔽训练中更为有效。

专家混合物(MoE): 测试了在交替区块中具有 MoE 层的 DiT-Tiny/2 变压器。总体性能与无 MoE 层的基线模型相似,Clip-score 略有提高(从 28.11 到 28.66),FID 分数有所下降(从 6.92 到 6.98)。改进幅度有限的原因是 60K 步的训练和每位专家看到的样本量较小。

与以往研究的比较

COCO 数据集(表 1)上的零镜头图像生成: 根据标题生成 30 000 幅图像,并使用 FID-30K 比较其与真实图像的分布情况。拟议方法的FID-30K 得分为 12.66,与之前的低成本训练方法相比,计算成本降低了 14 倍,而且不依赖于专有数据集。该方法的计算成本也比 Würstchen 低 19 倍(Pernias et al.)

表 1:在 COCO 数据集上生成零镜头图像

详细图像生成比较**(表 2)****:**GenEval(Ghosh 等人,2024 年)用于评估生成物体位置、共现、数量和颜色的能力。与 Stable-DiffusionXL-turbo 和 PixArt-α 模型相比,拟议方法在单个物体生成方面的准确性接近完美,与 Stable-Diffusion 变体相当,优于 Stable-Diffusion-1.5。在以下方面也表现出卓越的性能

表 2.详细图像生成对比

总结

本评论文章重点讨论了旨在降低扩散变换器训练计算成本的补丁掩蔽策略。本文提出了一种 "延迟掩蔽 "策略,以缓解现有掩蔽方法的不足,并显示了所有掩蔽比率下的显著性能改进。

特别是使用了 75% 的延迟掩蔽率,并在真实和合成图像数据集上进行了大规模训练。尽管与最先进的技术相比成本大大降低,但还是取得了具有竞争力的零镜头图像生成性能结果。希望这种低成本的训练机制能鼓励更多研究人员参与大规模扩散模型的训练和开发。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2268054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java - 日志体系_Simple Logging Facade for Java (SLF4J)日志门面_SLF4J集成Log4j1.x 及 原理分析

文章目录 Pre官网集成Log4j1.x步骤POM依赖使用第一步:编写 Log4j 配置文件第二步:代码 原理分析1. 获取对应的 ILoggerFactory2. 根据 ILoggerFactory 获取 Logger 实例3. 日志记录过程 小结 Pre Java - 日志体系_Apache Commons Logging(JC…

详解VHDL如何编写Testbench

1.概述 仿真测试平台文件(Testbench)是可以用来验证所设计的硬件模型正确性的 VHDL模型,它为所测试的元件提供了激励信号,可以以波形的方式显示仿真结果或把测试结果存储到文件中。这里所说的激励信号可以直接集成在测试平台文件中,也可以从…

Linux day 11 28

一.Linux简介 1.1 Linux介绍 Linux 是一套免费使用和自由传播的操作系统。说到操作系 统,大家比较熟知的应该就是 Windows 和 MacOS 操作系统, 我们今天所学习的 Linux 也是一款操作系统。 1.2 Linux发展历史 时间: 1991 年 地点&#xf…

【深度学习环境】NVIDIA Driver、Cuda和Pytorch(centos9机器,要用到显示器)

文章目录 一 、Anaconda install二、 NIVIDIA driver install三、 Cuda install四、Pytorch install 一 、Anaconda install Step 1 Go to the official website: https://www.anaconda.com/download Input your email and submit. Step 2 Select your version, and click i…

Lecture 17

10’s Complement Representation 主要内容: 1. 10’s 补码表示: • 10’s 补码表示法需要指定表示的数字位数(用 n 表示)。 • 表示的数字取决于 n 的位数,这会影响具体数值的解释。 2. 举例: • 如果采用 3 位补码&…

惠州市政数局局长杨伟斌:惠州市公共数据授权运营模式探索

近期,2024数字资产管理大会召开。会上,惠州市政务服务和数据管理局局长杨伟斌在会上做了题为基于“隐私计算区块链”的惠州市公共数据授权运营模式探索主旨演讲,从三个方面展开,一是建制度汇数据,二是夯基础保安全&…

C#编写的金鱼趣味小应用 - 开源研究系列文章

今天逛网,在GitHub中文网上发现一个源码,里面有这个金鱼小应用,于是就下载下来,根据自己的C#架构模板进行了更改,最终形成了这个例子。 1、 项目目录; 2、 源码介绍; 1) 初始化; 将样…

探索贝叶斯魔法和误差的秘密

引言 今天我们要一起学习两个神秘的魔法概念:贝叶斯魔法和误差的秘密。这些概念听起来可能有点复杂,但别担心,我会用最简单的方式来解释它们。 一、贝叶斯魔法 贝叶斯魔法是一种预测的魔法,它帮助我们理解在不确定的情况下事情…

深度学习blog-卷积神经网络(CNN)

卷积神经网络(Convolutional Neural Network,CNN)是一种广泛应用于计算机视觉领域,如图像分类、目标检测和图像分割等任务中的深度学习模型。 1. 结构 卷积神经网络一般由以下几个主要层组成: 输入层:接收…

深度学习笔记(6)——循环神经网络RNN

循环神经网络 RNN 核心思想:RNN内部有一个“内部状态”,随着序列处理而更新 h t f W ( h t − 1 , x t ) h_tf_W(h_{t-1},x_t) ht​fW​(ht−1​,xt​) 一般来说 h t t a n h ( W h h h t − 1 W x h x t ) h_ttanh(W_{hh}h_{t-1}W_{xh}x_t) ht​tanh(Whh​ht−1​Wxh​xt…

最新版Edge浏览器加载ActiveX控件技术——alWebPlugin中间件V2.0.28-迎春版发布

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品,致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX控件直接嵌入浏览器,实现插件加载、界面显示、接口调用、事件回调等。支持Chrome、Firefo…

【前端,TypeScript】TypeScript速成(三):枚举类型

枚举类型 枚举类型是 TypeScript 相较于 JavaScript 而言特有的部分。一个简单的枚举声明如下: enum HTTPStatus {OK,NOT_FOUND,INTERNAL_STATUS_ERROR, }与编译成 JavaScript 的代码相比较: 显然 TypeScript 非常的简洁。 尝试使用上述枚举类型&…

Webpack学习笔记(6)

首先搭建一个基本的webpack环境: 执行npm init -y,创建pack.json,保存安装包的一些信息 执行npm install webpack webpack-cli webpack-dev-server html-webpack-plugin -D,出现node_modules和package-lock.json。 1.source-Ma…

Java高频面试之SE-06

hello啊,各位老6!!!本牛马baby今天又来了!哈哈哈哈哈嗝🐶 访问修饰符 public、private、protected的区别是什么? 在Java中,访问修饰符用于控制类、方法和变量的访问权限。主要的访…

报表工具DevExpress Reporting v24.2亮点 - AI功能进一步强化

DevExpress Reporting是.NET Framework下功能完善的报表平台,它附带了易于使用的Visual Studio报表设计器和丰富的报表控件集,包括数据透视表、图表,因此您可以构建无与伦比、信息清晰的报表。 报表工具DevExpress Reporting v24.2将于近期发…

每天40分玩转Django:Django表单集

Django表单集 一、知识要点概览表 类别知识点掌握程度要求基础概念FormSet、ModelFormSet深入理解内联表单集InlineFormSet、BaseInlineFormSet熟练应用表单集验证clean方法、验证规则熟练应用自定义配置extra、max_num、can_delete理解应用动态管理JavaScript动态添加/删除表…

MVCC实现原理以及解决脏读、不可重复读、幻读问题

MVCC实现原理以及解决脏读、不可重复读、幻读问题 MVCC是什么?有什么作用?MVCC的实现原理行隐藏的字段undo log日志版本链Read View MVCC在RC下避免脏读MVCC在RC造成不可重复读、丢失修改MVCC在RR下解决不可重复读问题RR下仍然存在幻读的问题 MVCC是什么…

自学记录鸿蒙API 13:实现人脸比对Core Vision Face Comparator

完成了文本识别和人脸检测的项目后,我发现人脸比对是一个更有趣的一个小技术玩意儿。我决定整一整,也就是对HarmonyOS Next最新版本API 13中的Core Vision Face Comparator API的学习,这项技术能够对人脸进行高精度比对,并给出相似…

代码解析:安卓VHAL的AIDL参考实现

以下内容基于安卓14的VHAL代码。 总体架构 参考实现采用双层架构。上层是 DefaultVehicleHal,实现了 VHAL AIDL 接口,并提供适用于所有硬件设备的通用 VHAL 逻辑。下层是 FakeVehicleHardware,实现了 IVehicleHardware 接口。此类可模拟与实…