Bert Encoder和Transformer Encoder有什么不同

news2025/1/12 22:50:40

前言:本篇文章主要从代码实现角度研究 Bert Encoder和Transformer Encoder 有什么不同?应该可以帮助你:

  • 深入了解Bert Encoder 的结构实现
  • 深入了解Transformer Encoder的结构实现

本篇文章不涉及对注意力机制实现的代码研究。

注:本篇文章所得出的结论和其它文章略有不同,有可能是本人代码理解上存在问题,但是又没有找到更多的文章加以验证,并且代码也检查过多遍。

观点不太一致的文章:bert-pytorch版源码详细解读_bert pytorch源码-CSDN博客 这篇文章中,存在 “这个和我之前看的transformers的残差连接层差别还挺大的,所以并不完全和transformers的encoder部分结构一致。” 但是我的分析是:代码实现上不太一样,但是本质上没啥不同,只是Bert Encoder在Attention之后多了一层Linear。具体分析过程和结论可以阅读如下文章。

如有错误或问题,请在评论区回复。

1、研究目标

这里主要的观察对象是BertModel中Bert Encoder是如何构造的?从Bert Tensorflow源码,以及transformers库中源码去看。

然后再看TransformerEncoder是如何构造的?从pytorch内置的transformer模块去看。

最后再对比不同。

2、tensorflow中BertModel主要代码如下

class BertModel(object):
    def __init__(...):
        ...得到了self.embedding_output以及attention_mask
        
        # transformer_model就代表了Bert Encoder层的所有操作
        self.all_encoder_layers = transformer_model(input_tensor=self.embedding_output, attention_mask=attention_mask,...)
        
        # 这里all_encoder_layers[-1]是取最后一层encoder的输出
        self.sequence_output = self.all_encoder_layers[-1]
        
        ...pooler层,对 sequence_output中的first_token_tensor,即CLS对应的表示向量,进行dense+tanh操作
        with tf.variable_scope("pooler"):
          first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
          self.pooled_output = tf.layers.dense(
              first_token_tensor,
              config.hidden_size,
              activation=tf.tanh,
              kernel_initializer=create_initializer(config.initializer_range))
        
def transformer_model(input_tensor, attention_mask=None,...):
    ...
    for layer_idx in range(num_hidden_layers):
        # 如下(1)(2)(3)就是每一层Bert Encoder包含的结构和操作
        with tf.variable_scope("layer_%d" % layer_idx):
            # (1)attention层:主要包含两个操作,获取attention_output,对attention_output进行dense + dropout + layer_norm
            with tf.variable_scope("attention"):
                # (1.1)通过attention_layer获得 attention_output
                attention_output
                
                # (1.2)output层:attention_output需要经过dense + dropout + layer_norm操作
                with tf.variable_scope("output"):
                    attention_output = tf.layers.dense(attention_output,hidden_size,...)
                    attention_output = dropout(attention_output, hidden_dropout_prob)
                    # “attention_output + layer_input” 表示 残差连接操作
                    attention_output = layer_norm(attention_output + layer_input)
        
            # (2)intermediate中间层:对attention_output进行dense+激活(GELU)
            with tf.variable_scope("intermediate"):
              intermediate_output = tf.layers.dense(
                  attention_output,
                  intermediate_size,
                  activation=intermediate_act_fn,)
            
            # (3)output层:对intermediater_out进行dense + dropout + layer_norm
            with tf.variable_scope("output"):
              layer_output = tf.layers.dense(
                  intermediate_output,
                  hidden_size,
                  kernel_initializer=create_initializer(initializer_range))
              layer_output = dropout(layer_output, hidden_dropout_prob)
              # "layer_output + attention_output"是残差连接操作
              layer_output = layer_norm(layer_output + attention_output)
              
              all_layer_outputs.append(layer_output)

3、pytorch的transformers库中的BertModel主要代码;

  • 其中BertEncoder对应要研究的目标
class BertModel(BertPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        
    def forward(...):
        # 这是嵌入层操作
        embedding_output = self.embeddings(input_ids=input_ids,position_ids=position_ids,token_type_ids=token_type_ids,...)
        
        # 这是BertEncoder层的操作
        encoder_outputs = self.encoder(embedding_output,attention_mask=extended_attention_mask,...)
        
        # 这里encoder_outputs是一个对象,encoder_outputs[0]是指最后一层Encoder(BertLayer)输出
        sequence_output = encoder_outputs[0]
        # self.pooler操作是BertPooler层操作,是先取first_token_tensor(即CLS对应的表示向量),然后进行dense+tanh操作
        # 通常pooled_output用于做下游分类任务
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        
class BertEncoder(nn.Module):
    def __init__(self, config):
        ...
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        ...

    def forward(...):
        for i, layer_module in enumerate(self.layer):
            
            # 元组的append做法,将每一层的hidden_states保存到all_hidden_states;
            # 第一个hidden_states是BertEncoder的输入,后面的都是每一个BertLayer的输出
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            
            ...
            # 执行BertLayer的forward方法,包含BertAttention层 + BertIntermediate中间层 + BertOutput层
            layer_outputs = layer_module(...)
            
            # 当前BertLayer的输出
            hidden_states = layer_outputs[0]
            
            # 添加到all_hidden_states元组中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)


class BertLayer(nn.Module):
    def __init__(self, config):
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(...):
        # (1)Attention是指BertAttention
        # BertAttention包含:BertSelfAttention + BertSelfOutput
        # BertSelfAttention包括计算Attention+Dropout
        # BertSelfOutput包含:dense+dropout+LayerNorm,LayerNorm之前会进行残差连接
        self_attention_outputs = self.attention(...)
        # self_attention_outputs是一个元组,取[0]获取当前BertLayer中的Attention层的输出
        attention_output = self_attention_outputs[0]
        
        # (2)BertIntermediate中间层包含:dense+gelu激活
        # (3)BertOutput层包含:dense+dropout+LayerNorm,LayerNorm之前会进行残差连接
        # feed_forward_chunk的操作是:BertIntermediate(attention_output) + BertOutput(intermediate_output, attention_output)
        # BertIntermediate(attention_output)是:dense+gelu激活
        # BertOutput(intermediate_output, attention_output)是:dense+dropout+LayerNorm;
        # 其中LayerNorm(intermediate_output + attention_output)中的“intermediate_output + attention_output”是残差连接操作
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk, ..., attention_output)
        

4、pytorch中内置的transformer的TransformerEncoderLayer主要代码

  • torch.nn.modules.transformer.TransformerEncoderLayer
class TransformerEncoderLayer(Module):
    '''
    Args:
    d_model: the number of expected features in the input (required).
    nhead: the number of heads in the multiheadattention models (required).
    dim_feedforward: the dimension of the feedforward network model (default=2048).
    dropout: the dropout value (default=0.1).
    activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    '''
    
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
    super(TransformerEncoderLayer, self).__init__()
    self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
    # Implementation of Feedforward model
    self.linear1 = Linear(d_model, dim_feedforward)
    self.dropout = Dropout(dropout)
    self.linear2 = Linear(dim_feedforward, d_model)

    self.norm1 = LayerNorm(d_model)
    self.norm2 = LayerNorm(d_model)
    self.dropout1 = Dropout(dropout)
    self.dropout2 = Dropout(dropout)

    self.activation = _get_activation_fn(activation)
    
    def forward(...):
        # 过程:
        # (1)MultiheadAttention操作:src2 = self.self_attn
        # (2)Dropout操作:self.dropout1(src2)
        
        # (3)残差连接:src = src + self.dropout1(src2)
        # (4)LayerNorm操作:src = self.norm1(src)
        
        # 如下是FeedForword:做两次线性变换,为了更深入的提取特征
        # (5)Linear操作:src = self.linear1(src)
        # (6)RELU激活(默认RELU)操作:self.activation(self.linear1(src))
        # (7)Dropout操作:self.dropout(self.activation(self.linear1(src)))
        # (8)Linear操作:src2 = self.linear2(...)
        # (9)Dropout操作:self.dropout2(src2)
        
        # (10)残差连接:src = src + self.dropout2(src2)
        # (11)LayerNorm操作:src = self.norm2(src)
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                      key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
    

5、区别总结

        Transformer Encoder的结构如上图所示,代码也基本和上图描述的一致,不过代码中在Multi-Head Attention和Feed Forward之后都存在一个Dropout操作。(可以认为每层网络之后都会接一个Dropout层,是作为网络模块的一部分)

可以将Transformer Encoder过程表述为:

(1)MultiheadAttention + Dropout + 残差连接 + LayerNorm

(2)FeedForword(Linear + RELU + Dropout + Linear + Dropout) + 残差连接 + LayerNorm;Transformer默认的隐含层激活函数是RELU;

可以将 Bert Encoder过程表述为:

(1)BertSelfAttention: MultiheadAttention + Dropout

(2)BertSelfOutput:Linear+ Dropout + 残差连接 + LayerNorm; 注意:这里的残差连接是作用在BertSelfAttention的输入上,不是Linear的输入。

(3)BertIntermediate:Linear + GELU激活

(4)BertOutput:Linear + Dropout + 残差连接 + LayerNorm;注意:这里的残差连接是作用在BertIntermediate的输入上,不是Linear的输入;

进一步,把(1)(2)合并,(3)(4)合并:

(1)MultiheadAttention + Dropout + Linear + Dropout + 残差连接 + LayerNorm

(2)FeedForword(Linear + GELU激活 + Linear + Dropout) + 残差连接 + LayerNorm;Bert默认的隐含层激活函数是GELU;

所以,Bert Encoder和Transformer Encoder最大的区别是,Bert Encoder在做完Attention计算后,还会用一个线性层去提取特征,然后才进行残差连接。其次,是FeedForword中的默认激活函数不同。Bert Encoder图结构如下:

Bert 为什么要这么做?或许是多一个线性层,特征提取能力更强,模型表征能力更好。

GELU和RELU:GELU是RELU的改进版,效果更好。

Reference

  • GeLU、ReLU函数学习_gelu和relu-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1497410.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

快来查看!你的简历亮点在哪里?还有精美模板等你来下载!

一、个人简历写作指南 编写个人简历是展示自己专业技能、工作经历和教育背景的重要方式。以下是一些个人简历写作的指南,希望对你有所帮助: 1. 简明扼要 简洁清晰:简历内容应该简明扼要,突出重点信息,避免冗长。易读性…

J8 - Inception v1算法

目录 理论知识Inception卷积计算 模型结构模型实现inception 结构GoogLeNet模型打印模型结构 模型效果总结与心得体会 理论知识 GoogLeNet首次出现就在2014年的ILSVRC比赛中获得冠军,最初的版本为InceptionV1。共有22层深,参数量5M。 可以达到同时期VGG…

FreeROTS day2

总结DMA空闲中断接收数据的使用方法 首先要要选择串口然后配置串口的参数,配置MDA通道选择接受数据,配置空闲中断,定义一个数据接收的容器,启动MDA传输当串口收到数据时MDA将数据传输到容器中,MDA会一直检测是否有数据当有数据并…

Node 旧淘宝源 HTTPS 过期处理

今天拉取老项目更新依赖,出现 urlshttps%3A%2F%2Fregistry.npm.taobao.org%2Fegg-logger%2Fdownload%2Fegg-logger-2.6.1.tgz: certificate has expired 类似报错。即使删除 node_modules 重新安装,问题依然无法解决。 一、问题演示 二、原因分析 1、淘…

泰克P6139B TektronixP6139B无源探头

特征: 500 MHz 探头带宽 探头尖端的大输入阻抗 10 MOhm,8 pF 补偿范围:8 pF 至 18 pF 电缆长度:1.3M 10X 衰减系数 300 V CAT II 输入电压 用于探测小几何电路元件的紧凑型探头 用于增强被测设备可见性的小型探头主体 可更换的探…

leetcode 3.6

Leetcode hot 100 一.矩阵1.旋转图像 二.链表1. 相交链表2.反转链表3.回文链表4.环形链表5.环形链表 II 一.矩阵 1.旋转图像 旋转图像 观察规律可得: matrix[i][j] 最终会被交换到 matrix [j][n−i−1]位置,最初思路是直接上三角交换,但是会…

CTP-API开发系列之五:SimNow环境介绍

CTP-API开发系列之五:SimNow环境介绍 CTP-API开发系列之五:SimNow环境介绍SimNow模拟测试环境第一套第二套登录关键字段可视化终端常见问题 CTP-API开发系列之五:SimNow环境介绍 如果你要研发一套国内期货程序化交易系统,从模拟测…

AI嵌入式CanMV-K230项目(1)-简介

文章目录 前言一、嘉楠的产品体系二、开发板介绍三、应用领域总结 前言 前一些列文章我们介绍了K210的使用方法,近期嘉楠科技发布了最新一版的K230芯片,下面我们来了解下这款芯片,后续我们将介绍该款芯片开发板的使用方法。 一、嘉楠的产品体…

ant-desgin charts双轴图DualAxes,柱状图无法立即显示,并且只有在调整页面大小(放大或缩小)后才开始显示

摘要 双轴图表中,柱状图无法立即显示,并且只有在调整页面大小(放大或缩小)后才开始显示 官方示例代码 在直接复制,替换为个人数据时,出现柱状图无法显示问题 const config {data: [data, data],xFiel…

Cobalt Strike 4.9.1(已更新,文章图片没换)

Cobalt Strike 4.9.1 1. 工具介绍1.1. 工具添加1.2. 工具获取 2. 工具使用2.1. 添加权限并运行2.2. 连接服务端2.3. 连接成功 3. 安全性自查 1. 工具介绍 CS 是Cobalt Strike的简称,是一款渗透测试神器,常被业界人称为CS神器。Cobalt Strike已经不再使用…

类和对象周边知识

再谈构造函数 前几期我们把六个默认成员函数一一说明后,构造函数还有一些周边知识。 初始化列表 我们在没有了解初始化列表的时候一般都是使用构造函数初始化或者在声明哪里给予缺省值,那么为什么好药存在初始化列表呢?是因为①.有些值必须…

GAN 网络的损失函数介绍代码

文章目录 GAN的损失函数介绍1.L1 losses2.mse loss3.smooth L14.charbonnier_loss5.perceptual loss (content and style losses)6.Gan损失7.WeightedTVLoss8.完整代码方便使用,含训练epoch代码。 GAN的损失函数介绍 1.L1 losses pixel_opt: type: L1Loss loss_weight: 1.0 r…

Linux Ubuntu部署SVN服务端结合内网穿透实现客户端公网访问

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

2024.3.7

大端存储:高存低,低存高; 小端存储:高存高,低存低; sizeof 用于获取数据类型或变量的大小,strlen 用于获取字符串的长度。 不能改变常量字符串, char *arr"hello"; *ar…

外汇天眼:伦敦金属交易所宣布新的高级领导任命

伦敦金属交易所(LME)今日宣布了多项高级领导职务任命和组织设计变更。 LME的任命将于2024年4月1日生效。 苏珊斯莫尔被任命为总法律顾问,负责监督LME及LME Clear的法律职能。斯莫尔女士将于6月加入,并将向LME及LME Clear的首席执…

JEDEC标准介绍及JESD22全套下载

JEDEC标准 作为半导体相关的行业的从业者,或多或少会接触到JEDEC标准。标准对硬件系统的设计、应用、验证,调试等有着至关重要的作用。 JEDEC(全称为 Joint Electron Device Engineering Council)是一个电子组件工程标准制定组织…

2024【问题解决】Github 2024无法克隆git clone自从签了2F2安全协议之后

项目场景:ping通Github但没法clone–502 问题描述 提示:ping通Github但没法clone--502: 例如:git clone https://gitclone.com/l.git/*** $ git clone https://github.com/darrenpig/Yocto Cloning into Yocto_tutorial... fatal: unable to access https://gitclone.co…

计划任务和日志

一、计划任务 计划任务概念解析 在Linux操作系统中,除了用户即时执行的命令操作以外,还可以配置在指定的时间、指定的日期执行预先计划好的系统管理任务(如定期备份、定期采集监测数据)。RHEL6系统中默认已安装了at、crontab软件…

Freecad Assembly4装配模型设计入门

一、基本信息 本文内容:学习Assembly4装配模型设计功能。 2024年3月7日 最新版Freecad 0.21.2 最新版 Assembly4 0.50.8 下载地址:stoneold/FreeCAD_Assembly4 最新版 Assembly4 示例教程 下载地址:FreeCAD_Examples: Freecad Assmbly4 …

【JDBC】Java连接数据库

目录 JDBC的工作原理JDBC API:JDBC开发步骤加载并注册JDBC驱动:建立数据库连接:创建Statement对象:执行SQL语句:处理结果:Connection接口的常用方法Statement接口的常用方法ResultSet接口的常用方法 SQL注入…