BERT模型解读与简单任务实现(论文复现)

news2024/9/20 15:27:16

BERT模型解读与简单任务实现(论文复现)

本文所涉及所有资源均在传知代码平台可获取

概述

相关背景

语言模型:语言模型是指对于任意的词序列,它能够计算出这个序列是一句话的概率。
预训练:预训练是一种迁移学习的概念,指的是用海量的数据来训练一个泛化能力很强的模型
微调:微调(Fine-tuning)是指在预训练模型的基础上,针对特定任务或数据领域,对部分或全部模型参数进行进一步的训练和调整
Transformer:
BERT是基于Transformer实现的,BERT中包含很多Transformer模块,其取得成功的一个关键因素是Transformer的强大作用。BERT仅有Encoder部分,因为它并不是生成式模型。Transformer个模块通过自注意力机制实现快速并行,改进了RNN最被人诟病的训练慢的缺点,并且可以增加到非常深的深度,充分发掘DNN模型的特性,提升模型准确率。
Transformer首先对每个句子进行词向量化,进行编码,再添加某个词蕴含的位置信息,生成一个向量。而后通过Attention算法,生成一个新向量,这个新向量不仅包含了词的含义,词中句子中的位置信息,也包含了该词和句子中的每个单词含义之间的关系和价值信息。这种方法突破了时序序列的屏障,使得Transformer得到了广泛的应用。

BERT的优势

  • 作为一种预训练模型,在特定场景使用时不需要用大量的语料来进行训练,节约时间效率高效,泛化能力较强。
  • Bert是一种端到端(end-to-end)的模型,不需要我们调整网络结构,只需要在最后加上特定于下游任务的输出层。
  • 基于Transformer,可以实现快速并行,也可以增加到非常深的深度,充分发掘DNN模型的特性,提升模型准确率。
  • 和ELMO,GPT等其他预训练模型相比,BERT是一种双向的模型,结合上下文来进行训练,具有更好的性能。

BERT和传统nlp相比的特点

1.更好的语义理解能力
传统的自然语言处理工具只能从字面意义上进行文本分析,无法理解句子的含义和上下文。而BERT模型是双向的,可以同时考虑句子左右两侧的上下文信息,从而更好地理解句子的含义和语境。因此,在对话系统、文本分类等领域中BERT模型的表现更加优秀
2.更好的文本预训练能力
BERT是基于预训练的模型,使用了大型无标注语料库进行训练。由于BERT训练时使用了大量的语料库。因此具有更好的泛化能力和适应性,可以适应不同的自然语言处理任务。
3.可拓展性强
BERT采用Transformer结构,使得模型可以轻松地进行拓展。可以通过增加层数、增加训练数据等方式来提高模型的性能。因此,BERT模型在对新领域的应用中具有很大的潜力。
4.更好的效果
针对一些自然语言处理领域的任务,BERT模型的表现要优于其他传统的自然语言处理模型。例如,BERT在文本分类任务中表现出的效果比传统的卷积网络和循环神经网络要好,在当前的文本分类领域中有着广泛的应用

BERT的应用领域

BERT作为一个预训练模型,能够通过适当的数据集进行微调,使得它能够胜任自然语言处理领域的多种任务,比如情感分析、摘要、对话等任务

模型架构

BERT的模型架构是基于多层双向Transformer编码器。具体的,Google提供了一大一小两个BERT模型:
BERT_Small(L=12,H=768,A=12,总参数=110M)
BERT_Large(L=24,H=1024,A=16,总参数=340M)

输入输出表示

为了使BERT能够处理各种下游任务,输入表示能够明确表示单个句子和一对句子(例如,⟨question,answer⟩)在一个标记序列中。
BERT的输入由三部分组成:
Token Embeddings:使用具有30,000个标记词汇表的WordPiece嵌入。每个序列的第一个标记始终是一个特殊的分类标记([CLS])。对应于此标记的最终隐藏状态用作分类任务的聚合序列表示。
Segment Embeddings:用于区分两个句子。通过两种方式区分句子:1.用一个特殊标记([SEP])将它们分开。2.为每个标记添加一个学习的嵌入,指示它属于句子A还是句子B。
Position Embeddings:位置编码,transformer没有捕捉位置信息的能力,所以需要额外的位置编码,这里没有使用transformer论文中的正弦位置编码, 而是采用了learned positional embeddings。
将BERT的输入表示可视化如下:

在这里插入图片描述

BERT预训练任务

使用两个无监督任务来预训练BERT,包括MLM和NSP。

MLM掩码语言模型

直观来看,深度双向语言模型当然比单向的从左到右或者从右到左模型更有效。但不幸的是,标准条件语言模型只能从左到右或从右到左进行训练,因为双向条件将允许每个单词在多层上下文中间接 “看到自己”
为了训练一个深度双向表示,Google学者简单地随机掩盖一定比例的输入标记,然后预测这些被掩盖的标记,这个过程称为“掩码语言模型”(MLM),也就是类似于完形填空任务。
但这种办法存在两个问题:
1.在预训练和微调之间导致了不匹配,因为[MASK]标记在微调期间不会出现。为了缓解这一问题,他们并不总是用实际的[MASK]标记替换“被掩盖”的单词,而是在训练时随机选择15%的标记位置进行预测。如果选择第i个标记,将第i个标记以以下方式替换:(1) 80%的概率 用[MASK]标记替换 (2) 10%的概率用随机标 记替换 (3) 10%的概率保持不变。
2.每个batch只预测15%的tokens,需要更多的轮次才能收敛。

NSP下句预测

许多重要的下游任务,例如问答(QA)和自然语言推断(NLI),都是基于理解两个文本句子之间的关系的。这个没有办法直接由语言模型捕捉,所以增加了一个next sentence prediction(NSP)任务。具体的,对于训练语料中一对句子A和B,B有一半的概率是A的下一句,一半的概率是随机的句子。

预训练过程

预训练过程在很大程度上遵循现有关于语言模型预训练的文献。对于预训练语料库,BERT使用了BooksCorpus(800M字)和英文维基百科(2,500M字)。对于维基百科,仅提取文本段落,忽略列表、表格和标题。使用文档级语料库而不是乱序句子级语料库是至关重要的,以便提取长连续序列

微调过程

在不同任务上微调BERT的示意图如图所示。任务特定模型是通过将BERT与一个额外的输出层结合形成的,因此只需要从头开始学习少量参数。在这些任务中,a和b是序列级任务,而c和d是标记级任务。在图中,E代表输入嵌入,Ti代表标记 i的上下文表示,[CLS]是用于分类输出的特殊符号,[SEP]是用于分隔非连续标记序列的特殊符号。

在这里插入图片描述

核心逻辑

代码主要分为三部分。
1.dataset,主要负责数据的预处理。比如如何对语料做mask,如何加入CLS、SEP符号等等。
2.model,主要包括bert模型架构,两个预训练任务的实现。
3.trainer,主要实现了预训练的逻辑。
核心内容在model部分。
model部分的核心在于BERT、BERTEmbedding和transformer着三部分

BERT

下面的类实现了BERT的模型,模型的输入通过BERT Embedding层和transformer组建,实现模型架构。

class BERT(nn.Module)
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.feed_forward_hidden = hidden * 4
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.embedding(x, segment_info)
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x

transformer

下面实现的是transformer块,通过多头注意力层、前馈网络层(两层全连接层)、两层残差链接和dropout层来实现

class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """

        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

BERT Embedding

下面的类实现了BERT Embedding层,BERT模型的数据输入的三种编码加和,再经过一个dropout层。

class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)

演示效果

经过训练后的BERT模型,能够根据输入的句子计算他们的编码值。以及,可以区分两个句子的编码,对指定的词进行self-Attention操作,例如:

输入内容:

bert_model = BertModel.from_pretrained(MODEL_PATH, config = model_config)
print(tokenizer.encode('我不喜欢你'))  
sen_code = tokenizer.encode_plus('我不喜欢这世界','我只喜欢你')
print(sen_code)

输出内容:

[101, 2769, 679, 1599, 3614, 872, 102]
{'input_ids': [101, 2769, 679, 1599, 3614, 6821, 686, 4518, 102, 2769, 1372, 1599, 3614, 872, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

同时可以实现论文中提到的两个任务。

MLM任务:

import numpy as np
import torch
from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
from transformers import BertModel
model_name = 'xxxxx'       #指定需下载的预训练模型参数

#任务一:MLM
import numpy as np
import torch

from transformers import logging
logging.set_verbosity_error()

from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
from transformers import BertModel
model_name = 'xxxx'       #指定需下载的预训练模型参数

#任务一:MLM
samples =  ['[CLS] Where is the capital of China? [SEP] Beijing is the capital of [MASK] . [SEP]']
tokenizer = BertTokenizer.from_pretrained(model_name)
tokenized_text = [tokenizer.tokenize(i) for i in samples]
input_ids = [tokenizer.convert_tokens_to_ids(i) for i in tokenized_text]
input_ids = torch.LongTensor(input_ids)


model = BertForMaskedLM.from_pretrained(model_name, cache_dir=model_name)
model.eval()

outputs = model(input_ids)
prediction_scores = outputs[0]
sample = prediction_scores[0].detach().numpy()

pred = np.argmax(sample, axis=1)
print(pred)

输出pred代表每个位置最大概率的字符索引:

[1012 2073 2003 1996 3007 1997 2859 1029 1012 7211 2003 1996 3007 1997
 2859 1012 1012]

NSP任务:

from transformers import BertTokenizer, BertForSequenceClassification
import torch
path='xxxx'
tokenizer = BertTokenizer.from_pretrained(path)
model = BertForSequenceClassification.from_pretrained(path)
TEMPERATURE = 1 
MERGE_RATIO = 0.5 

def is_nextsent(sent, next_sent):

        encoding = tokenizer(sent, next_sent, return_tensors="pt",truncation=True, padding=False)

        with torch.no_grad():
            labels = torch.tensor([1]).unsqueeze(0)
            outputs = model(**encoding, labels=labels)
            logits = outputs.logits
            probs = torch.softmax(logits/TEMPERATURE, dim=1)
            next_sentence_prob = probs[:, 0].item()
            print(next_sentence_prob)
        if next_sentence_prob <= MERGE_RATIO:

            return False
        else:
            return True

sen1="今天天气怎么样"
sen2="今天天气很好"
sen3="小明今年多大了"
sen4="小明天天吃饭"
print(is_nextsent(sen1,sen2))
print(is_nextsent(sen3,sen4))

一个输出例子,说明句子1和句子2之间是上下句关系,而句子3和句子4之间不是上下句关系,其中小数表示模型预测的两个句子是上下文关系的概率,这里概率越大,说明两个模型越可能是上下句关系:

0.5294263362884521
True
0.4861791133880615
False

使用方式

下载模型,进入huggingface官网搜索框输入bert-base-chinese,下载需要的文件。在本地新建一个文件夹,把上面文件下载到这个目录下面,注意:不能改变文件名和后缀;如果无法登录huggingface,可以前往附件查看,附件中提供了模型的下载链接。本地加载模型,使用如下代码本地加载模型,即可进行模型加载使用等功能。

path='xxxx'
tokenizer = BertTokenizer.from_pretrained(path)
model = BertForSequenceClassification.from_pretrained(path)

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149267.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c++类和对象(6个默认成员函数)第二级中阶

目录 6个默认成员函数介绍 构造函数 构造函数是什么&#xff1f; 构造函数的6种特性 析构函数 析构函数是什么&#xff1f; 析构函数的特性 拷贝构造函数 什么是拷贝构造函数 拷贝函数的特性 四.默认生成的拷贝构造实行的是浅拷贝&#xff08;值拷贝&#xff09;&am…

【2024】前端学习笔记9-内部样式表-外部导入样式表-类选择器

学习笔记 内部样式表外部导入样式表类选择器&#xff1a;class 内部样式表 内部样式表是将 CSS 样式规则写在 HTML 文档内部。通过<style>标签在 HTML 文件的<head>部分定义样式。 简单示例&#xff1a; <!DOCTYPE html><html><head><style…

【linux】基础IO(上)

1. 共识原理 文件 内容 属性文件分为 打开的文件 没打开的文件打开的文件 &#xff1a; 是进程打开的 ----- 本质是要研究文件和进程的关系没打开的文件 &#xff1a; 没打开的文件储存在磁盘上&#xff0c;由于没打开的文件很多&#xff0c;所以需要分门别类的防止好&…

常用函数式接口的使用

FunctionalInterface注解 函数式接口在java中是指:有且仅有一个抽象方法的接口。 虽然知道怎么使用&#xff0c;但是没有搞懂使用场景&#xff0c;暂且记录下使用方法吧&#xff0c;不至于看到源码的时候不知所云。 要我自己写代码&#xff0c;我是想不起来这样用的&#xff0…

YOLOv9改进策略【注意力机制篇】| 2024 SCSA-CBAM 空间和通道的协同注意模块

一、本文介绍 本文记录的是基于SCSA-CBAM注意力模块的YOLOv9目标检测改进方法研究。现有注意力方法在空间-通道协同方面未充分挖掘其潜力&#xff0c;缺乏对多语义信息的充分利用来引导特征和缓解语义差异。SCSA-CBAM注意力模块构建一个空间-通道协同机制&#xff0c;使空间注…

C语言 结构体和共用体——典型实例:洗发牌模拟

目录 如何表示52张扑克牌&#xff1f; 如何保存一副扑克牌&#xff1f; 如何发牌&#xff1f; 如何设计主函数&#xff1f; 如何模拟洗牌&#xff1f; 如何表示52张扑克牌&#xff1f; 如何保存一副扑克牌&#xff1f; 如何发牌&#xff1f; 如何设计主函数&#xff1f; 如…

窗户检测系统源码分享

窗户检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vision …

十大常用加密软件排行榜|2024年好用的加密软件推荐(企业必备)

在数字化时代&#xff0c;数据安全已经成为企业生存和发展的关键因素之一。随着网络攻击和数据泄露事件的频发&#xff0c;企业对数据加密的需求日益增长。选择一款可靠的加密软件&#xff0c;不仅能保护企业的核心数据&#xff0c;还能确保业务的连续性和合规性。本文将为您介…

Stable Diffusion 使用详解(11)--- 场景ICON制作

目录 背景 controlNet 整体描述 Canny Lineart Depth 实际使用 AI绘制需求 绘制过程 PS打底 场景模型选择 设置提示词及绘制参数 controlnet 设置 canny 边缘 depth 深度 lineart 线稿 效果 背景 这段时间不知道为啥小伙伴似乎喜欢制作很符合自己场景的ICON。…

共享wifi哪家公司正规合法?看这3点就够了!

随着共享wifi项目的热度不断上升&#xff0c;越来越多的公司都开始加入到共享wifi贴码的研发行列之中&#xff0c;让意向入局该项目的创业者拥有更多选择的同时&#xff0c;也让许多想要借此割一波韭菜的不法分子有了可乘之机。在此背景下&#xff0c;共享wifi哪家公司正规合法…

OpenHarmony(鸿蒙南向开发)——小型系统内核(LiteOS-A)【内核启动】

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 子系统开发内核 轻量系统内核&#xff08;LiteOS-M&#xff09; 轻量系统内核&#…

Docker安装rabbitmq并配置延迟队列

下载rabbitmq镜像 docker pull rabbitmq:management 运行rabbitmq镜像 docker run -id --namerabbitmq -p 5671:5671 -p 5672:5672 -p 4369:4369 -p 15671:15671 -p 15672:15672 -p 25672:25672 -e RABBITMQ_DEFAULT_USERtom -e RABBITMQ_DEFAULT_PASStom rabbitmq:management …

回归传统,Domino拷贝式迁移!

大家好&#xff0c;才是真的好。 前面讲太多普及型的概念&#xff0c;今天我们来点实在的内容。 在Notes/Domino的黄金年代&#xff0c;有一件事情大家干得风生水起&#xff0c;那就是Domino服务器迁移。 要么迁移到另一台硬件服务器上&#xff0c;要么迁移到新换的磁盘当中…

展会上想要留住俄罗斯客户,柯桥成人俄语培训

展品 экспонат 模型 макет 证明(书) свидетельство 预算 бюджет 确认订单 подтверждение заказа 缺点,毛病,缺陷 недостаток 退换 возвращать 更换 заменять 调整 урегулир…

[PTA]7-1 谁管谁叫爹

[PTA]7-1 谁管谁叫爹 输入格式&#xff1a; 输入第一行给出一个正整数 N&#xff08;≤100&#xff09;&#xff0c;为游戏的次数。以下 N 行&#xff0c;每行给出一对不超过 9 位数的正整数&#xff0c;对应 A 和 B 给出的原始数字。题目保证两个数字不相等。 输出格式&…

虹科干货 | CAN/CAN FD故障揭秘:快速排查与解决技巧

是否在处理CAN总线问题时感到头疼&#xff1f;是否在寻找简单直接的方法来解决那些看似复杂的连接故障&#xff1f;本文将为您提供实用技巧&#xff0c;让您能够轻松应对这些难题。 CAN总线因其高效、可靠的数据交换能力&#xff0c;在汽车、工业控制、航空航天等多个关键领域得…

【软件方案】智慧社区总体解决方案(PPT原件)

1.智慧社区整体建设方案内容 2.整体功能介绍 软件全套资料部分文档清单&#xff1a; 工作安排任务书&#xff0c;可行性分析报告&#xff0c;立项申请审批表&#xff0c;产品需求规格说明书&#xff0c;需求调研计划&#xff0c;用户需求调查单&#xff0c;用户需求说明书&…

故障模拟测试负载是如何实现的

故障模拟测试负载是在系统或设备上故意引入故障&#xff0c;以测试其应对能力的方法。这种方法可以帮助我们了解系统在面临各种故障时的响应和恢复能力&#xff0c;从而提高系统的可靠性和稳定性。故障模拟测试负载的实现主要依赖于以下几个步骤&#xff1a; 1. 确定故障类型&…

uniapp快速入门教程,内容来源于官方文档,仅仅记录快速入门需要了解到的知识点

uniapp快速入门教程&#xff0c;内容来源于官方文档&#xff0c;仅仅记录快速入门需要了解到的知识点 目录 介绍uniapp 介绍uniapp x 介绍功能框架图创建项目&发布组件/标签的变化js的变化css的变化工程结构和页面管理 pages.jsonmanifest.json 应用配置组件easycom组件规…

【Unity杂谈】iOS 18中文字体显示问题的调查

一、问题现象 最近苹果iOS 18系统正式版推送&#xff0c;周围升级系统的同事越来越多&#xff0c;有些同事发现&#xff0c;iOS 18上很多游戏&#xff08;尤其是海外游戏&#xff09;的中文版&#xff0c;显示的字很奇怪&#xff0c;就像一些字被“吞掉了”&#xff0c;无法显示…