【原理+使用】DeepCache: Accelerating Diffusion Models for Free

news2024/9/23 5:19:46

论文:arxiv.org/pdf/2312.00858

代码:horseee/DeepCache: [CVPR 2024] DeepCache: Accelerating Diffusion Models for Free (github.com)

介绍


DeepCache是一种新颖的无训练且几乎无损的范式,从模型架构的角度加速了扩散模型。DeepCache利用 扩散模型顺序去噪步骤中观察到的固有时间冗余,缓存和检索相邻去噪阶段的特征,从而减少冗余计算。利用U-Net的特性,重用高级特征,同时以低成本的方式更新低级特征。将 Stable Diffusion v1.5 加速了 2.3 倍,CLIP 分数仅下降了 0.05 倍,LDM-4-G(ImageNet) 加速了 4.1 倍,FID 降低了 0.22。

动机:

由于顺序去噪过程和繁琐的模型尺寸,训练扩散模型会产生大量的计算成本。本文希望在没有额外训练的情况下,减少每个去噪步骤的计算开销,从而实现对扩散模型的无成本压缩。

背景:

反向扩散过程的加速。反向扩散过程的固有性质减慢了推理速度。目前的研究主要集中在两种加速扩散模型推理的方法上:

  1. 优化采样效率。侧重于减少采样步骤的数量。DDIM、一致性模型将随机噪声转换为初始图像,只需要进行一次模型评估。
  2. 优化结构效率。减少每个采样步骤的推理时间。

U-Net的高级和低级特征。由于跳跃式连接,UNet具有很强的合并低级和高级特征的能力。U-Net构建在堆叠的下采样和上采样块上,将输入图像编码为高级表示,然后对其进行解码,用于下游任务。表示为的块对,通过额外的跳过路径连接,直接将低级的信息从Di转发到Ui。在U-Net体系结构的前向传播过程中,数据通过两条路径并发地遍历:主分支和跳过分支。这些分支汇聚在一个连接模块,主分支提供处理过的高级特征,这些特征来自前面的上采样块Ui+1,而跳过分支提供来自对称块Di的相应特征。因此,U-Net模型的核心是来自跳过分支的低级特征和来自主分支的高级特征的连接:

原理

序列去噪中的特征冗余

去噪过程中的相邻步骤在高级特征上表现出显著的时间相似性。

图2实验揭示了两个主要观点:

  1. 在去噪过程中,相邻步骤之间,存在明显的时间特征相似性,表明连续步骤之间的变化通常较小。
  2. 无论使用哪种扩散模型,如稳定扩散、LDM和DDPM,对于每个时间步长,至少有10%的相邻时间步长与当前步长表现出高度相似(>0.95),这表明某些高级特征以渐进的速度变化。

每次计算,得到的特征都与前一步相似,存在大量冗余计算,产生边际效益。本文目标是利用这一特性来加速去噪过程。

扩散模型的深度缓存

DeepCache利用反向扩散过程中步骤之间的时间冗余来加速推理。从计算机系统中的缓存机制中获得灵感,结合了为随时间变化最小的元素设计的存储组件。应用于扩散模型,通过缓存那些变化缓慢的特征,来消除冗余计算,从而无需在后续步骤中重复计算

实现重点为U-Net中的跳过连接,它本质上提供了双路径优势:主分支需要大量的计算来遍行整个网络,而跳过分支只需要通过一些浅层,从而产生非常小的计算负载。主要分支中突出的特征相似性允许重用已经计算的结果,而不是为所有时间步重复计算。

去噪中的可缓存特性。

在两个连续时间步长 𝑡 和 𝑡−1 之间,根据反向过程,𝑥𝑡−1 将基于先前的结果 𝑥𝑡​ 进行条件生成。实验:首先生成 𝑥𝑡​,计算跨整个U-Net进行。为了获得下一个输出 𝑥𝑡−1,我们检索在先前时间步长 𝑡 中生成的高层次特征。即,考虑U-Net中的一个跳跃分支 𝑚,它连接 𝐷𝑚​ 和 𝑈𝑚​,在时间 𝑡 从先前的上采样块缓存特征图:

这是时间步长 𝑡 的主分支中的特征。这些缓存的特征将在后续推理中使用。

在下一个时间步长 𝑡−1 中,推理并不在整个网络上进行,只计算 m-th 跳跃分支中所需的部分,并用缓存中的特征替代主分支的计算。因此,时间步长 𝑡−1 中 𝑈𝑡−1𝑚​ 的输入可表示为:

𝐷𝑡−1𝑚​ 代表 m-th 下采样块的输出,如果选择一个较小的 𝑚,则只包含几层。例如,如果我们在第一层执行 DeepCache 并选择 𝑚=1,则只需要执行一个下采样块以获得。至于第二个特征 ​,由于可以简单地从缓存中检索,因此不需要额外的计算成本。过程如图3.

在第t - 1步,通过重用第t步缓存的特征,来生成xt - 1,并且为了更有效的推理,不执行D2, D3, U2, U3块。

扩展到1:N推理。缓存的特征计算一次,可以在后续的N−1步中重用,以取代原始的。对于所有去噪的T步,执行完全推理的时间步长序列为:

非均匀1:N推理。

基于1:N策略,在假定高级特征在连续N步中不变的前提下,成功地加速了扩散推理。然而,并非总是如此,特别是对于N,如图2(c)所示,特征的相似性并不是在所有步骤中都保持不变。对于像LDM这样的模型,特征的时间相似性会在去噪过程中显著降低40%左右。

因此,对于非均匀的1:N推理,我们倾向于对那些与相邻步骤相似度相对较小的步骤进行更多采样。在这里,执行完整推理的时间步长序列变为:

使用

import torch
from diffusers import StableDiffusionPipeline
from DeepCache import DeepCacheSDHelper

# 加载 Stable Diffusion 模型
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda:0")

# 创建 DeepCacheSDHelper 对象
helper = DeepCacheSDHelper(pipe=pipe)

# 设置缓存参数
helper.set_params(
    cache_interval=3,
    cache_branch_id=0,
)

# 启用缓存机制
helper.enable()

# 定义输入提示词
prompt = "a beautiful landscape with mountains and rivers"

# 生成图像
deepcache_image = pipe(
    prompt,
    output_type='pt'
).images[0]

# 禁用缓存机制
helper.disable()

 库:

diffusers==0.24.0
transformer

仅需要用DeepCache提供的Pipeline替换Diffusers库的Pipeline,即可实现扩散模型加速。目前支持 StableDiffusionPipeline 可以加载的模型。可以通过参数指定模型名称。

尝试1:将 DeepCacheSDHelper 应用于整个 pipeline,并确保缓存机制只启用一次

 pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # 初始化 DeepCacheSDHelper
    helper = DeepCacheSDHelper(pipe=pipe)
    # 设置缓存参数
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    # 启用缓存机制
    helper.enable()

报错:

AttributeError: 'Pose2VideoPipeline' object has no attribute 'unet'

报错信息显示 Pose2VideoPipeline 对象没有 unet 属性,这说明 DeepCacheSDHelper 无法找到所需的 UNet 模型。要解决这个问题,必须确保传递给 DeepCacheSDHelper 的 pipeline 具有 unet 属性,并且该属性指向实际的 UNet 模型。

而 Pose2VideoPipeline 包含多个 UNet 模型( reference_unetdenoising_unet),需要对 DeepCacheSDHelper 进行修改,使其能够处理这种情况。一种解决方法是扩展 DeepCacheSDHelper 以接受多个 UNet 模型。解决方案:修改DeepCacheSDHelper类,pipe 和包含所有 UNet 模型的列表传递给 DeepCacheSDHelper:

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # Initialize DeepCacheSDHelper with both UNet models
    helper = DeepCacheSDHelper(pipe=pipe, unets=[reference_unet, denoising_unet])
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    helper.enable()

尝试2:分别对 reference_unetdenoising_unet 初始化并启用 DeepCacheSDHelper

 reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(dtype=weight_dtype, device="cuda")
    
    # Import the DeepCacheSDHelper
    helper = DeepCacheSDHelper(reference_unet=reference_unet)
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,
    )
    helper.enable()

    inference_config_path = config.inference_config
    infer_config = OmegaConf.load(inference_config_path)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")
    
    helper = DeepCacheSDHelper(denoising_unet=denoising_unet)
    helper.set_params(
        cache_interval=3,
        cache_branch_id=0,   # 指定缓存的分支 ID,上下两个unet是否需要不同分支?
    )
    helper.enable()

TypeError: DeepCacheSDHelper.__init__() got an unexpected keyword argument 'reference_unet', DeepCacheSDHelper 需要对 pipeline 中所有相关的 UNet 模型进行统一处理,而不是分别处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1906639.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小白·使用Tesseract-OCR工具读取图片

1、直接pip安装 工具使用vscode和pycharm都可以。 这里介绍使用vscode的方法。 (1)、调出终端 (2)、安装依赖 (3)、编写代码 import pyocr import pyocr.builders from PIL import Image import re# 获取Tesseract-OCR工具 tools pyocr.get_available_tools() tool tools[…

使用 MFA 保护对企业应用程序的访问

多因素身份验证(MFA)是在授予用户访问特定资源的权限之前,使用多重身份验证来验证用户身份的过程,仅使用单一因素(传统上是用户名和密码)来保护资源,使它们容易受到破坏,添加其他身份…

C# 实现基于exe内嵌HTTPS监听服务、从HTTP升级到HTTPS 后端windows服务

由于客户需要把原有HTTP后端服务升级为支持https的服务,因为原有的HTTP服务是一个基于WINDOWS服务内嵌HTTP监听服务实现的,并不支持https, 也不像其他IIS中部署的WebAPI服务那样直接加载HTTPS证书,所以这里需要修改原服务支持https和服务器环…

Java基础-Java中的常用类(上)

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,请留下您的足迹) 目录 String类 创建字符串 字符串长度 连接字符串 创建格式化字符串 String 方法 System类 常用方法 方…

谨慎投稿!这本EI期刊正在被“劫持”!

Journsl ofTisniin lniversity Seience and Technology《天津大学学报(自然科学与工程技术版)》创刊于l955年,月刊,全国核心期刊,天津市一级期刊。该刊是由天津大学主办的综合性学术刊物,主要刊登自然科学和…

【第三版 系统集成项目管理工程师】第4章 信息系统架构

持续更新。。。。。。。。。。。。。。。 【第三版】系统集成项目管理工程师 考情分析4.1架构基础4.1.1指导思想(非重点) P1364.1.2设计原则(非重点) P1364.1.3建设目标(非重点) P1374.1.4总体框架 P138练习…

SaaS产品和独立部署型产品有什么区别,该怎么选择?

随着云计算和软件服务的多样化,产品形式主要划分SaaS型(开通即用)和独立部署(完整交付)两种模式,那么SaaS产品和独立部署产品有哪些区别,我们在选择产品的时候应该如何去抉择?本文我…

Java的Thread类中的常用方法解析

Java可以通过Thread类实现多线程,下面来介绍几个Thread类中常用的方法 void start() 开启线程,jvm自动调用run方法 void run() 设置线程任务,这个run方法是Thread重写的接口Runnable中的run方法 String getName() 获取线程名字 void s…

linux 安装Openjdk1.8

一、在线安装 1、更新软件包 sudo apt-get update 2、安装openjdk sudo apt-get install openjdk-8-jdk 3、配置openjdk1.8 openjdk默认会安装在/usr/lib/jvm/java-8-openjdk-amd64 vim ~/.bashrc export JAVA_HOME/usr/lib/jvm/java-8-openjdk-amd64 export JRE_HOME${J…

【Linux】文件和目录管理命令——ls,cp,rm,mv

1.文件与目录的查看:Is ls [-aAdfFhilnrRst] 文件名或目录名称ls [ --color{never,auto,always} ]文件名或目录名称ls [ --full-time ]文件名或目录名称 选项与参数: -a:全部的文件,连同隐藏文件&am…

电子产品分销商 DigiKey 在新视频系列中探索智能城市中的AI

电子产品分销商DigiKey推出了一系列新视频,深入探讨了AI在智能城市中的集成应用。这个名为“智能世界中的AI”的系列是其“城市数字”视频系列的第四季,它审视了城市环境中从基础设施到公共服务的多种AI硬件和软件的部署情况。 该系列由电子制造商莫仕&…

Java的垃圾回收机制解说

Java 内存运行时区域中的程序计数器、虚拟机栈、本地方法栈随线程而生灭;栈中的栈帧随着方法的进入和退出而有条不紊地执行着出栈和入栈操作。每一个栈帧中分配多少内存基本上是在类结构确定下来时就已知的(尽管在运行期会由 JIT 编译器进行一些优化&…

苹果电脑视频压缩工具,苹果电脑视频压缩软件

随着数字媒体内容的爆炸性增长,视频文件的体积越来越大,如何在保证画质的前提下,有效地压缩视频文件,成为许多创作者和普通用户的一大需求。本文将为您详细介绍视频压缩界的佼佼者,让您轻松应对视频文件体积过大的难题…

vue3中使用 tilwindcss报错 Unknown at rule @tailwindcss

解决方法: vscode中安装插件 Tailwind CSS IntelliSense 在项目中的 .vscode中 settings.json添加 "files.associations": {"*.css": "tailwindcss"}

网络连接线相关问题

问题1; 直通线为什么两头都是T568B?是否可以两台T5568A?或者任意线序,只需两头一致? 不行,施工规范规定。(原因;网线最长距离100m,实际用起来要把网线包管,走…

Mapboxgl 根据 AWS 地形的高程值制作等高线

更多精彩内容尽在dt.sim3d.cn&#xff0c;关注公众号【sky的数孪技术】&#xff0c;技术交流、源码下载请添加VX&#xff1a;digital_twin123 使用mapboxgl 3.0版本&#xff0c;根据 AWS 地形图块的高程值制作等高线&#xff0c;源码如下&#xff1a; <!DOCTYPE html> &…

CSS content 计数器

CSS content 计数器 CSS 计数器通过一个变量来设置&#xff0c;根据规则递增变量。 使用计数器自动编号 CSS 计数器根据规则来递增变量。 CSS 计数器使用到以下几个属性&#xff1a; counter-reset - 创建或者重置计数器&#xff0c;给计算器命名。注意声明计算器不能在自身…

乡村振兴指数与其30个原始变量数据(Shp/Dta/Excel格式,2000-2022年)

数据简介&#xff1a;这份数据是我国各地级市乡村振兴指数与其30各原始变量数据并对其进行地图可视化表达。城镇化是当今中国社会经济发展的必由之路。当前我国城镇化处于发展的关键时期&#xff0c;但城镇化发展的加快却是一把双刃剑&#xff0c;为何要如此形容呢?因为当前城…

【产品经理】订单处理12-订单的取消与反取消

在电商ERP系统中&#xff0c;订单取消与反取消也是常见功能之一。 订单取消与反取消也是电商ERP系统的常见功能&#xff0c;本次主要讲解下订单取消与反取消的逻辑。 一、订单取消 在电商ERP系统中&#xff0c;订单取消一般由审单员操作&#xff0c;此类取消一般是由于上下游…

ICMP隧道

后台私信找我获取工具 目录 ICMP隧道作用 ICMP隧道转发TCP上线MSF 开启服务端 生成后门木马 msf开启监听 开启客户端icmp隧道 执行后门木马&#xff0c;本地上线 ICMP隧道转发SOCKS上线MSF 开启服务端 生成后门木马 msf开启监听 开启客户端icmp隧道 ​执行后…