【扩散模型(六)】Stable Diffusion 3 diffusers 源码详解1-推理代码-文本处理部分

news2024/9/20 9:41:25

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 本系列文章将介绍 SD3 源码的推理过程,包括文本处理部分(encode_prompt)、提供时间步的 Scheduler(FlowMatchEulerDiscreteScheduler)、代替 Unet 的主干网络 (SD3Transformer2DModel),而本文重点为文本 (caption/prompt) 处理部分。

文章目录

  • 系列文章目录
  • 前言
  • 一、文本处理的整体流程
  • 二、Text Encoder 1、2(CLIP)
    • 1. 模型部分
    • 2. 两个 Text Encoder 的输入和输出
  • 三、Text Encoder 3(T5)
  • 其他


前言

下图为《Scaling Rectified Flow Transformers for High-Resolution Image Synthesis》 (ICML 2024 )中的 SD3 架构图。
在这里插入图片描述


一、文本处理的整体流程

下面流程图只对正向提示词进行了梳理,负向提示词的流程并无差异。
在这里插入图片描述

本文分析的源代码为 diffusers 包中的 SD3 pipeline (位置在/path/to/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py),文本处理部分主要为 其中 __call__() 函数调用的 self.encode_prompt() 函数,主要涉及了 3 个 text encoder 以及对应的 3 个 tokenizer。

其输入输出如下:

 (
     prompt_embeds,
     negative_prompt_embeds,
     pooled_prompt_embeds,
     negative_pooled_prompt_embeds,
 ) = self.encode_prompt(
     prompt=prompt,
     prompt_2=prompt_2,
     prompt_3=prompt_3,
     negative_prompt=negative_prompt,
     negative_prompt_2=negative_prompt_2,
     negative_prompt_3=negative_prompt_3,
     do_classifier_free_guidance=self.do_classifier_free_guidance,
     prompt_embeds=prompt_embeds,
     negative_prompt_embeds=negative_prompt_embeds,
     pooled_prompt_embeds=pooled_prompt_embeds,
     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
     device=device,
     clip_skip=self.clip_skip,
     num_images_per_prompt=num_images_per_prompt,
     max_sequence_length=max_sequence_length,
 )

输入:

  • 其中 prompt 和 negative_prompt 为输入的字符串
  • 其他的 prompt_2、 prompt_3、 negative_prompt_2、 negative_prompt_3、prompt_embeds、 negative_prompt_embeds、pooled_prompt_embeds、negative_pooled_prompt_embeds 均为 None
  • do_classifier_free_guidance 一般都是 True
  • max_sequence_length = 256

具体而言是在 encode_prompt 函数中,通过两次 _get_clip_prompt_embeds_get_t5_prompt_embeds 来调用 3 个 Text Encoder。

prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
    prompt=prompt,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
    prompt=prompt_2,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

t5_prompt_embed = self._get_t5_prompt_embeds(
    prompt=prompt_3,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
    device=device,
)

二、Text Encoder 1、2(CLIP)

1. 模型部分

  • 根据输入的 clip_tokenizers、clip_text_encoders 序号分别选择 text_encoder (CLIP L/141) 或者 text_encoder_2 (OpenCLIP bigG/142)。
  • 从下面初始化代码可以看出,二者 text_encodertext_encoder_2 采用的类一致,所以二者的区别主要是模型权重以及 config 不同。
...
def __init__(...
     text_encoder: CLIPTextModelWithProjection,
     tokenizer: CLIPTokenizer,
     text_encoder_2: CLIPTextModelWithProjection,
     tokenizer_2: CLIPTokenizer,
...

 def _get_clip_prompt_embeds(
     self,
     prompt: Union[str, List[str]],
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
     clip_model_index: int = 0,
 ):
     device = device or self._execution_device

     clip_tokenizers = [self.tokenizer, self.tokenizer_2]
     clip_text_encoders = [self.text_encoder, self.text_encoder_2]

     tokenizer = clip_tokenizers[clip_model_index]
     text_encoder = clip_text_encoders[clip_model_index]

在下载的 SD3 模型权重文件中,/path/to/stable-diffusion-3-medium-diffusers 可以找到 text_encodertext_encoder_2 子目录,对比其中的 config(下图中左边为 text_encoder ,右边为 text_encoder_2 ),可以知道二者更具体的不同之处:

  1. hidden_size 不同:768 vs 1280
  2. hidden_act: quick_gelu vs gelu
  3. intermediate_size 不同:3072 vs 5120
  4. “num_attention_heads” 和 “num_hidden_layers”:12/12 vs 20/32
  5. projection_dim 不同:768 vs 1280
  • 从以上 config 中,可以明显看出 text_encoder_2 (OpenCLIP bigG/14) 确实更加 big。
  • 两个 Text Encoder 最终的输出也和上文 “一、文本处理的整体流程” 中的流程图一致,分别输出 [n, 77, 768 ] 和 [n, 77, 1280]。
    • n 为推理时的 num_images_per_prompt,每个 prompt 的出图数量。

在这里插入图片描述

2. 两个 Text Encoder 的输入和输出

  • 二者的输入是相同的 prompt,得到输出为不同的两对 prompt_embed, pooled_prompt_embedprompt_2_embed, pooled_prompt_2_embed
  • 其中,
    • prompt_embed [n, 77, 768 ] 和 prompt_2_embed [n, 77, 1280]为主要的 prompt 特征,并在后续 cat 到一起,得到 clip_prompt_embeds [n, 77, 2048]。
    • pooled_prompt_embed 和 pooled_prompt_2_embed 也一样 cat,
    • 两种特质的区别:prompt_embed(prompt_2_embed)是更主要/细粒度的文本特征、而 pooled_prompt_embed(pooled_prompt_2_embed)是更粗粒度的文本特征
    • 原文:However, as the pooled text representation retains only coarse-grained information about the text input 3, the network also requires information from the sequence representation c t x t c_{txt} ctxt.
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
    prompt=prompt,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=0,
)
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
    prompt=prompt_2,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    clip_skip=clip_skip,
    clip_model_index=1,
)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
...
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

三、Text Encoder 3(T5)

T5EncoderModel 的调用则更简洁一点,输入同样是 prompt,并且只有一个输出。

def __init__(...
	text_encoder_3: T5EncoderModel,
       tokenizer_3: T5TokenizerFast,
...

t5_prompt_embed = self._get_t5_prompt_embeds(
    prompt=prompt_3,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
    device=device,
)

# 实际为 clip_prompt_embeds = torch.nn.functional.pad(
#    clip_prompt_embeds, (0, 4096-2048)
#),即在后面 2048 个维度上 pad 全 0. 
clip_prompt_embeds = torch.nn.functional.pad(
    clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
)

# 在序列长度的维度(-2)上 cat 到一起,得到 77+256 = 333 的长度
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
  • 作用:增强对复杂文本的生成能力。
  • 原文:T5 对于复杂的提示词很重要,例如涉及高度细节或拼写较长的文本(第2行和第3行)。然而,对于大多数提示,作者发现在推理时删除T5仍然可以获得具有竞争力的性能。
    在这里插入图片描述

其他

强烈安利另外一位博主的文章:

  1. Stable Diffusion1.5网络结构-超详细原创
  2. Stable Diffusion XL网络结构-超详细原创

  1. Learning transferable visual models from natural language supervision, 2021. ↩︎

  2. Reproducible scaling laws for contrastive language-image learning. In 2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). IEEE, 2023. doi: 10.1109/cvpr52729.2023.00276. URL http://dx.doi.org/10.1109/CVPR52729.2 023.00276. ↩︎

  3. Sdxl: Improving latent diffusion models for high-resolution image synthesis, 2023. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1936566.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于python的去除图像内部填充

1 代码功能 该代码实现了一个图像处理的功能,具体来说是去除图像内部填充(或更准确地说,是提取并显示图像中轮廓的外围区域,而忽略内部填充)。以下是该功能的详细步骤: 读取图像:使用cv2.imread…

Hadoop-38 Redis 高并发下的分布式缓存 Redis简介 缓存场景 读写模式 旁路模式 穿透模式 缓存模式 基本概念等

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: HadoopHDFSMapReduceHiveFlumeSqoopZookeeperHBaseRedis 章节内容 上一节我们完成了: HBase …

机器学习-18-统计学与机器学习中回归的区别以及统计学基础知识

参考通透!一万字的统计学知识大梳理 参考3万字长文!手把手教你学会用Python实现统计学 参考统计学的回归和机器学习中的回归有什么差别? 1 研究对象 一维:就是当前摆在我们面前的“一组”,“一批数据。这里我们会用到统计学的知识去研究这类对象。 二维:就是研究某个“事…

【系统架构设计】数据库系统(三)

数据库系统(三) 数据库模式与范式数据库设计备份与恢复分布式数据库系统分布式数据库的概念特点分类目标 分布式数据库的架构分布式数据库系统与并行数据库系统 数据仓库数据挖掘NoSQL大数据 数据库模式与范式 数据库设计 备份与恢复 分布式数据库系统…

生活中生智慧

【 圣人多过 小人无过 】 觉得自己做得不够才能做得更好,互相成全;反求诸己是致良知的第一步;有苦难才能超越自己,开胸怀和智慧;不浪费任何一次困苦,危机中寻找智慧,成长自己。 把困苦当作当下…

自动驾驶三维车道线检测系列—LATR: 3D Lane Detection from Monocular Images with Transformer

文章目录 1. 概述2. 背景介绍3. 方法3.1 整体结构3.2 车道感知查询生成器3.3 动态3D地面位置嵌入3.4 预测头和损失 4. 实验评测4.1 数据集和评估指标4.2 实验设置4.3 主要结果 5. 讨论和总结 1. 概述 3D 车道线检测是自动驾驶中的一个基础但具有挑战性的任务。最近的进展主要依…

【NetTopologySuite类库】GeometryFixer几何自动修复,解决几何自相交等问题

介绍 NetTopologySuite 2.x 提供了GeometryFixer类,该类能够将几何体修复为有效几何体,同时尽可能保留输入的形状和位置。几何的IsValid属性,反映了几何是否是有效的。 输入的几何图形始终会被处理,因此即使是有效的输入也可能会…

特征工程方法总结

方法有以下这些 首先看数据有没有重复值、缺失值情况 离散:独热 连续变量:离散化(也成为分箱) 作用:1.消除异常值影响 2.引入非线性因素,提升模型表现能力 3.缺点是会损失一些信息 怎么分:…

pdf太大了怎么变小 pdf太大了如何变小一点

在数字化时代,pdf文件已成为工作与学习的重要工具。然而,有时我们可能会遇到pdf文件过大的问题,这会导致传输困难或者存储不便。别担心,下面我将为你介绍一些实用的技巧和工具,帮助你轻松减小pdf文件的大小。 方法一、…

docker的学习(一):docker的基本概念和命令

简介 docker的学习,基本概念,以及镜像命令和容器命令的使用 docker docker的基本概念 一次镜像,处处运行。 在部署程序的过程中,往往是很繁琐的,要保证运行的环境,软件的版本,配置文件&…

SQLite数据库在Android中的使用

目录 一,SQLite简介 二,SQLIte在Android中的使用 1,打开或者创建数据库 2,创建表 3,插入数据 4,删除数据 5,修改数据 6,查询数据 三,SQLiteOpenHelper类 四&…

信弘智能与图为科技共探科技合作新蓝图

本期导读 近日,图为信息科技(深圳)有限公司迎来上海信弘智能科技有限公司代表的到访,双方共同探讨英伟达生态系统在人工智能领域的发展。 在科技日新月异的今天,跨界合作与技术交流成为了推动行业发展的重要驱动。7月…

使用JWT双令牌机制进行接口请求鉴权

在前后端分离的开发过程中,前端发起请求,调用后端接口,后端在接收请求时,首先需要对收到的请求鉴权,在这种情况先我们可以采用JWT机制来鉴权。 JWT有两种机制,单令牌机制和双令牌机制。 单令牌机制服务端…

JAVA 异步编程(线程安全)二

1、线程安全 线程安全是指你的代码所在的进程中有多个线程同时运行,而这些线程可能会同时运行这段代码,如果每次运行的代码结果和单线程运行的结果是一样的,且其他变量的值和预期的也是一样的,那么就是线程安全的。 一个类或者程序…

Linux驱动开发-06蜂鸣器和多组GPIO控制

一、控制蜂鸣器 1.1 控制原理 我们可以看到SNVS_TAMPER1是这个端口在控制着蜂鸣器,同时这是一个PNP型的三极管,在端口输出为低电平时,蜂鸣器响,在高电平时,蜂鸣器不响 1.2 在Linux中端口号的控制 gpiochipX:当前SoC所包含的GPIO控制器,我们知道I.MX6UL/I.MX6ULL一共包…

整顿职场?安全体系建设

本文由 ChatMoney团队出品 00后整顿职场,职场到底怎么了?无压力、无忧虑的00后可以直接开整,那绝大部分打工人寒窗苦读、闯过高考,艰辛毕业,几轮面试杀入职场,结婚买房、上有老下有小,就活该再被…

怎么剪辑音频文件?4款适合新的音频剪辑软件

是谁还不会音频剪辑?无论是个人音乐爱好者,还是专业音频工作者,我们都希望能找到一款操作简便、功能强大且稳定可靠的音频剪辑工具。今天,我就要为大家带来四款热门音频剪辑软件的体验感分享。 一、福昕音频剪辑 福昕音频剪辑是…

JUnit 单元测试

JUnit 测试是程序员测试,就是白盒测试,可以让程序员知道被测试的软件如何 (How)完成功能和完成什么样(What)的功能。 下载junit-4.12和hamcrest-core-1.3依赖包 相关链接 junit-4.12:Central …

【JavaScript 算法】最长公共子序列:字符串问题的经典解法

🔥 个人主页:空白诗 文章目录 一、算法原理状态转移方程初始条件 二、算法实现注释说明: 三、应用场景四、总结 最长公共子序列(Longest Common Subsequence,LCS)是字符串处理中的经典问题。给定两个字符串…

Go语言之参数传递

文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 修改参数 假设你定义了一个函数,并在函数里对参数进行…