3.基于分数的生成模型

news2024/11/17 3:49:29
1.简介

        基于分数的生成模型(SGM)的核心是Stein分数(或分数函数)。给定一个概率密度函数p(x),其分数函数定义为对数概率密度的梯度Vxlogp(x)。生成模型通过学习并建模输入数据的分布,从而采集生成新的样木,该模型广泛运用于图片视频生成、文本生成和药物分子生成。扩散模型是一类概率生成模型,扩散模型通过向数据中逐步加入噪声来破坏数据的结构,然后学习一个相对应的逆向过程来进行去噪,从而学习原始数据的分布,扩散模型可以生成与真实样本分布高度致的高质量新样本。
        原始图片在加噪过程中逐渐失去了所有信息,最终变成了无法辨识的白噪声。反过程是从噪声开始,模型逐渐对数据进行去噪,可辨识的信息越来越多,直到所有噪声全部被去掉,并生成了新的图片。展示了去噪过程中最重要的概念--分数函数,即当前数据对数似然的梯度,直观上它指向拥有更大似然(更少噪声)的数据分布。逆向过程中去噪的每一步都需要计算当前数据的分数函数,然后根据分数函数对数据进行去噪。一般的生成模型可以分为两类:一类可以直接对数据分布进行建模,比如自回归模型和能量模型;还有是基于潜在变量的模型,它们先假设了潜在变量的分布,然后通过学习一个随机或者非随机的变换将潜在变量进行转换,使转换后的分布逼近真实数据的分布。第二类的生成模型包括变分自编码器(VariationalAuto-Encoder,VAE)、生成对抗网络(Generative Adversarial Network,GAN)、归一化流(Normalizing Flow)。与变分自编码器、生成对抗网络、归一化流等基于潜在变量的生成模型类似,扩散模型也是对潜在变量进行变换,使变换后的数据分布逼近真实数据的分布。但是变分自编码器不仅需要学习从潜在变量到数据的“生成器”q\theta(xIZ),还需要学习用后验分布q\varphi(z|x)来近似真实后验分布q\theta(zIx)以训练生成器。

        而如何选择后验分布是变分自缘码器的难点,如果选得比较简单,那么很可能没办法近似真实后验分布,从而导致型效果不好;而如果选得比较复杂,那么其计算又会很复杂。虽然生成对抗网络和化流都不涉及计算后验分布,但它们也有各自的缺点。生成对抗网络的训练需要外的判别器,这导致其训练难、不稳定;归一化流则要求潜在变量到数据的映射是可映射,这大大限制了其表达能力,并且不能直接使用SOTA(state-ofthe-art)的神经网络框架。

        而扩散模型则综合了上述模型的优点并且避免了上述模型的缺点,只需要训练生成器即可。损失函数的形式简单且容易训练,不需要如判别器等其他的辅助网络表达能力强。当前对扩散模型的研究大多基于3个主要框架:去噪扩散概率模型、基于分数的生成模型、随机微分方程。

2.去噪扩散概率模型

        去噪扩散概率模型(DDPM),定义了一个马尔可夫链(MarkovChain)(马尔科夫链是一种随机过程,它描述了一个系统在不同状态之间转换的概率模型。在马尔科夫链中,系统的未来状态只依赖于当前状态,而与过去的状态无关。这种性质称为无记忆性马尔科夫性质并缓慢地向数据添加随机噪声,然后学习逆向扩散过程,从噪声中构建所需的数据样本。一个DDPM由两个马尔可夫链组成,一个正向马尔可夫链(以下简称“正向链”)将数据转化为噪声;一个逆向马尔可夫链(以下简称“逆向链”)将噪声转化为数据。正向链通常是预先设计的,其目标是逐步将数据分布转化为简单的先验分布如标准高斯分布。而逆向链的每一步的转移核(转移核通常指的是转移概率分布,它描述了马尔科夫链中从一种状态转移到另一种状态的概率)是由深度神经网络学习得到的,其目标是逆向链转正向链从而生成数据。新数据的生成需要先从先验分布中抽取随机向量,然后将此随机向量输入逆向链并使用祖先采样法(祖先采样法它通过构建一个马尔科夫链来近似目标分布,然后通过这个链进行采样)生成新数据。

        超参数是指在学习过程开始之前设置的参数,而不是通过训练数据直接估计的参数。这些参数通常用于控制学习过程中的行为,比如算法的复杂度、学习率、迭代次数等。

3.随机微分方程

        DDPM 和SGM 可以进一步推广到无限时间步长或噪声强度的情况,其中扰动过程和去噪过程是随机微分方程(SDE)的解。我们称这个形式为“Score SDE”,因为它利用SDE进行噪声扰动和样本生成,去过程需要估计噪声数据分布的分数函数。

        扩散模型的加噪过程可以视作“特定SDE的解”,而去噪过程可以视作“基于分数匹配学习到的逆向SDE的解”。扩散模型在计算机视觉、自然语言处理、多模态学习等领域中都有出色的表现,这意味着扩散模型可以处理各种类型的数据,如连续型数据、离散型数据,或是存在于特定区域的数据。理论分析表明,只要分数估计足够精确,并且前向扩散的时间足够长(使得最终加噪后的分布趋于先验分布),那么扩散模型就可以以多项式复杂度逼近任何(满足较弱条件的)连续型分布,而对于有紧支集的分布(如存在于流形上的分布)只需要进行“早期停止”(earlystop),扩散模型就仍然具有多项式的收敛复杂度。

4.扩散模型的架构

        扩散模型需要训练一个神经网络来学习加噪数据的分数函数▽xlogqt(x),或者学习加在数据上的噪声\epsilon。由于分数函数是对输入数据的似然的导数,所以其维度和输入数据的维度相同;同样地,由于我们对输入数据的每一个维度都加入了独立的标准高斯噪声,所以神经网络预测的噪声维度与输入数据相同。将扩散模型用在图像生成上,U-Net是一个常用的选择,因为它满足输出和输入的分辨率相同的条件。U-Net是一种典型的编码-解码结构,主要由3部分组成:下采样、上采和跳连。编码器利用卷积层和池化层进行逐级下采样。下采样过程中因为进行池化,所以数据的空间分辨率变小。但数据的通道数因为卷积的作用逐渐变大,从而可以学习图片的高级语义信息。解码器利用反卷积进行逐级上采样,空间分辨率变大,数据维度变小。输入原始图像中的空间信息与图像中的边缘信息会被逐渐恢复。由此低分辨率的特征图最终会被映射为与原数据维度相同的像素级结果图。因为下采样和上采样过程形成了一个U形结构,所以被称为“U-Net”。而为了进一步弥补编码阶段下采样丢失的信息,在网络的编码器与解码器之间,U-Net 算法利用跳连来融合两个过程中对应位置上的特征图,使得解码器在进行上采样时能够融合不同层次的特征信息,进而恢复、完善原始图像中的细节信息。这是一个用于扩散模型的 U-Net 架构图,该结构在第步去噪过程中,接收去噪对象x和时间嵌入(timeembedding)temb,输出去噪结果。值得注意的是,由于去噪过程是依赖于时间t的,所以 U-Net 中的残差模块也进行了相应的修改,在抽取特征时,将temb考虑进来。

      目前U-Net是扩散模型的主流架构,但是研究人员发现使用其他架构也能实现转好的效果,比如使用Transformer 架构。近年 Transformer 被广泛地应用在深度学习的各个领域中。其在架构中抛弃了传统的CNN和RNN,整个网络结构完全由Attentio机制组成,拥有并行能力和可扩展性。更准确地讲,Transformer仅由自注意力机(Self-Attention Mechanism)和前馈神经网络(Feed Forward Neural Network)组成在自注意力机制中,输入序列中的每个元素都会与其他元素进行相互作用,从而生一个新的特征向量。这种机制允许模型对输入序列进行非常灵活的处理,能够捕捉入序列中的长程依赖关系。除了自注意力机制,Transformer中的前馈神经网络模块发挥着重要作用。该模块由几层全连接层组成,使用激活函数ReLU对中间层进行活。前馈神经网络模块可以帮助模型捕捉输入序列中的非线性关系,从而更好地进数据建模。Transfrmer的自注意力机制是Transformer 最核心的内容,自注意力机能够对一个序列中的每个元素计算权重,表示该元素与其他元素之间的相关性。

        然后,通过加权求和的方式将所有元素聚合起来得到一个新的表示。下面主要讲解Transfommer的编码阶段,因为在扩散模型中我们只需要提取图像特征从而学习分数函数,或者逆向转移核的参数。为了使用Transformer架构处理图像数据,需要先通过patch操作将图像的空间表示转化为一系列token,并加入位置嵌入。对于一个token序列,首先通过可学习的线性映射计算出序列中的每个向量(ti)对应的Query向量(Qi),Key 向量(Ki)和 Value 向量(Vi),然后为每一个向量计算它和其他向量的评分:<Qi,Kj>/√dk,其中dk是K的维度。对评分进行sofmax 计算得到注意力系数aij,最终得到输出结果zi=∑jaijvj。之后z就会被输入前馈神经网络做进一步处理。举一个简单的例子,当顾客在某电商平台搜索某件商品(如有深度学习代码的参考书)时,顾客在搜索引擎中输入的内容便是Query,然后搜索引擎根据Query 为顾客匹配Key(如“深度学习”“代码”“参考书”),然后根据 Query和Key的相似度得到匹配的内容(Value)。这里的<Qi,Kj>可以视为向量i和向量j的相关程度,sij就是向量i对向量j的注意力大小。为了防止学习退化,Self-Attention中使用了残差链接。一个向量可以拥有多个(Q,K,V),对每个(Q,K,V)都进行上述计算,最终的输出结果就是所有并行Head中Self-Attention输出结果的拼接,这种方式被称为“Multi-Head Attention”(多头注意力)。一个基于Transformer 的可训练的神经网络可以通过堆叠Transformer 的形式进行搭建。在扩散模型中,可以使用Transformer 架构对每一步的加噪数据进行编码然后使用编码结果来预测下一步转移核的期望和方差,从而代替U-Net 架构。

        Peebles等人在“ Scalable Diffusion Models with Transformers”中使用Transfmmer替换 U-Ne,不仅速度更快(更高的Gops),而且在条件生成任务上,效果更好。该研究提出的DiT框架,DiT基于“Latent DiiusioIranstonmer”进行了了种改进,将每一步中的temb和label等条件信息作为引导信息加入Transfommer结构中,加入的方式分为3种:(1)自适应的层标准化,将Transforme模块中常用的层标准化(Layer Normalization,LN)换成了自适应的层标准化(AdaptivLayer Nommalization,AdaLN),即用引导信息去自适应地生成相应的缩放和漂移参数(2)交叉注意力,将引导信息直接和输入的中间特征进行混合;(3上下文条件(In-Context Conditioning)。将引导信息作为额外的输入拼接在输入端其中,AdaLN的效果更好,速度更快。在ageNet上的生成实验表明了基Transformer的扩散模型架构的优越性。

        DiT 还做了一项验证实验,如增加 DiT 中“tansformer”的深度/宽度,或者增加输入的“token”数量(减少图像“patch”的大小)都能够提高生成图像的效果。


       

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2184758.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

假期惊喜,收到公司款项86167.14元

假期惊喜 近日&#xff0c;有网友爆料称&#xff0c;比亚迪在未提前通知员工的情况下&#xff0c;突然发放了利润奖金。 有人获得了七八万元&#xff0c;也有人拿到了十多万元。 一位比亚迪员工的帖子显示&#xff0c;在9月26日下午&#xff0c;他的银行卡突然收到一笔 86167.1…

数字化那点事:一文读懂数字孪生

一、数字孪生的定义 数字孪生&#xff08;Digital Twin&#xff09;是指通过数字技术构建的物理实体的虚拟模型&#xff0c;能够对该实体进行全方位、动态跟踪和仿真预测。简单来说&#xff0c;数字孪生就是在一个设备或系统的基础上创造一个数字版的“克隆体”&#xff0c;这…

Redis --- 第二讲 --- 特性和安装

一、背景知识 Redis特性&#xff1a; Redis是一个在内存中存储数据的中间件&#xff0c;用于作为数据库&#xff0c;作为缓存&#xff0c;在分布式系统中能够大展拳脚。Redis的一些特性造就了现在的Redis。 在内存中存储数据&#xff0c;通过一系列的数据结构。MySQL主要是通…

Ollama安装部署CodeGeeX4 - ALL - 9B

一、模型本地部署准备 1、 conda create -n ollama python3.82、 curl -fsSL https://ollama.com/install.sh | sh3、验证安装 安装完成后&#xff0c;通过运行以下命令来验证Ollama是否正确安装&#xff1a; ollama --version4、启动ollama ollama serve模型地址&#xff…

【重学 MySQL】四十八、DCL 中的 commit 和 rollback

【重学 MySQL】四十八、DCL 中的 commit 和 rollback commit的定义与作用rollback的定义与作用使用场景相关示例注意事项DDL 和 DML 的说明 在MySQL中&#xff0c;DCL&#xff08;Data Control Language&#xff0c;数据控制语言&#xff09;用于管理数据库用户和控制数据的访问…

Ubuntu 安装RUST

官方给的是这样如下脚本 curl --proto https --tlsv1.2 -sSf https://sh.rustup.rs | sh 太慢了 curl --proto https --tlsv1.2 -sSf https://sh.rustup.rs | sh -x 执行这个脚本后会给出对应的下载链接 如下图 我直接给出来 大多数应该都是这个 https://static.rust-…

初识算法 · 双指针(1)

目录 前言&#xff1a; 双指针算法 题目一&#xff1a; ​编辑 题目二: 前言&#xff1a; 本文作为算法部分的第一篇文章&#xff0c;自然是少不了简单叭叭两句&#xff0c;对于算法部分&#xff0c;多刷是少不了&#xff0c;我们刷题从暴力过度到算法解法&#xff0c;自…

csp-j模拟二补题报告

目录传送门 前言第一题下棋&#xff08;chess&#xff09;我的代码&#xff08;AC了&#xff09;AC代码 第二题汪洋&#xff08;BigWater&#xff09;我的代码&#xff08;0&#xff09;AC代码 第三题删数&#xff08;delnum&#xff09;我的代码&#xff08;0&#xff09;AC代…

秋招突击——9/13——携程提前准备和实际面经——专程飞过去线下,结果一面挂(难受)

文章目录 引言面经收集面经整理一1. ArrayList和LinkedList2. 线程安全的列表和链表有么&#xff1f;如果没有怎么实现&#xff1f;3. threadlocal4. synchronized锁升级过程及原理5. ReentrantLock原理&#xff0c;以及和synchronized的对比6. 线程池工作原理7. redis常用数据…

数据流和数据流处理技术

一数据流 首先明确数据流概念&#xff1a;数据流是连续不断生成的、快速变化的无界数据序列 数据流类型&#xff1a; 数据流大致可以分为四种类型 1.连续型数据流&#xff1a;不断地产生数据&#xff0c;数据稳定速度输入系统。 2.突发型数据流&#xff1a;在某特定时间或…

【吊打面试官系列-MySQL面试题】Mysql如何存储日期?

大家好&#xff0c;我是锋哥。今天分享关于【Mysql如何存储日期&#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; Mysql如何存储日期&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Datatime:以 YYYY-MM-DD HH:MM:SS 格式存储时期时间&a…

基于Spring Boot+Unipp的中考体测训练小程序(协同过滤算法、图形化分析)【原创】

&#x1f388;系统亮点&#xff1a;协同过滤算法、图形化分析&#xff1b; 一.系统开发工具与环境搭建 1.系统设计开发工具 后端使用Java编程语言的Spring boot框架 项目架构&#xff1a;B/S架构 运行环境&#xff1a;win10/win11、jdk17 前端&#xff1a; 技术&#xff1a;框…

C++中stack和queue的模拟实现

目录 1.容器适配器 1.1什么是适配器 1.2STL标准库中stack和queue的底层结构 1.3deque的简单介绍 1.3.1deque的原理介绍 1.3.2deque的优点和缺陷 1.3.3deque和vector进行排序的性能对比 1.4为什么选择deque作为stack和queue的底层默认容器 2.stack的介绍和模拟…

c++-类和对象-点和圆关系

注意&#xff1a; 1.在一个类中可以让另一个类作为成员 2.可以把一个类拆成过个头文件&#xff0c;在.cpp中写成员函数实现&#xff0c;在头文件中留下类的声明和属性 实践 结果

我谈陷波滤波器

《数字图像处理&#xff08;电子信息前沿技术丛书&#xff09;》PP180~182勘误。 陷波滤波器在信号处理中就是带阻滤波器&#xff0c;信号处理中陷波滤波器不是这样定义的&#xff0c;二维比一维有这样的特殊性&#xff0c;我想这是Gonzalez创造的概念&#xff0c;在学术中借用…

初识算法 · 双指针(2)

目录 前言&#xff1a; 盛最多水的容器 题目解析&#xff1a; 算法原理&#xff1a; 算法编写&#xff1a; 有效三角形的个数 题目解析&#xff1a; 算法原理&#xff1a; 算法编写&#xff1a; 前言&#xff1a; 本文介绍两个题目&#xff0c;盛最多水的容器和有效三…

Excel下拉菜单制作及选项修改

Excel下拉菜单 1、下拉菜单制作2、下拉菜单修改 下拉框&#xff08;选项菜单&#xff09;是十分常见的功能。Excel支持下拉框制作&#xff0c;通过预设选项进行菜单选择&#xff0c;可以避免手动输入错误和重复工作&#xff0c;提升数据输入的准确性和效率 1、下拉菜单制作 步…

留存率的定义与SQL实现

1.什么是留存率 留存率是指在特定时间段内&#xff0c;仍然继续使用某项产品或服务的用户占用户总数的百分比。 通常&#xff0c;留存率会以日&#xff0c;周&#xff0c;或月为单位进行统计和分析。 2.SQL留存率常见问题 1.计算新用户登录的日期的次日留存率以及3日留存率 …

【鸿蒙学习】深入了解UIAbility组件

文章目录 组件概述生命周期启动模式基本用法 在鸿蒙操作系统&#xff08;HarmonyOS&#xff09;的开发过程中&#xff0c;UIAbility组件是构建应用界面的关键。本文将带您了解UIAbility组件的概述、生命周期、启动模式以及基本用法&#xff0c;并通过代码示例帮助您更好地掌握这…

微信互助学习平台|互助学习平台系统|基于java的微信互助学习平台设计与实现(源码+数据库+文档)

微信互助学习平台 目录 基于java的微信互助学习平台设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕设布道师…