transformer系列3---transformer结构参数量统计

news2024/11/25 20:37:07

Transformer参数量统计

  • 1 Embedding
  • 2 Positional Encoding
  • 3 Transformer Encoder
    • 3.1 单层EncoderLayer
      • 3.1.1 MHA
      • 3.1.2 layer normalization
      • 3.1.3 MLP
      • 3.1.4 layer normalization
    • 3.2 N层Encoderlayer总参数量
  • 4 Transformer Decoder
    • 4.1 单层Decoderlayer
      • 4.1.1 mask MHA
      • 4.1.2 layer normalization
      • 4.1.3 交叉多头注意力
      • 4.1.4 layer normalization
      • 4.1.5 MLP
      • 4.1.6 layer normalization
    • 4.2 N层Decoderlayer总参数量
    • 5 Transformer输出

1 Embedding

NLP算法会使用不同的分词方法表示所有单词,确定分词方法之后,首先建立一个词表,词表的维度是词总数vocab_size ×表示每个词向量维度d_model(论文中dmodel默认值512),这是一个非常稀疏的矩阵。这样,对于Transformer的encoder输入的句子sentence,先用相应的分词方法转换成新的序列src_vocab,然后用每个词的id去前面的稀疏矩阵查表,通过查表(nn.Embedding)将该序列转换到新的向量空间,就是词嵌入的结果。Transformer的decoder输入(输出)同理,先用相应的分词方法转换成新的序列tgt_vocab,然后将该序列经过embedding转换到新的向量空间。
因此统计参数量时,应为词表的维度=vocab_size × d_model

2 Positional Encoding

位置编码同理,首先建立一个位置矩阵,维度是输入向量的最大长度src_max_len × dmodel,实际使用时候根据实际词长度n取前n个位置编码
因此,位置编码的参数量=src_max_len × dmodel

3 Transformer Encoder

3.1 单层EncoderLayer

3.1.1 MHA

在这里插入图片描述

MHA包含 W Q , W K , W V W^{Q},W^{K},W^{V} WQWKWV和输出的权重矩阵 W O W^{O} WO以及偏置, W Q , W K 维度是 d m o d e l × d k , W V 维度是 d m o d e l × d v W^{Q},W^{K}维度是dmodel × dk,W^{V}维度是dmodel × dv WQWK维度是dmodel×dkWV维度是dmodel×dv,输出权重矩阵 W O 的维度是 ( h × d v ) × d m o d e l W^{O}的维度是(h×dv)×dmodel WO的维度是(h×dv)×dmodel,论文中dk = dv = dmodel/h = 64,头数h = 8。

  1. 一个头中3个矩阵的参数量是dmodel × dk + dmodel × dk + dmodel × dv= 3dk × dmodel
  2. h个头的参数量=h × 3 × dk × dmodel = 3 d m o d e l 2 3dmodel^{2} 3dmodel2
  3. 加上输出矩阵后的矩阵总参数量=(h × dv) × dmodel + 3dmodel × dmodel = 4 d m o d e l 2 4dmodel^{2} 4dmodel2
  4. 每个矩阵偏置维度是dmodel,4个矩阵的偏置=4dmodel
  5. MHA总参数量= 4 d m o d e l 2 + 4 d m o d e l 4dmodel^{2}+4dmodel 4dmodel2+4dmodel

3.1.2 layer normalization

layer normalization层的参数包含weight和bias,代码中nn.LayerNorm(dmodel),因此weight和bias的维度都是dmodel,参数量之和=2dmodel

3.1.3 MLP

由两个线性层组成,W1维度是(dmodel,4×dmodel),b1维度是4×dmodel,W2维度是(4×dmodel,dmodel),b2维度是dmodel,参数量为 dmodel×4×dmodel+4×dmodel+4×dmodel×dmodel+dmodel = 8 d m o d e l 2 + 5 d m o d e l 8dmodel^{2}+5dmodel 8dmodel2+5dmodel

3.1.4 layer normalization

同3.1.2,参数量之和=2dmodel

3.2 N层Encoderlayer总参数量

  1. 综上计算,Transformer中1层Encoderlayer的总参数量是 4 d m o d e l 2 + 4 d m o d e l + 2 d m o d e l + 8 d m o d e l 2 + 5 d m o d e l + 2 d m o d e l = 12 d m o d e l 2 + 13 d m o d e l 4dmodel^{2}+4dmodel+2dmodel+8dmodel^{2}+5dmodel+2dmodel=12dmodel^{2}+13dmodel 4dmodel2+4dmodel+2dmodel+8dmodel2+5dmodel+2dmodel=12dmodel2+13dmodel

  2. 论文中Encoderlayer的层数是N = 6 ,因此N层的Encoderlayer的参数量统计为 12 N d m o d e l 2 + 13 N d m o d e l 12Ndmodel^{2}+13Ndmodel 12Ndmodel2+13Ndmodel,实际中常常省略一次项,参数量统计= 12 N d m o d e l 2 12Ndmodel^{2} 12Ndmodel2

4 Transformer Decoder

Decoder比Encoder多一层交叉多头注意力,以及一个layer normalization,但计算方式与Encoder相同,直接采用上面的结论

4.1 单层Decoderlayer

4.1.1 mask MHA

mask MHA总参数量= 4 d m o d e l 2 + 4 d m o d e l 4dmodel^{2}+4dmodel 4dmodel2+4dmodel

4.1.2 layer normalization

layer normalization参数量=2dmodel

4.1.3 交叉多头注意力

总参数量= 4 d m o d e l 2 + 4 d m o d e l 4dmodel^{2}+4dmodel 4dmodel2+4dmodel

4.1.4 layer normalization

layer normalization参数量=2dmodel

4.1.5 MLP

MLP参数量= 8 d m o d e l 2 + 5 d m o d e l 8dmodel^{2}+5dmodel 8dmodel2+5dmodel

4.1.6 layer normalization

layer normalization参数量=2dmodel

4.2 N层Decoderlayer总参数量

  1. 1层Decoderlayer参数量为上述计算之和, 4 d m o d e l 2 + 4 d m o d e l + 2 d m o d e l + 4 d m o d e l 2 + 4 d m o d e l + 2 d m o d e l + 8 d m o d e l 2 + 5 d m o d e l + 2 d m o d e l = 16 d m o d e l 2 + 19 d m o d e l 4dmodel^{2}+4dmodel+2dmodel+4dmodel^{2}+4dmodel+2dmodel+8dmodel^{2}+5dmodel+2dmodel=16dmodel^{2}+19dmodel 4dmodel2+4dmodel+2dmodel+4dmodel2+4dmodel+2dmodel+8dmodel2+5dmodel+2dmodel=16dmodel2+19dmodel
  2. N层Decoderlayer参数量= 16 N d m o d e l 2 + 19 N d m o d e l 16Ndmodel^{2}+19Ndmodel 16Ndmodel2+19Ndmodel

5 Transformer输出

Decoder输入输出向量的最大长度tgt_max_len,最后一层参数量=dmodel×tgt_max_len

Transformer的总参数量为上述5个部分的参数量之和。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1042918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AUTOSAR中的Crypto Stack(二)--CSM数据类型解析

在上一节,简单梳理了加密栈的基本要求。其中最关键最核心的还是用户如何使用HSM这个黑盒子,这就必须要对Crypto Service Manager要有很清晰的认识。 那么首先我们还是围绕概述里提到的job类型进行分析。 1. Crypto_JobType 上图, 在AUTOSAR的架构里,所有的密码操作…

笔记本电脑查询连接wifi密码

笔记本电脑查询连接wifi密码 1、背景2、环境3、实操3.1、已连接wifi查看密码3.2、之前连接过的wifi密码查看 1、背景 在日常使用过程中遇到两个使用场景。网络管理员跳过一下步骤,针对wifi使用人员。 1、刚到一个新环境中需要连接wifi的场景 2、在一个场所连接过一…

【LeetCode热题100】--160.相交链表

160.相交链表 使用双指针: /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode(int x) {* val x;* next null;* }* }*/ public class Solution {public ListNode getInter…

基于Vue+ELement搭建动态树与数据表格实现分页

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的专栏《ELement》。🎯🎯 &#x1…

高等数学应试考点速览(下)

函数项级数 【收敛域】上,收敛于:【和函数】; 幂级数:绝对收敛区间 ( − R , R ) (-R,R) (−R,R),(端点是否属于收敛域,需要再探讨) R lim ⁡ n → ∞ ∣ a n a n 1 ∣ R\lim_{n…

LLM(二)| LIMA:在1k高质量数据上微调LLaMA1-65B,性能超越ChatGPT

本文将介绍在Lit-GPT上使用LoRA微调LLaMA模型,并介绍如何自定义数据集进行微调其他开源LLM 监督指令微调(Supervised Instruction Finetuning) 什么是监督指令微调?为什么关注它? 目前大部分LLM都是decoder-only&…

右键菜单添加 Open Git Bash

前言 在使用 TortoiseGit 作为Git的可视化工具,但是会经常用到命令行操作,一般来说,安装了TortoiseGit后,右键会出现 open git-bash here... 的命令。但是,可能由于某些原因,这个右键菜单选项不见了。下面…

springcloud:三、ribbon负载均衡原理+调整策略+饥饿加载

Ribbon负载均衡原理 调整Ribbon负载均衡策略 第一种会对order-service里所有的服务消费者都采用该新规则 第二种会针对order-service里某个具体的服务消费者采用该新规则 饥饿加载

【LeetCode】力扣364.周赛题解

Halo,这里是Ppeua。平时主要更新C,数据结构算法,Linux与ROS…感兴趣就关注我bua! 1.最大二进制奇数 🍉题目: 🍉例子: 🍉 题解: 首先看题目,最大二进制奇数,在一个二…

二十六、MySQL并发事务问题:脏读/不可重复读/幻读

1、事务的隔离级别 (1)隔离级别 Read uncommitted # 读,未提交 Read committed # 读,已提交 Repeatable Read(默认) # 可重复读 Serializable # 串读 (2)基础语法 set transaction isolation level 事…

高等数学应试考点速览(上)

极限 上界存在,则上确界存在数列极限 定义性质:唯一、有界(保序、夹逼、不等式性质)、保号、四则运算判定: 单侧:单调有界双侧:闭区间套增量:柯西审敛 归并和收敛子列聚点有限覆盖原…

想要精通算法和SQL的成长之路 - 最长递增子序列 II(线段树的运用)

想要精通算法和SQL的成长之路 - 最长递增子序列 II(线段树的运用) 前言一. 最长递增子序列 II1.1 向下递推1.2 向上递推1.3 更新操作1.4 查询操作1.5 完整代码: 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 最长递增子序列 II 原题链接…

idea 通过tomcat 配置 https方式访问

步骤1:管理员模式打开cmd命令进行生成密匙 D:\software\apache-tomcat-8.5.93\bin\tomcat.keystore 是生成密匙存放的路径,修改成自己tomcat的路径即可 keytool -genkeypair -alias "tomcat" -keyalg "RSA" -keystore "D:\s…

Spring Boot 集成 MinIO 实现文件上传、下载和删除

MinIO 是一种开源的对象存储服务,它基于云原生架构构建,并提供了高性能、易于扩展和安全的存储解决方案。 一.安装和配置 MinIO 服务器 为了演示方便,本文采用Windows安装 1.在官方网站下载MinIO 安装文件,地址:ht…

MATLAB中norm函数用法

目录 语法 说明 示例 向量模 向量的 1-范数 两个点之间的欧几里德距离 矩阵的 2-范数 N 维数组的 Frobenius 范数 常规向量范数 norm函数的功能是计算向量范数和矩阵范数。 语法 n norm(v) n norm(v,p) n norm(X) n norm(X,p) n norm(X,"fro") 说明…

Android 面试经历复盘整理~

此次面试一共4面4小时,中间只有几分钟间隔。对持续的面试状态考验还是蛮大的。 关于面试的心态,保持悲观的乐观主义心态比较好。面前做面试准备时保持悲观,尽可能的做足准备。面后积极做复盘,乐观的接受最终结果。 切忌急于下结论…

从裸机开始安装操作系统

目录 一、预置知识 电脑裸机 win10版本 官方镜像 V.S. 正版系统 二、下载微软官方原版系统镜像 三、使用微PE系统维护U盘 四、安装操作系统 五、总结 一、预置知识 电脑裸机 ●只有硬件部分,还未安装任何软件系统的电脑叫做裸机。 ●主板、硬盘、显卡等必…

2005-2018年上市公司高管前三名薪酬比例数据

2005-2018年上市公司高管前三名薪酬比例数据 1、时间:2005-2018年 2、指标:证券代码、year、高管薪酬总额、高管前三名薪酬总额、高管前三名薪酬比例、市场类型、行业代码、交易状态 3、范围:上市公司 4、指标解释: 薪酬是员…

深入理解C语言(1):数据在内存中的存储

文章主题:数据在内存中的存储🌏所属专栏:深入理解C语言📔作者简介:更新有关深入理解C语言知识的博主一枚,记录分享自己对C语言的深入解读。😆个人主页:[₽]的个人主页🏄&…

【红队攻防】从零开始的木马免杀到上线

0、环境配置说明 应该全部使用云服务器完整演示比较好,奈何太穷了买不起服务器,只能用本地环境演示。所需环境如下: 系统环境: CentOS 7 ,Windows 10 软件环境 Cobalt Strike 4.7 , ShellQMaker, 360杀…