如何魔改 diffusers 中的 pipelines

news2024/11/29 0:30:18

如何魔改 diffusers 中的 pipelines

整个 Stable Diffusion 及其 pipeline 长得就很适合 hack 的样子。不管是通过简单地调整采样过程中的一些参数,还是直接魔改 pipeline 内部甚至 UNet 内部的 Attention,都可以实现很多有趣的功能或采样生图结果。

本文主要介绍两种魔改 diffusers pipelines 的方式,一是通过注册 callback 函数,来在采样生图过程中执行某些操作,二是直接自己写 custom pipelines。

pipeline callbacks

可参考官方文档:Pipelines Callback

通过在 pipe 推理生图时传入自定义的回调函数,不用动底层代码,可以在每个时间步结束时动态地执行一些我们想要的动作,比如在特定的时间步修改特定的采样参数或张量。目前仅支持 callback_on_step_end 在单步结束时执行回调函数。我们通过两个例子来介绍如何通过 callback 函数来魔改 diffusers pipelines。

Dynamic classifier-free guidance

classifier-free guidance (cfg)用于通过 prompt 来引导图像生成的内容。在 diffusers 中,会同时用 CLIP 文本编码器同时编码 prompt 的 embeds 和空 prompt (空字符串或 negative prompt)的 embeds,然后拼接起来,一起通过交叉注意力与 UNet 交互。

通过 callback 函数,我们可以按照自己的需求动态控制 cfg,比如说,我想在特定的时间步之后停止使用 cfg,从而节省计算开销,并且性能不会有很大的损失。 callback 函数需要接收以下参数:

  • pipeline:通过 pipe 可以访问和编辑许多重要的采样参数,如 pipe.num_timestepspipe.guidance_scale 等。在本例中,我们就可以通过将 pipe._guidance_scale 来停用 cfg。

  • timestep 和 step_index:这两个参数可以让我们知道本次采样过程一共有多少时间步,以及当前我们位于哪一步。从而可以根据当前位于整个采样过程中的位置,来选择进行什么操作。在本例中,我们可以设置在整个采样过程 40% 及以后的位置,停用 cfg。

  • callback_kwargs:callback_kwargs 是一个字典,包含了在采样生图的过程中你可以编辑的张量。具体包含哪些张量,需要再调用 pipe 采样生图时通过 callback_on_step_end_tensor_inputs 参数传入。不同的 pipe 可能包含不同的可编辑张量,具体可以通过 pipe 的 _callback_tensor_inputs 属性查看。在本例中,我们需要在停用 cfg 之后调整 prompt_embeds 张量的批尺寸,丢掉空 prompt 部分。这是因为 sd pipe 是根据 _guidance_scale 的值来判断是否进行 cfg,所以我们将这个值改为零,就不会进行 cfg 了,所以需要将 prompt_embeds 不带 prompt 的部分丢掉,只保存带 prompt 的部分。

返回值方面,回调函数必须返回(修改好的) callback_kwargs。

最终,我们的 callback 函数是这样的:

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs, percent):
    if step_index == int(pipe.num_timesteps * percent):
        prompt_embeds = callback_kwargs['prompt_embeds']
        prompt_embeds = prompt_embeds.chunk(2)[-1]
        pipe._guidance_scale = 0.0
        callback_kwargs['prompt_embeds'] = prompt_embeds
    return callback_kwargs

然后在推理生图时传入该参数:

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(1)
out = pipeline(
    prompt,
    generator=generator,
    callback_on_step_end=callback_dynamic_cfg,
    callback_on_step_end_tensor_inputs=['prompt_embeds']
)

out.images[0].save("out_custom_cfg.png")
Display image after each generation step

在搭建生图 UI 时,一般需要支持用户在看到前几个时间步的结果不符合预期时,手动终止采样生图过程。这就需要两个功能,一是展示每一步的生图结果,二是支持在过程中终止本次生图。下面以展示中间步结果为例,介绍回调函数的使用。

我们知道,在 SD 中,去噪采样生图过程是发生在隐空间的,以 SDXL 为例,隐空间的特征图尺寸为 128 × 128 × 3 128\times 128\times 3 128×128×3 。我们需要将其转换到像素空间,才能看到这一步的实际生图结果。SDXL 还有点特殊,需要先将四通道的隐空间转换为 RGB 三通道,详情见 Explaining the SDXL latent space 。

def latents_to_rgb(latents):
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )

    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)

    return Image.fromarray(image_array)
  
def decode_tensors(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    
    image = latents_to_rgb(latents)
    image.save(f"{step}.png")

    return callback_kwargs

然后再推理生图时传入该参数回调函数:

from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
).to("cuda")

image = pipeline(
    prompt = "A croissant shaped like a cute bear.",
    negative_prompt = "Deformed, ugly, bad anatomy",
    callback_on_step_end=decode_tensors,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]

在这里插入图片描述

Custom Pipelines

可参考官方文档:Custome Pipelines、contribute-pipeline

在 diffusers 中,我们可以很方便的自定义并加载自己的定制化 pipelines。在实现自己的自定义 pipelines 时,需要继承基类 DiffusionPipeline

加载 custom pipelines

1 从 diffusers 仓库中加载自定义的 pipeline

从 hf hub 加载自定义的 pipeline 非常简单,只需要将 model_id 传入 custom_pipeline 参数,然后就会加载该仓库中对应的 pipeline.py。

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)

2 从本地文件加载自定义的 pipeline

如果想从本地文件加载 pipeline,需要将 pipeline.py 所在的文件目录(注意是目录名)传给 custom_pipeline 参数。

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="path/to/dir"
)

3 加载官方收录的社区 custom pipelines

这里是合入 diffuses 官方仓库的一些社区的自定义 pipelines。我们只需要将对应文件名(不含 py 后缀,如 clip_guided_stable_diffusion)传给 custom_pipeline 参数。

由于自定义 pipelines 的通常比较复杂,所以我们也可以通过官方 pipeline 来加载模型,再将模型传入自定义 pipelines。

from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id)

pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="clip_guided_stable_diffusion",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
)
实现 custom pipelines

我们可以继承 DiffusionPipeline 基类并实现自己的 custom pipeline,这样所有人就都可以加载我们实现的 pipeline。一个 custom pipeline 的框架大致如下:

import torch
from diffusers import DiffusionPipeline


class MyPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
        # Sample gaussian noise to begin loop
        image = torch.randn((batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size))

        image = image.to(self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, eta).prev_sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return image

确保 xxxPipeline 这个类的相关实现都在这一个文件,而且该文件中只包含 xxxPipeline 一个类。因为 pipeline 的识别加载是自动的。

接下来,我们以一个最简单的 one-step pipeline 为例,简单介绍自己实现 custom pipeline 的过程。在这个one-step pipeline 中,只会用到 UNet 一个模型,并将 timestep 固定为 1,只进行一次模型前向。

首先新建一个 one_step_unet.py 文件,然后在其中继承 DiffusionPipeline 基类并实现 UnetSchedulerOneForwardPipeline 类。

初始化:定义 __init__ 方法

在初始化方法中,我们简单地的 one-step pipeline 只需要接收 unet 和 scheduler 两个初始化参数,我们在初始化方法中将变量定义好。注意,为了使得 save_pretrained 方法能够将我们的模型完整地保存下来,需要通过 register_modules 方法将我们想要保存的 unet 和 scheduler 注册进来。

from diffusers import DiffusionPipeline
import torch

class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

前向推理:定义 __call__ 方法

定义好初始化方法 __init__ 之后,再来实现 pipeline 推理生图的 __call__ 方法。在这里,我们可以任意发挥,任意组合,魔改扩散模型的采样过程,实现自己想要的功能。在我们的 one-step pipeline 中,这里要做的非常简单:采样一个噪声图,UNet 进行一次前向。

@torch.no_grad()
def __call__(self):
    image = torch.randn(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.sanple_size)
    timestep = 1
    model_output = self.unet(image, timestep).sample
    scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
    return scheduler_output

这就 ok 了,我们已经实现好了自定义的 one-step pipeline。

推理

我们传入 unet 和 scheduler,实例化一个刚刚自定义好的 UnetSchedulerOneForwardPipeline,然后进行推理生图:

from diffusers import DDPMScheduler, UNet2DModel

scheduler = DDPMScheduler()
unet = UNet2DModel()

pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)

output = pipeline()

如果我们的 custom pipeline 结果如果跟某个已有的 pipeline 的预训练权重是完全一样的,我们还可以直接通过 from_pretrained 方法来加载它们的权重。比如说,我们的 UnetSchedulerOneForwardPipeline 就可以直接加载 google/ddpm-cifar10-32 的权重:

pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)

output = pipeline()
分享 custom pipelines

要共享自己的 custom pipeline 有三个方法:

  1. 将自己实现的 custom pipeline 推送到 hf hub 仓库,需要将文件名命名为 pipeline.py

  2. 像 diffusers 官方仓库提交 PR,合并之后可以我们的 pipeline 就会出现在这里,别人可以通过文件名来加载

  3. 共享出自己的源码文件,如 clip_guided_stable_diffusion.py,别人也可以导入我们的 pipeline

而作为使用者,使用 custom pipeline 的方式有两种:

# 方式 1
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
    "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
)

# 方式 2
from cus_pipe import CustomPipeline    # cus_pipe is copied from hf-internal-testing/diffusers-dummy-pipeline
pipe = CustomPipeline.from_pretrained("google/ddpm-cifar10-32")

这两种使用方式应该是等价的,这可以从源码中看到:

if custom_pipeline is not None:
    pipeline_class = get_class_from_dynamic_module(
        custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
    )
elif cls != DiffusionPipeline:
    pipeline_class = cls
else:
    diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
    pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

总结

diffusers 的 api 设计非常友好,我们可以通过 pipeline callback 和 custom pipeline 等方式定制化实现自己想要的功能,其中前者不用动底层代码,简单优雅,后者则是功能强大,现在最新的 AIGC 相关的论文基本都是通过 custom diffusion 的方式公开自己的源码,非常方便。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1576703.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM 全景图

今天我重新复习了一下 jvm 的一些知识点。我以前觉得 jvm 的知识点很多很碎,而且记起来很困难,但是今天我重新复习了一下,对这些知识点进行了简单的梳理之后,产生了不一样的看法。虽然 jvm 的知识点很碎,但是如果你真的…

创建型模式--1.单例模式【巴基速递】

1. 巴基的订单 在海贼世界中,巴基速递是巴基依靠手下强大的越狱犯兵力,组建的集团海贼派遣公司,它的主要业务是向世界有需要的地方输送雇佣兵(其实是不干好事儿)。 自从从特拉法尔加罗和路飞同盟击败了堂吉诃德家族 &…

系统监测工具-tcpdump的使用

一个简单的tcpdump抓包过程。主要抓包观察三次握手,四次挥手的数据包 有两个程序:客户端和服务器两个程序 服务器端的ip地址使用的是回环地址127.0.0.1 端口号使用的是6000 tcpdump -i 指定用哪个网卡等,dstip地址端口指定抓取目的地址…

ctfshow web入门 文件包含 web151--web161

web151 打算用bp改文件形式(可能没操作好)我重新试了一下抓不到 文件上传不成功 改网页前端 鼠标右键&#xff08;检查&#xff09;&#xff0c;把png改为php访问&#xff0c;执行命令 我上传的马是<?php eval($_POST[a]);?> 查看 web152 上传马 把Content-Type改为…

基于大模型的态势认知智能体

源自&#xff1a;指挥控制与仿真 作者&#xff1a;孙怡峰, 廖树范, 吴疆 李福林 “人工智能技术与咨询” 发布 摘要 针对战场态势信息众多、变化趋势认知困难的问题,提出基于大模型的态势认知智能体框架和智能态势认知推演方法。从认知概念出发,结合智能体的抽象性、具…

基于单片机收音机调幅系统设计仿真源码

**单片机设计介绍&#xff0c;基于单片机收音机调幅系统设计仿真源码 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机收音机调幅系统设计的仿真源码&#xff0c;主要实现了通过单片机控制调幅收音机的核心功能。以下是…

如何关闭WordPress的自动更新功能

Wordpress为什么自动更新 WordPress自动更新是为了提供更好的安全性和稳定性。 安全性&#xff1a;WordPress是一个广泛使用的内容管理系统&#xff0c;因此成为恶意攻击的目标。WordPress的自动更新功能确保你的网站及时获得最新的安全补丁和修复程序&#xff0c;以保护你的网…

ctfshow web入门 命令运行 web39---web52

ctfshow web入门 命令执行 昨天看了一下我的博客真的很恼火&#xff0c;不好看&#xff0c;还是用md来写吧 web39 查看源代码 看到include了&#xff0c;还是包含(其实不是) 源代码意思是当c不含flag的时候把c当php文件运行 php伪协议绕过php文件执行 data://text/plain 绕…

算法汇总啊

一些常用算法汇总 算法思想-----数据结构动态规划(DP)0.题目特点1.【重点】经典例题(简单一维dp&#xff09;1.斐波那契数列2.矩形覆盖3.跳台阶4.变态跳台阶 2.我的日常练习汇总(DP)1.蓝桥真题-----路径 算法思想-----数据结构 数据结构的存储方式 : 顺序存储(数组) , 链式存储…

ubuntu安装nginx以及开启文件服务器

1. 下载源码 下载页面&#xff1a;https://nginx.org/en/download.html 下载地址&#xff1a;https://nginx.org/download/nginx-1.24.0.tar.gz curl -O https://nginx.org/download/nginx-1.24.0.tar.gz2. 依赖配置 sudo apt install gcc make libpcre3-dev zlib1g-dev ope…

轨迹规划 | 图解最优控制LQR算法(附ROS C++/Python/Matlab仿真)

目录 0 专栏介绍1 最优控制理论2 线性二次型问题3 LQR的价值迭代推导4 基于差速模型的LQR控制5 仿真实现5.1 ROS C实现5.2 Python实现5.3 Matlab实现 0 专栏介绍 &#x1f525;附C/Python/Matlab全套代码&#x1f525;课程设计、毕业设计、创新竞赛必备&#xff01;详细介绍全…

护眼台灯什么品牌好?揭秘护眼台灯十大排名

台灯作为我们日常生活中使用率较高的照明工具&#xff0c;光源的品质也是很重要的&#xff01;如果长时间使用一款质量不好的台灯&#xff0c;可能会影响我们的眼睛健康&#xff0c;特别是孩子的眼睛&#xff0c;还没有发育完全&#xff0c;影响更大。 要知道市面上很多劣质台…

IT廉连看——SpringBoot——SpringBoot快速入门

IT廉连看——SpringBoot——SpringBoot快速入门 1、idea创建工程 &#xff08;1&#xff09;普通Maven工程创建 工程名spring-boot-test 2、添加依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/PO…

理解 Golang 变量在内存分配中的规则

为什么有些变量在堆中分配、有些却在栈中分配&#xff1f; 我们先看来栈和堆的特点&#xff1a; 简单总结就是&#xff1a; 栈&#xff1a;函数局部变量&#xff0c;小数据 堆&#xff1a;大的局部变量&#xff0c;函数内部产生逃逸的变量&#xff0c;动态分配的数据&#x…

人工智能的分类有哪些

人工智能&#xff08;AI&#xff09;可以根据不同的分类标准进行分类。以下是一些常见的分类方法&#xff1a; 1. **按功能分类**&#xff1a; - 弱人工智能&#xff08;Narrow AI&#xff09;&#xff1a;也称为狭义人工智能&#xff0c;指专注于执行特定任务的AI系统&…

【蓝桥杯嵌入式】第十三届省赛(第二场)

目录 0 前言 1 展示 1.1 源码 1.2 演示视频 1.3 题目展示 2 CubeMX配置(第十三届省赛第二场真题) 2.1 设置下载线 2.2 HSE时钟设置 2.3 时钟树配置 2.4 生成代码设置 2.5 USART1 2.5.1 基本配置 2.5.2 NVIC 2.5.3 DMA 2.6 TIM 2.6.1 TIM2 2.6.2 TIM4 2.6.3 …

【Linux】 OpenSSH_9.3p1 升级到 OpenSSH_9.6p1(亲测无问题,建议收藏)

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;CSDN博客专家   &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01…

品牌门店稽查可调研内容

执行门店稽查的方式主要分为两类&#xff0c;明访和暗访&#xff0c;调研形式的不同&#xff0c;使得调查内容也有所差距&#xff0c;为了保证调研数据的真实性&#xff0c;目前稽查多为暗访&#xff0c;只有在需要对门店导购等的专业素养&#xff0c;或者产品库存盘点时做明访…

JAVA毕业设计134—基于Java+Springboot+Vue的社区医院管理系统(源代码+数据库+万字论文)

毕设所有选题&#xff1a; https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootVue的社区医院管理系统(源代码数据库万字论文)134 一、系统介绍 本项目前后端分离&#xff0c;分为管理员、用户、医生、前台四种角色 1、用户&#xff1a; 注…

有大学老师正用ChatGPT批改论文,让同学也这么做!

4月7日&#xff0c;CNN消息&#xff0c;美国伊萨卡学院-战略传播学教授Diane Gayeski&#xff0c;正在使用ChatGPT批改学生的论文。 当Diane收到学生提交的论文时&#xff0c;会将部分内容输入到ChatGPT&#xff0c;然后让其进行评分并给出详细的修改建议。 Diane也会让班里的…