BertModel源码解析
- 1. BertModel 介绍
- 2. BertModel 源码逐行注释
1. BertModel 介绍
BertModel 是 transformers 库中的核心模型之一,它实现了 BERT(Bidirectional Encoder Representations from Transformers)模型的架构。BERT 是基于 Transformer 编码器的堆叠模块来构建的。以下是 BertModel 内部包含的主要模块和组件的详细介绍:
- BertEmbeddings (BertEmbeddings 源码解析)
将词嵌入(Token Embeddings) 、位置嵌入(Position Embeddings) 和 标记类型嵌入(Segment Embeddings) 组合起来,为每个输入token生成最终的嵌入表示
- BertEncoder
BERT 模型的核心部分,包含了多个堆叠的 Transformer 编码器层(Layer)。每一层都是一个自注意力机制与前馈神经网络的组合。即:
--------- Self-Attention Heads (BertSelfAttention)
--------- Feedforward Neural Network (BertIntermediate & BertOutput)
- BertPooler
负责将编码器的输出转化为单一的全局表示。
通常使用第一个 token([CLS])的表示作为整个序列的表示,并通过一个线性层加上 tanh 激活函数生成最终的句子向量。
这个句子向量可以用于分类或其他需要整体序列表示的任务。
BertModel 是由多个模块组合而成的复杂架构,这些模块协同工作,共同实现了强大的文本表示能力。通过这些模块,BertModel 能够捕捉句子中深层次的语义信息,并应用于广泛的 NLP 任务。
2. BertModel 源码逐行注释
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
# -*- coding: utf-8 -*-
# @time: 2024/7/11 10:43
"""PyTorch BERT model."""
import torch
from typing import List, Optional, Tuple, Union
from transformers import BertPreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask_for_sdpa
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class BertModel(BertPreTrainedModel):
"""
模型可以作为编码器(仅使用自注意力)或解码器。在作为解码器时,自注意力层之间会添加一层交叉注意力层,这遵循 Ashish Vaswani、Noam Shazeer、Niki Parmar、Jakob Uszkoreit、Llion Jones、Aidan N. Gomez、Lukasz Kaiser 和 Illia Polosukhin 所描述的 [Attention is all you need](https://arxiv.org/abs/1706.03762) 架构。
要使模型作为解码器,需要在初始化时将配置中的 `is_decoder` 参数设置为 `True`。要在 Seq2Seq 模型中使用,需要将 `is_decoder` 和 `add_cross_attention` 参数都设置为 `True`;此时在前向传播中需要提供 `encoder_hidden_states` 作为输入。
"""
def __init(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config # 保存传入的配置对象
self.embeddings = BertEmbeddings(config) # 初始化 BERT 的嵌入层
self.encoder = BertEncoder(config) # 初始化 BERT 的编码器层
self.pooler = BertPooler(config) if add_pooling_layer else None # 如果 add_pooling_layer 为 True,则初始化池化层
self.attn_implementation = config._attn_implementation # 保存注意力机制的实现细节
self.position_embedding_type = config.position_embedding_type # 保存位置嵌入的类型
self.post_init() # 执行一些初始化后的操作
def get_input_embeddingss(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
修剪模型的注意力头。heads_to_prune: 是一个字典,包含 {layer_num: 该层中要修剪的头的列表}。详见基类 PreTrainedModel。
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
# 为模型的前向传递函数添加文档字符串和代码示例
"""
add_start_docstrings_to_model_forward装饰器会将 BERT_INPUTS_DOCSTRING 中的文档字符串添加到模型前向传递函数的开头部分。
BERT_INPUTS_DOCSTRING 是一个格式化字符串,其中包含有关输入张量形状的信息。在这里,它被格式化为 "batch_size, sequence_length",描述了输入的批量大小和序列长度。
add_code_sample_docstrings装饰器会为模型的前向传递函数添加代码示例文档字符串。
checkpoint 参数指定了用于文档的检查点名称。
output_type 参数指定了模型前向传递输出的类型,这里是 BaseModelOutputWithPoolingAndCrossAttentions。
config_class 参数指定了模型配置的类,这里是 _CONFIG_FOR_DOC。
"""
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`,*可选*):
编码器最后一层输出的隐藏状态序列。如果模型配置为解码器,则在交叉注意力中使用。
encoder_attention_mask (`torch.FloatTensor`,形状为 `(batch_size, sequence_length)`,*可选*):
用于避免对编码器输入中的填充标记索引执行注意力操作的掩码。如果模型配置为解码器,则在交叉注意力中使用。掩码值选择 `[0, 1]`:
- 1 表示**未被掩码**的标记,
- 0 表示**被掩码**的标记。
past_key_values (`tuple(tuple(torch.FloatTensor))`,长度为 `config.n_layers`,每个元组有4个形状为 `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)` 的张量):
包含注意力块的预计算键和值的隐藏状态。可用于加速解码。
如果使用 `past_key_values`,用户可以选择仅输入形状为 `(batch_size, 1)` 的最后一个 `decoder_input_ids`(那些没有其过去键值状态的输入)而不是形状为 `(batch_size, sequence_length)` 的所有 `decoder_input_ids`。
use_cache (`bool`,*可选*):
如果设置为 `True`,则返回 `past_key_values` 键值状态,并可用于加速解码(参见 `past_key_values`)。
"""
# ------------------------------1. 关于参数的配置---------------------------
"""
最后得到的参数有:
output_attentions(是否返回所有注意力层的注意力张量),
output_hidden_states(是否返回所有层的隐藏状态),
return_dict(是否返回ModelOutput而不是普通元组),
use_cache(如果设置为True, past_key_values则返回键值状态并可用于加快解码速度),
batch_size,
seq_length,
device,
past_key_values_length(包含注意力块的预计算键和值隐藏状态, 可用于加速解码),
token_type_ids,
"""
# 如果 output_attentions 不为 None,则使用其值,否则使用配置中的默认值 self.config.output_attentions
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# 如果 output_hidden_states 不为 None,则使用其值,否则使用配置中的默认值 self.config.output_hidden_states
output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
# 如果 return_dict 不为 None,则使用其值,否则使用配置中的默认值 self.config.use_return_dict
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 如果模型配置为解码器,设置 use_cache 参数
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False # 如果不是解码器,强制将 use_cache 设置为 False
# 检查 input_ids 和 inputs_embeds,不能同时指定
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
# 如果指定了 input_ids,检查填充和注意力掩码
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size() # 获取 input_ids 的形状
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] # 获取 inputs_embeds 的形状(除去最后一维)
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# 从 input_shape 中获取 batch_size 和 seq_length
batch_size, seq_length = input_shape
# 如果 input_ids 不为 None,则设备为 input_ids 的设备. 否则,设备为 inputs_embeds 的设备
device = input_ids.device if input_ids is not None else inputs_embeds.device
# 如果 past_key_values 不为 None,则获取 past_key_values 中第一个元素的形状的第三维长度作为 past_key_values_length. 否则,将 past_key_values_length 设置为 0
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if token_type_ids is None:
# 如果模型的嵌入层有 token_type_ids 属性
if hasattr(self.embeddings, "token_type_ids"):
# 从嵌入层的 token_type_ids 中获取前 seq_length 的部分
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
# 扩展 token_type_ids 以匹配 batch_size 和 seq_length
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
# 将扩展后的 token_type_ids 赋值给 token_type_ids
token_type_ids = buffered_token_type_ids_expanded
else:
# 如果嵌入层没有 token_type_ids 属性,则创建一个全零的 token_type_ids
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# ------------------------------ *2. 输入嵌入层(Input Embeddings)---------------------------
# 计算嵌入层的输出
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# ------------------------------3. 注意力掩码的配置------------------------------------------
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
# 判断是否使用 SDPA 注意力掩码的条件
use_sdpa_attention_masks = (
self.attn_implementation == "sdpa" # 判断注意力实现是否为 SDPA
and self.position_embedding_type == "absolute" # 判断位置嵌入类型是否为绝对位置
and head_mask is None # 判断是否没有指定 head_mask
and not output_attentions # 判断是否不需要输出注意力
)
# 根据条件 use_sdpa_attention_masks 进行扩展注意力掩码
if use_sdpa_attention_masks:
# 为 SDPA 扩展注意力掩码
# [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]
if self.config.is_decoder:
# 如果是解码器,准备 4D 因果注意力掩码: 这种掩码确保解码器在生成下一个词时只能看到当前词及其之前的词,而不能看到未来的词。
"""
attention_mask:输入的注意力掩码,通常是一个二维张量,表示每个词的位置是否应该被注意力机制关注。
input_shape:输入的形状,通常是 (batch_size, seq_length)。
embedding_output:嵌入层的输出,包含输入词的嵌入表示。
past_key_values_length:过去键值的长度,用于支持缓存机制(如在解码器中)。
"""
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
# 如果不是解码器,准备 4D 注意力掩码
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# 提供一个维度为 [batch_size, from_seq_length, to_seq_length] 的自注意力掩码
# 只需使其可广播到所有注意力头
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# ================================ 获得了extended_attention_mask: [batch_size, 1, seq_length, seq_length]=======================================
# 如果为交叉注意力提供了 2D 或 3D 注意力掩码
# 需要使其可广播到 [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None: # 解码器的配置
# 获取编码器隐藏状态的形状
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
# 如果没有提供编码器注意力掩码,创建一个全为 1 的编码器注意力掩码
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
if use_sdpa_attention_masks:
# 为 SDPA 扩展注意力掩码
# [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# 否则,反转编码器注意力掩码
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
# 如果不需要编码器扩展注意力掩码,设置为 None
encoder_extended_attention_mask = None
# ================================ 获得了encoder_extended_attention_mask: [batch_size, 1, seq_length, seq_length]或 None====================================
# 准备注意力头掩码(如果需要)
# 1.0 表示保留该注意力头
# attention_probs 的形状为 bsz x n_heads x N x N
# 输入的 head_mask 的形状为 [num_heads] 或 [num_hidden_layers x num_heads]
# head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# ================================ 获得了head_mask, 一般情况下是None ====================================
# ------------------------------ *4. 编码器层(Encoder Layers)-----------------------------
# 传递输入到编码器,并获取编码器的输出
encoder_outputs = self.encoder(
embedding_output, # 嵌入层的输出
attention_mask=extended_attention_mask, # 扩展的注意力掩码
head_mask=head_mask, # 注意力头掩码,一般None
encoder_hidden_states=encoder_hidden_states, # 编码器的隐藏状态,一般None
encoder_attention_mask=encoder_extended_attention_mask, # 编码器的注意力掩码
past_key_values=past_key_values, # 过去的键值对
use_cache=use_cache, # 是否使用缓存
output_attentions=output_attentions, # 是否输出注意力
output_hidden_states=output_hidden_states, # 是否输出隐藏状态
return_dict=return_dict, # 是否返回字典
)
# 从编码器的输出中获取序列输出,这里的encoder_outputs[0]值其实就是last_hidden_state
sequence_output = encoder_outputs[0]
# ------------------------------ *5. 池化层(Pooling Layer)--------------------------------
# 如果存在池化层,则对序列输出进行池化,池化就是加了一层线性变换和tanh激活函数
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
# ------------------------------6. 返回输出结果--------------------------------------------
# 如果 return_dict 为 False,返回一个元组,其中包含序列输出、池化输出和编码器输出中的其他部分
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
# 如果 return_dict 为 True,返回一个 BaseModelOutputWithPoolingAndCrossAttentions 对象
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, # 最后隐藏状态
pooler_output=pooled_output, # 池化输出
past_key_values=encoder_outputs.past_key_values, # 过去的键值对
hidden_states=encoder_outputs.hidden_states, # 隐藏状态
attentions=encoder_outputs.attentions, # 注意力
cross_attentions=encoder_outputs.cross_attentions, # 交叉注意力
)