pytorch ddp 范例

news2025/1/8 5:41:36

pytorch ddp 范例:

################
## main.py文件
import argparse
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
# 新增:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

### 1. 基础模块 ### 
# 假设我们的模型是这个,与DDP无关
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv2_bn = nn.BatchNorm2d(16, eps=1e-4, momentum=0.01)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        #x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv2_bn(self.conv2(x))))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# 假设我们的数据是这个
def get_dataset():
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    my_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
        download=True, transform=transform)
    # DDP:使用DistributedSampler,DDP帮我们把细节都封装起来了。
    #      用,就完事儿!sampler的原理,第二篇中有介绍。
    train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
    # DDP:需要注意的是,这里的batch_size指的是每个进程下的batch_size。
    #      也就是说,总batch_size是这里的batch_size再乘以并行数(world_size)。
    trainloader = torch.utils.data.DataLoader(my_trainset, 
        batch_size=16, num_workers=2, sampler=train_sampler)
    return trainloader
    
### 2. 初始化我们的模型、数据、各种配置  ####
# DDP:从外部得到local_rank参数
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank

# DDP:DDP backend初始化
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')  # nccl是GPU设备上最快、最推荐的后端

# 准备数据,要在DDP初始化之后进行
trainloader = get_dataset()

# 构造模型
model = ToyModel().to(local_rank)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)

# DDP: Load模型要在构造DDP模型之前,且只需要在master上加载就行了。
ckpt_path = None
if dist.get_rank() == 0 and ckpt_path is not None:
    model.load_state_dict(torch.load(ckpt_path))
# DDP: 构造DDP model
model = DDP(model, device_ids=[local_rank], output_device=local_rank)

# DDP: 要在构造DDP model之后,才能用model初始化optimizer。
#因为optimizer和DDP是没有关系的,所以optimizer初始状态的同一性是不被DDP保证的!
#大多数官方optimizer,其实现能保证从同样状态的model初始化时,其初始状态是相同的。
#所以这边我们只要保证在DDP模型创建后才初始化optimizer,就不用做额外的操作。
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# 假设我们的loss是这个
loss_func = nn.CrossEntropyLoss().to(local_rank)

### 3. 网络训练  ###
model.train()
iterator = tqdm(range(100))
for epoch in iterator:
    # DDP:设置sampler的epoch,
    # DistributedSampler需要这个来指定shuffle方式,
    # 通过维持各个进程之间的相同随机数种子使不同进程能获得同样的shuffle效果。
    trainloader.sampler.set_epoch(epoch)
    # 后面这部分,则与原来完全一致了。
    for data, label in trainloader:
        data, label = data.to(local_rank), label.to(local_rank)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_func(prediction, label)
        loss.backward()
        iterator.desc = "loss = %0.3f" % loss
        optimizer.step()
    # DDP:
    # 1. save模型的时候,和DP模式一样,有一个需要注意的点:保存的是model.module而不是model。
    #    因为model其实是DDP model,参数是被`model=DDP(model)`包起来的。
    # 2. 只需要在进程0上保存一次就行了,避免多次保存重复的东西。
    if dist.get_rank() == 0:
        torch.save(model.module.state_dict(), "%d.ckpt" % epoch)


################
## Bash运行
# DDP: 使用torch.distributed.launch启动DDP模式
# 使用CUDA_VISIBLE_DEVICES,来决定使用哪些GPU
# CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node 2 main.py

pytorch ddp原理

  1. 更加深入的了解了下ddp模式下index的分配机制。
    比如总共10个数据, 在程序开始的时候会随机打乱总的indices。
    由于每张卡上打乱的随机种子是相同的,因此可以保证每个进程上的数据集是不重复的,并且能取到所有的数据集。
    A机器数据量10,B机器数据量10,batchsize都是2

master机器分配的:
indices=[4, 7, 3, 0, 6]
Slave机器分配的:
indices=[1, 5, 9, 8, 2]

通过代码加的打印信息如下:

  1. 通过实验得知,DDP模式下都是根据当前机器上面的数据集来确定数据量大小的,只是在划分数据index的时候根据卡数来平分,
indices = indices[self.rank:self.total_size:self.num_replicas]

并且每轮迭代都会重新打乱总的indices。
附pytorch相关源码:
anaconda3/envs/pytorch1.7.0_general/lib/python3.7/site-packages/torch/utils/data/distributed.py

DistributedSampler的__iter__函数
def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            """
                由于shuffle=True,因此这一步必定是会执行的
                根据self.epoch + self.seed来确定每一个进程的都是一样的
            """
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        # 子采样,指定步长为显卡的数量,根据每张卡的不同次序,指定起点
        # 由于每张卡上打乱的进程是相同的,因此可以保证每个进程上的数据集是不重复的,并且能取到所有的数据集
        indices = indices[self.rank:self.total_size:self.num_replicas]
        # self.num_samples在初始化的时候就已经是所有样本除以进程数量以后的
        # 这里确保取得的索引是和样本数量长度相等的,由于是assert断言,因此必然是相等的
        assert len(indices) == self.num_samples

        return iter(indices)

https://blog.csdn.net/yang332233/article/details/129020200?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522168630008716800213026543%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=168630008716800213026543&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2blogfirst_rank_ecpm_v1~rank_v31_ecpm-2-129020200-null-null.268v1koosearch&utm_term=ddp&spm=1018.2226.3001.4450

https://blog.csdn.net/yang332233/article/details/129053867?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522168630012016800215051330%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=168630012016800215051330&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2blogfirst_rank_ecpm_v1~rank_v31_ecpm-1-129053867-null-null.268v1koosearch&utm_term=ddp&spm=1018.2226.3001.4450

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/630318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零开始手搓一个STM32与机智云的小项目——GPIO的输入输出

文章目录 前言GPIO简介GPIO的命名与数量GPIO的功能STM32F1 GPIO的寄存器 库函数开发搭建库函数的工程查看原理图WACK_UP输入按键继电器输出138控制流水灯 代码编写库函数简介GPIO输出模式控制继电器通过138控制ledGPIO实现按键输入的操作编写逻辑代码 实物效果 总结 前言 上一…

Redis学习总结(二)

AOF 为什么是在执行完命令之后记录日志&#xff1f; 关系型数据库&#xff08;如 MySQL&#xff09;通常都是执行命令之前记录日志&#xff08;方便故障恢复&#xff09;&#xff0c;而 Redis AOF 持久化机制是在执行完命令之后再记录日志。AOF 记录日志过程为什么是在执行完命…

如何让GPT不再胡说八道

相信我们大部分人在使用GPT的时候&#xff0c;会发现GPT经常在胡言乱语、回复错误的答案等情况&#xff0c;甚至有的内容牛头不对马嘴&#xff0c;直接开始编造&#xff0c;例如下面案例&#xff1a; 我&#xff1a; 周树人是谁 GPT&#xff1a;周树人 (1897年-1975年) &…

独立开发变现周刊(第90期):自学开发了一个36万美元/年的ChatGPT应用

分享独立开发、产品变现相关内容&#xff0c;每周五发布。 目录 1、ChatGPT-Midjourney: 开源 ChatGPTMidjourney 网页应用2、PLExtension: 一个图床上传浏览器扩展3、EasySpider: 一个可视化爬虫软件4、BibiGPT: 音视频 AI 一键总结 & 对话5、自学的程序员开发了一个36万美…

【i阿极送书——第四期】《ChatGPT时代:ChatGPT全能应用一本通》

系列文章目录 作者&#xff1a;i阿极 作者简介&#xff1a;数据分析领域优质创作者、多项比赛获奖者&#xff1a;博主个人首页 &#x1f60a;&#x1f60a;&#x1f60a;如果觉得文章不错或能帮助到你学习&#xff0c;可以点赞&#x1f44d;收藏&#x1f4c1;评论&#x1f4d2;…

MongoDB集群和安全

目录 副本集-Replica Sets简介副本集的三个角色副本集架构目标副本集的创建主节点副本节点仲裁节点初始化配置副本集和主节点查看副本集的配置内容查看副本集状态添加副本从节点添加仲裁从节点副本集的数据读写操作 主节点的选举原则完整的连接字符串 分片集群-Sharded Cluster…

spring杂记

1、springboot是如何解析yml配置文件中的 tomcat配置&#xff0c;并将其赋值给 tomcat的 重要类 ServerProperties。该类为解析yml文件中的server配置 下面我们主要看看是怎样将 端口号 port 赋值给tomcat的 找到port属性&#xff0c;点击getter方法 发现调用该方法的地方为 …

在弹出框内三个元素做水平显示

最终效果图要求是这样&#xff1a; js代码&#xff1a; // 显示弹出窗口 function showPopup(node) {var popup document.createElement(div);popup.className popup;var inputContainer1 document.createElement(div);/* inputContainer1.className input-container1; */…

Upscayl:开源AI图像放大增强工具 | AIGC实践

连续写了两篇比较理论的文章——一篇行业思考&#xff0c;一篇技术讨论——可能劝退了很多不明真相的人民群众&#xff0c;一看后台数据&#xff0c;好么…… 马上周末了&#xff0c;今天分享一篇轻松小文&#xff0c;介绍一款开源免费、成熟度高、操作简单、效果显著的开源AI图…

记录--开始使用Vue 3时应避免的10个错误

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 Vue 3 稳定已经有一段时间了。许多代码库正在生产中使用它&#xff0c;其他人最终也必须进行迁移。我有机会与它一起工作&#xff0c;并记录了我的错误&#xff0c;这可能是你想避免的。 1.使用响应式…

.Net8罕见的技术:MSIL的机器码简析

前言 一般的只有最终的汇编代码才有机器码表示&#xff0c;然一个偶然的机会发现&#xff0c;MSIL(Microsoft intermediate language)作为一个中间语言表示&#xff0c;居然也有机器码&#xff0c;其实这也难怪&#xff0c;计算机里面万物都是二进制&#xff0c;本篇来看下,以下…

【GitHub探索】用python写web前端之reactpy探索

你有想象过用python来写web前端这种操作么&#xff1f;近期在github-trending上就有这样的一个项目reactpy&#xff0c;可以满足你在python上写web前端的欲望。为此&#xff0c;笔者也决定踩踩坑&#xff0c;看看这个项目的形式到底如何&#xff0c;能不能很方便地实际投产。 …

对比 document.URL 和 location.href

对比 document.URL 和 location.href document.URL 和 location.href 的不同点 document.URL只读 , location.href读写 给 document.URL 赋值, document.URL的值不会改变 给 location.href 赋值, location.href 的值改变了, 并且页面也改变了, 效果和 location.assign()一样…

解数独--难的一批

1题目 编写一个程序&#xff0c;通过填充空格来解决数独问题。 数独的解法需 遵循如下规则&#xff1a; 数字 1-9 在每一行只能出现一次。数字 1-9 在每一列只能出现一次。数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。&#xff08;请参考示例图&#xff09; 数…

【MySQL】数据库SQL语句之DML

目录 前言&#xff1a; 一.DML添加数据 1.1给指定字段添加数据 1.2给全部字段添加数据 1.3批量添加数据 二.DML修改数据 三.DML删除数据 四.结尾 前言&#xff1a; 时隔一周&#xff0c;啊苏今天来更新啦&#xff0c;简单说说这周在做些什么吧&#xff0c;上课、看书、…

ubuntu编译安装pcl

环境配置&#xff1a; ubuntu18.04pcl1.11.0 下载源码并解压 tar -zxvf pcl-pcl-1.11.0.tar.gz 进入解压后的文件夹、建立bulid文件夹并进入该文件夹 安装依赖 sudo apt-get update 使用apt-get包管理器安装CMake&#xff1a; sudo apt-get install cmake 使用apt-get包管理…

创新案例 | 新锐品牌Usmile如何借助社媒运营打造爆品成为国产电动牙刷TOP1?

Usmile 是广州星际悦动股份有限公司旗下全面口腔护理品牌。2016 年至今&#xff0c;Usmile共荣获了 16 项国内外设计大奖&#xff0c;2020 年“双十一”期间&#xff0c;入选 2020 年度天猫十大新品牌&#xff0c;销售额超 1 亿&#xff0c;成为国内首个破亿的电动牙刷品牌&…

【立体视觉(一)】之成像原理与相机畸变

【立体视觉&#xff08;一&#xff09;】之成像原理与相机畸变 一、成像原理一&#xff09;针孔模型二&#xff09;坐标系转换1. 世界坐标系到相机坐标系2. 相机坐标系到图像坐标系3. 图像坐标系到像素坐标系4. 相机坐标系到像素坐标系5. 世界坐标系到像素坐标系 二、相机畸变一…

618数码节该如何挑选,推荐几款618值得入手的数码好物

又到了一年一度的618剁手季&#xff0c;各大电商平台都纷纷推出了超级大促活动&#xff0c;激发了无数值友的狂热购物欲望。你是否也已经开始摩拳擦掌&#xff0c;准备掏钱包买买买呢&#xff1f;那么赶快听听小编的建议吧&#xff01;经过自己使用的亲身体验&#xff0c;小编给…

Superset | 地图无法显示的问题

知识目录 一、写在前面二、Superset地图显示不了三、Superset无法加载已更新的MySQL数据库数据 一、写在前面 大家好&#xff01;我是初心&#xff0c;一直在寻找并尝试着适合自己的方向&#xff01; Apache Superset是一款由Python语言为主开发的开源时髦数据探索分析以及可…