PyTorch 并行训练 DistributedDataParallel完整代码示例

news2024/11/16 19:55:00

使用大型数据集训练大型深度神经网络 (DNN) 的问题是深度学习领域的主要挑战。 随着 DNN 和数据集规模的增加,训练这些模型的计算和内存需求也会增加。 这使得在计算资源有限的单台机器上训练这些模型变得困难甚至不可能。 使用大型数据集训练大型 DNN 的一些主要挑战包括:

  • 训练时间长:训练过程可能需要数周甚至数月才能完成,具体取决于模型的复杂性和数据集的大小。
  • 内存限制:大型 DNN 可能需要大量内存来存储训练期间的所有模型参数、梯度和中间激活。 这可能会导致内存不足错误并限制可在单台机器上训练的模型的大小。

为了应对这些挑战,已经开发了各种技术来扩大具有大型数据集的大型 DNN 的训练,包括模型并行性、数据并行性和混合并行性,以及硬件、软件和算法的优化。

在本文中我们将演示使用 PyTorch 的数据并行性和模型并行性。

我们所说的并行性一般是指在多个gpu,或多台机器上训练深度神经网络(dnn),以实现更少的训练时间。数据并行背后的基本思想是将训练数据分成更小的块,让每个GPU或机器处理一个单独的数据块。然后将每个节点的结果组合起来,用于更新模型参数。在数据并行中,模型体系结构在每个节点上是相同的,但模型参数在节点之间进行了分区。每个节点使用分配的数据块训练自己的本地模型,在每次训练迭代结束时,模型参数在所有节点之间同步。这个过程不断重复,直到模型收敛到一个令人满意的结果。

下面我们用用ResNet50和CIFAR10数据集来进行完整的代码示例:

在数据并行中,模型架构在每个节点上保持相同,但模型参数在节点之间进行了分区,每个节点使用分配的数据块训练自己的本地模型。

PyTorch的DistributedDataParallel 库可以进行跨节点的梯度和模型参数的高效通信和同步,实现分布式训练。本文提供了如何使用ResNet50和CIFAR10数据集使用PyTorch实现数据并行的示例,其中代码在多个gpu或机器上运行,每台机器处理训练数据的一个子集。训练过程使用PyTorch的DistributedDataParallel 库进行并行化。

导入必须要的库

 importos
 fromdatetimeimportdatetime
 fromtimeimporttime
 importargparse
 importtorchvision
 importtorchvision.transformsastransforms
 importtorch
 importtorch.nnasnn
 importtorch.distributedasdist
 fromtorch.nn.parallelimportDistributedDataParallel

接下来,我们将检查GPU

 importsubprocess
 result=subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE)
 print(result.stdout.decode())

因为我们需要在多个服务器上运行,所以手动一个一个执行并不现实,所以需要有一个调度程序。这里我们使用SLURM文件来运行代码(slurm面向Linux和Unix类似内核的免费和开源工作调度程序),

 defmain():
         
     # get distributed configuration from Slurm environment
     
     parser=argparse.ArgumentParser()
     parser.add_argument('-b', '--batch-size', default=128, type=int,
                         help='batch size. it will be divided in mini-batch for each worker')
     parser.add_argument('-e','--epochs', default=2, type=int, metavar='N',
                         help='number of total epochs to run')
     parser.add_argument('-c','--checkpoint', default=None, type=str,
                         help='path to checkpoint to load')
     args=parser.parse_args()
     
     rank=int(os.environ['SLURM_PROCID'])
     local_rank=int(os.environ['SLURM_LOCALID'])
     size=int(os.environ['SLURM_NTASKS'])
     master_addr=os.environ["SLURM_SRUN_COMM_HOST"]
     port="29500"
     node_id=os.environ['SLURM_NODEID']
     ddp_arg= [rank, local_rank, size, master_addr, port, node_id]
     train(args, ddp_arg)     

然后我们使用DistributedDataParallel 库来执行分布式训练。

 deftrain(args, ddp_arg):
     
     rank, local_rank, size, MASTER_ADDR, port, NODE_ID=ddp_arg
     
     # display info
     ifrank==0:
         #print(">>> Training on ", len(hostnames), " nodes and ", size, " processes, master node is ", MASTER_ADDR)
         print(">>> Training on ", size, " GPUs, master node is ", MASTER_ADDR)
     #print("- Process {} corresponds to GPU {} of node {}".format(rank, local_rank, NODE_ID))
     
     print("- Process {} corresponds to GPU {} of node {}".format(rank, local_rank, NODE_ID))
     
     
     # configure distribution method: define address and port of the master node and initialise communication backend (NCCL)
     #dist.init_process_group(backend='nccl', init_method='env://', world_size=size, rank=rank)
     dist.init_process_group(
         backend='nccl',
         init_method='tcp://{}:{}'.format(MASTER_ADDR, port),
         world_size=size,
         rank=rank
     )
     
     # distribute model
     torch.cuda.set_device(local_rank)
     gpu=torch.device("cuda")
     #model = ResNet18(classes=10).to(gpu)
     model=torchvision.models.resnet50(pretrained=False).to(gpu)
     ddp_model=DistributedDataParallel(model, device_ids=[local_rank])
     ifargs.checkpointisnotNone:
         map_location= {'cuda:%d'%0: 'cuda:%d'%local_rank}
         ddp_model.load_state_dict(torch.load(args.checkpoint, map_location=map_location))
     
     # distribute batch size (mini-batch)
     batch_size=args.batch_size
     batch_size_per_gpu=batch_size//size
     
     # define loss function (criterion) and optimizer
     criterion=nn.CrossEntropyLoss()  
     optimizer=torch.optim.SGD(ddp_model.parameters(), 1e-4)
     
     
     transform_train=transforms.Compose([
         transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ])
 
     # load data with distributed sampler
     #train_dataset = torchvision.datasets.CIFAR10(root='./data',
     #                                           train=True,
     #                                           transform=transform_train,
     #                                           download=False)
     
     # load data with distributed sampler
     train_dataset=torchvision.datasets.CIFAR10(root='./data',
                                                train=True,
                                                transform=transform_train,
                                                download=False)
     
     train_sampler=torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                     num_replicas=size,
                                                                     rank=rank)
     
     train_loader=torch.utils.data.DataLoader(dataset=train_dataset,
                                                batch_size=batch_size_per_gpu,
                                                shuffle=False,
                                                num_workers=0,
                                                pin_memory=True,
                                                sampler=train_sampler)
 
     # training (timers and display handled by process 0)
     ifrank==0: start=datetime.now()         
     total_step=len(train_loader)
     
     forepochinrange(args.epochs):
         ifrank==0: start_dataload=time()
         
         fori, (images, labels) inenumerate(train_loader):
             
             # distribution of images and labels to all GPUs
             images=images.to(gpu, non_blocking=True)
             labels=labels.to(gpu, non_blocking=True) 
             
             ifrank==0: stop_dataload=time()
 
             ifrank==0: start_training=time()
             
             # forward pass
             outputs=ddp_model(images)
             loss=criterion(outputs, labels)
 
             # backward and optimize
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             ifrank==0: stop_training=time() 
             if (i+1) %10==0andrank==0:
                 print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Time data load: {:.3f}ms, Time training: {:.3f}ms'.format(epoch+1, args.epochs,
                                                                         i+1, total_step, loss.item(), (stop_dataload-start_dataload)*1000,
                                                                         (stop_training-start_training)*1000))
             ifrank==0: start_dataload=time()
                     
         #Save checkpoint at every end of epoch
         ifrank==0:
             torch.save(ddp_model.state_dict(), './checkpoint/{}GPU_{}epoch.checkpoint'.format(size, epoch+1))
 
     ifrank==0:
         print(">>> Training complete in: "+str(datetime.now() -start))
 
 
 if__name__=='__main__':
 
     main()

代码将数据和模型分割到多个gpu上,并以分布式的方式更新模型。下面是代码的一些解释:

train(args, ddp_arg)有两个参数,args和ddp_arg,其中args是传递给脚本的命令行参数,ddp_arg包含分布式训练相关参数。

rank, local_rank, size, MASTER_ADDR, port, NODE_ID = ddp_arg:解包ddp_arg中分布式训练相关参数。

如果rank为0,则打印当前使用的gpu数量和主节点IP地址信息。

dist.init_process_group(backend=‘nccl’, init_method=‘tcp://{}:{}’.format(MASTER_ADDR, port), world_size=size, rank=rank) :使用NCCL后端初始化分布式进程组。

torch.cuda.set_device(local_rank):为这个进程选择指定的GPU。

model = torchvision.models. ResNet50 (pretrained=False).to(gpu):从torchvision模型中加载ResNet50模型,并将其移动到指定的gpu。

ddp_model = DistributedDataParallel(model, device_ids=[local_rank]):将模型包装在DistributedDataParallel模块中,也就是说这样我们就可以进行分布式训练了

加载CIFAR-10数据集并应用数据增强转换。

train_sampler=torch.utils.data.distributed.DistributedSampler(train_dataset,num_replicas=size,rank=rank):创建一个DistributedSampler对象,将数据集分割到多个gpu上。

train_loader =torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size_per_gpu,shuffle=False,num_workers=0,pin_memory=True,sampler=train_sampler):创建一个DataLoader对象,数据将批量加载到模型中,这与我们平常训练的步骤是一致的只不过是增加了一个分布式的数据采样DistributedSampler

为指定的epoch数训练模型,以分布式的方式使用optimizer.step()更新权重。

rank0在每个轮次结束时保存一个检查点。

rank0每10个批次显示损失和训练时间。

结束训练时打印训练模型所花费的总时间也是在rank0上。

代码测试

在使用1个节点1/2/3/4个gpu, 2个节点6/8个gpu,每个节点3/4个gpu上进行了训练Cifar10上的Resnet50的测试如下图所示,每次测试的批处理大小保持不变。完成每项测试所花费的时间以秒为单位记录。随着使用的gpu数量的增加,完成测试所需的时间会减少。当使用8个gpu时,需要320秒才能完成,这是记录中最快的时间。这是肯定的,但是我们可以看到训练的速度并没有像GPU数量增长呈现线性的增长,这可能是因为Resnet50算是一个比较小的模型了,并不需要进行并行化训练。

在多个gpu上使用数据并行可以显著减少在给定数据集上训练深度神经网络(DNN)所需的时间。随着gpu数量的增加,完成训练过程所需的时间减少,这表明DNN可以更有效地并行训练。

这种方法在处理大型数据集或复杂的DNN架构时特别有用。通过利用多个gpu,可以加快训练过程,实现更快的模型迭代和实验。但是需要注意的是,通过Data Parallelism实现的性能提升可能会受到通信开销和GPU内存限制等因素的限制,需要仔细调优才能获得最佳结果。

https://avoid.overfit.cn/post/67095b9014cb40888238b84fea17e872

作者:Joseph El Kettaneh

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/356022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot监控

文章目录一、PrometheusGrafana监控Springboot1、简介2、SpringBoot应用镜像搭建2.1 springboot应用创建2.2 镜像创建3、Prometheus3.1 概述3.2 Prometheus创建4、Grafana可视化监控4.1 可视化4.2 告警设置二、轻量级日志系统Loki1、简介1.1 介绍1.2 与ELK差异2、grafana loki日…

linux宝塔安装和部署node全栈项目

使用服务器:阿里云ECS系列 服务器操作系统: Alibaba Cloud Linux 2.1903 LTS 64位 连接服务器方式: Workbench远程连接 使用公网IP登录 Workbench远程桌面,使用命令安装linux宝塔面板操作服务器: 1.登录linux宝塔面板,使用终端命令安装linux宝塔 yum i…

【操作系统】计算机系统概述

文章目录操作系统的概念、功能和目标熟悉的操作系统计算机系统的层次结构操作系统的概念操作系统的功能和目标作为系统资源的管理者作为用户和计算机之间的接口作为最接近硬件的层次操作系统的四个特征并发共享并发和共享的关系虚拟异步操作系统的发展和分类手工操作阶段单道批…

1207. 大臣的旅费/树的直径【AcWing】

1207. 大臣的旅费 很久以前,T王国空前繁荣。 为了更好地管理国家,王国修建了大量的快速路,用于连接首都和王国内的各大城市。 为节省经费,T国的大臣们经过思考,制定了一套优秀的修建方案,使得任何一个大…

使用Docker-Compose搭建Redis集群

1. 集群配置3主3从由于仅用于测试,故我这里只用1台服务器进行模拟redis列表2.编写redis.conf在server上创建一个目录用于存放redis集群部署文件。这里我放的路径为/root/redis-cluster 在/opt/docker/redis-cluster目录下创建redis-1,redis-2,redis-3,redis-4,redis…

Python 使用 pip 安装 matplotlib 模块(秒解版)

长话短说:本人下载 matplotlib 花了大概三个半小时屡屡碰壁,险些暴走。为了不让新来的小伙伴走我的弯路,特意创作本片文章指明方向。 1.首先需要下载 python 我直接是在电脑自带的软件商店里下载的,图方便,当然在官网下…

操作系统 四(设备管理)

I/O系统功能 隐藏I/O设备的细节;保证设备无关性;提高处理机和I/O设备的利用率;对I/O设备进行控制;确保对设备的正确共享;处理错误。中断、通道、DMA概念 中断:CPU对I/O设备发来的中断信号的一种响应DMA&am…

【配电网优化】基于串行和并行ADMM算法的配电网优化研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

超详细讲解文件函数

超详细讲解文件函数!!!!字符输入/输出函数fgetcfputc文本行输入/输出函数fgetsfputs格式化输入/输出函数fscanffprintf二进制输入/输出函数freadfwrite打开/关闭文件函数fopenfclose字符输入/输出函数 fgetc fgetc函数可以从指定…

个人谈谈对ThreadLocal内存泄露的理解

个人谈谈对ThreadLocal内存泄露的理解ThreadLocal作用ThreadLocalMap内存泄露解释为什么要这样设计ThreadLocalMap的实现思路ThreadLocal作用 平时我们会使用ThreadLocal来存放当前线程的副本数据,让当前线程执行流中各个位置,都可以从ThreadLocal中获取…

Java SPI 机制详解

在面向对象的设计原则中,一般推荐模块之间基于接口编程,通常情况下调用方模块是不会感知到被调用方模块的内部具体实现。一旦代码里面涉及具体实现类,就违反了开闭原则。如果需要替换一种实现,就需要修改代码。 为了实现在模块装…

使用packetbeat对MySQL进行网络抓包

文章目录一、Packetbeat 简介二、packetbeat部署和使用2.1 官方下载解压2.2 修改配置文件2.3 导入索引模板和dashboard2.4 启动packetbeat三、效果展示一、Packetbeat 简介 Packetbeat 是一款轻量型实时网络数据包分析器,能够将主机和容器中的数据发送至 Logstash 或…

uboot编译分析

uboot编译分析 V 1 –> Q ,在一行命令前面加上表示不会在终端输出命令 KCONFIG_CONFIG ? .config.config 默认是没有的,默认是需要使用命令“make xxx_defconofig”先对uboot进行配置,配置完成就会在uboot根目录下生成.config。如果后续自行调整…

多种方法解决谷歌(chrome)、edge、火狐等浏览器F12打不开调试页面或调试模式(面板)的问题。

文章目录1. 文章引言2. 解决问题3. 解决该问题的其他方法1. 文章引言 不论是前端开发者,还是后端开发者,我们在调试web项目时,偶尔弹出相关错误。 此时,我们需要打开浏览器的调试模式,如下图所示: 通过浏…

智能拣配单解决方案

电子货架标签系统(ESLs),是一种放置在货架上、可替代传统纸质价格标签的电子显示装置, 每一个电子货架标签通过有线或者无线网络与商场计算机数据库相连, 并将最新的商品价格通过电子货架标签上的屏显示出来。 电子…

基于微信小程序图书馆管理系统

开发工具:IDEA、微信小程序服务器:Tomcat9.0, jdk1.8项目构建:maven数据库:mysql5.7前端技术:vue、uniapp服务端技术:springbootmybatis-plus本系统分微信小程序和管理后台两部分,项…

量子计算(7)pyqpanda编程2循环与条件判断

目录 一、QWhile 二、QIf 各位读者老爷们,大家好呀,前些时忙着学校的期末考试,小编好久没更新量子计算的文章啦,这段时间也有读者私信小编,问了一些问题。我知道大家都很急,但大家先别急。这不&#xff0…

【数据结构】——队列

文章目录前言一.什么是队列,队列的特点二、队列相关操作队列的相关操作声明队列的创建1.队列的初始化2.对队列进行销毁3.判断队列是否为空队列4.入队操作5.出队操作6.取出队头数据7. 取出队尾数据8.计算队伍的人数总结前言 本文章讲述的是数据结构的特殊线性表——…

Python3 错误和异常实例及演示

作为 Python 初学者,在刚学习 Python 编程时,经常会看到一些报错信息,在前面我们没有提及,这章节我们会专门介绍。 Python 有2种错误很容易辨认:语法错误和异常。 Python assert(断言)用于判断…

通信算法之一百零四:QPSK完整收发仿真链路

1.发射机物理层基带仿真链路 1.1 % Generates the data to be transmitted [transmittedBin, ~] BitGenerator(); 2.2 % Modulates the bits into QPSK symbols modulatedData QPSKModulator(transmittedBin); 2.3 % Square root Raised Cosine Transmit Filter %comm…