Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)

news2024/11/14 6:05:02

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

一、模型资源下载

  1. RNN升级版LSTM模型:本项目训练好的情感分类模型-下载训练好的IMDB分类模型。

二、模型加载与推理

class RNN(nn.Cell):
    def __init__(self, embeddings, hidden_dim, output_dim, n_layers,
                 bidirectional, pad_idx):
        super().__init__()
        vocab_size, embedding_dim = embeddings.shape
        self.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings),
                                      padding_idx=pad_idx)
        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           batch_first=True)
        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))
        self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)

    def construct(self, inputs):
        embedded = self.embedding(inputs)
        _, (hidden, _) = self.rnn(embedded)
        hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        output = self.fc(hidden)
        return output

编写预测接口:test_interface

def predict_sentiment(model, vocab, sentence):
    score_map = {
        1: "Positive",
        0: "Negative"
    }
    model.set_train(False)
    tokenized = sentence.lower().split()
    indexed = vocab.tokens_to_ids(tokenized)
    tensor = ms.Tensor(indexed, ms.int32)
    tensor = tensor.expand_dims(0)
    prediction = model(tensor)
    return score_map[int(np.round(ops.sigmoid(prediction).asnumpy()))]

def test_interface():
    # train()
    score_map = {
        1: "Positive",
        0: "Negative"
    }

    ckpt_file_name = './IMDB/IMDB/sentiment-analysis.ckpt'
    # 预训练词向量表
    glove_path = r"./IMDB/IMDB/glove.6B.zip"
    vocab, embeddings = load_glove(glove_path)  # 预定义词向量表
    hidden_size = 256
    output_size = 1
    num_layers = 2
    bidirectional = True
    pad_idx = vocab.tokens_to_ids('<pad>')
    model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
    param_dict = ms.load_checkpoint(ckpt_file_name)
    ms.load_param_into_net(model, param_dict)

    # 预测
    while True:
        try:
            print("go on!")
            sentence = input("请输入:")
            res = predict_sentiment(model, vocab, sentence)
            print("用户输入的内容为:", sentence, "评价结果是:", res)
        except:
            break

def load_glove(glove_path):
    glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')  # 保存数据词典
    if not os.path.exists(glove_100d_path):
        glove_zip = zipfile.ZipFile(glove_path)
        glove_zip.extractall(cache_dir)

    embeddings = []
    tokens = []
    with open(glove_100d_path, encoding='utf-8') as gf:
        for glove in gf:
            word, embedding = glove.split(maxsplit=1)
            tokens.append(word)
            embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))
    # 添加 <unk>, <pad> 两个特殊占位符对应的embedding
    embeddings.append(np.random.rand(100))
    embeddings.append(np.zeros((100,), np.float32))

    vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)
    embeddings = np.array(embeddings).astype(np.float32)
    return vocab, embeddings

预测推理

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import os
import zipfile
import numpy as np

test_interface()

预测结果。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1946029.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vue实战教程】之 Vue Router 路由详解

Vue Router路由 1 路由基础 1.1 什么是路由 用Vue.js创建的项目是单页面应用&#xff0c;如果想要在项目中模拟出来类似于页面跳转的效果&#xff0c;就要使用路由。其实&#xff0c;我们不能只从字面的意思来理解路由&#xff0c;从字面上来看&#xff0c;很容易把路由联想…

CV13_混淆矩阵、F1分数和ROC曲线

1.1 混淆矩阵Confusion Matrix 混淆矩阵&#xff08;Confusion Matrix&#xff09;是机器学习和统计学中用于描述监督学习算法性能的特定表格布局。它是一种特定类型的误差矩阵&#xff0c;可以非常直观地表示分类模型在测试数据集上的预测结果与实际结果之间的对比。 混淆矩…

谣言检测文献阅读十二—A Convolutional Approach for Misinformation Identification

系列文章目录 谣言检测文献阅读一—A Review on Rumour Prediction and Veracity Assessment in Online Social Network谣言检测文献阅读二—Earlier detection of rumors in online social networks using certainty‑factor‑based convolutional neural networks谣言检测文…

FreeModbus学习——读输入寄存器eMBFuncReadInputRegister

FreeModbus版本&#xff1a;1.6 当功能码为04时&#xff0c;也就是读输入寄存器MB_FUNC_READ_INPUT_REGISTER 看一下它是怎么调用读输入寄存器处理函数的 当功能码为04时&#xff0c;调用读输入寄存器处理函数 这个函数在数组xFuncHandlers中&#xff0c;也就是eMBFuncRead…

Mysql数据库第四次作业

mysql> create table student(sno int primary key auto_increment,sname varchar(30) not null unique,Ssex varchar(2) check (Ssex男 or Ssex女) not null,Sage int not null,Sdept varchar(10) default计算机 not null); mysql> create table Course(Con int primar…

【通信协议-RTCM】MSM语句(2) - RINEXMSM7语句总结(重要!自动化开发计算卫星状态常用)

注释&#xff1a; 在工作中主要负责的是RTCM-MSM7语句相关开发工作&#xff0c;所以主要介绍的就是MSM7语句相关内容 1. 相位校准参考信号 2. MSM1、MSM2、MSM3、MSM4、MSM5、MSM6和MSM7的消息头内容 DATA FIELDDF NUMBERDATA TYPENO. OF BITSNOTES Message Number - 消息编…

1. Docker的介绍和安装 (二)

5 Docker的原理 5.1 Namespace Namespace&#xff08;命名空间&#xff09;提供了一个独立的工作环境&#xff0c;Docker使用Namespace来隔离容器&#xff0c;使得每个容器都有自己独立的系统资源&#xff08;如进程ID、主机名、网络等&#xff09;。 PID Namespace&#xf…

SBTI科学碳目标认证是什么?SBTI科学碳目标的重要性

SBTI科学碳目标认证&#xff0c;作为企业在应对气候变化和追求可持续发展道路上的重要里程碑&#xff0c;其认证过程严谨而系统。以下是获得SBTI科学碳目标认证的详细步骤&#xff1a; 首先&#xff0c;企业需要在线注册并提交承诺书&#xff0c;郑重承诺在未来24个月内提交科学…

Linux网络-配置IP

作者介绍&#xff1a;简历上没有一个精通的运维工程师。希望大家多多关注作者&#xff0c;下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 本来IP配置应该放在Linux安装完成的就要配置的&#xff0c;但是由于那个时候对Linux不怎么熟悉&#xff0c;所以单独列了一个…

每日一题 LeetCode03 无重复字符的最长字串

1.题目描述 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的最长字串的长度。 2 思路 可以用两个指针, 滑动窗口的思想来做这道题,即定义两个指针.一个left和一个right 并且用一个set容器,一个length , 一个maxlength来记录, 让right往右走,并且用一个set容器来…

扫雷-C语言

一、前言&#xff1a; 众所周知&#xff0c;扫雷是一款大众类的益智小游戏&#xff0c;它的游戏目标是在最短的时间内根据点击格子出现的数字找出所有非雷格子&#xff0c;同时避免踩雷&#xff0c;踩到一个雷即全盘皆输。 今天&#xff0c;我们的目的就是通过C语言来实现一个简…

Open函数使用 Json与pickle Os模块

一. 文件操作与 open() 函数 Open函数是Python中用于打开文件的内置函数&#xff0c;其基本语法如下&#xff1a; open(file, moder, buffering-1, encodingNone, errorsNone, newlineNone, closefdTrue, openerNone) 各参数说明&#xff1a; file&#xff1a; 要打开的文件…

34_YOLOv5网络详解

1.1 简介 YOLOV5是YOLO&#xff08;You Only Look Once&#xff09;系列目标检测模型的一个重要版本&#xff0c;由 Ultralytics 公司的Glenn Jocher开发并维护。YOLO系列以其快速、准确的目标检测能力而闻名&#xff0c;尤其适合实时应用。YOLOV5在保持高效的同时&#xff0c…

El-Table 表格的表头字段切换

最近写了一个小功能&#xff0c;比较有意思&#xff0c;特此博客记录。 提出需求&#xff1a;需要表头字段变化&#xff0c;但是我在官网上的表格相关上查找&#xff0c;没有发现便捷方法。 于是我有两个想法&#xff1a;1.做三个不同的表格。2.做一个表格使用不同的表头字段。…

2024.7.24 远程连接到另一设备(win)上的vrep时无响应(防火墙!)

Windows防火墙禁止了软件的端口的通信&#xff0c;打开即可 如何设置Windows 7 防火墙端口规则

字典集合案例

1.统计字符 统计字符串中每个字符出现的次数 s l like summer very much #去掉空格 s s.replace(" ","") d dict() for i in s:if i in d:d[i] 1else:d[i] 1 for i in d:print(i,d[i]) 2.求不重复的随机数 #导入随机数 import random a int(input(&q…

VMware 上安装 CentOS 7 教程 (包含网络设置)

**建议先看一些我安装VMware的教程&#xff0c;有些网络配置需要做一下 1.打开VMware&#xff0c;创建虚拟机 2.勾选自定义&#xff0c;点击下一步 3.点击下一步 4.勾选“稍后安装操作系统”&#xff0c;点击下一步 5.勾选linux&#xff0c;勾选centos7&#xff0c;点击下一步…

每日Attention学习12——Exterior Contextual-Relation Module

模块出处 [ISBI 22] [link] [code] Duplex Contextual Relation Network for Polyp Segmentation 模块名称 Exterior Contextual-Relation Module (ECRM) 模块作用 内存型特征增强模块 模块结构 模块思想 原文表述&#xff1a;在临床环境中&#xff0c;不同样本之间存在息肉…

Python算法基础:解锁冒泡排序与选择排序的奥秘

在数据处理和算法设计中&#xff0c;排序是一项基础且重要的操作。本文将介绍两种经典的排序算法&#xff1a;冒泡排序&#xff08;Bubble Sort&#xff09;和选择排序&#xff08;Selection Sort&#xff09;。我们将通过示例代码来演示这两种算法如何对列表进行升序排列。 一…

BGP选路之Local Preference

原理概述 当一台BGP路由器中存在多条去往同一目标网络的BGP路由时&#xff0c;BGP协议会对这些BGP路由的属性进行比较&#xff0c;以确定去往该目标网络的最优BGP路由。BGP首先比较的是路由信息的首选值&#xff08;PrefVal)&#xff0c;如果 PrefVal相同&#xff0c;就会比较本…