diffusion model 简单demo

news2024/11/25 2:28:01

参考自:
Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
diffusion 简单demo
扩散模型之DDPM

核心公式和逻辑

在这里插入图片描述
在这里插入图片描述

q_x 计算公式,后面会用到:
在这里插入图片描述
推理:
在这里插入图片描述

代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve, make_swiss_roll
from PIL import Image
import torch
import io

# get data
# s_curve, _ = make_s_curve(10**4 , noise=0.1)
# s_curve = s_curve[:, [0, 2]] / 10.0

swiss_roll, _ = make_swiss_roll(10**4,noise=0.1)
s_curve = swiss_roll[:, [0, 2]]/10.0

print('shape of moons: ', np.shape(s_curve))

data = s_curve.T
fix, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolors='white', alpha=0.5)

ax.axis('off')

# plt.show()
plt.savefig('./s_curve.png')

dataset = torch.Tensor(s_curve).float()

# set params
num_steps = 100

betas = torch.linspace(-6, 6, num_steps)    # # 逐渐递增
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5    # β0,β1,...,βt

print('beta: ', betas)

alphas = 1 - betas
alphas_pro = torch.cumprod(alphas, 0)   # αt^ = αt的累乘

# αt^往右平移一位, 原第t步的值维第t-1步的值, 第0步补1
alphas_pro_p = torch.cat([torch.tensor([1]).float(), alphas_pro[:-1]], 0)   # p表示previous, 即 αt-1^


alphas_bar_sqrt = torch.sqrt(alphas_pro)    # αt^ 开根号
one_minus_alphas_bar_log = torch.log(1 - alphas_pro)    # log (1 - αt^)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_pro)  # 根号下(1-αt^)

assert alphas.shape == alphas_pro.shape == alphas_pro_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape

print('beta: shape ', betas.shape)

# diffusion process

def q_x(x_0, t):
    ''' get q_x_{\t}
    作用: 可以基于x[0]得到任意时刻t的x[t]
    输入: x_0:初始干净图像; t:采样步
    输出: x_t:第t步时的x_0的样子
    '''
    noise = torch.randn_like(x_0) # 正态分布的随机噪声
    alphas_t = alphas_bar_sqrt[t]
    alphas_l_m_t = one_minus_alphas_bar_sqrt[t]

    return (alphas_t * x_0 + alphas_l_m_t * noise)


# test add noise
num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
plt.rc('text', color='blue')

# 测试一下加噪下过
## 共有10000个点,每个点包含两个坐标
## 生成100步以内,每个5步加噪后图像


for i in range(num_shows):
    j = i // 10
    k = i % 10
    q_i = q_x(dataset, torch.tensor(i * num_steps // num_shows))    # 生成t时刻的采样数据
    axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')
    axs[j, k].set_axis_off()
    axs[j, k].set_title('$q(\mathbf{x}_{' + str(i*num_steps // num_shows) + '})$')
    
# plt.show()
plt.savefig('diffusion_process.png')

# diffusion reverse process

# --------------------- diffusion model -----------------

import torch
import torch.nn as nn

class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=32):
        super(MLPDiffusion, self).__init__()
        
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2)
            ]
        )
        
        self.step_embeddings = nn.ModuleList(
            [nn.Embedding(n_steps, num_units),
             nn.Embedding(n_steps, num_units),
             nn.Embedding(n_steps, num_units),
             ]
        )

    def forward(self, x, t):
        """
        模型的输入是加噪后的图片x和加噪step-> t, 输出是噪声
        """
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)

        x = self.linears[-1](x) # shape: [10000, 2]

        return x

# loss function
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, use_cuda=False):
    """
    作用: 对任意时刻t进行采样计算loss
    参数:
        model: 模型
        x_0: 干净的图
        alphas_bar_sqrt: 根号下αt^
        one_minus_alphas_bar_sqrt: 根号下(1-αt^)
        n_steps: 采样步
    """
    batch_size = x_0.shape[0]

    # 对一个batchsize样本生成随机的时刻t, 覆盖到更多不同的t
    t = torch.randint(0, n_steps, size=(batch_size//2,))  # 在0~99内生成整数采样步
    t = torch.cat([t, n_steps-1-t], dim=0)  # 一个batch的采样步, 尽量让生成的t不重复
    t = t.unsqueeze(-1)  # 扩展维度 -> [batchsize, 1]
    if use_cuda:
        t = t.cuda()

    # x0的系数
    a = alphas_bar_sqrt[t]  # 根号下αt^

    # eps的系数
    aml = one_minus_alphas_bar_sqrt[t]  # 根号下(1-αt^)

    # 生成随机噪音eps
    e = torch.randn_like(x_0)
    if use_cuda:
        e = e.cuda()

    # 构造模型的输入
    x = x_0 * a + e * aml  # 前向过程:根号下αt^ * x0 + 根号下(1-αt^) * eps

    # 送入模型,得到t时刻的随机噪声预测值
    output = model(x, t.squeeze(-1))  # 模型预测的是噪声, 噪声维度与x0一样大, [10000,2]

    # 与真实噪声一起计算误差,求平均值
    return (e - output).square().mean()



# --------------- reverse process ---------------
def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt, use_cuda=False):
    """
    作用: 从x[T]恢复x[T-1]、x[T-2]、...x[0]
    输入:
        model:模型
        shape:数据大小,用于生成随机噪声
        n_steps:逆扩散总步长
        betas: βt
        one_minus_alphas_bar_sqrt: 根号下(1-αt^)
    输出:
        x_seq: 一个序列的x, 即 x[T]、x[T-1]、x[T-2]、...x[0]
    """
    if use_cuda:
        cur_x = torch.randn(shape).cuda()
    else:
        cur_x = torch.randn(shape)  # 随机噪声, 对应xt
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt, use_cuda=use_cuda)
        x_seq.append(cur_x)

    return x_seq


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt, use_cuda=False):
    """
    作用: 从x[T]采样t时刻的重构值
    输入:
        model:模型
        x: 采样的随机噪声x[T]
        t: 采样步
        betas: βt
        one_minus_alphas_bar_sqrt: 根号下(1-αt^)
    输出:
        sample: 样本
    """
    if use_cuda:
        t = torch.tensor([t]).cuda()
    else:
        t = torch.tensor([t])

    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]  # 模型输出的系数:βt/根号下(1-αt^) = 1-αt/根号下(1-αt^)
    
    eps_theta = model(x, t)  # 模型的输出: εθ(xt, t)
        
    # (1/根号下αt) * (xt - (1-αt/根号下(1-αt^))*εθ(xt, t))
    mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))  
    if use_cuda:
        z = torch.randn_like(x).cuda()  # 对应公式中的 z
    else:
        z = torch.randn_like(x)  # 对应公式中的 z

    sigma_t = betas[t].sqrt()  # 对应公式中的 σt

    sample = mean + sigma_t * z

    return (sample)


# ----------- trainning ------------

print('Training model...')
if_use_cuda = True
batch_size = 1024
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)
num_epoch = 4000
plt.rc('text',color='blue')


model = MLPDiffusion(num_steps)  # 输出维度是2,输入是x和step
if if_use_cuda:
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

iteration = 0
for t in range(num_epoch):
    for idx, batch_x in enumerate(dataloader):
        # 损失计算
        if if_use_cuda:
            loss = diffusion_loss_fn(model, batch_x.cuda(), alphas_bar_sqrt.cuda(), one_minus_alphas_bar_sqrt.cuda(), num_steps, use_cuda=if_use_cuda)
        else:
            loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)

        optimizer.zero_grad()  # 梯度清零
        loss.backward()  # 损失回传
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.)  # 梯度裁剪
        optimizer.step()

        iteration += 1

        # if iteration % 100 == 0:
        if(t % 100 == 0):
            print(f'epoch: {t} , loss: ', loss.item())
            if if_use_cuda:
                x_seq = p_sample_loop(model, dataset.shape, num_steps, betas.cuda(), one_minus_alphas_bar_sqrt.cuda(), use_cuda=True)
            else:
                x_seq = p_sample_loop(model, dataset.shape, num_steps, betas, one_minus_alphas_bar_sqrt, if_use_cuda)

            fig, axs = plt.subplots(1, 10, figsize=(28,3))
            for i in range(1, 11):
                cur_x = x_seq[i*10].cpu().detach()
                axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
                axs[i-1].set_axis_off();
                axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')

            plt.savefig('./diffusion_train_tmp.png')


### ----------------动画演示扩散过程和逆扩散过程-------------------------
# 前向过程
imgs = []
for i in range(100):
    plt.clf()
    q_i = q_x(dataset,torch.tensor([i]))
    plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);
    plt.axis('off');
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    imgs.append(img)

# 逆向过程
reverse = []
for i in range(100):
    plt.clf()
    cur_x = x_seq[i].cpu().detach()
    plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);
    plt.axis('off')

    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    reverse.append(img)

print('save gif...')
imgs = imgs
imgs[0].save("diffusion_forward.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

imgs = reverse
imgs[0].save("diffusion_denoise.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1604845.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

08-GPtimer

通用定时器 (GPTimer) 通用定时器简介 通用定时器可用于准确设定时间间隔、在一定间隔后触发(周期或非周期的)中断或充当硬件时钟。如下图所示,ESP32-S3 包含两个定时器组,即定时器组 0 和定时器组 1。每…

力扣练习题(2024/4/14)

1接雨水 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例 1: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数组 [0,1,0,2,1,0,1,3,2…

vue3 -- 项目使用自定义字体font-family

在Vue 3项目中使用自定义字体(font-family)的方法与在普通的HTML/CSS项目中类似。可以按照以下步骤进行操作: 引入字体文件: 首先,确保你的字体文件(通常是.woff、.woff2、.ttf等格式)位于项目中的某个目录下,比如src/assets/font/。 在全局样式中定义字体: 在你的全局…

mysql常见语法操作笔记

1. 数据库的基本操作 1.1. MYSQL登录与退出 D:\phpstudy_pro\Extensions\MySQL5.7.26\bin 输入 mysql -uroot -proot -h127.0.0.1 退出的三种方法 mysql > exit; mysql > quit; mysql > \q; 1.2. MYSQL数据库的一些解释 注意:数据库就相当于文件夹 …

IDEA 控制台中文乱码 4 种解决方案

前言 IntelliJ IDEA 如果不进行相关设置,可能会导致控制台中文乱码、配置文件中文乱码等问题,非常影响编码过程中进行问题追踪。本文总结了 IDEA 中常见的中文乱码解决方法,希望能够帮助到大家。 IDEA 中文乱码 解决方案 一、设置字体为支…

挣钱新玩法,一文带你掌握流量卡推广秘诀

手机流量卡推广项目是什么?听名字我相信大家就已经猜出来了,就是三大运营商为了开发新用户,发起的有奖推广活动,也是为了长期黏贴用户。在这个活动中,用户通过我们的渠道,就能免费办理低套餐流量卡&#xf…

Obsidian 插件安装

方法一: Obsidian 最简单的插件安装当然是通过第三方插件库进行搜索,但是由于魔法上网的问题,经常连不上github,或者下载不了,导致插件无法安装。 方法二: obsidian 社区插件汇总:Airtable -…

【第三十一篇】Autorize插件安装使用教程(结合Burp实现越权实战案例)

Burp Suite是一款功能强大的渗透测试工具,被广泛应用于Web应用程序的安全测试和漏洞挖掘中。 本专栏将结合实操及具体案例,带领读者入门、掌握这款漏洞挖掘利器 读者可订阅专栏:【Burp由入门到精通 |CSDN秋说】 文章目录 前言安装教程使用教程垂直越权垂直越权实战注意前言 …

群晖 NAS rsync 远程文件同步

客户机是外网的 Windows 11,服务器是群晖。 客户机上安装 WSL Alpine Linux 来运行 rsync 进行文件下载。Alpine 相对比 Ubuntu、Debian,要小巧轻量,占用存储空间少,启动速度也很快。 一、安装 WSL Alpine Linux 在 Windows 中&…

scala---基础核心知识(变量定义,数据类型,流程控制,方法定义,函数定义)

一、什么是scala Scala 是一种多范式的编程语言,其设计初衷是要集成面向对象编程和函数式编程的各种特性。Scala运行于Java平台(Java虚拟机),并兼容现有的Java程序。 二、为什么要学习scala 1、优雅 2、速度快 3、能融合到hado…

ADOP-400G光模块问题发布会

前沿光学(ADOP)400G光模块为客户提供各种超高密度的400G以太网连接方案,广泛应用于数据中心、企业网和服务提供商。 📣📣以下一些问题是我们新一代400G光模块常能遇见问题,所以我们决定在这里开一场小小的…

ubuntu22安装宝塔面板

方法一:运行安装宝塔命令 wget -O install.sh https://download.bt.cn/install/install-ubuntu_6.0.sh && sudo bash install.sh ed8484bec 安装成功后,需到服务器管理后台的安全组中配置新规则,放行宝塔面板的端口(以阿…

基于SSM和vue的机票订购管理系统

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM和vue的机票订购管理系统2拥有两种角色 管理员:用户管理、机票管理、订票管理、公告管理、广告管理、系统管理、添加机票等 用户:登录注册、订票、查看公…

论文复现《SplaTAM: Splat, Track Map 3D Gaussians for Dense RGB-D SLAM》

前言 SplaTAM算法是首个开源的基于RGB-D数据,生成高质量密集3D重建的SLAM技术。 通过结合3DGS技术和SLAM框架,在保持高效性的同时,提供精确的相机定位和场景重建。 代码仓库:spla-tam/SplaTAM: SplaTAM: Splat, Track & Map 3…

MySQL表级锁——技术深度+1

引言 本文是对MySQL表级锁的学习,MySQL一直停留在会用的阶段,需要弄清楚锁和事务的原理并DEBUG查看。 PS:本文涉及到的表结构均可从https://github.com/WeiXiao-Hyy/blog中获取,欢迎Star! MySQL表级锁 MySQL中表级锁主要有表锁…

【简单介绍下PostCSS】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

如何实现在 Windows 上运行 Linux 程序?

在Windows 上运行Linux程序是可以通过以下几种方法实现: 1.使用 Windows Subsystem for Linux (WSL): WSL是微软提供的功能,可以在Windows 10上运行一个完整的Linux系统。用户可以在Microsoft Store中安装所需的 在开始前我有一些资料,是我根据网友给的…

SQL --索引

索引 INDEX 伪列 伪装起来的列,不容易被看见,要特意查询才能看见 ROWNUM: 是对查询结果自动生成的一组连续的自然数序号。 SELECT emp.*,ROWNUM FROM emp例题:查询emp表中,前三个员工 SELECT * FROM * from emp w…

Midjourney 实现角色一致性的新方法

AI 绘画的奇妙之处,实乃令人叹为观止!就像大千世界中,寻不见两片完全相同的树叶一般,AI 绘画亦复如是。同一提示之词,竟能催生出千变万化的图像,使得AI所绘之作,宛如自然之物般独特,…

将百度网盘中数据集直接下载到服务器上

步骤: 1:下载安装bypy pip install bypybypy,是一个使用 python 编写的命令行百度网盘客户端 2:初始化 bypy info将这个链接复制到浏览器中打开 复制授权码,粘贴到服务器命令,回车 等待一会,会显示你云盘空间大小信…