In this section we implement a BERT model based on the Transformer architecture.
1. The MultiHeadAttention Class
This class implements multi-head self-attention, the core building block of the Transformer architecture.
It has already been explained in the previous articles in this series, so we go straight to the code.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each attention head
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        # project and split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # scaled dot-product attention scores: (batch, num_heads, seq_len, seq_len)
        atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # (batch, seq_len) -> (batch, 1, 1, seq_len); padded positions get a large negative score
            mask = mask.unsqueeze(1).unsqueeze(1)
            atten_scores = atten_scores.masked_fill(mask == 0, -1e9)
        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        # merge the heads back into a single (batch, seq_len, d_model) tensor
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.dropout(self.o_proj(out))
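As a quick sanity check, here is a minimal sketch of running the module on a dummy batch (the sizes d_model=768 and num_heads=12 match the configuration used in the complete script below; the batch of random activations is purely illustrative):

import torch

# minimal smoke test: dummy activations, shapes only
mha = MultiHeadAttention(d_model=768, num_heads=12, dropout=0.1)
x = torch.randn(2, 25, 768)                 # (batch, seq_len, d_model)
mask = torch.ones(2, 25, dtype=torch.long)  # 1 = real token, 0 = padding
out = mha(x, mask)
print(out.shape)                            # torch.Size([2, 25, 768])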
2. The FeedForward Class
This class implements the position-wise feed-forward network (FFN) used inside the Transformer.
It has also been explained in the previous articles, so here is the code.
class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.W1 = nn.Linear(d_model, dff)
        self.act = nn.GELU()
        self.W2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.W2(self.dropout(self.act(self.W1(x))))
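The FFN widens each position from d_model to dff (4*d_model in the configuration below) and projects it back, acting on every position independently. A minimal shape check under the same illustrative sizes:

import torch

ffn = FeedForward(d_model=768, dff=3072, dropout=0.1)
x = torch.randn(2, 25, 768)
print(ffn(x).shape)  # torch.Size([2, 25, 768]); the 3072-wide hidden layer exists only inside the block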
3. The TransformerEncoderBlock Class
This class implements a single encoder block of the Transformer architecture.
The previous articles covered the decoder, which works on essentially the same principles as the encoder, so here is the code.
class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout, dff):
        super().__init__()
        self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_block = FeedForward(d_model, dff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))
        res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))
        return res2
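Both sublayers are wrapped in a residual connection followed by LayerNorm (post-norm, as in the original Transformer), so the block preserves the input shape and can be stacked freely. A minimal sketch of stacking two blocks with illustrative sizes:

import torch
import torch.nn as nn

# stack two blocks; each one preserves the (batch, seq_len, d_model) shape
blocks = nn.ModuleList([TransformerEncoderBlock(768, 12, 0.1, 3072) for _ in range(2)])
x = torch.randn(2, 25, 768)
mask = torch.ones(2, 25, dtype=torch.long)
for blk in blocks:
    x = blk(x, mask)
print(x.shape)  # torch.Size([2, 25, 768])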
4. The BertModel Class
This class implements the overall backbone of the BERT model.
class BertModel(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.seg_emb = nn.Embedding(3, d_model)   # segment ids 0/1, plus 2 for padding positions
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dropout, dff)
            for _ in range(N_blocks)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, seg_ids, mask):
        pos = torch.arange(x.shape[1], device=x.device)
        tok_emb = self.tok_emb(x)
        seg_emb = self.seg_emb(seg_ids)
        pos_emb = self.pos_emb(pos)
        x = self.drop(tok_emb + seg_emb + pos_emb)  # sum the three embeddings, then dropout
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        return x
- Token, segment, and position embeddings:
  - tok_emb: maps the input token indices into the embedding space.
  - seg_emb: distinguishes different sentences (in BERT, sentence A vs. sentence B; index 2 is used for padding positions).
  - pos_emb: encodes each position into the embedding space so the model can make use of word order.
- Transformer encoder layers: nn.ModuleList stacks N_blocks TransformerEncoderBlock modules, each of which performs further feature extraction on the input sequence.
- Layer normalization and dropout: dropout is applied to the summed embeddings, and the output of the last encoder block is layer-normalized, which further stabilizes the model's output.
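Putting the backbone together, here is a minimal sketch of one forward pass (vocab_size=21128 is illustrative here; the full script below reads it from the tokenizer):

import torch

model = BertModel(vocab_size=21128, d_model=768, seq_len=25,
                  N_blocks=2, num_heads=12, dropout=0.1, dff=3072)
tok_ids = torch.randint(0, 21128, (2, 25))      # dummy token ids
seg_ids = torch.zeros(2, 25, dtype=torch.long)  # every token treated as sentence A
mask = torch.ones(2, 25, dtype=torch.long)      # no padding
out = model(tok_ids, seg_ids, mask)
print(out.shape)  # torch.Size([2, 25, 768]): one contextual vector per input token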
Complete BERT code (the components discussed in this section are all included below):
import re
import math
import torch
import random
import torch.nn as nn
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader

# nn.TransformerEncoderLayer
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)
            atten_scores = atten_scores.masked_fill(mask == 0, -1e9)
        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.dropout(self.o_proj(out))

class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.W1 = nn.Linear(d_model, dff)
        self.act = nn.GELU()
        self.W2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.W2(self.dropout(self.act(self.W1(x))))

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout, dff):
        super().__init__()
        self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_block = FeedForward(d_model, dff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))
        res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))
        return res2
class BertModel(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.seg_emb = nn.Embedding(3, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dropout, dff)
            for _ in range(N_blocks)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, seg_ids, mask):
        pos = torch.arange(x.shape[1], device=x.device)
        tok_emb = self.tok_emb(x)
        seg_emb = self.seg_emb(seg_ids)
        pos_emb = self.pos_emb(pos)
        x = self.drop(tok_emb + seg_emb + pos_emb)
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        return x
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):
        super().__init__()
        self.bert = BertModel(vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff)
        self.mlm_head = nn.Linear(d_model, vocab_size)  # predicts the original token at each position
        self.nsp_head = nn.Linear(d_model, 2)           # binary next-sentence-prediction classifier

    def forward(self, mlm_tok_ids, seg_ids, mask):
        bert_out = self.bert(mlm_tok_ids, seg_ids, mask)
        cls_token = bert_out[:, 0, :]  # representation of [CLS], used for NSP
        mlm_logits = self.mlm_head(bert_out)
        nsp_logits = self.nsp_head(cls_token)
        return mlm_logits, nsp_logits
def read_data(file):
    with open(file, "r", encoding="utf-8") as f:
        data = f.read().strip().replace("\n", "")
    # split the raw text into sentences on Chinese punctuation
    corpus = re.split(r'[。,“”:;!、]', data)
    corpus = [sentence for sentence in corpus if sentence.strip()]
    return corpus

def create_nsp_dataset(corpus):
    nsp_dataset = []
    for i in range(len(corpus)-1):
        next_sentence = corpus[i+1]
        # pick a sentence that is not adjacent to sentence i as the negative example
        rand_id = random.randint(0, len(corpus) - 1)
        while abs(rand_id - i) <= 1:
            rand_id = random.randint(0, len(corpus) - 1)
        negt_sentence = corpus[rand_id]
        nsp_dataset.append((corpus[i], next_sentence, 1))  # positive sample: true next sentence
        nsp_dataset.append((corpus[i], negt_sentence, 0))  # negative sample: random sentence
    return nsp_dataset
class BERTDataset(Dataset):
    def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):
        self.nsp_dataset = nsp_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.pad_id = tokenizer.pad_token_id
        self.mask_id = tokenizer.mask_token_id

    def __len__(self):
        return len(self.nsp_dataset)

    def __getitem__(self, idx):
        sent1, sent2, nsp_label = self.nsp_dataset[idx]
        sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)
        sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)
        # [CLS] sent1 [SEP] sent2 [SEP]
        tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]
        seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)
        mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)
        mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)
        seg_ids = self.pad_to_seq_len(seg_ids, 2)
        mlm_labels = self.pad_to_seq_len(mlm_labels, -100)  # -100 is ignored by CrossEntropyLoss
        mask = (mlm_tok_ids != 0)
        return {
            "mlm_tok_ids": mlm_tok_ids,
            "seg_ids": seg_ids,
            "mask": mask,
            "mlm_labels": mlm_labels,
            "nsp_labels": torch.tensor(nsp_label)
        }

    def pad_to_seq_len(self, seq, pad_value):
        seq = seq[:self.max_length]
        pad_num = self.max_length - len(seq)
        return torch.tensor(seq + pad_num * [pad_value])

    def build_mlm_dataset(self, tok_ids):
        mlm_tok_ids = tok_ids.copy()
        mlm_labels = [-100] * len(tok_ids)
        for i in range(len(tok_ids)):
            if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:
                if random.random() < 0.15:  # select 15% of the non-special tokens
                    mlm_labels[i] = tok_ids[i]
                    r = random.random()  # draw once so the 80/10/10 split is exact
                    if r < 0.8:
                        mlm_tok_ids[i] = self.mask_id  # 80%: replace with [MASK]
                    elif r < 0.9:
                        # 10%: replace with a random non-special token; the remaining 10% keep the original token
                        mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)
        return mlm_tok_ids, mlm_labels
if __name__ == "__main__":
    data_file = "4.10-BERT/背影.txt"
    model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"
    tokenizer = BertTokenizer.from_pretrained(model_path)
    corpus = read_data(data_file)
    max_length = 25  # len(max(corpus, key=len))
    print("Max length of dataset: {}".format(max_length))
    nsp_dataset = create_nsp_dataset(corpus)
    trainset = BERTDataset(nsp_dataset, tokenizer, max_length)
    batch_size = 16
    trainloader = DataLoader(trainset, batch_size, shuffle=True)

    vocab_size = tokenizer.vocab_size
    d_model = 768
    N_blocks = 2
    num_heads = 12
    dropout = 0.1
    dff = 4*d_model
    model = BERT(vocab_size, d_model, max_length, N_blocks, num_heads, dropout, dff)

    lr = 1e-3
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    epochs = 20
    for epoch in range(epochs):
        for batch in trainloader:
            batch_mlm_tok_ids = batch["mlm_tok_ids"]
            batch_seg_ids = batch["seg_ids"]
            batch_mask = batch["mask"]
            batch_mlm_labels = batch["mlm_labels"]
            batch_nsp_labels = batch["nsp_labels"]

            mlm_logits, nsp_logits = model(batch_mlm_tok_ids, batch_seg_ids, batch_mask)
            # the two pre-training losses are simply summed
            loss_mlm = loss_fn(mlm_logits.view(-1, vocab_size), batch_mlm_labels.view(-1))
            loss_nsp = loss_fn(nsp_logits, batch_nsp_labels)
            loss = loss_mlm + loss_nsp

            loss.backward()
            optim.step()
            optim.zero_grad()
        print("Epoch: {}, MLM Loss: {:.4f}, NSP Loss: {:.4f}".format(epoch, loss_mlm.item(), loss_nsp.item()))
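As a quick check after training, here is a minimal sketch (hypothetical usage, assuming the model and trainloader objects from the script above are in scope) of how the MLM and NSP heads could be evaluated on a single batch:

model.eval()
with torch.no_grad():
    batch = next(iter(trainloader))
    mlm_logits, nsp_logits = model(batch["mlm_tok_ids"], batch["seg_ids"], batch["mask"])
    masked = batch["mlm_labels"] != -100          # positions that were actually masked
    mlm_pred = mlm_logits.argmax(dim=-1)[masked]  # predicted ids at the masked positions
    mlm_true = batch["mlm_labels"][masked]
    nsp_pred = nsp_logits.argmax(dim=-1)
    print("MLM accuracy on this batch:", (mlm_pred == mlm_true).float().mean().item())
    print("NSP accuracy on this batch:", (nsp_pred == batch["nsp_labels"]).float().mean().item())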