【论文笔记】Scalable Diffusion Models with State Space Backbone

news2024/11/15 21:43:57

原文链接:https://arxiv.org/abs/2402.05608

1. 引言

主干网络是扩散模型发展的关键方面,其中基于CNN的U-Net(下采样-跳跃连接-上采样)和基于Transformer的结构(使用自注意力替换采样块)是代表性的例子。

状态空间模型(SSM)在长序列建模方面有极大潜力。本文受Mamba启发,建立基于SSM的扩散模型,称为DiS。DiS将所有输入(时间、条件和有噪声的图像patch)视为离散token。DiS中的状态空间模型使其比CNN和Transformer有更优的放缩性,且有更低的计算开销。

2. 方法

2.1 准备知识

扩散模型:扩散模型逐步向数据加入噪声,然后将此过程反过来从噪声生成数据。噪声的加入过程称为前向过程,可表达为马尔科夫链。逆过程中,使用高斯模型近似真实逆转移,其中学习相当于对噪声的预测(即使用噪声预测网络,来最小化噪声预测目标)。

条件扩散模型会将条件(如类别、文本等,通常形式为索引或连续嵌入)引入噪声预测目标中。

具体公式见扩散模型(Diffusion Model)简介 - CSDN。

状态空间主干:状态空间模型的传统定义是将 x ( t ) ∈ R N x(t)\in\mathbb R^N x(t)RN通过隐状态 h ( t ) ∈ R N h(t)\in\mathbb R^N h(t)RN映射为 y ( t ) ∈ R N y(t)\in\mathbb R^N y(t)RN的线性时不变系统:
h ′ ( t ) = A h ( t ) + B x ( t ) y ( t ) = C h ( t ) h'(t)=Ah(t)+Bx(t)\\y(t)=Ch(t) h(t)=Ah(t)+Bx(t)y(t)=Ch(t)

其中 A ∈ R N × N A\in\mathbb R^{N\times N} ARN×N为状态矩阵, B , C ∈ R N B,C\in\mathbb R^N B,CRN为输入和输出矩阵。真实世界的数据通常为离散形式,可将上式离散化为
h t = A ˉ h t − 1 + B ˉ x t y t = C h t h_t=\bar Ah_{t-1}+\bar Bx_t\\y_t=Ch_t ht=Aˉht1+Bˉxtyt=Cht

其中 A ˉ = exp ⁡ ( Δ ⋅ A ) , B ˉ = ( Δ ⋅ A ) − 1 ( exp ⁡ ( Δ ⋅ A ) − I ) ⋅ ( Δ B ) \bar A=\exp(\Delta\cdot A),\bar B=(\Delta\cdot A)^{-1}(\exp(\Delta\cdot A)-I)\cdot(\Delta B) Aˉ=exp(ΔA),Bˉ=(ΔA)1(exp(ΔA)I)(ΔB)为离散状态参数, Δ \Delta Δ为离散步长。

虽然SSM理论上性质优良,但通常有高计算量和数值不稳定性。结构状态空间模型(S4)通过强制 A A A的形式来减轻这一问题,能达到比Transformer更高的性能;Mamba则进一步通过输入依赖的选择机制和更快的硬件感知算法改进之。

2.2 模型结构设计

DiS参数化噪声预测网络 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t,t,c) ϵθ(xt,t,c),以时间 t t t、条件 c c c和噪声图像 x t x_t xt,预测向 x t x_t xt加入的噪声。DiS基于双向Mamba结构,如下图所示。
在这里插入图片描述
图像patch化:DiS的第一层将输入图像 I ∈ R H × W × C I\in\mathbb R^{H\times W\times C} IRH×W×C转化为拉直的2D patch X ∈ R J × ( p 2 ⋅ C ) X\in\mathbb R^{J\times (p^2\cdot C)} XRJ×(p2C)。然后,通过对每个patch进行线性嵌入,转化为含 J J J个token的、维度为 D D D的序列。为每个输入token使用可学习位置编码。 J = H × W p 2 J=\frac{H\times W}{p^2} J=p2H×W由patch大小 p p p决定。

SSM块:输入token会被一组SSM块处理。SSM块的输入还包括时间 t t t与条件 c c c。本文使用双向序列建模,即SSM块的前向过程包含了前向和反向两个方向的处理。

跳跃连接:本文将 L L L个SSM块分为前半和后半两部分,每部分 ⌊ L 2 ⌋ \lfloor\frac L2\rfloor 2L个。设 h s h a l l o w , h d e e p ∈ R J × D h_{shallow},h_{deep}\in\mathbb{R}^{J\times D} hshallow,hdeepRJ×D分别为跳跃连接分支和主分支的隐状态,则通过拼接和线性投影后再送入下一个SSM块,即 L i n e a r ( C o n c a t ( h s h a l l o w , h d e e p ) ) \mathtt{Linear}(\mathtt{Concat}(h_{shallow},h_{deep})) Linear(Concat(hshallow,hdeep))

线性解码器:需要将最后一个SSM块的隐状态解码为噪声预测和对角化协方差矩阵(与原始输入尺寸相同)。本文使用线性解码器,即LayerNorm+线性层,将每个token转化为 p 2 ⋅ C p^2\cdot C p2C的张量。最后,将解码的token重排为原始大小,得到预测噪声与协方差。

条件引入:本文在输入token的序列上增加时间 t t t与条件 c c c的向量嵌入作为额外token(类似ViT中的类别token),从而无需修改SSM块。在最后一个SSM块后,从序列移除条件token。此外,还用自适应归一化层替换标准归一化层,使模型从 c c c t t t嵌入向量的和中回归缩放和偏移参数。

2.3 计算分析

对序列 X ∈ R 1 × J × D X\in\mathbb R^{1\times J\times D} XR1×J×D和状态扩维默认设置 E = 2 E=2 E=2,自注意力与SSM的计算复杂度分别为 O ( S A ) = 4 J D 2 + 2 J 2 D O(SA)=4JD^2+2J^2D O(SA)=4JD2+2J2D O ( S S M ) = 3 J ( 2 D ) N + J ( 2 D ) N 2 O(SSM)=3J(2D)N+J(2D)N^2 O(SSM)=3J(2D)N+J(2D)N2

其中自注意力的计算是序列长度 J J J的二次方,而SSM则是线性关系。注意 N N N为固定参数。这说明DiS有较强的可放缩性。

3. 实验

3.1 实验设置

数据集:仅使用水平翻转数据增广。

实施细节:本文对DiS的权重使用指数移动平均方法。

3.2 模型分析

patch大小的影响:当模型大小一致时,减小patch大小(增加token数),性能会提高。这可能是扩散模型噪声预测任务的低级特性,导致需要小型patch,而不像更高级的分类任务。对高分辨率图像,使用小尺寸patch可能会引入高计算成本,可将图像转换为低维隐式表达,然后再使用DiS处理。

长跳跃的影响:比较拼接( L i n e a r ( C o n c a t ( h s h a l l o w , h d e e p ) ) \mathtt{Linear}(\mathtt{Concat}(h_{shallow},h_{deep})) Linear(Concat(hshallow,hdeep))
、求和( h s h a l l o w + h d e e p h_{shallow}+h_{deep} hshallow+hdeep)和无跳跃连接三种方式。实验表明,求和不会带来明显的性能提升,因为SSM自身可以通过线性方式保留一些浅层信息。而使用拼接和可学习的线性投影可以大幅增加性能。

条件组合:比较两种引入时间 t t t的方案:(1)将 t t t视为token,与图像patch一同处理;(2)将 t t t的嵌入整合到SSM块的层归一化中,类似U-Net中的自适应分组归一化,得到自适应层归一化: A d a L N ( h , s ) = y s L a y e r N o r m ( h ) + y b AdaLN(h,s)=y_s\mathtt{LayerNorm}(h)+y_b AdaLN(h,s)=ysLayerNorm(h)+yb,其中 h h h为SSM的隐状态, y s , y b y_s,y_b ys,yb为时间嵌入的线性投影。实验表明前者的性能优于后者。

缩放模型大小:增大模型深度(SSM块层数)和宽度(隐状态维度)均能提高性能。

3.3 主要结果

无条件图像生成:DiS与基于U-Net或Transformer的扩散模型有相当的性能,但参数量更少。

以类别为条件的图像生成:本文的方法可以超过其余方法的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1502872.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用R语言进行聚类分析

一、样本数据描述 城镇居民人均消费支出水平包括食品、衣着、居住、生活用品及服务、通信、文教娱乐、医疗保健和其他用品及服务支出这八项指标来描述。表中列出了2016年我国分地区的城镇居民的人均消费支出的原始数据,数据来源于2017年的《中国统计年鉴》&#xf…

传递函数硬件化

已知一个系统的传递函数,如何进行硬件化呢? 只需要将传递函数离散化,得到差分方程,就可以根据差分方程进行硬件设计。 通过例子说明: 得到差分方程后,其中y(k)/y(k-1)/y(k-2)/u(k-1)/u(k-2)等代表不同周期…

【Spring】Spring状态机

1.什么是状态机 (1). 什么是状态 先来解释什么是“状态”( State )。现实事物是有不同状态的,例如一个自动门,就有 open 和 closed 两种状态。我们通常所说的状态机是有限状态机,也就是被描述的事物的状态的数量是有…

BC161 大吉大利,今晚吃鸡

一&#xff1a;题目 二&#xff1a;思路 三&#xff1a;代码 #include<bits/stdc.h>using namespace std;long long cnt;//柱子定义为x, y, z void move(int n, char x, char y, char z) {if(n 1){//printf("%c -> %c\n", x, y);//最大盘从x->y//prin…

git远程仓库分支推送与常见问题

1.查看远程仓库分支情况 git fetch origin git branch -r2.删除远程仓库中的某一分支(如master) git push origin --delete master问: 如果我的本地文件只有一个分支main,而远程仓库有两个分支Main和CubeMX, 若要将本地文件中新增的文件Test1.txt更改放入CubeMX中&#xff0c…

大数据开发-Hadoop分布式集群搭建

大数据开发-Hadoop分布式集群搭建 文章目录 大数据开发-Hadoop分布式集群搭建环境准备Hadoop配置启动Hadoop集群Hadoop客户端节点Hadoop客户端节点 环境准备 JDK1.8Hadoop3.X三台服务器 主节点需要启动namenode、secondary namenode、resource manager三个进程 从节点需要启动…

Linux操作系统项目上传Github代码仓库指南

文章目录 1 创建SSH key2.本地git的用户名和邮箱设置3.测试连接4.创建仓库5.终端项目上传 1 创建SSH key 1.登录github官网,点击个人头像,点击Settings,然后点击SSH and GPG keys,再点击New SSH key。 Title 可以随便取&#xff0c;但是 key 需要通过终端生成。 Linux终端执行…

窄带波束形成

阵列信号处理有以下三个研究方向&#xff1a; 检测入射信号是否存在&#xff0c;以及入射信号的数目检测入射信号的到达方向&#xff08;DOA)角增强某个感兴趣方向的信号&#xff0c;抑制其他方向的干扰&#xff08;beamforming) 波束形成&#xff08;beamforming&#xff09;…

福州·名城银河湾220㎡现代简约风装修案例分享。福州中宅装饰,福州装修

以手作维度构境, 跳脱约定成俗的风格, 转化内外地域分际, 于静谧中凝聚丰厚的美学能量, 谦虚且沉默以对。 平面设计图 项目信息 项目名称 | 名城银河湾 设计地址 | 福建福州 项目面积 | 220㎡ 项目户型 | 5室2厅2厨3卫 设计风格 | 现代轻奢 首席设计师丨欧阳光玉 中…

STM32 | STM32F407ZE(LED寄存器开发续第二天源码)

上节回顾 STM32 | STM32时钟分析、GPIO分析、寄存器地址查找、LED灯开发(第二天)STM32 | Proteus 8.6安装步骤(图文并茂)一、 LED灯开发 1、理解led灯原理图 LED0连接在PF9 PF9输出低电平(0),灯亮;PF9输出高电平(1),灯灭;(低电平有效) 2、打开GPIOF组时钟 //将…

随机输一次(Python3)

大家应该都会玩“锤子剪刀布”的游戏&#xff1a;两人同时给出手势&#xff0c;胜负规则如图所示&#xff1a; 现要求你编写一个控制赢面的程序&#xff0c;根据对方的出招&#xff0c;给出对应的赢招。但是&#xff01;为了不让对方意识到你在控制结果&#xff0c;你需要隔 K …

网络安全相关证书有哪些?

从事于信息安全工作的人们&#xff0c;在面对繁杂问题的时候&#xff0c;往往会有焦虑和烦躁的表现。一部分可能来自于系统和流程的实际漏洞&#xff0c;一方面可能是自身的能力还有部分短板。许多人认为庞杂的问题或多或少的难以下手&#xff0c;如果有好的方式能够同时解决这…

Linux报错排查-刚安装好的ubuntu系统无法ssh连接

Linux运维工具-ywtool 目录 一.问题描述二.问题解决2.1 先给ubuntu系统配置阿里云源2.2 安装openssh-server软件2.3 在尝试ssh连接,可以连接成功了 三.其他命令 一.问题描述 系统:ubuntu-18.04-desktop-amd64 系统安装完后,想要通过xshell软件连接系统,发现能Ping通系统的IP,但…

计算布尔二叉树的值

题目 题目链接 . - 力扣&#xff08;LeetCode&#xff09; 题目描述 代码实现 class Solution { public:bool evaluateTree(TreeNode* root) {if(root->left nullptr && root->right nullptr) return root->val;bool left evaluateTree(root->left)…

CubeMX使用教程(6)——ADC模拟输出

本篇将利用CubeMX开发工具学习ADC&#xff08;模拟输出&#xff09;的使用 我们还是利用上一章的工程进行二次开发&#xff0c;这样方便 首先打开CubeMX进行相关配置 通过查看G431RBT6开发板有关模拟输出部分的原理图可知&#xff0c;模拟输出用到的IO口是PB15和PB12 接着我…

11、Linux-安装和配置Redis

目录 第一步&#xff0c;传输文件和解压 第二步&#xff0c;安装gcc编译器 第三步&#xff0c;编译Redis 第四步&#xff0c;安装Redis服务 第五步&#xff0c;配置Redis ①开启后台启动 ②关闭保护模式&#xff08;关闭之后才可以远程连接Redis&#xff09; ③设置远程…

接口自动化测试框架搭建:基于python+requests+pytest+allure实现

众所周知&#xff0c;目前市面上大部分的企业实施接口自动化最常用的有两种方式&#xff1a; 1、基于代码类的接口自动化&#xff0c;如&#xff1a; PythonRequestsPytestAllure报告定制 2、基于工具类的接口自动化&#xff0c;如&#xff1a; PostmanNewmanJenkinsGit/svnJme…

【Kotlin】类和对象

1 前言 Kotlin 是面向对象编程语言&#xff0c;与 Java 语言类似&#xff0c;都有类、对象、属性、构造函数、成员函数&#xff0c;都有封装、继承、多态三大特性&#xff0c;不同点如下。 Java 有静态&#xff08;static&#xff09;代码块&#xff0c;Kotlin 没有&#xff1…

和数软件:区块链技术的爆发与冲击

什么是区块链&#xff1f;它是如何发展而来的&#xff1f;应用在哪些领域&#xff1f;将会对我国的社会经济产生哪些重大影响&#xff1f; 什么是区块链 区块链作为一种底层技术&#xff0c;最早的实践是数字货币。根据最早的中本聪定义&#xff0c;区块链实质上是一种基于网…

人工智能|机器学习——Canopy聚类算法(密度聚类)

1.简介 Canopy聚类算法是一个将对象分组到类的简单、快速、精确地方法。每个对象用多维特征空间里的一个点来表示。这个算法使用一个快速近似距离度量和两个距离阈值T1 > T2 处理。 Canopy聚类很少单独使用&#xff0c; 一般是作为k-means前不知道要指定k为何值的时候&#…