Pytorch实现多GPU并行训练(DDP)

news2024/11/19 3:36:11

Pytorch实现并行训练通常有两个接口:DP(DataParallel)DDP(DistributedDataParallel)。目前DP(DataParallel)已经被Pytorch官方deprecate掉了,原因有二:1,DP(DataParallel)只支持单机多卡,无法支持多机多卡;2,DP(DataParallel)即便在单机多卡模式下效率也不及DDP(DistributedDataParallel)。我们可以看一下官方文档里的描述:

DistributedDataParallel is proven to be significantly faster than torch.nn.DataParallel for single-node multi-GPU data parallel training.

基于这两个缺点,DP(DataParallel)本文就不介绍了,即使DP(DataParallel)更好上手。本文重点讲解如何使用DDP(DistributedDataParallel)来完成并行训练。

官方文档:https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
官方视频教程:https://pytorch.org/tutorials/beginner/ddp_series_intro.html


DDP

DDP是基于多进程来实现并行训练,每个GPU依靠独立的进程来驱动,进程之间有特殊的通信机制。先介绍几个变量:

变量名意义
rank0在节点中的编号,比如当前是第0张GPU
world_size4整个节点数量,比如一共4张GPU

我们先看一个单GPU代码,然后将其转换成DDP并行训练代码,观察变化来进行学习
单GPU版本:(实验代码来自https://github.com/pytorch/examples/tree/main/distributed/ddp-tutorial-series)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size
    
    def __getitem__(self, index):
        return self.data[index]


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        ckp = self.model.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


def load_train_objs():
    train_set = MyTrainDataset(2048)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True
    )


def main(device, total_epochs, save_every, batch_size):
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, device, save_every)
    trainer.train(total_epochs)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    
    device = 0  # shorthand for cuda:0
    main(device, args.total_epochs, args.save_every, args.batch_size)

直接复制上面代码,保存为single_gpu.py,然后用以下命令就可以让代码运行:

python single_gpu.py 10000 1000

这一步如果运行失败,请留言告知。如果运行成功,我们观察GPU,会发现该程序只会调用第一张GPU。我们将上述代码稍微改进,就可以得到多GPU版本了:

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size
    
    def __getitem__(self, index):
        return self.data[index]

def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.model = DDP(model, device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        ckp = self.model.module.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


def load_train_objs():
    train_set = MyTrainDataset(2048)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)

将上述代码保存为mp_train.py,然后直接运行:

python mp_train.py 10000 1000

我们再观察GPU,会发现所有GPU均被调用。到此,实验部分已经结束。下面进入分析阶段。


从单卡变为多卡,需要进行哪些修改?

我们可以简单比对以下差别:

vimdiff single_gpu.py mp_train.py

vimdiff可以比较两个文档的差别,这是一个小trick,分享给大家。

  • 第一步:初始化
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

def ddp_setup(rank, world_size):
   """
   Args:
       rank: Unique identifier of each process
       world_size: Total number of processes
   """
   os.environ["MASTER_ADDR"] = "localhost"
   os.environ["MASTER_PORT"] = "12355"
   init_process_group(backend="nccl", rank=rank, world_size=world_size)
   torch.cuda.set_device(rank)

init_process_group就是初始化群组,相当于给GPU们打了个招呼,我们要进行并行训练啦。这里的"os.environ"可以不用修改,因为只是单机多卡,不涉及多机多卡。"nccl"表示Nvidia Collective Communications Library,是一个跨GPU的通信后端类型,一般也不用修改。这里需要注意的是rankworld_size,比如咱们是4卡机,rank的取值范围是[0,1,2,3],而world_size=4。

  • 第二步:改造模型和数据加载器
    vimdiff
    vimdiff2
    左边是单卡模式,右边是多卡模式。
  • 步骤三:用多进程启动
world_size = torch.cuda.device_count()
mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)

如有不解,欢迎留言讨论~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/699790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅聊Android性能优化~

作者:一只修仙的猿 前言 关于性能优化,可能我们的第一反应是这是高手做的事情,一直以来我也是这样认为的。但在最近一段时间,在公司项目上做了一些框架的性能优化,让我初步掀开了性能优化的面纱,也对他有了…

VirtualBox 部署 KVM 虚拟化

什么是KVM技术? KVM(Kernel-based Virtual Machine)是一种开源的虚拟化技术,它是Linux内核的一部分。KVM通过将Linux内核转换为Hypervisor,允许在同一物理主机上运行多个虚拟机实例,每个实例可以独享一部分…

JAVA学习之String学习

1.底层是用什么实现的? JDK8用的char数组,JDK9开始使用byte数组,而且都是final型,所以不同字符串(值)的地址必然不同。 char和byte的区别:char是2个字节表示,而byte是一个字节。 JDK17中&…

Vue3解决:[Vue warn]: Failed to resolve component: el-table(或el-button) 的三种解决方案

1、问题描述&#xff1a; 其一、报错为&#xff1a; [Vue warn]: Failed to resolve component: el-table If this is a native custom element, make sure to exclude it from component resolution via compilerOptions.isCustomElement. at <App> 或者&#xff1a; …

网络安全(黑客)自学

建议一&#xff1a;黑客七个等级 黑客&#xff0c;对很多人来说充满诱惑力。很多人可以发现这门领域如同任何一门领域&#xff0c;越深入越敬畏&#xff0c;知识如海洋&#xff0c;黑客也存在一些等级&#xff0c;参考知道创宇 CEO ic&#xff08;世界顶级黑客团队 0x557 成员&…

kubectl-ai:K8S资源清单的GPT助手

琦彦&#xff0c;在 **云原生百宝箱 **公众号等你&#xff0c;与你一起探讨应用迁移&#xff0c;GitOps&#xff0c;二次开发&#xff0c;解决方案&#xff0c;CNCF生态&#xff0c;及生活况味。 kubectl-ai 项目是一个kubectl使用 OpenAI GPT 生成和应用 Kubernetes 清单的插件…

【APP自动化测试必知必会】Appium之微信小程序自动化测试

本节大纲 H5 与小程序介绍 混合 App 元素定位环境部署 混合 App 元素操作 Airtest 测试 App 01.H5与小程序介绍 H5概述 H5 是指第 5 代 HTML &#xff0c;也指用 H5 语言制作的一切数字产品。 所谓 HTML 是“超文本标记语言”的英文缩写。我们上网所看到网页&#xf…

Oculus创始人谈Vision Pro:苹果在硬件设计、营销都做对了选择

早在Vision Pro正式发布之前&#xff0c;Oculus创始人Palmer Luckey就已经体验过早期版本&#xff0c;并给出了极高的评价。Luckey指出&#xff0c;苹果在XR头显上的策略是明智的&#xff0c;先打造出每个人预期中的头显&#xff0c;然后再去考虑如何让大家买得起。 Vision Pro…

远程控制电脑软件VNC安装使用教程:Windows系统

什么是VNC&#xff1f; VNC (Virtual Network Console)&#xff0c;即虚拟网络控制台&#xff0c;它是一款基于 UNIX 和 Linux 操作系统的优秀远程控制工具软件&#xff0c;由著名的 AT&T 的欧洲研究实验室开发&#xff0c;远程控制能力强大&#xff0c;高效实用&#xff…

【python】python编程基础

基础工具包 python 原生数据结构元组 Tuple列表 list集合 set字典 dictionary NumPy 数据结构数组 Ndarray矩阵 Matrix Pandas 数据结构序列 Series &#xff08;一维&#xff09;数据框 DataFrame &#xff08;二维&#xff09; Matplotlib 数据可视化绘制饼图绘制折线图绘制直…

《Linux操作系统编程》 第六章 Linux中的进程监控: fork函数的使用,以及父子进程间的关系,掌握exec系列函数

&#x1f337;&#x1f341; 博主 libin9iOak带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33…

一文读懂高分文章必备分析-GSEA

Gene Set Enrichment Analysis 或称 GSEA&#xff0c;是一种常用于转录组基因表达分析的数据挖掘技术&#xff0c;已经在《nature》、《Cell》、《ISME》、《Molecular Cell》、《Bioactive Materials》等高分杂志中发表多篇文章&#xff0c;涉及转录组及多组学内容。 凌恩生物…

yxcms存储型XSS至getshell 漏洞复现

为方便您的阅读&#xff0c;可点击下方蓝色字体&#xff0c;进行跳转↓↓↓ 01 环境部署02 漏洞配置03 利用方式04 修复方案 01 环境部署 &#xff08;1&#xff09;yxcms yxcms 基于 PHPMySQL 开发&#xff0c;这是一个采用轻量级 MVC 设计模式的网站管理系统。轻量级 MVC 设…

13-Cookie、Session、Token

目录 1.前置知识——HTTP协议 1.1.HTTP 的主要特点有以下 5 个&#xff1a; 1.2.HTTP 组成 1.3.为什么会有Cookie、Session、Token&#xff1f; 2.Cookie 3.Session PS&#xff1a;Cookie 和 Session 的联系与区别 4.Token 4.1.token的组成 4.2.token是如何生成的&am…

【广州华锐互动】VR航天航空体验展厅提供沉浸式的展示效果

VR航天航空体验展厅是一种基于虚拟现实技术的在线展览形式&#xff0c;它通过模拟真实的太空环境&#xff0c;为用户提供了一种身临其境的参观体验。与传统的线上展览相比&#xff0c;VR航天航空体验展厅具有以下几个特色&#xff1a; 1.沉浸式体验&#xff1a;VR航天航空体验…

史上最详细的webrtc-streamer访问摄像机视频流教程

目录 前言 一、webrtc-streamer的API 二、webrtc-streamer的启动命令介绍 1.原文 2.译文 三、webrtc-streamer的安装部署 1.下载地址 https://github.com/mpromonet/webrtc-streamer/releases 2.windows版本部署 3.Linux版本部署 四、springboot整合webrtc-streamer …

技术不断变革,亚马逊云科技中国峰会引领企业重塑业务

过去十年&#xff0c;数字化转型的浪潮携带着机遇和挑战席卷而来&#xff0c;几乎每个企业都在做数字化转型&#xff0c;开始向大数据、人工智能等新技术寻求生产力的突破。但随着数字化转型深入&#xff0c;很多企业开始感受到数字化投入的成本压力&#xff0c;加之新技术正带…

Alibi:Attention With Linear Biases Enables Input Length Extrapolation

Alibi:Attention With Linear Biases Enables Input Length Extrapolation IntroductionMethodResult参考 Introduction 假设一个模型在512token上做训练&#xff0c;在推理的时候&#xff0c;模型在更长的序列上表现叫做模型的外推性。作者表明以前的位置编码如Sin、Rotary、…

JS 数据变化监听函数封装

文章目录 监听函数使用用例重复添加函数&#xff0c;只有最后一个监听函数有效 监听函数 /*** 监听函数* param {对象} vm * param {键值} key * param {触发函数} action */ function WatchValueChange(vm, key, action) {var val vm[key]Object.defineProperty(vm, key, {e…

阿里内部流传出来的《1000 道互联网大厂 Java 工程师面试题》来袭,面试必刷,跳槽大厂神器

眼看着"金九银十"也快到来了&#xff0c;很多小伙伴都蠢蠢欲动想要刚给自己涨一波薪资&#xff1b;面试作为涨薪最直接最有效的方式&#xff0c;我们需要花费巨大的精力和时间来准备。除了自身的技术积累之外&#xff0c;掌握一定的面试技巧和熟悉最常见的面试题&…