[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

news2024/9/22 3:42:40

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"
 

DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"

if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; then
    echo "-> preprocessed_11k or 212k"
    S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; then
    S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
else
    echo "Something wrong in data folder name"
    exit
fi

wget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"

UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if needed

BATCH_SIZE=2
GRAD_ACCUMULATION=4

StartTime=$(date +%s)

CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \
  --pretrained_model_name_or_path $MODEL_NAME \
  --train_data_dir $TRAIN_DATA_DIR\
  --use_ema \
  --resolution 512 --center_crop --random_flip \
  --train_batch_size $BATCH_SIZE \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --learning_rate 5e-05 \
  --max_grad_norm 1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --report_to="all" \
  --max_train_steps=20 \
  --seed 1234 \
  --gradient_accumulation_steps $GRAD_ACCUMULATION \
  --checkpointing_steps 5 \
  --valid_steps 5 \
  --lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \
  --use_copy_weight_from_teacher \
  --unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \
  --output_dir $OUTPUT_DIR


EndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1273184.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

了解ConcurrnetHashMap 吗?

程序员的公众号:源1024,获取更多资料,无加密无套路! 最近整理了一波电子书籍资料,包含《Effective Java中文版 第2版》《深入JAVA虚拟机》,《重构改善既有代码设计》,《MySQL高性能-第3版》&…

kafka中的常见问题处理

文章目录 1. 如何防⽌消息丢失2. 如何防⽌重复消费3. 如何做到消息的顺序消费4. 如何解决消息积压问题4.1 消息积压问题的出现4.2 消息积压的解决⽅案 5. 实现延时队列的效果5.1 应用场景5.2 具体方案 1. 如何防⽌消息丢失 ⽣产者:1)使⽤同步发送 2&…

决策树(Classification and Regression Tree)

学了数据结构的树后,一直没发现树有哪些应用。学而时习(实践)之,不亦说乎?故特地上网查了查树的应用,在下阐释: 1.文件系统:文件和目录的组织通常以树的形式表示,允许高效…

掌握Python BentoML:构建、部署和管理机器学习模型

更多资料获取 📚 个人网站:ipengtao.com BentoML是一个开源的Python框架,旨在简化机器学习模型的打包、部署和管理。本文将深入介绍BentoML的功能和用法,提供详细的示例代码和解释,帮助你更好地理解和应用这个强大的工…

【C++】异常处理 ③ ( 栈解旋 | 栈解旋概念 | 栈解旋作用 )

文章目录 一、栈解旋1、栈解旋引入2、栈解旋概念3、栈解旋作用 二、代码示例 - 栈解旋1、代码示例2、执行结果 一、栈解旋 1、栈解旋引入 C 程序 抛出异常后 对 局部变量的处理 : 当 C 应用程序 在 运行过程 中发生异常时 , 程序会跳转到异常处理程序 , 并执行一些操作以处理异…

10.30 作业 C++

设计一个Per类&#xff0c;类中包含私有成员:姓名、年龄、指针成员身高、体重&#xff0c;再设计一个Stu类&#xff0c;类中包含私有成员:成绩、Per类对象p1&#xff0c;设计这两个类的构造函数、析构函数和拷贝构造函数。 #include <iostream>using namespace std;clas…

【C语言学习疑难杂症】第6期:C语言中如何打印一些特殊字符,比如打印扩展ascii码字符

首先我们来看下ascii表和ascii拓展表: ascii表中的字符只有128个,是从0-127,而拓展ascii表的内容是128-255。拓展表中它们都是一些特殊的字符,如果我们想答应ascii拓展码中的一些字符应该要怎么操作呢? 比如下面的代码: unsigned char a = 176, b = 219;printf("%…

客餐书房一体布局,新中式风格禅意十足。福州中宅装饰,福州装修

你是否曾经遇到过这样的痛点&#xff1a;装修时不知道该选择什么样的风格&#xff0c;让家居空间显得既时尚又实用&#xff1f;如果你对此感到困惑&#xff0c;那么新中式风格可能正是你想要的选择&#xff01; 今天&#xff0c;我们将一起探讨一种别样的家居布局&#xff0c;它…

openGauss学习笔记-136 openGauss 数据库运维-例行维护-检查数据库性能

文章目录 openGauss学习笔记-136 openGauss 数据库运维-例行维护-检查数据库性能136.1 检查办法136.2 异常处理 openGauss学习笔记-136 openGauss 数据库运维-例行维护-检查数据库性能 136.1 检查办法 通过openGauss提供的性能统计工具gs_checkperf可以对硬件性能进行检查。 …

单词拆分 II

题目链接 单词拆分 II 题目描述 注意点 s 和 wordDict[i] 仅有小写英文字母组成wordDict 中所有字符串都 不同词典中的同一个单词可能在分段中被重复使用多次以任意顺序 返回所有这些可能的句子 解答思路 使用深度优先遍历回溯解决本题&#xff0c;每一层从idx开始遍历s&a…

OSG编程指南<十六>:OSG渲染到纹理RTT及三维纹理体渲染技术简介

1、渲染到纹理&#xff08;RTT&#xff09; 1.1 RTT介绍 RTT&#xff08;Render to Texture&#xff09;即渲染到纹理。在普通的图形渲染流程中&#xff0c;最终结果是渲染到帧缓存中&#xff0c;然后才会显示到屏幕上。而RTT则是将场景渲染到一张纹理上&#xff0c;并且在之后…

知识蒸馏代码实现(以MNIST手写数字体为例,自定义MLP网络做为教师和学生网络)

dataloader_tools.py import torchvision from torchvision import transforms from torch.utils.data import DataLoaderdef load_data():# 载入MNIST训练集train_dataset torchvision.datasets.MNIST(root "../datasets/",trainTrue,transformtransforms.ToTens…

Unity 注释的方法

1、单行注释&#xff1a;使用双斜线&#xff08;//&#xff09;开始注释&#xff0c;后面跟注释内容。通常注释一个属性或者方法&#xff0c;如&#xff1a; //速度 public float Speed;//打印输出 private void DoSomething() {Debug.Log("运行了我"); } …

老师旁听公开课到底听什么

经常参加公开课是老师提升自己教学水平的一种方式。那么&#xff0c;在旁听公开课时&#xff0c;老师应该听什么呢&#xff1f; 听课堂氛围 一堂好的公开课&#xff0c;应该能够让学生积极参与&#xff0c;课堂气氛活跃&#xff0c;而不是老师一个人唱独角戏。如果老师能够引导…

第16关 革新云计算:如何利用弹性容器与托管K8S实现极速服务POD扩缩容

------> 课程视频同步分享在今日头条和B站 天下武功&#xff0c;唯快不破&#xff01; 大家好&#xff0c;我是博哥爱运维。这节课给大家讲下云平台的弹性容器实例怎么结合其托管K8S&#xff0c;使用混合服务架构&#xff0c;带来极致扩缩容快感。 下面是全球主流云平台弹…

Windows系列:windows2003-建立域

windows2003-建立域 Active Directory建立DNS建立域查看日志xp 加入域 Active Directory 活动目录是一个包括文件、打印机、应用程序、服务器、域、用户账户等对象的数据库。 常见概念&#xff1a;对象、属性、容器 域组件&#xff08;Domain Component&#xff0c;DC&#x…

java操作windows系统功能案例(二)

1、打印指定文件 可以使用Java提供的Runtime类和Process类来打印指定文件。以下是一个示例代码&#xff1a; import java.io.File; import java.io.IOException;public class PrintFile {public static void main(String[] args) {if (args.length ! 1) {System.out.println(…

C# Onnx 百度飞桨开源PP-YOLOE-Plus目标检测

目录 效果 模型信息 项目 代码 下载 C# Onnx 百度飞桨开源PP-YOLOE-Plus目标检测 效果 模型信息 Inputs ------------------------- name&#xff1a;image tensor&#xff1a;Float[1, 3, 640, 640] name&#xff1a;scale_factor tensor&#xff1a;Float[1, 2] ----…

HuggingFace学习笔记--Model的使用

1--Model介绍 Transformer的 model 一般可以分为&#xff1a;编码器类型&#xff08;自编码&#xff09;、解码器类型&#xff08;自回归&#xff09;和编码器解码器类型&#xff08;序列到序列&#xff09;&#xff1b; Model Head&#xff08;任务头&#xff09;是在base模型…

Windows11如何让桌面图标的箭头消失(去掉快捷键箭头)

在Windows 11中&#xff0c;桌面图标的箭头是快捷方式图标的一个标志&#xff0c;用来表示该图标是一个指向文件、文件夹或程序的快捷方式。如果要隐藏这些箭头&#xff0c;你需要修改Windows注册表或使用第三方软件。 在此之前&#xff0c;我需要提醒你&#xff0c;修改注册表…