DiT:Scalable Diffusion Models with Transformers

news2024/12/25 0:37:49

TOC

  • 1 前言
  • 2 方法和代码

1 前言

该论文发表之前,市面上几乎都是用卷积网络作为实际意义上的(de-facto)backbone。于是一个想法就来了:为啥不用transformer作为backbone呢?

文章说本论文的意义就在于揭示模型选择对于扩散模型的重要性,并为生成模型研究提供一个可借鉴的基准(baseline)。

本文还揭示出卷积网络的inductive bias对生成性能并没有多大的影响,所以可以使用transformer网络去替代卷积网络。文章使用Gflops和FID去分别评估模型复杂度和生成图像质量。

刚刚又去学了一下FLOPs,真是破破烂烂,缝缝补补啊……

总的来说,DiT有如下优点:

  1. 高质量:achieve a state-of-the-art result of 2.27 FID on the classconditional 256 × 256 ImageNet generation benchmark.
  2. 发现了FID和GFLOPs之间存在强相关关系,通过增加depth of transformer或者amount of patches可以增加GFLOPs
  3. 灵活性:可以挑战模型大小、patches大小和序列长度
  4. 跨领域研究:DiT架构和ViT类似,为跨领域研究提供可能

2 方法和代码

在这里插入图片描述
整体来看:

  • 使用transformer作为其主干网络,代替了原先的UNet
  • 在latent space进行训练,通过transformer处理潜在的patch
  • 输入的条件(timestep 和 text/label )的四种处理方法:
    • In-context conditioning: 将condition和input embedding合并成一个tokens(concat),不增加额外计算量
    • Cross-attention block:在transformer中插入cross attention,将condition当作是K、V,input当作是Q
    • Adaptive layer norm (adaLN) block:将timestep和 text/label相加,通过MLP去回归参数scale和shift,也不增加计算量。并且在每一次残差相加时,回归一个gate系数。
    • adaLN-Zero block:参数初始化为0,那么在训练开始时,残差模块当于identical function。
  • 整体流程:patchify -> Transfomer Block -> Linear -> Unpatchify。 注意最后输出的维度是原来维度的2倍,分别输出noise和方差。

由下图可见,adaLN-Zero最好。然后就是探索各种调参效果,此处略。
在这里插入图片描述

代码以及注释:
DiTBlock

# DIT的核心子模块
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        # 此处为miltihead-self-Attention

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        #使用自适应归一化替换标准归一化层
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
  • addLN_zero: 先通过SiLU,然后再通过线性层输出6个值

forward

  def forward(self, x, t, y):

        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        # time step embedding
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        # 送入上述的DIT-Block中
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x
  • x通过embedding,与position embedding相加(固定的sin-cos位置编码)
  • t通过embedding
  • y通过embedding, t和y相加得到c
  • 遍历每一个block,传入x和c
  • 最后传入最后一层线性层,然后通过unpatchify恢复图像
class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        
     nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
  • 同样引入adpLN_zero,并且让输出维度为p*p*2c,是特征维度原来大小的2倍,分别预测noise和方差

最后unpatchify

    def unpatchify(self, x):
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1500312.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二叉树—层序遍历

102. 二叉树的层序遍历 代码实现: /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ /*** Return an array of arrays of size *returnSize.* The sizes of the arrays …

L波段光端机-L波段+CATV射频光端机工作机制及行业应用探究

L波段光端机-L波段CATV射频光端机工作机制及行业应用探究 北京海特伟业任洪卓发布于2023年3月8日 一、何为L波段光端机 L波段光端机是一种用于光通信的设备,其主要工作波长位于L波段,即40~860MHz和950~2600MHz的带宽,可选独立工作于950~260…

开发Chrome扩展插件

1.首先开发谷歌chrome扩展插件,没有严格的项目结构目录,但是需要保证里面有一个mainfest.json文件 (必不可少的文件)。在这个文件里有三个属性必不可少:name、version、mainfest_version; // 清单文件的版本,这个必须写…

消息队列-Kafka-消费方如何分区与分区重平衡

消费分区 资料来源于网络 消费者订阅的入口:KafkaConsumer#subscribe 消费者消费的入口:KafkaConsumer#poll 处理流程: 对元数据重平衡处理:KafkaConsumer#updateAssignmentMetadataIfNeeded 协调器的拉取处理:onsum…

java常用排序算法——冒泡排序,选择排序概述

前言: 开始接触算法了,记录下心得。打好基础,daydayup! 算法 算法是指解决某个实际问题的过程和方法 排序算法 排序算法指给混乱数组排序的算法。常见的有:冒泡排序,选择排序 冒泡排序: 冒泡排序指在数组…

python异常机制

当代码出现异常后底下代码都不会被执行了,也就是程序崩溃了。当然能避免异常的话尽量避免但是有的时候这个是没有办法避免的。 异常处理 (注:异常处理是从上往下处理,所以编写代码时要注意) 语法 try:可能出现异常…

SpringCloud-SpringBoot读取Nacos上的配置文件

在 Spring Boot 应用程序中,可以使用 Spring Cloud Nacos 来实现从 Nacos 服务注册中心和配置中心读取配置信息。以下是如何在 Spring Boot 中读取 Nacos 上的配置文件的步骤: 1. 引入依赖 首先,在 Spring Boot 项目的 pom.xml 文件中添加 …

JAVA虚拟机实战篇之内存调优[3](诊断问题:MAT工具分析堆内存快照)

文章目录 版权声明解决内存溢出的思路诊断 – 内存快照 MAT内存泄漏检测原理基础知识支配树深堆和浅堆string案例分析 MAT内存泄漏检测原理 导出运行中系统内存快照分析超大堆的内存快照 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明&am…

回溯算法题解(难度由小到大)(力扣,洛谷)

目录 注意: P1157 组合的输出(洛谷)https://www.luogu.com.cn/problem/P1157int result[10000] { 0 }; 216. 组合总和 IIIhttps://leetcode.cn/problems/combination-sum-iii/ 17. 电话号码的字母组合https://leetcode.cn/problems/lett…

YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information

paper: https://arxiv.org/abs/2402.13616 code YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information 一、引言部分二、问题分析2.1 信息瓶颈原理2.2 可逆函数 三、本文方法3.1 可编程梯度信息 四、实验4.1消融实验部分 今天的深度学习方法关注的…

ELK介绍使用

文章目录 一、ELK介绍二、Elasticsearch1. ElasticSearch简介:2. Elasticsearch核心概念3. Elasticsearch安装4. Elasticsearch基本操作1. 字段类型介绍2. 索引3. 映射4. 文档 5. Elasticsearch 复杂查询 三、LogStash1. LogStash简介2. LogStash安装 四、kibana1. …

hv静态资源web服务

在实际工作中,为了保证App的高可用性,服务端需要缓存一部分静态资源,通过web服务来分发资源。hv即可快速实现web服务。 hv静态资源服务。 HttpService router; router.Static("/statics", "smart-yi-ui");目录结构(sma…

kafka 可视化工具

kafka可视化工具 随着科技发展,中间件也百花齐放。平时我们用的redis,我就会通过redisInsight-v2 来查询数据,mysql就会使用goland-ide插件来查询,都挺方便。但是kafka可视化工具就找了半天,最后还是觉得redpandadata…

javaSE-----继承和多态

目录 一.初识继承: 1.1什么是继承,为什么需要继承: 1.2继承的概念与语法: 二.成员的访问: 2.1super关键字 2.2this和super的区别: 三.再谈初始化: 小结: 四.初识多态: 4.1多…

Java Web开发---复试Tips复习

***********(自用,摘录自各种文章和自己总结)********** 小知识点理解 Web Web应用开发主要是基于浏览器的应用程序开发。一个Web应用由多部分组成 java web就是用java语言开发出可在万维网上浏览的程序 Web应用程序编写完后,…

【自然语言处理六-最重要的模型-transformer-上】

自然语言处理六-最重要的模型-transformer-上 什么是transformer模型transformer 模型在自然语言处理领域的应用transformer 架构encoderinput处理部分(词嵌入和postional encoding)attention部分addNorm Feedforward & add && NormFeedforw…

在哪里能找到抖音短视频素材?推荐热门的抖音短视频素材下载资源

哎呦喂,小伙伴们,是不是在短视频的大海里划船,想找到那颗能让你起飞的珍珠,但又觉得素材难寻如针海捞针?别急,今天我就来给你们送上几个超实用的宝藏素材网站,让你的短视频创作不再愁素材 1&am…

从零开始的LeetCode刷题日记:142.环形链表II

一.相关链接 视频链接:代码随想录:142.环形链表II 题目链接:142.环形链表II 二.心得体会 这道题是一道链表题,但他没有对头结点的操作,所以不用虚拟头结点。这道题要分两步进行,第一步是判断链表有没有环…

如何获取国外信用卡?需要国外银行卡支付怎么解决?如何订阅国外产品?

当国内的用户想要使用国外的产品时,很多产品是需要订阅付费的。其中有些产品还没有引入国内,只能用国外的信用卡支付,对于在国内的朋友,如何获取一张国外的信用卡呢? 这里推荐一个平台:wildCard waildCard…

基于Java的生活废品回收系统(Vue.js+SpringBoot)

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容三、界面展示3.1 登录注册3.2 资源类型&资源品类模块3.3 回收机构模块3.4 资源求购/出售/交易单模块3.5 客服咨询模块 四、免责说明 一、摘要 1.1 项目介绍 生活废品回收系统是可持续发展的解决方案,旨在鼓…