基于Python的自然语言处理系列(10):使用双向LSTM进行文本分类

news2024/12/26 8:08:46

        在前一篇文章中,我们介绍了如何使用RNN进行文本分类。在这篇文章中,我们将进一步优化模型,使用双向多层LSTM来替代RNN,从而提高模型在序列数据上的表现。LSTM通过引入一个额外的记忆单元(cell state)来解决标准RNN中的梯度消失问题。此外,双向LSTM能够同时考虑句子前后的信息,进一步提高模型的性能。

1. LSTM与RNN的区别

        标准RNN容易在处理长序列时出现梯度消失或爆炸的现象,导致模型难以学习长期依赖。LSTM通过引入一个额外的cell state来存储和控制长期信息的流动,避免了梯度消失的问题。具体来说,LSTM使用了三个门来控制信息的流动:输入门、遗忘门和输出门。

        LSTM的计算公式如下:

        我们将在本文中实现一个双向多层LSTM,即同时使用正向和反向的LSTM来处理文本序列。

2. 数据预处理与FastText词嵌入

        首先,我们加载数据集,并使用与前面文章类似的预处理方法,包括使用spacy进行标记化、创建词汇表,并引入预训练的FastText词嵌入。

from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import FastText

# 加载数据集
train, test = AG_NEWS()

# 使用spacy进行标记化
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# 构建词汇表
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab["<unk>"])

# 引入FastText词嵌入
fast_vectors = FastText(language='simple')
fast_embedding = fast_vectors.get_vecs_by_tokens(vocab.get_itos()).to(device)

3. LSTM模型设计

        在这部分中,我们设计了一个双向多层LSTM模型。我们使用nn.LSTM代替nn.RNN,并通过设置bidirectional=True来启用双向LSTM。此外,我们还将使用多层LSTM,通过设置num_layers=2来增加模型的复杂度。

import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout):
        super().__init__()
        # 嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=vocab['<pad>'])
        # 双向多层LSTM
        self.lstm = nn.LSTM(emb_dim, 
                           hid_dim, 
                           num_layers=num_layers, 
                           bidirectional=bidirectional, 
                           dropout=dropout,
                           batch_first=True)
        # 全连接层,接收双向LSTM的输出,因此乘以2
        self.fc = nn.Linear(hid_dim * 2, output_dim)
        
    def forward(self, text, text_lengths):
        # 嵌入层
        embedded = self.embedding(text)
        
        # 打包序列
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'), enforce_sorted=False, batch_first=True)
        
        # 通过LSTM
        packed_output, (hn, cn) = self.lstm(packed_embedded)
        
        # 解包序列
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        
        # 拼接正向和反向LSTM的输出
        hn = torch.cat((hn[-2,:,:], hn[-1,:,:]), dim=1)
        
        return self.fc(hn)

4. 训练与评估

        我们将使用Adam优化器,并在训练过程中计算模型的损失和准确率。以下是完整的训练与评估代码:

import torch.optim as optim

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 计算准确率
def accuracy(preds, y):
    predicted = torch.max(preds.data, 1)[1]
    batch_corr = (predicted == y).sum()
    acc = batch_corr / len(y)
    return acc

# 训练函数
def train(model, loader, optimizer, criterion, loader_length):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    
    for i, (label, text, text_length) in enumerate(loader): 
        label = label.to(device)
        text = text.to(device)
                
        # 前向传播
        predictions = model(text, text_length).squeeze(1)
        
        # 计算损失和准确率
        loss = criterion(predictions, label)
        acc  = accuracy(predictions, label)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
                        
    return epoch_loss / loader_length, epoch_acc / loader_length

# 评估函数
def evaluate(model, loader, criterion, loader_length):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    
    with torch.no_grad():
        for i, (label, text, text_length) in enumerate(loader): 
            label = label.to(device)
            text = text.to(device)

            predictions = model(text, text_length).squeeze(1)
            
            loss = criterion(predictions, label)
            acc  = accuracy(predictions, label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / loader_length, epoch_acc / loader_length

        我们通过5个epoch训练模型,并保存最佳模型的状态。

num_epochs = 5
best_valid_loss = float('inf')

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, len(train_loader))
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion, len(valid_loader))
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-model.pt')
    
    print(f'Epoch {epoch+1} | Train Loss: {train_loss:.3f}, Train Acc: {train_acc*100:.2f}%')
    print(f'Valid Loss: {valid_loss:.3f}, Valid Acc: {valid_acc*100:.2f}%')

5. 测试与预测

        训练完成后,我们可以使用模型对新文本进行预测。以下是如何使用训练好的模型预测随机新闻文本的类别:

def predict(text, text_length):
    with torch.no_grad():
        output = model(text, text_length).squeeze(1)
        predicted = torch.max(output.data, 1)[1]
        return predicted

test_str = "Google is now facing challenges in its business strategy."
text = torch.tensor(text_pipeline(test_str)).unsqueeze(0).to(device)
text_length = torch.tensor([text.size(1)]).to(device)

prediction = predict(text, text_length)
print(f'预测结果: {prediction.item()}')

结语

        在这篇文章中,我们通过引入双向LSTM改进了文本分类模型的性能。LSTM通过其独特的记忆单元门控机制,有效解决了传统RNN中存在的梯度消失问题,从而能够更好地捕捉长序列中的依赖关系。此外,双向LSTM的加入使模型不仅能够关注序列的前向信息,还能同时捕捉序列中的反向信息,这在处理自然语言中尤为重要。毕竟,在许多语言表达中,句子前后的词语和短语之间存在密切关联,双向LSTM的设计帮助我们更全面地理解文本中的语义。

        通过实验,我们观察到,双向多层LSTM能够显著提升文本分类任务的准确性。相较于传统RNN,LSTM不仅能够捕捉更长时间步的依赖,还通过多层结构让模型具有更深的语义理解能力。使用双向LSTM,模型在多个方向上进行信息处理,进一步提升了模型的学习能力。

        尽管LSTM在序列建模中展现了其优势,但它依然存在一些局限性。例如,当处理极长的序列时,LSTM的效率可能会受到影响。此外,虽然双向LSTM能够提供更好的上下文信息,但它的计算量也相应增加,尤其是当模型层数增加时,训练时间可能会大幅增长。因此,在实际应用中,我们还需要根据具体的任务场景平衡模型的性能和计算成本。

        在未来的研究和实践中,我们可以继续探索更为先进的模型,如Transformer,它在并行计算和长序列建模方面展现了强大的能力。此外,我们也可以尝试将LSTM与其他模型(如卷积神经网络CNN)结合,进一步提高模型的表达能力。

        总的来说,LSTM为处理自然语言中的序列数据提供了强大的工具,尤其是在文本分类、机器翻译、序列标注等任务中具有广泛的应用前景。通过掌握LSTM及其变种模型,开发者可以在更多复杂的自然语言处理任务中获得显著的性能提升。

        在下一篇文章中,我们将探索如何使用**卷积神经网络(CNN)**进行文本分类,CNN以其在图像处理中的成功经验,也能为文本分类任务提供一种有效的建模方式。我们将讨论如何将CNN应用于自然语言处理任务中,并通过实验验证其效果。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2141235.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux:vim编辑技巧

命令模式 光标跳转 输入18&#xff0c;再输入G&#xff0c;可以跳转到18行。 复制、粘贴、删除 P是往上一行粘贴 小写u可以撤销 查找/撤销/保存 大写U可能失效&#xff0c;用CTRLr 末行模式 保存/退出/文件操作 字符串替换 开关参数的控制

基于python+django+vue的在线学习资源推送系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于协同过滤pythondjangovue…

近乎实时的物联网数据管道架构

这篇论文的标题是《Near Real-Time IoT Data Pipeline Architectures》&#xff0c;作者是 Markus Multamki&#xff0c;完成于 2024 年&#xff0c;属于计算机科学与工程硕士学位论文。论文主要研究了物联网&#xff08;IoT&#xff09;数据分析的可扩展数据管道架构&#xff…

FloodFill算法【下】

417. 太平洋大西洋水流问题 题目链接&#xff1a;417. 太平洋大西洋水流问题 题目解析 题目给我们一个矩阵&#xff0c;这个矩阵相当于陆地&#xff0c;被两个洋包围&#xff0c;左和上代表太平洋&#xff0c;右和下代表大西洋。 矩阵里面的数字代表海拔&#xff0c;水可以…

STM32之FMC—扩展外部 SDRAM

文章目录 一、FMC外设介绍二、SDRAM 控制原理1、SDRAM关键参数a、容量、分区b、引脚SDRAM 使用 2、SDRAM芯片IS42S16400J3、SDRAM 控制引脚说明控制逻辑地址控制SDRAM 的存储阵列SDRAM 的命令预充电刷新 W9825G6KH&#xff1a;W9825G6KH引脚 三、STM32F429 FMC四、其他文章打开…

医学数据分析实训 项目四回归分析--预测帕金森病病情的严重程度

文章目录 项目四&#xff1a;回归分析实践目的实践平台实践内容 预测帕金森病病情的严重程度作业&#xff08;一&#xff09;数据读入及理解&#xff08;二&#xff09;数据准备&#xff08;三&#xff09;模型建立&#xff08;四&#xff09;模型预测&#xff08;五&#xff0…

神经网络通俗理解学习笔记(4) 深度生成模型VAE、GAN

深度生成模型 什么是生成式模型蒙特卡洛方法变分推断Variational Inference变分自编码器VAE生成对抗网络Generative Adversarial NetworkDiffusion 扩散模型VAE和GAN 代码实现 什么是生成式模型 判别式和生成式模型 判别式:CNN/RNN/transformer;生成式:AE/VAE/GAN 判别式模型学…

Linux:RPM软件包管理以及Yum软件包仓库

挂载光驱设备 RPM软件包管理 RPM软件包简介 区分软件名和软件包名 软件名&#xff1a;firefox 软件包名&#xff1a;firefox-52.7.0-1.el7.centos.x86_64.rpm 查询软件信息 查询软件&#xff08;参数为软件名&#xff09; ]# rpm -qa #当前系统中所有已安装的软件包 ]# r…

Unity实战案例全解析 :PVZ 植物脚本分析

植物都继承了Pants脚本&#xff0c;但是我因为没注意听讲&#xff0c;把Pants也挂在植物上了&#xff0c;所以子类的PlantEnableUpdate和PlantDisableUpdate抢不过父类&#xff0c;无法正确触发动画&#xff0c;我还找不到哪里出了问题&#xff0c;所以就使用了携程加while强行…

Navicat使用 笔记04

Navicat调用数据库 1.创建一个自己的链接&#xff08;文件-->新建连接-->MySQL&#xff09; 进入到这个界面中&#xff1a; 【注意&#xff1a;密码是下载登录软件时设定过的】 创建一个连接完成&#xff08;通过双击激活&#xff09;。 2.在创建好的连接中创建数据库…

神经网络通俗理解学习笔记(5) 自然语言处理

自然语言处理 词嵌入和word2vec词义搜索和句意表示预训练模型Hugging Face库介绍经典NLP数据集代码案例-电影评论情感分析 词嵌入和word2vec 词嵌入是一种 将高维的数据表示映射到低维空间的方法 word embedding 是将语言中的词编码成向量便于后续的分析和处理 词嵌入和词向量…

感知器神经网络

1、原理 感知器是一种前馈人工神经网络&#xff0c;是人工神经网络中的一种典型结构。感知器具有分层结构&#xff0c;信息从输入层进入网络&#xff0c;逐层向前传递至输出层。根据感知器神经元变换函数、隐层数以及权值调整规则的不同&#xff0c;可以形成具有各种功能特点的…

宿舍管理系统的设计与实现 (含源码+sql+视频导入教程)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 宿舍管理系统拥有三个角色&#xff0c;分别为系统管理员、宿舍管理员以及学生。其功能如下&#xff1a; 管理员&#xff1a;宿舍管理员管理、学生管理、宿舍楼管理、缺勤记录管理、个人密…

django学习入门系列之第十点《A 案例: 员工管理系统8》

文章目录 10.6 重写样式10.7 判断数据是否合法10.8 保存内容至数据库10.9 修改入职时间10.10 错误提示10.11 重写错误信息往期回顾 10.6 重写样式 注意&#xff1a;因为他框架都已经给你写好了&#xff0c;所以如果要使用样式的话可能要自己重新定义框架来进行修改 他有两种方…

衣食住行的投资与消费

机器人工程课程与科研采取敏捷开发的弊端和反思_工业机器人适合敏捷开发吗-CSDN博客 →学历消费者←自我救赎↑2024↓(*Φ皿Φ*)-CSDN博客 大部分衣食住行相关的产品都是消费品&#xff0c;只有极少部分是能保值的资产。 物以稀为贵&#xff0c;量产供应的一般而言都是消费品…

第二百三十五节 JPA教程 - JPA Lob列示例

JPA教程 - JPA Lob列示例 以下代码显示了如何使用Lob注释将字节数组保存到数据库。 LOB在数据库中有两种类型&#xff1a;字符大对象&#xff08;称为CLOB&#xff09;和二进制大对象&#xff08;或BLOB&#xff09;。 CLOB列保存大字符序列&#xff0c;BLOB列可存储大字节序…

JDK的选择安装和下载

搭建Java开发环境 要使用Java首先必须搭建Java的开发环境&#xff1b;Java的产品叫JDK&#xff08;Java Development Kit&#xff1a;Java开发工具包&#xff09;&#xff0c;必须安装JDK才能使用Java。 JDK发展史 那么这么多JDK&#xff0c;应该使用哪个版本&#xff0c;此处…

C# 比较对象新思路,利用反射技术打造更灵活的比较工具

前言 嘿&#xff0c;大家好&#xff01;如果你之前看过我分享的文章《C# 7个方法比较两个对象是否相等》&#xff0c;你可能会意识到对象比较在实际业务中经常出现的场景。今天&#xff0c;我想继续与大家分享一个在实际项目中遇到的问题。 有一次&#xff0c;我接手了一个别…

LLVM PASS-PWN-前置

文章目录 参考环境搭建基础知识![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/dced705dcbb045ceb8df2237c9b0fd71.png)LLVM IR实例1. **.ll 格式&#xff08;人类可读的文本格式&#xff09;**2. **.bc 格式&#xff08;二进制格式&#xff09;**3. **内存表示** …

无心剑英译张九龄《望月怀远》

望月怀远 Watching the Moon and Missing You Far Away 张九龄 By Zhang Jiuling 海上生明月&#xff0c;天涯共此时 情人怨遥夜&#xff0c;竟夕起相思 灭烛怜光满&#xff0c;披衣觉露滋 不堪盈手赠&#xff0c;还寝梦佳期 The bright moon rises from the sea, So far apart…