windows下使用pytorch进行单机多卡分布式训练

news2024/11/17 17:27:20

首先,pytorch的版本必须是大于1.7,这里使用的环境是:

 

pytorch==1.12+cu11.6
四张4090显卡
python==3.7.6

使用nn.DataParallel进行分布式训练

这一种方式较为简单:
首先我们要定义好使用的GPU的编号,GPU按顺序依次为0,1,2,3。gpu_ids可以通过命令行的形式传入:

gpu_ids = args.gpu_ids.split(',')
gpu_ids = [int(i) for i in gpu_ids]
torch.cuda.set_device('cuda:{}'.format(gpu_ids[0]))

创建模型后用nn.DataParallel进行处理,

 model.cuda()
 r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])

对,没错,只需要这么两步就行了。需要注意的是保存模型后进行加载时,需要先用nn.DataParallel进行处理,再加载权重,不然参数名没对齐会报错。

checkpoint = torch.load(checkpoint_path)
model.cuda()
r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
r_model.load_state_dict(checkpoint['state_dict'])

如果不使用分布式加载模型,你需要对权重进行映射:

new_start_dict = {}
for k, v in checkpoint['state_dict'].items():
    new_start_dict["module." + k] = v
model.load_state_dict(new_start_dict)

使用Distributed进行分布式训练

首先了解一下概念:
node:主机数,单机多卡就一个主机,也就是1。
rank:当前进程的序号,用于进程之间的通讯,rank=0的主机为master节点。
local_rank:当前进程对应的GPU编号。
world_size:总的进程数。
在windows中,我们需要在py文件里面使用:

import os
os.environ["CUDA_VISIBLE_DEVICES]='0,1,3'

来指定使用的显卡。
假设现在我们使用上面的三张显卡,运行时显卡会重新按照0-N进行编号,有:

[38664] rank = 1, world_size = 3, n = 1, device_ids = [1]
[76032] rank = 0, world_size = 3, n = 1, device_ids = [0]
[23208] rank = 2, world_size = 3, n = 1, device_ids = [2]

也就是进程0使用第1张显卡,进行1使用第2张显卡,进程2使用第三张显卡。
有了上述的基本知识,再看看具体的实现。

使用torch.distributed.launch启动

使用torch.distributed.launch启动时,我们必须要在args里面添加一个local_rank参数,也就是:
parser.add_argument("--local_rank", type=int, default=0)
1、初始化:

import torch.distributed as dist

env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
}
current_work_dir = os.getcwd()
    init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]),
                                world_size=int(env_dict["WORLD_SIZE"]))

这里需要重点注意,这种启动方式在环境变量中会存在RANK和WORLD_SIZE,我们可以拿来用。backend必须指定为gloo,init_method必须是file:///,而且每次运行完一次,下一次再运行前都必须删除生成的ddp_example,不然会一直卡住。
2、构建模型并封装
local_rank会自己绑定值,不再是我们--local_rank指定的。

 model.cuda(args.local_rank)
 r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

3、构建数据集加载器并封装

  train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
  train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
  train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
                              collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler)

4、计算损失函数
我们把每一个GPU上的loss进行汇聚后计算。

def loss_reduce(self, loss):
        rt = loss.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.args.local_world_size
        return rt

loss = self.criterion(outputs, labels)
torch.distributed.barrier()
loss = self.loss_reduce(loss)

注意打印相关信息和保存模型的时候我们通常只需要在local_rank=0时打印。同时,在需要将张量转换到GPU上时,我们需要指定使用的GPU,通过local_rank指定就行,即data.cuda(args.local_rank),保证数据在对应的GPU上进行处理。
5、启动
在windows下需要把换行符去掉,且只变为一行。

python -m torch.distributed.launch \
--nnode=1 \
--node_rank=0 \
--nproc_per_node=3 \
main_distributed.py \
--local_world_size=3 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

nproc_per_node、local_world_size和GPU的数目保持一致。

使用torch.multiprocessing启动

使用torch.multiprocessing启动和使用torch.distributed.launch启动大体上是差不多的,有一些地方需要注意。

mp.spawn(main_worker, nprocs=args.nprocs, args=(args,))

main_worker是我们的主运行函数,dist.init_process_group要放在这里面,而且第一个参数必须为local_rank。即main_worker(local_rank, args)
nprocs是进程数,也就是使用的GPU数目。
args按顺序传入main_worker真正使用的参数。
其余的就差不多。
启动指令:

python main_mp_distributed.py \
--local_world_size=4 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

最后需要说明的,假设我们设置的batch_size=64,那么实际上的batch_size = int(batch_size / GPU数目)。
附上完整的基于bert的中文文本分类单机多卡训练代码:GitHub - taishan1994/pytorch_bert_chinese_text_classification: 基于pytorch+bert的中文文本分类GitHub - taishan1994/pytorch_bert_chinese_text_classification: 基于pytorch+bert的中文文本分类

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/803363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信小程序nodejs+vue房屋租赁交易租房平台

本系统功分为用户,经纪人,管理员三个角色,其中用户可以进行注册登陆系统,用户可以查看中介门店,查看经纪人,在线预定房源,在线租赁房源,可以发布求购房源,求租房源&#…

Mac/win开发快捷键、vs插件、库源码、开发中的专业名词

目录 触控板手势(2/3指) 鼠标右键 快捷键 鼠标选择后shift⬅️→改变选择 mac command⬅️:删除←边的全部内容 commadtab显示下栏 commandshiftz向后撤回 commandc/v复制粘贴 command ⬅️→回到行首/末 commandshift3/4截图 飞…

element dialog弹出框层级错乱问题

需要加modal-append-to-body 默认为true,遮罩层是否插入至 body 元素上,若为 false,则遮罩层会插入至 Dialog 的父元素上。 为false时的HTML结构 为true时的HTML结构 出现弹框层级错乱问题时可以modal-append-to-body是否设置为false了。

【LS科技芯团队成立】基础研究是科学之本、技术之源、创新之魂

目录 LS科技芯团队简介 团队创建人 成立本团队的核心目的 来自各个省份的大佬专家们 加入LS科技芯团队吧! LS科技芯团队简介 “LS科技芯”团队于2023年7月25日下午成立。汇聚了来自各个省份的技术博主,涵盖了电子技术,程序设计,…

软件兼容性测试中需注意的关键问题

在进行软件兼容性测试时,有一些关键问题需要特别注意,以确保测试的准确性和全面性。本文将介绍一些在软件兼容性测试中需注意的关键问题,帮助测试人员更好地进行兼容性测试工作。 首先,测试范围,测试人员需要明确测试的…

sql查询语句大全-详细讲解(格式、示例)

目录 范围查询 BETWEEN...AND in 为空 模糊查询 去重查询 AND OR 排序查询 聚合函数 1.count:计算个数 2.max:计算最大值 3.min:计算最小值 4.sum:计算和 5.avg:计算平均数 分组查询 group by 分组后…

vite / nuxt3 项目使用define配置/自定义,可以使用process.env.xxx获取的环境变量

每日鸡汤:每个你想要学习的瞬间,都是未来的你向自己求救 首先可以看一下我的这篇文章了解一下关于 process.env 的环境变量。 对于vite项目,在我们初始化项目之后,在浏览器中打印 process.env,只有 NODE_ENV这个变量&…

【HMS Core】统一扫描连续扫码、闪光灯关闭问题

【问题描述1】 使用Default View Mode进行扫码,如何实现连续扫码 【解决方案】 在默认扫码模式Default View中,功能是集成在SDK内部的,无法设置连续扫码模式等信息。 可以使用Customized View Mode这种模式,它提供了相关的api可…

【一天三道算法题】代码随想录——Day14

一. 有效的括号 题目链接:力扣 思路:无非三种情况: 1. 左侧括号多,右侧少 2. 左右侧一样多,该字符串属于有小括号字符串 3. 右侧括号多,左侧少 那么说白了就是要比较左右括号的数量,谁多&…

KNN背景分割算法

以下代码用OpenCV实现了视频中背景消除和提取的建模,涉及到KNN(K近邻算法),真题效果比较好,可以用来进行状态分析。 原理如下: 背景建模:在背景分割的开始阶段,建立背景模型。 前景…

容器部署jenkins定时构建于本地时间不一致

1. Dockerfile FROM jenkins/jenkins:2.411-jdk11 USER root #以下生成密钥方式为旧格式,因为新格式暂不能被"Publish over SSH--->Jenkins SSH Key"功能识别 RUN ssh-keygen -q -m PEM -t rsa -b 2048 -N -f /root/.ssh/id_rsa ADD ./apache-maven…

区分jdbcTemplate操作数据库和mybatis操作数据库

JdbcTemplate和MyBatis是Java中常用的两种数据库操作方式。它们在实现上有一些区别,下面我将为你介绍它们的主要特点和区别: JdbcTemplate: JdbcTemplate是Spring框架中提供的一个类,用于简化JDBC操作。使用JdbcTemplate时&#x…

【设计模式】观察者设计模式解析

目录 一、观察者模式定义 二、观察者模式角色 三、观察者模式类图 四、观察者模式实例 五、观察者模式优缺点 5.1、优点 5.2、缺点 六、观察者模式应用 6.1、Spring 中观察者模式的四个角色 6.2、coding~~~~~~ 一、观察者模式定义 观察者模式(Observer Pattern)&#…

【Unity细节】关于NotImplementedException: The method or operation is not implemented

👨‍💻个人主页:元宇宙-秩沅 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 秩沅 原创 收录于专栏:unity细节和bug ⭐关于NotImplementedException: The method or operation is not implemented.⭐…

【MATLAB第61期】基于MATLAB的GMM高斯混合模型回归数据预测

【MATLAB第61期】基于MATLAB的GMM高斯混合模型回归数据预测 高斯混合模型GMM广泛应用于数据挖掘、模式识别、机器学习和统计分析。其中,它们的参数通常由最大似然和EM算法确定。关键思想是使用高斯混合模型对数据(包括输入和输出)的联合概率…

MFC自定义控件使用

用VS2005新建一个MFC项目,添加一个Custom Control控件在窗体 我们需要为自定义控件添加一个类。项目,添加类,MFC类 设置类名字,基类为CWnd,你也可以选择CDialog作为基类 类创建完成后,在它的构造函数中注册一个新的自定义窗体,取名为"MyWindowClass" WNDCL…

3ds max 烘培世界坐标到贴图/顶点色

设置Diffuse 为ObjectNormal Normalize(objectNormal) * 0.5 0.5 把Diffuse烘培到顶点色 烘培Diffuse到贴图 模型按UV展开 右键复制 , 到mesh上粘贴 烘培到贴图 UE使用 贴图导入为BC7 float3 n ObjectNormal*2-1; return float3(n.x,n.z,n.y); // x ,z ,y

【深度学习Week3】ResNet+ResNeXt

ResNetResNeXt 一、ResNetⅠ.视频学习Ⅱ.论文阅读 二、ResNeXtⅠ.视频学习Ⅱ.论文阅读 三、猫狗大战Lenet网络Resnet网络 四、思考题 一、ResNet Ⅰ.视频学习 ResNet在2015年由微软实验室提出,该网络的亮点: 1.超深的网络结构(突破1000层&…

叶工好容5-日志与监控

目录 前言 平台维度 docker运行状态 cAdvisor-日志采集者 Heapster-日志收集 metrics-server-出生决定成败 kube-state-metrics-不完美中的完美 应用维度 日志 部署方式 输出方式 工具选择 日志接入 监控 serviceMonitor Annotation Prometheus扩展性 Thanos …

StackOverFlow刚刚宣布推出自己的AI产品!

StackOverFlow刚刚宣布要推出自己的AI产品! OverflowAI是StackOverFlow即将推出自己AI产品的名字,据称也是以VSCode插件的形式,计划在8月发布。我们来看看都有些什么功能,通过目前的信息看,OverflowAI的主要功能就是&…