diffusion model(四)文生图diffusion model(classifier-free guided)

news2024/11/14 19:17:40

文章目录

系列阅读

  • diffusion model(一)DDPM技术小结 (denoising diffusion probabilistic)
  • diffusion model(二)—— DDIM技术小结
  • diffusion model(三)—— classifier guided diffusion model
URL
paperClassifier-Free Diffusion Guidance
GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models
githubhttps://github.com/openai/glide-text2im

文生图diffusion model(classifier-free guided)

背景

在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。

Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即 ϵ θ ( x t , t ) → ϵ ^ θ ( x t , y , t ) \epsilon_{\theta}(x_t, t) \rightarrow \hat{\epsilon}_{\theta}(x_t, y, t) ϵθ(xt,t)ϵ^θ(xt,y,t)。从而简化整体的pipeline。

此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:

在这里插入图片描述

下面我们来看他是怎么做的吧!

方法大意

classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。

模型训练目标实现功能训练数据
DM (DDPM, DDIM) ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)从服从高斯分布的噪声中生成图片图片
classifier-guided DM ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)和分类器 p ( y ∣ x t ) p(y|x_t) p(yxt)从服从高斯分布的噪声中生成特定类别的图片DM:图片 分类器:图片-标签对
classifier-free DM ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t), ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)从服从高斯分布的噪声中生成符合文本描述的图片图片-文本对
  • 对于训练 ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)来估计 x t x_t xt在时间 t t t上添加的噪声,再根据采样公式推出 x t − 1 x_{t-1} xt1,从而实现图片生成。训练数据只需要准备图片即可。
  • 对于classifier-guided DM是在普通DM的基础上,额外再训练一个Classifier来获得当前时间步生成的图片类别概率分布,从而实现特定类别的图片生成。
  • 对于classifier-free DM将类别信息(或语义信息)集成到diffusion model的训练过程中,训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) ( 即 ϵ θ ( x t , t ) ) \epsilon_{\theta}(x_t, y=\empty,t)(\text{即}\epsilon_{\theta}(x_t,t)) ϵθ(xt,y=,t)(ϵθ(xt,t))。训练的时候也会加入无类别信息(或语义信息)的图片进行训练。

回答3个问题深入理解classifier-free DM

  1. 模型如何融入类别信息(或语义信息)
  2. 如何训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)
  3. 如何进行采样生成

模型如何融入类别信息(或语义信息)

采用交叉注意力机制融入

我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个embedding layer。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding。随后需要将text embedding和原本模型中的image representation进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention。具体来说就是将text embedding作为注意力机制中的keyvalue,原始的图片表征作为query。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。

在这里插入图片描述

class SpatialCrossAttention(nn.Module):
  
    def __init__(self, dim, context_dim, heads=4, dim_head=32) -> None:
        super(SpatialCrossAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.proj_in = nn.Conv2d(dim, context_dim, kernel_size=1, stride=1, padding=0)
        self.to_q = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x, context=None):
        
        x_q = self.proj_in(x) 
        b, c, h, w = x_q.shape
        x_q = rearrange(x_q, "b c h w -> b (h w) c")
        if context is None:
            context = x_q
        if context.ndim == 2:
            context = rearrange(context, "b c -> b () c")
        q = self.to_q(x_q)
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        
        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
        out = self.to_out(out)
        return out 

基于channel-wise attention融入

该融入方法与time-embedding的融入方法相同,在时间中往往会预先和time-embedding进行融合,再融入到图片特征中,伪代码如下:

# mixture time-embedding and label embedding
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
    assert y.shape == (x.shape[0],)
    emb = emb + self.label_emb(y)
while len(emb_out.shape) < len(h.shape):
    emb_out = emb_out[..., None]
emb_out = self.emb_layers(emb).type(h.dtype)  # h is image feature
scale, shift = th.chunk(emb_out, 2, dim=1)  # One half of the embedding is used for scaling and the other half for offset
h = h * (1 + scale) + shift  

基于channel-wise的融入粒度没有CrossAttention细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention的方法。

在这里插入图片描述

如何训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t) ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\emptyset,t) ϵθ(xt,y=,t)

ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)的数据不宜过多,否则模型生成结果的保真度会降低。反之,若 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=,t)数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。

有两个实践中的trick需要注意

  • 在实践中,为了统一 y = ∅ y=\empty y= y ≠ ∅ y \neq \empty y=两种情形,通常会给定一个 y = ∅ y=\empty y=的embedding(可以随机初始化,也可以人为给定),来统一两种情形的建模。
  • 即使所有的数据都有图片对也没有关系,只需在每一个batch中随机将某些数据的标签编码替换为 y = ∅ y=\empty y=的embedding即可。另外

如何进行采样生成

classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的 ϵ ( x t , t ) \epsilon(x_t, t) ϵ(xt,t)用下式代替。
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{1} \end{align} ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](1)

下面给出详细的推导过程:

首先根据贝叶斯公式有
p ( y ∣ x t ) = p ( x t ∣ y ) p ( y ) ⏞ 先验分布 p ( x t ) ⇒ p ( y ∣ x t ) ∝ p ( x t ∣ y ) / p ( x t ) ⇒ 取对数 log ⁡ p ( y ∣ x t ) = log ⁡ p ( x t ∣ y ) − log ⁡ p ( x t ) ⇒ 对 x t 求导 ∇ x t log ⁡ p ( y ∣ x t ) = ∇ x t log ⁡ p ( x t ∣ y ) − ∇ x t log ⁡ p ( x t ) ⇒ 根据score function ∇ x t log ⁡ p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) ∇ x t log ⁡ p ( y ∣ x t ) = − 1 1 − α ‾ t ( ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ) (2) \begin{aligned} p (y| x_t) & = \frac{p (x_t|y) \overbrace{p(y)}^{\text{先验分布}} } {p(x_t) } \\ \Rightarrow p (y| x_t) & \propto p (x_t|y) / {p (x_t) } \\ \stackrel{取对数} \Rightarrow \log{p (y| x_t)} & = \log{p (x_t|y)} - \log{{p (x_t) }} \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = \nabla_{x_t}\log{p (x_t|y)} - \nabla_{x_t}\log{{p (x_t) }} \\ \stackrel{\text{根据score function} \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t)} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}}(\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ) \end{aligned} \tag{2} p(yxt)p(yxt)取对数logp(yxt)xt求导xtlogp(yxt)根据score functionxtlogpθ(xt)=1αt 1ϵθ(xt)xtlogp(yxt)=p(xt)p(xty)p(y) 先验分布p(xty)/p(xt)=logp(xty)logp(xt)=xtlogp(xty)xtlogp(xt)=1αt 1(ϵθ(xt,y,t)ϵθ(xt,y=,t))(2)
当我们得到 ∇ x t log ⁡ p ( y ∣ x t ) \nabla_{x_t}\log{p (y| x_t)} xtlogp(yxt),参考classifier-guided的式(17)
ϵ ^ ( x t ∣ y ) ⏟ 本文中的 ϵ ^ θ ( x t , y , t ) : = ϵ θ ( x t ) ⏟ 本文中的 ϵ θ ( x t , y = ∅ , t ) − s 1 − α ‾ t ∇ x t log ⁡ p ϕ ( y ∣ x t ) (3) \underbrace{\hat{\epsilon}(x_t|y)}_{\text{本文中的}\hat{\epsilon}_{\theta}(x_t, y, t)} := \underbrace{\epsilon_\theta(x_t)}_{\text{本文中的}\epsilon_{\theta}(x_t, y=\empty, t)} - s\sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{3} 本文中的ϵ^θ(xt,y,t) ϵ^(xty):=本文中的ϵθ(xt,y=,t) ϵθ(xt)s1αt xtlogpϕ(yxt)(3)
可得
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{4} \end{align} ϵ^θ(xt,y,t)=ϵθ(xt,y=,t)+s[ϵθ(xt,y,t)ϵθ(xt,y=,t)](4)
后面的采样过程与之前的方式一致。

结语

本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。

参考文献

[1]: Classifier-Free Diffusion Guidance
[2]: GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/739735.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

P2 第一章 电路模型与电路定律

1、电源为什么可以等效为负值电阻&#xff1f; 当然电源可以等效为负值电阻&#xff0c;但是它不是真实电阻。 思考&#xff1a;阻碍反义词有推动&#xff0c;电动势&#xff1a;即电子运动的趋势&#xff0c;能够克服导体电阻对电流的阻力&#xff0c;使电荷在闭合的导体回路…

【Linux后端服务器开发】信号量与信号

目录 一、信号量概述 二、信号概述 三、信号产生 1、终端按键产生信号 2、调用系统函数产生信号 3、硬件异常产生信号 4、软件条件 四、信号保存 1、信号阻塞 2、信号捕捉流程 五、信号递达 一、信号量概述 信号量&#xff1a;一个计数器&#xff0c;通常用来表示公…

【档案专题】二、电子档案管理

导读&#xff1a;主要针对电子档案管理相关内容介绍。对从事电子档案管理信息化的职业而言&#xff0c;不断夯实电子档案管理相关理论基础是十分重要。只有通过不断梳理相关知识体系和在实际工作当中应用实践&#xff0c;才能走出一条专业化加职业化的道路&#xff0c;从而增强…

《黑马头条》 ElectricSearch 分词器 联想词 MangoDB day08-平台管理[实战]作业

07 app端文章搜索 1) 今日内容介绍 1.1)App端搜索-效果图 1.2)今日内容 2) 搭建ElasticSearch环境 2.1) 拉取镜像 docker pull elasticsearch:7.4.0 2.2) 创建容器 docker run -id --name elasticsearch -d --restartalways -p 9200:9200 -p 9300:9300 -v /usr/share/elasticse…

Android 查看ANR和Crash日志(adb bugreport)

今天测试那儿出了个ANR&#xff0c;我自己手机没问题&#xff0c;很烦&#xff0c;定位不了位置。 于是还是得用ADB连接来看一下&#xff0c;之前用&#xff0c;但是老是会忘记&#xff0c;今天总结一下。 ADB命令查看应用包名_adb查看包名命令_&岁月不待人&的博客-C…

基于matlab使用部分或较低分辨率图像快速处理阻塞图像(附源码)

一、前言 此示例展示了如何使用两种策略快速处理阻塞图像&#xff0c;这两种策略可以对高分辨率图像的较小代表性样本进行计算。 处理被阻止的图像可能非常耗时&#xff0c;这使得算法的迭代开发成本过高。有两种常见的方法可以缩短反馈周期&#xff1a;迭代较低分辨率的图像…

【Tensorflow2.x】tensorflow-gpu 在 Ubuntu 上的安装

好几次遇到问为什么安装的 tensorflow 不能调用GPU&#xff0c;之前搞定过几次&#xff0c;前两天又有人问&#xff0c;又捣鼓了很久才搞定&#xff0c;这里简单记录一下我遇到的问题&#xff0c;以及解决方案。 一、安装方法 &#xff08;一&#xff09;安装并更新 conda 1…

C++STL:顺序容器之list

文章目录 1. 概述2. 成员函数3. list容器的创建4. 迭代器5. 访问元素6. 添加/插入元素list insert()成员方法list splice()成员方法 7. 删除元素 1. 概述 STL list 容器&#xff0c;又称双向链表容器&#xff0c;即该容器的底层是以双向链表的形式实现的。这意味着&#xff0c…

RTOS学习笔记

前言 进程&#xff1f;线程&#xff1f;并发&#xff1f;并行&#xff1f;主线程&#xff1f;子线程&#xff1f;主线程中创建子线程&#xff1f;每个线程就是一个死循环&#xff1f; 进程 多个线程&#xff0c;每个线程可以写一个死循环处理一个需要循环执行的代码块&#x…

leetcode-203.移除链表元素

leetcode-203.移除链表元素 文章目录 leetcode-203.移除链表元素题目描述代码提交 题目描述 代码提交 代码 class Solution { public:ListNode* removeElements(ListNode* head, int val) {ListNode *dummyhead new ListNode(0); // 设置一个虚拟头结点,堆上dummyhead->ne…

SOLIDWORKS、UG、Proe三款三维绘图软件哪个好?

提到制图&#xff0c;很多人可能会先想到AutoCAD&#xff0c;但它现在主要会被用来进行二维的平面制图。3DMAX是一款被广泛应用的三维制图软件。Proe也是一种比较好用的三维建模软件。而SW也就是SolidWorks更为知名&#xff0c;它是世界上第一个专为Windows系统开发的三维CAD建…

解决小程序 scroll-view 里面的image有间距、小程序里面的图片之间有空隙的问题。

1&#xff09;小程序 image跟view标签上下会有间隙&#xff0c;解决方法如下&#xff1a; 在image那里设置vertical-align:top/bottom/text-top/text-bottom 原因&#xff1a;图片文字等inline元素默许是跟父级元素的baseline对齐&#xff0c;而baseline又和父级底边有必定间距…

web 前端 Day 3

伪类选择器 <title>伪类选择器</title> </head> <style>a:link {color: beige;} a:visited {color: aquamarine; } a:hover { 鼠标悬停cursor: cell; 鼠标样式font-size: 80px; } a:active {font-size: 70px; } div{width: 300px;height: 400…

813. 打印矩阵

链接&#xff1a; 打印矩阵 题目&#xff1a; 给定一个 rowcolrowcol 的二维数组 aa&#xff0c;请你编写一个函数&#xff0c;void print2D(int a[][N], int row, int col)&#xff0c;打印数组构成的 rowrow 行&#xff0c;colcol 列的矩阵。 注意&#xff0c;每打印完一整行…

JPA的saveAndFlush

#Stable Diffusion 美图活动一期# 关于MyBatis与JPA&#xff1a; 笔者初次接触这两个持久层框架的时候&#xff0c;那还是得从iBatis、Hibernate开始说起。那时候知道的一个很浅显、但最明显的区别就是&#xff1a;iBatis是半自动化的ORM框架&#xff0c;适用于表关联关系复杂的…

浅谈利用树莓派卡片电脑进行图像识别学习和研发

利用树莓派进行图像识别学习和研发是一个非常有前景和潜力的领域。树莓派是一款小巧且功能强大的单板计算机&#xff0c;具备较高的处理能力和丰富的接口&#xff0c;非常适合用于图像识别的应用开发。 在图像识别方面&#xff0c;树莓派可以利用其强大的计算能力和丰富的软件…

react知识点汇总四--react router 和 redux

react-router React-Router 是一个用于在 React 应用中实现页面导航和路由管理的库。它提供了一种方式来创建单页应用&#xff08;Single-Page Application&#xff0c;SPA&#xff09;&#xff0c;其中页面的切换是在客户端进行的&#xff0c;而不需要每次跳转都向服务器请求…

Mac上绿色软件怎么长期保存

1、找到想长期保存的绿色软件&#xff0c;右键拷贝 2、来到「应用程序」&#xff0c;点工具栏-操作-粘贴项目 3、这样绿色软件就长期保留下来了

华纳云:一台香港多IP服务器如何设置多个IP?

在一台香港多IP服务器上设置多个IP的步骤如下&#xff1a; 1.确认服务器支持多个IP地址&#xff1a;首先&#xff0c;确保你的服务器有多个网卡接口或虚拟网卡接口&#xff0c;以支持多个IP地址。 2.查看当前IP配置&#xff1a;运行以下命令来查看当前的IP配置信息&#xff1a;…

深度学习——神经网络参数的保存和提取

代码与详细注释&#xff1a; Talk is cheap. Show you the code&#xff01; import torch import matplotlib.pyplot as plt# 造数据 x torch.unsqueeze(torch.linspace(-1, 1, 100), dim1) # x data (tensor), shape(100, 1) y x.pow(2) 0.2*torch.rand(x.size()) # n…