手撕扩散模型(一)| 训练部分——前向扩散,反向预测代码全解析

news2025/1/14 20:51:04

文章目录

  • 1 直接使用 核心代码
  • 2 工程代码实现
    • 2.1 DDPM
    • 2.2 训练

三大模型VAE,GAN, DIffusion扩散模型 是生成界的重要模型,但是最近一段时间扩散模型被用到的越来越多的,最近爆火的OpenAI的 Sora文生视频模型其实也是用了这种的方式,因而我打算系统回顾扩散系列知识,并注重代码的分析,感兴趣可以关注这一系列的博客,先介绍基础版本的,之后介绍扩散进阶的相关知识。

扩散模型很多的讲解上来会讲解很多的数学,会让人望而却步,但其实扩散在实际使用的时候并不复杂,我会先从代码的角度告诉大家怎么实操,再介绍数学推理

扩散要弄明白训练和推理两个过程~这节主要分析训练过程

1 直接使用 核心代码

基础版本的扩散核心就两句话

(1) DDPM前向扩散得到加噪后的图片在这里插入图片描述
得到标记,对应一个核心公式**

(2) DDPM反向利用Unet网络预测加的噪声

实际上抽象一下,忽略细节,训练部分代码就主要以下部分

import torch
from torch import nn
n_steps=1000#假设我们最大的加噪步数是1000
x0=torch.ones(128,1,28,28) #模拟输入,1个batch有128张图片,通道数1,宽度高度为28
eta = torch.randn_like(x0) #生成初始随机噪声,形状和模拟输入一样
t= torch.randint(0, n_steps, (128,))#t是加噪时间,注意这里的t是随机生成的0到1000的128个随机数
noisy_imgs = ddpm(x0, t, eta) #前向加噪 输入原始输入图片和随机的t,得到128个加噪后的图像,扩散模型核心的第一句话
eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) #反向预测,给定图和t,得到预测噪声,扩散模型核心的第二句话
loss = nn.mse(eta_theta, eta) #计算噪声和实际的噪声之间的差异作为损失
optim.zero_grad()
loss.backward()
optim.step()

2 工程代码实现

当然上面是一个简略版本,实际中肯定要考虑较多的细节问题~

先来实现DDPM

2.1 DDPM

我们申明一个这样一个类MyDDPM

class MyDDPM(nn.Module):
    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(1, 28, 28)):
        super(MyDDPM, self).__init__()
        self.n_steps = n_steps  #扩散时间总步数 
        self.device = device  
        self.image_chw = image_chw #image_chw 用于表示图像的通道数、高度和宽度。这里通道数1,宽度高度为28
        self.network = network.to(device)
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(
            device)  # beta预先算出来了
        self.alphas = 1 - self.betas #alphas也预先算出来
        self.alpha_bars = torch.tensor([torch.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]).to(device) #alphas_bars也预先算出来了 前i个乘积

    def forward(self, x0, t, eta=None):
        
        n, c, h, w = x0.shape  #[批大小,通道数,图片高,图片宽]
        a_bar = self.alpha_bars[t]  #t的大小和批大小相等

        if eta is None:
            eta = torch.randn(n, c, h, w).to(self.device)

        noisy_img = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        return noisy_img

    def backward(self, x, t):
        # Run each image through the network for each timestep t in the vector t.
        # The network returns its estimation of the noise that was added.
        return self.network(x, t)

这段代码定义了一个名为MyDDPM的类,它是nn.Module的子类。

MyDDPM类的构造函数__init__中,有以下几个重要的属性和操作:

  • n_steps:扩散时间总步数,表示模型在每个输入上进行的扩散步数。
  • device:设备,表示模型在哪个设备上运行(如CPU或GPU)。
  • image_chw:图像通道数、高度和宽度的元组,用于表示图像的形状。在这里,通道数为1,高度和宽度为28。
  • network:神经网络模型,用于估计添加的噪声。
  • betas:通过使用torch.linspace函数在min_betamax_beta之间生成n_steps个均匀间隔的值,得到一个表示扩散系数的张量。
  • alphas:通过将1减去betas得到的张量,表示衰减系数。
  • alpha_bars:通过计算alphas的前i+1个元素的乘积,得到一个表示衰减系数累积乘积的张量。

MyDDPM类还定义了两个方法:

  • forward方法用于前向传播。它接受输入x0、时间步t和可选的噪声eta作为参数。在该方法中,首先获取输入x0的形状,并根据时间步t获取对应的衰减系数a_bar。如果未提供噪声eta,则使用torch.randn函数生成一个与输入形状相同的噪声张量。然后,根据衰减系数和噪声,计算得到带有噪声的图像张量,并返回该张量作为输出。
  • backward方法用于反向传播。它接受输入x和时间步t作为参数,并通过调用network模型对每个时间步t的输入x进行处理,得到估计的添加噪声。最后,返回估计的噪声张量作为输出。

2.2 训练

有了DDPM我们就可以进行训练了(实际上这里的network我们先当做一个黑盒,在下一节讲解结构,network实现的效果就是输入某一时刻的t,和该时刻加噪后的图像,输出预测的噪声结果,该结果和前向生成的噪声做损失函数~优化参数)

def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):
    mse = nn.MSELoss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps

    for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
        epoch_loss = 0.0
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):
            # Loading data
            x0 = batch[0].to(device) #[128,1,1,28]
            n = len(x0)

            # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
            eta = torch.randn_like(x0).to(device)
            t = torch.randint(0, n_steps, (n,)).to(device) #注意这里的t是随机生成的

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, eta) #经过前向过程 y一次得到一个批次的

            # Getting model estimation of noise based on the images and the time-step
            eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1)) 
            loss = mse(eta_theta, eta) #预测噪声和给出的噪声之间的差异
            optim.zero_grad()
            loss.backward()
            optim.step()

            epoch_loss += loss.item() * len(x0) / len(loader.dataset)

        # Display images generated at this epoch
        if display:
            show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)
  1. 函数定义:

    def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):
    
    • 这个函数接受多个参数:ddpm是一个对象,loader是一个数据加载器,n_epochs是训练的轮数,optim是优化器,device是设备(如CPU或GPU),display是一个布尔值,用于控制是否显示生成的图像,store_path是模型存储的路径。
    • 函数没有返回值。
  2. 导入模块:

    mse = nn.MSELoss()
    
    • 这里导入了nn模块,并创建了一个MSELoss的实例对象mse
  3. 初始化变量:

    best_loss = float("inf")
    n_steps = ddpm.n_steps
    
    • best_loss被初始化为正无穷大,用于跟踪最佳损失值。
    • n_stepsddpm对象中获取,表示模型的步数。
  4. 训练循环:

    for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
        epoch_loss = 0.0
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):
            # Loading data
            x0 = batch[0].to(device) #[128,1,1,28]
            n = len(x0)
            ...
    
    • 外部循环是训练的轮数,使用range(n_epochs)生成一个迭代器,并使用tqdm函数包装,以显示训练进度条。
    • 内部循环是对数据加载器中的批次进行迭代,使用enumerate函数包装,并使用tqdm函数包装,以显示每个批次的进度条。
    • 在每个批次中,首先从批次中加载数据,并将其移动到指定的设备上。
    • x0是批次中的第一个元素,表示输入数据。
    • n是批次的大小。
  5. 数据处理和模型训练:

    eta = torch.randn_like(x0).to(device)
    t = torch.randint(0, n_steps, (n,)).to(device)
    noisy_imgs = ddpm(x0, t, eta)
    eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1))
    loss = mse(eta_theta, eta)
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    • eta是一个与x0形状相同的随机张量,用于添加噪声。
    • t是一个随机生成的整数张量,表示时间步骤。
    • noisy_imgs是通过将x0t作为输入,使用ddpm对象进行前向传播得到的噪声图像。
    • eta_theta是通过将noisy_imgst进行反向传播,使用ddpm对象得到的噪声估计。
    • loss是通过计算eta_thetaeta之间的均方误差(MSE)得到的损失。
    • optim.zero_grad()用于清除优化器的梯度。
    • loss.backward()用于计算损失相对于模型参数的梯度。
    • optim.step()用于更新模型参数。
  6. 显示生成的图像和存储模型:

    if display:
        show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")
    ...
    if best_loss > epoch_loss:
        best_loss = epoch_loss
        torch.save(ddpm.state_dict(), store_path)
        log_string += " --> Best model ever (stored)"
    ...
    print(log_string)
    
    • 如果displayTrue,则调用show_images函数显示生成的图像。
    • generate_new_images函数用于生成新的图像样本。
    • 如果当前轮的损失比之前的最佳损失更低,则将模型参数保存到指定的路径。
    • 最后,打印训练日志字符串。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1459146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java期末】学生成绩管理系统(MySQL数据库)

诚接C语言、C、Java、Python、HTML、JavaScript、vue、MySQL相关编程作业, 标价10-20每份,如有需要请加文章最下方QQ。 本文资源:https://download.csdn.net/download/weixin_47040861/88856340 1.题目要求 学生成绩管理系统 通过Java控制…

PNG图片压缩-UPNG.js参数说明及示例

UPNG.js是一个非常轻量且高效的库,用于处理PNG图像。它可以编码和解码PNG图片,同时支持压缩和解压缩功能。特别适合在前端项目中处理图像,尤其是在需要优化图像大小而不牺牲质量时。 UPNG.encode()函数是UPNG.js中用于将图像数据编码成PNG格…

量化巨头“卖空”被刷屏!网友:又一类量化策略要“收摊”了

量化圈遇到了龙年首宗“大事件”! 2月20日晚间,沪深交易所同时出手对量化巨头灵均投资的异常交易行为进行“处理”。 沪深交易所均称发现灵均在2月19日开盘1分钟内,名下多个账户通过计算机程序自动生产交易指令,短时间大量下单卖…

WireShark 安装指南:详细安装步骤和使用技巧

Wireshark是一个开源的网络协议分析工具,它能够捕获和分析网络数据包,并以用户友好的方式呈现这些数据包的内容。Wireshark 被广泛应用于网络故障排查、安全审计、教育及软件开发等领域。接下将讲解Wireshark的安装与简单使用。 目录 Wireshark安装步骤…

过了30岁了,一定要专注一件事情?视频号值得尝试!

经常说视频号下载助手, 但发现大多数的大佬都只是先专注一件事情。 小编初6就回来了,和一个大佬吃饭,虽然人家规模并不大,但日引客户上千也是基本的。 这里给大家揭秘一下,他的做法!!&#x…

猫头虎分享已解决Bug || 脚本执行错误(Script Execution Failure):ScriptError, ExecutionFailure

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

挑战30天学完Python:Day15 错误类型

📘 Day 14 🎉 本系列为Python基础学习,原稿来源于 30-Days-Of-Python 英文项目,大奇主要是对其本地化翻译、逐条验证和补充,想通过30天完成正儿八经的系统化实践。此系列适合零基础同学,或仅了解Python一点…

Linux编辑器——Vim详解

目录 ⭐前言 ⭐vim的基本概念 ⭐vim的基本操作 ⭐vim命令模式命令集 ⭐vim末行模式命令集 ⭐简单vim配置 ⭐配置文件的位置 ⭐常用配置选项 ⭐前言 vi/vim的区别简单点来说,它们都是多模式编辑器,不同的是vim是vi的升级版本,它不仅兼容…

课程大纲:图像处理中的矩阵计算

课程名称:《图像处理中的矩阵计算》 课程简介: 图像处理中的矩阵计算是图像分析与处理的核心部分。本课程旨在教授学员如何应用线性代数中的矩阵计算,以实现各种图像处理技术。我们将通过强调实际应用和实践活动来确保学员能够理解和掌握这些…

代码随想录算法训练营第三六天 | 无重叠区间、划分字母区间、合并区间

目录 无重叠区间划分字母区间合并区间 LeetCode 435. 无重叠区间 LeetCode 763.划分字母区间 LeetCode 56. 合并区间 无重叠区间 给定一个区间的集合 intervals ,其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量,使剩余区间互不重叠…

vue3 之 商城项目—会员中心

整体功能梳理 1️⃣个人中心—个人信息和猜你喜欢数据渲染 2️⃣我的订单—各种状态下的订单列表展示 路由配置&#xff08;三级路由配置&#xff09; 准备模版member/index.vue <script setup> </script><template><div class"container">…

深度学习图像算法工程师--面试准备(1)

1 请问人工神经网络中为什么 ReLU 要好过于 tanh 和 Sigmoid function&#xff1f; 采⽤Sigmoid 等函数&#xff0c;算激活函数时&#xff08;指数运算&#xff09;&#xff0c;计算量⼤&#xff0c;反向传播求误差梯度时&#xff0c;求导涉及除法和指数运算&#xff0c;计算量…

《VitePress 简易速速上手小册》第2章:Markdown 与页面创建(2024 最新版)

文章目录 2.1 Markdown 基础及扩展2.1.1 基础知识点解析2.1.2 重点案例&#xff1a;技术博客2.1.3 拓展案例 1&#xff1a;食谱分享2.1.4 拓展案例 2&#xff1a;个人旅行日记 2.2 页面结构与布局设计2.2.1 基础知识点解析2.2.2 重点案例&#xff1a;公司官网2.2.3 拓展案例 1&…

软件测试方法_边界值分析法

目录&#xff1a; ①边界值分析法的介绍和概念 ②边界值分析法的原理和思想 ③单缺陷假设和多缺陷假设 ④边界值测试数据类型 ⑤内部边界值分析 ⑥各类边界值测试介绍 ⑦基于边界值分析方法选择测试用例的原则 ⑧边界值分析法的实例分析 1、边界值分析法的介绍和概念 …

力扣94 二叉树的中序遍历 (Java版本) 递归、非递归

文章目录 题目描述递归解法非递归解法 题目描述 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,3,2] 示例 2&#xff1a; 输入&#xff1a;root [] 输出&#xff1a;[] 示…

【大厂AI课学习笔记】【2.2机器学习开发任务实例】(7)特征构造

特征分析之后&#xff0c;就是特征构造。 特征构造第一步 特征构造往往要进行数据的归一化。 在本案例中&#xff0c;我们将所有的数据&#xff0c;将所有特征区间调整为0~1之间。 如上图。 那么&#xff0c;为什么要进行归一化&#xff0c;又如何将数据&#xff0c;调整为…

【安卓基础1】初识Android

&#x1f3c6;作者简介&#xff1a;|康有为| &#xff0c;大四在读&#xff0c;目前在小米安卓实习&#xff0c;毕业入职。 &#x1f3c6;安卓学习资料推荐&#xff1a; 视频&#xff1a;b站搜动脑学院 视频链接 &#xff08;他们的视频后面一部分没再更新&#xff0c;看看前面…

【力扣白嫖日记】1873.计算特殊奖金

前言 练习sql语句&#xff0c;所有题目来自于力扣&#xff08;https://leetcode.cn/problemset/database/&#xff09;的免费数据库练习题。 今日题目&#xff1a; 1873.计算特殊奖金 表&#xff1a;Employees 列名类型employee_idintnamevarcharsalaryint employee_id 是…

LeetCode 450.删除二叉搜索树中的节点和669.修建二叉搜索树思路对比 及heap-use-after-free问题解决

题目描述 450.删除二叉搜索树中的节点 给定一个二叉搜索树的根节点 root 和一个值 key&#xff0c;删除二叉搜索树中的 key 对应的节点&#xff0c;并保证二叉搜索树的性质不变。返回二叉搜索树&#xff08;有可能被更新&#xff09;的根节点的引用。 一般来说&#xff0c;…

代码控制写入excel文件

1、引言 在工作和学习中&#xff0c;我们经常使用到excel表格&#xff0c;有时候表格中的数据很多&#xff0c;此时我们就希望能够通过程序去控制某些表格数据的生成和修改&#xff0c;从而达到简化操作&#xff0c;缩减工作量的目的&#xff0c;这里就来简单实现一下对excel表…