TensorFlow入门(十二、分布式训练)

news2024/10/2 10:42:24

1、按照并行方式来分

        ①模型并行

                假设我们有n张GPU,不同的GPU被输入相同的数据,运行同一个模型的不同部分。

                在实际训练过程中,如果遇到模型非常庞大,一张GPU不够存储的情况,可以使用模型并行的分布式训练,把模型的不同部分交给不同的GPU负责。这种方式存在一定的弊端:①这种方式需要不同的GPU之间通信,从而产生较大的通信成本。②由于每个GPU上运行的模型部分之间存在一定的依赖,导致规模伸缩性差。

        ②数据并行

                假设我们有n张GPU,不同的GPU被输入不同的数据,运行相同的完整的模型。

                如果遇到一张GPU就能够存下一个模型的情况,可以采用数据并行的方式,这种方式的各部分独立,伸缩性好。

2、按照更新方式来分

        采用数据并行方式时,由于每个GPU负责一部分数据,涉及到如何更新参数的问题,因此分为同步更新和异步更新两种方式。

        ①同步更新

                所有GPU计算完每一个batch(也就是每批次数据)后,再统一计算新权值,等所有GPU同步新值后,再开始进行下一轮计算。

                同步更新的好处是loss的下降比较稳定,但是这个的坏处也很明显,这种方式有等待,处理的速度取决于最慢的那个GPU计算的时间。

        ②异步更新

                每个GPU计算完梯度后,无需等待其他GPU更新,立即更新整体权值并同步。

                异步更新的好处是计算速度快,计算资源能得到充分利用,但是缺点是loss的下降不稳定,抖动大。

3、按照算法来分

        ①Parameter Sever算法

                原理:假设我们有n张GPU,GPU0将数据分成n份分到各张GPU上,每张GPU负责自己那一批次数据的训练,得到梯度后,返回给GPU0上做累计,得到更新的权重参数后,再分发给各张GPU。

        ②Ring AllReduce算法

                原理:假设我们有n张GPU,它们以环形相连,每张GPU都有一个左邻和一个右邻,每张GPU向各自的右邻发送数据,并从它的左邻接近数据。循环n-1次完成梯度积累,再循环n-1次做参数同步。整个算法过程分两个步骤进行:首先是scatter_reduce,然后是allgather。在scatter-reduce,然后是allgather。在scatter-reduce步骤中,GPU将交换数据,使每个GPU可得到最终结果的一个块。在allgather步骤中,gpu将交换这些块,以便所有gpu得到完整的最终结果。

tf.distribute API:

        它是TensorFlow在多GPU、多机器上进行分布式训练用的API。使用这个API,可以在尽可能少改动代码的同时,分布式训练模型。

        它的核心API是tf.distribute.Strategy,只需简单几行代码就可以实现单机多GPU,多机多GPU等情况的分布式训练。

        它的主要优点:

                ①简单易用,开箱即用,高性能

                ②便于各种分布式Strategy切换

                ③支持Custom Training Loop、Estimator、Keras

                ④支持eager excution

tf.distribute.Strategy目前主要有四个Strategy:

        ①MirroredStrategy,即镜像策略

                MirroredStrategy用于单机多GPU、数据并行、同步更新的情况,它会在每个GPU上保存一份模型副本,模型中的每个变量都镜像在所有副本中。这些变量一起形成一个名为MirroredVariable的概念变量。通过apply相同的更新,这些变量保持彼此同步。

                创建一个镜像策略的方法如下:

                        mirrored_strategy = tf.distribute.MirroredStrategy()

                也可以自定义用哪些devices,如:

                        mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0","/gpu:1"])

                训练过程中,镜像策略用了高效的All-reduce算法来实现设备之间变量的传递更新。默认情况下它使用NVIDA NCCL (tf.distribute.NcclAllReduce)作为all-reduce算法的实现。通过apply相同的更新,这些变量保持彼此同步。

                官方也提供了其他的一些all-reduce实现方法,可供选择,如:

                        tf.distribute.CrossDeviceOps

                        tf.distribute.HierarchicalCopyAllReduce

                        tf.distribute.ReductionToOneDevice

        ②CentralStorageStrategy,即中心存储策略

                使用该策略时,参数被统一存在CPU里,然后复制到所有GPU上,它的优点是通过这种方式,GPU是负载均衡的,但一般情况下CPU和GPU通信代价比较大。

                创建一个中心存储策略的方法如下:

                             central_storage_strategy = tf.distribute.experimental.CentralStorageStratygy()

        ③MultiWorkerMirroredStrategy,即多端镜像策略

                该API和MirroredStrategy类似,它是其多机多GPU分布式训练的版本。

                创建一个多端镜像策略的方法如下:

                             multiworker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

        ④ParameterServerStrategy,即参数服务策略

                简称PS策略,由于计算速度慢和负载不均衡,很少使用这种策略。

                创建一个参数服务策略的方法如下:

                              ps_strategy = tf.distribute.experimental.ParameterServerStrategy()

示例代码如下:

import tensorflow as tf

#设置总训练轮数
num_epochs = 5
#设置每轮训练的批大小
batch_size_per_replica = 64
#设置学习率,指定了梯度下降算法中用于更新权重的步长大小
learning_rate = 0.001

#创建镜像策略
strategy = tf.distribute.MirroredStrategy()
#通过同步更新时副本的数量计算出本机的GPU设备数量
print("Number of devices: %d"% strategy.num_replicas_in_sync)
#通过副本数量乘以每轮训练的批大小,得出训练总数据量的大小
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync

#函数将输入的图片调整为224x224大小,再将像素值除以255进行归一化,同时返回标签信息
def resize(image,label):
    image = tf.image.resize(image,[224,224])/255.0
    return image,label

#载入数据集并预处理
dataset,_ = tf.keras.datasets.cifar10.load_data()
images,labels = dataset
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)

#在strategy.scope下创建模型和优化器
with strategy.scope():
    #载入了MobileNetV2模型,该模型在ImageNet上预先训练好了,并可以在分类问题上进行微调
    model = tf.keras.applications.MobileNetV2()
    #设置训练时用的优化器、损失函数和准确率评测标准
    model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate),
        loss = tf.keras.losses.sparse_categorical_crossentropy,
        metrics = [tf.keras.metrics.sparse_categorical_accuracy]
        )
    
#执行训练过程
model.fit(dataset,epochs = num_epochs)

对于CIFAR-10数据集下载过慢的问题,可以手动去官网下载

https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gzicon-default.png?t=N7T8https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz下载完成后将其放在如下图的路径下,并将数据集文件改名为cifar-10-batches-py.tar.gz并解压

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1072185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

COM组件IDispatch操作

IDispatch 组件接口,继承IUnkown,实现了反射机制,可以通过invoke调用dll函数 一般执行过程需要GetIDsOfNames、InvokeHelper函数执行,queryinterface查询获取对象 检查GetIDsOfNames返回的dispid是否正确 COleDispatchDriver 单…

Windows搭建FTP服务器

以win10为例: 安装FTP服务器支持和IIS管理平台。 设置->应用->程序与功能->启用或关闭windows功能-> Internet Infomation Services->勾选【FTP服务器】和Web管理工具的【IIS管理控制台】-> 点击确定等待安装完成。 打开IIS管理器 此电脑-右键…

微信小程序使用CryptoJS加密PassWord(MD5)

微信小程序使用CryptoJS加密PassWord(MD5) 背景及环境: 微信小程序登录页面,需要加密登录密码发送给后端,使用 MD5 来加密密码 开发工具:微信开发者工具 npm安装CryptoJS 查看有哪些crypto的包 npm search crypto 找到自己需要的包…

【AI】将图片制作成可编辑的图标

需要工具:QQ截图、AI描摹 图标网站:有用的网站收藏夹_h5动漫引擎-CSDN博客 第一:随便找个图标网站,选中自己需要的图标【付费的也可以】【肯定是图标越大越清晰】【此处做示范没用最大的图标】,用QQ截图(…

微信小程序:实现列表单选

效果 代码 wxml <view class"all"><view class"item_all" wx:for"{{info}}" wx:key"index"><view classposition {{item.checked?"checked_parameter":""}} data-id"{{item.employee_num}}…

(五)Python字符串常用方法详解

在了解字符串的基本使用之后&#xff0c;本章将介绍 Python 字符串类型常用的几个方法。 在 Python 开发过程中&#xff0c;经常需要对字符串进行一些特殊处理&#xff0c;比如拼接字符串、截取字符串、格式化字符串等&#xff0c;这些操作无需开发者自己设计实现&#xff0c;…

[天翼杯 2021]esay_eval - RCE(disabled_function绕过||AS_Redis绕过)+反序列化(大小写wakeup绕过)

[天翼杯 2021]esay_eval 1 解题流程1.1 分析1.2 解题1.2.1 一阶段1.2.2 二阶段 二、思考总结 题目代码&#xff1a; <?php class A{public $code "";function __call($method,$args){eval($this->code);}function __wakeup(){$this->code "";…

学习记忆——数学篇——算术——无理数

谐音记忆法 2 \sqrt{2} 2 ​≈1.41421&#xff1a;意思意思而已&#xff1b;意思意思&#xff1b; 3 \sqrt{3} 3 ​≈1.7320&#xff1a;—起生鹅蛋&#xff1b;一起生儿&#xff1b; 5 \sqrt{5} 5 ​≈2.2360679&#xff1a;两鹅生六蛋(送)六妻舅&#xff1b;儿儿生&#xf…

Kafka的分布式架构与高可用性

导语 一开始我们就说过Kafka是一款开源的高吞吐、分布式的消息队列系统&#xff0c;那么今天我们就来说下它的分布式架构和高可用性以及双/多中心部署。 Kafka 体系架构简介 以下是 Kafka 的软件架构&#xff0c;整个 Kafka 体系结构由 Producer、Consumer、Broker、ZooKeepe…

小程序等轻应用技术是不是对企业有价值?

技术的持续迭代发展和用户使用习惯的养成&#xff0c;影响了企业业务载体和创新方式的改变。回看企业与用户交互技术载体的变革&#xff0c;发现曾经是PC软件&#xff0c;然后是网页&#xff0c;再后来是App&#xff0c;之后是小程序形态的轻应用。 移动互联网风起云涌的数十年…

【C语言】Linux平台下解析pcap文件

开发环境是readhat、ubuntu、kali 在wireshark上抓包需要使用 Wireshark/tcpdump/ 且 文件后缀名为.pcap 方式保存 效果如下&#xff1a; 引入俩文件如下。 my_pcap.h #pragma once #include <netinet/in.h>#define PCAP_MAGIC 0xa1b2c3d4typedef struct pcap_file_he…

TCP 和UDP通信流程

TCP 通信流程 根据上图可以看到&#xff0c;TCP 服务器和客户端通信分为 TCP 服务端和客户端&#xff0c;需要先建立服务 端然后再建立客户端与之连接进行数据交互。 服务端编程步骤&#xff1a; 1.使用 socket 创建流式套接字 2.使用 bind 绑定将服务器绑定到 IP 3.listen…

测试小白必掌握软件测试十大原则

软件测试是确保软件质量的重要手段之一&#xff0c;它可以检测软件中的各种缺陷和问题&#xff0c;从而提高软件的可靠性、可用性和安全性。软件测试也是一项极富创造性、极具挑战性的工作。为了尽可能发现软件中的错误&#xff0c;提高软件产品的质量&#xff0c;在软件测试的…

不用休眠的 Kotlin 并发:深入对比 delay() 和 sleep()

本文翻译自&#xff1a; https://blog.shreyaspatil.dev/sleepless-concurrency-delay-vs-threadsleep 毫无疑问&#xff0c;Kotlin 语言中的协程 Coroutine 极大地帮助了开发者更加容易地处理异步编程。该特性中封装的诸多高效 API&#xff0c;可以确保开发者花费更小的精力去…

2023年中国隆鼻行业发展历程及趋势分析:隆鼻手术市场将实现进一步增长[图]

隆鼻术就是以各种植入材料置入为主要方法&#xff0c;隆起或抬高鼻部形态为主要目的的鼻整形术式。隆鼻术可能是开展最多的整形美容手术之一。隆鼻术也是一种很成熟的美容手术&#xff0c;操作较为简单、安全、风险较小&#xff0c;也易于接受。 隆鼻行业分类 资料来源&#x…

【2023研电赛】安谋科技企业命题特别奖:面向独居老人的智能居家监护系统

本文为2023年第十八届中国研究生电子设计竞赛安谋科技企业命题特别奖分享&#xff0c;参加极术社区的【有奖活动】分享2023研电赛作品扩大影响力&#xff0c;更有丰富电子礼品等你来领&#xff01;&#xff0c;分享2023研电赛作品扩大影响力&#xff0c;更有丰富电子礼品等你来…

滚雪球学Java(43):探究 Java 中的 Class 类:透视类的本质和实现原理

&#x1f3c6;本文收录于「滚雪球学Java」专栏&#xff0c;专业攻坚指数级提升&#xff0c;助你一臂之力&#xff0c;带你早日登顶&#x1f680;&#xff0c;欢迎大家关注&&收藏&#xff01;持续更新中&#xff0c;up&#xff01;up&#xff01;up&#xff01;&#xf…

科普丨语音芯片选型应遵守的原则

在选择语音芯片时&#xff0c;设计者应该首先详细了解设计要求&#xff0c;并从要求中整理出电路功能模块和性能指标要求。根据功能和性能要求&#xff0c;制定总体设计方案。一般来说&#xff0c;选择语音芯片有以下要求&#xff1a; 1、 性价比&#xff1a;选择物美价廉的语…

16.(开发工具篇mysql)mysql不同库同步数据的异常记录

1:mysql导入时出现“ERROR at line : Unknown command ‘\‘‘.“的解决办法 default-character-set=utf82:ERROR 2006 (HY000) at line 71: MySQL server has gone away (1) 连接超时 查看各项连接时间: show global variables like %timeout;这些值是相对是MySQL的默认…

Redis AOF重写原原理

重写aof之前 appendonly.aof.1.base.aof appendonly.aof.1.incr.aof appendonly.aof.manifest 重写aof 一次 appendonly.aof.2.base.aof 大小变化 appendonly.aof.2.incr.aof 大小o appendonly.aof.manifest 大小不变 AOF文件重写并不是对原文件进行重新整理&#xff0c;而是直…