PyTorch内存优化的10种策略总结:在有限资源环境下高效训练模型

news2025/4/21 15:20:03

在大规模深度学习模型训练过程中,GPU内存容量往往成为制约因素,尤其是在训练大型语言模型(LLM)和视觉Transformer等现代架构时。由于大多数研究者和开发者无法使用配备海量GPU内存的高端计算集群,因此掌握有效的内存优化技术变得尤为关键。本文将系统性地介绍多种内存优化策略,这些技术组合应用可使模型训练的内存消耗降低近20倍,同时不会损害模型性能和预测准确率。以下大部分技术可以相互结合,以获得更显著的内存效率提升。

1、自动混合精度训练

混合精度训练是降低内存占用的基础且高效的方法,它充分利用16位(FP16)和32位(FP32)浮点格式的优势。

混合精度训练的核心思想是在大部分计算中使用较低精度执行数学运算,从而减少内存带宽和存储需求,同时在计算的关键环节保持必要的精度。通过对激活值和梯度采用FP16格式,这些张量的内存占用可减少约50%。然而某些特定的层或操作仍需要FP32格式以避免数值不稳定。PyTorch对自动混合精度(AMP)的原生支持大大简化了实现过程。

混合精度训练低精度训练 有本质区别

关于混合精度训练是否会影响模型准确率的问题,答案是。混合精度训练通过精心设计的计算流程保持了计算精度。

混合精度训练原理

混合精度训练通过结合16位(

FP16

)和32位(

FP32

)浮点格式来保持计算准确性。使用16位精度计算梯度可显著加快计算速度并减少内存消耗,同时维持与32位分辨率相当的结果质量。这种方法在计算资源受限的环境中尤为有效。

"混合精度"一词更准确地描述了这一过程,因为并非所有参数和操作都转换为16位格式。实际训练过程中,32位和16位操作交替执行,形成混合精度计算流程。

如上图所示,该过程首先将权重转换为低精度(

FP16

)以加速计算,然后计算梯度,接着将梯度转回高精度(

FP32

)以确保数值稳定性,最后使用这些适当缩放的梯度更新原始权重。通过这种方式,混合精度训练可提高训练效率的同时保持网络的整体精度和稳定性。

使用

torch.cuda.amp.autocast()

可轻松实现混合精度训练,示例代码如下:

 import torch
from torch.cuda.amp import autocast, GradScaler

# Assume your model and optimizer have been defined elsewhere.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for data, target in data_loader:
    optimizer.zero_grad()
    # Enable mixed precision
    with autocast():
        output = model(data)
        loss = loss_fn(output, target)

    # Scale the loss and backpropagate
    scaler.scale(loss).backward()
    scaler.step(optimizer)
     scaler.update()

2、低精度训练

除了混合精度训练,我们还可以尝试使用完整的16位低精度格式进行训练。由于16位浮点数的表示范围限制,这种方法可能导致

NaN

值出现。为解决这一问题,研究人员开发了多种专用浮点格式。其中,Brain Floating Point(

BF16

)是Google为此专门开发的一种广受欢迎的格式。与标准

FP16

相比,

BF16

提供了更大的动态范围,能够表示极大和极小的数值,使其更适合于深度学习应用中可能遇到的多样化数值情况。尽管较低精度可能在某些计算中影响精确度或导致舍入误差,但在大多数深度学习应用场景中,这种影响对模型性能的影响极小。

虽然

BF16

最初是为TPU设计的,但现在大多数现代GPU(Nvidia Ampere架构及更高版本)也支持这种格式。可以通过以下方法检查GPU是否支持

BF16

格式:

 import torch
 print(torch.cuda.is_bf16_supported())  # should print True

3、梯度检查点

即便采用混合精度和低精度训练,大型模型在前向传播过程中产生的大量中间张量仍会消耗大量内存。**梯度检查点(Gradient Checkpointing)**技术通过在前向传播过程中选择性地仅存储部分中间结果来解决这一问题。在反向传播过程中,系统会重新计算缺失的中间值,这虽然增加了计算成本,但可以显著降低内存需求。

通过战略性地选择需要设置检查点的层,可以通过动态重新计算激活值而非存储它们来减少内存使用。对于具有深层架构的模型,中间激活值通常占据了内存消耗的主要部分,此时这种权衡尤为有效。以下是梯度检查点的实现示例:

 import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_segment(input_tensor):
    # This function represents a portion of your model
    # which will be recomputed during the backward pass.
    # You can create a custom forward pass for this segment.
    return model_segment(input_tensor)

# Instead of a conventional forward pass, wrap the segment with checkpoint.
 output = checkpoint(checkpointed_segment, input_tensor)

采用此方法,在多数情况下可将激活值所需的内存减少40-50%。尽管反向传播现在包含额外的计算开销,但当GPU内存成为限制因素时,这种权衡通常是合理的。

4、使用梯度累积降低批量大小

在尝试上述方法后,一个自然的问题是:

为何不直接减小批量大小?

虽然这确实是最直接的方法,但通常使用较小批量大小会导致预测性能下降。简单减小批量大小虽然能显著降低内存消耗,但往往会对模型准确率产生不良影响。

如何在这两者之间取得平衡?

**梯度累积(Gradient Accumulation)**正是为解决这一问题而设计的技术。它允许在训练过程中虚拟增加批量大小,其核心原理是为较小的批量计算梯度,并在多次迭代中累积这些梯度(通常通过求和或平均),而不是在每个批次后立即更新模型权重。一旦累积的梯度达到目标"虚拟"批量大小,才使用这些累积的梯度更新模型参数。

然而需要注意,这种技术的主要缺点是显著增加了训练时间。

5、张量分片和分布式训练

对于即使应用上述优化后仍无法在单个GPU上容纳的超大模型,**完全分片数据并行(Fully Sharded Data Parallel, FSDP)**技术提供了解决方案。FSDP将模型参数、梯度和优化器状态分片到多个GPU上,这不仅使得训练超大模型成为可能,还能通过更合理地分配通信开销提高训练效率。

FSDP不是在每个GPU上维护完整的模型副本,而是将模型参数分配到多个可用设备上。在执行前向或反向传播时,系统仅将相关分片加载到内存中。这种分片机制显著降低了单个设备的内存需求,与前述技术结合使用,在某些情况下可实现高达10倍的内存降低效果。

FSDP可通过以下方式实现:

 import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Initialize your model and ensure it is on the correct device.
model = MyLargeModel().cuda()

# Wrap the model in FSDP for sharded training across GPUs.
 fsdp_model = FSDP(model)

6、高效的数据加载

内存优化中常被忽视的一个方面是数据加载效率。虽然大部分优化关注点集中在模型内部结构和计算过程,但低效的数据处理同样可能造成不必要的瓶颈,影响内存利用和计算速度。作为经验法则,当处理数据加载器时,应始终启用

Pinned Memory

和配置适当的

Multiple Workers

,如下所示:

 from torch.utils.data import DataLoader

# Create your dataset instance and then the DataLoader with pinned memory enabled.
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,      # Adjust based on your CPU capabilities
    pin_memory=True     # Enables faster host-to-device transfers
 )

7、使用原地操作

在处理张量时,如果不谨慎管理,每个操作都可能创建新的张量对象。**原地操作(In-place Operations)**通过直接修改现有张量而非分配新张量,有助于减少内存碎片和总体内存占用。这种方式减少了临时内存分配,在迭代训练循环中尤为重要。示例如下:

 import torch
 
 x = torch.randn(100, 100, device='cuda')
 y = torch.randn(100, 100, device='cuda')
 
 # Using in-place addition
 x.add_(y)  # Here x is modified directly instead of creating a new tensor

8、激活和参数卸载

对于极大规模模型,即使应用了所有上述技术,由于大量中间激活值的存在,仍可能达到GPU内存限制。**激活和参数卸载(Activation and Parameter Offloading)**技术通过将部分中间数据移动到CPU内存,为GPU内存提供额外的缓解。

这种方法通过战略性地将部分激活值和/或参数临时卸载到主机内存(CPU),仅在GPU内存中保留关键计算所需的数据。虽然DeepSpeed、Fabric等专用框架可自动管理这种数据移动,但也可以按如下方式实现自定义卸载逻辑:

 def offload_activation(tensor):
    # Move tensor to CPU to save GPU memory
    return tensor.cpu()

def process_batch(data):
    # Offload some activations explicitly
    intermediate = model.layer1(data)
    intermediate = offload_activation(intermediate)
    intermediate = intermediate.cuda()  # Move back when needed
    output = model.layer2(intermediate)
     return output

9、使用更精简的优化器

各种优化器在内存消耗方面存在显著差异。例如,广泛使用的Adam优化器为每个模型参数维护两个额外状态参数(动量和方差),这意味着更多的内存消耗。将Adam替换为无状态优化器(如SGD)可将参数数量减少近2/3,这在处理LLM等大型模型时尤为重要。

标准SGD的缺点是收敛特性较差。为弥补这一点,可引入余弦退火学习率调度器以实现更好的收敛效果。实现示例:

 # instead of this
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# use this
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_steps = NUM_EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=num_steps)

这种优化可在保持模型准确率达到约97%(取决于具体应用)的同时,显著改善峰值内存消耗。

10、进阶优化技术

除上述基础技术外,以下高级策略可进一步优化GPU内存使用,充分发挥硬件潜能:

内存分析和缓存管理

精确测量是有效优化的前提。PyTorch提供了多种实用工具用于监控GPU内存使用情况:

 import torch
 
 # print a detailed report of current GPU memory usage and fragmentation
 print(torch.cuda.memory_summary(device=None, abbreviated=False))
 
 # free up cached memory that's no longer needed by PyTorch
 torch.cuda.empty_cache()

使用TorchScript进行JIT编译

PyTorch的即时编译器(JIT)能够将Python模型转换为经过优化的、可序列化的TorchScript程序。这种转换通过优化内核启动和减少运行时开销,可带来内存和性能的双重提升:

 import torch
 
 # Suppose `model` is an instance of your PyTorch network.
 scripted_model = torch.jit.script(model)
 
 # Now, you can run the scripted model just like before.
 output = scripted_model(input_tensor)

这种编译方式可显著优化模型运行效率。

自定义内核融合

编译的另一项重要优势是能够将多个操作融合到单个计算内核中。内核融合有助于减少内存读写操作,提高总体计算吞吐量:

使用torch.compile()进行动态内存分配

进一步利用编译技术,JIT编译器可通过编译时优化改进动态内存分配效率。结合跟踪和计算图优化技术,这种方法可在大型模型和Transformer架构中实现更显著的内存和性能优化。

总结

在GPU和云计算资源成本高昂的环境下,最大化利用现有计算资源至关重要。对于希望在有限计算资源条件下训练或微调大型模型(如LLM或视觉Transformer)的研究者和开发者而言,掌握上述优化技术尤为重要。本文介绍的这些策略代表了研究人员和专业人士在资源受限条件下进行高效模型训练的常用方法。

https://avoid.overfit.cn/post/dc61dc9f03fc45f48dba26c21a276bce

作者:Sahib Dhanjal

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2308537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【湖北省计算机信息系统集成协会主办,多高校支持 | ACM出版,EI检索,往届已见刊检索】第二届边缘计算与并行、分布式计算国际学术会议(ECPDC 2025)

第二届边缘计算与并行、分布式计算国际学术会议(ECPDC 2025)将于2025年4月11日至13日在中国武汉盛大召开。本次会议旨在为边缘计算、并行计算及分布式计算领域的研究人员、学者和行业专家提供一个高水平的学术交流平台。 随着物联网、云计算和大数据技术…

硬件工程师入门教程

1.欧姆定律 测电压并联使用万用表测电流串联使用万用表,红入黑出 2.电阻的阻值识别 直插电阻 贴片电阻 3.电阻的功率 4.电阻的限流作用 限流电阻阻值的计算 单位换算关系 5.电阻的分流功能 6.电阻的分压功能 7.电容 电容简单来说是两块不连通的导体加上中间的绝…

性能测试监控工具jmeter+grafana

1、什么是性能测试监控体系? 为什么要有监控体系? 原因: 1、项目-日益复杂(内部除了代码外,还有中间件,数据库) 2、一个系统,背后可能有多个软/硬件组合支撑,影响性能的因…

DeepSeek如何快速开发PDF转Word软件

一、引言 如今,在线工具的普及让PDF转Word成为了一个常见需求,常见的PDF转Word工具有收费的WPS,免费的有PDFGear(详见:PDFGear:一款免费的PDF编辑、格式转化软件-CSDN博客),以及在线工具SmallP…

目标检测——数据处理

1. Mosaic 数据增强 Mosaic 数据增强步骤: (1). 选择四个图像: 从数据集中随机选择四张图像。这四张图像是用来组合成一个新图像的基础。 (2) 确定拼接位置: 设计一个新的画布(输入size的2倍),在指定范围内找出一个随机点(如…

基于springboot+vue的拖恒ERP-物资管理

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

spring结合mybatis多租户实现单库分表

实现单库分表-水平拆分 思路:student表数据量大,所以将其进行分表处理。一共有三个分表,分别是student0,student1,student2,在新增数据的时候,根据请求头中的meta-tenant参数决定数据存在哪张表…

YoloV8改进策略:Block改进|CBlock,Transformer式的卷积结构|即插即用

摘要 论文标题: SparseViT: Nonsemantics-Centered, Parameter-Efficient Image Manipulation Localization through Spare-Coding Transformer 论文链接: https://arxiv.org/pdf/2412.14598 官方GitHub: https://github.com/scu-zjz/SparseViT 这段代码出自SparseViT ,代码如…

微服务架构实践:SpringCloud与Docker容器化部署

## 微服务架构实践:SpringCloud与Docker容器化部署 随着互联网应用的复杂性不断增加,传统的单体应用架构面临着诸多挑战,如难以部署、维护困难、开发效率低下等问题凸显出来。为了解决这些问题,微服务架构应运而生,它通…

[原创]openwebui解决searxng通过接口请求不成功问题

openwebui 对接 searxng 时 无法查询到联网信息,使用bing搜索,每次返回json是正常的 神秘代码: http://172.30.254.200:8080/search?q北京市天气&formatjson&languagezh&time_range&safesearch0&languagezh&locale…

8 SpringBootWeb(下):登录效验、异步任务和多线程、SpringBoot中的事务管理@Transactional

文章目录 案例-登录认证1. 登录功能1.1 需求1.2 接口文档1.3 思路分析1.4 功能开发1.5 测试2. 登录校验2.1 问题分析2.2 会话技术2.2.1 会话技术介绍2.2.2 会话跟踪方案2.2.2.1 方案一 - Cookie2.2.2.2 方案二 - Session2.2.2.3 方案三 - 令牌技术2.2.3 JWT令牌(Token)2.2.3.…

2025年山东省职业院校技能大赛(高职组)“云计算应用”赛项赛卷1

“云计算应用”赛项赛卷1 2025年山东省职业院校技能大赛(高职组)“云计算应用”赛项赛卷1模块一 私有云(30分)任务1 私有云服务搭建(5分)1.1.1 基础环境配置1.1.2 yum源配置1.1.3 配置无秘钥ssh1.1.4 基础安…

MySQL数据库基本概念

目录 什么是数据库 从软件角度出发 从网络角度出发 MySQL数据库的client端和sever端进程 mysql的client端进程连接sever端进程 mysql配置文件 MySql存储引擎 MySQL的sql语句的分类 数据库 库的操作 创建数据库 不同校验规则对查询的数据的影响 不区分大小写 区…

塔能科技:工厂智慧照明,从底层科技实现照明系统的智能化控制

在全球节能减碳和智慧生活需求激增的背景下,基于“用软件定义硬件,让物联运维更简捷更节能”的产品理念,塔能科技的智慧照明一体化方案如新星般崛起,引领照明行业新方向。现在,我们来深入探究其背后的创新技术。该方案…

P3398 仓鼠找 sugar【题解】

这是LCA的一个应用,关于LCA P3398 仓鼠找 sugar 题目描述 小仓鼠的和他的基(mei)友(zi)sugar 住在地下洞穴中,每个节点的编号为 1 ∼ n 1\sim n 1∼n。地下洞穴是一个树形结构。这一天小仓鼠打算从从他…

Android Trace埋点beginSection打tag标签,Kotlin

Android Trace埋点beginSection打tag标签,Kotlin import android.os.Bundle import android.os.Trace import android.util.Log import androidx.appcompat.app.AppCompatActivityclass ImageActivity : AppCompatActivity() {companion object {const val TRACE_TA…

Lua的table(表)

Lua表的基本概念 Lua中的表(table)是一种多功能数据结构,可以用作数组、字典、集合等。表是Lua中唯一的数据结构机制,其他数据结构如数组、列表、队列等都可以通过表来实现。 表的实现 Lua的表由两部分组成: 数组部分…

51页精品PPT | 农产品区块链溯源信息化平台整体解决方案

PPT展示了一个基于区块链技术的农产品溯源信息化平台的整体解决方案。它从建设背景和需求分析出发,强调了农产品质量安全溯源的重要性以及国际国内的相关政策要求,指出了食品安全问题在流通环节中的根源。方案提出了全面感知、责任到人、定期考核和追溯反…

Jenkins 自动打包项目镜像部署到服务器 ---(前端项目)

Jenkins 新增前端项目Job 指定运行的节点 选择部署运行的节点标签,dev标签对应开发环境 节点的远程命令执行配置 jenkins完整流程 配置源码 拉取 Credentials添加 触发远程构建 配置后可以支持远程触发jenkins构建(比如自建的CICD自动化发布平台&…

使用AoT让.NetFramework4.7.2程序调用.Net8编写的库

1、创建.Net8的库&#xff0c;双击解决方案中的项目&#xff0c;修改如下&#xff0c;启用AoT&#xff1a; <Project Sdk"Microsoft.NET.Sdk"><PropertyGroup><OutputType>Library</OutputType><PublishAot>true</PublishAot>&…