PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

news2025/4/3 4:02:12

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

在这里插入图片描述

二、完整 DDP 训练 Demo

  • 官方DDP Dem参考

1. 基础训练脚本 (ddp_demo.py)

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """清理分布式环境"""
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    """简单的CNN模型"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(9216, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x)

def prepare_dataloader(rank, world_size, batch_size=32):
    """准备分布式数据加载器"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader

def train(rank, world_size, epochs=2):
    """训练函数"""
    setup(rank, world_size)
    
    # 设置当前设备
    torch.cuda.set_device(rank)
    
    # 初始化模型、优化器等
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(ddp_model.parameters())
    scaler = GradScaler()  # 混合精度训练
    criterion = nn.CrossEntropyLoss()
    train_loader = prepare_dataloader(rank, world_size)
    
    for epoch in range(epochs):
        ddp_model.train()
        train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shuffle
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            
            optimizer.zero_grad()
            
            # 混合精度训练
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = ddp_model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")
    
    cleanup()

if __name__ == "__main__":
    # 单机多卡启动时,torchrun会自动设置这些环境变量
    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    # 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    if rank == 0:  # 只在主进程保存
        torch.save(model.state_dict(), "model.pth")
    
  3. 多机训练

    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
    # 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    for i, (data, target) in enumerate(train_loader):
        if i % 2 == 0:
            optimizer.zero_grad()
        # 前向和反向...
        if i % 2 == 1:
            optimizer.step()
    
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2326942.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Unity】记录TMPro使用过程踩的一些坑

1、打包到webgl无法输入中文,编辑器模式可以,但是webgl不行,试过网上的脚本,还是不行 解决方法:暂时没找到 2、针对字体asset是中文时,overflow的效果模式处理奇怪,它会出现除了overflow模式以…

计算机视觉初步(环境搭建)

1.anaconda 建议安装在D盘,官网正常安装即可,一般可以安装windows版本 安装成功后,可以在电脑应用里找到: 2.创建虚拟环境 打开anaconda prompt, 可以用conda env list 查看现有的环境,一般打开默认bas…

基于聚类与引力斥力优化的选址算法

在众多实际场景中,诸如消防设施选址、基站布局规划以及充电桩站点部署等,都面临着如何利用最少的资源,实现对所有目标对象全面覆盖的难题。为有效解决这类问题,本文提出一种全新的组合算法模型 —— 基于聚类与引力斥力优化的选址…

Mac 电脑移动硬盘无法识别的解决方法

在使用 Mac 电脑的过程中,不少用户都遇到过移动硬盘没有正常推出,导致无法识别的问题。这不仅影响了数据的传输,还可能让人担心硬盘内数据的安全。今天,我们就来详细探讨一下针对这一问题的解决方法。 当发现移动硬盘无法识别时&…

LeetCode Hot100 刷题笔记(4)—— 二叉树、图论

目录 一、二叉树 1. 二叉树的深度遍历(DFS:前序、中序、后序遍历) 2. 二叉树的最大深度 3. 翻转二叉树 4. 对称二叉树 5. 二叉树的直径 6. 二叉树的层序遍历 7. 将有序数组转换为二叉搜索树 8. 验证二叉搜索树 9. 二叉搜索树中第 K 小的元素 …

【计算机视觉】YOLO语义分割

一、语义分割简介 1. 定义 语义分割(Semantic Segmentation)是计算机视觉中的一项任务,其目标是对图像中的每一个像素赋予一个类别标签。与目标检测只给出目标的边界框不同,语义分割能够在像素级别上区分不同类别,从…

【SpringBoot + MyBatis + MySQL + Thymeleaf 的使用】

目录: 一:创建项目二:修改目录三:添加配置四:创建数据表五:创建实体类六:创建数据接口七:编写xml文件八:单元测试九:编写服务层十:编写控制层十一…

在ensp进行OSPF+RIP+静态网络架构配置

一、实验目的 1.Ospf与RIP的双向引入路由消息 2.Ospf引入静态路由信息 二、实验要求 需求: 路由器可以互相ping通 实验设备: 路由器router7台 使用ensp搭建实验坏境,结构如图所示 三、实验内容 1.配置R1、R2、R3路由器使用Ospf动态路由…

Redis安全与配置问题——AOF文件损坏问题及解决方案

Java 中的 Redis AOF 文件损坏问题全面解析 一、AOF 文件损坏的本质与危害 1.1 AOF 持久化原理 Redis 的 AOF(Append-Only File) 通过记录所有写操作命令实现持久化。文件格式如下: *2\r\n$6\r\nSELECT\r\n$1\r\n0\r\n *3\r\n$3\r\nSET\r\…

3.第二阶段x64游戏实战-分析人物移动实现人物加速

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 上一个内容:2.第二阶段x64游戏实战-x64dbg的使用 想找人物的速度,就需要使用Ch…

Scala(三)

本节课学习了函数式编程,了解到它与Java、C函数式编程的区别;学习了函数的基础,了解到它的基本语法、函数和方法的定义、函数高级。。。学习到函数至简原则,高阶函数,匿名函数等。 函数的定义 函数基本语法 例子&…

什么是 Java 泛型

一、什么是 Java 泛型? 泛型(Generics) 是 Java 中一种强大的编程机制,允许在定义类、接口和方法时使用类型参数。通过泛型,可以将数据类型作为参数传递,从而实现代码的通用性和类型安全。 简单来说&…

Unity中根据文字数量自适应长宽的对话气泡框UI 会自动换行

使用Ugui制作一个可以根据文本数量自动调整宽度,并可以自动换行的文字UI 或者不要独立的Bg,那么一定要把bg的img设置成切片

【小也的Java之旅系列】02 分布式集群详解

文章目录 前言为什么叫小也 本系列适合什么样的人阅读正文单体优点缺点 CAP为什么CAP不可能全部满足?CAP 三选二 分布式事务分布式方案——SeataXA模式(强一致)AT模式(自动补偿,默认模式)TCC模式&#xff0…

Ubuntu里安装Jenkins

【方式1】:下载war包,直接运行,需提前搭建Java环境,要求11或17,不推荐,war包下载地址,将war包上传到服务器,直接使用命令启动 java -jar /data/jenkins/jenkins.war【方式2】&#…

C++包管理工具vcpkg的安装使用教程

前言 使用vcpkg可以更方便地安装各种库,省去配置的时间和配置失败的风险,类似python中的anaconda,懒人必备 参考 安装参考:https://bqcode.blog.csdn.net/article/details/135831901?fromshareblogdetail&sharetypeblogde…

微服务面试题:配置中心

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编…

Qt msvc2017程序无法用enigma vitrual box打包,用winrar打包

我们通常打包Qt程序用Enigma virtual box。这样我们的程序就可以在别的电脑上也能运行,但是有时候,我们发现Enigma virtual box在打包的时候,对于msvc2017需要编译的程序中引用webengineview模块,打包时候发现不能运行。 我们如何…

微服务集成测试 -华为OD机试真题(A卷、JavaScript)

题目描述 现在有n个容器服务,服务的启动可能有一定的依赖性(有些服务启动没有依赖),其次,服务自身启动加载会消耗一些时间。 给你一个n n 的二维矩阵useTime,其中useTime[i][i]10表示服务i自身启动加载需…

Mac: 运行python读取CSV出现 permissionError

在MAC机器里,之前一直运行程序在某个指定的目录下读取excel和csv文件,没有出现错误,有一天突然出现错误:permissionError:[Errno 1] Operation not permitted, 具体错误信息如下: 经过调查得知&#xff0c…