NLP之LSTM与BiLSTM

news2024/10/7 18:20:06

文章目录

  • 代码展示
  • 代码解读
  • 双向LSTM介绍(BiLSTM)

代码展示

import pandas as pd
import tensorflow as tf
tf.random.set_seed(1)
df = pd.read_csv("../data/Clothing Reviews.csv")
print(df.info())

df['Review Text'] = df['Review Text'].astype(str)
x_train = df['Review Text']
y_train = df['Rating']
print(y_train.unique())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 23486 entries, 0 to 23485
Data columns (total 11 columns):
 #   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 
 0   Unnamed: 0               23486 non-null  int64 
 1   Clothing ID              23486 non-null  int64 
 2   Age                      23486 non-null  int64 
 3   Title                    19676 non-null  object
 4   Review Text              22641 non-null  object
 5   Rating                   23486 non-null  int64 
 6   Recommended IND          23486 non-null  int64 
 7   Positive Feedback Count  23486 non-null  int64 
 8   Division Name            23472 non-null  object
 9   Department Name          23472 non-null  object
 10  Class Name               23472 non-null  object
[4 5 3 2 1]
from tensorflow.keras.preprocessing.text import Tokenizer

dict_size = 14848
tokenizer = Tokenizer(num_words=dict_size)

tokenizer.fit_on_texts(x_train)
print(len(tokenizer.word_index),tokenizer.index_word)

x_train_tokenized = tokenizer.texts_to_sequences(x_train)
from tensorflow.keras.preprocessing.sequence import pad_sequences
max_comment_length = 120
x_train = pad_sequences(x_train_tokenized,maxlen=max_comment_length)

for v in x_train[:10]:
    print(v,len(v))
# 构建RNN神经网络
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Embedding,LSTM,Bidirectional
import tensorflow as tf

rnn = Sequential()
# 对于rnn来说首先进行词向量的操作
rnn.add(Embedding(input_dim=dict_size,output_dim=60,input_length=max_comment_length))
# RNN:simple_rnn (SimpleRNN)  (None, 100)   16100
# LSTM:simple_rnn (SimpleRNN)  (None, 100)  64400
rnn.add(Bidirectional(LSTM(units=100)))  # 第二层构建了100个RNN神经元
rnn.add(Dense(units=10,activation=tf.nn.relu))
rnn.add(Dense(units=6,activation=tf.nn.softmax))  # 输出分类的结果
rnn.compile(loss='sparse_categorical_crossentropy',optimizer="adam",metrics=['accuracy'])
print(rnn.summary())
result = rnn.fit(x_train,y_train,batch_size=64,validation_split=0.3,epochs=10)
print(result)
print(result.history)

代码解读

首先,我们来总结这段代码的流程:

  1. 导入了必要的TensorFlow Keras模块。
  2. 初始化了一个Sequential模型,这表示我们的模型会按顺序堆叠各层。
  3. 添加了一个Embedding层,用于将整数索引(对应词汇)转换为密集向量。
  4. 添加了一个双向LSTM层,其中包含100个神经元。
  5. 添加了两个Dense全连接层,分别包含10个和6个神经元。
  6. 使用sparse_categorical_crossentropy损失函数编译了模型。
  7. 打印了模型的摘要。
  8. 使用给定的训练数据和验证数据对模型进行了训练。
  9. 打印了训练的结果。

现在,让我们逐行解读代码:

  1. 导入依赖:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,SimpleRNN,Embedding,LSTM,Bidirectional
import tensorflow as tf

你导入了创建和训练RNN模型所需的TensorFlow Keras库。

  1. 初始化模型:
rnn = Sequential()

你选择了一个顺序模型,这意味着你可以简单地按顺序添加层。

  1. 添加Embedding层:
rnn.add(Embedding(input_dim=dict_size,output_dim=60,input_length=max_comment_length))

此层将整数索引转换为固定大小的向量。dict_size是词汇表的大小,max_comment_length是输入评论的最大长度。

  1. 添加LSTM层:
rnn.add(Bidirectional(LSTM(units=100)))

你选择了双向LSTM,这意味着它会考虑过去和未来的信息。它有100个神经元。

  1. 添加全连接层:
rnn.add(Dense(units=10,activation=tf.nn.relu))
rnn.add(Dense(units=6,activation=tf.nn.softmax))

这两个Dense层用于模型的输出,最后一层使用softmax激活函数进行6类的分类。

  1. 编译模型:
rnn.compile(loss='sparse_categorical_crossentropy',optimizer="adam",metrics=['accuracy'])

你选择了一个适合分类问题的损失函数,并选择了adam优化器。

  1. 显示模型摘要:
print(rnn.summary())

这将展示模型的结构和参数数量。

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 embedding (Embedding)       (None, 120, 60)           890880    
                                                                 
 bidirectional (Bidirectiona  (None, 200)              128800    
 l)                                                              
                                                                 
 dense (Dense)               (None, 10)                2010      
                                                                 
 dense_1 (Dense)             (None, 6)                 66        
                                                                 
=================================================================
Total params: 1,021,756
Trainable params: 1,021,756
Non-trainable params: 0
_________________________________________________________________
None
  1. 训练模型:
result = rnn.fit(x_train,y_train,batch_size=64,validation_split=0.3,epochs=10)

你用训练数据集训练了模型,其中30%的数据用作验证,训练了10个周期。

Epoch 1/10
257/257 [==============================] - 74s 258ms/step - loss: 1.2142 - accuracy: 0.5470 - val_loss: 1.0998 - val_accuracy: 0.5521
Epoch 2/10
257/257 [==============================] - 57s 221ms/step - loss: 0.9335 - accuracy: 0.6293 - val_loss: 0.9554 - val_accuracy: 0.6094
Epoch 3/10
257/257 [==============================] - 59s 229ms/step - loss: 0.8363 - accuracy: 0.6616 - val_loss: 0.9321 - val_accuracy: 0.6168
Epoch 4/10
257/257 [==============================] - 61s 236ms/step - loss: 0.7795 - accuracy: 0.6833 - val_loss: 0.9812 - val_accuracy: 0.6089
Epoch 5/10
257/257 [==============================] - 56s 217ms/step - loss: 0.7281 - accuracy: 0.7010 - val_loss: 0.9559 - val_accuracy: 0.6043
Epoch 6/10
257/257 [==============================] - 56s 219ms/step - loss: 0.6934 - accuracy: 0.7156 - val_loss: 1.0197 - val_accuracy: 0.5999
Epoch 7/10
257/257 [==============================] - 57s 220ms/step - loss: 0.6514 - accuracy: 0.7364 - val_loss: 1.1192 - val_accuracy: 0.6080
Epoch 8/10
257/257 [==============================] - 57s 222ms/step - loss: 0.6258 - accuracy: 0.7486 - val_loss: 1.1350 - val_accuracy: 0.6100
Epoch 9/10
257/257 [==============================] - 57s 220ms/step - loss: 0.5839 - accuracy: 0.7749 - val_loss: 1.1537 - val_accuracy: 0.6019
Epoch 10/10
257/257 [==============================] - 57s 222ms/step - loss: 0.5424 - accuracy: 0.7945 - val_loss: 1.1715 - val_accuracy: 0.5744
<keras.callbacks.History object at 0x00000244DCE06D90>
  1. 显示训练结果:
print(result)
<keras.callbacks.History object at 0x0000013AEAAE1A30>
print(result.history)
{'loss': [1.2142471075057983, 0.9334620833396912, 0.8363043069839478, 0.7795010805130005, 0.7280740141868591, 0.693393349647522, 0.6514003872871399, 0.6257606744766235, 0.5839114189147949, 0.5423741340637207], 
'accuracy': [0.5469586253166199, 0.6292579174041748, 0.6616179943084717, 0.6833333373069763, 0.7010340690612793, 0.7156326174736023, 0.7363746762275696, 0.748600959777832, 0.7748783230781555, 0.7944647073745728], 
'val_loss': [1.0997602939605713, 0.9553984999656677, 0.932131290435791, 0.9812102317810059, 0.9558586478233337, 1.019730806350708, 1.11918044090271, 1.1349923610687256, 1.1536787748336792, 1.1715185642242432], 
'val_accuracy': [0.5520862936973572, 0.609423816204071, 0.6168038845062256, 0.6088560819625854, 0.6043145060539246, 0.5999148488044739, 0.6080045700073242, 0.6099914908409119, 0.6019017696380615, 0.574368417263031]
}

这将展示训练过程中的损失和准确性等信息。

双向LSTM介绍(BiLSTM)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
例子:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1155534.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux进程控制】进程控制专篇

【Linux进程控制】进程控制专篇 目录 【Linux进程控制】进程控制专篇进程创建fork函数写实拷贝fork常规用法fork调用失败的原因 进程终止进程退出场景进程常见退出方法_exit函数return退出 进程等待进程等待必要性进程等待的方法获取子进程status 具体代码实现进程程序替换替换…

3.5 队列的表示和操作的实现

思维导图&#xff1a; 3.5.1 队列类型 3.5.1 队列的类型定义 1. 简介 队列是一种特殊的线性表&#xff0c;它的特性是只能在表的一端进行插入操作&#xff0c;而在另一端进行删除操作。通常将允许插入操作的一端称为队尾&#xff0c;允许删除操作的一端称为队头。 2. 抽象…

ArcGIS计算土地现状容积率

本文讲解在ArcGIS中,基于建筑数据和地籍边界数据,计算土地容积率。 一、容积率介绍 容积率(Plot Ratio/Floor Area Ratio/Volume Fraction)是指一个小区的地上建筑总面积与净用地面积的比率。又称建筑面积毛密度。 二、数据分析 (1)建筑数据(dwg) (2)地籍边界数据…

VsCode 只有一个标签页 编辑区只能打开一个文件

产生如图所示的问题&#xff1a; 可能是不小心取消了勾选 勾选&#xff0c;Show Tabs

软件安利——火绒安全

近年来&#xff0c;以优化、驱动、管理为目标所打造的软件屡见不鲜&#xff0c;大同小异的电脑管家相继走入了公众的视野。然而&#xff0c;在这日益急功近利的社会氛围驱动之下&#xff0c;真正坚持初心、优先考虑用户体验的电脑管家逐渐湮没在了浪潮之中。无论是鲁大师&#…

Enfocus PitStop Pro 2022(Acrobat dc增强)

Enfocus PitStop Pro 2022是一款Acrobat dc PDF编辑和校对工具&#xff0c;为Mac用户提供了强大的功能和精确的控制&#xff0c;以确保PDF文件的质量和准确性。该软件具有全面的PDF编辑功能&#xff0c;包括添加、删除或重新排列页面&#xff0c;合并和分割PDF文件&#xff0c;…

工程中Http的请求、各种回调函数的使用

文章目录 1、登录回调以及各种函数的使用1、SdoLoginClient工程中的SdoBase_Initialize3接口2、LoginClient中的Initialize接口3、ProcessResponse调用ProcessLoginResponse传递参数给回调函数使用4、ProcessLoginResponse登录响应接口的使用5、ProcessResponse调用然后根据req…

四、[mysql]索引优化-1

目录 前言一、场景举例1.联合索引第一个字段用范围查询不走索引(分情况&#xff09;2.强制走指定索引3.覆盖索引优化4.in和or在表数据量比较大的情况会走索引&#xff0c;在表记录不多的情况下会选择全表扫描5.like 后% 一般情况都会走索引(索引下推) 二、Mysql如何选择合适的索…

中移链浏览器简介

&#xff08;1&#xff09;简介 生活中&#xff0c;常用的互联网浏览器&#xff0c;是用来检索、展示以及传递Web信息资源的应用程序。用浏览器进行搜索&#xff0c;可以快速查找到目标信息。而对于区块链而言&#xff0c;也有区块链浏览器。 区块链浏览器&#xff0c;是指为用…

【案例实战】NodeJS+Vue3+MySQL实现列表查询功能

这篇文章&#xff0c;给大家带来一个列表查询的功能&#xff0c;从前端到后端的一个综合案例实战。 采用vue3作为前端开发&#xff0c;nodejs作为后端开发。 首先我们先来看一下完成的页面效果。点击分页&#xff0c;可以切换到上一页、下一页。搜索框可以进行模糊查询。 后端…

17.基干模型Swin-Transformer解读

文章目录 SWin-Transformer解读1.基础介绍关于Shifted Window based Self-Attention相对位置偏置网络整体结构和层级特征欢迎访问个人网络日志🌹🌹知行空间🌹🌹 SWin-Transformer解读 1.基础介绍 Swin-Transformer是2021年03月微软亚洲研究院提交的论文中提出的,比V…

【Linux】常见指令以及具体其使用场景

君兮_的个人主页 即使走的再远&#xff0c;也勿忘启程时的初心 C/C 游戏开发 Hello,米娜桑们&#xff0c;这里是君兮_&#xff0c;随着博主的学习&#xff0c;博主掌握的技能也越来越多&#xff0c;今天又根据最近的学习开设一个新的专栏——Linux&#xff0c;相信Linux操作系…

毛发渲染方案实现

一、毛发材质概述 以前毛发只能用离线来做 现在实时毛发逐渐可能。长毛渲染和短毛渲染采用的是不同的方案。 二、长毛类制作分析 各向异性 kajiya算法 # 三、短毛类制作分析 四、制作心得及技巧

Ansible中的playbook

目录 一、playbook简介 二、playbook的语法 三、playbook的核心组件 四、playbook的执行命令 五、vim 设定技巧 六、基本示例 一、playbook简介 1、playbook与ad-hoc相比&#xff0c;是一种完全不同的运用。 2、playbook是一种简单的配置管理系统与多机器部署系统的基础…

阿里云Apsara云栖大会2023

文章目录 2023/10/312023/11/012023/11/02彩蛋1&#xff1a;神州十六号彩蛋2&#xff1a;emm… 计算&#xff0c;为了无法计算的价值。 2023/10/31 合规性评审 2023/11/01 暂未开始 2023/11/02 暂未开始 彩蛋1&#xff1a;神州十六号 彩蛋2&#xff1a;emm…

Linux系统jdkTomcatMySQL安装以及J2EE后端接口部署

目录 一、jdk&Tomcat安装 1.1 上传安装包到服务器 1.2 解压对应工具包 1.3 配置环境变量并测试jdk安装 1.4 启动tomcat 1.5 防火墙设置 1.5.1 开启/关闭防火墙以及防火墙状态查看 1.5.2 开放防火墙端口 二、MySQL安装 2.1 卸载mariadb 2.2 在线下载MySQL安装包(…

全方位 Linux 性能调优经验总结

Part1Linux性能优化 1性能优化 性能指标 高并发和响应快对应着性能优化的两个核心指标&#xff1a;吞吐和延时 图片来自: www.ctq6.cn 应用负载角度&#xff1a;直接影响了产品终端的用户体验系统资源角度&#xff1a;资源使用率、饱和度等 性能问题的本质就是系统资源已经…

AI:46-基于深度学习的垃圾邮件识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌本专栏包含以下学习方向: 机器学习、深度学…

libpcap获取数据包

一、用户空间 以Linux以及TPACKET_V3为例。 调用pcap_dispatch获取数据包&#xff0c;然后回调用户传递的数据包处理函数。 read_op实际调用的是pcap_read_linux_mmap_v3 // pcap.c int pcap_dispatch(pcap_t *p, int cnt, pcap_handler callback, u_char *user) {return (p-…

2023年【R1快开门式压力容器操作】考试题及R1快开门式压力容器操作模拟考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 R1快开门式压力容器操作考试题是安全生产模拟考试一点通生成的&#xff0c;R1快开门式压力容器操作证模拟考试题库是根据R1快开门式压力容器操作最新版教材汇编出R1快开门式压力容器操作仿真模拟考试。2023年【R1快开…