Stable Diffusion——SDXL Turbo让 AI 出图速度提高10倍

news2024/11/26 0:25:47

摘要

在本研究中,我们提出了一种名为对抗扩散蒸馏(ADD)的创新训练技术,它能够在1至4步的采样过程中,高效地对大规模基础图像扩散模型进行处理,同时保持图像的高质量。该方法巧妙地结合了分数蒸馏技术,利用现有的大型图像扩散模型作为指导信号,同时引入对抗损失,确保在较少的采样步骤下依然能够产生高保真度的图像。通过这种结合,我们的模型在单步操作中的表现显著超越了现有的少步骤方法,包括生成对抗网络(GAN)和潜在一致性模型。更为引人注目的是,ADD在仅经过四个采样步骤的情况下,就能达到与当前最先进的扩散模型SDXL相媲美的性能水平。这一成就标志着ADD成为首个能够利用基础模型实现单步实时图像合成的技术,为图像生成领域带来了新的突破。

论文地址: https://stability.ai/research/adversarial-diffusion-distillation

1.简介

扩散模型(DMs)在生成建模领域扮演着核心角色,尤其在高质量图像和视频合成方面取得了显著的进展。DMs的一个关键优势在于它们的可伸缩性和迭代性,使得它们能够处理复杂的任务,如根据自由格式文本提示合成图像。然而,DMs的迭代推理过程需要大量的采样步骤,这限制了它们在实时应用中的使用。相比之下,生成式对抗网络(GANs)以其单步生成和快速特性而闻名。尽管GANs尝试扩展到大型数据集,但在样本质量方面往往无法与DMs相媲美。本研究的目标是将DMs的高质量样本生成能力与GANs的速度优势结合起来。

我们提出了一种概念上简单的方法——对抗扩散蒸馏(ADD),这是一种将预训练扩散模型的推理步骤减少到1至4个采样步骤的通用方法,同时保持高采样保真度,并可能提升模型的整体性能。我们结合了两个训练目标:对抗损失和与分数蒸馏抽样(SDS)相对应的蒸馏损失。对抗损失确保模型在每次前向传递时直接生成高质量的图像样本,避免了其他蒸馏方法中常见的模糊和伪影。蒸馏损失则利用另一个预训练的DM作为教师模型,有效地利用其丰富的知识,并保持大型DMs的强组合性。在推理过程中,我们的方法避免了无分类器引导,从而进一步降低了内存需求,同时保留了通过迭代细化改进结果的能力,这比以往的基于GAN的一步方法更具优势。

主要贡献包括:

  • 引入了ADD,这是一种能够将预训练的扩散模型转化为高保真的实时图像生成器的方法,仅需1至4个采样步骤。
  • 我们的方法结合了对抗性训练和分数蒸馏,经过精心设计,显著优于现有的强基线模型,如LCM、LCM-XL和单步GAN,并能够在单个推理步骤中处理复杂的图像组成,同时保持高图像真实感。
  • 使用四个采样步骤的ADD-XL在512像素分辨率上的表现甚至超过了其教师模型SDXL-Base。

2. 研究背景

扩散模型在合成和编辑高分辨率图像及视频方面取得了显著成就,但其迭代性质限制了在实时应用中的使用。潜在扩散模型尝试通过在潜在空间中表示图像来降低计算成本,但仍然依赖于大型模型的迭代处理。为了解决这一问题,研究者们探索了更快的采样器和模型蒸馏技术,如渐进蒸馏和引导蒸馏,尽管这些方法减少了采样步骤,但可能会牺牲原始性能,并需要迭代训练过程。

一致性模型通过在ODE轨迹上施加一致性正则化来提高性能,并在少量样本设置中展示了基于像素的模型的强大能力。其他方法,如InstaFlow,推荐使用修正流来优化蒸馏过程。然而,这些方法通常在较少的采样步骤下会产生模糊和伪影问题。

与此同时,生成对抗网络(GANs)作为单步模型,虽然在文本到图像合成方面表现出快速的采样速度,但其性能通常不如基于扩散的模型。这是因为GANs的训练需要精细地平衡特定架构以实现稳定的对抗目标,而且在不破坏这种平衡的情况下扩展模型并集成神经网络架构的进步是具有挑战性的。目前最先进的文本到图像GANs缺乏像无分类器制导这样的技术,这对于大规模的DMs来说至关重要。

分数蒸馏抽样是一种新提出的方法,用于将基础T2I模型的知识转移到3D综合模型中。一些基于分数蒸馏抽样的作品专注于3D对象的逐场景优化,并已应用于文本到3D视频合成和图像编辑。最近的研究展示了基于分数的模型和GANs之间的紧密联系,并提出了使用来自DM的基于分数的扩散流进行训练的Score gan。类似地,DiffInstruct是一种推广分数蒸馏抽样的方法,它可以将预训练的扩散模型转化为不带鉴别器的生成器。

我们的方法结合了对抗性训练和分数蒸馏,旨在解决当前几步生成模型中的最佳表现问题。通过这种混合目标的方法,我们旨在提升模型的性能,同时减少采样步骤,以实现更快、更高质量的图像合成。

3. 研究方法

我们的目标是在尽可能少的采样步骤中生成高保真度的样本,同时匹配最先进模型的质量[7,50,53,55]。对抗性目标[14,60]自然适合快速生成,因为它训练的模型可以在单个向前步骤中输出图像。然而,将gan扩展到大型数据集的尝试[58,59]发现,不仅要依赖鉴别器,还要使用预训练的分类器或CLIP网络来改善文本对齐,这一点至关重要。如文献[59]所述,过度使用判别网络会引入伪影,影响图像质量。相反,我们通过分数蒸馏目标利用预训练扩散模型的梯度来改善文本对齐和样本质量。此外,我们使用预训练的扩散模型权值初始化模型,而不是从头开始训练;已知预训练生成器网络可以显著改善具有对抗性损失的训练[15]。最后,我们采用了标准的扩散模型框架,而不是使用仅用于GAN训练的解码器架构[26,27]。这种设置自然支持迭代细化。

3.1. 训练过程

我们的训练过程如图2所示,涉及三个网络:ADD-student从一个权重为θ的预训练UNet-DM初始化,一个权重为φ的可训练鉴别器初始化,一个权重为ψ的DM教师初始化。在训练过程中,ADD-student从有噪声的数据xs中生成样本δ xθ(xs, s)。通过前向扩散过程xs = αsx0 + σ sε,从真实图像数据集x0产生带噪数据点。在我们的实验中,我们使用相同的系数αs和σs作为学生DM和样本s,均匀地来自集合Tstudent = {τ1,…, τn} (N个选择的学生时间步长)在实践中,我们选择N = 4。

图片

重要的是,我们设置τn = 1000,并在训练过程中强制执行zero-terminalSNR[33],因为模型在推理过程中需要从纯噪声开始。

对于对抗目标,生成的样本x0和真实图像x0被传递给鉴别器,目的是区分它们。鉴别器和对抗性损耗的设计将在3.2节中详细描述。为了从DM老师那里提取知识,我们用老师的前向过程将学生样本δ xθ扩散到δ xθ,t,并使用老师的去噪预测δ xψ(δ xθ,t, t)作为蒸馏损失Ldistill的重建目标,参见3.3节。因此,总体目标是

图片

虽然我们在像素空间中制定了我们的方法,但很容易将其适应于在潜在空间中工作的ldm。当使用师生共享潜在空间的ldm时,蒸馏损失可以在像素或潜在空间中计算。我们在像素空间中计算蒸馏损失,因为这在蒸馏潜在扩散模型时产生更稳定的梯度[72]。

3.2. 对抗损失

对于鉴别器,我们遵循[59]中提出的设计和训练程序,并对其进行简要总结;有关细节,我们建议读者参阅原文。我们使用一个固定的预训练特征网络F和一组可训练的轻量级鉴别器头d φ,k。对于特征网络F, Sauer等[59]发现视觉变压器(ViTs)[9]工作良好,我们在第4节中对ViTs目标和模型尺寸进行了不同的选择。将可训练判别器头应用于特征网络不同层的特征Fk上。

为了提高性能,鉴别器可以通过投影以附加信息为条件[46]。通常,在文本到图像设置中使用文本嵌入ctext。但是,与标准GAN训练相比,我们的训练配置还允许对给定图像进行条件设置。当τ < 1000时,ADD-student接收到来自输入图像x0的信号。因此,对于给定的生成样本xθ(xs, s),我们可以根据来自x0的信息来约束鉴别器。这鼓励add学生有效地利用输入。在实践中,我们使用附加的特征网络来提取嵌入cimg的图像。

接下来[57,59],我们使用hinge loss[32]作为对抗目标函数。因此,add学生的对抗性目标Ladv(xθ(xs, s), ϕ)等于

图片

而鉴别器被训练成最小化

图片

其中R1为R1梯度惩罚[44]。我们不是计算相对于像素值的梯度惩罚,而是在每个鉴别器头dpn,k的输入上计算它。我们发现,当输出分辨率大于128像素时,R1惩罚特别有益。

3.3. 分数蒸馏损失

式(1)中的蒸馏损失表示为

图片

式中sg为停止梯度操作。直观地说,损失使用距离度量d来测量由ad -student和DM-teacher的输出生成的样本xθ之间的不匹配:xψ(φ xθ,t, t) = (φ xθ,t - σt φ ϵψ(φ xθ,t, t))/αt随时间步长t和噪声ε′的平均值。值得注意的是,教师并没有直接应用于ADD-student的生成的δ xθ上,而是应用于扩散的输出δ xθ上,t = αt δ xθ + σ δ δ ',因为对于教师模型来说,非扩散的输入将不在分布范围内[68]。

下面,我们定义距离函数d(x, y):= ||x−y||2。对于加权函数c(t),我们考虑两种选择:指数加权,其中c(t) = αt(噪声水平越高贡献越小),以及分数蒸馏抽样(SDS)加权[51]。在补充材料中,我们证明了当d(x, y) = ||x−y||2以及c(t)的特定选择时,我们的蒸馏损失等于SDS目标LSDS,如[51]中提出的那样。我们的公式的优点是它能够实现重建目标的直接可视化,并且它自然地促进了几个连续去噪步骤的执行。最后,我们还评估了无噪声分数蒸馏(NFSD)目标,这是最近提出的SDS的变体[28]。

4. 实验

在我们的实验中,我们训练了两个不同容量的模型,ADD-M (860M参数)和ADD-XL (3.1B参数)。为了消融ADD-M,我们使用了稳定扩散(SD) 2.1主干[54],为了与其他基线进行公平比较,我们使用了SD1.5。ADD-XL使用SDXL[50]骨干网。所有实验均在512x512像素的标准化分辨率下进行;产生更高分辨率的模型的输出被下采样到这个大小。

我们在所有实验中采用λ = 2.5的蒸馏加权因子。此外,R1惩罚强度γ设为10−5。对于判别器条件,我们使用预训练的clip - vit -g-14文本编码器[52]来计算文本嵌入上下文,并使用DINOv2 vit- l编码器[47]的CLS嵌入来进行图像嵌入。对于基线,我们使用了最好的公开可用模型:潜在扩散模型[50,54](SD1.5, SDXL)级联像素扩散模型[55](IF-XL),蒸馏扩散模型[39,41](LCM-1.5, LCM-1.5- xl),以及OpenMUSE [48],这是MUSE[6]的重新实现,是专门为快速推理而开发的转换器模型。注意,我们比较的是没有附加细化模型的sdxl - base -1.0模型;这是为了保证公平的比较。由于没有公开的最先进的GAN模型,我们使用改进的鉴别器重新训练StyleGAN-T[59]。这个基线(stylegan - t++)在FID和CS中明显优于之前最好的gan,见补充。我们通过FID[18]量化样本质量,通过CLIP评分[17]量化文本对齐。对于CLIP评分,我们使用LAION-2B训练的ViT-g-14模型[61]。这两个指标在来自COCO2017的5k个样本上进行了评估[34]。

4.1. 消融实验

我们的训练设置在对抗性损失、蒸馏损失、初始化和损失相互作用方面开辟了许多设计空间。我们对表1中的几个选择进行了消融研究;每个表格下面突出显示了关键的见解。我们将在下文中讨论每个实验。

图片

鉴别器特征网络。(表1)。Stein等人[67]最近的见解表明,使用CLIP[52]或DINO[5,47]目标训练的vit特别适合评估生成模型的性能。同样地,这些模型作为鉴别器特征网络似乎也很有效,其中DINOv2成为最佳选择。

鉴别器。(表1 b)。与先前的研究类似,我们观察到鉴别器的文本条件反射增强了结果。值得注意的是,图像条件反射优于文本条件反射,并且文本和图像条件反射的组合产生了最好的结果。

学生pretraining。(表1 c)。我们的实验证明了对add学生进行预训练的重要性。与纯GAN方法相比,能够使用预训练的生成器是一个显著的优势。gan的一个问题是缺乏可扩展性;Sauer等[59]和Kang等[25]都观察到在达到一定网络容量后,性能会达到饱和。这一观察结果与dm一般平滑的标度规律形成了对比[49]。然而,ADD可以有效地利用更大的预训练DM(见表1c),并从稳定的DM预训练中获益。

损失项。(表1 d)。我们发现这两种损失都是必不可少的。蒸馏损失本身是无效的,但当与对抗损失相结合时,结果会有明显的改善。不同的权重计划导致不同的行为,指数计划倾向于产生更多样化的样本,正如较低的FID, SDS和NFSD计划所表明的那样,可以提高质量和文本对齐。当我们使用指数计划作为所有其他消融的默认设置时,我们选择NFSD加权来训练我们的最终模型。选择最优权重函数提供了改进的机会。或者,如3D生成建模文献[23]中所探讨的,可以考虑在训练上调度蒸馏权值。

老师类型。(表1 e)。有趣的是,一个更大的学生和老师并不一定会带来更好的FID和CS。相反,学生采用教师的特点。SDXL通常获得更高的FID,可能是因为其输出的多样性较少,但它具有更高的图像质量和文本对齐[50]。

老师的步数。(表1)。虽然我们的蒸馏损失公式允许通过构造与老师一起采取连续的几个步骤,但我们发现几个步骤并不能最终导致更好的性能。

4.2. 与最新技术的定量比较

对于我们与其他方法的主要比较,我们避免使用自动化指标,因为用户偏好研究更可靠[50]。在这项研究中,我们的目标是评估及时的依从性和整体形象。作为一种性能度量,我们在比较几种方法时计算两两比较的获胜百分比和ELO分数。对于报告的ELO分数,我们计算提示跟随和图像质量之间的平均分数。关于ELO分数计算和研究参数的详细信息在补充材料中列出。

图片

图片

图5、图6为研究结果。最重要的结果是:首先,ADD-XL仅用一步就优于LCM-XL(4步)。其次,ADD-XL在大多数比较中可以击败SDXL(50步)。这使得ADD-XL在单步和多步设置中都是最先进的。图7显示了相对于推理速度的ELO分数。最后,表2比较了使用相同基本模型的不同小步取样和蒸馏方法。ADD优于所有其他方法,包括标准的八步DPM求解器。

图片

4.3. 定性结果

为了补充我们上面的定量研究,我们在本节中提出了定性结果。为了描绘更完整的画面,我们在补充材料中提供了额外的样本和定性比较。图3比较了ADD-XL(1步)和当前的最佳基线。ADD-XL的迭代采样过程如图4所示。这些结果显示了我们的模型在初始样本基础上的改进能力。这种迭代改进代表了与纯GAN方法(如stylegan - t++)相比的另一个显著优势。最后,图8将ADD-XL与其教师模型SDXL-Base进行了直接比较。从4.2节的用户研究中可以看出,ADD-XL在质量和快速对齐方面都优于它的老师。

图片

图片

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1578731.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用TensorBoard可视化PyTorch

一、TensorBoard与PyTorch配合使用的基本步骤 PyTorch可以直接与TensorBoard进行集成&#xff0c;因为TensorBoard是一个独立于TensorFlow之外的可视化工具。TensorBoard被设计为支持机器学习实验的可视化&#xff0c;如训练的进度和结果等。PyTorch中的torch.utils.tensorboa…

【数据结构】考研真题攻克与重点知识点剖析 - 第 6 篇:图

前言 本文基础知识部分来自于b站&#xff1a;分享笔记的好人儿的思维导图与王道考研课程&#xff0c;感谢大佬的开源精神&#xff0c;习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析&#xff0c;本人技术…

智慧安防系统EasyCVR视频汇聚平台接入大华设备无法语音对讲的原因排查与解决

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台支持7*24小时实时高清视频监控&#xff0c;能同时播放多路监控视频流&#xff0c;视频画面1、4、9、16个可选&#xff0c;支持自定义视频轮播。EasyCVR平台可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标…

数据可视化-ECharts Html项目实战(10)

在之前的文章中&#xff0c;我们学习了如何在ECharts中编写雷达图&#xff0c;实现特殊效果的插入运用&#xff0c;函数的插入&#xff0c;以及多图表雷达图。想了解的朋友可以查看这篇文章。同时&#xff0c;希望我的文章能帮助到你&#xff0c;如果觉得我的文章写的不错&…

甲方安全建设之研发安全-SCA

前言 大多数企业或多或少的会去采购第三方软件&#xff0c;或者研发同学在开发代码时&#xff0c;可能会去使用一些好用的软件包或者依赖包&#xff0c;但是如果这些包中存在恶意代码&#xff0c;又或者在安装包时不小心打错了字母安装了错误的软件包&#xff0c;则可能出现供…

shrine-攻防世界

题目 代码 import flask import os app flask.Flask(__name__) app.config[FLAG] os.environ.pop(FLAG) app.route(/) def index(): return open(__file__).read() app.route(/shrine/) def shrine(shrine): def safe_jinja(s): s s.replace((, ).replace(), ) …

算法之美:缓存数据淘汰算法分析及分解实现

在设计一个系统的时候&#xff0c;由于数据库的读取速度远小于内存的读取速度&#xff0c;那么为加快读取速度&#xff0c;需先将一部分数据加入到内存中&#xff08;该动作称为缓存&#xff09;&#xff0c;但是内存容量又是有限的&#xff0c;当缓存的数据大于内存容量时&…

nodejs+python基于vue的羽毛球培训俱乐部管理系统django

语言&#xff1a;nodejs/php/python/java 框架&#xff1a;ssm/springboot/thinkphp/django/express 请解释Flask是什么以及他的主要用途 Flask是一个用Python编写的清凉web应用框架。它易于扩展且灵活&#xff0c;适用于小型的项目或者微服务&#xff0c;以及作为大型应用的一…

spring eureka 服务实例实现快速下线快速感知快速刷新配置解析

背景 默认的Spring Eureka服务器&#xff0c;服务提供者和服务调用者配置不够灵敏&#xff0c;总是服务提供者在停掉很久之后&#xff0c;服务调用者很长时间并没有感知到变化。或者是服务已经注册上去了&#xff0c;但是服务调用方很长时间还是调用不到&#xff0c;发现不了这…

【Mysql高可用集群-双主双活-myql+keeplived】

Mysql高可用集群-双主双活-myqlkeeplived 一、介绍二、准备工作1.两台centos7 linux服务器2.mysql安装包3.keepalived安装包 三、安装mysql1.在128、129两台服务器根据《linux安装mysql服务-两种安装方式教程》按方式一安装好mysql应用。2.修改128服务器/etc/my.cnf配置文件&am…

第8章 数据集成和互操作

思维导图 8.1 引言 数据集成和互操作(DII)描述了数据在不同数据存储、应用程序和组织这三者内部和之间进行移动和整合的相关过程。数据集成是将数据整合成物理的或虚拟的一致格式。数据互操作是多个系统之间进行通信的能力。数据集成和互操作的解决方案提供了大多数组织所依赖的…

携程旅行 abtest

声明: 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01;wx a15018601872 本文章…

Java 基于微信小程序的助农扶贫小程序

博主介绍&#xff1a;✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3fb; 不…

React - 你知道useffect函数内如何模拟生命周期吗

难度级别:中级及以上 提问概率:65% 很多前端开发人员习惯了Vue或者React的组件式开发,熟知组件的周期过程包含初始化、挂载完成、修改和卸载等阶段。但是当使用Hooks做业务开发的时候,看见一个个useEffect函数,却显得有些迷茫,因为在us…

Flutter之Flex组件布局

目录 Flex属性值 轴向:direction:Axis.horizontal 主轴方向:mainAxisAlignment:MainAxisAlignment.center 交叉轴方向:crossAxisAlignment:CrossAxisAlignment 主轴尺寸:mainAxisSize 文字方向:textDirection:TextDirection 竖直方向排序:verticalDirection:VerticalDir…

Java 线程池 参数

1、为什么要使用线程池 线程池能有效管控线程&#xff0c;统一分配任务&#xff0c;优化资源使用。 2、线程池的参数 创建线程池&#xff0c;在构造一个新的线程池时&#xff0c;必须满足下面的条件&#xff1a; corePoolSize&#xff08;线程池基本大小&#xff09;必须大于…

JVM流程图自我总结

JVM流程图总览 运行时数据区是否有GC、OOM图 从线程共享角度区别图

【深度学习】最强算法之:图神经网络(GNN)

图神经网络 1、引言2、图神经网络2.1 定义2.2 原理2.3 实现方式2.4 算法公式2.4.1 GNN2.4.2 GCN 2.5 代码示例 3、总结 1、引言 小屌丝&#xff1a;鱼哥&#xff0c;给俺讲一讲图神经网络啊 小鱼&#xff1a;你看&#xff0c;我这会在忙着呢 小屌丝&#xff1a;啊~ 小鱼&#…

如何在Rust中操作JSON

❝ 越努力&#xff0c;越幸运 ❞ 大家好&#xff0c;我是「柒八九」。一个「专注于前端开发技术/Rust及AI应用知识分享」的Coder。 前言 我们之前在Rust 赋能前端-开发一款属于你的前端脚手架中有过在Rust项目中如何操作JSON。 由于文章篇幅的原因&#xff0c;我们就没详细介绍…

java算法day48 | 动态规划part09 ● 198.打家劫舍 ● 213.打家劫舍II ● 337.打家劫舍III

198.打家劫舍 class Solution {public int rob(int[] nums) {if(nums.length0) return 0;if(nums.length1) return nums[0];int[] dpnew int[nums.length];dp[0]nums[0];dp[1]Math.max(nums[1],nums[0]);for(int i2;i<nums.length;i){dp[i]Math.max(dp[i-1],dp[i-2]nums[i])…