Multi-Concept Customization of Text-to-Image Diffusion——【代码复现】

news2024/11/14 20:22:22

本文是发表于CVPR 2023上的一篇论文:[2212.04488] Multi-Concept Customization of Text-to-Image Diffusion (arxiv.org)

一、引言

本文主要做的工作是对stable-diffusion的预训练模型进行微调,需要的显存相对较多,论文中测试时是在两块GPU上微调,需要30GB的显存,不过他调的batchsize=8,因为我自己的算力有限,我把复现的时候把batchsize调成了2,然后在两块3090上跑的,至于最低要求多少还没测试,不过个人认为最低也要有一张3090。

在复现前,请自行安装好Python的环境,本文就不叙述了哈哈。

二、下载相关文件及搭建环境

1.下载项目及环境搭建

adobe-research/custom-diffusion: Custom Diffusion: Multi-Concept Customization of Text-to-Image Diffusion (CVPR 2023) (github.com)

上述链接是本文代码的链接,这篇文章的代码实际上是基于Stable-diffusion构建的,所以我的建议是可以先去复现一下stable-diffusion的代码,再来学习这篇文章以及代码。stable-diffusion的复现可以看我另外一篇文章:stable-diffusion复现笔记,当然如果你想直接上手,可以按照项目中readme来构建,这里我默认已经有装过stable-diffusion了哈,因为很多文件都是相同的,如果你是直接上手,有些文件比如sd-v1-4.ckpt的下载等问题,都可以去看我这篇stable-diffusion复现笔记。

git clone https://github.com/adobe-research/custom-diffusion.git
cd custom-diffusion
git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion
conda env create -f environment.yaml
conda activate ldm
pip install clip-retrieval tqdm 

上述是论文给出的环境搭建代码,如果你跟我一样已经做过stable-diffusion的安装,可以直接执行最后一行 pip install clip-retrieval tqdm 。

2.下载数据集

复现的时候我用的是官方给的数据集,下载地址:https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip

三、运行

复现的过程我主要采用以生成的图像作为正则化来实现,方便起见,主要还是按照官方给的示例来复现。

1.单一概念微调——生成的图像作为正则化

第一步:这里我们可以直接执行命令文件,<pretrained-model-path>是预训练模型的路径,如:/data/disk1/sxtang/models/sd-v1-4.ckpt

bash scripts/finetune_gen.sh "cat" data/cat gen_reg/samples_cat  cat finetune_addtoken.yaml <pretrained-model-path>

这个sh文件会执行两个脚本文件:sample.py、train.py。

先执行sample.py生成用于正则化的图像,一共是200张,然后再执行train.py文件对预训练的模型进行微调,如果一切顺利,命令行最后的输入应该如下:

生成的正则化图像的目录:

 

微调所得模型目录:

 

复现过程中我所遇到的问题:

(1).我是在RTX3090上进行采样生成图片的,但是如果按照代码中默认的参数去执行,我的显存是不够的(论文毕竟是在两块A100做的),然后我的解决方法是把参数调了一下,改成:

--n_samples 5  --n_iter 40 

这里主要还是根据自己的情况去调整,如果还是爆显存的话,可以把数值都调小点,然后多执行几次sample脚本也是可以的。

(2).之前也说了,代码默认的batchsize=4,我跑不了哈哈,所以调整一下batchsize的大小。

具体的,在configs/custom-diffusion/finetune_addtoken.yaml文件中更改:

(3).TypeError: CUDACallback.on_train_epoch_end() missing 1 required positional argument: 'outputs'问题。

这里主要是pytorch-lighting的版本问题,需要把这个outputs参数删掉,具体的,在train.py文件下的on_train_epoch_end函数中:

 

(4).pytorch_lightning.utilities.exceptions.MisconfigurationException: No `test_dataloader()` method defined to run `Trainer.test`.


 这里说什么没定义这个方法,解决的方法就是在运行的时候直接加上参数--no-test即可。

第二步:更新权重

执行下面的命令即可实现,这里<folder-name> 就是你微调后的那个模型的文件夹,比如:2024-01-13T14-11-49_cat-sdv4,这一步我在执行过程中没有遇到什么问题。

## save updated model weights
python src/get_deltas.py --path logs/<folder-name> --newtoken 1

第三步:运行

## sample
python sample.py --prompt "<new1> cat playing with a ball" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>

这个new1就是个占位符,无需更改;<folder-name>和上述的含义一样,最后这个“000004.ckpt”是你想要用的权重文件名称。 最后--ckpt <pretrained-model-path> 就是预训练的模型路径。

如果一切顺利的话,就会出图啦!

图片存放的位置以及我生成的图片如下:

 

2.多概念微调——生成的图像作为正则化

官方的readme中只给出了基于真实图像的代码,所以自己实现了一下生成图像正则化。

第一步:生成正则化图像。

上面我们已经生成的cat的正则化图像,这里还需要wooden_pot的正则化图像,所以我们需要先采样生成图像,我这里用的命令如下:

python -u sample.py \
        --n_samples 5 \
        --n_iter 40 \
        --scale 6 \
        --ddim_steps 50  \
        --ckpt  /data/disk1/sxtang/models/sd-v1-4.ckpt  \  #预训练模型的路径
        --ddim_eta 1. \
        --outdir "gen_reg/samples_wooden_pot" \   # 输出图像的路径
        --prompt "photo of a wooden_pot" 

 第二步:微调,这里我稍微改了一下那个项目中给出的基于真实图像实现的.sh文件

#!/usr/bin/env bash
#### command to run with retrieved images as regularization
# 1st arg: target caption1
# 2nd arg: path to target images1
# 3rd arg: path where retrieved images1 are saved
# 4rth arg: target caption2
# 5th arg: path to target images2
# 6th arg: path where retrieved images2 are saved
# 7th arg: name of the experiment
# 8th arg: config name
# 9th arg: pretrained model path

ARRAY=()

for i in "$@"
do
    echo $i
    ARRAY+=("${i}")
done


python -u  train.py \
        --base configs/custom-diffusion/${ARRAY[7]}  \
        -t --gpus 6,7 \
        --resume-from-checkpoint-custom  ${ARRAY[8]} \
        --caption "<new1> ${ARRAY[0]}" \
        --datapath ${ARRAY[1]} \
        --reg_datapath "${ARRAY[2]}/samples" \
        --reg_caption "${ARRAY[0]}" \
        --caption2 "<new2> ${ARRAY[3]}" \
        --datapath2 ${ARRAY[4]} \
        --reg_datapath2 "${ARRAY[5]}/samples" \
        --reg_caption2 "${ARRAY[3]}" \
        --modifier_token "<new1>+<new2>" \
        --name "${ARRAY[6]}-sdv4"

 执行命令:

bash scripts/finetune_joint_gen.sh "wooden pot" data/wooden_pot gen_reg/samples_wooden_pot \
                                    "cat" data/cat gen_reg/samples_cat  \
                                    wooden_pot+cat finetune_joint.yaml /data/disk1/sxtang/models/sd-v1-4.ckpt

注:如果需要调整如batchsize等参数,这里是在finetune_joint.yaml文件中更改。

如果一切顺利,出现如下界面,就代表着微调成功啦:

后面两步和单个概念那边一样,这里不过多叙述。

第二步:更新权重

## save updated model weights
python src/get_deltas.py --path logs/<folder-name> --newtoken 2

 第三步:运行

## sample
python sample.py --prompt "the <new2> cat sculpture in the style of a <new1> wooden pot" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>

下面是我测试所生成的图像:

四、最后

这篇文章和Dreambooth等有着异曲同工之妙,都是为了实现个性化的图像生成,当然论文中还有比如通过diffusers实现等功能,如果感兴趣可以自己去试试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1380929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

原子操作类原理剖析

UC包提供了一系列的原子性操作类&#xff0c;这些类都是使用非阻塞算法CAS实现的&#xff0c;相比使用锁实现原子性操作这在性能上有很大提高。 由于原子性操作类的原理都大致相同&#xff0c;所以只讲解最简单的AtomicLong类的实现原理以及JDK8中新增的LongAdder和LongAccumu…

领导风格测试

领导风格指的是管理者在开展管理工作时的思维和行为模式&#xff0c;通常我们也称之为习惯&#xff0c;或者是人格特征。这种习惯是固化的&#xff0c;是长期的经历所形成的&#xff0c;其中包含了个人的知识&#xff0c;经验&#xff0c;人际关系等。领导风格测试是企业人才选…

必练的100道C语言程序设计练习题(上)

前言: 在计算机编程的世界中&#xff0c;C语言一直是一门备受推崇的语言。它的简洁性、高效性以及广泛应用使得学习C语言成为每一位程序员的必由之路。然而&#xff0c;掌握这门语言并不是一蹴而就的事情&#xff0c;它需要不断的练习和实践。为了帮助各位编程爱好者更好地理解…

进程和线程的比较

目录 一、前言 二、Linux查看进程、线程 2.1 Linux最大进程数 2.2 Linux最大线程数 2.3 Linux下CPU利用率高的排查 三、线程的实现 四、上下文切换 五、总结 一、前言 进程是程序执行相关资源&#xff08;CPU、内存、磁盘等&#xff09;分配的最小单元&#xff0c;是一…

Halcon边缘滤波器edges_image 算子

Halcon边缘滤波器edges_image 算子 基于Sobel滤波器的边缘滤波方法是比较经典的边缘检测方法。除此之外&#xff0c;Halcon也提供了一些新式的边缘滤波器&#xff0c;如edges_image算子。它使用递归实现的滤波器&#xff08;如Deriche、Lanser和Shen&#xff09;检测边缘&…

【无标题】关于异常处理容易犯的错

一般项目是方法打上 try…catch…捕获所有异常记录日志&#xff0c;有些会使用 AOP 来进行类似的“统一异常处理”。 其实&#xff0c;这种处理异常的方式非常不可取。那么今天&#xff0c;我就和你分享下不可取的原因、与异常处理相关的坑和最佳实践。 捕获和处理异常容易犯…

Java研学-分页查询

一 分页概述 1 介绍 将大量数据分段显示&#xff0c;避免一次性加载造成的内存溢出风险 2 真假分页 ① 真分页   一次性查询出所有数据存到内存&#xff0c;翻页从内存中获取数据&#xff0c;性能高但易造成内存溢出 ② 假分页   每次翻页从数据库中查询数据&#xff0c…

day16 二叉树的最大深度 n叉树的最大深度 二叉树的最小深度 完全二叉树的节点数

题目1&#xff1a;104 二叉树的最大深度 题目链接&#xff1a;104 二叉树的最大深度 题意 二叉树的根节点是root&#xff0c;返回其最大深度&#xff08;从根节点到最远叶子节点的最长路径上的节点数&#xff09; 递归 根节点的的高度就是二叉树的最大深度 所以使用后序遍…

【Node.js学习 day4——模块化】

模块化介绍 什么是模块化与模块&#xff1f; 将一个复杂的程序文件依据一定规则&#xff08;规范&#xff09;拆分成多个文件的过程称之为模块化 其中拆分的每个文件就是一个模块&#xff0c;模块的内部数据是私有的&#xff0c;不过模块可以暴露内部数据以便其他模块使用。什…

010集:with as 代码块读写关闭文件—python基础入门实例

接009集&#xff1a; 读写文本文件的相关方法如下。 read &#xff08; size-1 &#xff09;&#xff1a;从文件中读取字符串&#xff0c; size 限制读取的字符数&#xff0c; si ze-1 指对读取的字符数没有限制。 readline &#xff08; size-1 &#xff09;&#xff1a;在…

常见的加密算法

加密算法 AES 高级加密标准(AES,Advanced Encryption Standard)为最常见的对称加密算法(微信小程序加密传输就是用这个加密算法的)。对称加密算法也就是加密和解密用相同的密钥&#xff0c;具体的加密流程如下图&#xff1a; RSA RSA 加密算法是一种典型的非对称加密算法&am…

JavaScript数据类型、判断、检测

JavaScript数据类型 number、string、boolean、null、undefined、symbol、bigint Object【Array、RegExp、Date、Math、Function】 存储方式 1. 基础类型存储在栈内存中&#xff0c;被引用或者拷贝时&#xff0c;会创建一个完全相同的变量。 2. 引用类型存放在堆内存中&…

[redis] redis高可用之持久化

一、Redis 高可用的相关知识 1.1 什么是高可用 在web服务器中&#xff0c;高可用是指服务器可以正常访问的时间&#xff0c;衡量的标准是在多长时间内可以提供正常服务(99.9%、99.99%、99.999%等等)。 但是在Redis语境中&#xff0c;高可用的含义似乎要宽泛一些&#xff0c;…

wpf使用Popup封装数据筛选框

(关注博主后,在“粉丝专栏”,可免费阅读此文) 类似于DevExpress控件的功能 这是DevExpress的winform筛选样式,如下: 这是DevExpress的wpf筛选样式,如下: 这是Excel的筛选样式,如下: 先看效果 本案例使用wpf原生控件封装,功能基本上都满足,只是颜色样式没有写…

轻松掌握构建工具:Webpack、Gulp、Grunt 和 Rollup 的使用技巧(下)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

蓝桥杯省赛无忧 STL 课件15 queue

01 queue队列 02 priority_queue优先队列 接下来介绍几种优先队列修改比较函数的方法 03 deque双端队列 04 例题讲解 https://www.lanqiao.cn/problems/1113/learning/?page1&first_category_id1&problem_id1113输入 5 IN xiaoming N IN Adel V IN laozhao N OUT …

VMware workstation搭建与安装AlmaLinux-9.2虚拟机

VMware workstation搭建与安装AlmaLinux-9.2虚拟机 适用于需要在VMware workstation平台安装AlmaLinux-9.2&#xff08;最小化安装、无图形化界面&#xff09;虚拟机。 1. 安装准备 1.1 安装平台 Windows 11 1.2. 软件信息 软件名称软件版本安装路径VMware-workstation 1…

一些硬件知识(三)

uint8_t, uint32_t, 和 uint16_t 是 C 和 C 语言中的数据类型&#xff0c;它们分别表示无符号的 8 位、32 位和 16 位整数。这些数据类型定义在标准库 <stdint.h>&#xff08;在 C 语言中&#xff09;或 <cstdint>&#xff08;在 C 中&#xff09;。 uint8_t&…

记录 | ubuntu软链接查看、删除、创建

软连接查看 ls -il 软连接删除 rm -rf ** 软连接创建 ln -s 源文件 目标文件 实例&#xff0c;软连接报错&#xff1a; 若要建立libtiny_reid.so*间软连接&#xff1a; 先删除 rm -rf libtiny_reid.so libtiny_reid.so.3 libtiny_reid.so.3.1 再建立 ln -s libtiny_re…

Nocalhost 为 KubeSphere 提供更强大的云原生开发环境

1 应用商店安装 Nocalhost Server 已集成在 KubeSphere 应用商店&#xff0c;直接访问&#xff1a; 设置应用「名称」&#xff0c;确认应用「版本」和部署「位置」&#xff0c;点击「下一步」&#xff1a; 在「应用设置」标签页&#xff0c;可手动编辑清单文件或直接点击「安装…