BERT(Transformer Encoder)详解和TensorFlow实现(附源码)

news2024/11/27 17:39:57

文章目录

  • 一、BERT简介
    • 1. 模型
    • 2. 训练
      • 2.1 Masked Language Model
      • 2.2 Next Sentence Prediction
      • 2.3 BERT的输出
    • 3. 微调
  • 二、源码
    • 1. 加载BERT模型
    • 2. 加载预处理模型
    • 3. 加载BERT
    • 4. 构建BERT微调模型
    • 5. 训练
    • 6. 推理

一、BERT简介

1. 模型

  BERT的全称为Bidirectional Encoder Representation from Transformers,从名字中可以看出,BERT来源于Transformer的Encoder,见如下Transformer网络结构图,其中红框部分即BERT:
在这里插入图片描述
  图中Encoder(BERT)和Decoder(GPT)结构相似,核心的区别在于其Attention Model的不同。BERT采用了双向注意力模型,而GPT采用的是单向注意力模型(即某个token只与该句子中位于其前方的token计算Attention),如下图:
在这里插入图片描述
  单向语言模型会限制模型的表征能力,使其只能获取单方向的上下文信息,而BERT利用双向注意力来构建整个神经网络,因此最终生成能融合左右上下文信息的深层双向语言表征,即真正意义上的Bi-Directional Context信息。这使得该模型在“完型填空”、“问答”、“情感分析”、“目标式搜索”、和“辅助性导航”等需要完整理解左右上下文信息的场景上,有着非常优秀的表现。

2. 训练

  BERT使用两种预训练方法:Masked Language Model和Next Sentence Prediction,分别捕捉词语和句子级别的特征。

2.1 Masked Language Model

  Masked Language Model称作遮蔽语言模型(简称 MLM)。MLM可以理解为完形填空,我们会随机屏蔽(mask)每一个句子中15%的词,用其上下文来预测原本的词语。
  例如:my dog is hairy → my dog is [MASK],此处将hairy进行了mask处理,然后预测mask位置的词是什么。
  训练过程中,我们要做如下处理:
  80%的情况下采用[MASK],my dog is hairy → my dog is [MASK]
  10%的情况下随机取一个词来代替[MASK]的词,my dog is hairy -> my dog is apple
  10%的情况下保持不变,my dog is hairy -> my dog is hairy
  这么做的主要原因是:在后续微调任务中语句中并不会出现[MASK]标记,且预测一个词汇时,模型并不知道输入对应位置的词汇是否为正确的词汇(10%概率),这就迫使模型更多地依赖于上下文信息去预测词汇,并且赋予了模型一定的纠错能力,训练过程如下图:
在这里插入图片描述

2.2 Next Sentence Prediction

  Next Sentence Prediction称作下一句预测(简称 NSP),即给定一篇文章中的两句话,判断第二句话在文本中是否紧跟在第一句话之后,如下图所示:
在这里插入图片描述
  在实际的训练中,通常另训练集中的50%符合IsNext关系,另外50%的第二句话随机从语料中提取,它们的关系是NotNext,并将这个关系保存在[CLS]中。

2.3 BERT的输出

  众所周知,BERT(以及GPT、Transformer)的模型特点是有多少个输入就有多少个对应的输出,如下图:
在这里插入图片描述
  故,理解其输出具有重要的意义。图中“C”为分类token “[CLS]”对应的输出,而“T_i”
则代表其他token “Tok_i”对应的输出。
  对于一些token级别的任务(如完型填空或问答任务),就把“T_i” 输入到额外的输出层中进行预测;对于一些句子级别的任务(如情感分类任务),就把“C”输入到额外的输出层中。
  注:图中的“[CLS]”代表了输入句子的分类;“[SEP]”代表输入的分隔符,其前后代表两个不同的句子。

3. 微调

  在海量语料训练完BERT之后,便可以将其应用到NLP的各个任务中。
  对应于NSP任务来说,其条件概率表示为:
在这里插入图片描述
  其中C是BERT输出中的[CLS]符号,W是可学习的权值矩阵。
  对于其他任务来说,我们也可以根据BERT的输出信息做出相应的预测。
  下图展示了BERT在11种各不同任务中的模型,它们只需要在BERT的基础上再添加一个输出层便可以完成对特定任务的微调。
在这里插入图片描述
  记录一篇很好的文章:https://blog.csdn.net/KK_1657654189/article/details/122204640

二、源码

  下面以语义分析为例,展示BERT的模型构建、预训练参数的加载、和微调的过程。

1. 加载BERT模型

  注:BERT的模型有很多个版本,其中Small BERT的参数量较小,微调耗时较小,适合新手学习使用;如果想要更高的准确率,且模型规模较小,可以选择ALBERT;如果要更高的准确率,而不关心模型规模,则可以选择classic BERT等版本。
  此处以Small BERT为例:

bert_model_name = 'small_bert/bert_en_uncased_L-4_H-512_A-8' 

map_name_to_handle = {
    'small_bert/bert_en_uncased_L-4_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',
}

map_model_to_preprocess = {
    'small_bert/bert_en_uncased_L-4_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
}

tfhub_handle_encoder = map_name_to_handle[bert_model_name]
tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]

2. 加载预处理模型

  用于对输入的句子进行预处理,转换成BERT能识别的格式。

bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess) # 加载为Keras Model

text_test = ['this is such an amazing movie!']
text_preprocessed = bert_preprocess_model(text_test)

3. 加载BERT

bert_model = hub.KerasLayer(tfhub_handle_encoder)

4. 构建BERT微调模型

  由于本例是做语义分析,如前边内容所述,需要将BERT输出内容中的[CLS]输入到Dense层中做分类处理,如下:

def build_classifier_model():
  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
  preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
  encoder_inputs = preprocessing_layer(text_input)
  encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
  outputs = encoder(encoder_inputs)
  net = outputs['pooled_output']
  net = tf.keras.layers.Dropout(0.1)(net)
  net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net) # 此处unit个数是1,因此是二分类
  return tf.keras.Model(text_input, net)

# 构建微调模型
classifier_model = build_classifier_model()
bert_raw_result = classifier_model(tf.constant(text_test))
print(tf.sigmoid(bert_raw_result))

  构建后的微调模型如下:
在这里插入图片描述

5. 训练

# 定义损失函数和度量函数
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
metrics = tf.metrics.BinaryAccuracy()

# 定义optimizer
epochs = 5
steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1*num_train_steps)

init_lr = 3e-5
optimizer = optimization.create_optimizer(init_lr=init_lr,
                                          num_train_steps=num_train_steps,
                                          num_warmup_steps=num_warmup_steps,
                                          optimizer_type='adamw')

# 训练
classifier_model.compile(optimizer=optimizer,
                         loss=loss,
                         metrics=metrics)
print(f'Training model with {tfhub_handle_encoder}')
history = classifier_model.fit(x=train_ds,
                               validation_data=val_ds,
                               epochs=epochs)

6. 推理

loss, accuracy = classifier_model.evaluate(test_ds)

  至此BERT训练微调的主要过程已完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/620742.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java.time 时区详解

from: https://blog.zhjh.top/archives/MFTOJ-jorm4ISK9KXEYFE LocalDateTime 类是不包含时区信息的,可以通过 atZone 方法来设置 ZoneId,返回 ZonedDateTime 类实例,通过 atOffset 方法来设置 ZoneOffset,返回 OffsetDateTime 类…

攻防世界-web-supersqli

1. 题目描述: 2. 思路分析 这里尝试按照基本思路进行验证,先确定注入点,然后通过union查询依次确认数据库名,表名,字段名,最终获取到我们想要的字段信息。 这里只有一个输入框,所以注入点肯定…

【犀牛书】JavaScript 类型、值、变量章节读书笔记

本文为对《JavaScript权威指南》第三章:类型、值、变量精读的读书笔记,对重点进行了记录以及在一些地方添加了自己的理解。 JavaScript类型可以分为两类:原始类型和对象类型。Javascript的原始类型包括数值、文本字符串(也称字符串…

驱动操作控制LED灯

控制LED灯: 驱动如何操作寄存器 rgb_led灯的寄存器是物理地址,在linux内核启动之后, 在使用地址的时候,操作的全是虚拟地址。需要将物理地址 转化为虚拟地址。在驱动代码中操作的虚拟地址就相当于 操作实际的物理地址。 物理地址&…

2023年5月榜单丨飞瓜数据B站UP主排行榜(哔哩哔哩)发布!

飞瓜轻数发布2023年5月飞瓜数据UP主排行榜(B站平台),通过充电数、涨粉数、成长指数三个维度来体现UP主账号成长的情况,为用户提供B站号综合价值的数据参考,根据UP主成长情况用户能够快速找到运营能力强的B站UP主。 飞…

Git—版本管理工具

作用:分布式版本控制 一句话:在开发的过程中用于管理对文件、目录或工程等内容的修改历史,方便查看历史记录,备份以便恢复以前的版本的软件工程技术 官网下载安装:https://git-scm.com/ 命令大全:https://g…

OceanBase 4.1 全面测评及部署流程,看这篇就够了【建议收藏】

背景 测试 OceanBase 对比 MySQL,TiDB 的性能表现,数据存储压缩,探索多点内部项目一个数据库场景落地 Oceanbase(MySQL->OceanBase)。 单机测试 准备 OBD 方式部署单机 文件准备 wget https://obbusiness-pri…

Bilinear CNN:细粒度图像分类网络,对Bilinear CNN中矩阵外积的解释。

文章目录 一、Bilinear CNN 的网络结构二、矩阵外积(outer product)2.1 外积的计算方式2.2 外积的作用 三、PyTorch 网络代码实现 细粒度图像分类(fine-grained image recognition)的目的是区分类别的子类,如判别一只狗…

【web自动化测试】Web网页测试针对性的流程解析

前言 测试行业现在70%是以手工测试为主,那么只有20%是自动化测试,剩下的10%是性能测试。 有人可能会说,我现在做手工,我为什么要学自动化呢?我去学性能更好性能的人更少? 其实,性能的要求比自动…

蓝桥杯2022年第十三届决赛真题-齿轮

题目描述 这天,小明在组装齿轮。 他一共有 n 个齿轮,第 i 个齿轮的半径为 ri,他需要把这 n 个齿轮按一定顺序从左到右组装起来,这样最左边的齿轮转起来之后,可以传递到最右边的齿轮,并且这些齿轮能够起到提…

小程序容器与PWA是一回事吗?

PWA代表“渐进式网络应用”(Progressive Web Application)。它是一种结合了网页和移动应用程序功能的技术概念。PWA旨在提供类似于原生应用程序的用户体验,包括离线访问、推送通知、后台同步等功能,同时又具有网页的优势&#xff…

软件验收测试该怎么进行?权威的软件检测机构应该具备哪些资质?

软件测试是软件开发周期中非常重要的一个环节。软件测试的目的是发现软件在不同环境下的各种问题,保证软件在发布前能够达到用户的要求。软件验收测试是软件测试的最后一个环节,该环节主要验证软件是否满足用户需求。那么对于软件验收测试,该…

分布式事务二 Seata使用及其原理剖析

一 Seata 是什么 Seata 介绍 Seata 是一款开源的分布式事务解决方案,致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式,为用户打造一站式的分布式解决方案。AT模式是阿里首推的模式,阿里云上有商用版本…

【Spring源码】Spring源码导入Idea

1.基础环境准备 相关软件、依赖的版本号 Spring源码版本 5.3.x软件 ideaIU-2021.1.2.exeGradle gradle-7.2-bin.zip https://services.gradle.org/distributions/gradle-7.2-bin.zip - 网上说要单独下载gradle并配置环境变量,亲测当前5.3.X版本通过gradlew的方式进…

虚函数详解及应用场景

目录 概述1. 虚函数概述2. 虚函数的声明与重写3. 析构函数与虚函数的关系4. 虚函数的应用场景4.1. 多态性4.2. 接口定义与实现分离4.3. 运行时类型识别4.4. 多级继承与虚函数覆盖 结论 概述 虚函数是C中一种实现多态性的重要机制,它允许在基类中声明一个函数为虚函…

PDCCH monitoring capability

欢迎关注同名微信公众号“modem协议笔记”。 前段时间看search space set group (SSSG) switching相关内容时,注意到R17和R16的描述由于PDCCH monitoring capability的变化,内容有些不一样。于是就顺带看了下R16 R17PDCCH monitoring capability的内容。…

Domino 14.0早期测试版本

大家好,才是真的好。 本篇是超级图片篇,图片多,内容丰富,流量党请勿手残。 前天我们说到Engageug2023正在如火如荼进行,主题是“The Future is Now”。 因为时差的关系,实际上在写这篇公众号时&#xff…

设计模式(七):结构型之适配器模式

设计模式系列文章 设计模式(一):创建型之单例模式 设计模式(二、三):创建型之工厂方法和抽象工厂模式 设计模式(四):创建型之原型模式 设计模式(五):创建型之建造者模式 设计模式(六):结构型之代理模式 设计模式…

Java --- springboot3之web内容协商原理

一、内容协商原理 HttpMessageConverter 定制 HttpMessageConverter 来实现多端内容协商 编写WebMvcConfigurer提供的configureMessageConverters底层,修改底层的MessageConverter ResponseBody由HttpMessageConverter处理 标注了ResponseBody的返回值 将会由支持它…

蹭个高考热度,中国人民大学与加拿大女王大学金融硕士项目给你更多的选择

今日各大平台热搜都被“高考”霸屏,朋友圈里到处都是高考的祝福。期待莘莘学子都将交上满意的答卷,考出理想的未来。针对职场上的我们而言高考已是过去时,但知识的力量却是无穷的,在职的我们依然可以向上生长,中国人民…