Stable Diffusion 微调及推理优化实践指南

news2025/2/24 2:21:29

随着 Stable Diffsuion 的迅速走红,引发了 AI 绘图的时代变革。然而对于大部分人来说,训练扩散模型的门槛太高,对 Stable Diffusion 进行全量微调也很难入手。由此,社区催生了一系列针对 Stable Diffusion 的高效微调方案,在保留原模型泛化能力的同时,实现自定义风格的融合,最关键的是,操作简单且资源消耗量低。

本文将介绍 Stable Diffsuion 微调方案选型,以及如何使用 Dreambooth 和 LoRA 进行微调实践,最后,我们会使用腾讯云 TACO 对微调后的 Dreambooth 和 LoRA 模型进行推理优化。

图片

Stable Diffusion 微调

Stable Diffusion 微调的目标,是将新概念注入预训练模型,利用新注入的概念以及模型的先验知识,基于文本引导条件生成自定义图片。目前主流训练 Stable Diffusion 模型的方法有 Full FineTune、Dreambooth、Text Inversion 和 LoRA,不同方法的实现逻辑和使用场景不同,选型简单对比如下:
在这里插入图片描述

需要注意的是,LoRA 是一种加速训练的方法,Stable Diffusion 从大语言模型微调中借鉴而来,可以搭配 Full FineTune 或 Dreambooth 使用。针对上述几种训练方法,我们在 A10-24G 机型上进行测试,5-10张训练图片,所需资源和时长对比如下:

在这里插入图片描述

接下来,我们重点介绍如何使用 Dreambooth 和 Lora(w Dreambooth) 对 Stable Diffusion 模型进行微调。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

建了技术答疑、交流群!想要进交流群、需要资料的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、添加微信号:mlc2060,备注:技术交流
方式②、微信搜索公众号:机器学习社区,后台回复:技术交流

资料1
在这里插入图片描述

资料2
在这里插入图片描述

Dreambooth

图片

Dreambooth 用一个罕见字符(identifier)来代表训练图片的概念,对 UNet 模型的所有权重进行调整。这里选择罕见字符(identifier),是希望原模型没有该 identifier 的先验知识,否则容易在模型先验和新注入概念(instance)间产生混淆。

对比 Full FineTune,虽然都会调整原模型的所有权重,但 Dreambooth 的创新点在于,它会使用 Stable Diffusion 模型去生成一个已有相关主题(class) 的先验知识,并在训练中充分考虑原 class 和新 instance 的 prior preservation loss,从而避免新 instance 图片特征渗透到其他生成里。

另外,训练中加入一个已有的相关主题(class)的描述,可以将 instance 和 class 进行绑定,这样新 instance 也可以使用到 class 对应的先验知识。

我们使用 Huggingface 提供的训练代码,准备5-10张图片,在A10上使用以下脚本启动训练:

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --mixed_precision=fp16 \
  --instance_prompt="a photo of az baby" \
  --class_prompt="a photo of baby" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800

其中 --instance_data_dir 为新 instance 的图片目录,在 --instance_prompt 参数里设置对应的 identifier,在 --class_prompt 设置相关 class 描述。训练代码

图片

训练集图片示例:

图片

训练完毕后,输入“a photo of az baby”,可以看到生成的图片具备训练集人物特征。

图片

训练好的模型,如果需要在 Stable Diffusion Web UI 上使用,先通过脚本进行转换,输出ckpt或者safetensors格式,再放入 $HOME/stable-diffusion-webui/models/Stable-diffusion 目录。脚本链接

python  ../scripts/convert_diffusers_to_original_stable_diffusion.py --model_path ./dreambooth_baby --checkpoint_path dreambooth_baby.safetensors --use_safetensors

LoRA(w Dreambooth)

LoRA(Low-Rank Adaptation of Large Language Models ) 是一种轻量级的微调方法,通过少量的图片训练出一个小模型,然后和基础模型结合使用,并通过插层的方式影响模型结果。

LoRA 的一个创新点,是通过“矩阵分解”的方式,优化插入层的参数量。我们可以将一个权重矩阵分解为两个矩阵进行存储,如果W是d*d维矩阵,那么A和B矩阵的尺寸可以减小到d*n,这样n远小于d,大幅度减少存储空间。

图片

训练会冻结预训练模型的参数,通过 W’ = W +△W 的方式来调整模型参数,这里的△W= ABT,其中AB矩阵就是我们的训练目标。如下图所示:

图片

LoRA 的优势在于生成的模型较小,训练速度快,但推理需要同时使用 LoRA 模型和基础模型。LoRA 模型虽然会向原有模型中插入新的网络层,但最终效果还是依赖基础模型。

我们使用 Huggingface 提供的训练代码,准备好图片后,在A10上使用以下脚本启动训练:

accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of az baby" \
  --class_prompt="a photo of baby" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --validation_prompt="a photo of az baby" \
  --validation_epochs=50 \
  --seed="0"

因为我们采用 Dreambooth-LoRA 方式进行训练,所以超参数基本与前述的 Dreambooth 一致。训练代码

LoRA 输出默认为 Pytorch 文件格式,如果需要在 Stable Diffusion Web UI 里使用,先将模型转化为 safetensors 格式,然后放入 $HOME/stable-diffusion-webui/models/Lora 目录使用。脚本链接

python diffusers-lora-to-safetensors.py --file pytorch_lora_weights.bin

Stable Diffusion 性能优化

与训练阶段侧重于准确预测标签和提高模型精度不同,推理阶段更看重高效处理输入并生成预测结果,同时减少资源消耗,在一些应用场景里,还会采用量化技术,在精度和性能之间取得平衡。

Stable Diffusion 是一个多模型组成的扩散Pipeline,由三个部分组成:变分自编码器 VAE、UNet 和文本编码器 CLIP。模型的推理耗时主要集中在 UNet,我们选择对这部分进行优化,提高推理性能和效率。

图片

目前社区和硬件厂商提供了多种优化方案,但这些方案接口定义复杂,使用门槛高,使得难以被广泛采用。腾讯云 TACO 只需简单操作,即可实现 Stable Diffusion 推理优化,轻松应用只被少数专家掌握的技术。

腾讯云 TACO 使用自研的编译后端,对 UNet 模型以静态图方式进行编译优化,同时根据不同的底层硬件,动态选择 Codegen 优化策略,输出更高效的机器代码,提升推理速度,减少资源占用。

Dreambooth 优化

复用训练使用的 A10 GPU 服务器,参考TACO Infer 优化 Stable Diffusion 模型,安装 Docker runtime,并拉取预置优化环境的 sd_taco:v3 镜像。因涉及编译生成机器码,最终部署的目标 GPU 型号,需要和优化时的 GPU 型号保持一致。

使用-v命令挂载微调后的 Dreambooth diffusers 模型目录,交互式启动容器。

docker run -it --gpus=all --network=host -v /[diffusers_model_directory]:/[custom_container_directory] sd_taco:v3 bash

在镜像里执行 python export_model.py,采用 TorchScript tracing 生成序列化的 UNet 模型文件。

script_model = torch.jit.trace(model, test_data, strict=False)
script_model.save("trace_module.pt")

在镜像里执行 python demo.py,对导出的 UNet Model 进行性能优化。这一步 TACO sdk 会对导出的 IR 进行编译优化,包括计算图结构优化、算子优化、以及其他针对代码生成和执行的优化技术。

完成后,使用 jit 方式加载优化后的 UNet Model。对模型输入 a. 图像隐空间向量【batchsize,隐空间通道,图片高度/8,图片宽度/8】b. timesteps值 【batchsize】c. 【batchsize,文本最大编码长度,向量大小】,即可对优化结果进行测试。代码参考如下:

import torch
import taco
import os

taco_path = os.path.dirname(taco.__file__)
torch.ops.load_library(os.path.join(taco_path, "torch_tensorrt/lib/libtorchtrt.so"))
optimized_model = torch.jit.load("optimized_recursive_script_module.pt")

pic = torch.rand(1, 4, 64, 64).cuda() // picture
timesteps = torch.tensor([1]*1) // timesteps
context = torch.randn(1, 77, 768) // text embedding

with torch.no_grad():
    output = optimized_model(pic, timesteps, context)
    print(output)

对比社区方案,TACO 优化后模型出图速度提高50%,效果见下图:

图片

(20 steps,Euler a,512 * 512,torch 1.12,无xformers,1s出图)

LoRA 优化

使用 LoRA合并脚本,将训练得到的 LoRA 文件,和基础模型进行合并。命令参考:

python networks/merge_lora.py --sd_model ../v1-5-pruned-emaonly.safetensors --save_to ../lora-v1-5-pruned-emaonly.safetensors --models <LoRA文件目录> --ratios <LoRA权重>

参考上述 Dreambooth 的优化方法,对合并后的模型进行导出和优化。效果见下图:

图片

(20 steps,Euler a,512 * 512,anime-tarot-card,torch 1.12,无xformers,1s出图)

ControlNet 优化

Dreambooth 及 LoRA 优化模型,依然适用于 ControlNet 使用场景,对比社区方案,TACO 优化后 ControlNet 的出图速度可以提高30%以上,效果见下图:

图片

(20 steps,Euler a,512 * 512,ControlNet-canny,torch 1.12,无xformers,2s出图)

经过 TACO 优化后的 UNet 模型,测试表明前向推理速度提高至开源方案的4倍。在实际应用中,512*512,20 steps 的配置下,Stable Diffusion Web UI 端到端的推理时间缩短 1 秒。以上优化详细过程及环境获取,参考 TACO Infer 优化 Stable Diffusion 系列模型。

总结

本文介绍了 Dreambooth 和 LoRA 在腾讯云A10机型上的微调实践,以及针对这两种模型的 TACO 推理优化过程。感兴趣的同学可以在文章的基础上,尝试训练风格独特的模型,辅以 TACO 推理优化能力,创造符合自身业务的云上 Stable Diffusion。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316674.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

go学习redis的学习与使用

文章目录 一、redis的学习与使用1.Redis的基本介绍2.Redis的安装下载安装包即可3.Redis的基本使用1&#xff09;Redis的启动&#xff1a;2&#xff09;Redis的操作的三种方式3&#xff09;说明&#xff1a;Redis安装好后&#xff0c;默认有16个数据库&#xff0c;初始默认使用0…

深入解析 Spring 和 Spring Boot 的区别

目录 引言 1. 设计理念 1.1 Spring 框架的设计理念 1.2 Spring Boot 的设计理念 2. 项目配置 2.1 Spring 框架的项目配置 2.2 Spring Boot 的项目配置 3. 自动配置 3.1 Spring 框架的自动配置 3.2 Spring Boot 的自动配置 4. 微服务支持 4.1 Spring 框架的微服务支持…

C# WPF上位机开发(加密和解密)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 在报文传输的过程中&#xff0c;根据报文传输的形态&#xff0c;有两种形式&#xff0c;一种是明文传输&#xff0c;一种是加密传输。当然明文传输…

基于单片机的太阳能数据采集系统(论文+源码)

1. 系统设计 在本次太阳能数据采集系统的设计中&#xff0c;以AT89C52单片机为主要核心&#xff0c;主要是由LCD液晶显示模块、存储模块、温度检测模块、串口通信模块&#xff0c;光照检测模块等组成&#xff0c;其实现了对太阳能板的温度&#xff0c;光照强度的检测和记录&…

【Qt开发流程】之UDP

概述 UDP (User Datagram Protocol)是一种简单的传输层协议。与TCP不同&#xff0c;UDP不提供可靠的数据传输和错误检测机制。UDP主要用于那些对实时性要求较高、对数据传输可靠性要求较低的应用&#xff0c;如音频、视频、实时游戏等。 UDP使用无连接的数据报传输模式。在传…

2机5节点系统潮流MATLAB仿真

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 电力系统潮流计算是电力系统最基本的计算&#xff0c;也是最重要的计算。所谓潮流计算&#xff0c;就是已知电网的接线方式与参数及运行条件&#xff0c;计算电力系统稳态运行各母线电压、各支路电流、功率及…

AntDesignBlazor示例——分页查询

本示例是AntDesign Blazor的入门示例&#xff0c;在学习的同时分享出来&#xff0c;以供新手参考。 示例代码仓库&#xff1a;https://gitee.com/known/BlazorDemo 1. 学习目标 分页查询框架天气数据分页功能表格自定义分页 2. 创建分页查询框架 Table组件分页默认为前端分…

《科技风》期刊发表投稿方式、收稿方向

《科技风》杂志是经国家新闻出版总署批准&#xff0c;河北省科学技术协会主管&#xff0c;河北省科技咨询服务中心主办的国内公开发行的大型综合类科技期刊。 该刊集科技性、前瞻性、创新性和专业性于一体&#xff0c;始终以“把脉科技创新 引领发展风尚”为办刊宗旨&#xff…

回归预测 | MATLAB实现NGO-SCN北方苍鹰算法优化随机配置网络的数据回归预测 (多指标,多图)

回归预测 | MATLAB实现NGO-SCN北方苍鹰算法优化随机配置网络的数据回归预测 &#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现NGO-SCN北方苍鹰算法优化随机配置网络的数据回归预测 &#xff08;多指标&#xff0c;多图&#xff09;效果一览基本介绍…

crmeb v5新增一个功能的完整示例记录

首先&#xff0c;需求 工作中的二开需求是这样的&#xff0c;修改首页的装修&#xff0c;并新增回收报价的功能 开始动手 第一步&#xff0c;我们要到后面的管理界面&#xff0c;去装修中修改首面的展示 首页的页面配置好之后&#xff0c;就要在 前端的展示程序中 配置相…

【音视频 | H.264】H.264视频编码及NALU详解

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

ChatGPT在指尖跳舞: open-interpreter实现本地数据采集、处理一条龙

原文&#xff1a;ChatGPT在指尖跳舞: open-interpreter实现本地数据采集、处理一条龙 - 知乎 目录 收起 Part1 前言 Part2 Open - Interpreter 简介 Part3 安装与运行 Part4 工作场景 1获取网页内容 2 pdf 文件批量转换 3 excel 文件合并 Part5总结 参考资料 往期推…

nodejs+vue+微信小程序+python+PHP校园二手交易系统的设计与实现-计算机毕业设计推荐

(2)管理员 进行维护&#xff0c;以及平台的后台管理工作都依靠管理员&#xff0c;其可以对信息进行管理。需具备功能有&#xff1b;首页、个人中心、学生管理、物品分类管理、物品信息管理、心愿贴、系统管理、订单管理等功能。系统分成管理员控制模块和学生模块。 本系统有良好…

极坐标下的牛拉法潮流计算39节点MATLAB程序

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 潮流计算&#xff1a; 潮流计算是根据给定的电网结构、参数和发电机、负荷等元件的运行条件&#xff0c;确定电力系统各部分稳态运行状态参数的计算。通常给定的运行条件有系统中各电源和负荷点的功率、枢纽…

飞翔的鸟。

一.准备工作 首先创建一个新的Java项目命名为“飞翔的鸟”&#xff0c;并在src中创建一个包命名为“com.qiku.bird"&#xff0c;在这个包内分别创建4个类命名为“Bird”、“BirdGame”、“Column”、“Ground”&#xff0c;并向需要的图片素材导入到包内。 二.代码呈现 pa…

Flask学习三:模型操作

ORM flask 通过Model操作数据库&#xff0c;不管你的数据库是MySQL还是Sqlite&#xff0c;flask自动帮你生成相应数据库类型的sql语句&#xff0c;所以不需要关注sql语句和类型&#xff0c;对数据的操作flask帮我们自动完成&#xff0c;只需要会写Model就可以了 flask使用对象关…

105基于matlab的阶次分析算法

基于matlab的阶次分析算法&#xff0c;用于变转速机械故障特征提取&#xff0c;可运行&#xff0c;包含寻找脉冲时刻&#xff0c;等角度时刻。数据可更换自己的&#xff0c;程序已调通&#xff0c;可直接运行。 105阶次分析变转速信号处理 (xiaohongshu.com)

Java 基础学习(十一)File类与I/O操作

1 File类 1.1 File类概述 1.1.1 什么是File类 File是java.io包下作为文件和目录的类。File类定义了一些与平台无关的方法来操作文件&#xff0c;通过调用File类中的方法可以得到文件和目录的描述信息&#xff0c;包括名称、所在路径、读写性和长度等&#xff0c;还可以对文件…

『K8S 入门』二:深入 Pod

『K8S 入门』二&#xff1a;深入 Pod 一、基础命令 获取所有 Pod kubectl get pods2. 获取 deploy kubectl get deploy3. 删除 deploy&#xff0c;这时候相应的 pod 就没了 kubectl delete deploy nginx4. 虽然删掉了 Pod&#xff0c;但是这是时候还有 service&#xff0c…

时序预测 | Python实现CNN-LSTM电力需求预测

时序预测 | Python实现CNN-LSTM电力需求预测 目录 时序预测 | Python实现CNN-LSTM电力需求预测预测效果基本描述程序设计参考资料预测效果 基本描述 该数据集因其每小时的用电量数据以及 TSO 对消耗和定价的相应预测而值得注意,从而可以将预期预测与当前最先进的行业预测进行比…