Pytorch实战笔记(1)——BiLSTM 实现情感分析

news2025/1/16 0:50:52

本文展示的是使用 Pytorch 构建一个 BiLSTM 来实现情感分析。本文的架构是第一章详细介绍 BiLSTM,第二章粗略介绍 BiLSTM(就是说如果你想快速上手可以跳过第一章),第三章是核心代码部分。

目录

  • 1. BiLSTM的详细介绍
  • 2. BiLSTM 的简单介绍
  • 3. BiLSTM 实现情感分析
  • 参考

1. BiLSTM的详细介绍

坦白的说,其实我也不懂 LSTM,但是我这里还是尽我最大的可能解释这个模型。这里我就盗个图 [1](懒得自己画了,而且感觉好像他也是盗的李宏毅老师课件的图)。
LSTM
简单来说,LSTM 在每个时刻的输入都是由该时刻输入的序列信息 X t X^t Xt 与上一时刻的隐藏状态 h t − 1 h^{t-1} ht1 通过四种不同的非线性变化映射而成,分别为:

  1. 遗忘门控信号:遗忘门控信号 z f z^f zf 的计算公式如下:
    z f = s i g m o i d ( W f [ X t ; h t − 1 ] ) , z^f = {\rm sigmoid}(W^f\left[ X^t; h^{t-1} \right]), zf=sigmoid(Wf[Xt;ht1]),
    其中, [ X t ; h t − 1 ] [X^t;h^{t-1}] [Xt;ht1] 是将 X t X^t Xt h t − 1 h^{t-1} ht1 拼接起来; W f W^f Wf 是权重; S i g m o i d ( ⋅ ) {\rm Sigmoid}(\cdot) Sigmoid() 是 Sigmoid 激活函数,用于将数据映射到 (0, 1) 的区间范围内。
  2. 记忆门控信号:记忆门控信号 z i z^i zi 的计算公式如下:
    z i = s i g m o i d ( W i [ X t ; h t − 1 ] ) . z^i={\rm sigmoid}(W^i\left[ X^t; h^{t-1} \right]). zi=sigmoid(Wi[Xt;ht1]).
  3. 输出门控信号:输出门控信号 z o z^o zo 的计算公式如下:
    z o = s i g m o i d ( W o [ X t ; h t − 1 ] ) . z^o = {\rm sigmoid}(W^o\left[ X^t; h^{t-1} \right]). zo=sigmoid(Wo[Xt;ht1]).
  4. 当前时刻的信息:当前时刻的信息 z z z 的计算公式如下:
    z = t a n h ( W [ X t ; h t − 1 ] ) , z = {\rm tanh}(W\left[ X^t; h^{t-1} \right]), z=tanh(W[Xt;ht1]),
    其中, t a n h ( ⋅ ) {\rm tanh}(\cdot) tanh() 是将数据放缩到 (-1, 1) 的区间内。

通过以上的公式,我们可以发现, z f , z i , z o z^f, z^i, z^o zf,zi,zo 都是 (0, 1) 区间的值,而 z z z(-1, 1) 区间的值。

接着就是 LSTM 的内部计算公式,即图上所示的那几个,分别为:

  1. 当前时刻的细胞状态 c t c^t ct 的计算公式如下:
    c t = z f ⊙ c t − 1 + z i ⊙ z , c^t = z^f \odot c^{t-1} + z^i \odot z, ct=zfct1+ziz,
    其中, ⊙ \odot 是哈达玛积,即矩阵元素对位相乘,但是需要注意的是,哈达玛积数学上不可解释,但是跑出来效果好
  2. 当前时刻的隐藏状态 h t h^t ht 的计算公式如下:
    h t = z o ⊙ t a n h ( c t ) . h^t = z^o \odot {\rm tanh} (c^t). ht=zotanh(ct).
  3. 当前时刻的输出 y t y^t yt 的计算公式如下:
    y t = σ ( W ′ h t ) . y^t = \sigma (W'h^t). yt=σ(Wht).

公式列举完后,这里说一下我对这些公式的理解(不一定是对的哈)。

  • 首先是 c t c^t ct 的计算。我们看到 c t c^t ct 的计算分为了两部分。一部分是 z f ⊙ c t − 1 z^f \odot c^{t-1} zfct1,这一部分是 LSTM 的遗忘过程,由于刚刚提到, z f z^f zf 是 (0, 1) 区间范围内的值,同时,sigmoid 函数是一个无限趋近于 0 或者 1 的函数,也就是说, c t − 1 c^{t-1} ct1 无论怎样,都会有些数据被遗弃,始终不会完全保留下来,这也就模拟了一个遗忘的过程。同理,对于记忆部分 z i ⊙ z z^i \odot z ziz,这一步也是只会保留部分 z z z 的信息,也就模拟了人的记忆是由些许失真的过程。同时,两者相加后,那么就代表了当前细胞状态 c t c^t ct 中保留的是没有被遗忘掉的过去的信息和当前时刻被记忆下来的信息
  • 接着是 h t h^t ht 的计算。首先是为什么要先对 c t c^t ct 做一次 t a n h ( ⋅ ) {\rm tanh}(\cdot) tanh(),这是因为由于 c t c^t ct 的区间范围不是 (-1, 1),因为 z i ⊙ z z^i \odot z ziz 的区间范围是 (-1, 1),再与 z f ⊙ c t − 1 z^f \odot c^{t-1} zfct1 相加,那么 c t c^t ct 的范围就有可能超出 (-1, 1),所以先用一个 tanh 将数值给放缩到 (-1, 1) 内。接着再与 z o z^o zo 做一次哈达玛积后,得到的隐藏状态就是 (-1, 1) 的数据,那么该数据放到后续模块中,就可以代表当前时刻的输入是正的还是负的,同时有多大。
  • 最后就是 y t y^t yt 的计算,实际上这就是个全连接层,将隐藏状态进行一次映射,再通过一个非线性变化的激活函数。

2. BiLSTM 的简单介绍

当然,其实你没看懂上面的部分也不重要,从使用的角度上来讲,会用就行了,就像你用手机,你不会去搞懂里面每个元器件是怎么做出来的,每个 APP 是怎么写出来的;就像你去打篮球,也不用梳个中分,穿个背带裤。

那么对于 BiLSTM,你需要了解的是什么?

  • 首先,这是一个序列模型,它接受一个序列的输入,并且输出这个序列的信息。对于序列中每个位置的输出,它会包含该位置的信息以及之前的信息。就是说 LSTM 能够捕获到位置 t t t 及其之前位置的信息。而对于 BiLSTM 的话,则能捕获到 t t t 的双向信息。
  • 如果是 BiLSTM,它的每个位置的输出,是前向 L S T M → \overrightarrow{LSTM} LSTM 的输出 y → \overrightarrow{y} y 与反向 L S T M ← \overleftarrow{LSTM} LSTM 的输出 y ← \overleftarrow{y} y 拼接在一起的, [ y → ; y ← ] [\overrightarrow{y}; \overleftarrow{y}] [y ;y ]。所以假设你设置 LSTM 的隐藏层维度为 128,那么单向 LSTM 的输出维度是 128,但是双向就是 256 (128*2).
  • 但是虽然说 LSTM 好像大概可能也许 maybe possibly 能够捕获长距离依赖信息哈,毕竟 LSTM 的全称都是 Long Short-Term Memory,但是实际上这是 LSTM 的骗局,LSTM 并没有捕获长距离依赖信息的能力!LSTM 并没有捕获长距离依赖信息的能力!LSTM 并没有捕获长距离依赖信息的能力! 从数学上说,你经过这么多次 sigmoid,还能保留个啥?当然,在《An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling》这篇论文[2]中,作者用了大量的实验来说明了,LSTM 不仅并行计算能力差(因为要上一个时间步的信息才能计算下一个时间步,所以 LSTM 不是个并行系统),同时在它最吹嘘的长距离信息捕获能力上,都不如 CNN,所以以后在跑实验的时候,可以尝试使用 TextCNN 来试试,说不定效果比 BiLSTM 好(反正我做过的实验中 TextCNN 性能一般比 BiLSTM 高8-10个点)。

3. BiLSTM 实现情感分析

  • 全部代码在 github 上,网址为:https://github.com/Balding-Lee/Pytorch4NLP
  • 我采用的是 IMDb 数据集,由于数据集没有验证集,而且读取起来很麻烦,所以我将数据给读取出来,放到了一个文件中,并且将训练集中的10%划分为了验证集,数据集链接如下: https://pan.baidu.com/s/128EYenTiEirEn0StR9slqw ,提取码:xtu3 。
  • 采用的词嵌入是谷歌的词嵌入,词嵌入的链接如下:链接:https://pan.baidu.com/s/1SPf8hmJCHF-kdV6vWLEbrQ ,提取码:r5vx

在本博客中仅介绍模型部分,详细代码见 github。

模型图如图所示:
模型图
具体而言,就是输入序列输入到一个双向 LSTM 中,并将双向 LSTM 的最后一个隐藏状态(即句向量)输入到一个全连接层(也可以说是分类器)中,输出最后的分类结果,具体模型的代码如下:

import torch.nn as nn

class BiLSTM_SA(nn.Module):

    def __init__(self, embed, config):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embed, freeze=False)
        self.LSTM = nn.LSTM(config.embed_size, config.lstm_hidden_size,
                            num_layers=config.num_layers, batch_first=True,
                            bidirectional=True)
        # 因为是双向 LSTM, 所以要乘2
        self.ffn = nn.Linear(config.lstm_hidden_size * 2,
                             config.dense_hidden_size)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(config.dense_hidden_size,
                                    config.num_outputs)

    def forward(self, inputs):
        # shape: (batch_size, max_seq_length, embed_size)
        embed = self.embedding(inputs)
        # shape: (batch_size, max_seq_length, lstm_hidden_size * 2)
        lstm_hidden_states, _ = self.LSTM(embed)
        # LSTM 的最后一个时刻的隐藏状态, 即句向量
        # shape: (batch, lstm_hidden_size * 2)
        lstm_hidden_states = lstm_hidden_states[:, -1, :]
        # shape: (batch, dense_hidden_size)
        ffn_outputs = self.relu(self.ffn(lstm_hidden_states))
        # shape: (batch, num_outputs)
        logits = self.classifier(ffn_outputs)

        return logits

全连接层我采用了两个全连接层,一个将维度从 256 压缩到 128,另外一个是分类器。

这里有个小细节要注意一下,通常在论文的公式里面,我们都会看到别人写的分类器的公式如下: y ^ = S o f t m a x ( W h + b ) \hat{y} = {\rm Softmax}(Wh+b) y^=Softmax(Wh+b),有个 softmax 的激活函数,但是在 pytorch 中实际不需要,就比如我代码里面是写的:

logits = self.classifier(ffn_outputs)

而不是:

y_hat = self.softmax(self.classifier(ffn_outputs))

这是因为如果你后面选用交叉熵作为损失函数,而且调用的是torch中的 nn.CrossEntropyLoss(),那么就没必要在输出的时候用 softmax,这是因为 nn.CrossEntropyLoss() 中自带有 softmax 操作,虽然这样对你的分类结果不会产生任何影响,但是你得损失会变得很大。

最后的测试集的实验结果为:

test loss 0.419664 | test accuracy 0.813760 | test precision 0.804267 | test recall 0.829360 | test F1 0.816621

参考

[1] 陈诚. 人人都能看懂的LSTM[EB/OL]. https://zhuanlan.zhihu.com/p/32085405, 2018
[2] Shaojie Bai, J. Zico Kolter, Vladlen Koltun. An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling [EB/OL]. https://arxiv.org/abs/1803.01271, 2018

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/166133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【三年面试五年模拟】算法工程师的独孤九剑秘籍(第十二式)

Rocky Ding公众号:WeThinkIn写在前面 【三年面试五年模拟】栏目专注于分享AI行业中实习/校招/社招维度的必备面积知识点与面试方法,并向着更实战,更真实,更从容的方向不断优化迭代。也欢迎大家提出宝贵的意见或优化ideas&#xff…

【算法】二叉树

❤️ Author: 老九 ☕️ 个人博客:老九的CSDN博客 🙏 个人名言:不可控之事 乐观面对 😍 系列专栏: 文章目录二叉树数组转化为二叉树二叉树转化为二叉链表二叉树的遍历排序二叉树BST(二叉搜索树&…

新思路dp

参考文章思路:点我 题:C. Count Binary Strings 前言:嗯,今天做这个题的时候,想了一堆乱七八糟的解法,想记录一下hhhhhh。 题意:输入以类似于邻接表的形式给出字符串(只由000和111组成&#…

深入理解MySQL中的bin log、redo log、undo log

bin log(二进制日志) 什么是bin log? 记录数据库执行的写入性操作信息,以二进制的形式保存在磁盘中。 由服务层产生,所有储存引擎都支持。 bin log属于逻辑日志。 bin log日志有三种格式:STATMENT、ROW、M…

详解数据库的锁机制及原理

详解数据库的锁机制及原理1.数据库锁的分类2.行锁共享锁排他锁更新锁3.意向锁4.锁机制解释数据库隔离级别1.数据库锁的分类 本图源自CSDN博主:Stephen.W 数据库锁一般可以分为两类,一个是悲观锁,一个是乐观锁 乐观锁一般是指用户自己实现的…

Java 中是如何获取 IP 属地的

细心的小伙伴可能会发现,抖音新上线了 IP 属地的功能,小伙伴在发表动态、发表评论以及聊天的时候,都会显示自己的 IP 属地信息下面,我就来讲讲,Java 中是如何获取 IP 属地的,主要分为以下几步 通过 HttpSer…

HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务

资料 Hugging Face 官方文档:https://huggingface.co/ Hugging Face 代码链接:https://github.com/huggingface/transformers 1. 环境准备 创建 conda 环境激活 conda 环境下载 transformers 依赖下载 transformers 中需要处理数据集的依赖下载 pytor…

win10录屏文件在哪?如何找到录制后的文件

在工作和学习中,我们会遇到需要使用录屏工具录制电脑屏幕的情况,很多小伙伴在录制完win10电脑屏幕之后,找不到录制的视频文件。win10录屏文件在哪?今天小编教大家如何找到电脑录屏文件和录制win10电脑屏幕的方法,如果您…

带你认识QOwnNotes

导读QOwnNotes 是一款自由而开源的笔记记录和待办事项的应用,可以运行在 Linux、Windows 和 mac 上。这款程序将你的笔记保存为纯文本文件,它支持 Markdown 支持,并与 ownCloud 云服务紧密集成。 QOwnNotes 的亮点就是它集成了 ownCloud 云服…

数据量大也不卡的bi软件有哪些?

用过数据分析软件的都知道,很多的软件在数据量不算特别大的时候还好,分析效率、响应速度都不慢,但一旦使用的数据量超过一定范围,系统就会明显变慢,甚至崩溃。随着企业业务的发展扩张,数据分析的精细化&…

Linksys WRT路由器刷入OpenWrt与原厂固件双固件及切换

Linksys路由器OpenWrt与原厂固件双固件刷入及切换双固件机制使用原厂固件刷其他固件使用原厂固件切换启动分区使用OpenWrt刷入Sysupgrade使用OpenWrt刷入Img使用OpenWrt切换分区通用的硬切换分区(三次重启)双固件机制 新机器默认有一个原厂固件&#xf…

详解分布式系统核心概念——CAP、CP和AP

最近研究Sykwalking,当调研 oap如何进行集群部署时发现:skywalking oap 之间本身不能搭建集群,需要一个集群管理器来组建集群,它支持nacos、zookeeper、Kubernetes、Consul、Etcd 五种集群管理器。我重点比较了nacos和zookeeper&a…

python中的闭包和装饰器

目录 一.闭包 1.闭包的用途和用法 简单闭包 2.nonlocal关键字的作用 ATM闭包实现 注意事项 小结 二.装饰器 装饰器的一般写法(闭包写法) 装饰器的语法糖写法 一.闭包 1.闭包的用途和用法 先看如下代码: 通过全局变量account_amount来…

【Python学习】条件和循环

前言 往期文章 【Python学习】列表和元组 【Python学习】字典和集合 条件控制 简单来说:当判断的条件为真时,执行某种代码逻辑,这就是条件控制。 那么在讲条件控制之前,可以给大家讲一个程序员当中流传的比较真实的一个例子…

CUDA规约算法(加和)

1.block内相邻元素规约(线程不连续) 上图为1个block内的16个线程的操作示意: 第0个线程会和第1,2,4,8发生关系 第2个线程会和第3个线程发生关系 第4个线程会和第5,6个线程发生关系 ... 以上…

这7个网络设备配置接口基本参数要牢记,从此接口相关配置不用怕!

本文给大家介绍网络设备配置接口基本参数,包括接口描述信息、接口流量统计时间间隔功能以及开启或关闭接口。 进入接口视图 背景信息 对接口进行基本配置前,需要进入接口视图。 操作步骤 执行命令system-view,进入系统视图。执行命令inte…

Widget小组件

目录 技能点 Widget背调 a. 设计定位 b. Widget小组件限制 c. Widget小组件 开发须知 d. 什么是 SwiftUI App Group 数据共享 a. 配置 App Groups 1、开发者账号配置,并更新pp证书 2、Xcode配置 b. 缓存数据共享-代码实现 1、文件存储 2. 沙盒存储&…

【MySQL】运算符及相关函数详解

序号系列文章3【MySQL】MySQL基本数据类型4【MySQL】MySQL表的七大约束5【MySQL】字符集与校对集详解6【MySQL】MySQL单表操作详解文章目录前言MySQL运算符1,算术运算符1.1,算术运算符的基本使用1.2,常用数学函数的基本使用2,比较…

vulnhub DC系列 DC-7

总结:社工尝试 目录 下载地址 漏洞分析 信息收集 ssh webshell 命令执行 提权 下载地址 DC-7.zip (Size: 939 MB)Download: http://www.five86.com/downloads/DC-7.zipDownload (Mirror): https://download.vulnhub.com/dc/DC-7.zip漏洞分析 信息收集 这里还…

代码随想录算法训练营第13天 239.滑动窗口最大值、347. 前 K 个高频元素

代码随想录算法训练营第13天 239.滑动窗口最大值、347. 前 K 个高频元素 滑动窗口最大值 力扣题目链接(opens new window) 给定一个数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只…