目录
- model 内存
- gradients 内存
- activates 内存
经典图打底:
训练深度模型的内存消耗主要有以下几个部分:
- 存储模型可训练参数
- 存储梯度
- 存储反向传播中间变量,例如:
L = ( Y − Y ^ ) 2 Y ^ = X T W ∂ L ∂ W = − 2 ( Y − Y ^ ) ∂ Y ^ ∂ W = − 2 ( Y − X T W ) X \begin{aligned} L &= (Y - \hat Y)^2\\ \hat Y &= X^T W\\ \frac{\partial L}{\partial W}&= -2(Y-\hat Y) \frac{\partial \hat Y }{\partial W} = -2(Y- X^T W) X \end{aligned} LY^∂W∂L=(Y−Y^)2=XTW=−2(Y−Y^)∂W∂Y^=−2(Y−XTW)X
这里面 X X X 就需要保存下来供反向传播时使用
下面具体的分析中需要用到每一层的具体运算张量,具体可以参考 Transfomer矩阵维度分析及MultiHead详解
model 内存
"""
计算储存Transformer模型可训练参数所需的内存
参数:
- vocab_in_size: vocab_in大小
- vocab_out_size: vocab_out大小
- encoder_layers_num: 编码器层数
- decoder_layers_num: 解码器层数
- d_model: 编码器和解码器的隐藏层大小
- num_head: 头的数量
- embedding_size: 词嵌入大小
- filter_size: 前馈子层的隐藏层大小
- batch_size: 批大小
- seq_len: 输入序列长度
- bias: 是否加偏置项
- include_pos_embedding: 位置编码是否单独包含可优化参数
- dropout_rate: 例如: 0.1
- dtype_size: 默认为4 (FP32),若是FP16,改为2
返回:
- 所需内存,以字节为单位。
"""
bias = bias * 1
# 计算encoder embedding的参数内存消耗
encoder_embedding_params = vocab_in_size * embedding_size
# 计算 Encoder 的参数内存消耗
# Multi-head Attention parameters: 3 * (d_model * d_model) + (d_model * d_model)
# Layer normalization: d_model + d_model * bias
# Feed-forward network parameters: d_model * filter_size + filter_size * d_model
attention_params = 4 * d_model * d_model
layer_norm_params = d_model + d_model * bias
ffn_params_params = 2 * d_model * filter_size
encoder_params = (attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * encoder_layers_num
# 计算decoder embedding的参数内存消耗
decoder_embedding_params = vocab_out_size * embedding_size
# 计算 Decoder 的参数内存消耗
# Masked Multi-head Attention parameters: 4 * (d_model * d_model)
# Multi-head Attention parameters: 4 * (d_model * d_model)
decoder_params = (attention_params + layer_norm_params + attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * decoder_layers_num
# 计算最后 output 层的参数内存消耗
output_params = d_model * vocab_out_size
# 计算储存模型可训练参数所需内存,考虑 dropout_rate(近似估算)
model_memory = (encoder_embedding_params + encoder_params + decoder_embedding_params + decoder_params + output_params) * (1 + dropout_rate) * dtype_size
if include_pos_embedding:
model_memory += seq_len * d_model * 2 # encoder 和 decoder 各有一个 pos embedding
gradients 内存
这里除了 gradients 内存,还考虑了一些小项,例如 mask,优化器 等消耗的内存
def get_inputs_mem(batch_size, seq_len, dtype_size=8):
"""
计算Transformer模型输入数据的内存占用
参数:
- batch_size: 批大小
- seq_len: 输入序列长度
- dtype_size: 默认为8 (int64)
返回:
- 所需内存,以字节为单位。
"""
return batch_size * seq_len * dtype_size * 2 # 同时计算输入和输出
# 计算attention中的mask的内存消耗
# Mask: seq_len * seq_len for each attention block
mask_memory = seq_len * seq_len * (encoder_layers_num + decoder_layers_num*2) * dtype_size
# 计算gradients消耗的内存, 训练过程中的梯度与模型参数的形状相同,因此梯度的内存大小也是 model_memory
grads_memory = model_memory
# 计算优化器消耗的内存,此处以adam为例,对每一个可训练参数,需要储存一个一阶动量和一个二阶动量
# 若使用的其他优化器,此处按需修改
optimizer_memory = 2 * model_memory
# 数据存储消耗的内存
inputs_memory = get_inputs_mem(batch_size,seq_len)
activates 内存
"""
计算中间结果(activates)的内存消耗,反向传播需要用到这些中间结果
参数:
- vocab_out_size: vocab_out大小
- encoder_layers_num: 编码器层数
- decoder_layers_num: 解码器层数
- d_model: 编码器和解码器的隐藏层大小
- num_head: 头的数量
- filter_size: 前馈子层的隐藏层大小
- batch_size: 批大小
- seq_len: 输入序列长度
- dtype_size: 默认为4 (FP32),若是FP16,改为2
返回:
- 所需内存,以字节为单位。
"""
# 由于各个layer的输入和输出size都是 batch_size * seq_len * d_model, 先计算出来后续使用
N = batch_size * seq_len * d_model * dtype_size
# 计算每层 attention 部分的中间结果内存消耗
# 1.linear transformation: X*W_q = Q, X*W_k = K, X*W_v = V, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 X (只需储存一个,因为是同一个X)
# 2.由于 Attention(Q,K,V) = softmax(QK^T/sqrt(d))V, 其中 QK^T 的张量为 [batch_size * num_head * seq_len * d_model/num_head] * [batch_size * num_head * d_model/num_head * seq_len] = [batch_size * num_head * seq_len * seq_len]
# V 张量为 [batch_size * num_head * seq_len * d_model/num_head], 需要存储 Q, K, V, softmax(QK^T/sqrt(d))
# 3.output linear transformation: Y = Attention(Q,K,V)*W_2, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 Attention(Q,K,V)
linear_memory = N
softmax_memory = 3 * N + batch_size * num_head * seq_len * seq_len * dtype_size
output_memory = N
attention_memory = linear_memory + softmax_memory + output_memory
# 计算每层的 Layer normalization 的中间结果内存消耗, Layer normalization 输出张量为 batch_size * seq_len * d_model
layer_norm_memory = N
# 计算每层的 FFN 部分的中间结果内存消耗
# 1.第一层 linear transformation: X*W_1 = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * filter_size] = [batch_size * seq_len * filter_size], 需储存 X
# 2.中间 Relu 连接: Y' = Relu(Y), 需储存 Y'
# 3.第二层 linear transformation: Y'*W_2 = Z, 张量为 [batch_size * seq_len * filter_size] * [batch_size * filter_size * d_model] = [batch_size * seq_len * d_model], 需储存 Y'
ffn_memory = N + 2 * batch_size * seq_len * filter_size * dtype_size
encoder_memory = (attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * encoder_layers_num
decoder_memory = (attention_memory + layer_norm_memory + attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * decoder_layers_num
# 计算 output 层的中间结果内存消耗
# 1.output linear transformation: X*W = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * vocab_out_size] = [batch_size * seq_len * vocab_out_size], 需储存 X
# 2.softmax(Y): 需储存 softmax(Y)
output_memory = N + batch_size * seq_len * vocab_out_size * dtype_size
total_activates_memory = encoder_memory + decoder_memory + output_memory
将上述三个部分加总,就是训练 Transfomer 模型大概需要的内存消耗。
NOTE:
- 这里没有考虑混合精度训练,如果考虑混合精度训练,还需要在不同的部分,使用不同的 dtype_size
- 如果是GPT这种 decoder-only 或者 encoder-only 的模型,只需要 decoder_layers_num = 0,即可 (decoder-only 也是这样做的,因为decoder-only 中的 Masked Multi-head Attention 没有了,实际的参数情况和 encoder-only 是一样的)
Reference:
Transformer Memory Arithmetic: Understanding all the Bytes in nanoGPT
Formula to compute approximate memory requirements of transformer models
Transformer Math 101