跟代码执行流程,读Megatron源码(四)megatron训练脚本initialize.py之initialize_megatron()分布式环境初始化

news2025/1/10 23:53:31

  在前文中,我们讲述了pretrain函数的执行流程,其首要步骤是megatron分组的初始化与环境的配置。本文将深入initialize_megatron函数源码,剖析其初始化分布式训练环境的内部机制。

  注:在此假设读者具备3D并行相关知识

一. initialize_megatron函数的上下文调用关系(initialize.py)

  在Megatron-LM中,initialize.py文件中的initialize_megatron函数是分布式训练环境的初始化核心。该函数由trainning.py的pretrain函数调用,是整个pretrain流程中的第一个核心步骤,负责配置3D并行的环境和分组信息。

1. initialize_megatron函数源码

  initialize_megatron函数的核心代码段如下:

  需要注意的是,尽管initialize_megatron函数还涵盖了设置全局参数、分词器构建、自动恢复配置、TensorBoard日志记录、计时器设置以及依赖编译等辅助功能,但这些功能在初始化流程中虽具重要性,却非其核心职责所在。其核心功能聚焦于上述代码段所描述的分布式与模型并行初始化流程,而该流程主要是通过finish_mpu_init中调用的_initialize_distributed函数实现的,如下图。

2. _initialize_distributed

  _initialize_distributed函数主要有两个作用:

  a. 通过调用torch.distributed.init_process_group()初始化分布式环境,该函数设置了分布式训练所需的基本通信环境,包括进程间的通信后端、worldsize(参与训练的进程总数)、每个进程的rank号等。默认情况下,它会创建一个全局的进程组,这个进程组定义了哪些进程可以相互通信,也可以根据需要创建更多的进程组以支持更复杂的通信模式。

  在使用 init_process_group() 初始化分布式环境之后,我们可以使用PyTorch提供的分布式通信和同步 API 来实现跨进程的通信和数据同步。这包括使用dist.all_reduce() 来聚合梯度、dist.barrier() 同步所有进程的执行点等。

  请注意,在调用init_process_group()之前,需要确保已经正确设置了所有相关的环境变量(如MASTER_ADDR、MASTER_PORT),并且这些环境变量对于每张卡都是唯一的。

  b. 通过调用mpu.initialize_model_parallel()来初始化分布式训练环境中的数据并行(DP)、张量并行(TP)、和流水线并行(PP)的分组,如下图。

  mpu.initialize_model_parallel()的入参解释如下:

  • tensor_model_parallel_size:张量并行的大小。

  • pipeline_model_parallel_size:流水线并行的大小。

  • virtual_pipeline_model_parallel_size:虚拟流水线并行的大小,这是一个更高级的特性,允许在流水线阶段内部进一步分割模型。

  • pipeline_model_parallel_split_rank:该参数指定了流水线分割的起始rank,即决定了哪个rank的设备将开始处理流水线的第一个阶段,然后接下来的阶段按顺序分配给rank号递增的设备。

  • context_parallel_size、expert_model_parallel_size等参数用于特定的模型架构,如带有上下文并行或专家并行的Transformer模型。

  • distributed_timeout_minutes:分布式操作的超时时间(以分钟为单位)。

  • nccl_communicator_config_path:NCCL通信器的配置文件路径,NCCL是用于NVIDIA GPU的高效通信库。

  • order:指定并行策略的顺序,例如'tp-cp-ep-dp-pp'表示张量并行(Tensor Parallelism)、上下文并行(Context Parallelism)、专家并行(Expert Parallelism)、数据并行(Data Parallelism)和流水线并行(Pipeline Parallelism)的顺序。

  • encoder_pipeline_model_parallel_size:专门用于编码器的流水线并行大小。

  • get_embedding_ranks和get_position_embedding_ranks:用于获取特定用于embedding或position embedding的GPU rank号,以便为这些组件配置特定的并行策略。

二. mpu.initialize_model_parallel函数的调用关系(_init_.py)

  mpu.initialize_model_parallel()函数的调用关系在import中表明,如下图:

  其中mpu定义在megatron/core/_init_.py中:

  如上图,mpu指向parallel_state,因此,对于mpu.initialize_model_parallel()的调用既是对parallel_state.initialize_model_parallel()的调用,该函数实现在parallel_state.py中。

三. 分组逻辑的具体实现(parallel_state.py)

1. parallel_state.initialize_model_parallel的调用关系

  parallel_state.initialize_model_parallel函数在分布式训练架构中扮演着关键但非终结性的角色,其核心功能是启动模型并行所需的基础设置。该函数专注于预配置模型并行相关的进程群组(process groups)与全局状态变量,确保并行执行环境的基础通信框架得以确立。

  下面以dp为例,展示该函数的代码执行逻辑。

  首先,该函数创建RankGenerator对象实体,该对象的作用是根据用户输入的tp/dp/pp大小,以及总卡数(world_size),确定最终每张卡的分组,如下图。

  其次,它首先调用RankGenerator组件动态生成高效的rank分配方案,这一步骤是优化资源利用与通信效率的关键。随后,通过调用后端(backend)接口,根据RankGenerator产出的分组策略,构建起实际用于数据交换的通信群组(communication groups)。这些通信群组覆盖了数据并行(dp)、张量并行(tp)、以及流水线并行(pp)等维度(这里只以dp为例),确保模型训练过程中的数据流通与参数同步能够高效且有序地进行,如下图。

  随着所有必需通信群组的成功建立,以及全局分组变量的初始化,模型并行的核心通信网络得以全面搭建完成。这一网络的构建不仅标志着模型并行化训练环境的初步就绪,更为后续的高性能计算任务奠定了坚实的基础,确保了分布式训练过程中数据一致性与效率的最优化实现。

2. RankGenerator.get_ranks

  RankGenerator,顾名思义,其主要职责为生成rank分组,它管理着用户的tp/pp/dp/ep/cp数值,以及全局卡数(world_size)等分组相关的配置项,并在get_ranks函数中调用generate_masked_orthogonal_rank_groups函数获取用户需要的最终分组信息,如下图:

3. generate_masked_orthogonal_rank_groups

  generate_masked_orthogonal_rank_groups函数是rank分组的最终实现,其代码逻辑如下:

  a. 筛选并行性尺寸:

  masked_shape:从parallel_size和mask中筛选出被掩码(即True)的并行尺寸,这些尺寸将用于生成组内的rank。

  unmasked_shape:同样从parallel_size和mask中筛选出未被掩码的并行尺寸,这些尺寸将用于在更广泛的并行环境中(跨组)定位每个组。

  b. 计算步长:

  global_stride:通过prefix_product(parallel_size)计算得到全局步长。

  masked_stride和unmasked_stride:分别根据mask从global_stride中筛选出被掩码和未被掩码的步长。这些步长用于计算全局rank。

  c. 确定组大小和组数:

  group_size:通过prefix_product(masked_shape)[-1]计算得到每个组的大小。

  num_of_group:通过world_size // group_size计算得到组的数量,即全局大小除以每个组的大小。

  d. 生成rank:

  遍历每个组(group_index从0到num_of_group-1)。

  使用decompose(group_index, unmasked_shape)根据未被掩码的并行尺寸分解组索引,得到该组在全局并行环境中的位置(decomposed_group_idx)。

  在每个组内,遍历每个rank(rank_in_group从0到group_size-1)。

  使用decompose(rank_in_group, masked_shape)根据被掩码的并行尺寸分解组内rank,得到该rank在组内的位置(decomposed_rank_idx)。

  计算每个rank的全局索引,通过将被掩码和未被掩码的索引分别与其对应的步长进行内积(inner_product),然后将两个内积相加得到。

  最后,将计算得到的每个组内的rank添加到ranks列表中。

  e. 返回rank列表:

  函数最终返回ranks列表,其中包含了每个组内的所有rank,这些rank在全局并行环境中是唯一的。

  至此分布式训练分组的全部逻辑均介绍完毕,后续文章将继续解析分组完成后的训练逻辑。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1941756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux基于CentOS7【yum】【vim】的基础学习,【普通用户提权】

目录 yum生态 什么是yum yum是如何得知目标服务器的地址和下载链接 vim vim模式 命名模式 光标移动 插入模式 i键插 a键插 o键插 底行模式 批量化注释 批量化去注释 创建vim配置文件 例子 高亮功能: 缩进功能: 符号位自动补齐功能…

软件缺陷(Bug)、禅道

目录 软件缺陷的判定标准 软件缺陷的核心内容 构成缺陷的基本要素 缺陷报告 缺陷管理 缺陷的跟踪流程 项目管理工具--禅道 软件在使用过程中存在的任何问题(如:错误、异常等),都叫软件的缺陷,简称bug。 软件缺…

175道Docker面试题(上)

目录 1、什么是docker? 2、Docker与普通虚拟机的对比: 3、Docker常用命令: 4、Docker镜像是什么? 5、Docker容器是什么? 6、Docker容器有几种状态? 7、Dockerfile中最常见的指令是什么? …

v-for 进行列表的 增删改查

通过对象下标替换属性值 但是通过实践此方法是错误的&#xff0c;Vue监听的是students这个对象&#xff0c;而不是这个对象里面的数组信息&#xff0c;也就是说&#xff0c;改变里面的值&#xff0c;并不能在页面上实现更新的功能 <!DOCTYPE html> <html lang"en…

windows系统conda 使用注意事项

在windows系统中&#xff0c;可以直接在Anaconda 的终端中使用conda功能。 但是&#xff0c;在Anaconda的终端中&#xff0c; 无法使用完整的命令行工具。 如设置环境变量的set命令就不起作用。 这时我们就想要在windows的cmd命令中直接使用conda。 但直接使用会有问题&#…

HarmonyOS NEXT零基础入门到实战-第四部分

自定义组件: 概念: 由框架直接提供的称为 系统组件&#xff0c; 由开发者定义的称为 自定义组件。 源代码&#xff1a; Component struct MyCom { build() { Column() { Text(我是一个自定义组件) } } } Component struct MyHeader { build() { Row(…

Maven的常用命令(面试篇之Maven)

我在写项目时,使用Maven的插件的命令来进行打包等,却发现报错误了,虽然解决了, 但借此机会来总结一下Maven的常用命令: 这些插件都有着自己的命令,虽然,我们可以简化的使用一些idea中的方便的按键: 但 , 一个程序员的功力深浅就在这些细节末尾处: 在Maven中&#xff0c;插件是…

高职国培丨数据分析与数据挖掘课程实施能力提升培训班正式开班

7月15日&#xff0c;由广东机电职业技术学院牵头&#xff0c;广东泰迪智能科技股份有限公司作为合作单位的“高职教师数据分析与数据挖掘课程实施能力提升培训班&#xff08;高职国培&#xff09;”正式开班。来自广东省各地36位高校教师参与本次线下师资国培班。 广东机电职业…

el-menu根据多层树形结构递归遍历展示菜单栏

文章目录 前提条件假设菜单等级只有两个等级结果如下所示 但是如果菜单等级超过两个等级或者多个等级的话App.vueMenuItems.vue结果如下所示 关于遍历时图标前的展示后续完善关于点击路由跳转参考element plus的官网即可 前提条件 package.json如下所示&#xff0c;这是一个Vi…

广州邀请媒体宣传(附媒体名单)

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 广州地区 媒体邀约&#xff1a; 记者现场采访&#xff0c;电视台到场报道&#xff0c;展览展会宣传&#xff0c;广交会企业宣传&#xff0c;工厂探班&#xff0c;媒体专访等。 适合广州…

vue3前端开发-小兔鲜项目-二级页面面包屑导航和跳转

vue3前端开发-小兔鲜项目-二级页面面包屑导航和跳转&#xff01;这一次&#xff0c;做两件事。第一件事是把二级分类页面的跳转&#xff08;也就是路由&#xff09;设计一下。第二件事是把二级页面的面包屑导航设计一下。 第一件事&#xff0c;二级页面的跳转路由设计一下。 如…

【RabbitMQ】Windows下RabbitMQ的安装和部署

Windows下RabbitMQ的安装和部署 一、引言二、环境搭建三、安装ERLANG四、安装RabbitMQ五、安装RabbitMQ-Plugins六、验证 一、引言 RabbitMQ——Rabbit Message Queue的简写&#xff0c;但不能仅仅理解其为消息队列&#xff0c;消息代理更合适。RabbitMQ 是一个由 Erlang 语言…

智慧隧道可视化:安全与效率的智能保障

运用图扑可视化技术&#xff0c;实时监测隧道内的环境和交通状况&#xff0c;提升维保效率和应急响应能力&#xff0c;确保隧道运营的安全和畅通。

Build a Large Language Model (From Scratch)GPT-4o翻译和代码每行中文注释Ch5

目录 Pretraining on Unlabeled DataThis chapter covers5.1 Evaluating generative text models5.1.1 Using GPT to generate text5.1.2 Calculating the text generation loss5.1.3 Calculating the training and validation set losses 5.2 Training an LLM5.3 Decoding str…

免费分享一套微信小程序图书馆座位预约管理系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】,帅呆了~~

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序图书馆座位预约管理系统(SpringBoot后端Vue管理端)&#xff0c;分享下哈。 项目介绍 随着移动互联网技术的飞速发展和智能设备的普及&#xff0c;图书馆服务模式正在经历深刻的变革。本论文旨在…

【Linux】从零开始认识多线程 --- 线程ID

在这个浮躁的时代 只有自律的人才能脱颖而出 -- 《觉醒年代》 1 前言 上一篇文章中讲解了线程控制的基本接口&#xff1a; 线程创建pthread_create(pthread_t *thread, const pthread_attr_t *attr, void *(*start_routine) (void *), void *arg);: pthread_t *thread :输出…

艺术成分很高的完全自定义的UITabBar(很简单)

引言 在iOS应用开发中&#xff0c;UITabBar是一个非常场景且重要的UI组件。系统为我们提供的UITabBar虽然功能强大&#xff0c;但是在某些情况下&#xff0c;它的标准样式并不能满足我们特定的设计需求&#xff0c;它的灵活性也有一些局限。为了打造更具个性化好的用户友好的交…

Ai绘画变现的14种途径 学习Stablediffusion midjourney用途

AIGC&#xff0c;一个在当代社会中不可忽视的词汇&#xff0c;指的是利用人工智能技术生成创作内容。近年来&#xff0c;全球范围内涌现出50个热门的AI工具&#xff0c;其中&#xff0c;以140亿次访问量雄踞榜首的“GBT”&#xff0c;无疑是AI领域的领头羊。在这些工具中&#…

node.js中nodemon : 无法加载和使用问题,这是由于windows安全策略影起的按如下操作即可

1、用管理员权限打开vscode 2、文件终端中打开&#xff0c;输入 Set-ExecutionPolicy -Scope CurrentUser 3、再输入RemoteSigned 4、使用get-ExecutionPolicy查看权限&#xff0c;可以看到变为了RemoteSigned 重启问题解决

一个注解解决重复提交问题

一、前言 ​ 在应用系统中提交是一个极为常见的功能&#xff0c;倘若不加管控&#xff0c;极易由于用户的误操作或网络延迟致使同一请求被发送多次&#xff0c;从而生成重复的数据记录。针对用户的误操作&#xff0c;前端通常会实现按钮的 loading 状态&#xff0c;以阻…