Scalable Diffusion Models with Transformers(DiTs)论文阅读 -- 文生视频Sora模型基础结构DiT

news2024/9/23 7:23:27

nlpcver

nlpcver

忠于理想

​关注他

106 人赞同了该文章

文章地址:Scalable Diffusion Models with Transformers

简介

文章提出使用Transformers替换扩散模型中U-Net主干网络,分析发现,这种Diffusion Transformers(DiTs)不仅速度更快(更高的Gflops),而且在ImageNet 512×512和256×256的类别条件图片生成任务上,取得了更好的效果,256×256上实现了SOTA的FID指标(2.27)。

Transformers已经广泛应用于包括NLP、CV在内的机器学习的各个领域。然而,很多图片level的生成模型还坚持使用卷积神经网络,比如扩散模型采用的就是U-Net的主干网络架构。经过演化,扩散模型中的U-Net网络增加了稀疏的自注意力模块,此外 Dhariwal and Nichol 也尝试过在U-Net模型上的一些改变,比如通过增加适配的正则化层来注入条件信息和Channel数量。尽管如此,U-Net的顶层设计还是与原始U-Net相差无几。

文章的目标就是要揭开扩散模型架构选择的神秘面纱,提供一个强有力的baseline。文章发现U-Net并非不可替代,并且很容易使用诸如Transformers的结构替代U-Net,使用Transformers可以很好地保持原有的优秀特性,比如可伸缩性、鲁棒性、高效性等,并且使用新的标准化架构可能在跨领域研究上展现出更多的可能。文章从网络复杂度和采样质量两个方面对DiTs方法进行评估。

相关工作

Transformers

当前,Transformers架构已经应用在了文本、视觉、强化学习、元学习等多个领域,同时模型的大小、训练开销、数据量等也不断地上涨。在语言模型的启发下,有些工作在视觉任务上训练离散的codebook,这种架构可以同时应用在自回归模型和masked生成模型。本文将研究在扩散模型的主干网络上应用Transformers。

DDPMs

扩散模型是借鉴了物理学上的扩散过程,在生成模型上,分为正向和逆向的过程。正向过程是向信号中逐渐每步加少量噪声,当步数足够大时可以认为信号符合一个高斯分布。所以逆向过程就是从随机噪声出发逐渐的去噪,最终还原成原有的信号。

去噪过程一般采用UNet或者ViT,使用t步的结果和条件输入预测t-1步增加的噪声,然后使用DDPM可以得到t-1步的分布,经过多步迭代就可以从随机噪声还原到有实际意义的信号。如果使用原始DDPM速度会慢很多,所以很多工作如DDIM、FastDPM等工作实现了解码加速。

在图像的无条件生成任务上,扩散模型的性能已经超过了GANs,并且在有条件生成如文图生成任务上大放异彩。

架构复杂度

对于图片生成的迭代过程,我们可以使用参数量来衡量不同模型的复杂度。一般而言,参数量来评估模型复杂度不是很合适,因为参数量并不能代表模型的计算复杂度,比如当模型参数量相同时,图片分辨率不同会导致计算复杂度上较大的差异。所以文章采用Gflops来衡量模型架构的复杂度。

方法

扩散模型基础

前向过程是一个T步逐渐加噪的马尔科夫链,公式如下

给定前向扩散过程作为先验,扩散模型训练反转的过程,可以通过去除所加噪声从XT恢复成X0,并且每步的扩散过程都采样自特定的高斯分布,其期望和方差如下:

优化目标是负的X0概率似然,其上界如下所示:

并且其目标可以简化为预测和ground truth之间的l2 loss。

Classifier-free guidance

条件扩散模型是将条件信息作为额外的输入,比如一个分类标签c。这种情况下反向过程变为了

根据贝叶斯规则

因此

所以在想要条件的概率较大,就可以将条件的梯度增加到优化目标里,最终可以表示成如下形式:

模型在训练时,使用一个网络架构优化两个模型(uncond,cond)。

Latent diffusion models

模型使用VAE(固定权重)将图片encoder到隐空间,生成结果同样也是通过VAE解码成原始大小的图片。

DiTs架构

文章提出DiTs模型架构,完整的架构图如下所示:

Patch化:DiT的输入是通过VAE后的一个稀疏的表示z(256×256×3的图片,z为32×32×4),类似其他ViTs的方式,首先要将输入转成patch,文章采用超参p=2,4,8进行对比实验。

DiT模块设计

  • In-context条件:in-context条件是将t和c作为额外的token拼接到DiT的token输入中;
  • Cross-attention模块:DiT结构与Condition交互的方式,与原来U-Net结构类似;
  • Adaptive layer norm(adaLN)模块:使用adaLN替换原生LayerNorm(NeurIPS2019的文章,LN 模块中的某些参数不起作用,甚至会增加过拟合的风险。所以提出一种没有可学习参数的归一化技术);
  • adaLN-zero模块:之前的工作发现ResNets中每一个残差模块使用相同的初始化函数是有益的。文章提出对DiT中的残差模块的参数γ、β、α进行衰减,以达到类似的目的。

模型大小:与ViT大小相似,分别使用DiT-S、DiT-B、DiT-L和DiT-XL,Gflops从0.3dao118.6。

Transformer Decoder:在Transformer最上层需要预测噪音,因为Transformer可以保证大小与输入一致,所以在最上层使用一层线性进行decoder。

实验

实验设置

模型使用结构/patch数量方式表示,比如DiT-XL/2表示模型采用DiT-XL,patch size为2。

训练:在ImageNet 256×256和512×512分辨率的数据集上训练。初始化最后一层线性层为0,其他初始化都与ViT一致。训练模型采用AdamW,学习率1e-4,no weight decay,batch size为256,数据增广仅有水平翻转。无需学习率warmup和正则化。实验结果使用EMA model(decay 0.9999)。

Diffusion:使用VAE将256×256×3的图像编码到32×32×4的隐空间,经过扩散模型的逆向过程后,将32×32×4的隐空间还原到256×256×3的图像。

评价指标:使用250步DDPM采样,计算FID-50K的结果,没用特殊说明时未采用classifier-free guiance。此外还增加了Inception Score、sFID、Precision/Recall等指标。

实验结果

DiT结构:从下图可以看出adaLN-Zero方法明显好于cross-attention和in-contenxt,所以下文中均采用adaLN-Zero方法进行上下文交互。

model size and patch size评估:如下图所示,模型越大、patch size越小生成图像质量越好。

计算开销和模型效果的关系:如下图左所示,Gflops越大模型效果越好,同样如下图右所示,模型越大计算约高效(相同计算量下模型效果越好)

不同扩散模型的效果对比如下( DiT-XL/2 (118.6 Gflops) is compute-efficient relative to latent space U-Net models like LDM-4 (103.6 Gflops) ):

结论

文章提出DiTs结构进行扩散模型图像生成,在Gflops与Stable Diffusion相当的DiTs-XL/2的结构上,把ImageNet 256×256数据集上的FID指标优化到了2.27,达到了SOTA的水平。未来将进一步探索更大的DiTs模型和token数量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1502898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

市域社会治理现代化指挥中心项目方案

1.4.1专题分析应用建设要求 1.4.1.1总体要求 系统根据各专题实际业务需求提供详细的指标体系清单,同时应根据指标体系提供设计各专题应用的原型效果图;围绕党建引领、基层治理、城市管理、公共服务、公共安全多方面进行分析展示核心数据,体…

AI跟踪报道第32期-新加坡内哥谈技术-本周AI新闻:超越GPT4的Claude

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

删除网络连接不存在的网络驱动器

目录预览 一、问题描述二、解决方案三、参考链接 一、问题描述 虚拟机之前连接了主机, 二、解决方案 方法一:从文件资源管理器断开 Windows 10 上映射的网络驱动器, 在 Windows 10 上打开文件资源管理器。 从左窗格中单击此 PC 。 在“网络…

Linux系统架构----nginx的访问控制

nginx的访问控制 一、nginx基于授权的访问控制概述 Nginx与Apache一样,可以实现基于用户权限的访问控制,当客户端想要访问相应的网站或者目录时,要求用户输入用户名和密码,才能正常访问配置步骤生成用户密码认证文件 &#xff1…

若依/RuoYi-Vue使用docker-compose部署

系统需求 JDK > 1.8 MySQL > 5.7 Maven > 3.0 Node > 12 Redis > 3 思路 前端服务器 nginx 后端服务器代码打包 java、maven、node 数据库/缓存 mysql、redis 开始 创建目录ruoyi并进入 克隆若依代码 git clone RuoYi-Vue: 🎉 基于Spring…

【数据分享】2013-2022年全国范围逐日SO2栅格数据

空气质量数据是在我们日常研究中经常使用的数据!之前我们给大家分享了2013-2022年全国范围逐月SO2栅格数据和逐年SO2栅格数据(均可查看之前的文章获悉详情)。 本次我们给大家带来的是2013-2022年全国范围的逐日的SO2栅格数据,原始…

mq基础类设计

消息队列就是把阻塞队列这样的数据结构单独提取成一个程序独立进行部署。——>实现生产者消费者模型。 但是阻塞队列是在一个进程内部进行的; 消息队列是在进程与进程之间进行实现的, 解耦合:就是在分布式系统中,A服务器调用B…

RT-DETR优化改进:特征融合篇 | GELAN(广义高效层聚合网络)结构来自YOLOv9

🚀🚀🚀本文改进:使用GELAN改进架构引入到RT-DETR 🚀🚀🚀RT-DETR改进创新专栏:http://t.csdnimg.cn/vuQTz 🚀🚀🚀学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 🚀🚀🚀RT-DETR模型创新优化,涨点技巧分享,科研小助手; 1.YOLOv9介绍 论…

在用Java写算法的时候如何加快读写速度

对于解决该方法我们一般如下操作,不需要知道为什么,有模板(个人观点) 使用BufferedReader代替Scanner:Scanner类在读取大量输入时性能较差,而BufferedReader具有更高的读取速度。可以使用BufferedReader的r…

JVM的工作流程

目录 1.JVM 简介 2.JVM 执行流程 3. JVM 运行时数据区 3.1 堆(线程共享) 3.3 本地方法栈(线程私有) 3.4 程序计数器(线程私有) 3.5 方法区(线程共享) 4.JVM 类加载 ① 类…

webUI自动化之元素及浏览器操作

一、元素定位方式 1、元素属性定位: 1 element driver.find_element_by_id(self, id)    该类方法已经过时,新的方法如下: element driver.find_element(By.ID, ID 值)        # 用元素的 ID 属性定位element driver.find_eleme…

云打印软件免费版在哪?云打印服务怎么使用?

随着新的一年的到来,很多同学们又开始准备着新一轮的学习冲刺了。在学习的旅途中,打印资料的需求必然伴随着每一个人,但是线下打印店价格贵、打印不方便、没时间去打印等多种因素总是制约着我们。在这种情况下,云打印软件和云打印…

Svg Flow Editor 原生svg流程图编辑器(一)

系列文章 Svg Flow Editor 原生svg流程图编辑器(二) 效果展示 项目概述 svg flow editor 是一款流程图编辑器,提供了一系列流程图交互、编辑所必需的功能,支持前端研发自定义开发各种逻辑编排场景,如流程图、ER 图、…

【论文笔记】Scalable Diffusion Models with State Space Backbone

原文链接:https://arxiv.org/abs/2402.05608 1. 引言 主干网络是扩散模型发展的关键方面,其中基于CNN的U-Net(下采样-跳跃连接-上采样)和基于Transformer的结构(使用自注意力替换采样块)是代表性的例子。…

使用R语言进行聚类分析

一、样本数据描述 城镇居民人均消费支出水平包括食品、衣着、居住、生活用品及服务、通信、文教娱乐、医疗保健和其他用品及服务支出这八项指标来描述。表中列出了2016年我国分地区的城镇居民的人均消费支出的原始数据,数据来源于2017年的《中国统计年鉴》&#xf…

传递函数硬件化

已知一个系统的传递函数,如何进行硬件化呢? 只需要将传递函数离散化,得到差分方程,就可以根据差分方程进行硬件设计。 通过例子说明: 得到差分方程后,其中y(k)/y(k-1)/y(k-2)/u(k-1)/u(k-2)等代表不同周期…

【Spring】Spring状态机

1.什么是状态机 (1). 什么是状态 先来解释什么是“状态”( State )。现实事物是有不同状态的,例如一个自动门,就有 open 和 closed 两种状态。我们通常所说的状态机是有限状态机,也就是被描述的事物的状态的数量是有…

BC161 大吉大利,今晚吃鸡

一&#xff1a;题目 二&#xff1a;思路 三&#xff1a;代码 #include<bits/stdc.h>using namespace std;long long cnt;//柱子定义为x, y, z void move(int n, char x, char y, char z) {if(n 1){//printf("%c -> %c\n", x, y);//最大盘从x->y//prin…

git远程仓库分支推送与常见问题

1.查看远程仓库分支情况 git fetch origin git branch -r2.删除远程仓库中的某一分支(如master) git push origin --delete master问: 如果我的本地文件只有一个分支main,而远程仓库有两个分支Main和CubeMX, 若要将本地文件中新增的文件Test1.txt更改放入CubeMX中&#xff0c…

大数据开发-Hadoop分布式集群搭建

大数据开发-Hadoop分布式集群搭建 文章目录 大数据开发-Hadoop分布式集群搭建环境准备Hadoop配置启动Hadoop集群Hadoop客户端节点Hadoop客户端节点 环境准备 JDK1.8Hadoop3.X三台服务器 主节点需要启动namenode、secondary namenode、resource manager三个进程 从节点需要启动…