【扩散模型(七)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

news2024/12/23 23:02:24

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

文章目录

  • 系列文章目录
  • 一、Image Encoder 的使用区别?
    • 1.1 Image Encoder 组成
    • 1.2 .hidden_states[-2] 表示什么?
  • 二、ImageProjModel 和 Resampler 的区别?
    • 2.1 ImageProjModel 代码
    • 2.2 Resampler 代码


从下图中可以很直观地看出有两处不同,第一是使用 image encoder 的方式不同、得到了不同的图像特征,第二是将原有的简单 ImageProjModel 替换成了更加复杂的 Resampler 以提取更多的图像信息。

在这里插入图片描述

一、Image Encoder 的使用区别?

1.1 Image Encoder 组成

Image Encoder 是 CLIPVisionModelWithProjection 类(位于 /path/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py)

根据其构造函数,可见分为两块(1)vision_model【CLIPVisionTransformer】 和 (2)visual_projection【Linear】

class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

1.视觉模型(vision_model

  • 这通常是一个处理输入图像的视觉转换器(ViT)。
  • 它从图像中提取特征并输出表示,通常包括总结整个图像的“合并”输出。
  • 视觉模型处理了理解图像内容的繁重任务。

2.视觉投影(visual_projection

  • 这是一个线性层,将视觉模型的高维输出映射到低维空间。
  • 在 CLIP 这样的多模态模型中,投影会将图像表示与文本表示对齐
  • 它确保图像嵌入与文本嵌入位于同一空间,便于比较或组合。

1.2 .hidden_states[-2] 表示什么?

我们仔细对比 IP-Adapter 和 IP-Adapter Plus 的细节,会发现采用 Image Encoder 的方式不一样

# IP-Adapter
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds

# IP-Adapter Plus
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]

关键区别在于:
(1) .image_embeds属性来获取图像的嵌入表示,是 经过编码器最后一层(visual_projection)处理后的结果。

(2) .hidden_states[-2]:这行代码调用self.image_encoder时,通过 output_hidden_states=True 参数指示编码器输出除了最终输出之外的所有隐藏状态

  • self.image_encoder返回一个对象,其中hidden_states是一个包含所有隐藏层输出的列表。
  • 然后,通过索引 [-2] 访问这个列表的倒数第二个元素,即倒数第二个隐藏层的输出。
  • 未经过编码器最后一层(visual_projection)处理后的结果。

二、ImageProjModel 和 Resampler 的区别?

  • ImageProjModelResampler 都是用于将图像嵌入(image_embeds)投影到一个更高维度的空间,以便作为后续的生成引导。通过对 2.1 和 2.2 的两段代码,可以总结出差异:
  1. 网络结构

    • ImageProjModel:包含一个线性层self.proj用于投影,以及一个层归一化self.norm
    • Resampler:包含位置嵌入(如果apply_pos_embTrue)、输入投影self.proj_in、输出投影self.proj_out和层归一化self.norm_out。此外,它还包含一个由多个注意力和前馈网络层组成的模块列表self.layers,这些层用于处理输入数据。
  2. 注意力机制

    • ImageProjModel:没有使用注意力机制。
    • Resampler:使用自定义的PerceiverAttention模块进行注意力计算。
  3. 前馈网络

    • ImageProjModel:没有前馈网络。
    • Resampler:使用FeedForward模块,这是一个标准的前馈网络,通常用于Transformer架构中。
  4. 序列处理

    • ImageProjModel:没有特别处理序列数据。
    • Resampler:设计用于序列数据,包括可选的通过self.to_latents_from_mean_pooled_seq从平均池化序列生成额外的潜在表示。
  5. 可学习的参数

    • ImageProjModel:主要参数是线性层的权重。
    • Resampler:除了线性层的权重外,还包括可学习的潜在表示self.latents
  6. 输出

    • ImageProjModel:输出经过投影和归一化的图像嵌入。
    • Resampler:输出经过多层处理和归一化的序列特征。
  7. 特殊函数

    • Resampler中使用了masked_mean函数,这表明它可能用于处理带有掩码的序列数据,例如在处理变长序列时。

总结来说,ImageProjModel是一个简单的投影模型,可能用于将图像特征投影到一个多维空间以便于与其他类型的数据结合。而Resampler是一个更复杂的模型 (主要来源于论文1),设计用于处理序列数据,并通过注意力和前馈网络层进行特征转换。

2.1 ImageProjModel 代码

class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

2.2 Resampler 代码

Flamingo 论文中的图像,可与代码对照理解。
在这里插入图片描述

class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        
        print(embedding_dim, dim)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

  1. Flamingo: a Visual Language Model for Few-Shot Learning ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2080085.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

19c库启动报ORA-600 kcbzib_kcrsds_1---惜分飞

一套19c的库由于某种情况,发现异常,当时的技术使用隐含参数强制拉库,导致数据库启动报ORA-00704 ORA-600 kcbzib_kcrsds_1错误 2024-08-24T06:11:25.49430408:00 ALTER DATABASE OPEN 2024-08-24T06:11:25.49437008:00 TMI: adbdrv open database BEGIN 2024-08-24 06:11:25.49…

Iptables-快速上手

Iptables firewall 防火墙Iptables简述一、Iptables的四表五链1.filter表2.nat表3.raw表4. mangle表5.数据包的流通过程 二、快速上手1. 查看规则2. 规则详细3. 添加规则4. 自定义链 三、关于iptables和docker1. 背景2. 解决方案 firewall 防火墙 从逻辑上讲,可以分…

【国外比较权威的免费的卫星数据网站】

国外比较权威的免费卫星数据网站有多个,它们各自在数据覆盖范围、分辨率、以及数据种类等方面具有不同的特点和优势。以下是一些推荐的网站: NASA Worldview 网址:https://worldview.earthdata.nasa.gov/简介:NASA Worldview显示…

Visual Studio解决scanf不能正常输入的问题

总所周知,vs中直接使用scanf会报错,用scanf_s就不会,然而很多时候我们用的还是scanf,下面讲解如何在vs中使用scanf 🎁1.添加#define _CRT_SECURE_NO_WARNINGS 不做任何处理,会出现的报错 注意下方的C499…

MySQL商品复购率计算

先看表格 复购率计算: 根据商品ID、商品名称、订单状态、订单创建时间、收货人电话来进行复购率计算: select b.商品ID,b.名称,b.购买人数,c.复购人数,c.复购人数/b.购买人数 as "复购率" from ( select 商品ID,max(商品名称) as "名称…

嵌入式学习day34

单循环服务器:同一时刻,只能处理一个客户端的任务 并发服务器:同一时刻,能够处理多个客户端的任务 UDP不需要创建连接 TCP并发服务器 1.多进程 2.多线程 3.IO多路复用 1、多进程 2、多线程 3、IO多路复用 IO模型&#xff1a…

机器学习:K-means算法(内有精彩动图)

目录 前言 一、K-means算法 1.K-means算法概念 2.具体步骤 3.精彩动图 4.算法效果评价 二、代码实现 1.完整代码 2.结果展示 3.步骤解析 1.数据预处理 2.建立并训练模型 3.打印图像 四、算法优缺点 1.优点 2.缺点 总结 前言 机器学习里除了分类算法&#xff0…

Threejs绘制方形管道

之前有用Threejs的TubeGeometry绘制管道效果,但是TubeGeometry的管道效果默认是圆形的截面,这节实现方形截面的管道绘制。 因为Threejs不提供方形截面的管道,所以使用的是绘制截面,然后拉伸的方式,所以需要先绘制一个方…

【FPGA数字信号处理】- 什么是时域

​数字信号处理的领域中,时域是我们理解和处理数字信号的关键维度之一。 时域分析能够让我们直接观察信号随时间的变化情况,为后续的信号处理和系统设计提供坚实的基础。 接下来将以通俗易懂的方式,让大家深入了解数字信号处理基础中的时域…

算法学习:一维数组的排序算法

【排序算法】八种排序算法可视化过程_哔哩哔哩_bilibili 1,冒泡排序: 冒泡排序(Bubble Sort): 冒泡排序是一种简单的排序算法,它通过重复地交换相邻的元素,直到整个序列有序。算法思路是:从第一个元素开始,依次比较相邻的两个元素,如果前者大于后者,就交…

day-41 零钱兑换

思路 动态规划的思想&#xff0c;创建一个长度为amount的数组arr&#xff0c;arr[i]表示当amounti时的最少硬币数 解题过程 arr初始化值为Integer.MAX_VALUE&#xff0c;再令arr[0]0&#xff0c;arr[coins[j]]1(0<j<coins.length),然后i从1向后遍历&#xff08;icoins[j…

DNS劫持问题

目录 DNS劫持概述 定义 图示 ​编辑图示说明 DNS劫持的原理 1. DNS请求与响应过程 图示 ​编辑2. 劫持发生点 本地劫持 路由器劫持 中间人攻击 图示 ​编辑图示说明 DNS劫持的影响 1. 对个人用户的影响 图示 ​编辑图示说明 2. 对企业的影响 图示 ​编辑图示…

2024年8月23日(docker 数据存储)

1、打包 [rootdocker1 ~]# docker save -o centos.tar centos:latest [rootdocker1 ~]# systemctl start docker [rootdocker1 ~]# docker ps -all CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES e84261634543 …

LoadBalancer负载均衡

一、概述 1.1、Ribbon目前也进入维护模式 Spring Cloud Ribbon是基于Netflix Ribbon实现的一套客户端负载均衡的工具。 简单的说&#xff0c;Ribbon是Netflix发布的开源项目&#xff0c;主要功能是提供客户端的软件负载均衡算法和服务调用。Ribbon客户端组件提供一系列完善的…

监控领域的物理对抗攻击综述——Physical Adversarial Attacks for Surveillance: A Survey

介绍 文章贡献 框架提出&#xff1a;提出了一个新的分析框架&#xff0c;用于理解和评估生成和设计物理对抗性攻击的方法。全面调查&#xff1a;对物理对抗性攻击在监控系统中的四个关键任务—检测、识别、跟踪和行为识别—进行了全面的调查和分析。跨领域探索&#xff1a;讨…

OpenHarmony轻量设备Hi3861芯片开发板启动流程分析

引言 OpenHarmony作为一款万物互联的操作系统&#xff0c;覆盖了从嵌入式实时物联网操作系统到移动操作系统的全覆盖&#xff0c;其中内核包括LiteOS-M,LiteOS-A和Linux。LiteOS-M内核是面向IoT领域构建的轻量级物联网操作系统内核&#xff0c;主要面向没有MMU的处理器&#x…

数据结构---顺序表---单链表

目录 一、什么是程序&#xff1f; 程序 数据结构 算法 二、一个程序释放优秀的两个标准 2.1.时间复杂度 2.2.空间复杂度 三、数据结构 3.1.数据结构间的关系 1.逻辑结构 1&#xff09;线性关系 2&#xff09;非线性关系 2.存储结构 1&#xff09;顺序存储结构 …

Python的起源与发展历程:从创意火花到全球热门编程语言

目录 创意的火花名字的由来圣诞节的礼物社区的力量今天的Python Python的起源可以追溯到1989年&#xff0c;当时荷兰计算机科学家Guido van Rossum&#xff08;吉多范罗苏姆&#xff09;在阿姆斯特丹的荷兰国家数学和计算机科学研究所&#xff08;CWI&#xff09;工作。Python的…

Android Studio 自定义字体大小

常用编程软件自定义字体大全首页 文章目录 前言具体操作1. 打开设置对话框2. 选择外观字体 前言 Android Studio 自定义字体大小&#xff0c;统一设置为 JetBrains Mono &#xff0c;大小为 14 具体操作 【File】>【Settings...】>【Appearance & Behavior】>【…

计算机视觉编程 3(图片处理)

目录 图像差分 高斯差分 形态学-物体计数 ​编辑 图片降噪 图像差分 # -*- coding: utf-8 -*- from PIL import Image from pylab import * from scipy.ndimage import filters import numpy# 添加中文字体支持 from matplotlib.font_manager import FontProperties font…