文本分类-RNN-LSTM

news2025/1/24 14:34:38

1.前言

        本节介绍RNN和LSTM,并采用它们在电影评论数据集上实现文本分类,会涉及以下几个知识点。

        1. 词表构建:包括数据清洗,词频统计,词频截断,词表构建。

        2. 预训练词向量应用:下载并加载Glove的预训练embedding进行训练,主要是如何把词向量放到nn.embedding层中的权重。

        3. RNN及LSTM构建:涉及nn.RNN和nn.LSTM的使用。

2.任务介绍

        本节采用的数据集是斯坦福大学的大型电影评论数据集(large movie review dataset) https://ai.stanford.edu/~amaas/data/sentiment/

        包含25000个训练样本,25000个测试样本,下载解压后得到aclImdb文件夹,aclImdb下有train和test,neg和pos下分别 有txt文件,txt中为电影评论文本。

         来看看一条具体的样本,train/pos/3_10.txt:

        本节任务就是对这样的一条文本进行处理,输出积极/消极的二分类概率向量。

3.数据模块

        文本任务与图像任务不同,输入不再是像素这样的数值,而是字符串,因此需要将字符串转为矩阵运算可接受的向量形 式。

         为此需要在数据处理模块完成以下步骤:

        a.分词:将一长串文本切分为一个个独立语义的词,英文可用空格来切分。

        b. 词嵌入:词嵌入通常分两步。首先将词字符串转为索引序号,然后索引序号根据词嵌入矩阵(embedding层)取对应的向量。其中词与索引之间的映射关系需要提前构建,这就是词表构建的过程。

        因此,代码开发整体流程:

        1. 编写分词功能函数

        2. 构建词表:对训练数据进行分词,统计词频,并构建词表。例如{'UNK': 0, 'PAD': 1, 'the': 2, '.': 3, 'and': 4, 'a': 5, 'of': 6, 'to': 7, ...}

        3. 编写PyTorch的Dataset,实现分词、词转序号、长度填充/截断序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

        序号转词向量的过程由模型的nn.Embedding层实现,因此数据模块只需将词变为索引序号即可,接下来一一解析各环节核心功能代码实现。

4.词表构建

        参考配套代码a_gen_vocabulary.py,首先编写分词功能函数,分词前做一些简单的数据清洗,例如在标点符号前加入空 格、去除掉不是大小写字母及 .!? 符号的数据。

        接着,写一个词表统计类实现词频统计,和词表字典的创建,代码注释非常详细,这里不赘述。 运行代码,即可完成词频统计,词表的构建,并保存到本地npy文件,在训练及推理过程中使用。

        在词表构建过程中有一个截断数量的超参数需要设置,这里设置为20000,即最多有20000个词的表示,不在字典中的词被归为UNK这个词。

         在这个数据集中,原始词表长度为74952,即通过split切分后,有7万多个不一样的字符串,通常可以通过降序排列,取前面一部分即可。

        代码会输出词频统计图,也可以观察出词频下降的速度以及高频词是哪些。

5.Dataset编写

        参考配套代码aclImdb_dataset.py,getitem中主要做两件事,首先获取label,然后获取文本预处理后的列表,列表中元素是词所对应的索引序号。

        在self.word2index.encode中需要注意设置文本最大长度self.max_len,这是由于需要将所有文本处理到相同长度,长度不足的用词填充,长度超出则截断。

6.模型模块——RNN

        模型的构建相对简单,理论知识在这里不介绍,需要了解和温习的推荐看看《动手学》。这里借助动手学的RNN图片讲解代码的实现。

        在构建的模型RNNTextClassifier中,需要三个子module,分别是:

                1. nn.Embedding:将词序号变为词向量,用于后续矩阵运算

                2. nn.RNN:循环神经网络的实现

                3. nn.Linear:最终分类输出层的实现

        在forward时,流程如下:

                1. 获取词向量

                2. 构建初始化隐藏层,默认为全0

                3. rnn推理获得输出层和隐藏层

                4. fc层输出分类概率:fc层的输入是rnn最后一个隐藏层

        更多关于nn.RNN的参数设置,可以参考官方文档:

        torch.nn.RNN(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)

7.模型模块——LSTM

        RNN是神经网络中处理时序任务最为经典的设计,但是其也存在一些缺点,例如梯度消失和梯度爆炸,以及长期依赖问 题。

        当序列很长时,RNN模型很难捕捉到远距离的依赖关系,导致模型预测不准确。

        为此,带门控机制的RNN涌现,包括GRU(Gated Recurrent Unit,门控循环单元)和LSTM(Long Short-Term Memory,长短期记忆网络),其中LSTM应用最广,这里直接跳过GRU。         LSTM模型引入了三个门(input gate、forget gate和output gate),用于控制输入、输出和遗忘的流动,允许模型有选择性地忘记或记住一些信息。

        input gate用于控制输入的流动

        forget gate用于控制遗忘的流动

        output gate用于控制输出的流动

        相较于RNN,除了输出隐藏层向量h,还输出记忆层向量c,不过对于下游使用,不需要关心向量c的存在。 同样地,借助《动手学》中的LSTM示意图来理解代码。

        在这里,借鉴《动手学》的代码,采用的LSTM为双向LSTM,这里简单介绍双向循环神经网络的概念。

         双向循环神经网络(Bidirectional Recurrent Neural Network,Bi-RNN)同时考虑前向和后向的上下文信息,前向层和后向层的输出在每个时间步骤上都被连接起来,形成了一个综合的输出,这样可以更好地捕捉序列中的上下文信息。

        在pytorch代码中,只需要将bidirectional设置为True即可,

        nn.LSTM(embed_size, num_hiddens, num_layers=num_layers, bidirectional=True)。

        当采用双向时,需要注意output矩阵的shape为 [ sequence length , batch size ,2×hidden size]

        更多关于nn.LSTM的参数设置,可以参考官方文档:torch.nn.LSTM(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

        详细参考:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

8.embedding预训练加载

        模型构建好之后,词向量的embedding层是随机初始化的,要从头训练具备一定逻辑关系的词向量表示是费时费力的, 通常可以采用在大规模预料上训练好的词向量矩阵。

        这里可以参考斯坦福大学的GloVe(Global Vectors for Word Representation)预训练词向量。

        GloVe是一种无监督学习算法,用于获取单词的向量表示,GloVe预训练词向量可以有效地捕捉单词之间的语义关系,被广泛应用于自然语言处理领域的各种任务,例如文本分类、命名实体识别和机器翻译等。

        Glove有四大类,根据数据量不同进行区分,相同数据下又根据向量长度分

        a.Wikipedia 2014 + Gigaword 5 (6B tokens, 400K vocab, uncased, 50d, 100d, 200d, & 300d vectors, 822 MB download): glove.6B.zip

        b.Common Crawl (42B tokens, 1.9M vocab, uncased, 300d vectors, 1.75 GB download): glove.42B.300d.zip

        c.Common Crawl (840B tokens, 2.2M vocab, cased, 300d vectors, 2.03 GB download): glove.840B.300d.zip

        d.Twitter (2B tweets, 27B tokens, 1.2M vocab, uncased, 25d, 50d, 100d, & 200d vectors, 1.42 GB download): glove.twitter.27B.zip

         在这里,采用Wikipedia 2014 + Gigaword 5 中的100d,即词向量长度为100,向量的token数量有6B。

        下载好的GloVe词向量矩阵是一个txt文件,一行是一个词和词向量,中间用空格隔开,因此加载该预训练词向量矩阵可以这样。

        原始GloVe预训练词向量有40万个词,在这里只关心词表中有的词,因此可以在加载字典时加一行过滤,即在词表中的词,才去获取它的词向量。

        在本案例中,词表大小是2万,根据匹配,只有19720个词在GloVe中找到了词向量,其余的词向量就需要随机初始化。

        获取GloVe预训练词向量字典后,需要把词向量放到embedding层中的矩阵,对弈embedding层来说,一行是一个词的词向量,因此通过词表的序号找到对应的行,然后把预训练词向量放进去即可,代码如下:

9.训练及实验记录

        准备好了数据和模型,接下来按照常规模型训练即可。

        这里将会做一些对比实验,包括模型对比:

         a.RNN vs LSTM

        b.有预训练词向量 vs 无预训练词向量

       c. 冻结预训练词向量 vs 放开预训练词向量

        具体指令如下,推荐放到bash文件中,一次性跑

        实验结果如下所示:

        1. RNN整体不work,经过分析发现设置的文本token长度太长,导致RNN梯度消失,以至于无法训练。调整 text_max_len为50后,train acc=0.8+, val=0.62,整体效果较差。

         2. 有了预训练词向量要比没有预训练词向量高出10多个点。

         3. 放开词向量训练,效果会好一些,但是不明显。

        补充实验:将RNN模型的文本最长token数量设置为50,其余保持不变,得到的三种embedding方式的结果如下:

        结论:

        1. LSTM较RNN在长文本处理上效果更好

        2. 预训练词向量在小样本数据集上很关键,有10多个点的提升

        3. 放开与冻结embedding层训练,效果差不多

10.小结

        本小节通过电影影评数据集实现文本分类任务,通过该任务可以了解:

        1. 文本预处理机制:包括清洗、分词、词频统计、词表构建、词表截断、UNK与PAD特殊词设定等。

        2. 预训练词向量使用:包括GloVe的下载及加载、nn.embedding层的设置 。

        3. RNN系列网络模型使用:大致了解循环神经网络的输入/输出是如何构建,如何配合fc层实现文本分类。

         4. RNN可接收的文本长度有限:文本过长,导致梯度消失,文本过短,导致无法捕获更多文本信息,因此推荐采用 LSTM等门控机制的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1869144.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

端到端图像分类算法开发实战:从 Arm 虚拟硬件到 Grove Vision AI Module V2 物理硬件

端到端图像分类算法开发实战:从 Arm 虚拟硬件到 Grove Vision AI Module V2 物理硬件 文章目录 1. 写在前面2. 产品简介2.1 Arm 虚拟硬件镜像产品简介2.2 Grove - Vision AI V2 产品简介 3. 实验前准备4. 实验步骤4.1 模型训练4.2 Arm 虚拟硬件镜像上的部署测试4.2…

【HarmonyOS NEXT】har 包的构建生成过程

Har模块文件结构 构建HAR 打包规则 开源HAR除了默认不需要打包的文件(build、node_modules、oh_modules、.cxx、.previewer、.hvigor、.gitignore、.ohpmignore)和.gitignore/.ohpmignore中配置的文件,cpp工程的CMakeLists.txt,…

【Python机器学习】自动化特征选择——迭代特征选择

在单变量测试中,没有使用模型;在基于模型的选择中,使用单个模型来选择特征。而在迭代特征选择中,将会构造一系列模型,每个模型都使用不同数量的特征。有两种基本方法: 1、开始时没有特征,然后逐…

【MySQL基础篇】概述及SQL指令:DDL及DML

数据库是一个按照数据结构来组织、存储和管理数据的仓库。以下是对数据库概念的详细解释:定义与基本概念: 数据库是长期存储在计算机内的、有组织的、可共享的、统一管理的大量数据的集合。 数据库不仅仅是数据的简单堆积,而是遵循一定的规则…

聚合项目学习

首先建立一个总的工程目录,里边后期会有我们的父工程、基础工程(继承父工程)、业务工程(依赖基础工程)等模块 1、在总工程目录中(open一个空的文件夹),首先建立一个父工程模块(通过spring init…

Unity中模拟抛物线(非Unity物理)

Unity中模拟抛物线非Unity物理 介绍剖析问题以及所需公式重力加速度公式:h 1/2*g*t*t(h 1/2 * g * t ^ 2)速度公式:Vt V初 a * t 主要代码总结 介绍 用Unity物理系统去做的抛物线想要控制速度或者想要细微的控制一些情况是非常困难的。所以想要脱离U…

【Linux系列】Fedora40安装VMware Workstation Pro报错

问题描述 由于Fedora 40使用的Linux内核是6.9,导致安装VMware Workstation Pro 时,安装依赖无法成功,具体报错如下 ..................CC [M] /tmp/modconfig-a8Fcf5/vmnet-only/smac.oCC [M] /tmp/modconfig-a8Fcf5/vmnet-only/vnetEvent.oCC [M] /tmp/modconfig-a8Fcf…

《数据勒索防范手册(1.0版)》

当前,数据勒索攻击已成为全球最严重的数据安全威胁之一攻击方式呈现 APT 化、平台化、多重化、AI驱动化等发展趋势:据统计,近年来针对制造业、公共事业、卫生保健、电力、交通、能源等领域的勒索攻击显著增加。随着云计算、边缘计算等技术的不断发展&…

深入探究小型语言模型 (SLM)

使用 Microsoft Bing Image Creator 创建 大型语言模型 (LLM) 已经流行了一段时间。最近,小型语言模型 (SLM) 增强了我们处理和使用各种自然语言和编程语言的能力。但是,一些用户查询需要比在通用语言上训练的模型所能提供的更高的准确性和领域知识。此外…

大疆车载的第一款油车智驾:上汽大众途观L Pro的智能辅助驾驶系统

引言 在自驾行业中,有一个低调但迅速崭露头角的选手——大疆车载。自2016年成立以来,大疆车载(现已更名为卓御)通过其先进的智能驾驶技术,逐渐在市场上赢得了声誉。此次,上汽大众途观L Pro成为大疆车载首款…

如何科学减肥先从了解自己在到饮食运动

在这个以瘦为美的时代,许多人被肥胖所困扰着, 今天就来教大家如何科学减脂。 一、如何判断自己是否需要减脂? 第一步就是判断自己的体重指数(BMI)是否在正常标准。BMI是国际上衡量人体胖瘦程度及是否健康的一个常用指…

打破生态「孤岛」,Catizen将开启Telegram小游戏2.0时代?

Catizen:引领Telegram x TON生态的顶级猫咪链游 在区块链游戏领域,吸引玩家的首要因素往往是游戏的趣味性。然而,仅靠趣味性无法评估一个项目的长期价值和发展潜力。真正能在区块链游戏市场中取得长久成功的项目,无一例外都依靠扎…

软件自动化测试有哪些流程?可替代手工测试吗?

随着科技的不断发展,软件在我们生活中的地位越来越重要。然而,在软件开发过程中,必然会出现各种各样的问题和bug,为了提高软件的质量和稳定性,保证用户的使用体验,软件自动化测试应运而生。 那么&#xff…

百元蓝牙耳机哪款性价比高?盘点性价比高的百元蓝牙耳机品牌

在如今快节奏的生活中,蓝牙耳机已经成为人们日常生活中不可或缺的配件。然而,市面上百元左右性价比高的蓝牙耳机琳琅满目,消费者往往难以选择到一款质量好、耐用的产品。我们希望可以为广大消费者提供一些参考和建议,接下来&#…

基于强化学习DQN的股票预测【股票交易】

强化学习笔记 第一章 强化学习基本概念 第二章 贝尔曼方程 第三章 贝尔曼最优方程 第四章 值迭代和策略迭代 第五章 强化学习实例分析:GridWorld 第六章 蒙特卡洛方法 第七章 Robbins-Monro算法 第八章 多臂老虎机 第九章 强化学习实例分析:CartPole 第十章 时序差分法 第十一…

商家转账到零钱开通指南

商家转账到零钱功能是微信支付开发的一款商家可以直接向个人微信发放零钱的产品,商家可通过此功能手动或者自动向多个微信用户发起转账。不过因为人工审核门槛的问题,不少商家很难自主通过申请,以下是经过我们上万次开通操作的经验总结&#…

观成科技:证券行业加密业务安全风险监测与防御技术研究

摘要:解决证券⾏业加密流量威胁问题、加密流量中的应⽤⻛险问题,对若⼲证券⾏业的实际流量内容进⾏调研分析, 分析了证券⾏业加密流量⾯临的合规性⻛险和加密协议及证书本⾝存在的⻛险、以及可能存在的外部加密流量威 胁,并提出防…

缓冲区溢出

本文作者:杉木涂鸦智能安全实验室 前置知识点 栈 栈(Stack)是计算机中的一种数据结构,用于存储临时数据。它的特点是后入先出(LIFO),只能在栈顶添加或删除数据。在程序中,栈被用于…

【JavaScript】JS对象和JSON

目录 一、创建JS对象 方式一:new Object() 方式二:{属性名:属性值,...,..., 方法名:function(){ } } 二、JSON格式 JSON格式语法: JSON与Java对象互转: 三、JS常见对象 3.1数组对象API 3.2 其它对象API 一、创建JS对象 方式一:new…

创新前沿:Web3如何颠覆传统计算机模式

随着Web3技术的快速发展,传统的计算机模式正面临着前所未有的挑战和改变。本文将深入探讨Web3技术的定义、原理以及它如何颠覆传统计算机模式,以及对全球科技发展的潜在影响。 1. 引言:Web3技术的兴起与背景 Web3不仅仅是技术创新的一种&…