Tensorflow2.0笔记 - 循环神经网络RNN做IMDB评价分析

news2024/12/26 23:14:43

        本笔记记录使用SimpleRNNCell做一个IMDB评价系统情感二分类问题的例子。

import os
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
#np.random.seed(22)
tf.__version__


#取常见的10000个单词
total_words = 10000
#句子最长的单词数量设置为80
max_review_len = 80
#embedding设置为100,表示每个单词用100维向量表示
embedding_len = 100
#加载IMDB数据集
(x_train,y_train), (x_test, y_test) = datasets.imdb.load_data(num_words = total_words)
#对训练数据和测试数据的句子进行填充或截断
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)

#构建数据集
batchsize = 128
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsize, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsize, drop_remainder=True)

#x_train包含25000个句子,每个句子包含80个单词,y_train标签为1表示好评,0表示差评
print('x_train: shape - ', x_train.shape, ' y_train: max/min -', tf.reduce_max(y_train).numpy(), '/', tf.reduce_min(y_train).numpy())
print('x_test: shape - ', x_test.shape)


class MyRNN(keras.Model):
    #units:state的维度
    def __init__(self, total_words, embedding_len, max_review_len, units):
        super(MyRNN, self).__init__()
        #初始的序列状态初始化为0(第0时刻的状态)
        self.state0 = [tf.zeros([batchsize, units])]
        self.state1 = [tf.zeros([batchsize, units])]
        #embedding层,将文本转换为embedding表示
        #[b, 80] => [b, 80, 100]
        self.embedding = layers.Embedding(total_words, embedding_len, input_length=max_review_len)
        #[b, 80, 100] , units: 64 - 转换为64维的state [b, 64]
        self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.2)
        self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.2)
        #全连接层 [[b, 64] => [b, 1]
        self.outlayer = layers.Dense(1)
    #inputs: [b, 80] 
    def call(self, inputs, training=None):
        x = inputs
        #做embedding,[b,80] => [b, 80, 100]
        x = self.embedding(x)
        #做RNN cell计算
        #[b, 80, 100] => [b,  64]
        #遍历句子中的每个单词
        # word: [b, 100]
        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):
            #h1 = x*w_xh + h0*w_hh
            out0, state0 = self.rnn_cell0(word, state0, training)
            out1, state1 = self.rnn_cell1(out0, state1)
        #循环完毕后,得到的out为[b, 64],表示每个句子最终得到的状态
        x = self.outlayer(out1)
        #计算最终评价结果
        prob = tf.sigmoid(x)
        return prob

def main():
    units = 64
    epochs = 15
    lr = 0.001

    model = MyRNN(total_words, embedding_len, max_review_len, units)
    model.compile(optimizer = optimizers.Adam(lr), loss = tf.losses.BinaryCrossentropy(),
                 metrics=['accuracy'])
    model.fit(db_train, epochs=epochs, validation_data=db_test)

    model.evaluate(db_test)

if __name__ == '__main__':
    main()

运行结果:

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1657341.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

模拟实现链表的功能

1.什么是链表? 链表是一种物理存储结构上非连续存储结构,数据元素的逻辑顺序是通过链表中的引用链接次序实现的 。 实际中链表的结构非常多样,以下情况组合起来就有8种链表结构: 单向或者双向 带头或者不带头 …

机器学习:基于线性回归、岭回归、xgboost回归、Lasso回归、随机森林回归预测卡路里消耗

前言 系列专栏:机器学习:高级应用与实践【项目实战100】【2024】✨︎ 在本专栏中不仅包含一些适合初学者的最新机器学习项目,每个项目都处理一组不同的问题,包括监督和无监督学习、分类、回归和聚类,而且涉及创建深度学…

小丑的身份证和复印件 (BFS + Floyd)

本题链接:登录—专业IT笔试面试备考平台_牛客网 题目: 样例: 输入 2 10 (JOKERjoke #####asdr) 输出 12 思路: 根据题意,要求最短时间,实际上也可以理解为最短距离。 所以应该联想到有关最短距离的算法&…

【图文教程】PyCharm安装配置PyQt5+QtDesigner+PyUic+PyRcc

这里写目录标题 PyQt5、Qt Designer、PyUic、PyRcc简介(1)下载安装PyQt5(2)打开designer.exe所在位置(3)在PyCharm中配置QtDesigner(4)验证QtDesigner是否配置成功(5&…

重学java 34.API 5.工具类

有失才有悟,崩塌后的重建只会更牢固 —— 24.5.9 一、System类 1.概述: 系统相关类,是一个工具类 2.特点: a.构造私有,不能利用构造方法new对象 b.方法都是静态的 3.使用: 类名直接调用 4.方法 方法 …

Linux系统入侵排查(二)

前言 为什么要做系统入侵排查 入侵排查1 1.排查历史命令记录 2.可疑端口排查 3.可疑进程排查 4.开机启动项 4.1系统运行级别示意图: 4.2查看运行级别命令 4.3系统默认允许级别 4.4.开机启动配置文件 入侵排查2: 1.启动项文件排查&#xff1…

Python从0到POC编写--实用小脚本

UrlCheck: 假设我们要对一份 url 列表进行访问是不是 200 , 量多的话肯定不能一个一个去点开看, 这个时候我们可以借助脚本去判断, 假如有一份这样的列表, 这份列表呢,奇奇怪怪,有些写错了…

基于Spring Boot的公司OA系统设计与实现

基于Spring Boot的银行OA系统设计与实现 开发语言:Java 框架:springboot JDK版本:JDK1.8 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea 系统部分展示 用户登录界面,在银行OA系统运行后&#x…

刷题第3天(中等题):LeetCode24--两两交换链表中的节点--递归法

LeetCode24: 给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 示例 1: 输入:head [1,2,3,4…

FastDFS - 无法获取服务端连接资源:can‘t create connection to/xx.xx.xx.xx:0

问题描述 根据官方文档 安装完FastDFS服务器后, 服务正常启动,但是在 SpringBoot 项目使用 fastdfs-client 客户端报错无法获取服务端连接资源:cant create connection to/xx.xx.xx.xx:0, 一系列排查发现是获取到的 tracker 端口为 0 。 co…

Docx文件误删除如何恢复?别再花冤枉钱了,4个高效恢复软件!

不管是工作还是学习,总是会与各种各样的文件打交道。文件量越多就越容易出现文件丢失、文件误删的情况。遇到这些情况,失去的文件还能找回来吗?只要掌握了一些数据恢复方法,是很有机会恢复回来的,下面我会将这些方法分…

生信分析进阶2 - 利用GC含量的Loess回归矫正reads数量

在NGS数据比对后,需要矫正GC偏好引起的reads数量误差可用loess回归算法,使用R语言对封装的loess算法实现。 在NIPT中,GC矫正对检测结果准确性非常重要,具体研究参考以下文章。 Noninvasive Prenatal Diagnosis of Fetal Trisomy…

static静态成员变量和静态方法

当有new创建一个对象的,里面属性和方法,通过构造函数,能定义多个不同的对象,在我们做面向对象开发的时候,给一个场景,人在一个班级的时候,你的老师可能是固定的。 当我们用构造方法去构造的时候,每次都去传递一个固定的实参去定义个老师。 这样好会显得代码非常的…

DNS 解析在网络传输中有什么意义?

首先我们先说说什么是DNS解析? DNS解析是将域名解析为对应的IP地址的过程。DNS它作为将域名和IP地址相互映射的一个分布式数据库,能够使人更方便地访问互联网。DNS解析的过程就是寻找哪个IP地址对应你所输入的网址,然后将网页内容返回给用户…

常用的文件摆渡系统有哪些 | 好用的文件摆渡系统推荐

一、什么是文件摆渡系统 简单来说,文件摆渡系统是一种高效的、以文件为中心的文件管理系统,它的出现旨在解决企业在文件传输、共享和管理过程中的种种痛点。 更为值得一提的是,文件摆渡系统还具备强大的安全合规性,能够有效防止…

MultiBooth:文本驱动的多概念图像生成技术

在人工智能的领域,将文本描述转换为图像的技术正变得越来越先进。最近,一个由清华大学和Meta Reality Labs的研究人员组成的团队,提出了一种名为MultiBooth的新方法,它能够根据用户的文本提示,生成包含多个定制概念的图…

pytorch加载模型出现错误

大概的错误长下面这样: 问题出现的原因: ​很明显,我就是犯了第一种错误。 网上的修改方法: 我觉得按道理哈,确实,蓝色部分应该是可以把问题解决了的​。​但是我没有解决,因为我犯了另外一个错…

Django关于ORM的增删改查

Django中使用orm进行数据库的管理,主要包括以下步骤 1、创建model, 2、进行迁移 3、在视图函数中使用 以下的内容可以先从查询开始看,这样更容易理解后面删除部分代码 主要包括几下几种: 1、增 1)实例例化model,代…

struct和union大小计算规则

Union 一:联合类型的定义 联合也是一种特殊的自定义类型,这种类型定义的变量也包含一系列的成员,特征是这些成员公用同一块空间(所以联合也叫共用体) 比如:共用了 i 这个较大的空间 二: 联合的…

每日Attention学习4——Spatial Attention Module

模块出处 [link] [code] [MM 21] Complementary Trilateral Decoder for Fast and Accurate Salient Object Detection 模块名称 Spatial Attention Module (SAM) 模块作用 空间注意力 模块结构 模块代码 import torch import torch.nn as nn import torch.nn.functional a…