【MindSpore学习打卡】应用实践-自然语言处理-基于RNN的情感分类:使用MindSpore实现IMDB影评分类

news2025/2/26 6:11:51

情感分类是自然语言处理(NLP)中的一个经典任务,广泛应用于社交媒体分析、市场调研和客户反馈等领域。本篇博客将带领大家使用MindSpore框架,基于RNN(循环神经网络)实现一个情感分类模型。我们将详细介绍数据准备、模型构建、训练与评估等步骤,并最终实现对自定义输入的情感预测。

数据准备

下载和加载数据集

  • 为什么要使用requeststqdm库? requests库提供了简洁的HTTP请求接口,方便我们从网络上下载数据。tqdm库则用于显示下载进度条,帮助我们实时了解下载进度,提高用户体验。
  • 为什么要使用临时文件和shutil库? 使用临时文件可以确保下载的数据在出现意外中断时不会影响最终保存的文件。shutil库提供了高效的文件操作方法,确保数据能正确地从临时文件复制到最终保存路径。

我们使用IMDB影评数据集,这是一个经典的情感分类数据集,包含积极(Positive)和消极(Negative)两类影评。为了方便数据下载和处理,我们首先设计一个数据下载模块。

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path

# 指定保存路径
cache_dir = Path.home() / '.mindspore_examples'

def http_get(url: str, temp_file: IO):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit='B', total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

def download(file_name: str, url: str):
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_path = os.path.join(cache_dir, file_name)
    cache_exist = os.path.exists(cache_path)
    if not cache_exist:
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)
            temp_file.flush()
            temp_file.seek(0)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
    return cache_path

下载IMDB数据集并进行解压和加载:

import re
import six
import string
import tarfile

class IMDBData():
    label_map = {"pos": 1, "neg": 0}
    def __init__(self, path, mode="train"):
        self.mode = mode
        self.path = path
        self.docs, self.labels = [], []
        self._load("pos")
        self._load("neg")

    def _load(self, label):
        pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))
        with tarfile.open(self.path) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                                         .translate(None, six.b(string.punctuation)).lower()).split())
                    self.labels.append([self.label_map[label]])
                tf = tarf.next()

    def __getitem__(self, idx):
        return self.docs[idx], self.labels[idx]

    def __len__(self):
        return len(self.docs)

加载训练数据集进行测试,输出数据集数量:

imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_train = IMDBData(imdb_path, 'train')
print(len(imdb_train))

加载预训练词向量

为什么要使用预训练词向量? 预训练词向量(如Glove)是基于大规模语料库训练得到的,能够捕捉到丰富的语义信息。使用预训练词向量可以提高模型的性能,尤其是在数据量较小的情况下。

  • 为什么要添加<unk><pad>标记符? <unk>标记符用于处理词汇表中未出现的单词,避免模型在遇到新词时无法处理。<pad>标记符用于填充不同长度的文本序列,使其能够被打包为一个batch进行并行计算,提高训练效率。

我们使用Glove预训练词向量对文本进行编码。以下是加载Glove词向量的代码:

import zipfile
import numpy as np

def load_glove(glove_path):
    glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')
    if not os.path.exists(glove_100d_path):
        glove_zip = zipfile.ZipFile(glove_path)
        glove_zip.extractall(cache_dir)

    embeddings = []
    tokens = []
    with open(glove_100d_path, encoding='utf-8') as gf:
        for glove in gf:
            word, embedding = glove.split(maxsplit=1)
            tokens.append(word)
            embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))
    embeddings.append(np.random.rand(100))
    embeddings.append(np.zeros((100,), np.float32))

    vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)
    embeddings = np.array(embeddings).astype(np.float32)
    return vocab, embeddings

glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
print(len(vocab.vocab()))

数据集预处理

为什么要对文本进行分词和去除标点? 分词和去除标点是文本预处理的基础步骤,有助于模型更好地理解文本的语义。将文本转为小写可以减少词汇表的大小,提高模型的泛化能力。

对加载的数据集进行预处理,包括将文本转为index id序列,并进行序列填充。

import mindspore as ms
import mindspore.dataset as ds

lookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
type_cast_op = ds.transforms.TypeCast(ms.float32)

imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])

imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])

imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
imdb_valid = imdb_valid.batch(64, drop_remainder=True)

模型构建

接下来,我们构建用于情感分类的RNN模型。模型主要包括以下几层:

  1. Embedding层:将单词的index id转为词向量。
  2. RNN层:使用LSTM进行特征提取。
  3. 全连接层:将LSTM的输出特征映射到分类结果。

为什么选择LSTM而不是经典RNN? LSTM通过引入门控机制,有效地缓解了经典RNN中存在的梯度消失问题,能够更好地捕捉长距离依赖信息,提高模型的效果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniform

class RNN(nn.Cell):
    def __init__(self, embeddings, hidden_dim, output_dim, n_layers, bidirectional, pad_idx):
        super().__init__()
        vocab_size, embedding_dim = embeddings.shape
        self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, batch_first=True)
        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
        self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)

    def construct(self, inputs):
        embedded = self.embedding(inputs)
        _, (hidden, _) = self.rnn(embedded)
        hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        output = self.fc(hidden)
        return output

损失函数与优化器

我们使用二分类交叉熵损失函数nn.BCEWithLogitsLoss和Adam优化器。

为什么使用二分类交叉熵损失函数和Adam优化器? 二分类交叉熵损失函数适用于二分类任务,能够衡量模型预测结果与真实标签之间的差异。Adam优化器结合了动量和自适应学习率的优点,具有较快的收敛速度和较好的效果,广泛应用于深度学习模型的训练。

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')

model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

训练逻辑

设计训练一个epoch的函数,用于训练过程和loss的可视化。

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss

grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, label):
    loss, grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train_one_epoch(model, train_dataset, epoch=0):
    model.set_train()
    total = train_dataset.get_dataset_size()
    loss_total = 0
    step_total = 0
    with tqdm(total=total) as t:
        t.set_description('Epoch %i' % epoch)
        for i in train_dataset.create_tuple_iterator():
            loss = train_step(*i)
            loss_total += loss.asnumpy()
            step_total += 1
            t.set_postfix(loss=loss_total/step_total)
            t.update(1)

评估指标和逻辑

设计评估逻辑,计算模型在验证集上的准确率。

def binary_accuracy(preds, y):
    rounded_preds = np.around(ops.sigmoid(preds).asnumpy())
    correct = (rounded_preds == y).astype(np.float32)
    acc = correct.sum() / len(correct)
    return acc

def evaluate(model, test_dataset, criterion, epoch=0):
    total = test_dataset.get_dataset_size()
    epoch_loss = 0
    epoch_acc = 0
    step_total = 0
    model.set_train(False)

    with tqdm(total=total) as t:
        t.set_description('Epoch %i' % epoch)
        for i in test_dataset.create_tuple_iterator():
            predictions = model(i[0])
            loss = criterion(predictions, i[1])
            epoch_loss += loss.asnumpy()

            acc = binary_accuracy(predictions, i[1])
            epoch_acc += acc

            step_total += 1
            t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)
            t.update(1)

    return epoch_loss / total

模型训练与保存

进行模型训练,并保存最优模型。

num_epochs = 2
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')

for epoch in range(num_epochs):
    train_one_epoch(model, imdb_train, epoch)
    valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        ms.save_checkpoint(model, ckpt_file_name)

模型加载与测试

加载已保存的最优模型,并在测试集上进行评估。

param_dict = ms.load_checkpoint(ckpt_file_name)
ms.load_param_into_net(model, param_dict)
imdb_test = imdb_test.batch(64)
evaluate(model, imdb_test, loss_fn)

自定义输入测试

设计一个预测函数,实现对自定义输入的情感预测。

score_map = {
    1: "Positive",
    0: "Negative"
}

def predict_sentiment(model, vocab, sentence):
    model.set_train(False)
    tokenized = sentence.lower().split()
    indexed = vocab.tokens_to_ids(tokenized)
    tensor = ms.Tensor(indexed, ms.int32)
    tensor = tensor.expand_dims(0)
    prediction = model(tensor)

通过本文的学习,我们成功地使用MindSpore框架实现了一个基于RNN的情感分类模型。我们从数据准备开始,详细讲解了如何加载和处理IMDB影评数据集,以及使用预训练的Glove词向量对文本进行编码。然后,我们构建了一个包含Embedding层、LSTM层和全连接层的情感分类模型,并使用二分类交叉熵损失函数和Adam优化器进行训练。最后,我们评估了模型在测试集上的性能,并实现了对自定义输入的情感预测。希望这篇博客能帮助你更好地理解RNN在自然语言处理中的应用,并激发你在NLP领域的更多探索和实践。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901793.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构之“栈”(全方位认识)

&#x1f339;个人主页&#x1f339;&#xff1a;喜欢草莓熊的bear &#x1f339;专栏&#x1f339;&#xff1a;数据结构 前言 栈是一种数据结构&#xff0c;具有" 后进先出 "的特点 或者也可见说是 ” 先进后出 “。大家一起加油吧冲冲冲&#xff01;&#xff01; …

SpringBoot | 大新闻项目后端(redis优化登录)

该项目的前篇内容的使用jwt令牌实现登录认证&#xff0c;使用Md5加密实现注册&#xff0c;在上一篇&#xff1a;http://t.csdnimg.cn/vn3rB 该篇主要内容&#xff1a;redis优化登录和ThreadLocal提供线程局部变量&#xff0c;以及该大新闻项目的主要代码。 redis优化登录 其实…

基于深度学习LightWeight的人体姿态检测跌倒系统源码

一. LightWeight概述 light weight openpose是openpose的简化版本&#xff0c;使用了openpose的大体流程。 Light weight openpose和openpose的区别是&#xff1a; a 前者使用的是Mobilenet V1&#xff08;到conv5_5&#xff09;&#xff0c;后者使用的是Vgg19&#xff08;前10…

Charles拦截发送数据包-cnblog

Charles拦截发送数据包 打开允许断点 右键要打断点的数据包&#xff0c;打断点 重新发请求进入断点模式 修改完毕后发送

uniapp实现一个键盘功能

前言 因为公司需要&#xff0c;所以我.... 演示 代码 键盘组件代码 <template><view class"keyboard_container"><view class"li" v-for"(item, index) in arr" :key"index" click"changArr(item)" :sty…

数据库管理-第216期 Oracle的高可用-01(20240703)

数据库管理216期 2024-07-03 数据库管理-第216期 Oracle的高可用-01&#xff08;20240703&#xff09;1 MAA简介2 MAA等级2.1 BRONZE2.2 SILVER2.3 GOLD2.4 PLATINUM 3 业务延续性总结 数据库管理-第216期 Oracle的高可用-01&#xff08;20240703&#xff09; 作者&#xff1a;…

Google Play上架:恶意软件、移动垃圾软件和行为透明度详细解析和解决办法 (一)

近期整理了许多开发者的拒审邮件和内容,也发现了许多问题,今天来说一下关于恶意软件这类拒审的问题。 目标邮件如下: 首先说一下各位小伙伴留言私信的一个方法,提供你的拒审邮件和时间,尽可能的详细,这样会帮助我们的团队了解你们的问题,去帮助小伙伴么解决问题。由于前…

Java面试八股之MySQL存储货币数据,用什么类型合适

MySQL存储货币数据&#xff0c;用什么类型合适 在MySQL中存储货币数据&#xff0c;最合适的类型是DECIMAL。这是因为货币数据通常需要高精度&#xff0c;尤其是对于财务交易&#xff0c;即使是极小的精度损失也可能导致严重的会计错误。DECIMAL类型可以提供固定的精度&#xf…

ASP.NET Core----基础学习02----中间件的执行顺序 静态文件中间件

文章目录 1.终端中间件&#xff08;Middleware&#xff09;2.中间件的执行顺序&#xff08;1&#xff09;当只有2个中间件的时候&#xff0c;先执行普通中间件&#xff0c;再执行终端中间件&#xff08;2&#xff09;当有多个中间件的时候&#xff0c;中间件的执行顺序 3.添加静…

中英双语介绍百老汇著名歌剧:《猫》(Cats)和《剧院魅影》(The Phantom of the Opera)

中文版 百老汇著名歌剧 百老汇&#xff08;Broadway&#xff09;是世界著名的剧院区&#xff0c;位于美国纽约市曼哈顿。这里汇集了许多著名的音乐剧和歌剧&#xff0c;吸引了全球各地的观众。以下是两部百老汇的经典音乐剧&#xff1a;《猫》和《剧院魅影》的详细介绍。 1.…

Hyper-V克隆虚拟机教程分享!

方法1. 使用导出导入功能克隆Hyper-V虚拟机 导出和导入是Hyper-V服务器备份和克隆的一种比较有效的方法。使用此功能&#xff0c;您可以创建Hyper-V虚拟机模板&#xff0c;其中包括软件、VM CPU、RAM和其他设备的配置&#xff0c;这有助于在Hyper-V中快速部署多个虚拟机。 在…

el-table实现固定列,及解决固定列导致部分滚动条无法拖动的问题

一、el-table实现固定列 当数据量动态变化时&#xff0c;可以为 Table 设置一个最大高度。 通过设置max-height属性为 Table 指定最大高度。此时若表格所需的高度大于最大高度&#xff0c;则会显示一个滚动条。 <div class"zn-filter-table"><!-- 表格--…

数字化精益生产系统--MRP 需求管理系统

MRP&#xff08;Material Requirements Planning&#xff0c;物料需求计划&#xff09;需求管理系统是一种在制造业中广泛应用的计划工具&#xff0c;旨在通过分析和计划企业生产和库存需求&#xff0c;优化资源利用&#xff0c;提高生产效率。以下是对MRP需求管理系统的功能设…

基于Java的网上花店系统

目 录 1 网上花店商品销售网站概述 1.1 课题简介 1.2 设计目的 1.3 系统开发所采用的技术 1.4 系统功能模块 2 数据库设计 2.1 建立的数据库名称 2.2 所使用的表 3 网上花店商品销售网站设计与实现 1. 用户注册模块 2. 用户登录模块 3. 鲜花列表模块 4. 用户购物车…

CTFShow的RE题(三)

数学不及格 strtol 函数 long strtol(char str, char **endptr, int base); 将字符串转换为长整型 就是解这个方程组了 主要就是 v4, v9的关系&#xff0c; 3v9-(v10v11v12)62d10d4673 v4 v12 v11 v10 0x13A31412F8C 得到 3*v9v419D024E75FF(1773860189695) 重点&…

#数据结构 链式栈

1. 概念 链式栈LinkStack 逻辑结构&#xff1a;线性结构物理结构&#xff1a;链式存储栈的特点&#xff1a;后进先出 栈具有后进先出的特点&#xff0c;我们使用链表来实现栈&#xff0c;即链式栈。那么栈顶是入栈和出栈的地方&#xff0c;单向链表有头有尾&#xff0c;那我…

性能测试相关理解(一)

根据学习全栈测试博主的课程做的笔记 一、说明 若未特别说明&#xff0c;涉及术语都是jmeter来说&#xff0c;线程数&#xff0c;就是jmeter线程组中的线程数 二、软件性能是什么 1、用户关注&#xff1a;响应时间 2、业务/产品关注&#xff1a;响应时间、支持多少并发数、…

创新配置,秒级采集,火爆短视频评论抓取

快速采集评论数据的好处 快速采集评论数据是在当今数字信息时代的市场趋势分析和用户反馈分析中至关重要的环节。通过准确获取并分析大量用户评论&#xff0c;您将能够更好地了解消费者的需求、情感和偏好。集蜂云采集平台提供了一种简单配置的方法&#xff0c;使您能够快速采…

NDVI数据集提取植被覆盖度FVC

植被覆盖度FVC 植被覆盖度&#xff08;Foliage Vegetation Cover&#xff0c;FVC&#xff09;是指植被冠层覆盖地表的面积比例&#xff0c;通常用来描述一个区域内植被的茂密程度或生长状况。它是生态学、环境科学以及地理信息系统等领域的重要指标&#xff0c;对于理解地表能…

A股继续3000以下震荡,而国外股市屡创新高,人民币反弹能带动A股吗?

今天的A股&#xff0c;让人愤愤不已&#xff0c;你知道是为什么吗&#xff1f;盘面上出现3个耐人寻味的重要信号&#xff0c;一起来看看&#xff1a; 1、今天两市一度回踩2920点&#xff0c;让股民的心都开始悬起来了。而午后市场行情有了转变&#xff0c;下跌的股票开始明显变…