训练 Transfomer 模型的内存消耗计算

news2025/1/11 4:15:24

目录

  • model 内存
  • gradients 内存
  • activates 内存

经典图打底:

训练深度模型的内存消耗主要有以下几个部分:

  1. 存储模型可训练参数
  2. 存储梯度
  3. 存储反向传播中间变量,例如:

L = ( Y − Y ^ ) 2 Y ^ = X T W ∂ L ∂ W = − 2 ( Y − Y ^ ) ∂ Y ^ ∂ W = − 2 ( Y − X T W ) X \begin{aligned} L &= (Y - \hat Y)^2\\ \hat Y &= X^T W\\ \frac{\partial L}{\partial W}&= -2(Y-\hat Y) \frac{\partial \hat Y }{\partial W} = -2(Y- X^T W) X \end{aligned} LY^WL=(YY^)2=XTW=2(YY^)WY^=2(YXTW)X
这里面 X X X 就需要保存下来供反向传播时使用

下面具体的分析中需要用到每一层的具体运算张量,具体可以参考 Transfomer矩阵维度分析及MultiHead详解


model 内存

    """
    计算储存Transformer模型可训练参数所需的内存

    参数:
    - vocab_in_size: vocab_in大小
    - vocab_out_size: vocab_out大小
    - encoder_layers_num: 编码器层数
    - decoder_layers_num: 解码器层数
    - d_model: 编码器和解码器的隐藏层大小
    - num_head: 头的数量
    - embedding_size: 词嵌入大小
    - filter_size: 前馈子层的隐藏层大小
    - batch_size: 批大小
    - seq_len: 输入序列长度
    - bias: 是否加偏置项
    - include_pos_embedding: 位置编码是否单独包含可优化参数
    - dropout_rate: 例如: 0.1
    - dtype_size: 默认为4 (FP32),若是FP16,改为2

    返回:
    - 所需内存,以字节为单位。
    """

    bias = bias * 1

    # 计算encoder embedding的参数内存消耗
    encoder_embedding_params = vocab_in_size * embedding_size

    # 计算 Encoder 的参数内存消耗
    # Multi-head Attention parameters: 3 * (d_model * d_model) + (d_model * d_model)
    # Layer normalization: d_model + d_model * bias
    # Feed-forward network parameters: d_model * filter_size + filter_size * d_model
    attention_params = 4 * d_model * d_model
    layer_norm_params = d_model + d_model * bias
    ffn_params_params = 2 * d_model * filter_size
    encoder_params = (attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * encoder_layers_num

    # 计算decoder embedding的参数内存消耗
    decoder_embedding_params = vocab_out_size * embedding_size
    # 计算 Decoder 的参数内存消耗
    # Masked Multi-head Attention parameters: 4 * (d_model * d_model)
    # Multi-head Attention parameters: 4 * (d_model * d_model)
    decoder_params = (attention_params + layer_norm_params + attention_params + layer_norm_params + ffn_params_params + layer_norm_params) * decoder_layers_num

    # 计算最后 output 层的参数内存消耗
    output_params = d_model * vocab_out_size

    # 计算储存模型可训练参数所需内存,考虑 dropout_rate(近似估算)
    model_memory = (encoder_embedding_params + encoder_params + decoder_embedding_params + decoder_params + output_params) * (1 + dropout_rate) * dtype_size
    if include_pos_embedding:
        model_memory += seq_len * d_model * 2 # encoder 和 decoder 各有一个 pos embedding
     

gradients 内存

这里除了 gradients 内存,还考虑了一些小项,例如 mask,优化器 等消耗的内存

def get_inputs_mem(batch_size, seq_len, dtype_size=8):
    """
    计算Transformer模型输入数据的内存占用

    参数:
    - batch_size: 批大小
    - seq_len: 输入序列长度
    - dtype_size: 默认为8 (int64)

    返回:
    - 所需内存,以字节为单位。
    """
    return batch_size * seq_len * dtype_size * 2  # 同时计算输入和输出

    # 计算attention中的mask的内存消耗
    # Mask: seq_len * seq_len for each attention block
    mask_memory = seq_len * seq_len * (encoder_layers_num + decoder_layers_num*2) * dtype_size

    # 计算gradients消耗的内存, 训练过程中的梯度与模型参数的形状相同,因此梯度的内存大小也是 model_memory
    grads_memory = model_memory

    # 计算优化器消耗的内存,此处以adam为例,对每一个可训练参数,需要储存一个一阶动量和一个二阶动量
    # 若使用的其他优化器,此处按需修改
    optimizer_memory = 2 * model_memory

    # 数据存储消耗的内存
    inputs_memory = get_inputs_mem(batch_size,seq_len)


activates 内存

    """
    计算中间结果(activates)的内存消耗,反向传播需要用到这些中间结果
    
     参数:
    - vocab_out_size: vocab_out大小
    - encoder_layers_num: 编码器层数
    - decoder_layers_num: 解码器层数
    - d_model: 编码器和解码器的隐藏层大小
    - num_head: 头的数量
    - filter_size: 前馈子层的隐藏层大小
    - batch_size: 批大小
    - seq_len: 输入序列长度
    - dtype_size: 默认为4 (FP32),若是FP16,改为2

    返回:
    - 所需内存,以字节为单位。
    """

    # 由于各个layer的输入和输出size都是 batch_size * seq_len * d_model, 先计算出来后续使用
    N = batch_size * seq_len * d_model * dtype_size

    # 计算每层 attention 部分的中间结果内存消耗
    # 1.linear transformation: X*W_q = Q, X*W_k = K, X*W_v = V, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 X (只需储存一个,因为是同一个X)
    # 2.由于 Attention(Q,K,V) = softmax(QK^T/sqrt(d))V, 其中 QK^T 的张量为 [batch_size * num_head * seq_len * d_model/num_head] * [batch_size * num_head * d_model/num_head * seq_len] = [batch_size * num_head * seq_len * seq_len]
    # V 张量为 [batch_size * num_head * seq_len * d_model/num_head], 需要存储 Q, K, V, softmax(QK^T/sqrt(d))
    # 3.output linear transformation: Y = Attention(Q,K,V)*W_2, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * d_model] = [batch_size * seq_len * d_model], 需储存 Attention(Q,K,V)
    linear_memory = N
    softmax_memory = 3 * N + batch_size * num_head * seq_len * seq_len * dtype_size
    output_memory = N
    attention_memory = linear_memory + softmax_memory + output_memory

    # 计算每层的 Layer normalization 的中间结果内存消耗, Layer normalization 输出张量为 batch_size * seq_len * d_model
    layer_norm_memory = N

    # 计算每层的 FFN 部分的中间结果内存消耗
    # 1.第一层 linear transformation: X*W_1 = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * filter_size] = [batch_size * seq_len * filter_size], 需储存 X
    # 2.中间 Relu 连接: Y' = Relu(Y), 需储存 Y'
    # 3.第二层 linear transformation: Y'*W_2 = Z, 张量为 [batch_size * seq_len * filter_size] * [batch_size * filter_size * d_model] = [batch_size * seq_len * d_model], 需储存 Y'
    ffn_memory = N + 2 * batch_size * seq_len * filter_size * dtype_size

    encoder_memory = (attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * encoder_layers_num
    decoder_memory = (attention_memory + layer_norm_memory + attention_memory + layer_norm_memory + ffn_memory + layer_norm_memory) * decoder_layers_num

    # 计算 output 层的中间结果内存消耗
    # 1.output linear transformation: X*W = Y, 张量为 [batch_size * seq_len * d_model] * [batch_size * d_model * vocab_out_size] = [batch_size * seq_len * vocab_out_size], 需储存 X
    # 2.softmax(Y): 需储存 softmax(Y)
    output_memory = N + batch_size * seq_len * vocab_out_size * dtype_size

    total_activates_memory = encoder_memory + decoder_memory + output_memory

将上述三个部分加总,就是训练 Transfomer 模型大概需要的内存消耗。

NOTE:

  1. 这里没有考虑混合精度训练,如果考虑混合精度训练,还需要在不同的部分,使用不同的 dtype_size
  2. 如果是GPT这种 decoder-only 或者 encoder-only 的模型,只需要 decoder_layers_num = 0,即可 (decoder-only 也是这样做的,因为decoder-only 中的 Masked Multi-head Attention 没有了,实际的参数情况和 encoder-only 是一样的)

Reference:
Transformer Memory Arithmetic: Understanding all the Bytes in nanoGPT
Formula to compute approximate memory requirements of transformer models
Transformer Math 101

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2043217.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Transformer架构;Encoder-Decoder;Padding Mask;Sequence Mask;

目录 Transformer架构 Transformer架构的主要组成部分: 简单举例说明输入和输出: Encoder-Decoder 编码器/解码器组成 6、位置前馈网络(Position-wise Feed-Forward Networks) 7、残差连接和层归一化 10、掩码Mask 10.1 Padding Mask 10.2 Sequence Mask 为什么…

Gradio 复杂布局的实现

Gradio Interface 和 ChatInterface 布局都相对固定,只能通过参数添加组件,如果想要自定义页面布局,就需要更高级的布局方式 Block 。Gradio 中可以通过行和列进行布局,可以互相嵌套。我们先看一官方的例子: import g…

Vue Mixins 深度解析含面试常问题

Vue Mixins 深度解析含面试常问题 文章目录 Vue Mixins 深度解析含面试常问题一、Mixin 是什么二、Vue中如何使用1. 创建Mixin2. 使用Mixin3. 合并策略4. 全局Mixin5. 使用场景 三、包含哪些属性或方法API四、扩展与高级技巧1. 命名冲突2. 全局 vs 局部3. 合并策略深入4. 使用高…

商品期权会爆仓吗?

商品期权交易中存在爆仓的情况。一个期权的价格与其基础资产的波动性密切相关。在波动性高的情况下,尽管收益可能更高,但投资者也需要面对更大的价格波动风险,商品期权有买方和卖方,买方无爆仓风险,卖方是保证金交易有…

Hadoop大数据集群搭建

一、虚拟机配置网络 1、配置文件 进入“/etc/sysconfig/network-scripts”目录,查看当前目录下的“ifcfg-ens33”文件 对“ens33”文件进行配置 2、重启网络 systemctl restart network 3、测试网络 Ping www.baidu.com 4、设置虚拟机主机名称 5、绑定主机名和…

【android 9】【input】【11.发送普通motion事件1——touch设备的加载——MultiTouchInputMapper】

系列文章目录 可跳转到下面链接查看下表所有内容https://blog.csdn.net/handsomethefirst/article/details/138226266?spm1001.2014.3001.5501文章浏览阅读2次。系列文章大全https://blog.csdn.net/handsomethefirst/article/details/138226266?spm1001.2014.3001.5501 目录 …

传知代码-CENet及多模态情感计算实战(论文复现)

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 一、概述 本文对 “Cross-Modal Enhancement Network for Multimodal Sentiment Analysis” 论文进行讲解和手把手复现教学,解决当下热门的多模态情感计算问题,并展示在MOSI和MOSEI两个数…

labview经验分享1-任意16进制字符类型匹配

系列文章目录 1、任意16进制字符类型匹配 文章目录 系列文章目录问题导入实现任意16进制字符类型匹配在这里插入图片描述 总结 问题导入 labveiw的字符串匹配,使用的是正则表达式,可以让我们很方便的对字符串进行字符处理操作。 但是某些情况下&#…

WEB渗透Bypass篇-常规操作

绕过lsa-protection https://github.com/RedCursorSecurityConsulting/PPLKillerLinux绕过disable_function LD_PRELOAD linux环境 putenv()、mail()可用 https://github.com/yangyangwithgnu/bypass_disablefunc_via_LD_PRELOAD http://192.168.0.107/bypass_disablefunc.p…

一篇文章教你搭建一个高深莫测的SQL优化器

❓在数据库操作中,SQL优化一直是一个让人头疼的问题。今天,我将教你一种无需编写任何代码,只需要两个组件,便能轻松搭建一个高深莫测的SQL优化器的方法。通过这个方法,它可以将巨慢无比的SQL,把速度优化到极…

重启人生计划-浮舟沧海

🥳🥳🥳 茫茫人海千千万万,感谢这一刻你看到了我的文章,感谢观赏,大家好呀,我是最爱吃鱼罐头,大家可以叫鱼罐头呦~🥳🥳🥳 如果你觉得这个【重启人生…

VIM复合命令

VIM提供了很多 复合命令,可以把两个动作合并为一次按键。极大提高了编辑效率。以下是一些具体的例子: 复合命令等效的长命令说明Cc$删除光标到行尾scl删除光标位置的字符S^C删除整行I^i光标移动到行首A$a光标移动到行尾oA 回车光标下方开启一行Oko光标…

一文掌握SOP搭建步骤方法

如果你正在阅读这篇文章,那么你很可能在寻找如何为你的企业编写标准操作程序(SOP)的指导,以确保更好的流程被传达给你的团队并且得到遵循。 为什么SOPs很重要 SOPs必须清晰地传达你的业务流程,以标准化操作并确保盈利性…

Vue2 消息订阅与发布

1.pubsub-js 第三方库实现 实现任何框架的消息订阅发布 npm i pubsub-js <template><div class"student"><h2>展示学生的名称:{{ name }}</h2><h2>展示学生的性别:{{ sex }}</h2></div> </template><script>…

浏览器插件利器--allWebPluginV2.0.0.16-Stable版发布

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品&#xff0c;致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX控件直接嵌入浏览器&#xff0c;实现插件加载、界面显示、接口调用、事件回调等。支持Chrome、Firefo…

ollama使用llama3.1案例

ollama安装和运行llama3.1 8b conda create -n ollama python3.11 -y conda activate ollama curl -fsSL https://ollama.com/install.sh | sh ollama run songfy/llama3.1:8b 就这么简单就能运行起来了. 我们可以在命令行中与他交互. 当然我们也可以用接口访问: curl http:…

在IDEA中用自带的数据库 连接 redis 失败(JedisAccessControlException)

文章目录 1、问题出现的背景2、分析问题出现的原因3、解决办法不用输入用户名直接输入密码即可 1、问题出现的背景 redis.clients.jedis.exceptions.JedisAccessControlException: WRONGPASS invalid username-password pair or user is disabled.2、分析问题出现的原因 查看…

智慧水务项目(六)PyScada学习一,初步建立项目并测试

一、说明 Pyscada是scada的python实现&#xff0c;需要学习一下&#xff0c;以备不时之需&#xff0c;目前我的想法是用他来模拟opc数据&#xff0c;毕竟我准备做的项目需要系统与scada通过opc进行通信&#xff0c;正好做一个简单的scada系统 是一个开源的SCADA&#xff08;S…

记录|C#主界面设计【Web风格】

目录 前言一、页面效果二、布局设计2.1 左边菜单栏搭建框架Step1. panelMenu &#xff1a;Step2. panelLogoStep3. button模板Step4. 复制buttonStep5. 微调Button 2.2 界面颜色变换Step1. ThemeColor类Step2. From1.csStep3. 更换按钮点击颜色效果 2.3 按钮点击事件2.4 顶部ti…

十、Linux二进制安装ClickHouse集群(含rpm安装)

目录 十、Linux二进制安装ClickHouse集群(含rpm安装&#xff0c;单机版使用rpm&#xff0c;集群使用tar包安装方式)1 部署前服务器配置&#xff08;集群的话三台都要配置&#xff09;1.2 配置hosts文件1.3 打开文件数限制1.4 取消 SELINUX1.5 禁用透明大页 2 下载所需文件2.1 t…