用 Transformer 替换 diffusion 的U-Net:可伸缩的 diffusion 模型

news2024/11/22 23:11:20

论文标题: Scalable diffusion models with transformers

论文链接:https://openaccess.thecvf.com/content/ICCV2023/html/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.html

代码:https://github.com/facebookresearch/DiT/blob/main/README.md

7f5a4edd3bf8dd430bbd8f657714f661.png

引用:Peebles W, Xie S. Scalable diffusion models with transformers[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023: 4195-4205.

导读

本文探索了一种基于 transformer 体系结构的新型扩散模型。作者使用Transformer来训练图像的潜在扩散模型,取代了通常使用的U-Net骨干网络,这个Transformer操作在潜在图像块上。研究还分析了这种新模型的可伸缩性,通过Gflops(每秒十亿次浮点运算)来衡量前向传播复杂性。研究发现,具有更高Gflops的Diffusion Transformers(DiTs)——通过增加Transformer的深度/宽度或增加输入标记的数量——通常具有更低的FID(Frechet Inception Distance)。此外,研究中最大的DiT-XL/2模型在类别条件的ImageNet 512x512和256x256基准上表现出色,取得了后者的最先进 FID 成绩为2.27。

本文贡献

提出了一种新的扩散模型架构,称之为Diffusion Transformers(DiTs)。这一架构基于Transformers,用于图像生成任务。

研究表明,传统的扩散模型架构中常用的U-Net骨干并不是性能的关键因素。他们成功地将U-Net替换为标准的Transformer架构,这意味着扩散模型可以采用更通用的设计,如Transformers,而不受限于特定的架构。

通过使用DiTs架构,研究者在 ImageNet 生成基准上实现了显著的性能提升,将FID(Frechet Inception Distance)降低到2.27,达到了最新的最先进水平。

预备知识

扩散模型的基本理论

高斯扩散模型假设了一个前向噪声过程,并逐渐将噪声应用于真实数据:

b4255b8e6c184f8a23e1b96ef4ea3a06.png
976abaa09c8464dbd8f5b903c2a5cdbb.png

通过应用重新参数化的技巧,我们可以进行采样:

2214e66cffe4838b72b4b7773426bf24.png

扩散模型的训练是为了学习反向过程,即将前向过程中的损坏恢复成原始数据的过程:

50979fad211e8d9142df244a733798ae.png
3cb9a4c27d5751327dbd0f5fab406ccf.png

训练反向过程模型时,使用了变分下界(variational lower bound)来估计x0的对数似然:

f7616c3bd337d3ce2ac4e5d603c19cde.png
9df90b514f6a30be5339bb3683ac2a33.png

通过将µθ重新参数化为噪声预测网络 εθ,该模型可以使用预测噪声

41ba4aafdfb98bdda05589a6d9a1ec91.png

和地面真实采样高斯噪声 εt 之间的简单均方误差进行训练: b6654b7c197d4b838ce5103fb6f02fe9.png

但是,为了用学习到的反向过程协方差

2bebf58130f766239d5eff636ae68029.png
来训练扩散模型,需要优化完整的D_KL项,我们遵循 Nichol 和 Dhariwal 的方法。

无分类器引导

条件扩散模型将额外信息作为输入,如类别标签c。在这种情况下,反向过程变为:

400d4aa5c862b66ded6aeb84636f738f.png

在这种情况下,可以使用无分类器的指导来鼓励采样程序找到x,从而使 log p(c|x) 变高。根据贝叶斯规则:

e6e83c9c252e26882513b40e3bf3ea28.png

因此,

a5e0470a89c3710c78d40d3f186db20e.png

所以在想要条件的概率较大,就可以将条件的梯度增加到优化目标里,最终可以表示成如下形式:

00cb9787fa4e621505a4702e457b9816.png

无分类器引导已被广泛认为能够显著提高样本生成的质量,而这一趋势在DiTs模型中同样有效。

潜在扩散模型(Latent diffusion models)

直接在高分辨率像素空间中训练扩散模型在计算上是代价高昂的。潜在扩散模型(LDMs)通过一个两阶段的方法来解决这一问题:首先,学习一个自编码器,将图像压缩为具有学习编码器E的较小空间表示;其次,训练表示z = E(x)的扩散模型,而不是图像x的扩散模型(E是冻结的)。然后,可以通过从扩散模型中采样表示z,然后使用学习的解码器进行解码,生成新的图像x = D(z)。如图2所示,潜在扩散模型在使用像ADM这样的像素空间扩散模型的Gflops的一小部分的情况下实现了良好的性能。因为作者关注计算效率,这使得它们成为架构探索的吸引人的起点。

62ab5cda45171d12005d997ac1add895.png

本文方法

e79b660077ac7646dde2d52b3b87616b.png

Patchify: DiT的输入是一个空间表示z(对于256x256x3的图像,z的形状为32x32x4)。DiT的第一层是“patchify”,它通过线性嵌入输入中的每个图像块,将空间输入转换为T个维度为d的标记序列。随后,我们对所有输入标记应用标准的ViT基于频率的位置嵌入(正弦-余弦版本)。通过patchify创建的标记数量T由补丁大小的超参数p确定。

如图4所示,将p减半会使T增加四倍,从而至少使总的Transformer Gflops增加四倍。尽管对Gflops有显著影响,但需要注意的是,更改p对下游参数数量没有实际影响。

作者将p设置为2、4和8。

5cce97db4dc59bbb8f95281da977c96a.png

DiT块设计:在经过patchify之后,输入标记由一系列Transformer块进行处理。除了噪声图像输入,扩散模型有时还处理额外的条件信息,如噪声时间步t、类别标签c、自然语言等。研究者探索了四种不同处理条件输入的Transformer块变体。这些设计对标准的ViT块设计进行了小而重要的修改。所有块的设计都在图3中显示(经过实验分析最终作者选择了adaLN块)。下面介绍这四种块。

  • In-context Conditioning:这种方法简单地将t和c的向量嵌入作为两个额外的标记附加在输入序列中,对待它们与图像标记没有区别。这类似于ViTs中的cls标记,它允许我们在不进行修改的情况下使用标准的ViT块。在最后一个块之后,将条件化标记从序列中删除。这种方法对模型引入了几乎可以忽略的新Gflops开销。

  • Cross-Attention Block:这种方法将t和c的嵌入连接成一个长度为 2 的序列,与图像标记序列分开。Transformer块进行了修改,包括多头自注意力块之后的多头跨注意力层,跨注意力块为模型添加的Gflops最多,大约增加了15%的开销。

  • Adaptive Layer Norm (adaLN) Block:这种方法基于GANs和具有UNet骨干的扩散模型中广泛使用的自适应标准化层,将Transformer块中的标准规范层替换为自适应规范(adaLN)。与直接学习γ和β等参数不同,它们从t和c的嵌入向量之和中回归得出。在作者研究的三种块设计中,adaLN添加的Gflops最少,因此计算效率最高。这也是唯一一种将相同函数应用于所有标记的条件化机制。

  • adaLN-Zero Block:之前的研究发现,将每个残差块初始化为恒等函数是有益的。为了实现这一目标,作者探索了adaLN DiT块的修改版本,该版本与之前类似。除了回归γ和β,他们还回归了应用于DiT块内的任何残差连接之前的维度缩放参数α。他们将MLP初始化为对所有α输出零向量,从而将整个DiT块初始化为恒等函数。与普通的adaLN块一样,adaLN-Zero对模型添加的Gflops几乎可以忽略不计。

Transformer decoder:在DiT架构的最后一个DiT块之后,需要将图像标记序列解码为输出的噪声预测和对角协方差预测。这两个输出的形状与原始的空间输入相同。为了实现这一目标,作者使用了标准的线性解码器,将最后的层规范(如果使用adaLN,则为自适应)应用于每个标记,并线性解码为一个p x p x 2C张量,其中C是DiT输入中的通道数。最后,将解码后的标记重新排列成原始的空间布局,得到了噪声和协方差的预测。

实验

实验结果

不同扩散模型的效果对比如下:

1726957096ef3bef1fef197b3490fe2c.png

从下图可以看出adaLN-Zero方法明显好于cross-attention和in-contenxt,所以实验中均采用adaLN-Zero方法进行上下文交互:

86c9aaee9595d10553aa02b106d86e43.png

缩放DiT模型可以提高训练的所有阶段的FID:

c9a9b9375d75f8aa8de14056d09e386c.png

模型越大、patch size越小生成图像质量越好:

fc7b6f03fea0d84c0a3ffbe7108ba4ec.png

结论

文章提出DiTs结构进行扩散模型图像生成,在Gflops与Stable Diffusion相当的DiTs-XL/2的结构上,把ImageNet 256×256数据集上的FID指标优化到了2.27,达到了SOTA的水平。未来将进一步探索更大的DiTs模型和token数量。

☆ END ☆

如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。

扫描二维码添加小编↓

ed8e02ecc19416a334b6089070c5eca1.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1187497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【EI会议征稿】第四届计算机网络安全与软件工程国际学术会议(CNSSE 2024)

第四届计算机网络安全与软件工程国际学术会议(CNSSE 2024) 2024 4th International Conference on Computer Network Security and Software Engineering 第四届计算机网络安全与软件工程国际学术会议(CNSSE 2024)将于2024年2月…

FPBJXDN224、FPBJXDV224插头式电比例节流阀放大器

FPBJXDN224、FPBJXDV224此阀是一款先导控制,常开,电比例节流阀,带反向单向阀。比例电磁铁得电可以在先导级产生作用力,从而比例地关闭主级的阀芯。液流方向为2口流向1口。无论电比例开启还是关闭反向单向阀都允许油液从1口自由流向…

影视小程序源码 付费短剧小程序源码 支持会员模式 多平台支付方式

这是一款功能强大的全开源付费短剧小程序源码,支持多种展现形式,包括付费、免费、任务等方式解锁自由配置。此外,还有用户运营、营销推广、付费观看和成熟代理机制等多种功能。 该小程序源码支持无限滑动、高性能滑动、预加载和视频预览等功能…

天猫店铺所有商品数据接口(Tmall.item_search_shop)

天猫平台店铺所有商品数据接口是开放平台提供的一种API接口,通过调用该接口,开发者可以获取天猫整店的商品的标题、价格、库存、月销量、总销量、库存、详情描述、图片、价格信息等详细信息。 要使用天猫店铺所有商品数据接口,您需要先登录天…

I/O软件层次介绍

一、I/O系统 1.设备独立性软件 2.设备驱动程序 3.中断处理程序 总览 二、输入、输出管理 1.应用程序接口 网络通信方式过程 2.设备驱动程序接口

应用在便携式多媒体播放器中的音频Codec芯片

便携式多媒体播放器(PMP,Portable Media Player),也就是通常人们所说的MP4。PMP的主要优点是:携带方便,能够直接播放高品质音/视频文件;也可以浏览图片,以及作为移动硬盘使用;此外,P…

对Mysql和应用微服务做TPS压力测试

1.对Mysql 使用工具:mysqlslap工具 使用命令: mysqlslap -uroot pGG8697000!#--auto generate sql -auto generate sql-load typemixed-concurrency100,200 - number of queries1000-iterations10 - number-int-cols7 - number-charcols13auto genera…

PBJ | IF=13.8 利用ChIP-seq和ATAC-seq技术揭示MdRAD5B调控苹果耐旱性的双重分子作用机制

2023年10月24日,西北农林科技大学园艺学院管清美教授团队在Plant Biotechnology Journal(最新IF:13.8)上发表题为“The chromatin remodeller MdRAD5B enhances drought tolerance by coupling MdLHP1-mediated H3K27me3 in apple…

Word文件损坏怎么办?这3个方法教你轻松解决!

使用Word编写文档时,我们可能会遇到各种各样的问题,这会给我们的学习和工作带来不好的影响。Word文件损坏也是比较常见的一种情况。怎么解决这个问题呢? 如果Word文档损坏后想要恢复应该怎么做呢?小编给大家总结了几个小妙招&…

更改 npm的默认缓存地址

npm的默认缓存一般在C:\Users\用户名\AppData\Roaming路径下的npm和npm_cache,而c盘往往空间不大。 1、在其他盘新建两个文件夹,如D盘,node_cache和node_global。如下图所示。 2、在cmd中执行npm config set prefix “node_cache的路径”&a…

基于ssm的网上药房管理系统的设计与实现(源码+LW+调试)

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。今天给大家介绍一篇基于java的ssm网上药房管…

阿里云服务器系统怎么选?Alibaba Cloud Linux操作系统介绍

Alibaba Cloud Linux是阿里云基于龙蜥社区(OpenAnolis)的龙蜥操作系统(Anolis OS)打造的操作系统发行版,在全面兼容RHEL/CentOS生态的同时也为云上应用程序环境提供Linux社区的增强功能,并针对阿里云基础设…

若依vue-初步下载使用

若依框架可以满足大部分的后台管理系统的开发,使用频率也是比较高的,所以这里讲一下如何使用若依框架 若依框架代码克隆 首先去若依官网 http://www.ruoyi.vip/ 这里演示的是若依-vue版本的使用 我们点击下载 会跳转到码云仓库 或者直接点击下面的链接去码云仓库 https://git…

Linux开发工具之vim

文章目录 1.vim是啥?1.1问问度娘1.2自己总结 2.vim的初步了解2.1进入和退出2.2vim的模式1.介绍2.使用 3.vim的配置3.1自己配置3.2下载插件3.3安装大佬配置好的文件 4.程序的翻译 1.vim是啥? 1.1问问度娘 1.2自己总结 vi/vim都是多模式编辑器,vim是vi的升级版本&a…

linux安装kafka教程

kafka需要安装jdk,我的是jdk17 一、安装kafka 1、下载kafka 1.到kafka的官网,去下载想用的kafka包:http://kafka.apache.org/downloads 2.我这里下载的是:kafka_2.12-3.4.1.tgz 3.将安装包传送到服务器并解压(默认…

json数据格式的理解(前+后)

什么是JSON: JSON(JavaScript Object Notation)是一种广泛使用的数据交换格式,它在前端和后端开发中都扮演着重要的角色。 JSON 的结构: JSON 数据由大括号 {} 包围,表示对象。 对象中的数据以键值对形式…

leetcode:203. 移除链表元素(有哨兵位的单链表和无哨兵位的单链表)

一、题目 函数原型: struct ListNode* removeElements(struct ListNode* head, int val) 二、思路 本题有两种思路: 思路1 遍历单链表,如果遇到值为val的结点,则将该结点删除。 注意:当删除结点时,如果出现…

《持续交付:发布可靠软件的系统方法》- 读书笔记(十二)

持续交付:发布可靠软件的系统方法(十二) 第 12 章 数据管理12.1 引言12.2 数据库脚本化12.3 增量式修改12.3.1 对数据库进行版本控制12.3.2 联合环境中的变更管理 12.4 数据库回滚和无停机发布12.4.1 保留数据的回滚12.4.2 将应用程序部署与数…

数字孪生与电力行业的完美融合

电力行业一直是现代社会不可或缺的一部分,而数字孪生技术正逐渐改变这一传统行业的面貌。数字孪生电力解决方案通过将物理世界与数字世界相结合,为电力行业带来了前所未有的机会和挑战。本文为大家介绍山海鲸电力行业系列解决方案,带大家了解…