音视频开发之旅(98) -潜扩散模型(Latent Diffusion Model)原理及源码解析

news2024/10/27 18:10:03

目录

1.背景

2. 潜扩散模型(Latent Diffusion Model)原理

3. 应用场景

4. 推理源码解析

5. 资料

一、背景

前面我们分析扩散模型(Diffusion Model)了解到,它通过向数据中添加噪声,然后训练一个去噪模型实现数据的还原,扩散模型在图像生成、图像修复等任务中表现了很好的效果,尤其是生成质量和多向性方面有显著的优势。但其本身是一个马尔科夫链的过程,即前后时刻数据有非常紧密的绑定关系,无法进行跳跃预测,生成过程通常需要超多轮次迭代导致生成速度较慢。并且它是直接在原始像素空间进行操作,对于高分辨率的图像生成任务,训练和推理都需要大量的计算资源。

为了在有限的计算资源上进行Diffusion Model的训练,同时保证其生成质量和灵活性,本论文作者创新性的提出将扩散过程应用于高度压缩的潜空间而不是原始的像素空间,并且通过交叉注意力机制,将条件输入(例如:文本描述,图像语义等)注入到网络,使得生成更符合预期的图像。

二、潜扩散模型(Latent Diffusion Model)原理

2.1  语义压缩(Semantic Compression)和感知压缩(Perceptual Compression)

图片

上图横坐标Rate(bits/dim)表示每个像素所需的比特数,Rate越高,表示编码压缩后的图像包含的细节越多,最左侧原图的Rate被设置为接近0,是因为进行语义压缩之前,原图在这条压缩率-失真曲线上没有经过压缩状态,将其设置为0作为参考点。

纵坐标(Distortion RMSE)是通过均方根RMSE计算的失真,值越大表示原图失真率越低。

右上角展示了语义压缩(Semantic Compression),目标是去除图像中对语义(或者称为人类视觉)不重要的信息。这个阶段失真(RMSE)从100快速下降到10左右,这意味着大部分图像数据的减少,而图像中很多的位信息对应的是人类视觉系统感知不到的细节,这些细节可以被大量压缩或者丢弃,不影响对整体图像的理解。

语义压缩在库潜扩散模型中被应用于,将输入的图像x经过编码器ϵ进入潜空间z的过程,编码器会在这个过程中丢弃图像中不重要的细节信息,只保留图像的核心语义,即高层语义信息。

右下角展示了感知压缩(Perceptual Compression),目标是保留图像的视觉细节,即使细微的变换也要保证图像的视觉效果不发生明显的变换。它是通过自动编码器(Autoencoder)+生成对抗网络(GAN)来保证图像的细节和质感。这个阶段的比特率Rate上升主要用于保存图像的精细特征。

感知压缩主要用在图像重建的过程,特别是潜空间中的扩散过程和去噪过程,扩散过程的每个步都在逐渐去噪恢复图像的细节,尤其是去噪UNet网络中的交叉注意力机制(Cross-Attention)和跳跃连接(Skip Connection)在这个过程中通过起到重要作用。通过感知压缩机制,逐渐生成那些细微的、对人类视觉重要的细节(如面部纹理、光影变化等)。

如果注意看会发现:上图下方人脸图像,第一张是女生图像,后面变成了男生图像。这是因为,随着语义压缩,模型逐渐丢弃了图像中的”不重要的“细节信息,只保留了大致的语义结构,后续模型在极低Rate情况下重建的图像可能因缺乏性别等细节特征,生成了男生形象,这反映了生成模型在生成图像时高层语义上的一些不确定性。

2.2 网络结构和流程

图片

图片来自:论文High-Resolution Image Synthesis with Latent Diffusion Models

训练过程:

1. 像素空间到潜在空间:图像x从原始的高纬的像素空间(Pixel Space)经过编码器ϵ,压缩到低维的潜在空间(Latent Space)表示z;

2. 在潜在空间进行使用U-Net网络进行前向扩散处理(Diffusion Process),训练潜空间的去噪模型

推理阶段:

1. 将文本、语义图等数据作为条件输入(Conditioning),通过特定的编码器映射到和潜在空间相同的维度,用于指导图像生成过程

图片

图片来自:how-stable-diffusion-work

2. 使用去噪U-Net网络,在潜空间将噪声图像逐步去噪生成清晰的图像(Z_T,Z_T-1,...,Z_0),U-Net中包含多个跳跃连接(skip connections)和交叉注意力机制(cross-attention),其中跳跃连接有助于不同尺度的特征融合,交叉注意力机制则用来整合条件输入和噪声图像,图像用于Q矩阵的生成,条件输入则用于K和V,使得模型在生成图像的每一步动态地关注条件输入的不同部分,从而生成与条件输入更紧密相关的潜空间图像

图片

图片来自:Latent Diffusion Models (LDMs) 模型学习笔记

3. 使用解码器D将生成的潜空间图像转换为像素空间,得到重建的图像

另外还有个Switch模块,表示在扩散过程不同阶段可以在不同模式(有条件和无条件生成)之间切换。整体流程如下图所示:

图片

2.3 公式解读

下面两个公式分别表示扩散模型(Diffusion Model,简称DM)和潜扩散模型(Latent Diffusion Model,简称LDM)的核心优化目标。

主要区别在于:DM直接操作的是高维的原始数据空间(像素空间),而LDM直接操作的是经过编码器编码后的ϵ(x)所在的低维的潜在空间,大大降低了计算负担。

图片

图片

x 是原始图像数据

ϵ(x)是通过编码器ϵ从原始数据x获得的潜在表示

ϵ~N(0,1),是从标准正态分布N(0,1)中采样的噪声

X_t是在时间步骤t时加噪声的数据,Z_t则是在时间步骤t时加噪声的潜在表示

图片

图片

∥⋅∥是欧几里得距离,这里使用的是L2范数,用于衡量预测的噪声和实际噪声之间的差异

三、应用场景

文生图

图片

图生图

图片

图像超分

图片

图像Inpainting修复

图片

四、推理源码解析

基于https://github.com/CompVis/latent-diffusion分析

Latent Dissusion的训练分为两阶段:

第一阶段训练自编码器(AutoencoderKL),具体实现在latent-diffusion/ldm/models/autoencoder.py

第二阶段训练Diffusion(基于DDPM),具体实现在latent-diffusion-main/ldm/models/diffusion/ddpm.py

图片

图片来自:一文详解 Latent Diffusion官方源码

下面我们看下推理的过程

使用示例(以文生图为例)

流程:获取配置、加载模型,初始化DDIM采样器、进行采样去噪生成图片

这里使用的DDIM采样器(Denoising Diffusion Implicit Models)而不是DDPM采样器(Denoising Diffusion Probabilistic Models),DDIM的主要改进点有:

1. 不再严格遵循马尔科夫过程,允许在采样过程中跳过部分中间步骤,直接从高噪声状态向低噪声状态过渡。使得采样步数大幅减少(10-20步 vs 1000步),从而显著降低计算量和加快生成速度。

2. DDIM提供了一种更灵活的采样策略,可以通过调整参数来控制生成过程的方差。这使得DDIM在生成样本的多样性和质量之间提供了更好的平衡。

class Txt2Img():    def __init__(self):        #txt2img-1p4B-eval.yaml中配置了LatentDiffusion、UNetModel、AutoencoderKL和BERTEmbedder相关配置        config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")        #加载模型        self.model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")        self.model = self.model.to(device)
        #设置采样器,这里使用DDIM,而不是DDPM,加速推理过程        self.sampler = DDIMSampler(model)        def do_predict(self,prompt,outpath,ddim_steps=50,ddim_eta=0.0,n_iter=4.0,W=256,H=256,n_samples=4,scale=5.0):        sample_path = os.path.join(outpath, "samples")        os.makedirs(sample_path, exist_ok=True)
        all_samples=list()        with torch.no_grad():            with self.model.ema_scope():                uc = None                if scale != 1.0:                    uc = self.model.get_learned_conditioning(n_samples * [""])                for n in trange(n_iter, desc="Sampling"):                    c = self.model.get_learned_conditioning(n_samples * [prompt])                    shape = [4, H//8, W//8]                    #调用DDIM采用器,进行图像生成                    samples_ddim, _ = self.sampler.sample(S=ddim_steps,                                                    conditioning=c,                                                    batch_size=n_samples,                                                    shape=shape,                                                    verbose=False,                                                    unconditional_guidance_scale=scale,                                                    unconditional_conditioning=uc,                                                    eta=ddim_eta)
                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)                    x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
                    for x_sample in x_samples_ddim:                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')                        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))                        base_count += 1                    all_samples.append(x_samples_ddim)

        # additionally, save as grid        grid = torch.stack(all_samples, 0)        grid = rearrange(grid, 'n b c h w -> (n b) c h w')        grid = make_grid(grid, nrow=n_samples)
        # to image        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()        Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))

DDIMSampler

latent-diffusion-main/ldm/models/diffusion/ddim.py

sample 方法是类的主入口点,用于生成样本。它接受多个参数,包括采样步数 S、批量大小 batch_size、图像形状 shape 等。

ddim_sampling 方法执行实际的采样过程。它使用 tqdm 进度条来跟踪采样进度,并在每一步中调用 p_sample_ddim 方法来更新图像。p_sample_ddim 方法执行单个时间步的采样。它计算当前时间步的预测值 pred_x0 和下一个时间步的图像 x_prev。这个方法使用了模型的参数,如 alphas、betas 和 sigmas,以及一些辅助函数来计算噪声和预测值。p_sample_ddim 实现如下所示

def p_sample_ddim(self,     x,                     # 当前时刻的噪声图像    c,                     # 条件信息(如文本提示)    t,                     # 当前时间步    index,                # 时间步索引    repeat_noise=False,    # 是否重复使用噪声    use_original_steps=False,  # 是否使用原始DDPM步数    quantize_denoised=False,   # 是否量化去噪结果    temperature=1.,        # 采样温度    noise_dropout=0.,      # 噪声dropout率    score_corrector=None,  # 分数修正器    unconditional_guidance_scale=1.,  # 无条件引导比例    unconditional_conditioning=None    # 无条件信息):    # 如果没有无条件引导或比例为1,直接用条件信息预测噪声    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:        e_t = self.model.apply_model(x, t, c)    else:        # Classifier-Free Guidance (CFG)实现        x_in = torch.cat([x] * 2)  # 复制输入        t_in = torch.cat([t] * 2)  # 复制时间步        c_in = torch.cat([unconditional_conditioning, c])  # 拼接条件        # 分别得到无条件和有条件的预测        e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)        # CFG公式:预测 = 无条件预测 + scale * (有条件预测 - 无条件预测)        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)            # 根据是否使用原始步数选择相应的参数    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas        # 获取当前时间步的参数值    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)        #DDIM采样核心步骤    # 预测x0(原始图像)    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()        # 可选的量化操作    if quantize_denoised:        pred_x0, *_ = self.model.first_stage_model.quantize(pred_x0)    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t    # 生成噪声    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature    if noise_dropout > 0.:        noise = torch.nn.functional.dropout(noise, p=noise_dropout)        # DDIM更新公式:计算上一时刻的图像    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise    return x_prev, pred_x0

五、资料

论文High-Resolution Image Synthesis with Latent Diffusion Models:https://arxiv.org/pdf/2112.10752

论文源码:GitHub - CompVis/latent-diffusion: High-Resolution Image Synthesis with Latent Diffusion Models

sd-v1:GitHub - CompVis/stable-diffusion: A latent text-to-image diffusion model

sd-v2:GitHub - Stability-AI/stablediffusion: High-Resolution Image Synthesis with Latent Diffusion Models

SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis:https://arxiv.org/pdf/2307.01952

how-stable-diffusion-work:https://stable-diffusion-art.com/how-stable-diffusion-work

The Illustrated Stable Diffusion:The Illustrated Stable Diffusion – Jay Alammar – Visualizing machine learning one concept at a time.

一文读懂Stable Diffusion 论文原理+代码超详细解读:https://zhuanlan.zhihu.com/p/640545463

深入浅出完整解析Stable Diffusion(SD)核心基础知识:https://zhuanlan.zhihu.com/p/632809634

一文详解 Latent Diffusion官方源码: https://blog.csdn.net/yusijinfs/article/details/134684608

Latent Diffusion Models (LDMs) 模型学习笔记 https://blog.csdn.net/hjhr2018/article/details/140152429

DIFFUSION 系列笔记| Latent Diffusion Model https://kevinng77.github.io/posts/notes/articles/%E7%AC%94%E8%AE%B0latent_diffusion.html

深入浅出 diffusion(5):理解 Latent Diffusion Models(LDMs) https://blog.csdn.net/m0_37324740/article/details/135939637

Latent Diffusion:开始的开始 https://zhuanlan.zhihu.com/p/652186695

一文带你看懂DDPM和DDIM(含原理简易推导,pytorch代码)https://zhuanlan.zhihu.com/p/666552214

感谢你的阅读

接下来我们继续学习输出AI相关内容,欢迎关注公众号“音视频开发之旅”,一起学习成长。

欢迎交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2224841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SSM学习day01 JS基础语法

一、JS基础语法 跟java有点像,但是不用注明数据类型 使用var去声明变量 特点1:var关键字声明变量,是为全局变量,作用域很大。在一个代码块中定义的变量,在其他代码块里也能使用 特点2:可以重复定义&#…

【mysql进阶】4-6. InnoDB 磁盘文件

InnoDB 磁盘⽂件 1 InnoDB存储引擎包含哪些磁盘⽂件? 🔍 分析过程 ✅ 解答问题 InnoDB的磁盘⽂件主要是表空间⽂件和其他⽂件,表空间包括:系统表空间、独⽴表空间、通⽤表空间、临时表空间和撤销表空间;其他⽂件有重做…

XQT_UI 组件|03 |加载组件 XQtLoading

XQtLoading 使用文档 简介 XQtLoading 是一个自定义的加载动画组件,旨在为用户提供可配置的旋转花瓣动画效果。它可以在应用程序中用于指示加载状态,提升用户体验。 特征 可配置性:用户可以根据需求调整旋转周期、缩放周期、最大/最小缩放…

JavaScript part2

一.前言 前面我们讲了一下js的基础语法,但是这些还是远远不够的,我们要想操作标签,实现一个动态且好看的页面,就得学会BOM和DOM,这些都是浏览器和页面的,这样我们才能实现一个好看的页面 二.BOM对象 BOM…

golang将指针传给cgo后还能被回收吗?

问题&#xff1a; 如果把golang分配的变量&#xff0c;其指针通过cgo传给c&#xff0c;并被c存储&#xff0c;那这个变量还能被gc回收吗&#xff1f; 实验代码&#xff1a; test_memory_leak.go package main/* #include <stdlib.h> #include <string.h> #incl…

yolov11的onnx模型C++ 调用

yolov11的onnx模型C调用 效果图一、python调用二、onnx模型导出三、python的onnx调用调用检测模型调用分割模型 四、C的onnx模型调用五 、视频流的检测后续 效果图 一、python调用 本文只记录生成的yolov11模型如何调用&#xff0c;其他可参考各种yolov11博客 模型下载&#x…

Spring Boot 应用开发全攻略:从入门到精通

Spring Boot 应用开发全攻略&#xff1a;从入门到精通 引言 在当今快速发展的软件开发领域&#xff0c;Spring Boot 作为一种快速开发框架&#xff0c;凭借其简洁、易用的特性&#xff0c;赢得了开发者的广泛青睐。无论是微服务架构还是传统的单体应用&#xff0c;Spring Boo…

Redis 单机、主从、哨兵和集群架构详解和搭建

目录 前言 单机部署 检查安装 gcc 环境 下载安装 Redis 启动 Redis 关闭 Redis 配置Redis 主从部署 整体架构图 主从复制配置 重启 Redis 验证 主从复制的作⽤ 主从复制缺点 哨兵部署&#xff08;Sentinel&#xff09; 整体架构图 哨兵模式配置 启动哨兵 验证…

首席数据官和首席数据分析官

根据分析人士的预测&#xff0c;首席数据官&#xff08;CDO&#xff09;和首席数据分析官&#xff08;CDAO&#xff09;必须更有效地展示他们对企业和AI项目的价值&#xff0c;以保障其在高管层的地位。Gartner的最新报告指出&#xff0c;CDO和CDAO在AI时代需要重新塑造自身定位…

ElegantBook:优美的 LATEX 书籍模板(中文的latex模版)

关注B站可以观看更多实战教学视频&#xff1a;hallo128的个人空间 ElegantBook&#xff1a;优美的 LATEX 书籍模板&#xff08;中文的latex模版&#xff09; Github地址&#xff1a;https://github.com/ElegantLaTeX/ElegantBook使用说明文档&#xff1a;https://static.latexs…

C++11实践指北

C11&#xff1a;书、在线工具、库。 书 1. 《现代C语言核心特性解析》 覆盖 C11~C20 特性的讲解。 视频跟读&#xff1a;https://www.bilibili.com/video/BV1nN4y1j7fv 现代CPP随笔_0CCh - 每天5分钟了解现代C新特性 2. 《C Primer》第五版 基于 C11 的 C 入门书。 正在看…

故障诊断 | CNN-ResNets滚动轴承故障诊断实例代码

故障诊断 | CNN-ResNets滚动轴承故障诊断实例代码 目录 故障诊断 | CNN-ResNets滚动轴承故障诊断实例代码效果一览基本介绍程序设计参考资料 效果一览 基本介绍 CNN-ResNets&#xff08;卷积神经网络-残差网络&#xff09;在滚动轴承故障诊断中是一种常用的方法。这种方法利用…

使用Angular构建动态Web应用

&#x1f496; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4bb; Gitee主页&#xff1a;瑕疵的gitee主页 &#x1f680; 文章专栏&#xff1a;《热点资讯》 使用Angular构建动态Web应用 1 引言 2 Angular简介 3 安装Angular 4 创建Angular项目 5 设计应用结构 6 创建组件 7 …

Python小游戏14——雷霆战机

首先&#xff0c;你需要确保安装了Pygame库。如果你还没有安装&#xff0c;可以使用pip来安装&#xff1a; bash pip install pygame 代码如下&#xff1a; python import pygame import sys import random # 初始化Pygame pygame.init() # 设置屏幕大小 screen_width 800 scr…

云原生笔记

#1024程序员节|征文# 单页应用(Single-Page Application&#xff0c;SPA) 云原生基础 云原生全景内容宽泛&#xff0c;以至于刚开始就极具挑战性。 云原生应用是高度分布式系统&#xff0c;它们存在于云中&#xff0c;并且能够对变化保持韧性。系统是由多个服务组成的&#…

java-JVM面试问题-2024

1、简单介绍下虚拟机内存模型&#xff1f; VM由三部分组成&#xff1a;类加载子系统、运行时数据区、执行引擎 类加载子系统&#xff1a;通过类加载机制加载类的class文件&#xff0c;如果该类是第一次加载&#xff0c;会执行加载、验证、解析。只负责class文件的加载&#x…

基于neo4j的医疗问诊系统

当你身体不适时&#xff0c;想要找到准确的答案却经常遇到模棱两可的答复&#xff0c;糟心吗&#xff1f;现在&#xff0c;基于neo4j的智能医疗问诊系统为你带来全新体验&#xff01;我们设计了一个具备自动化问答功能的医疗系统&#xff0c;帮助用户快速获取专业的健康知识答案…

如何具备阅读JAVA JDK虚拟机源码能力

源码位置https://github.com/openjdk/jdk 核心实现源码[部分截图] /* * Copyright (c) 1995, 2024, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistr…

《欢乐饭米粒儿》持续热播:第四期小品笑中有思,引发观众共鸣

由鲜博士独家冠名播出的独创小品剧《欢乐饭米粒儿》第九季自播出以来&#xff0c;便以其贴近生活的题材和幽默风趣的表演赢得了观众的喜爱。每个小品不仅让人捧腹大笑&#xff0c;更在笑声中传递了深刻的生活哲理。近日&#xff0c;《欢乐饭米粒儿》又带来了几个新的小品&#…

计算机视觉专栏(1)【LeNet】论文详解

Lenet 系列 论文精讲部分0.摘要1.引言2.CNN3.结果分析4.总结 论文精讲部分 本专栏旨在深入解析计算机视觉模型的论文及其发展背景&#xff0c;并通过代码部分的实际实验来加深理解。读者可以根据自己的需要参考其中的内容。其主体为原文&#xff0c;笔者理解内容会采用引用格式…