Background
PyTorch provides Distributed Data Parallel (DDP) for parallel training across multiple GPUs and multiple nodes, together with the torchrun command to launch it. A torchrun launch, however, is awkward to debug. A convenient workaround is to temporarily rewrite the entry point into an equivalent mp.spawn launch, debug it in the IDE, and then switch back to torchrun for the actual training run.
Procedure
Assume the original DDP training code looks like this:
import os

import torch
import torch.distributed as dist

def main():
    # torchrun injects these environment variables into every worker process.
    # ``args`` is assumed to be an argparse.Namespace created elsewhere (omitted here).
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.world_size = int(os.environ["WORLD_SIZE"])
    args.rank = int(os.environ["RANK"])
    dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)
    ...  # rest of the training code omitted

if __name__ == "__main__":
    main()
DDP training is then launched with the following command (--nproc_per_node=4 starts four worker processes on a single node, and --rdzv_endpoint=localhost:0 lets torchrun pick a free rendezvous port):
torchrun --nnodes=1 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py
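Note that torchrun not only sets RANK, LOCAL_RANK and WORLD_SIZE, but also MASTER_ADDR and MASTER_PORT for each worker, which is why the code above never configures the rendezvous address itself. If in doubt, a throwaway check like the sketch below can confirm what is in the environment (show_dist_env is a hypothetical name, not part of the original script):

import os

def show_dist_env():
    # Print the variables torchrun injects into each worker process.
    for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
        print(f"{key} = {os.environ.get(key)}")

if __name__ == "__main__":
    show_dist_env()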
To debug inside an IDE (e.g. PyCharm), change the code as follows; the script can then be run and debugged directly in the IDE:
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def main(rank, world_size):
    # mp.spawn passes the process index as the first argument; on a single
    # node it serves as both the local rank and the global rank.
    args.local_rank = rank
    args.world_size = world_size
    args.rank = rank
    # Without torchrun, the rendezvous address and port must be set manually.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=args.rank, world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)
    ...  # rest of the training code omitted

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # Spawn one process per visible GPU; each process calls main(rank, world_size).
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
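With this version the entry point is a plain python train.py, so it can be set as the IDE's run/debug target directly. If switching back and forth becomes tedious, one option is to detect the launch mode at runtime. The sketch below illustrates that idea under a single-node assumption; setup_distributed and the port 12355 are hypothetical choices, not part of the original code:

import os

import torch
import torch.distributed as dist

def setup_distributed(rank=None, world_size=None):
    # Hypothetical helper that works under both launch modes.
    if "LOCAL_RANK" in os.environ:
        # torchrun mode: ranks and rendezvous info come from the environment.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
    else:
        # mp.spawn mode: rank/world_size are function arguments; set the
        # rendezvous address and port manually (single-node assumption).
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        local_rank = rank
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size

With such a helper, the same training function can be started by torchrun for the real run and by mp.spawn while debugging, without editing any code when switching back.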