PyTorch中的多进程并行处理

news2025/1/4 19:09:13

PyTorch是一个流行的深度学习框架,一般情况下使用单个GPU进行计算时是十分方便的。但是当涉及到处理大规模数据和并行处理时,需要利用多个GPU。这时PyTorch就显得不那么方便,所以这篇文章我们将介绍如何利用torch.multiprocessing模块,在PyTorch中实现高效的多进程处理。

多进程是一种允许多个进程并发运行的方法,利用多个CPU内核和GPU进行并行计算。这可以大大提高数据加载、模型训练和推理等任务的性能。PyTorch提供了torch.multiprocessing模块来解决这个问题。

导入库

 import torch
 import torch.multiprocessing as mp
 from torch import nn, optim

对于多进程的问题,我们主要要解决2方面的问题:1、数据的加载;2分布式的训练

数据加载

加载和预处理大型数据集可能是一个瓶颈。使用torch.utils.data.DataLoader和多个worker可以缓解这个问题。

 from torch.utils.data import DataLoader, Dataset
 class CustomDataset(Dataset):
     def __init__(self, data):
         self.data = data
     def __len__(self):
         return len(self.data)
     def __getitem__(self, idx):
         return self.data[idx]
 data = [i for i in range(1000)]
 dataset = CustomDataset(data)
 dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
 for batch in dataloader:
     print(batch)

num_workers=4意味着四个子进程将并行加载数据。这个方法可以在单个GPU时使用,通过增加数据读取进程可以加快数据读取的速度,提高训练效率。

分布式训练

分布式训练包括将训练过程分散到多个设备上。torch.multiprocessing可以用来实现这一点。

我们一般的训练流程是这样的

 class SimpleModel(nn.Module):
     def __init__(self):
         super(SimpleModel, self).__init__()
         self.fc = nn.Linear(10, 1)
 def forward(self, x):
         return self.fc(x)
 def train(rank, model, data, target, optimizer, criterion, epochs):
     for epoch in range(epochs):
         optimizer.zero_grad()
         output = model(data)
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
         print(f"Process {rank}, Epoch {epoch}, Loss: {loss.item()}")

要修改这个流程,我们首先需要初始和共享模型

 def main():
     num_processes = 4
     data = torch.randn(100, 10)
     target = torch.randn(100, 1)
     model = SimpleModel()
     model.share_memory()  # Share the model parameters among processes
     optimizer = optim.SGD(model.parameters(), lr=0.01)
     criterion = nn.MSELoss()
     processes = []
     for rank in range(num_processes):
         p = mp.Process(target=train, args=(rank, model, data, target, optimizer, criterion, 10))
         p.start()
         processes.append(p)
     for p in processes:
         p.join()
 if __name__ == '__main__':
     main()

上面的例子中四个进程同时运行训练函数,共享模型参数。

多GPU的话则可以使用分布式数据并行(DDP)训练

对于大规模的分布式训练,PyTorch的torch.nn.parallel.DistributedDataParallel(DDP)是非常高效的。DDP可以封装模块并将其分布在多个进程和gpu上,为训练大型模型提供近线性缩放。

 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP

修改train函数初始化流程组并使用DDP包装模型。

 def train(rank, world_size, data, target, epochs):
     dist.init_process_group("gloo", rank=rank, world_size=world_size)
     
     model = SimpleModel().to(rank)
     ddp_model = DDP(model, device_ids=[rank])
     
     optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
     criterion = nn.MSELoss()
 
     for epoch in range(epochs):
         optimizer.zero_grad()
         output = ddp_model(data.to(rank))
         loss = criterion(output, target.to(rank))
         loss.backward()
         optimizer.step()
         print(f"Process {rank}, Epoch {epoch}, Loss: {loss.item()}")
 
     dist.destroy_process_group()

修改main函数增加world_size参数并调整进程初始化以传递world_size。

 def main():
     num_processes = 4
     world_size = num_processes
     data = torch.randn(100, 10)
     target = torch.randn(100, 1)
     mp.spawn(train, args=(world_size, data, target, 10), nprocs=num_processes, join=True)
 if __name__ == '__main__':
     mp.set_start_method('spawn')
     main()

这样,就可以在多个GPU上进行训练了

常见问题及解决

1、避免死锁

在脚本的开头使用mp.set_start_method(‘spawn’)来避免死锁。

 if __name__ == '__main__':
     mp.set_start_method('spawn')
     main()

因为多线程需要自己管理资源,所以请确保清理资源,防止内存泄漏。

2、异步执行

异步执行允许进程独立并发地运行,通常用于非阻塞操作。

 def async_task(rank):
     print(f"Starting task in process {rank}")
     # Simulate some work with sleep
     torch.sleep(1)
     print(f"Ending task in process {rank}")
 def main_async():
     num_processes = 4
     processes = []
     
     for rank in range(num_processes):
         p = mp.Process(target=async_task, args=(rank,))
         p.start()
         processes.append(p)
     
     for p in processes:
         p.join()
 if __name__ == '__main__':
     main_async()

3、共享内存管理

使用共享内存允许不同的进程在不复制数据的情况下处理相同的数据,从而减少内存开销并提高性能。

 def shared_memory_task(shared_tensor, rank):
     shared_tensor[rank] = shared_tensor[rank] + rank
 def main_shared_memory():
     shared_tensor = torch.zeros(4, 4).share_memory_()
     processes = []
     
     for rank in range(4):
         p = mp.Process(target=shared_memory_task, args=(shared_tensor, rank))
         p.start()
         processes.append(p)
     
     for p in processes:
         p.join()
     print(shared_tensor)
 if __name__ == '__main__':
     main_shared_memory()

共享张量shared_tensor可以被多个进程修改

总结

PyTorch中的多线程处理可以显著提高性能,特别是在数据加载和分布式训练时使用torch.multiprocessing模块,可以有效地利用多个cpu,从而实现更快、更高效的计算。无论您是在处理大型数据集还是训练复杂模型,理解和利用多处理技术对于优化PyTorch中的性能都是必不可少的。使用分布式数据并行(DDP)进一步增强了跨多个gpu扩展训练的能力,使其成为大规模深度学习任务的强大工具。

https://avoid.overfit.cn/post/a68990d2d9d14d26a4641bbaf265671e

作者:Ali ABUSALEH

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1905586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

含并行连结的网络

一、Inception块 1、白色部分通过降低通道数来控制模型复杂度,蓝色做特征提取工作,每条路上的通道数可能不同,大概我们会把更重要的那部分特征分配更多的通道数 2、Inception只改变高宽,不改变通道数 3、在不同的情况下需要选择…

渐开线花键测量学习笔记分享

大家好,继续渐开线花键的相关内容,本期是渐开线花键测量相关的学习笔记分享: 花键检测项目有花键大径和小径检验;内花键齿槽宽和外花键齿厚,以及渐开线终止圆 和起始圆直径检测;齿距累计误差 、齿形误差 、…

MySQL—统计函数和数学函数以及GROUP BY配合HAVING

合计/统计函数 count -- 演示 mysql 的统计函数的使用 -- 统计一个班级共有多少学生? SELECT COUNT(*) FROM student -- 统计数学成绩大于 90 的学生有多少个? SELECT COUNT(*) FROM student WHERE math > 90 -- 统计总分大于 250 的人数有多少&…

Centos新手问题——yum无法下载软件

起因: 最近在学习centos7,在VM上成功安装后,用Secure进行远程登陆。然后准备下载一个C编译器,看网络上的教程,都是用yum来下载,于是我也输入了命令: yum -y install gcc* 本以为会自动下载&a…

数据统计与数据分组18-25题(30 天 Pandas 挑战)

数据统计与数据分组 1. 知识点1.18 分箱与统计个数1.19 分组与求和统计1.20 分组获取最小值1.21 分组获取值个数1.22 分组与条件查询1.23 分组与条件查询及获取最大值1.24 分组及自定义函数1.25 分组lambda函数统计 2. 题目2.18 按分类统计薪水(数据统计&#xff09…

《python程序语言设计》2018版第5章第52题利用turtle绘制sin函数

这道题是送分题。因为循环方式已经写到很清楚,大家照抄就可以了。 但是如果说光照抄可是会有问题。比如我们来演示一下。 import turtleturtle.penup() turtle.goto(-175, 50 * math.sin((-175 / 100 * 2 * math.pi))) turtle.pendown() for x in range(-175, 176…

5款屏幕监控软件精选|电脑屏幕监控软件分享

屏幕监控软件在现代工作环境中扮演着越来越重要的角色,无论是为了提高员工的工作效率,还是为了保障企业数据的安全,它们都成为了不可或缺的工具。 下面,让我们以一种新颖且易于理解的方式,来介绍五款备受好评的屏幕监…

前端JS特效第21集:HTML5响应式多种切换效果轮播大图切换js特效代码

HTML5响应式多种切换效果轮播大图切换js特效代码&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-t…

灵活视图变换器:为扩散模型设计的革新图像生成架构

在自然界中&#xff0c;图像的分辨率是无限的&#xff0c;而现有的图像生成模型在跨任意分辨率泛化方面存在困难。虽然扩散变换器&#xff08;DiT&#xff09;在特定分辨率范围内表现出色&#xff0c;但在处理不同分辨率的图像时却力不从心。为了克服这一限制&#xff0c;来自上…

MySQL篇三:数据类型

文章目录 前言1. 数值类型1.1 tinyint类型1.2 bit类型1.3 小数类型1.3.1 float1.3.2 decimal 2. 字符串类型2.1 char2.2 varchar2.3 char和varchar比较 3. 日期类型4. enum和set 前言 数据类型分类&#xff1a; 1. 数值类型 1.1 tinyint类型 在MySQL中&#xff0c;整型可以指…

MPS---MPQ86960芯片layout设计总结

MPQ86960 是一款内置功率 MOSFET 和栅极驱动的单片半桥。它可以在宽输入电压 (VIN) 范围内实现高达 50A 的连续输出电流 (IOUT)&#xff0c;通过集成MOSFET 和驱动可优化死区时间 (DT) 并降低寄生电感&#xff0c;从而实现高效率。 MPQ86960 兼容三态输出控制器&#xff0c;另…

[附源码]基于Flask的演唱会购票系统

摘要 随着互联网技术的普及和发展&#xff0c;传统购票方式因其效率低下、流程繁琐等问题已难以满足现代社会的需求。本文设计并实现了一个基于Flask框架的演唱会购票系统&#xff0c;该系统集成了用户管理、演唱会信息管理、票务管理以及数据统计与分析等功能模块&#xff0c…

如何让代码兼容 Python 2 和 Python 3?Future 库助你一臂之力

目录 01Future 是什么? 为什么选择 Future? 安装与配置 02Future 的基本用法 1、兼容 print 函数 2、兼容整数除法 3、兼容 Unicode 字符串 03Future 的高级功能 1. 处理字符串与字节 2. 统一异常处理…

nullptr和NULL

nullptr 既不是整型类型&#xff0c;也不是指针类型&#xff0c;nullptr 的类型是 std::nullptr_t&#xff08;空指针类型&#xff09;&#xff0c;能转换成任意的指针类型。 NULL是被定义为0的常量&#xff0c;当遇到函数重载时&#xff0c;就会出现问题。避免歧义 函数重载…

Django QuerySet对象,filter()方法

filter()方法 用于实现数据过滤功能&#xff0c;相当于sql语句中的where子句。 filter(字段名__exact10) 或 filter(字段名10)类似sql 中的 10 filter(字段名__gt10) 类似SQL中的 >10 filter(price__lt29.99) 类似sql中的 <29.99 filter(字段名__gte10, 字段名__lte20…

ELFK简介

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;CSDN博客专家   &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01…

初识java—jdk17的一些新增特性

文章目录 前言一 &#xff1a; yield关键字二 &#xff1a;var关键字三 &#xff1a;密封类四 &#xff1a;空指针异常&#xff1a;五&#xff1a;接口中的私有方法&#xff1a;六&#xff1a;instanceof关键字 前言 这里介绍jdk17相对于jdk1.8的部分新增特性。 一 &#xff…

python集成Bartender实现二维码打印

本文摘录于&#xff1a;https://blog.csdn.net/mynameisJW/article/details/105500773只是做学习备份之用&#xff0c;绝无抄袭之意&#xff0c;有疑惑请联系本人&#xff01; 这里上传我优化了一下的代码:https://download.csdn.net/download/chengdong1314/89522026 我这里弄…

GuLi商城-商品服务-API-品牌管理-OSS整合测试

各语言SDK参考文档_对象存储(OSS)-阿里云帮助中心 安装SDK&#xff1a; <dependency><groupId>com.aliyun.oss</groupId><artifactId>aliyun-sdk-oss</artifactId><version>3.17.4</version> </dependency> 测试上传文件流&…

【leetcode周赛记录——405】

405周赛记录 #1.leetcode100339_找出加密后的字符串2.leetcode100328_生成不含相邻零的二进制字符串3.leetcode100359_统计X和Y频数相等的子矩阵数量4.leetcode100350_最小代价构造字符串 刷了一段时间算法了&#xff0c;打打周赛看看什么水平了 #1.leetcode100339_找出加密后的…