机器学习洞察 | 分布式训练让机器学习更加快速准确

news2025/1/17 3:02:38

机器学习能够基于数据发现一般化规律的优势日益突显,我们看到有越来越多的开发者关注如何训练出更快速、更准确的机器学习模型,而分布式训练 (Distributed Training) 则能够大幅加速这一进程。

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

关于在亚马逊云科技上进行分布式训练的话题,在各种场合和论坛我们讨论了很多。随着 PyTorch 这一开源机器学习框架被越来越多的开发者在生产环境中使用,我们也将围绕它展开话题。本文我们将分别探讨在 PyTorch 上的两种分布式训练:数据分布式训练,以及模型分布式训练。

首先我们来看看当今机器学习模型训练的演进趋势中,开发者对模型训练结果的两种迫切需求:

  • 更快速

  • 更准确

更快速的数据分布式训练

对于机器学习模型训练来说,将庞大的训练数据有效拆分,有助于加快训练速度。

常见的数据分布式训练方式有两种:

基于参数服务器的数据分布式训练

异步)参数服务器 (Parameter Server) : 如 TensorFlow Parameter Server Strategy

对于参数服务器 (Parameter Server) 来说,计算节点被分成两种:

  • Workers:保留一部分的训练数据,并且执行计算;

  • Servers:共同维持全局共享的模型参数。

而 Workers 只和 Servers 有通信,Workers 相互之间没有通信。

参数服务器方式的优点开发者都很熟悉就不赘述了,而参数服务器的一个主要问题是它们对可用网络带宽的利用不够理想,Servers 常常成为通信瓶颈

由于梯度在反向传递期间按顺序可用,因此在任何给定的时刻,从不同服务器发送和接收的数据量都存在不平衡。有些服务器正在接收和发送更多的数据,有些很少甚至没有。随着参数服务器数量的增加,这个问题变得更加严重。

基于 Ring ALL-Reduce 的数据分布式训练

(同步)Ring All-Reduce: 如 Horovod 和 PyTorch DDP

Ring All-Reduce 的网络连接是一个环形,这样就不需要单独的 GPU 做 Server。6 个 GPU 独立做计算,用各自的数据计算出各自的随机梯度,然后拿 6 个随机梯度的相加之和来更新模型参数。为了求 6 个随机梯度之和,我们需要做 All-Reduce。在全部的 GPU 都完成计算之后,通过 Ring All-Reduce 转 2 圈(第 1 圈加和,第 2 圈广播这个加和),每个 GPU 就有了 6 个梯度的相加之和。注意算法必须是同步算法,因为 All-Reduce 需要同步(即等待所有的 GPU 计算出它们的梯度)。

Ring All-Reduce 的主要问题是:

  • 通过 Ring All-Reduce 转圈传递信息时,例如:G0 传递给 G1 时,其它 GPU 都在闲置状态;因此,这种步进时间越长,GPU 闲置时间就越长;而 GPU 越多这种通信代价就越大;

  • All-Reduce 的资源会占用宝贵的 GPU 资源,所以会在扩展的时候,面临效率挑战。

实例:Amazon SageMaker 数据并行的分布式方法

那么如何尽可能消除上述弊端?我们通过亚马逊云科技在 Amazon SageMaker 上的数据并行实例来演示如何解决这一问题。

SageMaker 从头开始构建新的 All-Reduce 算法,以充分利用亚马逊云科技网络和实例拓扑,利用 EC2 实例之间的节点到节点通信。

这样做的优势在于:

  • 引入了一种名为平衡融合缓冲区的新技术,以充分利用带宽。GPU 中的缓冲区将梯度保持到阈值大小,然后复制到 CPU 内存,分片成 N 个部分,然后将第 i 个部分发送到第 i 个服务器。平衡服务器发送和接收的数据,有效利用带宽。

  • 可以有效地将 All-Reduce 从 GPU 转移到 CPU。

我们能够重叠向后传递和 All-Reduce,从而缩短步进时间,释放 GPU 资源用于计算。

在这里分享关键的 PyTorch 代码步骤:

  1. 更新训练脚本

与非分布式训练不同的是,在这里我们输入 mdistributed.dataparallel.torch.torch_smdbp 的模型:

# Import SMDataParallel PyTorch Modules
import smdistributed.dataparallel.torch.torch_smddp
  1. 提交训练任务

在这里指定一个开关,打开数据并行即可。这样可以非常方便地调试,而不用在底层配置上花费时间。

# Training using SMDataParallel Distributed Training Framework
      distribution={"smdistributed": 
          {"dataparallel": 
                 {"enabled": True
                 }
          }
       },
      debugger_hook_config=False,

您可以在 GitHub 上查看完整代码示例:

Amazon Sagemaker Examples

  • https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/data_parallel/mnist/code/train_pytorch_smdataparallel_mnist.py?trk=cndc-detail

  • https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb?trk=cndc-detail

更准确的模型分布式训练

众所周知,模型越大,那么预测结果的准确度越高。

那么在面对庞大模型的时候如何进行模型并行?我们推荐开发者使用以下方式:

自动模型拆分

主要的优化策略基于内存使用和计算负载,从而更好地实现大模型的兼容。

  • 平衡内存使用:平衡每台设备上存储的可训练参数和激活次数的总数。

  • 平衡计算负载:平衡每台设备中执行的操作次数。

流水线执行计划

Amazon Sagemaker PyTorch SDK 中可以选择两种方式实现:

  • 简单流水线:需要等前项全部计算完之后才能进行后项的计算。

  • 交错流水线:通过更高效利用 GPU 来实现更好的性能,包括模型并行等方式。

Amazon SageMaker 分布式模型并行库的核心功能是流水线执行 (Pipeline Execution Schedule) ,它决定了模型训练期间跨设备进行计算和数据处理的顺序。流水线是一种通过让 GPU 在不同的数据样本上同时进行计算,从而在模型并行度中实现真正的并行化技术,并克服顺序计算造成的性能损失。

流水线基于将一个小批次拆分为微批次,然后逐个输入到训练管道中,并遵循库运行时定义的执行计划。微批次是给定训练微型批次的较小子集。管道调度决定了在每个时隙由哪个设备执行哪个微批次。例如,根据流水线计划和模型分区,GPU i 可能会在微批处理 b 上执行(向前或向后)计算,而 GPU i+1 对微批处理 b+1 执行计算,从而使两个 GPU 同时处于活动状态。

该库提供了两种不同的流水线计划,简单式和交错式,可以使用 SageMaker Python SDK 中的工作流参数进行配置。在大多数情况下,交错流水线可以通过更高效地利用 GPU 来实现更好的性能。

更多相关信息可参考:

SageMaker 模型并行库的核心功能 - 亚马逊 SageMaker

PyTorch 模型并行的分布式训练的关键步骤如下(以PyTorch SageMaker Distributed Model Parallel GPT2 代码为例):

  1. 更新训练脚本

a. Import 模型并行模块

b. 带入参数的模型并行的初始化 smp.int (smp_config)

详细代码请参考:https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py?trk=cndc-detail

  1. 提交训练任务

a. 消息传递接口 (MPI) 是编程并行计算机程序的基本通信协议,这里可描述每台机器上的 GPU 数量等参数

b. 激活模型分布式训练框架和相关配置等

详细代码请参考:https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py?trk=cndc-detail

实例:Amazon SageMaker 分布式训练

训练医疗计算机视觉 (CV) 模型需要可扩展的计算和存储基础架构。下图案例向您展示如何将医疗语义分割训练工作负载从 90 小时减少到 4 小时。

图片来源:官方博客《使用 Amazon SageMaker 训练大规模医疗计算机视觉模型》

解决方案中使用了 Amazon SageMaker 处理进行分布式数据处理,使用 SageMaker 分布式训练库来加快模型训练。数据 I/O、转换和网络架构是使用 PyTorch 和面向人工智能的医疗开放网络 (MONAI) 库构建的。

在下篇文章中,我们将继续关注无服务器推理,请持续关注 Build On Cloud 微信公众号。

往期推荐

  • 机器学习洞察 | 挖掘多模态数据机器学习的价值

作者黄浩文

亚马逊云科技资深开发者布道师,专注于 AI/ML、Data Science 等。拥有 20 多年电信、移动互联网以及云计算等行业架构设计、技术及创业管理等丰富经验,曾就职于 Microsoft、Sun Microsystems、中国电信等企业,专注为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

 文章来源:https://dev.amazoncloud.cn/column/article/63e32dd06b109935d3b77259?sc_medium=regulartraffic&sc_campaign=crossplatform&sc_channel=CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/738987.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

限制远程访问,保障服务器安全,如何指定某台电脑远程本服务器?

好多人都在问,如何限制某台电脑远程访问本服务器是一个必须要解决的问题。下面,我将为大家介绍几种限制远程访问的方法,帮助大家保障服务器的安全性。 1.修改远程桌面端口号 默认情况下,Windows服务器的远程桌面端口号…

时序预测 | Matlab+Python实现基于高斯混合模型聚类结合CNN-BiLSTM-Attention的风电场短期功率预测

时序预测 | MatlabPython实现基于高斯混合模型聚类结合CNN-BiLSTM-Attention的风电场短期功率预测 目录 时序预测 | MatlabPython实现基于高斯混合模型聚类结合CNN-BiLSTM-Attention的风电场短期功率预测效果一览基本介绍模型描述程序设计参考资料 效果一览 基本介绍 基于高斯混…

1.Git使用技巧-常用命令3

1.Git使用技巧-常用命令3 文章目录 1.Git使用技巧-常用命令3一、版本分支介绍二、版本控制常用命令例子 三、git 仓库如何使用总结 一、版本分支介绍 分支介绍: Master : 稳定压倒一切,禁止尚review和测试过的代码提交到这个分支上&#xff…

1.2 向量基础

什么是向量 向量的定义 ①向量是有大小和方向的有向线段。 ②向量没有位置,只有大小和方向 ③向量的箭头是向量的结束,尾是向量的开始 ④向量魔术的位移能被认为是宇宙平行的唯一序列 (向量的数组不是向量的位置,而是向量在各个维…

C++多线程学习(十七、简单实现线程池)

目录 线程池 设计线程池的关键问题 代码 可能出现的疑问 queue> task; 总结: template auto InsertQueue(T&& t, Args&& ...args)->future;(t(args...))> 总结: ThreadPool(size_t size);构造函数 总结&#xff1…

在SpringBoot中搭建微服务的项目(19版)

1.创建SpringBoot项目 2.删除不需要的,留一个pom文件 3.掉地SpringBoot的版本: <version>2.1.6.RELEASE</version> 4.导入该pom文件 <dependencies> <!-- SpringBoot启动器--><dependency><groupId>org.springframework.boot</g…

关于Redis因OAuth 2.0内存溢出解决方案

一、背景介绍 1.问题简介 本次问题是由OAuth 2.0授权框架&#xff08;用于授权第三方应用程序【客户端】访问受保护的资源。&#xff09;存储在Redis集群中的一个key引起的&#xff1a;client_id_to_access&#xff08;或称为 “client ID to access”&#xff09;通常是指在O…

安全用电管理平台针对电气火灾的解决方案 安科瑞 许敏

摘要&#xff1a; 安全用电管理平台是针对我国当前电气火灾事故频发而设计的一套电气火灾预警和预防管理系统&#xff0c;该系统是基于移动互联网、云计算技术、通过物联网传感终端&#xff08;现场监控模块、传输模块&#xff09;&#xff0c;将供电侧、用电侧电气安全参数实时…

java 打包Spring Boot项目,并运行在windows系统中

前面呢 我们已经把Spring Boot比较基础的东西都弄完了 然后呢 我们来看运维这方面的知识 首先 我们做个打包运行 其实很多人可能会比较熟悉windows系统 而linux服务器 相对没那么了解 那么我们就先来弄windows的 首先 我们要知道 为什么要打包 我们就看我们前面做的MMP项目 当…

git轻量级服务器gogs

确保本真机已启动sshd服务 sudo apt install openssh-server -y sudo systemctl start sshgogs部署 启动 sudo docker stop gogs; sudo docker rm gogs; rm -fr /build/gogs_data/*; sudo docker run --namegogs -p 10022:22 -p 10880:3000 -v /build/gogs_data:/data …

布雷默浪丹 PT 141:189691-06-3,1607799-13-2,Bremelanotide,布美诺肽

Bremelanotide&#xff0c;布雷默浪丹 PT 141&#xff0c;布美诺肽Product structure&#xff1a; Product specifications&#xff1a; 1.CAS No&#xff1a;189691-06-3/1607799-13-2 2.Molecular formula&#xff1a;C50H68N14O10 3.Molecular weight&#xff1a;1025.063 4…

抖音seo矩阵系统源码开发部署-技术开源(三)

场景&#xff1a;抖音seo源码。抖音矩阵源码&#xff0c;短视频seo源码&#xff0c;短视频矩阵源码开发部署&#xff0c;技术分享&#xff0c; 一、 抖音seo源码开发所需服务器环境配置 要开发抖音SEO矩阵系统&#xff0c;需要以下服务器环境&#xff1a; Web服务器&#xff…

Jmeter的常用设置(二)【处理乱码问题】

文章目录 前言一、察看结果树响应结果是乱码_解决方法 方法一&#xff1a;在察看结果树之前添加 后置处理器 中的 “BeanShell PostProcessor” 来动态修改结果处理编码方法二&#xff1a;在配置文件中修改二、使用步骤 1.引入库2.读入数据总结 前言 接口测试中遇到的各种问题…

使用 ViteJs 将 Jest 测试集成到现有的 Vue 3 项目中

根据我最近的经验&#xff0c;我面临着将 Jest 测试框架集成到使用Vite构建的现有Vue3 js项目中的挑战。我在各种博客上找到有用的安装指南时遇到了困难。然而&#xff0c;经过多次尝试和付出很大的努力&#xff0c;我最终找到了解决方案。在这篇博文中&#xff0c;我的目标是提…

2023黑马头条.微服务项目.跟学笔记(五)

2023黑马头条.微服务项目.跟学笔记 五 延迟任务精准发布文章1.文章定时发布2.延迟任务概述2.1 什么是延迟任务2.2 技术对比2.2.1 DelayQueue2.2.2 RabbitMQ实现延迟任务2.2.3 redis实现 3.redis实现延迟任务4.延迟任务服务实现4.1 搭建heima-leadnews-schedule模块4.2 数据库准…

Swagger简介及Springboot集成Swagger详细教程

Swagger简介及Springboot集成Swagger详细教程 学习目标 了解Swagger的作用和概念了解前后端分离在SpringBoot中集成Swagger 1、Swagger简介 前后端分离 VueSpringBoot 后端时代 前端只用管理静态页面&#xff1b;html–>后端。模版引擎JSP–>后端是主力 前后端分离式时…

获取mysql存储过程的异常信息

示例 CREATE DEFINERrootlocalhost PROCEDURE getErrorMsg() BEGIN-- 定义存储变量DECLARE code CHAR(5) DEFAULT ;DECLARE msg TEXT;DECLARE result TEXT;-- 声明异常处理DECLARE CONTINUE HANDLER FOR SQLEXCEPTIONBEGIN-- 获取异常code,异常信息GET DIAGNOSTICS CONDITION …

基于单片机的恒温恒湿温室大棚温湿度控制系统的设计与实现

功能介绍 以51单片机作为主控系统&#xff1b;液晶显示当前温湿度按键设置温湿度报警上限和下限&#xff1b;温度低于下限继电器闭合加热片进行加热&#xff1b;温度超过上限继电器闭合开启风扇进行降温湿度低于下限继电器闭合加湿器进行加湿湿度高于上限继电器闭合开启风扇进行…

干翻Dubbo系列第三篇:Dubbo术语与第一个应用程序

前言 不从恶人的计谋&#xff0c;不站罪人的道路&#xff0c;不坐亵慢人的座位&#xff0c;惟喜爱耶和华的律法&#xff0c;昼夜思想&#xff0c;这人便为有福&#xff01;他要像一棵树栽在溪水旁&#xff0c;按时候结果子&#xff0c;叶子也不枯干。凡他所做的尽都顺利。 如…

小程序页面顶部标题栏、导航栏navigationBar如何隐藏、变透明?

在app.json中的 "window"下面追加一行 "navigationStyle": "custom" 小程序顶部的白色背景条就不见了&#xff0c;直接变透明&#xff0c;只剩下右上角的胶囊按钮 警告&#xff1a; 如果页面有 <web-view src"{{src}}" /> …