使用FP8加速PyTorch训练

news2024/11/24 6:58:58

现代的人工智能硬件架构(例如,Nvidia Hopper, Nvidia Ada Lovelace和Habana Gaudi2)中,FP8张量内核能够显著提高每秒浮点运算(FLOPS),以及为人工智能训练和推理工作负载提供内存优化和节能的机会。

在这篇文章中,我们将介绍如何修改PyTorch训练脚本,利用Nvidia H100 GPU的FP8数据类型的内置支持。这里主要介绍由Transformer Engine库公开的fp8特定的PyTorch API,并展示如何将它们集成到一个简单的训练脚本中。(我们这里只介绍如何使用FP8,不会介绍FP8具体的理论知识)

随着人工智能模型变得越来越复杂,训练它们所需的机器也越来越复杂。Nvidia H100 GPU据称支持“前所未有的性能和可扩展性”。

在AWS中,H100 gpu是作为AWS EC2 p5实例的一个组件提供的。这些实例声称“与上一代基于gpu的EC2实例相比,可将解决方案的时间加快4倍,并将训练ML模型的成本降低高达40%”。

当涉及到机器学习训练实例时,并不总是越大越好。p5实例族尤其如此。p5可能会比其他实例要快很多,因为H100是无可争议的性能野兽。但是一旦考虑到p5的成本(8-GPU p5.48xlarge实例的成本为每小时98.32美元),你可能会发现其他实例类型更适合。

下面我们将在p5.48xlarge上训练一个相对较大的计算机视觉模型,并将其性能与p4d进行比较。p4d.24xlarge包含8个Nvidia A100 gpu。

模型

我们定义了一个Vision Transformer (ViT)支持的分类模型(使用流行的timm Python包版本0.9.10)以及一个随机生成的数据集。ViT主干有多种形状和大小。我们选择了通常被称为ViT-Huge的配置-具有6.32亿个参数-这样能够更好地利用H100对大型模型的容量。

 import torch, time
 import torch.optim
 import torch.utils.data
 import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 import torch.multiprocessing as mp
 
 # modify batch size according to GPU memory
 batch_size = 64
 
 from timm.models.vision_transformer import VisionTransformer
 
 from torch.utils.data import Dataset
 
 
 # use random data
 class FakeDataset(Dataset):
     def __len__(self):
         return 1000000
 
     def __getitem__(self, index):
         rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
         label = torch.tensor(data=[index % 1000], dtype=torch.int64)
         return rand_image, label
 
 
 def mp_fn(local_rank, *args):
     # configure process
     dist.init_process_group("nccl",
                             rank=local_rank,
                             world_size=torch.cuda.device_count())
     torch.cuda.set_device(local_rank)
     device = torch.cuda.current_device()
     
     # create dataset and dataloader
     train_set = FakeDataset()
     train_loader = torch.utils.data.DataLoader(
         train_set, batch_size=batch_size,
         num_workers=12, pin_memory=True)
 
     # define ViT-Huge model
     model = VisionTransformer(
             embed_dim=1280,
             depth=32,
             num_heads=16,
         ).cuda(device)
     model = DDP(model, device_ids=[local_rank])
 
     # define loss and optimizer
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
     model.train()
 
     t0 = time.perf_counter()
     summ = 0
     count = 0
 
     for step, data in enumerate(train_loader):
         # copy data to GPU
         inputs = data[0].to(device=device, non_blocking=True)
         label = data[1].squeeze(-1).to(device=device, non_blocking=True)
   
         # use mixed precision to take advantage of bfloat16 support
         with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
             outputs = model(inputs)
             loss = criterion(outputs, label)
         optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()
         
         # capture step time
         batch_time = time.perf_counter() - t0
         if step > 10:  # skip first steps
             summ += batch_time
             count += 1
         t0 = time.perf_counter()
         if step > 50:
             break
     print(f'average step time: {summ/count}')
 
 
 if __name__ == '__main__':
     mp.spawn(mp_fn,
              args=(),
              nprocs=torch.cuda.device_count(),
              join=True)

我们使用专用PyTorch 2.1 AWS深度学习容器(763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-ec2)在p5.48xlarge和p4d上都训练了这个模型。

p5的性能远远超过了p4d的性能——每步0.199秒比0.41秒——快了两倍多!!这意味着训练大型机器学习模型的时间将减少一半。但是当你考虑到成本的差异(p4d每小时32.77美元,p5每小时98.32美元),p5的性价比比p4d差30% !!

在这一点上,可能会得出两个可能的结论之一。第一种可能性是,尽管有这么多宣传,但p5根本不适合您。第二个是p5仍然是可行的,但是需要对模型进行调整,充分利用它的潜力。

FP8与Transformer Engine的集成

PyTorch(版本2.1)不包括FP8数据类型。为了将我们的脚本编程为使用FP8,我们将使用Transformer Engine (TE),这是一个用于在NVIDIA gpu上加速Transformer模型的专用库。TE(版本0.12)预装在AWS PyTorch 2.1 DL容器中。

使用FP8的机制比16位(float16和bfloat16)要复杂得多。TE库实现向用户隐藏了所有杂乱的细节。有关如何使用TE api的说明(请参阅官方文档)。

为了修改我们的模型以使用TE,我们将TE的专用Transformer层,所以需要我们自己写一个包装器:

 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 
 
 class TE_Block(te.transformer.TransformerLayer):
     def __init__(
             self,
             dim,
             num_heads,
             mlp_ratio=4.,
             qkv_bias=False,
             qk_norm=False,
             proj_drop=0.,
             attn_drop=0.,
             init_values=None,
             drop_path=0.,
             act_layer=None,
             norm_layer=None,
             mlp_layer=None
     ):
         super().__init__(
             hidden_size=dim,
             ffn_hidden_size=int(dim * mlp_ratio),
             num_attention_heads=num_heads,
             hidden_dropout=proj_drop,
             attention_dropout=attn_drop
             )

然后修改VisionTransformer初始化自定义块:

   model = VisionTransformer(
       embed_dim=1280,
       depth=32,
       num_heads=16,
       block_fn=TE_Block
       ).cuda(device)

到目前为止,还没有做任何针对h100特定的更改-相同的代码可以在我们的a100的p4d实例类型上运行。最后一个修改是用te包裹模型前向传递。Fp8_autocast上下文管理器。此更改需要支持FP8的GPU:

 with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
     with te.fp8_autocast(enabled=True):
         outputs = model(inputs)
     loss = criterion(outputs, label)

关于使用FP8的一些注意事项

使用8位浮点表示(相对于16位或32位表示)意味着较低的精度和较低的动态范围。这些可以对模型收敛的可达性和/或速度产生有意义的影响,但不能保证这将适用于所有的模型。所以可能需要调整底层FP8机制(例如,使用TEapi),调整一些超参数,和/或将FP8的应用限制在模型的子模型(一部分)。最坏的可能是尽管进行了所有尝试,模型还是无法与FP8兼容。

结果

在下表中总结了在两个p4d上的实验结果。24xlarge和p5.48xlarge EC2实例类型,使用和不使用TE库。对于p5.48xlarge实验,我们将批处理大小加倍,这样提高80 GB GPU内存的利用率。使用FP8可以减少GPU内存消耗,从而进一步增加批处理大小。

可以看到,使用TE提高了p4d(19%)和p5(32%)的性价比。使用FP8可将p5上的性能额外提高约20%。在TE和FP8优化之后,基于h100的p5.48large的性价比优于基于a100的p4d.xlarge——虽然差距不大(2%)。考虑到训练速度提高了3倍,我们可以有把握地得出结论,p5将是训练优化模型的更好的实例类型。

但是我们也看到了,这是相对较小的性价比提升(远低于p5公告中提到的40%),所以可能还有更多的优化方案,我们需要继续研究。

总结

在这篇文章中,我们演示了如何编写PyTorch训练脚本来使用8位浮点类型。展示了FP8的使用是如何从Nvidia H100中获得最佳性能的关键因素。FP8的可行性及其对训练性能的影响可以根据模型的细节而变化很大。

https://avoid.overfit.cn/post/541a04c656db474d91ee5eb1fa5bc5f8

作者:Chaim Rand

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1220208.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DBeaver 23.2.4发布

导读DBeaver 23.2.4发布,修复和添加了一些新功能。 SQL 编辑器 为更新语句添加了代码自动补全功能修复了智能补全和替换带引号表达式的问题删除了日志中首次打开 SQL 编辑器时出现的多余错误 数据库导航器 添加了过滤对象可视化功能修复了脚本文件夹打开问题 数据传输 正确…

消除“数据烟囱”,瓴羊港如何打破壁垒将多数据融通成大数据?

作为数字经济时代的“新石油”,数据已成为重要的生产要素。阿里巴巴副总裁、瓴羊CEO朋新宇认为,目前正处在数据流通变革的时代,其中最核心的问题是如何破解数实融合发展的堵点。数据流通中最重要的原则是,不流通无价值&#xff0c…

echarts 实现同一组legend控制两个饼图示例

实现同一组legend控制两个饼图示例: 该示例有如下几个特点: ①饼图不同值实现分割 ②实现tooltip自定义样式(echarts 实现tooltip提示框样式自定义-CSDN博客) ③自定义label内容 ④不同值颜色渐变 代码如下: this.o…

html实现图片裁剪处理(附源码)

文章目录 1.设计来源1.1 主界面1.2 裁剪界面 2.效果和源码2.1 动态效果2.2 源代码 源码下载 作者:xcLeigh 文章地址:https://blog.csdn.net/weixin_43151418/article/details/134455169 html实现图片裁剪处理(附源码),支持图片放大缩小&#…

Linux下安装部署redis(离线模式)

一、准备工作 1.下载redis的安装包 下载地址:Index of /releases/ 大家可以自行选择redis的版本,笔者选择的是最新的 2.上传到服务器 前提是我先在服务器上创建了一个目录redis7.2.3,我直接上传到这个目录下 二、安装redis 1.解压redis t…

阿里云服务器ECS安装宝塔面板

前言 如今各种云服务器租借平台,例如腾讯云、阿里云之类的,很轻松的就能租借得到一台Linux的服务器。但是Linux的管理和使用存在一定的门槛。宝塔面板作为一款流行的服务器管理软件,提供了简单易用的图形化界面和丰富的管理功能,降…

Android SmartTable根据int状态格式化文字及颜色

private void initData() {List<UserInfo> list new ArrayList<>();list.add(new UserInfo("一年级", "李同学", 6, 1, 120, 1100, 450, 0));list.add(new UserInfo("一年级", "张同学", 6, 2, 120, 1100, 450, 1));list…

electron使用better-sqlite3打包失败(electron打包有进程没有界面)

remove *\chrome_100_percent.pak: Access is denied. 解决&#xff1a; 管理员权限执行&#xff1a;taskkill /IM 你的进程名.exe /F&#xff0c;再次执行build electron使用better-sqlite3打包后有进程没有界面 原因是代码及依赖包安装有误&#xff0c;模块丢失。主要分享的…

Flat Ads将在杭州举办社交出海沙龙,探寻海外巨大增量空间

深圳站落幕后&#xff0c;Flat Ads社交沙龙活动迎来杭州站&#xff01;11月29日&#xff0c;Flat Ads联动Alibaba Cloud、TopOn、融云&#xff0c;开展《泛娱乐社交APP出海新风口-杭州站》&#xff0c;分享如何捕捉出海新赛道的风向标&#xff0c;并迅速实现获客增长&#xff0…

36 mysql 主键冲突 和 唯一索引冲突

前言 我们这里 来看一下 我们经常碰到的 "duplicate key xxx" 测试表结构如下 CREATE TABLE tz_test (id int(11) unsigned NOT NULL AUTO_INCREMENT,field1 varchar(128) DEFAULT NULL,PRIMARY KEY (id) USING BTREE,KEY field1 (field1) USING BTREE ) ENGINEI…

超声功率放大器使用范围有哪些

超声功率放大器是一种特殊的设备&#xff0c;用于放大超声波信号的功率级别。它在各种领域都有广泛的应用范围&#xff0c;下面将详细介绍超声功率放大器的使用范围。 医学影像领域&#xff1a; 在医学影像领域&#xff0c;超声功率放大器被广泛用于超声诊断设备。它们能够放大…

亲测一款超实用的在线制作产品册工具,一看就会

最近&#xff0c;我一直在寻找一款简单易用的在线制作产品册工具&#xff0c;终于让我找到了一个超实用的神器&#xff01;这款工具不仅功能强大&#xff0c;而且操作简单&#xff0c;一看就会。 首先&#xff0c;这款工具提供了丰富的模板和素材&#xff0c;用户可以根据自己的…

【测试功能篇 01】Jmeter 压测接口最大并发量、吞吐量、TPS

压力测试&#xff0c;我们针对比较关键的接口&#xff0c;可以进行相应的压力测试&#xff0c;主要还是测试看看接口能抗住多少的请求数&#xff0c;TPS稳定在多少&#xff0c;也就是吞吐量多少 安装 Jmeter的安装很简单&#xff0c;官网下载地址 http://jmeter.apache.org/ &…

万字长文:从 C# 入门学会 RabbitMQ 消息队列编程

RabbitMQ 简介 RabbitMQ 是一个实现了 AMQP 协议的消息队列&#xff0c;AMQP 被定义为作为消息传递中间件的开放标准的应用层协议。它代表高级消息队列协议&#xff0c;具有消息定位、路由、队列、安全性和可靠性等特点。 目前社区上比较流行的消息队列有 kafka、ActiveMQ、Pul…

mac中安装Homebrew

1、Homebrew是什么&#xff1f; 软件安装管理工具 2、先检查电脑中是否已经安装了Homebrew 打开终端输入&#xff1a;brew 提示命令没有找到&#xff0c;说明电脑没有安装Homebrew 如果提示上述图片说明Homebrew已经安装成功 3、安装Homebrew 进入https://brew.sh/ 复制的命…

3.ubuntu20.04环境的ros搭建

ros搭建比较简单&#xff0c;主要步骤如下&#xff1a; 1.配置ros软件源&#xff1a; sudo sh -c echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list 2.配置密钥 sudo apt-key adv --keyser…

01线性回归

目录 常规求解&#xff1a; 矩阵求解 sklean算法求解 # 二元一次方程 # x y 14 # 2x - y 10 常规求解&#xff1a; x np.array([[1,1],[2,-1]])print(x) # [[ 1 1] # [ 2 -1]]y np.array([14, 10])w np.linalg.solve(x, y)print(正常求救&#xff1a;)print(w) …

在Centos7.9_2207安装CDH6.3.2

在Centos7.9_2207安装CDH6.3.2 背景 笔者做大数据开发&#xff0c;实时部分一般要用到HBase、Kudu、Redis等组件来保证幂等性&#xff0c;为了方便&#xff0c;还是选用老古董CDH6.3.2【最后的免费版】做一个单节点机器&#xff0c;方便随时挂起。多节点虚拟机由之前的双路E5…

实践小记——C#科学计数法格式化输出

文章速览 示例默认输出&#xff0c;不设置小数精度设置尾数部分的小数精度 总结参考文章 坚持记录实属不易&#xff0c;希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区&#xff01; 谢谢~ 示例 默认输出&#xff0c;不设置小数精度 private void Fo…