噪声条件分数网络——NCSN原理解析

news2024/11/24 17:47:00

1、前言

本篇文章,我们讲NCSN,也就是噪声条件分数网络。这是宋飏老师在2019年提出的模型,思路与传统的生成模型大不相同,令人拍案叫绝!!!

参考论文:

①Generative Modeling by Estimating Gradients of the Data Distribution (arxiv.org)

②Tutorial on Diffusion Models for Imaging and Vision (arxiv.org)

参考代码:GitHub - Lingyu-Kong/ncsn: Handwritten Score-Based Generative Model

视频:[噪声条件得分(分数)网络——NCSN原理解析-哔哩哔哩]

Ps:这篇文章我简单讲一下思路就算了,过程并不严谨,因为这个内容并不是很重要

2、引入

回忆一下梯度下降,假设我们有一个二次函数
f ( x ) = ( 0.5 x − 3 ) 2 f(x)=(0.5x-3)^2 f(x)=(0.5x3)2
导数为 f ′ ( x ) = ( 0.5 x − 3 ) f'(x)=(0.5x-3) f(x)=(0.5x3),使用梯度下降
x t + 1 = x t − 0.1 f ′ ( x t ) (1) x_{t+1}=x_t-0.1f'(x_t)\tag{1} xt+1=xt0.1f(xt)(1)
其中 x t 、 x t + 1 x_t、x_{t+1} xtxt+1表示优化前和优化后的x对应的值, 0.1 0.1 0.1是步长。初始化蓝色点 x t = − 6 x_t=-6 xt=6,迭代100轮梯度下降,就可以得到下面的图(可以看到蓝色点逐渐向着函数最低点靠近)

在这里插入图片描述

为什么会这样?因为梯度总是指向函数值上升的方向。而Eq.(1),是减去梯度,相当于对梯度取反方向。于是x的值就沿着函数值下降的方向走了。如果换成梯度上升,则Eq.(1)改为
x t + 1 = x t + 0.1 f ′ ( x t ) (2) x_{t+1}=x_t+0.1f'(x_t)\tag{2} xt+1=xt+0.1f(xt)(2)
对应图像为

在这里插入图片描述

再回忆一下一维高斯分布的概率密度的图像

在这里插入图片描述

当y值(密度值)取到最高点,其对应样本点在均值处

此时我们注意到,高斯分布的图像,与Eq.(2)何其相像,那我们把Eq.(2)里面的 f ( x ) f(x) f(x)当作是高斯分布的密度函数,而 x x x则对应高斯分布的样本点
x t + 1 = x t + 0.1 f ′ ( x t ) x_{t+1}=x_t+0.1f'(x_t) xt+1=xt+0.1f(xt)
那么这个梯度上升的意思就变成了,对于一个样本 x t x_t xt,不断往概率密度函数 f ′ ( x t ) f'(x_t) f(xt)密度值高的地方靠近。如果优化到最优点,那么图像就会变成这样

在这里插入图片描述

也就是说,样本点 x t x_t xt,最终会走到概率值最高对应的点,那么此时的样本点 x t x_t xt,就可以认为是从高斯分布中采样出来的一个概率最高的样本。我们写成概率分布的一般形式
x t + 1 = x t + α ∇ x P ( x t ) x_{t+1}=x_t+\alpha \nabla_xP(x_t) xt+1=xt+αxP(xt)
α \alpha α表示步长,比如之前的0.1, ∇ x \nabla_x x是对x求梯度。

我们在 P ( x t ) P(x_t) P(xt)前面取一个log对数,不改变单调性,仍然会使 x t x_t xt收敛到最优值
x t + 1 = x t + α ∇ x log ⁡ P ( x t ) x_{t+1}=x_t+\alpha \nabla_x\log P(x_t) xt+1=xt+αxlogP(xt)
更一般的,从一个概率分布中采样,我们往往会存在一些偏差项,于是我们加上一个随机噪声
x t + 1 = x t + α ∇ x log ⁡ P ( x t ) + 2 α z t (3) x_{t+1}=x_t+\alpha \nabla_x\log P(x_t)+\sqrt{2\alpha}z_t\tag{3} xt+1=xt+αxlogP(xt)+2α zt(3)
2 α \sqrt{2\alpha} 2α 是缩放系数,而 z t z_t zt是标准高斯分布,加上一个噪声后, x t x_t xt的收敛值会在概率最高点处不断徘徊

图像表示为

在这里插入图片描述

现在,我们更进一步,我们把 x t x_t xt当作是一个随机初始化的图像,然后 P ( x ) P(x) P(x)是我们训练图像的所对应的分布,通过不断执行Eq.(3),便可以让随机初始化的图像,不断往 P ( x ) P(x) P(x)概率最高点周围靠近,那么就间接说明,经过了大T步Eq.(3),得到的 x t x_t xt,可以认为是从 P ( x ) P(x) P(x)中采样出来的。

仔细看一下,这不就是一个生成图像的过程吗?

这种方式,又被称为郎之万动力采样。emmmmm,不懂,物理学的东西。。。

我们看一个可视化的过程(图像来自参考①)

在这里插入图片描述

3、目标函数

既然Eq.(3)能够通过迭代的方式,生成图像,那自然只需要求解Eq.(3)就可以了。不幸的是,我们没办法求解

我们的训练图像,它们所服从的概率分布往往及其复杂,也就是说 P ( x ) P(x) P(x)是难以求解的​,好在我们的目标并不是求出 P ( x ) P(x) P(x),而是对应的梯度(也称为分数)
L = 1 2 E P d a t a ( x ) [ ∣ ∣ s θ ( x ) − ∇ x log ⁡ P d a t a ( x ) ∣ ∣ 2 2 ] (4) L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x)}\left[||s_\theta(x)-\nabla_x\log P_{data}(x)||_2^2\right]\tag{4} L=21EPdata(x)[∣∣sθ(x)xlogPdata(x)22](4)
P d a t a P_{data} Pdata表示训练数据所服从的分布

也就是通过最小化上式,便可得到 s θ ( x ) ≈ ∇ x log ⁡ P d a t a ( x ) s_\theta(x)\approx \nabla_x\log P_{data}(x) sθ(x)xlogPdata(x)

4、问题

理论上,我们直接求解Eq.(4)就可以了,但是,我们样本所服从的分布往往是服从,概率分布中往往存在一些低密度区域,那么对应的样本就很少。

而样本少,意味着对应为止的梯度分数,得不到很好的训练,那么神经网络在那些样本点就很容易估不准。作者博客给出了一张很形象的图像(图像来自参考①)

在这里插入图片描述

可以看到,数据的密度分别都在左下角和右上角,那么这些区域就能够用神经网络得到很好的拟合,对应Accurate区域。相反,低密度区域,没有得到很好的拟合,对应Inaccurate区域。

当我们使用郎之万动力采样的时候,随机初始化一个 x 0 x_0 x0,它落在低密度区域的概率非常之高。而低密度的区域没有经过很好的训练,所以郎之万动力采样在短时间内很难得到较好的结果。

那么,该如何解决这个问题呢?一个很好的方法就是——加噪声

我们通过对图像加入随机扰动噪声,会填充原本的低密度区域,从而让整个区域看起来较为的均匀(图像来自参考①)

在这里插入图片描述

也就是这样,让原本的密度点扩张开来。

加噪的过程我们可以表示为 x ~ = x + σ z \tilde x=x+\sigma z x~=x+σz x x x表示原始图像, x ~ \tilde x x~表示加噪后的图像。

我们用 q ( x ~ ∣ x ) ∼ N ( x , σ 2 I ) q(\tilde x|x)\sim N(x,\sigma^2I) q(x~x)N(x,σ2I)去表示这个加噪过程

于是Eq.(3)就可以变成
L = 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x + σ z ) − ∇ x ~ log ⁡ q ( x ~ ∣ x ) ∣ ∣ 2 2 ] (5) L_{}=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)-\nabla_{\tilde x}\log q(\tilde x|x)||_2^2\right]\tag{5} L=21EPdata(x),x~N(x,σ2I)[∣∣sθ(x+σz)x~logq(x~x)22](5)
emmmm,我感觉这样讲貌似挺合理的,但是它是需要证明的,也就是证明Eq.(4)、Eq.(5)的优化等价性。我就不证明了,证明过程在参考论文②,并不难,读者自己看一下就知道了

除此之外,真正导致需要加噪的,其实有其他原因,我只讲了其中一个。其他原因请看参考②,里面讲的非常之详细。我也懒得写了

现在,我们预测的是加噪后的梯度分数,通过加噪的过程,也避免了直接求解 P ( x ) P(x) P(x)的问题。那我们来看一下这个等式可以变成什么吧

如果我们加的噪声足够小,那么 P d a t a ( x ) ≈ q ( x ~ ∣ x ) P_{data}(x)\approx q(\tilde x|x) Pdata(x)q(x~x)

因为 q ( x ~ ∣ x ) q(\tilde x|x) q(x~x)是服从高斯分布的,是完全可以求出来的,所以梯度为
∇ x ~ log ⁡ q ( x ~ ∣ x ) = ∇ x ~ log ⁡ 1 2 π σ 2 d exp ⁡ { − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 } = ∇ x ~ ( log ⁡ 1 2 π σ 2 d − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 ) = − 2 ( x ~ − x ) 2 σ 2 = − x ~ − x σ 2 = − z σ \begin{aligned}\nabla_{\tilde x}\log q(\tilde x|x)=&\nabla_{\tilde x}\log \frac{1}{\sqrt{2\pi\sigma^2}^d}\exp \left\{-\frac{||\tilde x-x||^2}{2\sigma^2}\right\}\\=&\nabla_{\tilde x}\left(\log \frac{1}{\sqrt{2\pi\sigma^2}^d}-\frac{||\tilde x-x||^2}{2\sigma^2}\right)\\=&-\frac{2(\tilde x-x)}{2\sigma^2}\\=&-\frac{\tilde x -x}{\sigma^2}\\=&-\frac{z}{\sigma}\end{aligned} x~logq(x~x)=====x~log2πσ2 d1exp{2σ2∣∣x~x2}x~(log2πσ2 d12σ2∣∣x~x2)2σ22(x~x)σ2x~xσz
所以损失函数就可以变成
L = 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ 2 I ) [ ∣ ∣ s θ ( x + σ z ) + x ~ − x σ 2 ∣ ∣ 2 2 ] L=\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma^2I)}\left[||s_\theta(x+\sigma z)+\frac{\tilde x -x}{\sigma^2}||_2^2\right] L=21EPdata(x),x~N(x,σ2I)[∣∣sθ(x+σz)+σ2x~x22]

按理说,我们只需要最优化这个目标函数即可。

可问题又来了

我们该如何加入噪声呢?加多少?加的小了,低密度区域没有得到很好的填充。加多了,直接改变原本的数据分布了,这显然也不行。

我们干脆一不做二不休,我们加多个量级噪声,不同量级都进行训练。

当训练完成之后,就得到了不同噪声强度的噪声条件得分网络。

假设不同强度等级的噪声有S个, { σ i } i = 1 S \{\sigma_i\}_{i=1}^S {σi}i=1S,我们看一张图(里面显示了三个噪声强度的情况,图像来自参考①)

在这里插入图片描述

那么进行采样的时候,就可以从高强度的噪声,进行郎之万动力采样,然后慢慢降低噪声的强度。总而言之,就是每个噪声强度,都进行一轮郎之万动力采样,比如下图(图像来自参考①)(Gif图像太大,上传不了…看视频里面吧)

假设有S个噪声强度,那么就可以变成
L = 1 S ∑ i = 1 S λ i 1 2 E P d a t a ( x ) , x ~ ∼ N ( x , σ i 2 I ) [ ∣ ∣ s θ ( x + σ i z , σ i ) + x ~ i − x σ i 2 ∣ ∣ 2 2 ] L=\frac{1}{S}\sum\limits_{i=1}^S\lambda_i\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma_i^2I)}\left[||s_\theta(x+\sigma_i z,\sigma_i)+\frac{\tilde x_i -x}{\sigma_i^2}||_2^2\right] L=S1i=1Sλi21EPdata(x),x~N(x,σi2I)[∣∣sθ(x+σiz,σi)+σi2x~ix22]
x ~ i \tilde x_i x~i表示在噪声强度为 σ i \sigma_i σi的加噪图像。 λ i \lambda_i λi代表的是一个加权系数.一般情况下,我们取 λ i = σ i 2 \lambda_i=\sigma^2_i λi=σi2

对于噪声强度数量,一般是数百到数千;噪声强度选择一般采用几何级数。

采样的时候正如前面所说,先在高强度噪声量级进行郎之万动力采样,而后慢慢降低,所以采样方法为

5、结束

好了,本篇文章到此为止,如有问题,还望指出,阿里嘎多!!!

在这里插入图片描述

6、参考

①Generative Modeling by Estimating Gradients of the Data Distribution | Yang Song (yang-song.net)

②基于分数的生成模型(Score-based generative models) — 张振虎的博客 张振虎 文档 (zhangzhenhu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1685886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IDEA设置运行内存

1.开启内存指示条​​​​​​​ 查看idea右下角​​​​​​​ 2.环境变量查看ideaVM地址,没有的话那就是默认的配置文件: idea 安装 bin 目录下 idea64.exe.vmoptions 3.去对应路径修改内存参数大小 4.重启IDEA,end

leetcode-主持人调度(二)-110

题目要求 思路 1.先将开始时间和结束时间拆分放到两个数组中进行排序 2.如果开始的时间小于结束时间,说明目前没有空闲的人,需要增加人,如果大于等于,说明有人刚结束了主持,可以进行新的主持了,变更到下一…

JavaEE技术之分布式事务(理论、解决方案、Seata解决分布式事务问题、Seata之原理简介、断点查看数据库表数据变化)

文章目录 JavaEE技术之分布式事务准备:1. 本地事务回顾1.1 什么是事务1.2 事务的作用1.3 事务ACID四大特性1.4 事务的并发问题1.5 MySQL事务隔离级别1.6 事务相关命令(了解)1.7 事务传播行为(propagation behavior)1.8 伪代码练习1.9 回滚策略1.10 超时事…

重构2:重构的原则之笔记

最近在看重构2:改善既有代码的设计这本书,对于代码重构指导非常有帮助,然后也是做个笔记记录下,以下是我阅读本书的前两章的时候整理的思维导图:

The Sandbox 和 Bitkub 联手增强东南亚元宇宙中心

作为去中心化游戏虚拟世界和区块链平台的先驱,The Sandbox 正与泰国领先的区块链网络 Bitkub Blockchain Technology Co., Ltd. 展开创新合作。双方合作的目的是将Bitkub元宇宙的影响力扩展到The Sandbox,建立一个元宇宙中心,向用户承诺从 Bi…

react使用antd警告:Warning: findDOMNode is deprecated in StrictMode.

警告信息: Warning: findDOMNode is deprecated in StrictMode. findDOMNode was passed an instance of DOMWrap which is inside StrictMode. Instead, add a ref directly to the element you want to reference. Learn more about using refs safely here: htt…

SerDes系列之CTLE均衡技术

CTLE(连续时间线性均衡)是一种施加在接收器上的线性模拟高通滤波器,通过衰减低频信号分量,以补偿奈奎斯特频率附近的衰减比例,从而实现信道补偿。当低频信号分量向下衰减并推入底噪范围时,CTLE就会失去调节…

解决Wordpress中Cravatar头像无法访问问题

一、什么是Cravatar Gravatar是WordPress母公司Automattic推出的一个公共头像服务,也是WordPress默认的头像服务。但因为长城防火墙的存在,Gravatar在中国时不时就会被墙一下,比如本次从2021年2月一直到8月都是不可访问状态。 在以往的时候&…

JS 实现鼠标框选(页面选择)时返回对应的 HTML 或文案内容

JS 实现鼠标框选(页面选择)时返回对应的 HTML 或文案内容 一、需求背景 1、项目需求 当用户进行鼠标框选选择了页面上的内容时,把选择的内容进行上报。 2、需求解析 虽然这需求就一句话的事,但是很显然,没那么简单…

MySQL -- 相关知识点

1.数据库相关介绍 数据库的选择通常取决于具体的应用需求,如性能、扩展性、数据一致性和易用性等因素。 1. 关系型数据库(RDBMS) MySQL: 广泛使用的开源数据库,支持大多数操作系统。强调易用性、灵活性和广泛的社区支…

代码随想录算法训练营第36期DAY37

DAY37 先二刷昨天的3道题目,每种方法都写:是否已完成:是。 报告:134加油站的朴素法没写对。原因是:在if中缺少了store>0的判断,只给出了indexi的判断。前进法没写出来。因为忘记了总油量的判断。Sum。…

基于Vue的自定义服务说明弹窗组件的设计与实现

基于Vue的自定义服务说明弹窗组件的设计与实现 摘要 随着技术的不断发展,前端开发面临着越来越高的复杂性和不断变化的需求。传统开发方式往往将整个系统构建为整块应用,这导致对系统的任何微小改动都可能触发整体的逻辑变更,从而增加了开发…

第二证券:见证历史!印度这一交易所市值突破5万亿美元

又一次见证前史! 孟买证券交易所本周实现了一个重要的里程碑,其市值突破5万亿美元,总市值在不到6个月的时间里添加了1万亿美元。 据了解,印度股市两大交易所别离为孟买证券交易所(BSE) 和国家证券交易所&…

discuzX2.5的使用心得 札记一

从开始接受php论坛的开发任务,对php感兴趣的我开始迷恋上discuz这个产品了, 像戴志康这样的创新人才,是我们这代人的骄傲和学习的榜样 应该是了解一下,啥事discuzX2.5,百度看一下 discuz x2.5_百度百科 看完百度词条…

如何通过软件IIC使用MPU6050陀螺仪

目录 1. MPU6050简介 2. MPU6050参数 3. MPU6050硬件电路 4. 代码编写 4.1 MPU6050写寄存器 4.2 MPU6050读寄存器 4.3 初始化 4.4 MPU6050获取ID号 4.5 MPU6050获取数据 1. MPU6050简介 MPU6050是一个6轴姿态传感器,可以测量芯片自身X、Y、Z轴的…

AWTK实现汽车仪表Cluster/DashBoard嵌入式GUI开发(七):快启

前言: 汽车仪表是人们了解汽车状况的窗口,而仪表中的大部分信息都是以指示灯形式显示给驾驶者。仪表指示灯图案都较为抽象,对驾驶不熟悉的人在理解仪表指示灯含义方面存在不同程度的困难,尤其对于驾驶新手,如果对指示灯的含义不求甚解,有可能影响驾驶的安全性。即使是对…

关于新配置的adb,设备管理器找不到此设备问题

上面页面中一开始没有找到此android设备, 可能是因为我重新配置的adb和设备驱动, 只把adb配置了环境变量,驱动没有更新到电脑中, 点击添加驱动, 选择路径,我安装时都放在了SDK下面,可以尝试…

怎么查看公网IP?

在网络通信中,每个设备都会被分配一个IP地址,用于在互联网上进行唯一标识和通信。公网IP是指可以被公开访问的IP地址,可以用来建立远程连接或者进行网络访问等操作。怎么查看公网IP呢?下面将介绍几种常用的方法。 使用命令行查询公…

mainwindow.ui和mainwindow.h和ui_mainwindow.h这几个文件之间的联系是什么

在Qt应用程序开发中,mainwindow.ui, mainwindow.h, 和 ui_mainwindow.h 这三个文件之间有着紧密的联系,共同构成了使用Qt Designer设计的图形用户界面(GUI)应用程序的基础。下面是这三个文件各自的作用及它们之间的关联&#xff1…

linux day7 wget,curl

wget下载命令 curl [-O] 网址 不写-O表示请求网址,会返回网页html代码 写-O表示请求下载网页文件