【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

news2024/11/13 8:50:42

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。

文章目录

  • 系列文章目录
  • 整体训练框架
  • 一、训了哪些部分?
      • 第一块 - image_proj_model
      • 第二块 - adapter_modules
  • 二、训练目标


整体训练框架

在这里插入图片描述

一、训了哪些部分?

本文以原仓库 1 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。

从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model.parameters(), 和 ip_adapter.adapter_modules.parameters() 。

	ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

    # optimizer
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
    ...
    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

第一块 - image_proj_model

在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。

	# freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)
    
    #ip-adapter-plus
    image_proj_model = Resampler(
        dim=unet.config.cross_attention_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=args.num_tokens,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4
    )
    ...

第二块 - adapter_modules

Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。

	# init adapter modules
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

二、训练目标

每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:

  • latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
  • 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  • clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
    • 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
    • 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
     clip_images = []
     for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
         if drop_image_embed == 1:
             clip_images.append(torch.zeros_like(clip_image))
         else:
             clip_images.append(clip_image)
     clip_images = torch.stack(clip_images, dim=0)
     with torch.no_grad():
         image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]
 
     with torch.no_grad():
         encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

  1. https://github.com/tencent-ailab/IP-Adapter/tree/main ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2080904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ASP.NET MVC+LayUI视频上传完整教程

前言 前段时间在使用APS.NET MVCLayUI做视频上传功能的时,发现当上传一些内存比较大的视频就会提示上传失败,后来通过查阅相关资料发现.NET MVC框架为考虑安全问题,在运行时对请求的文件的长度(大小)做了限制默认为4M…

缓存Mybatis一级缓存与二级缓存

缓存 为什么使用缓存 缓存(cache)的作用是为了减去数据库的压力,提高查询性能,缓存实现原理是从数据库中查询出来的对象在使用完后不销毁,而是存储在内存(缓存)中,当再次需要获取对象时,直接从内存(缓存)中提取,不再向数据库执行select语句,从而减少了对数据库的查询次数,因此…

力扣之字母异位词分组(python)

题目 给你一个字符串数组,请你将 字母异位词 组合在一起。可以按任意顺序返回结果列表。 字母异位词 是由重新排列源单词的所有字母得到的一个新单词。 示例 1: 输入: strs ["eat", "tea", "tan", "ate", "nat&qu…

回归预测|基于鹅GOOSE优化LightGBM的数据回归预测Matlab程序 多特征输入单输出 2024年优化算法

回归预测|基于鹅GOOSE优化LightGBM的数据回归预测Matlab程序 多特征输入单输出 2024年优化算法| 文章目录 前言回归预测|基于鹅GOOSE优化LightGBM的数据回归预测Matlab程序 多特征输入单输出 2024年优化算法GOOSE-LightGBM 一、GOOSE-LightGBM模型原理:流程&#xf…

【STM32】IIC

超级常见的外设通信方式,一般叫做I方C。 大部分图片来源:正点原子HAL库课程 专栏目录:记录自己的嵌入式学习之路-CSDN博客 目录 1 基本概念 1.1 总线结构 1.2 IIC协议 1.3 软件模拟IIC逻辑 2 AT24C02 2.1 设备地址与…

华为手机数据丢失如何恢复?

在智能手机普及的今天,华为手机凭借其卓越的性能和用户体验赢得了众多用户的青睐。然而,在使用过程中,我们难免会遇到数据丢失或误删除的情况。面对这一困境,许多用户可能会感到束手无策。别担心,本文将为你提供一份全…

什么是响应式?

表达式: 用于表达式进行插值,渲染到页面之中 语法: {{ 表达式 }} 案例 <template><h1>{{ arr[2] }}</h1><h1>{{ 9 5 }}</h1><h1>{{ "神奇" }}</h1> </template><script setup> import { ref } from vue; …

react中配置Sentry

sentry 打开sentrySentry Docs | Application Performance Monitoring &amp; Error Tracking Software官网&#xff0c; 注册。根据提示创建应用后 在 React 应用中配置 Sentry 可以按照以下步骤进行&#xff1a; 安装 Sentry SDK: 在项目根目录下运行&#xff1a; npm in…

DDR5 Channel SI设计的挑战

DDR5延续了前几代数据速率不断提高的趋势。数据传输速度在3200至6400MT/s之间。同时将继续像前几代一样使用单端数据线的方式。为了帮助减少由高数据速率引起的信号完整性问题&#xff0c;DRAM端也会考虑加入判决反馈均衡&#xff08;DFE&#xff09;来减轻反射、ISI对信号传输…

十、Java异常

文章目录 一、异常简介二、异常体系图三、常见的异常3.1 常见的运行时异常3.2 常见的编译异常 四、异常处理4.1 异常处理的方式4.2 try-catch异常处理4.2.1 try-catch异常处理基本语法4.2.2 try-catch异常处理的注意细节 4.3 throws异常处理4.3.1 throws异常处理基本介绍4.3.2 …

Android - Windows平台下Android Studio使用系统的代理

这应该是第一篇Android的博文吧。以后应该会陆续更新的。记录学习Android的点点滴滴。 之前也看过&#xff0c;不过看完书就忘了&#xff0c;现在重拾Android&#xff0c;记录学习历程。 为何要用代理 因为更新gradle太慢了。 如何使用系统的代理 先找到系统代理的ip和端口。…

YOLO与PyQt5结合-增加论文工作量-实现一个目标检测的UI界面

这是个简单的界面&#xff0c;Qtdesigner支持各种界面&#xff0c;支持替换背景添加图标等。 接下来实现一个简单YOLO目标检测界面&#xff1a; 功能&#xff1a; 1、在窗口打开视频或图片进行目标检测&#xff0c;具有中断检测功能&#xff1a;比如检测视频的时候突然打开图…

速盾:cdn可以解决带宽问题么

一、速盾 CDN 的基本概念 CDN&#xff08;Content Delivery Network&#xff09;即内容分发网络&#xff0c;速盾 CDN 是这一技术的具体应用。它的工作原理是通过在全球多地部署服务器节点&#xff0c;将网站的内容缓存到这些节点上。 速盾 CDN 具有诸多优势。首先&#xff0…

分布式百万商户架构之缓存技术 本地化及未来之窗行业应用跨平台架构

如果数据读取速度比文件读取慢&#xff0c;将数据缓存到文件有以下优点&#xff1a; 一、提高读取效率 当需要反复访问某些数据时&#xff0c;从缓存文件中读取可以大大减少读取时间。因为文件系统通常会对文件进行一定程度的优化&#xff0c;使得文件的读取更加高效。而相比之…

优雅回收多个成员变量内存——使用函数模板实现内存安全释放

目录 从析构类中的多个成员说起什么是函数模板使用函数模板 从析构类中的多个成员说起 你有没有遇到过这种情况&#xff0c;一个类的构造函数中new了很多个成员变量&#xff0c;在析构函数中回收内存时&#xff0c;写了一遍又一遍 下面的代码&#xff1a; if (ptr ! nullptr)…

EXCEL文件如何批量加密,有什么方法

EXCEL文件的加密&#xff0c;通常在EXCEL软件上进行设置&#xff0c;它有打开密码与写保护密码&#xff0c;如果有多个文件的话&#xff0c;想通过一键设置的方法进行密码设置&#xff0c;那么它通常需要用到第三方软件进行批处理&#xff0c;因为EXCEL软件只能对当前打开的文件…

Wan-本科阶段部分作品

1、简易无接触温度测量与身份识别装置&#xff08;电赛 省一&#xff09; 2、基于交叉带式分拣结构的智能垃圾分类系统&#xff08;工训赛 省二&#xff09; 3、基于STM32的智能语音风扇&#xff08;大创优秀结题&#xff09;

鸿蒙界面开发(八):Grid网格布局Badge角标组件

Badge角标组件 在目标组件的外层包裹一层Badge角标组件 支持位置&#xff1a;右上&#xff0c;左&#xff0c;右 也可以使用绝对定位实现更灵活的角标位置。 Badge({count:1,//角标数值&#xff0c;角标数值为0时不展示position:BadgePosition.RightTop,//角标位置&#xff0…

【工作实践】MVEL 2.x语法指南

目录标题 MVEL 2.x语法指南一、基本语法1. 简单属性表达式2. 复合语句3. 返回值 二、值判断1. 判断空值2. 判断Null值3. 强制转换 三、内联Lists、Maps和数组Arrays1. Lists2. Maps3. 数组Arrays4. 数组强制转换 四、属性导航1. Bean属性2. Bean的安全属性导航3. 集合(1). List…

BOSS AI

BOSS AI 人工智能一点也不智能啊&#xff0c;机器人都不考虑用户的需求和体验吗&#xff1f; 这么多&#xff0c;我怎么看&#xff0c;我也不知道对面是人呢&#xff1f;还是机器人&#xff1f; 然后推送的东西也不知道我想要的&#xff0c;难道年龄到了&#xff0c;就活该天…