【扩散模型 李宏毅B站教学以及基础代码运用】

news2024/12/25 12:48:44

李宏毅教学视频:
Link1

B站DDPM公式推导以及代码实现:
Link2

这个视频里面有论文里面的公式推导,并且1小时10分开始讲解实例代码。

文章目录

    • 扩散模型概念:
    • Diffusion Model工作原理:
    • 影像生成模型本质上的共同目标
    • B站简单示例代码讲解

扩散模型概念:

就像石头里面已经有了雕塑,只需要看我们怎么把其他多余的部分去掉。
在这里插入图片描述
注意观察,我们每一个Denoise阶段都不一样,因为每一个阶段传入的图片以及需要处理的noise都不一样,并且直接产生图片比直接产生噪音更难,所以我们通过预测noise来解决问题。
在这里插入图片描述

比如下图所示:step2是我们加的噪声,那么传入input和2的时候就希望预测出gt了,然后进行相减得到step1的图片。
在这里插入图片描述

Diffusion Model工作原理:

VAE和Diffusion的区别
在这里插入图片描述
先看整个训练过程:
在这里插入图片描述

实际结果和我们想的是不一样的。训练时通过X0和噪声得到一个图,逆向的时候输入t和生成的图来得到噪音。想象的是一点一点加入噪音,实际上是直接加进去的。在这里插入图片描述
推断时刻:theat是带有参数的网络。
在这里插入图片描述

影像生成模型本质上的共同目标

通过采样一个高深distribution生成一个图片。希望生成的图片和真实的图片的distribution很接近。
在这里插入图片描述
那么怎么衡量这两个分布的接近程度呢?多数采用的都是Maximum liklihood Estimation.
我们希望我们采样的数据能够通过theta网络计算出来的概率越大越好。 在这里插入图片描述
通过数学变换,将概率最大变为Pdata和Ptheat这两个distribution的KL散度最小。
在这里插入图片描述
VAE的下界
Ptheat(x)表示:通过theta产生x的概率。
在这里插入图片描述

在这里插入图片描述
DDPM计算Ptheta(x)的方法 下图表示产生X0的概率。
在这里插入图片描述
两者对比
在这里插入图片描述
接下来需要计算q(x1|x0)此类公式。
计算方法:X1到X2的计算方法在论文中有提及。
在这里插入图片描述
两个高斯分布都是服从N(0,1),相加的话还是一个高斯分布,并且还是服从N(0,1),只是前面系数会发生变化。系数的话是根号下面数字相加。所以相加之后均值还是为0,方差a方加b方即可,这个在另外一个视频里面有讲解。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
经过一番推导之后得到:
在这里插入图片描述
之后计算最下面三项:
在这里插入图片描述
通过以下推导:
在这里插入图片描述
之后通过X0,Xt可以得到Xt-1的分布。
在这里插入图片描述
可以看到前面一项的mean 和 variance是固定的,第二项的variance也是固定的,因此我们需要把第二项的mean变得和第一项的接近。
在这里插入图片描述
那么怎么minimiaze这个mean呢?希望用Xt去预测出来那个mean。
在这里插入图片描述
经过推导:
在这里插入图片描述
最终得到下图:
在这里插入图片描述
里面beta可以学习,但是效果不好,所以使用线性固定。最后加上一个噪声猜测是为了增强鲁棒性,并且本身就是从噪声开始,不加噪声的话可能不会生成图片。

B站简单示例代码讲解

# 加载数据集
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch

s_curve,_ = make_s_curve(10**4,noise=0.1)
print(np.shape(s_curve))
s_curve = s_curve[:,[0,2]]/10.0

print("shape of s:",np.shape(s_curve))

data = s_curve.T

fig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');

ax.axis('off')
 
dataset = torch.Tensor(s_curve).float()

在这里插入图片描述

# 2确定超参数的值
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
# print(alphas_prod)
alphas_prod_p = torch.cat([torch.tensor([1]).float(),alphas_prod[:-1]],0)
# print(alphas_prod_p)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)

、确定扩散过程任意时刻的采样值

#3 计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):
    """可以基于x[0]得到任意时刻t的x[t]"""
    noise = torch.randn_like(x_0)
    alphas_t = alphas_bar_sqrt[t]
    alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
    return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
j
# 4 演示原始数据分布加噪100步后的结果

num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):
    j = i//10
    k = i%10
    q_i = q_x(dataset,torch.tensor([i*num_steps//num_shows]))#生成t时刻的采样数据
    axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')
    axs[j,k].set_axis_off()
    axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')

在这里插入图片描述

# 5 编写拟合逆扩散过程高斯分布的模型

import torch
import torch.nn as nn
​
class MLPDiffusion(nn.Module):
    def __init__(self,n_steps,num_units=128):
        super(MLPDiffusion,self).__init__()
        
        self.linears = nn.ModuleList(
            [
                nn.Linear(2,num_units),
                nn.ReLU(),
                nn.Linear(num_units,num_units),
                nn.ReLU(),
                nn.Linear(num_units,num_units),
                nn.ReLU(),
                nn.Linear(num_units,2),
            ]
        )
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
            ]
        )
    def forward(self,x,t):
#         x = x_0
        for idx,embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2*idx](x)
            x += t_embedding
            x = self.linears[2*idx+1](x)
            
        x = self.linears[-1](x)
        
        return x

loss_fn 就是Lsimple得表达式。通过传入参数,生成一个随机噪声,并且送入model里面,那么上面也讲了,model的作用是根据X0,和t预测出我们应该减去的噪声,所以损失函数就是用生成的噪声减去预测的噪声。
在这里插入图片描述

# 6 编写训练的误差函数
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
    """对任意时刻t进行采样计算loss"""
    batch_size = x_0.shape[0]
    
    #对一个batchsize样本生成随机的时刻t
    t = torch.randint(0,n_steps,size=(batch_size//2,))
    t = torch.cat([t,n_steps-1-t],dim=0)
    t = t.unsqueeze(-1)
    
    #x0的系数
    a = alphas_bar_sqrt[t]
    
    #eps的系数
    aml = one_minus_alphas_bar_sqrt[t]
    
    #生成随机噪音eps
    e = torch.randn_like(x_0)
    
    #构造模型的输入
    x = x_0*a+e*aml
    
    #送入模型,得到t时刻的随机噪声预测值
    output = model(x,t.squeeze(-1))
    
    #与真实噪声一起计算误差,求平均值
    return torch.pow((e - output),2).mean()
# 7、编写逆扩散采样函数(inference)

def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
    """从x[T]恢复x[T-1]、x[T-2]|...x[0]"""
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):
        cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq
​
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
    """从x[T]采样t时刻的重构值"""
    t = torch.tensor([t])
    
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    
    eps_theta = model(x,t)
    
    mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
    
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    
    sample = mean + sigma_t * z
    
    return (sample)
# 8、开始训练模型,打印loss及中间重构效果

seed = 1234class EMA():
    """构建一个参数平滑器"""
    def __init__(self,mu=0.01):
        self.mu = mu
        self.shadow = {}
        
    def register(self,name,val):
        self.shadow[name] = val.clone()
        
    def __call__(self,name,x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0-self.mu)*self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average
    
print('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
​
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)for t in range(num_epoch):
    for idx,batch_x in enumerate(dataloader):
        loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.)
        optimizer.step()
        
    if(t%100==0):
        print(loss)
        x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)
        
        fig,axs = plt.subplots(1,10,figsize=(28,3))
        for i in range(1,11):
            cur_x = x_seq[i*10].detach()
            axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
            axs[i-1].set_axis_off();
            axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

最后的演示

动画演示扩散过程和逆扩散过程

import io
from PIL import Image
​
imgs = []
for i in range(100):
    plt.clf()
    q_i = q_x(dataset,torch.tensor([i]))
    plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);
    plt.axis('off');
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    imgs.append(img)
mg = Image.open(img_buf)
    reverse.append(img)
reverse = []
for i in range(100):
    plt.clf()
    cur_x = x_seq[i].detach()
    plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);
    plt.axis('off')
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    reverse.append(img)
​
imgs = imgs +reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/982441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

qt使用QCustomplot绘制cpu和内存使用率图

一、QCustomPlot介绍 QCustomPlot是一个开源的Qt C图表库,用于可视化数据。该库提供了多种类型的可定制的图表,包括散点图、线图、柱状图和等高线图等。它还支持自定义绘制,可以创建任意形状和大小的元素,并使其与其他元素交互。Q…

iPad电容笔贵吗?开学季比较好用的ipad手写笔

“ipad好买,但是ipad的配件不好买”,我相信很多人都会有这个问题,如果你想要购买像Apple Pencil这样的官方配件,却很难舍得下手,因为一款Apple Pencil的价格就已经接近1000元了。就像许多人不愿购买昂贵的苹果官方产品…

05 C/C++ 指针复杂类型说明 9月5日

目录 C语⾔ (1)数组 (2)指针 指针变量 空指针 (3)指针复杂类型 int a 0; int *p &a; int p[3];​​​​​​​ int *p[3]; int (*p)[3]; int **p; int p(int); int(*p)(int); C语⾔ (1)数组 当数据具有相同的数据类型;使用过程中需要保留原始…

在学习DNS的过程中给我的启发

在国内,关于DNS相关的话题一直络绎不绝,比如DNS根服务器为什么中国没有,还有Anycast BGP实现负载,为什么DNS只有13个,还有DNS over HTTPS 和 DNS over TLS的优劣等等问题,接下来我会找出几个一一说一下其中…

【Linux】- 一文秒懂shell编程

shell编程 1.1 Shell 是什么1.2 Shell 脚本的执行方式1.3 编写第一个 Shell 脚本2.1 Shell 的变量2.2 shell 变量的定义2.3 设置环境变量3.1 位置参数变量3.2 预定义变量4.1 运算符4.2 条件判断5.1 流程控制5.2 case 语句5.3 for 循环5.4 while 循环5.5 read基本语法6.1函数6.2…

API接口已经成为企业应用程序开发和管理的重要组成部分

API接口的价值 随着数字化时代的到来,API接口已经成为企业应用程序开发和管理的重要组成部分。API不仅是一种连接不同系统、提高数据流动性和促进协作的工具,而且还是一种重要的商业战略,可以为组织带来许多实际的价值。本文将探讨API接口的…

Android的本地数据

何为本地,即写完之后除非手动修改,否像嘎了一样在那固定死了 在实际安卓开发中,这种写死的概念必不可少,如控件的id,某一常量,Kotlin中的Val 当然,有些需求可能也会要求我们去写死数据&#x…

一文搞懂XaaS

云服务是指通过互联网按需提供给企业和客户的各种服务,大致可以分为IaaS、PaaS、SaaS三类,每一类又衍生出不同细分的云服务模式。本文介绍了当前已经提出的19种云服务模式,原文: The Comprehensive Concept of IaaS, PaaS, SaaS, AaaS, BaaS,…

基于STM32,TB6612,TCRT5000的简易红外循迹小车

提醒:本文章只叙述此小车相关大概内容(如模块的设置,C语言基础实现等),单片机详细教学不涉及。 摘要 循迹小车是学习单片机的“地基”,它能够让初学者认识单片机内部硬件结构及其功能,熟悉单片机…

安装RabbitMQ的各种问题(包括已注册成windows服务后,再次重新安装,删除服务重新注册遇到的问题)

一、安装Erlang(傻瓜式安装) 安装完成之后,配置环境变量: 1.新建系统变量名为:ERLANG_HOME 变量值为erlang安装地址 2. 双击系统变量path,点击“新建”,将%ERLANG_HOME%\bin加入到path中。 …

学习笔记——Java入门第一季

1.1 Java的介绍与前景 Java语言最早期的制作者:James Gosling(詹姆斯高斯林) 1995年5月23日,Sun Microsystems公司宣布Java语言诞生。 1.2 Java的特性与版本 跨平台 开源(开放源代码) Java代码&#xff…

酷开系统游戏空间,开启大屏娱乐新玩法

在这个充满科技感和无限创意的时代,游戏已经成为我们生活的一部分。而随时着科技的不断发展,以及游戏爱好者的游戏需求在不断提高,促使游戏体验也向更加丰富多彩的方向发展。显然,酷开科技早已经认识到游戏发展的新蓝图&#xff0…

金鸣识别名片识别模块 ,名片扫描仪的神仙“伴侣”

名片扫描仪是现代办公中常见的设备,其作用是将纸质名片转换为电子格式并进行识别。在实现这一功能方面,使用自带OCR功能和金鸣识别两种方式均具有各自的优势。 一方面,自带OCR功能的名片扫描仪具有便捷性和即时性的优势。通过设备内置的OCR技…

国产信创服务器如何进行安全可靠的文件传输?

信创,即信息技术应用创新,2018年以来,受“华为、中兴事件”影响,国家将信创产业纳入国家战略,并提出了“28n”发展体系。从产业链角度,信创产业生态体系较为庞大,主要包括基础硬件、基础软件、应…

SpringMVC综合案例

目录 一、SpringMVC常用注解 二、传递参数 2.1 基础类型String 2.2 复杂类型 2.3 RequestParam 2.4 PathVariable 2.5 RequestBody 2.6 RequestHeader 2.7 请求方法 三、返回值 3.1 void 3.2 String 3.3 StringModel 3.4 ModelAndView 四、页面跳转 4.1 转发 4…

iPhone用户的价值是安卓用户的4倍?难以置信,研究发现竟是7.4倍

据Asymco机构分析师Horace Dediu发布的最新报告,苹果用户在应用上的平均支出是安卓用户的7.4倍,远高于此前提出的4倍观点。这意味着,尽管安卓用户数量是iPhone用户的两倍,但iPhone应用商店开发者的收入是谷歌PlayStore的两倍。 在…

淘宝销量展示方式变更背后的逻辑

淘宝销量展示方式发生了调整,平台于8月16日将商品详情销量展示表达由【月销**件】全部换成展示【已售**件】,将30天销量改成了近365天销量。 【已售**件】统计口径:统计近365天支付的商品件数,数据更新请关注24-48小时。其中涉及销…

数据库模式迁移工具的演进:CLI,GUI,集成式协作数据库平台

数据库模式迁移可能是应用程序开发中最具风险的领域,它困难、有风险且令人痛苦。数据库模式迁移工具的存在就是为了减轻这些痛苦,并且已经取得了长足的进步:从基本的CLI工具到GUI工具,从简单的SQL GUI客户端到集成式协作数据库平台…

PROSOFT PTQ-PDPMV1网络接口模块

通信接口:PROSOFT PTQ-PDPMV1 网络接口模块通常配备了多种通信接口,以便与不同类型的设备和网络进行通信。常见的接口包括以太网、串行端口(如RS-232和RS-485)、Profibus、DeviceNet 等。 协议支持:该模块通常支持多种…

《向量数据库指南》——AI原生向量数据库Milvus Cloud 2.3 新功能ScaNN 索引和Iterator

ScaNN 索引 Milvus 目前支持了 Faiss 中的 FastScan 算法,在各项 benchmark 中有着不俗的表现,对比 HNSW 有 20% 左右提升,约为 IVFFlat 的 7 倍,同时构建索引速度更快。ScaNN 在算法上跟 IVFPQ 比较类似,聚类分桶,然后桶里的向量使用 PQ 做量化,区别是 ScaNN 对于量化比…