竞赛 基于GRU的 电影评论情感分析 - python 深度学习 情感分类

news2025/1/23 13:13:48

文章目录

  • 1 前言
    • 1.1 项目介绍
  • 2 情感分类介绍
  • 3 数据集
  • 4 实现
    • 4.1 数据预处理
    • 4.2 构建网络
    • 4.3 训练模型
    • 4.4 模型评估
    • 4.5 模型预测
  • 5 最后

1 前言

🔥 优质竞赛项目系列,今天要分享的是

基于GRU的 电影评论情感分析

该项目较为新颖,适合作为竞赛课题方向,学长非常推荐!

🧿 更多资料, 项目分享:

https://gitee.com/dancheng-senior/postgraduate

1.1 项目介绍

其实,很明显这个项目和微博谣言检测是一样的,也是个二分类的问题,因此,我们可以用到学长之前提到的各种方法,即:

朴素贝叶斯或者逻辑回归以及支持向量机都可以解决这个问题。

另外在深度学习中,我们可以用CNN-Text或者RNN以及LSTM等模型最好。

当然在构建网络中也相对简单,相对而言,LSTM就比较复杂了,为了让不同层次的同学们可以接受,学长就用了相对简单的GRU模型。

如果大家想了解LSTM。以后,学长会给大家详细介绍。

2 情感分类介绍

其实情感分析在自然语言处理中,情感分析一般指判断一段文本所表达的情绪状态,属于文本分类问题。一般而言:情绪类别:正面/负面。当然,这就是为什么本人在前面提到情感分析实际上也是二分类问题的原因。

3 数据集

学长本次使用的是非常典型的IMDB数据集。

该数据集包含来自互联网的50000条严重两极分化的评论,该数据被分为用于训练的25000条评论和用于测试的25000条评论,训练集和测试集都包含50%的正面评价和50%的负面评价。该数据集已经经过预处理:评论(单词序列)已经被转换为整数序列,其中每个整数代表字典中的某个单词。

查看其数据集的文件夹:这是train和test文件夹。

在这里插入图片描述

接下来就是以train文件夹介绍里面的内容
在这里插入图片描述

然后就是以neg文件夹介绍里面的内容,里面会有若干的text文件:
在这里插入图片描述

4 实现

4.1 数据预处理



    #导入必要的包
    import zipfile
    import os
    import io
    import random
    import json
    import matplotlib.pyplot as plt
    import numpy as np
    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Embedding
    from paddle.fluid.dygraph.base import to_variable
    from paddle.fluid.dygraph import GRUUnit
    import paddle.dataset.imdb as imdb


    #加载字典
    def load_vocab():
        vocab = imdb.word_dict()
        return vocab
    #定义数据生成器
    class SentaProcessor(object):
        def __init__(self):
            self.vocab = load_vocab()
    
        def data_generator(self, batch_size, phase='train'):
            if phase == "train":
                return paddle.batch(paddle.reader.shuffle(imdb.train(self.vocab),25000), batch_size, drop_last=True)
            elif phase == "eval":
                return paddle.batch(imdb.test(self.vocab), batch_size,drop_last=True)
            else:
                raise ValueError(
                    "Unknown phase, which should be in ['train', 'eval']")



步骤

  1. 首先导入必要的第三方库

  2. 接下来就是数据预处理,需要注意的是:数据是以数据标签的方式表示一个句子,因此,每个句子都是以一串整数来表示的,每个数字都是对应一个单词。当然,数据集就会有一个数据集字典,这个字典是训练数据中出现单词对应的数字标签。

4.2 构建网络

这次的GRU模型分为以下的几个步骤

  • 定义网络
  • 定义损失函数
  • 定义优化算法

具体实现如下


#定义动态GRU
class DynamicGRU(fluid.dygraph.Layer):
def init(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation=‘sigmoid’,
candidate_activation=‘relu’,
h_0=None,
origin_mode=False,
):
super(DynamicGRU, self).init()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[ :, i:i+1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res

class GRU(fluid.dygraph.Layer):
    def __init__(self):
        super(GRU, self).__init__()
        self.dict_dim = train_parameters["vocab_size"]
        self.emb_dim = 128
        self.hid_dim = 128
        self.fc_hid_dim = 96
        self.class_dim = 2
        self.batch_size = train_parameters["batch_size"]
        self.seq_len = train_parameters["padding_size"]
        self.embedding = Embedding(
            size=[self.dict_dim + 1, self.emb_dim],
            dtype='float32',
            param_attr=fluid.ParamAttr(learning_rate=30),
            is_sparse=False)
        h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
        h_0 = to_variable(h_0)
        
        self._fc1 = Linear(input_dim=self.hid_dim, output_dim=self.hid_dim*3)
        self._fc2 = Linear(input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="relu")
        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
                                output_dim=self.class_dim,
                                act="softmax")
        self._gru = DynamicGRU(size=self.hid_dim, h_0=h_0)
        
    def forward(self, inputs, label=None):
        emb = self.embedding(inputs)
        o_np_mask =to_variable(inputs.numpy().reshape(-1,1) != self.dict_dim).astype('float32')
        mask_emb = fluid.layers.expand(
            to_variable(o_np_mask), [1, self.hid_dim])
        emb = emb * mask_emb
        emb = fluid.layers.reshape(emb, shape=[self.batch_size, -1, self.hid_dim])
        fc_1 = self._fc1(emb)
        gru_hidden = self._gru(fc_1)
        gru_hidden = fluid.layers.reduce_max(gru_hidden, dim=1)
        tanh_1 = fluid.layers.tanh(gru_hidden)
        fc_2 = self._fc2(tanh_1)
        prediction = self._fc_prediction(fc_2)
        
        if label is not None:
            acc = fluid.layers.accuracy(prediction, label=label)
            return prediction, acc
        else:
            return prediction

4.3 训练模型


def train():
with fluid.dygraph.guard(place = fluid.CUDAPlace(0)): # # 因为要进行很大规模的训练,因此我们用的是GPU,如果没有安装GPU的可以使用下面一句,把这句代码注释掉即可
# with fluid.dygraph.guard(place = fluid.CPUPlace()):

        processor = SentaProcessor()
        train_data_generator = processor.data_generator(batch_size=train_parameters["batch_size"], phase='train')

        model = GRU()
        sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=train_parameters["lr"],parameter_list=model.parameters())

        steps = 0
        Iters, total_loss, total_acc = [], [], []
        for eop in range(train_parameters["epoch"]):
            for batch_id, data in enumerate(train_data_generator()):

                steps += 1
                doc = to_variable(
                    np.array([
                        np.pad(x[0][0:train_parameters["padding_size"]], 
                              (0, train_parameters["padding_size"] - len(x[0][0:train_parameters["padding_size"]])),
                               'constant',
                              constant_values=(train_parameters["vocab_size"]))
                        for x in data
                    ]).astype('int64').reshape(-1))
                label = to_variable(
                    np.array([x[1] for x in data]).astype('int64').reshape(
                        train_parameters["batch_size"], 1))
        
                model.train()
                prediction, acc = model(doc, label)
                loss = fluid.layers.cross_entropy(prediction, label)
                avg_loss = fluid.layers.mean(loss)
                avg_loss.backward()
                sgd_optimizer.minimize(avg_loss)
                model.clear_gradients()
 
                if steps % train_parameters["skip_steps"] == 0:
                    Iters.append(steps)
                    total_loss.append(avg_loss.numpy()[0])
                    total_acc.append(acc.numpy()[0])
                    print("step: %d, ave loss: %f, ave acc: %f" %
                         (steps,avg_loss.numpy(),acc.numpy()))

                if steps % train_parameters["save_steps"] == 0:
                    save_path = train_parameters["checkpoints"]+"/"+"save_dir_" + str(steps)
                    print('save model to: ' + save_path)
                    fluid.dygraph.save_dygraph(model.state_dict(),
                                                   save_path)
    draw_train_process(Iters, total_loss, total_acc)

在这里插入图片描述
在这里插入图片描述

4.4 模型评估

在这里插入图片描述

结果还可以,这里说明的是,刚开始的模型训练评估不可能这么好,很明显是过拟合的问题,这就需要我们调整我们的epoch、batchsize、激活函数的选择以及优化器、学习率等各种参数,通过不断的调试、训练最好可以得到不错的结果,但是,如果还要更好的模型效果,其实可以将GRU模型换为更为合适的RNN中的LSTM以及bi-
LSTM模型会好很多。

4.5 模型预测


train_parameters[“batch_size”] = 1

with fluid.dygraph.guard(place = fluid.CUDAPlace(0)):

    sentences = 'this is a great movie'
    data = load_data(sentences)
    print(sentences)
    print(data)
    data_np = np.array(data)
    data_np = np.array(np.pad(data_np,(0,150-len(data_np)),"constant",constant_values =train_parameters["vocab_size"])).astype('int64').reshape(-1)
    infer_np_doc = to_variable(data_np)

    model_infer = GRU()
    model, _ = fluid.load_dygraph("data/save_dir_750.pdparams")
    model_infer.load_dict(model)
    model_infer.eval()
    result = model_infer(infer_np_doc)
    print('预测结果为:正面概率为:%0.5f,负面概率为:%0.5f' % (result.numpy()[0][0],result.numpy()[0][1]))

在这里插入图片描述

训练的结果还是挺满意的,到此为止,我们的本次项目实验到此结束。

5 最后

🧿 更多资料, 项目分享:

https://gitee.com/dancheng-senior/postgraduate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1600966.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

注册表让我重回80年代(狗头保命

现在是2024年4月16日23:09:07,今天之所以这么晚才睡,是因为遇到了一个很有意思的事情,以至于解决完之后,强挺困意,将其记录—— 缘由是想只用键盘操纵电脑,上面有写,那用winR就是家常便饭。只不…

贴片滚珠振动开关 / 振动传感器的用法

就是这种小东西: 上面的截图来自:https://item.szlcsc.com/3600130.html 以前写过一篇介绍这种东西内部的结构原理:贴片微型滚珠振动开关的结构原理。就是有个小滚珠会接通开关两边的电极,振动时滚珠会在内部蹦跳,开关…

基于Springboot的影城管理系统

基于SpringbootVue的影城管理系统的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录 首页展示 电影信息 电影资讯 后台登录页 后台首页 用户管理 电影类型管理 放映…

RabbitMQ Stream插件使用详解

2.4版为RabbitMQ流插件引入了对RabbitMQStream插件Java客户端的初始支持。 RabbitStreamTemplateStreamListener容器 将spring rabbit流依赖项添加到项目中&#xff1a; <dependency><groupId>org.springframework.amqp</groupId><artifactId>sprin…

WebKit内核游览器

WebKit内核游览器 基础概念游览器引擎Chromium 浏览器架构Webkit 资源加载这里就不得不提到http超文本传输协议这个概念了&#xff1a; 游览器多线程HTML 解析总结 基础概念 百度百科介绍 WebKit 是一个开源的浏览器引擎&#xff0c;与之相对应的引擎有Gecko&#xff08;Mozil…

C# 字面量null对于引用类型变量和值类型变量

编译器让相同的字符串字面量共享堆中的同一内存位置以节约内存。 在C#中&#xff0c;字面量&#xff08;literal&#xff09;是指直接表示固定值的符号&#xff0c;比如数字、字符串或者布尔值。而关键字&#xff08;keyword&#xff09;则是由编程语言定义的具有特殊含义的标…

mysql 转pg 两者不同的地方

因项目数据库&#xff08;原来是MySQL&#xff09;要改成PostgreSQL。 项目里面的sql要做一些调整。 1&#xff0c;写法上的区别&#xff1a; 1&#xff0c;数据准备&#xff1a; 新建表格&#xff1a; CREATE TABLE property_config ( CODE VARCHAR(50) NULL…

PHP一句话木马

一句话木马 PHP 的一句话木马是一种用于 Web 应用程序漏洞利用的代码片段。它通常是一小段 PHP 代码&#xff0c;能够在目标服务器上执行任意命令。一句话木马的工作原理是利用 Web 应用程序中的安全漏洞&#xff0c;将恶意代码注入到服务器端的 PHP 脚本中。一旦执行&#xf…

免费ssl通配符证书申请教程

在互联网安全日益受到重视的今天&#xff0c;启用HTTPS已经成为网站运营的基本要求。它不仅保障用户数据传输的安全&#xff0c;提升搜索引擎排名&#xff0c;还能增强用户对网站的信任。通配符证书是一种SSL/TLS证书&#xff0c;用于同时保护一个域名及其所有下一级子域名的安…

【Qt】:界面优化(一:基本语法)

界面优化 一.基本语法1.设置指定控件样式2.设置全局控件样式3.从文件加载样式表4.使⽤Qt Designer编辑样式&#xff08;最常用&#xff09; 二.选择器1.概述2.子控件选择器3.伪类型选择器 三.盒模型 在网页前端开发领域中,CSS是一个至关重要的部分.描述了一个网页的"样式&…

基于JavaWeb开发的springboot网约车智能接单规划小程序[附源码]

基于JavaWeb开发的springboot网约车智能接单规划小程序[附源码] &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; &#x1f345; 查看下方微信号获取联系方式 承接各种…

Weakly Supervised Audio-Visual Violence Detection 论文阅读

Weakly Supervised Audio-Visual Violence Detection 论文阅读 摘要III. METHODOLOGYA. Multimodal FusionB. Relation Modeling ModuleC. Training and Inference IV. EXPERIMENTSV. CONCLUSION阅读总结 文章信息&#xff1a; 发表于&#xff1a;IEEE TRANSACTIONS ON MULTIME…

上海亚商投顾:沪指低开低走跌 两市逾千股跌超10%

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 一.市场情绪 三大指数昨日低开低走&#xff0c;沪指跌超1.6%&#xff0c;深成指、创业板指尾盘跌逾2%&#xff0c;微盘股指…

【c 语言】结构体指针

&#x1f388;个人主页&#xff1a;豌豆射手^ &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 &#x1f917;收录专栏&#xff1a;C语言 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共同学习、交流进步&…

2024年4月最新版GPT

2024年4月最新版ChatGPT/GPT4, 附上最新的使用教程。 随着人工智能技术的不断发展&#xff0c;ChatGPT和GPT4已经成为了人们日常生活中不可或缺的助手。2024年4月,OpenAI公司推出了最新版本的GPT4,带来了更加强大的功能和更加友好的用户体验。本文将为大家带来最新版GPT4的实用…

外网如何访问内网数据库?

企业和个人常常需要在外部网络环境中访问内部的数据库资源。这也是为了实现更大范围的资源共享和便捷的工作模式。由于网络安全和防火墙的限制&#xff0c;外网访问内网数据库并不是一件容易的事情。 在解决这个问题的过程中&#xff0c;天联组网应运而生。天联组网是一款异地组…

AIGC实战——VQ-GAN(Vector Quantized Generative Adversarial Network)

AIGC实战——VQ-GAN 0. 前言1. VQ-GAN2. ViT VQ-GAN小结系列链接 0. 前言 本节中&#xff0c;我们将介绍 VQ-GAN (Vector Quantized Generative Adversarial Network) 和 ViT VQ-GAN&#xff0c;它们融合了变分自编码器 (Variational Autoencoder, VAE)、Transformer 和生成对…

科技驱动未来,提升AI算力,GPU扩展正当时

要说这两年最火的科技是什么&#xff1f;我想“AI人工智能”肯定是最有资格上榜的&#xff0c;尤其ChatGPT推出后迅速在社交媒体上走红&#xff0c;短短5天&#xff0c;注册用户数就超过100万&#xff0c;2023年一月末&#xff0c;ChatGPT的月活用户更是突破1亿&#xff0c;成为…

内网kift私有网盘如何实现在外网公网访问?快解析映射方案

KIFT是一款面向个人、团队、小型组织的网盘服务器系统&#xff0c;安装运行比较简单&#xff0c;开箱即用&#xff0c;下载解压&#xff0c;双击jar文件即可启动。因为是开源的&#xff0c;不少人选择使用KIFT做开源私有网盘&#xff0c;有能力的大佬还可以对它进行定制开发。 …

python 列表对象函数

对象函数必须通过一个对象调用。 列表名.函数名() append() 将某一个元素对象添加在列表的表尾 如果添加的是其他的序列&#xff0c;该序列也会被看成是一个数据对象 count() 统计列表当中 某一个元素出现的次数 extend() 在当前列表中 将传入的其他序列的元素添加在表尾…