Hands-On Deep Learning (0): Series Article Navigation
Preface: DeepSpeed has become a standard tool for training models in the era of large models. Using a DDP-based Stable Diffusion training script as an example, this post explains how to convert DDP training code to DeepSpeed.
Contents
Building the optimizer
Building the scheduler
Argument
Init Process
Loading the training components
Modifying the training loop
Launch command
weight_dtype
DDP parallelism
Configuration file
Loss is NaN
Training precision
Tuning steps
Example 1
Example 2
Building the optimizer
The original DDP code looks like this:
optimizer = torch.optim.AdamW(
trainable_params,
lr=learning_rate,
betas=(adam_beta1, adam_beta2),
weight_decay=adam_weight_decay,
eps=adam_epsilon,
)
With DeepSpeed, you only need to put these parameters in the config file:
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5,
"betas": [0.9, 0.999],
"eps": 1e-08,
"weight_decay": 1e-2
}
}
If you need a different optimizer, see the official documentation: DeepSpeed Configuration JSON - DeepSpeed
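As an alternative to a JSON file, the same optimizer section can also be passed as a Python dict. A minimal sketch, my addition and not from the original post; it assumes a recent DeepSpeed version that accepts a dict through the config argument of deepspeed.initialize, and that unet and trainable_params exist in the surrounding training script:

import deepspeed

ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 1e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 1e-2},
    },
}
# unet and trainable_params come from the surrounding training script.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=unet, model_parameters=trainable_params, config=ds_config)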
Building the scheduler
In the original DDP code the scheduler was built like this:
lr_scheduler = get_scheduler(
lr_scheduler,
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
num_training_steps=max_train_steps * gradient_accumulation_steps,
)
In DeepSpeed, it is likewise defined in the config file:
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
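If you would rather keep the diffusers scheduler than use DeepSpeed's WarmupLR, both the optimizer and the scheduler can stay in Python and be handed to deepspeed.initialize; in that case leave the optimizer/scheduler blocks out of the JSON. A minimal sketch, my addition; optimizer settings are illustrative, and trainable_params, unet, lr_warmup_steps, max_train_steps, cmd_args are assumed to exist in the training script:

import torch
import deepspeed
from diffusers.optimization import get_scheduler

optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)
lr_scheduler = get_scheduler(
    "constant_with_warmup",
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps,
    num_training_steps=max_train_steps,
)
model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=unet, optimizer=optimizer, lr_scheduler=lr_scheduler, args=cmd_args)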
Argument
You must add the following code to the train.py file:
parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')
# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()
Otherwise you will get errors, and the DeepSpeed config file will not be recognized!
Init Process
The original DDP code:
import torch.distributed as dist
dist.init_process_group(backend=backend, **kwargs)
Now replace it with:
deepspeed.init_distributed()
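After deepspeed.init_distributed(), the usual torch.distributed queries still work, which is convenient when porting the rest of a DDP script. A short sketch of my own; cmd_args comes from the Argument section above:

import torch
import torch.distributed as dist
import deepspeed

deepspeed.init_distributed()            # replaces dist.init_process_group
global_rank = dist.get_rank()           # torch.distributed still works afterwards
world_size = dist.get_world_size()
is_main_process = global_rank == 0
torch.cuda.set_device(cmd_args.local_rank)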
Loading the training components
Once the optimizer and scheduler are defined in the config file, you can load everything directly like this; no extra construction code is needed:
unet, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=unet,
    # optimizer=optimizer,
    model_parameters=trainable_params,
    args=deep_args
)
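For reference, deepspeed.initialize returns a 4-tuple: the engine (use it in place of the model), the optimizer and lr_scheduler built from the config file, and a training dataloader that is only populated when training_data is passed, hence the underscore above. A sketch with the names spelled out; the names are illustrative and not from the original post:

engine, ds_optimizer, ds_dataloader, ds_lr_scheduler = deepspeed.initialize(
    model=unet,
    model_parameters=trainable_params,
    args=deep_args,
)
# engine wraps unet; ds_dataloader is None here because no training_data was given.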
Modifying the training loop
The standard DeepSpeed training loop looks like this. It differs only slightly from plain PyTorch training, so the change is straightforward:
for step, batch in enumerate(data_loader):
    # forward() method
    loss = model_engine(batch)
    # runs backpropagation
    model_engine.backward(loss)
    # weight update
    model_engine.step()
If the scheduler is built outside DeepSpeed (neither in the config file nor passed to deepspeed.initialize), step it yourself. Note that model_engine.step() already applies the optimizer update, so a separate optimizer.step() is normally redundant:
lr_scheduler.step()
# optimizer.step()  # not needed when the optimizer is managed by the engine
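A small sketch of my own, assuming current DeepSpeed engine behaviour: with gradient_accumulation_steps > 1 in the config, every micro-batch still calls backward() and step(), but the engine only applies the optimizer at accumulation boundaries, and is_gradient_accumulation_boundary() lets you count real update steps.

# Sketch: gradient accumulation is handled inside the engine.
global_step = 0
for step, batch in enumerate(data_loader):
    loss = model_engine(batch)
    model_engine.backward(loss)
    if model_engine.is_gradient_accumulation_boundary():
        global_step += 1  # this micro-step triggers the real optimizer update
    model_engine.step()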
Launch command
Here is the command I use for reference:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 deepspeed --num_gpus=8 train_deepspeed.py \
--deepspeed --deepspeed_config 'configs/deepspeed/stage_2.json'
The --deepspeed_config argument points to the configuration file described below.
weight_dtype
Even though the training precision is specified in the config file, you still need to convert dtypes manually in the code!
if mixed_precision in ("fp16", "bf16"):
weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
else:
weight_dtype = torch.float32
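The chosen weight_dtype is then applied by hand to the frozen components and to the input batch, exactly as in Example 2 below. A short sketch; vae, image_encoder, device and train_dataloader are assumed from the surrounding script:

vae.to(device, dtype=weight_dtype)
image_encoder.to(device, dtype=weight_dtype)
for step, batch in enumerate(train_dataloader):
    image = batch["image"].to(device, dtype=weight_dtype)
    pose = batch["pose"].to(device, dtype=weight_dtype)
    # ... cast the remaining input tensors the same way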
DDP parallelism
Original code:
unet = DDP(unet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Delete this line entirely; DeepSpeed handles the parallelism for us automatically.
Configuration file
Below is a ZeRO stage 2 config file I use as a reference:
{
"bfloat16": {
"enabled": false
},
"fp16": {
"enabled": true,
"auto_cast": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1e-4
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 1e-5,
"betas": [0.9, 0.999],
"eps": 1e-08,
"weight_decay": 1e-2
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": 1,
"gradient_clipping": 1.0,
"train_batch_size": 8,
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1e5
}
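One relationship worth double-checking (my own sanity-check sketch, not from the original post): DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu x gradient_accumulation_steps x number of GPUs. With the file above and the 8-GPU launch command, that is 1 x 1 x 8 = 8.

# Sanity check of the batch-size relation that DeepSpeed enforces at startup.
micro_batch_per_gpu = 1      # train_micro_batch_size_per_gpu
grad_accum_steps = 1         # gradient_accumulation_steps
num_gpus = 8                 # from the launch command above
assert micro_batch_per_gpu * grad_accum_steps * num_gpus == 8  # train_batch_size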
Loss is NaN
This typically happens when a model was trained in bf16 but is then run in fp16. It is common with models Google trained on TPUs, such as T5. In that case, use fp32 or bf16 instead.
Training precision
- Since fp16 mixed precision greatly reduces memory requirements and is faster, only consider dropping mixed-precision training when the model performs poorly under it. This usually happens when the model was not pretrained with fp16 mixed precision (for example, a model pretrained in bf16). Such a model may overflow and produce NaN losses; in that case, use full fp32.
- On Ampere-architecture GPUs, PyTorch 1.7 and later automatically switches to the more efficient tf32 format for some operations, but the results are still returned in fp32.
- With the HF Trainer, enable it with --tf32, or disable it with --tf32 0 or --no_tf32; otherwise the PyTorch default applies. A plain-PyTorch toggle is sketched right after this list.
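Outside the HF Trainer, TF32 can be toggled directly in PyTorch. A minimal sketch; note the defaults differ between PyTorch versions, so check your own setup:

import torch
# Allow TF32 for matmuls and cuDNN convolutions on Ampere or newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True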
Tuning steps
- Set batch_size to 1 and reach any effective batch size through gradient accumulation.
- If you hit OOM, enable gradient checkpointing: --gradient_checkpointing 1 (HF Trainer), or model.gradient_checkpointing_enable().
- If still OOM, try ZeRO stage 2.
- If still OOM, try ZeRO stage 2 + offload_optimizer.
- If still OOM, try ZeRO stage 3 (see the config sketch right after this list).
- If still OOM, try offloading parameters (offload_param) to CPU.
- If still OOM, try offloading the optimizer (offload_optimizer) to CPU.
- If still OOM, try lowering some defaults, e.g. shrink the beam-search width when calling generate.
- If still OOM, use mixed-precision training: bf16 on Ampere GPUs, fp16 on older GPUs.
- If it still OOMs, use ZeRO-Infinity, with offload_param and offload_optimizer going to NVMe.
- Once batch_size=1 no longer OOMs, measure the effective throughput, then increase the batch size as much as possible.
- Finally, start tuning: turn off some offloading or lower the ZeRO stage, adjust the batch size, and keep measuring throughput until performance is satisfactory (tuning can buy roughly 66% more performance).
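For the later steps in this list, here is a sketch of what a stage-3 section with CPU offloading looks like. The key names follow the DeepSpeed config schema, but the values are only illustrative; it is written as a Python dict that could be merged into a config passed to deepspeed.initialize(config=...), replacing the stage-2 zero_optimization block above:

# Illustrative stage-3 zero_optimization section with CPU offloading (values not tuned).
zero3_section = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}
# For ZeRO-Infinity, set "device" to "nvme" and add an "nvme_path" entry.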
Example 1
The official repository provides complete examples, and many patterns can be copied directly: GitHub - microsoft/DeepSpeedExamples: Example models using DeepSpeed
Beginners can start with the ImageNet example: DeepSpeedExamples/training/imagenet/main.py at master · microsoft/DeepSpeedExamples · GitHub
import argparse
import os
import pdb
import random
import shutil
import time
import warnings
from enum import Enum
import torch
import deepspeed
from openpyxl import Workbook
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
help='path to dataset (default: imagenet)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N',
help='mini-batch size (default: 256), this is the total '
'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
help='number of nodes for distributed training')
parser.add_argument('--seed', default=None, type=int,
help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
help='GPU id to use.')
parser.add_argument('--multiprocessing_distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
parser.add_argument('--local_rank', type=int, default=-1, help="local rank for distributed training on gpus")
parser = deepspeed.add_config_arguments(parser)
best_acc1 = 0
def main():
args = parser.parse_args()
if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
cudnn.deterministic = True
cudnn.benchmark = False
warnings.warn('You have chosen to seed training. '
'This will turn on the CUDNN deterministic setting, '
'which can slow down your training considerably! '
'You may see unexpected behavior when restarting '
'from checkpoints.')
if args.gpu is not None:
warnings.warn('You have chosen a specific GPU. This will completely '
'disable data parallelism.')
if args.world_size == -1:
args.world_size = int(os.environ["WORLD_SIZE"])
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
if torch.cuda.is_available():
ngpus_per_node = torch.cuda.device_count()
else:
ngpus_per_node = 1
args.world_size = ngpus_per_node * args.world_size
t_losses, t_acc1s = main_worker(args.gpu, ngpus_per_node, args)
#dist.barrier()
# Write the losses to an excel file
if dist.get_rank() ==0:
all_losses = [torch.empty_like(t_losses) for _ in range(ngpus_per_node)]
dist.gather(tensor=t_losses, gather_list=all_losses,dst=0)
else:
dist.gather(tensor=t_losses, dst=0)
if dist.get_rank() ==0:
all_acc1s = [torch.empty_like(t_acc1s) for _ in range(ngpus_per_node)]
dist.gather(tensor=t_acc1s, gather_list=all_acc1s,dst=0)
else:
dist.gather(tensor=t_acc1s, dst=0)
if dist.get_rank() == 0:
outputfile = "Acc_loss_log.xlsx"
workbook = Workbook()
sheet1 = workbook.active
sheet1.cell(row= 1, column = 1, value = "Loss")
sheet1.cell(row= 1, column = ngpus_per_node + 4, value = "Acc")
for rank in range(ngpus_per_node):
for row_idx, (gpu_losses, gpu_acc1s) in enumerate(zip(all_losses[rank], all_acc1s[rank])):
sheet1.cell(row=row_idx + 2, column = rank+1, value = float(gpu_losses))
sheet1.cell(row=row_idx + 2, column = rank+1 + ngpus_per_node + 3, value = float(gpu_acc1s))
workbook.save(outputfile)
def main_worker(gpu, ngpus_per_node, args):
global best_acc1
args.gpu = gpu
if args.gpu is not None:
print("Use GPU: {} for training".format(args.gpu))
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()
# In case of distributed process, initializes the distributed backend
# which will take care of sychronizing nodes/GPUs
if args.local_rank == -1:
if args.gpu:
device = torch.device('cuda:{}'.format(args.gpu))
else:
device = torch.device("cuda")
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
deepspeed.init_distributed()
def print_rank_0(msg):
if args.local_rank <=0:
print(msg)
args.batch_size = int(args.batch_size / ngpus_per_node)
if not torch.cuda.is_available():# and not torch.backends.mps.is_available():
print('using CPU, this will be slow')
device = torch.device("cpu")
model = model.to(device)
# define loss function (criterion), optimizer, and learning rate scheduler
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
if args.gpu is None:
checkpoint = torch.load(args.resume)
elif torch.cuda.is_available():
# Map model to be loaded to specified single gpu.
loc = 'cuda:{}'.format(args.gpu)
checkpoint = torch.load(args.resume, map_location=loc)
args.start_epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']
if args.gpu is not None:
# best_acc1 may be from a checkpoint from a different GPU
best_acc1 = best_acc1.to(args.gpu)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
# Initialize DeepSpeed for the model
model, optimizer, _, _ = deepspeed.initialize(
model = model,
optimizer = optimizer,
args = args,
lr_scheduler = None,#scheduler,
dist_init_required=True
)
# Data loading code
if args.dummy:
print("=> Dummy data is used!")
train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
else:
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
val_dataset = datasets.ImageFolder(
valdir,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
if args.local_rank != -1:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
else:
train_sampler = None
val_sampler = None
print("Batch_size:",args.batch_size)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True, sampler=val_sampler)
if args.evaluate:
validate(val_loader, model, criterion, args)
return
losses = torch.empty(args.epochs).cuda()
acc1s = torch.empty(args.epochs).cuda()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
# train for one epoch
this_loss = train(train_loader, model, criterion, optimizer, epoch, device, args)
losses[epoch] = this_loss
# evaluate on validation set
acc1 = validate(val_loader, model, criterion, args)
acc1s[epoch] = acc1
scheduler.step()
# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
and args.gpu is None):
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_acc1': best_acc1,
'optimizer' : optimizer.state_dict(),
'scheduler' : scheduler.state_dict()
}, is_best)
return (losses, acc1s)
def train(train_loader, model, criterion, optimizer, epoch, device, args):
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
prefix="Epoch: [{}]".format(epoch))
# switch to train mode
model.train()
end = time.time()
for i, (images, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
# move data to the same device as model
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# compute output
output = model(images)
loss = criterion(output, target)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
model.backward(loss)
model.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
progress.display(i + 1)
return (float(losses.val))
def validate(val_loader, model, criterion, args):
def run_validate(loader, base_progress=0):
with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(loader):
i = base_progress + i
if torch.cuda.is_available():
target = target.cuda(args.gpu, non_blocking=True)
images = images.cuda(args.gpu, non_blocking=True)
# compute output
output = model(images)
loss = criterion(output, target)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
progress.display(i + 1)
batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
losses = AverageMeter('Loss', ':.4e', Summary.NONE)
top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
progress = ProgressMeter(
len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
[batch_time, losses, top1, top5],
prefix='Test: ')
# switch to evaluate mode
model.eval()
run_validate(val_loader)
if args.distributed:
top1.all_reduce()
top5.all_reduce()
if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
aux_val_dataset = Subset(val_loader.dataset,
range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
aux_val_loader = torch.utils.data.DataLoader(
aux_val_dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
run_validate(aux_val_loader, len(val_loader))
progress.display_summary()
return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
class Summary(Enum):
NONE = 0
AVERAGE = 1
SUM = 2
COUNT = 3
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
self.name = name
self.fmt = fmt
self.summary_type = summary_type
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def all_reduce(self):
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
self.sum, self.count = total.tolist()
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = ''
if self.summary_type is Summary.NONE:
fmtstr = ''
elif self.summary_type is Summary.AVERAGE:
fmtstr = '{name} {avg:.3f}'
elif self.summary_type is Summary.SUM:
fmtstr = '{name} {sum:.3f}'
elif self.summary_type is Summary.COUNT:
fmtstr = '{name} {count:.3f}'
else:
raise ValueError('invalid summary type %r' % self.summary_type)
return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
def __init__(self, num_batches, meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
def display(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
print('\t'.join(entries))
def display_summary(self):
entries = [" *"]
entries += [meter.summary() for meter in self.meters]
print(' '.join(entries))
def _get_batch_fmtstr(self, num_batches):
num_digits = len(str(num_batches // 1))
fmt = '{:' + str(num_digits) + 'd}'
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
if __name__ == '__main__':
main()
Example 2
The original DDP code:
import os
import math
import logging
import inspect
import argparse
import datetime
import random
import subprocess
from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from typing import Dict, Tuple, Optional
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModel
from diffusers import AutoencoderKL, DDPMScheduler, AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import compute_snr
from src.models.motion_module import VanillaTemporalModule, zero_module
from src.models.rd_unet import RealisDanceUnet
from src.utils.util import get_distributed_dataloader, sanity_check
def init_dist(launcher="slurm", backend="nccl", port=29500, **kwargs):
"""Initializes distributed environment."""
if launcher == "pytorch":
rank = int(os.environ["RANK"])
num_gpus = torch.cuda.device_count()
local_rank = rank % num_gpus
torch.cuda.set_device(local_rank)
dist.init_process_group(backend=backend, **kwargs)
elif launcher == "slurm":
proc_id = int(os.environ["SLURM_PROCID"])
ntasks = int(os.environ["SLURM_NTASKS"])
node_list = os.environ["SLURM_NODELIST"]
num_gpus = torch.cuda.device_count()
local_rank = proc_id % num_gpus
torch.cuda.set_device(local_rank)
addr = subprocess.getoutput(
f"scontrol show hostname {node_list} | head -n1")
os.environ["MASTER_ADDR"] = addr
os.environ["WORLD_SIZE"] = str(ntasks)
os.environ["RANK"] = str(proc_id)
port = os.environ.get("PORT", port)
os.environ["MASTER_PORT"] = str(port)
dist.init_process_group(backend=backend)
print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; "
f"node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")
else:
raise NotImplementedError(f"Not implemented launcher type: `{launcher}`!")
return local_rank
def main(
image_finetune: bool,
name: str,
launcher: str,
output_dir: str,
pretrained_model_path: str,
pretrained_clip_path: str,
train_data: Dict,
train_cfg: bool = True,
cfg_uncond_ratio: float = 0.05,
pose_shuffle_ratio: float = 0.0,
pretrained_vae_path: str = "",
pretrained_mm_path: str = "",
unet_checkpoint_path: str = "",
unet_additional_kwargs: Dict = None,
noise_scheduler_kwargs: Dict = None,
pose_guider_kwargs: Dict = None,
fusion_blocks: str = "full",
clip_projector_kwargs: Dict = None,
fix_ref_t: bool = False,
zero_snr: bool = False,
snr_gamma: Optional[float] = None,
v_pred: bool = False,
max_train_epoch: int = -1,
max_train_steps: int = 100,
learning_rate: float = 5e-5,
scale_lr: bool = False,
lr_warmup_steps: int = 0,
lr_scheduler: str = "constant",
trainable_modules: Tuple[str] = (),
num_workers: int = 4,
train_batch_size: int = 1,
adam_beta1: float = 0.9,
adam_beta2: float = 0.999,
adam_weight_decay: float = 1e-2,
adam_epsilon: float = 1e-08,
max_grad_norm: float = 1.0,
gradient_accumulation_steps: int = 1,
checkpointing_epochs: int = 5,
checkpointing_steps: int = -1,
checkpointing_steps_tuple: Tuple[int] = (),
mixed_precision: str = "fp16",
resume: bool = False,
global_seed: int or str = 42,
is_debug: bool = False,
*args,
**kwargs,
):
# check version
check_min_version("0.30.0.dev0")
# Initialize distributed training
local_rank = init_dist(launcher=launcher)
global_rank = dist.get_rank()
num_processes = dist.get_world_size()
is_main_process = global_rank == 0
if global_seed == "random":
global_seed = int(datetime.datetime.now().timestamp()) % 65535
seed = global_seed + global_rank
torch.manual_seed(seed)
# Logging folder
if resume:
# first split "a/b/c/checkpoints/xxx.ckpt" -> "a/b/c/checkpoints",
# the second split "a/b/c/checkpoints" -> "a/b/c
output_dir = os.path.split(os.path.split(unet_checkpoint_path)[0])[0]
else:
folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H")
output_dir = os.path.join(output_dir, folder_name)
if is_debug and os.path.exists(output_dir):
os.system(f"rm -rf {output_dir}")
*_, config = inspect.getargvalues(inspect.currentframe())
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Handle the output folder creation
if is_main_process:
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))
if train_cfg and is_main_process:
logging.info(f"Enable CFG training with drop rate {cfg_uncond_ratio}.")
# Load scheduler, tokenizer and models
if is_main_process:
logging.info("Load scheduler, tokenizer and models.")
if pretrained_vae_path != "":
if 'SVD' in pretrained_vae_path:
vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_vae_path, subfolder="vae")
else:
vae = AutoencoderKL.from_pretrained(pretrained_vae_path, subfolder="sd-vae-ft-mse")
else:
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
image_encoder = AutoModel.from_pretrained(pretrained_clip_path)
noise_scheduler_kwargs_dict = OmegaConf.to_container(
noise_scheduler_kwargs
) if noise_scheduler_kwargs is not None else {}
if zero_snr:
if is_main_process:
logging.info("Enable Zero-SNR")
noise_scheduler_kwargs_dict["rescale_betas_zero_snr"] = True
if v_pred:
noise_scheduler_kwargs_dict["prediction_type"] = "v_prediction"
noise_scheduler_kwargs_dict["timestep_spacing"] = "linspace"
noise_scheduler = DDPMScheduler.from_pretrained(
pretrained_model_path,
subfolder="scheduler",
**noise_scheduler_kwargs_dict,
)
unet = RealisDanceUnet(
pretrained_model_path=pretrained_model_path,
image_finetune=image_finetune,
unet_additional_kwargs=unet_additional_kwargs,
pose_guider_kwargs=pose_guider_kwargs,
clip_projector_kwargs=clip_projector_kwargs,
fix_ref_t=fix_ref_t,
fusion_blocks=fusion_blocks,
)
# Load pretrained unet weights
unet_state_dict = {}
if pretrained_mm_path != "" and not image_finetune:
if is_main_process:
logging.info(f"mm from checkpoint: {pretrained_mm_path}")
mm_ckpt = torch.load(pretrained_mm_path, map_location="cpu")
state_dict = mm_ckpt[
"state_dict"] if "state_dict" in mm_ckpt else mm_ckpt
unet_state_dict.update(
{name: param for name, param in state_dict.items() if "motion_modules." in name})
unet_state_dict.pop("animatediff_config", "")
m, u = unet.unet_main.load_state_dict(unet_state_dict, strict=False)
print(f"mm ckpt missing keys: {len(m)}, unexpected keys: {len(u)}")
assert len(u) == 0
for unet_main_module in unet.unet_main.children():
if isinstance(unet_main_module, VanillaTemporalModule) and unet_main_module.zero_initialize:
unet_main_module.temporal_transformer.proj_out = zero_module(
unet_main_module.temporal_transformer.proj_out
)
resume_step = 0
if unet_checkpoint_path != "":
if is_main_process:
logging.info(f"from checkpoint: {unet_checkpoint_path}")
unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
if "global_step" in unet_checkpoint_path:
if is_main_process:
logging.info(f"global_step: {unet_checkpoint_path['global_step']}")
if resume:
resume_step = unet_checkpoint_path['global_step']
state_dict = unet_checkpoint_path["state_dict"]
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith("module."):
new_k = k[7:]
else:
new_k = k
new_state_dict[new_k] = state_dict[k]
m, u = unet.load_state_dict(new_state_dict, strict=False)
if is_main_process:
logging.info(f"Load from checkpoint with missing keys:\n{m}")
logging.info(f"Load from checkpoint with unexpected keys:\n{u}")
assert len(u) == 0
# Set unet trainable parameters
unet.requires_grad_(False)
unet.set_trainable_parameters(trainable_modules)
# Set learning rate and optimizer
if scale_lr:
learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)
trainable_parameter_keys = []
trainable_params = []
for param_name, param in unet.named_parameters():
if param.requires_grad:
trainable_parameter_keys.append(param_name)
trainable_params.append(param)
if is_main_process:
logging.info(f"trainable params number: {trainable_parameter_keys}")
logging.info(f"trainable params number: {len(trainable_params)}")
logging.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
optimizer = torch.optim.AdamW(
trainable_params,
lr=learning_rate,
betas=(adam_beta1, adam_beta2),
weight_decay=adam_weight_decay,
eps=adam_epsilon,
)
# Set learning rate scheduler
lr_scheduler = get_scheduler(
lr_scheduler,
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
num_training_steps=max_train_steps * gradient_accumulation_steps,
)
# Freeze vae and image_encoder
vae.eval()
vae.requires_grad_(False)
image_encoder.eval()
image_encoder.requires_grad_(False)
# move to cuda
vae.to(local_rank)
image_encoder.to(local_rank)
unet.to(local_rank)
unet = DDP(unet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
# Get the training dataloader
train_dataloader = get_distributed_dataloader(
dataset_config=train_data,
batch_size=train_batch_size,
num_processes=num_processes,
num_workers=num_workers,
shuffle=True,
global_rank=global_rank,
seed=global_seed,)
# Get the training iteration
overrode_max_train_steps = False
if max_train_steps == -1:
assert max_train_epoch != -1
max_train_steps = max_train_epoch * len(train_dataloader)
overrode_max_train_steps = True
if checkpointing_steps == -1:
assert checkpointing_epochs != -1
checkpointing_steps = checkpointing_epochs * len(train_dataloader)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if overrode_max_train_steps:
max_train_steps = max_train_epoch * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
if is_main_process:
logging.info("***** Running training *****")
logging.info(f" Num examples = {len(train_dataloader)}")
logging.info(f" Num Epochs = {num_train_epochs}")
logging.info(f" Instantaneous batch size per device = {train_batch_size}")
logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logging.info(f" Total optimization steps = {max_train_steps}")
global_step = resume_step
first_epoch = int(resume_step / num_update_steps_per_epoch)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
progress_bar.set_description("Steps")
# Support mixed-precision training
scaler = torch.cuda.amp.GradScaler() if mixed_precision in ("fp16", "bf16") else None
for epoch in range(first_epoch, num_train_epochs):
train_dataloader.sampler.set_epoch(epoch)
unet.train()
for step, batch in enumerate(train_dataloader):
# Data batch sanity check
# if epoch == first_epoch and step == 0:
# sanity_check(batch, f"{output_dir}/sanity_check", image_finetune, global_rank)
""" >>>> Training >>>> """
# Get images
image = batch["image"].to(local_rank)
pose = batch["pose"].to(local_rank)
hamer = batch["hamer"].to(local_rank)
smpl = batch["smpl"].to(local_rank)
ref_image = batch["ref_image"].to(local_rank)
ref_image_clip = batch["ref_image_clip"].to(local_rank)
if train_cfg and random.random() < cfg_uncond_ratio:
ref_image_clip = torch.zeros_like(ref_image_clip)
drop_reference = True
else:
drop_reference = False
if not image_finetune and pose_shuffle_ratio > 0:
B, C, L, H, W = pose.shape
if random.random() < pose_shuffle_ratio:
rand_idx = torch.randperm(L).long()
pose[:, :, rand_idx[0], :, :] = pose[:, :, rand_idx[rand_idx[0]], :, :]
if random.random() < pose_shuffle_ratio:
rand_idx = torch.randperm(L).long()
hamer[:, :, rand_idx[0], :, :] = hamer[:, :, rand_idx[rand_idx[0]], :, :]
if random.random() < pose_shuffle_ratio:
rand_idx = torch.randperm(L).long()
smpl[:, :, rand_idx[0], :, :] = smpl[:, :, rand_idx[rand_idx[0]], :, :]
# Convert images to latent space
with torch.no_grad():
if not image_finetune:
video_length = image.shape[2]
image = rearrange(image, "b c f h w -> (b f) c h w")
latents = vae.encode(image).latent_dist
latents = latents.sample()
latents = latents * vae.config.scaling_factor
ref_latents = vae.encode(ref_image).latent_dist
ref_latents = ref_latents.sample()
ref_latents = ref_latents * vae.config.scaling_factor
clip_latents = image_encoder(ref_image_clip).last_hidden_state
# clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)
if not image_finetune:
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
# Sample noise that we"ll add to the latents
bsz = latents.shape[0]
noise = torch.randn_like(latents)
del image, ref_image, ref_image_clip
torch.cuda.empty_cache()
# Sample a random timestep for each video
train_timesteps = noise_scheduler.config.num_train_timesteps
timesteps = torch.randint(0, train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
# Mixed-precision training
if mixed_precision in ("fp16", "bf16"):
weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
else:
weight_dtype = torch.float32
with torch.cuda.amp.autocast(
enabled=mixed_precision in ("fp16", "bf16"),
dtype=weight_dtype
):
model_pred = unet(
sample=noisy_latents,
ref_sample=ref_latents,
pose=pose,
hamer=hamer,
smpl=smpl,
timestep=timesteps,
encoder_hidden_states=clip_latents,
drop_reference=drop_reference,
).sample
if snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
mse_loss_weights = mse_loss_weights / snr.clamp(min=1e-8) # incase zero-snr
elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
# Backpropagate
if mixed_precision in ("fp16", "bf16"):
scaler.scale(loss).backward()
""" >>> gradient clipping >>> """
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
""" <<< gradient clipping <<< """
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
""" >>> gradient clipping >>> """
torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
""" <<< gradient clipping <<< """
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
global_step += 1
""" <<<< Training <<<< """
del noisy_latents, ref_latents, pose, hamer, smpl
torch.cuda.empty_cache()
# Save checkpoint
if is_main_process and (global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple):
save_path = os.path.join(output_dir, f"checkpoints")
state_dict = {
"epoch": epoch,
"global_step": global_step,
"state_dict": unet.state_dict(),
}
torch.save(state_dict, os.path.join(save_path, f"checkpoint-iter-{global_step}.ckpt"))
logging.info(f"Saved state to {save_path} (global_step: {global_step})")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if is_main_process and global_step % 500 == 0:
logging.info(f"step: {global_step} / {max_train_steps}: {logs}")
if global_step >= max_train_steps:
break
# save the final checkpoint
if is_main_process:
save_path = os.path.join(output_dir, f"checkpoints")
state_dict = {
"epoch": num_train_epochs - 1,
"global_step": global_step,
"state_dict": unet.state_dict(),
}
torch.save(state_dict, os.path.join(save_path, f"checkpoint-final.ckpt"))
logging.info(f"Saved final state to {save_path}")
dist.destroy_process_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
parser.add_argument("--output", type=str, required=False, default="")
parser.add_argument("--resume", type=str, default="")
args = parser.parse_args()
exp_name = Path(args.config).stem
exp_config = OmegaConf.load(args.config)
if args.resume != "":
exp_config["unet_checkpoint_path"] = args.resume
exp_config["lr_warmup_steps"] = 0
exp_config["resume"] = True
if args.output != "":
exp_config["output_dir"] = args.output
main(name=exp_name, launcher=args.launcher, **exp_config)
After the modification:
import os
import math
import logging
import inspect
import argparse
import datetime
import random
import subprocess
from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from typing import Dict, Tuple, Optional
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModel
from diffusers import AutoencoderKL, DDPMScheduler, AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import compute_snr
from src.models.motion_module import VanillaTemporalModule, zero_module
from src.models.rd_unet import RealisDanceUnet
from src.utils.util import get_deepspeed_dataloader
import deepspeed
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--output", type=str, default="")
parser.add_argument("--resume", type=str, default="")
parser.add_argument('--local_rank', type=int, default=-1,
help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
deep_args = parser.parse_args()
exp_name = Path(deep_args.config).stem
exp_config = OmegaConf.load(deep_args.config)
if deep_args.resume != "":
exp_config["unet_checkpoint_path"] = deep_args.resume
exp_config["lr_warmup_steps"] = 0
exp_config["resume"] = True
if deep_args.output != "":
exp_config["output_dir"] = deep_args.output
def main(
image_finetune: bool,
name: str,
output_dir: str,
pretrained_model_path: str,
pretrained_clip_path: str,
train_data: Dict,
train_cfg: bool = True,
cfg_uncond_ratio: float = 0.05,
pose_shuffle_ratio: float = 0.0,
pretrained_vae_path: str = "",
pretrained_mm_path: str = "",
unet_checkpoint_path: str = "",
unet_additional_kwargs: Dict = None,
noise_scheduler_kwargs: Dict = None,
pose_guider_kwargs: Dict = None,
fusion_blocks: str = "full",
clip_projector_kwargs: Dict = None,
fix_ref_t: bool = False,
zero_snr: bool = False,
snr_gamma: Optional[float] = None,
v_pred: bool = False,
max_train_epoch: int = -1,
max_train_steps: int = 100,
learning_rate: float = 5e-5,
scale_lr: bool = False,
lr_warmup_steps: int = 0,
lr_scheduler: str = "constant",
trainable_modules: Tuple[str] = (),
num_workers: int = 4,
train_batch_size: int = 1,
adam_beta1: float = 0.9,
adam_beta2: float = 0.999,
adam_weight_decay: float = 1e-2,
adam_epsilon: float = 1e-08,
max_grad_norm: float = 1.0,
gradient_accumulation_steps: int = 1,
checkpointing_epochs: int = 5,
checkpointing_steps: int = -1,
checkpointing_steps_tuple: Tuple[int] = (),
mixed_precision: str = "fp16",
resume: bool = False,
global_seed: int or str = 42,
is_debug: bool = False,
*args,
**kwargs,
):
# check version
check_min_version("0.30.0.dev0")
# Initialize distributed training
deepspeed.init_distributed()
global_rank = dist.get_rank()
if global_seed == "random":
global_seed = int(datetime.datetime.now().timestamp()) % 65535
seed = global_seed
torch.manual_seed(seed)
# Logging folder
if resume:
# first split "a/b/c/checkpoints/xxx.ckpt" -> "a/b/c/checkpoints",
# the second split "a/b/c/checkpoints" -> "a/b/c
output_dir = os.path.split(os.path.split(unet_checkpoint_path)[0])[0]
else:
folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H")
output_dir = os.path.join(output_dir, folder_name)
if is_debug and os.path.exists(output_dir):
os.system(f"rm -rf {output_dir}")
*_, config = inspect.getargvalues(inspect.currentframe())
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
if mixed_precision in ("fp16", "bf16"):
weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
else:
weight_dtype = torch.float32
if train_cfg:
logging.info(f"Enable CFG training with drop rate {cfg_uncond_ratio}.")
# Load scheduler, tokenizer and models
if pretrained_vae_path != "":
if 'SVD' in pretrained_vae_path:
vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_vae_path, subfolder="vae")
else:
vae = AutoencoderKL.from_pretrained(pretrained_vae_path, subfolder="sd-vae-ft-mse")
else:
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
image_encoder = AutoModel.from_pretrained(pretrained_clip_path)
noise_scheduler_kwargs_dict = OmegaConf.to_container(
noise_scheduler_kwargs
) if noise_scheduler_kwargs is not None else {}
if zero_snr:
logging.info("Enable Zero-SNR")
noise_scheduler_kwargs_dict["rescale_betas_zero_snr"] = True
if v_pred:
noise_scheduler_kwargs_dict["prediction_type"] = "v_prediction"
noise_scheduler_kwargs_dict["timestep_spacing"] = "linspace"
noise_scheduler = DDPMScheduler.from_pretrained(
pretrained_model_path,
subfolder="scheduler",
**noise_scheduler_kwargs_dict,
)
unet = RealisDanceUnet(
pretrained_model_path=pretrained_model_path,
image_finetune=image_finetune,
unet_additional_kwargs=unet_additional_kwargs,
pose_guider_kwargs=pose_guider_kwargs,
clip_projector_kwargs=clip_projector_kwargs,
fix_ref_t=fix_ref_t,
fusion_blocks=fusion_blocks,
)
# Load pretrained unet weights
unet_state_dict = {}
if pretrained_mm_path != "" and not image_finetune:
logging.info(f"mm from checkpoint: {pretrained_mm_path}")
mm_ckpt = torch.load(pretrained_mm_path, map_location="cpu")
state_dict = mm_ckpt[
"state_dict"] if "state_dict" in mm_ckpt else mm_ckpt
unet_state_dict.update(
{name: param for name, param in state_dict.items() if "motion_modules." in name})
unet_state_dict.pop("animatediff_config", "")
m, u = unet.unet_main.load_state_dict(unet_state_dict, strict=False)
print(f"mm ckpt missing keys: {len(m)}, unexpected keys: {len(u)}")
assert len(u) == 0
for unet_main_module in unet.unet_main.children():
if isinstance(unet_main_module, VanillaTemporalModule) and unet_main_module.zero_initialize:
unet_main_module.temporal_transformer.proj_out = zero_module(
unet_main_module.temporal_transformer.proj_out
)
resume_step = 0
if unet_checkpoint_path != "":
logging.info(f"from checkpoint: {unet_checkpoint_path}")
unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
if "global_step" in unet_checkpoint_path:
logging.info(f"global_step: {unet_checkpoint_path['global_step']}")
state_dict = unet_checkpoint_path["state_dict"]
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith("module."):
new_k = k[7:]
else:
new_k = k
new_state_dict[new_k] = state_dict[k]
m, u = unet.load_state_dict(new_state_dict, strict=False)
logging.info(f"Load from checkpoint with missing keys:\n{m}")
logging.info(f"Load from checkpoint with unexpected keys:\n{u}")
# Set unet trainable parameters
unet.set_trainable_parameters(trainable_modules)
trainable_parameter_keys = []
trainable_params = []
for param_name, param in unet.named_parameters():
if param.requires_grad:
trainable_parameter_keys.append(param_name)
trainable_params.append(param)
# logging.info(f"trainable params number: {trainable_parameter_keys}")
logging.info(f"trainable params number: {len(trainable_params)}")
logging.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
torch.cuda.set_device(deep_args.local_rank)
device = torch.device("cuda", deep_args.local_rank)
unet, optimizer, _, lr_scheduler = deepspeed.initialize(
model=unet,
# optimizer=optimizer,
model_parameters=trainable_params,
args=deep_args
)
# Freeze vae and image_encoder
vae.eval()
vae.requires_grad_(False)
image_encoder.eval()
image_encoder.requires_grad_(False)
# move to cuda
vae.to(weight_dtype)
image_encoder.to(weight_dtype)
unet.to(weight_dtype)
vae.to(device)
image_encoder.to(device)
unet.to(device)
# Get the training dataloader
train_dataloader = get_deepspeed_dataloader(
dataset_config=train_data,
batch_size=train_batch_size,
num_workers=num_workers,
shuffle=True,
seed=global_seed,)
# Get the training iteration
overrode_max_train_steps = False
if max_train_steps == -1:
assert max_train_epoch != -1
max_train_steps = max_train_epoch * len(train_dataloader)
overrode_max_train_steps = True
if checkpointing_steps == -1:
assert checkpointing_epochs != -1
checkpointing_steps = checkpointing_epochs * len(train_dataloader)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if overrode_max_train_steps:
max_train_steps = max_train_epoch * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
# Handle the output folder creation
if unet.global_rank == 0:
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))
# Train!
if unet.global_rank == 0:
logging.info("***** Running training *****")
logging.info(f" Num examples = {len(train_dataloader)}")
logging.info(f" Num Epochs = {num_train_epochs}")
logging.info(f" Instantaneous batch size per device = {train_batch_size}")
logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logging.info(f" Total optimization steps = {max_train_steps}")
global_step = resume_step
first_epoch = int(resume_step / num_update_steps_per_epoch)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, max_train_steps), disable=not unet.global_rank == 0)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, num_train_epochs):
train_dataloader.sampler.set_epoch(epoch)
unet.train()
for step, batch in enumerate(train_dataloader):
# Data batch sanity check
""" >>>> Training >>>> """
# Get images
image = batch["image"].to(weight_dtype)
pose = batch["pose"].to(weight_dtype)
hamer = batch["hamer"].to(weight_dtype)
smpl = batch["smpl"].to(weight_dtype)
ref_image = batch["ref_image"].to(weight_dtype)
ref_image_clip = batch["ref_image_clip"].to(weight_dtype)
image = image.to(device)
pose = pose.to(device)
hamer = hamer.to(device)
smpl = smpl.to(device)
ref_image = ref_image.to(device)
ref_image_clip = ref_image_clip.to(device)
if train_cfg and random.random() < cfg_uncond_ratio:
ref_image_clip = torch.zeros_like(ref_image_clip)
drop_reference = True
else:
drop_reference = False
if not image_finetune and pose_shuffle_ratio > 0:
B, C, L, H, W = pose.shape
if random.random() < pose_shuffle_ratio:
rand_idx = torch.randperm(L).long()
pose[:, :, rand_idx[0], :, :] = pose[:, :, rand_idx[rand_idx[0]], :, :]
if random.random() < pose_shuffle_ratio:
rand_idx = torch.randperm(L).long()
hamer[:, :, rand_idx[0], :, :] = hamer[:, :, rand_idx[rand_idx[0]], :, :]
if random.random() < pose_shuffle_ratio:
rand_idx = torch.randperm(L).long()
smpl[:, :, rand_idx[0], :, :] = smpl[:, :, rand_idx[rand_idx[0]], :, :]
# Convert images to latent space
with torch.no_grad():
if not image_finetune:
video_length = image.shape[2]
image = rearrange(image, "b c f h w -> (b f) c h w")
latents = vae.encode(image).latent_dist
latents = latents.sample()
latents = latents * vae.config.scaling_factor
ref_latents = vae.encode(ref_image).latent_dist
ref_latents = ref_latents.sample()
ref_latents = ref_latents * vae.config.scaling_factor
clip_latents = image_encoder(ref_image_clip).last_hidden_state
# clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)
if not image_finetune:
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
# Sample noise that we"ll add to the latents
bsz = latents.shape[0]
noise = torch.randn_like(latents)
latents = latents.to(weight_dtype)
noise = noise.to(weight_dtype)
# Sample a random timestep for each video
train_timesteps = noise_scheduler.config.num_train_timesteps
timesteps = torch.randint(0, train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(
sample=noisy_latents,
ref_sample=ref_latents,
pose=pose,
hamer=hamer,
smpl=smpl,
timestep=timesteps,
encoder_hidden_states=clip_latents,
drop_reference=drop_reference,
).sample
if snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
mse_loss_weights = mse_loss_weights / snr.clamp(min=1e-8) # incase zero-snr
elif noise_scheduler.config.prediction_type == "v_prediction":
mse_loss_weights = mse_loss_weights / (snr + 1)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
unet.backward(loss)
# unet.clip_grad_norm_(max_grad_norm)
unet.step()
lr_scheduler.step()
optimizer.step()
# optimizer.zero_grad()
progress_bar.update(1)
global_step += 1
""" <<<< Training <<<< """
# Save checkpoint
if unet.global_rank == 0 and (global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple):
save_path = os.path.join(output_dir, f"checkpoints")
state_dict = {
"epoch": epoch,
"global_step": global_step,
"state_dict": unet.state_dict(),
}
torch.save(state_dict, os.path.join(save_path, f"checkpoint-iter-{global_step}.ckpt"))
logging.info(f"Saved state to {save_path} (global_step: {global_step})")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= max_train_steps:
break
# save the final checkpoint
if unet.global_rank == 0:
save_path = os.path.join(output_dir, f"checkpoints")
state_dict = {
"epoch": num_train_epochs - 1,
"global_step": global_step,
"state_dict": unet.state_dict(),
}
torch.save(state_dict, os.path.join(save_path, f"checkpoint-final.ckpt"))
logging.info(f"Saved final state to {save_path}")
dist.destroy_process_group()
main(name=exp_name, **exp_config)
One more note: for training diffusion models, prefer accelerate, which already has DeepSpeed integrated!
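For completeness, a minimal sketch of that accelerate route; this is my addition, the DeepSpeedPlugin arguments are illustrative, and compute_loss is a hypothetical stand-in for the diffusion loss computed in the loop above:

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=ds_plugin)
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler)
for batch in train_dataloader:
    loss = compute_loss(unet, batch)   # hypothetical helper wrapping the loss above
    accelerator.backward(loss)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()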