DDPM pytorch 代码复现

news2025/1/11 6:12:24

本次只分享代码以及效果,后续更新原理
代码参考 deep_thought

先看动图效果
在这里插入图片描述

1.选择一个数据集

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch

s_curve, _ = make_s_curve(10 ** 4, noise=0.1)
s_curve = s_curve[:, [0, 2]] / 10.0

print("shape of moons:", np.shape(s_curve))

data = s_curve.T
fig, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolor='white')

ax.axis('off')

dataset = torch.Tensor(s_curve).float()

在这里插入图片描述

2. 确定超参数

num_steps = 100  # 对于步骤,一开始可以由 被他、分布的均值和标准差来共同确定

# 制定每一步的 beta
betas = torch.linspace(-6, 6, num_steps)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5

# 计算alpha,alpha_prod,alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)  # p 表示 previous
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)

assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape
print("all the same shape:", betas.shape)

3. 确定扩散过程任意时刻的采样值

# 计算任意时刻的x的采样值,基于x_0核参数重整化技巧
def q_x(x_0, t):
    """可以基于x[0]"得到任意时刻t的x[t]"""
    noise = torch.randn_like(x_0)  # noise 是从正太分布中生成的随机噪声
    alphas_t = alphas_bar_sqrt[t]
    alphas_l_m_t = one_minus_alphas_bar_sqrt[t]
    # alphas_t = extract(alphas_bar_sqrt,t,x_0) # 得到sqrt(alphas_bar[t]),x_0的作用是传入shape
    # alphas_l_m_t = extract(one_minus_alphas_bar_sqrt,t,x_0) # 得到sqrt(1-alphas_bar[t])
    return (alphas_t * x_0 + alphas_l_m_t * noise)  # 在 x[0]基础上添加噪声

4.演示原始数据分布加噪 100 步后的效果

num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
plt.rc('text', color='blue')
# 共有 10000 个点,每个点包含两个坐标
# 生成 100 步以内每隔 5 步加噪声的图像
for i in range(num_shows):
    j = i // 10
    k = i % 10
    q_i = q_x(dataset, torch.tensor([i * num_steps // num_shows]))  # 生成 t 时刻的采样数据
    axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolors='white')

    axs[j, k].set_axis_off()
    axs[j, k].set_title('$q(\mathbf{x}_{' + str(i * num_steps // num_shows) + '})$')

在这里插入图片描述

5. 编写拟合扩散过程高斯分布的模型


import torch
import torch.nn as nn


class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
        self.linears = nn.ModuleList([
            nn.Linear(2, num_units),
            nn.ReLU(),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            nn.Linear(num_units, 2)
        ])
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
            ]
        )

    def forward(self, x_0, t):
        x = x_0
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)

        x = self.linears[-1](x)

        return x

6.编写训练的误差函数

def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """对任意时刻t进行采样计算loss"""
    batch_size = x_0.shape[0]

    # 随机采样一个时刻t,为了提高训练效率,这里确保 t 不重复
    # weights = torch.ones(n_steps).expand(batch_size,-1)
    # t = torch.multinomial(weights,num_samples=1,replacement=False) # [batch_size,1]
    t = torch.randint(0, n_steps, size=(batch_size // 2,))
    t = torch.cat([t, n_steps - 1 - t], dim=0)
    t = t.unsqueeze(-1)
    # print(t.shape)

    # x0 的系数
    a = alphas_bar_sqrt[t]

    # eps的系数
    aml = one_minus_alphas_bar_sqrt[t]

    # 生成随机噪声eps
    e = torch.randn_like(x_0)

    # 构造模型的输入
    x = x_0 * a + e * aml

    # 送入模型,得到 t 时刻的随机噪声预测值
    output = model(x, t.squeeze(-1))

    # 与真实噪声一起计算误差,求平均值
    return (e - output).square().mean()

7.编写逆扩散采样函数

def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt):
   """ 从x[T]恢复x[T-1],x[t-2]...x[0]"""
   cur_x = torch.randn(shape)
   x_seq = [cur_x]
   for i in reversed(range(n_steps)):
       cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
       x_seq.append(cur_x)
   return x_seq


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
   """从x[T]采样 t 时刻的重构值"""

   t = torch.tensor([t])

   coeff = betas[t] / one_minus_alphas_bar_sqrt[t]

   eps_theta = model(x, t)

   mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))

   z = torch.randn_like(x)
   sigma_t = betas[t].sqrt()

   sample = mean + sigma_t * z

   return (sample)

8.开始训练模型,并打印loss以及中间的重构效果

seed = 1234


class EMA():
    """构建一个参数平滑器"""

    def __init__(self, mu=0.01):
        self.mu = mu
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, name, x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average


print("training model...")

"""
ema = EMA(0.5)
for name,param in model.named_parameters():
    if param.requires_grad:
        ema.register(name,param.data)
"""

batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_epochs = 4000
plt.rc('text', color='blue')

model = MLPDiffusion(num_steps)  # 输出维度是 2 ,输入还x和step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for t in range(num_epochs):
    for idx, batch_x in enumerate(dataloader):
        loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()
        # for name,param in model.named_parameters():
        #     if param.requires_grad:
        #         param.data = ema(name,param.data)

    # print loss
    if t % 100 == 0:
        print(loss)
        x_seq = p_sample_loop(model, dataset.shape, num_steps, betas, one_minus_alphas_bar_sqrt)  # 共有 100 个元素
        fig, axs = plt.subplots(1, 10, figsize=(28, 3))
        for i in range(1, 11):
            cur_x = x_seq[i * 10].detach()
            axs[i - 1].scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white')
            axs[i - 1].set_axis_off()
            axs[i - 1].set_title('$q(\mathbf{x}_{' + str(i * 10) + '})$')

<>:55: SyntaxWarning: invalid escape sequence '\m'
<>:55: SyntaxWarning: invalid escape sequence '\m'
C:\Users\28374\AppData\Local\Temp\ipykernel_10752\1573120526.py:55: SyntaxWarning: invalid escape sequence '\m'
  axs[i-1].set_title('$q(\mathbf{x}_{' + str(i*10)+'})$')

training model...
tensor(0.8371, grad_fn=<MeanBackward0>)
tensor(0.3398, grad_fn=<MeanBackward0>)
tensor(0.3658, grad_fn=<MeanBackward0>)
tensor(0.2152, grad_fn=<MeanBackward0>)
tensor(0.3706, grad_fn=<MeanBackward0>)
tensor(0.2685, grad_fn=<MeanBackward0>)
tensor(0.4213, grad_fn=<MeanBackward0>)
tensor(0.3830, grad_fn=<MeanBackward0>)
tensor(0.2178, grad_fn=<MeanBackward0>)
tensor(0.1918, grad_fn=<MeanBackward0>)
tensor(0.2116, grad_fn=<MeanBackward0>)
tensor(0.3871, grad_fn=<MeanBackward0>)
tensor(0.3366, grad_fn=<MeanBackward0>)
tensor(0.1989, grad_fn=<MeanBackward0>)
tensor(0.5254, grad_fn=<MeanBackward0>)
tensor(0.2641, grad_fn=<MeanBackward0>)
tensor(0.3108, grad_fn=<MeanBackward0>)
tensor(0.1901, grad_fn=<MeanBackward0>)
tensor(0.5101, grad_fn=<MeanBackward0>)
tensor(0.3037, grad_fn=<MeanBackward0>)
tensor(0.8759, grad_fn=<MeanBackward0>)

C:\Users\28374\AppData\Local\Temp\ipykernel_10752\1573120526.py:50: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). Consider using `matplotlib.pyplot.close()`.
  fig,axs = plt.subplots(1,10,figsize=(28,3))

tensor(0.3038, grad_fn=<MeanBackward0>)
tensor(0.4054, grad_fn=<MeanBackward0>)
tensor(0.3833, grad_fn=<MeanBackward0>)
tensor(0.4251, grad_fn=<MeanBackward0>)
tensor(0.3462, grad_fn=<MeanBackward0>)
tensor(0.1814, grad_fn=<MeanBackward0>)
tensor(0.2301, grad_fn=<MeanBackward0>)
tensor(0.4002, grad_fn=<MeanBackward0>)
tensor(0.4273, grad_fn=<MeanBackward0>)
tensor(0.3140, grad_fn=<MeanBackward0>)
tensor(0.3192, grad_fn=<MeanBackward0>)
tensor(0.8542, grad_fn=<MeanBackward0>)
tensor(0.4358, grad_fn=<MeanBackward0>)
tensor(0.2812, grad_fn=<MeanBackward0>)
tensor(0.4819, grad_fn=<MeanBackward0>)
tensor(0.2980, grad_fn=<MeanBackward0>)
tensor(0.4941, grad_fn=<MeanBackward0>)
tensor(0.6179, grad_fn=<MeanBackward0>)
tensor(0.2370, grad_fn=<MeanBackward0>)

<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>
<Figure size 2800x300 with 10 Axes>

这里应该会生成 40 张图片,这里只展现能够提现过程的图片了。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

9. 动画演示扩散过程核逆扩散过程

# Generating the forward image sequence 生成前向过程,也就是逐步加噪声
import io
from PIL import Image

imgs = []

for i in range(100):
    plt.clf()
    q_i = q_x(dataset, torch.tensor([i]))
    plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')
    plt.title('step:'+str(i+1))
    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    imgs.append(img)

# Generating the reverse diffusion sequence

reverse = []

for i in range(100):
    plt.clf()
    cur_x = x_seq[i].detach() # 拿到训练末尾阶段生成的 x_seq
    plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5)
    plt.axis('off')
    plt.title('step:'+str(i+1))
    img_buf = io.BytesIO()
    plt.savefig(img_buf,format='png')
    img = Image.open(img_buf)
    reverse.append(img) 

imgs = imgs + reverse
imgs[0].save("diffusion.gif",format='gif',append_images=imgs,save_all=True,duration=100,loop=1)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1876003.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

零基础STM32单片机编程入门(四)ADC详解及实战含源码视频

文章目录 一.概要二.STM32F103C8T6单片机ADC外设特点三.STM32单片机ADC内部结构图1.ADC相关引脚说明2.ADC通道分类3.触发源4.转换周期5.电压转换计算6.更精确电压转换计算 四.规则通道ADC采集信号流向1.单次转换模式2.连续转换模式 五.CubeMX配置一个ADC采集例程六.CubeMX工程源…

通天星CMSV6车载监控平台CompanyList信息泄露漏洞

1 漏洞描述 通天星CMSV6车载视频监控平台是东莞市通天星软件科技有限公司研发的监控平台,通天星CMSV6产品覆盖车载录像机、单兵录像机、网络监控摄像机、行驶记录仪等产品的视频综合平台。通天星科技应用于公交车车载、校车车载、大巴车车载、物流车载、油品运输车载、警车车…

风控图算法之中心性算法(小数据集Python版)

风控图算法之中心性算法&#xff08;小数据集Python版&#xff09; 图算法在金融风控领域的应用已经超越了传统的社区发现技术&#xff0c;这些技术曾被主要用于识别和分析欺诈性行为模式&#xff0c;例如黑产团伙。当前&#xff0c;一系列图统计算法&#xff0c;包括介数中心…

笔记本重装系统怎么操作? windows电脑重装系统,超实用的四种方法

重新安装操作系统是维护计算机性能和确保系统稳定运行的重要步骤。对于 Windows 笔记本用户而言&#xff0c;熟悉重装系统的方法可以帮助他们解决各种问题&#xff0c;从提高系统速度到修复软件故障。然而具体来讲&#xff0c;笔记本重装系统怎么操作呢&#xff1f;接下来&…

【01】Java代码如何运行

JRE: 包含Java虚拟机以及核心类库 JDK: 同样包含了JRE&#xff0c;并且附带了一系列开发、诊断工具 一、为什么Java要在虚拟机中运行 一、 Java语言特性&#xff1a;高级、语法复杂、抽象 Java语言-- 【编译器】 --> Java字节码 --【虚拟机】–> 实现 二、 托管环境 自…

正点原子rk3588编译sdk

1、编译SDK 1.1 安装 RK3588 Linux SDK .repo/repo/repo sync -l -j101.2 SDK 工程目录介绍 app&#xff1a;存放上层应用 app&#xff0c;包括 Qt 应用程序&#xff0c;以及其它的 C/C应用程序。 buildroot&#xff1a;基于 buildroot 开发的根文件系统。 debian&#xff1…

AIGC对图片行业的影响分析!

前言 自从去年生成式AI火起来之后&#xff0c;不论是文字领域还是图片领域受到的冲击都非常大。比如说SD和Midjourney的爆火&#xff0c;不止是创作者&#xff0c;还有交易平台和使用方&#xff0c;都在发生变化。 AIGC自2023年全面进入大家视野&#xff0c;对各行各业造成了或…

论证型大语言模型:促进可解释性与可质疑的决策制定

Argumentative Large Language Models for Explainable and Contestable Decision-Making 论文地址: https://arxiv.org/abs/2405.02079https://arxiv.org/abs/2405.02079 1.概述 在探讨大型语言模型(LLMs)在决策支持系统中的应用时,我们需正视其面临的核心问题。这些问题…

MYSQL函数进阶详解:案例解析(第19天)

系列文章目录 一、MySQL的函数&#xff08;重点&#xff09; 二、MySQL的窗口函数&#xff08;重点&#xff09; 三、MySQL的视图&#xff08;熟悉&#xff09; 四、MySQL的事务&#xff08;熟悉&#xff09; 文章目录 系列文章目录前言一、MySQL的函数1. 聚合函数2. group_c…

基于Vue,mysql,JavaEE的简单投票与投票管理系统

项目介绍 ​ 本项目&#xff0c;基于Vue2.6,mysql,JavaEE 实现简单的投票与投票管理系统 项目地址 VotingSystem: 投票系统1.0 管理员和普通用户 (gitee.com) 有问题请评论私聊哦 项目分类 数据库 创建投票人&#xff0c;被投票人&#xff0c;投票关系&#xff08;追踪谁…

python基础语法 003-4 数据类型集合

1 集合 1.1 什么是集合 什么是集合&#xff1f;ANS:集合set是一个无序的不重复元素序列集合怎么表示&#xff1f;ANS: {} , 用逗号隔开打印元组类型&#xff0c;type()一个元素的集合怎么表示&#xff1f;&#xff1a;ANS:存储多种类型{"a", 1} """…

海外媒体发稿:2个必选媒体宣发套餐引爆影响力-华媒舍

本文旨在介绍2个必选媒体宣发套餐的特点及其如何引爆影响力。 在当今竞争激烈的媒体环境中&#xff0c;有效的宣传和推广策略对于企业和个人的成功至关重要。这就是为什么选择正确的宣发套餐成为了一个关键的决策。 2. 媒体宣发套餐概述 媒体宣发套餐是一种综合性的宣传方案&…

把飞书云文档变成HTML邮件:问题挑战与解决历程

一、背景 云文档转HTML邮件 基于公司内部的飞书办公套件&#xff0c;早在去年6月&#xff0c;我们就建设了将飞书云文档转译成HTML邮件的能力&#xff0c;方便同学们在编写邮件文档和发送邮件时&#xff0c;都能有较好的体验和较高的效率。 当下问题 要被邮件客户端识别&am…

电脑数据恢复篇:如何恢复误删除的文件

在清理电脑或优化存储设备时无意中删除重要文件是人类常见的错误。不可否认的是&#xff0c;在批量删除文件时&#xff0c;您经常会同时删​​除垃圾文件和重要文件。后来您发现一堆重要的文档或文件不见了。在这种情况下&#xff0c;您唯一的选择就是寻找恢复已删除文件的方法…

DC-DC原理,升降压原理,BUCK,BOOST

DC-DC简述 开关电源包括电源模块&#xff0c;可以直接使用&#xff0c;不需要外部电路&#xff0c;提供的功率比较小。还有电源稳压器&#xff0c;这种功率MOS一般集成在芯片内部&#xff0c;但是需要选择外部电感。另外还有PWM控制器&#xff0c;需要选择功率MOS&#xff0c;二…

2024年河北省计划招聘“特岗计划”教师2300名

2024年河北省计划招聘“特岗计划”教师2300名 报名时间&#xff1a;6月28日9:00至7月2日18:00 笔试准考证打印&#xff1a;7月11日-7月13日 笔试时间&#xff1a;7月14日上午9:00-11:30 面试时间&#xff1a;8月3日至8月5日 报名网站&#xff1a;河北教师教育网 报名照规格&…

C#微信预约挂号系统全套源码,适用于各级公立和民营医院,与院内his、lis、pacs系统对接。

C#微信预约挂号系统源码&#xff0c;团队自主研发&#xff0c;三甲医院应用多年&#xff0c;系统稳定&#xff0c;功能齐全&#xff0c;支持二次开发&#xff0c;项目使用。 微信预约挂号系统可以让患者足不出户就可以利用微信进行在线挂号&#xff0c;实现分时段就诊&#xff…

Github忘记了Two-factor Authentication code

意外重置了edge浏览器 码农家园github自从开启开启了2FA认证&#xff0c;每次输入auth code确实麻烦&#xff0c;于是下载了浏览器插件 Open two factor authenticator&#xff0c; 最近edge频繁宕机&#xff0c;而且提示磁盘空间不足&#xff0c;要不要立即清理并重置浏览器临…

kafka 消费者 API 使用总结

前言 应用程序使用KafkaConsumer向Kafka订阅主题&#xff0c;并从订阅的主题中接收消息。不同于从其他消息系统读取数据&#xff0c;从Kafka读取数据涉及一些独特的概念和想法。如果不先理解这些概念&#xff0c;则难以理解如何使用消费者API。本文将先解释这些重要的概念&…

【10分钟速通webpack,全流程打包,编译,发包,全干货,附代码 】

需求 后端有个nodejs 基础库&#xff0c;用typescript编写&#xff0c;需要发包到代码仓库上&#xff0c;被其它业务引入。这其中就涉及了&#xff1a; 编译&#xff0c; 打包&#xff0c;发包。 工作流速览 前提依赖 webpack主体 npm install --save-dev webpack webpack…