昇思25天学习打卡营第17天|linchenfengxue

news2025/1/8 4:50:02

RNN实现情感分类

概述

情感分类是自然语言处理中的经典任务,是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型,实现如下的效果:

输入: This film is terrible
正确标签: Negative
预测标签: Negative

输入: This film is great
正确标签: Positive
预测标签: Positive

数据准备

本节使用情感分类的经典数据集IMDB影评数据集,数据集包含Positive和Negative两类,下面为其样例:

ReviewLabel
"Quitting" may be as much about exiting a pre-ordained identity as about drug withdrawal. As a rural guy coming to Beijing, class and success must have struck this young artist face on as an appeal to separate from his roots and far surpass his peasant parents' acting success. Troubles arise, however, when the new man is too new, when it demands too big a departure from family, history, nature, and personal identity. The ensuing splits, and confusion between the imaginary and the real and the dissonance between the ordinary and the heroic are the stuff of a gut check on the one hand or a complete escape from self on the other.Negative
This movie is amazing because the fact that the real people portray themselves and their real life experience and do such a good job it's like they're almost living the past over again. Jia Hongsheng plays himself an actor who quit everything except music and drugs struggling with depression and searching for the meaning of life while being angry at everyone especially the people who care for him most.Positive

此外,需要使用预训练词向量对自然语言单词进行编码,以获取文本的语义特征,本节选取Glove词向量作为Embedding。

数据下载模块

为了方便数据集和预训练词向量的下载,首先设计数据下载模块,实现可视化下载流程,并保存至指定路径。数据下载模块使用requests库进行http请求,并通过tqdm库对下载百分比进行可视化。此外针对下载安全性,使用IO的方式下载临时文件,而后保存至指定的路径并返回。

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path

# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'

def http_get(url: str, temp_file: IO):
    """使用requests库下载数据,并使用tqdm库进行流程可视化"""
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit='B', total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

def download(file_name: str, url: str):
    """下载数据并存为指定名称"""
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_path = os.path.join(cache_dir, file_name)
    cache_exist = os.path.exists(cache_path)
    if not cache_exist:
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)
            temp_file.flush()
            temp_file.seek(0)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
    return cache_path
imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

加载IMDB数据集

下载好的IMDB数据集为tar.gz文件,我们使用Python的tarfile库对其进行读取,并将所有数据和标签分别进行存放。

加载预训练词向量

预训练词向量是对输入单词的数值化表示,通过nn.Embedding层,采用查表的方式,输入单词对应词表中的index,获得对应的表达向量。 因此进行模型构造前,需要将Embedding层所需的词向量和词表进行构造。这里我们使用Glove(Global Vectors for Word Representation)这种经典的预训练词向量, 其数据格式如下:

我们直接使用第一列的单词作为词表,使用dataset.text.Vocab将其按顺序加载;同时读取每一行的Vector并转为numpy.array,用于nn.Embedding加载权重使用。

数据集预处理

通过加载器加载的IMDB数据集进行了分词处理,但不满足构造训练数据的需要,因此要对其进行额外的预处理。其中包含的预处理如下:

  • 通过Vocab将所有的Token处理为index id。
  • 将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。

这里我们使用mindspore.dataset中提供的接口进行预处理操作。这里使用到的接口均为MindSpore的高性能数据引擎设计,每个接口对应操作视作数据流水线的一部分,详情请参考MindSpore数据引擎。 首先针对token到index id的查表操作,使用text.Lookup接口,将前文构造的词表加载,并指定unknown_token。其次为文本序列统一长度操作,使用PadEnd接口,此接口定义最大长度和补齐值(pad_value),这里我们取最大长度为500,填充值对应词表中<pad>的index id。

模型构建

完成数据集的处理后,我们设计用于情感分类的模型结构。首先需要将输入文本(即序列化后的index id列表)通过查表转为向量化表示,此时需要使用nn.Embedding层加载Glove词向量;然后使用RNN循环神经网络做特征提取;最后将RNN连接至一个全连接层,即nn.Dense,将特征转化为与分类数量相同的size,用于后续进行模型优化训练。整体模型结构如下:

nn.Embedding -> nn.RNN -> nn.Dense

这里我们使用能够一定程度规避RNN梯度消失问题的变种LSTM(Long short-term memory)做特征提取层。

Embedding

Embedding层又可称为EmbeddingLookup层,其作用是使用index id对权重矩阵对应id的向量进行查找,当输入为一个由index id组成的序列时,则查找并返回一个相同长度的矩阵,例如:

embedding = nn.Embedding(1000, 100) # 词表大小(index的取值范围)为1000,表示向量的size为100
input shape: (1, 16)                # 序列长度为16
output shape: (1, 16, 100)

这里我们使用前文处理好的Glove词向量矩阵,设置nn.Embeddingembedding_table为预训练词向量矩阵。对应的vocab_size为词表大小400002,embedding_size为选用的glove.6B.100d向量大小,即100。

RNN(循环神经网络)

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。下图为RNN的一般结构:

图示左侧为一个RNN Cell循环,右侧为RNN的链式连接平铺。实际上不管是单个RNN Cell还是一个RNN网络,都只有一个Cell的参数,在不断进行循环计算中更新。

由于RNN的循环特性,和自然语言文本的序列特性(句子是由单词组成的序列)十分匹配,因此被大量应用于自然语言处理研究中。下图为RNN的结构拆解:

RNN单个Cell的结构简单,因此也造成了梯度消失(Gradient Vanishing)问题,具体表现为RNN网络在序列较长时,在序列尾部已经基本丢失了序列首部的信息。为了克服这一问题,LSTM(Long short-term memory)被提出,通过门控机制(Gating Mechanism)来控制信息流在每个循环步中的留存和丢弃。下图为LSTM的结构拆解:

本节我们选择LSTM变种而不是经典的RNN做特征提取,来规避梯度消失问题,并获得更好的模型效果。下面来看MindSpore中nn.LSTM对应的公式:

这里nn.LSTM隐藏了整个循环神经网络在序列时间步(Time step)上的循环,送入输入序列、初始状态,即可获得每个时间步的隐状态(hidden state)拼接而成的矩阵,以及最后一个时间步对应的隐状态。我们使用最后的一个时间步的隐状态作为输入句子的编码特征,送入下一层。

Dense

在经过LSTM编码获取句子特征后,将其送入一个全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

损失函数与优化器

完成模型主体构建后,首先根据指定的参数实例化网络;然后选择损失函数和优化器。针对本节情感分类问题的特性,即预测Positive或Negative的二分类问题,我们选择nn.BCEWithLogitsLoss(二分类交叉熵损失函数)。

训练逻辑

在完成模型构建,进行训练逻辑的设计。一般训练逻辑分为一下步骤:

  1. 读取一个Batch的数据;
  2. 送入网络,进行正向计算和反向传播,更新权重;
  3. 返回loss。

模型训练与保存

前序完成了模型构建和训练、评估逻辑的设计,下面进行模型训练。这里我们设置训练轮数为5轮。同时维护一个用于保存最优模型的变量best_valid_loss,根据每一轮评估的loss值,取loss值最小的轮次,将模型进行保存。为节省用例运行时长,此处num_epochs设置为2,可根据需要自行修改。

模型加载与测试

模型训练完成后,一般需要对模型进行测试或部署上线,此时需要加载已保存的最优模型(即checkpoint),供后续测试使用。这里我们直接使用MindSpore提供的Checkpoint加载和网络权重加载接口:1.将保存的模型Checkpoint加载到内存中,2.将Checkpoint加载至模型。

load_param_into_net接口会返回模型中没有和Checkpoint匹配的权重名,正确匹配时返回空列表。

自定义输入测试

最后我们设计一个预测函数,实现开头描述的效果,输入一句评价,获得评价的情感分类。具体包含以下步骤:

  1. 将输入句子进行分词;
  2. 使用词表获取对应的index id序列;
  3. index id序列转为Tensor;
  4. 送入模型获得预测结果;
  5. 打印输出预测结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1907346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1.8.0-矩阵乘法的反向传播-简单推导

1相关资料 之前分享过一个博客里面写的&#xff0c;我们大致了解并记住结论的博客&#xff1a;【深度学习】7-矩阵乘法运算的反向传播求梯度_矩阵梯度公式-CSDN博客&#xff1b;这里再分享一下自然语言处理书上关于这部分的推导过程&#xff1a;3-矩阵相乘-梯度反向传播的计算…

开源模型应用落地-FastAPI-助力模型交互-进阶篇(一)

一、前言 FastAPI 的高级用法可以为开发人员带来许多好处。它能帮助实现更复杂的路由逻辑和参数处理&#xff0c;使应用程序能够处理各种不同的请求场景&#xff0c;提高应用程序的灵活性和可扩展性。 在数据验证和转换方面&#xff0c;高级用法提供了更精细和准确的控制&#…

上海慕尼黑电子展开展,启明智显携物联网前沿方案亮相

随着科技创新的浪潮不断涌来&#xff0c;上海慕尼黑电子展在万众瞩目中盛大开幕。本次展会汇聚了全球顶尖的电子产品与技术解决方案&#xff0c;成为业界瞩目的焦点。启明智显作为物联网彩屏显示领域的佼佼者携产品亮相展会&#xff0c;为参展者带来了RTOS、LINUX全系列方案及A…

【见刊通知】MVIPIT 2023机器视觉、图像处理与影像技术国际会议

MVIPIT 2023&#xff1a;https://ieeexplore.ieee.org/xpl/conhome/10578343/proceeding 入库Ei数据库需等20-50天左右 第二届会议征稿启动&#xff08;MVIPIT 2024&#xff09; The 2nd International Conference on Machine Vision, Image Processing & Imaging Techn…

java —— tomcat 部署项目

一、通过 war 包部署 1、将项目导出为 war 包&#xff1b; 2、将 war 包放置在 tomcat 目录下的 webapps 文件夹下&#xff0c;该 war 包稍时便自动解析为项目文件夹&#xff1b; 3、启动 tomcat 的 /bin 目录下的 startup.bat 文件&#xff0c;此时即可从浏览器访问项目首页…

3.Python学习:模块\包\yaml

1.模块与包–互相引用 &#xff08;1&#xff09;一个模块就是一个.py文件 &#xff08;2&#xff09;有模块的目录–文件夹 &#xff08;3&#xff09;包&#xff1a;文件夹包含__init__.py文件 &#xff08;4&#xff09;导入包时&#xff0c;init.py文件里的内容会执行一次 …

聚鼎装饰画:装饰画喊个与现在是什么情况

回眸历史长河&#xff0c;装饰画以其独特的魅力一直为人类生活环境添彩增趣。从古埃及的壁画到文艺复兴时期的油画&#xff0c;再到现代简约的线条画&#xff0c;装饰画如同时代的缩影&#xff0c;映射出不同历史阶段的文化特征与审美趣味。 在现代社会&#xff0c;装饰画的现状…

C语言学习笔记[21]:分支语句if...else

C语言是结构化的程序设计语言 顺序结构选择结构循环结构 分支语句对应的就是选择结构&#xff0c;循环语句对应的就是循环结构 分支语句 if...elseswitch 循环语句 whilefordo...while goto语句 语句 C语言中由分号隔开的就是一条语句&#xff0c;比如&#xff1a; #…

猫咪浮毛多怎么办?一分钟推荐性价比高的养猫空气净化器排名

作为一名猫咖店老板&#xff0c;我发现很多铲屎官来店里咨询&#xff0c;在春夏换季时会频繁打喷嚏、全身过敏红肿。这是因为猫咪在换季时会大量掉毛&#xff0c;家里就像下雪一样&#xff0c;空气中充满了猫毛。这些猫毛上附带的细菌会随浮毛被人吸入&#xff0c;从而引发打喷…

【计算机毕业设计】021基于weixin小程序微信点餐

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

奇安信20240513笔试

题目一 解题思路 n转为字符串&#xff0c;如果位数为偶数&#xff0c;取前一半设为x&#xff0c;后一段为y&#xff0c;从x最低位开始&#xff0c;9&#xff0c;9*10&#xff0c;9*10*10。。。 到最高位&#xff0c;加x&#xff0c;如果x大于或等于y&#xff0c;加1. 位数为奇数…

数据结构——二叉树之c语言实现堆与堆排序

目录 前言&#xff1a; 1.二叉树的概念及结构 1.1 特殊的二叉树 1.2 二叉树的存储结构 1.顺序存储 2.链式存储 2. 二叉树的顺序结构及实现 2.1 堆的概念 ​编辑 2.2 堆的创建 3.堆的实现 3.1 堆的初始化和销毁 初始化&#xff1a; 销毁&#xff1a; 插入&…

React Hooks:上天在提醒你,别再用Class组件了!

React Hooks&#xff1a;上天在提醒你&#xff0c;别再用Class组件了&#xff01; React Hooks 的出现可以说是前端界的一场革命。它不仅让我们告别了繁琐的 Class 组件&#xff0c;还让代码变得更加简洁、易读、易维护。如果你还在固守 Class 组件的阵地&#xff0c;那么这篇…

如何处理 PostgreSQL 中由于表锁定导致的并发访问问题?

文章目录 一、表锁定的类型二、表锁定导致的并发访问问题三、解决方案&#xff08;一&#xff09;使用合适的锁定模式&#xff08;二&#xff09;优化事务处理&#xff08;三&#xff09;避免不必要的锁定&#xff08;四&#xff09;使用索引&#xff08;五&#xff09;监控和分…

Java-链表反转

题目&#xff1a; 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 图示&#xff1a; 输入&#xff1a; head [1,2,3,4,5] 输出&#xff1a; [5,4,3,2,1] 解题思路&#xff1a; 情况一&#xff1a; 只有一个节点或者没有节点 …

小白学C++(第一天)基础入门

温馨提醒&#xff1a;本篇文章&#xff0c;请各位c基础不行的童鞋不要贸然观看 C的第一个程序 第一个关键字namespace namespace 是定义空间的名字的关键字&#xff0c;使用格式格式如下&#xff1a; namespace 空间名 { } 其中{ }内的命名空间的成员&#xff0c;可以定义…

计算机图形学入门26:高级光线传播

1.有偏与无偏 在做光线追踪很多方法都是用蒙特卡洛积分去估计&#xff0c;蒙特卡洛积分有些是无偏的(Unbiased)&#xff0c;所谓无偏估计就是无论使用多少个样品&#xff0c;所估计的期望值都是正确的。那么&#xff0c;所有其他情况都是有偏的(Biased)&#xff0c;就是估计的期…

MySQL存储与优化 一、MySQL架构原理

1.MySQL体系架构 MySQL Server架构自顶向下大致可以分网络连接层、服务层、存储引擎层和系统文件层 (1)网络连接层 客户端连接器&#xff08;Client Connectors&#xff09;&#xff1a;提供与MySQL服务器建立的支持。目前几乎支持所有主流的服务端编程技术&#xff0c;例如常…

《初级C++》(一)

初级C&#xff08;一&#xff09; 1: C参考⽂档2&#xff1a;C创建与实现创建C的第一套程序命名空间的理解空间命名的实现C输⼊&输出缺省参数 1: C参考⽂档 https://legacy.cplusplus.com/reference/ 《非官方》 https://zh.cppreference.com/w/cpp 《官方中文版》 https:/…

前端面试题28(Vue3的Teleport功能在什么场景下特别有用?能给个例子吗?)

Vue 3 的 Teleport 功能在需要将组件的渲染结果放置在 DOM 树中与当前组件位置无关的任意位置时特别有用。这通常涉及到需要将某些UI元素&#xff08;如模态框、弹出菜单、通知、工具提示等&#xff09;从其逻辑上的父级组件中“提取”出来&#xff0c;放置到页面的更高层级或完…