大模型显存占用由以下几部分组成:
1. 模型本身参数,假设是1个单位
2.模型的梯度,同样也是一个单位
3.优化器参数(占大头):以Adam参数为例,还需要在显卡中额外存储m和v两个参数,因此为2个单位参数
4.模型的中间计算结果,因为反向传播求导时会用到,需要存储每一层的输入x(下图以Transformer中的全连接层为例,每一个全连接层的输入参数维度为[batch, 句子长度, 每个token维度])
以11B大小模型为例,其模型参数占据显存大小就为40GB,再加上其余三个部分后显存花销更大