Sora 基础作品之 DiT:Scalable Diffusion Models with Transformer

news2024/11/24 10:30:43

Paper name

Scalable Diffusion Models with Transformers (DiT)

Paper Reading Note

Paper URL: https://arxiv.org/abs/2212.09748

Project URL: https://www.wpeebles.com/DiT.html

Code URL: https://github.com/facebookresearch/DiT

TL;DR

  • 2022 年 UC Berkeley 出品的论文,将 transformer 应用于 diffusion 上实现了当时最佳的生成效果。DiT 论文作者也是 OpenAI 项目领导者之一,该论文是 Sora 的基础工作之一。

Introduction

背景

  • transformer 在自回归模型中得到了广泛应用,但在其他生成模型框架中的采用较少。例如,扩散模型已经处于近期图像级生成模型进展的前沿;然而,它们都采用了卷积 U-Net 架构作为默认的主干选择
  • 本文展示了 U-Net 的归纳偏置对扩散模型的性能并非至关重要,可以替换为 transformer

本文方案

  • 提出了基于 transformer 的 diffusion 模型 Diffusion Transformers (简称 DiTs)。该架构具有良好的可扩展性,即网络复杂度(以Gflops衡量)与样本质量(以FID衡量)之间存在强相关性。通过简单地放大 DiT 并训练一个具有高容量主干的 LDM,能够在类条件 256 × 256 ImageNet 生成基准测试上达到 2.27 FID 的最新结果
    DiT 效果可视化

Methods

整体设计思路

  • 使用 Latent diffusion models(LDM) + Classifier-free guidance + transformer + VAE (Conv) 架构设计,从下图可以看出该设计的优势,左图显示有 scaling law,右图显示 LDM 相比于 pixel space diffusion 模型 ADM 有优势,不仅精度更高,训练计算量也更低
    scaling law for DiT

Diffusion Transformers

  • DiT 基于 ViT 修改得到,整体架构如下图所示:DiT block 通过区分 condition 的添加方式分为三种设计思路,分别是通过 adaLN (或 adaLN-Zero),cross-attention 或 In-Context
    DiT 架构
DiT 前向过程的各个模块
  • Patchify:DiT 的输入是图像的空间表示 z(对于 256×256×3 的图像,z 的形状为 32×32×4)。DiT的第一层是 “patchify”,它通过将输入中的每个补丁线性嵌入,将空间输入转换成一系列 T 个 token,每个 token 的维度为 d。在执行 patchify 之后,我们对所有输入 token 应用标准的 ViT 基于频率的位置嵌入(正弦-余弦版本)。由 patchify 创建的 token 数量 T 由补丁大小超参数 p 决定。如下图所示,将 p 减半会使 T 增加四倍,因此至少使整个 transformer Gflops 增加四倍。DiT 中主要实验了 p = 2, 4, 8
    patchify

  • DiT block:如整体框架图中所示,根据 condition 加入的不同方式分为以下四种设计思路

    • 上下文条件化。我们简单地将 t 和 c 的向量嵌入作为输入序列中的两个额外 tokens 附加上去,对待它们与图像 tokens 没有区别。这与 ViT 中的 cls tokens 类似,它允许我们无需修改就使用标准的 ViT 模块。在最后一个模块之后,我们从序列中移除条件化 tokens。这种方法对模型的新 Gflops 增加可以忽略不计。
    • 交叉注意力模块。我们将 t 和 c 的嵌入 concat 成一个长度为二的序列,与图像 token 序列分开。transformer 模块被修改为在多头自注意力模块后面增加一个额外的多头交叉注意力层,类似于 Attention is All you need 中的原始设计,也类似于 LDM 用于条件化类别标签的设计。交叉注意力对模型的 Gflops 增加最多,大约增加了 15% 的开销。
    • 自适应层归一化(adaLN)模块。在 GANs 和具有 UNet 骨干的扩散模型中广泛使用自适应归一化层之后,我们探索了用自适应层归一化(adaLN)替换 transformer 模块中的标准归一化层。adaLN 并不是直接学习维度规模的缩放和偏移参数 γ 和 β,而是从 t 和 c 的嵌入向量之和中回归得到它们。在我们探索的三种模块设计中,adaLN 增加的 Gflops 最少,因此是最计算高效的。它也是唯一一个限制对所有 tokens 应用相同函数的条件化机制。
    • adaLN-Zero 模块。之前的 ResNets 工作发现,将每个残差块初始化为恒等函数是有益的。例如,在监督学习环境中,将每个块中最后的批量归一化缩放因子 γ 零初始化可以加速大规模训练。扩散 U-Net 模型使用了类似的初始化策略,在任何残差连接之前零初始化每个块中的最终卷积层。我们探索了对 adaLN DiT 模块的修改,它做了同样的事情。除了回归 γ 和 β,我们还回归了在 DiT 模块内的任何残差连接之前作用的 dimension-wise 的缩放参数 α。初始化 MLP 以输出所有 α 为零向量;这将完整的 DiT 模块初始化为恒等函数。与标准的 adaLN 模块一样,adaLNZero 对模型的 Gflops 增加可以忽略不计。
  • Model size

    • 使用四种配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。它们涵盖了从 0.3 到 118.6 Gflops 不同范围的模型大小和浮点运算分配,使我们能够评估扩展性能。下表提供了配置的详细信息。
      model size
  • Transformer 解码器

    • 在最后的 DiT 模块之后,需要将图像 tokens 序列解码成输出噪声预测和输出对角协方差预测。这两个输出的形状等于原始的空间输入。我们使用标准 linear 解码器来完成这一任务;我们应用最终的层归一化(如果使用 adaLN 则为自适应的)并将每个 token 线性解码成一个 p×p×2C 的张量,其中 C 是输入到 DiT 的空间输入中的通道数。最后,我们将解码后的 tokens 重新排列成它们原始的空间布局,以得到预测的噪声和协方差。

Experiments

训练配置
  • ImageNet, 256x256 或 512x512 训练
  • AdamW, no weight decay
  • constant lr: 1 x 10−4
  • batch size: 256
  • EMA: decay 0.9999
VAE/Diffusion
  • Stable Diffusion 中的 VAE,下采样倍数为 8: 256 × 256 × 3 -> 32 × 32 × 4.
  • tmax=1000
  • linear variance schedule: 1e-4 -> 2e-2
评测指标
  • FID, FID-50k,250 DDPM sampling step

DiT block 消融

  • condition 的加入方式很影响精度:adaLN-Zero 精度最佳,说明权重初始化方式也很重要(让 DiT 的 block 初始化为 identity 函数)。
  • 计算量:in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)
    DiT block

scaling 分析

  • 提升模型计算量稳定涨点
    scaling 基础实验

  • 只要模型计算量接近,FID 就接近。
    scaling

训练效率
  • 更大的 DiT 模型计算效率更高。
    • 训练计算量的评估方式:Gflops · batch size · training steps · 3。因子 3 大致近似于反向传播的计算量是前向传播的两倍。发现,即使训练时间更长,小型 DiT 模型最终相对于训练步数更少的大型 DiT 模型而言,在计算上变得效率低下。同样,我们发现,除了 patch 大小不同之外,其他都相同的模型即使在控制了训练 Gflops 的情况下,也有不同的性能表现。例如,在大约 1 0 10 10^{10} 1010 Gflops 之后,XL/4 的性能被 XL/2 超越。
      训练效率
可视化效果
  • 提升模型计算量可视化效果明显提升
    visualize scaling

class condition 定量分析

  • 达到 sota 效果,比之前的 sota StyleGAN-XL 精度更高
    class2img
    class2img 512px
增加模型参数量 or 增加 sampling 步数
  • 扩散模型的独特之处在于可以通过增加生成图像时的采样步骤来在训练后使用额外的计算资源。
  • 研究了在使用更多采样计算的情况下,较小模型计算的 DiT 是否能够超越较大的模型。结论是增加采样计算的规模无法弥补模型计算能力的不足
    scaling model vs scaling sampling

Thoughts

  • 符合 scaling law 的简洁架构才是王道。scaling law 在 DiT 这的实验效果极佳,和 OpenAI 价值观相符,这应该是作为 Sora 基础工作之一的原因
  • condition 的加入方式可能还需要更多的 class condition 之外的消融实验,比如 image condition、text condition 等

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1564012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue 移动端弹窗带滚动效果 滚动到底的时候弹窗下的页面会跟着滑动

<template><div class"wrap" :style"dynamicStyle"><!--dynamicStyle主要是介个 通过computed设置postion的值 弹窗的时候设置为fixed 关闭弹窗的时候设置为unset--><div class"banner-wrap"><img src"/assets/…

C语言一维数组及二维数组详解

引言&#xff1a; 小伙伴们&#xff0c;我发现我正文更新的有些慢&#xff0c;但相信我&#xff0c;每一篇文章真的都很用心在写的&#xff0c;哈哈&#xff0c;在本篇博客当中我们将详细讲解一下C语言中的数组知识&#xff0c;方便大家后续的使用&#xff0c;有不会的也可以当…

ArcGIS Pro打不开Excel?Microsoft驱动程序安装不上?

刚用ArcGIS pro的朋友们可能经常在打开xls或者xlsx文件的时候都会提示&#xff0c;未安装所需的Microsoft驱动程序。 怎么办呢&#xff1f;当然&#xff0c;按照提示装一下驱动就会好吗&#xff1f;有什么状况会出现&#xff1f;有什么临时替代方案呢&#xff1f; 全文目录&a…

ssm017网上花店设计+vue

网上花店的设计与实现 摘 要 网络技术和计算机技术发展至今&#xff0c;已经拥有了深厚的理论基础&#xff0c;并在现实中进行了充分运用&#xff0c;尤其是基于计算机运行的软件更是受到各界的关注。加上现在人们已经步入信息时代&#xff0c;所以对于信息的宣传和管理就很关…

C++ 哈希思想应用:位图,布隆过滤器,哈希切分

C 哈希思想应用:位图,布隆过滤器,哈希切分 一.位图1.位图的概念1.问题2.分析3.位图的概念4.演示 2.位图的操作3.位图的实现1.char类型的数组2.int类型的数组3.解决一开始的问题位图开多大呢?小小补充验证 4.位图的应用1.给定100亿个整数&#xff0c;设计算法找到只出现一次的整…

【Redis】NoSQL之Redis的配置和优化

关系型数据库与非关系型数据库 关系型数据库 关系型数据库是一个结构化的数据库&#xff0c;创建在关系型模型&#xff08;二维表&#xff09;的基础上&#xff1b;一般面向于记录&#xff1b; SQL语句(标准数据查询语句)就是一种基于关系型数据库的语言&#xff0c;用于执行…

转圈游戏(acwing)

题目描述&#xff1a; n 个小伙伴&#xff08;编号从 0 到 n−1&#xff09;围坐一圈玩游戏。 按照顺时针方向给 n 个位置编号&#xff0c;从 0 到 n−1。 最初&#xff0c;第 0 号小伙伴在第 0 号位置&#xff0c;第 1 号小伙伴在第 1 号位置&#xff0c;…

FastAPI Web框架教程 第14章 部署

14-1 在Linux上安装Python 【环境】 腾讯云服务器 Centos 8 【安装方式】 源码编译安装 安装步骤&#xff1a; 第1步&#xff1a;更新yum源 cd /etc/yum.repos.d/ sed -i s/mirrorlist/#mirrorlist/g /etc/yum.repos.d/CentOS-* sed -i s|#baseurlhttp://mirror.centos.…

SV学习笔记(一)

SV&#xff1a;SystemVerilog 开启SV之路 数据类型 內建数据类型 四状态与双状态 &#xff1a; 四状态指0、1、X、Z&#xff0c;包括logic、integer、 reg、 wire。双状态指0、1&#xff0c;包括bit、byte、 shortint、int、longint。 有符号与无符号 &#xff1a; 有符号&am…

ObjectiveC-03-XCode的使用和基础数据类型

本节做为Objective-C的入门课程&#xff0c;笔者会从零基础开始介绍这种程序设计语言的各个方面。 术语 ObjeC&#xff1a;Objective-C的简称&#xff0c;因为完整的名称过长&#xff0c;后续会经缩写来代替&#xff1b;项目/工程&#xff1a;也称工程&#xff0c;指的是一个A…

记某客户的一次无缝数据迁移

背景 客户需要将 Elasticsearch 集群无缝迁移到移动云&#xff0c;迁移过程要保证业务的最小停机时间。 实现方式 通过采用成熟的 INFINI 网关来进行数据的双写&#xff0c;在集群的切换恢复过程中来记录数据变更&#xff0c;待全量数据恢复之后再追平后面增量数据&#xff…

Node.js------Express

◆ 能够使用 express.static( ) 快 速 托 管 静 态 资 源◆ 能够使用 express 路 由 精 简 项 目 结 构◆ 能够使用常见的 express 中间件◆ 能够使用 express 创建API接口◆ 能够在 express 中启用cors跨域资源共享 一.初识Express 1.Express 简介 官方给出的概念&#xff…

Discuz! X3.5苗木_苗木网_苗木价格_苗木求购信息_苗木批发网模板utf-8

适合做苗木行业平台苗木网站、苗木信息网,提供苗木报价、各地苗木求购信息、绿化苗木采购招标、苗木基地展示、苗木百科知识、花木交易及苗木资讯、各地苗木信息网络行情。解压上传到template目录下&#xff0c;后台安装即可&#xff0c;包含PC手机端模板 下载地址&#xff1a;…

Windows 上路由、端口转发配置,跨网络地址段

一、背景 有时候我们会遇到这样的场景&#xff0c;一批同一局域网中只有某一台主机带外且系统为windows&#xff0c;局域网中其他非带外的主机要想访问外网&#xff0c;本文将介绍如何配置在带外主机上开启路由及端口转发。 二、配置操作 2.1、带外主机开启路由转发 1&#x…

QA测试开发工程师面试题满分问答6: 如何判断接口功能正常?从QA的角度设计测试用例

判断接口功能是否正常的方法之一是设计并执行相关的测试用例。下面是从测试QA的角度设计接口测试用例的一些建议&#xff0c;包括功能、边界、异常、链路、上下游和并发等方面&#xff1a; 通过综合考虑这些测试维度&#xff0c;并设计相应的测试用例&#xff0c;可以更全面地评…

一文盘点Mendix在SAP之上的那些事儿

前言 近来接手了2个与SAP有关的低代码案子&#xff0c;客户都会问Mendix和SAP之间怎么回事。 2017年开始Mendix 成为SAP Endorsed APP级别合作伙伴&#xff0c;并再度升级为Solution Extension最高级别。 两家公司风雨同舟七载&#xff0c;服务的全球大客户不胜枚举。 商业…

【嵌入式智能产品开发实战】(十四)—— 政安晨:通过ARM-Linux掌握基本技能【链接静态库与动态库】

目录 链接静态库 动态链接 与地址无关的代码 全局偏移表 延迟绑定 共享库 政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: 嵌入式智能产品开发实战 希望政安晨的博客能够对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论…

穿什么有这么重要?--装饰模式

1.1 穿什么有这么重要&#xff1f; 约会穿什么&#xff1f; "那要看你想给人家什么印象&#xff1f;是比较年轻&#xff0c;还是比较干练&#xff1b;是比较颓废&#xff0c;还是要比较阳光&#xff1b;也有可能你想给人家一种极其难忘的印象&#xff0c;那穿法又大不一样…

算法错题本

这里写目录标题 错题本注意数据的耦合性对于无解情况的处理思路一组数据以0为结束标记&#xff0c;如何输入到数组中&#xff0c;并计数多个数据进行比较链表删除重复元素的启发循环体里谨慎写类型定义并初始化&#xff08;一般写上就是错&#xff09;队列中读取队尾元素数组当…

基于ssm的三省学堂-学习辅助系统(java项目+文档+源码)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的三省学堂-学习辅助系统。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 三省学堂-学习辅助系统的…