Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建

news2025/1/3 4:56:26

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

tips:安装依赖库

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install tqdm requests

一、RNN模型构建

数据集准备完成了输入文本通过查字典(序列化)的向量化。并使用nn.Embedding层加载了Glove词向量。下一步将使用RNN循环神经网络做特征提取,最后将RNN连接至全连接网络nn.Dednse,将特征转化为分类。

nn.Embedding -> nn.RNN -> nn.Dense

本项目,采用规避RNN梯度消的变种LSTM(Long short-term memory)代替RNN做特征提取层。

1.1 关于RNN

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。下图为RNN的一般结构:

RNN-0

图示左侧为一个RNN Cell循环,右侧为RNN的链式连接平铺。实际上不管是单个RNN Cell还是一个RNN网络,都只有一个Cell的参数,在不断进行循环计算中更新。

由于RNN的循环特性,和自然语言文本的序列特性(句子是由单词组成的序列)十分匹配,因此被大量应用于自然语言处理研究中。下图为RNN的结构拆解:

RNN

1.2 关于LSTM(Long short-term memory)

RNN单个Cell的结构简单,因此也造成了梯度消失(Gradient Vanishing)问题,具体表现为RNN网络在序列较长时,在序列尾部已经基本丢失了序列首部的信息。为了克服这一问题,LSTM(Long short-term memory)被提出,通过门控机制(Gating Mechanism)来控制信息流在每个循环步中的留存和丢弃。下图为LSTM的结构拆解:

LSTM

本项目选择LSTM变种而不是经典的RNN做特征提取,可规避梯度消失问题,并获得更好的模型效果。
在MindSpore中nn.LSTM对应的公式:

h 0 : t , ( h t , c t ) = LSTM ( x 0 : t , ( h 0 , c 0 ) ) h_{0:t}, (h_t, c_t) = \text{LSTM}(x_{0:t}, (h_0, c_0)) h0:t,(ht,ct)=LSTM(x0:t,(h0,c0))

这里nn.LSTM隐藏了整个循环神经网络在序列时间步(Time step)上的循环,送入输入序列、初始状态,即可获得每个时间步的隐状态(hidden state`)拼接而成的矩阵,以及最后一个时间步对应的隐状态。我们使用最后的一个时间步的隐状态作为输入句子的编码特征,送入下一层

Time step:在循环神经网络计算的每一次循环,成为一个Time step。在送入文本序列时,一个Time step对应一个单词。因此在本例中,LSTM的输出 h 0 : t h_{0:t} h0:t对应每个单词的隐状态集合, h t h_t ht c t c_t ct对应最后一个单词对应的隐状态。

下一层:全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

1.3 特征提取网络构建

RNN循环神经网络: nn.LSTM()
初始化参数:

 embeddings:输入向量,
  hidden_dim:隐藏层特征的维度, 
  output_dim:输出维数, 
  n_layers:RNN 层的数量,
  bidirectional:是否为双向 RNN, 
   pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

tips:使用nn.embeddings()创建嵌入层时,可以通过padding_idx参数指定一个特定的索引,用于表示填充值。
embedding_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0),将padding_idx设置为0,表示使用索引为0的词汇作为填充值。在文本序列中,我们将使用0来填充较短的序列。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniform

class RNN(nn.Cell):
    def __init__(self, embeddings, hidden_dim, output_dim, n_layers,
                 bidirectional, pad_idx):
        super().__init__()
        vocab_size, embedding_dim = embeddings.shape
        self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)
        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           batch_first=True)
        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
        self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)

    def construct(self, inputs):
        embedded = self.embedding(inputs)
        _, (hidden, _) = self.rnn(embedded)
        hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        output = self.fc(hidden)
        return output

实例化模型,打印输出

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')

model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
print(model)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1947491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

松下UV电源MID SONIC600 ANUP8304NAIS电源设备更新换下的

松下UV电源MID SONIC600 ANUP8304NAIS电源设备更新换下的

JL 跳转指令的理解

一般情况下&#xff0c;JU 和 JC 是最常见的跳转指令&#xff1b;但有时会用到JL 指令&#xff0c;JL 说起来更像是一组指令&#xff0c;类似C,C# 语言中的 switch case 语句&#xff0c;但是有个明显的不同&#xff0c;前者的判断条件可以是任意合理数字&#xff0c;后者范围…

洗地机什么品牌质量好耐用?口碑最好的洗地机排名分享

在追求高效、便捷的现代家居环境中&#xff0c;洗地机作为清洁工具的关键角色&#xff0c;其品牌与品质的选择成为了消费者关注的焦点。面对琳琅满目的洗地机市场&#xff0c;洗地机什么品牌质量好耐用&#xff1f;如何挑选出一款既高效又智能&#xff0c;且能带来卓越清洁体验…

算力共享:环形结构的算力分配策略

目录 算力共享:环形结构的算力分配策略 方法签名 方法实现 注意事项 nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True) end = round(start + (node[1].memory / total_memory), 5) 算力共享:环形结构的算力分配策略 这段代码定义了一个名为RingMemoryWeig…

基于 HTML+ECharts 实现智慧运维数据可视化大屏(含源码)

智慧运维数据可视化大屏&#xff1a;基于 HTML 和 ECharts 的实现 在现代企业中&#xff0c;运维管理是确保系统稳定运行的关键环节。随着数据量的激增&#xff0c;如何高效地监控和分析运维数据成为了一个重要课题。本文将介绍如何利用 HTML 和 ECharts 实现一个智慧运维数据可…

菜鸟从0学微服务——MyBatis-Plus

关于“菜鸟从0学微服务” 针对有编程基础&#xff0c;开始学习微服务的同学&#xff0c;我们陆续推出从0学微服务的笔记分享。力求从各个中间件的使用来反思这些中间件的作用和优势。 会分享的比较快&#xff0c;会记录demo演算和中间件的使用过程&#xff0c;至于细节的理论…

Python的人脸识别程序

1.录入人脸&#xff0c;输入ID号 haarcascade_frontalface_default.xml # 导入模块 import os import numpy as np import cv2 as cv import cv2face_detector cv2.CascadeClassifier(rD:\Automation_All_Files\OCR\haarcascade_frontalface_default.xml) # 待更改# 为即将…

【VTKExamples::Movie】制作并保存动画

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK中创建动画,并保存动画的方法,样例及样例源码,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ…

vue-快速入门

Vue 前端体系、前后端分离 1、概述 1.1、简介 Vue (发音为 /vjuː/&#xff0c;类似 view) 是一款用于构建用户界面的 JavaScript 框架。它基于标准 HTML、CSS 和 JavaScript 构建&#xff0c;并提供了一套声明式的、组件化的编程模型&#xff0c;可以高效地开发用户界面。…

vue3实现在新标签中打开指定的网址

有一个文件列表&#xff0c;如下图&#xff1a; 我希望点击查看按钮的时候&#xff0c;能够在新的标签页面打开这个文件的地址进行预览&#xff0c;该如何实现呢&#xff1f; 比如&#xff1a; 实际上要实现这个并不难&#xff0c;参考demo如下&#xff1a; 首先&#x…

【Go系列】Go的UI框架GIO

其实主要我是要花一个折线图&#xff0c;但是使用Fyne貌似画不出来&#xff0c;使用plot也没法动态生成&#xff0c;听说Gio可以&#xff0c;那就先介绍一下什么是Gio把。 GIO&#xff08;gioui.org&#xff09;是一个用于Go语言的跨平台GUI库&#xff0c;旨在为开发人员提供构…

【ROS2】高级:安全-设置安全性

目标&#xff1a;使用 sros2 设置安全性。 教程级别&#xff1a;高级 时间&#xff1a;15 分钟 内容 背景 安装 从源代码安装选择替代中间件 运行演示 1. 为安全文件创建一个文件夹2. 生成密钥库3. 生成密钥和证书4. 配置环境变量5. 运行 talker/listener 演示 参加测验&#x…

一起搭WPF界面之MVVM架构的简单搭建

一起搭WPF界面之MVVM架构的简单搭建 1 前言2 创建项目2.1新建项目2.2WPF2.3创建完成 3 MVVM划分3.1 划分逻辑3.2文件夹创建3.3文件创建3.3.1 Views——可在主界面的基础上&#xff0c;划分多个用户控件模块3.3.2 ViewModels——创建数据结构存放的cs文件3.3.3 Models——创建处…

在 VM 虚拟机中安装 openEuler + 桌面

在 VM 虚拟机中安装 openEuler 1 介绍2 步骤语言Root 账户安装位置网络和主机名自动检索到【推荐】手动配置网络 软件选择安装完成登录测试网络curl ip / ping ipip link show / ip a如网络不通&#xff0c;可检查网卡状态和dns配置 安装命令设置以图形界面的方式启动【dde】第…

sql-libs通关详解

1-4关 1.第一关 我们输入?id1 看回显&#xff0c;通过回显来判断是否存在注入&#xff0c;以及用什么方式进行注入&#xff0c;直接上图 可以根据结果指定是字符型且存在sql注入漏洞。因为该页面存在回显&#xff0c;所以我们可以使用联合查询。联合查询原理简单说一下&…

PyTorch之ResNet101模型与示例

【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客 ResNet101模型 ResNet101是一种深度残差网络&#xff0c;它是ResNet系列中的一种&#xff0c;下面详解ResNet101网络结构。 ResNet101网络结构中有101层&#xff0c;其中第一层是77的卷积层&#xff0c;然后是4个…

Nacos 配置中心配置加载源码分析

前言&#xff1a;上一篇我们分析 Nacos 配置中心服务端源码的时候&#xff0c;多次看到有去读取本地配置文件&#xff0c;那本地配置文件是何时加载的&#xff1f;本篇我们来进行详细分析。 Nacos 系列文章传送门&#xff1a; Nacos 初步认识和 Nacos 部署细节 Nacos 配置管…

https改造-python https 改造

文章目录 前言https改造-python https 改造1.1. https 配置信任库2. 客户端带证书https发送,、服务端关闭主机、ip验证 前言 如果您觉得有用的话&#xff0c;记得给博主点个赞&#xff0c;评论&#xff0c;收藏一键三连啊&#xff0c;写作不易啊^ _ ^。   而且听说点赞的人每…

遗传算法与深度学习实战——进化深度学习

遗传算法与深度学习实战——进化深度学习 0. 前言1. 进化深度学习1.1 进化深度学习简介1.2 进化计算简介 2. 进化深度学习应用场景3. 深度学习优化3.1 优化网络体系结构 4. 通过自动机器学习进行优化4.1 自动机器学习简介4.2 AutoML 工具 5. 进化深度学习应用5.1 模型选择&…

Java给定一些元素随机从中选择一个

文章目录 代码实现java.util.Random类实现随机取数(推荐)java.util.Collections实现(推荐)Java 8 Stream流实现(不推荐) 完整代码参考&#xff08;含测试数据&#xff09; 在Java中&#xff0c;要从给定的数据集合中随机选择一个元素&#xff0c;我们很容易想到可以使用 java.…