【diffusers 极速入门(二)】如何得到扩散去噪的中间结果?Pipeline callbacks 管道回调函数

news2025/1/12 5:00:28

本文是对 Hugging Face Diffusers 文档中关于回调函数的翻译与总结,:


管道回调函数

在管道的去噪循环中,可以使用callback_on_step_end参数添加自定义回调函数。该回调函数在每一步结束时执行,并修改管道属性和变量,以供下一步使用。这在动态调整某些管道属性或修改张量变量时非常有用。利用回调函数,你可以实现新的功能而无需修改底层代码。

目前,Diffusers 仅支持callback_on_step_end,如果你有其他执行点的回调需求,可以在 github 上提出功能请求。

官方回调函数

官方提供了一些可用于修改去噪循环的回调函数列表:

  • SDCFGCutoffCallback:在一定步数后禁用 CFG。对于 SD 1.5 pipelines 适用, 包括 text-to-image, image-to-image, inpaint, controlnet。
  • SDXLCFGCutoffCallback:在一定步数后禁用 CFG。对于 SDXL pipelines 适用, 包括 text-to-image, image-to-image, inpaint, controlnet。
  • IPAdapterScaleCutoffCallback:在一定步数后禁用 IP Adapter。对所有支持 IP-Adapter 的 pipelines 适用。

要设置回调函数,可以指定cutoff_step_ratiocutoff_step_index参数。

  • cutoff_step_ratio:带有步长比的浮点数。
  • cutoff_step_index:一个整数,包含步数的确切编号。

示例代码

import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.callbacks import SDXLCFGCutoffCallback

callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
# 也可以用 cutoff_step_index
# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)


pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)

prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution"
generator = torch.Generator(device="cpu").manual_seed(2628670641)

out = pipeline(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
)

out.images[0].save("official_callback.png")

在这里插入图片描述

动态无分类器引导

动态无分类器引导(classifier-free guidance,CFG)允许在一定步数后禁用 CFG,从而节省计算成本。回调函数应包含以下参数:

  • pipeline:访问管道实例属性(如num_timesteps和guidance_scale)。
  • step_indextimestep:当前步骤索引和时间步。在达到num_timesteps的40%后,使用step_index关闭CFG。
  • callback_kwargs:包含在去噪循环中可以修改的张量变量。是一个dict,包含可以在去噪循环中修改的张量变量。
    • 它只包括callback_on_step_end_tensor_inputs参数中指定的变量,该参数被传递给管道的__call__方法。
    • 不同的管道可能使用不同的变量集,因此请检查管道的_callback_tensor_inputs属性以获取可以修改的变量列表。一些常见的变量包括latents和prompt_embeds。
    • 对于此函数,请在将guidance_scale设置为0.0后更改prompt_embeds的批处理大小,以使其正常工作。

示例回调函数:

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
# adjust the batch_size of prompt_embeds according to guidance_scale
    if step_index == int(pipeline.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        prompt_embeds = prompt_embeds.chunk(2)[-1]

# update guidance_scale and prompt_embeds
        pipeline._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs

每步生成后显示图像(中间结果)

通过访问并转换潜在空间,可以在每步生成后显示图像。以下函数将 SDXL 的潜在空间(4 通道)转换为 RGB 张量(3 通道)。

  1. 使用以下函数将SDXL潜伏时间(4个通道)转换为RGB张量(3个通道)
def latents_to_rgb(latents):
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )

    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)

    return Image.fromarray(image_array)
  1. 使用该函数在每步生成后解码并保存潜在空间为图像。
def decode_tensors(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    image = latents_to_rgb(latents)
    image.save(f"{step}.png")
    return callback_kwargs
  1. decode_tensors函数传递给callback_on_step_end参数,以在每一步之后对张量进行解码。还需要在callback_on_step_end_tensor_inputs参数中指定要修改的内容,在本例中为 latents。
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
).to("cuda")

image = pipeline(
    prompt="A croissant shaped like a cute bear.",
    negative_prompt="Deformed, ugly, bad anatomy",
    callback_on_step_end=decode_tensors,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]

在这里插入图片描述

详细内容请参见Hugging Face Diffusers 官方文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1839571.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小米SU7遇冷,下一代全新车型被官方意外曝光

不知道大伙儿有没有发现,最近小米 SU7 热度好像突然之间就淡了不少? 作为小米首款车型,SU7 自上市以来一直承载着新能源轿车领域流量标杆这样一个存在。 发售 24 小时订单量破 8 万,2 个月后累计交付破 2 万台。 看得出来限制它…

模拟自动滚动并展开所有评论列表以及回复内容(如:抖音、b站等平台)

由于各大视频平台的回复内容排序不都是按照时间顺序,而且想看最新的评论回复讨论内容还需逐个点击展开,真的很蛋疼,尤其是热评很多的情况,还需要多次点击展开,太麻烦! 于是写了一个自动化展开所有评论回复…

英伟达成全球市值第一公司;苹果暂停下一代高端头显研发丨 RTE 开发者日报 Vol.227

开发者朋友们大家好: 这里是 「RTE 开发者日报」 ,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE(Real-Time Engagement) 领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「…

4.1 初探Spring Boot

初探Spring Boot实战概述 Spring Boot简介 Spring Boot是一个开源的Java框架,由Pivotal团队(现为VMware的一部分)开发,旨在简化Spring应用程序的创建和部署过程。它通过提供一系列自动化配置、独立运行的特性和微服务支持&#…

JupyterLab使用指南(六):JupyterLab的 Widget 控件

1. 什么是 Widget 控件 JupyterLab 中的 Widget 控件是一种交互式的小部件,可以用于创建动态的、响应用户输入的界面。通过使用 ipywidgets 库,用户可以在 Jupyter notebook 中创建滑块、按钮、文本框、选择器等控件,从而实现数据的交互式展…

Orangepi Zero2

1、Orangepi Zero2 Orangepi Zero2 是基于全志H616的一款产品 特性: CPU全志H616四核64位1.5GHz高性能Cortex-A53处理器 GPU MaliG31MP2 SupportsOpenGLES1.0/2.0/3.2、OpenCL2.0 运行内存1GB DDR3(与GPU共享) 存储TF卡插槽,测试128G可支持、2MB SPI Fl…

深度学习(十二)——神经网络:搭建小实战和Sequential的使用

一、torch.nn.Sequential代码栗子 官方文档:Sequential — PyTorch 2.0 documentation # Using Sequential to create a small model. When model is run, # input will first be passed to Conv2d(1,20,5). The output of # Conv2d(1,20,5) will be used as the in…

嵌入式实验---实验一 通用GPIO实验

一、实验目的 1、掌握STM32F103 GPIO程序设计流程; 2、熟悉STM32固件库的基本使用。 二、实验原理 1、通过按键实现:按键按下,LED点亮;按键释放,LED熄灭。 三、实验设备和器材 电脑、Keil uVision5软件、Proteus…

TVS的原理及选型

目录 案例描述 TVS管的功能与作用: TVS选型注意事项: 高速TVS管选型 最近项目中遇到TVS管选型错误的问题。在此对TVS的功能及选型做一个分享。 案例描述 项目中保护指标应为4-14V,而选型的TVS管位SMJ40CA,其保护电压为40V未…

深信服科技:2023网络安全深度洞察及2024年趋势研判报告

2023 年,生成式人工智能和各种大模型迅速应用在网络攻击与对抗中,带来了新型攻防场景和安全威胁。漏洞利用链组合攻击实现攻击效果加成,在国家级对抗中频繁使用。勒索团伙广泛利用多个信创系统漏洞,对企业数据安全与财产安全造成了…

【Android】三种常见的布局LinearLayout、GridLayout、RelativeLayout

【Android】三种常见的布局LinearLayout、GridLayout、RelativeLayout 在 Android 开发中,布局(Layout)是构建用户界面的基础。通过合理的布局管理,可以确保应用在不同设备和屏幕尺寸上都能有良好的用户体验。本文将简单介绍 And…

Premiere Pro 关键帧的运用(热气球飞行)

制作“热气球飞行”效果,步骤如下: 1.新建项目>新建序列,项目名称为“热气球飞行”,序列设置如下图所示。 2.将“背景”和“飞船”素材导入项目面板中,并将“背景”素材拖到时间线的视频轨道v1上,将“飞…

如何下载Visual Studio2022编译器。详细教程

朝三暮四就该死! --莫迪大仙 引言:今天下个Visual Studio给大家也分享分享,同时怕自己忘记,也记录记录。(windows版本演示) 一、下载网址 网址:Visual Studio2022下载 - Windows、Mac、Linux…

Leetcode Hot100之双指针

1. 移动零 题目描述 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。 请注意 ,必须在不复制数组的情况下原地对数组进行操作。解题思路 双指针遍历一遍即可解决: 我们定义了两个指针 i 和 j&#xf…

scratch编程03-反弹球

这篇文章和上一篇文章《scratch3编程02-使用克隆来编写小游戏》类似(已经完全掌握了克隆的可以忽略这篇文章),两篇文章都使用到了克隆来编写一个小游戏,这篇文章与上篇文章不同的是,本体在进行克隆操作时,不…

图卷积网络(Graph Convolutional Network, GCN)

图卷积网络(Graph Convolutional Network, GCN)是一种用于处理图结构数据的深度学习模型。GCN编码器的核心思想是通过邻接节点的信息聚合来更新节点表示。 图的表示 一个图 G通常表示为 G(V,E),其中: V 是节点集合,…

【单片机】三极管的电路符号及图片识别

一:三极管的电路符号 二:三极管的分类 a;按频率分:高频管和低频管 b;按功率分:小功率管,中功率管和的功率管 c;按机构分:PNP管和NPN管 d;按材质分:硅管和锗管 e;按功能分:开关管和放…

Python基础-引用参数、斐波那契数列、无极分类

1.引用参数的问题 (1)列表(list) 引用参数,传地址的参数,即list1会因list2修改而改变。 list1 [1,2,3,4] list2 list1 print(list1) list2[2] 1 print(list2) print(list1)非引用参数,不传…

获得淘宝app商品详情原数据API接口|商品价格详情页面优惠券主图

item_get_app:通过商品id获取商品详情页数据 注册账号获取API测试地址 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中&#xf…