从头开始制作扩散模型(实现快速扩散模型的简单方法)

news2025/1/16 14:57:48

一、说明

        本文是关于自己从头开始构建扩散模型的教程。我总是喜欢让事情变得简单易行,所以在这里,我们避免了复杂的数学。这不是一个正常的扩散模型。相反,我称之为快速扩散模型。将仅使用卷积神经网络(CNN)来制作扩散模型。在本文中,我不会为您提供任何现有的模型/权重/脚本文件。


您需要自己训练模型。
(我们正在使用TensorFlow提供的CIFAR-10数据集。

你可以在我的 GitHub
https://github.com/Seachaos/Tree.Rocks/blob/main/QuickDiffusionModel/QuickDiffusionModel.ipynb中找到代码

二、这个想法

        这就是扩散模型的工作原理:它就像基于一个完全嘈杂的图像,并逐渐提高图像质量,直到它变得清晰。
(如下图所示)

扩散模型示例改善了图像

        因此,我们可以创建一个深度学习模型,可以提高图像质量(从全噪声到清晰的图像),流程思想:

快速扩散模型流程

        为了更清晰地了解,请查看此附加流程图。

图像在扩散模型中的流动方式

        如上图所示,该模型正在尝试生成噪声逐渐减少的图像。现在,我们只需要训练一个深度学习模型来学习如何减少噪音。
        对于该任务,我们需要模型中的两个输入:

  • 输入图像 — 需要处理噪声图像
  • 时间戳 — 告诉模型什么是噪声状态,以便更容易学习

三、实现快速扩散模型

        首先,让我们导入我们需要的内容:

import numpy as np

from tqdm.auto import trange, tqdm
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers

        并准备我们的数据集, 在本教程中,我们将使用大量汽车图像(CIFAR-10)作为示例,以使事情尽可能简单快捷。
(但是,如果您有足够的样本,则可以选择您喜欢的任何图像。

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train[y_train.squeeze() == 1]
X_train = (X_train / 127.5) - 1.0

        接下来,让我们定义变量。

IMG_SIZE = 32     # input image size, CIFAR-10 is 32x32
BATCH_SIZE = 128  # for training batch size
timesteps = 16    # how many steps for a noisy image into clear
time_bar = 1 - np.linspace(0, 1.0, timesteps + 1) # linspace for timesteps

        在这里,我们设置“时间步长”,这意味着我们的模型将学习通过训练过程生成从嘈杂(级别 0)到清晰(级别 16)的图像。

让我们看一张图片以获得更清晰的想法

plt.plot(time_bar, label='Noise')
plt.plot(1 - time_bar, label='Clarity')
plt.legend()
图像噪点和清晰度随时间步长的变化

        如您所见,从时间步长 0 到 16,噪音减少,清晰度逐渐提高。这就是我们希望我们的模型学习的内容。

        并为预览数据准备一些功能

def cvtImg(img):
    img = img - img.min()
    img = (img / img.max())
    return img.astype(np.float32)

def show_examples(x):
    plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i+1)
        img = cvtImg(x[i])
        plt.imshow(img)
        plt.axis('off')

show_examples(X_train)

CIFAR-10 汽车

3.1 培训准备

        在这里,我们需要准备训练图像的代码。

        这个想法是从随机时间点获得两个图像(A和B),其中A是噪声图像,B是更清晰的图像。
我们的模型将学习根据该特定时间点将A转换为B(从嘈杂到更清晰)。
(再次作为此图)

图像 A 在上面,图像 B 在下面

        因此,我们在这里forward_noise功能。

def forward_noise(x, t):
    a = time_bar[t]      # base on t
    b = time_bar[t + 1]  # image for t + 1
    
    noise = np.random.normal(size=x.shape)  # noise mask
    a = a.reshape((-1, 1, 1, 1))
    b = b.reshape((-1, 1, 1, 1))
    img_a = x * (1 - a) + noise * a
    img_b = x * (1 - b) + noise * b
    return img_a, img_b
    
def generate_ts(num):
    return np.random.randint(0, timesteps, size=num)

# t = np.full((25,), timesteps - 1) # if you want see clarity
# t = np.full((25,), 0)             # if you want see noisy
t = generate_ts(25)             # random for training data
a, b = forward_noise(X_train[:25], t)
show_examples(a)

        如果你想了解它是如何工作的,我建议运行我注释掉的代码。( t = ... )

预览训练数据示例

3.2 构建 CNN 块

        我们将使用 U-Net 作为我们的模型,详细信息将在下面的代码中解释。

        模型架构,详细内容会在后面的代码中讲解,在构建模型之前,我们需要先定义块。
        这是 make 块的代码:

def block(x_img, x_ts):
    x_parameter = layers.Conv2D(128, kernel_size=3, padding='same')(x_img)
    x_parameter = layers.Activation('relu')(x_parameter)

    time_parameter = layers.Dense(128)(x_ts)
    time_parameter = layers.Activation('relu')(time_parameter)
    time_parameter = layers.Reshape((1, 1, 128))(time_parameter)
    x_parameter = x_parameter * time_parameter
    
    # -----
    x_out = layers.Conv2D(128, kernel_size=3, padding='same')(x_img)
    x_out = x_out + x_parameter
    x_out = layers.LayerNormalization()(x_out)
    x_out = layers.Activation('relu')(x_out)
    
    return x_out

        每个块包含两个带有时间参数的卷积网络,允许网络确定其当前的时间步长并输出相应的信息。
        您可以看到块流程图:
                (x_img 是输入图像,是噪声图像,x_ts 是时间步长的输入)

块的流向

搭建模型,现在我们可以构建我们的模型

def make_model():
    x = x_input = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='x_input')
    
    x_ts = x_ts_input = layers.Input(shape=(1,), name='x_ts_input')
    x_ts = layers.Dense(192)(x_ts)
    x_ts = layers.LayerNormalization()(x_ts)
    x_ts = layers.Activation('relu')(x_ts)
    
    # ----- left ( down ) -----
    x = x32 = block(x, x_ts)
    x = layers.MaxPool2D(2)(x)
    
    x = x16 = block(x, x_ts)
    x = layers.MaxPool2D(2)(x)
    
    x = x8 = block(x, x_ts)
    x = layers.MaxPool2D(2)(x)
    
    x = x4 = block(x, x_ts)
    
    # ----- MLP -----
    x = layers.Flatten()(x)
    x = layers.Concatenate()([x, x_ts])
    x = layers.Dense(128)(x)
    x = layers.LayerNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Dense(4 * 4 * 32)(x)
    x = layers.LayerNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Reshape((4, 4, 32))(x)
    
    # ----- right ( up ) -----
    x = layers.Concatenate()([x, x4])
    x = block(x, x_ts)
    x = layers.UpSampling2D(2)(x)
    
    x = layers.Concatenate()([x, x8])
    x = block(x, x_ts)
    x = layers.UpSampling2D(2)(x)
    
    x = layers.Concatenate()([x, x16])
    x = block(x, x_ts)
    x = layers.UpSampling2D(2)(x)
    
    x = layers.Concatenate()([x, x32])
    x = block(x, x_ts)
    
    # ----- output -----
    x = layers.Conv2D(3, kernel_size=1, padding='same')(x)
    model = tf.keras.models.Model([x_input, x_ts_input], x)
    return model

model = make_model()
# model.summary()

这是一个U-Net,左、右、MLP部分可以参考上图(模型架构)。

不要忘记编译模型

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)
loss_func = tf.keras.losses.MeanAbsoluteError()
model.compile(loss=loss_func, optimizer=optimizer)

        我们使用 Adam 作为优化器,使用 MeanAbsoluteError (MAE) 作为损失函数。

        预测结果:我们现在可以尝试我们的第一个预测。预测步骤如下:

  1. 创建嘈杂的图像
  2. 以时间步长输入到我们的模型中
  3. 继续这样做直到时间步结束

        所以这是这个函数:

def predict(x_idx=None):
    x = np.random.normal(size=(32, IMG_SIZE, IMG_SIZE, 3))
    for i in trange(timesteps):
        t = i
        x = model.predict([x, np.full((32), t)], verbose=0)
    show_examples(x)

predict()

        未经训练的模型输出图像 上面是我们的未经训练的模型输出,如您所见,它没有任何用处。 这个函数还可以帮助我们查看每个步骤:

def predict_step():
    xs = []
    x = np.random.normal(size=(8, IMG_SIZE, IMG_SIZE, 3))

    for i in trange(timesteps):
        t = i
        x = model.predict([x, np.full((8),  t)], verbose=0)
        if i % 2 == 0:
            xs.append(x[0])

    plt.figure(figsize=(20, 2))
    for i in range(len(xs)):
        plt.subplot(1, len(xs), i+1)
        plt.imshow(cvtImg(xs[i]))
        plt.title(f'{i}')
        plt.axis('off')

predict_step()
未经训练的模型输出步骤

四、训练模型

        这个训练功能很简单

def train_one(x_img):
    x_ts = generate_ts(len(x_img))
    x_a, x_b = forward_noise(x_img, x_ts)
    loss = model.train_on_batch([x_a, x_ts], x_b)
    return loss

        我们只需要提供x_tsx_img(x_a),使我们的模型能够学习如何生成x_b。

        并使其成为纪元函数

def train(R=50):
    bar = trange(R)
    total = 100
    for i in bar:
        for j in range(total):
            x_img = X_train[np.random.randint(len(X_train), size=BATCH_SIZE)]
            loss = train_one(x_img)
            pg = (j / total) * 100
            if j % 5 == 0:
                bar.set_description(f'loss: {loss:.5f}, p: {pg:.2f}%')

        最后,多次运行并逐渐降低学习率

for _ in range(10):
    train()
    # reduce learning rate for next training
    model.optimizer.learning_rate = max(0.000001, model.optimizer.learning_rate * 0.9)

    # show result 
    predict()
    predict_step()
    plt.show()

        你可以得到一些这样的输出图像

快速扩散模型输出示例

五、结论

        本教程设计简单,允许您进行实验。您可以尝试自己的参数(如更改图像大小,CNN过滤器,时间步长或MLP等)和更多的时期训练以获得更好的结果。海沌

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1016302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LwIP介绍

文章目录 一、LwIP简介二、LwIP主要特性:三、文件说明lwip-2.1.3contrib-2.1.0一、LwIP简介 lwIP(Light weight IP)是瑞典计算机科学院(SICS)的Adam Dunkels 开发的一个小型开源的TCP/IP协议栈。LwIP是Light Weight (轻型)IP协议,有无操作系统的支持都可以运行。LwIP 的设…

【Obsidian】中编辑模式和阅读模式光标乱跳问题以及编辑模式中段落聚集的问题解决

前言 最近用Obsidian 软件写md笔记,但是当我分别使用编辑模式和阅读模式时出现了光标乱跳的问题。比如我在编辑模式,光标停留在第500行,但是切换成编辑模式就变成了1000行。而且光标根本没停在原来的位置。这样重新定位非常麻烦。 两种阅读…

mybati缓存了解

title: “mybati缓存了解” createTime: 2021-12-08T12:19:5708:00 updateTime: 2021-12-08T12:19:5708:00 draft: false author: “ggball” tags: [“mybatis”] categories: [“java”] description: “mybati缓存了解” mybatis的缓存 首先来看下mybatis对缓存的规范&…

Web服务(Web Service)

简介 Web服务(Web Service)是一种Web应用开发技术,用XML描述、发布、发现Web服务。它可以跨平台、进行分布式部署。 Web服务包含了一套标准,例如SOAP、WSDL、UDDI,定义了应用程序如何在Web上实现互操作。 Web服务的服…

第十九章、【Linux】开机流程、模块管理与Loader

19.1.1 开机流程一览 以个人计算机架设的 Linux 主机为例,当你按下电源按键后计算机硬件会主动的读取 BIOS 或 UEFI BIOS 来载入硬件信息及进行硬件系统的自我测试, 之后系统会主动的去读取第一个可开机的设备 (由 BIOS 设置的) …

线程安全问题的原因及解决方案

要想知道线程安全问题的原因及解决方案,首先得知道什么是线程安全,想给出一个线程安全的确切定义是复杂的,但我们可以这样认为:如果多线程环境下代码运行的结果是符合我们预期的,即在单线程环境应该的结果,…

基于 IntelliJ 的 IDE 将提供 Wayland 支持

导读对于使用 IntelliJ 开发环境的用户,JetBrains 一直致力于提供原生 Wayland 支持。 JetBrains 正在致力于为基于 IntelliJ 的 IDE 提供 Wayland 支持,以增强 Linux 桌面体验以及在 Windows Subsystem for Linux 下运行。 Wayland 支持功能尚未完成&…

Jmeter性能实战之分布式压测

分布式执行原理 1、JMeter分布式测试时,选择其中一台作为调度机(master),其它机器作为执行机(slave)。 2、执行时,master会把脚本发送到每台slave上,slave 拿到脚本后就开始执行,slave执行时不需要启动GUI&#xff0…

专栏十:10X单细胞的聚类树绘图

经常在文章中看到对细胞群进行聚类,以证明两个cluster之间的相关性,这里总结两种绘制这种图的方式和代码,当然我觉得这些五颜六色的颜色可能是后期加的,本帖子只总结画树状图的方法 例一 文章Single-cell analyses implicate ascites in remodeling the ecosystems of pr…

zemax慧差与消慧差

基础设置: 该表面用于对系统的波前进行调制,得到想要的波前形状 通过理想透镜的光线在像空间聚焦,得到完美的球面波,经过调制可以模拟出任意的像差 这里的系数为泽尼克系数 1:平移 2:x轴倾斜 3&#x…

C盘简易无门槛清理指南

C盘在日常使用过程中会逐渐越来越少明明什么也没装,C盘空间却满了,导致最后爆满出现系统运行变慢,软件卡等现象。但随便删除一些东西,系统就崩溃了。本篇分析原因和介绍一些解决方法。 爆满原因主要分为四大类: 一&a…

浅谈C++|文件篇

引子&#xff1a; 程序运行时产生的数据都属于临时数据&#xff0c;程序一旦运行结束都会被释放通过文件可以将数据持久化。C中对文件操作需要包含头文件< fstream > 。 C提供了丰富的文件操作功能&#xff0c;你可以使用标准库中的fstream库来进行文件的读取、写入和定位…

Mobirise for Mac:轻松创建手机网站的手机网站建设软件

如果你是一位设计师或者开发人员&#xff0c;正在寻找一款强大的手机网站建设软件&#xff0c;那么Mobirise for Mac绝对值得你尝试。这个独特的应用程序将帮助你轻松创建优雅而实用的手机网站&#xff0c;而无需编写复杂的代码。 Mobirise for Mac的主要特点包括&#xff1a;…

Java ReentrantLock锁源码走读

目录 多线程例子程序&#xff1a;两个线程累加共享变量&#xff0c;结果正确非公平锁加锁&#xff08;即 lock.lock();&#xff09;过程非公平锁解锁&#xff08; lock.unlock();&#xff09;过程公平锁公平锁的加锁逻辑公平锁的释放锁逻辑 多线程例子程序&#xff1a;两个线程…

【JavaSE笔记】继承与多态(万字详解)

一、前言 在Java的核心概念中&#xff0c;继承和多态无疑是重要的一环。它们都是Java以及其他许多面向对象编程语言的基石&#xff0c;为我们提供了强大的工具来创建模块化&#xff0c;可重用和易于维护的代码。继承让我们可以创建新的类&#xff0c;通过继承现有类的属性和方…

关于单片机的分频定时器的记录

记录一内部时钟&#xff1a; 对于单片机的频率原来一直不太明白&#xff0c;现在在学习进行记录&#xff1a; 主频&#xff1a; 以一个72M的STM32单片机作为主频为例子&#xff0c;这个72M主频说得是一秒钟产生72000000&#xff08;七千两百万&#xff09;个脉冲或周期&…

POLARDB IMCI 白皮书 云原生HTAP 数据库系统 一 数据压缩打更新 (本篇有数据到列节点异步但不延迟的解释)...

开头还是介绍一下群&#xff0c;如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;在新加的朋友会分到2群&#xff08;共…

使用ChatGPT和Blender绘制金色球的完整指南

简介&#xff1a; 在本篇博客中&#xff0c;我们将了解如何结合使用ChatGPT和Blender来创建一个金色的球体。ChatGPT是OpenAI开发的强大自然语言处理模型&#xff0c;而Blender则是一款流行的3D建模和渲染软件。通过结合这两个工具&#xff0c;您可以获得详细的指导&#xff0c…

【JavaEE】_JavaScript(WebAPI)

目录 1. DOM 1.1 DOM基本概念 1.2 DOM树 2. 选中页面元素 2.1 querySelector 2.2 querySelectorAll 3. 事件 3.1 基本概念 3.2 事件的三要素 3.3 示例 4.操作元素 4.1 获取/修改元素内容 4.2 获取/修改元素属性 4.3 获取/修改表单元素属性 4.3.1 value&#xf…

04条件构造器和常用接口

条件构造器和常用接口 wapper介绍 条件构造器的两个条件之间默认就是AND并列关系,如果需要或者的关系则需要调用构造器的or()方法 条件构造器类型作用Wrapper条件构造抽象类,最顶端父类AbstractWrapper生成SQL的where条件QueryWrapper封装查询或删除的条件UpdateWrapper封装修…