Diffusion 扩散模型(DDPM)详解及torch复现

news2024/7/6 19:24:26

文章目录

  • torch复现
    • 第1步:正向过程=噪声调度器
    • Step 2: 反向传播 = U-Net
    • Step 3: 损失函数
    • 采样
    • Training

我公众号文章目录综述:

https://wangguisen.blog.csdn.net/article/details/127065903

保姆级讲解 Diffusion 扩散模型(DDPM)

https://mp.weixin.qq.com/s?__biz=Mzk0MzIzODM5MA==&mid=2247486128&idx=1&sn=7ffef5d8c1bbf24565d0597eb5eaeb16&chksm=c337b729f4403e3f4ca4fcc1bc04704f72c1dc02876a2bf83c330e7857eba567864da6a64e18#rd

torch复现

先看下数据样子

import torch
import torchvision
import matplotlib.pyplot as plt
device = "cuda" if torch.cuda.is_available() else "cpu"

def show_images(datset, num_samples=20, cols=4):
    """ Plots some samples from the dataset """
    plt.figure(figsize=(15,15)) 
    for i, img in enumerate(data):
        if i == num_samples:
            break
        plt.subplot(int(num_samples/cols + 1), cols, i + 1)
        plt.imshow(img[0])

data = torchvision.datasets.StanfordCars(root=".", download=True)
show_images(data)

在这里插入图片描述

第1步:正向过程=噪声调度器

我们首先需要为我们的模型构建输入,这些输入是越来越多的噪声图像。我们可以使用论文中提供的封闭形式来单独计算任何时间步长的图像,而不是按顺序执行此操作。

详情点:

  • 可以预先计算噪声水平/方差
  • 有不同类型的方差表
  • 我们可以对每个时间步长的图像进行独立采样(高斯之和也是高斯的)
  • 在这个前进步骤中不需要模型
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """ 
    返回所传递的值列表vals中的特定索引,同时考虑到批处理维度。
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device=device):
    """ 
    接收一个图像和一个时间步长作为输入,并 返回它的噪声版本
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    #均值+方差
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


# 界定测试时间表
T = 300
betas = linear_beta_schedule(timesteps=T)

# 预先计算闭合形式的不同项
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

让我们在我们的数据集上测试一下:

from torchvision import transforms 
from torch.utils.data import DataLoader
import numpy as np

IMG_SIZE = 64
BATCH_SIZE = 128

# 数据转换
def load_transformed_dataset():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # Scales data into [0,1] 
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] 
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.StanfordCars(root=".", download=True, 
                                         transform=data_transform)

    test = torchvision.datasets.StanfordCars(root=".", download=True, 
                                         transform=data_transform, split='test')
    return torch.utils.data.ConcatDataset([train, test])

#tensor转化成图像
def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :] 
    plt.imshow(reverse_transforms(image))

data = load_transformed_dataset()
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

模拟正向扩散:

image = next(iter(dataloader))[0]

plt.figure(figsize=(15,15))
plt.axis('off')
num_images = 10
stepsize = int(T/num_images)

for idx in range(0, T, stepsize):
    t = torch.Tensor([idx]).type(torch.int64)
    plt.subplot(1, num_images+1, int(idx/stepsize) + 1)
    image, noise = forward_diffusion_sample(image, t)
    show_tensor_image(image)

在这里插入图片描述

Step 2: 反向传播 = U-Net

U-Net教程: https://amaarora.github.io/2020/09/13/unet.html.

详情点:

  • 我们使用一种简单形式的 UNet 来预测图像中的噪声
  • 输入是噪声图像,输出图像中的噪声
  • 因为参数是跨时间共享的,所以我们需要告诉网络我们在哪个时间步长
  • Timestep 由变压器 Sinusoidal Embedding 编码
  • 我们输出一个单一的值(均值),因为方差是固定的
from torch import nn
import math


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # 第一次卷积
        h = self.bnorm1(self.relu(self.conv1(x)))
        # 时间嵌入
        time_emb = self.relu(self.time_mlp(t))
        # 扩展到最后2个维度
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # 添加时间通道
        h = h + time_emb
        # 第二次卷积
        h = self.bnorm2(self.relu(self.conv2(h)))
        # 上采样或者下采样
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SimpleUnet(nn.Module):
    """
    Unet架构的一个简化版本
    """
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1 
        time_emb_dim = 32

        # 时间嵌入
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # 初始预估
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # 下采样
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # 上采样
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)

    def forward(self, x, timestep):
        # 时间嵌入
        t = self.time_mlp(timestep)
        # 初始卷积
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # 添加残差结构作为额外的通道
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)

model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
model

Step 3: 损失函数

def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)

采样

  • 如果不加入@torch.no_grad(),我们很快就会耗尽内存,因为pytorch会把之前所有的图像都打包用于梯度计算

  • 因为我们预先计算了前向通道的噪声方差,所以当我们依次执行后向过程时也必须使用这些方差

@torch.no_grad()#防止内存爆炸
def sample_timestep(x, t):
    """
    调用模型来预测图像中的噪声,并返回 
    去噪后的图像。
    如果我们还没有进入最后一步,则对该图像施加噪声。
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
    # 调用模型(当前图像--噪声预测)。
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

@torch.no_grad()
def sample_plot_image():
    # 样本噪声
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            plt.subplot(1, num_images, int(i/stepsize+1))
            show_tensor_image(img.detach().cpu())
    plt.show()            

Training

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 100 # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
      loss = get_loss(model, batch[0], t)
      loss.backward()
      optimizer.step()

      if epoch % 5 == 0 and step == 0:
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
        sample_plot_image()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/190460.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python面向对象:入门

python面向对象:入门 文章目录python面向对象:入门一、实验目的二、实验原理三、实验环境四、实验内容五、实验步骤1.创建类和对象2.构造方法3.类中的实例方法一、实验目的 掌握类的基本使用 二、实验原理 面向过程:根据业务逻辑从上到下写…

Win10 BCD文件损坏怎么修复?

在Windows 10操作系统中,BCD代表引导配置数据(Boot Configuration Data),Windows运行时BCD将告诉Windows引导加载程序在哪里查找引导信息, 因此,它对成功地加载和运行操作系统是非常重要的。哪些情况下会导…

二叉树的前序遍历

题目144. 二叉树的前序遍历 - 力扣(LeetCode)本章我们来探讨运用二叉树中所学的知识解决本题。对二叉树仍有疑问的小伙伴可以点击下方链接哦。参考文献:(1条消息) 二叉树(三)_染柒_GRQ的博客-CSDN博客原理首先我们来回…

CyberBattleSim-(内网自动化渗透)研究分析

**01 **背景知识介绍 CyberBattleSim介绍 CyberBattleSim是一款微软365 Defender团队开源的人工智能攻防对抗模拟工具,来源于微软的一个实验性研究项目。该项目专注于对网络攻击入侵后横向移动阶段进行威胁建模,用于研究在模拟抽象企业网络环境中运行的…

基于蜣螂算法优化的核极限学习机(KELM)分类算法-附代码

基于蜣螂算法优化的核极限学习机(KELM)分类算法 文章目录基于蜣螂算法优化的核极限学习机(KELM)分类算法1.KELM理论基础2.分类问题3.基于蜣螂算法优化的KELM4.测试结果5.Matlab代码摘要:本文利用蜣螂算法对核极限学习机(KELM)进行优化,并用于分类1.KELM理…

R语言贝叶斯方法在生态环境领域中的应用

贝叶斯统计已经被广泛应用到物理学、生态学、心理学、计算机、哲学等各个学术领域,其火爆程度已经跨越了学术圈,如促使其自成统计江湖一派的贝叶斯定理在热播美剧《The Big Bang Theory》中都要秀一把。贝叶斯统计学即贝叶斯学派是一门基本思想与传统基于…

STM32CubeMx使用FreeRTOS搭建printf输出串口打印-----基于正点原子开发板阿波罗

文章目录STM32CubeMx使用FreeRTOS搭建printf输出串口打印-----基于正点原子开发板阿波罗1.输入目标芯片2.选择RCC时钟3.配置调试模式4.USART的配置5.配置中断6.printf的重定向功能7.代码添加8.修改中断函数9.添加全局变量10.增加FreeRTOS支持11.在FreeRTOS中添加源码STM32CubeM…

【数学建模】数学建模中的常用工具推荐

前言 整理了几款我在建模比赛中需要准备的小工具,后续会随时不定期更新,以及完善内容,需要的小伙伴建议收藏一波~ [1] 公式编译器 Axmath(建议购买正版)mathtypeWord内置的公式编辑器 Axmath是国产的软件&#xff0…

pytorch安装(模式识别与图像处理课程实验)

pytorch安装(模式识别与图像处理课程实验)1、 打开cmd,创建torch虚拟环境。2、 激活创建的torch虚拟环境2.1、 进入pytorch官网,复制如下的命令,进行pytorch的安装2.2、测试安装是否成功3、 通过pip命令安装pytorch&am…

基于WPS实现Excel表的二级下拉选择框

基于WPS实现Excel表的二级下拉选择框第一步:先在sheet2上创建源数据第二步:创建一级下拉框第三步:创建二级下拉框报错记录: “列表源”XXXXXX第一步:先在sheet2上创建源数据 第二步:创建一级下拉框 一级下…

难受啊,备战字节跳动132天,因为一个疏忽让我前功尽弃...

📌 博客主页: 程序员二黑 📌 专注于软件测试领域相关技术实践和思考,持续分享自动化软件测试开发干货知识! 📌 如果你也想学习软件测试,文末卡片有我的交流群,加入我们,一…

Kibana报错:Kibana server is not ready yet

背景 网页中访问kinaba http://localhost:5601 ,一直提示“Kibana server is not ready yet”。 执行如下命令查看kibana日志, docker logs kibana 发现有提示: 正文 怀疑是不是容器重启后,各容器内部ip变化了导致。 1、故执行如下命令查看…

Android框架源码分析——从设计模式角度看 RxJava 核心源码

从设计模式角度来看 RxJava 核心源码 从订阅者模式和装饰器模式的角度来看 RxJava 源码。 1. 普通订阅者模式与 RxJava 中的订阅者模式 订阅者模式又叫做观察者模式,主要用于实现响应式编程。其本质,就是接口回调。 普通订阅者模式:多个对…

来啦,华东师范大学2024年入学MBA提前面试流程及时间

项目简介华东师范大学系国家教育部和国务院学位办【2007(36)号】批准的工商管理硕士(MBA)培养单位。华东师范大学MBA项目,依托学校深厚的人文底蕴和育人文化,利用多学科支撑的优势、利用多元化的办学资源&a…

CUAD学习笔记

目录一、头文件**1、mex.h****2、matrix.h****3、string****4、iostream****5、omp.h**6、cuda_runtime.h7、stdlib.h8、sys/time.h9、stdio.h10、string.h11、time.h12、math.h13、device_launch_parameters.h二、一些声明语句1、using namespace std**2、typedef unsigned ch…

pytorch数据读取深入理解

来源:投稿 作者:小灰灰 编辑:学姐 了解数据 Q:我现在什么基础也没有,我要学习深度学习,学习cv,学习nlp。 A:首先我们知道,深度学习是建立在数据集的基础上。现在呢&…

【C++】C++11 ~ 右值引用和移动语义

🌈欢迎来到C专栏~~右值引用和移动语义 (꒪ꇴ꒪(꒪ꇴ꒪ )🐣,我是Scort目前状态:大三非科班啃C中🌍博客主页:张小姐的猫~江湖背景快上车🚘,握好方向盘跟我有一起打天下嘞!送给自己的一…

mongodb 使用密钥文件身份验证部署副本集

一 副本集介绍 集群中每个节点有心跳检测 如果由于资源限制,可以部署一主一从一仲裁 副本集集群可以实现主从的自动切换 Read Preference 在客户端连接中,可以实现读取优先,就是连接器会自动判断,把读取请求发送到副本集中的…

whois命令常见用法

whois命令常见用法whois命令简介安装whoisWindows使用whoisLinux安装whoiswhois常见用法Linux下whois查询域名注册信息whois命令简介 whois就是一个用来查询域名是否已经被注册,以及注册域名的详细信息的数据库(如域名所有人、域名注册商)。…

分析第一个安卓项目

整体分析 .gradle和.idea 这两个目录下放置的都是Android Studio自动生成的一些文件,我们无须关心,也不要去手动编辑。 app 项目中的代码、资源等内容几乎都是放置在这个目录下的。 gradle 这个目录下包含了gradle wrapper的配置文件,使…