【笔记】扩散模型(七):Latent Diffusion Models(Stable Diffusion)论文解读与代码实现

news2024/11/14 4:13:44

论文链接:High-Resolution Image Synthesis with Latent Diffusion Models

官方实现:CompVis/latent-diffusionCompVis/stable-diffusion

这一篇文章的内容是 Latent Diffusion Models(LDM),也就是大名鼎鼎的 Stable Diffusion。先前的扩散模型一直面临的比较大的问题是采样空间太大,学习的噪声维度和图像的维度是相同的。当进行高分辨率图像生成时,需要的计算资源会急剧增加,虽然 DDIM 等工作已经对此有所改善,但效果依然有限。Stable Diffusion 的方法非常巧妙,其把扩散过程转换到了低维度的隐空间中,解决了这个问题。

方法介绍

本方法的整体结构如下图所示,主要分为三部分:最左侧的红框对应于感知图像压缩,中间的绿框对应 Latent Diffusion Models,右侧的白框表示生成条件,下面将分别介绍这三个部分。

Latent Diffusion Models 结构图

感知图像压缩

LDM 把图像生成过程从原始的图像像素空间转换到了一个隐空间,具体来说,对于一个维度为 x ∈ R H × W × 3 \mathbf{x}\in\mathbb{R}^{H\times W\times 3} xRH×W×3 的 RGB 图像,可以使用一个 encoder E \mathcal{E} E 将其转换为隐变量 z = E ( x ) \mathbf{z}=\mathcal{E}(\mathbf{x}) z=E(x),也可以用一个 decoder D \mathcal{D} D 将其从隐变量转换回像素空间 x ~ = D ( E ( x ) ) \tilde{\mathbf{x}}=\mathcal{D}(\mathcal{E}(\mathbf{x})) x~=D(E(x))。在转换时会将图像下采样,作者测试了一系列下采样倍数 f ∈ { 1 , 2 , 4 , 8 , 16 , 32 } f\in\{1, 2, 4, 8, 16, 32\} f{1,2,4,8,16,32},发现下采样 4-16 倍的时候可以比较好地权衡效率和质量。

在进行图像压缩时,为了防止压缩后的空间是某个高方差的空间,需要进行正则化。作者使用了两种正则化,第一种是 KL-正则化,也就是将隐变量和标准高斯分布使用一个 KL 惩罚项进行正则化;第二种是 VQ-正则化,也就是使用一个 vector quantization 层进行正则化。

Latent Diffusion Models

实际上 latent diffusion models 和普通的扩散模型没有太大区别,只是因为从像素空间变到了隐空间,所以维度降低了。训练的优化目标也没有太大变化,普通的扩散模型优化目标为:
L DM = E x , ϵ ∼ N ( 0 , 1 ) , t [ ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 2 ] L_\textrm{DM}=\mathbb{E}_{\mathbf{x},\epsilon\sim\mathcal{N}(0,1),t}\left[||\epsilon-\epsilon_\theta(\mathbf{x}_t,t)||_2^2\right] LDM=Ex,ϵN(0,1),t[∣∣ϵϵθ(xt,t)22]
而 Latent Diffusion Models 的优化目标只是套了一层 autoencoder:
L LDM = E E ( x ) , ϵ ∼ N ( 0 , 1 ) , t [ ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 2 ] L_\textrm{LDM}=\mathbb{E}_{\textcolor{red}{\mathcal{E}(\mathbf{x})},\epsilon\sim\mathcal{N}(0,1),t}\left[||\epsilon-\epsilon_\theta(\mathbf{x}_t,t)||_2^2\right] LLDM=EE(x),ϵN(0,1),t[∣∣ϵϵθ(xt,t)22]
在采样时,首先从隐空间随机采样噪声,在去噪后再用 decoder 转换到像素空间即可。

条件生成

为了进行条件生成,需要学习 ϵ θ ( x t , t , y ) \epsilon_\theta(\mathbf{x}_t,t,y) ϵθ(xt,t,y),这里使用的方法是在去噪网络中加入 cross attention 层,条件通过交叉注意力注入。在计算注意力时, z \mathbf{z} z 为 Query、 y y y 为 Key 和 Value,具体的内容已经在 Classifier-Free Guidance 的文章中介绍过了,对具体细节感兴趣的读者可以去看一下。

代码解读

Stable Diffusion 有两套主流的代码实现,第一种是 CompVis 的官方实现,第二种是 huggingface 的实现。这里的代码解读都以文生图任务为例。

CompVis 的实现

这个实现的代码比较分散,层次结构不太好梳理,不过可以照着配置文件看各部分都在哪里。这个配置文件有点类似 openmmlab 的那套框架的写法,例如文生图的配置文件 models/ldm/text2img256/config.yaml

model:
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder

无关的内容都略去,可以看到顶层的模块是 LatentDiffusion,去噪网络是 UNetModel、encoder 是 VQModelInterface、文本编码器是 BERTEmbedder

这里主要还是关注 LatentDiffusion 的采样过程。具体的采样代码位于 LatentDiffusion.sample

@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
           verbose=True, timesteps=None, quantize_denoised=False,
           mask=None, x0=None, shape=None, **kwargs):
    # 一些数据的封装以及格式转换等等
    if shape is None:
        shape = (batch_size, self.channels, self.image_size, self.image_size)
    if cond is not None:
        if isinstance(cond, dict):
            cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
            list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
        else:
            cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
    # 实际的采样过程
    return self.p_sample_loop(cond,
                              shape,
                              return_intermediates=return_intermediates, x_T=x_T,
                              verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                              mask=mask, x0=x0)

可以看到实际的采样过程并不是在这一层进行,这一层只进行了一些封装,例如采样的大小以及条件的数据格式等等,具体的采样则是在 p_sample_loop 中进行的:

@torch.no_grad()
def p_sample_loop(self, cond, shape, timesteps=None):
    iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(range(0, timesteps))
    for i in iterator:
        ts = torch.full((b,), i, device=device, dtype=torch.long)
        img = self.p_sample(img, cond, ts,
                            clip_denoised=self.clip_denoised,
                            quantize_denoised=quantize_denoised)
    return img

去掉一堆杂七杂八的代码之后可以发现在 p_sample_loop 中是一个循环,也就对应于一步步进行降噪的过程,具体的降噪在 p_sample 中实现:

@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
             return_codebook_ids=False, quantize_denoised=False, return_x0=False,
             temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
    b, *_, device = *x.shape, x.device
    outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                   return_codebook_ids=return_codebook_ids,
                                   quantize_denoised=quantize_denoised,
                                   return_x0=return_x0,
                                   score_corrector=score_corrector,
                                   corrector_kwargs=corrector_kwargs)
    model_mean, _, model_log_variance = outputs
    noise = noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

p_sample 中,首先用模型预测出了均值和方差(也就是 p_mean_variance,这里就不展开讲了),然后进行了去噪。

综合上述分析来看,如果看原始代码,可能会觉得非常混乱,但是其实去掉不重要的内容之后,核心的代码并不算非常多。这里没有展开具体的 p_mean_variance 内部的内容,在 CompVis 的框架中,定义了很多 diffusion 中常用的常量(例如 alphas_cumprodsqrt_recipm1_alphas_cumprod 等)和方法(例如 q_mean_variancep_mean_variance 等),后续我应该还会写一篇文章专门介绍这些内容,这里暂时略过,只需要知道最顶层的 p_mean_variance 是预测了均值和方差即可。

huggingface 的实现

相比于 CompVis 的实现,huggingface 的实现更加工程化一点,相关的在 diffusers 库中。这个库主要包括三大类元素:models(各种神经网络的实现,unet、vae 等)、schedulers(diffusion 相关的操作,加噪去噪等)、pipelines(high level 封装,相当于 models+schedulers,这个应该是方便用户直接用的)。

这里直接看 diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py 的采样过程,定义在 __call__ 函数中:

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    sigmas: List[float] = None,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    **kwargs,
):

可以看到参数实在是非常的多,我们在这里不关注工程的部分,只关注核心的逻辑。这里的第一个需要关注的点是对生成条件进行编码:

prompt_embeds, negative_prompt_embeds = self.encode_prompt(
    prompt,
    device,
    num_images_per_prompt,
    self.do_classifier_free_guidance,
    negative_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    lora_scale=lora_scale,
    clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

这里实际上还有 LoRA 和 IP-Adaptor 相关的处理,暂时省略。可以看到这里对生成的 prompt 进行了编码,并且不仅有正常的 prompt,还有 negative 的 prompt,这是为了做 classifier-free guidance。并且由于两个 prompt 需要分别推理,这里还将其在 batch 维度拼接,来进行并行化。随后获取 timesteps:

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler, num_inference_steps, device, timesteps, sigmas
)

然后初始化噪声,这个就相当于 x T \mathbf{x}_T xT

num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)

上边准备了 x \mathbf{x} x、timestep 以及 condition,现在就可以正式进行生成了:

num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

可以看到有一些为了 classifier-free guidance 进行的处理,其他的都是正常 diffusion 的操作。最后将隐变量解码回像素空间得到生成结果:

image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]

总结

最近看了这么多文章,感觉比较成功的 researcher 的工作都是连贯的。就像宋飏研究 sliced score matching,然后紧随其后做出了 score-based generative model;又如 OpenAI 训出 CLIP 然后基于 CLIP 做了一系列文生图的工作。今天这篇文章看起来也是 CompVis 把 VQGAN 迁移到 diffusion models 上的成果,感觉对平时做研究的启发还是很大的,我个人一直以来研究方向都比较摇摆不定,也应该反思学习一下。

参考资料:

  1. diffusion model(五):LDM: 在隐空间用diffusion model合成高质量图片
  2. 扩散模型(六)| Stable Diffusion

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(七):Latent Diffusion Models(Stable Diffusion)理论与实现,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2122033.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HarmonyOS开发之Swiper的使用(跳转到指定索引的方法)

一,效果图 class MyDataSource implements IDataSource {private list: number[] []private listener: DataChangeListenerconstructor(list: number[]) {this.list list}totalCount(): number {return this.list.length}getData(index: number): any {return thi…

svg图标的使用

图片的格式有很多,前端经常使用的有以下类型:jpg,jpeg,png,gif,svg,这篇文章将简单svg的情况,以及项目中如何使用和配置svg图标 目录 什么是svg图标 SVG图标的优缺点 优点 缺点 svg前端使用场景 SVG在代码中的使用 简单使用创建svg 作为图标引入…

注册网站怎么注册

网站注册成为我们日常生活中不可或缺的一部分。无论是社交媒体、电子商务平台还是各种在线服务,注册都是参与这些平台的第一步。下面将为您详细介绍一般网站注册的步骤,帮助您轻松完成注册过程。 1. 选择合适的网站 在注册之前,首先要确定您…

使用kubeadm部署k8s集群

1、简介 K8s部署主要有两种方式: 1、Kubeadm Kubeadm是一个K8s部署工具,提供kubeadm init和kubeadm join,用于快速部署Kubernetes集群。 2、二进制 从github下载发行版的二进制包,手动部署每个组件,组成Kubernetes集…

通过 汇编 分析 结构体

不使用结构体的情况, 网上的资料: 使用结构体的情况 总结 ; 使用 结构体之后, 会节省汇编的 ldr 指令, 结构体 就直接使用 偏移量 来 对变量进行赋值了。 注意 : 这里 结构体 依然是一个全局变量。

CentOS7 安装配置Maven

一、Maven介绍 Apache Maven 是一个 Java 项目的构建自动化工具,主要用于构建、依赖管理和项目信息管理。Maven 使用一种称为“生命周期”(Lifecycle)的概念来管理构建过程的不同阶段,例如编译源代码、运行测试、打包、部署等。这…

ubuntu使用命令行查看硬件信息

ubuntu使用命令行查看硬件信息 CPU cat /proc/cpuinfo其中,model name就显示了cpu的型号,cpu cores显示cpu的所有物理核心数量。 内存 cat /proc/meminfo其中,MemTotal就显示总内存大小,这里为32GB内存,SwapTotal显…

走近张大鹏教授:哈工大走出的中国第一位人工智能博士

写在最前 张大鹏,加拿大皇家科学院院士,加拿大工程院院士,国际电气与电子工程师协会终身会士(IEEE Fellow),国际模式识别协会会士,亚太人工智能学会会士,香港中文大学(深…

速通GPT-3:Language Models are Few-Shot Learners全文解读

文章目录 论文实验总览1. 任务设置与测试策略2. 任务类别3. 关键实验结果4. 数据污染与实验局限性5. 总结与贡献 Abstract1. 概括2. 具体分析3. 摘要全文翻译4. 为什么不需要梯度更新或微调⭐ Introduction1. 概括2. 具体分析3. 进一步分析 Approach1. 概括2. 具体分析3. 进一步…

批发订货系统源码怎么弄 门店订货系统小程序价格

上线批发订货系统可以显著提升业务效率和管理水平,它能够帮助企业自动化处理订单、实时跟踪库存、简化订单管理、生成数据报表…这些优势能最终帮助你降低成本、提高效率,提升业务竞争力。今天,小编为您分享批发订货系统源码怎么弄。大家点赞…

自带线充电宝哪个牌子质量好性价比高?口碑最好自带线充电宝

在如今这个快节奏的时代,手机等电子设备已经成为我们生活中不可或缺的一部分。然而,电量不足的困扰时常让我们陷入尴尬境地。自带线充电宝的出现,无疑为我们解决了这一难题。它不仅方便携带,无需再额外携带充电线,而且…

新手入行项目管理,需掌握六大核心技能

对于新手而言,学习项目管理的核心技能对于确保项目目标的明确性、资源的有效利用、团队协作的顺畅性、风险的有效控制,以及按时按质完成任务至关重要。项目管理对组织成功至关重要,它提高资源配置效率,促进创新,确保项…

一个请求入参 req 引发的魔法攻击

项目场景 月初检修上线后没几天,隔壁项目组的同事,反馈说出现了生产问题,调用我们这边的接口报错。 问题描述 看到这个问题的第一眼,什么鬼,请求参数错误? 但是看到 “操作用户信息为空” 这个提示的时候…

MySQL系列—10.Innodb行格式

我们平时的数据以行为单位来想表中插入数据,这些记录在磁盘上的存放方式也被称为行格式或者记录格式。InnoDB存储引擎设计了 4 种不同类型的行格式,分别是Compact、Redundant、Dynamic 和 Compressed行格式 查看MySQL8的默认行格式: SELECT…

STM32 HAL freertos零基础(四) 二值信号量

1、二值信号量 FreeRTOS中的二值信号量是一种用于任务间同步的机制,它只能有两个状态:0 或 1。二值信号量通常用来表示某个事件是否发生,比如硬件中断发生时设置信号量为1,表示事件已发生;而任务在需要等待该事件发生时…

Jupyter Notebook远程登录配置

目录 一、之前的版本修改方法 1、生成配置文件 2、设置密码、获取秘钥 3、修改默认配置文件 注:自动化脚本 二、新版本 注:自动化脚本 三、访问 四、ip查询 1、win 2、linux 一、之前的版本修改方法 1、生成配置文件 jupyter notebook --ge…

选对crm管理系统软件,客户留存率提升70%不是梦!

本文将盘点10款行业领先的crm管理系统软件,为企业选型提供参考! CRM系统,全称Customer Relationship Management System,即客户关系管理系统,是企业用来管理和分析客户互动与数据的软件系统。CRM系统的核心在于“以客户…

idea 拉取项目需要log in to git地址

idea 拉取项目需要log in to git地址 一. 问题复现二. 解决办法 一. 问题复现 1.使用 idea 拉取 git 代码 2.弹出“log in to XXXX 二. 解决办法

JavaWeb案例-登录认证

在前面的文章中,我们复习了部门管理、员工管理的基本功能。但是我们并没有登录,就直接访问到了Tilias智能辅助系统的后台。这是不安全的,所以今天复习登录认证。最终实现的效果就是用户必须登录之后,才可以访问后台系统中的功能。…

Java版本管理工具Jabba安装教程(Windows)

Java版本管理工具Jabba安装教程(Windows) 前言 Java版本的管理工具有很多,诸如Jenv,Jabba等,考虑到我之前使用Node.js的nvm还比较顺手,Jabba是受Node.js的nvm启发而来,故选择Jabba作为版本管理工具 这里…