SDXS:知识蒸馏在高效图像生成中的应用

news2024/11/24 2:37:20

人工智能咨询培训老师叶梓 转载标明出处

扩散模型虽然在图像生成方面表现出色,但其迭代采样过程导致在低功耗设备上部署面临挑战,同时在云端高性能GPU平台上的能耗也不容忽视。为了解决这一问题,小米公司的Yuda Song、Zehao Sun、Xuanwu Yin等人提出了一种新的方法——SDXS,通过知识蒸馏简化了U-Net和图像解码器架构,并引入了一种创新的一步式DM训练技术,使用特征匹配和得分蒸馏,从而在单GPU上实现了大约100 FPS(比SD v1.5快30倍)和30 FPS(比SDXL快60倍)的推理速度。

图1为在图像生成时间限制为1秒的情况下,不同模型的性能对比。SDXL模型在这种情况下只能使用16次函数评估(NFEs)来生成稍微模糊的图像,而提出的SDXS-1024模型却能够生成30张清晰的图像。这表明SDXS-1024在保持图像质量的同时显著提高了生成速度。本方法还能够训练ControlNet,这是一种能够嵌入空间引导的网络,用于图像到图像的任务,如草图到图像的转换、修复和超分辨率等。证明了SDXS方法的灵活性和应用潜力。

方法

LDM框架由三个关键要素组成:文本编码器、图像解码器以及一个需要多次迭代以生成清晰图像的去噪模型。由于文本编码器的开销相对较低,因此优化其大小并不是研究的重点。

VAE优化:LDM框架通过将样本投影到计算效率更高的低维潜在空间,显著提高了高分辨率图像扩散模型的训练效率。这一过程通过使用预训练模型,如变分自编码器(Variational AutoEncoder, VAE)或向量量化变分自编码器(Vector Quantised-Variational AutoEncoder, VQVAE)来实现高比例图像压缩。VAE包含一个将图像映射到潜在空间的编码器,以及一个重建图像的解码器。其训练通过平衡重建损失、Kullback-Leibler (KL) 散度和GAN损失来优化。然而,训练中对所有样本同等对待引入了冗余。研究者们提出了一种VAE蒸馏(VD)损失,用于训练一个小型的图像解码器G: 其中,D是GAN判别器,用于平衡两个损失项,表示在8倍下采样图像上的L1损失。图2(a)展示了蒸馏小型图像解码器的训练策略。倡使用简化的CNN架构,不包含注意力机制和归一化层等复杂组件,只关注基本的残差块和上采样层。

U-Net优化: LDMs采用U-Net架构作为核心去噪模型,该架构结合了残差块和Transformer块。为了利用预训练的U-Nets的能力,同时减少计算需求和参数数量,研究者们采用了知识蒸馏策略,这一策略受到BK-SDM的块移除训练策略启发。这涉及从U-Net中选择性地移除残差和Transformer块,目的是训练一个更紧凑的模型,该模型仍能有效复现原始模型的中间特征图和输出。图2(b)展示了蒸馏小型U-Net的训练策略。知识蒸馏通过输出知识蒸馏(OKD)和特征知识蒸馏(FKD)损失实现:总的损失函数是两者的结合: 其中,λF​平衡两个损失项。与BK-SDM不同,研究者们排除了原始的去噪损失。模型基于SD-2.1基础版和SDXL-1.0基础版进行了小型化。对于SD-2.1基础版,研究者们去除了中间阶段、下采样阶段的最后阶段和上采样阶段的第一阶段,并去除了最高分辨率阶段的Transformer块。对于SDXL-1.0基础版,研究者们去除了大部分Transformer块。

ControlNet优化: ControlNet通过嵌入空间引导来增强扩散模型,使图像到图像的任务如草图到图像的转换、修复和超分辨率成为可能。它复制了U-Net的编码器架构和参数,并增加了额外的卷积层以纳入空间控制。尽管ControlNet继承了U-Net的参数并采用零卷积来提高训练稳定性,但其训练过程仍然成本高昂且显著受数据集质量影响。为了解决这些挑战,研究者们提出了一种蒸馏方法,将原始U-Net中的ControlNet蒸馏到小型U-Net中的相应ControlNet。图2(b)展示了这一过程,不是直接蒸馏ControlNet零卷积的输出,而是将ControlNet与U-Net结合,然后蒸馏U-Net的中间特征图和输出,这使得蒸馏后的ControlNet和小型U-Net能够更好地协同工作。考虑到ControlNet不影响U-Net编码器的特征图,特征蒸馏仅应用于U-Net的解码器。

尽管扩散模型(DMs)在图像生成方面表现出色,但它们依赖于多个采样步骤,即使采用先进的采样器,这也引入了显著的推理延迟。为了解决这个问题,先前的研究引入了知识蒸馏技术,例如渐进式蒸馏(progressive distillation)和一致性蒸馏(consistency distillation),旨在减少采样步骤并加速推理。然而,这些方法通常只能在4到8个采样步骤中产生清晰的图像,这与在生成对抗网络(GANs)中看到的一步式生成过程形成了鲜明对比。

直接训练一步式模型的方法包括初始化噪声ϵ,并使用常微分方程(ODE)采样器ψ进行采样以获得生成的图像,从而构建噪声-图像对。这些对在训练期间作为学生模型的输入和真实情况。然而,这种方法通常导致生成质量低下的图像。根本问题是使用预训练的DM生成的噪声-图像对的采样轨迹交叉,导致不适定问题。Rectified Flow通过拉直采样轨迹来解决这一挑战。它替换了训练目标,并提出了一种“重流”策略来优化配对,从而最小化轨迹交叉。

采样轨迹的交叉可能导致一个噪声输入对应多个真实图像,导致训练模型生成的图像是多个可行输出的加权和。为了解决这个问题,研究者们探索了改变权重方案以优先考虑更清晰图像的替代损失函数。在大多数情况下,可以使用L1损失、感知损失和LPIPS损失来改变权重形式。研究者们基于特征匹配的方法,计算由编码器模型生成的中间特征图上损失。具体来说,他们从DISTS损失中汲取灵感,对这些特征图应用结构相似性指数(SSIM)以获得更精细的特征匹配损失: 其中 是由编码器 编码的第 个中间特征图上计算的SSIM损失的权重,是由小型U-Net 生成的图像,是由原始U-Net xϕ​ 使用ODE采样器ψ生成的图像。在实践中,使用预训练的CNN骨干、ViT骨干和DM U-Net的编码器都能产生有利的结果,与MSE损失的比较在图6中展示。

尽管特征匹配损失可以产生几乎清晰的图像,但它未能实现真正的分布匹配,因此训练的模型只能作为正式训练的初始化。为了解决这一差距,Diff-Instruct中使用的训练策略,该策略旨在通过在时间步上匹配边际得分函数,使模型的输出分布与预训练模型的分布更紧密地对齐。然而,因为它需要在 t→T 时添加高水平的噪声以使目标得分可计算,此时估计的得分函数是不准确的。研究者们指出,扩散模型的采样轨迹从粗糙到精细,这意味着 t→T 时,得分函数提供了低频信息的梯度,而 t→0 时,它提供了高频信息的梯度。因此,研究者们将时间步分为两段:,后者被LFM替换,因为它可以提供足够的低频梯度。这种策略可以正式表示为: 其中 是在时间 t 和状态 下的函数,用于平衡两段的梯度,。研究者们有意将 α 设置接近1,并将 设置在高值,以确保模型的输出分布与预训练得分函数预测的分布平滑对齐。在概率密度显著重叠后,逐渐降低 α 和 。图3描述了训练策略,其中离线DM表示预训练DM的U-Net,在线DM是从离线DM初始化并在生成的图像上通过等式(1)微调得到的。在实践中,在线DM和学生DM交替训练,如算法1所示。

 一旦一步式DM训练完成,就可以像其他DM一样进行微调,以调整生成图像的风格。研究者们结合使用LoRA和提出的分段得分蒸馏来微调一步式DM,如图4所示。具体为将预训练的LoRA插入离线DM中,如果它也与教师DM兼容,也会插入到那里。要注意,不将LoRA插入在线DM中,因为它对应于一步式DM的输出分布。然后,使用与一步式训练相同的训练程序,但跳过特征匹配预热,因为LoRA微调比完全微调更稳定。另外当教师DM不能纳入预训练的LoRA时,使用降低的 。通过这种方式,可以将预训练的LoRA蒸馏到SDXS的LoRA中。

研究者们的方法也可以适应于ControlNet的训练,使微小的一步式模型能够在其图像生成过程中纳入图像条件,如图5所示。与用于文本到图像生成的基础模型相比,这里训练的模型是伴随前面提到的小型U-Net的蒸馏ControlNet,并且在训练期间U-Net的参数是固定的。重点是需要从教师模型采样的图像中提取控制图像,而不是从数据集图像中提取,以确保噪声、目标图像和控制图像形成一个配对三元组。此外,原始多步U-Net的伴随预训练ControlNet与在线U-Net和离线U-Net集成,但不参与训练。与文本编码器类似,其功能限于作为预训练的特征提取器。通过这种方式,为了进一步减少损失L,训练的ControlNet学习利用从目标图像中提取的控制图像。同时,得分蒸馏鼓励模型匹配边际分布,增强生成图像的上下文相关性。值得注意的是,研究发现用新初始化的噪声替换U-Net噪声输入的一部分可以增强控制能力。图5展示了基于特征匹配和得分蒸馏提出的一步式ControlNet训练策略。虚线表示梯度反向传播。

实验

研究者的代码是基于diffusers库开发的。由于他们无法访问SD v2.1基础版和SDXL的训练数据集,整个训练过程几乎是无数据的,完全依赖于公开可访问数据集中提供的提示。他们使用开源的预训练模型与这些提示结合,生成相应的图像。为了训练模型,他们将训练小批量大小配置在1,024到2,048之间。为了在现有硬件上适应这个批量大小,必要时他们有策略地实施了梯度累积。他们发现所提出训练策略导致模型生成的图像纹理较少。因此,在训练后,他们使用GAN损失结合极低秩的LoRA进行了短暂的微调。当需要GAN损失时,他们使用了StyleGAN-T中的Projected GAN损失,基本设置与ADD一致。对于SDXS-1024的训练,他们使用Vega,SDXL的紧凑版本,作为在线DM和离线DM的初始化,以减少训练开销。

表3为在MS-COCO 2017验证集上的定量结果,即FID和CLIP分数。由于FID对高斯分布的强烈假设,它不是衡量图像质量的一个好的指标,因为它受到生成样本多样性的显著影响。表3显示了MS-COCO 2017 5K子集上的性能比较,图7显示了一些示例。尽管模型大小和所需的采样步骤数量都有明显减少,但SDXS-512的提示跟随能力仍然优于SD v1.5。与Tiny SD(另一个为效率而设计的模型)相比,SDXS-512的优越性更加明显。这一观察结果也在SDXS-1024的性能中得到了一致的验证。使用所提方法训练LoRA的样本如图9所示。显然,模型生成的图像风格可以有效地转移到与离线DM集成的风格导向LoRA匹配的风格,同时通常保持场景布局的一致性。

研究者引入的一步式训练方法是足够通用的,可以应用于图像条件生成。他们展示了其在促进图像到图像转换方面的有效性,特别是利用ControlNet进行涉及canny边缘和深度图的转换。图8展示了两个不同任务的代表性示例,突出了生成图像紧密遵循控制图像提供的指导的能力。然而,这也揭示了在图像多样性方面的显著局限性。如图1所示,虽然问题可以通过替换提示来缓解,但它仍然是后续研究工作中加强的领域。

实验证明将高效的图像条件生成部署在边缘设备上是一个充满前景的研究方向,研究者计划在未来探索包括修复和超分辨率在内的更多应用。通过不断的技术创新和优化,人工智能在图像生成领域的应用将更加广泛和深入。

论文链接:https://arxiv.org/abs/2403.16627

项目地址:https://idkiro.github.io/sdxs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2064053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

财务报表解读指南:关键指标与分析方法详解

一、概述 财务报表中包含了丰富的信息,但如果在分析时缺乏明确的思路或忽略重点,很容易被复杂的数据搞得无所适从。本文将介绍财务报表中的关键指标,包括资产负债率的分析、净资产收益率的解读,以及销售复合增长率的计算&#xf…

如何生成随机数(通过rand函数,srand函数,time函数深入讲解)

目录 1. 随机数的生成 2. srand函数 3. time函数 4. 设置随机数的范围 1. 随机数的生成 既然是猜数字游戏,那么最终的数字答案肯定是重要的,我们要如何实现这个随机数的生成呢? 在这个功能上,C语言提供了一个函数叫rand&…

智算中心算力池化技术深度分析报告

智算中心算力池化技术深度分析 智能算力,人工智能基石,助力构建多要素融合信息基础设施。作为数字经济高质量发展核心引擎,智能算力基础设施建设正迎来高潮。 智算中心,作为集约化算力基础设施,以智能算力为核心&…

特征工程练手(四):特征选择

本文为和鲸python 特征工程入门与实践闯关训练营资料整理而来,加入了自己的理解(by GPT4o) 原活动链接 原作者:云中君,大厂后端研发工程师 目录 0、关卡总结1、前言2、基础知识讲解2.1get_best_model_and_accuracy2…

springboot集成kafka-生产者发送消息

springboot集成kafka发送消息 1、kafkaTemplate.send()方法1.1、springboot集成kafka发送消息Message对象消息1.2、springboot集成kafka发送ProducerRecord对象消息1.3、springboot集成kafka发送指定分区消息 2、kafkaTemplate.sendDefault()方法3、kafkaTemplate.send(...)和k…

案例-异常

题目: (如果一开始不知道如何用异常的语法写,可先用如if语句代替try...catch,最后再把if优化为try...catch) 代码: javabean类: 测试类:

Java CompletableFuture:你真的了解它吗?

文章目录 1 什么是 CompletableFuture?2 如何正确使用 CompletableFuture 对象?3 如何结合回调函数处理异步任务结果?4 如何组合并处理多个 CompletableFuture? 1 什么是 CompletableFuture? CompletableFuture 是 Ja…

springboot静态资源访问问题归纳

以下内容基于springboot 2.3.4.RELEASE 1、默认配置的springboot项目,有四个静态资源文件夹,它们是有优先级的,如下: "classpath:/META-INF/resources/", (优先级最高) "classpath:/reso…

【精选】基于Spark的国漫推荐系统(精选设计产品)

目录: 系统开发技术 Python可视化技术 Django框架 Hadoop介绍 Scrapy介绍 IDEA介绍 B/S架构 MySQL数据库介绍 系统流程分析 操作流程 添加信息流程 删除信息流程 系统系统介绍: 可以查看我的B站: 系统测试 运行环境 软件平台 硬…

docker-compose安装NebulaGraph 3.8.0

文章目录 一. 安装NebulaGraph1.1 通过 Git 克隆nebula-docker-compose仓库的3.8.0分支到主机1.2 部署1.3 卸载1.4 查看 二. 安装NebulaGraph Studio2.1 下载 Studio 的部署配置文件2.2 创建nebula-graph-studio-3.10.0目录,并将安装包解压至目录中2.3 解压后进入 n…

shaushaushau1

CVE-2023-7130 靶标介绍: College Notes Gallery 2.0 允许通过“/notes/login.php”中的参数‘user’进行 SQL 注入。利用这个问题可能会使攻击者有机会破坏应用程序,访问或修改数据. 已经告诉你在哪里存在sql注入了,一般上来应该先目录扫…

【补充篇】AUTOSAR多核OS介绍(下)

文章目录 前文回顾1 AUTOSAR OS1.1 AUTSOAR OS元素1.1.1 操作系统对象1.1.2 操作系统应用程序1.1.3 AUTOSAR OS裁剪类型1.1.4 AUTOSAR OS软件分区1.2 AUTOSAR OS自旋锁1.3 AUTOSAR OS核间通信1.4 AUTOSAR OS多核调度前文回顾 在上篇文章【补充篇】AUTOSAR多核OS介绍(上)中,…

对于一个36岁的人来说,现在转行AI大模型还来得及吗?

前言 在职场生涯中,33岁似乎是一个尴尬的年龄。许多人在这个阶段已经定型,难以寻求新的突破。然而,随着科技行业的飞速发展,人工智能成为了新时代的宠儿。那么,对于一个33岁的人来说,现在转行AI大模型还来…

做SSH实验下载 paramiko库

今天做SSH实验下载paramiko库文件一直出问题,后面库文件下好了还是报错,这里记录了我的解决方案。 pycharm修改默认下载路径为国内镜像(我这里用清华大学的镜像下载快一些) Simple Index 到这里路径就改好了,接下来就…

从就业出发,深度剖析大数据行业的现状与前景

以一个经典案例引入——啤酒与纸尿裤的故事。 20世纪90年代,沃尔玛从购物的后台信息数据中,发现很多买了纸尿裤的男士会同时买啤酒。后来,调查发现,此类人多是被“轰出来”买纸尿裤,一想到养娃压力大,心情…

牛客竞赛数据结构专题班树状数组、线段树练习题

牛客竞赛_ACM/NOI/CSP/CCPC/ICPC算法编程高难度练习赛_牛客竞赛OJ G 智乃酱的平方数列(线段树,等差数列,多项式) 题目描述 想必你一定会用线段树维护等差数列吧?让我们来看看它的升级版。 请你维护一个长度为510 ^5…

Mysql高级 [Linux版] 性能优化 数据库系统配置优化 和 MySQL的执行顺序 以及 Mysql执行引擎介绍

数据库系统配置优化 1、定义 数据库是基于操作系统的,目前大多数MySQL都是安装在linux系统之上,所以对于操作系统的一些参数配置也会影响到MySQL的性能,下面就列出一些常用的系统配置。 2、优化配置参数-操作系统 优化包括操作系统的优化及My…

集运系统:如何实现不同员工的不同操作权限?

在集运行业,员工的角色和职责各有不同,因此对系统的操作权限需求也不尽相同。为了确保数据的安全性和业务的顺利进行,易境通集运系统提供了灵活的权限管理功能,让企业可以根据员工的角色和职责,设置不同的操作权限。 易…

Redis (day 3)

一、通过jedis连接数据库 1.首先导入依赖 <!-- https://mvnrepository.com/artifact/redis.clients/jedis --><dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId><version>5.1.0</version></de…

mac 微信数据直接存储到移动硬盘

在apple设备上存储都是1500块/128gb的价格收取的&#xff0c;真的是寸土寸金。在手机已经占用了一遍存储空间之后&#xff0c;微信备份还要占用一遍。 iCloud备份微信聊天记录的稳定性真的非常差劲&#xff0c;比如我微信30g&#xff0c;经常恢复到20g左右就被打断&#xff0c;…