Bert详细学习及代码实现详解

news2024/10/7 6:50:08

BERT概述

BERT的全称是Bidirectional Encoder Representation from Transformers,即双向Transformer的Encoder,因为decoder是不能获要预测的信息的。在大型语料库(Wikipedia + BookCorpus)上训练一个大型模型(12 层到 24 层 Transformer)很长时间(1M 更新步骤),这就是 BERT。

  • 模型的主要创新点都在pre-train方法上,即用了Masked LMNext Sentence Prediction两种方法分别捕捉词语句子级别的representation。

    • Masked LM --> word
    • Next Sentence Prediction --> sentence

在这里插入图片描述

Mask掩码

在原始预处理代码中,我们随机选择 WordPiece 标记进行掩码。
例如:
Input Text: the man jumped up , put his basket on phil ##am ##mon ' s head
Original Masked Input: [MASK] man [MASK] up , put his [MASK] on phil [MASK] ##mon ' s head

全字掩码改进:
Whole Word Masked Input: the man [MASK] up , put his basket on [MASK] [MASK] [MASK] ' s head

改进思想:
训练是相同的——我们仍然独立预测每个屏蔽的 WordPiece 标记。改进来自于这样的事实:对于已拆分为多个 WordPieces 的单词,原始预测任务过于“简单”。

  • 一次预测一个mask太简单了,把原来的mask周围的词全部都mask掉,提高难度。

Enmbedding

三种Embedding求和构成的:
在这里插入图片描述

  • Token Embeddings是词向量,第一个单词是CLS标志,可以用于之后的分类任务
  • Segment Embeddings用来区别两种句子,因为预训练不光做LM还要做以两个句子为输入的分类任务
  • Position Embeddings和之前文章中的Transformer不一样,不是三角函数而是学习出来的

Pre-training Task 1: Masked Language Model

为什么要bidirection?

意思就是如果使用预训练模型处理其他任务,那人们想要的肯定不止某个词左边的信息,而是左右两边的信息。

  • 在训练过程中作者随机mask 15%的token,而不是把像cbow一样把每个词都预测一遍。最终的损失函数只计算被mask掉那个token。
Input: the man went to the [MASK1] . he bought a [MASK2] of milk.
Labels: [MASK1] = store; [MASK2] = gallon

mask的技巧:

Mask如何做也是有技巧的,如果一直用标记[MASK]代替(在实际预测时是碰不到这个标记的)会影响模型,所以随机mask的时候10%的单词会被替代成其他单词,10%的单词不替换,剩下80%才被替换为[MASK]。

  • 要注意的是Masked LM预训练阶段模型是不知道真正被mask的是哪个词,所以模型每个词都要关注。

sequence_length:

  • 因为序列长度太大(512)会影响训练速度,所以90%的steps都用seq_len=128训练,余下的10%步数训练512长度的输入。

Pre-training Task 2: Next Sentence Prediction

因为涉及到QA和NLI之类的任务,增加了第二个预训练任务

  • 目的是让模型理解两个句子之间的联系。训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,模型预测B是不是A的下一句。预训练的时候可以达到97-98%的准确度。

注意:作者特意说了语料的选取很关键,要选用document-level的而不是sentence-level的,这样可以具备抽象连续长序列特征的能力。

Sentence A: the man went to the store .
Sentence B: he bought a gallon of milk .
Label: IsNextSentence
Sentence A: the man went to the store .
Sentence B: penguins are flightless .
Label: NotNextSentence

fine-tuning

code:run_classifier.py / run_squad.py(tpu)

Sentence (and sentence-pair) classification tasks

在运行此示例之前,您必须通过运行此脚本下载 GLUE 数据并将其解压到某个目录 $GLUE_DIR 。接下来,下载 BERT-Base 检查点并将其解压缩到某个目录 $BERT_BASE_DIR 。

此示例代码在 Microsoft Research Paraphrase Corpus (MRPC) 语料库上微调 BERT-Base ,该语料库仅包含 3,600 个示例,并且可以在大多数 GPU 上在几分钟内进行微调。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name=MRPC \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/
***** Eval results *****
  eval_accuracy = 0.845588
  eval_loss = 0.505248
  global_step = 343
  loss = 0.505248

训练完分类器后,您可以使用 --do_predict=true 命令在推理模式下使用它。输入文件夹中需要有一个名为 test.tsv 的文件。输出将在输出文件夹中名为 test_results.tsv 的文件中创建。每行将包含每个样本的输出,列是类概率。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier

python run_classifier.py \
  --task_name=MRPC \
  --do_predict=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=/tmp/mrpc_output/

影响内存使用的因素有:

  1. max_seq_length :发布的模型使用高达 512 的序列长度进行训练,但您可以使用更短的最大序列长度进行微调以节省大量内存。这是由示例代码中的 max_seq_length 标志控制的。

  2. train_batch_size :内存使用量也与批量大小成正比。

  3. 模型类型, BERT-Base 与 BERT-Large : BERT-Large 模型比 BERT-Base 需要更多的内存。

  4. 优化器:BERT的默认优化器是Adam,它需要大量额外的内存来存储 m 和 v 向量。切换到内存效率更高的优化器可以减少内存使用量,但也会影响结果。我们还没有尝试过其他优化器进行微调。


Using BERT to extract fixed feature vectors

在某些情况下,与其对整个预训练模型进行端到端的微调,不如获得预训练的上下文嵌入,这些嵌入是从预训练的隐藏层生成的每个输入标记的固定上下文表示。 -训练有素的模型。这也应该可以缓解大部分内存不足问题。

# Sentence A and Sentence B are separated by the ||| delimiter for sentence
# pair tasks like question answering and entailment.
# For single sentence inputs, put one sentence per line and DON'T use the
# delimiter.
echo 'Who was Jim Henson ? ||| Jim Henson was a puppeteer' > /tmp/input.txt

python extract_features.py \
  --input_file=/tmp/input.txt \
  --output_file=/tmp/output.jsonl \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --layers=-1,-2,-3,-4 \
  --max_seq_length=128 \
  --batch_size=8

If you need to maintain alignment between the original and tokenized words (for projecting training labels), see the Tokenization section below.

注意:您可能会看到类似 Could not find trained model in model_dir: /tmp/tmpuB5g5c, running initialization to predict. 的消息 此消息是预期的,它仅意味着我们正在使用 init_from_checkpoint() API 而不是保存的模型 API。如果您不指定检查点或指定无效的检查点,该脚本将会抱怨。

tokenalization

  1. 实例化 tokenizer = tokenization.FullTokenizer 的实例

  2. 使用 tokens = tokenizer.tokenize(raw_text) 对原始文本进行标记。

  3. 截断至最大序列长度。 (您最多可以使用 512 个,但出于内存和速度原因,您可能希望使用更短的长度。)

  4. 在正确的位置添加 [CLS] 和 [SEP] 标记。

在我们描述处理单词级任务的一般方法之前,了解我们的分词器到底在做什么非常重要。它有三个主要步骤:

  • (1) 文本规范化:将所有空白字符转换为空格,并(对于 Uncased 模型)将输入小写并去掉重音标记。例如, John Johanson’s, → john johanson’s, 。

  • (2) 标点符号分割:分割两侧的所有标点符号(即在所有标点符号周围添加空格)。标点符号定义为 (a) 任何具有 P* Unicode 类的字符,(b) 任何非字母/数字/空格 ASCII 字符(例如,像 $ 这样的字符在技术上不是标点)。例如, john johanson’s, → john johanson ’ s ,

  • (3) WordPiece 标记化:将空格标记化应用于上述过程的输出,并对每个标记单独应用 WordPiece 标记化。 (我们的实现直接基于 tensor2tensor 中的实现,该实现是链接的)。例如, john johanson ’ s , → john johan ##son ’ s ,

### Input
orig_tokens = ["John", "Johanson", "'s",  "house"]
labels      = ["NNP",  "NNP",      "POS", "NN"]

### Output
bert_tokens = []

# Token map will be an int -> int mapping between the `orig_tokens` index and
# the `bert_tokens` index.
orig_to_tok_map = []

tokenizer = tokenization.FullTokenizer(
    vocab_file=vocab_file, do_lower_case=True)

bert_tokens.append("[CLS]")
for orig_token in orig_tokens:
  orig_to_tok_map.append(len(bert_tokens))
  bert_tokens.extend(tokenizer.tokenize(orig_token))
bert_tokens.append("[SEP]")

# bert_tokens == ["[CLS]", "john", "johan", "##son", "'", "s", "house", "[SEP]"]
# orig_to_tok_map == [1, 2, 4, 6]

分类任务

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

预训练模型

每个 .zip 文件包含三项:

  1. 包含预训练权重(实际上是 3 个文件)的 TensorFlow 检查点 ( bert_model.ckpt )。

  2. 用于将 WordPiece 映射到单词 id 的词汇文件 ( vocab.txt )。

  3. 指定模型超参数的配置文件 ( bert_config.json )。


代码详解

https://github.com/google-research/bert/blob/master/run_classifier.py

输入组成:

  • guid: Unique id for the example.
    text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified.
    text_b: (Optional) string. The untokenized text of the second sequence. Only must be specified for sequence pair tasks.
    label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/842301.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows为nginx添加定时任务(开机延迟启动)

windows开机启动任务 调用定时任务管理器选中windows创建基本任务设置名称和描述设置触发器 并且添加个延迟触发设置操作设置条件配置设置 调用定时任务管理器 winr 输入 taskschd.msc回车 选中windows创建基本任务 设置名称和描述 设置触发器 并且添加个延迟触发 设置操作 …

深入学习 Redis - 事务、实现原理、指令使用及场景

目录 一、Redis 事务 vs MySQL事务 二、Redis 事务的执行原理 2.1、执行原理 2.2、Redis 事务设计这么简单,为什么不涉及成 MySQL 那样强大呢? 三、Redis 事务的使用 3.1、使用场景 3.2、具体演示 开启/执行/放弃事务 watch 监控 watch 实现原理…

Visual ChatGPT:Microsoft ChatGPT 和 VFM 相结合

推荐:使用 NSDT场景编辑器助你快速搭建可二次编辑的3D应用场景 什么是Visual ChatGPT? Visual ChatGPT 是一个包含 Visual Foundation 模型 (VFM) 的系统,可帮助 ChatGPT 更好地理解、生成和编辑视觉信息。VFM 能够指…

UML箭头汇总

参考:http://www.cnblogs.com/damsoft/archive/2016/10/24/5993602.html 1.UML简介 Unified Modeling Language (UML)又称统一建模语言或标准建模语言。 简单说就是以图形方式表现模型,根据不同模型进行分类,在UML 2.0中有13种图&#xff…

Hi,运维,你懂Java吗--No.9:线程池

作为运维,你不一定要会写Java代码,但是一定要懂Java在生产跑起来之后的各种机制。 本文为《Hi,运维,你懂Java吗》系列文章 第九篇,敬请关注后续系列文章 欢迎关注 龙叔运维(公众号) 持续分享运维…

8个最高效的Python爬虫框架,你用过几个?

前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 小编收集了一些较为高效的Python爬虫框架。分享给大家。 1.Scrapy Scrapy是一个为了爬取网站数据,提取结构性数据而编写的应用框架。 可以应用在包括数据挖掘,信息处理或存储历史数据等一系列的程…

Springboot @Validated注解详细说明

在Spring Boot中&#xff0c;Validated注解用于验证请求参数。它可以应用在Controller类或方法上 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-validation</artifactId> </depen…

SpringBoot源码分析(7)--prepareContext/准备应用上下文

文章目录 一、前言二、prepareContext2.1、context.setEnvironment2.2、postProcessApplicationContext(context);2.3、applyInitializers(context)2.4、发布ApplicationContextInitializedEvent事件2.5、打印启动和profile日志2.6、注册单例Bean2.6.1、手工注册单例Bean流程 2…

尚品汇总结七:商品详情模块(面试专用)

一、业务介绍 订单业务在整个电商平台中处于核心位置&#xff0c;也是比较复杂的一块业务。是把“物”变为“钱”的一个中转站。 整个订单模块一共分四部分组成&#xff1a; 结算页面 在购物车列表页面中,有一个结算的按钮,用户一点击这个按钮时,跳转到结算页,结算页展示了用…

单细胞测序基础知识

构建文库 上机测序 根据不同的荧光检测不同的碱基 质量控制&#xff08;质控QC&#xff09; 去除低质量的序列 表达定量 统计reads数&#xff0c;进而得到表达矩阵 标准化 让所有样本处在同一起跑线上 主成分分析PCA 图中每个点都代表一个样本&#xff0c;不同颜色…

【Linux】网络套接字知识点补足

目录 1 地址转换函数 1.1 字符串转in_addr的函数: 1.2 in_addr转字符串的函数: 1.3 关于inet_ntoa 2 TCP协议通讯流程 1 地址转换函数 本节只介绍基于IPv4的socket网络编程,sockaddr_in中的成员struct in_addr sin_addr表示32位 的IP 地址但是我们通常用点分十进制的字符串…

[BabysqliV3.0]phar反序列化

文章目录 [BabysqliV3.0]phar反序列化 [BabysqliV3.0]phar反序列化 开始以为是sql注入 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ST1jvadM-1691302941344)(https://raw.githubusercontent.com/leekosss/photoBed/master/202308032140269.png)…

CentOS 7中,配置了Oracle jdk,但是使用java -version验证时,出现的版本是OpenJDK,如何解决?

1.首先&#xff0c;检查已安装的jdk版本 sudo yum list installed | grep java2.移除、卸载圈红的系统自带的openjdk sudo yum remove java-1.7.0-openjdk.x86_64 sudo yum remove java-1.7.0-openjdk-headless.x86_64 sudo yum remove java-1.8.0-openjdk.x86_64 sudo yum r…

java.util.NoSuchElementException: No value present-报错(已解决)

阿丹&#xff1a; 今天在spring-boot整合MongoDB的过程中出现了下面的错误&#xff0c;是因为追求新技术、更优雅产生的。 记录一下。 错误截图如下&#xff1a; 错误位置代码如下&#xff1a; 主要问题&#xff08;问题原因&#xff09;&#xff1a; 因为之前升级了我的jdk的…

Java基础——注解

1 概述 注解用于对Java中类、方法、成员变量做标记&#xff0c;然后进行特殊处理&#xff0c;至于到底做何种处理由业务需求来决定。 例如&#xff0c;JUnit框架中&#xff0c;标记了注解Test的方法就可以被当做测试方法进程执行 2 自定义注解 public interface 注解名称 {p…

GLTF在线场景编辑工具

推荐&#xff1a;使用 NSDT场景编辑器助你快速搭建可二次编辑的3D应用场景 以下是Babylon.js Sandbox的主要功能和特点&#xff1a; 1、创建和编辑场景&#xff1a;Babylon.js Sandbox允许用户在一个交互式的3D环境中创建和编辑glTF场景。您可以添加不同类型的物体、调整其位置…

重型并串式液压机械臂建模与simscape仿真

一、建模 Hydraulic manipulator Figure 1 shows different constituting parts of the manipulator considered, with every part labeled using numbers from 1 to 10. For each part, a CAD model is provided. Each file is named in accordance with the corresponding la…

基于YOLOv7的密集场景行人检测识别分析系统

密集场景下YOLO系列模型的精度如何&#xff1f;本文的主要目的就是想要基于密集场景基于YOLOv7模型开发构建人流计数系统&#xff0c;简单看下效果图&#xff1a; 这里实验部分使用到的数据集为VSCrowd数据集。 实例数据如下所示&#xff1a; 下载到本地解压缩后如下所示&…

K8s operator从0到1实战

Operator基础知识 Kubernetes Operator是一种用于管理和扩展Kubernetes应用程序的模式和工具。它们是一种自定义的Kubernetes控制器&#xff0c;可以根据特定的应用程序需求和业务逻辑扩展Kubernetes功能。 Kubernetes Operator基于Kubernetes的控制器模式&#xff0c;通过自…

cocos creator 的input.on 不生效

序&#xff1a; 1、执行input.on的时候发现不生效 2、一直按控制台也打印不出来console.log 3、先收藏这篇&#xff0c;因为到时候cocos要开发serveApi的时候&#xff0c;你得选一款趁手的后端开发并且&#xff0c;对习惯用ts写脚本的你来说&#xff0c;node是入门最快&#xf…