使用Dreambooth LoRA微调SDXL 0.9

news2024/11/20 3:18:22

本文将介绍如何通过LoRA对Stable Diffusion XL 0.9进行Dreambooth微调。DreamBooth是一种仅使用几张图像(大约3-5张)来个性化文本到图像模型的方法。

本教程基于通过LoRA进行Unet微调,而不是进行全部的训练。LoRA是在LoRA: Low-Rank Adaptation of Large Language Models中引入的一种参数高效的微调技术。

本文基于diffusers包,至少需要0.18.2或更高版本。

基于GeForce RTX 4090 GPU (24GB)的本地实验,VRAM消耗如下:

  • 512分辨率- 11GB用于训练,19GB保存检查点
  • 1024分辨率- 17GB的训练,19GB时保存检查点

环境设置

建议创建一个新的虚拟环境,下面是我们需要使用的python包

Pytorch

 pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

Diffusers

 pip install git+https://github.com/huggingface/diffusers

LoRA进行SDXL 0.9 Dreambooth微调需要0.19.0.dev0及以上版本的diffusers

还有一些其他的依赖包:

 pip install invisible_watermark transformers accelerate safetensors

然后就是进行配置,在终端上执行如下命令配置accelerate:

 accelerate config

使用以下设置在单个GPU上进行混合精度的训练:

 ----------------------------------------------------------------------------------------------------------------------------
 In which compute environment are you running?
 This machine
 ----------------------------------------------------------------------------------------------------------------------------
 Which type of machine are you using?
 No distributed training
 
 Do you want to run your training on CPU only (even if a GPU is available)? [yes/NO]:
 no
 
 Do you wish to optimize your script with torch dynamo?[yes/NO]:
 no
 
 Do you want to use DeepSpeed? [yes/NO]: 
 no
 
 What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:
 all
 ----------------------------------------------------------------------------------------------------------------------------
 Do you wish to use FP16 or BF16 (mixed precision)?
 fp16

或者,使用以下命令使用默认值

 accelerate config default

数据集

我们这里将介绍Dreambooth微调所需的最简单配置。对于数据集的准备,只需收集一些相同主题或风格的图像,并将其放在一个目录中。

比如下面的文件夹结构:

 data/xxx.png
 data/xxy.png
 ...
 data/xxz.png
 data/yyz.png

这里要确保所有的训练图像都是相同的大小。如果大小不同,需要先调整大小。建议使用1024 * 1024作为图像分辨率。

我们这里使用dog示例数据集通过LoRA测试Dreambooth微调。这个数据集可以直接从网站下载,以下Python脚本可以将其下载到本地:

 from huggingface_hub import snapshot_download
 
 local_dir = "./data"
 snapshot_download(
     "diffusers/dog-example",
     local_dir=local_dir, repo_type="dataset",
     ignore_patterns=".gitattributes",
 )

微调训练

在官方库下载train_dreambooth_lora_sdxl.py训练脚本。将该文件放在工作目录中。

如果你使用的是旧版本的diffusers,它将由于版本不匹配而报告错误。但是你可以通过在脚本中找到check_min_version函数并注释它来轻松解决这个问题,如下所示:

 # check_min_version("0.19.0.dev0")

虽然可以用,但是还是建议使用官方的推荐版本。

如果全部设置正确,那么可以通过LoRA进行Dreambooth微调的训练命令:

 accelerate launch train_dreambooth_lora_sdxl.py \
   --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-0.9" \
   --instance_data_dir=data \
   --output_dir=output \
   --mixed_precision="fp16" \
   --instance_prompt="a photo of zwc dog" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --learning_rate=1e-4 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --checkpointing_steps=500 \
   --max_train_steps=1000 \
   --seed="0" \
   --checkpoints_total_limit=5

对于Windows用户,将所有\符号替换为^符号。因为转义符不同

简单的介绍一些参数:

  • instance_prompt:带有指定实例标识符的提示符。
  • resolution:输入图像的分辨率,训练/验证数据集中的所有图像都将调整为此大小。默认值是512,将其设置为1024,因为它是用于SDXL训练的分辨率。
  • train_batch_size:训练数据加载器的批处理大小(每个设备)。减少批处理大小,防止训练过程中出现内存不足错误。
  • num_train_steps:训练步数。建议设置为N × 100,其中N表示训练图像的个数。
  • checkpointing_steps:每X次更新时保存训练状态的检查点。默认为500。将其设置为更高的值以减少保存的检查点数量,因为模型需要保存到磁盘,所以频繁的保存会降低训练速度。
  • checkpoints_total_limit:限制保存的检查点的数量。将删除/删除旧的检查点。

在第一次运行是,程序会下载Stable Diffusion模型并将其保存在本地缓存文件夹中,如果网不好的话这里会很慢。在随后的运行中,它将重用相同的缓存数据。

请注意SDXL 0.9权重需要登录HuggingFace并接受许可。然后,通过HuggingFace -cli命令登录,并使用从HuggingFace设置中获取的API令牌。

默认情况下,每个checkpointing_steps脚本只保存一次LoRA权重和一些检查点文件。

最后我们的结果如下:

 |- output
 |  |- checkpoint-500
 |  |- checkpoint-1000
 |  |- checkpoint-1500
 |  |- checkpoint-2000
 |- data
 |- train_dreambooth_lora_sdxl.py

上面的每个checkpoint文件夹包含以下文件:

  • optimizer.bin
  • pytorch_lora_weights.bin
  • random_states_0.pkl
  • scaler.pt
  • scheduler.bin

pytorch_lora_weights.bin文件可以直接用于推理。

推理

创建一个名为inference.py的新Python文件:

 from diffusers import DiffusionPipeline
 import torch
 #初始化,加载所需的LoRA权重
 pipe = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-0.9",
     torch_dtype=torch.float16,
     variant="fp16",
     use_safetensors=True
 )
 # load LoRA weight
 pipe.unet.load_attn_procs("data/checkpoint-2000/pytorch_lora_weights.bin", use_safetensors=False)
 pipe.enable_model_cpu_offload()
 
 refiner = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-refiner-0.9",
     torch_dtype=torch.float16,
     variant="fp16",
     use_safetensors=True
 )
 refiner.enable_model_cpu_offload()
 
 #推理和保存文件
 seed = 12345
 n_steps = 50
 prompt = "a photo of zwc dog in a bucket"
 
 generator = torch.Generator(device="cuda").manual_seed(seed)
 latent_image = pipe(prompt=prompt, num_inference_steps=n_steps, generator=generator, output_type="latent").images[0]
 image = refiner(prompt=prompt, num_inference_steps=n_steps, generator=generator, image=latent_image).images[0]
 image.save("image.jpg")

然后我们可以执行如下命令:

 python inference.py

结果展示

以下是我做的一个快速测试,使用16张具有各种情绪的chibi 人物图像作为训练数据集。

分辨率1024 × 1024 、duoduo 作为实例提示

大约花了4个小时的训练,下面的输出示例:

总结

使用我们上面的代码可以随意使用不同的数据集和训练配置进行实验,以获得所需的结果。

本文首先简要介绍了Dreambooth和LoRA背后的基本概念。然后介绍了通过pip install进行安装的过程。还探讨了数据集的准备。然后整理了训练命令,并对一些常用的训练参数进行了详细的说明。并使用代码加载新训练的LoRA权重,根据输入提示生成相应的图像。最后展示了一个在本地进行的训练的简单实验。

本文使用的主要库:

https://avoid.overfit.cn/post/0423804f782b4cb9a74f1ae6a6f99b34

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/755839.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何通过CRM系统减低客户流失率并提高销售业绩?

销售人员如何提高业绩,减低客户流失率?通过CRM客户管理系统与客户建立良好的客户关系、提升客户体验助力销售人员业绩节节攀升,降低客户流失率。接下来我们就来说一说CRM系统如何实现的? 1.全渠道沟通提升客户体验 只有足够多的…

搜索结果处理

1、排序 #sort排序 GET /hotel/_search {"query": {"match_all": {}},"sort": [{"score": "desc"},{"price": "asc"}] }#找到121.6,31周围的酒店,距离升序排序 GET /hotel/_sea…

前端学习——JS进阶 (Day2)

深入对象 创建对象三种方式 构造函数 小练习 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport"…

python_openpyxl常用语法记录

目录 写在前面&#xff1a; 开始 工作薄 and 工作表 操作单元格 访问&#xff08;调用&#xff09;单个单元格 访问&#xff08;调用&#xff09;多个单元格 保存工作薄 用例 例&#xff1a;简单数据和条形图 操作样式 样式的默认值 单元格样式和命名样式 单元格样…

配置Hadoop_0

配置Hadoop_0 1配置Hadoop100模板虚拟机1.1配置Hadoop100模板虚拟机硬件1.2配置Hadoop100模板虚拟机软件1.3配置Hadoop100模板虚拟机IP地址1.4配置Hadoop100模板虚拟机主机名称/主机名称映射1.5配置Hadoop100模板虚拟机远程操作工具 1配置Hadoop100模板虚拟机 Hadoop100 内存…

编译给IOS平台用的liblzma库(xz与lzma)

打开官方网: XZ Utils 新工程仓库: git clone https://git.tukaani.org/xz.git git clone https://github.com/tukaani-project/xz 旧工程仓库: git clone https://git.tukaani.org/lzma.git 旧版本工程编译: 先进行已下载好的lzma目录 执行./autogen.sh生成configure配置…

233. 数字 1 的个数

题目描述&#xff1a; 主要思路&#xff1a; 寻找1的个数主要分为两个部分&#xff1a;完整的1和取余的1。 完整的1&#xff1a;从个位一直到最高位&#xff0c;例如十位上的1可以出现10次&#xff0c;10-19&#xff0c;然后看他的高位&#xff0c;看看可以出现几轮循环。 取余…

QListWidget 小节

QListWidget 小节 QListWidget 简介举例UI设计头文件源文件 QListWidget 简介 以下是 QListWidget 常用函数的一些说明&#xff1a; addItem(item)&#xff1a;向列表中添加一个项。 addItems(items)&#xff1a;向列表中添加多个项。 clear()&#xff1a;清空列表中的所有项…

射线与物质的相互作用

射线与物质的相互作用 射线与物质的相互作用概要 电离——核外层电子克服束缚成为自由电子&#xff0c;原子成为正离子激发——使核外层电子由低能级跃迁到高能级而使原子处于激发状态&#xff0c;退激发光 射线 致电离辐射 慢化 电离损失&#xff1a;带电粒子与靶物质原子…

this指针/闭包及作用域

一.作用域链 1.通过一个例子 let aglobalconsole.log(a);//globalfunction course(){let bjsconsole.log(b);//jssession()function session(){let cthisconsole.log(c);//Windowteacher()//函数提升function teacher(){let dstevenconsole.log(d);//stevenconsole.log(test1,…

【unity之IMGUI实践】单例模式管理数据存储【二】

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;uni…

FocusState, SubmitTextField 的使用

1. FocusState 输入文本添加焦点状态 1.1 实现 /// 输入文本焦点状态 struct FocusStateBootcamp: View {// 使用枚举enum OnboardingFields: Hashable{case usernamecase password}//FocusState private var usernameInFocus: BoolState private var username: String "…

两分钟python发个邮件

python简单发个邮件 直接上代码测试 之前spring boot简单发送发送个邮件大约5min&#xff0c;ennn这个python发个邮件两三分钟吧 直接上代码 import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMETextclass MailTest(object):def…

Flink 在新能源场站运维的应用

摘要&#xff1a;本文整理自中南电力设计院工程师、注册测绘师姚远&#xff0c;在 Flink Forward Asia 2022 行业案例专场的分享。本篇内容主要分为四个部分&#xff1a; 1. 建设背景 2. 技术架构 3. 应用落地 4. 后续及其他 Tips&#xff1a;点击「阅读原文」免费领取 5000CU*…

Vue localhost 从 http 307 到 https

Vue localhost 从 http 307 到 https HTTP 307 与 HSTS HTTP 307中间人攻击HSTS - HTTP Strict Transport Security 如何解决问题 Vue localhost 从 http 307 到 https 一个 Vue2 项目之前本地都是通过 HTTP 的 localhost 访问(如下) 后来突然无法访问了, 提示的错误内容是 E…

静电接地桩的设计和施工

静电接地桩是用于将静电荷引导到地下的装置&#xff0c;以确保工作环境。以下是一般静电接地桩设计的一些建议和步骤&#xff1a; 1. 选择合适的位置&#xff1a;静电接地桩应该位于静电产生源附近&#xff0c;并接近地面。可以选择在室内或室外&#xff0c;但要确保容易维护和…

web中引入live2d的moc3模型

文章目录 前言下载官方sdk文件使用ide编译项目&#xff08;vsCode&#xff09;项目初始化使用vsCode项目树介绍使用live server运行index页面 演示导入自己的模型并显示modelDir文件resources文件夾案例模型修改modelDir然後重新打包項目運行 前言 先跟着官方sdk调试一遍&…

14.live555mediaserver-play请求与响应

live555工程代码路径 live555工程在我的gitee下&#xff08;doc下有思维导图、drawio图&#xff09;&#xff1a; live555 https://gitee.com/lure_ai/live555/tree/master 章节目录链接 0.前言——章节目录链接与为何要写这个&#xff1f; https://blog.csdn.net/yhb1206/art…

基于C/S架构工作原理序号工作步骤和理论的区别

基于C/S架构工作原理序号工作步骤和理论的区别 SSH 概念 对称加密linux 系统加密&#xff0c;就是加密和揭秘都是使用同一套密钥。 非对称加密有两个密钥&#xff1a;“私钥”和“公钥”。私钥加密后的密文&#xff0c;只能通过对应的公钥进行揭秘。而通过私钥推理出公钥的…

深入解析浏览器Cookie(图文码教学)

深入解析浏览器Cookie 前言一、什么是 Cookie?二、Cookie的特点二、如何创建 Cookie&#xff1f;三、服务器如何获取 Cookie四、Cookie 值的修改4.1 方案一4.2 方案二 五、浏览器查看 Cookie六、Cookie 生命控制七、Cookie 有效路径 Path 的设置八、案例&#xff1a;Cookie 练…