使用FP8加速PyTorch训练的两种方法总结

news2024/10/6 6:40:58

在PyTorch中,FP8(8-bit 浮点数)是一个较新的数据类型,用于实现高效的神经网络训练和推理。它主要被设计来降低模型运行时的内存占用,并加快计算速度,同时尽量保持训练和推理的准确性。虽然PyTorch官方在标准发布中尚未全面支持FP8,但是在2.2版本中PyTorch已经包含了对FP8的“有限支持”并且出现了2个新的变量类型,

torch.float8_e4m3fn

torch.float8_e5m2

,而H100也支持这种类型,所以这篇文章我们就来介绍如何使用FP8来提高训练效率

模型架构

我们定义了一个Vision Transformer (ViT)支持的分类模型(使用流行的timm Python包版本0.9.10)以及一个随机生成的数据集。我们选择了ViT-Huge的有6.32亿个参数的最大的模型,这样可以演示FP8的效果。

 import torch, time
 import torch.optim
 import torch.utils.data
 import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 import torch.multiprocessing as mp
 
 # modify batch size according to GPU memory
 batch_size = 64
 
 from timm.models.vision_transformer import VisionTransformer
 
 from torch.utils.data import Dataset
 
 
 # use random data
 class FakeDataset(Dataset):
     def __len__(self):
         return 1000000
 
     def __getitem__(self, index):
         rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
         label = torch.tensor(data=[index % 1000], dtype=torch.int64)
         return rand_image, label
 
 
 def mp_fn(local_rank, *args):
     # configure process
     dist.init_process_group("nccl",
                             rank=local_rank,
                             world_size=torch.cuda.device_count())
     torch.cuda.set_device(local_rank)
     device = torch.cuda.current_device()
     
     # create dataset and dataloader
     train_set = FakeDataset()
     train_loader = torch.utils.data.DataLoader(
         train_set, batch_size=batch_size,
         num_workers=12, pin_memory=True)
 
     # define ViT-Huge model
     model = VisionTransformer(
             embed_dim=1280,
             depth=32,
             num_heads=16,
         ).cuda(device)
     model = DDP(model, device_ids=[local_rank])
 
     # define loss and optimizer
     criterion = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
     model.train()
 
     t0 = time.perf_counter()
     summ = 0
     count = 0
 
     for step, data in enumerate(train_loader):
         # copy data to GPU
         inputs = data[0].to(device=device, non_blocking=True)
         label = data[1].squeeze(-1).to(device=device, non_blocking=True)
   
         # use mixed precision to take advantage of bfloat16 support
         with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
             outputs = model(inputs)
             loss = criterion(outputs, label)
         optimizer.zero_grad(set_to_none=True)
         loss.backward()
         optimizer.step()
         
         # capture step time
         batch_time = time.perf_counter() - t0
         if step > 10:  # skip first steps
             summ += batch_time
             count += 1
         t0 = time.perf_counter()
         if step > 50:
             break
     print(f'average step time: {summ/count}')
 
 
 if __name__ == '__main__':
     mp.spawn(mp_fn,
              args=(),
              nprocs=torch.cuda.device_count(),
              join=True)

Transformer Engine

PyTorch(版本2.1)不包括FP8的数据类型。所以我们需要通过第三方的库Transformer Engine (TE),这是一个用于在NVIDIA gpu上加速Transformer模型的专用库。

使用FP8要比16float16和bfloat16复杂得多。这里我们不用关心细节,因为TE都已经帮我们实现了,我们只要拿来用就可以了。

但是需要对我们上面的模型进行一些简单的修改,需要将transformer变为TE的专用transformer层

 import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 
 
 class TE_Block(te.transformer.TransformerLayer):
     def __init__(
             self,
             dim,
             num_heads,
             mlp_ratio=4.,
             qkv_bias=False,
             qk_norm=False,
             proj_drop=0.,
             attn_drop=0.,
             init_values=None,
             drop_path=0.,
             act_layer=None,
             norm_layer=None,
             mlp_layer=None
     ):
         super().__init__(
             hidden_size=dim,
             ffn_hidden_size=int(dim * mlp_ratio),
             num_attention_heads=num_heads,
             hidden_dropout=proj_drop,
             attention_dropout=attn_drop
             )

然后修改VisionTransformer初始化使用自定义层:

   model = VisionTransformer(
       embed_dim=1280,
       depth=32,
       num_heads=16,
       block_fn=TE_Block
       ).cuda(device)

最后一个修改是用te包裹模型前向传递。Fp8_autocast上下文管理器。此更改需要支持FP8的GPU:

 with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
     with te.fp8_autocast(enabled=True):
         outputs = model(inputs)
     loss = criterion(outputs, label)

下面我们就可以测试结果:

可以看到,使用TE块提高了p4d(19%)和p5(32%)的性价比。使用FP8可将p5上的性能额外提高约20%。在TE和FP8优化之后,基于h100的p5.48large的性价比优于基于a100的p4d.24large 。并且训练速度提高了3倍。

Pytorch的原生FP8

在2.2版本后,pytorch原生FP8支持已经是“有限支持”了,所以我们可以先学习一下如何使用了。

 import torch
 from tabulate import tabulate
 
 f32_type = torch.float32
 bf16_type = torch.bfloat16
 e4m3_type = torch.float8_e4m3fn
 e5m2_type = torch.float8_e5m2
 
 # collect finfo for each type
 table = []
 for dtype in [f32_type, bf16_type, e4m3_type, e5m2_type]:
     numbits = 32 if dtype == f32_type else 16 if dtype == bf16_type else 8
     info = torch.finfo(dtype)
     table.append([info.dtype, numbits, info.max, 
                   info.min, info.smallest_normal, info.eps])
 
 headers = ['data type', 'bits', 'max', 'min', 'smallest normal', 'eps']
 print(tabulate(table, headers=headers))
 
 '''
 Output:
 
 data type      bits          max           min  smallest normal          eps
 -------------  ----  -----------  ------------  ---------------  -----------
 float32          32  3.40282e+38  -3.40282e+38      1.17549e-38  1.19209e-07
 bfloat16         16  3.38953e+38  -3.38953e+38      1.17549e-38    0.0078125
 float8_e4m3fn     8          448          -448         0.015625        0.125
 float8_e5m2       8        57344        -57344      6.10352e-05         0.25
 '''

我们可以通过在张量初始化函数中指定dtype来创建FP8张量,如下所示:

 device="cuda"
 e4m3 = torch.tensor(1., device=device, dtype=e4m3_type)
 e5m2 = torch.tensor(1., device=device, dtype=e5m2_type)

也可以强制转换为FP8。在下面的代码中,我们生成一个随机的浮点张量,并比较将它们转换为四种不同的浮点类型的结果:

 x = torch.randn(2, 2, device=device, dtype=f32_type)
 x_bf16 = x.to(bf16_type)
 x_e4m3 = x.to(e4m3_type)
 x_e5m2 = x.to(e5m2_type)
 
 print(tabulate([[‘float32’, *x.cpu().flatten().tolist()],
                 [‘bfloat16’, *x_bf16.cpu().flatten().tolist()],
                 [‘float8_e4m3fn’, *x_e4m3.cpu().flatten().tolist()],
                 [‘float8_e5m2’, *x_e5m2.cpu().flatten().tolist()]],
                headers=[‘data type’, ‘x1’, ‘x2’, ‘x3’, ‘x4’]))
 
 '''
 The sample output demonstrates the dynamic range of the different types:
 
 data type                  x1              x2              x3              x4
 -------------  --------------  --------------  --------------  --------------
 float32        2.073093891143  -0.78251332044  -0.47084918620  -1.32557279110
 bfloat16       2.078125        -0.78125        -0.4707031      -1.328125
 float8_e4m3fn  2.0             -0.8125         -0.46875        -1.375
 float8_e5m2    2.0             -0.75           -0.5            -1.25
 -------------  --------------  --------------  --------------  --------------
 '''

虽然创建FP8张量很容易,但FP8张量上执行一些基本的算术运算是不支持的。并且需要特定的函数,比如torch._scaled_mm来进行矩阵乘法。

 output, output_amax = torch._scaled_mm(
         torch.randn(16,16, device=device).to(e4m3_type),
         torch.randn(16,16, device=device).to(e4m3_type).t(),
         bias=torch.randn(16, device=device).to(bf16_type),
         out_dtype=e4m3_type,
         scale_a=torch.tensor(1.0, device=device),
         scale_b=torch.tensor(1.0, device=device)
     )

那么如何进行模型的训练呢,我们来做一个演示

 import torch
 from timm.models.vision_transformer import VisionTransformer
 from torch.utils.data import Dataset, DataLoader
 import os
 import time
 
 #float8 imports
 from float8_experimental import config
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history
 )
 
 #float8 configuration (see documentation)
 config.enable_amax_init = False
 config.enable_pre_and_post_forward = False
 
 # model configuration controls:
 fp8_type = True # toggle to change floating-point precision
 compile_model = True # toggle to enable model compilation
 batch_size = 32 if fp8_type else 16 # control batch size
 
 device = torch.device('cuda')
 
 # use random data
 class FakeDataset(Dataset):
     def __len__(self):
         return 1000000
     def __getitem__(self, index):
         rand_image = torch.randn([3, 256, 256], dtype=torch.float32)
         label = torch.tensor(data=[index % 1024], dtype=torch.int64)
         return rand_image, label
 
 # get data loader
 def get_data(batch_size):
     ds = FakeDataset()
     return DataLoader(
            ds,
            batch_size=batch_size, 
            num_workers=os.cpu_count(),
            pin_memory=True
          )
 
 # define the timm model
 def get_model():
     model = VisionTransformer(
         class_token=False,
         global_pool="avg",
         img_size=256,
         embed_dim=1280,
         num_classes=1024,
         depth=32,
         num_heads=16
     )
     if fp8_type:
         swap_linear_with_float8_linear(model, Float8Linear)
     return model
 
 # define the training step
 def train_step(inputs, label, model, optimizer, criterion):
     with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
         outputs = model(inputs)
         loss = criterion(outputs, label)
     optimizer.zero_grad(set_to_none=True)
     loss.backward()
     if fp8_type:
         sync_float8_amax_and_scale_history(model)
     optimizer.step()
 
 
 model = get_model()
 optimizer = torch.optim.Adam(model.parameters())
 criterion = torch.nn.CrossEntropyLoss()
 train_loader = get_data(batch_size)
 
 # copy the model to the GPU
 model = model.to(device)
 if compile_model:
     # compile model
     model = torch.compile(model)
 model.train()
 
 t0 = time.perf_counter()
 summ = 0
 count = 0
 
 for step, data in enumerate(train_loader):
     # copy data to GPU
     inputs = data[0].to(device=device, non_blocking=True)
     label = data[1].squeeze(-1).to(device=device, non_blocking=True)
 
     # train step
     train_step(inputs, label, model, optimizer, criterion)
 
     # capture step time
     batch_time = time.perf_counter() - t0
     if step > 10:  # skip first steps
         summ += batch_time
         count += 1
     t0 = time.perf_counter()
     if step > 50:
         break
 
 print(f'average step time: {summ / count}')

这里需要特定的转换函数,将一些操作转换为支持FP8的版本,需要说明的是,因为还在试验阶段所以可能不稳定

FP8线性层的使用使我们的模型的性能比我们的基线实验提高了47%(!!)

对比TE

未编译的TE FP8模型的性能明显优于我们以前的FP8模型,但编译后的PyTorch FP8模型提供了最好的结果。因为TE FP8模块不支持模型编译。所以使用torch.compile会导致“部分编译”,即它在每次使用FP8时将计算分拆为多个图。

总结

在这篇文章中,我们演示了如何编写PyTorch训练脚本来使用8位浮点类型。TE是一个非常好的库,因为它可以让我们的代码修改量最小,而PyTorch原生FP8支持虽然需要修改代码,并且还是在试验阶段(最新的2.3还是在试验阶段),可能会产生问题,但是这会让训练速度更快。

不过总的来说FP8的确可以加快我们的训练速度,提高GPU的使用效率。这里要提一句TE是由NVIDIA开发的,并对其gpu进行了大量定制,所以如果是N卡的话可以直接用TE
https://avoid.overfit.cn/post/0dd1fba546674b48b932260fa8742971

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695602.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【电路笔记】-二阶滤波器

二阶滤波器 二阶(或双极)滤波器由两个连接在一起的 RC 滤波器部分组成,可提供 -40dB/十倍频程滚降率。 1、概述 二阶滤波器也称为 VCVS 滤波器,因为运算放大器用作压控电压源放大器,是有源滤波器设计的另一种重要类型,因为与我们之前研究过的有源一阶 RC 滤波器一起,…

Git 的安装和使用

一、Git 的下载和安装 目录 一、Git 的下载和安装 1. git 的下载 2. 安装 二、Git 的基本使用-操作本地仓库 1 初始化仓库 1)创建一个空目录 2)git init 2 把文件添加到版本库 1)创建文件 2)git add . 3)g…

迅睿 CMS 中开启【ionCube 扩展】的方法

有时候我们想要某种功能时会到迅睿 CMS 插件市场中找现有的插件,但会有些担心插件是否适合自己的需求。于是迅睿 CMS 考虑到这一层推出了【申请试用】,可以让用户申请试用 30 天,不过试用是有条件的,条件如下: php 版…

MyBatis复习笔记

3.Mybatis复习 3.1 xml配置 properties&#xff1a;加载配置文件 settings&#xff1a;设置驼峰映射 <settings><setting name"mapUnderscoreToCamelCase" value"true"/> </settings>typeAliases&#xff1a;类型别名设置 #这样在映射…

28. 正定矩阵和最小值

文章目录 1. 概述2. 正定矩阵判定条件3. 举例 1. 概述 正定矩阵这节可以将主元&#xff0c;行列式&#xff0c;特征值&#xff0c;还有不稳定性结合起来。以前我们学的是解决方程 A x b Axb Axb 的问题&#xff0c;现在升级&#xff0c;变成 x T A x b x^TAxb xTAxb &…

html 字体设置 (web端字体设置)

windows自带的字体是有版权的&#xff0c;包括微软雅黑&#xff08;方正&#xff09;、宋体&#xff08;中易&#xff09;、黑体&#xff08;中易&#xff09;等 版权算是个大坑&#xff0c;所谓为了避免版权问题&#xff0c;全部使用开源字体即可 我这里选择的是思源宋体&…

Java进阶学习笔记10——子类构造器

子类构造器的特点&#xff1a; 子类的全部构造器&#xff0c;都会先调用父类的构造器&#xff0c;再执行自己。 子类会继承父类的数据&#xff0c;可能还会使用父类的数据。所以&#xff0c;子类初始化之前&#xff0c;一定先要完成父类数据的初始化&#xff0c;原因在于&…

【pyspark速成专家】7_SparkSQL编程1

目录 一&#xff0c;RDD&#xff0c;DataFrame和DataSet对比 二&#xff0c;创建DataFrame 本节将介绍SparkSQL编程基本概念和基本用法。 不同于RDD编程的命令式编程范式&#xff0c;SparkSQL编程是一种声明式编程范式&#xff0c;我们可以通过SQL语句或者调用DataFrame的相…

2024Spring> HNU-计算机系统-实验4-Buflab-导引+验收

前言 称不上导引了&#xff0c;因为验收已经结束了。主要是最近比较忙&#xff0c;在准备期末考试。周五晚上才开始看实验&#xff0c;自己跟着做了一遍实验&#xff0c;感觉难度还是比bomblab要低的&#xff0c;但是如果用心做的话对于栈帧的理解确实能上几个档次。 实验参考…

ClickHouse 24.4 版本发布说明

本文字数&#xff1a;13148&#xff1b;估计阅读时间&#xff1a;33 分钟 审校&#xff1a;庄晓东&#xff08;魏庄&#xff09; 本文在公众号【ClickHouseInc】首发 新的一个月意味着新版本的发布&#xff01; 发布概要 本次ClickHouse 24.4版本包含了13个新功能&#x1f381;…

List Control控件绑定变量

创建基于对话框的mfc项目 添加 List Control控件 右击控件&#xff0c;选择“添加变量” 在初始化对话框代码中增加一些代码 BOOL CMFCApplication3Dlg::OnInitDialog() { //...// TODO: 在此添加额外的初始化代码DWORD dwStyle m_programLangList.GetExtendedStyle(); …

程序员的那些经典段子

哈喽&#xff0c;大家好&#xff0c;我是明智&#xff5e; 本周咱们已经解决了在面试中经常碰到的OOM问题&#xff1a; 《美团一面&#xff0c;发生OOM了&#xff0c;程序还能继续运行吗&#xff1f;》 《美团一面&#xff1a;碰到过OOM吗&#xff1f;你是怎么处理的&#xff1…

【Linux】Linux的基本指令_1

文章目录 二、基本指令1. whoami 和 who2. pwd3. ls4. clear5. mkdir 和 cd6. touch7. rmdir 和 rm 未完待续 二、基本指令 直接在命令行的末尾&#xff08;# 后面&#xff09;输入指令即可。在学习Linux指令的过程中&#xff0c;还会穿插一些关于Linux的知识点。 1. whoami …

AI助力垃圾分类开启智慧环保新时代,基于卷积神经网络模型开发实践垃圾分类识别系统

在快节奏的现代生活中&#xff0c;垃圾分类已经成为一项重要的环保举措。然而&#xff0c;面对日益复杂的垃圾种类和繁多的分类标准&#xff0c;许多人感到困惑和无奈。幸运的是&#xff0c;随着人工智能技术的飞速发展&#xff0c;AI深度学习模型为垃圾分类带来了革命性的变化…

人工智能 框架 paddlepaddle 飞桨 使用指南 使用例子 线性回归模型demo 详解

安装过程&使用指南&线性回归模型 使用例子 本来预想 是安装 到 conda 版本的 11.7的 但是电脑没有gpu 所以 安装过程稍有变动,下面简单讲下 conda create -n paddle_env117 python=3.9 由于想安装11.7版本 py 是3.9 所以虚拟环境名称也是 paddle_env117 activa…

嵌入式全栈开发学习笔记---C语言笔试复习大全21(编程题25~30)

目录 25、实现字符串的排序。&#xff08;输入hello world good&#xff0c;输出good hello world&#xff0c;其中字符串个数任意&#xff09; 26、输入两个有序的字符串&#xff08;从小到大&#xff09;&#xff0c;合并成一个有序的字符串。&#xff08;输入cdhxyz fjln …

利用EAS自动生成数据模型和sql脚本

EAS适用于敏捷开发中小系统,这节主要讲解EAS对应的模型和数据库脚本输出应用。 在这个应用程序中,用户可自定义实体模型和枚举模型,只要选择相应的实体或者枚举进行右击添加即可。 解决方案参数设定,在解决方案的设定中可设置项目名称、通用语言,命名空间和输出位置。 连…

Python+Flask+Pandas怎样实现任意时间范围的对比数据报表

话不多说,有图有源码: 1.上图 2.因为是低代码的,只能发重要有用的代码片段了 实现思路:1)获取指定时间范围内的数据:2)df合并 #----------年份替换----------------for syear in range(int(byear),int(eyear)1):start_datestr(syear)strbdate[4:]end_datestr(syear)stredate…

2024-05-22 VS2022使用modules

点击 <C 语言编程核心突破> 快速C语言入门 VS2022使用modules 前言一、准备二、使用其一, 用VS installer 安装模块:第二个选项就是, 与你的代码一同编译std模块, 这个非常简单, 但是也有坑. 总结 前言 要解决问题: 使用VS2022开启modules. 想到的思路: 跟着官方文档整…

Linux更改系统中的root密码

Linux里面的root密码忘记了怎么办&#xff1f; 1 更改系统中的 root 密码 &#xff08;1&#xff09;键盘 CtrlAltT 快捷键打开终端。 &#xff08;2&#xff09;在终端窗口中输入以下代码&#xff1a; sudo passwd root &#xff08;3&#xff09;输入锁屏密码 &#xf…