【扩散模型(十)】IP-Adapter 源码详解 4 - 训练细节、具体训了哪些层?

news2024/11/13 23:44:32

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)
  • 【扩散模型(九)】IP-Adapter 与 IP-Adapter Plus 的具体区别是什么?

文章目录

  • 系列文章目录
    • adapter_modules 分为两类
  • 总结


通过前面的系列文章,很清楚要训练的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。

而 image_proj_model 这块比较简单,原码如下所示

    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    
    #ip-adapter
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )

adapter_modules 分为两类

  1. AttnProcessor 对应 self attention
  2. IPAttnProcessor 对应 cross attention

按理说 self attention 对应的 AttnProcessor 应该不会被训练,但是 training = True,便让人非常费解。
在这里插入图片描述
进一步查看 AttnProcessor2_0 和 IPAttnProcessor2_0 后,就清楚了,因为从 AttnProcessor2_0 的构造函数(init)中并没有参数,就算是 trianing = True 也并不影响训练,实际训练的模块还是 IPAttnProcessor2_0 构造函数中的 to_k_ip 和 to_v_ip 两层 linear!

class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
...

class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(

总结

  1. IP-Adapter 训的就是 image_proj_model(或者对于 plus 来说是 resampler) 和 adapter_modules 两块。
  2. 在 adapter_modules 中,实际只训了 IPAttnProcessor2_0 的 to_k_ip 和 to_v_ip。
  3. adapter_modules 是在每个有含有 cross attention 的 unet block 里进行的替换,如下图所示。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2111142.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文阅读 - Coordinated Activity Modulates the Behavior and Emotions ofOrganic Users

协调活动调节有机用户的行为和情绪:有关加沙冲突的推文案例研究 https://dl.acm.org/doi/pdf/10.1145/3589335.3651483 目录 摘要 1 INTRODUCTION 2 DATA 3 METHODOLOGY 3.1 Coordinated Activity Detection 3.3 用户互动动态特征 3.4 Organic Users’ Behav…

系统架构师考试学习笔记第三篇——架构设计高级知识(17)云原生架构设计理论与实践

本章知识考点: 第17课时主要学习云原生架构设计理论与实践。根据考试大纲,本课时知识点会涉及单选题型(约占2~4分)、案例题(25分)和论文题,本课时节内容偏重于方法掌握和应用,根据以…

KEIL中编译51程序 算法计算异常的疑问

KEIL开发 51 单片机程序 算法处理过程中遇到的问题 ...... by 矜辰所致前言 因为产品的更新换代, 把所有温湿度传感器都换成 SHT40 ,替换以前的 SHT21。在 STM32 系列产品上的替换都正常,但是在一块 51 内核的无线产品上面,数据…

两个月冲刺软考——逻辑地址与物理地址的转换(例题+讲解);文件类型的考点

1.已知计算机系统页面大小和进程的逻辑地址,根据页面变换表(页号-物理块号),求变换后的物理地址。 首先介绍几个公式: 逻辑地址 页号 页内地址 (默认为32机位) 物理地址 物理块号 物理地址的页内地址 其中:页内地址 物理地址…

Kubernetes--服务发布(Service、Ingress)

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 出自B站博主教程笔记: 完整版Kubernetes(K8S)全套入门微服务实战项目,带你一站式深入掌握K8S核心能…

算法_栈专题---持续更新

文章目录 前言删除字符中的所有相邻重复项题目要求题目解析代码如下 比较含退格的字符串题目要求题目解析代码如下 基本计算器II题目要求题目解析 字符串解码题目要求题目解析代码如下 验证栈序列题目要求题目解析代码如下 前言 本文将会向你介绍有关栈的相关题目:…

matter中的Fabric(网络结构)

什么是Fabric? Fabric可以被理解为一组相互信任的设备和控制器,它们共享一个共同的信任域。这意味着在同一个Fabric中的设备和控制器之间可以进行安全的通信,而无需额外的身份验证或安全检查。每个Fabric有一个唯一的标识,确保Fab…

迁移替换AD域时,有几个关键点需要注意

在当今的数字化时代,企业对于身份管理和访问控制的需求日益增长。然而,传统的AD域控方案在面对国产化替代和业务上云的趋势时,逐渐显露出一些局限性。宁盾国产化身份域管作为一种迁移替换AD域控的创新解决方案,正逐渐崭露头角&…

文心一言 VS 讯飞星火 VS chatgpt (341)-- 算法导论23.1 10题

十、给定图 G G G和 G G G的一棵最小生成树 T T T,假设减小了 T T T中一条边的权重。证明: T T T仍然是 G G G的一棵最小生成树。更形式化地,设 T T T为 G G G的一棵最小生成树, G G G的边权重由权重函数 w w w给出。选择一条边 (…

职称评审中,论文发表要求?

无论是医生、教师或其他等职业,职称评审无疑是一个非常重要的环节。而职称评审中的论文发表则是评定我们专业能力的重要一环,可如何才能让自己辛苦撰写的的论文被发表,达到论文发表都有哪些要求呢? 一、选题要新颖 编辑和审稿人…

VMware的三种网络模式及应用场景

在VMware中,虚拟机网络连接的方式主要有三种模式:桥接模式(Bridged Mode)、NAT模式(Network Address Translation Mode)、仅主机模式(Host-Only Mode)。每种模式都有其独特的用途和配…

SSM+Ajax实现广告系统

文章目录 1.案例需求2.编程思路3.案例源码(这里只给出新增部分的Handler和ajax部分,需要详情的可以私信我)4.小结 1.案例需求 使用SSMAjax实现广告系统,包括登录、查询所有、搜索、新增、删除、修改等功能,具体实现的效果图如下:…

『功能项目』状态模式转换场景【30】

本章项目成果展示 打开上一篇29Unity本地数据库读取进入游戏的项目, 本章要做的事情是通过状态者模式转换场景,在进入账号登陆界面前闪烁显示Logo 首先创建一个新的场景命名为StartUI 修改游戏场景名字 重命名为FightGame01 首先创建一个脚本文件夹&…

gazebo可能打不开的问题

如果经常遇到gazebo只能断网才能运行的时候,主要就是因为无法联网访问gazebo的在线模型库,此时我们一般无法在联网的情况下打开gazebo。 这个时候就直接将下载好的模型先放到~/.gazebo/models/文件夹下面即可: https://github.com/osrf/gazeb…

RTOS Sensor Framework对比

1.背景 传感器(Sensor)是物联网(IOT)的重要组成部分,用于感知和采集环境中的各种数据,大部分智能硬件都需要。 为使传感器能正常⼯作,驱动开发者需要开发不同的驱动程序,驱动程序要实现芯片寄存器\IO设置,又要响应使用…

搭建 canal 监控mysql数据到 elasticsearch 中(本机到远端sql)

搭建 canal 监控mysql数据到 elasticsearch 中(本机到远端sql) 需求: 要将 MySQL 数据库 info 中的 notice 和 result 表的增、删、改操作同步到 Elasticsearch 的 notice 和 result 索引,您需要正确配置 MySQL、Canal 、Canal Adapter 、 …

3--Web前端开发-前端工程化,vue项目

目录 端口配置 element 快速入门 table表格组件 分页组件 Dialog对话框组件 表单组件 端口配置 在vue.config.js中更改 源代码为 const { defineConfig } require(vue/cli-service) module.exports defineConfig({transpileDependencies: true })更改为 const { def…

Linux——redis主从复制、哨兵模式

一、redis 的安全加固: 对redis数据库访问的角度 auth // 验证登录redis 数据库的用户acl // 设置redis用户的权限将配置完成的ACL策略写入配置文件 config rewrite //目前redis生效的配置全部写入到默认配置文件的尾部写入到acl文件中,在加载配置文件时…

《论软件设计模式及其应用》通关范文,软考高级系统架构设计师

论文真题 设计模式(Design Pattern)是一套被反复使用的代码设计经验总结,代表了软件开发人员在软件开发过程中面临的一般问题的解决方案和最佳实践。使用设计模式的目的是提高代码的可重用性,让代码更容易被他人理解,并保证代码可靠性。现有的设计模式已经在前人的系统中…

每日一练:和为K的子数组

一、题目要求 给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例 1: 输入:nums [1,1,1], k 2 输出:2示例 2: 输入:n…