[读论文]DiT Scalable Diffusion Models with Transformers

news2025/2/24 8:36:46

论文翻译
Scalable Diffusion Models with Transformers-CSDN博客

论文地址:
https://arxiv.org/pdf/2212.09748.pdf

项目地址:
GitHub - facebookresearch/DiT: Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

论文主页:
Scalable Diffusion Models with Transformers

实验指标

1 可视化展示(x increasing transformer size | y decreasing patch size)
2 Transformer Gflops
   Training Compute (Gflops)
4 FID IS Precision Recalls (256x256 512x512)
5 sampling-up compute (Gflops)

github使用说明

一、采样过程samlpe

Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with sample.py. Weights for our pre-trained DiT model will be automatically downloaded depending on the model you use. The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

重新训练DiT检查点。
您可以使用sample.py从我们预训练的DiT模型中进行抽样。
我们预训练的DiT模型的权重将根据您使用的模型自动下载。
该脚本具有各种参数,用于在256x256和512x512模型之间切换,调整采样步骤,更改无分类器的指导尺度等。例如,要从512x512 DiT-XL/2模型中取样,您可以使用:

python sample.py --image-size 512 --seed 1

训练好的模型下载

Custom DiT checkpoints. If you've trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

自定义DiT检查点。
如果您使用train.py(见下文)训练了一个新的DiT模型,
那么您可以添加——ckpt参数来使用您自己的检查点。
例如,要从自定义256x256 DiT-L/4型号的EMA权重中采样,请运行:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt

二、训练过程

We provide a training script for DiT in train.py.
This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with N GPUs on one node:

我们在train.py中为DiT提供了一个训练脚本。
该脚本可用于训练类条件DiT模型,但是可以很容易地修改它以支持其他类型的条件。在一个节点上使用N个gpu启动DiT-XL/2 (256x256)训练:

torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train

torchrun
--nnodes=1
--nproc_per_node=N
train.py
--model DiT-XL/2
--data-path /path/to/imagenet/train

三、训练结果

PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models up to reasonable random variation. Some data points:

我们已经用PyTorch训练脚本从头开始训练DiT-XL/2和DiT-B/4模型,以验证它能够再现多达数十万次训练迭代的原始JAX结果。
在我们的实验中,与jax训练的模型相比,pytorch训练的模型给出了类似(有时略好)的结果,但存在合理的随机变化。一些数据点:

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the mse VAE decoder and without guidance (cfg-scale=1).

这些模型以256x256分辨率进行训练;
我们用8倍a100来训练XL/2,
4倍a100来训练B/4。
请注意,这里的FID是用250 DDPM采样步骤计算的,使用mse VAE解码器,没有指导(cfg-scale=1)。

TF32 Note (important for A100 users). When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. We've enabled them at the top of train.py and sample.py because it makes training and sampling way way way faster on A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to the above results.

TF32说明(对于A100用户很重要)。
当我们运行上述测试时,根据PyTorch的默认设置禁用了TF32 matmul。
我们在train.py和sample.py的顶部启用了它们,因为它使a100上的训练和采样方式更快(对于其他ampere gpu也应该如此),但请注意,与上述结果相比,使用TF32可能会导致一些差异。

eature Update Check out this repository at https://github.com/chuanyangjin/fast-DiT to preview a selection of training speed acceleration and memory saving features including gradient checkpointing, mixed precision training and pre-extrated VAE features.  With these advancements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU.
查看此存储库https://github.com/chuanyangjin/fast-DiT
预览选择的训练速度加速和内存节省功能,包括梯度检查点,混合精度训练和预提取的VAE功能。凭借这些进步,我们仅使用单个A100 GPU就实现了DiT-XL/2的0.84步/秒的训练速度。

Evaluation (FID, Inception Score, etc.)

We include a sample_ddp.py script which samples a large number of images from a DiT model in parallel.
This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics.
For example, to sample 50K images from our pre-trained DiT-XL/2 model over N GPUs, run:

我们包含了一个sample_ddp.py脚本,该脚本从DiT模型中并行采样大量图像。
该脚本生成一个样本文件夹以及一个.npz文件,该文件可以直接与ADM的TensorFlow评估套件一起使用,以计算FID, Inception Score和其他指标
例如,要在N个gpu上从预训练的DiT-XL/2模型中采样50K图像,请运行:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000

torchrun --nnodes=1 --nproc_per_node=N
sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1215242.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

个推用户运营全新上线用户生命周期管理功能,助力APP快速实现用户精细化运营

近期,个推用户运营上线了APP用户生命周期管理功能。该功能可以帮助APP多维度洞察⽤户所处的⽣命周期分布,旨在帮助运营人员快速全面地了解用户,从而基于用户生命周期针对性地做出用户运营策略调整,提升用户价值和运营指标。 个推如…

【LeetCode:2760. 最长奇偶子数组 | 模拟 双指针】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

(C++)把字符串转换成整数

把字符串转换成整数_牛客题霸_牛客网 愿所有美好如期而遇 思路 看到这个题目我们首先应该想到的就是去处理第一个字符,但是第一个字符也可能是数字字符,所以我们需要对他单独处理,如果他不符合条件,直接return,符合条…

QGIS之二十三矢量线融合

效果 步骤 1、准备数据 现有线分段太多,需要将部分线按照某个字段融合起来 2、融合 运行 3、结果 线已经融合了 线相交处也添加了线的节点

【开源】基于Vue和SpringBoot的网上药店系统

项目编号: S 062 ,文末获取源码。 \color{red}{项目编号:S062,文末获取源码。} 项目编号:S062,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 药品类型模块2.3 药…

[C++]:8.C++ STL引入+string(介绍)

C STL引入string(介绍) 一.STL引入:1.什么是STL2.什么是STL的版本:2-1:原始版本:2-2:P. J 版本:2-3:RW 版本:2-4:SGL版本: 3.STL 的六大组件&…

JS-项目实战-删除库存记录

1、fruit.js function $(name) {if (name) {//假设name是 #fruit_tblif (name.startsWith("#")) {name name.substring(1); //fruit_tblreturn document.getElementById(name);}} }//当页面加载完成后执行后面的匿名函数 window.onload function () {//get:获取…

YOLOv8-Seg改进: 捕捉空间上的局部关系和全局关系的CoordAttention注意力 | 分割注意力系列篇

🚀🚀🚀本文改进:CoordAttention注意力,引入到YOLOv8-seg,CoordAttention在计算注意力时,不仅会考虑输入的特征信息,还会考虑每个像素点的位置信息,从而更好地捕捉空间上的局部关系和全局关系。 🚀🚀🚀Context Aggregation小目标分割&复杂场景首选,实现…

Python winreg将cmd/PowerShell(管理员)添加到右键菜单

效果 1. 脚本 用管理员权限运行,重复执行会起到覆盖效果(根据sub_key)。 icon自己设置。text可以自定义。sub_key可以改但不推荐(避免改成和系统已有项冲突的)。command不要改。 from winreg import *registry r&q…

前端转行可以做什么

前端开发者通常拥有很好的技术背景和解决问题的能力,所以有很多可能的职业选择。以下是一些可能的选择: 全栈开发:这是一个非常热门的职位,需要能够处理前端和后端工作。使用多种编程语言和技术来构建从数据库到用户界面的整个应…

ps5计时计费软件安装教程,佳易王电玩店计时收费系统

ps5计时计费软件安装教程,佳易王电玩店计时收费系统 一、佳易王电玩PS5游戏厅计时计费软件部分功能简介: 1、计时计费功能 :开台时间和所用的时长直观显示,每3秒即可刷新一次时间。 2、销售商品功能 :商品可以绑定桌…

【深度学习实验】网络优化与正则化(五):数据预处理详解——标准化、归一化、白化、去除异常值、处理缺失值

文章目录 一、实验介绍二、实验环境1. 配置虚拟环境2. 库版本介绍 三、优化算法0. 导入必要的库1. 随机梯度下降SGD算法a. PyTorch中的SGD优化器b. 使用SGD优化器的前馈神经网络 2.随机梯度下降的改进方法a. 学习率调整b. 梯度估计修正 3. 梯度估计修正:动量法Momen…

ubuntu云服务器配置SFTP服务

目录 一、安装并运行SSH服务 1,安装ssh服务 2,运行ssh 3,查看ssh运行状态 二、创建SFTP用户并进行用户相关的配置 1,创建SFTP用户 2,限制用户只能使用 SFTP,并禁止 SSH 登录。打开/ect/ssh/sshd_conf…

一文看懂Spark中Cache和CheckPoint的区别

目录 循循渐进理解使用Cache或者PersistCheckPoint缓存和CheckPoint的区别 循循渐进理解 wc.txt数据 hello java spark hadoop flume kafka hbase kafka flume hadoop看下面代码会打印多少条-------------------------(RDD2) import org.apache.spark.rdd.RDD import org.ap…

这就是思维导图!全面分析思维导图的实际用途

思维导图是一种以图形方式呈现的思维工具,它以中心主题为核心,通过分支展开相关的子主题和想法。它可以帮助我们更好地组织和理解信息,提高学习、工作和生活的效率。 在信息爆炸的时代,有效地管理和利用大量的信息成为一个亟待解决…

Linux - Namespace

一、namespace 是什么? Linux namespaces 是对全局系统资源的一种封装隔离,使得处于不同 namespace 的进程拥有独立的全局系统资源,改变一个 namespace 中的系统资源只会影响当前 namespace 里的进程,对其他 namespace 中的进程没…

12-2- DCGAN -简单网络-卷积网络

功能 随机噪声→生成器→MINIST图像。 训练方法 1 判别器的训练,首先固定生成器参数不变,其次判别器应当将真实图像判别为1,生成图像判别为0 loss=loss(real_out, 1)+loss(fake_out, 0) 2 生成器的训练,首先固定判别器参数不变,其次判别器应当将生成图像判别为1 loss =…

如何避免在Flask中使用Response对象

在Flask框架中,Response对象的__bool__和__nonzero__方法被重载,以便返回一个表示HTTP响应状态是否为’OK’的布尔值。然而,这可能会导致一些预期之外的行为。 解决方案 对于上述问题,可以通过直接检查Response对象的ok属性来避…

在哪里可以制作一本精美的翻页产品册呢?

你是否曾经为了一张可滑动的画册而翻看了整个产品册?翻页产品册是一种数字化的画册形式,它可以在电脑、手机、平板等设备上进行浏览和阅读。相比传统的纸质画册,翻页产品册有着更多的优势和用途。那么,在哪里可以制作一本这种精美…

解决requests库中session.verify参数失效的问题

在使用requests库进行HTTP请求时,如果在环境变量中设置了’REQUESTS_CA_BUNDLE’,并且在session对象中设置了verify参数为False,那么API请求会使用环境变量中的值而不是session对象中的值。这是因为在requests库中,当session对象中…