【笔记】扩散模型(九):Imagen 理论与实现

news2024/11/14 15:30:29

论文链接:Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding

非官方实现:lucidrains/imagen-pytorch

Imagen 是 Google Research 的文生图工作,这个工作并没有沿用 Stable Diffusion 的架构,而是级联了一系列普通的 DDPM 模型。其主要的贡献有以下几个方面:

  1. 使用比较大的文本模型进行文本嵌入,可以获得比使用 CLIP 更好的文本理解能力;
  2. 在采样阶段引入了一种动态阈值的方法,可以利用更高的 guidance scale 来生成更真实、细节更丰富的图像(这里的阈值是控制 x \mathbf{x} x 的范围);
  3. 改良了 UNet,提出 Efficient UNet,使模型更简单、收敛更快、内存消耗更少。

该模型的架构如下图所示,可以看到使用了一个条件生成的 diffusion 模型以及两个超分辨率模型,每个模型都以文本模型的 embedding 作为条件,先生成一个 64 分辨率的图像,然后逐步超分辨率到 1024 大小。

Imagen 模型结构

Imagen

预训练文本模型

现在的文生图模型主流使用的文本嵌入方法是使用 CLIP 文本编码器,在直观上感觉是比较合理的,因为 CLIP 的文本特征和图像特征共享同一个空间,用来控制图像的生成过程是比较合理的。不过 CLIP 的缺点是对文本的表达能力比较有限,处理复杂文本比较困难。

这里选择的不是使用 CLIP,而是使用规模比较大、且在大规模文本语料上训练的文本模型,具体来说使用的模型有 BERT、T5 和 CLIP。经过实验(具体结果可以看原论文 Figure 4 的 a 和 b,以及 Figure A.5),主要有以下发现:

  • 缩放文本编码器对提升生成质量的作用很明显;
  • 相比增大 UNet 的尺寸,增大文本编码器的尺寸更重要;
  • 相比于 CLIP,人类更偏好 T5-XXL 的结果。

高 Guidance Scale 的改善

提高 classifier-free guidance 的 guidance scale 可以提升文本-图像的匹配程度,但是会破坏图像的质量。这个现象是因为高 guidance scale 会导致训练阶段和测试阶段出现 mismatch。具体来说,在训练时,所有的 x \mathbf{x} x 都分布在 [ − 1 , 1 ] [-1,1] [1,1] 的范围里,然而当使用比较大的 guidance scale 时,得到的 x \mathbf{x} x 会超出这个范围。这样会导致 x \mathbf{x} x 落在已经学习过的范围以外,为了解决这个问题,作者研究了静态阈值(static thresholding)和动态阈值(dynamic thresholding)两种方案,具体算法如下图所示:

静态阈值和动态阈值算法

静态阈值

这种方法就是在预测噪声后,先计算出 x 0 \mathbf{x}_0 x0,然后将其取值范围直接裁剪到 [ − 1 , 1 ] [-1,1] [1,1] 之间,然后再进行去噪。这种方法已经很多方法都使用了,例如 openai/guided-diffusion 中的这段代码就是为了进行这种处理:

def process_xstart(x):
    if denoised_fn is not None:
        x = denoised_fn(x)
    if clip_denoised:
        return x.clamp(-1, 1) # 裁剪到 [-1,1]
    return x

if self.model_mean_type == ModelMeanType.EPSILON:
    pred_xstart = process_xstart(
        self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) # 得到 x_0
    )
model_mean, _, _ = self.q_posterior_mean_variance(
    x_start=pred_xstart, x_t=x, t=t
)

动态阈值

这个方法不是很好理解,我们可以从一个例子出发,我们平时进行 classifier-free guidance 时使用的 guidance scale 通常都是 7.5,那么一个原本分布在 [ − 1 , 1 ] [-1,1] [1,1] 之间的变量乘以这个系数之后就会变到 [ − 7.5 , 7.5 ] [-7.5,7.5] [7.5,7.5] 的范围内。如果某处的几个数分别是 { 0.2 , 0.4 , 0.6 , 0.8 } \{0.2, 0.4, 0.6, 0.8\} {0.2,0.4,0.6,0.8},乘以 7.5 后就变成了 { 1.5 , 3.0 , 4.5 , 6.0 } \{1.5,3.0,4.5,6.0\} {1.5,3.0,4.5,6.0}。如果此时直接将这些数裁剪到 [ − 1 , 1 ] [-1,1] [1,1],那么所有的数都会变成 1,原本这些数之间是有比较大的差别的,裁剪后都变成了相同的数,这样很明显是不合理的,动态阈值就是为了寻找一个比较合理的裁剪范围。

这里的做法是寻找一个 x 0 \mathbf{x}_0 x0 的 p-分位数 s s s,也就是找到大多数的数字落在什么范围内,然后先裁剪到 [ − s , s ] [-s,s] [s,s] 范围内,再全部除以 s s s 以缩放到 [ − 1 , 1 ] [-1,1] [1,1] 的范围内。实验发现这种方法能比较好地改善图像的质量,这部分的代码如下所示(摘自非官方实现):

if pred_objective == 'noise':
    x_start = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
elif pred_objective == 'x_start':
    x_start = pred
elif pred_objective == 'v':
    x_start = noise_scheduler.predict_start_from_v(x, t=t, v=pred)

if dynamic_threshold: # 动态阈值
    # 找到 p-分位数
    s = torch.quantile(
        rearrange(x_start, 'b ... -> b (...)').abs(),
        self.dynamic_thresholding_percentile,
        dim = -1
    )
    s.clamp_(min=1.)
    s = right_pad_dims_to(x_start, s)
    # 进行归一化
    x_start = x_start.clamp(-s, s) / s
else: # 静态阈值,直接截断
    x_start.clamp_(-1., 1.)
mean_and_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t, t_next=t_next)

级联扩散模型

为了生成高分辨率图像,模型级联了三个扩散模型,一个用来生成低分辨率图像,两个用来将低分辨率图像逐步超分到高分辨率。在训练阶段,作者发现使用带有噪声条件增强的超分模型可以生成更高质量的模型。具体来说,每次生成噪声时,还从 [ 0 , 1 ] [0,1] [0,1] 范围内随机采样一个 aug level,然后基于这个 level 进行增强。在预测噪声时,不仅输入带噪声的图像、低分辨率图像、时间步,还输入一个 aug level。在推理阶段,使用一系列 aug level 进行增强,然后分别进行推理,从中选取一个最佳样本,这样可以提升采样效果。具体的算法如下所示:

超分模型的训练和采样过程

总结

除了上述的一些贡献,Imagen 还做了一些工程上的改进,例如使用了不同的 text condition 注入方式,以及对基础的 UNet 模型进行了改进,提出了 Efficient UNet 模型等。相比同期的其他方法,Imagen 应该是为数不多可以直接生成 1024 分辨率图像的 diffusion 模型,虽然和主流的 Stable Diffusion 架构不同,但其中的一些改进思路还是值得学习一下的。

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(九):Imagen 理论与实现,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2237675.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华夏风物 3.2.0 | 中国风物志,记录各地特产、美食、风景,旅游吃货必备

华夏风物是一款记录中国各地风物的App,类似于一本中国“风物志”。它记录了各地的特产、美食、风景,为用户提供了一个了解和探索中国文化的窗口。该应用的社区氛围非常真实,用户可以发现许多家乡的特色小吃和传统手艺。许多帖子由当地人发布&…

BIST(Built-in Self-Test,内建自测试)学习笔记

参考资料: 内建自测试(Built-in Self-Test,简称BIST)详解_built in self test-CSDN博客 芯片测试术语 ,片内测试(BIST),ATE测试-CSDN博客 可能是DFT最全面的介绍--BIST - 知乎 (zhihu.com) 汽车功能安全--TC3xx LB…

【Ubuntu24.04】从双系统到虚拟机再到单系统的故事

故事 在大学前期,我使用Ubuntu系统都是为了学习一些命令或者其它Linux的东西,对性能的要求不高,所以选择了虚拟机,后来为了做毕设,选择安装了Ubuntu20.04双系统,因为虚拟机实在带不动,那时我的主…

AntFlow一款开源免费且自主可控的仿钉钉工作流引擎

在现代企业管理中,流程审批的高效性直接影响到工作的流畅度与生产力。最近,我发现了一个非常有趣的项目——AntFlow。这个项目不仅提供了一个灵活且可定制的工作流平台,还能让用户以可视化的方式创建和管理审批流程。 如果你寻找一个快速集成…

科学计算服务器:如何计算算力?如何提升科学研究效率?

在现代科学研究的舞台上,科学计算服务器犹如一位强大的幕后英雄,为复杂科学计算任务的攻克提供着坚实支撑。准确计算其算力并充分发挥优势,对提升科学研究效率意义非凡。 服务器的中央处理器(CPU)计算力。在科学计算服…

Java String字符串

Java字符串通常被视为一种数据类型,但由于它们按顺序存储字符类型的元素,类似于数组,因此也常被视为数据结构。在本文中,我们将通过以下大纲简明地了解有关Java字符串的所有内容。 什么是Java字符串?如何创建Java字符…

leetcode25:k个一组链表反转

给你链表的头节点 head ,每 k 个节点一组进行翻转,请你返回修改后的链表。 k 是一个正整数,它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍,那么请将最后剩余的节点保持原有顺序。 你不能只是单纯的改变节点内部的值…

《⼆叉搜索树》

《⼆叉搜索树》 1. ⼆叉搜索树的概念2. ⼆叉搜索树的性能分析3 二叉树的功能说明及实现3.1 ⼆叉搜索树的插⼊3.2 ⼆叉搜索树的查找3.3 ⼆叉搜索树的删除 4二叉搜索树的实现代码5 ⼆叉搜索树key和key/value使⽤场景5.1 key搜索场景:5.2 key/value搜索场景&#xff1a…

势不可挡 创新引领 | 生信科技SOLIDWORKS 2025新品发布会·苏州站精彩回顾

2024年11月01日,由生信科技举办的SOLIDWORKS 2025新产品发布会在江苏苏州圆满落幕。现场邀请到制造业的专家学者们一同感受SOLIDWORKS 2025最新功能,探索制造业数字化转型之路。 在苏州站活动开场,达索系统专业客户事业部华东区渠道经理马腾飞…

Spark 程序开发与提交:本地与集群模式全解析

Spark 的介绍与搭建:从理论到实践-CSDN博客 Spark 的Standalone集群环境安装与测试-CSDN博客 PySpark 本地开发环境搭建与实践-CSDN博客 目录 一、本地开发与远程提交测试 (一)问题背景 (二)解决方案 集群环境准…

童装类目电商代运营公司——品融电商

童装类目电商代运营公司——品融电商 随着电商行业的快速发展,童装类目已成为市场中极具潜力的细分领域之一。消费者对童装的需求不仅限于基本穿着功能,更倾向于选购具有设计感、安全性和舒适度的产品。童装类目涵盖婴儿服、儿童套装、家居服、户外服饰等…

利用pythonstudio写的PDF、图片批量水印生成器,可同时为不同读者生成多组水印

现在很多场合需要将PDF或图片加水印,本程序利用pythonstudio编写。 第一步 界面 其中: LstMask:列表框 PopupMenu:PmnMark LstFiles:列表框 PopupMenu:PmnFiles OdFiles:文件选择器 Filter:PDF文件(.PDF)|.PDF|图像文件(.JPG)|.JPG|图像文件(.png…

PDF模板制作与填充(Java)

1.PDF模板制作 准备原始模板 准备一个原始PDF模板,可以编辑好Word,预留出要填充的部分,再转换成PDF格式。 设置表单域 用任意PDF编辑器打开PDF模板文件,设置表单域,下面以WPS为例: 拖动文本域到需要填充的…

kafka中节点如何服役和退役

服役新节点 1)新节点准备 (1)关闭 bigdata03,进行一个快照,并右键执行克隆操作。 (2)开启 bigdata04,并修改 IP 地址。 vi /etc/sysconfig/network-scripts/ifcfg-ens33修改完记…

笔记本怎么开启TPM2.0_笔记本开启TPM2.0教程(不同笔记本开启tpm2.0方法)

在win11最低要求是提示,电脑必须满足 TPM 2.0,并开需要开启TPM 才能正常安装windows11系统,有很多笔记本的用户问我,笔记本怎么开启tpm功能呢?下面小编就给大家详细介绍一下笔记本开启tpm功能的方法。 如何确认你笔记本…

【PyTorch项目实战】图像分割 —— U-Net:Semantic segmentation with PyTorch

文章目录 一、项目介绍二、项目实战2.1、搭建环境2.1.1、下载源码2.1.2、下载预训练模型2.1.3、下载训练集 2.2、环境配置2.3、模型预测 U-Net是一种用于生物医学图像分割的卷积神经网络架构,最初由Olaf Ronneberger等人于2015年提出。 论文: U-Net: Con…

开源竞争-大数据项目期末考核

开源竞争: 自己没有办法完全掌握技术的时候就开源这个技术,培养出更多的技术依赖,让更多人完善你的技术,那么这不就是在砸罐子吗?一个行业里面总会有人砸罐子的,你不如先砸还能听个想。 客观现实&#xf…

11月7日星期四今日早报简报微语报早读

11月7日星期四,农历十月初七,早报#微语早读。 1、河南:旅行社组织1000人次境外游客在豫住宿2夜以上,可申请激励奖补; 2、主播宣称下播后商品恢复原价构成欺诈,广州市监:罚款5万元;…

HTMLCSS:3D 旋转卡片的炫酷动画

效果演示 这段代码是一个HTML和CSS的组合&#xff0c;用于创建一个具有3D效果的动画卡片。 HTML <div class"obj"><div class"objchild"><span class"inn6"><h3 class"text">我是谁&#xff1f;我在那<…

词嵌入方法(Word Embedding)

词嵌入方法&#xff08;Word Embedding&#xff09; Word Embedding是NLP中的一种技术&#xff0c;通过将单词映射到一个空间向量来表示每个单词 ✨️常见的词嵌入方法&#xff1a; &#x1f31f;Word2Vec&#xff1a;由谷歌提出的方法&#xff0c;分为CBOW&#xff08;conti…