Pytorch实战笔记(3)——BERT实现情感分析

news2024/11/20 0:26:20

本文展示的是使用 Pytorch 构建一个 BERT 来实现情感分析。本文的架构是第一章详细介绍 BERT,其中包括 Self-attention,Transformer 的 Encoder,BERT 的输入与输出,以及 BERT 的预训练和微调方式;第二章是核心代码部分。

目录

  • 1 BERT
    • 1.1 self-attention
    • 1.2 multi-head self-attention
    • 1.3 Encoder
    • 1.4 BERT 的输入与输出
      • 1.4.1 BERT 的输入
      • 1.4.2 BERT 的输出
    • 1.5 BERT 预训练
    • 1.6 BERT 微调
  • 2 BERT 实现情感分析
  • 参考

1 BERT

1.1 self-attention

Self-attention 接受一个序列输入,并输出等长的序列。其运行流程如下。
self-attention
上图是 self-attention 的部分实例,因为仅展示了 b 1 b_1 b1 的计算过程。其计算过程如下所述(这里仅说明 b 1 b_1 b1 的计算过程, b 2 b_2 b2 b 4 b_4 b4 的计算方式与 b 1 b_1 b1 一样):

  1. 对于输入序列 { a 1 , a 2 , a 3 , a 4 } \{a_1, a_2, a_3, a_4\} {a1,a2,a3,a4},当我们计算 a 1 a_1 a1 对该输入序列的注意力向量时, a 1 a_1 a1 会经过三次不同的线性变换,得到 q 1 q_1 q1 k 1 k_1 k1 v 1 v_1 v1 向量,公式如下。这里的 q (query)、k (key)、v (value) 可以用数据库来理解,q 对应的就是 SQL 语句,来查询某个键,最后返回这个键的值,就比如 q 是 ‘select age from girlfriend’,这里 query 就是这个 sql 语句,key 就是 age,value 就是 18。
    q 1 = W q a 1 , k 1 = W k a 1 , v 1 = W v a 1 . q_1=W^qa_1, \\ k_1=W^ka_1,\\ v_1=W^va_1. q1=Wqa1,k1=Wka1,v1=Wva1.
  2. 而对于 { a 2 , a 3 , a 4 } \{a_2, a_3, a_4\} {a2,a3,a4} 而言,它们是被查询注意力的对象,所以只生成 k 和 v(这里需要注意的是,self-attention 是会计算自己对自己的注意力的,所以会有 k1 和 v1)。
  3. 接着 q 1 q_1 q1 会与 { k 1 , k 2 , k 3 , k 4 } \{k_1, k_2, k_3, k_4\} {k1,k2,k3,k4} 分别做一次点积操作,得到注意力权重 { α 1 , 1 , α 1 , 2 , α 1 , 3 , α 1 , 4 } \{\alpha_{1, 1}, \alpha_{1, 2}, \alpha_{1, 3}, \alpha_{1, 4}\} {α1,1,α1,2,α1,3,α1,4},公式如下。这里需要注意的是,由于 k k k 会经过一次转置,所以注意力权重 α \alpha α标量。同时,由于点积操作可以看做是一次相似度的计算(因为余弦相似度的计算公式是 c o s θ = a ⋅ b ∣ a ∣ ∣ b ∣ {\rm cos}\theta=\frac{a \cdot b}{|a||b|} cosθ=a∣∣bab,即 a ⋅ b = ∣ a ∣ ∣ b ∣ c o s θ a \cdot b = |a||b|{\rm cos}\theta ab=a∣∣bcosθ,所以内积可以看做是计算两个向量的相似度),所以这里内积就可以理解为计算 q 1 q_1 q1 { k 1 , k 2 , k 3 , k 4 } \{k_1, k_2, k_3, k_4\} {k1,k2,k3,k4} 的一次相似度权重计算(因为 α \alpha α 是标量,所以是相似度权重)。
    { a 1 , 1 , α 1 , 2 , α 1 , 3 , α 1 , 4 } = q 1 { k 1 , k 2 , k 3 , k 4 } T . \{a_{1,1}, \alpha_{1, 2}, \alpha_{1, 3}, \alpha_{1, 4}\} = q_1 \{k_1, k_2, k_3, k_4\}^{\rm T}. {a1,1,α1,2,α1,3,α1,4}=q1{k1,k2,k3,k4}T.
  4. 最后,相似度权重 { α 1 , 1 , α 1 , 2 , α 1 , 3 , α 1 , 4 } \{\alpha_{1, 1}, \alpha_{1, 2}, \alpha_{1, 3}, \alpha_{1, 4}\} {α1,1,α1,2,α1,3,α1,4} { v 1 , v 2 , v 3 , v 4 } \{v_1, v_2, v_3, v_4\} {v1,v2,v3,v4} 相乘,分别得到 a 1 a_1 a1 a 1 a_1 a1 的注意力向量、 a 1 a_1 a1 a 2 a_2 a2 的注意力向量、 a 1 a_1 a1 a 3 a_3 a3 的注意力向量、和 a 1 a_1 a1 a 4 a_4 a4 的注意力向量。接着将这些向量拼起来,就得到了 b 1 b_1 b1 b 1 b_1 b1 里面就包含了 a 1 a_1 a1 对整个输入序列的所有注意力向量。

1.2 multi-head self-attention

多头自注意力机制实际上就是计算多次 self-attention。如下图所示。
multi-head self-attention
multi-head self-attention 就是输入的向量会经过 h h h 个不同的线性变换,得到 h h h 个 q、k、v。比如 h = 2 h=2 h=2 的时候, a 1 a_1 a1 会通过以下公式得到 q 1 1 q_1^1 q11 k 1 1 k_1^1 k11 v 1 1 v_1^1 v11 q 1 2 q_1^2 q12 k 1 2 k_1^2 k12 v 1 2 v_1^2 v12
q 1 1 = W 1 q a 1 , k 1 1 = W 1 k a 1 , v 1 1 = W 1 v a 1 , q 1 2 = W 2 q a 1 , k 1 2 = W 2 k a 1 , v 1 2 = W 2 v a 1 . q_1^1=W^q_1a_1, \\ k_1^1=W^k_1a_1,\\ v_1^1=W^v_1a_1,\\ q_1^2=W^q_2a_1, \\ k_1^2=W^k_2a_1,\\ v_1^2=W^v_2a_1. q11=W1qa1,k11=W1ka1,v11=W1va1,q12=W2qa1,k12=W2ka1,v12=W2va1.
接着,每个 self-attention 后的输出,会拼在一起,再通过一个线性转换,得到 multi-head self-attention 的输出。设第一个头的输出为 h e a d 1 head_1 head1,第二个头的输出为 h e a d 2 head_2 head2,最后的输出为 O O O,则其计算公式为:
O = c o n c a t ( h e a d 1 , h e a d 2 ) W o . O = {\rm concat}(head_1, head_2)W^o. O=concat(head1,head2)Wo.

1.3 Encoder

这里的 Encoder 特指的是 Transformer[1] 中的 Encoder(左边是 Transformer 的 Encoder,右边是 Decoder),其模型结构如下:
Transformer
Encoder 中一共有以下几个部分:

  • Multi-head self-attention:在前面已介绍过了。
  • 残差连接 (Residual connection)[2]:对应的是图中的 Add。残差连接如下图所示。简单来说,残差连接就是将一个模块的输入与其输出相加,通常使用在层次较深的结构当中。那么为什么残差连接在深层次模型中有效?具体而言,如果不采用残差连接,那么前向传播为 F ( x ) F(x) F(x),反向传播的时候,求梯度就为 ∂ ( F ( x ) ) ∂ x \frac{\partial (F(x))}{\partial x} x(F(x)),当梯度消失的时候, ∂ ( F ( x ) ) ∂ x \frac{\partial (F(x))}{\partial x} x(F(x)) 就为0,就无法回传梯度。而当采用了残差连接后,前向传播变为 F ( x ) + x F(x) + x F(x)+x。从直觉上来讲,这样能够让模型更关注于经过了这个模块后变化的部分;而从数学上来将,在反向传播的时候会变成 ∂ ( F ( x ) + x ) ∂ x = ∂ ( F ( x ) ) ∂ x + 1 \frac{\partial (F(x)+x)}{\partial x}=\frac{\partial (F(x))}{\partial x}+1 x(F(x)+x)=x(F(x))+1。当梯度消失后,那么 ∂ ( F ( x ) ) ∂ x \frac{\partial (F(x))}{\partial x} x(F(x)) 趋近于0,所以 ∂ ( F ( x ) + x ) ∂ x \frac{\partial (F(x)+x)}{\partial x} x(F(x)+x) 趋近于1,使得梯度无法消失,始终能够回传。
    残差连接
  • 层归一化 (layer norm)[3]:对应的是图中的 Norm。层归一化的公式如下所示。其中, m m m 是向量 x i x_i xi 的均值, σ \sigma σ 是向量 x i x_i xi 的标准差。层归一化的示例图如下所示。具体而言,如果数据不做归一化,有可能在某些方向上梯度下降很快(从左下到右上),这样会导致越过最优点;有的方向上下降很慢(从右下到左上),这样会导致半天收敛不到最优点。而通过层归一化后,能够使得数据在各个方向上都能够下降的一样快,使得能够更快收敛。
    x i ′ = x i − m σ x_i'=\frac{x_i-m}{\sigma} xi=σxim
    layer norm
  • 位置嵌入 (Positional Encoding):对应的是图中最下面的 Positional Encoding。为什么要位置嵌入?由于 self-attention 可以看做是下图这样,两个位置之间间隔为1。如果不能理解为什么是 1,可以再回过头看看上面那个 gif。那么这样会导致一个问题,对于自然语言处理的任务而言,词语的先后顺序肯定是很重要的,就比如我现在这里写到了 positional encoding,那么和第一小节写的 self-attention 关联就很弱了,所以需要通过位置嵌入来控制词语与词语之间的位置。
    self-attention
  • 全连接层:对应图中的 Feed forward,没什么好说的,唯一要注意的是,这里是两层全连接层,公式如下:
    F F N ( x ) = W 2 ( R e L U ( W 1 x + b 1 ) ) + b 2 FFN(x)=W_2({\rm ReLU}(W_1x+b_1))+b_2 FFN(x)=W2(ReLU(W1x+b1))+b2

1.4 BERT 的输入与输出

1.4.1 BERT 的输入

BERT 的输入与传统的语言模型输入不同,传统的语言模型的输入就只是整个句子,而 BERT 在输入中还加入了几个特殊的字符。其中包括:

  • [CLS]:[CLS] 一定出现在句首,这个特殊字符通过 BERT 后得到的隐藏状态代表了该句子的句向量。 [CLS] 是一定会有的。
  • [SEP]:[SEP] 一定出现在句子的结尾。由于 BERT 支持单句和两句话输入,所以用 [SEP] 来区分哪句话是哪句话。[SEP] 是一定会有的。
  • [MASK]:[MASK] 会出现在 [CLS] 与 [SEP] 中的任意位置,该特殊字符是让 BERT 去预测这个位置是什么词语。[MASK] 不一定会有

以以下两句话为例 练习时长两年半唱跳 rap 打篮球,那么输入进 BERT 后会变成以下这样:[CLS] 练习时长两年半 [SEP] 唱跳 rap 打篮球 [SEP];如果只有前一句话输入,并且掩盖掉 的话,那么是如下这样:[CLS] 练习时长两年[MASK] [SEP]

1.4.2 BERT 的输出

与传统的序列模型一样,BERT 的输出有两部分:

  • 句向量:通过模型后,[CLS] 的隐藏状态即句向量。如果是一句话输入,那么就是这句话的句向量;如果是两句话输入,那么就是这两句话的句向量。
  • 每个词语的隐藏状态:和 LSTM 一样,BERT 也会输出每个词语的隐藏状态。这里特别需要注意的是,所谓的 BERT 的词嵌入,实际上指的就是这个通过 BERT 后的隐藏状态,而非 BERT 的嵌入层。 这是因为 BERT 是基于上下文的词嵌入(contextualized word embedding),你得有上下文信息,才能叫词嵌入。

1.5 BERT 预训练

BERT 预训练有两个部分,第一个部分是 masked language model (MLM),第二部分是 next sentence prediction (NSP)

  • MLM 简单来说就是随机将输入文本中 15% 的词语给提取出来,然后进行以下处理:1. 80% 的可能,将词语替换为 [MASK],这是让模型通过上下文来预测这 [MASK] 是什么词语;10% 的可能,将词语随机替换为另外一个词语;10% 的可能,保持词语不变。MLM 如下图所示。MLM 是个 V V V 分类任务,其中 V V V 是词表大小。
    MLM
  • NSP 简单来说就是输入两句话到模型中,让模型判断后一句话是否与前一句话有关联。NSP 如下图所示。NSP 是个二分类任务,其中 1 代表上下两句话有关联,0 代表没有关联。
    NSP

1.6 BERT 微调

微调阶段,首先 BERT 会先加载预训练好的参数,并额外添加上部分随机初始化的参数。如下图所示。图中,用橙色标出来的全连接层,就是在微调阶段随机初始化的参数。所以微调阶段,训练的为两部分内容:

  • 模型本身:这部分是可以不参与训练的,因为模型已经在预训练阶段训练好了,不是必须训练的。
  • 随机初始化的参数:这部分是必须训练的参数。
    BERT 微调

2 BERT 实现情感分析

  • 全部代码在 github 上,网址为:https://github.com/Balding-Lee/Pytorch4NLP
  • 我采用的是 IMDb 数据集,由于数据集没有验证集,而且读取起来很麻烦,所以我将数据给读取出来,放到了一个文件中,并且将训练集中的10%划分为了验证集,数据集链接如下: https://pan.baidu.com/s/128EYenTiEirEn0StR9slqw,提取码:xtu3 。
  • 采用的词嵌入是谷歌的词嵌入,词嵌入的链接如下:链接:https://pan.baidu.com/s/1SPf8hmJCHF-kdV6vWLEbrQ,提取码:r5vx
    在本博客中仅介绍模型部分,详细代码见 github。

具体的模型代码如下:

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertConfig, BertForSequenceClassification


class Config:
    def __init__(self):
        # 训练配置
        self.seed = 22
        self.batch_size = 64
        self.lr = 1e-5
        self.weight_decay = 1e-4
        self.num_epochs = 100
        self.early_stop = 512
        self.max_seq_length = 128
        self.save_path = '../model_parameters/BERT_SA.bin'

        # 模型配置
        self.bert_hidden_size = 768
        self.model_path = 'bert-base-uncased'
        self.num_outputs = 2


class Model(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.config = config
        self.device = device
        tokenizer_class, bert_class, model_path = BertTokenizer, BertForSequenceClassification, config.model_path
        bert_config = BertConfig.from_pretrained(model_path, num_labels=config.num_outputs)
        self.tokenizer = tokenizer_class.from_pretrained(model_path)
        self.bert = bert_class.from_pretrained(model_path, config=bert_config).to(device)

    def forward(self, inputs):
        tokens = self.tokenizer.batch_encode_plus(inputs,
                                                  add_special_tokens=True,
                                                  max_length=self.config.max_seq_length,
                                                  padding='max_length',
                                                  truncation='longest_first')

        input_ids = torch.tensor(tokens['input_ids']).to(self.device)
        att_mask = torch.tensor(tokens['attention_mask']).to(self.device)

        logits = self.bert(input_ids, attention_mask=att_mask).logits

        return logits

实验结果如下:

test loss 0.281900 | test accuracy 0.878846 | test precision 0.853424 | test recall 0.915280 | test F1 0.883270

参考

[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, et al. Attention is all you need [EB/OL]. https://arxiv.org/abs/1706.03762, 2017.
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, et al. Deep Residual Learning for Image Recognition [EB/OL]. https://arxiv.org/abs/1512.03385, 2015.
[3] Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton. Layer Normalization [EB/OL]. https://arxiv.org/abs/1607.06450, 2016.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/191447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器视觉_HALCON_HDevelop用户指南_4.HDevelop开发程序

文章目录四、HDevelop编程4.1. 新建一个新程序4.2. 输入一个算子4.3. 指定参数4.4. 获取帮助4.5. 添加其他程序4.6. 理解图像显示4.7. 检查变量4.8. 利用灰度直方图改进阈值4.9. 编辑代码行4.10. 重新执行程序4.11. 保存程序4.12. 选择特征区域4.13. 打开图形窗口4.14. 循环遍历…

Swig工具在win10上使用

SWIG 是一种软件开发工具,它将 C 和 C 编写的程序与各种高级编程语言连接起来。这里我们用它来将 C/C 转换成 Java。 一、Swig安装 1、下载 官网:SWIG官网下载 源码链接 GitHub:https://github.com/swig/swig.git 这两个地址可能会出现无…

STM32单片机智能蓝牙APP加油站火灾预警安防防控报警监控系统MQ2DHT11

实践制作DIY- GC0122-智能蓝牙APP加油站火灾预警 一、功能说明: 基于STM32单片机设计-智能蓝牙APP加油站火灾预警 功能介绍: 基于STM32F103C系列最小系统,MQ-2烟雾传感器,火焰传感器(不能直视阳光会受到阳光干扰&…

Cesium 渐变长方体实现-Shader

position获取: 1.1 在cesium中,可通过vec4 p = czm_computePosition();获取 模型坐标中相对于眼睛的位置矩阵 1.2 vec4 eyePosition = czm_modelViewRelativeToEye * p; // position in eye coordinates 获取eyePosition 1.3 v_positionEC = czm_inverseModelView * eyePo…

Python流程控制详解

和其它编程语言一样,Python流程控制可分为 3 大结构:顺序结构、选择(分支)结构和循环结构。 Python对缩进的要求(重点) Python 是一门非常独特的编程语言,它通过缩进来识别代码块,…

ConditionalOnBean详解及ConditionalOn××总结

ConditionalOnBean详解 为什么学习ConditionalOnBean 在学习 Springboot 自动装配的时候遇到 Bean 装配和 Bean 配置需要条件判断的场景时,查阅了相关内容了解到 Conditional 和 ConditionalOnBean 注解,深入学习之后受益匪浅。 ConditionalOnBean测试…

后量子 KEM 方案:Newhope

参考文献: Lyubashevsky V, Peikert C, Regev O. On ideal lattices and learning with errors over rings[J]. Journal of the ACM (JACM), 2013, 60(6): 1-35.Lyubashevsky V, Peikert C, Regev O. A toolkit for ring-LWE cryptography[C]//Advances in Cryptol…

Linux常见指令大全(一)

🌹作者:云小逸 📝个人主页:云小逸的主页 📝Github:云小逸的Github 🤟motto:要敢于一个人默默的面对自己,强大自己才是核心。不要等到什么都没有了,才下定决心去做。种一颗树,最好的时间是十年前…

POE交换机全方位解读(上)

POE交换机在安防行业的应用,给视频监控系统带来了质的改变,POE交换机。可通过网线为无线AP、网路摄像头等PoE终端设备供电,传送距离可达100m,安装简单,即插即用。非常适合无线城市、安防监控等行业使用。 POE供电方案及…

「融云政企数智办公解决方案」入选「大信创产品目录」

1月31日,CIO 时代、新基建创新研究院联合公布“大信创产品目录”,“融云政企数智办公解决方案”成功通过审核,被正式纳入“大信创产品目录”。 据悉,CIO 时代、新基建创新研究院从去年底开始组织开展“大信创产品目录”征集工作&a…

【C语言 数据结构】数组与对称矩阵的压缩存储

文章目录数组的定义数组的顺序表示和实现顺序表中查找和修改数组元素矩阵的压缩存储特殊矩阵稀疏矩阵数组的定义 提到数组,大家首先会想到的是:很多编程语言中都提供有数组这种数据类型,比如 C/C、Java、Go、C# 等。但本节我要讲解的不是作为…

frp构建多级网络代理

简介frp 是一个专注于内网穿透的高性能的反向代理应用,支持 TCP、UDP、HTTP、HTTPS 等多种协议,采用 Golang 编写,支持跨平台,仅需下载对应平台的二进制文件即可执行,没有额外依赖。frp可以将内网服务以安全、便捷的方…

Idea 中【Maven】的环境配置

目录 一 maven 项目管理工具软件二.首先要安装Jdk1.7/8 和IDEA三.在IDEA中配置maven四.在MavenDemo01下 创建多个模块项目四.Jar包依赖 插件五.运用一 maven 项目管理工具软件 1 . Maven项目对象模型(POM),可以通过一小段描述信息来管理项目的构建,报告和文档的项目管理工具…

MQ如何保证消息不丢失

如何保证消息不丢失 哪些环节会造成消息丢失 其实主要就是跨网络的环境中需要考虑消息的丢失,主要是有以下几个方面 生产者往MQ发送消息MQ的Broker是集群有主从的,主节点把消息同步到从节点时也需要考虑消息丢失问题消息从内存持久化到硬盘时&#xf…

软考高级系统架构师背诵要点---软件架构设计

软件架构设计 软件架构的概念: 软件架构为软件系统提供了一个结构、行为和属性的高级抽象,由构成系统的元素的描述、这些元素的相互作用、指导元素集成的模式及这些模式的约束组成 软件架构41视图: 逻辑视图:主要是整个系统的抽…

Java基础:面向对象

一、设计对象并使用 二、封装 对象代表什么,就得封装对应的数据,并提供数据对应的行为。 1.private关键字:priviate修饰的成员变量只能在本类中访问。 2.this关键字:能够直接对应成员变量(当局部变量名相同时)。 3. 构造方法…

【Linux】十分钟快速了解Linux常用指令(建议收藏)

目录💖一. 关机指令01. shutdown02. halt03. reboot💖二. 常用指令04. ls05. pwd06. cd07. touch08. mkdir09. rm10. man11. cp(复制)12. mv指令13. nano14. cat15. less16. head17. tail18. find19. grep20. zip/unzip21. tar💖三、 日期指令…

JS 中 reduce()方法及使用详解

reduce()方法可以搞定的东西特别多,就是循环遍历能做的,reduce都可以做,比如数组求和、数组求积、统计数组中元素出现的次数、数组去重等等。 reduce() 方法对数组中的每个元素执行一个由您提供的reduce函数(依次执行),将其结果汇…

Python字符串分割方法【心得总结】

Python中字符串分割的常用方法 是直接调用字符串的str.split方法, 但是其只能指定一种分隔符, 如果想指定多个分隔符拆分字符串需要用到re.split方法 (正则表达式的split方法) 源码资料电子书:点击此处跳转文末名片获取 str.spli…

OAuth2简单介绍

目录 一、什么是OAuth2 二、OAuth2中的角色 1、资源所有者 2、资源服务器 3、客户 4、授权服务器 三、认证流程 四、生活中的OAuth2思维 五、令牌的特点 六、OAuth2授权方式 1、授权码 2、隐藏式 3、密码式 4、凭证式 一、什么是OAuth2 OAuth2.0是目前使用非常广…