浅谈PyTorch中的DP和DDP

news2024/10/4 5:31:08

目录

  • 1. 引言
  • 2. PyTorch 数据并行(Data Parallel, DP)
    • 2.1 DP 的优缺点
    • 2.2 DP 实现代码示例
  • 3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)
    • 3.1 DDP 的优缺点
    • 3.2 分布式基本概念
    • 3.3 DDP 的应用流程
    • 3.5 DDP 实现代码示例
  • 4. DP和DDP的对比

1. 引言

在现代深度学习中,随着模型规模的不断增大以及数据量的快速增长,模型训练所需的计算资源也变得愈加庞大。尤其是在大型深度学习模型的训练过程中,单张 GPU 显存往往难以满足需求,因此,如何高效利用多 GPU 进行并行训练,成为了加速模型训练的关键手段。PyTorch 作为目前最受欢迎的深度学习框架之一,提供了多种并行训练的方式,其中最常用的是 数据并行(Data Parallel, DP)分布式数据并行(Distributed Data Parallel, DDP)

⚠️ 无论是DP还是DDP都只支持数据并行。

2. PyTorch 数据并行(Data Parallel, DP)

数据并行(Data Parallel, DP) 是 PyTorch 中一种简单的并行训练方式,它的主要思想是将数据拆分为多个子集,然后将这些子集分别分配给不同的 GPU 进行计算。DP 的工作原理如下:

  1. 在前向传播时,首先将模型的参数复制到每个 GPU 上。
  2. 每个 GPU 独立计算一部分数据的前向传播和损失值,并将计算结果返回到主 GPU。
  3. 主 GPU 汇总每个 GPU 计算的损失,并计算出梯度。
  4. 通过反向传播,将计算得到的梯度更新主 GPU 的模型参数,然后再将更新后的参数广播到其他 GPU 上。

2.1 DP 的优缺点

优点

  • 实现简单,使用 PyTorch 提供的 torch.nn.DataParallel 接口即可轻松实现。
  • 对于小规模的模型和数据集,DP 能够在单机多卡的场景下提供良好的加速效果。

缺点

  • DP 在每个 batch 中需要在 GPU 之间传递模型参数和数据,参数更新时也需要将梯度传递回主 GPU,这会造成大量的通信开销。
  • 由于梯度的计算和模型参数的更新都是在主 GPU 上完成的,主 GPU 的负载会显著增加,导致 GPU 资源无法得到充分利用。

2.2 DP 实现代码示例

使用 torch.nn.DataParallel 实现数据并行非常简单。我们只需要将模型封装到 DataParallel 中,然后传入多个 GPU 即可。下面我们通过代码示例展示如何使用 DP 进行并行训练。

import torch
import torch.nn as nn
import torchvision

BATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224)  # ResNet-18 的输入尺寸

# 1. 创建模型
net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
net = nn.DataParallel(net)
net = net.cuda()

# 2. 生成随机数据
total_steps = 100  # 假设每个 epoch 有 100 个步骤
inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).cuda()
targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).cuda()

# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
)

# 4. 开始训练
net.train()
for ep in range(1, EPOCHS + 1):
    train_loss = correct = total = 0
    for idx in range(total_steps):
        outputs = net(inputs)

        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += targets.size(0)
        correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()

        if (idx + 1) % 25 == 0 or (idx + 1) == total_steps:
            print(f"Epoch [{ep}/{EPOCHS}], Step [{idx + 1}/{total_steps}], Loss: {train_loss / (idx + 1):.3f}, Acc: {correct / total:.3%}")

在这个代码示例中,我们使用了随机生成的输入和标签数据,以简化代码并专注于并行训练的实现。通过将模型封装在 DataParallel 中,我们可以在多个 GPU 上进行并行计算。然而,由于 DP 存在较大的通信开销以及主 GPU 的计算瓶颈,因此在更大规模的训练中,我们更推荐使用分布式数据并行(DDP)来加速训练。

3. PyTorch 分布式数据并行(Distributed Data Parallel, DDP)

分布式数据并行(Distributed Data Parallel, DDP) 是 PyTorch 中推荐使用的多 GPU 并行训练方式,特别适合大规模训练任务。与 DP 不同,DDP 是一种多进程并行方式,避免了 Python 全局解释器锁(GIL)的限制,可以在单机或多机多卡环境中实现更高效的并行计算。DDP的工作原理如下:

  1. 在每个 GPU 上运行一个独立的进程,每个进程都有自己的一份模型副本和数据。
  2. 各个进程独立执行前向传播、计算损失和反向传播,得到各自的梯度。
  3. 在反向传播阶段,各个 GPU 的进程通过通信将梯度汇总,平均后更新每个进程中的模型参数。
  4. 每个进程的模型参数在整个训练过程中保持一致,避免了 DP 中由于参数广播导致的通信开销。

3.1 DDP 的优缺点

优点

  • 由于各个 GPU 上的进程独立计算梯度,更新模型参数时只需要同步梯度而非整个模型,通信开销较小,性能大幅提升。
  • DDP 可以在多机多卡环境下使用,支持大规模的分布式训练,适合深度学习模型的高效扩展。

缺点

  • 代码实现相对 DP 较为复杂,需要手动管理进程的初始化和同步。

3.2 分布式基本概念

在使用 DDP 进行分布式训练时,我们需要理解以下几个基本概念:

  1. node(节点):物理节点,一台机器即为一个节点。
  2. nnodes(节点数量):表示参与训练的物理节点数量。
  3. node rank(节点序号):节点的编号,用于区分不同的物理节点。
  4. nproc per node(每节点的进程数量):表示每个物理节点上启动的进程数量,通常等于 GPU 的数量。
  5. world size(全局进程数量):表示全局并行的进程总数,等于 nnodes * nproc_per_node
  6. rank(进程序号):表示每个进程的唯一编号,用于进程间通信,rank=0 的进程为主进程。
  7. local rank(本地进程序号):在某个节点上的进程的序号,local_rank=0 表示该节点的主进程。

3.3 DDP 的应用流程

使用 DDP 进行分布式训练的步骤如下:

  1. 初始化分布式训练环境:通过 torch.distributed.init_process_group 初始化进程组,指定通信后端和相关配置。
  2. 创建分布式模型:将模型封装到 torch.nn.parallel.DistributedDataParallel 中,进行并行训练。
  3. 生成或加载数据:在每个进程中加载数据,并确保数据在不同进程间的分布,如使用 DistributedSampler
  4. 执行训练脚本:在每个节点的每个进程上启动训练脚本,进行模型训练。

3.5 DDP 实现代码示例

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP

BATCH_SIZE = 256
EPOCHS = 5
NUM_CLASSES = 10
INPUT_SHAPE = (3, 224, 224)  # ResNet-18 的输入尺寸

if __name__ == "__main__":

    # 1. 设置分布式变量,初始化进程组
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda", local_rank)

    print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")

    # 2. 创建模型
    net = torchvision.models.resnet18(pretrained=False, num_classes=NUM_CLASSES)
    net = net.to(device)
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    # 3. 生成随机数据
    total_steps = 100  # 假设每个 epoch 有 100 个步骤
    inputs = torch.randn(BATCH_SIZE, *INPUT_SHAPE).to(device)
    targets = torch.randint(0, NUM_CLASSES, (BATCH_SIZE,)).to(device)

    # 4. 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        net.parameters(), lr=0.02, momentum=0.9, weight_decay=0.0001, nesterov=True
    )

    # 5. 开始训练
    net.train()
    for ep in range(1, EPOCHS + 1):
        train_loss = correct = total = 0
        for idx in range(total_steps):
            outputs = net(inputs)

            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            total += targets.size(0)
            correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()

            if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == total_steps):
                print(
                    "   == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(
                        idx + 1,
                        total_steps,
                        ep,
                        EPOCHS,
                        train_loss / (idx + 1),
                        100.0 * correct / total,
                    )
                )
    if rank == 0:
        print("\n            =======  Training Finished  ======= \n")

在以上代码中,我们使用了随机生成的输入和标签数据,以简化代码并专注于 DDP 的实现细节。通过在每个进程中初始化分布式环境,并将模型封装在 DistributedDataParallel 中,我们可以在多个 GPU 上高效地进行并行训练。

需要注意的是,DDP 的实现需要在每个进程中正确设置设备和初始化过程,这样才能确保模型和数据在对应的 GPU 上进行计算。

4. DP和DDP的对比

DP 是单进程多线程的分布式方法,主要用于单机多卡的场景。它的工作方式是在每个批处理期间,将模型参数分发到所有 GPU,各 GPU 计算各自的梯度后将结果汇总到 GPU0,再由 GPU0 完成参数更新,然后将更新后的模型参数广播回其他 GPU。由于 DP 只广播模型的参数,速度较慢,尤其是在多个 GPU 协同工作时,GPU 利用率低,通常效率不如 DDP。

相比之下,DDP 使用多进程架构,既支持单机多卡,也支持多机多卡,并避免了 GIL(全局解释器锁)带来的性能损失。每个进程独立计算梯度,计算完成后各进程汇总并平均梯度,更新参数时各进程均独立完成。这种方式减少了通信开销,只在初始化时广播一次模型参数,并且在每次更新后只传递梯度。由于各进程独立更新参数,且更新过程中模型参数保持一致,DDP 在效率和速度上大大优于 DP。

数据并行(DP)分布式数据并行(DDP)
实现复杂度使用 nn.DataParallel,实现简单,代码改动较少。需要设置分布式环境,使用 torch.distributed,代码实现相对复杂,需要手动管理进程和同步。
通信开销通信开销较大,参数和梯度需要在主 GPU 和其他 GPU 之间频繁传递。通信开销较小,只在反向传播时同步梯度,各 GPU 之间直接通信,无需通过主 GPU。
扩展性扩展性有限,适用于单机多卡,不支持多机训练。扩展性强,支持单机多卡和多机多卡,适合大规模分布式训练。
性能主 GPU 负载重,可能成为瓶颈,GPU 资源利用率较低。各 GPU 负载均衡,资源利用率高,训练速度更快。
适用场景适合小规模模型和数据集的单机多卡训练。适合大规模模型和数据集的单机或多机多卡训练。
梯度同步方式梯度在主 GPU 上汇总和更新,需要从其他 GPU 收集梯度。梯度在各 GPU 间直接同步,通常使用 All-Reduce 操作,效率更高。
模型参数广播每次前向传播都需要将模型参数从主 GPU 复制到其他 GPU。初始化时各进程各自持有一份模型副本,参数更新后自动同步,无需频繁复制。
对 Python GIL 的影响受限于 Python 全局解释器锁(GIL),因为是单进程多线程,无法充分利用多核 CPU。采用多进程方式,不受 GIL 影响,能够充分利用多核 CPU 和多 GPU 进行并行计算。
容错性主 GPU 故障会导致整个训练中断,容错性较差。各进程相对独立,某个进程出错不会影响其他进程,容错性较好。
调试难度由于是单进程,调试相对容易。多进程调试较为复杂,需要注意进程间的通信和同步问题。
代码修改量只需在模型外层加上 nn.DataParallel 封装,代码改动少。需要在代码中添加进程初始化、模型封装、设备设置等步骤,修改量较大。
数据加载方式使用常规的数据加载方式,无需特殊处理。需要使用 DistributedSampler 等工具,确保各进程加载不同的数据子集,避免数据重复。
资源占用主 GPU 内存和计算资源占用较高,其他 GPU 资源可能未被充分利用。各 GPU 资源均衡占用,能够最大化利用多 GPU 的计算能力。
训练结果一致性由于参数更新在主 GPU 上进行,可能存在精度损失或不一致的情况。各进程的模型参数同步更新,训练结果一致性更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187433.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学会使用maven工具看这一篇文章就够了

文章目录 概述一、定义与功能二、核心组件三、主要作用四、仓库管理 settings.xml说明一、文件位置与优先级二、主要配置元素三、配置示例 pom.xml文件说明一、pom.xml的基本结构二、pom.xml的主要元素及其说明三、依赖管理四、常用插件五、其他配置 maven安装配置一、下载Mave…

12.数据结构和算法-栈和队列的定义和特点

栈和队列的定义和特点 栈的应用 队列的常见应用 栈的定义和特点 栈的相关概念 栈的示意图 栈与一般线性表有什么不同 队列的定义和特点 队列的相关概念

创建一个Java Web API项目

创建一个Java Web API涉及多个步骤和技术栈,包括项目设置、依赖管理、数据访问层实现、业务逻辑实现、控制层开发以及测试和部署。在这篇详解中,我将带领你通过一个完整的Java Web API实现流程,采用Spring Boot和MyBatis-Plus作为主要技术工具…

redis高级篇 抢红包案例的设计以及分布式锁

一 抢红包案例 1.1 抢红包 二倍均值算法: M为剩余金额;N为剩余人数,公式如下: 每次抢到金额随机区间(0,(M/N)*2) 这个公式,保证了每次获取的金额平均值…

文心一言 VS 讯飞星火 VS chatgpt (360)-- 算法导论24.3 2题

二、请举出一个包含负权重的有向图,使得 Dijkstra 算法在其上运行时将产生不正确的结果。为什么在有负权重的情况下,定理 24.6 的证明不能成立呢?定理 24.6 的内容是:Dijkstra算法运行在带权重的有向图 G ( V , E ) G(V,E) G(V,E…

高炉计算笔记

一、总体概述 热风炉是一种重要的工业热能设备,通过燃烧燃料将水加热为蒸汽,用于驱动各种设备。在热风炉的运行过程中,烟气量是一个重要的参数,表示热风炉内燃料的利用率及运行效率。烟气量的计算公式如下: Q α Q…

Stream流的终结方法(二)——collect

1.Stream流的终结方法 2. collect方法 collect方法用于收集流中的数据放到集合中去,可以将流中的数据放到List,Set,Map集合中 2.1 将流中的数据收集到List集合中 package com.njau.d10_my_stream;import java.util.*; import java.util.f…

Leetcode—560. 和为 K 的子数组【中等】(unordered_map)

2024每日刷题&#xff08;166&#xff09; Leetcode—560. 和为 K 的子数组 C实现代码 class Solution { public:int subarraySum(vector<int>& nums, int k) {unordered_map<int, int> mp{{0, 1}};int ans 0;int prefix 0;for(int i 0; i < nums.size…

深度学习----------------------------编码器、解码器架构

目录 重新考察CNN重新考察RNN编码器-解码器架构总结编码器解码器架构编码器解码器合并编码器和解码器 重新考察CNN 编码器&#xff1a;将输入编码成中间表达形式&#xff08;特征&#xff09; 解码器&#xff1a;将中间表示解码成输出。 重新考察RNN 编码器&#xff1a;将文…

(11)MATLAB莱斯(Rician)衰落信道仿真2

文章目录 前言一、莱斯衰落信道仿真模型二、仿真代码与结果1.仿真代码2.仿真结果画图 三、后续&#xff1a;四、参考文献&#xff1a; 前言 首先给出莱斯衰落信道仿真模型&#xff0c;该模型由直射路径分量和反射路径分量组成&#xff0c;其中反射路径分量由瑞利衰落信道模型构…

水下垃圾识别数据集支持yolov5、yolov6、yolov7、yolov8、yolov9、yolov10总共3131张数据训练集1886张带标注的txt文件

水下垃圾识别数据集 支持yolov5、yolov6、yolov7、yolov8、yolov9、yolov10 总共3131张数据 训练集1886张 带标注的txt文件 水下垃圾识别数据集介绍 数据集名称 水下垃圾识别数据集 (Underwater Trash Detection Dataset) 数据集概述 该数据集专为训练和评估基于YOLO系列目…

【一文理解】conda install pip install 区别

大部分情况下&#xff0c;conda install & pip install 二者安装的package都可以正常work&#xff0c;但是混装多种package后容易版本冲突&#xff0c;出现各种报错。 目录 检查机制 支持语言 库的位置 环境隔离 编译情况 检查机制 conda有严格的检查机制&#xff0c…

python-线程与进程

进程 程序编写完没有运行称之为程序。正在运行的代码&#xff08;程序&#xff09;就是进程。在Python3语言中&#xff0c;对多进程支持的是multiprocessing模块和subprocess模块。multiprocessing模块为在子进程中运行任务、通讯和共享数据&#xff0c;以及执行各种形式的同步…

【Java数据结构】 链表

【本节目标】 1. ArrayList 的缺陷 2. 链表 3. 链表相关 oj题目 一. ArrayList的缺陷 上节课已经熟悉了ArrayList 的使用&#xff0c;并且进行了简单模拟实现。通过源码知道&#xff0c; ArrayList 底层使用数组来存储元素&#xff1a; public class ArrayList<E>…

探索Spring Boot:实现“衣依”服装电商平台

1系统概述 1.1 研究背景 如今互联网高速发展&#xff0c;网络遍布全球&#xff0c;通过互联网发布的消息能快而方便的传播到世界每个角落&#xff0c;并且互联网上能传播的信息也很广&#xff0c;比如文字、图片、声音、视频等。从而&#xff0c;这种种好处使得互联网成了信息传…

深入理解 CSS 浮动(Float):详尽指南

“批判他人总是想的太简单 剖析自己总是想的太困难” 文章目录 前言文章有误敬请斧正 不胜感恩&#xff01;目录1. 什么是 CSS 浮动&#xff1f;2. CSS 浮动的历史背景3. 基本用法float 属性值浮动元素的行为 4. 浮动对文档流的影响5. 清除浮动clear 属性清除浮动的技巧1. 使用…

从零开始讲PCIe(1)——PCI概述

一、前言 在之前的内容中&#xff0c;我们已经知道了PCIe是一种外设总线协议&#xff0c;其前身是PCI和PCI-X&#xff0c;虽然PCIe在硬件上有了很大的进步&#xff0c;但其使用的软件与PCI系统几乎保持不变。这种向后兼容性设计&#xff0c;目的是使从旧设计到新设计的迁移更加…

【QGis】生成规则网格/渔网(Fishnet)

【QGis】生成规则网格/渔网&#xff08;Fishnet&#xff09; QGis操作案例参考 QGIS下载安装及GIS4WRF插件导入可参见另一博客-【QGIS】软件下载安装及GIS4WRF插件使用。 QGis操作案例 1、加载中国省级边界&#xff0c;QGis界面如下&#xff1a; 查看坐标系&#xff1a; 如…

详解JVM类加载机制

❝ 前几篇文章我们分别详细描述了 JVM整体的内存结构 JVM对象内存是如何布局的以及内存分配的详细过程 但是对JVM内存结构各个模块没有深入的分析&#xff0c;为了熟悉JVM底层结构&#xff0c;接下来将把JVM运行时数据区的各个模块逐一分析&#xff0c;体系化的理解JVM的各个模…

【S32K3 RTD LLD篇5】K344 ADC SW+HW trigger

【S32K3 RTD LLD篇5】K344 ADC SWHW trigger 一&#xff0c;文档简介二&#xff0c;ADC SW HW 触发2.1 软硬件平台2.2 SWADC 软件触发2.3 SWBCTUADC 软件BCTU触发2.4 PITTRIGMUXADC 硬件PIT TRIGUMX触发2.5 EMIOSBCTUHWADC硬件EMIOS BCTU触发2.6 EMIOSBCTUHW LISTADC硬件EMIOS …