Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练

news2025/2/23 7:19:01

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

模型训练与推理

1. 模型加载

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')
# 模型加载
model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)

其中:vocab, embeddings = load_glove(glove_path)
模型构建和实例化参数:

  embeddings:输入向量,是数据集经过glove模型统一处理的词向量数值特征,
  hidden_dim:隐藏层特征的维度, 
  output_dim:输出维数, 
  n_layers:RNN 层的数量,
  bidirectional:是否为双向 RNN, 
  pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

2.模型训练

def train():
    # 音频数据集
    imdb_path = r'./IMDB/aclImdb_v1.tar.gz'

    # 训练集和测试集生成
    imdb_train, imdb_test = load_imdb(imdb_path)  # review评论-标签,数据集

    # 预训练词向量表
    glove_path = r"./IMDB/glove.6B.zip"
    vocab, embeddings = load_glove(glove_path)  # 预定义词向量表

    # 语句标签-数据集。将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。每条评论500字。
    lookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
    pad_op = ds.transforms.PadEnd([500],
                                  pad_value=vocab.tokens_to_ids('<pad>'))  # 使用PadEnd接口,定义最大长度和补齐值(pad_value),取最大长度为500
    type_cast_op = ds.transforms.TypeCast(ms.float32)  # 将label数据转为float32格式
    # 预处理操作流水线
    imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])
    imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])

    # 由于IMDB数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。
    imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
    # 调用数据集的map、split、batch为数据集处理流水线增加对应操作,返回值为新的Dataset类型。现在仅定义流水线操作,在执行时开始执行数据处理流水线,获取最终处理好的数据并送入模型进行训练。
    imdb_train = imdb_train.batch(64, drop_remainder=True)
    imdb_valid = imdb_valid.batch(64, drop_remainder=True)

    # 定义训练参数
    hidden_size = 256
    output_size = 1
    num_layers = 2
    bidirectional = True
    lr = 0.001
    pad_idx = vocab.tokens_to_ids('<pad>')

    model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)

    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss

    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    def train_one_epoch(model, train_dataset, epoch=0):
        model.set_train()
        total = train_dataset.get_dataset_size()
        loss_total = 0
        step_total = 0
        with tqdm(total=total) as t:
            t.set_description('Epoch %i' % epoch)
            for i in train_dataset.create_tuple_iterator():
                loss = train_step(*i)
                loss_total += loss.asnumpy()
                step_total += 1
                t.set_postfix(loss=loss_total / step_total)
                t.update(1)

    num_epochs = 50
    best_valid_loss = float('inf')
    ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')

    for epoch in range(num_epochs):
        train_one_epoch(model, imdb_train, epoch)
        valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            ms.save_checkpoint(model, ckpt_file_name)


if __name__ == "__main__":
    train()

训练完成:
在这里插入图片描述

3.小结

本节实现了情感分类模型的训练,精度达到99.9%,损失值降到0.0027。下一节将进行模型部署应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1954497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu上编译多个版本的frida

准备工作 Ubuntu20(WSL) 略 安装依赖 sudo apt update sudo apt-get install build-essential git lib32stdc-9-dev libc6-dev-i386 -y nodejs 去官网[1]下载nodejs&#xff0c;版本的话我就选的20.15.1&#xff1a; tar -xf node-v20.15.1-linux-x64.tar.xz 下载源码 …

科研论文之Word论文编辑

这篇文章介绍在word中怎么编辑论文&#xff0c;包括论文的模板、论文的字体设置、论文的插图、论文的参考文献等等。 为便利知识传播&#xff0c;我的所有文章都不会设置收费专栏。但文章写作不易&#xff0c;如有可能麻烦打赏一下&#xff0c;金额随意。收款码见下图&#xff…

Ubuntu下手动部署Java项目

1.1 打包项目上传至Ubuntu 1.2 java -jar 项目压缩包 1.3 确认防火墙打开 1.4 令进程在后台运行 nohup java -jar boot工程.jar &> hello.log & 1.5 停止项目运行 查看进程号&#xff0c;杀掉进程

基于微信小程序+SpringBoot+Vue的刷题系统(带1w+文档)

基于微信小程序SpringBootVue的刷题系统(带1w文档) 基于微信小程序SpringBootVue的刷题系统(带1w文档) 本系统是将网络技术和现代的管理理念相结合&#xff0c;根据试题信息的特点进行重新分配、整合形成动态的、分类明确的信息资源&#xff0c;实现了刷题的自动化&#xff0c;…

axure制作切换栏--动态面板的应用

先看下效果&#xff1a;点击上面的切换栏 切换到西游记栏目&#xff1a; 切换到水浒传栏目&#xff1a; 上述两个图片比对可以发现&#xff0c;在点击切换栏的时候&#xff0c;里面的内容以及切换栏的下面蓝色横线也会发生对应的变化。这里涉及到两个地方的变化&#xff0c;就…

VirtualBox虚拟机安装,Ubuntu iso 镜像下载

利用VirtualBox&#xff0c;在Windows主机上装Ubuntu的虚拟机 视频教程在这&#xff1a; Virtualbox虚拟机安装&#xff0c;Ubuntu iso镜像下载_哔哩哔哩_bilibili 一、Ubuntu iso 镜像下载 我们是要在Windows主机上装Ubuntu的虚拟机&#xff0c;下载下Ubuntu iso 镜像。下…

react中如何避免父子组件同时渲染(memo的使用)

1.需求说明 react的渲染机制是父子组件同时渲染&#xff0c;不管子组件是否有变化只要父组件重新渲染了子组件就跟着重新渲染。为了避免不必要的消耗&#xff0c;我们可以使用memo钩子函数 2.使用memo前展示 import { memo,useState } from "react"function Son()…

【03】Java虚拟机是如何加载Java类的

从class文件到内存中的类&#xff0c;按先后顺序需要经过加载、链接以及初始化三个步骤 一、加载 加载就是查找字节流&#xff0c;并且据此创建类的过程。 除了启动类加载器&#xff08;所有类加载器的祖师爷&#xff0c;由C实现&#xff0c;没有对应的Java对象&#xff09;之外…

uniapp实现局域网(内网)中APP自动检测版本,弹窗提醒升级

uniapp实现局域网&#xff08;内网&#xff09;中APP自动检测版本&#xff0c;弹窗提醒升级 在开发MES系统的过程中&#xff0c;涉及到了平板端APP的开发&#xff0c;既然是移动端的应用&#xff0c;那么肯定需要APP版本的自动更新功能。 查阅相关资料后&#xff0c;在uniapp的…

安全哈希算法:SHA算法

&#x1f3af; 主题简介 SHA&#xff08;Secure Hash Algorithm&#xff09;是比MD5更安全的哈希算法。通过案例形式了解SHA算法的原理、实现方法及注意细节。无论你是Python爱好者还是JavaScript高手&#xff0c;这篇内容都将为你提供一个深入了解SHA算法的机会。 &#x1f…

基于Libero的工程创建

基于Libero的工程创建 第一步&#xff1a;双击进入到工程界面&#xff0c;编写项目详细信息。 Project Name&#xff1a;标识您的项目名称。不要使用空格或保留的Verilog或VHDL关键字。 Project Location&#xff1a;在磁盘上标识您的项目位置。 Description&#xff1a;关于…

图论:1615. 最大网络秩(贪心,非完全图一定存在两个点之间没有边)

文章目录 1.计算出度排序哈希2.枚举3.贪心4.思考 1615. 最大网络秩 在不考虑两座道路直接相连时&#xff0c;我们求出入度&#xff08;或出度&#xff09;最大的两个点即可。 若相连&#xff0c;则存在一条边&#xff0c;所以我们将边存入一个集合中&#xff0c;快速查找是否存…

[每周一更]-(第107期):经典面试题-从输入URL到页面加载发生了什么

文章目录 过程概述简化版&#xff1a;详细版&#xff1a;1. 用户输入URL2. 浏览器解析URL3. DNS解析4. TCP连接5. SSL/TLS握手&#xff08;如果使用HTTPS&#xff09;6. HTTP请求和响应7. 浏览器渲染页面8. 处理后续请求 一般前后端都可以考察问题&#xff0c;让参与者了解网页…

WordPress设置固定连接后提示404

WordPress设置固定链接后出现404错误通常是因为服务器的伪静态规则没有正确设置。以下是几种常见的服务器环境下的解决方案&#xff1a; 宝塔面板&#xff1a;如果服务器安装了宝塔面板&#xff0c;可以在宝塔面板中选择对应的WordPress伪静态规则并保存设置 。 Apache服务器&a…

星间链路的卫星节点网络接口IP地址规划问题 based on 卫星互联网Walker星座

★★★第p个轨道面上的第n个卫星节点[ XL_p_n ]的IPv4子网和网络接口地址规划★★★ IPv4子网问题&#xff1a;中间2个点分十进制分别表示[P:轨道面索引]和[N:当前轨道面上的卫星索引]。考虑Exata设置IPv4子网默认为 190.0.0.0 &#xff0c;不妨&#xff1a; 将某个轨道高度的W…

【通信模块】简单玩转WiFi模块(ESP32、ESP8266)

笔者学习太极创客的学习笔记&#xff0c;链接如下&#xff1a;www.taichimaker.com 前期准备 电脑端口 固件烧录 WIFI到网页 对应七层网络协议 WIFI工作模式&#xff08;链路层&#xff09; 接入点模式、无线中断模式、混合模式 IP协议&#xff08;网络层&#xff09; 子网…

【python】Python考研分数 线性回归模型预测(源码+论文)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

前端缓存问题(浏览器缓存和http缓存)- 解决办法

问题描述&#xff1a;前端代码更新&#xff0c;但因浏览器缓存问题&#xff0c;导致页面源代码并未更新 查看页面源代码的方法&#xff1a;鼠标右键&#xff0c;点击查看页面源代码 如图&#xff1a; 解决方法&#xff1a; 注&#xff1a;每执行一步&#xff0c;就检查一下浏览…

c生万物系列(加减乘除模篇)

为了提高c语言的运行效率&#xff0c;我们需要采用更高效的运算&#xff0c;那么切入点就是随处可见的基本运算符合&#xff0c;从底层架构考虑&#xff0c;加减乘除的效率比位运算低很多&#xff0c;为了能够更好迎合CPU的二进制&#xff0c;有必要取代基本的加减乘除以及求余…

Java----队列(Queue)

目录 1.队列&#xff08;Queue&#xff09; 1.1概念 1.2队列的使用 1.3队列的模拟实现 1.4循环队列 1.4.1循环队列下标偏移 1.4.2如何区分队列是空还是满 1.5双端队列 (Deque) 1.队列&#xff08;Queue&#xff09; 1.1概念 队列&#xff1a;只允许在一端进行插入数据…