diffusion扩散模型之hello world

news2024/9/17 7:20:13

以mnist图像生成样本为例,详细解释diffusion的每个步骤和过程

扩散模型包括两个过程:前向过程(forward process)和反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process),如下图所示。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),其中反向过程可以用来生成数据,这里我们将通过变分推断来进行建模和求解。
另外要指出的是,扩散过程往往是固定的,即采用一个预先定义好的variance schedule,比如DDPM就采用一个线性的variance schedule。
扩散过程的一个重要特性是我们可以直接基于原始数据对于任意步长的xt进行采样
在这里插入图片描述
记住这个公式,这个公式很重要,因为它可以直接从原图片推断出,经过t次加噪声之后的图像。

反向过程
扩散是将数据噪声化,那么反向的过程就是一个逐步去噪的过程,如果我们知道反向过程每一步的真实的噪声分布,那么从一个随机噪声开始就能生成一个真实的样本,所以反向过程就是生成数据的过程,我们可以用神经网络来估计这些噪声的分布。
推导公式的过程比较麻烦,我这里直接把bubbliiiing大佬结论放在这里
在这里插入图片描述
也就是说,我们加噪声的时候,可以一步到位直接从原始图像知道迭代了N步以后的噪声,但是图像还原的时候,要迭代着一步步计算之前的噪声,去噪,然后一步步把图片还原出来

这里我们先实现一个用于预测噪声图像的unet的网络结构
我们模型会有两个输入,一个是时间的步长,表示噪声加了多少步了,第二个是原图,但是时间的步长它是一个数值,比如0,1,2等,它是一个整数,这里我们需要把0,1,2等整数步长转化为离散类型的矩阵tensor才能跟模型的特征图做累加。
把整数转化为矩阵tensor的过程叫做embedding,在更复杂的任务里面它也可以是把文字,词组等转化为矩阵
这个实现,我们手动来实现一个embedding层,(其实就是把步长1,2,3等转化为矩阵)

from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np

class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things  
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)
## 对于我们的embedding层进行测试一下看看功能
if __name__ == "__main__":
    #当前步长为80,总步长为400
    N = 80
    total_T = 400
    x1 = torch.tensor([N/total_T])
    # 把1维的输入转化为128维的矩阵
    model1 = EmbedFC(1,128)
    y1 = model1(x1)
    print(y1.shape)

上面的输出为
torch.Size([1, 128])

接下来我们要实现unet的部分;Unet是一个对称结构的网络,很多博客都有,详情看看bubbliiiing大佬的博客,最早用于图像分割领域,这里我们用于预测我们的噪声图:
在这里插入图片描述
可以直接拉取别人的改一改也行:


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels==out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # this adds on correct residual in case channels have increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2 
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        '''
        process and downscale the image feature maps
        '''
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        '''
        process and upscale the image feature maps
        '''
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x
       
 
class Unet(nn.Module):
    def __init__(self, in_channels, n_feat = 256, n_classes=10):
        super(Unet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
      
        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x,t):


        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)
        print(t)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        # up2 = self.up1(up1, down2) # if want to avoid add and multiply embeddings
        up2 = self.up1(up1+ temb1, down2)  # add and multiply embeddings
        up3 = self.up2(up2+ temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out
 if __name__ == "__main__":
    x = torch.rand(1,1,28,28)

    t = torch.tensor([80/400])
    model = Unet(in_channels=1,n_feat=128)
    y = model(x,t)
    print(y.shape)

输出为:torch.Size([1, 1, 28, 28])
这里我们已经把EbedFC的代码加入UNet中用于对特征图中叠加我们的时间步长信息,这样子在我们在训练的过程中就有了随着时间波动的噪声图。到这里网络已经搭建完毕了。接下来我们要训练我们的加噪和去噪的过程。
第一步就是结合上面的公式,从原图直接推导出第N次加噪后的图像。
在这里插入图片描述
也就是这里:
实现代码为:

#先随机选取一个步长N,代表了加噪N次后的原图和噪声叠加的图像
n_T = 400 #加噪的总步长
#随机加噪声的步长
_ts = torch.randint(1, n_T+1, (x.shape[0],))
#噪声图
noise = torch.randn_like(x) # eps  ~ N(0, 1),这就是公式中的eps
sqrtab = torch.sqrt(alphabar_t)
sqrtmab = torch.sqrt(1 - alphabar_t)
x_t = (
            sqrtab[_ts, None, None, None] * x
            + sqrtmab[_ts, None, None, None] * noise
        )

## x_t就是经过了_ts步加噪以后的有x生成的噪声图

这里我们就推断出了经过N次加噪以后的合成图,我们的Unet的输入是利用加噪后的图像和步长信息反推出上一步的噪声图,但是上一步的噪声图其实就是服从标准正太分布N(0,1)的噪声图
所以我们利用这两个信息来计算一下loss

loss = nn.MSELoss(noise, uNet(x_t, _ts / self.n_T))

loss的计算就算是搞明白了,接下来我们开始计算恢复的过程
由于恢复图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/454563.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Taro React组件开发(9) —— RuiCountDown 倒计时

1. 需求实现 根据传入的格式,返回倒计时的文本字段;时间格式需要自定义,需要返回对应时间的值;对毫秒级的时间进行渲染;自定义时间的样式;手动控制倒计时的开始、暂停和重置。2. 需求实现 查找网上类似组件 uView CountDown 倒计时;由于 uView CountDown 倒计时 是使用 …

深度学习 - 42.特征交叉与 SetNET、Bilinear Interaction 与 FiBiNet

目录 一.引言 二.摘要 - ABSTRACT 三.介绍 - INTRODUCTION 四.相关工作 - RELATED WORK 1.因式分解机及其变体 - Factorization Machine and Its relevant variants 2. 基于深度学习的点击率模型 - Deep Learning based CTR Models 3.SENET Module 五.FiBiNet Model 1…

【嵌入式】HC32F定时器PWM捕获+APC芯片实现模拟AD采样

目录 一 项目背景 二 原理说明 三 设计实现——定时器初始化 四 设计实现——PWM捕获 五 梳理总结 一 项目背景 目前使用了TI的ADC采样芯片ADS1018实现模拟量4-20mA/0-20mA的采样,原理是将外部输入的模拟量信号4-20mA,经由并联的两个100Ω电阻&#…

day-01 one-day projects

个人名片: 😊作者简介:一名大一在校生,web前端开发专业 🤡 个人主页:python学不会123 🐼座右铭:懒惰受到的惩罚不仅仅是自己的失败,还有别人的成功。 🎅**学习…

AIGC席卷,抖快、阅文、知乎大战网文圈

配图来自Canva可画 成熟的网文市场,时不时进来一条鲶鱼。 经历了二十几个夏秋秋冬,网文市场形成了阅文、晋江、七猫、番茄等平台割据一方稳定的市场格局。后来暗自布局网文市场的知乎、抖音、快手等新玩家开始浮出水面,未来的市场纷争下或许…

Docker持久化方式-v和-volume的区别

docker数据的持久化一直用的是-v的方式,又叫Bind Mounts(目录绑定),偶然间发现还有一种通过卷轴来实现持久化的方式,翻了下资料,整理了一下两种方式使用的场景。 -v(Bind Mounts) …

使用 Apache PDFBox 操作PDF文件

简介 Apache PDFBox库是一个用于处理PDF文档的开源Java工具。该项目允许创建新的PDF文档,操作现有PDF文档,并从PDF文档中提取内容。Apache PDFBox还包括几个命令行实用程序。 Apache PDFBox的主要功能如下: 从PDF文件中提取Unicode文本。将…

浅析提高倾斜摄影超大场景的三维模型轻量化的数据质量关键技术

浅析提高倾斜摄影超大场景的三维模型轻量化的数据质量关键技术 倾斜摄影超大场景的三维模型轻量化的质量关键技术主要包括: 1、保持数据精度。在进行轻量化处理时,必须确保数据的精度不受损失,否则会影响后续分析和应用方案。因此&#xff0…

接口测试不再难。这篇文章让你在5分钟内掌握接口自动化测试用例

目录 摘要: 一、背景 二、测试用例设计 三、测试脚本实现 四、最佳实践和技巧 总结 摘要: 本文介绍了接口自动化测试的重要性,并提供了一个简单的测试用例,涵盖了设计、条件、步骤和数据方面的考虑。通过使用Python中的req…

C/C++|物联网开发入门+项目实战|函数输入与输出|值传递|地址传递|连续空间的传递|嵌入式C语言高级|C语言函数的使用(1)-学习笔记(11)

文章目录 函数概述输入参数示例:值传递地址传递连续空间的传递 参考: 麦子学院-嵌入式C语言高级-C语言函数的使用 函数概述 一堆代码的集合,用一个标签去描述它 复用化,降低冗余度 标签 ------ 函数名 函数和数组都属于内存空间&#xff0c…

C语言system讲解

‘system’是C语言标准库中的一个函数,它的作用是对计算机系统进行操作,如创建文件夹,打开文件夹,清空屏幕等等,下面介绍一下常用的几个system命令 system函数原型 int system(const char* command); command是字符…

联发科的好日子结束,出货量暴跌,高通稳住阵脚并开始反击

在手机芯片市场连续3年时间顺风顺水之后,联发科终于迎来了高通的反击,特别是骁龙8G2的发布更是导致联发科在手机芯片市场的步步后退,推动了高通的反弹。 一、形势有利于高通 高通此前的骁龙8G1和骁龙888因出现发热问题,因此被誉为…

4. 线性表

4. 线性表 线性表是最基本、最简单、也是最常用的一种数据结构(逻辑结构)。一个线性表是n个具有相同特性的数据元素的有限序列。 前驱元素: 若A元素在B元素的前面,则称A为B的前驱元素 后继元素: 若B元素在A元素的后面,则称B为…

【翻译一下官方文档】之uniapp的界面弹框交互

大致分 3 种 普通提示loading框弹出选项 我个人理解就是大致知道有些什么,有啥功能,用到的时候,直接去用,不会的回来翻看文档 uni.showToast(OBJECT) 参数类型必填说明平台差异说明titleString是提示的内容,长度与…

C++ 多态详解

目录 多态的概念 定义 C直接支持多态条件 举例 回顾继承中遇到的问题 虚函数-虚函数指针-虚函数列表 虚函数 虚函数指针 虚函数列表 虚函数调用流程 虚函数于普通成员函数的区别 多态的概念 定义 多态:相同的行为方式导致了不同的行为结果,同一行…

【翻译一下官方文档】之uniapp的.sync修饰符

先用一个案例引出.sync修饰符 就是这样一个场景 父组件直接修改状态A当然没问题,但是子组件不能直接修改状态A,因为单向数据流限制 单向数据流 uni-app官网 所有的 prop 都使得其父子 prop 之间形成了一个单向下行绑定:父级 prop 的更新会…

AFP vs SMB vs NFS: 谁是最好的数据传输协议?

目录 SMB: 什么是SMB 协议? NFS: 什么是NFS协议? AFP: 设么是AFP协议? 如何选择合适的传输协议? 场景1: 大型企业 场景2: 小型网站设计公司 场景3: Linux软件开发组 可以在互联网上使用这些协议吗? AFP vs SMB vs NFS …

Docker的安装和镜像容器的基本操作

Docker的安装和镜像容器的基本操作 Docker 概述Docker与虚拟机的区别namespace的六项隔离Docker核心概念 安装 DockerDocker 镜像操作搜索镜像获取镜像镜像加速下载查看镜像信息查看下载的镜像文件信息查看下载到本地的所有镜像根据镜像的唯一标识 ID 号,获取镜像详…

基于struts + spring + hibernate的题库与试卷管理系统源码

3需求分析和设计方案 3.1 题库管理 3.1.1 试题管理需求分析 试题管理是整个系统非常核心的模块,它基于知识点模块、章节模块、课程模块、题型管理模块完成的基础上的。其中核心元素是试题,通过试题将题库中的各模块连接起来。 试题管理分为题库录入和…

MyBatisPlus学习

官网:https://mp.baomidou.com/ MyBatis Plus,简化 MyBatis ! 1.概述 需要的基础:把我的MyBatis、Spring、SpringMVC就可以学习这个了! 为什么要学习它呢?MyBatisPlus可以节省我们大量工作时间&#xff0…