【扩散模型(五)】IP-Adapter 源码详解3-推理代码

news2025/1/13 7:45:06

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
  • 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构训练数据纹理迁移进行了更详细的介绍。
  • 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
  • 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
  • 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。

文章目录

  • 系列文章目录
  • 前言
  • 一、输入处理
  • 二、过 Unet
  • 三、Unet 中被替换的 CA


前言

这里以 /path/to/IP-Adapter/ip_adapter_demo.ipynb 中最基础的以图生图(Image Variations)为例:

SD1.5-IPA 的推理流程如下图所示,可被分为 3 个部分:

  1. 输入处理:对 img prompt 和 txt prompt 分别先得到 embedding 后再送入 SD 的 pipeline;
  2. 过 Unet:与一般输入 txt prompt 类似,通过 Unet 的各个模块;
  3. Unet 中的 CA:对于 img prompt 部分需要拆出来,单独过针对性的 k (to_k_ip)和 v(to_v_ip)。

其中的关键在第一部分,与一般将 txt prompt 直接送入 SD pipeline 不太一样,是先处理为 embedding 再送入 pipeline 的。
在这里插入图片描述

*图中的 bs 代表 batch size

一、输入处理

IP-Adapter 的推理代码核心是在 /path/to/IP-Adapter/ip_adapter/ip_adapter.py 文件的 IPAdapter 类的 generate() 函数中。

在这里插入图片描述

  1. 输入1: image prompt
    • 通过冻结住的 image encoder(CLIPImageProcessor 先预处理,再通过 CLIPVisionModelWithProjection)
    • 以及训练好的 image_proj_model(ImageProjModel)
  2. 输入1对应的输出1有:
    • image_prompt_embeds
    • uncond_image_prompt_embeds(纯 0 tensor 过一次 ImageProjModel)
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
    self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.image_proj_model.load_state_dict(state_dict["image_proj"])# 从训好的权重中读取
...
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
  1. 输入2: text prompt、negative_prompt(默认的 ['monochrome, lowres, bad anatomy, worst quality, low quality']

    • text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
      • encode_prompt 中,对于直接文字的 prompt(str 字符串格式的),会先通过 tokenizer
      • 检查是否超过 clip 的长度
      • 通过 text_encoder (CLIPTextModel) 得到 prompt_embeds(文本特征)
    • negative_prompt 同样通过 tokenizer 和 text_encoder 得到 negative_prompt_embeds
  2. 输入2 对应的输出2有:

    • prompt_embeds_
    • negative_prompt_embeds_
  3. 输出1 的 image_prompt_embeds、uncond_image_prompt_embeds 分别和 输出2 prompt_embeds_、negative_prompt_embeds_ 在维度1上 torch.cat 后得到 self.pipe(第二次 encoder_prompt)的输入:prompt_embeds 和 negative_prompt_embeds。

with torch.inference_mode():
    prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
        prompt,
        device=self.device,
        num_images_per_prompt=num_samples,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
    )
    prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
    negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

二、过 Unet

  1. 按照 prompt 和 negative_prompt 为 None、将 prompt_embeds 和 negative_prompt_embeds 作为输入,通过 encode_prompt(),
    • 得到进一步的 prompt_embeds 和 negative_prompt_embeds
  2. prompt_embeds 和 negative_prompt_embeds 做 torch.cat 是在维度 0 上,这是针对 do_classifier_free_guidance 的操作,避免做两次前向传播。
 # For classifier free guidance, we need to do two forward passes.
 # Here we concatenate the unconditional and text embeddings into a single batch
 # to avoid doing two forward passes
 if self.do_classifier_free_guidance:
     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  1. 接下来的路径和 SD1.5 基本的推理流程基本一致,除了被替换的 Cross-Attn(CA)。
    在这里插入图片描述

三、Unet 中被替换的 CA

该部分应该无需多说,与训练部分一致,即增加一个针对 image prompt 的 k 和 v。上篇 也有相应代码的介绍。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1941132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端JS特效第48集:terseBanner焦点图轮播插件

terseBanner焦点图轮播插件&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatibl…

python每日学习:异常处理

python每日学习8&#xff1a;异常处理 Python中的错误可以分为两种&#xff1a;语法错误和异常 语法错误(Syntax errors) &#xff1a;代码编译时的错误&#xff0c;不符合Python语言规则的代码会停止编译并返回 错误信息。 缺少起始符号或结尾符号(括号、引号等)。 缩进错误…

算法篇 滑动窗口 leetCode 30 串联所有单词的子串

串联所有单词的子串 1.题目描述2.题目解释2.1 原理解释2.2 文字分析 3.代码演示 1.题目描述 2.题目解释 2.1 原理解释 2.2 文字分析 3.代码演示

移动硬盘在苹果电脑上使用后在windows中无法读取 Win和Mac的硬盘怎么通用

在日益普及的跨平台工作环境中&#xff0c;苹果电脑与Windows PC之间的数据交换成为日常需求。然而&#xff0c;用户常面临一个困扰&#xff1a;为何苹果电脑的硬盘能在macOS下流畅运行&#xff0c;却在Windows系统中变得“水土不服”&#xff1f;这一问题核心在于硬盘格式的不…

mac docker no space left on device

mac 上 docker 拉取镜像报错 Error response from daemon: write /var/lib/docker/tmp/docker-export-3995807640/b8464f52498789c4ebbc063d508f04e8d2586567fbffa475e3cd9afd3c5a7cf2/layer.tar: no space left on device解决&#xff1a; 增加 docker 虚拟磁盘大小。如下图

Echarts + 低代码 :可视化如何赋能企业的创新之路?

Echarts最新技术资源&#xff08;建议收藏&#xff09; https://gcdn.grapecity.com.cn/forum.php?modviewthread&tid149493&highlightecharts 前言 数据驱动已经成为企业决策和业务优化的关键所在&#xff0c;在数字化时代&#xff0c;高效的数据分析与可视化呈现是…

多类支持向量机损失(SVM损失)

(SVM) 损失。SVM 损失的设置是&#xff0c;SVM“希望”每个图像的正确类别的得分比错误类别高出一定幅度Δ。 即假设有一个分数集合s[13,−7,11] 如果y0为真实值&#xff0c;超参数为10&#xff0c;则该损失值为 超参数是指在机器学习算法的训练过程中需要设置的参数&#xf…

大数据之写入Doris数据问题

1. 解决Key columns should be a ordered prefix of the schema. KeyColumns[1] (starts from zero) is xxx, but 背景 create table if not exists XXX ( fathercorp varchar(50), id decimalv3(38,0) ) ENGINEOLAP UNIQUE KEY(id) COMMENT xxxx DISTRIBUTED BY HASH(id) BUC…

深入理解Linux网络(一):内核如何接收网络包

深入理解Linux网络&#xff08;一&#xff09;&#xff1a;内核如何接收网络包 一、网络收包总览二、Linux启动1、创建 ksoftirqd 内核进程2、网络子系统初始化3、协议栈注册4、网卡初始化NAPI 5、启动网卡 三、接收数据1、硬中断处理2、ksoftirqd 内核线程处理软中断3、网络协…

数据库基础与安装MYSQL数据库

一、数据库管理系统DBMS 数据库技术是计算机科学的核心技术之一&#xff0c;具有完备的理论基础。使用数据库可以高效且条理分明地存储数据&#xff0c;使人们能够更加迅速、方便地管理数据 1.可以结构化存储大量的数据信息&#xff0c;方便用户进行有效的检索和访问 2.可以…

24届电子信息应届硕士生秋招+春招心得与感悟

背景&#xff1a; 研二下学期在深圳某互联网独角兽公司实习过四个月 岗位为测试实习生 求职的方向为互联网-测试岗 24届电子信息硕士 24秋招&#xff08;2023.9-2023.12&#xff09; 其实早在7月份部分互联网公司和大厂已经开始提前批了&#xff0c;因为我不是科班出身&…

Step-DPO 论文——数学大语言模型理解

论文题目&#xff1a;STEP-DPO: STEP-WISE PREFERENCE OPTIMIZATION FOR LONG-CHAIN REASONING OF LLMS 翻译为中文就是&#xff1a;“LLMs长链推理的逐步偏好优化” 论文由港中文贾佳亚团队推出&#xff0c;基于推理步骤的大模型优化策略&#xff0c;能够像老师教学生一样优…

【BUG】已解决:requests.exceptions.ProxyError: HTTPSConnectionPool

已解决&#xff1a;requests.exceptions.ProxyError: HTTPSConnectionPool 目录 已解决&#xff1a;requests.exceptions.ProxyError: HTTPSConnectionPool 【常见模块错误】 原因分析 解决方案 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&am…

OCC 创建方管(拉伸操作)

目录 一、OCC 拉伸操作 二、例子 1、使BRepBuilderAPI_MakeFace 2、使用BRepPrimAPI_MakeRevol 3、垂直路径扫掠 一、OCC 拉伸操作 BRepPrimAPI_MakeSweep Class Reference - Open CASCADE Technology Documentation OCC提供几种图形的构建是由基本图形的旋转&#xff0c;…

使用Python快速比较和替换键值对

问题背景 您需要在多个文件中替换所有特定字符串的实例。例如&#xff0c;您有一个包含 60728 个键值对的映射词典&#xff0c;需要处理多达 50 个文件&#xff0c;每个文件大约有 250000 行&#xff0c;并且需要在每行中替换多个键。 解决方案 方法一&#xff1a;使用正则表…

【区块链 + 智慧政务】山东荣成:区块链政务诚信管理系统 | FISCO BCOS应用案例

2018 年 9 月&#xff0c;荣成市政府与山东观海数据技术有限公司合作&#xff0c;基于 FISCO BCOS 区块链技术推动智慧城市建设&#xff0c; 其中&#xff0c;信用管理是智慧城市核心之一。 荣成市区块链政务诚信管理系统&#xff0c;建设信用信息征集、评价、披露和应用于一体…

CloudCampus的三种部署模式

CloudCampus的三种部署模式 本地部署 客户购买控制器 自己运营 软件永久license sns &#xff0c;将软件补丁、软件升级&#xff08;含升级版本的新特性&#xff09;、远程支持等打包在一起组成SnS年费 msp自建云部署 msp 购买控制器 msp运营 …

美业SaaS门店收银系统怎么管理订单?博弈美业系统App实操|美业系统Java源码

- 打开博弈美业 - 首页点击订单管理 - 选择想查询的相应订单即可 美业门店管理系统Java源码、美业店务系统演示视频请私信

HTTP协议详解:从零开始的Web通信之旅

文章目录 一、引言&#xff1a;Web通信的基石 - HTTP协议二、HTTP请求方法2.1 OPTIONS2.2 HEAD2.3 GET2.4 POST2.5 PUT2.6 DELETE2.7 TRACE2.8 CONNECT2.9 注意 三、HTTP工作原理四、HTTP 请求/响应流程4.1、客户端连接到web服务器4.2、发送HTTP请求4.3、服务器接受请求并返回H…

【C++】学习笔记——红黑树

文章目录 十七、红黑树1.红黑树的概念红黑树的性质 2.红黑树节点的定义3.红黑树的插入4.红黑树的验证5.完整代码结果演示6.红黑树与AVL树的比较 未完待续 十七、红黑树 1.红黑树的概念 红黑树&#xff0c;是一种二叉搜索树&#xff0c;但在每个结点上增加一个存储位表示结点的…