ELECTRA模型简单介绍

news2024/11/25 15:23:00

目录

一、整体概要

二、生成器

三、判别器

四、模型训练

五、其它改进


一、整体概要

ELECTRA(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)采用了一种“生成器——判别器”结构,其与生成式对抗网络(Generative Adversarial Net,GAN)的结构非常相似。ELECTRA的整体模型结构如下图所示。

图中可以看到ELECTRA是由生成器(Generator)和判别器(Discriminator)串联起来的一个模型。这两个部分的作用如下。
 
(1)生成器。一个小的掩码语言模型(MLM),即在 [MASK]的位置预测原来的词;
 
(2)判别器。判断输入句子中的每个词是否被替换,即使用替换词检测(Replaced Token Detection,RTD)预训练任务,取代了BERT模型原始的MLM。需要注意的是这里并没有使用下一个句子预测(NSP)任务。
接下来,我们将结合图中的例子,详细介绍生成器和判别器的建模方法。

二、生成器

对于生成器来说,其目的是将带有掩码的输入文本x= x 1, ···,xn ,通过多层Transformer模型学习到上下文语义表示 h = h 1, ···, h n,并还原掩码位置的文本,即BERT中的MLM任务。需要注意的是,这里只预测经过掩码的词,即对于某个掩码位置t,生成器输出对应原文本 xt 的概率P_{}^{G} \in \mathbb{R}_{}^{|V|} (|V|是词表大小):

式中, w^{_{e}^{}}\in \mathbb{R}_{}^{|V|\times d}表示词向量矩阵; h_{t}^{G}表示原文本xt 对应的隐含层表示。
还是以上图为例,原始句子 x = x 1 x 2 x 3 x 4 x 如下:
 
the chef cooked the meal
经过随机掩码后的句子如下,记 M = {1, 3} 为所有经过掩码的单词位置的下标,记 x^{^{m}}=m_{1}x_{2}m_{3}x_{4}x_{5}   为经过掩码后的输入句子,如下所示:
 
[MASK] chef [MASK] the meal
 
那么生成器的目标是将m 1 还原为x 1 (即the),将m 3 还原为x 3 (即cooked)。在理想情况下,即当生成器的准确率为100%时,掩码标记 [MASK] 能够准确还原为原始句子中的对应单词。然而,在实际情况下,MLM的准确率并没有那么高。如果直接将掩码后的句子 x^{m}  输入生成器中,将产生采样后的句子 x^{s}
 
the chef ate the meal
从上面的例子可以看到,m 1 通过生成器成功地还原出单词the,而m3 采样(或预测)出的单词是ate,而不是原始句子中的cooked。
 
生成器生成的句子将会作为判别器的输入。由于通过生成器改写后的句子中不包含任何人为预先设置的符号(如 [MASK]),ELECTRA通过这种方法解决了预训练和下游任务输入不一致的问题。

三、判别器

受MLM准确率的影响,通过生成器采样后的句子 x^{s} 与原始句子有一定的差别。接下来,判别器的目标是从采样后的句子中识别出哪些单词是和原始句子 x 对应位置的单词一样的,即 替换词检测 任务。上述任务 可以通过二分类方法实现。
对于给定的采样句子 x^{s},通过Transformers模型得到对应的隐含层表示 h^{D} = h_{1}^{D}\cdots h_{n}^{D} 。随后,通过一个全连接层对每个时刻的隐含层表示映射成概率。

式中, w\in \mathbb{R}^{d} 表示全连接层的权重(d表示隐含层维度);M表示所有经过掩码的单词位置下标;σ表示Sigmoid激活函数。 假设1代表被替换过,0代表没有被替换过,则生成器采样生成的句子“the chef ate the meal”对应的预测标签如下,可以记为 y = y1···yn,即:
 
00100

四、模型训练

生成器和判别器分别使用以下损失函数训练:

最终,模型通过最小化以下损失学习模型参数:

式中,X 表示整个大规模语料库;\Theta ^{G}  和 \Theta ^{D} 分别表示生成器和判别器的参数。
注意:由于生成器和判别器衔接的部分涉及采样环节,判别器的损失并不会直接回传到生成器,因为采样操作是不可导的。另外,当预训练结束后,只需要使用判别器进行下游任务精调,而不再使用生成器。

五、其它改进

(1)更小的生成器。通过前面的介绍可以发现,生成器和判别器的主体结构均由BERT组成,因此两者完全可以使用同等大小的参数规模。但这样会导致预训练的时间大约为单个模型的两倍。为了提高预训练的效率,在ELECTRA中生成器的参数量要小于判别器。具体实现时会减小生成器中Transformer的隐含层维度、全连接层维度和注意力头的数目。对于不同模型规模的判别器,其缩放比例也不同,通常在1/4~1/2之间。以ELECTRA-base模型为例,缩放比例是1/3。下表展示了
ELECTRA-base 模型的生成器和判别器的各项参数大小对比。
 

为什么是减小生成器的大小,而不是判别器的大小?因为上文讲到生成器只会在预训练阶段使用,而在下游任务精调阶段是不使用的,因此减小生成器的大小是合理的。
 
(2)参数共享 为了实现更灵活的建模目的,ELECTRA首先引入了词向量因式分解方法,通过全连接层将词向量维度映射到隐含层维度。由于上面讲到,ELECTRA使用了一个更小的生成器,因此生成器和判别器之间无法直接进行参数共享。在ELECTRA中,参数共享只限于输入层权重,其中包括词向量和位置向量矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/587751.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

报表服务器Stimulsoft Server v2023.2亮点:支持PostgreSQL、选项卡

Stimulsoft Server(原Stimulsoft Reports.Server)是一款完善的商业智能解决方案,采用C/S架构,提供报告和分析功能。前端用来展现数据,后台用来处理和存储数据。Stimulsoft Server(原Stimulsoft Reports.Ser…

IDEA性能优化设置(解决卡顿问题)

在我们日常使用IDEA进行开发时,可能会遇到许多卡顿的瞬间,明明我们的机器配置也不低啊?为什么就会一直卡顿呢? 原来这是因为IDEA软件在我们安装的时候就设置了默认的内存使用上限(通常很小),这就…

成都远石无人机航测服务的具体内容

成都远石拥有多年西南地区无人机航测作业经验,现具备DEM、DSM、DOM、三维实景模型及机载激光雷达全套数据成果的生产能力,致力于为各个行业提供准确的数据支持。现在,小编就来给大家介绍成都远石无人机航测服务的具体内容。 1、数字高程模型…

SocketTools Library Edition C++ Crack

SocketTools 库版 SocketTools Library Edition 是一套 Windows 库,可简化 Internet 应用程序开发。它提供您入门所需的一切,包括文档和示例,以及免费的技术支持来回答您遇到的任何开发问题。SocketTools Library Edition 包括本机 Windows 库…

Java集成influxDB 默认UTC时区问题

dd 我用的influxDB 1.8版本高版本我不清楚,因为1.x版本便于写sq语法。 influxDB时序库默认使用 UTC时区,并且无法通过配置来修改这个时区,很多文档说在查询数据的时候加上 tz(Asia/Shanghai)。 而这个在Windows环境下的influxdb会报错…

【youcans动手学模型】AlexNet模型CIFAR10图像分类

欢迎关注『youcans动手学模型』系列 本专栏内容和资源同步到 GitHub/youcans 【youcans动手学模型】AlexNet模型CIFAR10图像分类 1. AlexNet 卷积神经网络模型1.1 论文简介1.2 AlexNet 的主要贡献1.3 AlexNet 网络1.4 模型的运行结果 2. 在 PyTorch 中定义 AlexNet 模型类2.1 按…

Qt编写视频监控系统77-Onvif组件支持非正常时间的设备

一、前言 在经历了大量的现场设备测试,至少几十种厂家、几百种设备,遇见过奇奇怪怪的问题,一个个想方设法解决,发现有个问题是在下发鉴权的时候,需要带上设备的时间,而不是发送端的时间,如果带…

LeetCode 1110. 删点成林

【LetMeFly】1110.删点成林 力扣题目链接:https://leetcode.cn/problems/delete-nodes-and-return-forest/ 给出二叉树的根节点 root,树上每个节点都有一个不同的值。 如果节点值在 to_delete 中出现,我们就把该节点从树上删去&#xff0c…

MySQL 系统信息函数

文章目录 系统信息函数1. 查看当前 MySQL 数据库版本号2. 查看当前使用的数据库3. 查看当前服务器连接次数 系统信息函数 当我们需要知道当前 MySQL 数据库的一些基本信息和使用情况的时候,可以使用系统信息函数来获取相关信息,以随时掌握数据库的使用情…

【C++系列P2】引用——背刺指针的神秘刺客(精讲一篇过!)

前言 大家好吖,欢迎来到 YY 滴 C系列 ,热烈欢迎!如标题所示,本章主要内容主要来侃侃“引用”这个刺客!如下就是大纲啦~ 一.引用 1.含义与特点 引用,即取别名。它的最大特点是编译器不会为引用变量而开辟空间…

Segment Anything——图像分割的基础模型介绍

人工智能中的基础模型变得越来越重要。这个术语开始在 NLP 领域加快步伐,现在,随着 Segment Anything Model 的出现,他们也慢慢进入了计算机视觉领域。Segment Anything是 Meta 的一个项目,旨在为图像分割的基础模型构建起点。在本文中,我们将了解 Segment Anything 项目最…

Python:Python编程:从入门到实践__超清版:线程

Python线程与安全 实现线程安全有多重方式,常见的包括:锁,条件变量,原子操作,线程本地存储等。 💚 1. 锁2. 条件变量3. 通过 join 阻塞当前线程4. 采用 sleep 来休眠一段时间5. 原子操作5.1 使用 threading…

HTTP请求中token、cookie、session有什么区别

cookie HTTP无状态的,每次请求都要携带cookie,以帮助识别身份服务端也可以向客户端set-cookie,cookie大小4kb默认有跨域限制:不可跨域共享,不可跨域传递cookie(可通过设置withCredential跨域传递cookie) cookie本地存…

【EXata】5.4 连接到互联网

目录 5.4 连接到互联网 5.4.1 Windows 互联网网关配置 5.4.3 验证互联网网关配置 5.4 连接到互联网 EXata 允许在操作主机上运行的基于 Internet 的应用程序通过模拟网络连接到 Internet。这使得即时通讯、流媒体视频、VoIP 等应用程序可以像在现实世界中一样在 EXata 上运行。…

理解Java关键字volatile

原文链接 理解Java关键字volatile 在Java中,关键字volatile是除同步锁以外,另一个同步机制,它使用起来比锁要简单方便,但是却很容易被忽略,或者被误用。这篇文章就来详细讲解一下volatile它的作用,它的原理…

【图像水印 2022 ACM】PIMoG

【图像水印 2022 ACM】PIMoG 论文题目:PIMoG: An Effective Screen-shooting Noise-Layer Simulation for Deep-Learning-Based Watermarking Network 中文题目:PIMoG:深度学习水印网络中一种有效的截屏噪声层仿真 论文链接:https://dl.acm.o…

Redis-- 缓存预热+缓存雪崩+缓存击穿+缓存穿透

Redis-- 缓存预热缓存雪崩缓存击穿缓存穿透**加粗样式** 一 面试题引入二 缓存预热三 缓存雪崩3.1 问题现象3.2 预防解决 四 缓存穿透4.1 定义4.2 解决方案4.2.1 空对象缓存或者缺省值4.2.2 Google布隆过滤器Guava解决缓存穿透 五 缓存击穿5.1 定义5.2 危害5.3 解决 六 总结 一…

Excel·VBA统计多部门多商品销售量前10%的商品

如图:根据表中唯一的货品ID,有m个事业部中分别有n种货品,统计各事业部销量前10%的货品名称,生成统计表(以下为2种统计方式) 目录 仅统计货品ID方法1:字典嵌套字典结果 方法2:自定义函…

【LED子系统】十、详细实现流程(番外篇)

个人主页:董哥聊技术 我是董哥,高级嵌入式软件开发工程师,从事嵌入式Linux驱动开发和系统开发,曾就职于世界500强公司! 创作理念:专注分享高质量嵌入式文章,让大家读有所得! 文章目录…

Hive ---- 文件格式和压缩

Hive ---- 文件格式和压缩 1. Hadoop压缩概述2. Hive文件格式1. Text File2. ORC3. Parquet3. 压缩1. Hive表数据进行压缩2. 计算过程中使用压缩 1. Hadoop压缩概述 为了支持多种压缩/解压缩算法,Hadoop引入了编码/解码器,如下表所示: Hadoo…