Transformers 加速的一些常用技巧

news2024/11/29 10:51:18

Transformers 是一个强大的架构,但模型因其采用的自注意力机制,虽然能够有效地处理序列数据并捕获长距离依赖关系,但同时也容易导致在训练过程中出现OOM(Out of Memory,内存不足)或者达到GPU的运行时限制。

主要是因为

  1. 参数数量庞大:Transformer模型通常包含大量的参数,尤其是在模型层面进行扩展时(例如,增加层数或头数)。这些参数需要大量的内存来存储权重和梯度。
  2. 自注意力计算:自注意力机制需要对输入序列的每个元素与其他所有元素计算其相互关系,导致计算复杂度和内存需求随着输入长度的增加而显著增加。对于非常长的序列,这一点尤其突出。
  3. 激活和中间状态存储:在训练过程中,需要存储前向传播中的中间激活状态,以便于反向传播时使用。这增加了额外的内存负担。

为了解决这些问题,我们今天来总结以下一些常用的加速策略

固定长度填充

在处理文本数据时,由于文本序列的长度可能各不相同,但许多机器学习模型(尤其是基于Transformer的模型)需要输入数据具有固定的尺寸,因此需要对文本序列进行固定长度填充(padding)。

在使用Transformer模型时,填充部分不应影响到模型的学习。因此通常需要使用注意力掩码(attention mask)来指示模型在自注意力计算时忽略这些填充位置。通过这种固定长度填充和相应的处理方法,可以使得基于Transformer的模型能够有效地处理不同长度的序列数据。在实际应用中,这种方法是处理文本输入的常见策略。

 def fixed_pad_sequences(sequences, max_length, padding_value=0):
     padded_sequences = []
     for sequence in sequences:
         if len(sequence) >= max_length:
             padded_sequence = sequence[:max_length]  # Trim the sequence if it exceeds max_length
         else:
             padding = [padding_value] * (max_length - len(sequence))  # Calculate padding
             padded_sequence = sequence + padding  # Pad the sequence
         padded_sequences.append(padded_sequence)
     return padded_sequences

这种方式会将所有的序列填充成一个长度,这样虽然长度相同了,但是因为序列的实际大小本来就不同,同一批次很可能出现有很多填充的情况,所以就出现了动态填充策略。

动态填充是在每个批处理中动态填充输入序列到最大长度。与固定长度填充不同,在固定长度填充中,所有序列都被填充以匹配整个数据集中最长序列的长度,动态填充根据该批中最长序列的长度单独填充每个批中的序列。

这样虽然每个批次的长度是不同的,但是批次内部的长度是相同的,可以加快处理速度。

 def pad_sequences_dynamic(sequences, padding_value=0):
     max_length = max(len(seq) for seq in sequences)  # Find the maximum length in the sequences
     padded_sequences = []
     for sequence in sequences:
         padding = [padding_value] * (max_length - len(sequence))  # Calculate padding
         padded_sequence = sequence + padding  # Pad the sequence
         padded_sequences.append(padded_sequence)
     return padded_sequences

等长匹配

等长匹配是在训练或推理过程中将长度相近的序列分组成批处理的过程。等长匹配通过基于序列长度将数据集划分为桶,然后从这些桶中采样批次来实现的。

从上图可以看到,通过等长匹配的策略,减少了填充量,这样也可以加速计算

 def uniform_length_batching(sequences, batch_size, padding_value=0):
     # Sort sequences based on their lengths
     sequences.sort(key=len)
     
     # Divide sequences into buckets based on length
     buckets = [sequences[i:i+batch_size] for i in range(0, len(sequences), batch_size)]
     
     # Pad sequences within each bucket to the length of the longest sequence in the bucket
     padded_batches = []
     for bucket in buckets:
         max_length = len(bucket[-1])  # Get the length of the longest sequence in the bucket
         padded_bucket = []
         for sequence in bucket:
             padding = [padding_value] * (max_length - len(sequence))  # Calculate padding
             padded_sequence = sequence + padding  # Pad the sequence
             padded_bucket.append(padded_sequence)
         padded_batches.append(padded_bucket)
     
     return padded_batches

自动混合精度

自动混合精度(AMP)是一种通过使用单精度(float32)和半精度(float16)算法的组合来加速深度学习模型训练的技术。它利用了现代gpu的功能,与float32相比,使用float16数据类型可以更快地执行计算,同时使用更少的内存。

 import torch
 from torch.cuda.amp import autocast, GradScaler
 
 # Define your model
 model = YourModel()
 
 # Define optimizer and loss function
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 criterion = torch.nn.CrossEntropyLoss()
 
 # Create a GradScaler object for gradient scaling
 scaler = GradScaler()
 
 # Inside the training loop
 for inputs, targets in dataloader:
     # Clear previous gradients
     optimizer.zero_grad()
     
     # Cast inputs and targets to the appropriate device
     inputs, targets = inputs.to(device), targets.to(device)
     
     # Enable autocasting for forward pass
     with autocast():
         # Forward pass
         outputs = model(inputs)
         loss = criterion(outputs, targets)
     
     # Backward pass
     # Scale the loss value
     scaler.scale(loss).backward()
     
     # Update model parameters
     scaler.step(optimizer)
     
     # Update the scale for next iteration
     scaler.update()

AMP在训练过程中动态调整计算精度,允许模型在大多数计算中使用float16,同时自动将某些计算提升为float32,以防止下流或溢出等数值不稳定问题。

Fp16 vs Fp32

双精度(FP64)消耗64位。符号值为1位,指数值为11位,有效精度为52位。

单精度(FP32)消耗32位。符号值为1位,指数值为8位,有效精度为23位。

半精度(FP16)消耗16位。符号值为1位,指数值为5位,有效精度为10位。

所以Fp16可以提高内存节省,并可以大大提高模型训练的速度。考虑到Fp16的优势和它在模型使用方面的主导区域,它非常适合推理任务。但是fp16会产生数值精度的损失,导致计算或存储的值不准确,考虑到这些值的精度至关重要。

另外就是这种优化师针对于分类任务的,对于回归这种需要精确数值的任务Fp16的表现并不好。

总结

以上这些方法,可以在一定程度上缓解内存不足和计算资源的限制,但是对于大型的模型我们还是需要一个强大的GPU。

https://avoid.overfit.cn/post/7240bee210cd408a90ca04279830040e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1670427.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

防爆巡检手持终端在燃气巡检作业中的应用

在燃气巡检作业中,安全始终是首要考虑的因素。面对易燃易爆的燃气环境,传统的巡检方式已经难以满足现代安全管理的需求。随着科技的不断进步,防爆巡检手持终端应运而生,成为燃气巡检作业的得力助手。这些终端不仅具备高度的防爆性…

介绍适用于 Node.js 的 Elastic OpenTelemetry 发行版

作者:来自 Elastic Trent Mick 我们很高兴地宣布推出 Elastic OpenTelemetry Distribution for Node.js 的 alpha 版本。 该发行版是 OpenTelemetry Node.js SDK 的轻量级包装,可以让你更轻松地开始使用 OpenTelemetry 来观察 Node.js 应用程序。 背景 …

RiPro主题美化【支付弹窗底部提示语根据入口不同有不同的提示】ritheme主题美化RiProV2 增加支付提示语,按支付类型不同,入口不同提示语不同的设置

RiPro主题美化【支付弹窗底部提示语根据入口不同有不同的提示】ritheme主题美化RiProV2 增加支付提示语,按支付类型不同,入口不同提示语不同的设置 背景: 接上文:https://www.uu2id.com/827.html 付费组件在以下几个地方会弹出:1)文章隐藏内容付费;2)付费资源下载;3…

HR人才测评:应变能力与岗位胜任力素质测评

什么是应变能力 应变能力在职场中可以说是必备的素质之一,它指的是从业者需要长期活动或者是行为来迎接即将到来的挑战,做提前的思考,以适应未来的挑战,具有随机应变的意思。在外界还未发生变化或者是已经发生变化时,…

沃比得DP28A 对数周期天线 200MHz~8GHz

产品概述 DP28A 对数周期天线,工作频率为 200MHz~8GHz。具有频带宽,性能可靠,增益高等优 点,是理想的无线电频谱管理、EMC 测试、电子对抗等领域的定向接收、发射天线。 应用领域 ● 无线电频谱管理 ● EMC 测试 …

(有奖调查)企业级3D模型资产管理平台,用户需求大调查!

(有奖调查)企业级3D模型文件管理平台用户需求大调查https://www.wjx.cn/vm/PpLKkmn.aspx#

作为一名普通投资者怎么查看现货白银的价格是多少?

做现货白银白银投资的投资者,经常会关注现货白银的价格是多少,因为交易决策是建立在具体的价格之上的。那么有什么方法可以让投资者可以时刻关注到现货白银的价格多少呢? 要时刻监测现货白银的价格,我们主要有2种途径,…

IO的阻塞和非阻塞浅析

在操作系统和网络编程中,IO(输入/输出)操作是一个非常重要的概念。 在处理IO的时候,阻塞和非阻塞都是同步IO。只有使用了特殊的API才是异步IO。 ——陈硕大神 网络IO层面 典型的一次IO的两个阶段是什么? 数据准备 和…

大数据项目中的拉链表(hadoop,hive)

缓慢渐变维 拉链表 拉链表,可实现数据快照,可以将历史和最新数据保存在一起 如何实现: 在原始数据增加两个新字段 起始时间(有效时间:什么时候导入的数据的时间),结束时间(默认的结束时间为99…

CUDA-基于累计直方图和共享内存的中值滤波算法

作者:翟天保Steven 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 实现原理 基于累计直方图的中值滤波算法详情参见: OpenCV-基于累计直方图的中值滤波算法-CSDN博客 为了进一步提升算…

06、SpringBoot 源码分析 - SpringApplication启动流程六

SpringBoot 源码分析 - SpringApplication启动流程六 初始化基本流程SpringApplication的prepareEnvironment准备环境SpringApplication的getOrCreateEnvironment创建环境configureEnvironment配置环境ApplicationConversionService的getSharedInstance配置转换器 SpringApplic…

C# WinForm —— 17 MaskedTextBox 介绍

1. 简介 本质是文本框,但它可以通过掩码来区分输入的正确与否,可以控制输入的格式、长度 主要应用场景是:需要格式化输入信息的情况 2. 常用属性 属性解释(Name)控件ID,在代码里引用的时候会用到,一般以 mtxt 开头AsciiOnly是否…

Ansible常用变量【下】

转载说明:如果您喜欢这篇文章并打算转载它,请私信作者取得授权。感谢您喜爱本文,请文明转载,谢谢。 前言 在上一篇文章《Ansible常用变量【上】》中,学习了Ansible常用变量的前半部分,放了个五一假&#x…

USB转串口芯片CH341、CH372、CH374、CH375等的电路及 PCB 设计的重要注意事项

前言 USB芯片的电路和PCB设计参考及注意事项,含CH34X、CH37X等系列芯片的电路设计说明。涉及工作稳定性和抗干扰以及USB-HOST带电热插拔。基于 USB 芯片的电路及 PCB 设计的重要注意事项 版本:2E 1、摘要 本文主要针对以下因电路及 PCB 设计不佳而引起…

浮点数的由来及运算解析

数学是自然科学的皇后,计算机的设计初衷是科学计算。计算机的最基本功能是需要存储整数、实数,及对整数和实数进行算术四则运算。 但是在计算机从业者的眼中,我们知道的数学相关的基本数据类型通常是整型、浮点型、布尔型。整型又分为int8&a…

点是否在三角形内C++源码实现

原理 思路: 面积和: abc obcaocabo,应该有更简洁的方法,但是这个方法思路更简单 代码实现: 注意二维向量的叉乘后,是垂直于平面的向量,相当于z为0三维向量叉乘,所以只有z维度有值,xy0. flo…

【Nginx <一>⭐️】Nginx 的初步了解以及安装使用

目录 👋前言 👀一、 Nginx 介绍 🌱二、 安装使用 💞️ 三、 总结 📫四、 章末 👋前言 小伙伴们大家好,前段时间主要在学习 Elasticsearch 相关的知识,花了两周的时间吧&#x…

排序-冒泡排序(bubble sort)

冒泡排序(Bubble Sort)是一种简单的排序算法,它重复地遍历待排序的数列,一次比较两个元素,如果它们的顺序错误就把它们交换过来。遍历数列的工作是重复地进行直到没有再需要交换,也就是说该数列已经排序完成…

JavaWeb--13Mybatis(2)

Mybatis(2) 1 Mybatis基础操作1.1 需求和准备工作1.2 删除员工日志输入参数占位符 1.3 新增员工1.4 修改员工信息1.5 查询员工1.5.1 根据ID查询数据封装 1.5.3 条件查询 2 XML配置文件规范3 MyBatis动态SQL3.1 什么是动态SQL3.2 动态SQL-if更新员工 3.3 …

决策树学习记录

对于一个决策树的决策面: 他其实是在任意两个特征基础上对于所有的点进行一个分类,并且展示出不同类别的之间的决策面,进而可以很清楚的看出在这两个特征上各个数据点种类的分布。 对于多输出的问题,在利用人的上半张脸来恢复下半…