基于tensorflow2.x的多GPU并行训练

news2025/1/11 1:51:24
由于最近训练transformer,在单卡上显存不够,另外一块卡上也无法加载,故尝试使用双卡并行的策略。将基本的流程、遇见的难题汇总在这里。

双卡满载

分布策略解释

使用官方给出的tf.distribute.MirroredStrategy作为分布策略。这个策略通过如下的方式运行:
1)所有变量和模型计算图都会在副本之间复制。
2)输入都均匀分布在副本中。
3)每个副本在收到输入后计算输入的损失和梯度。
4)通过求和,每一个副本上的梯度都能同步。
5)同步后,每个副本上的复制的变量都可以同样更新。

正文

初始化分布策略

可以使用如下的命令,查看当前设备有几块GPU可以供使用。

strategy = tf.distribute.MirroredStrategy()
print(strategy.num_replicas_in_sync)

一、数据加载

使用分布式训练,会将总的batch分散到多块GPU上。我这里有两块GPU,使用的batch是32,那么在每个上面就是16。这里,在数据加载的时候就需要做处理,具体处理过程如下:

1)创建一个总的batchsize
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
2) 加载数据集
train_ds = DataLoader().make_batch(PARA['train_'], GLOBAL_BATCH_SIZE, PARA['max_len_sequence'])
valid_ds = DataLoader().make_batch(PARA['vaild_'], GLOBAL_BATCH_SIZE, PARA['max_len_sequence'])
test_ds = DataLoader().make_batch(PARA['testt_'],  GLOBAL_BATCH_SIZE, PARA['max_len_sequence'])

3)对数据做分发

train_ds = strategy.experimental_distribute_dataset(train_ds)
valid_ds = strategy.experimental_distribute_dataset(valid_ds)
test_ds = strategy.experimental_distribute_dataset(test_ds)

经过上面这些操作,数据已经处理好了,接下来处理训练策略。

二、 定义损失函数

:这里有几个地方需要特别注意,tf.losses/tf.keras.losses 中的损失函数通常会返回输入最后一个维度的平均值。损失类封装这些函数。在创建损失类的实例时传递 reduction=Reduction.NONE,表示“无额外缩减”。对于样本输入形状为 [batch, W, H, n_classes] 的类别损失,会缩减 n_classes 维度。对于类似 losses.mean_squared_errorlosses.binary_crossentropy 的逐点损失,应包含一个虚拟轴,使 [batch, W, H, 1] 缩减为 [batch, W, H]。如果没有虚拟轴,则 [batch, W, H] 将被错误地缩减为 [batch, W]
增加虚拟轴的方式也很简单,labels = labels[:, tf.newaxis]如果没有这个,回归模型是跑不起来的!!!

1)使用 tf.distribute.Strategy 时应如何计算损失?

例如,假设有 2 个 GPU,批次大小为 64。一个批次的输入会分布在各个副本(2 个 GPU)上,每个副本获得一个大小为 32 的输入。

每个副本上的模型都会使用其各自的输入进行前向传递,并计算损失。现在,不将损失除以其相应输入中的样本数 (BATCH_SIZE_PER_REPLICA = 32),而应将损失除以 GLOBAL_BATCH_SIZE (64)

之所以需要这样做,是因为在每个副本上计算完梯度后,会通过对梯度求和在副本之间同步梯度。

2)计算方法

如果使用自定义训练循环,则应将每个样本的损失相加,然后将总和除以 GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE),或者使用 tf.nn.compute_average_loss,它会将每个样本的损失、可选样本权重和 GLOBAL_BATCH_SIZE 作为参数,并返回经过缩放的损失。比较而言,选择tf.nn.compute_average_loss这个会好一些。

由于我这里使用的是 tf.keras.losses 类,则需要将损失归约显式指定NONE 或 SUM。与 tf.distribute.Strategy 一起使用时,不允许使用 AUTO 和 SUM_OVER_BATCH_SIZE不允许使用 AUTO,因为用户应明确考虑他们想要的归约量,以确保在分布式情况下归约量正确。不允许使用 SUM_OVER_BATCH_SIZE,因为当前它只能按副本批次大小进行划分,而将按副本数量划分留给用户,这可能很容易遗漏。因此,您需要自己显式执行归约操作。

我做的是回归任务,具体的代码如下,可以看到,loss损失里面使用了reduction=tf.keras.losses.Reduction.NONE,返回损失值的时候使用了tf.nn.compute_average_loss

GLOBAL_BATCH_SIZE = PARA['batch_size']*strategy.num_replicas_in_sync
with strategy.scope():
    # Set reduction to `NONE` so you can do the reduction afterwards and divide by
    # global batch size.
    loss_object = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)
    def compute_loss(labels, predictions):
        # 这里有个坑,见最开始的注
        # 使用Reduction.NONE之后,回归损失会减少一个维度,故要在后面添加一列
        # https://tensorflow.google.cn/tutorials/distribute/custom_training?hl=zh-cn
        labels = labels[:,tf.newaxis]
        predictions = predictions[:, tf.newaxis]
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

三、定义评价指标

评价指标根据自己的实际情况来,我这里使用了loss跟rmse

with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    train_rmse = tf.keras.metrics.RootMeanSquaredError(name='train_rmse')
    test_rmse = tf.keras.metrics.RootMeanSquaredError(name='test_rmse')

四、初始化模型

模型、优化器和checkpoint务必要放在strategy.scope

with strategy.scope():
    model = Transformer(PARA['num_layers'], PARA['input_vocab_size'], PARA['target_vocab_size'],PARA['target_class'],PARA['max_len_sequence'],PARA['d_model'],PARA['num_heads'],PARA['dff'],rate=PARA['dropout_rate'])
    # 加载优化器:
    learning_rate = CustomizedSchedule(PARA['d_model'])
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    # 记录模型
    check = tf.train.Checkpoint(model=model, optimizer=optimizer)
    check_manager = tf.train.CheckpointManager(check, PARA['model_save'], max_to_keep=5)
    if check_manager.latest_checkpoint:
        check.restore(check_manager.latest_checkpoint)

五、构建训练策略

1) 先构建并行的策略,再构建train_step

with strategy.scope():
    # `run` replicates the provided computation and runs it
    # with the distributed input.
    @tf.function
    def distributed_train_step(dataset_inputs):
        per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    @tf.function
    def distributed_test_step(dataset_inputs):
        return strategy.run(test_step, args=(dataset_inputs,))

2) 构建train_step

def train_step(inputs):
    train_rmse.reset_states()
    sequence, tm, label = inputs
    with tf.GradientTape() as tape:
        predictions = model(sequence, training=True)
        loss = compute_loss(tm, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_rmse.update_state(tm, predictions)
    return loss

def test_step(inputs):
    sequence, tm, label = inputs
    predictions = model(sequence, training=False)
    t_loss = loss_object(tm, predictions)
    test_loss.update_state(t_loss)
    test_rmse.update_state(tm, predictions)

六、自定义训练过程

def fit(train_ds, valid_ds, test_ds):
        steps = 0
        start = time.time()
        for epoch in range(PARA['EPOCH']):
            # TRAIN LOOP
            total_loss, num_batches, batch = 0.0, 0, 0
            for (batch, x) in enumerate(train_ds):
            	# 这里返回每一个批次的损失值
                per_loss= distributed_train_step(x)
                total_loss += per_loss
                steps += 1
                # 这是自定义的记录函数,可以直接print当前值
                save_smurry('train','-', epoch, batch, steps, [per_loss, train_rmse.result()])
                
                if batch % (PARA['REPORT_STEP']*2) == 0 and batch:
                # 每次处理完之后,需要对test_loss及test_rmse做重置
                    for (batch, x) in enumerate(valid_ds):
                        distributed_test_step(x)
                    # 这里需要得到的是在整个验证集上的结果
                    save_smurry('vaild','-', epoch, batch, steps, [test_loss.result(), test_rmse.result()])
                    test_loss.reset_states()
                    test_rmse.reset_states()

                # 每50次做一次benchmark验证
                if batch % (PARA['REPORT_STEP']*5) == 0 and batch:
                    for x in test_ds:
                        distributed_test_step(x)
                    save_smurry('test','-', epoch, batch, steps, [test_loss.result(), test_rmse.result()])
                    test_loss.reset_states()
                    test_rmse.reset_states()
        time_used = 'Time take for 1 epoch:{} secs\n'.format(time.time()-start)
        fout(time_used)

至此,分布程序构建完成。欢迎一起讨论

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/453307.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Echarts渲染行政区划,实现聚焦高亮交互

首先需要准备行政区划的JSON数据&#xff0c;可以在DataV获取省市区的JSON数据。 最终效果图 渲染地图 建立一个地图容器&#xff0c;注意要给宽高 <!-- 地图容器 --> <div id"map"></div>请求JSON数据&#xff0c;渲染地图 $(function() {var …

Ubuntu 20版本将动态ip修改为静态ip时,ping 不通网络

问题描述&#xff1a; 在对Ubuntu 20版本将动态ip修改为静态ip时&#xff0c;ping www.baidu.com ping不通了 火狐浏览器没有了网路&#xff0c;下载不了东西 一直卡在这里不动 问题出在哪里还是配置ip dns 网关的问题 如果我们在当初安装ubuntu 时&#xff0c;将网络设置成…

24年专转本想要成功我们一个怎样做

23年转本报名人数创造高峰&#xff0c; 24年转本的同学们 如何脱颖而出&#xff0c;成功转本呢&#xff1f;一、明确转本目的 转本是一场重要的考试&#xff0c;有人把转本比喻为第二次高考。面对这唯一的进入本科院校学习的机会&#xff0c;考还是不考&#xff1f; 很多同…

小六壬学习笔记

小六壬学习笔记 简介前置知识:十二地支和十二时辰适用范围起课&#xff1a;月令日时卦象 疑问&#xff1a;遇到闰月怎么办&#xff1f;禁忌数字起课法手机计算器取余数 简单解卦 简介 马前课&#xff0c;又名&#xff1a;小六壬。 小六壬历史渊源&#xff1a;https://m.sohu.c…

RXJava2的基本概念与常见操作符使用实例解析

RXJava2是什么&#xff1f;可以简单介绍一下其特点和应用场景吗&#xff1f; RXJava2是基于观察者模式和链式编程思想的异步编程库&#xff0c;它可以用来优雅地处理异步操作&#xff0c;比如网络请求、数据库查询、文件I/O等操作&#xff0c;减少了回调嵌套&#xff0c;提高了…

【LeetCode】剑指 Offer 68. 二叉树中两个节点的最低公共祖先 p326 -- Java Version

1. 题目介绍&#xff08;68. 二叉树中两个节点的最低公共祖先&#xff09; 面试题68&#xff1a;二叉树中两个节点的最低公共祖先&#xff0c; 一共分为两小题&#xff1a; 题目一&#xff1a;二叉搜索树的最近公共祖先题目二&#xff1a;二叉树的最近公共祖先 2. 题目1&#x…

目标检测 pytorch复现SSD目标检测项目

目标检测 pytorch复现SSD目标检测项目 0、简介1、模型整体框架&#xff08;以VGG16为特征提取网络&#xff09;3、默认框&#xff08;default box&#xff09;的生成--相当于Faster-RCNN中生成的anchor4、预测层的实现原理&#xff1a;5、正负样本的选取6、损失的计算原理6、以…

SpringCloud-9、Sleuth+Zipkin

先吐槽下csdn&#xff0c;编辑器不知道怎么回事&#xff0c;快捷键一下就没有&#xff0c;现在用起来糟心 --- - 这些都用不了&#xff0c;求帮助。 基本介绍 Sleuth:分布式服务跟踪组件 /ZipKin Sleuth/ZipKin-搭建链路监控实例 官网&#xff1a;GitHub - spring-cloud/s…

【移动端网页布局】移动端网页布局基础概念 ⑦ ( 在 PhotoShop 中使用 Cutterman 切二倍图 | 使用二倍图作为背景图像 )

文章目录 一、在 PhotoShop 中使用 Cutterman 切二倍图二、使用二倍图作为背景图像 一、在 PhotoShop 中使用 Cutterman 切二倍图 参考 【CSS】PhotoShop 切图 ③ ( PhotoShop 切图插件 - Cutterman | 下载、安装、启动、注册、登录 Cutterman - 切图神奇 插件 | 使用插件进行切…

selenium应用之抓取b站黑马视频目录建立学习计划Excel

需求故事&#xff1a; 最近时间一下子多了起来&#xff0c;用来学习Java是最合适不过了&#xff0c;但是去b站看视频难免会没有自制力&#xff0c;于是决定用selenium来抓取b站黑马Java视频的目录创建一个学习计划的Excel&#xff0c;便于进行学习进度的管理。 注&#xff1a;纯…

【无模型自适应】基于紧格式动态线性化的无模型自适应控制matlab代码

例题来源&#xff1a;侯忠生教授的《无模型自适应控制&#xff1a;理论与应用》&#xff08;2013年科学出版社&#xff09;。 对应书本 4.2 单输入单输出系统(SISO)紧格式动态线性化(CFDL)的无模型自适应控制(MFAC) 例题4.1 题目要求 matlab代码 clc; clear all;%% 期望轨迹…

【opencv】图像数字化——矩阵的运算( 5 乘法运算)

5 乘法运算 5.1使用“*”运算符 对于Mat对象的乘法&#xff0c;两个Mat只能同时是float或者double类型&#xff0c;对于其它数据类型的矩阵乘法会报错src1的列数等于src2的行数mn * npmp #include <opencv2/core/core.hpp> #include<iostream> using namesp…

Android程序员向音视频进阶,有前景吗

随着移动互联网的普及和发展&#xff0c;Android开发成为了很多人的就业选择&#xff0c;希望在这个行业能获得自己的一席之地。然而&#xff0c;随着时间的推移&#xff0c;越来越多的人进入到了Android开发行业&#xff0c;就导致目前Android开发的工作越来越难找&#xff0c…

【博学谷学习记录】超强总结,用心分享 | 架构师 MinIO学习总结

文章目录 MinIO对象存储的概念计算机数据存储系统-架构模式对象存储的优势常见的对象存储系统/服务&#xff08;Object Storage Service&#xff0c;OSS&#xff09; MinIO简介特点高级特性小结 MinIO部署基于 linux Binary 部署 MinIO ServerMinIO数据组织结构MinIO Client**基…

【论文精读】Emergent Abilities of Large Language Models

1. Emergence 涌现&#xff08;emergence&#xff09;或称创发、突现、呈展、演生&#xff0c;是一种现象&#xff0c;为许多小实体相互作用后产生了大实体&#xff0c;而这个大实体展现了组成它的小实体所不具有的特性。 水分子聚集后组成了雪花是一个物理上的创发现象 扩大&…

C++ 类和对象(上)

类 面向对象的三大特性&#xff1a;封装&#xff0c;继承&#xff0c;多态 C语言结构体中只能定义变量&#xff0c;在C中&#xff0c;结构体内不仅可以定义变量&#xff0c;也可以定义函数。比如&#xff1a; 之前在数据结构初阶中&#xff0c;用C语言方式实现的栈&#xff0c;…

springboot入门和yaml数据格式和读取yaml型数据和多环境配置和命令行启动参数设置

springboot入门 搞掉了手动的spring&#xff0c;mybatis&#xff0c;springmvc配置类&#xff0c;只需要创建一个控制类即可 控制类&#xff1a; package com.itjh.controller;import org.springframework.web.bind.annotation.*;RestController RequestMapping("/book…

KDYZ-YM压敏电阻测试仪

一、概述 晶闸管的伏安特性是晶闸管的基本特性&#xff0c;这项特性的好坏&#xff0c;直接影响到器件在整机上的正常使用。因此&#xff0c;检测晶闸管的伏安特性在晶闸管器件的生产、经销及使用过程中都是十分重要的。 该测试仪的测试方法符合国标JB/T7624-94《整流二极管测试…

AI:人工智能领域AI工具产品集合分门别类(文本类、图片类、编程类、办公类、视频类、音频类、多模态类)的简介、使用方法(持续更新)之详细攻略

AI&#xff1a;人工智能领域AI工具产品集合分门别类(文本类、图片类、编程类、办公类、视频类、音频类、多模态类)的简介、使用方法(持续更新)之详细攻略 导读&#xff1a;由于ChatGPT、GPT-4近期火爆整个互联网&#xff0c;掀起了人工智能相关的二次开发应用的热潮&#xff0c…

MySQL 的 Replace into 与 Insert into on duplicate key update 真正的不同之处

相同点&#xff1a; &#xff08;1&#xff09;没有key的时候&#xff0c;replace与insert .. on deplicate udpate相同。 &#xff08;2&#xff09;有key的时候&#xff0c;都保留主键值&#xff0c;并且auto_increment自动1。 不同点 有key的时候&#xff0c;replace是dele…