【AIGC】Stable Diffusion原理快速上手,模型结构、关键组件、训练预测方式

news2024/9/21 0:39:29

【AIGC】Stable Diffusion的建模思想、训练预测方式快速

在这篇博客中,将会用机器学习入门级描述,来介绍Stable Diffusion的关键原理。目前,网络上的使用教程非常多,本篇中不会介绍如何部署、使用或者微调SD模型。也会尽量精简语言,无公式推导,旨在理解思想。让有机器学习基础的朋友,可以快速了解SD模型的重要部分。如有理解错误,请不吝指正。

大纲

  1. 关键概念
  2. 模型结构及关键组件
  3. 训练和预测方式

关键概念

名词解释

Stable Diffusion

之所以叫Stable,是因为金主公司叫StabilityAI。

其基础模型是Latent Diffusion Model(LDM),也是本文主要介绍的部分。

模型任务

  1. text-2-img:输入文本描述、输出图像
  2. img-2-img:输入图片及其他文本描述,输出图像

总的来说,不论是输入是文字还是图片,都可以称为是“condition”,用于指引图像生成的“方向”。因此,SD模型的任务,可以统称为是cond-2-img任务。

模型任务

模型结构与关键组件

模型结构

LDM论文结构图,初看时会有点懵,但稍微理解后还是非常清晰准确的。先初步介绍几个大的模块。建议把这张图截图固定在屏幕上,再继续浏览下面的内容。

模型结构

整体输入输出

上图中最左侧的 x x x x ~ \widetilde{x} x 是模型的输入与输出,形如 [ W , H , C ] [W, H, C] [W,H,C]的三维张量,代表一张图像的宽、高和通道数。

需要注意,这里的输入 x x x,并不是模型img-2-img中的输入图像,而是模型训练时的原始图像输入。img-2-img的输入图像,是上图中最右侧的Conditioning模块中的images。

像素空间与隐空间

所谓空间,可以理解为数据的表示形式,通常有着不同的坐标轴。

  • 像素空间(Pixel Space),上图左侧,红框部分。通常是人眼可以识别的图像内容。
  • 隐空间(Latent Space),上图中央,绿框部分。通常是人眼无法识别的内容,但包含的信息量与像素空间相近。

像素空间到隐空间

输入的图像 x x x,经过Encoder(图中蓝色的 E \mathcal{E} E),转换为另一种shape的张量 z z z,即称为隐空间。

从压缩角度理解:图像经过转换后,产生的新张量是人眼无法识别的。但其包含的信息量相差不大,数据尺寸却大幅缩小,因此可以看做是一种图像数据压缩方式

隐空间到像素空间

经过模型处理后的隐向量输出 z z z(特指绿框左下角的 z z z),经过Decoder(图中蓝色的 D \mathcal{D} D),转换回像素空间。

隐空间Diffusion操作

对应图中绿色Latent Space框的上半部分,包括以下三步:

  1. 图像经过Encoder压缩后,得到隐向量表示 z = E ( x ) z=\mathcal{E}(x) z=E(x)隐向量
  2. 从1~1000的均匀分布中,随机采样一个整数 T T T,称为扩散步数
  3. 对向量 z z z T T T次高斯噪声,满足分布 N ( 0 , β t ) N(0, \beta_t) N(0,βt),得到 z T z_T zT向量

在这个操作中,有一些有趣的特性:

噪声收敛

加噪声次数足够多时,理论上会得到一组符合高斯分布的噪声。利用这个特性,在预测阶段我们就不需要执行Diffusion操作,只需要采样一组高斯分布的噪声,即代表了 z T z_T zT

高斯噪声可加性

当我们需要得到任意时刻的 z T z_T zT时,可以直接从 z 0 z_0 z0以及一系列 β t \beta_t βt计算得到,只需要采样一次噪声。这部分的具体公式推导,可以参考由浅入深了解Diffusion Model - 知乎 (zhihu.com)。

隐空间Denoising操作

对应图中绿色框的下半部分,包括以下步骤:

  1. 输入 z t , t , c o n d z_t,t,cond zt,t,cond给U-Net结构,预测出一个噪声 ϵ θ ( z t , t , c o n d ) \epsilon_{\theta}(z_t,t,cond) ϵθ(zt,t,cond),shape与 z t z_t zt一致
  2. 使 z t − 1 = z t − ϵ θ ( z t , t , c o n d ) z_{t-1} = z_t - \epsilon_{\theta}(z_t,t,cond) zt1=ztϵθ(zt,t,cond),重复上一步骤,直至获得 z 0 z_0 z0隐向量
  3. 使用Decoder得到输出图像, x ~ = D ( z 0 ) \widetilde{x} = \mathcal{D}(z_0) x =D(z0)

条件Conditioning

对应图中最右边灰白色框,输入类型包括text、images等。在Conditioning模块中,会执行以下步骤:

  1. 这些“附加信息”会通过对应的编码器 τ θ \tau_\theta τθ,转换成向量表示
  2. 转换后的向量,会输入给U-Net,作为其中Attention模块的K、V输入,辅助噪声的预测

在这个模块中,有几个有趣的问题:

为什么需要Conditioning

由于“噪声收敛”特性,当噪声加得比较多时, z T z_T zT已经趋近于一个“纯噪声”了,但训练过程需要比对输入图像 x x x和输出图像 x ~ \widetilde{x} x 的相似度。如何从一个“纯噪声”,还原回与输入图像相似的图像,就必须要给模型提供额外的信息指引,这就是Conditioning的作用。

关键组件

VAE(Variational Auto Encoders)

在LDM中,如何将原始图片“压缩”转换至隐空间,经过处理再转换回来,即使用VAE的Encoder和Decoder。这个模块是预训练好的,在LDM训练时固定住参数。

原理

  1. 原始张量输入,经过非常简单的网络结构,转换成较小的张量
  2. 在Latent张量上,加一点点噪声扰动
  3. 用对称的简单网络结构,还原回原始大小
  4. 对比输入前后的张量是否相似

特点

  1. 网络计算复杂度比较低
  2. Encoder和Decoder可以分开使用
  3. 无监督训练,不需要标注输入的label
  4. 有了噪声扰动之后,Latent Space的距离具有实际物理含义,可以实现例如“(满杯水+空杯子)/ 2 = 半杯水”的操作

VAE

CLIP

文本信息如何转换成张量,靠的是CLIP模块。这个模块是预训练好的,在LDM训练时固定住参数。

训练方式

图像以及它的描述文本,经过各自的Encoder转换为向量表示,希望转换后的向量距离相近。经过训练后,文本描述可以映射到向量空间的一个点,其代表的物理含义与原始图像相近。

CLIP

假设无预训练

开个脑洞,假如没有这个模块,直接将文本token化后,去Embedding Table中查表作为文本张量,理论上也是可以训练的,只不过收敛速度会慢很多。

因此,这里使用一个预训练text-2-embedding模块,主要目的是加速训练。CLIP的训练数据集,也选择了和LDM的数据集的同一个(LAION-5B的子集),语义更一致。

模型标识解释

我们经常会看到类似“ViT-L/14”的模型名,表示一种CLIP的结构。具体的,ViT表示Vision Transformer,L表示Large(此外还有Base、Huge),14表示训练时把图像划分成14*14个子图序列输入给Transformer。

模型标识解释

U-Net

作为LDM的核心组件,U-Net是模型训练过程中,唯一需要参数更新的部分。在这个结构中,输入是带有噪声的隐向量 z t z_t zt、当前的时间戳 t t t,文本等Conditioning的张量表示 E E E,输出是 z t z_t zt中的噪声预测。

模型任务

U-Net的任务,就是从 z t z_t zt中预测出噪声部分 ϵ t \epsilon_t ϵt,从而得到降低噪声后的 z t − 1 = z t − ϵ t z_{t-1}=z_t - \epsilon_t zt1=ztϵt,直到获得 z 0 z_0 z0。下图是一个可视化示意图,实际上,我们去噪的 z t z_t zt是隐向量空间的数据,人眼无法识别。

U-Net模型任务

模型结构

U-Net大致上可以分为三块:降采样层、中间层、上采样层。之所以叫U-Net,是因为它的模型结构类似字母U。

U-Net模型结构

降采样层

  1. 时间戳 t t t转换为向量形式。用的是“Attention is All you Need”论文的Transformer方法,通过sin和cos函数再经过两个Linear进行变换
  2. 初始化输入 X = c o n v ( c o n c a t ( z t , E ) ) X = conv(concat(z_t, E)) X=conv(concat(zt,E)),其中 c o n v conv conv是卷积, E E E是Conditioning
  3. 重复以下步骤(a~c)多次,将输入尺寸降至目标尺寸(如上图的 4 × 4 4\times4 4×4
    1. 重复以下两步多次,训练多个ResBlock和SpatialTransformer层,输入值 X X X的尺寸不变
      1. 输入上一层的输出 X X X和时间戳向量,给ResBlock
      2. ResBlock的输出,与 E E E一起输入给SpatialTransformer,在这里考虑到text等信息
    2. 重复多次3~4步,
    3. 通过卷积或Avg-Pooling进行降采样,缩小 X X X的尺寸

U-Net模型结构-1
U-Net模型结构-2

中间层

很简单,ResBlock + SpatialTransformer + ResBlock,输入 X X X尺寸不变。

上采样层

大部分步骤与降采样层一致,只有以下两点不同

  1. 输入 X X X需要拼上对应降采样层的输出,称为skip connection,对应U-Net结构图中横向的箭头
  2. 把降采样步骤,换成使用卷积或插值(interpolate)方式来上采样,使得 X X X的尺寸增大

输出

上采样层的输出,会经过normalization + SiLU + conv,得到U-Net的最终输出,即噪声的预测值,尺寸保持与输入 z t z_t zt一致。

训练方式

模型更新方式

LDM模型需要训练的部分,只有U-Net的参数。训练的方式,可以简单总结为:

  1. 输入一张图片 x x x,以及它的文本描述等Conditioning,一个随机的整数 T T T
  2. 经过Encoder压缩、Diffusion加噪声,得到 z T z_T zT隐向量
  3. 结合Conditioning,使用U-Net,进行 T T T次去噪,得到预测值 z 0 z_0 z0向量
  4. 使用Decoder还原回 x ~ \widetilde{x} x ,计算 x x x x ~ \widetilde{x} x 之间的差距(KL散度),得到模型更新的loss

模型预测方式

  1. 随机一个高斯噪声,作为 z T z_T zT向量
  2. 输入text等Conditioning,使用U-Net进行指定次数 T T T的去噪操作
  3. 使用Decoder还原回 x ~ \widetilde{x} x ,得到图像输出

训练、预测过程,在论文中的伪代码为下图所示。
模型训练预测伪代码

展望

下一篇文章,将会讨论以下几个更深入的内容:

  1. ControlNet、LoRA等插件的实现
  2. 各种Conditioning Context是如何转换为张量的
  3. 训练的数据集情况

参考

The Illustrated Stable Diffusion – Jay Alammar – Visualizing machine learning one concept at a time. (jalammar.github.io)

【原创】万字长文讲解Stable Diffusion的AI绘画基本技术原理 - 知乎 (zhihu.com)

Diffusion Models:生成扩散模型 (yinglinzheng.netlify.app)

由浅入深了解Diffusion Model - 知乎 (zhihu.com)

How does Stable Diffusion work? - Stable Diffusion Art (stable-diffusion-art.com)

[2006.11239] Denoising Diffusion Probabilistic Models (arxiv.org)

CompVis/latent-diffusion: High-Resolution Image Synthesis with Latent Diffusion Models (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/440960.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

靶机精讲之Tr0ll

主机发现 nmap扫描 端口扫描 UDP扫描 服务扫描 先从ftp和http下手,shh排后 尝试ftp 匿名登录 查看文件下载的信息 wireshark利用读取文件 strings读取 lol.pcap文本 读代码感觉像目录 进行访问 下载 拷贝到目录下(记得背后加点) file查看文…

Redis五大数据类型

关于Redis的五大数据类型,它们分别为:String、List、Hash、Set、SortSet。本文将会从它的底层数据结构、常用操作命令、一些特点和实际应用这几个方面进行解析。对于数据结构的解析,本文只会从大的方面来解析,不会介绍详细的代码实…

Linux_Shell命令解析

简介 在linux终端中执行ls命令,ls命令是如何被解析并且执行的。Shell命令的格式一般为: [commond] [-options] [parameter]执行命令 命令的选项 命令的参数当执行ls命令是显示当前目录下所有文件的名称 执行ls -l命令是显示当前目录下所有文件的属性…

软件工程开发文档写作教程(01)—开发文档的意义与作用

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl本文参考资料:电子工业出版社《软件文档写作教程》 马平,黄冬梅编著 软件工程开发文档的意义 软件文档是整个软件开发工作的基础,现代工程…

Maven(一)基础入门

目录 一、Maven简介1.背景2.Maven是什么3.Maven的作用 二、下载与安装1.下载2.安装3.配置环境变量 三、Maven基础概念1.仓库2.坐标3.本地仓库配置4.远程仓库配置5.阿里云-镜像仓库配置6.全局 settings 与用户 settings 区别 四、第一个Maven项目(手工制作&#xff0…

【Python】快速简单搭建HTTP服务器并公网访问「cpolar内网穿透」

转载自远程内网穿透的文章:【Python】快速简单搭建HTTP服务器并公网访问「cpolar内网穿透」 1.前言 Python作为热度比较高的编程语言,其语法简单且语句清晰,而且python有良好的兼容性,可以轻松的和其他编程语言((比如…

Qt/QML编程学习之心得:ALSA音频开发(六)

Linux内核中对音频播放和捕获的最初支持是由开放声音系统(OSS)提供的。OSS API是为音频而设计的带有16位双通道回放和捕获的卡,以及随后的API通过open()、close()、read()和write()系统调用的标准POSIX。OSS的主要问题是,虽然基于文件的API实际上易于应用程序开发人…

电磁阀“位”与“通”的详细解说(示意图)

电磁阀是用电磁控制的工业设备,是用来控制流体的自动化基础元件,属于执行器。 而气动电磁阀是其中的一种,是通过控制阀体的移动来档住或漏出不同的排油的孔,而进油孔是常开的,液压油就会进入不同的排油管,…

物联网定位技术|实验报告|实验一 Wi-Fi指纹定位

目录 实验1 Wi-Fi指纹定位 1. 实验目标 2. 实验背景 3. 实验原理 3.1 WIFI基础知识 3.2室内定位方法建模 3.3指纹定位算法 ①离线/训练阶段 ②在线/定位阶段 4. 关键代码 5. 实验结果 6. 室内定位误差分析 6.1 非视距传播 6.2 多径传播 6.3 阴影效应 7. 实验总结 物联网定位技…

ESP32学习三-环境搭建(ESP-IDF V5.0,Ubuntu20.4)

一、准备事项 Ubuntu 20.04。具体安装可以参考如下链接。使用VMware安装Ubuntu虚拟机和VMware Tools_t_guest的博客-CSDN博客 二、安装ESP-IDF 1)、确认python3版本 输入python3 --version来确认python3的版本。因为要安装ESP-IDF 5.0版本,python3的版本…

Docker Compose与Docker Swarm的简介和区别

Docker Compose与Docker Swarm的简介和区别 背景Compose 简介Swarm 简介Compose 和 Swarm区别 背景 之前公司很多都是单体的spring boot服务,使用Docker的时候,只需要定义Dockerfile 文件,然后打成镜像把容器启动起来就ok了。但是现在的微服…

低成本,全流程!基于PaddleDepth和Paddle3D的三维视觉技术应用方案

现实生活中的很多应用场景都需要涉及到三维信息。针对三维视觉技术应用场景复杂多样、三维感知任务众多、流程复杂等问题,飞桨为开发者提供了低成本的深度信息搜集方案 PaddleDepth 以及面向自动驾驶三维感知的全流程开发套件 Paddle3D 。 三维视觉技术应用场景 3D …

01——计算机系统基础

计算机系统基础知识 计算机系统基础一、计算机系统的基本组成1 计算机硬件系统 二、计算机的类型三、计算机的组成和工作原理1 计算机的组成2 总线的基本概念2.1 总线的定义与分类 3 系统总线3.1 系统总线的概念3.2 常见的系统总线 4 外总线5 中央处理单元(CPU&…

【刷题】搜索——BFS:八数码【A*模板】

A*简介 某点u的距离f(u)定义如下: f ( u ) g ( u ) h ( u ) f(u) g(u) h(u) f(u)g(u)h(u) g(u):起点到u走的距离 h(u):u到终点估计的距离,保证 0 ≤ h ( u ) ≤ h ′ ( u ) 0 \leq h(u) \leq h(u) 0≤h(u)≤h′(u)。其中h’…

健康体检信息系统源码,个人体检、团队体检、体检报告、统计分析

健康体检管理系统源码 PEIS源码 数据对接 体检人员管理系统,系统有演示,文档齐全。 一套专业的体检管理系统源码,该系统涵盖个人体检、团队体检、关爱体检等多种体检类型,提供体检登记管理、体检结果管理、体检报告打印及发放…

阿里云服务器搭建网站流程by宝塔Linux面板

阿里云服务器安装宝塔面板教程,云服务器吧以阿里云Linux系统云服务器安装宝塔Linux面板为例,先配置云服务器安全组开放宝塔所需端口8888、888、80、443、20和21端口,然后执行安装宝塔面板命令脚本,最后登录宝塔后台安装LNMP&#…

尝试图像锐化

#图像锐化 拉普拉斯: 导数f(x,y)f(x1,y)f(x−1,y)f(x,y1)f(x,y−1)−4f(x,y) 可以扩展到8邻域: ​ Mat Sharpen(Mat input, int percent, int type) { Mat result; Mat s input.clone(); Mat kernel; switch (type) { case 0: kernel (Mat_(3, 3)…

4个令人惊艳的ChatGPT项目,开源了

自从 ChatGPT、Stable Diffusion 发布以来,各种相关开源项目百花齐放,着实让人应接不暇。今天,将着重挑选几个优质的开源项目,对我们的日常工作、学习生活,都会有很大的帮助。 一、Visual ChatGPT 这个是微软开源的项…

代码随想录_二叉树_leetcode700、98

leetcode700.二叉搜索树中的搜索 700. 二叉搜索树中的搜索 给定二叉搜索树(BST)的根节点 root 和一个整数值 val。 你需要在 BST 中找到节点值等于 val 的节点。 返回以该节点为根的子树。 如果节点不存在,则返回 null 。 示例 1: 输入&…

乘客出租出行需求短时预测

CLAB模型是一种空间-时间环境下基于深度学习的乘客流量预测模型,可有效挖掘出租车乘客出行的时空相关性,考虑历史数据流入量对出行需求的影响,从而提高预测准确性。 数据挖掘维度: 1.时间维度:预测的是短时预测&#x…