Pytorch单机多卡分布式训练

news2024/11/24 4:26:30

Pytorch单机多卡分布式训练

数据并行:

DP和DDP

这两个都是pytorch下实现多GPU训练的库,DP是pytorch以前实现的库,现在官方更推荐使用DDP,即使是单机训练也比DP快。

  1. DataParallel(DP)

    • 只支持单进程多线程,单一机器上进行训练。
    • 模型训练开始的时候,先把模型复制到四个GPU上面,然后把数据分配给四个GPU进行前向传播,前向传播之后再汇总到卡0上面,然后在卡0上进行反向传播,参数更新,再将更新好的模型复制到其他几张卡上。

    在这里插入图片描述

  2. DistributedDataParallel(DDP)

    • 支持多线程多进程,单一或者多个机器上进行训练。通常DDP比DP要快。

    • 先把模型载入到四张卡上,每个GPU上都分配一些小批量的数据,再进行前向传播,反向传播,计算完梯度之后再把所有卡上的梯度汇聚到卡0上面,卡0算完梯度的平均值之后广播给所有的卡,所有的卡更新自己的模型,这样传输的数据量会少很多。

      在这里插入图片描述

DDP代码写法

  1. 初始化

    import torch.distributed as dist
    import torch.utils.data.distributed
    
    # 进行初始化,backend表示通信方式,可选择的有nccl(英伟达的GPU2GPU的通信库,适用于具有英伟达GPU的分布式训练)、gloo(基于tcp/ip的后端,可在不同机器之间进行通信,通常适用于不具备英伟达GPU的环境)、mpi(适用于支持mpi集群的环境)
    # init_method: 告知每个进程如何发现彼此,默认使用env://
    dist.init_process_group(backend='nccl', init_method="env://")
    
  2. 设置device

    device = torch.device(f'cuda:{args.local_rank}')	# 设置device,local_rank表示当前机器的进程号,该方式为每个显卡一个进程
    torch.cuda.set_device(device)	# 设定device
    
  3. 创建dataloader之前要加一个sampler

    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)	# 加一个sampler
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)
    
  4. torch.nn.parallel.DistributedDataParallel包裹模型(先to(device)再包裹模型)

    net = torchvision.models.resnet101(num_classes=10)
    net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])	# 包裹模型
    
  5. 真正训练之前要set_epoch(),否则将不会shuffer数据

    for epoch in range(10):
        train_sampler.set_epoch(epoch)		# set_epoch
        for step, data in enumerate(data_loader_train):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            if step % 10 == 0:
                print("loss: {}".format(loss.item()))
    
  6. 模型保存

    if args.local_rank == 0:		# local_rank为0表示master进程
    	torch.save(net, "my_net.pth")
    
  7. 运行

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        # local_rank参数是必须的,运行的时候不必自己指定,DDP会自行提供
        parser.add_argument("--local_rank", type=int, default=0)
        args = parser.parse_args()
        main(args)
    
  8. 运行命令

    python -m torch.distributed.launch --nproc_per_node=2 多卡训练.py	# --nproc_per_node=2表示当前机器上有两个GPU可以使用
    

完整代码

import os
import argparse
import torch
import torchvision
import torch.distributed as dist
import torch.utils.data.distributed

from torchvision import transforms
from torch.multiprocessing import Process

def main(args):
    # nccl: 后端基于NVIDIA的GPU-to-GPU通信库,适用于具有NVIDIA GPU的分布式训练
    # gloo: 后端是一个基于TCP/IP的后端,可在不同机器之间进行通信,通常适用于不具备NVIDIA GPU的环境。
    # mpi: 后端使用MPI实现,适用于具备MPI支持的集群环境。
    # init_method: 告知每个进程如何发现彼此,如何使用通信后端初始化和验证进程组。 默认情况下,如果未指定 init_method,PyTorch 将使用环境变量初始化方法 (env://)。
    dist.init_process_group(backend='nccl', init_method="env://") # nccl比较推荐
    device = torch.device(f'cuda:{args.local_rank}')
    torch.cuda.set_device(device)
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST("./", train=True, transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set, batch_size=256, sampler=train_sampler)

    net = torchvision.models.resnet101(num_classes=10)
    net.conv1 = torch.nn.Conv2d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], output_device=[device])
    criterion = torch.nn.CrossEntropyLoss()
    opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
    for epoch in range(10):
        train_sampler.set_epoch(epoch)
        for step, data in enumerate(data_loader_train):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            opt.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            if step % 10 == 0:
                print("loss: {}".format(loss.item()))
    if args.local_rank == 0:
        torch.save(net, "my_net.pth")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDP
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    main(args)

参考:

https://zhuanlan.zhihu.com/p/594046884
https://zhuanlan.zhihu.com/p/358974461

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1045580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Aqara秋季品鉴会众新品亮相 空间智能全面进化

9月26日,全球IoT独角兽Aqara以“空间智能,全面进化”为主题在北京三里屯CHAO酒店举办了秋季品鉴会。会上正式发布了方舟智慧中枢M3、卡农智能墙壁开关Z1 Pro、银河系列高端灯具V1以及具备全新功能升级的场景传感器FP2,方舟技术升级至2.0版本。…

浅谈Deep Learning 与 Machine Learning 与Artificial Intelligence

文章目录 三者的联系与区别 三者的联系与区别 “Deep Learning is a kind of Machine Learning, and Machine Learning is a kind of Artificial Intelligence.” 人工智能(AI),机器学习(Machine Learning,简称ML&am…

玩转 CODING 自动化助手,助力高效研发!

点击链接了解详情 在日常工作中,您是否会遇到下面的情况: 作为研发人员,从需求拆分出来的开发子任务完成时,还要手动修改需求为完成状态,不仅耗时还容易遗漏; 作为产品经理,每天都要关注需求/任…

游戏开发过程中需要注意哪些问题呢?

游戏开发是一个复杂的过程,需要注意多个方面的问题。以下是一些需要特别关注的关键问题: 游戏设计: 确定游戏的核心玩法和目标受众。 制定详细的游戏设计文档,包括角色、关卡设计、游戏机制和故事情节。 技术选择:…

项目04-基于Docker的Prometheus+Grafana+AlertManager的飞书监控报警平台

文章目录 一.项目介绍1.流程图2.拓扑图3.详细介绍 二.前期准备1.项目环境2.IP划分 三. 项目步骤1.ansible部署软件环境1.1 安装ansible环境1.2 建立免密通道1.3 批量部署docker 2 部署nginx、MySQL以及cadvisor、exporter节点2.1 在nginx节点服务器上面配置nginx、node_exporte…

高效批量剪辑的秘诀与技巧,虚化背景技巧在视频剪辑中的应用与创意

你是否曾经为了制作一个高质量的视频而感到烦恼?视频剪辑是一项繁琐的工作,但是使用批量剪辑工具可以让这个过程变得更加高效。今天,我们将向您介绍一款强大的批量剪辑工具——视频工厂,帮助您轻松制作高质量视频。 首先&#xff…

linux中mysql启动失败以及数据迁移

背影:服务启动失败:报错数据库连接太多导致mysql挂了 解决过程: 在任何部署信息都不知道的前提下(因为是被临时拉来解决的): 1、通过【find / -name mysql】或者【whereis mysql】查找(ps&am…

两表查询常用SQL

1、两个表:table_a和table_b,求两表的交集,关键字:INNER JOIN SELECT a.*,b.* FROM table_a AS a INNER JOIN table_b AS b ON a.idb.id; 2、两个表:table_a和table_b,table_a为主表&#xff0…

新旅程、新经营丨神策 2023 数据驱动大会 10 月 27-28 日北京见

以数据驱动为手段 以客户旅程为抓手 实现更好的数字化客户经营 「新旅程、新经营,决胜数字化」 神策 2023 数据驱动大会 报名通道正式开启 识别二维码立即报名 历经八载,神策数据驱动大会已成为国内数字化转型和营销科技领域的年度盛会!本届大…

这个国庆场景下的创意数据应用,体现了数字经济时代的商业价值

在生成式AI爆火的2023年,数据协作和数据交换的商业价值越来越明显。大模型的训练正需要海量跨领域数据的“投喂”,才能真正创造商业价值涌现的奇迹。而如何在保护数据安全的前提下,有效发挥数据资产的商业价值,成为企业数字化亟需…

[异构图-论文阅读]Heterogeneous Graph Transformer

这篇论文介绍了一种用于建模Web规模异构图的异构图变换器(HGT)架构。以下是主要的要点: 摘要和引言 (第1页) 异构图被用来抽象和建模复杂系统,其中不同类型的对象以各种方式相互作用。许多现有的图神经网络(GNNs)主要针对同构图设计,无法有效表示异构结构。HGT通过设计…

【力扣2656】K个元素的最大和

👑专栏内容:力扣刷题⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、题目描述二、题目分析 一、题目描述 题目链接:K个元素的最大和 给你一个下标从 0 开始的整数数组 nums 和…

【kkFileView】源码编译打包构建镜像部署

目录 官网使用源码构建镜像k8s启动yaml参考使用介绍 官网 官网: http://kkfileview.keking.cn/zh-cn/index.html在线文档: http://kkfileview.keking.cn/zh-cn/docs/home.html源码地址: https://gitee.com/kekingcn/file-online-preview发行版下载页面: https://gitee.com/kek…

数据分析技能点-正态分布和其他变量分布

在数据驱动的世界里,了解和解释数据分布是至关重要的。不同类型的数据分布,如正态分布、二项分布和泊松分布,具有不同的特性和应用场景。这些分布不仅在统计学和数据科学中有广泛应用,而且在日常生活和商业决策中也起着关键作用。 文章目录 正态分布正态分布和偏差其他常见…

使用adb命令通过数据线操控Android手机设备屏幕

目录 第一步:下载并安装Android SDK Platform-Tools 第二步:启动adb并测试连接 第三步:操控手机 第一步:下载并安装Android SDK Platform-Tools 进入Android开发者网站上找到ADB工具包(包含在Android SDK Platform…

最新AI智能写作系统ChatGPT源码/支持GPT4.0+GPT联网提问/支持ai绘画Midjourney+Prompt+MJ以图生图+思维导图生成

一、AI创作系统 SparkAi系统是基于很火的GPT提问进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作ChatGPT系统?小编这里写一个详细图文教程吧&#x…

没有一技之长,该如何找工作?

很负责任的告诉你,跟你一样有这个困惑的人真的太多了! 而且你也会发现,你身边的大多数人也都很迷茫。 家庭、学历一般,没啥特长爱好,更没有拿的出手的技能。 想要告诉你的是,你觉得你自己一无所长&#…

基于Matlab求解2023华为杯研究生数学建模竞赛E题——出血性脑卒中临床智能诊疗建模实现步骤(附上源码+数据)

文章目录,源码见文末下载 背景介绍准备工作:处理数据第一题:血肿扩张风险相关因素探索建模a)问题b)问题 第二题: 血肿周围水肿的发生及进展建模,并探索治疗干预和水肿进展的关联关系a&#xff0…

图像的读写与保存

图像是由众多的像素值构成的,我们如何去操作图像呢? 答案就是将图像转化为数组。 OpenCV提供了这样的方法。 我们使用cv2.imread()方法读取图片,返回数组格式。 对于cv2.imread(filename, flags)函数参数如下: 参数filename&a…

Adaptive AUTOSAR CM模块介绍(二)

在Adaptive AUTOSAR CM模块介绍(一)中介绍了 AP CM模块的功能和定位,这一篇主要是讲解AP CM模块的ara::com API的内容: 为什么AUTOSAR发明了另一种通信中间件API/技术?在当时中间件技术有很多啊?在当时特别有名的中间件有&#xf…