pytorch 多机多卡训练方法

news2025/1/28 1:09:33

        在深度学习训练中,使用多机多卡(多台机器和多块 GPU)可以显著加速模型训练过程。 PyTorch 提供了多种方法来实现多机多卡训练,以下是一些常用的方法和步骤:

1. 使用 torch.distributed 包

        PyTorch 的 torch.distributed 包提供了分布式训练的支持。以下是使用 torch.distributed 进行多机多卡训练的步骤:

1.1 环境设置

        首先,确保每台机器上都安装了相同版本的 PyTorch 和 CUDA 。然后,设置环境变量:

export MASTER_ADDR="主节点的 IP 地址"
export MASTER_PORT="主节点的端口号"
export WORLD_SIZE="总的进程数(机器数 * 每台机器的 GPU 数)"
export RANK="当前进程的全局排名(从 0开始)"

1.2 初始化分布式环境

在代码中,使用 torch.distributed.init_process_group 初始化分布式环境:

import torch
import torch.distributed as dist

def init_distributed_mode():
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(local_rank)

1.3 创建模型和优化器

将模型和优化器移动到 GPU 上,并使用 torch.nn.parallel.DistributedDataParallel 包装模型:

model = MyModel().cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

1.4 使用 DistributedSampler

在数据加载时,使用 torch.utils.data.distributed.DistributedSampler 来确保每个进程加载不同的数据:

from torch.utils.data import DataLoader, DistributedSampler

train_dataset = MyDataset()
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_sampler)

1.5 训练循环

在训练循环中,确保每个 epoch 开始时调用 train_sampler.set_epoch(epoch)

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    for batch in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

2. 使用 torchrun 启动分布式训练

PyTorch 提供了 torchrun 工具来简化分布式训练的启动过程。以下是使用 torchrun 的步骤:

2.1 编写训练脚本

编写一个标准的训练脚本,确保包含分布式训练的初始化代码。

2.2 使用 torchrun 启动训练

在命令行中使用 torchrun 启动训练:

torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr="主节点的 IP 地址" --master_port="主节点的端口号" train.py

3. 使用 torch.distributed.launch 启动分布式训练

torch.distributed.launch 是另一种启动分布式训练的工具,但它已经被 torchrun 所取代。以下是使用 torch.distributed.launch 的步骤:

3.1 编写训练脚本

编写一个标准的训练脚本,确保包含分布式训练的初始化代码。

3.2 使用 torch.distributed.launch 启动训练

在命令行中使用 torch.distributed.launch 启动训练:

python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="主节点的 IP 地址" --master_port="主节点的端口号" train.py

4. 使用 torch.nn.DataParallel(单机多卡)

如果只需要在单台机器上使用多块 GPU,可以使用 torch.nn.DataParallel

model = MyModel()
model = torch.nn.DataParallel(model)
model = model.cuda()

for batch in train_loader:
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

5.示例 

'''多机多卡分布式训练
第一台机器
python -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="10.21.73.19" --master_port=12355 分布式训练_多机.py  --batch-size 6 --epochs 10 --lr 1e-6 --eval-steps 1000 --max-val-item-count 1000 --use-lora

第二台机器
python -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="10.21.73.19" --master_port=12355 分布式训练_多机.py  --batch-size 6 --epochs 10 --lr 1e-6 --eval-steps 1000 --max-val-item-count 1000 --use-lora

batchsize是每台机器上的batchsize
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2283650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

rocketmq-product-send方法源码分析

先看有哪些send方法 首先说红圈的 有3个红圈。归类成3种发送方式。假设前提条件,发送的topic,有3个broker,每个broker总共4个write队列,总共有12个队列。 普通发送。负载均衡12个队列。指定超时时间指定MessageQueue,发送&#…

69.在 Vue 3 中使用 OpenLayers 拖拽实现放大区域的效果(DragPan)

引言 在现代 Web 开发中,地图功能已经成为许多应用的重要组成部分。OpenLayers 是一个功能强大的开源地图库,支持多种地图源和交互操作。Vue 3 是一个流行的前端框架,以其响应式数据和组件化开发著称。本文将介绍如何在 Vue 3 中集成 OpenLa…

77,【1】.[CISCN2019 华东南赛区]Web4

有句英文,看看什么意思 好像也可以不看 进入靶场 点击蓝色字体 我勒个豆,百度哇 所以重点应该在url上,属于任意文件读取类型 接下来该判断框架了 常见的web框架如下 一,Python 框架 1.Flask URL 示例 1:http://…

手撕B-树

一、概述 1.历史 B树(B-Tree)结构是一种高效存储和查询数据的方法,它的历史可以追溯到1970年代早期。B树的发明人Rudolf Bayer和Edward M. McCreight分别发表了一篇论文介绍了B树。这篇论文是1972年发表于《ACM Transactions on Database S…

一文简单回顾复习Java基础概念

还是和往常一样,我以提问的方式回顾复习,今天回顾下Java小白入门应该知道的一些基础知识 Java语言有哪些特点呢? Java语言的特点有: 面向对象,主要是封装、继承、多态;平台无关性,“一次编写…

GCC之编译(8)AR打包命令

GCC之(8)AR二进制打包命令 Author: Once Day Date: 2025年1月23日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章请查看专栏: Linux实践记录_Once-Day的博客-C…

2.1.3 第一个工程,点灯!

新建工程 点击菜单栏左上角,新建工程或者选择“文件”-“新建工程”,选择工程类型“标准工程”选择设备类型和编程语言,并指定工程文件名及保存路径,如下图所示: 选择工程类型为“标准工程” 选择主模块机型&#x…

图像处理算法研究的程序框架

目录 1 程序框架简介 2 C#图像读取、显示、保存模块 3 C动态库图像算法模块 4 C#调用C动态库 5 演示Demo 5.1 开发环境 5.2 功能介绍 5.3 下载地址 参考 1 程序框架简介 一个图像处理算法研究的常用程序逻辑框架,如下图所示 在该框架中,将图像处…

计算机工程:解锁未来科技之门!

计算机工程与应用是一个充满无限可能性的领域。随着科技的迅猛发展,计算机技术已经深深渗透到我们生活的方方面面,从医疗、金融到教育,无一不在彰显着计算机工程的巨大魅力和潜力。 在医疗行业,计算机技术的应用尤为突出。比如&a…

Linux初识——基本指令(2)

本文将继续从上篇末尾讲起,讲解我们剩下的基本指令 一、剩余的基本指令 1、mv mv指令是move(移动)的缩写,其功能为:1.剪切文件、目录。2.重命名 先演示下重命名,假设我想把当前目录下的di34改成dir5 那…

单片机-STM32 WIFI模块--ESP8266 (十二)

1.WIFI模块--ESP8266 名字由来: Wi-Fi这个术语被人们普遍误以为是指无线保真(Wireless Fidelity),并且即便是Wi-Fi联盟本身也经常在新闻稿和文件中使用“Wireless Fidelity”这个词,Wi-Fi还出现在ITAA的一个论文中。…

80,【4】BUUCTF WEB [SUCTF 2018]MultiSQL

53,【3】BUUCTF WEB october 2019 Twice SQLinjection-CSDN博客 上面这个链接是我第一次接触二次注入 这道题也涉及了 对二次注入不熟悉的可以看看 BUUCTF出了点问题,打不开,以下面这两篇wp作为学习对象 [SUCTF 2018]MultiSQL-CSDN博客 …

Prometheus部署及linux、mysql、monog、redis、RocketMQ、java_jvm监控配置

Prometheus部署及linux、mysql、monog、redis、RocketMQ、java_jvm监控配置 1.Prometheus部署1.2.Prometheus修改默认端口 2.grafana可视化页面部署3.alertmanager部署4.监控配置4.1.主机监控node-exporter4.2.监控mysql数据库mysqld_exporter4.3.监控mongod数据库mongodb_expo…

问题排查 - TC397 CORE2 50MS/100MS任务不运行

1、问题描述 CORE2 的任务运行次数的计数值OsTask_100ms_Core2 - task_cnt[12]、OsTask_50ms_Core2 - task_cnt[16]不在累加,但是其他任务OsAlarm_1ms_Core2、OsAlarm_5ms_Core2、OsAlarm_10ms_Core2、OsAlarm_20ms_Core2 任务计数值累加正常。 如果是任务栈溢出&a…

Spring FatJar写文件到RCE分析

背景 现在生产环境部署 spring boot 项目一般都是将其打包成一个 FatJar,即把所有依赖的第三方 jar 也打包进自身的 app.jar 中,最后以 java -jar app.jar 形式来运行整个项目。 运行时项目的 classpath 包括 app.jar 中的 BOOT-INF/classes 目录和 BO…

百度APP iOS端磁盘优化实践(上)

01 概览 在APP的开发中,磁盘管理已成为不可忽视的部分。随着功能的复杂化和数据量的快速增长,如何高效管理磁盘空间直接关系到用户体验和APP性能。本文将结合磁盘管理的实践经验,详细介绍iOS沙盒环境下的文件存储规范,探讨业务缓…

蓝桥杯之c++入门(一)【第一个c++程序】

目录 前言一、第⼀个C程序1.1 基础程序1.2 main函数1.3 字符串1.4 头文件1.5 cin 和 cout 初识1.6 名字空间1.7 注释 二、四道简单习题(点击跳转链接)练习1:Hello,World!练习2:打印飞机练习3:第⼆个整数练习4&#xff…

14-6-1C++STL的list

(一)list容器的基本概念 list容器简介: 1.list是一个双向链表容器,可高效地进行插入删除元素 2.list不可以随机存取元素,所以不支持at.(pos)函数与[ ]操作符 (二)list容器头部和尾部的操作 list对象的默…

【AI论文】Sigma:对查询、键和值进行差分缩放,以实现高效语言模型

摘要:我们推出了Sigma,这是一个专为系统领域设计的高效大型语言模型,其独特之处在于采用了包括DiffQKV注意力机制在内的新型架构,并在我们精心收集的系统领域数据上进行了预训练。DiffQKV注意力机制通过根据查询(Q&…

InceptionV1_V2

目录 不同大小的感受野去提取特征 经典 Inception 网络的设计思路与运行流程 背景任务:图像分类(以 CIFAR-10 数据集为例) Inception 网络的设计思路 Inception 网络的运行流程 打个比方 多个损失函数的理解 1. 为什么需要多个损失函数&#…