Training - PyTorch Lightning 的 Horovod 策略实践 (all_gather)

news2025/1/13 7:57:03

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://blog.csdn.net/caroline_wendy/article/details/137686312

Horovod

在 PyTorch Lightning 中使用 Horovod 策略,可以在多个 GPU 上并行训练模型。Horovod 是分布式训练框架,通过优化数据传输来提高多 GPU/CPU 训练的效率。要在 PyTorch Lightning 中使用 Horovod,需要在训练命令中指定 Horovod 作为策略。

  • PyTorch Lightning 源码:GitHub - pytorch-lightning
  • Horovod 策略的具体源码:pytorch_lightning.strategies.horovod

1. 构建 Docker 环境

首先,需要构建支持 MPI 运行的 Docker,安装 PyTorch Lightning 与 Horovod 的安装包,目前而言,PyTorch Lightning 的 2.+ 版本,以上,已经移除 Horovod 策略,需要降级至 1.8.6 版本,才支持 Horovod 策略,即:

pip install pytorch-lightning==1.8.6
pip install cmake==3.24.2 
pip install horovod==0.27.0

注意:horovod 安装之前,需要满足 cmake 版本,需要预先安装 cmake 包,否则报错:

File "/tmp/pip-install-qcugcd1u/horovod_a39ef0ac7a9e4940bc6b5969457a47f4/setup.py", line 88, in get_cmake_bin
	raise RuntimeError("Failed to install temporary CMake. "
RuntimeError: Failed to install temporary CMake. Please update your CMake to 3.13+ or set HOROVOD_CMAKE appropriately.

参考:StackOverflow - How to reinstall the latest cmake version?

验证 PyTorch 与 Horovod 是否安装成功:

python

import torch
print(torch.__version__)  # 1.13.1
print(torch.cuda.is_available())  # True

from horovod.torch import mpi_lib_v2 as mpi_lib
# pass

也可以使用 Horovod 策略补充工程,支持 PyTorch Lightning 的 2.+ 版本,参考 GitHub - lightning-Horovod

Horovod

启动 Docker:

nvidia-docker run -it --name [your name] -v /pfs_beijing:/pfs_beijing -v /nfs_beijing:/nfs_beijing -v /nfs_beijing_ai:/nfs_beijing_ai [your image]:[version]

上传 Docker 至服务器:

# 提交 Tag
docker ps -l
docker commit 20df5ad955bb [your image]:[version]

# 准备远程 Tag
docker tag [your image]:[version] [remote image]:[version]
docker images | grep [your image]

# 推送至远程
docker push [remote image]:[version]

2. 配置 Horovod 策略

固定随机种子,确保分布式的表现一致:

# 设置 seed 参数
if args.seed is not None:
    seed_everything(args.seed)
    logger.info(f"[CL] Using seed: {args.seed}")

配置 Horovod 环境变量 与 策略,即:

from pytorch_lightning.strategies import HorovodStrategy

os.environ["HOROVOD_FUSION_THRESHOLD"] = "0"
os.environ["HOROVOD_CACHE_CAPACITY"] = "0"
os.environ["OMPI_MCA_btl_vader_single_copy_mechanism"] = "none"
import horovod.torch as hvd
hvd.init()
torch.cuda.set_device(hvd.local_rank())
strategy = HorovodStrategy()

# Horovod 不需要设置,使用默认值
args.num_nodes = 1
args.gpus = None

logger.info(f"[CL] Using HorovodStrategy")

注意:Horovod 策略,在 pl.Trainer 中,不需要设置 num_nodesgpus,使用默认值,即 1 和 None。

具体的 pl.Trainer 配置 Horovod 策略,如下:

trainer = pl.Trainer(
    accelerator="gpu",
    # ...
    strategy=strategy,  # 多机多卡配置
    num_nodes=args.num_nodes,  # 节点数
    devices=args.gpus,  # 每个节点 GPU 卡数
)

3. 配置 Horovod 的 all_gather 实例

在 PyTorch Lightning 中,不推荐直接使用 torch.distributed.all_gather_object() 进行分布式数据汇集,建议在 pl.LightningModule 类中,直接调用 self.all_gather() 方法。

  • torch.distributed.all_gather_object() 的源码,参考 Doc - PyTorch
  • LightningModule.all_gather() 的源码,参考 Doc - Lighting 1.8.6
  • horovod.torch.allgather() 的源码,参考 Doc - Horovod

LightningModule 的 all_gather() 调用 Horovod 的 allgather() 函数,源码如下:

def all_gather(self, result: Tensor, group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False) -> Tensor:
        if group is not None and group != dist_group.WORLD:
            raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

        if len(result.shape) == 0:
            # Convert scalars to single dimension tensors
            result = result.reshape(1)

        # sync and gather all
        self.join()
        return hvd.allgather(result)

其中,torch.distributed.all_gather_object() 方法,报错如下:

horovod all_gather_object "Default process group has not been initialized, please make sure to call init_process_group.""

原因是,在 LightningModule 中,不推荐直接使用 torch.distributed 的方法,建议直接调用 LightningModule 的内部方法。

其中 all_gather 的源码修改示例,如下:

class ModelWrapper(pl.LightningModule):
  	
    def gather_log(self, log, world_size):
        if world_size == 1:
            return log

        # 异常代码,不建议直接调用 torch.distributed
        # log_list = [None] * world_size
        # torch.distributed.all_gather_object(log_list, log)
        # log = {key: sum([l[key] for l in log_list], []) for key in log}

        log_gather_map = self.all_gather(log)
        # logger.info(f"[CL] log: {log}")
        # logger.info(f"[CL] log_list_map: {log_gather_map}")

        log_parse_map = dict()
        for key in log_gather_map.keys():
            # [sample,num_node],例如 样本 3 个,Node 2个,[[1,2],[3,4],[5,6]]
            tmp_list = log_gather_map[key]
            for item in tmp_list:
                if isinstance(item, torch.Tensor):
                    item_cpu = item.detach().cpu()
                    item_x = item_cpu.numpy().tolist()
                    if key not in log_parse_map.keys():
                        log_parse_map[key] = []
                    # sum([[1,2],[3,4]], []) -> [1, 2, 3, 4]
                    log_parse_map[key] += item_x
                elif isinstance(item, str):
                    # val_name = ['7skh_B', '7vqk_A', '7vrf_A'],all_gather 问题
                    continue
        # logger.info(f"[CL] log_parse_map: {log_parse_map}")
        return log_parse_map
      
	# ...

日志输出,包括2个卡,每个卡的数据,all_gather之后,获得全部数据,如下:

# Worker 0, all_gather 之前:
[worker-0:163] [INFO] [CL] log: 
{
  'val_first_ref_rmsd': [30.974, 21.57, 18.238],
  # ...
}

# Worker 1, all_gather 之前:
[worker-1:163] [INFO] [CL] log: 
{
	'val_first_ref_rmsd': [27.358, 19.888, 32.003],
  # ...
}

# Worker 0, all_gather 之后:
[worker-0:163] [INFO] [CL] log_gather_map:
{
  'val_first_ref_rmsd': [
    tensor([30.9740, 27.4560], device='cuda:0'),
    tensor([21.5700, 19.6400], device='cuda:0'),
    tensor([18.2380, 31.5020], device='cuda:0')
  ],
  # ...
}

# 获得全部的6个样本数据:
[worker-1:163] [INFO] [CL] log_parse_map: 
{
	'val_first_ref_rmsd': [30.9740, 27.4560, 21.5700, 19.6400, 18.2380, 31.5020],
	# ...
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1590616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手写ArrrayList

需求 自定义的MyArrayList import java.util.Arrays; import java.util.Objects;public class MyArrayList<E> {private Object[] elementData ; // 存储元素的数组private int size; // 记录 的元素个数private static final int DEFAULT_CAPACITY 10; // 默认容量// …

前端重置表单的多个Demo

目录 前言1. 纯重置2. reset重置3. resetFields重置4. 彩蛋 前言 由于从Java转全栈&#xff0c;对于前端的相关知识目前 以点科普面&#xff0c;此处的总结 重置前端表单内容&#xff0c;防止影响后续操作 其基本知识只需要通过点击按钮触发重置表单 1. 纯重置 可以通过按钮…

Golang | Leetcode Golang题解之第24题两两交换链表中的节点

题目&#xff1a; 题解&#xff1a; func swapPairs(head *ListNode) *ListNode {dummyHead : &ListNode{0, head}temp : dummyHeadfor temp.Next ! nil && temp.Next.Next ! nil {node1 : temp.Nextnode2 : temp.Next.Nexttemp.Next node2node1.Next node2.Nex…

一次http访问超时服务器端调试

问题&#xff1a;http访问服务器时没有返回&#xff0c;没有超时&#xff0c;一直在阻塞 处理过程&#xff1a;telnet端口能连上&#xff0c;服务端程序也不存在处理时间过长的情况。 说明tcp连接没问题。推测是客户端连接后再发起请求&#xff0c;服务端阻塞了。因为很多客户…

项目实训2024.04.12日志:Self-QA生成问答对

1. Self-QA技术 1.1. 为什么要用Self-QA技术 关于为什么要搜集问答对&#xff0c;我在创新实训2024.04.07日志&#xff1a;提取QA对这篇文章中提到过&#xff1a;训练大模型需要从业务侧积累的问题、资料、文档中提取出一些指令-问答对作为输入的语料。 之前我们对于问答对的…

Django中间件路由映射自动加/斜杠问题原因及分析

输入 http://127.0.0.1:8000/main/index/ 输入 http://127.0.0.1:8000/main/index 路由定义情况 urlpatterns [path("index/", views.index) ]可以发现我在输入URL的index路由时&#xff0c;如果没有和Django定义的路由匹配规则一样的话&#xff0c;浏览器自…

Python——详细解析目标检测xml格式标注转换为txt格式

本文简述了目标检测xml格式标注的内容&#xff0c;以及yolo系列模型所需的txt格式标注的内容。并提供了一个简单的&#xff0c;可以将xml格式标注文件转换为txt格式标注文件的python脚本。 1. xml格式文件内容 <size>标签下为图片信息&#xff0c;包括 <width> …

【SVN】clean up报错:Cleanup failed to process the following paths 解决方法

报错来源&#xff1a;代码更新有一个文件既不能接受自己的也不能接受别人的&#xff0c;只能取消&#xff0c;再提交提醒clean up&#xff0c;随后报标题错误。 解决方法&#xff1a;参考https://www.cnblogs.com/pinpin/p/11395438.html 注&#xff1a;如果clean up的时候有…

代码随想录算法训练营DAY24|C++回溯算法Part.1|回溯算法理论基础、77.组合、组合问题的剪枝操作

文章目录 回溯算法如何理解回溯算法回溯法模版回溯算法模版框架 77.组合树形结构回溯三部曲伪代码CPP代码实现 组合问题的剪枝操作 回溯算法 如何理解回溯算法 回溯法解决的问题都可以抽象为树形结构。 因为回溯法解决的都是在集合中递归查找子集&#xff0c;集合的大小就构成…

Spring Boot集成Graphql快速入门Demo

1.Graphql介绍 GraphQL 是一个用于 API 的查询语言&#xff0c;是一个使用基于类型系统来执行查询的服务端运行时&#xff08;类型系统由你的数据定义&#xff09;。GraphQL 并没有和任何特定数据库或者存储引擎绑定&#xff0c;而是依靠你现有的代码和数据支撑。 优势 GraphQL…

Stable Diffusion 本地部署教程:详细步骤与常见问题解析

作为一位热衷于探索前沿AI技术的博主&#xff0c;近期我深度研究了Stable Diffusion模型的本地部署过程。在这篇教程中&#xff0c;我将详述从环境准备到模型运行的每个步骤&#xff0c;并针对常见的部署问题给出解决方案&#xff0c;帮助你顺利在本地开启Stable Diffusion的创…

pyplot+pandas实现操作excel及画图

1、安装jupyter lab pip install jupyterlab # 启动 建议在指定的项目文件夹下 开启cmd窗口并执行 jupyter lab 启动后会自动打开浏览器访问 2、安装依赖 pip install matplotlib pip install xlrd pip install pandas 3、读取excel import pandas as pddf pd.read_excel(hi…

C# Solidworks二次开发:几何公差IGot相关操作API详解

大家好&#xff0c;今天要介绍的是关于几何公差IGot相关操作的API。 几何公差之前没有讲过&#xff0c;具体API如下面所示&#xff1a; &#xff08;1&#xff09;第一个为GetText&#xff0c;这个API的含义为获取此几何公差的指定文本部分&#xff0c;下面是官方的具体解释&…

基于springboot实现医疗病历互换系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现医疗病历交互系统演示 摘要 进入21世纪&#xff0c;计算机技术迅速向着网络化的、集成化方向发展。传统的单机版应用软件正在逐渐退出舞台&#xff0c;取而代之的是支持网络、支持多种数据信息的新一代网络版应用软件&#xff0c;形成了信息化的社会。信息…

FluentUI系列 - 1 - 介绍第一个窗口

介绍一个QML的UI库&#xff0c;国人编写&#xff0c;作者也耍知乎。这个UI库确实好用&#xff0c;但是教程基本等于无&#xff0c;个人在使用中顺便记录一下学习内容。这玩意儿也有Pyside6的版本&#xff0c;有需要的可以查看PySide6-FluentUI-QML。 FluentUI库地址​github.c…

开关灯---一维数组

直接看题&#xff1a; 开关灯 此题用模拟的复杂度是O(n&#xff09; &#xff0c;其实有更优解就是用完全平方数。但是我不想在C中遇到数学。。。所以用模拟解。 把数组的类型设为bool类型即可&#xff01; AC代码&#xff1a; #include<bits/stdc.h> using namespace …

Unity TMP Inputfield 输入框 框选 富文本 获取真实定位

一、带富文本标签的框选是什么 UGUI的InputField提供了selectionAnchorPosition和selectionFocusPosition&#xff0c;开始选择时的光标下标和当前光标下标 对于未添加富文本标签时&#xff0c;直接通过以上两个值&#xff0c;判断一下框选方向&#xff08;前向后/后向前&…

前端 接口返回来的照片太大 加载慢如何解决

现象 解决 1. 添加图片懒加载 背景图懒加载 对背景图懒加载做的解释 和图片懒加载不同&#xff0c;背景图懒加载需要使用 v-lazy:background-image&#xff0c;值设置为背景图片的地址&#xff0c;需要注意的是必须声明容器高度。 <div v-for"img in imageList&quo…

麒麟 V10 离线 安装 k8s 和kuboard

目录 安装文件准备 主机准备 主机配置 修改主机名&#xff08;三个节点分别执行&#xff09; 配置hosts&#xff08;所有节点&#xff09; 关闭防火墙、selinux、swap、dnsmasq(所有节点) 安装依赖包&#xff08;所有节点&#xff09; 系统参数设置(所有节点) 时间同步…

html 引入vue Element ui 的方式

第一种&#xff1a;使用CDN的方式引入 <!--引入 element-ui 的样式&#xff0c;--> <link rel"stylesheet" href"https://unpkg.com/element-ui/lib/theme-chalk/index.css"> <!-- 必须先引入vue&#xff0c; 后使用element-ui --> <…