Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成

news2024/12/30 3:08:21

1、绪论

1.1 概念

为便于后文讨论,首先进行相关概念的陈述。

  • Stable Diffusion:Stable Diffusion 是一个在图像生成领域广泛使用的技术,尤其是用于文本到图像的转换。它基于扩散模型(Diffusion
    Models),这是一种深度生成模型,通过逐步去除图像中的噪声来生成新图像。Stable Diffusion
    是一种特定的实现,以其稳定性和高质量的图像生成能力而闻名。

  • KerasCV:Keras 是一个高级神经网络API,能够运行在 TensorFlow、Theano 或 CNTK 之上。KerasCV 是 Keras 的一个扩展库,专注于计算机视觉任务。虽然目前并没有一个官方的 KerasCV 库被广泛接受,但这样的命名通常意味着一个专注于计算机视觉的 Keras 扩展。

  • 高性能图像生成:这指的是使用高效且强大的计算方法来生成高质量的图像。在深度学习和计算机视觉中,高性能通常意味着模型能够在较短时间内处理大量数据,同时保持或提高生成的图像质量。

1.2 本文探讨的范围

本文将展示如何使用KerasCV实现的稳定性.ai的文本到图像模型Stable Diffusion,根据文本提示生成新颖的图像。

Stable Diffusion是一个功能强大、开源的文本到图像生成模型。尽管存在多个开源实现,可以让您轻松地从文本提示创建图像,但KerasCV的实现提供了一些独特优势。这些包括XLA编译和混合精度支持,这两者结合在一起实现了最先进的生成速度。

本文将探索KerasCV的Stable Diffusion实现,展示如何使用这些强大的性能提升,并探索它们提供的性能优势。

注意:如果要在torch后端运行,请在所有地方将jit_compile设置为False。Stable Diffusion的XLA编译目前不适用于torch。

1.3 Stable Diffusion的优缺点

Stable Diffusion 是一种强大的开源文本到图像生成模型,它使用扩散模型(Diffusion Models)来逐步去除图像中的噪声,从而生成新的图像。以下是 Stable Diffusion 的一些主要优点和缺点:

优点:

  • 高质量输出:Stable Diffusion 能够根据文本描述生成高分辨率、逼真的图像,满足各种设计需求。
  • 可控的生成过程:通过调整扩散参数和逆向过程,可以控制生成过程的速度和效果,使用户能够根据自己的需求和偏好生成图像。
  • 广泛的适用性:该模型适用于各种领域,如艺术创作、产品设计、虚拟现实等,为这些领域提供了更多的创意和可能性。
  • 可扩展性:随着深度学习技术的不断进步,Stable Diffusion 的性能和效果有望进一步提升。
  • 高稳定性:Stable Diffusion 通过引入新的稳定性系数来控制模型的稳定性,避免了其他扩散模型中出现的不稳定性问题。
  • 较快的训练速度:通过使用更小的批次大小和更少的步骤来训练模型,Stable Diffusion 提高了训练速度。

缺点:

  • 计算资源需求高:Stable Diffusion 的训练和生成过程需要大量的计算资源和时间,尤其是在生成高分辨率图像时。
  • 超参数调优:Stable Diffusion 的性能和效果可能受到许多超参数的影响,如扩散参数、模型架构等,因此需要进行细致的调优。
  • 牺牲多样性:由于引入了稳定性系数,Stable Diffusion 可能会在一定程度上牺牲生成样本的多样性。
  • 生成样本速度可能较慢:虽然训练速度有所提升,但在生成样本时,Stable Diffusion 的速度可能会相对较慢。

总的来说,Stable Diffusion 是一种功能强大的文本到图像生成模型,具有广泛的应用前景。然而,它也面临一些挑战,如计算资源需求高和需要细致的超参数调优。随着技术的不断发展,这些问题有望得到解决。

1.4 设置

安装

!pip install -q --upgrade keras-cv
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import time
import keras_cv
import keras
import matplotlib.pyplot as plt

2、Stable Diffusion的使用方法

2.1 构建模型

使用类似如下的代码构建模型

model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=False
)

2.2 编辑文本提示

按照以下代码的方式给模型输入文本提示

images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)


def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


plot_images(images)

在这里插入图片描述
下面,我们看一个更复杂的示例

images = model.text_to_image(
    "cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(images)

在这里插入图片描述

3、工作原理

Stable Diffusion 并不是基于魔法的,它是一种“潜在扩散模型”。让我们深入了解一下这意味着什么。

程序员可能对超分辨率有所了解:可以通过训练一个深度学习模型来对输入图像进行去噪,从而将其转换为更高分辨率的版本。深度学习模型并不是通过神奇地恢复从嘈杂的低分辨率输入中缺失的信息来实现这一点的——相反,该模型使用其训练数据分布来“想象”出最可能基于输入的视觉细节。要了解更多关于超分辨率的信息,你可以查看以下 Keras.io 教程:

  • 使用高效的子像素卷积神经网络的图像超分辨率
  • 用于单图像超分辨率的增强深度残差网络

当程序员将这个想法推向极致时,可能会开始思考——如果我们只在纯噪声上运行这样的模型会怎样?模型将“对噪声进行去噪”并开始生成一个全新的图像。通过多次重复这个过程,程序员可以将一小块噪声转变为越来越清晰和高分辨率的人造图片。

这2020 年在《High-Resolution Image Synthesis with Latent Diffusion Models》中提出的潜在扩散的关键思想。要深入了解扩散模型,程序员可以查看 Keras.io 上的教程《Denoising Diffusion Implicit Models》。

要将潜在扩散模型转化为文本到图像系统,程序员还需要添加一个关键功能:通过提示关键字控制生成的视觉内容。这是通过“条件化”(conditioning)实现的,这是一种经典的深度学习技术,包括将表示文本的向量与噪声块连接起来,然后在包含 {图像: 描述} 对的数据集上训练模型。

这就产生了 Stable Diffusion 架构。Stable Diffusion 由三部分组成:

  • 文本编码器,将你的提示转换为潜在向量。
  • 扩散模型,反复“去噪”一个 64x64 的潜在图像块。
  • 解码器,将最终的 64x64 潜在块转换为更高分辨率的 512x512 图像。

首先,程序员的文本提示被文本编码器投影到潜在向量空间,这只是一个预训练并固定的语言模型。然后,该提示向量与随机生成的噪声块连接起来,通过一系列的“步骤”被扩散模型反复“去噪”(你运行的步骤越多,图像就越清晰、越漂亮——默认值是 50 步)。

最后,64x64 的潜在图像通过解码器以高分辨率正确渲染。

总的来说,这是一个相当简单的系统——Keras 实现包含在四个文件中,总共不到 500 行代码:

  • text_encoder.py: 87 行代码
  • diffusion_model.py: 181 行代码
  • decoder.py: 86 行代码
  • stable_diffusion.py: 106 行代码

但是,一旦程序员在数十亿张图片及其描述上进行训练,这个相对简单的系统就开始看起来像魔法一样。就像费曼(Feynman)对宇宙的描述:“它并不复杂,只是有很多!”

3.1 KerasCV的Stable Diffusion模型的优点

在多个公开的Stable Diffusion实现中,为什么你应该选择keras_cv.models.StableDiffusion呢?

除了易于使用的API之外,KerasCV的Stable Diffusion模型还带有一些强大的优势,包括:

  • 图模式执行(Graph mode execution):这种方式可以更高效地执行计算图,因为它允许优化和并行化计算。
  • 通过jit_compile=True启用XLA编译:XLA(Accelerated Linear Algebra)是TensorFlow的一个特性,它可以将多个计算步骤融合为一个优化的操作,从而提高执行速度。
  • 支持混合精度计算:混合精度计算允许模型在较低精度的数据类型上运行,如FP16(半精度浮点数),这通常可以提高速度并减少内存消耗,同时保持模型精度。

当这些特性组合在一起时,KerasCV的Stable Diffusion模型运行速度比原始实现快得多。为了进行比较,我们运行了基准测试,比较了HuggingFace的diffusers实现的Stable Diffusion与KerasCV实现的运行时。两个实现的任务都是为每张图像生成3张图像,每张图像的步数为50。在这个基准测试中,我们使用了Tesla T4 GPU。

我们所有的基准测试都是开源的,可以在GitHub上找到,并且可以在Colab上重新运行以重现结果。基准测试的结果如下表所示:

GPU ModelImplementationRuntime (s)
Tesla T4KerasCV (Warm Start)28.97
Tesla T4diffusers (Warm Start)41.33
Tesla V100KerasCV (Warm Start)12.45
Tesla V100diffusers (Warm Start)12.72

在Tesla T4上,执行时间提高了30%!尽管在V100上的提升幅度较小,但我们通常期望在所有的NVIDIA GPU上,基准测试的结果都将一致地偏向于KerasCV。

为了完整性,我们报告了冷启动和热启动的生成时间。冷启动执行时间包括模型创建和编译的一次性成本,因此在生产环境中(在那里你会多次重用相同的模型实例)它是可以忽略的。尽管如此,以下是冷启动的数据:

GPUModelRuntime (Cold Start)
Tesla T4KerasCV83.47s
Tesla T4diffusers46.27s
Tesla V100KerasCV76.43s
Tesla V100diffusers13.90s

注意:每个优化在硬件设置之间的性能提升可能会有很大的差异。

3.2 模型的基准测试方法

为了后续的讨论,我们先来对未优化模型进行基准测试:

benchmark_result = []
start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.

在这里插入图片描述

3.3混合精度

“混合精度”是指在计算时使用float16精度,同时将权重存储在float32格式中。这是为了利用现代NVIDIA GPU上float16操作比float32操作具有显著更快内核支持的事实。

在Keras中启用混合精度计算只需调用:

keras.mixed_precision.set_global_policy("mixed_float16")

之后,我们就能够非常简单的使用float16了:

model = keras_cv.models.StableDiffusion(jit_compile=False)

print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
    "Variable dtype:",
    model.diffusion_model.variable_dtype,
)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
Compute dtype: float16
Variable dtype: float32

这样,上面构建的模型现在使用混合精度计算;在计算时使用float16操作的速度优势,同时以float32精度存储变量。

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()

在这里插入图片描述

3.4XLA编译

TensorFlow和JAX内置了XLA(加速线性代数)编译器。keras_cv.models.StableDiffusion原生支持jit_compile参数。将此参数设置为True可以启用XLA编译,从而显著提高速度。

# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")

model = keras_cv.models.StableDiffusion(jit_compile=True)
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)
y using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
 50/50 ━━━━━━━━━━━━━━━━━━━━ 48s 209ms/step

在这里插入图片描述
XLA对标结果

start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA", end - start])
plot_images(images)

print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()
50/50 ━━━━━━━━━━━━━━━━━━━━ 11s 210ms/step
With XLA: 10.63 seconds

在这里插入图片描述
在目标GPU上,XLA实现了2倍的速度提升。

4、总结

在构建高性能的稳定扩散推理流程时,关键在于利用现代计算硬件和软件优化技术来最大化模型的推理速度和效率。以下是构建此类流程时需要考虑的关键点:

  • 硬件选择:选择支持高效浮点运算和并行处理的现代GPU硬件,如NVIDIA的GPU系列。这些GPU通常具有强大的计算能力和优化的硬件加速功能,能够显著提升模型的推理速度。

  • 混合精度计算:利用混合精度计算技术,即在计算过程中使用float16精度,并将权重存储在float32格式中。这种策略能够减少内存占用和计算量,同时保持足够的数值稳定性。在NVIDIA GPU上,float16操作通常比float32操作更快,因此可以显著提高推理速度。

  • XLA编译:启用TensorFlow或JAX中的XLA编译器,将计算图编译成优化的机器代码。XLA编译器能够自动进行各种优化,如循环展开、内联函数和向量化等,从而进一步提高推理性能。

  • 模型优化:对稳定扩散模型进行必要的优化,如剪枝、量化或模型压缩等。这些技术可以减少模型的复杂性和参数量,从而加快推理速度并减少内存占用。

  • 批处理:在可能的情况下,使用批处理来一次性处理多个输入数据。这可以充分利用GPU的并行处理能力,提高推理吞吐量。

  • 数据预处理和后处理:优化数据预处理和后处理流程,确保它们与推理流程相匹配并尽可能高效。例如,使用适当的数据加载器和缓存机制来加速数据读取。

  • 软件优化:利用最新的深度学习框架和库(如TensorFlow、PyTorch或JAX),这些库通常包含了许多针对高性能计算的优化功能。同时,确保你的代码是高效且并发友好的,以便充分利用多核CPU和GPU资源。

  • 监控与调优:在推理流程运行过程中进行性能监控,并使用性能分析工具来识别瓶颈和潜在优化点。根据监控结果进行调整和优化,以实现最佳性能。

综上所述,构建高性能的稳定扩散推理流程需要综合考虑硬件选择、计算精度、编译优化、模型优化、批处理、数据预处理和后处理以及软件优化等多个方面。通过结合这些技术和策略,可以显著提高模型的推理速度和效率,从而在处理大规模图像和视频数据时获得更好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Springboot017学生读书笔记共享

springboot005学生心理咨询评估系统 亲测完美运行带论文:获取源码,私信评论或者v:niliuapp 运行视频 Springboot017学生读书笔记共享 包含的文件列表(含论文) 数据库脚本:db.sql其他文件:ppt.ppt论文&am…

Java+Swing+Mysql实现飞机订票系统

一、系统介绍 1.开发环境 操作系统:Win10 开发工具 :Eclipse2021 JDK版本:jdk1.8 数据库:Mysql8.0 2.技术选型 JavaSwingMysql 3.功能模块 4.数据库设计 1.用户表(users) 字段名称 类型 记录内容…

【Linux】Linux的安装

文章目录 一、Linux环境的安装虚拟机 镜像文件云服务器(可能需要花钱) 未完待续 一、Linux环境的安装 我们往后的学习用的Linux版本为——CentOs 7 ,使用 Ubuntu 也可以 。这里提供几个安装方法: 电脑安装双系统(不…

OpenWrt 安装Quagga 支持ospf Bgp等动态路由协议 软路由实测 系列四

1 Quagga 是一个路由软件套件, 提供 OSPFv2,OSPFv3,RIP v1 和 v2,RIPng 和 BGP-4 的实现. 2 web 登录安装 #或者ssh登录安装 opkg install quagga quagga-zebra quagga-bgpd quagga-watchquagga quagga-vtysh # reboot 3 ssh 登录 #重启服务 /etc/init.d/quagga restart #…

Linux下Vision Mamba环境配置+多CUDA版本切换

上篇文章大致讲了下Vision Mamba的相关知识,网上关于Vision Mamba的配置博客太多,笔者主要用来整合下。 笔者在Win10和Linux下分别尝试配置相关环境。 Win10下配置 失败 \textcolor{red}{失败} 失败,最后出现的问题如下: https://…

ps进程查看命令详解

1、PS 命令是什么 查看它的man手册可以看到,ps命令能够给出当前系统中进程的快照。它能捕获系统在某一事件的进程状态。如果你想不断更新查看的这个状态,可以使用top命令。 2、ps命令支持三种使用的语法格式 UNIX 风格,选项可以组合在一起…

flutter 实现旋转星球

先看效果 planet_widget.dart import dart:math; import package:flutter/material.dart; import package:vector_math/vector_math_64.dart show Vector3; import package:flutter/gestures.dart; import package:flutter/physics.dart;class PlanetWidget extends StatefulW…

Android14 - 绘制系统 - 概览

从Android 12开始,Android的绘制系统有结构性变化, 在绘制的生产消费者模式中,新增BLASTBufferQueue,客户端进程自行进行queue的生产和消费,随后通过Transation提交到SurfaceFlinger,如此可以使得各进程将缓…

Altium Designer 中键拖动,滚轮缩放,并修改缩放速度

我的版本是AD19,其他版本应该都一样。 滚轮缩放 首先,要用滚轮缩放,先要调整一下AD 设置,打开Preferences,在Mouse Wheel Configuration 里,把Zoom Main Window 后面Ctrl 上的对勾取消掉,再把…

翻译《The Old New Thing》- What‘s the deal with the EM_SETHILITE message?

Whats the deal with the EM_SETHILITE message? - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20071025-00/?p24693 Raymond Chen 2007年10月25日 简要 文章讨论了EM_SETHILITE和EM_GETHILITE消息在文档中显示为“未实现”的原因。这些…

C语言 | Leetcode C语言题解之第112题路径总和

题目: 题解: bool hasPathSum(struct TreeNode *root, int sum) {if (root NULL) {return false;}if (root->left NULL && root->right NULL) {return sum root->val;}return hasPathSum(root->left, sum - root->val) ||ha…

与用户沟通获取需求的方法

1 访谈 访谈是最早开始使用的获取用户需求的技术,也是迄今为止仍然广泛使用的需求分析技术。 访谈有两种基本形式,分别是正式的和非正式的访谈。正式访谈时,系统分析员将提出一些事先准备好的具体问题,例如&#xff0…

C++_string简单源码剖析:模拟实现string

文章目录 🚀1.构造与析构函数🚀2.迭代器🚀3.获取🚀 4.内存修改🚀5. 插入🚀6. 删除🚀7. 查找🚀8. 交换swap🚀9. 截取substr🚀10. 比较符号重载🚀11…

pod install 报错 ‘SDK does not contain ‘libarclite‘ at the path...‘

报错内容: SDK does not contain ‘libarclite’ at the path ‘/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/arc/libarclite_iphoneos.a’; 这是报错已经很明确告诉我们,Xcode默认的工具链中缺少一个工具…

Linux Tcpdump抓包入门

Linux Tcpdump抓包入门 一、Tcpdump简介 tcpdump 是一个在Linux系统上用于网络分析和抓包的强大工具。它能够捕获网络数据包并提供详细的分析信息,有助于网络管理员和开发人员诊断网络问题和监控网络流量。 安装部署 # 在Debian/Ubuntu上安装 sudo apt-get install…

爬虫100个Python例子优化

今天看到一个Python 100例的在线资源,感觉每个都需要去点,太费时间了,于是,使用Python将数据爬取下来,方便查看。实际效果如下: 。。。。。。 用了13分钟,当然,这是优化后的效果,如果没有优化,需要的时间更长。 爬取url如下: https://www.runoob.com/python/pytho…

Pytorch深度学习实践笔记1

🎬个人简介:一个全栈工程师的升级之路! 📋个人专栏:pytorch深度学习 🎀CSDN主页 发狂的小花 🌄人生秘诀:学习的本质就是极致重复! 《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibi…

【数据结构与算法 刷题系列】移除链表元素

💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:数据结构与算法刷题系列(C语言) 期待您的关注 目录 一、问题描述 二、解题思路 三、源代码实现 一、问题…

不靠后端,前端也能搞定接口!

嘿,前端开发达人们!有个超酷的消息要告诉你们:MemFire Cloud来袭啦!这个神奇的东东让你们不用依赖后端小伙伴们,也能妥妥地搞定 API 接口。是不是觉得有点不可思议?但是事实就是这样,让我们一起…

软件项目详细设计说明书实际项目参考(word原件下载及全套软件资料包)

系统详细设计说明书案例(直接套用) 1.系统总体设计 2.性能设计 3.系统功能模块详细设计 4.数据库设计 5.接口设计 6.系统出错处理设计 7.系统处理规定 软件开发全文档下载(下面链接或者本文末个人名片直接获取):软件开发全套资料-…