深度学习模型: BERT(Bidirectional Encoder Representations from Transformers)详解

news2025/2/12 2:49:11

一、引言

自然语言处理(NLP)领域在过去几十年取得了显著的进展。从早期基于规则的方法到统计机器学习方法,再到如今基于深度学习的模型,NLP 不断向着更高的准确性和效率迈进。BERT 的出现为 NLP 带来了新的突破,它能够有效地对自然语言进行编码,从而在多个 NLP 任务中取得优异的表现。

二、BERT 的架构

(一)Transformer 基础

Transformer 架构由 Vaswani 等人提出,它摒弃了传统的循环神经网络(RNN)结构,采用了自注意力机制(Self - Attention Mechanism)。

自注意力机制

自注意力机制的核心公式如下:

import numpy as np
def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    if mask is not None:
        scores += (mask * -1e9)
    attention_weights = np.exp(scores) / np.sum(np.exp(scores), axis = -1, keepdims=True)
    return np.matmul(attention_weights, V)

这里,Q(Query)、K(Key)和V(Value)是输入的向量表示。该机制通过计算QK的点积并进行缩放来得到注意力权重,然后用这些权重对V进行加权求和,得到输出。

多头注意力

  • 多头注意力是对自注意力机制的扩展:
    def multi_head_attention(Q, K, V, num_heads):
        d_model = Q.shape[-1]
        d_k = d_model // num_heads
        Q_heads = np.array([Q[:, :, i * d_k:(i + 1) * d_k] for i in range(num_heads)])
        K_heads = np.array([K[:, :, i * d_k:(i + 1) * d_k] for i in range(num_heads)])
        V_heads = np.array([V[:, :, i * d_k:(i + 1) * d_k] for i in range(num_heads)])
        attention_heads = [scaled_dot_product_attention(Qh, Kh, Vh) for Qh, Kh, Vh in zip(Q_heads, K_heads, V_heads)]
        concat_attention = np.concatenate(attention_heads, axis=-1)
        return concat_attention

它将输入分成多个头(heads),每个头独立进行自注意力计算,然后将结果拼接起来。

输入表示

BERT 的输入由三部分组成:词嵌入(Token Embeddings)、段嵌入(Segment Embeddings)和位置嵌入(Position Embeddings)。

(二)BERT 的具体架构

BERT 的架构基于 Transformer 的编码器部分。

import torch
class BERTInputEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size):
        super(BERTInputEmbedding, self).__init__()
        self.token_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.segment_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size)
        self.position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size)
    def forward(self, input_ids, token_type_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype = torch.long, device = input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        token_embeds = self.token_embeddings(input_ids)
        segment_embeds = self.segment_embeddings(token_type_ids)
        position_embeds = self.position_embeddings(position_ids)
        return token_embeds + segment_embeds + position_embeds

 

词嵌入将输入的单词转换为向量表示,段嵌入用于区分不同的句子(例如在句子对任务中),位置嵌入则对单词的位置进行编码。

多层 Transformer 编码器

BERT 由多层 Transformer 编码器堆叠而成。

class BERTEncoder(torch.nn.Module):
    def __init__(self, num_layers, hidden_size, num_heads, intermediate_size, dropout):
        super(BERTEncoder, self).__init__()
        self.layers = torch.nn.ModuleList([BERTLayer(hidden_size, num_heads, intermediate_size, dropout) for _ in range(num_layers)])
    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states
  • 每一层 Transformer 编码器都包括多头注意力机制和前馈神经网络,并且在每层之间有残差连接和层归一化。

三、BERT 的预训练任务

(一)掩码语言模型(Masked Language Modeling,MLM)

  1. 原理
    • 在训练过程中,随机地将输入中的一些单词替换为特殊的[MASK]标记。模型的任务是根据上下文预测被掩码的单词。
    • 例如,对于句子 “The [MASK] is red”,模型需要预测出被掩码的单词 “apple”。
  2. 代码示例
    def mask_tokens(inputs, tokenizer, mlm_probability = 0.15):
        labels = inputs.clone()
        probability_matrix = torch.full(labels.shape, mlm_probability)
        special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype = torch.bool), value = 0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
        return inputs, labels

    (二)下一句预测(Next Sentence Prediction,NSP)

原理

对于给定的两个句子 A 和 B,模型需要判断 B 是否是 A 的下一句。这有助于模型学习句子之间的语义关系。

代码示例

def create_next_sentence_labels(sentence_pairs):
    next_sentence_labels = []
    for (sentence_a, sentence_b) in sentence_pairs:
        if sentence_b is not None:
            next_sentence_labels.append(1)
        else:
            next_sentence_labels.append(0)
    return torch.tensor(next_sentence_labels, dtype = torch.long)

四、BERT 的微调

(一)文本分类任务

  1. 架构调整
    • 在文本分类任务中,通常在 BERT 的输出上添加一个分类层。
class BERTForTextClassification(torch.nn.Module):
    def __init__(self, bert_model, num_classes):
        super(BERTForTextClassification, self).__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(p = 0.1)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_classes)
    def forward(self, input_ids, token_type_ids, attention_mask):
        outputs = self.bert(input_ids, token_type_ids, attention_mask)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
  • 这里利用 BERT 的输出池化结果,经过一个线性分类器进行分类。
  1. 微调过程
    • 在微调时,使用标记好的文本分类数据集,通过反向传播算法来更新 BERT 模型和分类层的参数。
    • 例如,使用交叉熵损失函数:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 2e - 5)
for epoch in range(num_epochs):
    for batch_input_ids, batch_token_type_ids, batch_attention_mask, batch_labels in data_loader:
        optimizer.zero_grad()
        logits = model(batch_input_ids, batch_token_type_ids, batch_attention_mask)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

(二)命名实体识别(NER)任务

  • 架构调整
    • 对于 NER 任务,通常在 BERT 的输出上添加一个 CRF(Conditional Random Field)层来进行序列标注。

CRF 层有助于考虑标签之间的依赖关系,提高命名实体识别的准确性。

微调过程

与文本分类类似,使用标记好的 NER 数据集进行微调。

训练过程中,除了更新 BERT 和分类层的参数,还会更新 CRF 层的参数。

五、BERT 的优势

(一)双向编码

  1. 与传统的单向语言模型(如 GPT)不同,BERT 采用双向编码机制。这使得模型能够同时利用上下文信息来对单词进行编码,从而更准确地理解单词的含义。
  2. 在处理诸如文本填空等任务时,双向编码能够更好地根据前后文信息来选择合适的单词。

(二)预训练的通用性

  1. BERT 的预训练任务(MLM 和 NSP)使得模型能够学习到通用的语言知识。
  2. 这种通用性使得 BERT 在微调用于不同的自然语言处理任务时,能够快速适应并取得较好的效果,无论是文本分类、问答系统还是命名实体识别等任务。

(三)性能表现

  1. 在多个自然语言处理基准测试中,BERT 都取得了领先的成绩。
  2. 例如,在 GLUE(General Language Understanding Evaluation)基准测试中,BERT 的表现远远超过了之前的模型。

六、BERT 的局限性

(一)计算资源需求大

  1. BERT 的训练和微调都需要大量的计算资源,包括 GPU 和大量的内存。
  2. 对于一些小型研究机构或企业来说,可能难以承担如此高的计算成本。

(二)长文本处理问题

  1. 虽然 BERT 在处理一般长度的文本时表现良好,但在处理非常长的文本时,由于其架构的限制,可能会出现性能下降的情况。
  2. 这是因为 Transformer 架构中的自注意力机制计算复杂度随着文本长度的增加而急剧增加。

(三)领域适应性

BERT 是在大规模通用语料上进行预训练的,在某些特定领域的自然语言处理任务中,可能需要进一步的领域适应性调整。

例如,在医学领域的文本处理中,BERT 可能需要在医学语料上进行进一步的预训练或微调才能达到较好的效果。

七、BERT 的改进和扩展

(一)RoBERTa

RoBERTa 是对 BERT 的改进,它在预训练过程中进行了一些优化。

例如,取消了下一句预测(NSP)任务,增加了预训练数据的量和多样性,并且采用了动态掩码(Dynamic Masking)的方法来进行掩码语言模型训练。

这些改进使得 RoBERTa 在一些自然语言处理任务中取得了更好的性能。

(二)ALBERT

ALBERT 在 BERT 的基础上进行了架构上的精简和优化。

它提出了跨层参数共享(Cross - Layer Parameter Sharing)的方法,减少了模型的参数数量,同时采用了句子顺序预测(Sentence Order Prediction,SOP)任务来替代 NSP 任务,进一步提高了模型的性能和训练效率。

八、结论

BERT 作为一种基于 Transformer 的预训练模型,在自然语言处理领域取得了巨大的成功。它的双向编码机制、有效的预训练任务和广泛的适用性使其成为自然语言处理研究和应用中的重要工具。尽管存在一些局限性,但通过不断的改进和扩展,如 RoBERTa 和 ALBERT 等变体的出现,BERT 及其相关模型将继续在自然语言处理领域发挥重要作用,推动该领域向着更高的准确性和效率迈进。

在未来的研究中,一方面可以继续探索如何优化 BERT 的架构和训练方法,以减少计算资源需求和提高长文本处理能力;另一方面,可以深入研究如何更好地将 BERT 应用于特定领域,提高其领域适应性,从而在更多的自然语言处理应用场景中取得更好的效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251866.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

亚马逊开发视频人工智能模型,The Information 报道

根据《The Information》周三的报道,电子商务巨头亚马逊(AMZN)已开发出一种新的生成式人工智能(AI),不仅能处理文本,还能处理图片和视频,从而减少对人工智能初创公司Anthropic的依赖…

LLM学习笔记(13)分词器 tokenizer

由于神经网络模型不能直接处理文本,因此我们需要先将文本转换为数字,这个过程被称为编码 (Encoding),其包含两个步骤: 使用分词器 (tokenizer) 将文本按词、子词、字符切分为 tokens;将所有的 token 映射到对应的 tok…

通过LabVIEW项目判断开发环境是否正版

在接收或分析他人提供的LabVIEW项目时,判断其开发环境是否为正版软件对于保护知识产权和避免使用非法软件至关重要。本文将详细介绍如何通过项目文件、可执行程序及开发环境信息判断LabVIEW是否为正版。 ​ 1. 从项目文件判断 LabVIEW项目的源码(VI 文件…

node.js基础学习-url模块-url地址处理(二)

前言 前面我们创建了一个HTTP服务器,如果只是简单的http://localhost:3000/about这种链接我们是可以处理的,但是实际运用中一般链接都会带参数,这样的话如果我们只是简单的判断链接来分配数据,就会报404找不到链接。为了解决这个问…

思科网络设备常用命令整理

思科网络设备的配置命令非常丰富,广泛应用于路由器、交换机和其他网络设备的管理与配置。以下是一些常见的思科设备配置命令,按照功能分类,以帮助你快速查找和使用。 一、基本命令 查看当前配置和状态 show running-config:查看…

2024年信号处理与神经网络应用(SPNNA 2024)

会议官网:www.spnna.org 会议时间:2024年12月13-15日 会议地点:中国武汉

Leecode经典题3-删除排序数组中的重复项

删除排序数组中的重复项 题目描述: 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 …

无人机数据处理系统:原理与核心系统

一、数据处理系统的运行原理 数据获取:无人机在飞行过程中,通过搭载的传感器(如相机、激光雷达等)采集到各种类型的数据,例如图像、点云等。这些数据是后续处理和分析的基础。 数据传输:采集到的数据会通…

ElasticSearch学习篇19_《检索技术核心20讲》搜推广系统设计思想

目录 主要是包含搜推广系统的基本模块简单介绍,另有一些流程、设计思想的分析。 搜索引擎 基本模块检索流程 查询分析查询纠错 广告引擎 基于标签倒排索引召回基于向量ANN检索召回打分机制:非精确打分精准深度学习模型打分索引精简:必要的…

【尚筹网】五、管理员维护

【尚筹网】五、管理员维护 任务清单分页管理管理员信息目标思路代码引入 PageHelperAdminMapper 中编写 SQL 语句AdminMapper 接口生成方法AdminServiceAdminHandler页面显示主体在页面上使用 Pagination 实现导航条 关键词查询页面上调整表单在翻页时保持关键词查询条件 单条删…

MySQL 启动失败问题分析与解决方案:`mysqld.service failed to run ‘start-pre‘ task`

目录 前言1. 问题背景2. 错误分析2.1 错误信息详解2.2 可能原因 3. 问题排查与解决方案3.1 检查 MySQL 错误日志3.2 验证 MySQL 配置文件3.3 检查文件和目录权限3.4 手动启动 MySQL 服务3.5 修复 systemd 配置文件3.6 验证依赖环境 4. 进一步优化与自动化处理结语 前言 在日常…

Apache storm UI如何更换默认8080端口

在搭建Apache storm环境的时候,遇到Apache storm UI默认端口是8080,但是这个端口会被其他java程序占用,导致Apache storm UI服务无法启动。报错Exception in thread “main” java.lang.RuntimeException: java.io.IOException: Failed to bi…

FPGA实现串口升级及MultiBoot(十)串口升级SPI FLASH实现

本文目录索引 工程架构example9工程设计Vivado设计Vitis设计example9工程验证1、读取FLASH ID2、擦除整个FLASH3、Blank-Check4、烧写Golden区位流5、读取FLASH内容6、烧写MultiBoot区位流(升级位流)7、MultiBoot区位流(升级位流)启动example10工程设计Vivado设计Vitis设计exam…

图解人工智能:从规则到深度学习的全景解析

🌟作者简介:热爱数据分析,学习Python、Stata、SPSS等统计语言的小高同学~🍊个人主页:小高要坚强的博客🍓当前专栏:Python之机器学习🍎本文内容:图解人工智能:…

Binder架构

一、架构 如上图,binder 分为用户层和驱动层两部分,用户层有客户端(Client)、服务端(Server)、服务管理(ServiceManager)。 从用户空间的角度,使用步骤如下(…

基于springboot中小型制造企业质量管理系统源码和论文

信息数据从传统到当代,是一直在变革当中,突如其来的互联网让传统的信息管理看到了革命性的曙光,因为传统信息管理从时效性,还是安全性,还是可操作性等各个方面来讲,遇到了互联网时代才发现能补上自古以来的…

Flutter 权限申请

这篇文章是基于permission_handler 10.2.0版本写的 前言 在App开发过程中我们经常要用到各种权限,我是用的是permission_handler包来实现权限控制的。 pub地址:https://pub.dev/packages/permission_handler permission_handler 权限列表 变量 Androi…

MATLAB期末复习笔记(下)

五、数据和函数的可视化 1.MATLAB的可视化对象 图形对象是 MATLAB用来创建可视化数据的组件。每个对象都有一个名为句柄 的唯一标识符。使用该句柄,您可以通过设置对象 属性 来操作现有图形对象的特征 ROOT: :即电脑屏幕 Figure :图窗…

web安全从0到1:burp-suite3

声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&a…

深度学习:梯度下降法

损失函数 L:衡量单一训练样例的效果。 成本函数 J:用于衡量 w 和 b 的效果。 如何使用梯度下降法来训练或学习训练集上的参数w和b ? 成本函数J是参数w和b的函数,它被定义为平均值; 损失函数L可以衡量你的算法效果&a…