昇思学习打卡-25-自然语言处理/RNN实现情感分类

news2025/1/16 5:49:05

文章目录

  • 数据下载
  • 加载预训练词向量
  • 数据集预处理
  • 模型构建
  • 损失函数与优化器
  • 训练逻辑
  • 评估指标和逻辑
  • 模型训练与保存
  • 模型加载与测试
  • 自定义输入测试
  • 测试

情感分类是自然语言处理中的经典任务,是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型,感觉跟这个任务很像链接: 昇思学习打卡-18-LLM原理与实践/MindNLP ChatGLM-6B StreamChat,本节使用情感分类的经典数据集IMDB影评数据集,数据集包含Positive和Negative两类。

数据下载

首先设计数据下载模块,实现可视化下载流程,并保存至指定路径。数据下载模块使用requests库进行http请求,并通过tqdm库对下载百分比进行可视化。此外针对下载安全性,使用IO的方式下载临时文件,而后保存至指定的路径并返回。

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path

# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'

def http_get(url: str, temp_file: IO):
    """使用requests库下载数据,并使用tqdm库进行流程可视化"""
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit='B', total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

def download(file_name: str, url: str):
    """下载数据并存为指定名称"""
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_path = os.path.join(cache_dir, file_name)
    cache_exist = os.path.exists(cache_path)
    if not cache_exist:
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)
            temp_file.flush()
            temp_file.seek(0)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
    return cache_path

数据集已分割为train和test两部分,且每部分包含neg和pos两个分类的文件夹,因此需分别train和test进行读取并处理数据和标签。

import re
import six
import string
import tarfile

class IMDBData():
    """IMDB数据集加载器

    加载IMDB数据集并处理为一个Python迭代对象。

    """
    label_map = {
        "pos": 1,
        "neg": 0
    }
    def __init__(self, path, mode="train"):
        self.mode = mode
        self.path = path
        self.docs, self.labels = [], []

        self._load("pos")
        self._load("neg")

    def _load(self, label):
        pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))
        # 将数据加载至内存
        with tarfile.open(self.path) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    # 对文本进行分词、去除标点和特殊字符、小写处理
                    self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                                         .translate(None, six.b(string.punctuation)).lower()).split())
                    self.labels.append([self.label_map[label]])
                tf = tarf.next()

    def __getitem__(self, idx):
        return self.docs[idx], self.labels[idx]

    def __len__(self):
        return len(self.docs)

加载预训练词向量

预训练词向量是对输入单词的数值化表示,通过nn.Embedding层,采用查表的方式,输入单词对应词表中的index,获得对应的表达向量。 因此进行模型构造前,需要将Embedding层所需的词向量和词表进行构造。这里我们使用Glove(Global Vectors for Word Representation)这种经典的预训练词向量 直接使用第一列的单词作为词表,使用dataset.text.Vocab将其按顺序加载;同时读取每一行的Vector并转为numpy.array,用于nn.Embedding加载权重使用。具体实现如下:

import zipfile
import numpy as np

def load_glove(glove_path):
    glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')
    if not os.path.exists(glove_100d_path):
        glove_zip = zipfile.ZipFile(glove_path)
        glove_zip.extractall(cache_dir)

    embeddings = []
    tokens = []
    with open(glove_100d_path, encoding='utf-8') as gf:
        for glove in gf:
            word, embedding = glove.split(maxsplit=1)
            tokens.append(word)
            embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))
    # 添加 <unk>, <pad> 两个特殊占位符对应的embedding
    embeddings.append(np.random.rand(100))
    embeddings.append(np.zeros((100,), np.float32))

    vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)
    embeddings = np.array(embeddings).astype(np.float32)
    return vocab, embeddings

由于数据集中可能存在词表没有覆盖的单词,因此需要加入标记符;同时由于输入长度的不一致,在打包为一个batch时需要将短的文本进行填充,因此需要加入标记符。完成后的词表长度为原词表长度+2。

数据集预处理

通过加载器加载的IMDB数据集进行了分词处理,但不满足构造训练数据的需要,因此要对其进行额外的预处理。其中包含的预处理如下:

  • 通过Vocab将所有的Token处理为index id。
  • 将文本序列统一长度,不足的使用补齐,超出的进行截断。

这里我们使用mindspore.dataset中提供的接口进行预处理操作。这里使用到的接口均为MindSpore的高性能数据引擎设计,每个接口对应操作视作数据流水线的一部分,详情请参考MindSpore数据引擎。 首先针对token到index id的查表操作,使用text.Lookup接口,将前文构造的词表加载,并指定unknown_token。其次为文本序列统一长度操作,使用PadEnd接口,此接口定义最大长度和补齐值(pad_value),这里我们取最大长度为500,填充值对应词表中的index id。

  除了对数据集中text进行预处理外,由于后续模型训练的需要,要将label数据转为float32格式。
import mindspore as ms

lookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
type_cast_op = ds.transforms.TypeCast(ms.float32)

完成预处理操作后,需将其加入到数据集处理流水线中,使用map接口对指定的column添加操作。

imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])

imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])

手动将IMDB数据集分割为训练和验证两部分,比例取0.7, 0.3。

imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])

模型构建

完成数据集的处理后,我们设计用于情感分类的模型结构。首先需要将输入文本(即序列化后的index id列表)通过查表转为向量化表示,此时需要使用nn.Embedding层加载Glove词向量;然后使用RNN循环神经网络做特征提取;最后将RNN连接至一个全连接层,即nn.Dense,将特征转化为与分类数量相同的size,用于后续进行模型优化训练。整体模型结构如下:nn.Embedding -> nn.RNN -> nn.Dense,这里我们使用能够一定程度规避RNN梯度消失问题的变种LSTM(Long short-term memory)做特征提取层。

损失函数与优化器

完成模型主体构建后,首先根据指定的参数实例化网络;然后选择损失函数和优化器。针对本节情感分类问题的特性,即预测Positive或Negative的二分类问题,我们选nn.BCEWithLogitsLoss(二分类交叉熵损失函数)。

训练逻辑

在完成模型构建,进行训练逻辑的设计。一般训练逻辑分为一下步骤:

  • 读取一个Batch的数据;
  • 送入网络,进行正向计算和反向传播,更新权重;
  • 返回loss。

下面按照此逻辑,使用tqdm库,设计训练一个epoch的函数,用于训练过程和loss的可视化。

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss
​
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):
    loss, grads = grad_fn(data, label)
    optimizer(grads)
    return loss
​
def train_one_epoch(model, train_dataset, epoch=0):
    model.set_train()
    total = train_dataset.get_dataset_size()
    loss_total = 0
    step_total = 0
    with tqdm(total=total) as t:
        t.set_description('Epoch %i' % epoch)
        for i in train_dataset.create_tuple_iterator():
            loss = train_step(*i)
            loss_total += loss.asnumpy()
            step_total += 1
            t.set_postfix(loss=loss_total/step_total)
            t.update(1)

评估指标和逻辑

训练逻辑完成后,需要对模型进行评估。即使用模型的预测结果和测试集的正确标签进行对比,求出预测的准确率。由于IMDB的情感分类为二分类问题,对预测值直接进行四舍五入即可获得分类标签(0或1),然后判断是否与正确标签相等即可。下面为二分类准确率计算函数实现:

def binary_accuracy(preds, y):
    """
    计算每个batch的准确率
    """

    # 对预测值进行四舍五入
    rounded_preds = np.around(ops.sigmoid(preds).asnumpy())
    correct = (rounded_preds == y).astype(np.float32)
    acc = correct.sum() / len(correct)
    return acc

有了准确率计算函数后,类似于训练逻辑,对评估逻辑进行设计, 分别为以下步骤:

  • 读取一个Batch的数据;
  • 送入网络,进行正向计算,获得预测结果;
  • 计算准确率。

同训练逻辑一样,使用tqdm进行loss和过程的可视化。此外返回评估loss至供保存模型时作为模型优劣的判断依据。

在进行evaluate时,使用的模型是不包含损失函数和优化器的网络主体; 在进行evaluate前,需要通过model.set_train(False)将模型置为评估状态,此时Dropout不生效。
def evaluate(model, test_dataset, criterion, epoch=0):
    total = test_dataset.get_dataset_size()
    epoch_loss = 0
    epoch_acc = 0
    step_total = 0
    model.set_train(False)

    with tqdm(total=total) as t:
        t.set_description('Epoch %i' % epoch)
        for i in test_dataset.create_tuple_iterator():
            predictions = model(i[0])
            loss = criterion(predictions, i[1])
            epoch_loss += loss.asnumpy()

            acc = binary_accuracy(predictions, i[1])
            epoch_acc += acc

            step_total += 1
            t.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)
            t.update(1)

    return epoch_loss / total

模型训练与保存

前序完成了模型构建和训练、评估逻辑的设计,下面进行模型训练。这里我们设置训练轮数为5轮。同时维护一个用于保存最优模型的变量best_valid_loss,根据每一轮评估的loss值,取loss值最小的轮次,将模型进行保存。为节省用例运行时长,此处num_epochs设置为2,可根据需要自行修改。

num_epochs = 2
best_valid_loss = float('inf')
ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')

for epoch in range(num_epochs):
    train_one_epoch(model, imdb_train, epoch)
    valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        ms.save_checkpoint(model, ckpt_file_name)

模型加载与测试

模型训练完成后,一般需要对模型进行测试或部署上线,此时需要加载已保存的最优模型(即checkpoint),供后续测试使用。这里我们直接使用MindSpore提供的Checkpoint加载和网络权重加载接口:1.将保存的模型Checkpoint加载到内存中,2.将Checkpoint加载至模型。

load_param_into_net接口会返回模型中没有和Checkpoint匹配的权重名,正确匹配时返回空列表。
param_dict = ms.load_checkpoint(ckpt_file_name)
ms.load_param_into_net(model, param_dict)

对测试集打batch,然后使用evaluate方法进行评估,得到模型在测试集上的效果。

imdb_test = imdb_test.batch(64)
evaluate(model, imdb_test, loss_fn)

自定义输入测试

最后我们设计一个预测函数,实现开头描述的效果,输入一句评价,获得评价的情感分类。具体包含以下步骤:

  • 将输入句子进行分词;
  • 使用词表获取对应的index id序列;
  • index id序列转为Tensor;
  • 送入模型获得预测结果;
  • 打印输出预测结果。
    具体实现如下:
score_map = {
    1: "Positive",
    0: "Negative"
}

def predict_sentiment(model, vocab, sentence):
    model.set_train(False)
    tokenized = sentence.lower().split()
    indexed = vocab.tokens_to_ids(tokenized)
    tensor = ms.Tensor(indexed, ms.int32)
    tensor = tensor.expand_dims(0)
    prediction = model(tensor)
    return score_map[int(np.round(ops.sigmoid(prediction).asnumpy()))]

测试

predict_sentiment(model, vocab, "This film is terrible")

在这里插入图片描述

此章节学习到此结束,感谢昇思平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1943784.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

android audio不同音频流,(六)settings内音频流音量调整

&#xff08;1&#xff09;settings内&#xff0c;可设置音频流音量&#xff0c;如下图&#xff1a; &#xff08;2&#xff09;settings调整音量条进度&#xff0c;会触发SeekBarVolumizer对象&#xff1a; SeekBarVolumizer文件路径&#xff1a; frameworks/base/core/java/…

设计模式13-单件模式

设计模式13-单件模式 写在前面对象性能模式典型模式1. 单例模式&#xff08;Singleton Pattern&#xff09;2. 享元模式&#xff08;Flyweight Pattern&#xff09;3. 原型模式&#xff08;Prototype Pattern&#xff09;4. 对象池模式&#xff08;Object Pool Pattern&#xf…

WebRTC QoS方法十三.2(Jitter延时的计算)

一、背景介绍 一些报文在网络传输中&#xff0c;会存在丢包重传和延时的情况。渲染时需要进行适当缓存&#xff0c;等待丢失被重传的报文或者正在路上传输的报文。 jitter延时计算是确认需要缓存的时间 另外&#xff0c;在检测到帧有重传情况时&#xff0c;也可适当在渲染时…

无人机图像目标检测技术详解

当前研究领域的热点之一。无人机搭载的高清摄像头能够实时捕获大量图像数据&#xff0c;对这些数据进行有效的目标检测对于军事侦察、环境监测、灾害救援等领域具有重要意义。本文将对无人机图像目标检测技术进行详解&#xff0c;包括图像处理技术、目标检测算法、关键技术应用…

LoFTR关键点特征匹配算法环境构建与图像匹配测试Demo

0&#xff0c;LoFTR CVPR 2021论文《LoFTR: Detector-Free Local Feature Matching with Transformers》开源代码 1&#xff0c;项目主页 LoFTR: Detector-Free Local Feature Matching with Transformers 2&#xff0c;GItHub主页 GitHub - zju3dv/LoFTR: Code for "…

Gen AI核心技术发展趋势分析

Gen AI核心技术解析及发展趋势 判别式模型&#xff0c;适用于处理回归与分类任务&#xff0c;其核心在于精准区分各类数据。与生成模型的生成新数据不同&#xff0c;判别模型专注于揭示输入特征与输出标签之间的紧密联系&#xff0c;从而实现准确分类或预测。在众多应用领域&am…

VS Code打开新文件会覆盖之前的窗口,解决办法

当我在VS Code中打开文件夹时&#xff0c;文件夹中只有一个文件能展示在窗口中&#xff0c;如果点击打开另外一个文件&#xff0c;之前打开的文件又会被覆盖。这样是无法进行文件之间的关联查找的。 要保证窗口可以打开多个文件&#xff0c;有不同的tab显示&#xff0c;设置如下…

好用的电脑屏幕监控软件推荐,什么软件能够监控电脑?

在当今信息化时代&#xff0c;电脑屏幕监控软件成为了企业管理、家长监管以及教育培训等领域的必备工具。通过实时监控电脑屏幕&#xff0c;这类软件可以有效提高工作效率&#xff0c;防止信息泄露&#xff0c;保障网络安全。本文将详细盘点几款主流的电脑屏幕监控软件&#xf…

PHP基础语法(四)

一、字符串类型 1、字符串定义语法 1&#xff09;单引号字符串&#xff1a;在单引号内部&#xff0c;所有的字符都会按照字面意义解释&#xff0c;不会进行变量替换或转义处理&#xff0c;除了 \ 表示单引号本身。 $str1 Hello, World!;2&#xff09;双引号字符串&#xff…

【机器学习算法基础】(基础机器学习课程)-08-决策树和随机森林-笔记

一、决策树之信息论基础 决策树是一种用来做决策的工具&#xff0c;就像我们生活中的选择树。例如&#xff0c;你在选择今天穿什么衣服时&#xff0c;会根据天气情况、出行活动等进行判断。决策树的构建过程涉及一些信息论的概念&#xff0c;用来衡量和选择最好的“分叉点”来进…

Unity打包设置

1.Resolution and Presentation (分辨率和显示) Fullscreen Window (全屏窗口): 应用程序将以全屏窗口模式运行&#xff0c;但不会独占屏幕。适用于想要全屏显示但仍需访问其他窗口的情况。 Resizable Window (可调整大小的窗口): 允许用户调整应用程序窗口的大小。适用于窗口…

Action通信 实践案例

Action通信 Server端实现 创建服务端功能包&#xff08;注意目录层级&#xff09;&#xff1a; ros2 pkg create my_exer05_action_server --build-type ament_cmake --node-name nav_server --dependencies rclcpp rclcpp_action my_exer_interfaces nav_msgs /*需求&#x…

【vim】ubuntu20-server 安装配置 vim 最新最详细

在编程的艺术世界里,代码和灵感需要寻找到最佳的交融点,才能打造出令人为之惊叹的作品。而在这座秋知叶i博客的殿堂里,我们将共同追寻这种完美结合,为未来的世界留下属于我们的独特印记。【vim】ubuntu20-server 安装配置 vim 最新最详细 开发环境一、vim github二、安装必…

基于SpringBoot实现验证码功能

目录 一 实现思路 二 代码实现 三 代码汇总 现在的登录都需要输入验证码用来检测是否是真人登录&#xff0c;所以验证码功能在现在是非常普遍的&#xff0c;那么接下来我们就基于springboot来实现验证码功能。 一 实现思路 今天我们介绍的是两种主流的验证码&#xff0c;一…

IP地址专用SSL/https证书——10分钟签发

一般常用的SSL证书多为域名型SSL证书&#xff0c;即需要提供准确的域名。如果不能提供域名&#xff0c;只能提供IP地址&#xff0c;则需要一种特殊的SSL证书——IP地址证书。下面是IP地址证书的申请教程 IP地址专用SSL证书获取链接https://www.joyssl.com/certificate/select/…

SQL中的LEFT JOIN、RIGHT JOIN和INNER JOIN

在SQL中&#xff0c;JOIN操作是连接两个或多个数据库表&#xff0c;并根据两个表之间的共同列&#xff08;通常是主键和外键&#xff09;返回数据的重要方法。其中&#xff0c;LEFT JOIN&#xff08;左连接&#xff09;、RIGHT JOIN&#xff08;右连接&#xff09;和INNER JOIN…

Open3D 将点云投影到球面

目录 一、概述 二、代码实现 2.1关键函数 2.2完整代码 三、实现效果 3.1原始点云 3.2投影后点云 前期试读&#xff0c;后续会将博客加入下列链接的专栏&#xff0c;欢迎订阅 Open3D点云算法与点云深度学习案例汇总&#xff08;长期更新&#xff09;-CSDN博客 一、概述…

【表单组件】地址组件新增精简模式

07/17 主要更新模块概览 快速筛选 精简模式 触发条件 自定义域名 01 表单管理 1.1 【表单组件】-数据关联组件新增快速筛选功能 说明&#xff1a; 数据关联组件新增快速筛选功能&#xff0c;用户在数据关联组件选择数据时&#xff0c;可以通过快速筛选功能&#xff0…

黑马头条Day07-app端文章搜索

一、今日内容介绍 1. App端搜索效果图 2. 今日内容 &#xff08;1&#xff09;文章搜索 Elasticsearch环境搭建索引库创建文章搜索多条件复合查询索引数据同步 &#xff08;2&#xff09;搜索历史记录 MongoDB环境搭建异步保存搜索历史查看搜索历史列表删除搜索历史 &…

Linux源码安装的Redis如何配置systemd管理并设置开机启动

文章目录 实验前提实验 实验前提 已完成源码安装并能正常启动redis /usr/local/bin/redis-server能正常启动redis 实验 vim /etc/systemd/system/redis.service内容如下&#xff1a; [unit] Descriptionredis-server Afternetwork.target[Service] Typeforking ExecStart/…