分布式训练:(Pytorch)

news2024/11/13 10:11:32

分布式训练是将机器学习模型的训练过程分散到多个计算节点或设备上,以提高训练速度和效率,尤其是在处理大规模数据和模型时。分布式训练主要分为数据并行模型并行两种主要策略:

1. 数据并行 (Data Parallelism)

数据并行是最常见的分布式训练方式。在这种方法中,模型副本会被复制到多个计算设备上,每个设备处理不同的批次(batch)数据。

工作流程:
  • 每个设备上都有一个完整的模型副本。
  • 数据集被分割成多个部分(mini-batches),每个设备处理其中一部分。
  • 每个设备独立计算模型的前向传播和反向传播,计算出梯度。
  • 通过某种方式(如梯度聚合),将所有设备的梯度平均化,并更新全局模型参数。
  • 同步方式可分为同步训练和异步训练:
    • 同步训练:所有设备都在同一个时刻更新模型参数。
    • 异步训练:各设备独立更新参数,可能导致一些参数不一致。
# Replicate module to devices in device_ids
replicas = nn.parallel.replicate(module, device_ids)
# Distribute input to devices in device_ids
inputs = nn.parallel.scatter(input, device_ids)
# Apply the models to corresponding inputs
outputs = nn.parallel.parallel_apply(replicas, inputs)
# Gather result from all devices to output_device
result = nn.parallel.gather(outputs, output_device)
优点:
  • 易于实现,特别是在GPU集群或云端平台中。
  • 可以在大规模数据集上显著加快训练过程。
缺点:
  • 通信开销较大,特别是在梯度同步阶段,可能会成为训练速度的瓶颈。
  • 对大模型的扩展性有限,因为每个设备都需要存储完整的模型。

2. 模型并行 (Model Parallelism)

模型并行将一个大型模型拆分到多个设备上,以便更好地利用计算资源,尤其适用于内存消耗较大的模型。

工作流程:
  • 模型被拆分成多个部分,每个设备负责模型的一个子集。
  • 输入数据在各设备间传递,完成前向传播和反向传播。
  • 各设备独立计算梯度并更新自己负责的模型参数。
优点:
  • 适合超大规模模型,尤其是单个设备无法存储整个模型的情况。
  • 内存使用效率较高。
缺点:
  • 由于模型的不同部分在不同设备上进行计算,存在大量的通信开销,尤其是在前向传播和反向传播时需要设备间频繁交互。
  • 难以实现模型的负载均衡,部分设备可能成为性能瓶颈。

常用的分布式训练框架

  • TensorFlow:支持多设备、多机器的分布式训练,通过 tf.distribute.Strategy 轻松实现。
  • PyTorch:通过 torch.distributed 提供原生支持,还支持基于 Horovod 等第三方工具的分布式训练。
  • Horovod:Uber 开源的分布式深度学习库,支持 TensorFlow、Keras、PyTorch 等。

关键挑战

  • 同步和通信开销:在数据并行训练中,梯度的同步可能成为瓶颈。
  • 负载均衡:在模型并行训练中,确保各设备之间的负载均衡非常重要,以避免性能瓶颈。
  • 容错性:分布式训练中节点故障可能导致训练过程中断,需要具备一定的容错机制。

常用的 API 有两个:

  • torch.nn.DataParallel(DP)
  • torch.nn.DistributedDataParallel(DDP)

torch.nn.DataParallel(简称 DP)是 PyTorch 提供的一个简单的并行化工具,主要用于在多个 GPU 上进行数据并行训练。DataParallel 通过将输入数据批次(batch)切分成多个小批次,并将其分发到多个 GPU 上,进行并行处理。它会自动处理梯度的同步和模型参数的更新。

torch.nn.DataParallel 的工作机制

  1. 模型复制DataParallel 会将模型复制到多个 GPU 上,每个 GPU 上有一个模型副本。
  2. 数据分割:输入数据会被划分成多个小批次(mini-batches),并分别分发给各个 GPU。
  3. 并行执行:每个 GPU 独立进行前向传播和反向传播,计算梯度。
  4. 梯度汇总:主设备(默认是 cuda:0)会收集所有 GPU 计算出的梯度,并将它们平均化,更新模型的全局参数。

使用 torch.nn.DataParallel

使用 DataParallel 非常简单,通常只需要将模型用 DataParallel 包裹,然后像普通模型一样使用即可。

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# 初始化模型和数据
model = SimpleModel()

# 将模型并行化
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs")
    model = nn.DataParallel(model)

model = model.cuda()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 模拟输入数据
inputs = torch.randn(32, 10).cuda()  # 一个 32 样本的 batch,每个样本 10 个特征
targets = torch.randn(32, 5).cuda()  # 对应的目标输出

# 前向传播
outputs = model(inputs)

# 计算损失
loss = criterion(outputs, targets)

# 反向传播
optimizer.zero_grad()
loss.backward()

# 更新模型参数
optimizer.step()

DistributedDataParallel (简称 DDP) 是 PyTorch 用于分布式训练的高级并行化工具,它的效率和灵活性比 DataParallel 更高,特别适合在多个 GPU 甚至跨多个节点(机器)上进行分布式训练。与 DataParallel 不同,DDP 在每个设备(GPU)上独立处理模型的前向传播和反向传播,并且避免了主设备的瓶颈问题。

DistributedDataParallel 的工作原理

  1. 模型的分发:与 DataParallel 类似,DDP 会在每个 GPU 上保留一份模型副本。但与 DataParallel 不同的是,DDP 不需要将数据集中在主设备上,而是让每个 GPU 独立完成自己的工作。
  2. 前向和反向传播:每个 GPU 上的模型执行前向传播和反向传播,并计算梯度。
  3. 梯度同步:每个设备上计算的梯度通过 all-reduce 操作在所有设备之间同步,确保所有模型副本的梯度相同。这个过程是并行进行的,不会像 DataParallel 那样集中在主设备上,因此通信效率更高。
  4. 参数更新:每个设备独立地应用梯度更新全局模型参数。

DistributedDataParallel 的优点

  • 高效的通信和同步:梯度的同步是在所有设备之间并行进行的,避免了主设备成为通信瓶颈的问题,因此在多 GPU 或跨节点时表现更加优异。
  • 可扩展性强DDP 支持跨多台机器的训练,适合超大规模模型或需要跨节点的分布式训练。
  • 无锁设计DDP 实现了无锁的梯度同步,不会因锁机制造成性能损失。

DistributedDataParallel 的使用

DataParallel 类似,DDP 也需要对模型进行包装,但它需要更多的设置,特别是在多机环境下,还需要配置通信后端。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# 初始化分布式环境
def setup(rank, world_size):
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

# 销毁分布式环境
def cleanup():
    dist.destroy_process_group()

# 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# 初始化模型、优化器和数据
def main(rank, world_size):
    setup(rank, world_size)

    model = SimpleModel().cuda(rank)
    ddp_model = DDP(model, device_ids=[rank])

    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

    # 模拟输入数据
    inputs = torch.randn(32, 10).cuda(rank)
    targets = torch.randn(32, 5).cuda(rank)

    # 前向传播
    outputs = ddp_model(inputs)
    loss = criterion(outputs, targets)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()

    # 更新模型参数
    optimizer.step()

    cleanup()

# 假设有两个GPU,可以这样启动分布式训练
if __name__ == "__main__":
    world_size = 2  # GPU数
    torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size, join=True)
特性DataParallel (DP) DistributedDataParallel (DDP)
通信模式主设备负责梯度同步所有设备间并行同步梯度
性能通信开销大,主设备瓶颈通信开销小,性能更高
可扩展性适用于单机多 GPU适用于单机或多机多 GPU
使用场景小规模并行大规模或跨节点分布式训练

2. 并行数据加载

在深度学习任务中,数据加载通常是训练过程中的一个瓶颈,特别是当数据量很大时。使用多个进程来并行加载数据,并将数据从可分页内存(虚拟内存)转移到固定内存(GPU 内存)可以显著提高训练效率。

工作流程

  1. 数据加载

    • 使用多个进程并行从磁盘读取数据。每个进程负责加载不同的数据批次,减少了磁盘 I/O 操作的等待时间。
  2. 生产者-消费者模式

    • 数据加载进程(生产者)将读取的数据批次放入队列中,而主线程(消费者)从队列中取出数据批次进行训练。这样可以在数据加载和模型训练过程中实现并行化,减少数据加载对训练速度的影响。
  3. 固定内存的使用

    • 将数据从主机的可分页内存转移到固定内存。数据被加载到固定内存中后,转移到 GPU 的速度会更快,因为固定内存中的数据可以快速传输。

参数解释

  1. num_workers

    • 这个参数指定了数据加载的进程数量。将 num_workers 设置为大于 0 的值可以让 DataLoader 使用多个子进程来并行加载数据。
    • 例如,num_workers=4 表示使用 4 个进程来加载数据。这可以显著提高数据加载速度,因为多个进程可以同时从磁盘读取不同的数据批次。
  2. pin_memory

    • 这个参数用于将数据从主机内存(CPU 内存)固定到页面锁定内存(pinned memory)。固定内存可以让数据传输到 GPU 更加高效。
    • pin_memory=True 时,DataLoader 会将数据从可分页的内存(虚拟内存)传输到固定内存中,这样在将数据转移到 GPU 时,数据传输速度会更快,因为固定内存可以避免页面交换的开销。

总结

  • 数据加载:使用多个进程来并行加载和预处理数据,通过流水线处理减少数据加载的延迟。
  • 数据传输:利用 CUDA 流优化从固定内存到 GPU 的数据传输。
  • 数据并行性:使用数据并行和 NCCL 等通信库实现高效的梯度同步和模型参数更新,优化训练过程。

这种方法结合了数据加载、数据传输和数据并行处理的优化,能够显著提升深度学习模型的训练效率和速度。

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, size):
        self.data = np.random.rand(size, 3, 224, 224).astype(np.float32)
        self.labels = np.random.randint(0, 2, size).astype(np.int64)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx]), torch.tensor(self.labels[idx])

dataset = CustomDataset(size=10000)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,      # 使用 4 个子进程加载数据
    pin_memory=True     # 将数据转移到固定内存
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for inputs, labels in dataloader:
    inputs, labels = inputs.to(device), labels.to(device)
    # 模型训练代码
    # ...

 参考文章:

Pytorch 分布式训练(DP/DDP)_pytorch分布式训练-CSDN博客icon-default.png?t=O83Ahttps://blog.csdn.net/ytusdc/article/details/122091284?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522CC589E02-BBE1-4F15-BDC0-CA76EBF6C160%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=CC589E02-BBE1-4F15-BDC0-CA76EBF6C160&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~top_positive~default-1-122091284-null-null.142^v100^control&utm_term=%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83&spm=1018.2226.3001.4187

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2140604.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构之树的常用术语

二叉树的常用术语 前言 由于数组在插入、删除上的缺点和链表在查询上的缺点,出现了树的数据结构,可以在增删改查中弥补数组和链表的缺陷。 常用数据 节点:每个节点根节点:最上层的节点,Root节点父节点:相…

基于SSM的宿舍管理系统的设计与实现 (含源码+sql+视频导入教程)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM的宿舍管理系统拥有两种角色,分别为管理员和宿管,具体功能如下: 管理员:学生管理、班级管理、宿舍管理、卫生管理、访客管理、用户…

SOT23封装1A电流LDO具有使能功能的 1A、低 IQ、高精度、低压降稳压器系列TLV757P

前言 SOT23-5封装的外形和丝印 该LDO适合PCB空间较小的场合使用,多数SOT23封装的 LDO输出电流不超过0.5A。建议使用时输入串联二极管1N4001,PCB布局需要考虑散热,参考文末PCB布局。 1 特性 • 采用 SOT-23 (DYD) 封装,具有 60.3C/W RθJA •…

双指针算法专题(2)

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏: 优选算法专题 想要了解双指针算法的介绍,可以去看下面的博客:双指针算法的介绍 目录 611.有效三角形的个数 LCR 1…

【天池比赛】【零基础入门金融风控 Task2赛题理解】实战进行中……20240915更新至2.3.4.3 查看训练集测试集中特征属性只有一值的特征

2.3 代码示例 2.3.1 导入数据分析及可视化过程需要的库 import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns import datetime import warnings warnings.filterwarnings(ignore) 2.3.2 读取文件 #读取数据时相对路径载入报错时…

【Redis】之Geo

概述 Geo就是Geolocation的简写形式,代表地理坐标。在Redis中,构造了能够存储地址坐标信息的一种数据结构,帮助我们根据经纬度来检索数据。 命令行操作方法 GEOADD 可以用来添加一个或者多个地理坐标。 GEODIST 返回一个key中两个成员之…

AgentTuning:提升大型语言模型的通用Agent能力

人工智能咨询培训老师叶梓 转载标明出处 大模型被用作现实中复杂任务的Agent时,它们的表现往往不如商业模型,如ChatGPT和GPT-4。这些任务要求LLMs作为中央控制器,负责规划、记忆和工具利用,这就需要精巧的提示方法和鲁棒性强的LL…

华为的仓颉和ArkTS这两门语言有什么区别

先贴下官网: ArkTs官网 仓颉官网 ArkTS的官网介绍说,ArkTS是TypeScript的进一步强化版本,简单来说就是包含了TS的风格,但是做了一些改进。 了解TypeScript的朋友都应该知道,其实TypeScript就是JavaScript的改进版本&…

基于springboot 自习室预订系统 前后端分离

基于springboot 自习室预订系统 前后端分离 目 录 摘 要 I Abstract II 第1章 前 言 2 1.1 研究背景 3 1.2 研究现状 3 1.3 系统开发目标 3 第2章 系统开发环境 5 2.1 java技术 5 2.2 Mysql数据库 6 2.3 B/S结构 7 2.4 springboot框架 7 2.5 ECLIPSE 开发环境 7 …

Redis的配置与优化

目录 一、关系数据库与非关系型数据库 1.1、关系型数据库 1.2、非关系型数据库 1.3、关系型数据库和非关系型数据库区别 数据存储方式不同 扩展方式不同 对事务性的支持不同 1.4、非关系型数据库产生背景 二、Redis简介 2.1、Redis优点 2.2、Redis为什么这么快 三、…

CefSharp_Vue交互(Element UI)_WinFormWeb应用---设置应用透明度(含示例代码)

一、界面预览 1.1 设置透明(整个页面透明80%示例) 限制输入值:10-100(数字太小会不好看见) 1.2 vue标题栏 //注册类与js调用 (async function(

速通汇编(五)认识段地址与偏移地址,CS、IP寄存器和jmp指令,DS寄存器

一,地址的概念 通常所说的地址指的是某内存单元在整个机器内存中的物理地址,把整个机器内存比作一个酒店,内存单元就是这个酒店的各个房间,给这些房间编的门牌号,类比回来就是内存单元的物理地址 在第一篇介绍debug的…

Scratch植物大战僵尸【机器人vs外星人版本】

小虎鲸Scratch资源站-免费少儿编程Scratch作品源码,素材,教程分享网站! 简介 在这个教学案例中,我们将制作一个类似《植物大战僵尸》的Scratch游戏,主题为“机器人对抗外星人”。这个版本将采用创新的角色设计,机器人将保护地球免受外星人入…

SQL题目分析:打折日期交叉问题--计算品牌总优惠天数

在电商平台的数据分析中,处理品牌促销活动的日期交叉问题是一个挑战。本文将介绍几种高级SQL技巧,用于准确计算每个品牌的总优惠天数,即使在存在日期交叉的情况下。 问题背景 我们有一个促销活动表 shop_discount,记录了不同品牌…

算法:76.最小覆盖子串

题目 链接:leetcode链接 思路分析(滑动窗口) 还是老样子,连续问题,滑动窗口哈希表 令t用的hash表为hash1,s用的hash表为hash2 利用hash表统计窗口内的个字符出现的个数,与hash1进行比较 选…

SpringBoot 消息队列RabbitMQ在代码中声明 交换机 与 队列使用注解创建

创建Fanout交换机 Configuration public class FanoutConfig {Beanpublic FanoutExchange fanoutExchange(){return new FanoutExchange("csdn.fanout");//交换机名称} }创建队列 Beanpublic Queue fanoutQueue3(){return new Queue("csdn.queue");}绑定…

Nature Climate Change | 全球土壤微生物群落调控微生物呼吸对变暖的敏感性(Q10)

本文首发于“生态学者”微信公众号! 全球变暖将加速有机物分解,从而增加土壤中二氧化碳的释放,触发正的碳-气候反馈。这种反馈的大小在很大程度上取决于有机质分解的温度敏感性(Q10)。Q10仍然是围绕土壤碳排放到大气的预测的主要不确定性来源…

软考架构-层次架构风格

一、两层C/S架构 客户端和服务器都有处理功能。处理在表示层(客户端)和数据层(服务器)进行 二、三层C/S架构 将处理功能独立出来。表示层在客户机上,功能层在应用服务器上,数据层在数据库服务器上。 三…

玄机科技浪漫绘情缘:海神缘下,一吻定情

在史莱克学院那片璀璨星空的见证下,《斗罗大陆II绝世唐门》第65集“海神缘相亲大会”的浪漫序幕,温柔地触动了每一位观众的心弦。 本集中,霍雨浩与王冬之间那段跨越重重障碍、终得相守的浪漫告白,在玄机科技独特的审美视角、精细…

强化学习Reinforcement Learning|Q-Learning|SARSA|DQN以及改进算法

一、强化学习RL 强化学习是机器学习的一个重要的分支,是一种有效的工具,在文献中被广泛用于解决MDP问题。在一个强化学习过程中,一个智能体只能通过和它所处的环境互动学习最优策略。特别地,智能体首先观察自己当前的状态&#xf…