torch.utils.checkpoint
是 PyTorch 中用于实现 梯度检查点(gradient checkpointing)的模块。它通过在反向传播中 重新计算 前向传播的某些部分,以显著减少激活值的显存占用。
梯度检查点的核心原理
- 在前向传播中,不是保存每一层的激活值,而是保存输入和部分中间结果。
- 在反向传播时,重新计算需要的前向激活值。
- 优势:
- 显存占用减少:适合超大模型的训练。
- 劣势:
- 计算量增加:反向传播时需要额外的前向计算。
核心API
1. torch.utils.checkpoint.checkpoint
torch.utils.checkpoint.checkpoint
是 PyTorch 提供的一种 内存优化工具,通过 计算图重新计算 的方式来节省显存。它特别适用于深度学习中 大模型或长序列 的训练场景,能够在不降低模型性能的情况下减少显存使用。
工作原理
-
标准前向传播:
- 默认情况下,PyTorch 在前向传播过程中,会存储中间激活值以供反向传播使用。
- 如果模型层数很多或者中间激活值占用大量显存,会导致显存不足。
-
检查点机制:
- 在前向传播时,
torch.utils.checkpoint.checkpoint
会丢弃某些中间激活值(未存储在显存中)。 - 在反向传播时,丢弃的中间激活值会通过 重新计算前向传播 来生成。
- 通过这种方式,显存的占用降低,但会增加一些前向计算的开销。
- 在前向传播时,
函数签名:
torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True)
参数
function
:- 前向传播的函数,必须是纯函数(只依赖输入,不依赖外部状态)。
*args
:- 传递给
function
的参数。
- 传递给
use_reentrant
(默认值为True
):- 如果设置为
True
,使用旧的递归检查点实现;如果为False
,启用非递归实现,推荐设置为False
来避免潜在问题。
- 如果设置为
优缺点
优点
节省显存:
- 丢弃中间激活值后,显存占用显著降低,适合训练大模型。
适配性强:
- 不需要修改模型结构,只需在关键的计算图中插入检查点即可。
返回值
output
:
- 前向传播的结果。
使用场景
大模型的训练:
- 模型层数较多,激活值占用大量显存时&#