扩散模型条件生成——Classifier Guidance和Classifier-free Guidance原理解析

news2024/11/28 20:48:50

1、前言

从讲扩散模型到现在。我们很少讲过条件生成(Stable DIffusion曾提到过一点),所以本篇内容。我们就来具体讲一下条件生成。这一部分的内容我就不给原论文了,因为那些论文并不只讲了条件生成,还有一些调参什么的。并且推导过程也相对复杂。我们从一个比较简单的角度出发。

参考论文:Understanding Diffusion Models: A Unified Perspective (arxiv.org)

参考代码:

classifier guidance:GitHub - openai/guided-diffusion

classifier-free guidance:GitHub - coderpiaobozhe/classifier-free-diffusion-guidance-Pytorch: a simple unofficial implementation of classifier-free diffusion guidance

视频:[扩散模型条件生成——Classifier Guidance和Classifier-free Guidance原理解析-哔哩哔哩]

2、常用的条件生成方法

在diffusion里面,如何进行条件生成呢?我们不妨回忆一下在Stable Diffusion里面的一个常用做法。即在训练的时候。给神经网络输入一个条件。
L = ∣ ∣ ϵ − ϵ θ ( x t , t , y ) ∣ ∣ 2 L=||\epsilon-\epsilon_{\theta}(x_t,t,y)||^2 L=∣∣ϵϵθ(xt,t,y)2
里面的y就是条件。至于为什么有效,请看我之前写过的Stable DIffusion那篇文章。在此不过多赘述了。我们来讲这种方法所存在的问题。

很显然的,这种训练的方式,会有一个问题,那就是神经网络或许会学会忽略或者淡化掉我们输入的条件信息。因为就算我们不输入信息,他也照样能够生成。

接下来我们来讲两种更为流行的方法——分类指导器(Classifier Guidance) 和无分类指导器( Classifier-Free Guidance)

3、Classifier Guidance

为了简单起见。我们从分数模型的角度出发。

回忆一下在SDE里面的结论。其反向过程为
d x = [ f ( x , t ) − g ( t ) 2 ∇ x log ⁡ p t ( x ) ] d t + g ( t ) d w ˉ (1) \mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{1} dx=[f(x,t)g(t)2xlogpt(x)]dt+g(t)dwˉ(1)
如果施加条件的话,还是根据Reverse-time diffusion equation models - ScienceDirect这篇论文,可得条件生成时的反向SDE为
d x = [ f ( x , t ) − g ( t ) 2 ∇ x log ⁡ p t ( x ∣ y ) ] d t + g ( t ) d w ˉ (2) \mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x|y)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{2} dx=[f(x,t)g(t)2xlogpt(xy)]dt+g(t)dwˉ(2)
我们利用贝叶斯公式,对 ∇ x log ⁡ p t ( x ∣ y ) \nabla x \log p_t(x|y) xlogpt(xy)进行处理
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( y ∣ x ) p t ( x ) p t ( y ) = ∇ x ( log ⁡ p t ( y ∣ x ) + log ⁡ p t ( x ) − log ⁡ p t ( y ) ) = ∇ x log ⁡ p t ( x ) + ∇ x log ⁡ p t ( y ∣ x ) \begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log\frac{p_t(y|x)p_t(x)}{p_t(y)}\\=&\nabla_x\left(\log p_t(y|x)+\log p_t(x)-\log p_t(y)\right)\\=&\nabla_x \log p_t(x)+\nabla_x\log p_t(y|x)\end{aligned}\nonumber xlogpt(xy)===xlogpt(y)pt(yx)pt(x)x(logpt(yx)+logpt(x)logpt(y))xlogpt(x)+xlogpt(yx)
第二个等号到第三个等号是因为对 log ⁡ p t ( y ) \log p_t(y) logpt(y)关于x求梯度等于0( log ⁡ p t ( y ) \log p_t(y) logpt(y)与x无关)

把它代入Eq.(2)可得
d x = [ f ( x , t ) − g ( t ) 2 ( ∇ x log ⁡ p t ( x ) + ∇ x log ⁡ p t ( y ∣ x ) ) ] d t + g ( t ) d w ˉ (3) \mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\left(\nabla_x\log p_t(x)+\nabla_x\log p_t(y|x)\right)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{3} dx=[f(x,t)g(t)2(xlogpt(x)+xlogpt(yx))]dt+g(t)dwˉ(3)
对比Eq.(1)和Eq.(3)。我们不难发现,它们的差别,居然是只多了一个 ∇ x log ⁡ p t ( y ∣ x ) \nabla_x\log p_t(y|x) xlogpt(yx)

p t ( y ∣ x ) p_t(y|x) pt(yx)是什么?是以 x x x作为条件,时间为t对应条件y的概率。我们可以怎么求呢?该怎么求出来呢?

当然是使用神经网络了。也就是说,我们可以额外设定一个神经网络,该神经网络输入是 x t x_t xt,输出是条件为y的概率

所以,实际上我们现在需要训练两部分,一部分是 ∇ x log ⁡ p t ( x ) \nabla_x\log p_t(x) xlogpt(x),这我们在SDE中已经讲过该如何训练了。

另一个就是 ∇ x log ⁡ p t ( y ∣ x ) \nabla_x\log p_t(y|x) xlogpt(yx),他就是一个分类神经网络网络。训练好之后,我们就可以使用Eq.(3)通过不同的数值求解器,进行优化了。

作者在此基础上,又引入了一个控制参数 λ \lambda λ
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + λ ∇ x log ⁡ p t ( y ∣ x ) (4) \nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(y|x)\tag{4} xlogpt(xy)=xlogpt(x)+λxlogpt(yx)(4)
λ = 0 \lambda=0 λ=0,表示不加入任何条件。当 λ \lambda λ很大时,模型会产生大量附带条件信息的样本。

这种方法的一个缺点就是,需要额外学习一个分类器 p t ( y ∣ x ) p_t(y|x) pt(yx)

4、Classifier-Free Guidance

之前推出
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + ∇ x log ⁡ p t ( y ∣ x ) (5) \nabla_x \log p_t(x|y)=\nabla_x \log p_t(x)+\nabla_x\log p_t(y|x)\tag{5} xlogpt(xy)=xlogpt(x)+xlogpt(yx)(5)
把该式子代入Eq.(4)可得
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + λ ( ∇ x log ⁡ p t ( x ∣ y ) − ∇ x log ⁡ p t ( x ) ) = ∇ x log ⁡ p t ( x ) + λ ∇ x log ⁡ p t ( x ∣ y ) − λ ∇ x log ⁡ p t ( x ) = ( 1 − λ ) ∇ x log ⁡ p t ( x ) + λ ∇ x log ⁡ p t ( x ∣ y ) \begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log p_t(x)+\lambda\left(\nabla_x\log p_t(x|y)-\nabla_x\log p_t(x)\right)\\=&\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(x|y)-\lambda\nabla_x\log p_t(x)\\=&\left(1-\lambda\right)\nabla_x\log p_t(x)+\lambda\nabla_x\log p_t(x|y)\end{aligned}\nonumber xlogpt(xy)===xlogpt(x)+λ(xlogpt(xy)xlogpt(x))xlogpt(x)+λxlogpt(xy)λxlogpt(x)(1λ)xlogpt(x)+λxlogpt(xy)
此时我们注意到,当 λ = 0 \lambda=0 λ=0是,第二项完全为0,会忽略掉条件;当 λ = 1 \lambda=1 λ=1时,使用第二项,第二项就是附带有条件情况下的分布分数网络;而当 λ > 1 \lambda> 1 λ>1,模型会优化考虑条件生成样本,并且远离第一项的无条件分数网络的方向,换句话说,它降低了生成不使用条件信息的样本的概率,而有利于生成明确使用条件信息的样本。

事实上,如果你看了free-Classifier Guidance这篇论文,会发现我们的结论不一样。

其实论文里面的控制参数是 w w w,也就是说,Eq.(4)就变成了这样
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + w ∇ x log ⁡ p t ( y ∣ x ) \nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+w\nabla_x\log p_t(y|x) xlogpt(xy)=xlogpt(x)+wxlogpt(yx)
我们把控制参数改成 1 + w 1+w 1+w不会有任何影响
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + ( 1 + w ) ∇ x log ⁡ p t ( y ∣ x ) \nabla_x \log p_t(x|y)=\nabla_x\log p_t(x)+(1+w)\nabla_x\log p_t(y|x) xlogpt(xy)=xlogpt(x)+(1+w)xlogpt(yx)
把Eq.(5)代入该式子
∇ x log ⁡ p t ( x ∣ y ) = ∇ x log ⁡ p t ( x ) + ( 1 + w ) ( ∇ x log ⁡ p t ( x ∣ y ) − ∇ x log ⁡ p t ( x ) ) = ∇ x log ⁡ p t ( x ) + ( 1 + w ) ∇ x log ⁡ p t ( x ∣ y ) − ( 1 + w ) ∇ x log ⁡ p t ( x ) = ( 1 + w ) ∇ x log ⁡ p t ( x ∣ y ) − w ∇ x log ⁡ p t ( x ) (6) \begin{aligned}\nabla_x \log p_t(x|y)=&\nabla_x\log p_t(x)+(1+w)\left(\nabla_x\log p_t(x|y)-\nabla_x\log p_t(x)\right)\\=&\nabla_x\log p_t(x)+(1+w)\nabla_x\log p_t(x|y)-(1+w)\nabla_x\log p_t(x)\\=&(1+w)\nabla_x\log p_t(x|y)-w\nabla_x\log p_t(x)\end{aligned}\tag{6} xlogpt(xy)===xlogpt(x)+(1+w)(xlogpt(xy)xlogpt(x))xlogpt(x)+(1+w)xlogpt(xy)(1+w)xlogpt(x)(1+w)xlogpt(xy)wxlogpt(x)(6)
这就是原论文里面的结论。

那么接下来,我们来探讨一下该如何去训练。

对于 ∇ x log ⁡ p t ( x ) \nabla_x\log p_t(x) xlogpt(x),这个不用说了,之前我们训练的就是这个;如何计算 ∇ x log ⁡ p t ( x ∣ y ) \nabla_x\log p_t(x|y) xlogpt(xy)呢,它实际上就是在给定y的情况下,求出 p t ( x ∣ y ) p_t(x|y) pt(xy)。那我们可以怎么做呢?

在NCSN,我们是使用一个加噪分布 q ( x ~ ∣ x ) q(\tilde x|x) q(x~x)取代 p ( x ) p(x) p(x),而从让它是可解的。

对于 p t ( x ∣ y ) p_t(x|y) pt(xy),即便是加多了一个条件之后,我们仍然建模为 q ( x ~ ∣ x ) q(\tilde x|x) q(x~x),也就是说,我们仍然把它建模成一个正向加噪过程。因此,无论是否增加条件。最终的损失函数结果都是
L = ∣ ∣ s θ − ∇ x log ⁡ q ( x ~ ∣ x ) ∣ ∣ 2 = ∣ ∣ s θ − ∇ x log ⁡ q ( x t ∣ x 0 ) ∣ ∣ 2 L=||s_\theta-\nabla_x\log q(\tilde x|x)||^2=||s_\theta-\nabla_x\log q(x_t|x_0)||^2 L=∣∣sθxlogq(x~x)2=∣∣sθxlogq(xtx0)2
后者是通过SDE统一的结果(我在SDE那一节讲过)

那该如何体现条件y呢?其实我们在第二节的时候已经说过了,就是在里面神经网络的输出加入一个条件y。
L = ∣ ∣ s θ ( x t , t , y ) − ∇ x log ⁡ q ( x t ∣ x 0 ) ∣ ∣ 2 (7) L=||s_\theta(x_t,t,y)-\nabla_x\log q(x_t|x_0)||^2\tag{7} L=∣∣sθ(xt,t,y)xlogq(xtx0)2(7)
而不施加条件的时候,长这样
L = ∣ ∣ s θ ( x t , t ) − ∇ x log ⁡ q ( x t ∣ x 0 ) ∣ ∣ 2 (8) L=||s_\theta(x_t,t)-\nabla_x\log q(x_t|x_0)||^2\tag{8} L=∣∣sθ(xt,t)xlogq(xtx0)2(8)
由Eq.(5)可知,我们需要训练两种情况,一种是有条件的,对应Eq.(7);另外一种是无条件的,对应Eq.(8)。

理论上,我们其实也是要训练两个神经网络。但实际上,我们可以把他们结合成一种神经网络。

具体操作就是把无条件的情况作为一种特例。

当我们训练有条件的神经网络的时候,会照样把条件输入进网络里面。而训练无条件的时候,我们构造一个无条件的标识符,把它作为条件输入给神经网络,比如对于所有无条件的情况,我都构造一个0作为条件输入到神经网络里面。通过这种方式,我们就可以把两个网络变成一个网络了,

对于损失函数,直接使用Eq.(7)。我们在SDE里面讲过 ∇ x log ⁡ p ( x ) = − 1 σ ϵ \nabla_x \log p(x)=-\frac{1}{\sigma}\epsilon xlogp(x)=σ1ϵ。所以我们最终我们把预测噪声,变成了预测分数。我们同样可以把它变回来,变成预测分数
L = ∣ ∣ ϵ − ϵ θ ( x t , t , y ) ∣ ∣ 2 L=||\epsilon-\epsilon_{\theta}(x_t,t,y)||^2 L=∣∣ϵϵθ(xt,t,y)2
所以损失函数就变成了这样。在训练的时候,作者设定一个大于等于0,小于等于1的超参数 p u n c o n d p_{uncond} puncond,它的作用就是判断是否需要输入条件(从0-1分布采样一个值,大于 p u n c o n d p_{uncond} puncond则使用条件,反之则不使用)。也就是说,这相当于dropout一样,随机舍弃掉一些条件,把他们作为无条件的情况(因为我们既要学习有条件的,又要学习无条件的)。所以,最终的训练过程就是这样

在这里插入图片描述

其中里面的 λ \lambda λ你就当作是时刻t吧(其实不是,其实是时刻t的噪声(噪声的初始化不一样,不是传统的等差数列,是用三角函数初始化的)。由于与本篇内容无关,故而忽略),c是条件。

同样的,采用过程使用Eq.(6)的结构进行采样

在这里插入图片描述

5、结束

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1802411.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ARFoundation自学04】AR Tracked Image 图像追踪识别

图像识别是很常用的AR功能!AR foundation 可以帮助我们轻松实现! 1.安装插件 首先还是在资源包中导入ARfoundation 。然后搭建基本的AR ARFoundation框架! 2.创建AR session 和XR origin结构! 3.然后在XR Origin 物体身上添加A…

ubuntu22.04编译OpenCV4.9(带contrib-4.9.0)

操作系统:ubuntu22.04 OpenCV版本:4.9.0 opencv_contrib版本:4.9.0 源码下载 OPenCV4.9.0下载地址:https://github.com/opencv/opencv/releases/tag/4.9.0 如下图所示: 按箭头所指点击下载source code(tar.gz)文件到…

电机专用32位MCU PY32MD310,Arm® Cortex-M0+内核

PY32MD310是一颗专为电机控制设计的MCU,非常适合用做三相/单相 BLDC/PMSM 的主控芯片。芯片采用了高性能的 32 位 ARM Cortex-M0 内核,QFN32封装。内置最大 64 Kbytes flash 和 8 Kbytes SRAM 存储器,最高48 MHz工作频率,多达 16 …

Day51 动态规划part10+Day52 动态规划part11

LC121买卖股票的最佳时机(未掌握) 暴力:双层循环寻找最优间距,每一次都确定一个起点,遍历剩余节点当作终点 贪心:取最左最小值,不断遍历那么得到的差值最最大值就是最大利润。 动态规划 dp数组…

【C++】C++ 基于QT实现散列表学生管理系统(源码+数据+课程论文)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…

0基础学习Elasticsearch-使用Java操作ES

文章目录 1 背景2 前言3 Java如何操作ES3.1 引入依赖3.2 依赖介绍3.3 隐藏依赖3.4 初始化客户端(获取ES连接)3.5 发送请求给ES 1 背景 上篇学习了0基础学习Elasticsearch-Quick start,随后本篇研究如何使用Java操作ES 2 前言 建议通篇阅读再回…

SpringBoot: 可执行jar的特殊逻辑

这一篇我们来看看Java代码怎么操作zip文件(jar文件),然后SpringBoot的特殊处理,文章分为2部分 Zip API解释,看看我们工具箱里有哪些工具能用SpringBoot的特殊处理,看看SpringBoot Jar和普通Jar的不同 1. Zip API解释 1. ZipFil…

NRF24L01(2.4G)模块的使用——SPI时序(软件)篇

一、SPI的简介: SPI 是英语Serial Peripheral interface的缩写,顾名思义就是串行外围设备接口。是Motorola首先在其MC68HCXX系列处理器上定义的。 SPI,是一种高速的,全双工,同步的通信总线,并且在芯片的管脚…

R语言 | 使用最简单方法添加显著性ggpubr包

本期教程原文:使用最简单方法添加显著性ggsignif包 本期教程 获得本期教程代码和数据,在后台回复关键词:20240605 小杜的生信笔记,自2021年11月开始做的知识分享,主要内容是R语言绘图教程、转录组上游分析、转录组下游…

毫米波SDK使用2

5.5 毫米波SDK-TI组件 毫米波SDK功能分解成组件将在接下来的几小节中解释。有关这些模块的详细文档&#xff0c;请参阅位于mmwave_mcuplus_sdk_<ver>/docs/mmwave_sdk_module_document .html的顶层文档。 5.5.1 演示 5.5.1.1 毫米波演示 这个演示位于mmwave_mcuplus_sd…

批量高效调整图片像素:自定义缩小bmp图片,画质优先,一键实现高效优化

图片已经成为我们生活中不可或缺的一部分。无论是社交媒体分享&#xff0c;还是工作文件传输&#xff0c;图片总是扮演着重要的角色。然而&#xff0c;有时候&#xff0c;我们可能会面临一个问题&#xff1a;图片像素过大&#xff0c;不仅占用过多的存储空间&#xff0c;还可能…

【网络教程】Iptables官方教程-学习笔记7-简单理解IPTABLES规则的作用流程

前面学习了IPTABLES的所有功能介绍后&#xff0c;一个Linux设备里的IPTABLES规则集是如何运行的&#xff0c;这里简单做个介绍。 在Linux设备里输入"iptables -nvl",得到该设备的所有防火墙规则&#xff0c;得到的结果中可以看到这个设备防火墙里所有的链以及链里的…

STM32F103C8移植uCOSIII并以不同周期点亮两个LED灯(HAL库方式)【uCOS】【STM32开发板】【STM32CubeMX】

STM32F103C8移植uC/OSIII并以不同周期点亮两个LED灯&#xff08;HAL库方式&#xff09;【uC/OS】【STM32开发板】【STM32CubeMX】 实验说明 将嵌入式操作系统uC/OSIII移植到STM32F103C8上&#xff0c;构建两个任务&#xff0c;两个任务分别以1s和3s周期对LED进行点亮—熄灭的…

力扣hot100:394. 字符串解码(递归)

LeetCode&#xff1a;394. 字符串解码 本题容易想到用递归处理&#xff0c;在写递归时主要是需要明确自己的递归函数的定义。 不过我们也可以利用括号匹配的方式使用栈进行处理。 1、递归 定义递归函数string GetString(string & s,int & i); 表示处理处理整个numbe…

高中数学:数列-基础概念

一、什么是数列&#xff1f; 一般地&#xff0c;我们把按照确定的顺序排列的一列数称为数列&#xff0c;数列中的每一个数叫做这个数列的项&#xff0c;数列的第一项称为首项。 项数有限个的数列叫做有穷数列&#xff0c;项数无限个的数列叫做无穷数列。 二、一般形式 数列和…

2024高考作文引发的人工智能争议

又是一年高考季&#xff0c;多少学子的修行成果也在这这一刻迎来了终极检验&#xff0c;多少学子的梦也在这一刻拉开了揭晓序幕&#xff0c;多少学习的命运也在这一刻迎来了人生中的第一次转变。每年的高考不仅是学子们的人生大事&#xff0c;也是多少父母的热切期望&#xff0…

Java Web学习笔记25——Vue组件库Element

什么是Element&#xff1f; Element: 是饿了么团队研发的&#xff0c;一套为开发者、设计师和产品经理准备的基于Vue2.0的桌面端组件库。 组件&#xff1a;组成网页的部件&#xff0c;例如&#xff1a;超链接、按钮、图片、表格、表单、分页条等等。 官网&#xff1a;https:…

详解C++中的ANSI、Unicode和UTF8三种字符编码及相互转换

目录 1、概述 2、Visual Studio中的字符编码 3、ANSI窄字节编码 4、Unicode宽字节编码 5、UTF8编码 6、如何使用字符编码 7、三种字符编码之间的相互转换&#xff08;附源码&#xff09; 8、Windows系统对使用ANSI窄字节字符编码的程序的兼容 9、字符编码导致程序启动…

1-8 C语言分支循环语句

C语言的语句分为 5 类 1&#xff1a;表达式语句2&#xff1a;函数调用语句3&#xff1a;控制语句4&#xff1a;复合语句5&#xff1a;空语句 控制语句&#xff1a;用于控制程序的执行流程&#xff0c;以实现程序的各种结构方式&#xff0c;它们由特定的语句定义符组成&#x…

【日记】遇到了一个 “不愿睁眼看世界也没受过社会毒打” 的逆天群友(464 字)

正文 今天坐在柜台玩了一天手机…… 手机都玩没电了快。下午在劝一个群友睁眼看世界&#xff0c;实在劝不动。他真的太逆天了&#xff0c;我不清楚这么高学历的人&#xff0c;怎么能说出这么天真的话。逆天又离谱。 晚上的时间几乎全在做家务。平时晚上都是跳舞来着&#xff0c…