[NLP] BERT模型参数量

news2024/7/3 17:35:24

一 BERT_Base 110M参数拆解

BERT_base模型的110M的参数具体是如何组成的呢,我们一起来计算一下:

刚好也能更深入地了解一下Transformer Encoder模型的架构细节。

借助transformers模块查看一下模型的架构:

import torch
from transformers import BertTokenizer, BertModel

bertModel = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
for name,param in bertModel.named_parameters():

print(name, param.shape)

得到的模型参数为:

embeddings.word_embeddings.weight torch.Size([30522, 768])
embeddings.position_embeddings.weight torch.Size([512, 768])
embeddings.token_type_embeddings.weight torch.Size([2, 768])
embeddings.LayerNorm.weight torch.Size([768])
embeddings.LayerNorm.bias torch.Size([768])

encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias torch.Size([768])
encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias torch.Size([768])
encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias torch.Size([768])

encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])

encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.0.intermediate.dense.bias torch.Size([3072])
encoder.layer.0.output.dense.weight torch.Size([768, 3072])
encoder.layer.0.output.dense.bias torch.Size([768])
encoder.layer.0.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.output.LayerNorm.bias torch.Size([768])

encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
encoder.layer.11.attention.self.query.bias torch.Size([768])
encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
encoder.layer.11.attention.self.key.bias torch.Size([768])
encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
encoder.layer.11.attention.self.value.bias torch.Size([768])
encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.11.attention.output.dense.bias torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.11.intermediate.dense.bias torch.Size([3072])
encoder.layer.11.output.dense.weight torch.Size([768, 3072])
encoder.layer.11.output.dense.bias torch.Size([768])
encoder.layer.11.output.LayerNorm.weight torch.Size([768])
encoder.layer.11.output.LayerNorm.bias torch.Size([768])

pooler.dense.weight torch.Size([768, 768])
pooler.dense.bias torch.Size([768])

其中,BERT模型的参数主要由三部分组成:

Embedding层参数

Transformer Encoder层参数

LayerNorm层参数

二 Embedding层参数

由于词向量是由Token embedding,Position embedding,Segment embedding三部分构成的,因此embedding层的参数也包括以上三部分的参数。

BERT_base英文词表大小为:30522, 隐藏层hidden_size=768,文本最大长度seq_len = 512

Token embedding参数量为:30522 * 768;

Position embedding参数量为:512 * 768;

Segment embedding参数量为:2 * 768。

因此总的参数量为:(30522 + 512 +2)* 768 = 23,835,648

 

LN层在Embedding层

norm使用的是layer normalization,每个维度有两个参数

768 * 2 = 1536

三 Transformer Encoder层参数

可以将该部分拆解成两部分:Self-attention层参数、Feed-Forward Network层参数

1.Self-attention层参数

改层主要是由Q、K、V三个矩阵运算组成,BERT模型中是Multi-head多头的Self-attention(记为SA)机制。先通过Q和K矩阵运算并通过softmax变换得到对应的权重矩阵,然后将权重矩阵与 V矩阵相乘,最后将12个头得到的结果进行concat,得到最终的SA层输出。

1. multi-head因为分成12份, 单个head的参数是 768 * (768/12) * 3,  紧接着将多个head进行concat再进行变换,此时W的大小是768 * 768

    12个head就是  768 * (768/12) * 3 * 12  + 768 * 768 = 1,769,472 + 589,824 = 2359296

3. LN层在Self-attention层

norm使用的是layer normalization,每个维度有两个参数

768 * 2 = 1536

2.Feed-Forward Network层参数

由FFN(x)=max(0, xW1+b1)W2+b2可知,前馈网络FFN主要由两个全连接层组成,且W1和W2的形状分别是(768,3072),(3072,768),因此该层的参数量为:

feed forward的参数主要由两个全连接层组成,intermediate_size为3072(原文中4H长度) ,那么参数为12*(768*3072+3072*768)= 56623104

LN层在FFN

norm使用的是layer normalization,每个维度有两个参数

768 * 2 = 1536

layer normalization

layer normalization有两个参数,分别是gamma和beta。有三个地方用到了layer normalization,分别是embedding层后、multi-head attention后、feed forward后,这三部分的参数为768*2+12*(768*2+768*2)=38400

四 总结

综上,BERT模型的参数总量为:

23835648 + 12*2359296(28311552)   + 56623104 +  38400  = 108808704  ≈103.7M

Embedding层约占参数总量的20%,Transformer层约占参数总量的80%。

注:本文介绍的参数仅是BERT模型的Transformer Encoder部分的参数,涉及的bias由于参数很少,本文也未计入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/903397.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Leetcode.118 杨辉三角

题目链接 Leetcode.118 杨辉三角 easy 题目描述 给定一个非负整数 n u m R o w s numRows numRows,生成「杨辉三角」的前 n u m R o w s numRows numRows 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出:…

STM32--DMA

文章目录 DMA简介DMA特性 DMA框图DMA基本结构DMA请求数据宽度对齐DMA数据转运工程DMAADC多通道 DMA简介 直接存储器存取(DMA)用来提供在外设和存储器之间或者存储器和存储器之间的高速数据传输。无须CPU干预,数据可以通过DMA快速地移动,这就节省了CPU的…

Goland 注释时自动在注释符号后添加空格

不得不说 JetBrains 旗下的 IDE 都好用,而且对于注释这块,使用 Ctrl / 进行注释的时候,大多会在每个注释符号后统一添加一个空格,比如 PyCharm 和 RubeMine 等。 # PyCharm # print("hello world") # RubyMine # req…

基于web网上订餐系统的设计与实现(论文+源码)_kaic

目录 1绪论 1.1课题研究背景 1.2研究现状 1.3主要内容 1.4本文结构 2网上订餐系统需求分析 2.1系统业务流程分析 2.2消费者用户业务流程分析 2.3商户业务流程分析 2.4管理员用户流程分析消费者用户用例分析 2.5系统用例分析 3网上订餐系统设计 3.1功能概述 3.2订单管理模块概要…

MySQL安装、配置和启动关闭

1. 概述 本文主要内容: MySQL下载;MySQL的安装;配置环境变量;登录MySQL服务器;查询系统数据库;启动和关闭服务; 2. 安装、配置、启动与关闭服务 2.1. MySQL下载 在MySQL官网就可以下载。 …

aardio简单日历实例

import console; io.open()//aardio简单日历实例getMonthDays function(year,month){var startDate year"/"month"/""1"; //当月1日var endDate time(startDate).addmonth(1).addday(-1); //当月末return endDate.diffday(time(startDate))1…

数据结构 - 算法设计的基本要求

1、算法的描述: 自然语言:英语、中文流程图:传统流程图、NS流程图伪代码:类语言 - 类C语言程序代码:Java、C语言 2、算法的特性: 一个算法必须具备以下五个特性: 3、算法设计的要求 正确性可…

msvcp140.dll文件丢失的解决方法是什么?

在日常使用电脑的时候,有时候会遇到一些使用问题。比如,有一次遇到了这样一个问题。那就是,因为“msvcp140.dll”这个文件丢失,有些软件安装不了。今天把我在网上找了很久的解决方法分享给大家,希望也能帮到大家。 丢…

页面滑动到可视区域加载更多内容思维流程

页面滑动到可视区域加载更多内容思维流程

Slingshot | 细胞分化轨迹的这样做比较简单哦!~(二)

1写在前面 今天又值班了,你没有听错!! 🥲 又值班了!!!!😅 最近自己的确不太在状态,做事情有极强的拖延症,要振奋起来啦,man&#xff0…

编写Dockerfile制作自己的镜像并推送到私有仓库

说明:我将用到的私有仓库是Harbor,安装教程参考我的这一篇文章: 安装搭建私有仓库Harbor_Word_Smith_的博客-CSDN博客 一、案例1 1、要求 编写Dockerfile制作Web应用系统nginx镜像,生成镜像nginx:v1.1,并推送其到私…

Linux学习(一)虚拟机安装

1、简介 最近准备开始进行linux的学习,本文从头开始记录学习过程以及遇到困难处理办法,便于以后复习、指令复制等。 2、虚拟机安装 2.1 VMware虚拟机安装 安装包链接:ubuntu20.04 https://www.aliyundrive.com/s/ZN8kZFKvBRu 点击链接保存…

十一、Linux用户及用户组的权限信息如何查看?如何修改?什么是权限的数字序号?

目录: 1、认知权限信息 2、rwx? (1)总括: (2)r权限: (3)w权限: (4)x权限: 3、修改权限 (1&a…

电脑提示找不到d3dcompiler_47.dll的解决方案

在电脑上玩游戏或许工作中是我日常生活中的一大乐趣。然而,最近我遇到了一个问题,让我对我的游戏还有我的工作软件体验感到非常沮丧。这个问题就是d3dcompiler_47.dll文件的丢失。当我尝试启动一个新的游戏时,一个错误提示窗口出现在我的屏幕…

操作系统经典互斥问题哲学家就餐问题

问题描述 由Dijkstra提出并解决的哲学家就餐问题是典型的同步问题。该问题描述的是五个哲学家共用一张圆桌,分别坐在周围的五张椅子上,在圆桌上有五个碗和五只筷子,他们的生活方式是交替的进行思考和进餐。平时,一个哲学家进行思考…

LeetCode 42题:接雨水

题目 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例 1: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数组 [0,1,0,2,1,0,1,3,2,1,…

c#扩展方法的使用

扩展方法可以向现有类型“添加”方法,无需创建新的派生类型、重新编译或以其他方式修改原始类型,用起来很方便,下面是我写的例子,为string这个常用的类型添加一个showmes方法,以下是扩展方法的代码: public…

React+Typescript 状态管理

好 本文 我们来说说状态管理 也就是我们的 state 我们直接顺便写一个组件 参考代码如下 import * as React from "react";interface IProps {title: string,age: number }interface IState {count:number }export default class hello extends React.Component<I…

【LeetCode-中等题】49. 字母异位词分组

题目 题解一:排序哈希表 思路:由于互为字母异位词的两个字符串包含的字母相同&#xff0c;因此对两个字符串分别进行排序之后得到的字符串一定是相同的&#xff0c;故可以将排序之后的字符串作为哈希表的键。 核心api: //将字符串转换为字符数组char[] ch str.toCharArray();…

【力扣】84. 柱状图中最大的矩形 <模拟、双指针、单调栈>

目录 【力扣】84. 柱状图中最大的矩形题解暴力求解双指针单调栈 【力扣】84. 柱状图中最大的矩形 给定 n 个非负整数&#xff0c;用来表示柱状图中各个柱子的高度。每个柱子彼此相邻&#xff0c;且宽度为 1 。求在该柱状图中&#xff0c;能够勾勒出来的矩形的最大面积。 示例…