跟代码执行流程,读Megatron源码(二)训练入口pretrain_gpt.py

news2025/1/15 6:22:51

  Megatron-LM默认支持GPT、T5、BERT等多个常见模型的预训练,当下大模型流行,故以pretrain_gpt.py为例做源码的走读。

一. 启动pretrain_gpt.py

  pretrain_gpt.py为GPT类模型的训练入口,它通过命令行形式被调用,其精确执行路径位于Megatron-LM框架的examples/gpt3目录下。具体而言,启动过程依赖于train_gpt3_175b_distributed.sh这一脚本,该脚本专为部署GPT-3模型在分布式环境下训练而设计(当然,也可以参照编写自定义的启动脚本)。

  在train_gpt3_175b_distributed.sh脚本内部,核心操作是通过torchrun命令实现的,该命令是PyTorch分布式训练的一部分,用于在多个计算节点上高效并行地执行pretrain_gpt.py。此过程确保了模型训练任务能够充分利用集群资源,加速训练过程,代码如下图:

二. torchrun简介

  trochrun是PyTorch官方推荐用于替代torch.distributed.launch的分布式数据并行训练模块。它旨在提供一种更灵活、更健壮的方式来启动和管理分布式训练任务。

  trochrun启动并行训练任务的原理如下:

1. 初始化分布式环境

  trochrun首先负责初始化分布式训练所需的环境。这包括设置通信后端(如NCCL、GLOO等)、分配工作进程的RANK和WORLD_SIZE(即参与训练的总进程数),以及处理其他与分布式训练相关的配置。

2. 分配工作进程

  trochrun会根据指定的参数(如--nnodes、--nproc-per-node等)来分配工作进程。这些进程可以是同一台机器上的多个 GPU,也可以是跨多台机器的GPU。每个进程都会加载相同的训练脚本(如pretrain_gpt.py),但会处理不同的数据子集,以实现并行训练。

3. 同步与通信

  在训练过程中,torchrun 管理下的各个工作进程需要频繁地进行同步和通信。这包括梯度同步(在反向传播后同步各GPU上的梯度)、参数更新(使用同步后的梯度更新模型参数)等。PyTorch提供了丰富的API(如torch.distributed.all_reduce、torch.distributed.barrier等)来支持这些操作。

4. 优雅处理故障

  trochrun相比torch.distributed.launch的一大改进是它能够更优雅地处理工作进程的故障。例如,如果某个工作进程因为某种原因崩溃了,torchrun可以尝试重新启动该进程,以确保训练任务的连续性。此外,torchrun还支持弹性训练(elastic training),即允许在训练过程中动态地增加或减少工作进程的数量。

5. 简化配置与启动

  trochrun通过提供命令行接口和配置文件选项来简化分布式训练的配置和启动过程。用户只需指定少量的参数(如节点数、每节点进程数等),即可启动复杂的分布式训练任务。此外,torchrun还支持从环境变量中读取配置信息,这使得在不同环境中部署训练任务变得更加灵活。

6. 自动化资源分配

  在某些情况下,torchrun还可以与资源管理器(如Kubernetes、Slurm等)集成,以自动化地分配和管理训练所需的计算资源。这包括GPU、CPU、内存和存储等资源。通过集成资源管理器,torchrun可以进一步提高分布式训练的可扩展性和灵活性。

  总之,torchrun通过以上机制共同作用,使得使用PyTorch进行分布式训练变得更加高效、可靠和易于管理。

三. 主要函数

  pretrain_gpt.py脚本封装了多个核心功能组件,具体包括model_provider(),forward_step(),train_valid_test_datasets_provider(),以及pretrain()等主要函数。

  其中,model_provider()负责提供预训练所需的模型实例对象;forward_step()定义了模型前向传播的具体步骤,包括输入处理、模型计算等;train_valid_test_datasets_provider()则负责准备训练、验证及测试所需的数据集,确保数据的有效供给。

  值得注意的是,前三个函数model_provider(),forward_step(),train_valid_test_datasets_provider()更是作为pretrain()函数的入参,共同构成了GPT模型训练入口。

  这种设计确保了预训练过程的模块化、灵活性与可扩展性。下面会从model_provider()开始逐行解析源码。

四. 源码分析

1. model_provider

  def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Builds the model.

    If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True.
        post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True.


    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else: # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if use_te:
                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
            else:
                transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
        )

    return model

  model_provider函数是用于构建GPT(生成预训练Transformer)模型实例的函数,它会以类似函数指针的形式,作为pretrain()的入参传递到后续的训练代码中,供训练过程调用。

  该函数主要代码流程包含以下几个步骤:

  a. 获取参数和配置:

  通过 get_args() 函数获取命令行参数和配置文件中的参数。根据 args.transformer_impl 的值确定是否使用 Transformer Engine (use_te)。

  b. 配置模型:

  如果指定了 YAML 配置文件 (args.yaml_cfg),则从 YAML 文件中加载模型结构。否则,根据命令行参数 (args) 加载。

  c. 选择模型类型:

  如果args.use_legacy_models为True,则使用megatron.legacy.model.GPTModel构建模型。这通常用于向后兼容或测试旧版本的模型。

  如果不使用旧版模型,则直接运行构建GPTModel,如图红框部分。

  其中,GPTModel的参数包括配置 (config)、词汇表大小 (vocab_size)、最大序列长度 (max_sequence_length)、是否进行前处理和后处理、是否使用 FP16 进行语言模型交叉熵计算、是否并行输出等。这些参数基本上都来源于配置文件,关于配置文件的内容和解析将于下文详述。

  注:此处的model_provider是作为函数指针传递到pretrain()中,函数指针只有在调用时才会真正执行,故,GPTModel的具体实现待到执行时再具体分析。

2. forward_step

def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
            data_iterator)
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask,
                              labels=labels)

    return output_tensor, partial(loss_func, loss_mask)

  forward_step顾名思义,这个函数是GPT模型训练过程中前向处理函数,负责处理一批输入数据并通过模型进行前向传播。

  该函数的核心实现,仍然是对model(forward_step函数的入参)的forward的调用,只是在调用之前封装了计时器计时逻辑以及批次数据获取的逻辑(这部分逻辑会根据不同的业务场景变化而变化,故不能直接封装到model的forward函数中,而是应该在pretrain脚本中实现),具体代码流程如下:

  a. 获取参数和计时器:

  通过get_args()和get_timers()函数分别获取训练参数和计时器对象,用于控制训练过程和记录时间消耗。

  b. 获取批次数据:

  使用timers对象记录获取批次数据的时间(可选,通过log_level=2控制)。

  调用get_batch函数从data_iterator中获取一批数据,包括tokens(输入文本对应的token IDs)、labels(训练标签,通常用于计算损失,对于语言模型任务,labels通常是tokens的右移一位版本)、loss_mask(损失掩码,用于忽略某些位置的损失计算,如填充位置)、attention_mask(注意力掩码,用于指示哪些位置需要参与注意力计算)和position_ids(位置ID,用于模型中的位置编码)。

  c. 模型前向传播:

  使用stimer(可能是一个自定义的计时器)记录模型前向传播的时间。

  将获取到的数据(tokens, position_ids, attention_mask, labels)传递给模型model进行前向传播。这里labels是可选的,用于计算损失,但在前向传播阶段不一定需要。

  模型输出output_tensor,通常包含模型的预测结果(如logits)。

  d. 返回输出和损失函数:

  返回output_tensor和partial函数,该函数需要loss_mask作为参数来计算损失。这种方式允许延迟损失的计算,直到所有相关的数据都已准备好。

3. train_valid_test_datasets_provider

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train test and validation datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples in train test and validation.
    """
    args = get_args()

    config = core_gpt_dataset_config_from_args(args)

    if args.mock_data:
        dataset_type = MockGPTDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type,
        train_val_test_num_samples,
        is_dataset_built_on_rank,
        config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds

  该函数接收一个参数train_val_test_num_samples,这是一个列表,包含了训练集、验证集和测试集的样本数量。函数的目的是根据提供的参数和配置,构建GPT模型的训练、验证和测试数据集。主要代码流程如下:

  a. 获取参数和配置:

  使用get_args()函数获取训练过程中的全局参数,并通过core_gpt_dataset_config_from_args根据这些参数生成数据集配置对象config。

  b. 确定数据集类型:

  根据args.mock_data的值决定使用哪种数据集类型。如果mock_data为True,则使用MockGPTDataset,这是一种模拟数据集,可能用于测试或快速原型开发。如果mock_data为False,则使用GPTDataset,这是实际的数据集类型,包含真实的训练数据。

  c. 构建数据集:

  使用BlendedMegatronDatasetBuilder类来构建数据集,传递给BlendedMegatronDatasetBuilder的参数包括数据集类型dataset_type、训练/验证/测试集的样本数量train_val_test_num_samples、is_dataset_built_on_rank(用于检查当前处理单元是否负责构建数据集),以及配置对象config。

  调用build()方法实际构建数据集,该方法返回三个数据集对象:训练集train_ds、验证集valid_ds和测试集test_ds。

  其中BlendedMegatronDatasetBuilder来源于包“megatron.core.datasets.blended_megatron_dataset_builder”,由于数据集构建逻辑比较简单,故,在此不做详述,有兴趣的同学可以自行查看。

  d. 返回值

  函数返回三个数据集对象:训练集train_ds、验证集valid_ds和测试集test_ds,这些对象可以用于后续的训练、验证和测试过程。

4. pretrain

  pretrain函数是megatron/pretrain_gpt.py文件中的一个执行入口,通常会将该函数写于文件的末尾。该函数被第一章中的启动脚本调用,进而开启训练流程。

  pretrain函数的入参如下:

  train_valid_test_datasets_provider:这是第3小节分析的函数指针,负责提供训练、验证和测试数据集。

  model_provider:这是第1小节分析的函数指针,负责提供GPT模型的实例。它可能根据传入的配置或参数来初始化模型。

  ModelType.encoder_or_decoder:这个参数指定了模型的类型,这里是编码器或解码器(对于GPT模型,它实际上是一个解码器)。

  forward_step:这是第2小节分析的函数指针,定义了模型训练过程中的一个前向传播步骤,包括数据的前向传递和损失的计算。

  args_defaults:这是一个字典,包含了预训练过程中一些默认参数的键值对。在这个例子中,它指定了默认的tokenizer_type为GPT2BPETokenizer,这意味着在文本预处理时将使用基于BPE(Byte Pair Encoding)的GPT-2分词器。

  至此,pretrain_gpt.py的源码基本解析完毕,下一篇文章将以pretrain函数为入口,跟随代码运行流程,深入其内部实现,详细解析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1940904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

n7.Nginx 第三方模块

Nginx 第三方模块 第三模块是对nginx 的功能扩展,第三方模块需要在编译安装Nginx 的时候使用参数–add-modulePATH指定路径添加,有的模块是由公司的开发人员针对业务需求定制开发的,有的模块是开 源爱好者开发好之后上传到github进行开源的模…

初学Linux之常见指令(上)

初学Linux之常见指令(上) 文章目录 初学Linux之常见指令(上)1. Linux下的小技巧热键man 指令 2. ls 指令3. pwd 指令4. cd 指令5. tree 指令6. touch 指令7. mkdir 指令8. rmdir 和 rm 指令9. cp 指令10. mv 指令 1. Linux下的小技…

微信小程序:vant-weapp 组件库、css 变量

vant-weapp 组件库 前往 vant-weapp 官网 npm 使用限制:不支持依赖于 Node.js 内置库、浏览器内置对象、C 插件 的包。 安装 vant-weapp # 通过 npm 安装 npm i vant/weapp -S --production# 通过 yarn 安装 yarn add vant/weapp --production# 安装 0.x 版本 npm i…

Hadoop3:MR程序处理小文件的优化办法(uber模式)

一、解决方案 1、在数据采集的时候,就将小文件或小批数据合成大文件再上传HDFS(数据源头) 2、Hadoop Archive(存储方向) 是一个高效的将小文件放入HDFS块中的文件存档工具,能够将多个小文件打包成一个HAR…

Git --- Branch Diverged

Git --- Branch Diverged Branch Diverged是如何形成的如何解决RebaseMerge Branch Diverged是如何形成的 尝试提交并将更改推送到 master 分支时,是否看到这条烦人的消息 原因是: 直到更改 B 之前,我的分支和“origin/master”完全相同。从…

通过HTML/CSS 实现各类进度条的功能。

需求:我们在开发中会遇到使用各式各样的进度条,因为当前插件里面进度条各式各样的,为了方便我们定制化的开发和方便修改样式,我们这里使用HTML和CSS样式来进行开发进度条功能。 通过本文学习我们会明白如何使用 HTML/CSS 创建各种…

Docker_一刀流_好用、好玩还方便_(持续更新)

Docker 简介一、镜像和容器的概念镜像(Image)容器(Container)镜像和容器的关系 宿主机二、Dockerfile2.1 什么是Dockerfile2.2 Dockerfile中的常见指令2.3Dockerfile实例2.4 详细扩展 三、镜像3.1 镜像的创建3.2对于镜像的一些常用…

开放原子校源行 | 湖南大学师生一行赴麒麟信安开展专业见习活动

“开放原子校源行”是开放原子开源基金会作为国家级开源公益平台发起的长期性开源教育推广公益项目。项目拟通过资助高校设立开源社团、推广开源课程、设置开源助学金、引导开源实践等方式培育开源人才,加快将开源文化、理念和技术融入校园,引导广大师生…

TCP/IP四层模型、OSI七层模型

OSI定义了网络互连的七层框架(物理层、数据链路层、网络层、传输层、会话层、表示层、应用层)TCP/IP 四层模型是目前被广泛采用的一种模型,由以下 4 层组成:应用层、传输层、网络层、网络接口层 FTP(File Transfer Pro…

Thinkphp开发文档二次整理版

基础部分 安装 环境要求 ​ *php>7.1.0 命令下载 通过Composer进行下载,操作步骤下载软件 phpstudy --->点击软件管理 --->安装Composer --->再点击网站 --->点击管理 --->点击Composer --->复制如下命令代码: ​ 稳定版&…

数据采集卡替代传统仪表的输出功能优势分析

在现代工业控制和科学研究中,数据采集卡和传统仪表在信号输出功能上各有其独特的优势。本期视频将对它们在这一方面的特点进行详细对比分析。 1. 多通道和多种信号输出支持 传统仪表:信号输出较为有限,通常只能提供固定的单一类型信号输出接…

解决 go 引用私有包,安装失败

问题描述 go mod tidy 或者 go run main.go 时,提示失败,例如 no such host(设置GOPRIVATE)或者 x509: certificate signed by unknown authority 之类的报错(设置GOINSECURE) 解决 在各种 insteadof 方…

Android 消息发布订阅框架:EventBus

目录 1.是什么 2.如何使用 3.五种线程模型 4.Eventbus2和Eventbus3的区别 一、是什么 EventBus是一款发布/订阅事件总线的框架,使用它可以进行模块间通信、解耦。它可以使用很少的代码,来实现多组件之间的通信,非常的方便。 为什么使用它…

安装caffe-CPU版本并进行训练

目录 前言 0、安装Ubuntu 18.04 版本 输入ls没有反应 ubuntu换源 换源出现的问题 1、安装caffe出现E:Unable to locate package caffe-cpu问题 2、把 code 文件夹下载到 ubuntu 3、在本地使用caffe-CPU,并部署数据标注工具 ATool 问题1 问题2 问题3 命令行…

C语言迷宫

目录 开头程序程序的流程图程序输入与输出的效果结尾 开头 大家好&#xff0c;我叫这是我58。今天&#xff0c;我要来看我用C语言编译出来的迷宫游戏。 程序 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <Windows.h> void printmaze(char s…

计算机网络基础:局域网、广域网及OSI七层模型解析

文章目录 一、局域网和广域网1、局域网&#xff08;LAN - Local Area Network&#xff09;2、广域网&#xff08;WAN - Wide Area Network&#xff09;3、对比局域网和广域网 二、OSI七层模型1、OSI的七层网络结构2、OSI的数据传输方式3、网络与操作系统的关系 一、局域网和广域…

“论系统安全架构设计及其应用”,写作框架,软考高级论文,系统架构设计师论文

论文真题 随着社会信息化进程的加快&#xff0c;计算机及网络已经被各行各业广泛应用&#xff0c;信息安全问题也变得愈来愈重要。它具有机密性、完整性、可用性、可控性和不可抵赖性等特征。信息系统的安全保障是以风险和策略为基础&#xff0c;在信息系统的整个生命周期中提…

【BUG】已解决:error: legacy - install - failure

error: legacy - install - failure 目录 error: legacy - install - failure 【常见模块错误】 【解决方案】 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页&#xff0c;我是博主英杰&#xff0c;211科班出身&#xff0c;就职于医疗科技公司&…

spring-boot 整合 redisson 实现延时队列(文末有彩蛋)

应用场景 通常在一些需要经历一段时间或者到达某个指定时间节点才会执行的功能&#xff0c;比如以下这些场景&#xff1a; 订单超时提醒收货自动确认会议提醒代办事项提醒 为什么使用延时队列 对于数据量小且实时性要求不高的需求来说&#xff0c;最简单的方法就是定时扫描数据…