94. BERT以及BERT代码实现

news2025/1/10 10:59:17

1. NLP里的迁移学习

  • 使用预训练好的模型来抽取词、句子的特征
    • 例如word2vec 或语言模型
  • 不更新预训练好的模型
  • 需要构建新的网络来抓取新任务需要的信息
    • Word2vec忽略了时序信息,语言模型只看了一个方向
    • Word2vec只是抽取底层的信息,作为embedding层,之后的网络还是得自己设计,所以新的任务需要构建新的网络

2. BERT的动机

在这里插入图片描述

3. BERT架构

在这里插入图片描述

4. 输入的修改

在这里插入图片描述

5. 预训练任务1:带掩码的语言模型

在这里插入图片描述

6. 预训练任务2:下一句子预测

在这里插入图片描述

7. 总结

  • BERT是针对微调设计的
  • 基于Transformer的编码器做了如下修改
    • 模型更大,训练数据更多
    • 输入句子对,片段嵌入,可学习的位置编码
    • 训练时使用两个任务:
      • 带掩码的语言模型
      • 下一个句子预测

8. 代码实现

import torch
from torch import nn
from d2l import torch as d2l

8.1 输入表示

下面的get_tokens_and_segments将一个句子或两个句子作为输入,然后返回BERT输入序列的标记及其相应的片段索引。

def get_tokens_and_segments(tokens_a, tokens_b=None):
    """获取输入序列的词元及其片段索引"""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0和1分别标记片段A和B
    # segements 是一个长为len(tokens_a) + 2的向量,每个元素都是0
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

下面的BERTEncoder类类似于transformer中实现的TransformerEncoder类。与TransformerEncoder不同,BERTEncoder使用片段嵌入可学习的位置嵌入

class BERTEncoder(nn.Module):
    """BERT编码器"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # segment_embedding中第一个参数是2,因为输入是0和1
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers): # 加入多少个EncoderBlock
            self.blks.add_module(f"{i}", d2l.EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
        # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
        # batch_size=1,随机初始化pos_embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding.data[:, :X.shape[1], :]
        # 以上两行代码是对输入X进行处理
        for blk in self.blks: # 接着让X输入到之前设置好的blk中,自行计算
            X = blk(X, valid_lens)
        return X

假设词表大小为10000,为了演示BERTEncoder的前向推断,让我们创建一个实例并初始化它的参数。

vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
                      ffn_num_hiddens, num_heads, num_layers, dropout)

我们将tokens定义为长度为82个输入序列,其中每个词元是词表的索引。使用输入tokens的BERTEncoder的前向推断返回编码结果,其中每个词元由向量表示,其长度由超参数num_hiddens定义。此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)

tokens = torch.randint(0, vocab_size, (2, 8))
segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
encoded_X = encoder(tokens, segments, None)
encoded_X.shape

运行结果:

在这里插入图片描述

8.2 预训练任务

1. 掩蔽语言模型(Masked Language Modeling)

我们实现了下面的MaskLM类来预测BERT预训练的掩蔽语言模型任务中的掩蔽标记。预测使用单隐藏层的多层感知机(self.mlp)。在前向推断中,它需要两个输入:BERTEncoder编码结果用于预测的词元位置。输出是这些位置的预测结果。

class MaskLM(nn.Module):
    """BERT的掩蔽语言模型任务"""
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens, vocab_size))

    # X是上面上一块代码中BERT encoder的输出,
    # pred_positions是要预测的词元的位置
    def forward(self, X, pred_positions):
        # 每个输入序列要预测的位置个数
        # 假设pred_positions是一个二维数组:[[1, 5, 2], [6, 1, 5]]
        # 则shape[1]表示列数,也就是有几个位置需要预测
        num_pred_positions = pred_positions.shape[1]
        # reshape(-1) 表示把pred_positions弄成一行,也就是:
        # [1, 5, 2, 6, 1, 5]
        pred_positions = pred_positions.reshape(-1)
        # X是encoded_X,形状是(2, 8, 768),
        # 表示批量大小为2,长度为8,每个词元用长为768的向量表示
        batch_size = X.shape[0]
        # batch_idx:(0,1)
        batch_idx = torch.arange(0, batch_size)
        # 假设batch_size=2,num_pred_positions=3
        # 那么batch_idx是np.array([0,0,0,1,1,1])
        # repeat_interleave((0,1), 3),得到batch_idx为:tensor([0, 0, 0, 1, 1, 1])
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
        # X[batch_idx, pred_positions]也就是X[([0,0,0,1,1,1]),([1, 5, 2, 6, 1, 5])]
        # 分别拿到(0,1),(0,5),(0,2),(1,6),(1,1),(1,5)这一些需要mask的位置
        # (0,1)解释:拿到第一个序列的第1行,第一行也就是我要用做预测的词元
        # 所以,masked_X得到的是所有要mask的词元的原本的向量表示(经过BERT之后的向量表示)
        masked_X = X[batch_idx, pred_positions] # 形状是(6,768)

        # 对masked_X 进行reshape(2,3,768)
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        # 再把对masked_X放入mlp中,这也对应了mlp定义的第一层的num_inputs大小为768
        # 放入mlp中就会对这些mask词元做预测,由长为768的向量表示,变成vocab_size大小的向量表示
        mlm_Y_hat = self.mlp(masked_X)
        # 因此返回的mlm_Y_hat的形状是(2,3,10000),也就是最后一维发生了改变
        return mlm_Y_hat

为了演示MaskLM的前向推断,我们创建了其实例mlm并对其进行了初始化。回想一下,来自BERTEncoder的正向推断encoded_X表示2个BERT输入序列。我们将mlm_positions定义为在encoded_X的任一输入序列中预测的3个指示。mlm的前向推断返回encoded_X的所有掩蔽位置mlm_positions处的预测结果mlm_Y_hat。对于每个预测,结果的大小等于词表的大小。

mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(encoded_X, mlm_positions)
mlm_Y_hat.shape

运行结果:

在这里插入图片描述

通过掩码下的预测词元mlm_Y的真实标签和预测结果mlm_Y_hat,我们可以计算在BERT预训练中的遮蔽语言模型任务交叉熵损失

mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
# mlm_Y_hat.reshape((-1, vocab_size))后得到的形状是(6,10000)
#  mlm_Y.reshape(-1)的形状是一个长为6的向量
# 这里应该是随意举的例子,为了表示一下要算交叉熵损失函数
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l.shape

运行结果:

在这里插入图片描述

2. 下一句预测(Next Sentence Prediction)

尽管掩蔽语言建模能够编码双向上下文来表示单词,但它不能显式地建模文本对之间的逻辑关系。为了帮助理解两个文本序列之间的关系,BERT在预训练中考虑了一个二元分类任务——下一句预测。在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。

下面的NextSentencePred类使用单隐藏层的多层感知机来预测第二个句子是否是BERT输入序列中第一个句子的下一个句子。由于Transformer编码器中的自注意力,特殊词元“< cls>”的BERT表示已经对输入的两个句子进行了编码。因此,多层感知机分类器的输出层(self.output)以X作为输入,其中X是多层感知机隐藏层的输出,而MLP隐藏层的输入是编码后的“< cls>”词元。

class NextSentencePred(nn.Module):
    """BERT的下一句预测任务"""
    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X的形状:(batchsize,num_hiddens)
        return self.output(X)

我们可以看到,NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测。

# 一开始的encoded_X的形状是(2,8,768)
# 从第1维开始展平得到的形状是(2,8*768)=(2,6144)
encoded_X = torch.flatten(encoded_X, start_dim=1)
# encoded_X.shape[-1]:6144
# 这个6144是来初始化NextSentencePred实例的,作为num_inputs
nsp = NextSentencePred(encoded_X.shape[-1])
# NSP的输入形状:(batchsize,num_hiddens)
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat.shape

运行结果:

在这里插入图片描述

还可以计算两个二元分类的交叉熵损失。

nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l.shape

运行结果:

在这里插入图片描述

值得注意的是,上述两个预训练任务中的所有标签都可以从预训练语料库中获得,而无需人工标注。原始的BERT已经在图书语料库 :cite:Zhu.Kiros.Zemel.ea.2015和英文维基百科的连接上进行了预训练。这两个文本语料库非常庞大:它们分别有8亿个单词和25亿个单词。

3. 整合代码

在预训练BERT时,最终的损失函数是掩蔽语言模型损失函数和下一句预测损失函数的线性组合。现在我们可以通过实例化三个类BERTEncoderMaskLMNextSentencePred来定义BERTModel类。前向推断返回编码后的BERT表示encoded_X、掩蔽语言模型预测mlm_Y_hat和下一句预测nsp_Y_hat

class BERTModel(nn.Module):
    """BERT模型"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        # Bert的encoder
        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                    dropout, max_len=max_len, key_size=key_size,
                    query_size=query_size, value_size=value_size)
        # 隐藏层,是下一个句子预测任务的隐藏层
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh())
        # 任务1
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        # 任务2
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        # 获得输入X进行了各种编码,位置编码+segement embedding + token embedding
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None: # 需要预测位置的话,就需要掩码语言模型
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None

        # encoded_X的第0个维度是batch_szie,第二个维度是输入的句子长度(一个句子对),
        # 那么0就是这个句子对的第一个元素“<cls>”的索引
        # 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/186346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据太多?3款免费数据分析软件,分分钟解决

本文分享下我在做数据分析时用过的几个简单易上手的数据可视化软件。 先放上目录&#xff1a; 数据统计收集类——简道云数据图表美化类——图表秀数据开发类——Echart 01 简道云 https://www.jiandaoyun.com/ 适用于&#xff1a;想要“简单易上手”适合业务人员&#xff0…

TF-A源码移植的过程

1.解压标准 tf-a 源码包&#xff1a; tar xfz tf-a-stm32mp-2.2.r2-r0.tar.gz 2.将 ST 官方补丁文件打到 tf-a 源码中&#xff1a; 3.配置交叉编译工具链&#xff1a; 进入~/FSMP1A/tf-a-stm32mp-2.2.r2-r0$ 目录下&#xff0c;打开Makefile.sdk 将如下内容进行更改 4.复制设…

【前端设计】监控顺序返回型总线超时的计时器模块设计

前言 总线超时检查机制是系统中必要的模块设计&#xff0c;用于在总线无法返回response时能够及时上报中断。从理论上分析&#xff0c;如果总线发生了诸如挂死或者物理损坏等超时行为&#xff0c;无论计时器上报timeout的时间偏大还是偏小&#xff0c;都是一定可以上报中断的。…

Xilinx MicroBlaze系列教程(适用于ISE和Vivado开发环境)

本文是Xilinx MicroBlaze系列教程的第0篇文章。这个系列文章是我个人最近两年使用Xilinx MicroBlaze软核的经验和笔记,以Xilinx ISE 14.7和Spartan-6,以及Vivado 2018.3和Artix-7为例,介绍MicroBlaze软核、AXI系列IP核的软硬件使用,希望能帮助到更多的人。 MicroBlaze是Xil…

什么是有限元分析?能用来干什么

您是否想过工程师和制造商如何测试他们设计的耐用性、强度和安全性&#xff1f;如果您看过汽车广告&#xff0c;您可能会相信工程师和设计师不断地破坏他们的产品以测试其强度。您可能会得出结论&#xff0c;制造商会重复此过程&#xff0c;直到设计能够承受巨大的损坏并达到可…

夜游经济:夜景“亮化”,形象“美化”,经济“活化”

复杂的国际形势之下&#xff0c;扩大国内消费需求&#xff0c;激发消费市场潜力&#xff0c;堪称疫后经济复苏振兴的“金钥匙”。这一背景下&#xff0c;大力发展夜游经济&#xff0c;成为提振国内消费需求、促进城乡居民就业、拉动经济复苏增长的重要突破口。去年以来&#xf…

无法超越的100米_百兆以太网传输距离_网线有哪几种?

对网络比较了解的朋友&#xff0c;都知道双绞线有一个“无法逾越”的“100米”传输距离。无论是10M传输速率的三类双绞线&#xff0c;还是100M传输速率的五类双绞线&#xff0c;甚至1000M传输速率的六类双绞线&#xff0c;最远有效传输距离为100米。在综合布线规范中&#xff0…

Qt下实现欧姆龙PLC 串口发送HOSTLINK(FINS)模式

文章目录前言一、HOSTLINK协议说明二、校验码&#xff08;FCS&#xff09;计算三、示例完整代码四、下载链接总结前言 本文讲述了Qt下模拟串口调试工具发送HOSTLINK&#xff08;FINS&#xff09;模式&#xff0c;主要进行了HR保持区的字和位的读写&#xff0c;对HOSTLINK协议中…

记一次CPU飚高以及排查过程

一.cpu突然飚高 收到系统频发的cpu超过90%的告警.虽然是在非线上环境出现.接到告警后第一反应还是去重启了机器,重启后cpu如期的下降了下来.以为能高枕无忧,不过一会儿还是收到了告警. 二.排查 2.1 top 指令查看物理机进程id 申请了堡垒机权限登上机器 top指令后.如下确实发…

微服务,Docker, k8s,Cloud native 云原生的简易发展史

微服务发展史 2005年&#xff1a;Dr. PeterRodgers在Web ServicesEdge大会上提出了“Micro-Web-Services”的概念。2011年&#xff1a;一个软件架构工作组使用了“microservice”一词来描述一种架构模式。2012年&#xff1a;同样是这个架构工作组&#xff0c;正式确定用“micr…

万字详解 C 语言文件操作

目录 一、什么是文件&#xff1f; 1.1 - 文件和流的基本概念 1.2 - 文件的分类 1.3 - 文件名 二、缓冲文件系统和非缓冲文件系统 三、文件指针类型 四、文件的打开和关闭 4.1 - fopen 4.2 - fclose 五、文件的顺序读写 5.1 - 字符输出函数 fputc 5.2 - 字符输入函数…

【Kubernetes 企业项目实战】06、基于 Jenkins+K8s 构建 DevOps 自动化运维管理平台(上)

目录 一、k8s 助力 DevOps 在企业落地实践 1.1 传统方式部署项目为什么发布慢&#xff0c;效率低&#xff1f; 1.2 上线一个功能&#xff0c;有多少时间被浪费了&#xff1f; 1.3 如何解决发布慢&#xff0c;效率低的问题呢&#xff1f; 1.4 什么是 DevOps&#xff1f; …

【JavaScript】原型链

文章目录构造函数原型对象访问机制内置构造函数一切皆对象原型链构造函数 - 本质还是一个函数- 和 new 关键字连用- 特点1. 自动创建一个对象2. 自动返回一个对象3. 让函数的this指向这个对象 书写构造函数的时候1. 属性写在函数内2. 方法写在原型上构造函数的不合理 把方法写在…

Android studio 护眼模式配置、字体大小设置、内存大小设置等各类疑难杂症

Android studio 4.1 1、左边目录栏颜色配置&#xff1a; 2、代码编辑区域背景色设置 3、控制台背景色设置 4、菜单栏、工具栏、左边栏字体大小设置 5、编码区字体大小设置 6、修改内存大小、显示内存 例如&#xff1a;修改android-studio/bin/studio.vmoptions studio64.vmop…

NANK南卡护眼台灯Pro评测,护眼台灯天花板般存在!

现代大环境下&#xff0c;生活节奏和压力的不断加快加重&#xff0c;如今的手机、电脑屏幕以及长时间的工作学习都会出现用眼过度的问题&#xff0c;久而久之&#xff0c;各种眼睛刺痛、干涩、肿胀等近视问题接踵而至。为了保障自己的健康&#xff0c;近些年&#xff0c;人们纷…

回归预测 | MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出

回归预测 | MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出 目录回归预测 | MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出预测效果基本介绍模型描述程序设计参考文献预测效果 基本介绍 MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出。 1 .data为…

基础课程11:调试工具

目标 有时事情不会按照预期进行&#xff0c;从总线(如果有的话)检索到的错误消息不能提供足够的信息。幸运的是&#xff0c;GStreamer附带了大量调试信息&#xff0c;这些信息通常会提示问题可能是什么。本教程展示了: 如何从GStreamer获取更多调试信息。 如何将自己的调试信…

电力电子器件简介

文章目录1、二极管2、BJT3、晶闸管&#xff08;SCR&#xff09;4、TRIAC5、GTO&#xff08;全控器件&#xff09;6、功率MOSFET&#xff08;开关速度快、电压驱动更容易&#xff09;7、IGBT8、总结![在这里插入图片描述](https://img-blog.csdnimg.cn/1d309b3d449040788c6437f8…

【胖虎的逆向之路】04——脱壳(一代壳)原理脱壳相关概念详解

【胖虎的逆向之路】04——脱壳&#xff08;一代壳&#xff09;原理&脱壳相关概念详解 【胖虎的逆向之路】01——动态加载和类加载机制详解 【胖虎的逆向之路】02——Android整体加壳原理详解&实现 【胖虎的逆向之路】03——Android一代壳脱壳办法&实操 文章目录【…

高速路如何避免ETC车辆漏计问题,ETC通道出入车辆校准看板

人群密集场所事故预防措施和应急管理方案的制定&#xff0c;对每一个交通枢纽和大型社会活动场所都显得尤为重要。对于交通管理部门来说&#xff0c;获取准确、可靠的交通数据已经变得越来越重要。 所以呢&#xff0c;ETC出入车辆校准看板是必要的。ETC出入车辆校准看板&#x…