深度学习:Pytorch分布式训练

news2024/11/17 3:37:57

深度学习:Pytorch分布式训练

  • 简介
  • 模型并行
  • 数据并行
  • 参考文献

简介

在深度学习领域,模型越来越庞大、数据量不断增加,训练这些大型模型越来越耗时。通过在多个GPU或多个节点上并行地训练模型,我们可以显著减少训练时间。此外,某些模型因为巨大的参数量,单个设备可能无法容纳其整个模型和数据。在这种情况下,分布式训练不仅能提高训练速度,更是必要的手段来训练大模型。为此,PyTorch 分布式训练提供了两种基本的并行方法:

  • 模型并行(Model Parallel):模型并行是指将模型的不同部分放到不同的设备上。这种方式通常用于当一个单独的模型太大而无法放到单个GPU上时。

  • 数据并行(Data Parallel):数据并行是将训练数据分割并在多个设备上同时训练的方法。PyTorch提供了 torch.nn.DataParallel torch.nn.parallel.DistributedDataParallel 用于在多个GPU上并行化模型训练。

模型并行

在这里插入图片描述

模型并行主要利用to(device)函数将模型和数据(Tensor张量)放置在适当设备上,其余代码基本无需额外改动。
以下是一个简单的模型并行的代码示例:

import torch
import torch.nn as nn
import torch.optim as optim


class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.net1 = torch.nn.Linear(10, 10).to('cuda:0')
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to('cuda:1')

    def forward(self, x):
        x = self.relu(self.net1(x.to('cuda:0')))
        return self.net2(x.to('cuda:1'))

model = DemoModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))
labels = torch.randn(20, 5).to('cuda:1')
loss_fn(outputs, labels).backward()
optimizer.step()

注意调用损失函数时,您只需要确保标签与输出位于同一设备上。不难看出,此模型并行的方法效率相对较低,因为在任何时间点,两个 GPU 中只有一个在工作,而另一个则处于闲置状态。而且中间过程变量从cuda:0复制到cuda:1,又会需要额外的开销。因此可以引入流水线并行来进行加速。

在以下代码示例中,采取将输入数据批次划分为 20 组。由于 PyTorch 异步启动 CUDA 操作,因此可以不需要生成多个线程来实现并发。值得注意的是,使用较小的结果split_size会导致许多微小的 CUDA 内核启动,而使用较大的split_size会导致在第一次和最后一次数据划分期间存在相对较长的空闲时间。因此split_size对于特定实验可能有一个最佳配置,可以多次尝试最佳的超参数。

class PipelineParallelResNet50(ModelParallelResNet50):
    def __init__(self, split_size=20, *args, **kwargs):
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def forward(self, x):
        splits = iter(x.split(self.split_size, dim=0))
        s_next = next(splits)
        s_prev = self.seq1(s_next).to('cuda:1')
        ret = []

        for s_next in splits:
            # A. ``s_prev`` runs on ``cuda:1``
            s_prev = self.seq2(s_prev)
            ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

            # B. ``s_next`` runs on ``cuda:0``, which can run concurrently with A
            s_prev = self.seq1(s_next).to('cuda:1')

        s_prev = self.seq2(s_prev)
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

        return torch.cat(ret)

数据并行

在这里插入图片描述

DataParallel是单进程、多线程,仅适用于单机,而是DistributedDataParallel多进程,适用于单机和多机训练。由于跨线程的 GIL 争用、每次迭代复制模型以及分散输入和收集输出带来的额外开销,DataParallel通常比DistributedDataParallel在单台机器上更慢。

一般地,数据并行的流程为:

  1. 在使用 distributed 包的任何其他函数之前,需要使用 init_process_group 初始化进程组,同时初始化 distributed 包。
  2. 如果需要进行组内集体通信,用 new_group 创建子分组
  3. 创建分布式并行模型 DDP(model, device_ids=device_ids)
  4. 为数据集创建 Sampler
  5. 使用启动工具 torch.distributed.launch 在每个主机上执行一次脚本,开始训练
  6. 使用 destory_process_group() 销毁进程组

以下是一个简单的数据并行的代码示例:

# demo_ddp.py
# 在init_process_group()时,一般可设置为Gloo、NCCL或mpi后端,Gloo目前在GPU上运行速度比 NCCL慢。所以经验法则是:
# 分布式GPU训练使用 NCCL 后端
# 分布式CPU训练使用 Gloo 后端

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")

    # create model and move it to GPU with id rank
    device_id = rank % torch.cuda.device_count()
    model = DemoModel().to(device_id)
    ddp_model = DDP(model, device_ids=[device_id])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    dist.destroy_process_group()

if __name__ == "__main__":
    demo_basic()

然后使用torchrun命令进行启动,其中,nnodes表示总节点数,nproc_per_node表示每个节点运行的进程数,rdzv_id表示用户定义的ID,唯一标识作业的工作组, rdzv_backend表示集合点的后端,rdzv_endpoint表示rendezvous后端运行的地址

# 需要应用 slurm 等集群管理工具来实际在 2 个节点上运行此命令。
export MASTER_ADDR=$(scontrol show hostname ${SLURM_NODELIST} | head -n 1)
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 demo_ddp.py

此命令表示在两台服务器上运行 DDP 脚本,每台服务器运行 8 个进程,即在 16 个 GPU 上运行。

参考文献

  1. https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
  2. https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  3. https://medium.com/deelvin-machine-learning/model-parallelism-vs-data-parallelism-in-unet-speedup-1341bc74ff9e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1611049.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python长时间序列遥感数据处理及在全球变化、物候提取、植被变绿与固碳分析、生物量估算与趋势分析等领域中的应用

植被是陆地生态系统中最重要的组分之一,也是对气候变化最敏感的组分,其在全球变化过程中起着重要作用,能够指示自然环境中的大气、水、土壤等成分的变化,其年际和季节性变化可以作为地球气候变化的重要指标。此外,由于…

安装Hexo上传插件时,使用hexo d 报错“Spawn Failed”

问题如下: 解决方案: 找到deploy的配置文件 将配置文件增加 [user]email xxxx你git的邮箱地址name 你git的用户名 然后执行hexo clean清除一下后,执行hexo d,会弹出让你登录git的账号和密码,登录后就上传成功了。 …

Java 判断两个Date类型的时间是否大于6天

判断两个Date类型的时间是否大于6天 new Date().toInstant() 是获取当前时间,并转换成Instant对象 cardDeviceTrajectoryInfo.getGpstime().toInstant() 是表中最后一条数据的时间,并转换成Instant对象 // 计算两个时间的间隔 long daysBetween ChronoU…

【基础IO】谈谈动静态库(怒肝7000字)

文章目录 前言实验代码样例静态库生成一个静态库归档工具ar静态库的链接 动态库创建动态库加载动态库 动静态链接静态链接动态链接动静态链接的优缺点 前言 在软件开发中,库(Library)是一种方式,可以将代码打包成可重用的格式&…

隧道代理的优势与劣势分析

“随着互联网的快速发展,网络安全已经成为一个重要的议题。为了保护个人和组织的数据,隧道代理技术逐渐成为网络安全的重要工具。隧道代理通过在客户端和服务器之间建立安全通道,加密和保护数据的传输,有效地防止黑客入侵和信息泄…

docker安装并跑通QQ机器人实践(4)-bs-cqhttp搭建

go-cqhttp,基于 Mirai 以及 MiraiGo 的 OneBot Golang 原生实现,只需简单的配置, 就可以基于 go-cqhttp 使用框架开发,具有轻量, 原生, 高并发, 低占用, 跨平台等特点。 1 go-cqhttp 官网及可执行文件下载链接 go-cqhttp 官网:ht…

如何通过MSTSC连接Ubuntu的远程桌面?

正文共:666 字 12 图,预估阅读时间:1 分钟 前面我们介绍了如何通过VNC连接Ubuntu 18.04的远程桌面(Ubuntu 18.04开启远程桌面连接),非常简单。但是有小伙伴咨询如何使用微软的远程桌面连接MSTSC&#xff08…

二维码门楼牌管理应用平台建设:取保候审的智能化监管

文章目录 前言一、取保候审的传统监管困境二、二维码门楼牌管理应用平台的优势三、取保候审备案信息的智能化处理四、保障居民合法权益五、展望未来 前言 随着信息技术的飞速发展,二维码门楼牌管理应用平台已成为现代社区治理的重要工具。本文重点探讨如何借助该平…

保姆级教程!QRCNN-BiLSTM一键实现多变量回归区间预测!区间预测全家桶再更新!

​ 声明:文章是从本人公众号中复制而来,因此,想最新最快了解各类智能优化算法及其改进的朋友,可关注我的公众号:强盛机器学习,不定期会有很多免费代码分享~ 今天对我们之前推出的区间预测全家桶进行…

最邻近插值和线性插值

最邻近插值 在图像分割任务中:原图的缩放一般采用双线性插值,用于上采样或下采样;而标注图像的缩放有特定的规则,需使用最临近插值,不用于上采样或下采样。 自定义函数 这个是通过输入原始图像和一个缩放因子来对图像…

JVM类加载基本流程及双亲委派模型

1.JVM内存区域划分 一个运行起来的Java进程就是一个JVM虚拟机,这就需要从操作系统中申请一片内存区域。JVM申请到内存之后,会把这个内存划分为几个区域,每个区域都有各自的作用。 一般会把内存划分为四个区域:方法区(也称 "…

力扣287. 寻找重复数

Problem: 287. 寻找重复数 文章目录 题目描述思路解题方法复杂度Code 题目描述 思路 利用二分查找搜索1 ~ n中重复的元素,我们每次取出当前二分查找的区间的中间元素mid并在元始的数组nums中统计小于mid的元素的个数count: 若count > mid则说明重复的…

大数据学习的第三天

文章目录 学习大数据命令的方式查看文件拷贝文件的方式添加数据的方式 出现了问题移动文件 hadoop工作流程和工作机制的方式namenodedatanodesecondarynamenode(主节点) 学习大数据命令的方式 查看文件 hadoop fs -cat /test/2.txt下载文件 hadoop fs -get -f /test/2.txt-f …

通俗说字解词:什么是道理?常说讲道理,李秘书讲写作这节就给你讲“道理”!

通俗说字解词:什么是道理?常说讲道理,李秘书讲写作这节就给你讲“道理”! 说到“道理”,这可真是个有意思的词。它由“道”和“理”两个部分组成,就像一碗好吃的面,有汤有料,缺一不可…

xilinx cpri ip 开发记录

CPRI是无线通信里的一个标准协议,连接REC和RE的通信。 Xilinx有提供CPRI IP核。 区别于其它通信协议,如以太网等,CPRI是一个同步系统。 这就意味着两端的Master和Slave应当是同源时钟的,两边不存在频差,并且内部延时…

ikigai极简3p模型:想、能、有

ikigai模型简化为3p模型: - passion 想要、想做 - professional 能要、能做 - profit 有益、有利 根据三角形不可能定律,三者满足两个就很不容易了。又想做又能做的未必有钱,又能做又有钱的未必想做,又想做又有钱的未必能做。 要实…

(C语言)sscanf 与 sprintf详解

目录 1.sprintf函数详解 2. sscanf函数详解 1.sprintf函数详解 头文件&#xff1a;stdio.h 作用&#xff1a;将格式化的数据写入字符串里&#xff0c;也就是将格式化的数据转变为字符串。 演示&#xff1a; #include <stdio.h> struct S {char name[10];int height;…

LeetCode---128双周赛

题目列表 3110. 字符串的分数 3111. 覆盖所有点的最少矩形数目 3112. 访问消失节点的最少时间 3113. 边界元素是最大值的子数组数目 一、字符串的分数 按照题目要求&#xff0c;直接模拟遍历即可&#xff0c;代码如下 class Solution { public:int scoreOfString(string …

如何通过通过钉钉发送信息????????

1、通过钉钉群添加一个机器人 2、代码实现 /*** 发钉钉审核.** param*/private void sendDingDing(String tableName) {String url "https://oapi.dingtalk.com/robot/send?access_token229c627d05a3157f79a5ef1942d29c4dfb4515bf5c0ad65e3c69423bc016f97c";JSONOb…

达梦数据库的AWR报告

达梦数据库的AWR报告 数据库快照是一个只读的静态的数据库。 DM 快照功能是基于数据库实现的&#xff0c;每个快照是基于数据库的只读镜像。通过检索快照&#xff0c;可以获取源数据库在快照创建时间点的相关数据信息。 为了方便管理自动工作集负载信息库 AWR&#xff08;Auto…