Pytorch多GPU分布式训练代码编写

news2025/1/22 21:03:46

Pytorch多GPU分布式训练代码编写

一、数据并行

1.单机单卡

  • 模型拷贝

    • model.cuda() 原地操作
  • 数据拷贝(每步)

    • data=data.cuda() 非原地操作
  • 基于torch.cuda.is_available()来判断是否可用

  • 模型保存与加载

    • torch.save 来保存模型、优化器、其他变量
    • torch.load(file.pt,map_location=torch.device(“cuda”/“cuda:0”/“cpu”))

实际演示

  1. 环境检测

    • 代码中操作

      if __name__ =="__main__":
          if torch.cuda.is_available():
              logging.warning("Cuda is available!")
              os.environ["CUDA_VISIBLE_DEVICES"]="0" #指定使用第0号gpu
          else:
              logging.warning("Cuda is not available! Exit!")
              return
      
    • 命令行操作

      CUDA_VISIBLE_DEVICE="0" python xxx.py
      
  2. 模型拷贝

    image-20240906215319656

  3. 数据拷贝

    image-20240906215835882

  4. 模型保存与加载

    image-20240906220229669

​ 以上便是单机单卡的模型的加载和训练

2.单机多卡

  • 检测GPU数目
    • torch.cuda.device_count()
    • 可用通过命令行CUDA_VISIBLE_DEVICES=”“ 来限制GPU卡的使用
  • torch.nn.DataParallel API(已经淘汰)
    • 简单一行代码,包裹model即可
      • model=DataParallel(model.cuda(),device_ids=[0,1,2,3])
      • data=data.cuda()
    • 模型保存与加载
      • torch.save 注意模型需要调用model.module.state_dict()
      • torch.load 需要注意map_location的使用
    • 缺点
      • 单进程,效率慢
      • 不支持多机情况
      • 不支持模型并行
    • 注意实现
      • 此处的【dataloder中】batch_size应该是每个GPU的batch_size的总和

​ 演示,我们将之前的单卡的程序改造成多卡:

  1. 检测GPU数量

image-20240906223326186

  1. 用DataParallel来包裹住模型

    image-20240907074343558

  2. save的地方需要修改

    image-20240907074525268

  3. load时的注意事项

    image-20240907075048175

  4. batch_size的设置

    image-20240907075303926

  • torch.nn.parallel.DistributedDataParallel(推荐)
    • 多进程执行多卡训练,效率高
    • 代码编写流程
      • torch.distributed.init_process_group(“nccl”,world_size=n_gpus,rank=args.local_rank)
      • torch.cuda.set_device(args.local_rank) 该语句作用相当于CUDA_VISIBLE_DEVICES环境变量
      • model=DistributedDataParallel(model.cuda(args.local_rank),device_ids=[args.local_rank])
      • train_sampler=DistributedSampler(train_dataset)源码位于torch/utils/data/distributed.py
      • train_dataloader=DataLoader(…,sampler=train_sampler)
      • data=data.cuda(args.local_rank)
    • 执行命令
      • python -m torch.distributed.launch - - nproc_per_node=n_gpus train.py
    • 模型保存与加载
      • torch.save 在local_rank=0 的位置进行保存,同样注意调用model.module.state_dict()
      • torch.load 注意 map_location
    • 注意事项
      • train.py 中要有接受local_rank的参数选项,launch会传入这个参数
      • 每个进程的batch_size应该是一个GPU所需要的batch_size大小
      • 在每个周期开始处,调用train_sampler.set_epech(epoch)可以使得数据充分打乱
      • 有了sampler,就不要在DataLoader中设置shuffle=True了

演示:

  1. 指定GPU数量

    image-20240907085058885

  2. 设置使用哪几张卡

    image-20240907085224555

  3. 模型拷贝

    image-20240907085440920

  4. 数据拷贝

    ​ 我们本来传入train函数中的是data_loader,但是train_sampler=DistributedSampler(train_dataset)需要的是dataset,所以我们在向train函数中传递参数的时候需要稍加改动:
    image-20240907092159614

    image-20240907092649161

    image-20240907092955004

  5. save时注意事项

    ​ 我们只在编号0的GPU上进行保存

    image-20240907093143875

  6. 每个epoch打乱数据顺序

    image-20240907093515875

  7. 命令行运行

    image-20240907093754264

3.多机多卡

  • torch.nn.parallel.DistributedDataParallel
    • 代码编写流程
      • 跟单机多卡一致
    • 执行命令(以两节点为例,每个节点处有n_gpus个GPU)
      • python -m torch.distributed.launch - -nproc_per_node=n_gpus - -nnodes=2 - -node_rank=0 - -master_addr=“主节点IP” - -master_port=主节点端口 train.py
      • python -m torch.distributed.launch - -nproc_per_node=n_gpus - -nnodes=2 - -node_rank=1 - -master_addr=“主节点IP” - -master_port=主节点端口 train.py
    • 模型保存与加载
      • 同单机多卡基本一致

二、模型并行

1.背景

  • 模型参数太大,单个GPU无法容纳,需要将模型的不同层拆分到多个GPU上

2.示例

  • 参考:http://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html

3.模型保存与加载

  • 分多个module进行分别保存与加载(略)

image-20240907095750272

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2112424.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring security 中的授权使用

一、认证 身份认证,就是判断一个用户是否为合法用户的处理过程。Spring Security 中支持多种不同方式的认证,但是无论开发者使用那种方式认证,都不会影响授权功能使用。因为 SpringSecurity 很好做到了认证和授权解耦。 二、授权 授权&#x…

红黑树的旋转

红黑树的基本性质 红黑树与普通的二叉搜索树不同,它在每个节点上附加了一个额外的属性——颜色,该颜色可以是红色或黑色。通过引入这些颜色,红黑树能够维持以下 5 个基本性质,以确保树的平衡性: 每个节点是红色或黑色…

C++入门10——stack与queue的使用

目录 1.什么是stack? stack的使用 2.什么是queue? queue的使用 3.priority_queue 3.1 什么是priority_queue? 3.2 priority_queue的使用 1.什么是stack? 在官网中,对stack有这样的介绍: Stacks are a type o…

一台电脑对应一个IP地址吗?‌探讨两台电脑共用IP的可能性

在当今数字化时代,‌IP地址作为网络世界中的“门牌号”,‌扮演着至关重要的角色。‌它负责在网络上唯一标识每一台设备,‌使得数据能够在庞大的互联网中准确无误地传输。‌然而,‌对于IP地址与电脑之间的对应关系,‌许…

uni-appH5项目实现导航区域与内容区域联动效果

一、需求描述 将导航区域与内容区域实现联动,即点击导航区域,内容区滚动到对应位置,内容区滚动过程中根据内容定位到相对应的导航栏。 效果如下: 侧边导航与内容联动效果 二、功能实现思路分析汇总: 三、具体代码 1…

Matplotlib通过axis()配置坐标轴数据详解

坐标轴范围设置 axis()可以直接传入列表[xmin,xmax,ymin,ymax]进行范围设置, 分别可以使用plt.axis()或者画布对象.axis()进行配置 import numpy as np import matplotlib.pyplot as pltx np.linspace(0, 20, 100) y x*2 plt.plot(x, y, r) plt.axis([0,30,0,100]) plt.sa…

【精选】文件摆渡系统:跨网文件传输的安全与效率之选

文件摆渡系统可以解决哪些问题? 文件摆渡系统(File Shuttle System)主要是应用于不同网络、网段、区域之间的文件数据传输流转场景, 用于解决以下几类问题: 文件传输问题: 大文件传输:系统可…

云服务器内网穿透连接云手机配置ALAS

文章目录 服务器安装TailscaleNAT网络(无独立IP)云服务器安装Tailscale有固定IP的云服务器安装Tailscale 云手机安装Tailscale开启无线网络调试安装Tailscale ALAS连接云手机 上次写到服务器连接云手机时只说了有独立IP的,但有独立IP的云手机…

IDM 工具下载 地图高程数据

巧用IDM工具 快捷下载ASTER GDEM v3高程数据 ASTER GDEM v3是NASA推出的30米高清DEM,覆盖了几乎全部的地球陆地。那么,在NASA官网怎么下载ASTER GDEM v3的地形高程数据呢? 首先,你需要注册一个nasa的账号 注册网址: https://urs.earthdata.nasa.gov/users/new 注册方式和国…

彩虹数字屏保时钟 芝麻时钟开启个性化的时代 屏保怎么能少它

彩虹数字屏保时钟 芝麻时钟开启个性化的时代 屏保怎么能少它?电脑屏保多样化,让大家有了更多的选择,让更多人有机会把自己的电脑打扮得漂漂亮亮,今天小编给大家推荐:芝麻时钟(官网下载地址:http…

vulhub GhostScript 沙箱绕过(CVE-2018-16509)

1.执行以下命令启动靶场环境并在浏览器访问 cd vulhub/ghostscript/CVE-2018-16509 #进入漏洞环境所在目录 docker-compose up -d #启动靶场 docker ps #查看容器信息 2.访问网页 3.下载包含payload的png文件 vulhub/ghostscript/CVE-2018-16509/poc.png at master vulh…

TYPE-C USB设计

目录 摘要 TYPE-C电路 握手过程 USB电路 摘要 TYPE-C,是USB的一种接口,USB的第一种接口为常见的USB接口,U盘即为这种接口;第二种接口的形状类似一个凸字,常应用在打印机中,第三种接口即为TYPE-C,支持正…

JdbcRowSetImpl利用链分析

文章目录 JdbcRowSetImpl利用链前言JdbcRowSetImpl利用链分析 JdbcRowSetImpl利用链 前言 首先说明一下:利用链都有自己的使用场景,要根据场景进行选择不同的利用链。 JdbcRowSetImpl利用链用于fastjson反序列化漏洞中。 为什么? 因为fa…

暑期档总结:哪部国漫年番表现更优?

“暑期档”可能是所有档期中绵延时间最长的,作为该时间段主力的学生人群,在学业压力较小的假期中,需要更多娱乐方式来填充生活。除了电影之外,动画番剧越来越成为这一群体的不二选择,各个动画制作公司也会选择把精彩剧…

html记账本改写:数据重新布局,更好用了,没有localStorage保存版本

<!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><title>htm记账本</title><style>table {user-select: none;/* width: 100%; */border-collapse: collapse;}table,th,td {border: 1px solid …

RISC-V架构下 DSA - AI算力的更多可能性:Banana Pi BPI-F3 进迭时空

AI已经从技术走向应用&#xff0c;改变了我们的生活和工作方式。近些年&#xff0c;AI算力芯片领域群雄逐鹿&#xff0c;通过对芯片、算力与AI三者发展迭代过程的理解&#xff0c;我们发现高能效比的算力、通用的软件栈以及高度优化的编译器&#xff0c;是我们的AI算力产品迈向…

稚晖君同款 clion嵌入式开发环境搭建

前言 前段时间看到稚晖君的单片机开发环境&#xff0c;感觉挺酷的&#xff0c;自己也想尝试下&#xff0c;这里记录下安装过程。 安装文件准备 stm32cubemx安装 stm32cubemx stm32cubemx下载地址 当前时间是2024年9月4日&#xff0c;下载的版本是6.12.0版本&#xff0c;下…

一、关系模型和关系代数,《数据库系统概念》,原书第7版

文章目录 [toc]一、引言1.1 什么是数据库1.2 数据完整性1.3 数据库的操作1.4 数据库的持久性1.5 数据库管理系统1.6 数据模型1.7 早期DBMS 二、关系模型2.1 什么是关系模型2.2 关系数据库的结构2.3 键2.4 约束2.5 数据操纵语言(DML)2.6 关系代数2.6.1 选择运算2.6.2 投影运算2.…

【南方科技大学】CS315 Computer Security 【Lab1 Packet Sniffing and Wireshark】

目录 IntroductionBackgroundTCP/IP Network StackApplication LayerTransport LayerInternet LayerLink LayerPacket Sniffer Getting WiresharkStarting WiresharkCapturing PacketsTest Run Questions for the Lab Introduction 实验的第一部分介绍数据包嗅探器 Wireshark。…

2024高教社杯全国大学生数学建模竞赛B题原创python代码

以下均为python代码。先给大家看看之前文章的部分思路&#xff1a; 接下来我们将按照题目总体分析-背景分析-各小问分析的形式来 1 总体分析 题目提供了一个电子产品生产的案例&#xff0c;要求参赛者建立数学模型解决企业在生产过程中的一系列决策问题。以下是对题目的总体…