Transformer 相关模型的参数量计算

news2024/9/24 7:19:45

如何计算Transformer 相关模型的参数量呢?
先回忆一下Transformer模型论文《Attention is all your need》中的两个图。
在这里插入图片描述
在这里插入图片描述

设Transformer模型的层数为N,每个Transformer层主要由self-attention 和 Feed Forward组成。设self-attention模块的head个数为 n h e a d n_{head} nhead,每一个head对应的维度为 d h e a d d_{head} dhead,self-attention输出维度为 d m o d e l = n heads ⋅ d head d_{model}= n_\text{heads}\cdot d_\text{head} dmodel=nheadsdhead。我们可以得到一个Transformer层的参数量为 12 d m o d e l 2 + 13 d m o d e l 12 d_{model}^2 + 13 d_{model} 12dmodel2+13dmodel,具体如下:

  • self-attention块的模型参数有Q、K、V的权重矩阵 W Q 、 W K 、 W V W_Q、W_K 、W_V WQWKWV和偏置,输出矩阵 W O W_O WO及其偏置。这4个权重矩阵的大小为 [ d m o d e l , d m o d e l ] [d_{model}, d_{model}] [dmodel,dmodel],4个偏置的大小为 [ d m o d e l ] [d_{model}] [dmodel],所以self-attention块的参数量为 4 d m o d e l 2 + 4 d m o d e l 4 d_{model}^2 + 4 d_{model} 4dmodel2+4dmodel

  • Feed Forward块一般由2个线性层组成,第一个线性层将维度从 d m o d e l d_{model} dmodel 映射成 4 d m o d e l 4d_{model} 4dmodel, 其权重矩阵 W 1 W_1 W1的大小为 [ d m o d e l , 4 d m o d e l ] [d_{model}, 4d_{model}] [dmodel,4dmodel] ,其偏置的大小为 [ 4 d m o d e l ] [4d_{model}] [4dmodel]。 第二个线性层将维度从 4 d m o d e l 4d_{model} 4dmodel 映射成 d m o d e l d_{model} dmodel,其权重矩阵 W 2 W_2 W2的大小为 [ 4 d m o d e l , d m o d e l ] [4d_{model}, d_{model}] [4dmodel,dmodel] ,其偏置的大小为 [ d m o d e l ] [d_{model}] [dmodel]。所以Feed Forward的参数量为 8 d m o d e l 2 + 5 d m o d e l 8 d_{model}^2 + 5 d_{model} 8dmodel2+5dmodel

  • self-attention 和 Feed Forward都跟随着layer normalization,它有两个可训练模型参数,形状都是 [ d m o d e l ] [d_{model}] [dmodel]。所以2个layer normalization的参数量为 4 d m o d e l 4 d_{model} 4dmodel

除了Transformer层之外的参数有:

  • 词embedding矩阵的参数量,embedding的维度通常等于 d m o d e l d_{model} dmodel,设词表的大小为V,则词embedding的参数量为 V d m o d e l Vd_{model} Vdmodel
  • 位置向量相关,有些位置向量表示方式需要学习参数。

所以N层Transformer模型的可训练模型参数量为 N ( 12 d m o d e l 2 + 13 d m o d e l ) + V d m o d e l N(12 d_{model}^2 + 13 d_{model}) + Vd_{model} N(12dmodel2+13dmodel)+Vdmodel。当 d m o d e l d_{model} dmodel较大时,可以忽略一次项,模型参数量近似为 12 N d m o d e l 2 12 N d_{model}^2 12Ndmodel2

最后试验一下模型参数估计量与论文是否对的上,下表是GPT3和LLaMA的计算对比,可以发现数量级是可以对的上的,因为我们忽略了一次项,所以具体数据与论文不一致。

模型名实际参数量 n l a y e r n_{layer} nlayer d m o d e l d_{model} dmodel n h e a d n_{head} nhead d h e a d d_{head} dhead估计参数量
GPT-3175B961228896128173946175488
LLaMA 6.7B6.7B324096321286442450944
LLaMA 13.0B13.0B4051204012812582912000
LLaMA 32.5B32.5B6066565212831897681920
LLaMA 65.2B65.2B8081926412864424509440

参考资料

  1. Transformer 论文(模型图来自论文)、GPT3的论文等

  2. 整理过程中参考的blog: 1. 知乎用户回旋托马斯x 的文章,除了计算量外,还算了计算量、中间激活等 , 2 transformer 参数量计算, 3 flops 计算, 4 transformers 参数量计算公式

  3. transfomers 库如何得到参数量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/900606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

回归预测 | MATLAB实现CSO-SVM布谷鸟优化算法优化支持向量机多输入单输出回归预测(多指标,多图)

回归预测 | MATLAB实现CSO-SVM布谷鸟优化算法优化支持向量机多输入单输出回归预测(多指标,多图) 目录 回归预测 | MATLAB实现CSO-SVM布谷鸟优化算法优化支持向量机多输入单输出回归预测(多指标,多图)效果一…

移植PeerTalk开源库IOS的USB通信监听服务到QT生成的FFmpeg工程

1.添加生成的PeerTalk库 下图选中部分为FFmpeg依赖库 将USB通信服务的m与h文件添加到工程 因为OC文件使用了弱指针,所以要启用弱指针支持 因为FFmpeg拉流动用到本地网络,所以要在plist文件中启动本地网络使用 设置PeerTalk为嵌入模式 设置Runpath Search Paths为@executable_p…

FPGA:uart原理+tx发送模块+rx接收模块

文章目录 一、串口通信二、UART通信三、tx发送模块四、rx模块接收 一、串口通信 处理器与外部设备通信的两种方式: 串行通信: 指数据的各个位使用多条数据线同时进行传输。 并行通信: 将数据分成一位一位的形式在一条数据线上逐个传输。 串…

1.flink快速入门

前言 下图表示的是一个简单的flink-job的计算图,这种图被称为DAG(有向无环图),表示的这个任务的计算逻辑,无论是spark、hive、还是flink都会把用户的计算逻辑转换为这样的DAG,数据的计算按照DAG触发,理论上只要构建出…

spring源码分析bean的生命周期(下)

doGetBean()执行过程 createBean()执行过程 一、DependsOn注解 spring创建对象之前会判断类上是否加了DependsOn注解,加了会遍历然后会添加到一个map中,spring会先创建DependsOn注解指定的类 二、spring类加载器 在合并BeanDefinition,确定…

centos7.9和redhat6.9 离线升级OpenSSH和openssl (2023年的版本)

升级注意事项! 1、多开几个连接窗口(xshell),避免升级openssh失败无法再次连接终端,否则要跑机房了。 2、可开启telnet服务、vnc服务、打快照。多几个“保命”的路数。一、centos7.9的信息 [rootnode2 ~]# openssl v…

1391. 检查网格中是否存在有效路径;2502. 设计内存分配器;1638. 统计只差一个字符的子串数目

核心思想:并查集。枚举网格中的块,把能连通的连通在一起,最后看(0,0)和(m-1,n-1)是否连通,然后网格中的每个点坐标是二维的,然后通过x*ny转换为一维&#xff…

大数据课程K2——Spark的RDD弹性分布式数据集

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 了解Spark的RDD结构; ⚪ 掌握Spark的RDD操作方法; ⚪ 掌握Spark的RDD常用变换方法、常用执行方法; 一、Spark最核心的数据结构——RDD弹性分布式数据集 1. 概述 初学Spark时,把RDD看…

超实用的批量管理工具 pssh 和 window 文件传输工具 pscp

文章目录 一、概述1)pssh2)pscp 二、pssh 工具安装三、pssh 命令的基本语法四、pscp 工具安装1)Windows 上安装2)Linux 系统上安装 五、 pscp 命令的基本语法1)从 windows 向 linux 传文件2)从 linux 传文件…

算法:滑动窗口解决连续区间子数组问题

文章目录 实现原理实现思路典型例题长度最小的子数组无重复字符的最小字串最大连续1的个数III将x减到0的最小操作水果成篮找到字符串中所有字母异位词(哈希表比较优化)对哈希表内元素比较的优化 总结 本篇积累的是滑动窗口的问题,滑动窗口在算法实现中有重要作用&am…

Python可视化在量化交易中的应用(16)_Seaborn热力图

Seaborn中热力图的绘制方法 seaborn中绘制热力图使用的是sns.heatmap()函数: sns.heatmap(data,vmin,vmax,cmap,center,robust,annot,fmt‘.2g’,annot_kws,linewidths0,linecolor‘white’,cbar,cbar_kws,cbar_ax,square,xticklabels‘auto’,yticklabels‘auto’…

systemd:初学者如何理解其中的争议

导读对于什么是 systemd,以及为什么它经常成为 Linux 世界争议的焦点,你可能仍然感到困惑。我将尝试用简单的语言来回答。 在 Linux 世界中,很少有争议能像传统的 System V 初始化 系统(通常称为 SysVinit)和较新的 s…

QT设置widget背景图片

首先说方法,在给widget或者frame或者其他任何类型的控件添加背景图时,在样式表中加入如下代码,指定某个控件,设置其背景。 类名 # 控件名 { 填充方式:图片路径 } 例如: QWidget#Widget {border-image: url…

1. 微信小程序开发环境搭建

下载 微信的小程序开发需要使用到微信开发者工具,通过https://developers.weixin.qq.com/miniprogram/dev/devtools/stable.html可以下载 下载完成后 安装

Linux 系统编程拾遗

Linux 系统编程拾遗 进程的创建 进程的创建 fork()、exit()、wait()以及execve()的简介 创建新进程:fork()

人工智能原理(6)

目录 一、机器学习概述 1、学习和机器学习 2、学习系统 3、机器学习发展简史 4、机器学习分类 二、归纳学习 1、归纳学习的基本概念 2、变型空间学习 3、归纳偏置 三、决策树 1、决策树组成 2、决策树的构造算法CLS 3、ID3 4、决策树的偏置 四、基于实例的学习…

嵌入式系统总线-片内总线

1.总线概述 总线是CPU与存储器和设备通信的机制,是计算机各部件之间传送数据、地址和控制信息的公共通道。 2.总线参数 总线宽度:又称总线位宽,指的是总线能同时传送数据的位数。如16位总线就是具有16位数据传送能力。 总线频率&#xff…

apex安装出错:TypeError unsupported operand type(s) for +: “NoneType“ and “str“

Windows 10 环境下安装apex报错:TypeError unsupported operand type(s) for : “NoneType“ and “str“ 1、首先apex不能直接pip install apex安装。 2、具体安装步骤:【python】【深度学习】apex的安装_apex python_愿东大没有食堂的博客-CSDN博客 …

深入竞品:解读竞品分析的艺术与策略

引言:为何竞品分析至关重要? 在当今的产品环境中,市场变得越来越拥挤。每个角落都有新的创业公司试图创造下一个行业的颠覆者,同时也有成熟的巨头在不断地迭代和优化他们的产品。在这样的环境中,不了解您的竞争对手是…

『C语言初阶』第八章 -结构体

前言 今天小羊又来给铁汁们分享关于C语言的结构体,在C语言中,结构体类型属于一种构造类型(其他的构造类型还有:数组类型,联合类型),今天我们主要简单了解一下结构体。 一、结构体是什么&#x…