Hands-On Deep Learning (29): Converting DDP Training Code to DeepSpeed

2025/1/23 13:43:50

Hands-On Deep Learning (0): Series Navigation

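Preface: DeepSpeed has become a standard tool for training models in the era of large models. Using a DDP-based Stable Diffusion training script as an example, this post explains how to convert DDP training code to DeepSpeed.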

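Contents

Building the optimizer

Building the scheduler

Argument

Init Process

Loading the training components

Modifying the training code

Launch command

weight_dtype

DDP parallel strategy

Config file

Loss is NaN

Training precision

Tuning steps

Example 1

Example 2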


Building the optimizer

The original DDP code:

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

With DeepSpeed, you only need to put these parameters in the config file:

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-08,
            "weight_decay": 1e-2
        }
    }
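As a sketch, the DDP AdamW arguments can be carried over into the config programmatically instead of being retyped. The helper name adamw_config is hypothetical; the dict it returns has the same shape as the JSON above and can be dumped into the config file.

```python
import json


def adamw_config(learning_rate, adam_beta1, adam_beta2, adam_weight_decay, adam_epsilon):
    """Build the "optimizer" section of a DeepSpeed config from DDP-style AdamW arguments.

    Hypothetical helper: the argument names mirror the DDP code above.
    """
    return {
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": learning_rate,
                "betas": [adam_beta1, adam_beta2],
                "eps": adam_epsilon,
                "weight_decay": adam_weight_decay,
            },
        }
    }


if __name__ == "__main__":
    section = adamw_config(1e-5, 0.9, 0.999, 1e-2, 1e-8)
    # Serialize to JSON so it can be merged into the config file.
    print(json.dumps(section, indent=4))
```

deepspeed.initialize also accepts a config argument (a path or a dict), so such a dict could be passed directly, but not in addition to --deepspeed_config on the command line.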

If you need a different optimizer, see the official documentation: DeepSpeed Configuration JSON - DeepSpeed

Building the scheduler

The original DDP code builds the scheduler like this:

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

In DeepSpeed, the scheduler is likewise just defined in the config file:

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    }
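The two scheduler setups above do not line up one-to-one. WarmupLR ramps from warmup_min_lr to warmup_max_lr and then holds warmup_max_lr, so that value, not the optimizer's lr, becomes the effective learning rate (with the numbers above, 0.001 rather than 1e-5). For a schedule that also decays, DeepSpeed provides WarmupDecayLR, which additionally takes total_num_steps. The sketch below is a hypothetical helper that maps a diffusers-style scheduler name onto one of these. It assumes step counts are expressed in optimizer updates, which is why it does not multiply by gradient_accumulation_steps as the DDP code does: the DeepSpeed engine steps a scheduler it manages once per optimizer update.

```python
def scheduler_config(lr_scheduler, learning_rate, lr_warmup_steps, max_train_steps):
    """Map a diffusers-style scheduler name onto a DeepSpeed "scheduler" section.

    Hypothetical helper. Step counts are assumed to be optimizer updates,
    not micro-batches. Returns None for a plain constant learning rate,
    in which case the optimizer's own "lr" is used unchanged.
    """
    if lr_scheduler == "constant":
        return None
    params = {
        "warmup_min_lr": 0.0,
        "warmup_max_lr": learning_rate,  # peak LR; keep it equal to the optimizer's "lr"
        "warmup_num_steps": lr_warmup_steps,
        "warmup_type": "linear",  # DeepSpeed's default warmup shape is "log"
    }
    if lr_scheduler == "constant_with_warmup":
        # Linear warmup, then hold the peak LR.
        return {"type": "WarmupLR", "params": params}
    if lr_scheduler == "linear":
        # Linear warmup, then linear decay to 0 at total_num_steps.
        params["total_num_steps"] = max_train_steps
        return {"type": "WarmupDecayLR", "params": params}
    raise ValueError(f"no mapping for scheduler {lr_scheduler!r} in this sketch")
```

Only the two most common names are mapped here; anything else raises so that a silent mismatch is not possible.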

Argument

You must add the following code to train.py:

import argparse
import deepspeed

parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed from distributed launcher')
# Include DeepSpeed configuration arguments (--deepspeed, --deepspeed_config, ...)
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()

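Otherwise the script fails with an error, and the DeepSpeed config file is not picked up either. The reason is that the deepspeed launcher passes --local_rank to every process, and the --deepspeed and --deepspeed_config flags in the launch command are only registered by deepspeed.add_config_arguments.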
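The failure mode can be reproduced with plain argparse: a parser that does not register --local_rank rejects the argument the launcher appends and exits. This stdlib-only sketch leaves out deepspeed.add_config_arguments so it runs without DeepSpeed installed; build_parser and parse_like_launcher are hypothetical helper names.

```python
import argparse


def build_parser(register_local_rank):
    """Hypothetical helper: a training-script parser with or without --local_rank."""
    parser = argparse.ArgumentParser(description="My training script.")
    if register_local_rank:
        parser.add_argument("--local_rank", type=int, default=-1,
                            help="local rank passed from distributed launcher")
    return parser


def parse_like_launcher(register_local_rank, local_rank=0):
    """Parse the extra argument the launcher appends; return None if argparse rejects it."""
    argv = [f"--local_rank={local_rank}"]
    try:
        return build_parser(register_local_rank).parse_args(argv)
    except SystemExit:
        # argparse prints "unrecognized arguments: --local_rank=..." and exits.
        return None
```

parse_like_launcher(True, 3) returns a namespace with local_rank == 3, while parse_like_launcher(False) hits the same "unrecognized arguments" exit that a real run would.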

Init Process

The original DDP code:

import torch.distributed as dist
dist.init_process_group(backend=backend, **kwargs)

Change it to:

deepspeed.init_distributed()
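A minimal sketch of the replacement in context, assuming the deepspeed launcher has exported the usual LOCAL_RANK environment variable for each process (setup_distributed and get_local_rank are hypothetical helper names). It mirrors Example 1 below, which binds the process to its GPU before calling deepspeed.init_distributed(). The torch and deepspeed imports sit inside the function so the rank helper can be used on its own.

```python
import os


def get_local_rank(default=0):
    """Read the per-process local rank that the launcher exports as LOCAL_RANK."""
    return int(os.environ.get("LOCAL_RANK", default))


def setup_distributed(backend="nccl"):
    """Replace dist.init_process_group(...) with DeepSpeed's initializer."""
    import torch
    import deepspeed

    local_rank = get_local_rank()
    # Pin this process to its own GPU before creating the process group.
    torch.cuda.set_device(local_rank)
    # Relies on RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT exported by the launcher.
    deepspeed.init_distributed(dist_backend=backend)
    return local_rank
```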

Loading the training components

Once everything is written in the config file, just load it directly like this; no extra code is needed:

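    unet, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=unet,
        # optimizer=optimizer,  # not needed: built from the "optimizer" section of the config
        model_parameters=trainable_params,  # parameters the config-defined optimizer updates
        args=cmd_args,  # parsed args from the Argument step, carrying --deepspeed_config
    )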
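model_parameters is what DeepSpeed hands to the optimizer built from the config, so it should contain only the parameters that are actually trained. A sketch of how trainable_params could be collected, assuming (as in the DDP script) that frozen parameters have requires_grad=False; the toy model is only a stand-in for the UNet.

```python
import torch.nn as nn


def collect_trainable_params(model):
    """Return the parameters that should be optimized (requires_grad=True)."""
    return [p for p in model.parameters() if p.requires_grad]


if __name__ == "__main__":
    # Toy stand-in for the UNet: freeze everything, then unfreeze the last layer,
    # similar to unet.requires_grad_(False) followed by set_trainable_parameters(...).
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    model.requires_grad_(False)
    model[2].requires_grad_(True)
    trainable_params = collect_trainable_params(model)
    print(len(trainable_params))  # weight and bias of the last Linear
```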

Modifying the training code

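The standard DeepSpeed training loop is shown below. It differs only slightly from plain PyTorch training, so the changes are straightforward: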

for step, batch in enumerate(data_loader):
    #forward() method
    loss = model_engine(batch)

    #runs backpropagation
    model_engine.backward(loss)

    #weight update
    model_engine.step()
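For comparison, here is roughly what backward() and step() replace in a plain PyTorch loop with gradient accumulation. This is a simplified CPU-only sketch, not DeepSpeed's implementation: the engine additionally handles loss scaling and ZeRO partitioning. The toy model, data, and gradient_accumulation_steps value are hypothetical.

```python
import torch
import torch.nn as nn


def train_plain_pytorch(model, batches, optimizer, lr_scheduler, gradient_accumulation_steps):
    """Plain-PyTorch counterpart of model_engine.backward(loss) / model_engine.step().

    Returns the number of optimizer updates performed.
    """
    loss_fn = nn.MSELoss()
    updates = 0
    for step, (x, y) in enumerate(batches):
        loss = loss_fn(model(x), y)
        # ~ model_engine.backward(loss): the loss is scaled by 1/GA before backward.
        (loss / gradient_accumulation_steps).backward()
        # ~ model_engine.step(): only acts at a gradient-accumulation boundary.
        if (step + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # "gradient_clipping": 1.0
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            updates += 1
    return updates


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: 1.0)
    batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]
    print(train_plain_pytorch(model, batches, optimizer, lr_scheduler, 4))
```

With 8 micro-batches and gradient_accumulation_steps=4 this performs 2 optimizer and scheduler updates, which is why a DeepSpeed-managed scheduler counts optimizer updates rather than micro-batches.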

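Note that model_engine.step() already performs the optimizer update and takes care of zeroing the gradients, and it also steps the learning-rate scheduler if that scheduler was built from the config or passed to deepspeed.initialize. Calling optimizer.step() on top of that would apply a second update. Step the scheduler yourself only when it is kept outside DeepSpeed (for example, when lr_scheduler=None is passed to deepspeed.initialize, as in Example 1 below):

    lr_scheduler.step()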

Launch command

For reference, here is the command I use:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 deepspeed --num_gpus=8 train_deepspeed.py \
    --deepspeed --deepspeed_config 'configs/deepspeed/stage_2.json'

The --deepspeed_config argument points to the config file described below.

weight_dtype

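Although the config file sets the training precision, you still need to cast to it manually in code. DeepSpeed only converts the model it wraps (here the UNet); frozen components such as the VAE and image encoder, and the tensors fed into the model, still have to be cast to weight_dtype yourself: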

    if mixed_precision in ("fp16", "bf16"):
        weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
    else:
        weight_dtype = torch.float32
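To keep weight_dtype from drifting out of sync with the config file, it could be derived from the same JSON. A sketch, assuming precision is enabled through the fp16 or bf16 sections (the legacy bfloat16 key used in the config file below is checked too); weight_dtype_from_config is a hypothetical helper name.

```python
import json

import torch


def weight_dtype_from_config(ds_config):
    """Pick the dtype for frozen modules and inputs from a DeepSpeed config dict."""
    def enabled(key):
        return bool(ds_config.get(key, {}).get("enabled", False))

    if enabled("bf16") or enabled("bfloat16"):
        return torch.bfloat16
    if enabled("fp16"):
        return torch.float16
    return torch.float32


if __name__ == "__main__":
    cfg = json.loads('{"bfloat16": {"enabled": false}, "fp16": {"enabled": true}}')
    print(weight_dtype_from_config(cfg))
```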

DDP parallel strategy

原始代码:

    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

Just delete this line; DeepSpeed takes care of the data parallelism for us.

Config file

Below is a ZeRO stage 2 config file I use, for reference:

{
    "bfloat16": {
        "enabled": false
    },
    "fp16": {
        "enabled": true,
        "auto_cast": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1e-4
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-08,
            "weight_decay": 1e-2
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "train_batch_size": 8,
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1e5
}
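DeepSpeed requires the three batch-size fields to agree: train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. In the file above that is 1 × 1 × 8 = 8, matching the 8-GPU launch command. A small hypothetical check that could be run before launching:

```python
import json


def check_batch_config(ds_config, world_size):
    """Verify train_batch_size == micro_batch_per_gpu * grad_accum * world_size."""
    train_batch = ds_config["train_batch_size"]
    micro_batch = ds_config["train_micro_batch_size_per_gpu"]
    grad_accum = ds_config.get("gradient_accumulation_steps", 1)
    expected = micro_batch * grad_accum * world_size
    if train_batch != expected:
        raise ValueError(
            f"train_batch_size={train_batch}, but "
            f"{micro_batch} x {grad_accum} x {world_size} = {expected}"
        )
    return expected


if __name__ == "__main__":
    cfg = json.loads('{"train_batch_size": 8, "train_micro_batch_size_per_gpu": 1, '
                     '"gradient_accumulation_steps": 1}')
    print(check_batch_config(cfg, world_size=8))
```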

Loss is NaN

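A common cause is a model that was trained in bf16 but is being run in fp16. This often happens with models Google trained on TPUs, such as T5. Because fp16 has a much narrower numeric range than bf16, activations can overflow to inf and then turn into NaN. In this case, use fp32 or bf16 instead.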
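A quick check of the two ranges in PyTorch: fp16 tops out at 65504, while bf16 keeps fp32's 8-bit exponent at the cost of precision, so a value such as 70000 overflows in one format but not the other. The value 70000 is just an illustrative activation.

```python
import torch

# Largest finite values of each format.
fp16_max = torch.finfo(torch.float16).max
bf16_max = torch.finfo(torch.bfloat16).max

# An activation value that is fine in bf16 but too large for fp16.
x = torch.tensor([70000.0])
x_fp16 = x.to(torch.float16)   # overflows to inf
x_bf16 = x.to(torch.bfloat16)  # stays finite (rounded)

# inf - inf is one way an overflow turns into a NaN loss.
nan_from_overflow = x_fp16 - x_fp16

print(fp16_max, bf16_max, x_fp16.item(), x_bf16.item(), nan_from_overflow.item())
```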

Training precision

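  • Because fp16 mixed precision greatly reduces memory requirements and runs faster, consider training without mixed precision only if the model performs poorly in this mode. This usually happens when the model was not pretrained in fp16 mixed precision (for example, a model pretrained in bf16). Such models can overflow, leading to a NaN loss. If that is the case, use full fp32 mode.

  • On Ampere-architecture GPUs, PyTorch 1.7 and later can use the more efficient tf32 format for some operations, while results are still returned in fp32.

  • With the Hugging Face Trainer, enable it with --tf32, or disable it with --tf32 0 or --no_tf32. Note that the PyTorch default has changed over time: since PyTorch 1.12, tf32 is disabled for matrix multiplications by default, while it remains enabled for cuDNN convolutions.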
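Outside the Trainer, tf32 is controlled by two PyTorch flags. A minimal sketch of turning it on explicitly, which only has an effect on Ampere or newer GPUs:

```python
import torch

# Allow tf32 tensor cores for matmuls (off by default since PyTorch 1.12).
torch.backends.cuda.matmul.allow_tf32 = True
# Allow tf32 for cuDNN convolutions (on by default).
torch.backends.cudnn.allow_tf32 = True

print(torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32)
```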

Tuning steps

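  • Set batch_size to 1 and use gradient accumulation to reach whatever effective batch size you want.

  • If you hit OOM, enable --gradient_checkpointing 1 (HF Trainer), or call model.gradient_checkpointing_enable().

  • If still OOM, try ZeRO stage 2.

  • If still OOM, try ZeRO stage 2 + offload_optimizer.

  • If still OOM, switch to ZeRO stage 3.

  • If still OOM, offload_param to the CPU.

  • If still OOM, also offload_optimizer to the CPU (under stage 3).

  • If still OOM, try lowering some default settings; for example, when using generate, narrow the beam search.

  • If still OOM, use mixed precision training: bf16 on Ampere GPUs, fp16 on older GPUs.

  • If you still run out of memory, use ZeRO-Infinity and offload both parameters and optimizer states (offload_param and offload_optimizer) to NVMe.

  • Once batch_size=1 no longer causes OOM, measure the effective throughput, then increase batch_size as far as possible.

  • Then start optimizing: turn off offloading or lower the ZeRO stage, adjust batch_size, and keep measuring throughput until performance is satisfactory (tuning can improve performance by about 66%).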
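The ZeRO and offload escalation in the steps above maps onto the zero_optimization section of the config. A hypothetical helper that produces that section for each rung of the ladder; the key names follow the stage 2 config file above and the DeepSpeed config docs, and the NVMe path is a placeholder to replace.

```python
# Each rung: (ZeRO stage, offload_optimizer device, offload_param device); None = no offload.
LADDER = [
    (2, None, None),      # ZeRO stage 2
    (2, "cpu", None),     # stage 2 + offload_optimizer
    (3, None, None),      # ZeRO stage 3
    (3, None, "cpu"),     # stage 3 + offload_param to CPU
    (3, "cpu", "cpu"),    # stage 3 + offload_param and offload_optimizer to CPU
    (3, "nvme", "nvme"),  # ZeRO-Infinity: both offloads to NVMe
]

NVME_PATH = "/local_nvme"  # placeholder path, adjust to your machine


def zero_section(level):
    """Return a "zero_optimization" dict for rung `level` of LADDER."""
    stage, opt_dev, param_dev = LADDER[level]
    section = {"stage": stage}
    for key, device in (("offload_optimizer", opt_dev), ("offload_param", param_dev)):
        if device is None:
            continue
        cfg = {"device": device, "pin_memory": True}
        if device == "nvme":
            cfg["nvme_path"] = NVME_PATH
        section[key] = cfg
    return section
```

offload_param only appears on stage 3 rungs, since parameter offloading requires ZeRO stage 3.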

Example 1

The official repository provides complete examples, and many of the steps can be copied as-is: GitHub - microsoft/DeepSpeedExamples: Example models using DeepSpeed

Beginners can start with the ImageNet example: DeepSpeedExamples/training/imagenet/main.py at master · microsoft/DeepSpeedExamples · GitHub

import argparse
import os
import pdb
import random
import shutil
import time
import warnings
from enum import Enum

import torch
import deepspeed
from openpyxl import Workbook
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                    help='path to dataset (default: imagenet)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing_distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
parser.add_argument('--local_rank', type=int, default=-1, help="local rank for distributed training on gpus")

parser = deepspeed.add_config_arguments(parser)
best_acc1 = 0

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
        
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    if torch.cuda.is_available():
        ngpus_per_node = torch.cuda.device_count()
    else:
        ngpus_per_node = 1

    args.world_size = ngpus_per_node * args.world_size
    t_losses, t_acc1s = main_worker(args.gpu, ngpus_per_node, args)
    #dist.barrier()
    
    # Write the losses to an excel file
    if dist.get_rank() ==0:
        all_losses = [torch.empty_like(t_losses) for _ in range(ngpus_per_node)]
        dist.gather(tensor=t_losses, gather_list=all_losses,dst=0)
    else:
        dist.gather(tensor=t_losses, dst=0)

    if dist.get_rank() ==0:
        all_acc1s = [torch.empty_like(t_acc1s) for _ in range(ngpus_per_node)]
        dist.gather(tensor=t_acc1s, gather_list=all_acc1s,dst=0)
    else:
        dist.gather(tensor=t_acc1s, dst=0)

    if dist.get_rank() == 0:
        outputfile = "Acc_loss_log.xlsx"
        workbook = Workbook()
        sheet1 = workbook.active
        sheet1.cell(row= 1, column = 1, value = "Loss")
        sheet1.cell(row= 1, column = ngpus_per_node + 4, value = "Acc")
        for rank in range(ngpus_per_node):
            for row_idx, (gpu_losses, gpu_acc1s) in enumerate(zip(all_losses[rank], all_acc1s[rank])):
                sheet1.cell(row=row_idx + 2, column = rank+1, value = float(gpu_losses))
                sheet1.cell(row=row_idx + 2, column = rank+1 + ngpus_per_node + 3, value = float(gpu_acc1s))
        workbook.save(outputfile)

def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # In case of distributed process, initializes the distributed backend
    # which will take care of synchronizing nodes/GPUs
    if args.local_rank == -1:
        if args.gpu is not None:
            device = torch.device('cuda:{}'.format(args.gpu))
        else:
            device = torch.device("cuda")
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        deepspeed.init_distributed()
    def print_rank_0(msg):
        if args.local_rank <=0:
            print(msg)

    args.batch_size = int(args.batch_size / ngpus_per_node)
    if not torch.cuda.is_available():# and not torch.backends.mps.is_available():
        print('using CPU, this will be slow')
        device = torch.device("cpu")
        model = model.to(device)

    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            elif torch.cuda.is_available():
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Initialize DeepSpeed for the model
    model, optimizer, _, _ = deepspeed.initialize(
        model = model,
        optimizer = optimizer,
        args = args,
        lr_scheduler = None,#scheduler,
        dist_init_required=True
        )

    # Data loading code
    if args.dummy:
        print("=> Dummy data is used!")
        train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
        val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.local_rank != -1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        train_sampler = None
        val_sampler = None

    print("Batch_size:",args.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)


    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    losses = torch.empty(args.epochs).cuda()
    acc1s = torch.empty(args.epochs).cuda()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        this_loss = train(train_loader, model, criterion, optimizer, epoch, device, args)
        losses[epoch] = this_loss

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        acc1s[epoch] = acc1

        scheduler.step()
        
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.gpu is None):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, is_best)

    return (losses, acc1s)

def train(train_loader, model, criterion, optimizer, epoch, device, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        model.backward(loss)
        model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i + 1)

    return (float(losses.val))


def validate(val_loader, model, criterion, args):

    def run_validate(loader, base_progress=0):
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i

                if torch.cuda.is_available():
                    target = target.cuda(args.gpu, non_blocking=True)
                    images = images.cuda(args.gpu, non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    if args.distributed:
        top1.all_reduce()
        top5.all_reduce()

    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
        aux_val_dataset = Subset(val_loader.dataset,
                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
        aux_val_loader = torch.utils.data.DataLoader(
            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        run_validate(aux_val_loader, len(val_loader))

    progress.display_summary()

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))
        
    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()

Example 2

The original DDP code:

import os
import math
import logging
import inspect
import argparse
import datetime
import random
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from typing import Dict, Tuple, Optional

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoModel

from diffusers import AutoencoderKL, DDPMScheduler, AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import compute_snr

from src.models.motion_module import VanillaTemporalModule, zero_module
from src.models.rd_unet import RealisDanceUnet
from src.utils.util import get_distributed_dataloader, sanity_check


def init_dist(launcher="slurm", backend="nccl", port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == "pytorch":
        rank = int(os.environ["RANK"])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == "slurm":
        proc_id = int(os.environ["SLURM_PROCID"])
        ntasks = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
        addr = subprocess.getoutput(
            f"scontrol show hostname {node_list} | head -n1")
        os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(ntasks)
        os.environ["RANK"] = str(proc_id)
        port = os.environ.get("PORT", port)
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group(backend=backend)
        print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; "
              f"node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f"Not implemented launcher type: `{launcher}`!")

    return local_rank


def main(
    image_finetune: bool,

    name: str,
    launcher: str,

    output_dir: str,
    pretrained_model_path: str,
    pretrained_clip_path: str,

    train_data: Dict,
    train_cfg: bool = True,
    cfg_uncond_ratio: float = 0.05,
    pose_shuffle_ratio: float = 0.0,

    pretrained_vae_path: str = "",
    pretrained_mm_path: str = "",
    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = None,
    noise_scheduler_kwargs: Dict = None,
    pose_guider_kwargs: Dict = None,
    fusion_blocks: str = "full",
    clip_projector_kwargs: Dict = None,
    fix_ref_t: bool = False,
    zero_snr: bool = False,
    snr_gamma: Optional[float] = None,
    v_pred: bool = False,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,

    learning_rate: float = 5e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (),
    num_workers: int = 4,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    checkpointing_steps_tuple: Tuple[int] = (),

    mixed_precision: str = "fp16",
    resume: bool = False,

    global_seed: int or str = 42,
    is_debug: bool = False,

    *args,
    **kwargs,
):
    # check version
    check_min_version("0.30.0.dev0")

    # Initialize distributed training
    local_rank = init_dist(launcher=launcher)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    if global_seed == "random":
        global_seed = int(datetime.datetime.now().timestamp()) % 65535

    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    if resume:
        # first split "a/b/c/checkpoints/xxx.ckpt" -> "a/b/c/checkpoints",
        # the second split "a/b/c/checkpoints" -> "a/b/c"
        output_dir = os.path.split(os.path.split(unet_checkpoint_path)[0])[0]
    else:
        folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H")
        output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))

    if train_cfg and is_main_process:
        logging.info(f"Enable CFG training with drop rate {cfg_uncond_ratio}.")

    # Load scheduler, tokenizer and models
    if is_main_process:
        logging.info("Load scheduler, tokenizer and models.")
    if pretrained_vae_path != "":
        if 'SVD' in pretrained_vae_path:
            vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_vae_path, subfolder="vae")
        else:
            vae = AutoencoderKL.from_pretrained(pretrained_vae_path, subfolder="sd-vae-ft-mse")
    else:
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")

    image_encoder = AutoModel.from_pretrained(pretrained_clip_path)

    noise_scheduler_kwargs_dict = OmegaConf.to_container(
        noise_scheduler_kwargs
    ) if noise_scheduler_kwargs is not None else {}
    if zero_snr:
        if is_main_process:
            logging.info("Enable Zero-SNR")
        noise_scheduler_kwargs_dict["rescale_betas_zero_snr"] = True
        if v_pred:
            noise_scheduler_kwargs_dict["prediction_type"] = "v_prediction"
            noise_scheduler_kwargs_dict["timestep_spacing"] = "linspace"
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_path,
        subfolder="scheduler",
        **noise_scheduler_kwargs_dict,
    )

    unet = RealisDanceUnet(
        pretrained_model_path=pretrained_model_path,
        image_finetune=image_finetune,
        unet_additional_kwargs=unet_additional_kwargs,
        pose_guider_kwargs=pose_guider_kwargs,
        clip_projector_kwargs=clip_projector_kwargs,
        fix_ref_t=fix_ref_t,
        fusion_blocks=fusion_blocks,
    )

    # Load pretrained unet weights
    unet_state_dict = {}
    if pretrained_mm_path != "" and not image_finetune:
        if is_main_process:
            logging.info(f"mm from checkpoint: {pretrained_mm_path}")
        mm_ckpt = torch.load(pretrained_mm_path, map_location="cpu")
        state_dict = mm_ckpt[
            "state_dict"] if "state_dict" in mm_ckpt else mm_ckpt
        unet_state_dict.update(
            {name: param for name, param in state_dict.items() if "motion_modules." in name})
        unet_state_dict.pop("animatediff_config", "")
        m, u = unet.unet_main.load_state_dict(unet_state_dict, strict=False)
        print(f"mm ckpt missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0
        for unet_main_module in unet.unet_main.children():
            if isinstance(unet_main_module, VanillaTemporalModule) and unet_main_module.zero_initialize:
                unet_main_module.temporal_transformer.proj_out = zero_module(
                    unet_main_module.temporal_transformer.proj_out
                )

    resume_step = 0
    if unet_checkpoint_path != "":
        if is_main_process:
            logging.info(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path:
            if is_main_process:
                logging.info(f"global_step: {unet_checkpoint_path['global_step']}")
            if resume:
                resume_step = unet_checkpoint_path['global_step']
        state_dict = unet_checkpoint_path["state_dict"]
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_k = k[7:]
            else:
                new_k = k
            new_state_dict[new_k] = state_dict[k]
        m, u = unet.load_state_dict(new_state_dict, strict=False)
        if is_main_process:
            logging.info(f"Load from checkpoint with missing keys:\n{m}")
            logging.info(f"Load from checkpoint with unexpected keys:\n{u}")
        assert len(u) == 0

    # Set unet trainable parameters
    unet.requires_grad_(False)
    unet.set_trainable_parameters(trainable_modules)

    # Set learning rate and optimizer
    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    trainable_parameter_keys = []
    trainable_params = []
    for param_name, param in unet.named_parameters():
        if param.requires_grad:
            trainable_parameter_keys.append(param_name)
            trainable_params.append(param)
    if is_main_process:
        logging.info(f"trainable params keys: {trainable_parameter_keys}")
        logging.info(f"trainable params number: {len(trainable_params)}")
        logging.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # Set learning rate scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Freeze vae and image_encoder
    vae.eval()
    vae.requires_grad_(False)
    image_encoder.eval()
    image_encoder.requires_grad_(False)

    # move to cuda
    vae.to(local_rank)
    image_encoder.to(local_rank)
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    # Get the training dataloader
    train_dataloader = get_distributed_dataloader(
        dataset_config=train_data,
        batch_size=train_batch_size,
        num_processes=num_processes,
        num_workers=num_workers,
        shuffle=True,
        global_rank=global_rank,
        seed=global_seed,)

    # Get the training iteration
    overrode_max_train_steps = False
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)
        overrode_max_train_steps = True

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if overrode_max_train_steps:
        max_train_steps = max_train_epoch * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataloader)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = resume_step
    first_epoch = int(resume_step / num_update_steps_per_epoch)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision in ("fp16", "bf16") else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            # Data batch sanity check
            # if epoch == first_epoch and step == 0:
            #     sanity_check(batch, f"{output_dir}/sanity_check", image_finetune, global_rank)

            """ >>>> Training >>>> """
            # Get images
            image = batch["image"].to(local_rank)
            pose = batch["pose"].to(local_rank)
            hamer = batch["hamer"].to(local_rank)
            smpl = batch["smpl"].to(local_rank)
            ref_image = batch["ref_image"].to(local_rank)
            ref_image_clip = batch["ref_image_clip"].to(local_rank)

            if train_cfg and random.random() < cfg_uncond_ratio:
                ref_image_clip = torch.zeros_like(ref_image_clip)
                drop_reference = True
            else:
                drop_reference = False

            if not image_finetune and pose_shuffle_ratio > 0:
                B, C, L, H, W = pose.shape
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    pose[:, :, rand_idx[0], :, :] = pose[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    hamer[:, :, rand_idx[0], :, :] = hamer[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    smpl[:, :, rand_idx[0], :, :] = smpl[:, :, rand_idx[rand_idx[0]], :, :]

            # Convert images to latent space
            with torch.no_grad():
                if not image_finetune:
                    video_length = image.shape[2]
                    image = rearrange(image, "b c f h w -> (b f) c h w")

                latents = vae.encode(image).latent_dist
                latents = latents.sample()
                latents = latents * vae.config.scaling_factor

                ref_latents = vae.encode(ref_image).latent_dist
                ref_latents = ref_latents.sample()
                ref_latents = ref_latents * vae.config.scaling_factor

                clip_latents = image_encoder(ref_image_clip).last_hidden_state
                # clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)

                if not image_finetune:
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)

                # Sample noise that we'll add to the latents
                bsz = latents.shape[0]
                noise = torch.randn_like(latents)

                del image, ref_image, ref_image_clip
                torch.cuda.empty_cache()

            # Sample a random timestep for each video
            train_timesteps = noise_scheduler.config.num_train_timesteps
            timesteps = torch.randint(0, train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            if mixed_precision in ("fp16", "bf16"):
                weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
            else:
                weight_dtype = torch.float32
            with torch.cuda.amp.autocast(
                enabled=mixed_precision in ("fp16", "bf16"),
                dtype=weight_dtype
            ):
                model_pred = unet(
                    sample=noisy_latents,
                    ref_sample=ref_latents,
                    pose=pose,
                    hamer=hamer,
                    smpl=smpl,
                    timestep=timesteps,
                    encoder_hidden_states=clip_latents,
                    drop_reference=drop_reference,
                ).sample

                if snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    snr = compute_snr(noise_scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                        dim=1
                    )[0]
                    if noise_scheduler.config.prediction_type == "epsilon":
                        mse_loss_weights = mse_loss_weights / snr.clamp(min=1e-8)  # in case of zero-SNR
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        mse_loss_weights = mse_loss_weights / (snr + 1)
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

            # Backpropagate
            if mixed_precision in ("fp16", "bf16"):
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            global_step += 1
            """ <<<< Training <<<< """
            
            del noisy_latents, ref_latents, pose, hamer, smpl
            torch.cuda.empty_cache()

            # Save checkpoint
            if is_main_process and (global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple):
                save_path = os.path.join(output_dir, f"checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                torch.save(state_dict, os.path.join(save_path, f"checkpoint-iter-{global_step}.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            if is_main_process and global_step % 500 == 0:
                logging.info(f"step: {global_step} / {max_train_steps}:  {logs}")
            if global_step >= max_train_steps:
                break

    # save the final checkpoint
    if is_main_process:
        save_path = os.path.join(output_dir, f"checkpoints")
        state_dict = {
            "epoch": num_train_epochs - 1,
            "global_step": global_step,
            "state_dict": unet.state_dict(),
        }
        torch.save(state_dict, os.path.join(save_path, f"checkpoint-final.ckpt"))
        logging.info(f"Saved final state to {save_path}")

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
    parser.add_argument("--output", type=str, required=False, default="")
    parser.add_argument("--resume", type=str, default="")
    args = parser.parse_args()

    exp_name = Path(args.config).stem
    exp_config = OmegaConf.load(args.config)

    if args.resume != "":
        exp_config["unet_checkpoint_path"] = args.resume
        exp_config["lr_warmup_steps"] = 0
        exp_config["resume"] = True

    if args.output != "":
        exp_config["output_dir"] = args.output

    main(name=exp_name, launcher=args.launcher, **exp_config)
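Both the original loader above and the DeepSpeed version below strip a leading `module.` from checkpoint keys. The reason is that `DDP(unet)` keeps the wrapped model in its `.module` attribute, so `unet.state_dict()` saved from the DDP loop yields keys like `module.<param name>`. Below is a minimal standalone sketch of that key cleanup. The helper name `strip_module_prefix` is hypothetical and not part of the original script.

```python
from typing import Any, Dict


def strip_module_prefix(state_dict: Dict[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a new dict with one leading `prefix` removed from each key (hypothetical helper).

    Keys without the prefix are kept unchanged, mirroring the k[7:] loop in the script.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


if __name__ == "__main__":
    saved = {"module.unet_main.conv_in.weight": 1, "pose_guider.bias": 2}
    print(strip_module_prefix(saved))
```

On Python 3.9+, `key.removeprefix(prefix)` does the same single-prefix removal. The explicit `startswith` check keeps the sketch working on older interpreters.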

After the changes:

import os
import math
import logging
import inspect
import argparse
import datetime
import random
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from typing import Dict, Tuple, Optional

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoModel

from diffusers import AutoencoderKL, DDPMScheduler, AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import compute_snr

from src.models.motion_module import VanillaTemporalModule, zero_module
from src.models.rd_unet import RealisDanceUnet
from src.utils.util import get_deepspeed_dataloader

import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--output", type=str, default="")
parser.add_argument("--resume", type=str, default="")
parser.add_argument('--local_rank', type=int, default=-1,
                help='local rank passed from distributed launcher')
parser = deepspeed.add_config_arguments(parser)
deep_args = parser.parse_args()

exp_name = Path(deep_args.config).stem
exp_config = OmegaConf.load(deep_args.config)

if deep_args.resume != "":
    exp_config["unet_checkpoint_path"] = deep_args.resume
    exp_config["lr_warmup_steps"] = 0
    exp_config["resume"] = True

if deep_args.output != "":
    exp_config["output_dir"] = deep_args.output

def main(
    image_finetune: bool,

    name: str,

    output_dir: str,
    pretrained_model_path: str,
    pretrained_clip_path: str,

    train_data: Dict,
    train_cfg: bool = True,
    cfg_uncond_ratio: float = 0.05,
    pose_shuffle_ratio: float = 0.0,

    pretrained_vae_path: str = "",
    pretrained_mm_path: str = "",
    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = None,
    noise_scheduler_kwargs: Dict = None,
    pose_guider_kwargs: Dict = None,
    fusion_blocks: str = "full",
    clip_projector_kwargs: Dict = None,
    fix_ref_t: bool = False,
    zero_snr: bool = False,
    snr_gamma: Optional[float] = None,
    v_pred: bool = False,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,

    learning_rate: float = 5e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (),
    num_workers: int = 4,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    checkpointing_steps_tuple: Tuple[int] = (),

    mixed_precision: str = "fp16",
    resume: bool = False,

    global_seed: int or str = 42,
    is_debug: bool = False,

    *args,
    **kwargs,
):
    # check version
    check_min_version("0.30.0.dev0")

    # Initialize distributed training
    deepspeed.init_distributed()

    global_rank = dist.get_rank()

    if global_seed == "random":
        global_seed = int(datetime.datetime.now().timestamp()) % 65535

    seed = global_seed
    torch.manual_seed(seed)

    # Logging folder
    if resume:
        # first split "a/b/c/checkpoints/xxx.ckpt" -> "a/b/c/checkpoints",
        # the second split "a/b/c/checkpoints" -> "a/b/c"
        output_dir = os.path.split(os.path.split(unet_checkpoint_path)[0])[0]
    else:
        folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H")
        output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if mixed_precision in ("fp16", "bf16"):
        weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
    else:
        weight_dtype = torch.float32


    if train_cfg:
        logging.info(f"Enable CFG training with drop rate {cfg_uncond_ratio}.")

    # Load scheduler, tokenizer and models
    if pretrained_vae_path != "":
        if 'SVD' in pretrained_vae_path:
            vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_vae_path, subfolder="vae")
        else:
            vae = AutoencoderKL.from_pretrained(pretrained_vae_path, subfolder="sd-vae-ft-mse")
    else:
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")

    image_encoder = AutoModel.from_pretrained(pretrained_clip_path)

    noise_scheduler_kwargs_dict = OmegaConf.to_container(
        noise_scheduler_kwargs
    ) if noise_scheduler_kwargs is not None else {}
    if zero_snr:
        logging.info("Enable Zero-SNR")
        noise_scheduler_kwargs_dict["rescale_betas_zero_snr"] = True
        if v_pred:
            noise_scheduler_kwargs_dict["prediction_type"] = "v_prediction"
            noise_scheduler_kwargs_dict["timestep_spacing"] = "linspace"
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_path,
        subfolder="scheduler",
        **noise_scheduler_kwargs_dict,
    )

    unet = RealisDanceUnet(
        pretrained_model_path=pretrained_model_path,
        image_finetune=image_finetune,
        unet_additional_kwargs=unet_additional_kwargs,
        pose_guider_kwargs=pose_guider_kwargs,
        clip_projector_kwargs=clip_projector_kwargs,
        fix_ref_t=fix_ref_t,
        fusion_blocks=fusion_blocks,
    )

    # Load pretrained unet weights
    unet_state_dict = {}
    if pretrained_mm_path != "" and not image_finetune:
        logging.info(f"mm from checkpoint: {pretrained_mm_path}")
        mm_ckpt = torch.load(pretrained_mm_path, map_location="cpu")
        state_dict = mm_ckpt[
            "state_dict"] if "state_dict" in mm_ckpt else mm_ckpt
        unet_state_dict.update(
            {name: param for name, param in state_dict.items() if "motion_modules." in name})
        unet_state_dict.pop("animatediff_config", "")
        m, u = unet.unet_main.load_state_dict(unet_state_dict, strict=False)
        print(f"mm ckpt missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0
        for unet_main_module in unet.unet_main.children():
            if isinstance(unet_main_module, VanillaTemporalModule) and unet_main_module.zero_initialize:
                unet_main_module.temporal_transformer.proj_out = zero_module(
                    unet_main_module.temporal_transformer.proj_out
                )

    resume_step = 0
    if unet_checkpoint_path != "":
        logging.info(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path:
            logging.info(f"global_step: {unet_checkpoint_path['global_step']}")
            if resume:
                resume_step = unet_checkpoint_path['global_step']
        state_dict = unet_checkpoint_path["state_dict"]
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_k = k[7:]
            else:
                new_k = k
            new_state_dict[new_k] = state_dict[k]
        m, u = unet.load_state_dict(new_state_dict, strict=False)
        logging.info(f"Load from checkpoint with missing keys:\n{m}")
        logging.info(f"Load from checkpoint with unexpected keys:\n{u}")


    # Set unet trainable parameters (freeze everything first, as in the DDP version)
    unet.requires_grad_(False)
    unet.set_trainable_parameters(trainable_modules)

    trainable_parameter_keys = []
    trainable_params = []
    for param_name, param in unet.named_parameters():
        if param.requires_grad:
            trainable_parameter_keys.append(param_name)
            trainable_params.append(param)

    # logging.info(f"trainable params number: {trainable_parameter_keys}")
    logging.info(f"trainable params number: {len(trainable_params)}")
    logging.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    torch.cuda.set_device(deep_args.local_rank)
    device = torch.device("cuda", deep_args.local_rank)
    unet, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=unet,
        # optimizer=optimizer,
        model_parameters=trainable_params,
        args=deep_args
    )

    # Freeze vae and image_encoder
    vae.eval()
    vae.requires_grad_(False)
    image_encoder.eval()
    image_encoder.requires_grad_(False)

    # move to cuda
    vae.to(weight_dtype)
    image_encoder.to(weight_dtype)
    unet.to(weight_dtype)
    vae.to(device)
    image_encoder.to(device)
    unet.to(device)

    # Get the training dataloader
    train_dataloader = get_deepspeed_dataloader(
        dataset_config=train_data,
        batch_size=train_batch_size,
        num_workers=num_workers,
        shuffle=True,
        seed=global_seed,)

    # Get the training iteration
    overrode_max_train_steps = False
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)
        overrode_max_train_steps = True

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if overrode_max_train_steps:
        max_train_steps = max_train_epoch * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Handle the output folder creation
    if unet.global_rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))

    # Train!
    if unet.global_rank == 0:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataloader)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = resume_step
    first_epoch = int(resume_step / num_update_steps_per_epoch)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not unet.global_rank == 0)
    progress_bar.set_description("Steps")


    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            # Data batch sanity check
            """ >>>> Training >>>> """
            # Get images
            image = batch["image"].to(device, dtype=weight_dtype)
            pose = batch["pose"].to(device, dtype=weight_dtype)
            hamer = batch["hamer"].to(device, dtype=weight_dtype)
            smpl = batch["smpl"].to(device, dtype=weight_dtype)
            ref_image = batch["ref_image"].to(device, dtype=weight_dtype)
            ref_image_clip = batch["ref_image_clip"].to(device, dtype=weight_dtype)

            if train_cfg and random.random() < cfg_uncond_ratio:
                ref_image_clip = torch.zeros_like(ref_image_clip)
                drop_reference = True
            else:
                drop_reference = False

            if not image_finetune and pose_shuffle_ratio > 0:
                B, C, L, H, W = pose.shape
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    pose[:, :, rand_idx[0], :, :] = pose[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    hamer[:, :, rand_idx[0], :, :] = hamer[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    smpl[:, :, rand_idx[0], :, :] = smpl[:, :, rand_idx[rand_idx[0]], :, :]

            # Convert images to latent space
            with torch.no_grad():
                if not image_finetune:
                    video_length = image.shape[2]
                    image = rearrange(image, "b c f h w -> (b f) c h w")

                latents = vae.encode(image).latent_dist
                latents = latents.sample()
                latents = latents * vae.config.scaling_factor

                ref_latents = vae.encode(ref_image).latent_dist
                ref_latents = ref_latents.sample()
                ref_latents = ref_latents * vae.config.scaling_factor

                clip_latents = image_encoder(ref_image_clip).last_hidden_state
                # clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)

                if not image_finetune:
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)

                # Sample noise that we'll add to the latents
                bsz = latents.shape[0]
                noise = torch.randn_like(latents)

                latents = latents.to(weight_dtype)
                noise = noise.to(weight_dtype)

            # Sample a random timestep for each video
            train_timesteps = noise_scheduler.config.num_train_timesteps
            timesteps = torch.randint(0, train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            model_pred = unet(
                sample=noisy_latents,
                ref_sample=ref_latents,
                pose=pose,
                hamer=hamer,
                smpl=smpl,
                timestep=timesteps,
                encoder_hidden_states=clip_latents,
                drop_reference=drop_reference,
            ).sample

            if snr_gamma is None:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            else:
                snr = compute_snr(noise_scheduler, timesteps)
                mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                    dim=1
                )[0]
                if noise_scheduler.config.prediction_type == "epsilon":
                    mse_loss_weights = mse_loss_weights / snr.clamp(min=1e-8)  # in case of zero-SNR
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    mse_loss_weights = mse_loss_weights / (snr + 1)
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                loss = loss.mean()

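            # DeepSpeed engine step: backward() handles loss scaling and gradient
            # accumulation; step() applies gradient clipping ("gradient_clipping" in
            # the config), updates the optimizer, steps the LR scheduler and zeroes
            # the gradients. Calling optimizer.step() / lr_scheduler.step() again
            # here would apply a second update per iteration.
            unet.backward(loss)
            unet.step()

            progress_bar.update(1)
            global_step += 1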
            """ <<<< Training <<<< """

            # Save checkpoint
            if unet.global_rank == 0 and (global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple):
                save_path = os.path.join(output_dir, f"checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                torch.save(state_dict, os.path.join(save_path, f"checkpoint-iter-{global_step}.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            logs = {"step_loss": loss.detach().item(), "lr": unet.get_lr()[0]}
            progress_bar.set_postfix(**logs)
            if global_step >= max_train_steps:
                break

    # save the final checkpoint
    if unet.global_rank == 0:
        save_path = os.path.join(output_dir, f"checkpoints")
        state_dict = {
            "epoch": num_train_epochs - 1,
            "global_step": global_step,
            "state_dict": unet.state_dict(),
        }
        torch.save(state_dict, os.path.join(save_path, f"checkpoint-final.ckpt"))
        logging.info(f"Saved final state to {save_path}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main(name=exp_name, **exp_config)
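In the DeepSpeed loop, `unet` is a DeepSpeed engine. With `gradient_accumulation_steps` > 1, its `step()` applies an optimizer and LR-scheduler update only on the micro-step that closes an accumulation window. DeepSpeed exposes this check as `is_gradient_accumulation_boundary()`. The loop's `global_step` counts micro-batches, while `max_train_steps` is derived from `num_update_steps_per_epoch`, i.e. counted in optimizer updates. So when accumulation is enabled, `global_step` seems to reach `max_train_steps` earlier than intended. The toy counter below is not DeepSpeed code. It is a pure-Python sketch that assumes the boundary rule `(micro_steps + 1) % grad_accum_steps == 0` and makes the two counts visible.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class AccumulationCounter:
    """Toy stand-in (not DeepSpeed code) for an engine's step bookkeeping."""

    grad_accum_steps: int
    micro_steps: int = 0       # incremented on every step() call, like global_step in the loop
    optimizer_steps: int = 0   # real parameter updates
    scheduler_steps: int = 0   # LR scheduler steps, taken together with the updates
    boundaries: List[int] = field(default_factory=list)

    def is_boundary(self) -> bool:
        # True when the current micro-step closes an accumulation window
        return (self.micro_steps + 1) % self.grad_accum_steps == 0

    def step(self) -> None:
        if self.is_boundary():
            # Here the engine would clip, update params, step the scheduler and zero grads
            self.optimizer_steps += 1
            self.scheduler_steps += 1
            self.boundaries.append(self.micro_steps)
        self.micro_steps += 1


def simulate(num_batches: int, grad_accum_steps: int) -> AccumulationCounter:
    counter = AccumulationCounter(grad_accum_steps)
    for _ in range(num_batches):
        # forward pass and engine.backward(loss) would go here
        counter.step()
    return counter


if __name__ == "__main__":
    c = simulate(num_batches=10, grad_accum_steps=4)
    print(c.micro_steps, c.optimizer_steps, c.boundaries)  # prints: 10 2 [3, 7]
```

With 10 batches and 4 accumulation steps, only 2 real updates happen. To count training progress in optimizer updates, one option is to increment a separate counter only when the engine reports an accumulation boundary.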

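When the optimizer and scheduler are created by `deepspeed.initialize` from the configuration file, `unet.step()` already performs the optimizer update and advances the scheduler at each gradient-accumulation boundary. Calling `lr_scheduler.step()` afterwards therefore advances the scheduler twice per iteration, and the extra `optimizer.step()` is at best redundant. The plain-Python sketch below uses a hypothetical linear warmup (DeepSpeed's `WarmupLR` uses a logarithmic curve by default, `warmup_type: "log"`) just to show the effect: warmup finishes in half the intended number of iterations.

```python
def linear_warmup_lr(step: int, warmup_num_steps: int, max_lr: float) -> float:
    """Hypothetical linear warmup: ramps from 0 to max_lr over warmup_num_steps scheduler steps."""
    return max_lr * min(step, warmup_num_steps) / warmup_num_steps


def lr_after(iterations: int, scheduler_steps_per_iter: int,
             warmup_num_steps: int = 1000, max_lr: float = 1e-3) -> float:
    """Learning rate after `iterations` training iterations."""
    step = iterations * scheduler_steps_per_iter
    return linear_warmup_lr(step, warmup_num_steps, max_lr)


# Correct: only unet.step() advances the scheduler (1 scheduler step per iteration).
print(lr_after(500, scheduler_steps_per_iter=1))  # halfway through warmup
# Buggy: unet.step() plus an extra lr_scheduler.step() (2 scheduler steps per iteration).
print(lr_after(500, scheduler_steps_per_iter=2))  # warmup already finished
```

Once warmup is over both curves sit at `max_lr`, so the bug mainly shows up as a warmup that is too short, which is easy to miss in the loss curve.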
Also: for training diffusion models, prefer Accelerate, which already integrates DeepSpeed!
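In the DeepSpeed version of the training loop above, gradient clipping is not done by calling a clip function inside the loop; it is set with the `gradient_clipping` key in the configuration. `deepspeed.initialize` also accepts the configuration as a Python dict through its `config` argument. The sketch below builds such a dict with a hypothetical helper `build_ds_config`; the values (micro-batch 4, clipping 1.0, bf16, ZeRO stage 2) are illustrative assumptions, and the optimizer/scheduler blocks mirror the ones in the configuration shown earlier in this article.

```python
import json


def build_ds_config(micro_batch_size: int = 4,
                    grad_accum_steps: int = 1,
                    max_grad_norm: float = 1.0,
                    lr: float = 1e-5,
                    warmup_steps: int = 1000) -> dict:
    """Return a DeepSpeed config dict (illustrative values, not tuned)."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": grad_accum_steps,
        # Replaces the manual clip_grad_norm_ call in the training loop
        "gradient_clipping": max_grad_norm,
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": lr, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 1e-2},
        },
        "scheduler": {
            "type": "WarmupLR",
            # WarmupLR writes its own lr into the optimizer, so keep warmup_max_lr equal to lr
            "params": {"warmup_min_lr": 0, "warmup_max_lr": lr, "warmup_num_steps": warmup_steps},
        },
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": 2},
    }


if __name__ == "__main__":
    ds_config = build_ds_config()
    # Either dump it to a JSON file for the launcher, or pass it directly, e.g.
    # deepspeed.initialize(model=unet, model_parameters=trainable_params, config=ds_config)
    print(json.dumps(ds_config, indent=2))
```

Keeping `warmup_max_lr` tied to the optimizer `lr` avoids a silent mismatch: after warmup, the scheduler's value is the learning rate actually used, not the one under `optimizer.params`.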

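The checkpoint code in the training loop saves `unet.state_dict()` from rank 0 only. That works when every rank holds the full weights, but with ZeRO stage 3 the parameters are partitioned across ranks. DeepSpeed's engine method `save_checkpoint(save_dir, tag=..., client_state=...)` is meant to be called on all ranks and also stores optimizer and scheduler state for resuming; `load_checkpoint(load_dir, tag=...)` returns the load path and the saved `client_state`. A minimal sketch with hypothetical helper names (`save_training_state`, `load_training_state`):

```python
import os


def save_training_state(engine, output_dir: str, global_step: int, epoch: int) -> str:
    """Save a resumable DeepSpeed checkpoint (hypothetical helper).

    Must be called on EVERY rank: with ZeRO each rank writes its own shard
    of the optimizer (and, for stage 3, parameter) state.
    """
    save_dir = os.path.join(output_dir, "checkpoints")
    tag = f"iter-{global_step}"
    # client_state is stored next to the engine state and returned again by load_checkpoint
    engine.save_checkpoint(save_dir, tag=tag,
                           client_state={"epoch": epoch, "global_step": global_step})
    return os.path.join(save_dir, tag)


def load_training_state(engine, output_dir: str, tag=None):
    """Resume training; tag=None loads the checkpoint recorded in the 'latest' file."""
    save_dir = os.path.join(output_dir, "checkpoints")
    load_path, client_state = engine.load_checkpoint(save_dir, tag=tag)
    return load_path, client_state
```

In the loop this would replace the `if unet.global_rank == 0 and ...` block: keep the step condition but drop the rank check, and call `save_training_state(unet, output_dir, global_step, epoch)` on every process.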
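This article was contributed by an internet user; the views expressed are the author's own and do not represent this site. This site only provides information storage, does not own the content, and bears no related legal responsibility. If you repost it, please credit the source: http://www.coloradmin.cn/o/2239018.html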