【扩散模型(七)】Stable Diffusion 3 diffusers 源码详解2 - DiT 与 MMDiT 相关代码(上)

news2025/1/7 7:37:14

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【扩散模型(六)】Stable Diffusion 3 diffusers 源码详解1-推理代码-文本处理部分
  • 本系列将对比介绍 DiT 和 MMDiT 的区别和具体的代码实现,本文先介绍 DiT 的核心代码。

文章目录

  • 系列文章目录
  • 一、DiT
      • DiT 整体代码
      • DiT Block
        • 这六个参数是否是相同的值?
      • 代替 Cross-attention 的 adaLN-Zero Block


一、DiT

  • DiT 1 是 SD3 中 MMDiT 的核心基础,而
  • 通过将 Diffusion 中的 Unet 换成了 DiT Block,来实现基于条件的图像生成。
  • 原文中的条件是类别标签,而非文本提示词。
    在这里插入图片描述
  • 原文测试了多种设置,最终采用了 adaLN-Zero 作为 Cross-Attention 的替代。

DiT 整体代码

官方代码仓库为 https://github.com/facebookresearch/DiT,下面代码的具体位置在 /path/to/DiT/models.py

  • 下方代码为上图的左边部分,输入 x 是 Noised Latent,t 是 Timestep,Label 为 y
  • 其中 block 则是上图中的 DiT Block,将 x 和 c 共同作为输入,以 c 为条件来生成 x (对 Noised Latent 进行去噪)。
   def forward(self, x, t, y):
       """
       Forward pass of DiT.
       x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
       t: (N,) tensor of diffusion timesteps
       y: (N,) tensor of class labels
       """
       x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
       t = self.t_embedder(t)                   # (N, D)
       y = self.y_embedder(y, self.training)    # (N, D)
       c = t + y                                # (N, D)
       for block in self.blocks:
           x = block(x, c)                      # (N, T, D)
       x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
       x = self.unpatchify(x)                   # (N, out_channels, H, W)
       return x

DiT Block

与下面代码中 forward 函数内对应的变量在 DiT Block 中的位置。
在这里插入图片描述

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

在这个 DiTBlock 类中,shift_msascale_msagate_msashift_mlpscale_mlpgate_mlp 是从 adaLN_modulation(c) 这一步中得到的,它们在具体功能上是有所区别的,虽然它们是通过同一个输入 c 生成的。

  1. shift_msascale_msa 这两个变量与 Multi-Head Self-Attention (MSA) 模块的自适应层归一化(adaptive LayerNorm, adaLN)有关:

    • shift_msa: 这个变量用于平移 LayerNorm 的输出,也就是在归一化的基础上加上一个偏置。它在调节 MSA 模块的激活输出时用作偏移量。
    • scale_msa: 这个变量用于缩放 LayerNorm 的输出,即对归一化的结果乘以一个比例因子。它控制了 MSA 模块中激活的放大或缩小程度。
  2. gate_msa: 这个变量是作为一个门控(gate)信号,作用于 MSA 模块的输出上。它决定了 MSA 模块输出在累加到 x 之前的权重。如果 gate_msa 很小,那么这个输出会被抑制;如果 gate_msa 接近1,则输出会如常累加。

  3. shift_mlpscale_mlp 这两个变量与 Pointwise Feedforward (MLP) 模块的自适应层归一化(adaLN)有关,类似于 shift_msascale_msa,但它们作用在 MLP 模块上:

    • shift_mlp: 用于平移 LayerNorm 的输出,在 MLP 模块中作为偏移量。
    • scale_mlp: 用于缩放 LayerNorm 的输出,在 MLP 模块中控制激活的放大或缩小。
  4. gate_mlp: 类似于 gate_msa,但它控制的是 MLP 模块的输出。它决定了 MLP 模块输出在累加到 x 之前的权重。

这六个参数是否是相同的值?

adaLN_modulation(c) 中,c 经过一个 nn.Linear 层(即 nn.Linear(hidden_size, 6 * hidden_size, bias=True)),然后被 chunk(6, dim=1) 分成六个部分,分别得到 shift_msa、scale_msa、gate_msa、shift_mlp、scale_mlp 和 gate_mlp。

虽然这些变量来自于同一个线性层的输出,但由于 nn.Linear 层的权重在训练过程中是可学习的,并且是随机初始化的,因此这些权重会在训练过程中被更新为不同的值。

代替 Cross-attention 的 adaLN-Zero Block

那么为什用 adaLN-Zero 来代替 Cross-Attention 呢?主要是因为计算资源。(DiT 原文提到 Cross-attention adds the most Gflops to the model, roughly a 15% overhead.)

  1. 什么是adaLN-Zero Block?
    adaLN-Zero Block是一种改进版的adaLN(Adaptive Layer Normalization)模块,主要用于扩散模型(Diffusion Model)中。它的核心思想是通过初始化技巧和引入额外的缩放参数,来加速模型训练并提高生成样本的质量。

  2. 为什么引入adaLN-Zero Block?

    • 加速训练: 通过将残差块初始化为恒等映射,模型在训练初期更容易收敛,从而加快训练速度。
    • 提升性能: 引入维度缩放参数,使得模型能够学习到更具表达能力的特征表示,从而生成质量更高的样本。
    • 增强稳定性: 恒等初始化有助于稳定模型的训练过程,尤其对于深层模型。
  3. adaLN-Zero Block的工作原理

    • 恒等初始化: 对于每个残差块的最后一个adaLN层,将缩放参数γ初始化为0。这使得该层在初始阶段相当于一个恒等映射,不会对输入数据进行缩放。
    • 维度缩放参数α: 在残差连接之前,引入一个维度缩放参数α,用于对特征进行缩放。这个参数是可学习的,能够自适应地调整特征的尺度。
  4. 与传统adaLN的区别

    • 初始化方式不同: adaLN-Zero对缩放参数γ进行了特殊的初始化,而传统的adaLN通常使用随机初始化。
    • 参数数量增加: adaLN-Zero引入了额外的维度缩放参数α,增加了模型的参数数量。
  5. 为什么有效?

    • 恒等初始化使得模型在训练初期能够快速学习到残差部分,从而加速训练过程。
    • 维度缩放参数提供了更大的灵活性,使得模型能够更好地适应不同尺度的特征。

最后也附上原文,便于对照理解。
在这里插入图片描述


  1. Peebles, William, and Saining Xie. “Scalable diffusion models with transformers.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2040778.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“从零开始的HTML 表格”——WEB开发系列09

HTML 表格是一种用于在网页上组织和显示信息的结构性元素&#xff0c;它能够将数据以行和列的形式呈现&#xff0c;帮助用户更清晰地理解数据关系。表格在展示统计数据、产品列表、日程安排等方面非常实用。 一、HTML 表格的基本结构 HTML 表格用 ​​<table>​​ 标签来…

创意无限!2024年热门视频剪辑软件精选

从专业级电影剪辑工具到简单易用的手机APP&#xff0c;再到集创意与高效于一身的桌面应用&#xff0c;各类剪辑软件如雨后春笋般涌现。本文将带你一窥2024年火热的剪辑视频的软件。 1.福昕视频剪辑 连接直达>>https://www.pdf365.cn/foxit-clip/ 这款视频编辑工具凭…

Oracle 12.2集群搭建遇到ORA-ORA-15227,ORA-15031,ORA-15018问题处理

报错&#xff1a; [FATAL] [DBT-30056] Labeling of disks failed. ORA-15227: could not perform label set/clear operation ORA-15031: disk specification /dev/asmdisk/ocr01 matches no disks [FATAL] [DBT-30002] Disk group OCR creation failed. ORA-15018: diskgrou…

用Python实现9大回归算法详解——03. 岭回归算法

1. 岭回归的基本概念与动机 1.1 为什么使用岭回归&#xff1f; 在线性回归中&#xff0c;当特征之间存在强烈的相关性&#xff08;即多重共线性&#xff09;时&#xff0c;回归系数会变得不稳定&#xff0c;导致模型在新数据上的表现很差。多重共线性会导致普通最小二乘法&am…

stm32f407新建项目工程及烧录

1、新建一个文件夹&#xff0c;打开keil5将项目工程放入文件夹中 2、弹出选择对应型号设备 3、弹出选择对应库 可以看见出现下图&#xff1a;感叹号表示有错 最后如图所示&#xff1a;点击ok就行了 4、创建对应的文件夹存放文件 4、建立main.c 5、添加对应的设置 最后写一个空白…

sp-eric靶机

端口扫描 靶机ip地址为192.168.7.46 目录扫描 访问80端口 拼接访问 /admin.php 发现登录框 测试sql注入&#xff0c;弱口令等&#xff0c;无结果 扫描目录发现了.git文件&#xff0c;存在源码泄漏 将其下载到kali上读取 python2 GitHack.py -u http://192.168.7.180/.git/…

Linux11

Linux运行级别 graphical.target图形化模式 runlevel查看运行级别 init 6自动重启 centos7单用户模式修改密码 Windows安全模式可用来删除木马&#xff0c;更为方便 单用户模式修改密码 选择第一个 按e键进入编辑模式&#xff0c;并完成以下修改&#xff08;注意&#xff0…

linux上Java生成图片中文乱码

在生成图形二维码时&#xff0c;设置底部中文导出空白乱码&#xff0c;效果如下&#xff1a; 这里服务器使用的是centos7&#xff0c;解决方案下载simsun.ttc文件&#xff0c;放入至jdk安装目录“/opt/jdk/jre/lib/fonts”中&#xff0c;具体根据自身本机jdk安装路径存放&…

ZOOKEEPER+KAFKA消息队列群集

前言 消息队列(Message Queue)&#xff0c;是分布式系统中重要的组件&#xff0c;其通用的使用场景可以简单地描述为:当不需要立即获得结果&#xff0c;但是并发量又需要进行控制的时候&#xff0c;差不多就是需要使用消息队列的时候。 消息队列 什么是消息队列 消息(Messa…

CAD二次开发IFoxCAD框架系列(19)-图层操作

1. 根据名称查询指定的图层 查看层表中是否含有名为“MyLayer”的图层。 using var tr new DBTrans();if(tr.LayerTable.Has("MyLayer")){//要执行的操作}2. 遍历图层名称 遍历图层表并打印每个图层的名字 。 using var tr new DBTrans();tr.LayerTable.GetRecor…

MySQL源码安装与MySQL基础学习

1、安装MySQL ​ 本次安装使用的是绿色硬盘版本&#xff0c;无需额外安装依赖环境&#xff0c;比较简单 修改相关配置文件&#xff1a; 设置环境变量&#xff0c;声明/宣告MySQL命令便于系统识别&#xff1a; 初始化数据库&#xff1a; 设置系统识别&#xff0c;进行操作&…

【代码随想录】数组总结篇

本博文为《代码随想录》的学习笔记&#xff0c;原文链接&#xff1a;代码随想录 数组理论基础 首先要知道数组在内存中的存储方式&#xff0c;数组时存放在连续内存空间上的相同类型数据的集合。数组可以方便地通过下标索引的方式获取到下标对应的数据。举例如下&#xff1a;…

电信优惠套餐到期会自动延续吗?这个问题你了解过吗?

电信优惠套餐到期会自动延续吗&#xff1f;看选择的套餐&#xff0c;不同的套餐情况不同。 对于电信流量卡的优惠期限&#xff0c;有以下几种情况&#xff1a; 短期套餐&#xff1a; 6个月、12个月、24个月等&#xff0c;套餐到期后会恢复原来的资费&#xff0c;不会自动延续…

分组汇总时保留不变列

Excel表格的ID列是分类&#xff0c;Value1和Value2是数值&#xff0c;ID相同时Descr 1和Descr 2保持不变。 ABCDE1IDValue 1Value 2Descr 1Descr 22112.51.8ax13112.31.1ax14111.91.6ax15123.73.5bx26123.91.5bx27132.50.2cx38132.64.1cx391324.8cx310132.71.8cx3 要求&#…

Linux Shell实例

1.查空行 答案&#xff1a; awk /^$/{print NR} file1.txt#awk:一个强大的文本分析工具&#xff0c;把文件逐行的读入&#xff0c;以空格为默认分隔符将每行切片&#xff0c;切开的部分再进行分析#处理。 #1&#xff09;基本语法 #awk [选项参数]/pattern1/{action1} /pattern…

【数据结构详解】——计数排序(动图详解)

目录 &#x1f552; 1. 计数排序 &#x1f552; 1. 计数排序 &#x1f4a1; 算法思想&#xff1a;计数排序又称为鸽巢原理&#xff0c;是对哈希直接定址法的变形应用&#xff0c;操作步骤&#xff1a; 统计相同元素出现次数根据统计的结果将序列回收到原来的序列中 void Coun…

Windows11 WSL2 Ubuntu编译安装perf工具

在Windows 11上通过WSL2安装并编译perf工具&#xff08;Linux性能分析工具&#xff09;可以按以下步骤进行。perf工具通常与Linux内核一起发布&#xff0c;因此你需要确保你的内核版本和perf版本匹配。以下是安装和编译perf的步骤&#xff1a; 1. 更新并升级系统 首先&#x…

【算法】并查集的介绍与使用

1.并查集的概论 定义&#xff1a; 并查集是一种树型的数据结构&#xff0c;用于处理一些不相交集合的合并及查询问题&#xff08;即所谓的并、查&#xff09;。比如说&#xff0c;我们可以用并查集来判断一个森林中有几棵树、某个节点是否属于某棵树等。 主要构成&#xff1a; …

three.js的粒子和粒子系统基础知识扫盲,附案例图

绚烂的烟花、急促的雨滴、深邃的宇宙等等这些效果都可以通过three.js的粒子效果模拟出来&#xff0c;已达到以假乱真的程度了&#xff0c;本文来分享一下three.js的粒子系统&#xff0c;欢迎大家点赞评论收藏。 一、什么是粒子和粒子系统 粒子&#xff1a;可以简单理解为一个具…

JDBC1 Mysql驱动,连接数据库

JDBC 一、JDBC Java Database Connectivity&#xff1a;Java访问数据库的解决方案 JDBC定义了一套标准接口&#xff0c;即访问数据库的通用API&#xff0c; 不同的数据库厂商根据各自数据库的特点去实现这些接口。 JDBC希望用相同的方式访问不同的数据库&#xff0c;让具体的…