ZigMa: DiT风格之字形Mamba扩散模型
论文链接:https://arxiv.org/abs/2403.13802
项目链接:https://taohu.me/zigma/
Abstract
扩散模型长期以来一直受到可扩展性和二次复杂度问题的困扰,特别是在基于Transformer的结构中。在本研究中,我们的目标是利用称为Mamba的状态空间模型的长序列建模能力来扩展其对可视化数据生成的适用性。首先,我们确定了目前大多数基于Mamba的视觉方法中的一个关键疏忽,即在Mamba扫描方案中缺乏对空间连续性的考虑。其次,在此基础上,我们介绍了Zigzag Mamba,这是一种简单、即插即用、最小参数负担、DiT风格的解决方案,与基于Transformer的基线相比,它优于基于Mamba的基线,并展示了更高的速度和内存利用率。最后,我们将Zigzag Mamba与随机插值框架相结合,研究了该模型在大分辨率视觉数据集(如FacesHQ 1024 × 1024和UCF101、MultiModal-CelebA-HQ和MS COCO 256 × 256)上的可扩展性。
1 Introduction
扩散模型在图像处理[67]、视频分析[40]、点云处理[87]和人体姿态估计[29]等各种应用中都取得了重大进展。其中许多模型建立在潜在扩散模型(Latent Diffusion models, LDM)[67]之上,后者通常基于UNet主干。然而,可扩展性仍然是LDM的一个重大挑战[41]。最近,基于Transformer的结构因其可扩展性[9,65]和在多模态训练中的有效性[10]而受到欢迎。值得注意的是,基于Transformer的结构DiT[65]甚至有助于OpenAI增强高保真视频生成模型SORA[64]。尽管通过诸如窗口[58]、滑动[12]、稀疏化[18,46]、散列[19,74]、Ring注意力[14,53]、Flash注意力[21]或它们的组合[8,97]等技术来减轻注意力机制的二次复杂度,但它仍然是扩散模型的瓶颈。
另一方面,状态空间模型[31,32,35]在长序列建模方面显示出巨大的潜力,与基于Transformer的方法竞争。几个已经提出了一些方法[27,30,32,69]来增强状态空间模型的鲁棒性[92]、可扩展性[30]和效率[32,33]。其中,一种名为Mamba的方法[30]旨在通过高效的并行扫描和其他依赖数据的创新来缓解这些问题。然而,Mamba的优势在于一维序列建模,将其扩展到二维图像是一个具有挑战性的问题。先前的研究[57,96]提出了通过计算机层次结构(如行-列-主顺序)直接平坦化二维标记,但这种方法忽略了空间连续性,如图1所示。其他工作[54,60]考虑在单个Mamba块中的各个方向,但这引入了额外的参数和GPU内存负担。在本文中,我们旨在强调Mamba空间连续性的重要性,并提出了几种直观和简单的方法,通过在图像中结合基于连续性的归纳偏置,使Mamba块应用于二维图像。通过对三维序列进行时空分解,将这些方法推广到三维。
最后,Stochastic Interpolant[3]提供了一个更广义的框架,可以统一各种生成模型,包括Normalizing Flow[16]、扩散模型[39,70,72]、Flow matching[4,51,56]和Schrödinger Bridge[52]。以前,一些作品[61]在相对较小的分辨率上探索随机插值,例如256×256, 512×512。在这项工作中,我们的目标是在更复杂的场景中探索它,例如1024 × 1024分辨率,甚至在视频中。
综上所述,我们的贡献如下:首先,我们确定了将Mamba块从一维序列建模推广到二维图像和三维视频建模的空间连续性的关键问题。基于这一见解,我们提出了一个简单的,即插即用,零参数的范式,命名为Zigzag Mamba(ZigMa),利用空间连续性来最大限度地结合视觉数据的归纳偏置。其次,我们将方法从二维扩展到三维,通过分解空间和时间序列来优化性能。其次,我们在扩散模型的范围内提供了围绕Mamba块的综合分析。最后,我们证明了我们设计的Zigzag Mamba优于相关的基于Mamba的基线,代表了大规模图像数据(1024×1024)和视频上随机插值的首次探索。
2 Related Works
Mamba。一些研究[82,83]已经证明状态空间模型在一定条件下具有普遍逼近能力。Mamba作为一种新型的状态空间模型,具有高效建模长序列的优越潜力,在医学成像[60,68,86,89]、图像恢复[34,95]、图形[11]、NLP word byte[80]、表格数据[2]、点云[49]、图像生成[25]等多个领域得到了探索。其中与我们关系最密切的是VisionMamba[57,96]、S4ND[63]和Mamba-ND[48]。VisionMamba[57,96]在判别性任务中使用双向SSM,这导致了很高的计算成本。我们的方法在生成模型中应用了一个简单的替代Mamba扩散。S4ND[63]在Mamba的推理过程中引入了局部卷积,超越了仅使用一维数据。Mamba-ND[48]在判别任务中考虑了多维度,在单个块内使用各种扫描。相比之下,我们的重点是在网络的每一层分布扫描复杂性,从而在零参数负担的情况下最大限度地结合视觉数据的归纳偏置。
扩散模型中的主干。扩散模型主要采用基于Unet[39,67]和基于ViT[9,65]的主干。UNet以高内存需求而闻名[67],而ViT则受益于可扩展性[17,22]和多模式学习[10]。然而,ViT的二次复杂度限制了视觉标记处理,促使人们研究如何缓解这一问题[12,21,84]。我们的工作受到Mamba[30]的启发,探索了一种基于SSM的模型作为通用扩散主干,保留了ViT的模态不可知和顺序建模优势。同时,DiffSSM[90]专注于S4模型中的无条件条件和类条件作用[32]。DIS[25]主要在相对较小的尺度上探索状态空间模型,这并不是我们的工作重点。我们的工作与他们的工作有很大的不同,因为它主要关注使用Mamba块的主干设计,并将其扩展到文本调节。此外,我们将该方法应用于更复杂的视觉数据。
扩散模型中的SDE和ODE。基于分数的生成模型领域包含了基础工作的重要贡献,例如Song等人[71]提出的基于朗格万动力学的分数匹配(SMLD),以及Ho等人提出的带有去噪分数匹配(DDPM)的扩散模型[39]。这些方法在随机微分方程(SDE)的框架内运行,这是Song等人[72]研究中进一步完善的概念。最近的研究进展,如Karras等人[42]和Lee等人[47]所示,展示了使用常微分方程(ODE)采样器进行扩散SDE的有效性,与需要离散化扩散SDE此外,在Flow Matching方法学[51]和Rectified Flow架构[55]领域,SMLD和DDPM都是在概率流ODE框架[72]的不同路径下出现的专门实例。这些模型通常利用线性插值的速度场参数化,这一概念在随机插值框架中得到了更广泛的应用[3],随后的推广扩展到流形设置[13]。SiT模型[61]仔细研究了采样和训练背景下插值方法之间的相互作用,尽管是在较小的分辨率(如512 × 512)的背景下。我们的研究努力将这些见解扩展到更大的范围,专注于1024 × 1024的2D图像和3D视频数据的泛化能力。的传统方法相比,大大降低了采样成本。
3 Method
在本节中,我们首先提供关于状态空间模型的背景信息[31,32,35],并特别关注被称为Mamba的特殊情况[30]。然后,我们强调了Mamba框架内空间连续性的关键问题,并基于这一见解,我们提出之字形Mamba。这种增强旨在通过结合二维数据固有的连续性归纳偏置来提高二维数据建模的效率。此外,我们在Mamba块上设计了一个基本的交叉注意块来实现text-conditioning。随后,我们建议将该方法扩展到3D视频数据,将模型分解为空间和时间维度,从而简化建模过程。最后,我们介绍了用于训练和采样的随机插值的理论方面,这是我们网络架构的基础。
3.1 背景:状态空间模型
状态空间模型(State Space Models, SSM)[31,32,35]已经被证明可以在理论上和经验上处理具有线性缩放w.r.t序列长度的远程依赖[33]。线性状态空间模型的一般形式为:
x
′
(
t
)
=
A
(
t
)
x
(
t
)
+
B
(
t
)
u
(
t
)
y
(
t
)
=
C
(
t
)
x
(
t
)
+
D
(
t
)
u
(
t
)
\begin{aligned} x^{\prime}(t) & =\mathbf{A}(t) x(t)+\mathbf{B}(t) u(t) \\ y(t) & =\mathbf{C}(t) x(t)+\mathbf{D}(t) u(t) \end{aligned}
x′(t)y(t)=A(t)x(t)+B(t)u(t)=C(t)x(t)+D(t)u(t)
通过隐式N-D潜在状态序列
x
(
t
)
∈
R
n
x(t) \in \mathbb{R}^n
x(t)∈Rn,将1-D输入序列
u
(
t
)
∈
R
u(t) \in \mathbb{R}
u(t)∈R映射到1-D输出序列
y
(
t
)
∈
R
y(t) \in \mathbb{R}
y(t)∈R。具体而言,深度SSM寻求在神经序列建模架构中使用这种简单模型的堆栈,其中每层的参数
A
,
B
,
C
\mathbf{A}, \mathbf{B}, \mathbf{C}
A,B,C 和
D
\mathbf{D}
D可以通过梯度下降来学习。
最近,Mamba[30]在保持计算效率的同时,通过放宽SSM参数的时不变约束,在很大程度上提高了SSM的灵活性。通过采用高效的并行扫描,Mamba减轻了重复的顺序性质的影响,而融合GPU操作则消除了实现扩展状态的要求。在本文中,我们专注于探索Mamba在扩散模型中的扫描方案,以最大限度地利用多维视觉数据中的归纳偏置。
3.2 扩散主干:之字形Mamba
DiT-Style网络。我们选择使用AdaLN的ViT框架[65],而不是跳跃层的U-ViT结构[9],因为ViT在文献[10,17,64]中已被验证为可扩展结构。考虑到前面提到的几点,它为图4所示的Mamba网络设计提供了信息。这个设计的核心部分是Zigzag形扫描,这将在下面的段落中解释。
Mamba之字形扫描。先前的研究[81,90]在SSM框架内使用了双向扫描。这种方法已经扩展到包括额外的扫描方向[54,57,91],以考虑二维图像数据的特征。这些方法沿着四个方向展开图像patch,产生四个不同的序列。每个序列随后通过每个SSM一起处理。然而,由于每个方向可能有不同的SSM参数(A、B、C和D),因此增加方向的数量可能会导致内存问题。在这项工作中,我们研究了将Mamba的复杂性分摊到网络的每一层的潜力。
我们的方法围绕着token重新排列的概念,然后将它们馈送到前向扫描块。对于来自层
i
i
i的给定输入特征
z
i
z_i
zi,重排后的前向扫描块的输出特征
z
i
+
1
z_{i+1}
zi+1可以表示为:
z
Ω
i
=
arrange
(
z
i
,
Ω
i
)
z
‾
Ω
i
=
scan
(
z
Ω
)
z
i
+
1
=
arrange
(
z
‾
Ω
i
,
Ω
ˉ
i
)
\begin{align} \mathbf{z}_{\Omega_i} & =\operatorname{arrange}\left(\mathbf{z}_i, \Omega_i\right) \tag{1}\\ \overline{\mathbf{z}}_{\Omega_i} & =\operatorname{scan}\left(\mathbf{z}_{\Omega}\right) \tag{2}\\ \mathbf{z}_{i+1} & =\operatorname{arrange}\left(\overline{\mathbf{z}}_{\Omega_i}, \bar{\Omega}_i\right)\tag{3} \end{align}
zΩizΩizi+1=arrange(zi,Ωi)=scan(zΩ)=arrange(zΩi,Ωˉi)(1)(2)(3)
Ω
i
Ω_i
Ωi表示第
i
i
i层的1D排列,将patch token的顺序重新排列
Ω
i
Ω_i
Ωi,
Ω
i
Ω_i
Ωi和
Ω
ˉ
i
\bar{Ω}_i
Ωˉi表示相反的操作。这确保了
z
i
z_i
zi和
z
i
+
1
z_{i+1}
zi+1都保持了原始图像token的采样顺序。
现在我们探索 Ω i Ω_i Ωi操作的设计,考虑来自2D图像的附加归纳偏置。我们提出了一个关键属性:空间连续性。在空间连续性方面,目前图像中Mamba的创新[54,57,96]经常直接按照计算机层次结构(如行-列-主顺序)挤压二维patch token。然而,这种方法对于将归纳偏置与相邻标记结合起来可能不是最优的,如图3所示。为了解决这个问题,我们引入了一种新的扫描方案,旨在保持扫描过程中的空间连续性。另外,我们考虑空间填充,即对于大小为 N × N N × N N×N的patch, 1D连续扫描方案的长度应为N2。这有助于有效地合并token,以最大限度地发挥Mamba块中长序列建模的潜力。
为了实现上述性质,我们启发式地设计了8种可能的空间填充连续方案1,记为 S j S_j Sj(其中 j ∈ [ 0 , 7 ] j∈[0,7] j∈[0,7]),如图3所示。虽然可能还有其他可想到的方案,但为了简单起见,我们将使用限制在这八个方案。因此,每一层的方案可以表示为 Ω i = S { i % 8 } Ω_i=S_{\{i \% 8\}} Ωi=S{i%8},其中%表示模算子。
1我们还尝试了更复杂的连续空间填充路径,如Hilbert空间填充曲线[62]。然而,实证结果表明,这种方法可能导致结果恶化。有关更详细的比较,请参阅附录。
在Zigzag Mamba上部署文本条件。虽然Mamba提供了高效的长序列建模的优势,但它是以牺牲注意力机制为代价的。因此,将文本条件作用纳入基于Mamba的扩散模型的探索有限。为了解决这个问题,我们提出了一个简单的跨注意力块,在Mamba块上构建跳跃层,如图4所示。这种设计不仅支持长序列建模,还支持多token条件,例如text-conditioning。此外,它有可能提供可解释性[15,38,75],因为交叉注意已被用于扩散模型。
通过分解空间和时间信息将其推广到3D视频。在前面的章节中,我们的重点是空间二维Mamba,在那里我们设计了几个空间连续的,空间填充的二维扫描方案。在本节中,我们的目标是利用这一经验来帮助设计3D视频处理的相应机制。如图5所示,我们从传统的定向Mamba开始我们的设计过程。给定一个视频特征输入 z ∈ R B × T × C × W × H z∈R^{B×T ×C×W×H} z∈RB×T×C×W×H,我们提出了视频Mamba块的三种变体,以促进3D视频的生成。
(a) 扫描:在这种方法中,我们直接将3D特征 z z z平面化,而不考虑空间或时间的连续性。值得注意的是,扁平化过程遵循计算机层次结构顺序,这意味着扁平化表示中没有保留连续性。
(b) 3D之字形:与前几节二维之字形的表述相比,我们采用类似的设计将其推广到三维之字形,以保持二维和三维同时的连续性。潜在地,该方案具有更大的复杂性。我们也启发式地列出了8种方案。然而,我们的经验发现这种方案会导致次优优化。
© 分解3D之字形= 2D之字形+ 1D扫描:为了解决次优优化问题,我们建议将空间和时间相关性分解为单独的Mamba块。它们的使用顺序可以根据需要进行调整,例如,“sstt”或“ststst”,其中“s”表示空间之字形Mamba,“t”表示时间之字形Mamba。对于一维时间扫描,我们简单地选择向前和向后扫描,因为在时间轴上只有一个维度。
计算分析。对于视觉序列
T
∈
R
1
×
M
×
D
T∈R^{1×M×D}
T∈R1×M×D,全局自注意力和k方向Mamba和我们的之字形Mamba的计算复杂度如下:
ζ
(
self-attention
)
=
4
M
D
2
+
2
M
2
D
ζ
(
k-mamba
)
=
k
×
[
3
M
(
2
D
)
N
+
M
(
2
D
)
N
2
]
ζ
(
zigzag
)
=
3
M
(
2
D
)
N
+
M
(
2
D
)
N
2
\begin{align} & \zeta(\text { self-attention })=4 \mathrm{MD}^2+2 \mathrm{M}^2 \mathrm{D} \tag{4}\\ & \zeta(\text { k-mamba })=k \times\left[3 \mathrm{M}(2 \mathrm{D}) \mathrm{N}+\mathrm{M}(2 \mathrm{D}) \mathrm{N}^2\right] \tag{5}\\ & \zeta(\text { zigzag })=3 \mathrm{M}(2 \mathrm{D}) \mathrm{N}+\mathrm{M}(2 \mathrm{D}) \mathrm{N}^2 \tag{6} \end{align}
ζ( self-attention )=4MD2+2M2Dζ( k-mamba )=k×[3M(2D)N+M(2D)N2]ζ( zigzag )=3M(2D)N+M(2D)N2(4)(5)(6)
其中,self-attention相对于序列长度M呈现二次复杂度,而Mamba呈现线性复杂度(N为固定参数,默认为16)。这里,k表示单个Mamba块的扫描方向数。因此,k-Mamba和之字形在自我注意方面具有线性复杂性。此外,我们的之字形方法可以消除k序列,进一步降低整体复杂度。
在完成设计之字形Mamba网络,以改善视觉诱导偏置集成,我们继续结合它与一个新的扩散框架,如下所示。
3.3 扩散框架:随机插值
基于向量
v
\mathbf{v}
v和分数
s
s
s进行抽样。继[3,76]之后,
x
t
\mathbf{x}_t
xt的时间相关概率分布
p
t
(
x
)
p_t(\mathbf{x})
pt(x)也与逆时SDE的分布相吻合[6]:
d
X
t
=
v
(
X
t
,
t
)
d
t
+
1
2
w
t
s
(
X
t
,
t
)
d
t
+
w
t
d
W
‾
t
(7)
d \mathbf{X}_t=\mathbf{v}\left(\mathbf{X}_t, t\right) d t+\frac{1}{2} w_t \mathbf{s}\left(\mathbf{X}_t, t\right) d t+\sqrt{w_t} d \overline{\mathbf{W}}_t \tag{7}
dXt=v(Xt,t)dt+21wts(Xt,t)dt+wtdWt(7)
其中
W
‾
t
\overline{\mathbf{W}}_t
Wt为逆时维纳过程,
w
t
>
0
w_t>0
wt>0为任意时变扩散系数,
s
(
x
,
t
)
=
∇
log
p
t
(
x
)
\mathbf{s}(\mathbf{x}, t)=\nabla \log p_t(\mathbf{x})
s(x,t)=∇logpt(x)为分数,
v
(
x
,
t
)
\mathbf{v}(\mathbf{x}, t)
v(x,t)由条件期望给出
v
(
x
,
t
)
=
E
[
x
˙
t
∣
x
t
=
x
]
=
α
˙
t
E
[
x
∗
∣
x
t
=
x
]
+
σ
˙
t
E
[
ε
∣
x
t
=
x
]
(8)
\begin{aligned} \mathbf{v}(\mathbf{x}, t) & =\mathbb{E}\left[\dot{\mathbf{x}}_t \mid \mathbf{x}_t=\mathbf{x}\right] \\ & =\dot{\alpha}_t \mathbb{E}\left[\mathbf{x}_* \mid \mathbf{x}_t=\mathbf{x}\right]+\dot{\sigma}_t \mathbb{E}\left[\varepsilon \mid \mathbf{x}_t=\mathbf{x}\right] \end{aligned} \tag{8}
v(x,t)=E[x˙t∣xt=x]=α˙tE[x∗∣xt=x]+σ˙tE[ε∣xt=x](8)
式中
α
t
\alpha_t
αt是
t
t
t的递减函数,
σ
t
\sigma_t
σt是
t
t
t的递增函数。其中,
α
˙
t
\dot{\alpha}_t
α˙t和
σ
˙
t
\dot{\sigma}_t
σ˙t分别表示
α
t
\alpha_t
αt和
σ
t
\sigma_t
σt的时间导数。
只要我们能够估计速度 v ( x , t ) \mathbf{v}(\mathbf{x}, t) v(x,t)和/或分数 s ( x , t ) \mathbf{s}(\mathbf{x}, t) s(x,t)场,我们就可以通过概率流ODE[72]或逆时SDE(7)将其用于采样过程。从 X T = ε ∼ N ( 0 , I ) \mathbf{X}_T=\varepsilon \sim \mathcal{N}(0, \mathbf{I}) XT=ε∼N(0,I)在时间上向后求解逆时SDE(7),可以从近似的数据分布 p 0 ( x ) ∼ p ( x ) p_0(\mathbf{x}) \sim p(\mathbf{x}) p0(x)∼p(x)生成样本。在采样过程中,我们可以直接从ODE或SDE进行采样,以平衡采样速度和保真度。如果我们选择进行ODE采样,我们可以简单地通过将噪声项 s s s设置为零来实现这一点。
在[3]中表明,在实践中需要对两个量中的一个
s
θ
(
x
,
t
)
\mathbf{s}_\theta(\mathbf{x}, t)
sθ(x,t)和
v
θ
(
x
,
t
)
\mathbf{v}_\theta(\mathbf{x}, t)
vθ(x,t)进行估计。这是直接从约束推导出来的
x
=
E
[
x
t
∣
x
t
=
x
]
=
α
t
E
[
x
∗
∣
x
t
=
x
]
+
σ
t
E
[
ε
∣
x
t
=
x
]
(9)
\begin{aligned} \mathbf{x} & =\mathbb{E}\left[\mathbf{x}_t \mid \mathbf{x}_t=\mathbf{x}\right] \\ & =\alpha_t \mathbb{E}\left[\mathbf{x}_* \mid \mathbf{x}_t=\mathbf{x}\right]+\sigma_t \mathbb{E}\left[\varepsilon \mid \mathbf{x}_t=\mathbf{x}\right] \end{aligned} \tag{9}
x=E[xt∣xt=x]=αtE[x∗∣xt=x]+σtE[ε∣xt=x](9)
用速度
v
(
x
,
t
)
\mathbf{v}(\mathbf{x}, t)
v(x,t)来重新表示分数
s
(
x
,
t
)
\mathbf{s}(\mathbf{x}, t)
s(x,t)等于
s
(
x
,
t
)
=
σ
t
−
1
α
t
v
(
x
,
t
)
−
α
˙
t
x
α
˙
t
σ
t
−
α
t
σ
˙
t
(10)
\mathbf{s}(\mathbf{x}, t)=\sigma_t^{-1} \frac{\alpha_t \mathbf{v}(\mathbf{x}, t)-\dot{\alpha}_t \mathbf{x}}{\dot{\alpha}_t \sigma_t-\alpha_t \dot{\sigma}_t} \tag{10}
s(x,t)=σt−1α˙tσt−αtσ˙tαtv(x,t)−α˙tx(10)
因此,
v
(
x
,
t
)
\mathbf{v}(\mathbf{x}, t)
v(x,t)和
s
(
x
,
t
)
\mathbf{s}(\mathbf{x}, t)
s(x,t)可以相互转换。我们将在下面演示如何计算它们。
分数s和速度v的估计。基于分数的扩散模型[72]表明,分数可以参数化地估计为
s
θ
(
x
,
t
)
\mathbf{s}_\theta(\mathbf{x}, t)
sθ(x,t),利用损失
L
s
(
θ
)
=
∫
0
T
E
[
∥
σ
t
s
θ
(
x
t
,
t
)
+
ε
∥
2
]
d
t
(11)
\mathcal{L}_{\mathrm{s}}(\theta)=\int_0^T \mathbb{E}\left[\left\|\sigma_t \mathbf{s}_\theta\left(\mathbf{x}_t, t\right)+\varepsilon\right\|^2\right] \mathrm{d} t \tag{11}
Ls(θ)=∫0TE[∥σtsθ(xt,t)+ε∥2]dt(11)
同样,速度
v
(
x
,
t
)
\mathbf{v}(\mathbf{x}, t)
v(x,t)可以通过损失参数化地估计为
v
θ
(
x
,
t
)
\mathbf{v}_\theta(\mathbf{x}, t)
vθ(x,t)
L
v
(
θ
)
=
∫
0
T
E
[
∥
v
θ
(
x
t
,
t
)
−
α
˙
t
x
∗
−
σ
˙
t
ε
∥
2
]
d
t
(12)
\mathcal{L}_{\mathrm{v}}(\theta)=\int_0^T \mathbb{E}\left[\left\|\mathbf{v}_\theta\left(\mathbf{x}_t, t\right)-\dot{\alpha}_t \mathbf{x}_*-\dot{\sigma}_t \varepsilon\right\|^2\right] \mathrm{d} t \tag{12}
Lv(θ)=∫0TE[∥vθ(xt,t)−α˙tx∗−σ˙tε∥2]dt(12)
其中
θ
\theta
θ表示我们在上一节描述的Zigzag Mamba网络,我们采用线性路径进行训练,因为它的简单性和相对的直线轨迹:
α
t
=
1
−
t
,
σ
t
=
t
(13)
\alpha_t=1-t, \quad \sigma_t=t \tag{13}
αt=1−t,σt=t(13)
我们注意到,在(11)和(12)的积分下可以包含任何与时间相关的权重。当T变大时,这些权重因素在基于分数的模型中起着至关重要的作用[44,45]。因此,它们提供了一个同时考虑时变权重和随机性的一般形式。
4 Experiment
在本节中,我们首先详细介绍有关图像和视频数据集的实验设置,以及我们的训练细节。随后,我们深入研究了几个深入的分析,旨在阐明我们在各种分辨率下设计方法的基本原理。最后,我们介绍了从更高分辨率和更复杂的数据集获得的结果。
4.1 数据集和训练细节
图像数据集。为了探索高分辨率下的可扩展性,我们在FacesHQ 1024×1024上进行了实验。我们用于训练和消融的一般数据集是FacesHQ,它是CelebA-HQ[88]和FFHQ[43]的汇编,在以前的工作中使用,如[24,26]。
对于文本条件生成,我们在MultiModalCelebA 2562、5122[88]和MS COCO 256×256[50]数据集上进行了实验。两个数据集都由文本-图像对组成,用于训练。通常,在COCO和MultiModal-CelebA中,每张图像有5到10个标题。我们使用CLIP文本编码器[66]在Stable Diffusion[67]之后将离散文本转换为token序列。然后将这些token作为标记序列输入到网络中。
视频数据集。UCF101数据集由13320个视频片段组成,分为101个类。这些视频剪辑的总长度超过27小时。这些视频都是从YouTube上收集的,固定帧率为25 FPS,分辨率为320 × 240。我们随机采样连续16帧,并将帧大小调整为256 × 256。
训练细节。我们统一使用学习率为1e−4的AdamW[59]优化器。为了提取潜在特征,我们使用了现成的VAE编码器。为了减少计算成本,我们采用了混合精度训练方法。此外,我们应用了阈值为2.0和权重衰减为0.01的梯度裁剪,以防止Mamba训练期间NaN的发生。我们的大部分实验都是在4个A100 GPU上进行的,可扩展性探索扩展到16和32个A100 GPU。在采样方面,出于速度考虑,我们采用ODE采样。有关详情,请参阅附录9.5。
4.2 消融研究
扫描方案消融。表1给出了基于不同分辨率MultiModal-CelebA数据集消融研究的几个重要发现。首先,将扫描方案从扫描转换为之字形,可以获得一些增益。其次,当我们将之字形方案从1增加到8时,我们看到了一致的收益。这表明在不同的块中交替扫描方案是有益的。最后,与较低分辨率(256 × 256,或较短序列token数)相比,Zigzag-1和Zigzag-8在较高分辨率(512 × 512,或较长的序列token数)下的相对增益更为突出,这表明在较长的序列token数下,Zigzag-1和Zigzag-8具有更大的潜力和更有效的归纳偏置融合。
空间连续性至关重要。我们首先通过将大小为N ×N的patch分组为2 × 2、4 × 4、8 × 8和16 × 16,探索了空间连续性在Mamba设计中的重要性,得到了大小分别为N/2 ×N /2、N/4 ×N /4、N/8 ×N /8和N/16 ×N /16的patch组。然后,我们将设计的Zigzag-8方案应用于group级而不是patch级。图6表明,随着空间连续性的增加,性能得到了显著提高。此外,我们将我们的方法与N × N块的随机洗牌进行了比较,发现在随机洗牌条件下的性能明显较差。所有这些结果共同表明,空间连续性是一个关键的要求,当应用Mamba在二维序列。
网络与FPS/GPU内存消融性研究。在图7 (a,b)中,我们分析了全局patch尺寸从32 × 32到196 × 196变化时的转发速度和GPU内存使用情况。对于速度分析,我们报告帧每秒(FPS)而不是FLOPS,因为FPS提供了更明确和适当的速度2评估。为了简单起见,我们统一应用zigzag-1 Mamba扫描方案,并在具有80GB内存的A100 GPU上使用batch size=1和patch size=1。值得注意的是,为了公平比较,所有方法都共享几乎相同的参数数。我们主要将我们的方法与两种流行的基于Transformer的扩散主干,U-ViT[9]和DiT[65]进行比较。显然,我们的方法在逐渐增加patch数量的情况下可以获得最佳的FPS和GPU利用率。当patch数为196时,U-ViT表现出最差的性能,甚至超过了内存界限。令人惊讶的是,DiT的GPU利用率与我们的方法接近,从实用的角度支持我们对DiT的主要选择。
此外,在图7 ©中,我们在不同的方法变体中对GPU内存和FPS进行了消融研究。我们发现,当逐渐增加Mamba扫描方案时,我们的方法几乎不会产生FPS和GPU内存负担。该分析提供了不同扩散骨干的性能和效率比较的见解,突出了我们提出的方法的优势。
Order感受野。我们提出了一个基于Mamba的多维数据结构的新概念。考虑到多维数据中可能存在各种空间连续之字形路径,我们引入了术语“Order感受野”,表示网络设计中明确使用的之字形路径的数量。
Order感受野与FPS/GPU内存的消融性研究。如图10(实际为图8)所示,Zigzag Mamba始终保持其GPU内存消耗和FPS速率,即使逐渐增加Order感受野。相比之下,我们的主要基线平行Mamba,以及双向Mamba和视觉Mamba等变体[57,96],由于参数的增加,FPS持续下降。值得注意的是,Zigzag Mamba的Order感受野为8,在不改变参数的情况下可以执行得更快。
patch的大小。我们在图7 (d)中对1、2、4、8的patch尺寸进行了消融研究,旨在探索它们在Mamba框架下的行为。结果表明,随着贴片尺寸的增加,FID会恶化,这与在Transformer领域观察到的共识一致[23,78]。这表明较小的patch大小对于最佳性能至关重要。
4.3 主要结果
在1024×1024 FacesHQ的主要结果。为了详细说明我们的方法在Mamba和随机插值框架内的可扩展性,我们在表3中提供了高分辨率数据集(1024×1024 FacesHQ)的比较。我们的主要比较对象是双向Mamba,这是一种将Mamba应用于二维图像数据的常用解决方案[57,96]。为了研究Mamba在高达1024的大分辨率下的可扩展性,我们在128 × 128的潜在空间上使用扩散模型,patch大小为2,得到4,096个token。该网络在16个A100 GPU上进行训练。值得注意的是,与双向Mamba相比,我们的方法显示出更好的结果。有关损失和FID曲线的详细信息见附录9.3。虽然受GPU资源限制的限制,防止更长的训练时间,我们预计双向Mamba与延长的训练时间一致优于。
COCO数据集。为了进一步比较我们的方法的性能,我们还在更复杂和常见的数据集MS COCO上对其进行了评估。我们与表2中的双向Mamba作为基线进行比较。应该注意的是,为了公平比较,所有方法都共享几乎相同的参数数。我们使用16个A100 GPU训练所有方法。详情请参阅附录9.5。如表2所示,我们的Zigzag-8方法优于双向Mamba和Zigzag-1。这表明,摊销各种扫描方案可以产生显著的改进,归因于更好地结合了Mamba二维图像的感应偏置。
UCF101数据集。在表4中,我们展示了我们在UCF101数据集上的结果,使用4个A100 GPU训练所有方法,并使用16个A100 GPU进行了进一步的可扩展性探索。我们主要将我们的方法与双向Mamba[96]进行一致性比较。关于3DZigzag Mamba的选择,请参阅附录9.5。针对因式三维之字形Mamba在视频处理中的应用,采用sst方案对时空建模进行因式分解。该方案将空间信息复杂性优先于时间信息,并假设在时间域中存在冗余。我们的结果一致地证明了我们的方法在各种情况下的优越性能,强调了我们方法的复杂性和有效性。
可视化。我们在图9中展示了我们在FacesHQ 1024和MultiModal-CelebA 512上的最佳结果的图像可视化。关于视频的可视化,请参见附录9.1。很明显,可视化在各种分辨率下都是令人愉悦的,这表明我们的方法是有效的。
5 Conclusion
在本文中,我们提出了在随机插值框架内发展的Zigzag Mamba扩散模型。我们最初的重点是解决空间连续性的关键问题。然后,我们设计了一个Zigzag Mamba块,以更好地利用2D图像中的归纳偏置。此外,我们将3D Mamba分解为2D和1D之字形Mamba,以方便优化。我们经验性地设计了各种消融研究来检验不同的因素。这种方法允许对随机插值理论进行更深入的探索。我们希望我们的努力能对Mamba网络设计的进一步探索有所启发。
6 Acknowledgement
7 Limitations and Future Work
我们的方法完全依赖于具有DiT风格布局和调节方式的Mamba块。然而,我们工作的一个潜在限制是,我们不能详尽地列出给定特定全局patch大小的所有可能的空间连续之字形扫描方案。目前,我们根据经验设置这些扫描方案,这可能导致性能次优。此外,由于GPU资源限制,我们无法探索更长的训练持续时间,尽管我们预计会得出类似的结论。
对于未来的工作,我们的目标是深入到各种应用之字形Mamba,利用其可扩展性的长序列建模。这种探索可能会提高跨不同领域和应用程序对Mamba框架的利用。
8 Impact Statement
这项工作旨在增强可扩展性,并在扩散模型框架内释放Mamba算法的潜力,从而能够生成高保真的大型图像。通过将交叉注意机制整合到Mamba块中,我们的方法还可以促进文本到图像的生成。然而,就像其他旨在增强大规模图像合成模型的能力和控制的努力一样,我们的方法有可能产生有害或欺骗性的内容。因此,必须实施道德考虑和保障措施来减轻这些风险。
9 Appendix
9.1 可视化
图16中的FacesHQ 1024 × 1024无策化可视化。
MS-COCO无策化可视化。我们在图15中可视化了样本。
9.2 扫描方案与位置嵌入的新结果
我们还对各种因素进行了基本的消融,包括位置嵌入和各种希尔伯特空间填充曲线。与主论文中的实验不同,我们在无条件的MultiModal-ElebA256数据集上进行了这些实验,以进行统一比较。我们训练网络100000步。
关于位置嵌入的消融。如表5所示,可学习嵌入优于正弦嵌入,正弦嵌入优于无位置嵌入。在各种情况下,我们的之字形方法超过了基线。值得注意的是,无论我们使用正弦位置嵌入还是不使用位置嵌入,我们的性能几乎保持不变。这表明,与我们的基线相比,我们的方法可以更好地结合空间诱导偏置。最后,使用可学习的位置嵌入提供了进一步的(尽管是边际的)增益,这表明即使在我们的锯齿形扫描方案中也存在更好的位置嵌入。
希尔伯特空间填充曲线的探索。首先,我们消去希尔伯特扫描曲线[62],如图12所示。考虑到不同的角度和起点,这种扫描也有8种变体。我们用与之字形扫描类似的方式重新排列它们。为了公平比较,所有参数保持一致。我们利用吉尔伯特算法来保证希尔伯特曲线在任何大小的正方形上都是连续的。我们在单个A100-SXM4-80GB上训练网络进行120k次迭代。
我们在5000张图像上评估了固定步骤的FID, FID曲线如图11所示。
虽然希尔伯特空间填充曲线比我们的之字形扫描提供了更多的局部性,并保持了连续性,但其复杂的结构似乎阻碍了SSM在平坦序列上工作的能力,导致比我们的之字形曲线在自然图像上的归纳偏置更差。因此,我们假设在生成任务中,结构可能比局部性更重要。
希尔伯特曲线很难优化。我们在表6中显示了结果。我们可以观察到希尔伯特扫描路径的性能显著下降,即使我们降低接收野的阶数(ORF)。这证实了希尔伯特扫描路径难以优化的假设,即使只考虑希尔伯特扫描的两种不同方案。
不同Order感受野的影响。我们在表7中显示了Order感受野(ORF)的影响。我们观察到将ORF增加到4可以提高性能。性能在ORF=4附近趋于平稳。由于没有额外的参数或GPU内存负担,我们建议从业者选择最大的ORF(在我们的示例中为8)。
另一种解释:之字形扫描是最简单的皮亚诺曲线。我们的z字形扫描可以看作是最简单的Peano曲线,如图13所示。
9.3 2D可视化数据的新结果
ZigMa模型的变体。我们在表9中列出了模型的变体。我们使用Base (B)模型作为默认值。使用交叉注意模型是可选的,因为该模块会引入一些参数和速度负担。然而,注意力优化方面的任何进步都可以无缝地集成到我们的模型中。
模型复杂度与FPS/GPU内存的消蚀研究。如图10所示。我们的方法在加入感受顺序后可以获得更好的参数效率。感受顺序指的是二维图像中累积的空间连续之字形扫描路径,我们将其作为归纳偏置纳入Mamba。我们在图10中列出了逐渐增加感受顺序时的参数消耗。感受顺序指的是二维图像中累积的空间连续之字形扫描路径,我们将其作为归纳偏置纳入Mamba。
损失和FID曲线。训练损失曲线和FID曲线如图14所示。损失和FID显示了相同的趋势,我们的锯齿Mamba始终优于其他基线,如Sweep-1和Sweep-2。
In-context v.s. Cross Attention 在表8中,我们比较了我们的交叉注意和语境内注意。对于上下文注意,我们将文本token与图像token连接起来,并将它们输入到Mamba块中。我们的研究结果表明,上下文注意比交叉注意表现得更差。我们假设这是由于文本token和图像patch token之间的不连续性。
9.4 3D可视化数据的新结果
选择3D之字形Mamba。针对因式三维之字形Mamba在视频处理中的应用,采用sst方案对时空建模进行因式分解。该方案优先考虑空间信息(ss)的复杂性,而不是时间信息(t),假设冗余存在于时间域。还有许多其他可能的s和t的组合需要探索,我们将其留给未来的工作。
9.5 更多详细信息
更多相关工作 一些研究[82,83]已经证明了状态空间模型在一定条件下具有通用逼近能力。Mamba作为一种新的状态空间模型,具有高效建模长序列的优越潜力,已在医学成像[28,60,68,86,89]、图像恢复[34,95]、图形[11,79]、NLP word byte[80]、表格数据[2]、人体运动合成[94]、点云[49,93]、图像生成[25]、半监督学习[85]、可解释性[5]、图像去雾[95]和泛锐化[37]等多个领域进行了探索。已扩展到混合专家[7]、谱空间[1]、多维度[48、57、63、96]和密集连接[36]。其中与我们关系最密切的是VisionMamba[57,96]、S4ND[63]和Mamba-ND[48]。VisionMamba[57,96]在判别性任务中使用双向SSM,这导致了很高的计算成本。我们的方法在生成模型中应用了一个简单的替代Mamba扩散。S4ND[63]在Mamba的推理过程中引入了局部卷积,超越了仅使用一维数据。Mamba-ND[48]在判别任务中考虑了多维度,在单个块内使用各种扫描。相比之下,我们的重点是在网络的每一层分布扫描复杂性,从而在零参数负担的情况下最大限度地结合视觉数据的归纳偏置。
双重索引问题
Ω
i
Ω_i
Ωi。如图2所示。为了实现空间连续Mamba推理,我们需要进行排序和重排操作,需要沿着token数维度进行索引,考虑到token数较大,索引会很耗时,我们可以将排序和重排操作表述如下:
Ω
i
′
=
Ω
ˉ
i
−
1
⋅
Ω
i
z
i
+
1
=
scan
(
z
Ω
i
′
)
\begin{align} \Omega_i^{\prime} & =\bar{\Omega}_{i-1} \cdot \Omega_i \tag{14}\\ \mathbf{z}_{i+1} & =\operatorname{scan}\left(\mathbf{z}_{\Omega_i^{\prime}}\right)\tag{15}\\ \end{align}
Ωi′zi+1=Ωˉi−1⋅Ωi=scan(zΩi′)(14)(15)
缺公式16
其中
Ω
ˉ
−
1
=
I
\bar{\Omega}_{-1}=I
Ωˉ−1=I,假设基于Mamba的网络与token的顺序是排列等变的。它们需要的索引操作减少了50%,为了更清晰的比较,我们在这里重申这一点:
z
Ω
i
=
arrange
(
z
i
,
Ω
i
)
z
‾
Ω
i
=
scan
(
z
Ω
)
z
i
+
1
=
arrange
(
z
‾
Ω
i
,
Ω
ˉ
i
)
\begin{align} \mathbf{z}_{\Omega_i} & =\operatorname{arrange}\left(\mathbf{z}_i, \Omega_i\right) \tag{17}\\ \overline{\mathbf{z}}_{\Omega_i} & =\operatorname{scan}\left(\mathbf{z}_{\Omega}\right) \tag{18}\\ \mathbf{z}_{i+1} & =\operatorname{arrange}\left(\overline{\mathbf{z}}_{\Omega_i}, \bar{\Omega}_i\right)\tag{19} \end{align}
zΩizΩizi+1=arrange(zi,Ωi)=scan(zΩ)=arrange(zΩi,Ωˉi)(17)(18)(19)
评价指标。对于图像级别的保真度,我们使用既定的度量,如Fréchet Inception Distance (FID) 和 Kernel Inception Distance (KID),遵循之前的工作。然而,由于研究[20,73]表明FID不能完全反映基于人类的意见,我们也采用了使用官方存储库的Fréchet DINOv2 Distance(FDD) 。我们的方法主要包括采样5000张真实图像和5000张假图像来计算相关指标。
我们主要考虑视频保真度评估的两个指标:framewise FID和 Fréchet Video Distance (FVD) [77]。我们对200个视频进行采样,并基于这些样本计算相应的指标。
各数据集的训练参数如表10所示。我们不应用任何位置编码,因为与Transformer不同,Mamba不是排列不变的。因此,它的位置是根据它在Mamba中的顺序自动编码的。对于COCO数据集,0.01的权重衰减可以产生边际FID增益(大约为0.8)。
时间步长和提示的调节。算法1说明了调节过程。